In [20]:
import mat73
import numpy as np
import scipy
import matplotlib.pyplot as plt
import csv
import os.path
import seaborn as sns
import pandas as pd
import time

# self-made functions
from preprocessing_utils import *
from centers_utils import *
# from classification_utils import *

# use to reload external python file
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Process Data

In [2]:
# Analysis configuration shared by the cells below.
sub_num = '01'  # subject id; used to build the path of the .mat file to load

# Time windows, both expressed in ms.
baseline = [-200, 0]
frame = [0, 600]

# Electrode labels retained for the analysis.
elec_to_keep = [
    'FP1', 'FP2',
    'F3', 'FZ', 'F4',
    'T7', 'Cz', 'T8',
    'P7', 'P8',
    'O1', 'Oz', 'O2',
]

In [3]:
# Read the selected subject's MATLAB v7.3 file from the local data/ folder.
eeg = mat73.loadmat("data/s" + sub_num + ".mat")

# processSessionsIndiv is not defined in this notebook (it comes from one of the
# star-imported helper modules). Its four return values are bound, in order, to
# the train/test x target/nontarget names below.
(
    train_target,
    train_nontarget,
    test_target,
    test_nontarget,
) = processSessionsIndiv(
    eeg,
    baseline,
    frame,
    elec_to_keep,
    opt_keep_baseline=False,
)

# Center Functions with Timing

In [18]:
# matches what we expect: https://ieeexplore.ieee.org/document/8013808
def euclidean_mean_timing(P_set):
    """Compute the arithmetic (Euclidean) mean of a stack of matrices and time it.

    Parameters
    ----------
    P_set : array_like
        Stack of matrices; the mean is taken along axis 0.

    Returns
    -------
    mean : ndarray
        ``np.mean(P_set, axis=0)``.
    elapsed : float
        Seconds spent computing the mean. This is a single value rather than a
        per-iteration list because the computation is one closed-form step.
    """
    # perf_counter() is monotonic and uses the highest-resolution clock
    # available; time.time() follows the adjustable system clock and can be
    # too coarse to time a sub-millisecond operation like this one.
    start = time.perf_counter()
    mean = np.mean(P_set, axis=0)
    elapsed = time.perf_counter() - start
    return mean, elapsed

In [8]:
# modified from: https://github.com/pyRiemann/pyRiemann/blob/master/pyriemann/utils/mean.py#L22 
# derived from: https://link.springer.com/chapter/10.1007/978-3-642-00826-9_6 
# find riemannian mean using gradient descent 
def riemannian_mean_timing(P_set, max_iter=50, nu=1.0, tol=10e-9, weights=None):
    """Estimate the Riemannian mean of a matrix stack by gradient descent, timing each iteration.

    The helpers ``euclidean_mean``, ``sqrtm``, ``expm`` and ``grad_riemann_mean``
    are not defined in this notebook; they must already be in scope.

    Parameters
    ----------
    P_set : ndarray
        3-D array whose first axis indexes trials.
    max_iter : int
        Maximum number of descent iterations.
    nu : float
        Initial step size. After each iteration it is multiplied by 0.95 when
        ``nu * ||grad||_F`` decreased below its best value so far, and halved
        otherwise.
    tol : float
        Stop once ``||grad||_F <= tol`` or the step size ``nu <= tol``.
        Note the default ``10e-9`` equals ``1e-8``.
    weights : array_like or None
        One weight per trial, forwarded to ``grad_riemann_mean``. ``None``
        weighs all trials evenly (``1 / n_trials`` each).

    Returns
    -------
    mean_curr : ndarray
        The estimate after the last iteration (the Euclidean mean if
        ``max_iter`` is 0).
    timing : list of float
        Elapsed seconds of each executed iteration, in order.
    """
    n_trials, _, _ = P_set.shape

    # `weights == None` is evaluated element-wise for ndarrays, which makes the
    # `if` raise "truth value of an array is ambiguous"; identity is the
    # correct test.
    if weights is None:
        weights = np.ones(n_trials) / n_trials  # evenly weigh all trials
    else:
        weights = np.asarray(weights)

    tau = np.finfo(np.float64).max  # best (smallest) nu * ||grad|| seen so far
    timing = []
    mean_curr = euclidean_mean(P_set)  # starting point of the descent

    for _ in range(max_iter):
        # perf_counter() is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()

        mean_sqrt = sqrtm(mean_curr)
        grad = grad_riemann_mean(P_set, mean_curr, weights)
        mean_curr = mean_sqrt @ expm(nu * grad) @ mean_sqrt

        # Step-size adaptation, taken from the pyRiemann implementation
        # referenced above this function.
        crit = np.linalg.norm(grad, ord='fro')
        h = nu * crit
        if h < tau:
            nu = 0.95 * nu
            tau = h
        else:
            nu = 0.5 * nu

        timing.append(time.perf_counter() - start)

        if crit <= tol or nu <= tol:
            break

    return mean_curr, timing

