In [1]:
# Notebook for benchmarking features on the 4-condition classification task
%matplotlib ipympl
import sys
import numpy as np
import pandas as pd
import os
from sklearn.svm import SVC
from sklearn.metrics import balanced_accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns

from openTSNE import TSNE

# folder holding the per-subject computed features used for model comparison
infold = '../STRG_computed_features/Mdl_comparison/'

# the helper module mv_python_utils lives outside this folder: make it importable first
sys.path.append('../helper_functions')
from mv_python_utils import cat_subjs_train_test, cat_subjs_train_test_ParcelScale


In [2]:
# load & concatenate the TimeFeats files of both experiments (ECEO and VS)
list_ExpConds = ['ECEO_TimeFeats', 'VS_TimeFeats']
(fullX_train, fullX_test,
 Y_train_A, Y_test_A,
 subjID_trials_labels) = cat_subjs_train_test(
    infold,
    strtsubj=0, endsubj=29,
    ftype=list_ExpConds,
    tanh_flag=True, compress_flag=True,
)
                                                                                            


KeyboardInterrupt



In [None]:
# compute 2d correlation of all features
from scipy.spatial import distance

# define function to adapt train and test sets
def adapt_data(dat):
    """Return a copy of a feature dictionary prepared for the feature-correlation analysis.

    The 'full_set' entry (if present) is dropped. The 'fooof_aperiodic' matrix is split
    into its interleaved column groups: even columns are stored as 'fooof_slope' and odd
    columns as 'fooof_offset'. The two new keys are appended after the remaining
    features, in that order.

    The input dictionary is NOT modified. Deleting keys from the caller's dict in
    place used to silently strip 'full_set' / 'fooof_aperiodic' from
    ``fullX_train`` / ``fullX_test`` for every later cell, and made re-running this
    cell fail with KeyError.

    Parameters
    ----------
    dat : dict
        Mapping feature name -> 2D array (presumably trials x parcels, TODO confirm).
        Must contain 'fooof_aperiodic'.

    Returns
    -------
    dict
        A new dictionary with the adapted entries. Arrays are shared with the input,
        not copied.
    """
    # shallow copy: adding/removing keys no longer touches the caller's dictionary
    dat = dict(dat)

    dat.pop('full_set', None)

    # split fooof aperiodic parameters into slope and offset.
    # NOTE(review): FOOOF itself orders aperiodic params as [offset, exponent]; confirm
    # the column interleaving produced by mv_python_utils matches these labels.
    fooof_aperiodic = dat.pop('fooof_aperiodic')
    dat['fooof_slope'] = fooof_aperiodic[:, 0::2]
    dat['fooof_offset'] = fooof_aperiodic[:, 1::2]

    return dat

# transform train and test sets
corr_Xtr = adapt_data(fullX_train)
corr_Xte = adapt_data(fullX_test)

print(corr_Xtr.keys())

feat_names = list(corr_Xtr.keys())
n_feats = len(feat_names)
corr_array, cosine_array = np.zeros((n_feats, n_feats)), np.zeros((n_feats, n_feats))

# loop to:
# 1: compute the test balanced accuracy of an RBF-SVM trained on each single feature
# 2: compute Pearson correlation and cosine distance between every pair of features

rbf_svm = SVC(C=10, random_state=42)

# keep it as a list: easier to re-sort it according to the dendrogram leaves
vect_accuracy = []

# linearize every matrix once, to correlate across trials & parcels.
# np.ravel (C order) gives the same vector as np.hstack on a 2D array, without
# re-flattening each feature once per pair.
flat_feats = [np.ravel(corr_Xtr[name]) for name in feat_names]

for row_idx, rowFeat in enumerate(feat_names):

    # fit the SVM model
    this_mdl = rbf_svm.fit(corr_Xtr[rowFeat], Y_train_A)

    # generate predictions & compute balanced accuracy
    pred_labels = this_mdl.predict(corr_Xte[rowFeat])
    vect_accuracy.append(balanced_accuracy_score(Y_test_A, pred_labels))

    # both measures are symmetric: fill the upper triangle (incl. diagonal) and mirror it,
    # which also guarantees an exactly symmetric matrix for squareform() later on
    row_flat = flat_feats[row_idx]
    for col_idx in range(row_idx, n_feats):
        col_flat = flat_feats[col_idx]
        r = np.corrcoef(row_flat, col_flat)[0, 1]
        d = distance.cosine(row_flat, col_flat)
        corr_array[row_idx, col_idx] = corr_array[col_idx, row_idx] = r
        cosine_array[row_idx, col_idx] = cosine_array[col_idx, row_idx] = d

    print(rowFeat)

