In [None]:
import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.io

import utils_stream

%matplotlib widget

In [None]:
# ---- Experiment configuration ----
# Values are consumed by utils_stream; meanings marked "presumably" are inferred, not verified.
latent_dimensions = 5      # forwarded to utils_stream.STREAM
fs = 20                    # forwarded to further_split_and_shuffle with sub_trial_length — presumably sampling rate in Hz; TODO confirm
hparadata = [9, 8]         # EEG: both entries go to process_data_per_view, [0] also to STREAM
hparafeats = [1, 0]        # audio features: same roles as hparadata
evalpara = [1, 1]          # forwarded to STREAM unchanged
weightpara = [0.9, 0.9]    # forwarded to adaptive_supervised / recursive — presumably weighting factors; TODO confirm
pool_size = 19             # NOTE(review): not used in the cells of this notebook that are visible here
SEED = 1
PARATRANS = True
nb_trials = 24             # labels of each condition are truncated to nb_trials * UPDATE_STEP entries
nb_disconnected = 0        # number of randomly chosen EEG columns zeroed per condition (0 = none)
sub_trial_length = 30      # length of each sub-trial (presumably seconds); None keeps the original trials
SHUFFLE = False

# UPDATE_STEP = number of sub-trials per UPDATE_WINDOW_S; the plot divides sample indices by it
# and labels the axis in minutes, hence a 60 s window.
UPDATE_WINDOW_S = 60
# Guard: sub_trial_length > 60 made UPDATE_STEP 0 (empty label slices, division by zero in the plot),
# and 0 was treated as "no split" here but still passed to the splitter in the data cell.
if sub_trial_length is not None and not 0 < sub_trial_length <= UPDATE_WINDOW_S:
    raise ValueError(
        f"sub_trial_length must be None or in (0, {UPDATE_WINDOW_S}], got {sub_trial_length!r}"
    )
UPDATE_STEP = UPDATE_WINDOW_S // sub_trial_length if sub_trial_length is not None else 1

In [None]:
# Read one subject's .mat file and build, per condition, lists of preprocessed EEG trials,
# preprocessed audio-feature trials and attended-speaker labels.
Subj_ID = 1
data_path = '../../Experiments/data/Zink/dataSubjectOfficial{}.mat'.format(Subj_ID)
data = scipy.io.loadmat(data_path, squeeze_me=True)
conditions = data['condition']
# np.unique returns sorted values, so the dicts below are filled in sorted condition order.
unique_conditions = np.unique(conditions)
data_conditions_dict = {}
feats_conditions_dict = {}
labels_conditions_dict = {}
# Legacy RandomState kept on purpose: switching generators would change the channel draws.
rng = np.random.RandomState(SEED)
for cond in unique_conditions:
    cond_mask = conditions == cond
    data_trials = list(data['eegTrials'][cond_mask])
    if nb_disconnected > 0:
        # Simulate disconnected electrodes: zero the same randomly chosen columns (axis 1,
        # presumably channels) in every trial of this condition. Zero copies, not the arrays
        # stored in `data`, so the loaded file contents stay untouched.
        disconnected_channels = rng.choice(data_trials[0].shape[1], nb_disconnected, replace=False)
        data_trials = [np.copy(trial) for trial in data_trials]
        for trial in data_trials:
            trial[:, disconnected_channels] = 0
    data_trials = [
        utils_stream.process_data_per_view(trial, hparadata[0], hparadata[1], NORMALIZE=True)
        for trial in data_trials
    ]
    feats_trials = [
        utils_stream.process_data_per_view(feat, hparafeats[0], hparafeats[1], NORMALIZE=True)
        for feat in data['audioTrials'][cond_mask]
    ]
    labels_trials = list(data['attSpeaker'][cond_mask])
    if sub_trial_length is not None:
        # Cut every trial into sub-trials of sub_trial_length (with fs), optionally shuffled.
        data_trials, feats_trials, labels_trials = utils_stream.further_split_and_shuffle(
            data_trials, feats_trials, labels_trials, sub_trial_length, fs, SHUFFLE=SHUFFLE, SEED=SEED
        )
    data_conditions_dict[cond] = data_trials
    feats_conditions_dict[cond] = feats_trials
    labels_conditions_dict[cond] = labels_trials

In [None]:
# Build the streaming evaluator and collect the concatenated predictions of four methods.
# Condition order: calibration conditions first, then the test conditions. Stored once as tuples
# and turned into a fresh list for every call, exactly like the separate literals used before.
CALIB_CONDS = ('CS-1', 'CS-2')
TEST_CONDS = ('TS-1', 'TS-2', 'TS-3', 'TS-4', 'FUS-1', 'FUS-2')
CONDS_SORTED = CALIB_CONDS + TEST_CONDS

