In [24]:
%env GEOMSTATS_BACKEND=numpy
%env NUMEXPR_MAX_THREADS=12 

import pandas as pd
import numpy as np
import scipy.linalg as la
import pickle


import geomstats as gs
import geomstats.geometry.spd_matrices as spd
from geomstats.learning.frechet_mean import FrechetMean

import logging
logger = logging.getLogger()
logger.disabled = True

from sklearn.linear_model import LogisticRegression

env: GEOMSTATS_BACKEND=numpy
env: NUMEXPR_MAX_THREADS=12


In [2]:
subjects = 20
epochs = 400
points = 512
channels = 62

conditions=['left','right']

# Throughout this code, left condition is marked 0, right condition is marked 1
# Same applies for some arrays, 0th position is smh corresponding to left, 1th position is smh corresponding to right 

In [3]:
filename = 'BDproject/SPDDataset.pickle'
infile=open(filename,'rb')
SPDDataset=pickle.load(infile)

In [63]:
class classicCSP:
    def __init__(self, subjects, train_size=0.8):
        self.subjects=subjects
        self.train_size=train_size
        
    def separate_classes(self, d, l):
        classSpecificSPD=[]
        for i in range(2): #two classes
            indecies = [j for j,val in enumerate(l) if val==i]
            classSPD = [d[j] for j in indecies]
            classSpecificSPD.append(classSPD)
        return classSpecificSPD
    
    def traintestsplit(self, d, l):
        N = len(d)
        d_train = d[:int(N*self.train_size)]
        l_train = l[:int(N*self.train_size)]
        
        d_test = d[int(N*self.train_size):]
        l_test = l[int(N*self.train_size):]
        
        return d_train, l_train, d_test, l_test
    
    def getSpatialFilter(self, d, l, n):
        left, right = self.separate_classes(d, l)
        
        left_avg = sum(left)/len(left)
        right_avg = sum(right)/len(right)
        _,V = la.eig(left_avg, left_avg+right_avg)
        V = np.concatenate((V[:, :n], V[:, -n:]), axis=1)
        return V
    
    def extractFeatures(self, dk, V):
        a = np.dot(np.dot(V.T, dk), V)
        f = np.log(np.diagonal(a)/np.trace(a))
        return f
    
    def mainloop(self, SPDDataset, n, C):
        
        results = []
        
        for i in range(self.subjects):
            d,l = SPDDataset[i]
            d_train, l_train, d_test, l_test = self.traintestsplit(d,l)
            
            V = self.getSpatialFilter(d_train, l_train, n)
            
            f_train = []
            f_test = []
            
            for j in range(len(d_train)):
                f_train.append(self.extractFeatures(d_train[j], V))
                
            for j in range(len(d_test)):
                f_test.append(self.extractFeatures(d_test[j], V))
                
            model = LogisticRegression(C=C, max_iter=2000)
            
            model.fit(f_train, l_train)
            accuracy = model.score(f_test, l_test)
            results.append(accuracy)
            print("For subject "+str(i+1) + " mean accuracy is " + str(accuracy))            
        return results

In [64]:
CSP = classicCSP(subjects, 0.6)

In [65]:
results = CSP.mainloop(SPDDataset, 3, 3)

For subject 1 mean accuracy is 0.55
For subject 2 mean accuracy is 0.48125
For subject 3 mean accuracy is 0.4625
For subject 4 mean accuracy is 0.61875
For subject 5 mean accuracy is 0.44375
For subject 6 mean accuracy is 0.5125
For subject 7 mean accuracy is 0.50625
For subject 8 mean accuracy is 0.46875
For subject 9 mean accuracy is 0.45625
For subject 10 mean accuracy is 0.5
For subject 11 mean accuracy is 0.40625
For subject 12 mean accuracy is 0.51875
For subject 13 mean accuracy is 0.48125
For subject 14 mean accuracy is 0.44375
For subject 15 mean accuracy is 0.45625
For subject 16 mean accuracy is 0.63125
For subject 17 mean accuracy is 0.48125
For subject 18 mean accuracy is 0.53125
For subject 19 mean accuracy is 0.46875
For subject 20 mean accuracy is 0.5375