# round correlation to the 5th decimal: snaps self-correlations that differ from 1 only by
# machine precision to exactly 1 (so 1-|r| has a zero diagonal) and keeps values
# comparable with earlier runs
corr_array = np.round(corr_array, 5)


In [None]:
# Heatmap of the pairwise feature correlations.
# vmin/vmax pin the diverging colormap to [-1, 1] so that white means r = 0.
feat_labels = list(corr_Xtr.keys())
fig, ax = plt.subplots()
sns.heatmap(corr_array, xticklabels=feat_labels, yticklabels=feat_labels,
            cmap="RdBu_r", vmin=-1, vmax=1, ax=ax)

plt.setp(ax.get_xticklabels(), rotation=90, ha="right")  # vertical x-axis labels: feature names are long
plt.setp(ax.get_yticklabels(), rotation=0)  # keep y-axis labels horizontal


#plt.subplot(212)
#sns.heatmap(cosine_array, xticklabels=corr_Xtr.keys(), yticklabels=corr_Xtr.keys(), cmap="RdBu")
#plt.tight_layout()

In [None]:
# start clustering
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from scipy.spatial.distance import squareform

# distance between features = 1 - |r|: strongly anti-correlated features count as "close"
# exactly like strongly correlated ones
dissimilarity_mat = 1 - np.abs(corr_array)

# complete-linkage hierarchical clustering on the condensed form of the distance matrix
Z = linkage(squareform(dissimilarity_mat), 'complete')

# NOTE(review): branches are coloured at .75 here, while the flat clusters below are cut at .8
fig, ax = plt.subplots()
R = dendrogram(Z,
               labels=list(corr_Xtr.keys()),
               orientation='top',
               leaf_rotation=90,
               color_threshold=.75,
               ax=ax)
fig.tight_layout()

# leaf order (indices into the feature list), reused to sort the accuracies
print(R['leaves'])


In [None]:
# Cut the dendrogram into flat clusters
threshold = .8
labels = fcluster(Z, threshold, criterion='distance')

# Show the cluster assigned to each feature. A bare list(...) expression in the middle
# of a cell is never displayed, so pair names and labels explicitly.
print(pd.Series(labels, index=list(corr_Xtr.keys()), name='cluster'))

# single-feature accuracies re-ordered to follow the dendrogram leaves (left to right)
sorted_accs = [vect_accuracy[i] for i in R['leaves']]

print(sorted_accs)

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

# distance matrix (1 - |r|) as a labelled DataFrame
DF_corr = pd.DataFrame(data=1-abs(corr_array), index=corr_Xtr.keys(), columns=corr_Xtr.keys())

# row_linkage / col_linkage reuse the precomputed tree Z, so rows and columns are ordered
# exactly as in the dendrogram above (`method` is only used when seaborn computes the linkage)
g = sns.clustermap(DF_corr, method="complete", cmap='Reds_r', xticklabels=True, yticklabels=True,
                   vmin=0, vmax=1, cbar_kws={'orientation': 'horizontal'},
                   row_linkage=Z, col_linkage=Z)

# place the colorbar above the row dendrogram and hide the redundant column dendrogram
x0, _y0, _w, _h = g.cbar_pos
g.ax_cbar.set_position([x0, 0.9, g.ax_row_dendrogram.get_position().width - .02, 0.02])
g.ax_cbar.set_title('distance \n(1-abs(r))')
g.ax_col_dendrogram.set_visible(False)



In [None]:
# accuracy plot: one bar per feature, in dendrogram-leaf order.
# x tick labels are hidden so the bars can be lined up with the dendrogram/clustermap.
fig, ax = plt.subplots()
ax.bar(range(len(sorted_accs)), sorted_accs)
ax.set_ylabel('balanced accuracy')

