In [1]:
import numpy as np
import scipy.linalg as la
import mne.filter as bandpass
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
import geomstats as gs
import geomstats.geometry.spd_matrices as spd
from geomstats.learning.frechet_mean import FrechetMean
from scipy.io import loadmat  

INFO: Using numpy backend


In [14]:
class RCSP_independent:
    """Subject-independent (Riemannian) CSP evaluated by leave-one-subject-out CV.

    Covariance matrices of all subjects are pooled. For every subject, CSP
    filters are fitted on the remaining subjects, using either the arithmetic
    class mean ("classic") or a geomstats Frechet mean ("AIRM", "LEM", "BW").
    An LDA classifier trained on the log-variance features is then scored on
    the held-out subject.
    """

    # Column indices of the loaded 'smt' array that are kept.
    # Original note: standard 10/20 cap mapping, electrodes responsible for MI.
    MI_CHANNELS = (7, 8, 9, 10, 12, 13, 14, 17, 18, 19, 20,
                   32, 33, 34, 35, 36, 37, 38, 39, 40)

    def __init__(self, subjects, epochs, nchannels):
        """
        Args:
            subjects: number of subjects to load (files subj1 .. subj<subjects>).
            epochs: number of epochs each subject contributes. LOSO slices the
                pooled arrays in chunks of this size.
            nchannels: dimension of the SPD matrices handed to the geomstats
                metrics; should equal the number of selected channels.
        """
        self.subjects = subjects
        self.nchannels = nchannels
        self.epochs = epochs
        self.V = None  # kept for compatibility; no method below reads or writes it

    def bandpass(self, X, fs, fl, fh):
        """Band-pass filter epochs between fl and fh Hz with mne.filter.filter_data.

        Args:
            X: epochs array laid out as (epochs, samples, channels), as passed
                by collectSubjects.
            fs: sampling frequency in Hz.
            fl, fh: lower / upper pass-band edges in Hz.

        Returns:
            The filtered array in the same (epochs, samples, channels) layout.
        """
        # filter_data filters along the last axis, so move samples there and back.
        # `bandpass` below is the module-level alias for mne.filter (class scope
        # is not visible inside a method body), not this method.
        X_time_last = np.swapaxes(X, 1, 2)
        X_time_last = bandpass.filter_data(X_time_last, fs, fl, fh, verbose=False)
        return np.swapaxes(X_time_last, 1, 2)

    def calculateCovariances(self, X):
        """Return the per-epoch scatter matrix X[i].T @ X[i].

        Args:
            X: array of shape (epochs, samples, channels).

        Returns:
            float64 array of shape (epochs, channels, channels). No mean is
            removed and the result is not divided by the number of samples.
        """
        X = np.asarray(X, dtype=np.float64)
        # Batched matrix product replaces the per-epoch Python loop.
        return np.matmul(np.swapaxes(X, 1, 2), X)

    def estimateMeans(self, classSpecificCOV, metric):
        """Frechet mean of each class's covariance matrices under an SPD metric.

        Args:
            classSpecificCOV: sequence of (k_i, C, C) stacks of SPD matrices.
            metric: "AIRM" (affine-invariant), "LEM" (log-Euclidean) or
                "BW" (Bures-Wasserstein).

        Returns:
            List of mean matrices, one per entry of classSpecificCOV.

        Raises:
            ValueError: for an unknown metric name.
        """
        # Metric classes are looked up lazily, so an unavailable one (e.g. BW in
        # some geomstats versions) only fails when it is actually requested.
        if metric == "AIRM":
            spd_metric = spd.SPDMetricAffine(n=self.nchannels)
        elif metric == "LEM":
            spd_metric = spd.SPDMetricLogEuclidean(n=self.nchannels)
        elif metric == "BW":
            # NOTE(review): original author marked this metric "doesn't work yet".
            spd_metric = spd.SPDMetricBuresWasserstein(n=self.nchannels)
        else:
            raise ValueError("Not implemented metric: {}".format(metric))
        estimator = FrechetMean(spd_metric, max_iter=64)

        means = []
        for COV in classSpecificCOV:
            estimator.fit(COV)
            means.append(estimator.estimate_)
        return means

    def separate_classes(self, X, Y):
        """Split X into [X[Y == 0], X[Y == 1]] (binary labels 0/1)."""
        return [X[np.where(Y == label)[0]] for label in range(2)]

    def CSP(self, X, Y, metric, n):
        """Fit 2*n CSP spatial filters from labelled covariance matrices.

        Args:
            X: (epochs, C, C) covariance matrices.
            Y: (epochs,) labels in {0, 1}.
            metric: "classic" for the arithmetic class mean, otherwise a
                metric name accepted by estimateMeans.
            n: number of filters taken from each end of the eigen-spectrum.

        Returns:
            (C, 2*n) filter matrix: the n eigenvectors with the smallest and
            the n with the largest generalized eigenvalues.

        Raises:
            ValueError: if n is outside [1, C // 2] or a class has no epochs.
        """
        n_channels = X.shape[-1]
        # n == 0 would make V[:, -0:] select every column, and n > C // 2
        # would select overlapping filters; reject both.
        if not 1 <= n <= n_channels // 2:
            raise ValueError(
                "n must be in [1, {}] for {} channels, got {}".format(
                    n_channels // 2, n_channels, n))

        classSpecificCOV = self.separate_classes(X, Y)
        for label, class_cov in enumerate(classSpecificCOV):
            if len(class_cov) == 0:
                raise ValueError("no training epochs for class {}".format(label))

        if metric == "classic":
            class0_avg, class1_avg = (np.mean(c, axis=0) for c in classSpecificCOV)
        else:
            class0_avg, class1_avg = self.estimateMeans(classSpecificCOV, metric)

        # Generalized eigenproblem class0 v = lambda (class0 + class1) v;
        # eigh returns eigenvalues in ascending order.
        _, V = la.eigh(class0_avg, class0_avg + class1_avg)
        return np.concatenate((V[:, :n], V[:, n_channels - n:]), axis=1)

    def applyCSP(self, dk, V):
        """Log-normalised variance features of one covariance matrix.

        Args:
            dk: (C, C) covariance matrix.
            V: (C, k) spatial filters, e.g. from CSP.

        Returns:
            Length-k vector log(diag(V.T dk V) / trace(V.T dk V)).
        """
        projected = V.T @ dk @ V
        return np.log(np.diagonal(projected) / np.trace(projected))

    def _load_split(self, split, channels, fs, fl, fh):
        """Crop, reorder and band-pass one train/test struct; return (epochs, labels)."""
        # NOTE(review): 'smt' is presumably (samples, trials, channels); samples
        # 1000:3500 are kept (1.0-3.5 s if fs=1000 Hz) — confirm against dataset docs.
        X = split['smt'][0][0][1000:3500, :, channels]
        X = self.bandpass(np.swapaxes(X, 0, 1), fs, fl, fh)
        # Shift labels to 0-based (assumes y_dec holds 1/2 — confirm).
        Y = split['y_dec'][0][0][0] - 1
        return X, Y

    def collectSubjects(self, session, fs=1000, fl=8, fh=30, n=3, channels=None):
        """Load, filter and pool covariance matrices of all subjects.

        Args:
            session: session folder name under datasets/54subjects/ (e.g. "sess01").
            fs: sampling frequency passed to the band-pass filter.
            fl, fh: band-pass edges in Hz.
            n: unused; accepted for backward compatibility.
            channels: channel indices to keep; defaults to MI_CHANNELS.

        Returns:
            (covariances, labels) pooled over subjects in subject order, with
            each subject's train epochs followed by its test epochs.
            Returns (None, None) when self.subjects < 1.
        """
        channels = list(self.MI_CHANNELS if channels is None else channels)

        cov_parts, label_parts = [], []
        for subject in range(1, self.subjects + 1):
            filename = 'datasets/54subjects/' + session + '/subj{}_EEG_MI.mat'.format(subject)
            data = loadmat(filename)

            X_train, Y_train = self._load_split(data['EEG_MI_train'], channels, fs, fl, fh)
            X_test, Y_test = self._load_split(data['EEG_MI_test'], channels, fs, fl, fh)

            cov_parts.append(self.calculateCovariances(np.concatenate((X_train, X_test))))
            label_parts.append(np.concatenate((Y_train, Y_test)))

        if not cov_parts:
            return None, None
        # Concatenate once instead of re-copying the growing array per subject.
        return np.concatenate(cov_parts), np.concatenate(label_parts)

    def _extract_features(self, covs, V):
        """Stack applyCSP features of every matrix in covs: (len(covs), V.shape[1])."""
        features = np.empty((len(covs), V.shape[1]))
        for j, cov in enumerate(covs):
            features[j] = self.applyCSP(cov, V)
        return features

    def LOSO(self, session, metric, n=3):
        """Leave-one-subject-out evaluation; prints and returns per-subject accuracy.

        Args:
            session: session folder passed to collectSubjects.
            metric: class-mean estimator passed to CSP.
            n: number of CSP filters from each end of the spectrum.

        Returns:
            List of LDA accuracies, one per held-out subject.

        Raises:
            ValueError: if the pooled data does not hold exactly
                subjects * epochs epochs (the per-subject slicing would misalign).
        """
        SubjectsCOV, SubjectsY = self.collectSubjects(session)

        expected = self.subjects * self.epochs
        found = 0 if SubjectsCOV is None else len(SubjectsCOV)
        if SubjectsCOV is None or found != expected:
            raise ValueError(
                "LOSO expects {} subjects x {} epochs = {} epochs, got {}".format(
                    self.subjects, self.epochs, expected, found))

        results = []
        for i in range(self.subjects):
            held_out = slice(self.epochs * i, self.epochs * (i + 1))

            testCOV, testY = SubjectsCOV[held_out], SubjectsY[held_out]
            trainCOV = np.delete(SubjectsCOV, held_out, axis=0)
            trainY = np.delete(SubjectsY, held_out, axis=0)

            V = self.CSP(trainCOV, trainY, metric, n)

            clf = LDA()
            clf.fit(self._extract_features(trainCOV, V), trainY)
            result = clf.score(self._extract_features(testCOV, V), testY)
            results.append(result)
            print("Subject " + str(i+1) + ", subject-independent accuracy is: " + str(result))

        av = sum(results) / len(results)
        print("Average: " + str(av))

        return results

In [15]:
# 20 subjects, 200 epochs per subject (LOSO slices the pooled data by this
# count), 20-dimensional covariances to match the 20 channel indices kept
# in collectSubjects.
csp = RCSP_independent(subjects=20, epochs=200, nchannels=20)

In [19]:
# Class-mean estimators to compare ("classic" = arithmetic mean) and the
# numbers of CSP filters taken from each end of the eigen-spectrum.
metrics = "classic AIRM LEM".split()
ns = list(range(2, 6))

In [23]:
# One LOSO run per (filter count n, mean estimator m). Every run's
# per-subject accuracies are kept in all_results[(n, m)] instead of being
# overwritten; `results` still holds the last run for later cells.
all_results = {}
for n in ns:
    print("N is " + str(n) + "\n")
    for m in metrics:
        print("Metric is " + m + "\n")
        results = csp.LOSO("sess01", m, n=n)
        all_results[(n, m)] = results
        print()

N is 2

Metric is classic

Subject 1, subject-independent accuracy is: 0.585
Subject 2, subject-independent accuracy is: 0.5
Subject 3, subject-independent accuracy is: 0.665
Subject 4, subject-independent accuracy is: 0.56
Subject 5, subject-independent accuracy is: 0.64
Subject 6, subject-independent accuracy is: 0.615
Subject 7, subject-independent accuracy is: 0.475
Subject 8, subject-independent accuracy is: 0.555
Subject 9, subject-independent accuracy is: 0.73
Subject 10, subject-independent accuracy is: 0.64
Subject 11, subject-independent accuracy is: 0.49
Subject 12, subject-independent accuracy is: 0.595
Subject 13, subject-independent accuracy is: 0.665
Subject 14, subject-independent accuracy is: 0.495
Subject 15, subject-independent accuracy is: 0.395
Subject 16, subject-independent accuracy is: 0.73
Subject 17, subject-independent accuracy is: 0.69
Subject 18, subject-independent accuracy is: 0.755
Subject 19, subject-independent accuracy is: 0.74
Subject 20, subject-ind

Subject 19, subject-independent accuracy is: 0.735
Subject 20, subject-independent accuracy is: 0.51
Average: 0.6325000000000001

Metric is LEM

Subject 1, subject-independent accuracy is: 0.745
Subject 2, subject-independent accuracy is: 0.53
Subject 3, subject-independent accuracy is: 0.815
Subject 4, subject-independent accuracy is: 0.65
Subject 5, subject-independent accuracy is: 0.74
Subject 6, subject-independent accuracy is: 0.67
Subject 7, subject-independent accuracy is: 0.45
Subject 8, subject-independent accuracy is: 0.62
Subject 9, subject-independent accuracy is: 0.85
Subject 10, subject-independent accuracy is: 0.615
Subject 11, subject-independent accuracy is: 0.485
Subject 12, subject-independent accuracy is: 0.545
Subject 13, subject-independent accuracy is: 0.525
Subject 14, subject-independent accuracy is: 0.625
Subject 15, subject-independent accuracy is: 0.435
Subject 16, subject-independent accuracy is: 0.655
Subject 17, subject-independent accuracy is: 0.72
Subje