# Machine Learning Classifier Evaluation

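Fit a range of ML classifiers to each subject in the study and create a table of metrics (accuracy, precision, recall, F1 and ROC AUC). Each metric is reported both as the mean over 10-fold stratified cross-validation on the training data and as the generalization performance of the model on a held-out test set (the proportion set by `test_size` in `config.yml`, e.g. 20% of the original data).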

### Import modules 

In [None]:
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
import seaborn as sns
from glob import glob
from pathlib import Path
import yaml
from yaml import CLoader as Loader
import os.path as op
# MNE
import mne
from mne import io, EvokedArray
from mne.decoding import Vectorizer, get_coef
from mne.decoding import LinearModel
# sklearn
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_validate
from sklearn.metrics import precision_recall_curve, precision_score, recall_score, accuracy_score, roc_auc_score, f1_score
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import LinearSVC

# Limit MNE console output to warnings and errors (hides INFO-level messages).
mne.set_log_level(verbose='Warning')

### Yaml + Pathing

In [None]:
## YAML
# NOTE(review): CLoader is PyYAML's full (non-safe) loader. That is acceptable
# for a local, trusted config file, but switch to yaml.CSafeLoader if
# config.yml could ever come from an untrusted source.
with open('config.yml', 'r', encoding='utf-8') as f:
    config = yaml.load(f, Loader=Loader)

study_name = config['study_name']
task = config['task']
data_type = config['data_type']
eog = config['eog']
montage_fname = config['montage_fname']
n_jobs = -1  # passed to sklearn/joblib: use all available CPU cores

# These config sections are lists of single-key dicts; flatten each into one
# dict (a later key overrides an earlier duplicate).
epoch_p = {k: v for d in config['preprocessing_settings']['epoch'] for k, v in d.items()}

cl_p = {k: v for d in config['classification'] for k, v in d.items()}

## Pathing
# Name the results folder after the test-set percentage, e.g. test_size=0.2
# -> 'classification_test_20_pct'. Taking the last character of
# str(test_size) only worked for one-decimal fractions (0.25 gave '50_pct').
# Assumes test_size is a fraction in (0, 1), as passed to train_test_split.
test_pct = round(float(cl_p['test_size']) * 100)
results_path = op.join('results', f'classification_test_{test_pct}_pct')
report_path = op.join(results_path, 'reports')
fig_path = op.join(results_path, 'figures')
tab_path = op.join(results_path, 'tables')

# exist_ok=True makes directory creation idempotent across notebook re-runs.
for dir_path in (results_path, report_path, fig_path, tab_path):
    Path(dir_path).mkdir(parents=True, exist_ok=True)

epochs_suffix = '-epo.fif'

# Output files
out_file = op.join(tab_path, 'classification_overall_results.csv')
summary_file = op.join(tab_path, 'classification_accuracy_summary.csv')
plot_stem = op.join(fig_path, 'plot_')
fig_format = 'pdf'

## Define conditions and labels

In [None]:
# Event tags are '<emotion>/<colour>/<status>'; build every combination
# instead of listing them by hand.
_emotions = ('Angry', 'Neutral')
_colours = ('Grey', 'Red')
_statuses = ('target', 'nontarget')

# All fully specified conditions, followed by the two pooled status tags.
conditions = [
    f'{emo}/{col}/{status}'
    for emo in _emotions
    for col in _colours
    for status in _statuses
] + list(_statuses)

# Conditions of interest.
coi = list(_statuses)

# Each contrast pits target against non-target epochs, either within a single
# emotion/colour cell or pooled across all cells ('Target-Nontarget').
contrasts = {
    f'{emo}/{col}': [f'{emo}/{col}/{status}' for status in _statuses]
    for emo in _emotions
    for col in _colours
}
contrasts['Target-Nontarget'] = list(_statuses)

## Instantiate classifiers and the cross-validation scheme

In [None]:
# One integer seed for all randomized components. Passing a shared
# RandomState *instance* made each fit consume that object's state, so a
# model's result depended on how many models had been fitted before it
# (e.g. re-running a single subject gave different SVM numbers). An int
# seed gives every fit the same starting state.
seed = 42
rng = np.random.RandomState(seed=seed)  # kept for any downstream use; estimators below use `seed`

# Shared preprocessing steps; each Pipeline built from them refits them on every fit().
scaler = StandardScaler()
vectorizer = Vectorizer()

# Presumably wrapped in MNE's LinearModel so patterns/filters can later be
# retrieved with get_coef (imported above) — confirm if that is needed.
logreg = LinearModel(LogisticRegression(solver='lbfgs', max_iter=1000, n_jobs=n_jobs, verbose=False, random_state=seed))
lda = LinearDiscriminantAnalysis()
svc = LinearSVC(C=8, max_iter=10000, verbose=False, random_state=seed)

classifiers = {'LR': logreg, 'LDA': lda, 'SVM': svc}

# For cross-validation: shuffled, stratified k-fold with a fixed seed.
k = 10
cv = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)

## Define Subjects & Load Data

In [None]:
# For Running Individual participants
subjects = ['sub-004', 'sub-005']
print("n subjects = ", len(subjects))

## Reading in data
raw_path = op.join('./', 'data')
epochs = {}
print('Loading Subjects:', subjects)
for subject in subjects:
    # Find *this* subject's epochs file. The previous glob ignored `subject`
    # and popped an arbitrary '*-epo.fif' file, so every subject silently
    # received the same data.
    # NOTE(review): assumes epoch filenames contain the subject ID
    # (e.g. 'sub-004_...-epo.fif') — confirm against the data folder.
    matches = sorted(glob(op.join(raw_path, '*' + subject + '*' + epochs_suffix)))
    if not matches:
        raise FileNotFoundError(f'No *{epochs_suffix} file for {subject} in {raw_path}')
    if len(matches) > 1:
        raise ValueError(f'Multiple epochs files found for {subject}: {matches}')
    epochs[subject] = mne.read_epochs(matches[0], proj=False, verbose=False, preload=True)

    # Correcting for presentation delay: move all time stamps earlier by
    # tshift (equivalent to subtracting tshift from epochs.times), using the
    # public API instead of mutating private attributes.
    epochs[subject].shift_time(-epoch_p['tshift'], relative=True)


## Classification loop
For each subject and contrast, fit each classifier, then score it with k-fold cross-validation on the training set and on the held-out test set.

In [None]:
%%time

# One dict of scores per (subject, contrast, classifier); converted to a
# DataFrame once at the end.
acc_tab_list = []

# Event-tag status -> binary class label (target is the positive class).
label_map = {'target': 1, 'nontarget': 0}


def _pct(score):
    """Return a score in [0, 1] as a percentage rounded to one decimal place."""
    # Scale first, then round: `score.round(3) * 100` left float noise such
    # as 85.70000000000002 in the table.
    return round(float(score) * 100, 1)


for subject in subjects:
    print('\n-------\n' + subject)

    for contr, conds in contrasts.items():
        print('-------\n' + contr)
        subj_epochs = epochs[subject][conds]

        # Recover each epoch's event name from its integer code, then split
        # the '/'-separated tags (e.g. 'Angry/Grey/target') into columns.
        event_id_rev = {code: name for name, code in subj_epochs.event_id.items()}
        event_names = pd.Series([event_id_rev[e] for e in subj_epochs.events[:, 2]])
        labels_all = event_names.str.split('/', expand=True).rename(
            columns={0: 'Emotion', 1: 'Colour', 2: 'Status', 3: 'Location'})
        labels_all['labels'] = labels_all['Status'].map(label_map)
        labels = labels_all['labels']

        # Epoch data; the pipeline's Vectorizer step flattens it per epoch.
        X = subj_epochs.get_data()

        # Stratified, fixed-seed train/test split.
        X_train, X_test, y_train, y_test = train_test_split(X, labels,
                                                            stratify=labels,
                                                            test_size=cl_p['test_size'],
                                                            random_state=42)

        for c_name, c in classifiers.items():
            print('-------\nRunning classifier: ' + c_name)
            clf = Pipeline([('Vectorizer', vectorizer),
                            ('Scaler', scaler),
                            (c_name, c)
                           ])

            # k-fold cross-validation on the training set only.
            print('Cross validate...')
            cv_cv = cross_validate(clf, X_train, y_train,
                                   scoring=['accuracy', 'precision', 'recall', 'f1', 'roc_auc'],
                                   cv=cv,
                                   n_jobs=n_jobs)

            # Refit on the whole training set, then evaluate on the held-out test set.
            print('Training...')
            clf.fit(X_train, y_train)
            print('Predicting...')
            y_pred = clf.predict(X_test)
            # ROC AUC must be computed from continuous scores, not hard 0/1
            # predictions; decision_function is also what the 'roc_auc' CV
            # scorer uses for these linear classifiers.
            y_score = clf.decision_function(X_test)

            print('Scoring...')
            acc_tab_list.append({'participant_id': subject,
                                 'Condition': contr,
                                 'Classifier': c_name,
                                 'CV_accuracy': _pct(cv_cv['test_accuracy'].mean()),
                                 'Test_accuracy': _pct(accuracy_score(y_test, y_pred)),

                                 'CV_precision': _pct(cv_cv['test_precision'].mean()),
                                 'Test_precision': _pct(precision_score(y_test, y_pred)),

                                 'CV_recall': _pct(cv_cv['test_recall'].mean()),
                                 'Test_recall': _pct(recall_score(y_test, y_pred)),

                                 'CV_f1': _pct(cv_cv['test_f1'].mean()),
                                 'Test_f1': _pct(f1_score(y_test, y_pred)),

                                 'CV_ROC_AUC': _pct(cv_cv['test_roc_auc'].mean()),
                                 'Test_ROC_AUC': _pct(roc_auc_score(y_test, y_score)),

                                 # mean CV fit time in seconds
                                 'Fit Time': round(float(cv_cv['fit_time'].mean()), 3),
                                 })

# compile accuracy results (one row per subject/contrast/classifier, with a
# 0..n-1 index; also yields an empty table instead of crashing if nothing ran)
acc_tab = pd.DataFrame(acc_tab_list)

# save compiled results as CSV file in `results` folder
acc_tab.to_csv(out_file)

## Full results grouped by participant

In [None]:
# Each participant/contrast/classifier combination occurs once in acc_tab,
# so this shows the raw scores arranged under a hierarchical index.
_subject_keys = ['participant_id', 'Condition', 'Classifier']
acc_tab.groupby(by=_subject_keys).mean()

## Average across subjects

In [None]:
# Average each metric over participants. numeric_only=True skips the string
# 'participant_id' column, which pandas >= 2.0 would otherwise try (and fail)
# to average.
acc_tab.groupby(['Condition', 'Classifier']).mean(numeric_only=True)

### Descriptive statistics of the above results

In [None]:
# Per contrast and classifier: count, mean, std, min, quartiles and max of
# each numeric score across participants; written to CSV and displayed.
descr_tab = (
    acc_tab
    .groupby(by=['Condition', 'Classifier'])
    .describe()
)
descr_tab.to_csv(summary_file)
descr_tab

## Visualize results

In [None]:
# accuracy
# Test-set accuracy per contrast, one panel per participant. kind='swarm'
# matches the other metric plots and the '_swarmplot_' file name (this cell
# previously drew a strip plot under a swarm-plot name).
# sns.catplot returns a FacetGrid, stored in `ax` here.
ax = sns.catplot(kind='swarm',
                 data=acc_tab,
                 y='Test_accuracy', x='Condition', hue='Classifier', col='participant_id',
                 aspect=.5
                 )
ax.set_xticklabels(rotation=30)
ax.savefig(plot_stem + 'accuracy_swarmplot_by_subj' + '.' + fig_format)

In [None]:
# precision
# Test-set precision per contrast, coloured by classifier, one column of
# panels per participant; saved next to the other per-metric figures.
ax = sns.catplot(
    data=acc_tab,
    x='Condition',
    y='Test_precision',
    hue='Classifier',
    col='participant_id',
    kind='swarm',
    aspect=.5,
)
ax.set_xticklabels(rotation=30)
ax.savefig(plot_stem + 'precision_swarmplot_by_subj' + '.' + fig_format)

In [None]:
# recall
# Test-set recall per contrast, one panel per participant. Saved under its
# own name: it previously reused 'precision_swarmplot_by_subj' and
# overwrote the precision figure.
ax = sns.catplot(kind='swarm',
                 data=acc_tab,
                 y='Test_recall', x='Condition', hue='Classifier', col='participant_id',
                 aspect=.5
                 )
ax.set_xticklabels(rotation=30)
ax.savefig(plot_stem + 'recall_swarmplot_by_subj' + '.' + fig_format)

In [None]:
# F-1 score
# Test-set F1 per contrast, one panel per participant. Saved under its own
# name: it previously reused 'precision_swarmplot_by_subj' and overwrote
# the precision figure.
ax = sns.catplot(kind='swarm',
                 data=acc_tab,
                 y='Test_f1', x='Condition', hue='Classifier', col='participant_id',
                 aspect=.5
                 )
ax.set_xticklabels(rotation=30)
ax.savefig(plot_stem + 'f1_swarmplot_by_subj' + '.' + fig_format)