In [17]:
# from https://www.sciencedirect.com/science/article/pii/S0377042711005218
# uses ADMM + proximal update to do the update
def matrix_median_timing(P_set, lam=0, gamma=1, max_iter=200, tol=10e-9):
    """Run the ADMM-style matrix-median iteration on a matrix stack, timing each iteration.

    Three quantities are iterated: the estimate ``X``, per-trial variables
    ``V`` (updated through a shrinkage step on ``Y = V - P``) and per-trial
    dual variables ``B``. ``euclidean_mean`` is not defined in this notebook
    and must already be in scope.

    Parameters
    ----------
    P_set : ndarray
        3-D array of shape ``(n_trials, n_cov, n_cov)``.
    lam : float
        Enters only through the denominator ``lam * gamma + n_trials`` of the
        X update.
    gamma : float
        Shrinkage threshold: a trial's residual is shrunk by the factor
        ``1 - gamma / ||S||_F`` when ``||S||_F >= gamma`` and set to zero
        otherwise.
    max_iter : int
        Maximum number of iterations.
    tol : float
        Stop once ``||X_curr - X_prev||_F < tol``. The default ``10e-9``
        equals ``1e-8``.

    Returns
    -------
    X_curr : ndarray
        The estimate after the last iteration.
    timing : list of float
        Elapsed seconds of each executed iteration, including the one that
        triggered termination.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``lam * gamma + n_trials`` is zero (the X update is then undefined).
    """
    n_trials, n_cov, _ = P_set.shape

    # V and B both start at zero.
    V_curr = np.zeros((n_trials, n_cov, n_cov))
    B_curr = np.zeros((n_trials, n_cov, n_cov))

    # Informed start. Because V and B start at zero, the first X update
    # overwrites this value; it only serves as X_prev in the first
    # convergence check (and as the return value when max_iter is 0).
    X_curr = euclidean_mean(P_set)

    # The X update multiplies by the inverse of (lam*gamma + n_trials) * I,
    # which is just a division by that scalar.
    scale = lam * gamma + n_trials
    if scale == 0:
        raise np.linalg.LinAlgError("lam * gamma + n_trials must be non-zero")

    timing = []

    for _ in range(max_iter):
        # perf_counter() is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()

        # X update
        X_prev = X_curr
        X_curr = np.sum(V_curr - B_curr, axis=0) / scale

        # termination condition
        if np.linalg.norm(X_curr - X_prev, ord='fro') < tol:
            timing.append(time.perf_counter() - start)
            break

        # V update through proximal (shrinkage) update on Y = V - P.
        # X_curr broadcasts over the trial axis.
        S_curr = B_curr + X_curr - P_set
        S_norm = np.linalg.norm(S_curr, axis=(1, 2), ord='fro')

        # Only trials whose residual norm reaches gamma are shrunk; the rest
        # stay at zero. Evaluating the factor on the selected trials only
        # (and never where the norm is exactly 0, where S itself is 0) avoids
        # divide-by-zero warnings and NaNs.
        shrink = (S_norm >= gamma) & (S_norm > 0)
        Y_curr = np.zeros(S_curr.shape)
        Y_curr[shrink] = (1 - gamma / S_norm[shrink])[:, None, None] * S_curr[shrink]
        V_curr = Y_curr + P_set

        # B update - dual update
        B_curr = B_curr + X_curr - V_curr

        timing.append(time.perf_counter() - start)

    return X_curr, timing