# y ticks and tick labels on the right for THIS axes only. Setting plt.rcParams here
# moved the y axis to the right in every figure created afterwards.
ax.yaxis.tick_right()
ax.tick_params(labelbottom=False, bottom=False)



In [None]:
# datasets to compare: the non-bandpassed TimeFeats loaded above, followed by the same
# features computed on each bandpassed signal
compareBands = {'no_bandpass': {'Xtrain': fullX_train,
                                'Xtest': fullX_test,
                                'Ytrain': Y_train_A,
                                'Ytest': Y_test_A,
                                'subjIDs': subjID_trials_labels}}

freqBands = ['delta', 'theta', 'alpha', 'beta', 'low_gamma', 'high_gamma']
infold_bands = '../STRG_computed_features/TimeFeats_bandpassed/'

for thisBand in freqBands:

    this_list_expConds = [f'{exp}_{thisBand}_TimeFeats' for exp in ('ECEO', 'VS')]
    Xtr, Xte, Ytr, Yte, subjID_trials_labels = cat_subjs_train_test(
        infold_bands, strtsubj=0, endsubj=29,
        ftype=this_list_expConds, tanh_flag=True, compress_flag=True)

    temp = {'Xtrain': Xtr, 'Xtest': Xte, 'Ytrain': Ytr, 'Ytest': Yte,
            'subjIDs': subjID_trials_labels}
    compareBands[thisBand] = temp

    print(thisBand + ' completed')

    
# try on parcel-defined scaling
# X_train, X_test, Y_train, Y_test, subjID_train, subjID_test = cat_subjs_train_test_ParcelScale(infold, strtsubj=0, endsubj=29, 
#                                      

In [None]:
# load & concatenate the FreqBands (band power) files of both experiments.
# NOTE(review): this REPLACES the compareBands dict built by the bandpassed-TimeFeats cell,
# so the classifier loop below sees only one of the two datasets, depending on run order.
list_ExpConds = ['ECEO_FreqBands', 'VS_FreqBands']
Xtr, Xte, Ytr, Yte, subjID_trials_labels = cat_subjs_train_test(
    infold, strtsubj=0, endsubj=29,
    ftype=list_ExpConds, tanh_flag=True, compress_flag=True)

temp = dict(Xtrain=Xtr, Xtest=Xte, Ytrain=Ytr, Ytest=Yte, subjIDs=subjID_trials_labels)

compareBands = {'freqBands': temp}


In [None]:
# Classifier 1
# - standardization, within subject, along repetitions.
# One RBF-SVM per (dataset, feature); test balanced accuracies are saved per dataset and
# aggregated in one table.

outfold = '../STRG_decoding_accuracy/Mdl_comparison/'
os.makedirs(outfold, exist_ok=True)
rbf_svm = SVC(C=10)

# compareBands holds EITHER the TimeFeats datasets ('no_bandpass' + bandpassed signals) OR
# the single 'freqBands' power dataset, depending on which loading cell ran last. Name the
# aggregate file after its content: always writing the "_TS_" file let a FreqBands run
# overwrite the TimeFeats results.
if 'no_bandpass' in compareBands:
    agg_fname = 'freqBands_TS_4conds_collapsed_reproduce.csv'
else:
    agg_fname = '_'.join(compareBands) + '_4conds_collapsed_aggregate_reproduce.csv'

blist = []
for bandName, dataset in compareBands.items():

    X_train = dataset['Xtrain']
    X_test = dataset['Xtest']
    Y_train = dataset['Ytrain']
    Y_test = dataset['Ytest']

    dict_accs = {}
    for key, Xtr in X_train.items():

        # fit the SVM model
        this_mdl = rbf_svm.fit(Xtr, Y_train)

        # generate predictions & compute balanced accuracy
        Xte = X_test[key]
        pred_labels = this_mdl.predict(Xte)
        this_acc = balanced_accuracy_score(Y_test, pred_labels)

        # print some feedback in the CL
        print(bandName + ' ' + key + ': ' + str(round(this_acc, 4)))

        dict_accs[key] = this_acc

    DFband = pd.DataFrame(dict_accs, index=[bandName])
    DFband.to_csv(outfold + bandName + '_4conds_collapsed_reproduce.csv')  # save, just to be sure
    blist.append(DFband)

