# Imports

In [None]:
from Utilities import *
from Models import *
from Pipelines import *
from Comparison import *
import numpy as np

import builtins

# Compatibility shims for code that still references Python 2's `unicode`
# builtin or NumPy's `np.unicode_` alias (removed in NumPy 2.0).
# Each alias is installed independently and only when it is missing, so an
# existing `builtins.unicode` no longer prevents the NumPy alias from being set.
if not hasattr(builtins, "unicode"):
    builtins.unicode = str
if not hasattr(np, "unicode_"):
    np.unicode_ = np.str_


# Data Loading

Code to load all data and data splits used in this notebook. Preprocessing is applied inside `Data_Pipeline()`.

In [None]:
# Load the preprocessed EEG, the 5-class and 10-class label sets, and the
# list of usable channels.
(
    EEG_smoothed,
    y_smooth_5,
    y_smooth,
    good_channels,
) = Data_Pipeline()

In [None]:
# Split the smoothed EEG into frequency bands.
# NOTE(review): 2000 is presumably the sampling rate in Hz — confirm in split_EEG_by_bandwave.
bandwise_EEG = split_EEG_by_bandwave(
    EEG_smoothed,
    2000,
)
# Band labels, used to name and index the bands of bandwise_EEG.
band_freqs_keys = [
    "delta",
    "theta",
    "alpha",
    "beta",
    "gamma",
    "High Gamma",
    "Ripples",
    "Fast Ripples",
    "Multi-Unit",
]

In [None]:
# Run the project's band-pass filter check on the smoothed EEG
# (same 2000 value as the band split).
test_bandpass_filters(
    EEG_smoothed,
    2000,
)

In [None]:
# Group the usable channels by region; the result is later iterated with
# .items(), i.e. as a mapping region -> data.
regionwise_EEG = split_eeg_by_region(
    EEG_smoothed,
    good_channels,
)

In [None]:
# Combined split of the smoothed EEG by region and frequency band.
regionwise_band_EEG = split_EEG_by_band_region(
    EEG_smoothed,
    good_channels,
    2000,
)

# ESN Executions

Code to train and run all ESN models used in the study.

## Simple ESN

In [None]:
# Baseline: one ESN on the full smoothed EEG with the 5-class labels.
run_ESN(
    EEG_smoothed,
    y_smooth_5,
    n_classes=5,
    reservoir_size=500,
    epochs=150,
)

In [None]:
# Baseline: one ESN on the full smoothed EEG with the 10-class labels.
run_ESN(
    EEG_smoothed,
    y_smooth,
    n_classes=10,
    reservoir_size=500,
    epochs=150,
)

## Regionwise ESN

In [None]:
# Train a separate standardized ESN on each region's data (5-class labels).
for region, data in regionwise_EEG.items():
    print("Region", region)
    run_ESN(
        data,
        y_smooth_5,
        n_classes=5,
        reservoir_size=500,
        epochs=50,
        standarize=True,
    )

In [None]:
# Train a separate standardized ESN on each region's data (10-class labels).
for region, data in regionwise_EEG.items():
    print("Region", region)
    run_ESN(
        data,
        y_smooth,
        n_classes=10,
        reservoir_size=500,
        epochs=50,
        standarize=True,
    )

In [None]:
# Region-wise ESN model over all regions, 5-class labels.
run_regionwise_ESN(
    regionwise_EEG,
    y_smooth_5,
    n_classes=5,
    reservoir_size=500,
)

In [None]:
# Region-wise ESN model over all regions, 10-class labels.
run_regionwise_ESN(
    regionwise_EEG,
    y_smooth,
    n_classes=10,
    reservoir_size=500,
)

## Bandwise ESN

In [None]:
# Inspect the dimensions of the band-split EEG array.
print(np.shape(bandwise_EEG))

In [None]:
# Train a separate standardized ESN on each frequency band (5-class labels).
for band_idx, band_name in enumerate(band_freqs_keys):
    print("---------------------------------", band_name, "---------------------------------")
    print()
    # Select one band: indexing axis 1 of the 4-D bandwise array drops that
    # axis, so run_ESN receives a 3-D array (not 2-D).
    # NOTE(review): assumes axis 1 of bandwise_EEG follows the order of band_freqs_keys.
    run_ESN(
        bandwise_EEG[:, band_idx, :, :],
        y_smooth_5,
        n_classes=5,
        reservoir_size=500,
        epochs=150,
        standarize=True,
    )
    print()
    print()

