In [1]:
import scipy.io
import os
import numpy as np
import pandas as pd
import time
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.metrics import cohen_kappa_score
import Butterworth_filtering

from BW_metric import BW_dist
from AI_metric import AI_dist
from manifold_project_toolbox import F_dist
from manifold_project_toolbox import bw_projection_mean
from manifold_project_toolbox import ai_projection_mean
from manifold_project_toolbox import x2corr

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all" # enable multiple output in one cell

In [2]:
def check_weights(weights, n_weights, *, check_positivity=False):
    
    if weights is None:
        weights = np.ones(n_weights)

    else:
        weights = np.asarray(weights)
        if weights.shape != (n_weights,):
            raise ValueError(
                "Weights do not have the good shape. Should be (%d,) but got "
                "%s." % (n_weights, weights.shape,)
            )
        if check_positivity and any(weights <= 0):
            raise ValueError("Weights must be strictly positive.")

    weights /= np.sum(weights)
    return weights
    
def mean_euclid(X, sample_weight=None):
    return np.average(X, axis=0, weights=sample_weight)

def _matrix_operator(C, operator):
    """Matrix function."""
    if not isinstance(C, np.ndarray) or C.ndim < 2:
        raise ValueError("Input must be at least a 2D ndarray")
    if C.dtype.char in np.typecodes['AllFloat'] and (
            np.isinf(C).any() or np.isnan(C).any()):
        raise ValueError(
            "Matrices must be positive definite. Add "
            "regularization to avoid this error.")
    eigvals, eigvecs = np.linalg.eigh(C)
    eigvals = operator(eigvals)
    if C.ndim >= 3:
        eigvals = np.expand_dims(eigvals, -2)
    D = (eigvecs * eigvals) @ np.swapaxes(eigvecs.conj(), -2, -1)
    return D

def sqrtm(C):
    return _matrix_operator(C, np.sqrt)

def invsqrtm(C):
    def isqrt(x): return 1. / np.sqrt(x)
    return _matrix_operator(C, isqrt)

def logm(C):
    return _matrix_operator(C, np.log)

def expm(C):
    return _matrix_operator(C, np.exp)

import warnings
def mean_riemann(X, *, tol=10e-9, maxiter=50, init=None, sample_weight=None):
    
    dists = []
    ssds = []
    
    n_matrices, n, _ = X.shape
    sample_weight = check_weights(sample_weight, n_matrices)
    if init is None:
        M = mean_euclid(X, sample_weight=sample_weight)
    else:
        M = check_init(init, n)
    
    nu = 1.0
    tau = np.finfo(np.float64).max
    crit = np.finfo(np.float64).max
    for i in range(maxiter):
        print(i)
        M12, Mm12 = sqrtm(M), invsqrtm(M)
        J = np.einsum("a,abc->bc", sample_weight, logm(Mm12 @ X @ Mm12))
        M_new = M12 @ expm(nu * J) @ M12

        # dist_mean = AI_dist(M_new, M)
        # dists.append(dist_mean)
        # ssd_new = np.sum(np.array([AI_dist(M_new,X[i])**2 for i in range(X.shape[0])]))
        # ssds.append(ssd_new)

        M = M_new
        
        crit = np.linalg.norm(J, ord="fro")
        h = nu * crit
        if h < tau:
            nu = 0.95 * nu
            tau = h
        else:
            nu = 0.5 * nu

        # print("crit <= tol:", crit <= tol, "nu <= tol:", nu <= tol)
        
        if crit <= tol or nu <= tol:
            break
        else:
            warnings.warn("Convergence not reached")

    # return M, np.array(dists), np.array(ssds), i+1
    return M

In [3]:
def MDM_dist(A, B, method):
    if method == "BW":
        output = BW_dist(A,B)
    elif method == "AI":
        output = AI_dist(A,B)
    elif method == "Euc":
        output = F_dist(A,B)
    return output

def MDM_mean(x, eps, method, verbose=False):
    # input: psd matrices [N,n,n]
    # output: mean matrix [n,n]
    if method == "BW":
        output = bw_projection_mean(x, eps, verbose=verbose)
    elif method == "AI":
        output = mean_riemann(x, tol=eps)
        # output = ai_projection_mean(x, eps, verbose=verbose)
    elif method == "Euc":
        output = np.mean(x, axis=0) # arithmetic mean
    return output

In [21]:
# 4,5,6,8,9,10,11,12,13,15,17,18,20,21,23,24