allBandsDF = pd.concat(blist)
allBandsDF.to_csv(outfold + agg_fname)

In [None]:
# how many top features?
nfeats = 20
outfold = '../STRG_decoding_accuracy/Mdl_comparison/'

# fetch band-power (FreqBands) and time-series (TimeFeats) accuracy tables
df_power = pd.read_csv(outfold + 'freqBands_4conds_collapsed.csv', index_col=0)
df_TS = pd.read_csv(outfold + 'freqBands_TS_4conds_collapsed.csv', index_col=0)

# df_power has one row with one column per band. Move its last entry (all bands together)
# first, to match the "no_bandpass" row of df_TS. The match is purely positional.
assert df_power.shape[0] == 1, 'expected a single-row power accuracy table'
power_vals = np.roll(np.asarray(df_power).ravel(), 1)
assert len(power_vals) == len(df_TS), 'power columns do not match the rows of df_TS'
df_TS['power'] = power_vals

# Convert to long format
melted_df = pd.melt(df_TS, var_name='features', value_name='balanced accuracy', ignore_index=False)

# Sort in descending order and rename 'index' to 'frequency'
sorted_df = (melted_df.sort_values(by='balanced accuracy', ascending=False)
             .reset_index()
             .rename(columns={'index': 'frequency'}))

# merge columns into one label per bar
sorted_df['freq_feat'] = sorted_df['frequency'] + '\n' + sorted_df['features']

# keep the best nfeats entries
top_feats = sorted_df.head(nfeats)

# Plotting. Style everything BEFORE plt.show(): with non-interactive backends show() closes
# the figure, and later calls would land on a new, empty one. No legend: the bars have no hue.
fig, ax = plt.subplots(figsize=(10, 8))
sns.barplot(x='balanced accuracy', y='freq_feat', data=top_feats, orient='h',
            palette="ch:start=.2,rot=-.3, dark=.4", ax=ax)

ax.set_xlabel('balanced accuracy')
ax.set_ylabel('Features')
ax.set_title('Top ' + str(nfeats) + ' features' + '\n4 conditions classification')
ax.set_xlim([.25, 1])
fig.tight_layout()
plt.show()

print(top_feats)


In [None]:
# Heatmap of accuracies: one row per frequency band, one column per feature.
# NOTE(review): if the top-features cell ran first, df_TS also carries its added 'power' column.
fig, ax = plt.subplots(figsize=(12, 8))

sns.heatmap(df_TS, cmap="Reds", xticklabels=df_TS.columns, yticklabels=df_TS.index, ax=ax)

# tilted x labels (45 degrees) for readability; y labels stay horizontal
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
plt.setp(ax.get_yticklabels(), rotation=0)
fig.tight_layout()



In [None]:
# save output & plot
# NOTE(review): dict_accs is whatever the Classifier 1 loop left behind (accuracies of the LAST
# dataset in compareBands); confirm this is the table meant to be stored here.
outfold = '../STRG_decoding_accuracy/Mdl_comparison/'
acc_table = pd.DataFrame(data=dict_accs, index=['balanced accuracy'])
acc_table.to_csv(outfold + 'test_acc_4conds_collapsed.csv')

ranked_table = acc_table.sort_values(by='balanced accuracy', ascending=False, axis=1)
fig, ax = plt.subplots()
sns.barplot(ranked_table, orient='h', palette="ch:start=.2,rot=-.3, dark=.4", ax=ax)
fig.tight_layout()


In [None]:
# Classifier 2a-b
# - standardization, within trial, along parcels. with or without shuffled participants labels
# The subject ID is appended to each feature matrix as an extra column. One SVM is trained
# with the true subject IDs and a SEPARATE one with shuffled IDs (control).
# NOTE(review): subjID_train / subjID_test are not defined in any visible cell (only in the
# commented-out cat_subjs_train_test_ParcelScale call). X_train / X_test / Y_train / Y_test
# are whatever the Classifier 1 loop left behind. Confirm the intended data are loaded.

