In [1]:
import numpy as np
import matplotlib.pyplot as plt
from pyriemann.spatialfilters import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import confusion_matrix

In [2]:
# ---- Configuration -------------------------------------------------------
nsplits = 10      # number of stratified shuffle splits used for evaluation
nfilters = 1      # number of CSP spatial filters kept
nchannels = 139   # columns of each recording (variables of the covariance)
nsamples = 191    # rows of each recording (observations)
nsubs = 163       # total number of subjects
n_hc = 97         # subjects 1..n_hc are read from HC_* files (label 0); n_hc+1..nsubs from mdd_* files (label 1)


def _load_subject(subject_id):
    """Load one subject's recording from ``filtered_data/`` and return ``(signal, label)``.

    Subjects ``1..n_hc`` come from ``HC_<id>_det.dat`` and get label 0; every
    later subject comes from ``mdd_<id>_det.dat`` and gets label 1.

    Raises:
        ValueError: if the loaded array is not ``(nsamples, nchannels)``. Without
            this check a malformed file could be silently broadcast into ``signals``.
    """
    prefix, label = ('HC', 0) if subject_id <= n_hc else ('mdd', 1)
    signal = np.loadtxt(f'filtered_data/{prefix}_{subject_id}_det.dat')
    if signal.shape != (nsamples, nchannels):
        raise ValueError(
            f"subject {subject_id}: expected shape {(nsamples, nchannels)}, got {signal.shape}"
        )
    return signal, label


# ---- Load every subject: raw signal, channel covariance and class label ----
signals = np.zeros([nsubs, nsamples, nchannels])
covs = np.zeros([nsubs, nchannels, nchannels])
labels = np.zeros(nsubs)
for i in range(1, nsubs + 1):
    signal, labels[i-1] = _load_subject(i)
    signals[i-1] = signal
    # np.cov treats rows as variables, so transpose to get a channel-by-channel covariance.
    covs[i-1] = np.cov(signal.T)

# Repeated stratified 80/20 hold-out. For each split, CSP + LDA are fitted on
# the training subjects only, and the 2x2 confusion matrix on the held-out
# subjects is recorded.
confusion_matrices = np.zeros([nsplits, 2, 2])
crossval = StratifiedShuffleSplit(n_splits=nsplits, test_size=0.2, random_state=0)
# The splitter only uses X for its length; stratification is driven by `labels`.
splits = crossval.split(X=np.zeros(nsubs), y=labels)
for split_i, (train_index, test_index) in enumerate(splits):
    # Build a fresh CSP every split so no fitted state can leak between splits.
    model = CSP(nfilter=nfilters, metric='euclid', log=True)
    model = model.fit(X=covs[train_index], y=labels[train_index])
    # Filters come from the training covariances only; test covariances are just projected.
    train_features = model.transform(covs[train_index])
    test_features = model.transform(covs[test_index])
    # solver may be "svd", "lsqr" or "eigen"; store_covariance=True keeps covariance_ on the fitted model.
    lda = LinearDiscriminantAnalysis(solver="svd", store_covariance=True)
    clf = lda.fit(train_features, labels[train_index])
    truths = labels[test_index]
    predictions = clf.predict(test_features)
    # labels=[0, 1] pins a 2x2 matrix with a fixed class order (row = true, column = predicted),
    # whatever classes happen to appear in this split's truths/predictions.
    confusion_matrices[split_i] = confusion_matrix(y_true=truths, y_pred=predictions, labels=[0, 1])

In [3]:
# Per-split counts from the confusion matrices (row = true class,
# column = predicted class; class 1 is treated as the positive class).
tn = confusion_matrices[:, 0, 0]
fn = confusion_matrices[:, 1, 0]
tp = confusion_matrices[:, 1, 1]
fp = confusion_matrices[:, 0, 1]


def _safe_ratio(numerator, denominator):
    """Element-wise ``numerator / denominator`` that yields 0 where ``denominator == 0``.

    This avoids NaN (and RuntimeWarning) in splits where, e.g., the classifier
    never predicts class 1. A single NaN would otherwise turn the mean over
    splits into NaN.
    """
    numerator = np.asarray(numerator, dtype=float)
    denominator = np.asarray(denominator, dtype=float)
    return np.divide(
        numerator, denominator,
        out=np.zeros(np.broadcast(numerator, denominator).shape),
        where=denominator != 0,
    )


accuracy = _safe_ratio(tp + tn, tp + tn + fp + fn)
recall = _safe_ratio(tp, tp + fn)
precision = _safe_ratio(tp, tp + fp)
# 2*tp / (2*tp + fp + fn) equals 2PR / (P + R) whenever that is defined,
# and stays finite (0) when tp == 0.
f1 = _safe_ratio(2 * tp, 2 * tp + fp + fn)

In [4]:
# Summary across the cross-validation splits: mean ± population std (np.std, ddof=0).
summary_lines = [
    f"Accuracy: {np.mean(accuracy):.2f} ± {np.std(accuracy):.2f}",
    f"Recall: {np.mean(recall):.2f} ± {np.std(recall):.2f}",
    f"Precision: {np.mean(precision):.2f} ± {np.std(precision):.2f}",
    f"F1-score: {np.mean(f1):.2f} ± {np.std(f1):.2f}",
]
for line in summary_lines:
    print(line)

Accuracy: 0.60 ± 0.08
Recall: 0.37 ± 0.12
Precision: 0.49 ± 0.14
F1-score: 0.42 ± 0.13
