## Results and Plotting
A notebook to compile results and organise plots and other resources needed for reporting

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# set global plotting params here for consistency
plt.rcParams['axes.titlesize'] = 20
plt.rcParams['axes.labelsize'] = 16
plt.rcParams['xtick.labelsize'] = 14
plt.rcParams['ytick.labelsize'] = 14

In [12]:
stim_freqs = [7,10,12] # stim freqs used
fs = 64 # sampling freq
Ns = 128 # number of sample points to consider
Nh = 1 # number of harmonics for CCA-based algos

index_pos = dict(zip(["Nc", "Ns", "Nt"], range(3)))

## Data Loading
Load data from json log files and arrange by frequency. The compiled data is stored in the dictionary `data` whose keys are the stimulus frequencies used and whose values are the data tensors corresponding to trials at those frequencies. Data tensors will be arranged like `Nc x Ns x Nt` (channels x samples x trials). 

Note that in this project, we only effectively had one channel. Also, all Nt trials would be independent recordings at the same stimulus frequency.


In [153]:
from eeg_lib.utils import read_json
import json

tests = {7: ["test-7hz-pos2"], 
         10: ["test-10hz-pos2"], 
         12: ["test-12hz-pos2"]
        }

all_data = read_json('eeg_lib/log_data.json')
data = {}

for f, test_set in tests.items():
    data[f] = []
    
    for test in test_set:
        values = all_data[test]
        proc_data = np.array([json.loads(values[i]) for i in range(len(values))])
        data[f].append(proc_data[1:, :Ns].reshape((1, Ns, -1))) # exclude first trial
        
# del all_data    

for f, proc_data in data.items():
    if len(proc_data) <= 1:
        data[f] = proc_data[0]
    else:
        data[f] = np.concatenate([*proc_data], axis=-1) # merge data from across trials

## Decoding
Run various decoding algos on gathered data and store results for comparison

### CCA
Vanilla CCA with no historical training data used across evaluations

In [146]:
from eeg_lib.cca import CCA

cca = CCA(stim_freqs, fs, Nh=Nh)

cca_results = {f:[] for f in stim_freqs}
cca_agg_results = {f:{} for f in stim_freqs}

for f in stim_freqs:
    data_f = data[f]
    for trial in range(1, data[f].shape[index_pos["Nt"]]):
        Xi = data_f[:, :, trial]
        result = cca.compute_corr(Xi)
        result = {k:np.round(v[0], 6) for k,v in result.items()}
        
        result['trial'] = f'f{f}_{trial}'
        result['y'] = f
        cca_results[f].append(result)
        
        # compute CCA result using data aggregated across trials
        agg_result = cca.compute_corr(data_f.mean(axis=index_pos["Nt"]))
        agg_result = {k:np.round(v[0], 6) for k,v in agg_result.items()}
        agg_result['y'] = f
        cca_agg_results[f] = agg_result
    
cca_df = pd.concat([pd.DataFrame(result_set) for result_set in cca_results.values()]).set_index('trial')
cca_df['y_hat'] = cca_df[stim_freqs].apply(lambda row: stim_freqs[np.argmax(row)], axis=1)

cca_agg_df = pd.DataFrame(list(cca_agg_results.values()))
cca_agg_df['y_hat'] = cca_agg_df[stim_freqs].apply(lambda row: stim_freqs[np.argmax(row)], axis=1)

cca_df

Unnamed: 0_level_0,7,10,12,y,y_hat
trial,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
f7_1,0.070296,0.034772,0.08788,7,12
f7_2,0.107306,0.095626,0.032961,7,7
f7_3,0.104282,0.133842,0.085106,7,10
f10_1,0.108304,0.060147,0.106589,10,7
f10_2,0.090138,0.16853,0.042788,10,10
f10_3,0.174531,0.166248,0.015555,10,7
f12_1,0.089876,0.205658,0.089087,12,10
f12_2,0.102849,0.168012,0.086462,12,10


## Template-based Algorithms 
This section explores decoding algos that, along with potentially the artificially-generated harmonic reference, include template data based on historical 'training' data. These include GCCA, MsetCCA, TRCA and others.

In [154]:
min_trial_len = np.min([test_set.shape[-1] for test_set in data.values()])

# Nf x Nc x Ns x Nt
data_tensor = np.array([test_set[:,:,:min_trial_len] for test_set in data.values()])

print("Data tensor shape: ", data_tensor.shape)

Data tensor shape:  (3, 1, 128, 6)


In [160]:
from sklearn.model_selection import LeavePOut

N_train = 3
lpo = LeavePOut(p=N_train)

n_trials = data_tensor.shape[-1]
template_idxs = list(lpo.split(range(n_trials)))