In [16]:
# See https://www.sciencedirect.com/science/article/pii/S1053811908012019?via%3Dihub 
# uses steepest descent
# directly from: https://github.com/pyRiemann/pyRiemann/blob/master/pyriemann/utils/median.py
def riemannian_median_timing(P_set, nu=1, max_iter=50, tol=10e-9, weights=None):
    """Estimate the Riemannian median of a matrix stack by steepest descent, timing each iteration.

    The helpers ``euclidean_mean``, ``riemannian_distance``, ``sqrtm``,
    ``invsqrtm``, ``logm`` and ``expm`` are not defined in this notebook; they
    must already be in scope (``logm`` is called on a 3-D stack).

    Parameters
    ----------
    P_set : ndarray
        3-D array whose first axis indexes trials.
    nu : float
        Step size; constant across iterations.
    max_iter : int
        Maximum number of iterations.
    tol : float
        Stop once ``||grad||_F <= tol``. The default ``10e-9`` equals ``1e-8``.
    weights : array_like or None
        One weight per trial. ``None`` weighs all trials evenly.

    Returns
    -------
    curr_med : ndarray
        The estimate after the last iteration.
    timing : list of float
        Elapsed seconds of each executed iteration, in order.
    """
    n_trials, _, _ = P_set.shape

    # `weights == None` is evaluated element-wise for ndarrays and makes the
    # `if` raise; identity is the correct test. asarray() also lets a plain
    # list be boolean-indexed below.
    if weights is None:
        weights = np.ones(n_trials) / n_trials  # evenly weigh all trials
    else:
        weights = np.asarray(weights)

    curr_med = euclidean_mean(P_set)  # starting point of the descent
    timing = []

    for _ in range(max_iter):
        # perf_counter() is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()

        distances = np.array([riemannian_distance(P_indiv, curr_med) for P_indiv in P_set])

        # Trials at distance exactly 0 are left out: their weight / distance
        # term would divide by zero.
        is_nonzero = distances != 0
        if not is_nonzero.any():
            # Every trial coincides with the current estimate, so there is no
            # descent direction; stop without sending an empty stack through
            # the matrix helpers.
            timing.append(time.perf_counter() - start)
            break
        nonzero_weights = weights[is_nonzero] / distances[is_nonzero]

        med_sqrt = sqrtm(curr_med)
        med_invsqrt = invsqrtm(curr_med)
        tangent_vecs = logm(med_invsqrt @ P_set[is_nonzero] @ med_invsqrt)
        # Weighted average of the tangent vectors, weights normalised to sum to 1.
        grad = np.einsum('a,abc->bc', nonzero_weights / np.sum(nonzero_weights), tangent_vecs)
        curr_med = med_sqrt @ expm(nu * grad) @ med_sqrt

        crit = np.linalg.norm(grad, ord='fro')
        timing.append(time.perf_counter() - start)
        if crit <= tol:
            break

    return curr_med, timing

