In [42]:
%env GEOMSTATS_BACKEND=numpy
import nu_smrutils as u
import pandas as pd
import numpy as np

from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_validate
from sklearn.linear_model import LogisticRegression

#from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
#from tensorflow.keras import Sequential
#from tensorflow.keras.layers import Dense
#from tensorflow.keras.regularizers import L1L2

import geomstats.backend as gs
import geomstats.geometry.spd_matrices as spd
from geomstats.learning.preprocessing import ToTangentSpace

# Lift pandas' display truncation for both axes so frames are shown in full.
for _display_option in ('display.max_columns', 'display.max_rows'):
    pd.set_option(_display_option, None)

env: GEOMSTATS_BACKEND=numpy


In [43]:
# Load the per-subject recordings and summarise the layout of the first one.
# NOTE(review): the '.pickle' extension suggests u.loaddat unpickles this
# file -- confirm, and only load files that come from a trusted source.
filename = 'datasets/aBNCI2014001R.pickle'
d = u.loaddat(filename)
subjects = len(d)

# Columns from position 3 onward are treated as the channel columns.
df = d[0].to_data_frame()
channels = df.columns[3:].to_list()
l_channels = len(channels)

# Number of distinct epoch ids, and number of rows belonging to epoch id 1.
epochs = len(df['epoch'].unique())
points = len(df.loc[df['epoch'] == 1])

print(f'Subjects: {subjects}')
print(f'Channels: {l_channels}')
print(f'Total epochs: {epochs}')
print(f'Points per epoch: {points}')

Subjects: 9
Channels: 22
Total epochs: 288
Points per epoch: 321


In [44]:
def convert_data(d, subject, epochs=None):
    """Build one covariance matrix and one class label per epoch of a subject.

    Parameters
    ----------
    d : sequence
        Per-subject recordings. ``d[subject]`` must provide
        ``to_data_frame()`` returning a frame with ``'epoch'`` and
        ``'condition'`` columns; every column from position 3 onward is
        treated as a signal column (the same layout the notebook relies on
        when it lists channels with ``columns[3:]``).
    subject : int
        Index into ``d``.
    epochs : int or None, optional
        Maximum number of epochs to convert, taken in ascending order of
        epoch id. ``None`` (the default) converts every epoch present.

    Returns
    -------
    SPD : list of numpy.ndarray
        Covariance matrix across the signal columns, one per converted epoch.
    labels : list
        ``0`` for ``'left_hand'``, ``1`` for ``'right_hand'``; any other
        condition value is passed through unchanged.

    Notes
    -----
    Epochs are taken from the ids actually present in the ``'epoch'``
    column instead of being assumed to be ``0 .. epochs - 1``. Gaps in the
    numbering, or a subject holding fewer epochs than requested, therefore
    no longer produce an empty slice and an ``IndexError``.
    """
    df = d[subject].to_data_frame()
    label_codes = {'left_hand': 0, 'right_hand': 1}

    SPD = []
    labels = []
    # sort=True visits epoch ids in ascending order, which matches the
    # previous 0, 1, 2, ... iteration whenever the ids are contiguous.
    for position, (_, df_slice) in enumerate(df.groupby('epoch', sort=True)):
        if epochs is not None and position >= epochs:
            break

        # The first row's condition is used as the label for the whole epoch.
        condition = df_slice['condition'].iloc[0]

        SPD.append(df_slice.iloc[:, 3:].cov().to_numpy())
        labels.append(label_codes.get(condition, condition))
    return SPD, labels

In [45]:
# Evaluate each subject independently: covariance matrices are mapped through
# ToTangentSpace and classified with logistic regression, scored by
# cross_validate with its default settings.
results = []
for i in range(subjects):
    SPD, labels = convert_data(d, i, epochs)

    # Geometry objects sized to the number of channels.
    manifold = spd.SPDMatrices(n=l_channels)
    metric = spd.SPDMetricLogEuclidean(n=l_channels)

    tangent_projection = ToTangentSpace(metric)
    classifier = LogisticRegression(C=3, max_iter=10000)
    pipeline = make_pipeline(tangent_projection, classifier)

    result = cross_validate(pipeline, SPD, labels)
    results.append(result)

    # Report the per-fold test scores as percentages.
    fold_scores = result['test_score']
    print("Subject #" + str(i + 1))
    print("Mean Acc: " + str(100 * fold_scores.mean()))
    print("Std deviation: " + str(100 * np.std(fold_scores)) + "\n")

Subject #1
Mean Acc: 78.82637628554144
Std deviation: 6.7726421858475385

Subject #2
Mean Acc: 67.35632183908046
Std deviation: 5.521054232756098

Subject #3
Mean Acc: 91.31276467029643
Std deviation: 4.248762490535245

Subject #4
Mean Acc: 73.28493647912886
Std deviation: 7.590156927159835

Subject #5
Mean Acc: 63.5632183908046
Std deviation: 4.330124113585231

Subject #6
Mean Acc: 63.87779794313369
Std deviation: 4.093595065651524

Subject #7
Mean Acc: 74.3254688445251
Std deviation: 3.3866128834657094

Subject #8
Mean Acc: 94.08348457350272
Std deviation: 3.9314346041301005

Subject #9
Mean Acc: 80.23593466424683
Std deviation: 6.168332857937689