#### GCCA
Generalised CCA aims to simultaneously maximise correlation between three sets of data: historical observations, measured signals in a new sample and the pre-constructed sinusoidal reference. As interpreted by the authors (Wong et al), the optimal spatial filters obtained through GCCA perform SSVEP signal denoising.

#### MsetCCA
MsetCCA is one extension of standard CCA that takes into account historical data instead of performing inference purely on new observations. Zhang et al propose that this is one of the reasons that standard CCA performs poorly on short time windows; it effectively over fits to localised dynamics. Furthermore, the authors suggest that exclusively using the pre-constructed sinusoidal reference set is not optimal since this artificial reference does not exclude other features from real EEG data. To circumvent this, MsetCCA seeks to optimise the reference signals used in the CCA algorithm by learning multiple linear transforms to maximise overall correlation between canonical variables over many sets of EEG data at each candidate frequency fk ∈ F. This optimisation effectively finds optimal joint spatial filters w1, . . . , wNt (over Nt trials) using only historical observations (‘training’ data). The authors claim that MsetCCA outperforms similar techniques, especially in cases with few channels and short time windows.

In [161]:
from eeg_lib.cca import GCCA_SSVEP
from eeg_lib.cca import MsetCCA_SSVEP

gcca = GCCA_SSVEP(stim_freqs, fs, Nh=Nh)
gcca_results = {f:[] for f in stim_freqs}

mset_cca = MsetCCA_SSVEP(stim_freqs)
mset_cca_results = {f:[] for f in stim_freqs}

for f in stim_freqs:
    for split_idx, (test_idxs, train_idxs) in enumerate(template_idxs):
        chi_train = data_tensor[:, :, :, train_idxs]
        
        # train models on current train-test split
        gcca.fit(chi_train)
        mset_cca.fit(chi_train)
        
        # extract test matrices from all test indices and compute result
        for test_idx in test_idxs:
            if test_idx in train_idxs:
                raise ValueError("Found intersection between train and test indices")
                
            X_test = data[f][:, :, test_idx]
            
            # GCCA
            result = {k: abs(np.round(v,4)) for k,v in gcca.classify(X_test).items()}
            result['idx'] = f'f{f}_split{split_idx+1}_test{test_idx+1}'
            result['y'] = f
            gcca_results[f].append(result)
            
            # MsetCCA
            result = {k: abs(np.round(v,4)) for k,v in mset_cca.classify(X_test).items()}
            result['idx'] = f'f{f}_split{split_idx+1}_test{test_idx+1}'
            result['y'] = f
            mset_cca_results[f].append(result)
            
gcca_df = pd.concat([pd.DataFrame(result_set) for result_set in gcca_results.values()]).set_index('idx')
gcca_df['y_hat'] = gcca_df[stim_freqs].apply(lambda row: stim_freqs[np.argmax(row)], axis=1)

mset_cca_df = pd.concat([pd.DataFrame(result_set) for result_set in mset_cca_results.values()]).set_index('idx')
mset_cca_df['y_hat'] = mset_cca_df[stim_freqs].apply(lambda row: stim_freqs[np.argmax(row)], axis=1)

# first analysis of accuracy
for f in stim_freqs:
    tmp_df = gcca_df[gcca_df['y']==f]
    gcca_acc_f = 100*len(tmp_df[tmp_df['y'] == tmp_df['y_hat']])/len(tmp_df)
    
    tmp_df = mset_cca_df[gcca_df['y']==f]
    mset_cca_acc_f = 100*len(tmp_df[tmp_df['y'] == tmp_df['y_hat']])/len(tmp_df)
    
    print(f"Results for {f}Hz: GCCA {round(gcca_acc_f, 3)}%, MsetCCA {round(mset_cca_acc_f, 3)}%")

Results for 7Hz: GCCA 90.0%, MsetCCA 91.667%
Results for 10Hz: GCCA 88.333%, MsetCCA 98.333%
Results for 12Hz: GCCA 80.0%, MsetCCA 90.0%


In [157]:
mset_cca_df[mset_cca_df['y'] == 10]

Unnamed: 0_level_0,7,10,12,y,y_hat
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
f10_split1_test3,0.1674,0.3053,0.1574,10,10
f10_split1_test4,0.0547,0.1958,0.0781,10,10
f10_split1_test5,0.1375,0.3009,0.2361,10,10
f10_split1_test6,0.1466,0.2036,0.1905,10,10
f10_split2_test2,0.0695,0.5176,0.1979,10,10
f10_split2_test4,0.0706,0.3843,0.0691,10,10
f10_split2_test5,0.2158,0.2623,0.2563,10,10
f10_split2_test6,0.1439,0.2634,0.1705,10,10
f10_split3_test2,0.0656,0.4629,0.1989,10,10
f10_split3_test3,0.0932,0.3707,0.178,10,10
