In [None]:
import os
import utils_stream
import numpy as np
import scipy
import matplotlib.pyplot as plt
from scipy.io import loadmat
%matplotlib widget

In [None]:
def data_per_subj(Subj_ID, data_folder):
    """Load the trials of one subject from ``dataSubject<Subj_ID>.mat``.

    Parameters
    ----------
    Subj_ID : int or str
        Identifier inserted into the file name ``dataSubject{Subj_ID}.mat``.
    data_folder : str
        Folder holding the ``.mat`` files. When the string contains
        ``'earEEG'`` only the first 29 columns of every EEG trial are kept
        (presumably the channel axis -- TODO confirm).

    Returns
    -------
    tuple of (list, list, list)
        ``(eeg_trials, feats_trials, label_trials)`` read from the keys
        ``'eegTrials'``, ``'audioTrials'`` and ``'attSpeaker'``; each list has
        one entry per EEG trial.
    """
    mat_content = loadmat(f'{data_folder}/dataSubject{Subj_ID}.mat', squeeze_me=True)

    raw_eeg = mat_content['eegTrials']
    if 'earEEG' in data_folder:
        # Drop every column beyond the 29th for the ear-EEG recordings.
        raw_eeg = [trial[:,:29] for trial in raw_eeg]
    raw_feats = mat_content['audioTrials']
    raw_labels = mat_content['attSpeaker']

    # The number of EEG trials drives the length of all three outputs.
    trial_ids = range(len(raw_eeg))
    eeg_trials = [raw_eeg[k] for k in trial_ids]
    feats_trials = [raw_feats[k] for k in trial_ids]
    label_trials = [raw_labels[k] for k in trial_ids]
    return eeg_trials, feats_trials, label_trials

In [None]:
# Alternative (inactive) configuration for the 'Neetha' dataset.
# latent_dimensions = 5
# fs = 20
# hparadata = [4, 3]
# hparafeats = [6, 0]
# evalpara = [3, 2]
# weightpara = [0.9, 0.9]
# pool_size = 19
# SEED = 2
# PARATRANS = True
# ori_trial_len = 60
# trial_len = 60
# nb_calibsessions = 1
# nb_disconnected = 0
# sub_trial_length = 15
# SHUFFLE = True
# UPDATE_STEP = trial_len // sub_trial_length if sub_trial_length else 1
# dataset = 'Neetha'

In [None]:
# Alternative (inactive) configuration for the 'fuglsang2018' dataset.
# latent_dimensions = 5
# fs = 32
# hparadata = [4, 3]
# hparafeats = [9, 0]
# evalpara = [3, 2]
# weightpara = [0.916, 0.916]
# pool_size = 23
# SEED = 2
# PARATRANS = True
# ori_trial_len = 50
# trial_len = 50
# nb_calibsessions = 1
# nb_disconnected = 0
# sub_trial_length = 25
# SHUFFLE = False
# UPDATE_STEP = trial_len // sub_trial_length if sub_trial_length else 1
# dataset = 'fuglsang2018'

In [None]:
# ---- Active configuration (earEEG dataset) ----
dataset = 'earEEG'
SEED = 2

# Model / signal settings.
fs = 20  # presumably the sampling rate in Hz -- TODO confirm
latent_dimensions = 5
hparadata = [4, 3]    # EEG view: both entries go to utils_stream.process_data_per_view
hparafeats = [6, 0]   # audio-feature view: same usage as hparadata
evalpara = [3, 2]
weightpara = [0.9, 0.9]
pool_size = 19
PARATRANS = True

# Trial segmentation.
ori_trial_len = 600   # trial length as stored on disk; re-split when it differs from trial_len
trial_len = 60
sub_trial_length = 30
SHUFFLE = False
nb_calibsessions = 1
nb_disconnected = 0   # number of EEG columns zeroed per subject (0 disables)

# Number of sub-trials per trial; a falsy sub_trial_length means no sub-splitting.
if sub_trial_length:
    UPDATE_STEP = trial_len // sub_trial_length
else:
    UPDATE_STEP = 1

In [None]:
# Discover the subject files of the selected dataset and cache every
# subject's trials in dictionaries keyed by subject ID.
data_folder = f'../../Experiments/Data/{dataset}/'
files = [f for f in os.listdir(data_folder) if f.endswith('.mat')]
# Subject ID = all digits of the file name concatenated, in ascending order.
subjects = sorted(int(''.join(ch for ch in f if ch.isdigit())) for f in files)

eeg_dict, feats_dict, labels_dict = {}, {}, {}
needs_resplit = trial_len != ori_trial_len
for subject in subjects:
    eeg_trials, feats_trials, label_trials = data_per_subj(subject, data_folder)
    if needs_resplit:
        # Re-cut the stored trials to the requested trial_len.
        # NOTE(review): SHUFFLE/SEED are not passed here, so the helper's
        # defaults apply -- confirm that is intended.
        eeg_trials, feats_trials, label_trials = utils_stream.further_split_and_shuffle(
            eeg_trials, feats_trials, label_trials, trial_len, fs
        )
    eeg_dict[subject], feats_dict[subject], labels_dict[subject] = eeg_trials, feats_trials, label_trials
# Hard-coded number of trials used per subject (not derived from the data).
nb_trials = 24 # len(eeg_dict[subjects[0]])

