In [None]:
# In this file, we train and test all of the data, and we visualise it in several graphs

# run preprocessing and classification to use functions
%run Preprocessing.ipynb
%run Classification.ipynb

# create list of worded columns and list of columns to drop
classes = ["Compound", "A", "B", "In literature", "Lowest distortion"]
drop_list = ["Compound", "In literature", "A", "B", "τ"]

file_directory = "Crystal_structure.csv"

# get preprocessed data
clean = Preprocessor(file_directory, drop_list, classes, replaceNAN = True, 
                     useScaler = True, scaletype = "standard").Process()

# split data using conventional data
split_data = SplitData(clean)

# split data using stratified k folds and run classification k times
k_folds = KFoldSplit(clean, 10)

# make tuple of classifiers to use (if defined above) 
classifiers = (RandomForest, KNN, Logistic)

print(f"VISUAL DATA FOR SINGLE FOLD \n")

# we first run only one fold of data using conventional train/test split. This split will be visualised in a graph
for classifier in classifiers:
    classifier(split_data)
print(f"STRATIFIED CROSS VALIDATION USING {len(k_folds)} FOLDS ... \n")   

# run cross validation and store dictionary of classifiers with their f1 scores after k folds have been averaged
score_dictionary = CrossValidate(k_folds, classifiers)

print(score_dictionary)


In [None]:
import matplotlib.widgets

# We have now found that random forest yields the best results, so we use it
# from here on. Next we study the effect on the scores when one of the
# independent (feature) columns is removed.

# Number of folds used for every leave-one-column-out run
N_ABLATION_FOLDS = 3

# read the csv file and grab the columns we need
data = pd.read_csv(file_directory)

# every column except the dependent (target) column
independant_columns = list(data.columns)
independant_columns.remove("Lowest distortion")

# feature columns that are not already dropped by the baseline drop_list
used_columns = [column for column in independant_columns if column not in drop_list]

# only random forest is evaluated here; kept under its own name so the
# classifier tuple defined for the full comparison is not overwritten
ablation_classifiers = (RandomForest,)

# maps the name of the removed column to the random forest scores obtained without it
scores_dict = {}

# iterate through the columns and cross validate random forest with each one removed
for removed_column in used_columns:
    print(f"removing {removed_column}...")

    # Build a fresh drop list for this run instead of appending to / removing
    # from the shared drop_list: an error or interrupt part-way through the
    # loop can then never leave drop_list with an extra column in it.
    ablation_drop_list = drop_list + [removed_column]

    # get preprocessed data without the removed column
    ablation_clean = Preprocessor(file_directory, ablation_drop_list, classes, replaceNAN=True,
                                  useScaler=True, scaletype="standard").Process()

    # get the k fold split
    k_split = KFoldSplit(ablation_clean, N_ABLATION_FOLDS)

    # run Classification.ipynb's cross-validation function; the returned dict
    # is indexed by classifier name
    report = CrossValidate(k_split, ablation_classifiers)

    # store the random forest score under the removed column
    scores_dict[removed_column] = report["RandomForest"]

print(scores_dict)
# In scores_dict each key is the name of the removed column. According to the
# original note the value is a tuple where:
# 0: average score, 1: macro average score, 2: weighted average score
# (produced by CrossValidate - confirm against Classification.ipynb)

In [None]:
def callback(label: str) -> None:
    """Redraw the bar chart for the radio button that was clicked.

    Relies on the notebook-level names ``ax``, ``categories``,
    ``used_columns``, ``x`` and ``scores_dict``.

    Parameters
    ----------
    label : str
        Text of the selected radio button; expected to be one of ``categories``.

    Raises
    ------
    ValueError
        If ``label`` is not present in ``categories``.
    KeyError
        If no scores were recorded for the corresponding column.
    """
    # categories is built from used_columns in the same order, so the position
    # of the label identifies which column was removed
    index = categories.index(label)
    removed_column = used_columns[index]

    # wipe the previous bars; this also discards the title and axis labels,
    # so they are set again below
    ax.clear()

    # plot the scores recorded for that column
    ax.bar(x, scores_dict[removed_column])
    ax.set_title(label)
    ax.set_ylabel("F1 score")

    # changes made inside a widget callback only become visible once the
    # canvas is redrawn, so request a redraw explicitly
    ax.figure.canvas.draw_idle()



# one radio-button label per column that was removed, in the order of used_columns
categories = [f"remove {removed_column}" for removed_column in used_columns]

# names of the three scores stored for each removed column
x = ["average", "macro", "weighted"]

fig, ax = plt.subplots(figsize=(9, 9))
# leave room on the left-hand side of the figure for the radio buttons
plt.subplots_adjust(left=0.45)

print(categories)

# RadioButtons selects its first entry by default, so the initial bars show the
# first removed column rather than a hard-coded column name; this keeps the
# plot in sync with the highlighted button and avoids a KeyError when a
# specific column name is not present in the data
initial_column = used_columns[0]
ax.bar(x, scores_dict[initial_column])
ax.set_title(categories[0])
ax.set_ylabel("F1 score")

# create the axes that hosts the radio buttons and hide its ticks and tick labels
box = plt.axes([0.0, 0.2, 0.3, 0.6], facecolor='#FFDDAA')
box.tick_params(bottom=False, labelbottom=False, left=False, labelleft=False)

# add buttons; the widget must stay referenced (here via `buttons`) to keep responding
buttons = matplotlib.widgets.RadioButtons(box, categories)

# redraw the bar chart whenever a different column is selected
buttons.on_clicked(callback)
plt.show()