In [None]:
%env GEOMSTATS_BACKEND=numpy
%env NUMEXPR_MAX_THREADS=12 

import pickle
import pandas as pd
import numpy as np

import geomstats.backend as gs
import geomstats.geometry.spd_matrices as spd
from geomstats.learning.frechet_mean import FrechetMean
from geomstats.learning.preprocessing import ToTangentSpace

# Lift pandas' truncation limits so wide/long frames are shown in full.
pd.set_option("display.max_columns", None, "display.max_rows", None)

In [None]:
class RMLNM:
    """Nearest-class-mean classifier on SPD matrices with nested
    leave-one-subject-out (LOSO) model selection.

    For every held-out subject (outer LOSO), an inner LOSO over the remaining
    subjects picks the metric in ``self.choices`` ("LE" = log-Euclidean,
    "AI" = affine-invariant) with the best mean accuracy. Per-class Frechet
    means are then fit on all remaining subjects with that metric, and each
    sample of the held-out subject is assigned to the class whose mean is
    closest.

    Assumes ``d`` (SPD matrices) and ``l`` (integer labels in
    ``range(n_classes)``) are ordered subject by subject, with exactly
    ``epochs`` consecutive entries per subject. ``d``/``l`` may be Python
    lists or numpy arrays.
    """

    def __init__(self, subjects, epochs, channels, n_classes):
        """
        Parameters
        ----------
        subjects : int
            Number of subjects contained in the data.
        epochs : int
            Number of consecutive entries belonging to each subject.
        channels : int
            Size ``n`` of the ``n x n`` SPD matrices.
        n_classes : int
            Number of classes; labels are compared against ``range(n_classes)``.
        """
        self.subjects = subjects
        self.epochs = epochs
        self.n_classes = n_classes
        self.LE_metric = spd.SPDMetricLogEuclidean(n=channels)
        self.AI_metric = spd.SPDMetricAffine(n=channels)
        self.LE_mean = FrechetMean(self.LE_metric)
        self.AI_mean = FrechetMean(self.AI_metric)
        # Candidate metrics for the inner LOSO, in tie-break order
        # (the first one wins when accuracies are equal).
        self.choices = ["LE", "AI"]

    def outer_LOSO(self, d, l):
        """Run the outer LOSO loop.

        Returns
        -------
        list of float
            ``subject_acc[i]`` is the accuracy on subject ``i`` when it is
            held out, using the metric chosen by :meth:`inner_LOSO` on the
            other subjects.
        """
        subject_acc = []
        for i in range(self.subjects):
            print("Subject: " + str(i + 1))
            d_leave, l_leave, d_remain, l_remain = self._split_subject(d, l, i)
            choice = self.inner_LOSO(d_remain, l_remain)
            print("Choice for the subject is: " + choice)
            classSpecificSPD = self.separate_classes(d_remain, l_remain)
            SPDk = self.compute_mean(classSpecificSPD, choice)
            accuracy = self.compute_accuracy(d_leave, l_leave, SPDk, choice)
            print("The subject accuracy is:" + str(accuracy) + "\n")
            # Record the per-subject result so callers actually receive it.
            subject_acc.append(accuracy)
        return subject_acc

    def inner_LOSO(self, d, l):
        """Select the best metric by LOSO over the subjects in ``d``/``l``.

        ``d``/``l`` are expected to hold ``self.subjects - 1`` subjects (the
        training part of an outer fold). For each inner fold and each
        candidate in ``self.choices``, class means are fit on the remaining
        subjects and accuracy is measured on the held-out one.

        Returns
        -------
        str
            The entry of ``self.choices`` with the highest mean accuracy
            across inner folds (first one wins on ties).
        """
        n_folds = self.subjects - 1
        totals = [0.0] * len(self.choices)
        for i in range(n_folds):
            d_leave, l_leave, d_remain, l_remain = self._split_subject(d, l, i)
            classSpecificSPD = self.separate_classes(d_remain, l_remain)
            for j, choice in enumerate(self.choices):
                SPDk = self.compute_mean(classSpecificSPD, choice)
                totals[j] += self.compute_accuracy(d_leave, l_leave, SPDk, choice)
        # Arithmetic mean over folds: every fold gets the same weight.
        acc = [total / n_folds for total in totals] if n_folds > 0 else totals
        print("Inner LOSO round complete!")
        print(self.choices)
        print(acc)
        best = acc.index(max(acc))
        return self.choices[best]

    def compute_accuracy(self, d, l, SPDk, choice):
        """Fraction of samples in ``d`` whose nearest class mean in ``SPDk``
        (under the ``choice`` metric) has the index given by ``l``.

        Raises
        ------
        ValueError
            If ``choice`` is unknown, ``d`` is empty, or ``d`` and ``l``
            differ in length.
        """
        metric = self._metric_for(choice)
        n_samples = len(d)
        if n_samples == 0:
            raise ValueError("cannot compute accuracy on an empty set")
        if len(l) != n_samples:
            raise ValueError("d and l must have the same length")
        errors = 0
        for sample, label in zip(d, l):
            dist = [metric.dist(sample, SPDk[k]) for k in range(self.n_classes)]
            # Nearest mean wins; ties go to the lowest class index.
            prediction = dist.index(min(dist))
            if prediction != label:
                errors += 1
        # Normalise by the real number of samples, not self.epochs.
        return 1 - errors / n_samples

    def compute_mean(self, classSpecificSPD, choice):
        """Fit one Frechet mean per class under the ``choice`` metric.

        Returns a list whose ``k``-th entry is the fitted ``estimate_`` for
        ``classSpecificSPD[k]``.

        Raises
        ------
        ValueError
            If ``choice`` is unknown or a class has no samples.
        """
        estimator = self._mean_for(choice)
        SPDk = []
        for k, class_samples in enumerate(classSpecificSPD):
            if len(class_samples) == 0:
                raise ValueError("class {} has no training samples".format(k))
            # NOTE(review): the same estimator is refit for every class; this
            # relies on fit() rebinding estimate_ rather than mutating it.
            estimator.fit(class_samples)
            SPDk.append(estimator.estimate_)
        return SPDk

    def separate_classes(self, d, l):
        """Group samples by label.

        Entry ``k`` of the result lists the samples of ``d`` whose label in
        ``l`` equals ``k``; labels outside ``range(n_classes)`` are dropped.
        """
        return [[sample for sample, label in zip(d, l) if label == k]
                for k in range(self.n_classes)]

    # ----- private helpers -------------------------------------------------

    def _split_subject(self, d, l, i):
        """Return ``(d_leave, l_leave, d_remain, l_remain)`` holding out subject ``i``."""
        start, stop = i * self.epochs, (i + 1) * self.epochs
        return (d[start:stop], l[start:stop],
                self._concat(d[:start], d[stop:]),
                self._concat(l[:start], l[stop:]))

    @staticmethod
    def _concat(head, tail):
        """Concatenate two slices of the same sequence.

        ``+`` concatenates lists but adds numpy arrays element-wise, so numpy
        inputs are joined with ``np.concatenate`` instead.
        """
        if isinstance(head, np.ndarray):
            return np.concatenate((head, tail))
        return head + tail

    def _metric_for(self, choice):
        """Return the metric object used for distances under ``choice``."""
        if choice == "LE":
            return self.LE_metric
        if choice == "AI":
            return self.AI_metric
        raise ValueError("Non implemented metric")

    def _mean_for(self, choice):
        """Return the Frechet-mean estimator associated with ``choice``."""
        if choice == "LE":
            return self.LE_mean
        if choice == "AI":
            return self.AI_mean
        raise ValueError("Non implemented metric")

