In [None]:
%env GEOMSTATS_BACKEND=numpy
%env NUMEXPR_MAX_THREADS=12 

import pandas as pd
import numpy as np
import scipy.linalg as la
import pickle


import geomstats as gs
import geomstats.geometry.spd_matrices as spd
from geomstats.learning.frechet_mean import FrechetMean

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.linear_model import RidgeClassifier
from sklearn.svm import SVC

import logging
# logging.getLogger() with no name returns the root logger; setting
# `disabled` on it suppresses records logged directly on the root logger.
# NOTE(review): presumably meant to silence library warnings during the long
# grid search below - confirm, since it stays in effect for the whole session.
logger = logging.getLogger()
logger.disabled = True

In [None]:
class RiemannCSP:
    """Two-class Common Spatial Patterns (CSP) feature extractor.

    Spatial filters are obtained from a generalized eigendecomposition of the
    two class-mean matrices.  The class means are either plain arithmetic
    averages (``metric="classic"``) or estimated with geomstats'
    ``FrechetMean`` under the ``SPDMetricAffine`` (``"AIRM"``) or
    ``SPDMetricLogEuclidean`` (``"LEM"``) metric.

    Labels are expected to be 0 and 1; class 0 is referred to as "left" and
    class 1 as "right" inside the implementation.
    """

    def __init__(self, channels=62):
        """Store the matrix dimension passed as ``n`` to the geomstats metrics."""
        self.channels = channels

    def separate_classes(self, d, l):
        """Group the entries of ``d`` by their label in ``l``.

        Parameters
        ----------
        d : sequence
            Matrices, indexable by position.
        l : sequence
            Labels aligned with ``d``.

        Returns
        -------
        list
            ``[class_0_matrices, class_1_matrices]``.  Entries whose label is
            neither 0 nor 1 are left out.
        """
        return [
            [d[idx] for idx, label in enumerate(l) if label == class_id]
            for class_id in range(2)  # two classes
        ]

    def estimateMeans(self, classSpecificSPD, metric):
        """Estimate one mean matrix per class with geomstats' ``FrechetMean``.

        Parameters
        ----------
        classSpecificSPD : list
            One collection of matrices per class (see ``separate_classes``).
        metric : str
            ``"AIRM"`` or ``"LEM"``.

        Returns
        -------
        list
            The fitted ``estimate_`` for each class, in input order.

        Raises
        ------
        ValueError
            If ``metric`` is not one of the supported names.
        """
        if metric == "AIRM":
            geometry = spd.SPDMetricAffine(n=self.channels)
        elif metric == "LEM":
            geometry = spd.SPDMetricLogEuclidean(n=self.channels)
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("Not implemented metric")

        estimator = FrechetMean(geometry, max_iter=64)

        means = []
        for class_matrices in classSpecificSPD:
            estimator.fit(class_matrices)
            means.append(estimator.estimate_)
        return means

    def getSpatialFilter(self, d, l, n, metric):
        """Compute the CSP filter bank from labelled training matrices.

        Parameters
        ----------
        d, l : sequence
            Training matrices and their 0/1 labels.
        n : int
            Number of filters taken from each end of the eigenvalue spectrum;
            must be at least 1.
        metric : str
            ``"classic"`` for arithmetic class means, otherwise forwarded to
            ``estimateMeans``.

        Returns
        -------
        numpy.ndarray
            Matrix whose columns are the first ``n`` and last ``n``
            generalized eigenvectors (``2 * n`` columns).

        Raises
        ------
        ValueError
            If ``n`` is smaller than 1 or the metric is unknown.
        """
        # With n == 0 the slice V[:, -n:] would select *every* column rather
        # than none, so reject non-positive n instead of returning a filter
        # bank of the wrong size.
        if n < 1:
            raise ValueError("n must be at least 1, got " + str(n))

        classSpecificSPD = self.separate_classes(d, l)

        if metric == "classic":
            left_avg = sum(classSpecificSPD[0]) / len(classSpecificSPD[0])
            right_avg = sum(classSpecificSPD[1]) / len(classSpecificSPD[1])
        else:
            left_avg, right_avg = self.estimateMeans(classSpecificSPD, metric)

        # Generalized problem  left_avg v = w (left_avg + right_avg) v ;
        # scipy returns the eigenvalues in ascending order, so the two ends of
        # V are the extreme ratios between the classes.
        _, V = la.eigh(left_avg, left_avg + right_avg)

        return np.concatenate((V[:, :n], V[:, -n:]), axis=1)

    def applySpatialFilter(self, dk, V):
        """Return ``log(diag(V.T @ dk @ V))`` for a single matrix ``dk``."""
        projected = np.dot(np.dot(V.T, dk), V)
        return np.log(np.diagonal(projected))

    def extractFeatures(self, d_train, l_train, d_test, l_test, n, metric):
        """Fit the filters on the training data and featurize both splits.

        The filters are computed from ``d_train``/``l_train`` only and then
        applied to every training and test matrix.

        Returns
        -------
        list
            ``[f_train, l_train, f_test, l_test]`` where the ``f_*`` entries
            are lists of feature vectors and the labels are passed through
            unchanged.
        """
        V = self.getSpatialFilter(d_train, l_train, n, metric)

        f_train = [self.applySpatialFilter(d_train[i], V) for i in range(len(d_train))]
        f_test = [self.applySpatialFilter(d_test[i], V) for i in range(len(d_test))]

        return [f_train, l_train, f_test, l_test]

