In [None]:
import inspect
import os
import pickle

import numpy as np
from matplotlib import colors, pyplot as plt
from mne import viz
from scipy import io
from sklearn import model_selection, ensemble, svm, discriminant_analysis, neighbors, linear_model

In [None]:
# Caffeine dose of the feature set to analyse; selects the Features{dose}, scores{dose}
# and figures{dose} folders used below.
CAF_DOSE = 200
# Permutation-test p-value below which an electrode is masked (highlighted) in the topomaps.
SIGNIFICANT_P = 0.05

# One of: SVM, LDA, QDA, GradientBoosting, KNeighbors, SGD-LOG, SGD-PERC (matched case-insensitively).
CLASSIFIER = 'SVM'

# Input/output locations. The defaults are the original machine-specific paths; set the
# CAF_DATA_PATH / CAF_PROJECT_PATH / CAF_RESULTS_PATH environment variables to run elsewhere.
DATA_PATH = os.environ.get(
    'CAF_DATA_PATH',
    'C:\\Users\\Philipp\\Documents\\Caffeine\\Features{dose}\\Combined'.format(dose=CAF_DOSE))
PROJECT_PATH = os.environ.get('CAF_PROJECT_PATH', 'E:\\Cafeine_data')
# os.path.join gives '..\\results' on Windows and a valid '../results' everywhere else.
RESULTS_PATH = os.environ.get('CAF_RESULTS_PATH', os.path.join('..', 'results'))

# Stages plotted (one figure each) by the topomap cell.
STAGES = ['AWA', 'AWSL', 'NREM', 'REM']
# Band names; not referenced by the cells in this section of the notebook.
BANDS = ['delta', 'theta', 'alpha', 'sigma', 'beta', 'low gamma']

In [None]:
# Electrode coordinates from the 'Cor' matrix of Coo_caf(.mat) (loadmat appends '.mat').
# The first two columns are swapped: sensor_pos[:, 0] is Cor[:, 1] and sensor_pos[:, 1]
# is Cor[:, 0] — presumably converting the file's axis order to the (x, y) order that
# plot_topomap expects; verify against the coordinate file.
sensor_pos = io.loadmat(
    os.path.join(PROJECT_PATH, 'Coo_caf')
)['Cor'][:, [1, 0]]

In [None]:
def _unpickle(filename):
    """Load and return one pickled object stored in DATA_PATH.

    Security: pickle.load can execute arbitrary code, so only point DATA_PATH at
    files you produced or otherwise trust.
    """
    with open(os.path.join(DATA_PATH, filename), 'rb') as handle:
        return pickle.load(handle)


# Feature matrices, class labels and group ids; the training cell indexes each by stage.
data = _unpickle('data.pickle')
labels = _unpickle('labels.pickle')
groups = _unpickle('groups.pickle')

In [None]:
def make_classifier(name):
    """Return a fresh, unfitted scikit-learn classifier for a CLASSIFIER name.

    Parameters
    ----------
    name : str
        One of 'SVM', 'LDA', 'QDA', 'GradientBoosting', 'KNeighbors', 'SGD-LOG',
        'SGD-PERC'; matched case-insensitively.

    Raises
    ------
    ValueError
        If ``name`` is not one of the supported classifiers.
    """
    key = name.lower()
    if key == 'svm':
        return svm.SVC(gamma='scale')
    if key == 'lda':
        return discriminant_analysis.LinearDiscriminantAnalysis()
    if key == 'qda':
        return discriminant_analysis.QuadraticDiscriminantAnalysis()
    if key == 'gradientboosting':
        return ensemble.GradientBoostingClassifier(n_estimators=100)
    if key == 'kneighbors':
        return neighbors.KNeighborsClassifier()
    if key == 'sgd-log':
        # scikit-learn 1.1 renamed loss='log' to 'log_loss' and 1.3 removed the old
        # spelling; use whichever one the installed version understands.
        known_losses = getattr(linear_model.SGDClassifier, 'loss_functions', {})
        log_loss = 'log' if ('log' in known_losses and 'log_loss' not in known_losses) else 'log_loss'
        return linear_model.SGDClassifier(loss=log_loss, max_iter=1000, tol=1e-3)
    if key == 'sgd-perc':
        return linear_model.SGDClassifier(loss='perceptron', max_iter=1000, tol=1e-3)
    raise ValueError(f'Unknown CLASSIFIER {name!r}; expected one of SVM, LDA, QDA, '
                     'GradientBoosting, KNeighbors, SGD-LOG, SGD-PERC')


# Number of electrode columns evaluated per feature matrix (one single-feature model each).
N_ELECTRODES = 20

# Fail fast on a misspelled CLASSIFIER instead of partway through (or, on a notebook
# re-run, silently reusing a classifier left over from a previous run).
make_classifier(CLASSIFIER)