rbf_svm = SVC(C=10)
rng = np.random.default_rng(42)  # seeded: the shuffled-label control is reproducible

# the subject-ID columns do not depend on the feature: build them once
vect_subjorder_train = np.array(subjID_train, ndmin=2).T
vect_subjorder_test = np.array(subjID_test, ndmin=2).T

dict_accs_parc = dict()
for key, X in X_train.items():

    # A fresh estimator for each fit: SVC.fit returns self, so fitting rbf_svm twice made
    # the "ordered" and "shuffled" models the very same object, and the ordered predictions
    # came from the model refit on shuffled labels.
    X_parts_ordered = np.concatenate((X, vect_subjorder_train), axis=1)
    this_mdl_ordered = SVC(**rbf_svm.get_params()).fit(X_parts_ordered, Y_train)

    # fit the SVM model on shuffled subject labels (rows permuted, like np.random.shuffle)
    vect_subjshuffled_train = rng.permutation(vect_subjorder_train)
    X_parts_shuffled = np.concatenate((X, vect_subjshuffled_train), axis=1)
    this_mdl_shuffled = SVC(**rbf_svm.get_params()).fit(X_parts_shuffled, Y_train)

    # create test data with ordered and shuffled subject labels, respectively
    vect_subjshuffled_test = rng.permutation(vect_subjorder_test)
    swapXtest_ordered = np.concatenate((X_test[key], vect_subjorder_test), axis=1)
    swapXtest_shuffled = np.concatenate((X_test[key], vect_subjshuffled_test), axis=1)

    # generate predictions & compute balanced accuracy
    pred_labels_ord = this_mdl_ordered.predict(swapXtest_ordered)
    this_acc_ord = balanced_accuracy_score(Y_test, pred_labels_ord)

    pred_labels_shffld = this_mdl_shuffled.predict(swapXtest_shuffled)
    this_acc_shffld = balanced_accuracy_score(Y_test, pred_labels_shffld)

    this_acc = [this_acc_ord, this_acc_shffld]

    # print some feedback in the CL
    print(key + ' ordered subj: ' + str(round(this_acc[0], 4)))
    print(key + ' shuffled subj: ' + str(round(this_acc[1], 4)))

    dict_accs_parc[key] = this_acc


acc_table_parc = pd.DataFrame(data=dict_accs_parc, index=['ordered subjlabels', 'shuffled subjlabels'])



In [None]:
# persist the parcel-normalization accuracy table next to the other decoding results
outfold = '../STRG_decoding_accuracy/Mdl_comparison/'
parc_csv = outfold + 'test_acc_4conds_collapsed_parc.csv'
acc_table_parc.to_csv(parc_csv)

In [None]:
# read classification accuracy data.
# Use a dedicated name for this folder: reassigning `infold` here redirected the
# feature-loading cells (which read `infold`) to the decoding-accuracy folder on re-run.
acc_infold = '../STRG_decoding_accuracy/Mdl_comparison/'

# keep only the shuffled-subject-labels row, selected by label rather than by position
acc_table_NormParc = pd.read_csv(acc_infold + 'test_acc_4conds_collapsed_parc.csv',
                                 index_col=0).loc[['shuffled subjlabels']]
acc_table_NormTrials = pd.read_csv(acc_infold + 'test_acc_4conds_collapsed.csv', index_col=0)

fig, ax = plt.subplots()
sns.barplot(acc_table_NormTrials.sort_values(by='balanced accuracy', ascending=False, axis=1),
            orient='h', palette="ch:start=.2,rot=-.3, dark=.4", ax=ax)
ax.set_title('Norm across trials')
fig.tight_layout()


In [None]:
# accuracies with within-trial (parcel) normalization, shuffled-subject control, best first
ranked_parc = acc_table_NormParc.sort_values(by='shuffled subjlabels', ascending=False, axis=1)
fig, ax = plt.subplots()
sns.barplot(ranked_parc, orient='h', palette="ch:start=.2,rot=-.3, dark=.4", ax=ax)
ax.set_title('Norm across parcels')
fig.tight_layout()


In [None]:
# TSNE testing