def crossvalidate(d, l, channels, n, metric, classifier, kfolds = 5):
    """Run contiguous k-fold cross-validation of the CSP + classifier pipeline.

    The data are cut into ``kfolds`` consecutive segments of
    ``len(d) // kfolds`` samples (no shuffling).  Each segment is used once as
    the test set while everything else is used for training.  When ``len(d)``
    is not a multiple of ``kfolds`` the trailing remainder is never tested but
    is part of every training set.

    Parameters
    ----------
    d : sequence
        Matrices, one per sample (list or array).
    l : sequence
        Labels aligned with ``d``.
    channels : int
        Passed to ``RiemannCSP``.
    n : int
        Number of filters per spectrum end, passed to ``extractFeatures``.
    metric : str
        Mean-estimation metric, passed to ``extractFeatures``.
    classifier : object
        Estimator exposing ``fit(X, y)`` and ``score(X, y)``; it is refit on
        every fold.
    kfolds : int, optional
        Number of folds.

    Returns
    -------
    tuple
        ``(mean_accuracy, per_fold_accuracies)``.
    """
    assert len(d)==len(l)

    RCSP = RiemannCSP(channels)

    segment = len(d) // kfolds

    k_accuracies = []

    for i in range(kfolds):
        start, stop = i * segment, (i + 1) * segment

        d_test = d[start:stop]
        l_test = l[start:stop]
        # Wrap the slices in list() so that `+` always concatenates; on NumPy
        # arrays a bare `+` would be elementwise addition (or a broadcasting
        # error) instead of joining the two parts of the training set.
        d_train = list(d[:start]) + list(d[stop:])
        l_train = list(l[:start]) + list(l[stop:])

        X_train, Y_train, X_test, Y_test = RCSP.extractFeatures(
            d_train, l_train, d_test, l_test, n, metric
        )

        classifier.fit(X_train, Y_train)
        k_accuracies.append(classifier.score(X_test, Y_test))

    av = sum(k_accuracies)/len(k_accuracies)

    return av, k_accuracies
    
    
def classification(SPDDataset, subjects, epochs, channels, n, metrics, classifier, verbose=False):
    """Average the cross-validated accuracy over subjects and metrics.

    ``SPDDataset[0]`` and ``SPDDataset[1]`` are read in consecutive runs of
    ``epochs`` entries, one run per subject, and each run is handed to
    ``crossvalidate`` as data and labels respectively.

    Parameters
    ----------
    SPDDataset : sequence
        Indexable with 0 (data) and 1 (labels).
    subjects : int
        Number of subject runs to evaluate.
    epochs : int
        Number of entries per subject.
    channels, n, classifier
        Forwarded unchanged to ``crossvalidate``.
    metrics : iterable of str
        Metric names evaluated one after another.
    verbose : bool, optional
        When true, print the per-subject and per-metric accuracies.

    Returns
    -------
    float
        Mean over metrics of the mean per-subject accuracy.
    """
    per_metric_means = []
    for metric in metrics:
        if verbose:
            print("Metric is " + metric)

        subject_accuracies = []
        for subject_idx in range(subjects):
            # Window selecting this subject's block of entries.
            window = slice(subject_idx * epochs, (subject_idx + 1) * epochs)
            subject_data = SPDDataset[0][window]
            subject_labels = SPDDataset[1][window]

            subject_mean, _ = crossvalidate(
                subject_data, subject_labels, channels, n, metric, classifier
            )
            subject_accuracies.append(subject_mean)
            if verbose:
                print(f"Subject {subject_idx + 1!s}. Accuracy = {subject_mean!s}")

        metric_mean = sum(subject_accuracies) / len(subject_accuracies)
        if verbose:
            print(f"Average is = {metric_mean!s}\n")
        per_metric_means.append(metric_mean)

    return sum(per_metric_means) / len(per_metric_means)
        



In [None]:
filename='datasets/SPDDataset54.pickle'
# NOTE: pickle.load can execute arbitrary code while unpickling; only load
# dataset files that come from a trusted source.
# The context manager closes the file handle as soon as the data are read.
with open(filename,'rb') as infile:
    SPDDataset=pickle.load(infile)

# `classification` reads SPDDataset[0] (data) and SPDDataset[1] (labels) in
# consecutive runs of `epochs` entries, one run per subject.
subjects = 10
channels = 62
epochs = 400

# Label convention used throughout this code: the left condition is 0 and the
# right condition is 1.  Two-element arrays follow the same order: index 0
# corresponds to left, index 1 corresponds to right.


metrics = ["AIRM", "LEM"]

# Candidate values for n (filters taken from each end of the spectrum): 1..8.
n_range = list(range(1, 9))
# 0.5, 1.0, ..., 3.0.  Not used by the cells visible here.
# NOTE(review): presumably candidate regularisation values for a later
# classifier - confirm against the cells that follow.
c_range = [i/2 for i in range(1, 7)]

In [None]:
# Grid search over n for an LDA classifier.  The score for each n is the value
# returned by `classification`: cross-validated accuracy averaged over all
# subjects and over every entry of `metrics`.
print("LDA: ")
print("Peforming grid search for optimal hyperparameters")

best_acc = 0
# NOTE(review): best_params stays 0 if no candidate scores above 0; the final
# call below would then run with n=0 - consider guarding against that.
best_params = 0


clf = LDA()
for n in n_range:
    average_acc = classification(SPDDataset, subjects, epochs, channels, n, metrics, clf) 
    # Strict '>' keeps the first (smallest) n among equally good candidates.
    if average_acc>best_acc:
        best_acc = average_acc
        best_params = n

print("Best n is " + str(best_params))

# Re-run the evaluation with the selected n on a fresh classifier.
# NOTE(review): the return value is discarded and `verbose` keeps its default
# of False, so this call produces no visible output - confirm whether
# verbose=True (or printing the result) was intended.
clf = LDA()
_ = classification(SPDDataset, subjects, epochs, channels, best_params, metrics, clf)