In [None]:
# Experiment configuration.
# epochs: number of entries per subject; channels: SPD matrices are channels x channels.
subjects, epochs, channels = 10, 400, 62
# One entry per class; len(conditions) is passed to the model as n_classes.
conditions = ["left", "right"]

In [None]:
filename = 'datasets/SPDDataset54.pickle'
# SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
# The context manager closes the handle (the previous version leaked it).
with open(filename, 'rb') as infile:
    data = pickle.load(infile)
# data[0]: SPD matrices, data[1]: class labels.
# NOTE(review): RMLNM assumes both are ordered subject by subject -- confirm.
d, l = data[0], data[1]

In [None]:
# Keep exactly `subjects` subjects of `epochs` entries each; the LOSO split in
# RMLNM slices by subject index and needs this exact length.
d = d[:subjects * epochs]
l = l[:subjects * epochs]
assert len(d) == len(l) == subjects * epochs, "dataset has fewer than subjects * epochs entries"

In [None]:
# Nested LOSO evaluation; `results` holds what outer_LOSO returns.
b = RMLNM(subjects=subjects, epochs=epochs, channels=channels,
          n_classes=len(conditions))
results = b.outer_LOSO(d, l)

In [None]:
results_dump = 'datasets/resultsRMLNM54.pickle'
# The context manager closes the file even if pickling raises.
with open(results_dump, 'wb') as outfile:
    pickle.dump(results, outfile)