stream = utils_stream.STREAM(data_conditions_dict, feats_conditions_dict, hparadata[0], hparafeats[0], latent_dimensions, SEED, evalpara, nb_trials, UPDATE_STEP)

# Ground truth: the first nb_trials * UPDATE_STEP labels of each condition, concatenated in the
# insertion order of labels_conditions_dict (np.unique, i.e. sorted condition names).
# NOTE(review): sorted order places FUS-* before TS-*, unlike CONDS_SORTED. This is only aligned
# with the predictions if the dicts returned below use the same key order — verify in utils_stream.
true_labels = np.concatenate([v[:nb_trials*UPDATE_STEP] for v in labels_conditions_dict.values()])

# Fixed supervised decoder, given the calibration and test condition lists.
pred_labels_dict = stream.fixed_supervised(labels_conditions_dict, list(CALIB_CONDS), list(TEST_CONDS))
pred_labels_fixed = np.concatenate(list(pred_labels_dict.values()))

# Adaptive supervised decoder (receives the labels).
pred_labels_dict = stream.adaptive_supervised(labels_conditions_dict, weightpara, PARATRANS=PARATRANS)
pred_labels_adapsup = np.concatenate(list(pred_labels_dict.values()))

# Recursive variants (no labels passed): SINGLEENC=True -> "Single-Enc", False -> "Two-Enc".
pred_labels_dict = stream.recursive(weightpara, PARATRANS=PARATRANS, SINGLEENC=True, conds_sorted=list(CONDS_SORTED))
pred_labels_single = np.concatenate(list(pred_labels_dict.values()))

pred_labels_dict = stream.recursive(weightpara, PARATRANS=PARATRANS, SINGLEENC=False, conds_sorted=list(CONDS_SORTED))
pred_labels_two = np.concatenate(list(pred_labels_dict.values()))

In [None]:
# Smoothed accuracy of each method against the same ground truth. `nearby` is forwarded to
# calc_smooth_acc unchanged — presumably the smoothing window size; TODO confirm in utils_stream.
SMOOTH_NEARBY = 14

# (printed header, concatenated predictions), in the order the results are unpacked below.
method_preds = [
    ("###########Fixed Supervised###########", pred_labels_fixed),
    ("###########Adaptive Supervised###########", pred_labels_adapsup),
    ("###########Single-Encoder###########", pred_labels_single),
    ("###########Two-Encoder###########", pred_labels_two),
]
smoothed_results = []
for header, preds in method_preds:
    print(header)
    smoothed_results.append(
        utils_stream.calc_smooth_acc(preds, true_labels, nb_trials, UPDATE_STEP, nearby=SMOOTH_NEARBY)
    )

# Each result is (accuracy on non-calibration data, smoothed accuracy series).
(
    (acc_non_calib_fixed, acc_fixed),
    (acc_non_calib_adapsup, acc_adapsup),
    (acc_non_calib_single_enc, acc_single),
    (acc_non_calib, acc_two),
) = smoothed_results
# Name consistent with the other methods; `acc_non_calib` is kept because the plotting cell reads it.
acc_non_calib_two_enc = acc_non_calib

In [None]:
# Smoothed accuracy over time for the four methods; dashed lines mark condition boundaries.
plt.close('all')
fig, ax = plt.subplots(figsize=(10, 2))

# UPDATE_STEP samples per x-unit, so condition i ends at x = i * nb_trials. The axis is labelled
# in minutes, which holds when UPDATE_STEP sub-trials span one minute (see the config cell).
x_axis = np.arange(len(true_labels)) / UPDATE_STEP
ax.plot(x_axis, acc_two, label='Two-Enc', color='orange')
ax.plot(x_axis, acc_single, label='Single-Enc', color='blue')
ax.plot(x_axis, acc_adapsup, label='Adaptive Supervised', color='red')

# acc_fixed only covers the samples after the calibration conditions, so right-align it.
# The offset comes from the lengths rather than a hard-coded number of calibration conditions.
calib_len = len(x_axis) - len(acc_fixed)
assert calib_len >= 0, "acc_fixed is longer than the time axis"
ax.plot(x_axis[calib_len:], acc_fixed, label='Fixed Supervised', color='green')

# Show the non-calibration accuracies of the two recursive variants in the title.
ax.set_title(f"Subject {Subj_ID}, Single-Enc: {acc_non_calib_single_enc:.2f}, Two-Enc: {acc_non_calib:.2f}")
ax.set_xlabel('time (min)')
ax.set_ylabel('accuracy')

# Boundaries from the end of calibration up to the end of the last condition.
samples_per_cond = nb_trials * UPDATE_STEP
n_calib_conds = calib_len // samples_per_cond if samples_per_cond else 0
for i in range(n_calib_conds, len(labels_conditions_dict) + 1):
    ax.axvline(x=i * nb_trials, color='grey', linestyle='--')
ax.legend()
plt.show()