In [None]:
# Per-subject preprocessing: optionally zero a random set of EEG columns,
# run both views through utils_stream.process_data_per_view, then optionally
# cut every trial into sub-trials of sub_trial_length.
data_subjects_dict = {}
feats_subjects_dict = {}
labels_subjects_dict = {}
rng = np.random.RandomState(SEED)
for subj in subjects:
    data_trials = eeg_dict[subj]
    if nb_disconnected > 0:
        disconnected_channels = rng.choice(data_trials[0].shape[1], nb_disconnected, replace=False)
        # Zero the columns on copies: writing into the arrays held by eeg_dict
        # would corrupt the cached raw data, so re-running this cell with a
        # different nb_disconnected/SEED would accumulate zeroed columns.
        data_trials = [trial.copy() for trial in data_trials]
        for trial in data_trials:
            trial[:, disconnected_channels] = 0
    data_trials = [utils_stream.process_data_per_view(d, hparadata[0], hparadata[1], NORMALIZE=True) for d in data_trials]
    feats_trials = feats_dict[subj]
    feats_trials = [utils_stream.process_data_per_view(f, hparafeats[0], hparafeats[1], NORMALIZE=True) for f in feats_trials]
    labels_trials = labels_dict[subj]
    # Same truthiness test as the one defining UPDATE_STEP: a falsy
    # sub_trial_length (None or 0) means "no sub-splitting".
    if sub_trial_length:
        data_trials, feats_trials, labels_trials = utils_stream.further_split_and_shuffle(data_trials, feats_trials, labels_trials, sub_trial_length, fs, SHUFFLE=SHUFFLE, SEED=SEED)
    data_subjects_dict[subj] = data_trials
    feats_subjects_dict[subj] = feats_trials
    labels_subjects_dict[subj] = labels_trials

In [None]:
# Build the streaming pipeline and collect the predictions of each strategy.
# The calls are kept in their original order in case `stream` carries state.
stream = utils_stream.STREAM(
    data_subjects_dict, feats_subjects_dict,
    hparadata[0], hparafeats[0], latent_dimensions,
    SEED, evalpara, nb_trials, UPDATE_STEP,
)
# Ground truth: the first nb_trials*UPDATE_STEP labels of every subject, concatenated.
labels_per_subject = nb_trials*UPDATE_STEP
true_labels = np.concatenate([subj_labels[:labels_per_subject] for subj_labels in labels_subjects_dict.values()])

# Fixed supervised: first nb_calibsessions subjects vs. the remaining ones.
pred_labels_dict = stream.fixed_supervised(
    labels_subjects_dict, subjects[:nb_calibsessions], subjects[nb_calibsessions:]
)
pred_labels_fixed = np.concatenate(list(pred_labels_dict.values()))

# Adaptive supervised.
pred_labels_dict = stream.adaptive_supervised(labels_subjects_dict, weightpara, PARATRANS=PARATRANS)
pred_labels_adapsup = np.concatenate(list(pred_labels_dict.values()))

# Recursive, single-encoder variant.
pred_labels_dict = stream.recursive(weightpara, PARATRANS=PARATRANS, SINGLEENC=True)
pred_labels_single = np.concatenate(list(pred_labels_dict.values()))

# Recursive, two-encoder variant.
pred_labels_dict = stream.recursive(weightpara, PARATRANS=PARATRANS, SINGLEENC=False)
pred_labels_two = np.concatenate(list(pred_labels_dict.values()))

In [None]:
def _report_smooth_acc(banner, pred_labels):
    """Print `banner`, then return utils_stream.calc_smooth_acc for `pred_labels`.

    All other arguments (true_labels, nb_trials, UPDATE_STEP, nearby=14,
    nb_calibsessions) are shared by the four strategies.
    """
    print(banner)
    return utils_stream.calc_smooth_acc(
        pred_labels, true_labels, nb_trials, UPDATE_STEP,
        nearby=14, nb_calibsessions=nb_calibsessions,
    )


acc_non_calib_fixed, acc_fixed = _report_smooth_acc("###########Fixed Supervised###########", pred_labels_fixed)
acc_non_calib_adapsup, acc_adapsup = _report_smooth_acc("###########Adaptive Supervised###########", pred_labels_adapsup)
acc_non_calib_single_enc, acc_single = _report_smooth_acc("###########Single-Encoder###########", pred_labels_single)
acc_non_calib, acc_two = _report_smooth_acc("###########Two-Encoder###########", pred_labels_two)

In [None]:
# Plot the smoothed accuracy curves of the two recursive variants.
plt.close('all')
fig, ax = plt.subplots(figsize=(10, 2))
# One x unit corresponds to UPDATE_STEP consecutive predictions.
x_axis = np.arange(len(true_labels))/UPDATE_STEP
for curve, curve_label, curve_color in (
    (acc_two, 'Two-Enc', 'orange'),
    (acc_single, 'Single-Enc', 'blue'),
):
    ax.plot(x_axis, curve, label=curve_label, color=curve_color)
# plt.plot(x_axis, acc_adapsup, label='Adaptive Supervised', color='red')
# plt.plot(x_axis[nb_calibsessions*nb_trials*UPDATE_STEP:], acc_fixed, label='Fixed Supervised', color='green')
# mark recur_acc_se and recur_acc_te in title
# plt.title(f"Single-Enc: {acc_non_calib_single_enc:.2f}, Two-Enc: {acc_non_calib:.2f}")
ax.set_xlabel('time (min)')
ax.set_ylabel('accuracy')
# Dashed line at the start of every subject after the first nb_calibsessions.
for boundary in range(nb_calibsessions, len(subjects)):
    ax.axvline(x=boundary*nb_trials, color='grey', linestyle='--')
ax.legend()
plt.show()