In [1]:
%env GEOMSTATS_BACKEND=numpy
%env NUMEXPR_MAX_THREADS=12 

import pandas as pd
import numpy as np
import numpy.linalg as la
import pickle

import geomstats.backend as gs
import geomstats.geometry.spd_matrices as spd
from geomstats.learning.frechet_mean import FrechetMean

env: GEOMSTATS_BACKEND=numpy
env: NUMEXPR_MAX_THREADS=12


INFO: Using numpy backend


In [None]:
#filename='BDproject/BCICompetition4_2a_SPD.pickle'
#infile=open(filename,'rb')
#SPDDataset=pickle.load(infile)

In [None]:
#clf0 = LogisticRegression(C=3, max_iter=1000)
#clf1 = SVC(C=3, max_iter=1000)
#clf2 = RidgeClassifier(alpha=3, max_iter=1000)
#clf = [clf0, clf1, clf2]
#clf_names = ["Logistic Regression", "Support Vector Machine", "Ridge Classifier"]

In [2]:
class DataPreparation:
    """Load pickled per-subject recordings and build per-epoch covariance matrices.

    Parameters
    ----------
    directory : str
        Path template with two ``{}`` placeholders, filled (in order) with the
        subject number and the split name (``'train'`` or ``'test'``).
    conditions : sequence
        Condition labels; a label is encoded as its position in this sequence.
    epochs : int
        Number of epoch ids (``0 .. epochs - 1``) visited by ``convertToSPD``.
    """

    def __init__(self, directory, conditions, epochs):
        self.conditions = conditions
        self.epochs = epochs
        self.directory = directory

    @staticmethod
    def _loadFrame(path):
        """Unpickle the object stored at ``path`` and return ``obj.to_data_frame()``."""
        # NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
        # The context manager guarantees the handle is closed even if loading fails.
        with open(path, 'rb') as infile:
            recording = pickle.load(infile)
        return recording.to_data_frame()

    def loadConcat(self, subject, test_epoch_offset=200):
        """Return the train and test data frames of ``subject`` stacked together.

        Parameters
        ----------
        subject : int or str
            Value substituted into the first placeholder of ``self.directory``.
        test_epoch_offset : int, optional
            Amount added to the ``'epoch'`` column of the test frame so its
            epoch ids do not collide with the train ones (default 200).
            # presumably the train split holds 200 epochs per subject -- TODO confirm

        Returns
        -------
        pandas.DataFrame
            ``pd.concat`` of the train frame and the shifted test frame.
        """
        train_df = self._loadFrame(self.directory.format(subject, 'train'))
        test_df = self._loadFrame(self.directory.format(subject, 'test'))
        test_df['epoch'] += test_epoch_offset

        return pd.concat([train_df, test_df])

    def convertToSPD(self, df, normalize=True):
        """Compute one covariance matrix and one integer label per epoch.

        Parameters
        ----------
        df : pandas.DataFrame
            Frame with ``'epoch'`` and ``'condition'`` columns; every column
            from position 3 onwards is treated as a signal channel.
            # assumes the first three columns are metadata -- TODO confirm
        normalize : bool, optional
            If true, z-score each channel within the epoch before computing
            the covariance.

        Returns
        -------
        list
            ``[SPD, labels]``: a list of covariance matrices (numpy arrays) and
            the matching list of labels. A label found in ``self.conditions``
            is replaced by its index; any other label is kept unchanged.
        """
        SPD = []
        labels = []
        for epoch_id in range(self.epochs):
            df_slice = df.loc[df['epoch'] == epoch_id, :]
            matrix = df_slice.iloc[:, 3:]
            if normalize:
                matrix = (matrix - matrix.mean()) / matrix.std()

            # Every row of an epoch carries the same condition; read the first one.
            label = df_slice['condition'].iloc[0]
            for code, condition in enumerate(self.conditions):
                if label == condition:
                    label = code  # encoding of conditions to integers
                    break

            SPD.append(matrix.cov().to_numpy())
            labels.append(label)

        return [SPD, labels]

    def generateSPDDataset(self, r, normalize=True):
        """Build ``[SPD, labels]`` for every subject numbered ``r[0]+1 .. r[1]``.

        Parameters
        ----------
        r : sequence of two ints
            Subject range; subjects ``r[0] + 1`` through ``r[1]`` inclusive
            are loaded.
        normalize : bool, optional
            Forwarded to ``convertToSPD``.

        Returns
        -------
        list
            One ``[SPD, labels]`` entry per subject, in subject order.
        """
        SPDDataset = []
        for subject in range(r[0] + 1, r[1] + 1):
            df = self.loadConcat(subject)
            SPDDataset.append(self.convertToSPD(df, normalize))
        return SPDDataset