In [None]:
# Train a separate standardized ESN on each frequency band (10-class labels).
for band_idx, band_name in enumerate(band_freqs_keys):
    print("---------------------------------", band_name, "---------------------------------")
    print()
    # Select one band: indexing axis 1 of the 4-D bandwise array drops that
    # axis, so run_ESN receives a 3-D array (not 2-D).
    # NOTE(review): assumes axis 1 of bandwise_EEG follows the order of band_freqs_keys.
    run_ESN(
        bandwise_EEG[:, band_idx, :, :],
        y_smooth,
        n_classes=10,
        reservoir_size=500,
        epochs=150,
        standarize=True,
    )
    print()
    print()

In [None]:
import importlib

import Models

# Reload Models so on-disk edits are picked up without restarting the kernel,
# then rebind the names used by the cells below.
# NOTE(review): names bound earlier via `from Models import *` (other than the
# three re-imported here) still refer to the pre-reload objects.
importlib.reload(Models)
from Models import run_bandwise_ESN, run_regionwise_ESN, run_band_regionwise_ESN

In [None]:
# Band-wise ESN model over all frequency bands, 5-class labels.
run_bandwise_ESN(
    bandwise_EEG,
    y_smooth_5,
    n_classes=5,
    reservoir_size=500,
)

In [None]:
# Band-wise ESN model over all frequency bands, 10-class labels.
run_bandwise_ESN(
    bandwise_EEG,
    y_smooth,
    n_classes=10,
    reservoir_size=500,
)

## Region-bandwise ESN

In [None]:
# Region-and-band ESN model, 5-class labels.
run_band_regionwise_ESN(
    regionwise_band_EEG,
    y_smooth_5,
    n_classes=5,
    reservoir_size=500,
)


In [None]:
# Region-and-band ESN model, 10-class labels.
run_band_regionwise_ESN(
    regionwise_band_EEG,
    y_smooth,
    n_classes=10,
    reservoir_size=500,
)

# Comparisons

Code to generate the figures in Section 5.2, with the exception of the PCA figure, which is generated by the `run_band_regionwise_ESN` function.

## Comparison on 500 Nodes

In [None]:
# Train the 5- and 10-class ESNs with a 500-node reservoir. The readout dicts
# (W_Out_*) are keyed "class_i"; W is presumably the reservoir weight matrix.
(
    acc_5,
    W_Out_5,
    acc_10,
    W_Out_10,
    W,
) = run_ESN_coefs(
    EEG_smoothed,
    y_smooth_5,
    y_smooth,
    reservoir_size=500,
    epochs=150,
)

In [None]:
# Compare the 5-class readout sets (spectral method, threshold 0.15) and
# display the resulting distance matrix.
keys, dist_matrix = compare_wout_sets(
    W_Out_5,
    threshold=0.15,
    method='spectral',
)
print(dist_matrix)

In [None]:
# Mean over axis 1 of each 10-class readout matrix. The loop variable is named
# w_out so it does not read as the reservoir matrix W.
# NOTE(review): not used by the plotting cells that follow — confirm it is still needed.
mean_wouts = {key: w_out.mean(axis=1) for key, w_out in W_Out_10.items()}

# Readout matrices keyed for side-by-side comparison (built once; an earlier
# duplicate construction was immediately discarded):
#   "5_class_i"        -> class i of the 5-class model  (i = 1..5)
#   "10_class_i"       -> class i of the 10-class model (i = 1..5)
#   "10_class_extra_j" -> class j of the 10-class model (j = 6..10)
combined_Wouts = {f"5_class_{i}": W_Out_5[f"class_{i}"] for i in range(1, 6)}
combined_Wouts.update({f"10_class_{i}": W_Out_10[f"class_{i}"] for i in range(1, 6)})
combined_Wouts.update({f"10_class_extra_{i}": W_Out_10[f"class_{i}"] for i in range(6, 11)})

# (5-class i, 10-class i) pairs for the two-way comparisons.
class_pairs = [(f"5_class_{i}", f"10_class_{i}") for i in range(1, 6)]

# (5-class i, 10-class i, 10-class i+5) triplets for the three-way comparisons.
class_triplets = [
    (f"5_class_{i}", f"10_class_{i}", f"10_class_extra_{i+5}") for i in range(1, 6)
]


In [None]:
# Two-way reservoir graph plots for each (5-class, 10-class) pair.
plot_reservoir_graph_comparisons(
    combined_Wouts,
    W,
    class_pairs,
    save_dir='reservoir_graphs',
    show=True,
)