# Lab Data
subject = 24
data_dir = "D:/Projects/MDRM 2.0/Lab Data/Epoched Data"
data = scipy.io.loadmat(os.path.join(data_dir, "p"+str(subject)+"c1.mat"))
data_filtered = Butterworth_filtering.filter_all(data['events'], order=5, lowcut=1, highcut=30, fs=500)

sample = x2corr(data_filtered)
label = data['labels']

sample.shape
label.shape

# from sklearn.covariance import LedoitWolf

# shrunk_matrices = []
# shrinkage_values = []
# for mat in sample:
#     lw = LedoitWolf()
#     lw.fit(mat)
#     shrunk_matrices.append(lw.covariance_)
#     shrinkage_values.append(lw.shrinkage_)

# sample = np.array(shrunk_matrices)
# sample.shape

(17, 64, 64)

(17,)

In [5]:
accuracy_all = {'AI': {}, 'BW': {}, 'Euc': {}}
execution_time_all = {'AI': {}, 'BW': {}, 'Euc': {}}

for method in ['AI','BW','Euc']:

    # parameter setting
    repetition = 100 # number of reps for cv
    
    accuracy = np.zeros(repetition, dtype = np.float32)
    execution_time = np.zeros(repetition, dtype = np.float32)
    
    for rep_idx in range(repetition):
        
        # stratified 80/20 train-test-split
        x_train, x_test, y_train, y_test = train_test_split(sample, label, test_size=0.2, stratify=label, random_state=rep_idx+22)
        ##### MDM
        # compute mean of each class
        start_time = time.time() # timer starts
        mean1 = MDM_mean(x_train[np.where(y_train==0)], 0.0001, method)
        mean2 = MDM_mean(x_train[np.where(y_train==1)], 0.0001, method)
        # assign each testing sample to the nearest class mean
        y_pred = np.zeros(x_test.shape[0], dtype = np.float32)
        for test_idx in range(x_test.shape[0]):
            dist1 = MDM_dist(mean1, x_test[test_idx], method)
            dist2 = MDM_dist(mean2, x_test[test_idx], method)
            y_pred[test_idx] = 0 if dist1<dist2 else 1
        end_time = time.time() # timer stops
        # record accuracy and execution time
        accuracy[rep_idx] = sum(y_pred==y_test)/len(y_test)
        execution_time[rep_idx] = end_time - start_time
        
        print("Repetition " + str(rep_idx+1) + " completed.")

    accuracy_all[method] = accuracy
    execution_time_all[method] = execution_time

0
1
2
3
4
5
0
1
2
3
4
Repetition 1 completed.
0
1
2
3
4
5
0
1
2
3
Repetition 2 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 3 completed.
0
1
2
3
4
0
1
2
3




Repetition 4 completed.
0
1
2
3
4
0
1
2
3
Repetition 5 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 6 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 7 completed.
0
1
2
3
4
5
0
1
2
3
4




Repetition 8 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 9 completed.
0
1
2
3
4
0
1
2
3
Repetition 10 completed.
0
1
2
3
4
5
0
1
2
3
Repetition 11 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 12 completed.
0




1
2
3
4
0
1
2
3
4
Repetition 13 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 14 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 15 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 16 completed.
0




1
2
3
4
5
0
1
2
3
4
Repetition 17 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 18 completed.
0
1
2
3
4
5
0
1
2
3
Repetition 19 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 20 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 21 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 22 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 23 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 24 completed.
0
1
2
3
4
0
1




2
3
Repetition 25 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 26 completed.
0
1
2
3
4
5
0
1
2
3
Repetition 27 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 28 completed.
0
1
2
3
4
0
1
2
3
4




Repetition 29 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 30 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 31 completed.
0
1
2
3
4
0
1
2
3
Repetition 32 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 33 completed.
0
1
2
3
4
5
0
1




2
3
4
Repetition 34 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 35 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 36 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 37 completed.
0
1
2
3
4
5
0
1
2




3
4
Repetition 38 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 39 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 40 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 41 completed.
0
1
2
3
4
5
0




1
2
3
4
Repetition 42 completed.
0
1
2
3
4
0
1
2
3
Repetition 43 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 44 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 45 completed.
0
1
2
3
4
5
0
1
2
3
4




Repetition 46 completed.
0
1
2
3
4
0
1
2
3
Repetition 47 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 48 completed.
0
1
2
3
4
0
1
2
3
Repetition 49 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 50 completed.
0
1
2
3
4
5




0
1
2
3
Repetition 51 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 52 completed.
0
1
2
3
4
5
0
1
2
3
Repetition 53 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 54 completed.
0
1
2
3
4
0
1
2
3
Repetition 55 completed.
0




