In [None]:
import mne
import scipy
import numpy as np
import matplotlib.pyplot as plt
import utils
import os
import glob
import copy
from findpeaks import findpeaks
from numpy import linalg as LA
from scipy.stats import zscore, pearsonr
from scipy.io import savemat, loadmat
from scipy import signal
%matplotlib widget

In [None]:
def clean_features(feats, smooth=True, objflow=True):
    """Return a cleaned copy of a (time x feature) matrix.

    The input is deep-copied first, so ``feats`` itself is never modified.
    Processing steps:

    1. Layout fix-up.
       * ``objflow=True``: columns 0-7 (a normalised histogram) are multiplied
         by column 8 to recover the unnormalised histogram. Requires at least
         nine columns.
       * ``objflow=False``: the last two columns (coordinates of the centre)
         are discarded.
    2. NaN samples of every column are filled by linear interpolation over the
       sample index, extrapolating linearly at the edges.
    3. If ``smooth`` is true, each column is replaced by its upper envelope:
       the local maxima found by ``scipy.signal.find_peaks`` are kept and all
       other samples are re-drawn from a cubic interpolant through those
       maxima. When ``objflow=True`` the last two columns (centre coordinates)
       are not smoothed. Columns with fewer than four peaks are left as they
       are, because a cubic interpolant needs at least four knots.

    Parameters
    ----------
    feats : numpy.ndarray
        2-D array, one row per time sample and one column per feature.
    smooth : bool, optional
        Apply the peak-envelope smoothing of step 3.
    objflow : bool, optional
        Selects the layout fix-up of step 1.

    Returns
    -------
    numpy.ndarray
        Cleaned array with the same number of rows as ``feats``; two columns
        fewer when ``objflow`` is false.

    Raises
    ------
    ValueError
        If a column that contains NaNs has fewer than two non-NaN samples, so
        that no linear interpolant can be built.
    """
    y = copy.deepcopy(feats)
    if not objflow:
        # discard the coordinates of the center if it is not object detection based optical flow
        y = y[:, :-2]
        # the center columns are gone, so every remaining column is a feature to smooth
        nb_unsmoothed = 0
    else:
        # recover unnormalized histogram
        # TODO: remove in the future version
        y[:, 0:8] = y[:, 0:8] * np.expand_dims(y[:, 8], axis=1)
        # the last two columns hold the center coordinates and must not be smoothed
        nb_unsmoothed = 2
    T, nb_feature = y.shape
    sample_idx = np.arange(T)
    for i in range(nb_feature):
        col = y[:, i]  # view into y: writes below modify y in place
        # interpolate NaN values (linearly)
        nans = np.isnan(col)
        if nans.any():
            valid = ~nans
            if np.count_nonzero(valid) < 2:
                raise ValueError(
                    'column %d has fewer than two non-NaN samples; cannot interpolate' % i
                )
            f1 = scipy.interpolate.interp1d(sample_idx[valid], col[valid], fill_value='extrapolate')
            col[nans] = f1(sample_idx[nans])
        if smooth and i < (nb_feature - nb_unsmoothed):
            # extract envelope by finding peaks and interpolating peaks with spline
            idx_peaks = scipy.signal.find_peaks(col)[0]
            if idx_peaks.size < 4:
                # not enough knots for a cubic interpolant: keep the column unsmoothed
                continue
            idx_rest = np.setdiff1d(sample_idx, idx_peaks)
            f2 = scipy.interpolate.interp1d(idx_peaks, col[idx_peaks], kind='cubic', fill_value='extrapolate')
            col[idx_rest] = f2(idx_rest)
    return y

## Data Loading

