In [None]:
import os
import utils_stream
import utils_prob
import numpy as np
import scipy
import matplotlib.pyplot as plt
from scipy.io import loadmat
%matplotlib widget

In [None]:
def data_per_subj(Subj_ID, data_folder):
    """Load the trials of one subject from ``dataSubject<Subj_ID>.mat``.

    The file is read with ``loadmat(..., squeeze_me=True)`` and must contain
    the keys ``'eegTrials'``, ``'audioTrials'`` and ``'attSpeaker'``.

    Parameters
    ----------
    Subj_ID
        Identifier interpolated into the file name.
    data_folder : str
        Folder holding the ``.mat`` files. When the string contains
        ``'earEEG'``, only the first 29 columns of every EEG trial are kept.

    Returns
    -------
    tuple of (list, list, list)
        EEG trials, audio-feature trials and attended-speaker labels, each
        with one entry per EEG trial.
    """
    mat = loadmat(f'{data_folder}/dataSubject{Subj_ID}.mat', squeeze_me=True)

    eeg = mat['eegTrials']
    if 'earEEG' in data_folder:
        # Keep only the first 29 columns of each trial for this dataset.
        eeg = [trial[:, :29] for trial in eeg]
    feats = mat['audioTrials']
    labels = mat['attSpeaker']

    # The number of EEG trials drives the length of all three outputs.
    n_trials = len(eeg)

    def first_n(seq):
        # Index one by one so array containers become plain Python lists.
        return [seq[idx] for idx in range(n_trials)]

    return first_n(eeg), first_n(feats), first_n(labels)

In [None]:
# Active experiment configuration. The values below are read by the later
# cells of this notebook; the meanings noted as "presumably" are inferred
# from how the names are used and should be confirmed against utils_stream /
# utils_prob.

# --- Dataset ---------------------------------------------------------------
dataset = 'Neetha'        # sub-folder of ../../Experiments/Data/ to load
fs = 20                   # presumably the sampling rate in Hz -- confirm
trial_len = 60            # presumably the trial duration in seconds -- confirm
leave_out_persubj = 12    # forwarded to utils_prob.estimate_distribution_corr

# --- Model hyper-parameters ------------------------------------------------
latent_dimensions = 5
# Two-element settings; elements [0] and [1] are passed to
# utils_stream.process_data_per_view for the EEG and feature views.
hparadata, hparafeats = [4, 3], [6, 0]
# evalpara[0] -> range_into_account, evalpara[1] -> nb_comp_into_account.
evalpara = [3, 2]
weightpara = [0.9] * 2
pool_size = 19

# --- Experiment switches ---------------------------------------------------
SEED = 10
PARATRANS = True
CHANGEORDER = SHUFFLE = False
nb_calibsessions = 1
nb_disconnected = 0       # number of EEG channels zeroed per subject

# --- Update schedule -------------------------------------------------------
sub_trial_length = 30
UPDATE_STEP = 2
update_len = UPDATE_STEP * sub_trial_length
data_len_persubj = trial_len * 72  # len(eeg_dict[subjects[0]]) * trial_len
nb_update_trials = data_len_persubj // update_len

In [None]:
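# Alternative configuration for the 'fuglsang2018' dataset (inactive; uncomment it in place of the active configuration cell to use it).
# latent_dimensions = 5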
# fs = 32
# hparadata = [4, 3]
# hparafeats = [9, 0]
# evalpara = [3, 2]
# weightpara = [0.916, 0.916]
# pool_size = 23
# SEED = 1
# PARATRANS = True
# trial_len = 50
# nb_calibsessions = 1
# nb_disconnected = 0
# sub_trial_length = 25
# CHANGEORDER = False
# SHUFFLE = False
# UPDATE_STEP = 2
# update_len = sub_trial_length * UPDATE_STEP
# leave_out_persubj = 12
# data_len_persubj = 60 * trial_len 
# nb_update_trials = data_len_persubj // update_len
# dataset = 'fuglsang2018'

In [None]:
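# Alternative configuration for the 'earEEG' dataset (inactive; uncomment it in place of the active configuration cell to use it).
# latent_dimensions = 5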
# fs = 20
# hparadata = [4, 3]
# hparafeats = [6, 0]
# evalpara = [3, 2]
# weightpara = [0.9, 0.9]
# pool_size = 19
# SEED = 1
# PARATRANS = True
# trial_len = 600
# nb_calibsessions = 1
# nb_disconnected = 0
# sub_trial_length = 30
# CHANGEORDER = False
# SHUFFLE = False
# UPDATE_STEP = 2
# update_len = sub_trial_length * UPDATE_STEP
# leave_out_persubj = 2
# data_len_persubj = 6 * trial_len # len(eeg_dict[subjects[0]]) * trial_len
# nb_update_trials = data_len_persubj // update_len
# dataset = 'earEEG'

In [None]:
# Load every subject found in the dataset folder.
data_folder = f'../../Experiments/Data/{dataset}/'
files = [f for f in os.listdir(data_folder) if f.endswith('.mat')]

# The subject ID is made of the digits contained in each file name.
# os.listdir returns entries in arbitrary order, so the IDs are sorted before
# the seeded shuffle; otherwise the shuffled order (and everything derived
# from it) would depend on the file system instead of only on SEED.
subjects = sorted(int(''.join(filter(str.isdigit, f))) for f in files)
rng = np.random.default_rng(SEED)
rng.shuffle(subjects)

# Per-subject containers, keyed by subject ID in shuffled order.
eeg_dict = {}
feats_dict = {}
labels_dict = {}
for subject in subjects:
    eeg_trials, feats_trials, label_trials = data_per_subj(subject, data_folder)
    eeg_dict[subject] = eeg_trials
    feats_dict[subject] = feats_trials
    labels_dict[subject] = label_trials

In [None]:
# Number of subjects set aside to estimate the correlation distributions.
nb_est_subjects = 2

# Split on the order in which subjects were loaded (eeg_dict preserves
# insertion order) instead of slicing `subjects` in terms of itself, so that
# re-running this cell does not discard further subjects each time.
loaded_subjects = list(eeg_dict)
est_subs = loaded_subjects[:nb_est_subjects]

est_corr_att_sum, est_corr_unatt_sum = utils_prob.estimate_distribution_corr(eeg_dict, feats_dict, labels_dict, est_subs, fs, hparadata, hparafeats, leave_out_persubj=leave_out_persubj, trial_len=update_len, range_into_account=evalpara[0], nb_comp_into_account=evalpara[1])
gmm_0, gmm_1 = utils_prob.fit_gmm(est_corr_att_sum, est_corr_unatt_sum, n_components_per_class=1)

# The remaining subjects are the ones used by the cells below.
subjects = loaded_subjects[nb_est_subjects:]

In [None]:
# Visualise the two fitted models on the estimated correlation features.
# Each sample is used twice: once as (att, unatt) and once as (unatt, att).
est_corr_att_unatt = np.stack([est_corr_att_sum, est_corr_unatt_sum], axis=1)
est_corr_unatt_att = np.stack([est_corr_unatt_sum, est_corr_att_sum], axis=1)

# Grid covering both point clouds with a margin of 0.1 on every side.
x_min = min(est_corr_unatt_att[:,0].min(), est_corr_att_unatt[:,0].min()) - 0.1
x_max = max(est_corr_unatt_att[:,0].max(), est_corr_att_unatt[:,0].max()) + 0.1
y_min = min(est_corr_unatt_att[:,1].min(), est_corr_att_unatt[:,1].min()) - 0.1
y_max = max(est_corr_unatt_att[:,1].max(), est_corr_att_unatt[:,1].max()) + 0.1

xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
                     np.linspace(y_min, y_max, 100))
grid_points = np.c_[xx.ravel(), yy.ravel()]

# Probabilities on the grid; column 1 is plotted as the class-1 probability.
grid_probs = utils_prob.predict_proba(grid_points, gmm_0, gmm_1)
blue_probs = grid_probs[:, 1].reshape(xx.shape)

plt.figure(figsize=(10, 8))

# Data points of both orderings.
plt.scatter(est_corr_unatt_att[:, 0], est_corr_unatt_att[:, 1], color='red', alpha=0.5, s=10, label='Class 0')
plt.scatter(est_corr_att_unatt[:, 0], est_corr_att_unatt[:, 1], color='blue', alpha=0.5, s=10, label='Class 1')

# Decision boundary (where the class-1 probability equals 0.5).
plt.contour(xx, yy, blue_probs, levels=[0.5], colors='black', linestyles='-', linewidths=2)

# Filled probability contours.
contour = plt.contourf(xx, yy, blue_probs, levels=np.linspace(0, 1, 101), 
                      alpha=0.3)
plt.colorbar(contour, label='Probability of Class 1')

# Make the figure self-explanatory: the scatter labels are only shown once a
# legend is drawn, and the axes were previously unlabelled.
plt.xlabel('first correlation feature')
plt.ylabel('second correlation feature')
plt.title('Class-1 probability and decision boundary')
plt.legend()
plt.show()


In [None]:
# Build the per-subject inputs for the streaming experiment.
data_subjects_dict = {}
feats_subjects_dict = {}
labels_subjects_dict = {}
rng = np.random.RandomState(SEED)
for subj in subjects:
    data_trials = eeg_dict[subj]
    if nb_disconnected > 0:
        # The same randomly chosen channels are zeroed in every trial of the subject.
        disconnected_channels = rng.choice(data_trials[0].shape[1], nb_disconnected, replace=False)
        # Work on copies: writing into the arrays held by eeg_dict would
        # corrupt the loaded recordings and make re-runs of this cell
        # accumulate zeroed channels.
        data_trials = [trial.copy() for trial in data_trials]
        for trial in data_trials:
            trial[:, disconnected_channels] = 0
    data_trials = [utils_stream.process_data_per_view(d, hparadata[0], hparadata[1], NORMALIZE=True) for d in data_trials]
    feats_trials = feats_dict[subj]
    feats_trials = [utils_stream.process_data_per_view(f, hparafeats[0], hparafeats[1], NORMALIZE=True) for f in feats_trials]
    labels_trials = labels_dict[subj]
    if sub_trial_length is not None:
        # Cut the trials into shorter sub-trials (optionally shuffled).
        data_trials, feats_trials, labels_trials = utils_stream.further_split_and_shuffle(data_trials, feats_trials, labels_trials, sub_trial_length, fs, SHUFFLE=SHUFFLE, SEED=SEED)
    if CHANGEORDER:
        data_trials, feats_trials, labels_trials = utils_stream.change_order_of_trials(data_trials, feats_trials, labels_trials)
    data_subjects_dict[subj] = data_trials
    feats_subjects_dict[subj] = feats_trials
    labels_subjects_dict[subj] = labels_trials

In [None]:
def _stack_per_subject(per_subject):
    """Stack the values of a per-subject dict along a new first axis (dict order)."""
    return np.stack(list(per_subject.values()), axis=0)


stream = utils_stream.STREAM(data_subjects_dict, feats_subjects_dict, hparadata[0], hparafeats[0], latent_dimensions, SEED, evalpara, nb_update_trials, UPDATE_STEP)

# Ground truth: the first nb_update_trials*UPDATE_STEP labels of every subject.
nb_labels_kept = nb_update_trials * UPDATE_STEP
true_labels = _stack_per_subject({subj: labels[:nb_labels_kept] for subj, labels in labels_subjects_dict.items()})

# The methods are run in a fixed order on the same STREAM object; each result
# dict is stacked into one array with a row per subject.
pred_labels_dict = stream.fixed_supervised(labels_subjects_dict, subjects[:nb_calibsessions], subjects[nb_calibsessions:])
pred_labels_fixed = _stack_per_subject(pred_labels_dict)

pred_labels_dict = stream.adaptive_supervised(labels_subjects_dict, weightpara, PARATRANS=PARATRANS)
pred_labels_adapsup = _stack_per_subject(pred_labels_dict)

pred_labels_dict = stream.recursive(weightpara, PARATRANS=PARATRANS, SINGLEENC=True)
pred_labels_single = _stack_per_subject(pred_labels_dict)

pred_labels_dict = stream.recursive(weightpara, PARATRANS=PARATRANS, SINGLEENC=False)
pred_labels_two = _stack_per_subject(pred_labels_dict)

pred_labels_dict = stream.recursive_soft(weightpara, gmm_0, gmm_1, PARATRANS=PARATRANS)
pred_labels_soft = _stack_per_subject(pred_labels_dict)

In [None]:
# Only the first 15 * UPDATE_STEP entries of every subject are evaluated.
nb_trials_considered = 15 * UPDATE_STEP


def _report_smooth_acc(banner, pred_labels):
    """Print the banner, then return utils_stream.calc_smooth_acc on the truncated labels."""
    print(banner)
    return utils_stream.calc_smooth_acc(
        pred_labels[:, :nb_trials_considered],
        true_labels[:, :nb_trials_considered],
        nb_trials_considered,
        nearby=10,
        nb_calibsessions=nb_calibsessions,
    )


# Each call yields a pair: the value shown in the plot title and the curve
# plotted over time in the next cell.
acc_non_calib_fixed, acc_fixed = _report_smooth_acc("###########Fixed Supervised###########", pred_labels_fixed)
acc_non_calib_adapsup, acc_adapsup = _report_smooth_acc("###########Adaptive Supervised###########", pred_labels_adapsup)
acc_non_calib_single_enc, acc_single = _report_smooth_acc("###########Single-Encoder###########", pred_labels_single)
acc_non_calib, acc_two = _report_smooth_acc("###########Two-Encoder###########", pred_labels_two)
acc_non_calib_soft, acc_soft = _report_smooth_acc("###########Soft-Recursive###########", pred_labels_soft)

In [None]:
# Smoothed accuracy over time, all subjects concatenated along the x axis.
plt.close('all')
plt.figure(figsize=(10, 3))

# One x position per considered sub-trial. The axis is labelled in minutes, so
# scale by the sub-trial duration (sub_trial_length, assumed to be in seconds)
# rather than by 1/UPDATE_STEP, which is only correct when
# sub_trial_length * UPDATE_STEP == 60.
minutes_per_subtrial = sub_trial_length / 60
x_axis = np.arange(true_labels[:,:nb_trials_considered].size) * minutes_per_subtrial
# plt.plot(x_axis, acc_two, label='Two-Enc', color='orange')
plt.plot(x_axis, acc_single, label='Single-Enc', color='blue')
# plt.plot(x_axis, acc_adapsup, label='Adaptive Supervised', color='red')
# NOTE: `nb_trials` below is not defined in the visible notebook cells; adjust before re-enabling.
# plt.plot(x_axis[nb_calibsessions*nb_trials*UPDATE_STEP:], acc_fixed, label='Fixed Supervised', color='green')
plt.plot(x_axis, acc_soft, label='Soft-Recursive', color='purple')
# The title reports the summary accuracy of every method, including those not plotted.
plt.title(f"FixSup: {acc_non_calib_fixed:.2f}, Single-Enc: {acc_non_calib_single_enc:.2f}, Two-Enc: {acc_non_calib:.2f}, Soft: {acc_non_calib_soft:.2f}, AdaSup: {acc_non_calib_adapsup:.2f}")
plt.xlabel('time (min)')
plt.ylabel('accuracy')
# Dashed lines mark where the data of each subject after the calibration sessions begins.
for i in range(nb_calibsessions, len(subjects)):
    plt.axvline(x=i * nb_trials_considered * minutes_per_subtrial, color='grey', linestyle='--')
plt.legend()
plt.show()