In [33]:
class ReimannMeanCSP:
    def __init__(self, subjects, channels, train_size=0.8):
        self.subjects=subjects
        self.channels=channels
        self.train_size=train_size
        
    def separate_classes(self, d, l):
        classSpecificSPD=[]
        for i in range(2): #two classes
            indecies = [j for j,val in enumerate(l) if val==i]
            classSPD = [d[j] for j in indecies]
            classSpecificSPD.append(classSPD)
        return classSpecificSPD
    
    def traintestsplit(self, d, l):
        N = len(d)
        d_train = d[:int(N*self.train_size)]
        l_train = l[:int(N*self.train_size)]
        
        d_test = d[int(N*self.train_size):]
        l_test = l[int(N*self.train_size):]
        
        return d_train, l_train, d_test, l_test
    
    def estimateMeans(self, classSpecificSPD, metric):
        if metric=="AIRM":
            estimator = FrechetMean(spd.SPDMetricAffine(n=self.channels), max_iter=64)
        elif metric=="LEM":
            estimator = FrechetMean(spd.SPDMetricLogEuclidean(n=self.channels), max_iter=64)
        else:
            raise Exception("Not implemented metric")
            
        means = []
        
        for SPD in classSpecificSPD:
            estimator.fit(SPD)
            means.append(estimator.estimate_)
            
        return means
        
    
    def getSpatialFilter(self, d, l, n, metric):
        classSpecificSPD = self.separate_classes(d, l)
        
        left_avg, right_avg = self.estimateMeans(classSpecificSPD, metric)
        
        _,V = la.eig(left_avg, left_avg+right_avg)
        V = np.concatenate((V[:, :n], V[:, -n:]), axis=1)
        return V
    
    def extractFeatures(self, dk, V):
        a = np.dot(np.dot(V.T, dk), V)
        f = np.log(np.diagonal(a)/np.trace(a))
        return f
    
    def mainloop(self, SPDDataset, n, C, metric):
        
        results = []
        
        for i in range(self.subjects):
            d,l = SPDDataset[i]
            d_train, l_train, d_test, l_test = self.traintestsplit(d,l)
            
            V = self.getSpatialFilter(d_train, l_train, n, metric)
            
            f_train = []
            f_test = []
            
            for j in range(len(d_train)):
                f_train.append(self.extractFeatures(d_train[j], V))
                
            for j in range(len(d_test)):
                f_test.append(self.extractFeatures(d_test[j], V))
                
            model = LogisticRegression(C=C, max_iter=2000)
            
            model.fit(f_train, l_train)
            accuracy = model.score(f_test, l_test)
            results.append(accuracy)
            print("For subject "+str(i+1) + " mean accuracy is " + str(accuracy))            
        return results

In [34]:
RCSP = ReimannMeanCSP(subjects, channels, 0.6)

In [35]:
AIRM_results = RCSP.mainloop(SPDDataset, 3, 3, "AIRM")

For subject 1 mean accuracy is 0.5375
For subject 2 mean accuracy is 0.56875
For subject 3 mean accuracy is 0.53125
For subject 4 mean accuracy is 0.6875
For subject 5 mean accuracy is 0.51875
For subject 6 mean accuracy is 0.54375
For subject 7 mean accuracy is 0.51875
For subject 8 mean accuracy is 0.48125
For subject 9 mean accuracy is 0.5625
For subject 10 mean accuracy is 0.48125
For subject 11 mean accuracy is 0.45625
For subject 12 mean accuracy is 0.61875
For subject 13 mean accuracy is 0.5375
For subject 14 mean accuracy is 0.49375
For subject 15 mean accuracy is 0.55625
For subject 16 mean accuracy is 0.63125
For subject 17 mean accuracy is 0.53125
For subject 18 mean accuracy is 0.50625
For subject 19 mean accuracy is 0.53125
For subject 20 mean accuracy is 0.55


In [36]:
LEM_results = RCSP.mainloop(SPDDataset, 3, 3, "LEM")

For subject 1 mean accuracy is 0.49375
For subject 2 mean accuracy is 0.53125
For subject 3 mean accuracy is 0.53125
For subject 4 mean accuracy is 0.675
For subject 5 mean accuracy is 0.45
For subject 6 mean accuracy is 0.5125
For subject 7 mean accuracy is 0.54375
For subject 8 mean accuracy is 0.4
For subject 9 mean accuracy is 0.54375
For subject 10 mean accuracy is 0.49375
For subject 11 mean accuracy is 0.475
For subject 12 mean accuracy is 0.6
For subject 13 mean accuracy is 0.51875
For subject 14 mean accuracy is 0.5
For subject 15 mean accuracy is 0.51875
For subject 16 mean accuracy is 0.54375
For subject 17 mean accuracy is 0.45
For subject 18 mean accuracy is 0.44375
For subject 19 mean accuracy is 0.475
For subject 20 mean accuracy is 0.5


In [66]:
classicCSPAverage = sum(results)/len(results)

In [38]:
AIRMCSPAverage = sum(AIRM_results)/len(AIRM_results)

In [39]:
LEMCSPAverage = sum(LEM_results)/len(LEM_results)

In [67]:
classicCSPAverage

0.4978124999999999

In [41]:
AIRMCSPAverage

0.5421875

In [42]:
LEMCSPAverage

0.51