In [None]:
def multisub_data_org(subjects, video_id, folder, fsStim, bads=None, eog=False, regression=False, normalize=True, smooth=True, objflow=True):
    """Load the stimulus features and the EEG of several subjects for one video.

    The feature file of ``video_id`` is loaded and cleaned with
    ``clean_features``. For every subject the matching ``*_eeg.set`` recording
    is passed through ``utils.preprocessing`` and its EEG channels are
    extracted. All signals are then truncated to the shortest length among the
    features and the recordings, and the first ``2*fsStim`` samples are dropped
    (two seconds if the data are sampled at ``fsStim`` Hz — TODO confirm).

    Parameters
    ----------
    subjects : iterable of str
        Subject folder names under ``../../Experiments/data/``. Must not be empty.
    video_id : str
        Prefix of the feature and EEG file names.
    folder : str
        Sub-folder of each subject directory that holds the recordings.
    fsStim : int
        Forwarded to ``utils.preprocessing`` as ``resamp_freqs``; also sets the
        number of leading samples discarded (``2*fsStim``).
    bads : list or None, optional
        Forwarded to ``utils.preprocessing`` as ``bads``. ``None`` (default)
        means an empty list.
    eog, regression, normalize : bool, optional
        Forwarded unchanged to ``utils.preprocessing``.
    smooth, objflow : bool, optional
        Forwarded to ``clean_features``; ``objflow`` also selects the
        ``_feats.npy`` (True) or ``_flow.npy`` (False) feature file.

    Returns
    -------
    feats : numpy.ndarray
        Clipped feature matrix, time along axis 0.
    eeg_multisub : numpy.ndarray
        Clipped EEG stacked as (time, channel, subject).
    fs : object
        Second value returned by ``utils.preprocessing`` for the last subject
        (presumably the sampling rate after resampling — TODO confirm).
    times : numpy.ndarray
        ``sample_index / fs`` for the retained samples, i.e. clipped exactly
        like ``feats`` and ``eeg_multisub`` so that the lengths match.

    Raises
    ------
    ValueError
        If ``subjects`` is empty.
    """
    subjects = list(subjects)
    if not subjects:
        raise ValueError('subjects must contain at least one subject identifier')
    # A fresh list per call instead of one mutable default shared by all calls.
    bads = [] if bads is None else bads

    feats_path_folder = '../Feature extraction/features/'
    feats_suffix = '_feats.npy' if objflow else '_flow.npy'
    feats = np.load(feats_path_folder + video_id + feats_suffix)
    feats = clean_features(feats, smooth=smooth, objflow=objflow)

    # T tracks the shortest length among the features and all recordings.
    T = feats.shape[0]
    eeg_list = []
    for sub in subjects:
        eeg_path = '../../Experiments/data/' + sub + '/' + folder + '/' + video_id + '_eeg.set'
        eeg_prepro, fs = utils.preprocessing(eeg_path, HP_cutoff=0.5, AC_freqs=50, resamp_freqs=fsStim, bads=bads, eog=eog, regression=regression, normalize=normalize)
        eeg_channel_indices = mne.pick_types(eeg_prepro.info, eeg=True)
        eeg_downsampled, _ = eeg_prepro[eeg_channel_indices]
        # transpose so that time runs along axis 0, like the features
        eeg_downsampled = eeg_downsampled.T
        eeg_list.append(eeg_downsampled)
        T = min(T, eeg_downsampled.shape[0])

    # Clip data: common end point T, first 2*fsStim samples discarded
    start = 2 * fsStim
    feats = feats[start:T, :]
    eeg_multisub = np.stack([eeg[start:T, :] for eeg in eeg_list], axis=2)
    # Clip the time axis like the data; it previously kept all T entries
    # and was therefore longer than the returned arrays.
    times = (np.arange(T) / fs)[start:]
    return feats, eeg_multisub, fs, times

In [None]:
%%capture
subjects = ['AS', 'YY']
folder = 'Single_obj'
# The list of videos is discovered from the first subject's recordings; the
# path is derived from `subjects` and `folder` so the three cannot drift apart.
eeg_path_folder = "../../Experiments/data/" + subjects[0] + "/" + folder + "/"
# os.listdir returns entries in arbitrary order. Sort (and de-duplicate) the
# two-character video ids so the concatenation order, and therefore the
# cross-validation folds computed later, are reproducible between runs.
video_ids = sorted({dataset[0:2] for dataset in os.listdir(eeg_path_folder) if dataset.endswith('.set')})
# To analyse a subset only, override the list here, e.g.
# video_ids = ['01', '02', '04', '05', '06', '08', '09', '14', '16']
if not video_ids:
    raise FileNotFoundError('no .set recordings found in ' + eeg_path_folder)
features_list = []
eeg_multisub_list = []
for video_id in video_ids:
    features, eeg_multisub, fs, _ = multisub_data_org(subjects, video_id, folder, fsStim=30, bads=['B25'], eog=True, regression=True, normalize=True, smooth=True, objflow=True)
    # Optional per-video normalisation of column 8:
    # features[:,8] = features[:,8]/LA.norm(features[:,8])
    features_list.append(features)
    eeg_multisub_list.append(eeg_multisub)
# Concatenate all videos along the time axis.
feature_concat = np.concatenate(features_list, axis=0)
# Optional global normalisation:
# feature_concat = feature_concat/LA.norm(feature_concat)
eeg_multisub_concat = np.concatenate(eeg_multisub_list, axis=0)
T = feature_concat.shape[0]
# Time axis of the concatenated data; fs is the value returned for the last video.
times = np.arange(T)/fs

In [None]:
# Split the concatenated feature matrix into named column groups.
# Column layout used below:
#   0-7   -> hist
#   8-12  -> mag_avg, mag_up, mag_down, mag_left, mag_right (also grouped as mag_all)
#   13-14 -> center
hist = feature_concat[:, :8]
# Each single-column feature is kept 2-D with shape (T, 1).
mag_avg, mag_up, mag_down, mag_left, mag_right = (
    feature_concat[:, col, np.newaxis] for col in range(8, 13)
)
mag_all = feature_concat[:, 8:13]
center = feature_concat[:, 13:15]