scores = {}
for stage in data.keys():
    scores[stage] = {}
    print(f'Sleep stage {stage}')
    for feature in data[stage].keys():
        scores[stage][feature] = []
        for electrode in range(N_ELECTRODES):
            print(f'   Training {CLASSIFIER} for feature {feature} (electrode {electrode + 1:2})...', end='\r')

            clf = make_classifier(CLASSIFIER)

            # Result is the (score, permutation_scores, pvalue) tuple; later cells read
            # element 0 (accuracy) and element 2 (p-value).
            # NOTE(review): with an integer cv and a classifier, scikit-learn splits with
            # StratifiedKFold, so `groups` only constrains how labels are permuted — the
            # folds are not split by group. If groups identify subjects, a grouped
            # splitter (e.g. GroupKFold) may be intended to avoid leakage; confirm.
            current = model_selection.permutation_test_score(estimator=clf,
                                                             X=data[stage][feature][:, electrode].reshape((-1, 1)),
                                                             y=labels[stage],
                                                             groups=groups[stage],
                                                             cv=10,
                                                             n_permutations=1000,
                                                             n_jobs=-1)
            scores[stage][feature].append(current)
        print()

In [None]:
# Persist the per-electrode permutation-test results for this dose and classifier.
scores_dir = os.path.join(RESULTS_PATH, f'scores{CAF_DOSE}')
# open(..., 'wb') does not create missing folders, so make sure the target exists.
os.makedirs(scores_dir, exist_ok=True)
with open(os.path.join(scores_dir, f'scores_individual_{CLASSIFIER}.pickle'), 'wb') as file:
    pickle.dump(scores, file)

In [None]:
# Accuracy (element 0 of each permutation_test_score result), nested stage -> feature -> electrode.
all_scores = [[[elec[0] for elec in ft] for ft in stage.values()] for stage in scores.values()]
# Flatten explicitly: if stages have different numbers of features the nested list is
# ragged, and numpy cannot reduce it as a numeric array.
flat_scores = np.array([acc
                        for stage_acc in all_scores
                        for feature_acc in stage_acc
                        for acc in feature_acc])
vmin = flat_scores.min()
vmax = flat_scores.max()

print(f'Min accuracy: {vmin * 100:.2f}%')
print(f'Max accuracy: {vmax * 100:.2f}%')
print(f'Mean accuracy: {flat_scores.mean() * 100:.2f}%')

In [None]:
plot_rows = 2  # minimum number of subplot rows; grows if a stage has more features than fit
plot_cols = 5
colormap = 'jet'

# MNE 1.2 replaced plot_topomap's vmin/vmax arguments with vlim (1.3 removed the old
# names); pass the colour range in whichever form the installed MNE accepts.
topomap_has_vlim = 'vlim' in inspect.signature(viz.plot_topomap).parameters

figures_dir = os.path.join(RESULTS_PATH, f'figures{CAF_DOSE}')
# savefig does not create missing folders.
os.makedirs(figures_dir, exist_ok=True)

for stage in STAGES:
    stage_scores = scores[stage]
    # Ceil-divide so every feature gets a subplot (a fixed 2x5 grid fails past 10 features).
    n_rows = max(plot_rows, -(-len(stage_scores) // plot_cols))

    fig = plt.figure(figsize=(18, 2.5 * n_rows))
    fig.suptitle(stage, y=1.05, fontsize=20)

    # One colour range shared by all feature topomaps of this stage.
    all_scores = [[elec[0] for elec in ft] for ft in stage_scores.values()]
    vmin = np.min(all_scores)
    vmax = np.max(all_scores)
    range_kwargs = {'vlim': (vmin, vmax)} if topomap_has_vlim else {'vmin': vmin, 'vmax': vmax}

    axes = []
    for subplot_index, (feature, feature_scores) in enumerate(stage_scores.items(), start=1):
        curr_acc = np.array([score[0] for score in feature_scores])  # accuracy per electrode
        curr_sig = np.array([score[2] for score in feature_scores])  # permutation p-value per electrode

        ax = fig.add_subplot(n_rows, plot_cols, subplot_index)
        axes.append(ax)
        ax.set_title(feature)
        # Electrodes below the significance threshold are passed as the topomap mask.
        mask = curr_sig < SIGNIFICANT_P
        viz.plot_topomap(curr_acc, sensor_pos, mask=mask, cmap=colormap, contours=False,
                         axes=ax, show=False, **range_kwargs)

    norm = colors.Normalize(vmin=vmin, vmax=vmax)
    sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
    sm.set_array([])
    fig.colorbar(sm, ax=axes, shrink=0.95, aspect=15)
    # bbox_inches='tight' keeps the suptitle (placed above the axes area at y=1.05) in the file.
    fig.savefig(os.path.join(figures_dir, f'{CLASSIFIER}_DA_individual_{stage}.png'), bbox_inches='tight')
    plt.show()
    plt.close(fig)  # one figure is created per stage; release it explicitly