In [15]:
# From: https://ieeexplore.ieee.org/abstract/document/7523317
def huber_centroid_timing(P_set, alpha=0.25, mu=0.5, nu_init=0.5, max_iter=50, tol=10e-9):
    """Estimate the Huber centroid of a matrix stack by fixed-step descent, timing each iteration.

    The helpers ``euclidean_mean``, ``huber_grad`` and ``exp_map`` are not
    defined in this notebook; they must already be in scope.

    Parameters
    ----------
    P_set : ndarray
        3-D array whose first axis indexes trials.
    alpha : float
        Currently unused by this function (``huber_grad`` is called without it).
    mu : float
        Currently unused; the disabled line-search sketch below multiplies the
        step size by ``mu``.
    nu_init : float
        Step size; constant across iterations while the line search is disabled.
    max_iter : int
        Maximum number of iterations.
    tol : float
        Stop once ``||grad||_F <= tol``. The default ``10e-9`` equals ``1e-8``.

    Returns
    -------
    curr : ndarray
        The estimate after the last iteration.
    timing : list of float
        Elapsed seconds of each executed iteration, in order.
    """
    # The unpacked sizes are not used below; the unpack is kept because it
    # fails fast when P_set is not 3-D.
    n_trials, n_cov, _ = P_set.shape
    curr = euclidean_mean(P_set)  # starting point of the descent
    nu = nu_init

    timing = []

    for _ in range(max_iter):
        # perf_counter() is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        grad = huber_grad(P_set, curr)

        # TODO: attempt at Armijo backtracking line search - need to check update
        # while huber_obj(P_set, exp_map(curr, -nu*grad)) > (huber_obj(P_set, curr) + nu*np.sum(huber_grad(P_set,curr)*exp_map(curr, -grad))):
        #     nu = mu * nu

        curr = exp_map(curr, -nu * grad)
        crit = np.linalg.norm(grad, ord='fro')
        timing.append(time.perf_counter() - start)
        if crit <= tol:
            break

    return curr, timing

# Get Timing of One Iteration 

In [14]:
# Benchmark input: the first entry of train_nontarget. Every timing cell below
# is run on this same array so the per-iteration costs are comparable.
P_test = train_nontarget[0]
# Show the shape so the problem size is visible next to the timings.
P_test.shape

(750, 26, 26)

In [32]:
# euclidean_mean_timing returns a single elapsed-time value rather than a
# per-iteration list, so np.mean() is a no-op here and the iteration count is
# reported as the constant 1.
_, euclid_mean_time = euclidean_mean_timing(P_test)
print("Time of one iteration (s):", np.mean(euclid_mean_time))
print("Number of iterations:", 1)


Time of one iteration (s): 0.000997781753540039
Number of iterations: 1


In [33]:
# riemannian_mean_timing returns (estimate, per-iteration timings); the
# estimate is discarded. Report the average iteration time and how many
# iterations ran before the loop stopped.
_, riemann_mean_time = riemannian_mean_timing(P_test)
print("Time of one iteration (s):", np.mean(riemann_mean_time))
print("Number of iterations:", len(riemann_mean_time))

Time of one iteration (s): 0.04880280494689941
Number of iterations: 15


In [34]:
# matrix_median_timing returns (estimate, per-iteration timings); the estimate
# is discarded. Report the average iteration time and how many iterations ran
# (its default cap is max_iter=200).
_, mat_med_time = matrix_median_timing(P_test)
print("Time of one iteration (s):", np.mean(mat_med_time))
print("Number of iterations:", len(mat_med_time))

Time of one iteration (s): 0.02243238442564664
Number of iterations: 146


In [35]:
# riemannian_median_timing returns (estimate, per-iteration timings); the
# estimate is discarded.
# NOTE(review): an iteration count equal to the default max_iter (50) suggests
# the run stopped at the cap rather than at the tolerance - confirm.
_, riemann_med_time = riemannian_median_timing(P_test)
print("Time of one iteration (s):", np.mean(riemann_med_time))
print("Number of iterations:", len(riemann_med_time))

Time of one iteration (s): 0.1380513572692871
Number of iterations: 50


In [36]:
# huber_centroid_timing returns (estimate, per-iteration timings); the
# estimate is discarded.
# NOTE(review): an iteration count equal to the default max_iter (50) suggests
# the run stopped at the cap rather than at the tolerance - confirm.
_, huber_time = huber_centroid_timing(P_test)
print("Time of one iteration (s):", np.mean(huber_time))
print("Number of iterations:", len(huber_time))

Time of one iteration (s): 0.17463696479797364
Number of iterations: 50