In [None]:
# Three-way reservoir graph plots for each class triplet.
plot_reservoir_graph_comparisons_3way(
    combined_Wouts,
    W,
    class_triplets,
    save_dir='reservoir_graphs_3way',
    show=True,
)

In [None]:
# Binned three-way readout heatmaps for each class triplet.
plot_reservoir_heatmap_comparisons_binned(
    combined_Wouts,
    class_triplets,
    save_dir='reservoir_heatmaps_3way',
    show=True,
)

In [None]:
# Class-similarity heatmaps between the 5- and 10-class readouts:
# first the compressed view, then the full version saved to disk.
plot_class_similarity_heatmaps_compressed(
    W_Out_5,
    W_Out_10,
)

plot_class_similarity_heatmaps(
    W_Out_5,
    W_Out_10,
    save_dir='reservoir_class_diff_heatmaps',
)

## Comparison on 50 Nodes

In [None]:
# Retrain with a 50-node reservoir (25 epochs); this rebinds the same names
# used by the 500-node comparison cells above.
(
    acc_5,
    W_Out_5,
    acc_10,
    W_Out_10,
    W,
) = run_ESN_coefs(
    EEG_smoothed,
    y_smooth_5,
    y_smooth,
    reservoir_size=50,
    epochs=25,
)

In [None]:
# Compare the 50-node 5-class readout sets (spectral method, threshold 0.15)
# and display the resulting distance matrix.
keys, dist_matrix = compare_wout_sets(
    W_Out_5,
    threshold=0.15,
    method='spectral',
)
print(dist_matrix)

In [None]:
# Mean over axis 1 of each 10-class readout matrix. The loop variable is named
# w_out so it does not read as the reservoir matrix W.
# NOTE(review): not used by the plotting cells that follow — confirm it is still needed.
mean_wouts = {key: w_out.mean(axis=1) for key, w_out in W_Out_10.items()}

# Readout matrices keyed for side-by-side comparison (built once; an earlier
# duplicate construction was immediately discarded):
#   "5_class_i"        -> class i of the 5-class model  (i = 1..5)
#   "10_class_i"       -> class i of the 10-class model (i = 1..5)
#   "10_class_extra_j" -> class j of the 10-class model (j = 6..10)
combined_Wouts = {f"5_class_{i}": W_Out_5[f"class_{i}"] for i in range(1, 6)}
combined_Wouts.update({f"10_class_{i}": W_Out_10[f"class_{i}"] for i in range(1, 6)})
combined_Wouts.update({f"10_class_extra_{i}": W_Out_10[f"class_{i}"] for i in range(6, 11)})

# (5-class i, 10-class i) pairs for the two-way comparisons.
class_pairs = [(f"5_class_{i}", f"10_class_{i}") for i in range(1, 6)]

# (5-class i, 10-class i, 10-class i+5) triplets for the three-way comparisons.
class_triplets = [
    (f"5_class_{i}", f"10_class_{i}", f"10_class_extra_{i+5}") for i in range(1, 6)
]


In [None]:
# Two-way reservoir graph plots for the 50-node run. These are written to their
# own directory: the 500-node section uses the same class_pairs keys, so sharing
# 'reservoir_graphs' could overwrite those figures.
plot_reservoir_graph_comparisons(
    combined_Wouts,
    W,
    class_pairs,
    save_dir='reservoir_graphs_50_nodes',
    show=True,
)

In [None]:
# Three-way reservoir graph plots for the 50-node run. These are written to
# their own directory so they cannot overwrite the 500-node figures, which use
# the same class_triplets keys.
plot_reservoir_graph_comparisons_3way(
    combined_Wouts,
    W,
    class_triplets,
    save_dir='reservoir_graphs_3way_50_nodes',
    show=True,
)

In [None]:
# Binned three-way readout heatmaps for the 50-node run. These are written to
# their own directory so they cannot overwrite the 500-node heatmaps, which use
# the same class_triplets keys.
plot_reservoir_heatmap_comparisons_binned(
    combined_Wouts,
    class_triplets,
    save_dir='reservoir_heatmaps_3way_50_nodes',
    show=True,
)

In [None]:
# Class-similarity heatmaps between the 50-node 5- and 10-class readouts.
plot_class_similarity_heatmaps_compressed(W_Out_5, W_Out_10)

# Saved to a separate directory so the 500-node heatmaps are not overwritten.
plot_class_similarity_heatmaps(
    W_Out_5,
    W_Out_10,
    save_dir='reservoir_class_diff_heatmaps_50_nodes',
)