1
2
3
4
5
0
1
2
3
4
Repetition 56 completed.
0
1
2
3
4
0
1
2
3
Repetition 57 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 58 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 59 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 60 completed.
0
1
2
3
4
5
0
1
2
3
Repetition 61 completed.
0
1
2
3
4
0
1
2
3
Repetition 62 completed.
0
1
2
3
4
0
1
2
3
Repetition 63 completed.
0
1




2
3
4
0
1
2
3
4
Repetition 64 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 65 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 66 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 67 completed.
0
1
2




3
4
5
0
1
2
3
Repetition 68 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 69 completed.
0
1
2
3
4
0
1
2
3
Repetition 70 completed.
0
1
2
3
4
0
1
2
3
Repetition 71 completed.
0
1
2
3
4
5
0
1
2




3
Repetition 72 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 73 completed.
0
1
2
3
4
0
1
2
3
Repetition 74 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 75 completed.
0
1
2
3
4
5
0
1
2
3
Repetition 76 completed.
0
1
2
3




4
0
1
2
3
Repetition 77 completed.
0
1
2
3
4
5
0
1
2
3
Repetition 78 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 79 completed.
0
1
2
3
4
5
0
1
2
3
Repetition 80 completed.
0
1
2
3
4
5
0
1




2
3
4
Repetition 81 completed.
0
1
2
3
4
0
1
2
3
Repetition 82 completed.
0
1
2
3
4
0
1
2
3
Repetition 83 completed.
0
1
2
3
4
0
1
2
3
Repetition 84 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 85 completed.
0




1
2
3
4
0
1
2
3
4
Repetition 86 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 87 completed.
0
1
2
3
4
0
1
2
3
Repetition 88 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 89 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 90 completed.
0
1
2
3
4
5
0
1
2
3
Repetition 91 completed.
0
1
2
3
4
5
0
1
2
3
Repetition 92 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 93 completed.
0
1
2
3
4
0
1
2
3
4




Repetition 94 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 95 completed.
0
1
2
3
4
0
1
2
3
4
Repetition 96 completed.
0
1
2
3
4
5
0
1
2
3
4
Repetition 97 completed.
0
1
2
3
4
0
1
2
3
Repetition 98 completed.
0
1
2
3




4
0
1
2
3
4
Repetition 99 completed.
0
1
2
3
4
5
0
1
2
3
Repetition 100 completed.
Repetition 1 completed.
Repetition 2 completed.
Repetition 3 completed.
Repetition 4 completed.
Repetition 5 completed.
Repetition 6 completed.
Repetition 7 completed.
Repetition 8 completed.
Repetition 9 completed.
Repetition 10 completed.
Repetition 11 completed.
Repetition 12 completed.
Repetition 13 completed.
Repetition 14 completed.
Repetition 15 completed.
Repetition 16 completed.
Repetition 17 completed.
Repetition 18 completed.
Repetition 19 completed.
Repetition 20 completed.
Repetition 21 completed.
Repetition 22 completed.
Repetition 23 completed.
Repetition 24 completed.
Repetition 25 completed.
Repetition 26 completed.
Repetition 27 completed.
Repetition 28 completed.
Repetition 29 completed.
Repetition 30 completed.
Repetition 31 completed.
Repetition 32 completed.
Repetition 33 completed.
Repetition 34 completed.
Repetition 35 completed.
Repetition 36 completed.
Repetition 37 completed.
R

In [6]:
np.mean(accuracy_all['AI'])
np.mean(accuracy_all['BW'])
np.mean(accuracy_all['Euc'])

np.float32(1.0)

np.float32(1.0)

np.float32(1.0)

In [7]:
np.mean(execution_time_all['AI'])
np.mean(execution_time_all['BW'])
np.mean(execution_time_all['Euc'])

np.float32(0.048292443)

np.float32(0.12964416)

np.float32(0.00020023584)

In [8]:
save_dir = "D:/Projects/Scientific Report/MDRM with LedoitWolf/Lab Data/"
rows = ['Rep{}'.format(i+1) for i in list(range(repetition))]

for method in ['AI','BW','Euc']:
    
    accuracy_df = pd.DataFrame(accuracy_all[method], index=rows)
    execution_time_df = pd.DataFrame(execution_time_all[method], index=rows)

    accuracy_df.to_csv(save_dir+"p"+str(subject)+"c1"+"_accuracy_"+method+".csv")
    execution_time_df.to_csv(save_dir+"Execution Time/"+"p"+str(subject)+"c1"+"_execution_time_"+method+".csv")