example_feat = 'MCL'

tsne_across, tsne_parcel = (TSNE(n_components=2,
                                perplexity=30,
                                metric="euclidean",
                                n_jobs=8,
                                random_state=42,
                                verbose=True),)*2

across_X_train, across_X_test = fullX_train[example_feat], fullX_test[example_feat]
%time embed_train_across = tsne_across.fit(across_X_train)
%time embed_test_across = embed_train_across.transform(across_X_test)

parcel_X_train, parcel_X_test = X_train[example_feat], X_test[example_feat]
%time embed_train_parcel = tsne_parcel.fit(parcel_X_train)
%time embed_test_parcel = embed_train_parcel.transform(parcel_X_test)


In [None]:
# test-set embedding (across-trials normalization) gathered in a DataFrame for plotting
dict_plot_across = dict(x=embed_test_across[:, 0],
                        y=embed_test_across[:, 1],
                        labels=Y_test_A)

DF_tsne_across = pd.DataFrame.from_dict(dict_plot_across)

# left: individual points; right: per-class density contours
fig = plt.figure()
ax_points = fig.add_subplot(121)
sns.scatterplot(data=DF_tsne_across, x='x', y='y', hue='labels', ax=ax_points)
ax_density = fig.add_subplot(122)
sns.kdeplot(data=DF_tsne_across, x='x', y='y', hue='labels', ax=ax_density)


In [None]:
# test-set embedding (parcel normalization) with subject IDs, gathered for plotting.
# NOTE(review): subjID_test is not defined in any visible cell; confirm where it comes from.
dict_plot_parc = dict(x=embed_test_parcel[:, 0],
                      y=embed_test_parcel[:, 1],
                      labels=Y_test,
                      subjIDs=subjID_test)

DF_tsne_parcel = pd.DataFrame.from_dict(dict_plot_parc)

# left: individual points; right: per-class density contours
fig = plt.figure()
ax_points = fig.add_subplot(121)
sns.scatterplot(data=DF_tsne_parcel, x='x', y='y', hue='labels', ax=ax_points)
ax_density = fig.add_subplot(122)
sns.kdeplot(data=DF_tsne_parcel, x='x', y='y', hue='labels', ax=ax_density)


In [None]:
# Create a new column 'scalarSubj' with unique integers for each unique value in 'subjIDs'
DF_tsne_parcel['scalarSubj'] = pd.factorize(DF_tsne_parcel['subjIDs'])[0]

# NOTE: sns.set changes the global plotting style for every later figure in the session
sns.set(style="darkgrid")
fig = plt.figure(figsize=(10, 8))

# 3D scatterplot: t-SNE plane on x/y, subject index on z, coloured by class label
ax1 = fig.add_subplot(projection='3d')
scatter = ax1.scatter(DF_tsne_parcel['x'], DF_tsne_parcel['y'], DF_tsne_parcel['scalarSubj'],
                      c=DF_tsne_parcel['labels'])
ax1.set_xlabel('x')
ax1.set_ylabel('y')
ax1.set_zlabel('scalarSubj')

# ax.legend() already attaches the legend to the axes. add_artist() is only needed to keep
# it when a second legend is created; here it made the same legend be drawn twice.
legend1 = ax1.legend(*scatter.legend_elements(), loc="upper right")
ax1.set_title('3D Scatterplot')





In [None]:
# 2D KDE projections of the (x, y, scalarSubj) cloud, one panel per pair of axes:
# (subplot position, x column, y column, title)
kde_panels = [
    (221, 'x', 'y', 'KDE Projection collapsing subjects'),
    (222, 'y', 'scalarSubj', 'KDE Projection collapsing x-axis'),
    (223, 'scalarSubj', 'x', 'KDE Projection collapsing y-axis'),
]

fig = plt.figure(figsize=(10, 8))

for position, x_col, y_col, panel_title in kde_panels:
    panel_ax = fig.add_subplot(position)
    sns.kdeplot(data=DF_tsne_parcel, x=x_col, y=y_col, hue='labels', ax=panel_ax)
    panel_ax.set_xlabel(x_col)
    panel_ax.set_title(panel_title)

plt.tight_layout()
plt.show()