In [None]:
# Quick look at the mag_avg feature (column 8) over the concatenated videos.
plt.close()  # close the current figure before drawing a new one
fig, ax = plt.subplots()
ax.plot(times, mag_avg)
ax.set_xlabel('Time (s)')  # assumes fs is expressed in Hz — TODO confirm
ax.set_ylabel('mag_avg')
ax.set_title('mag_avg over concatenated videos')
plt.show()

## CCA

In [None]:
# CCA between the EEG of one subject (index 0) and the mag_avg feature.
# features: non-causal temporal filter (causaly=False); L_feat=fs+1 is
# presumably the filter length in samples — TODO confirm against utils.
n_components = 5
fold = 10
eeg_onesub = eeg_multisub_concat[:,:,0]
# Pass the settings declared above instead of repeating the literals, so that
# editing `n_components` / `fold` actually changes the analysis.
corr_train, corr_test, V_A_train, V_B_train = utils.cross_val_CCA(
    eeg_onesub, mag_avg, fs,
    L_EEG=1, L_feat=fs+1, causalx=False, causaly=False,
    fold=fold, n_components=n_components,
    regularization='lwcov', K_regu=None, message=True, signifi_level=True,
)


In [None]:
# GCCA with two "views": the EEG of subject 0 and the mag_avg feature.
# Note: GCCA-one subject + stimulus = CCA
# features: non-causal temporal filter — both entries of causal_list are False,
# matching the CCA cell.
# NOTE(review): this cell was previously labelled "causal temporal filter";
# set causal_list[1] = True if a causal filter was actually intended.
datalist = [eeg_multisub_concat[:,:,0], mag_avg]
Llist = [1, fs+1]
causal_list = [False, False]
rhos = [1, 1]
corr_train, corr_test, Wlist_train, Flist_train = utils.cross_val_GCCA_multi_mod(
    datalist, Llist, causal_list, rhos, fs,
    fold=10, n_components=5, regularization='lwcov',
    message=True, signifi_level=True, ISC=True,
)

In [None]:
# Visualization: scalp topographies of the forward model of the CCA spatial filters.
# V_A_train was fitted on the EEG of subject index 0 (see the CCA cell), so the
# forward model is computed from that same subject's EEG. It previously used
# subject index 1, which pairs the filters with data they were not fitted on.
forward_model = utils.forward_model(eeg_multisub_concat[:,:,0], V_A_train)
biosemi_layout = mne.channels.read_layout('biosemi')
# Use the sampling rate of the loaded data rather than a hard-coded 30.
create_info = mne.create_info(biosemi_layout.names, ch_types='eeg', sfreq=fs)
create_info.set_montage('biosemi64')
plt.close()
fig = plt.figure()
# fig = plt.figure(figsize=(20, 20))
# One topography per canonical component (5 components on a 2 x 3 grid).
for i in range(5):
    ax = fig.add_subplot(2, 3, i + 1)
    # show=False: draw into `ax` only; the figure is shown once after the loop.
    mne.viz.plot_topomap(forward_model[:,i], create_info, ch_type='eeg', axes=ax, show=False)
    ax.set_title('CC '+str(i+1))
plt.show()

## GCCA

In [None]:
# GCCA-all subjects
# Single entry in datalist: the full EEG array, with subjects along its last axis.
datalist = [eeg_multisub_concat]
Llist = [1]
causal_list = [False]
rhos = [1]
corr_train, corr_test, Wlist_train, Flist_train = utils.cross_val_GCCA_multi_mod(
    datalist,
    Llist,
    causal_list,
    rhos,
    fs,
    fold=10,
    n_components=5,
    regularization='lwcov',
    message=True,
    signifi_level=True,
    ISC=True,
)

In [None]:
# GCCA-all subjects + stimulus
# Two entries in datalist: the multi-subject EEG array and the mag_avg feature.
datalist = [eeg_multisub_concat, mag_avg]
Llist = [1, fs+1]
causal_list = [False, False]
# rhos come from utils.rho_sweep, called with 11 evenly spaced values between -2 and 3.
rhos = utils.rho_sweep(
    datalist,
    np.linspace(-2,3,11),
    Llist,
    causal_list,
    fs,
    fold=10,
    n_components=5,
    message=True,
)
corr_train, corr_test, Wlist_train, Flist_train = utils.cross_val_GCCA_multi_mod(
    datalist,
    Llist,
    causal_list,
    rhos,
    fs,
    fold=10,
    n_components=5,
    regularization='lwcov',
    message=True,
    signifi_level=True,
    ISC=False,
)