In [3]:
# Experiment configuration: dataset dimensions and the on-disk path template.
subjects, epochs, points, channels = 20, 400, 512, 62
directory = 'datasets/54subjects/Subject{}_{}.pickle'

# Label convention used throughout this notebook: 'left' is encoded as 0 and
# 'right' as 1. The same order applies to per-class arrays: position 0 holds
# the entry for the left condition, position 1 the entry for the right one.
conditions = ['left', 'right']


In [4]:
# Build the [SPD, labels] dataset for subjects 1..subjects. The range is tied to
# `subjects` (instead of a literal 20) so it stays in sync with later cells that
# index SPDDataset with range(subjects).
dp = DataPreparation(directory, conditions, epochs)
SPDDataset = dp.generateSPDDataset([0, subjects])

In [1]:
def precomputeMeans(SPDDataset, subjects):
    """Compute per-subject, per-class Frechet means under two SPD metrics.

    Parameters
    ----------
    SPDDataset : sequence
        One ``[matrices, labels]`` pair per subject, where ``matrices`` is a
        list of square numpy arrays and ``labels`` holds the class (0 or 1)
        of each matrix.
    subjects : int
        Number of leading entries of ``SPDDataset`` to process.

    Returns
    -------
    tuple
        ``(LEM_means, AIRM_means)``. Each is a list with one entry per subject;
        every entry is a two-element list holding the estimator's ``estimate_``
        for class 0 and class 1, fitted with ``SPDMetricLogEuclidean`` and
        ``SPDMetricAffine`` respectively.
    """

    def separate_classes(d, l):
        """Split matrices ``d`` into two lists according to their labels ``l``."""
        classSpecificSPD = []
        for i in range(2):  # two classes
            classSpecificSPD.append([d[j] for j, val in enumerate(l) if val == i])
        return classSpecificSPD

    # Matrix size is read from the first matrix of the first subject.
    channels = SPDDataset[0][0][0].shape[0]

    # Fix: the log-Euclidean estimator was previously bound to `EM_mean_estimator`
    # while the loop below used `LEM_mean_estimator`, raising NameError.
    LEM_mean_estimator = FrechetMean(spd.SPDMetricLogEuclidean(n=channels), max_iter=64)
    AIRM_mean_estimator = FrechetMean(spd.SPDMetricAffine(n=channels), max_iter=64)

    LEM_means = []
    AIRM_means = []

    for i in range(subjects):
        d, l = SPDDataset[i]

        LEM_SPDk = []
        AIRM_SPDk = []

        # The estimators are reused: each fit() overwrites estimate_, so the
        # value is read immediately after fitting.
        for SPD in separate_classes(d, l):
            LEM_mean_estimator.fit(SPD)
            LEM_SPDk.append(LEM_mean_estimator.estimate_)

            AIRM_mean_estimator.fit(SPD)
            AIRM_SPDk.append(AIRM_mean_estimator.estimate_)

        LEM_means.append(LEM_SPDk)
        AIRM_means.append(AIRM_SPDk)
    return LEM_means, AIRM_means

In [17]:
# Fix: the function defined in this notebook is precomputeMeans(SPDDataset, subjects);
# `getMeans` does not exist, and the matrix size is derived inside the function.
LEM_means, AIRM_means = precomputeMeans(SPDDataset, subjects)



In [18]:
# Persist the log-Euclidean means. The context manager closes the file even if
# pickle.dump raises, which the manual open()/close() pair did not guarantee.
dump = 'BDproject/LEM_means.pickle'
with open(dump, 'w+b') as outfile:
    pickle.dump(LEM_means, outfile)

In [19]:
# Persist the affine-invariant means. The context manager closes the file even
# if pickle.dump raises, which the manual open()/close() pair did not guarantee.
dump = 'BDproject/AIRM_means.pickle'
with open(dump, 'w+b') as outfile:
    pickle.dump(AIRM_means, outfile)