# Description: 
This notebook trains a functional embedding on real recordings from multiple subjects using the PSC and MSC methods, and evaluates how well it generalizes to held-out time segments and held-out channels.

# Import Libraries 

In [None]:
# general libraries

from torch.utils.data import DataLoader
from scipy.io import loadmat
import plotly.io as pio
import pickle


In [2]:
# local libraries
from src.tools.simulation import *
from src.models.CNN_encoder import *
from src.signal.signal_processing import *
from src.pipelines.embedding_data_preparation import *
from src.utils.utils import *
from src.training.PSC_training import *
from src.training.visualize_embedding import *
from src.training.evaluation import *
from src.training.MSC_training import *


In [3]:
# dataset related libraries (have to replace these with your own dataset loading functions), replace with your own dataset source files
import sys
import os

# Folder that contains the dataset-specific `dataset` module.
# The DATASET_SRC_DIR environment variable, when set, overrides the default below so the
# notebook can run on another machine without editing this cell.
directory_path = os.environ.get("DATASET_SRC_DIR", "D:/Python Projects/ICLR 2026/Dataset_Related")

# Only extend sys.path once: re-running this cell must not keep appending duplicate entries.
dataset_src_dir = os.path.abspath(directory_path)
if dataset_src_dir not in sys.path:
    sys.path.append(dataset_src_dir)

# NOTE(review): the training cells use names that are not defined in this notebook
# (e.g. subject_sessions, get_true_region, parse_label); presumably this star import
# provides them - confirm against your own `dataset` module.
from dataset import *  

In [4]:
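# Dataset related functions and mappings
# These must be customized for each subject and dataset; commented-out example templates are provided below.
# NOTE: the training cells below use `subject_sessions` and `get_true_region`, so both names must be defined
# (uncomment and adapt the templates below, or provide them from your own dataset module) before running those cells.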


# Define which subjects and sessions to use. subjects are keys and sessions are values
# subject_sessions = {
#     's1': ['p1'], 
#     's2': ['p1'], 
#     's3': ['p4', 'p6'], 
#     's4': ['p1', 'p2'],
#     's5': ['p1', 'p2', 'p3', 'p6', 'p7'],
#     's6': ['p1', 'p2'], 
#     's7': ['p1'], 
#     's8': ['p9'],
#     's9': ['p1', 'p2', 'p3'], 
#     's10': ['p6', 'p7', 'p8', 'p9'],
# }

# def get_true_region(subject, electrode_name, side, electrode_number, verbose=False):
#     """
#     This function defines the mapping from electrode name and contact number to true brain region for a given subject.
#     """
#     if verbose: print(f"using default arrangement for {subject}")
#     if 'GPi'.lower() in electrode_name.lower() and electrode_number in [4,5,6]:
#         return 'GPi'
#     elif 'VPLa'.lower() in electrode_name.lower() and electrode_number in [4,5,6]:
#         return 'VPla'
#     elif 'VoSTN'.lower() in electrode_name.lower() and electrode_number in [4,5,6]:
#         return 'VO'
#     elif 'VoSTN'.lower() in electrode_name.lower() and electrode_number in [1,2,3]:
#         return 'STN'
#     elif 'STN'.lower() in electrode_name.lower() and electrode_number in [4,5,6]: 
#         return 'STN'
#     elif 'VIM'.lower() in electrode_name.lower() and electrode_number in [4,5,6]:
#         return 'VIM'
#     elif 'PPN'.lower() in electrode_name.lower() and electrode_number in [1,2,3]:
#         return 'PPN'
#     elif 'VA'.lower() in electrode_name.lower() and electrode_number in [4,5,6]:
#         return 'VA'
#     elif 'VoaVop'.lower() in electrode_name.lower() and electrode_number in [4,5,6]:
#         return 'VoaVop'
#     elif 'VoaVop'.lower() in electrode_name.lower() and electrode_number in [4,5,6]:
#         return 'VoaVop'
#     else: 
#         return 'unknown'

# Training functional encoder on aggregated dataset including several subjects and sessions Using PSC method
This section builds a single dataset from all recording sessions, trains a functional encoder on it with the PSC method, and evaluates the encoder on held-out time segments (the test set) and on held-out channels.

In [None]:
# PSC pipeline for the aggregated multi-subject dataset:
#   1. load every usable micro-electrode channel of every subject/session,
#   2. filter, normalize and segment it,
#   3. route every `hold_out_check`-th channel of a region (per session) to the held-out set and
#      split the remaining channels into validation / test / train segments,
#   4. train the Siamese encoder on balanced pairs,
#   5. evaluate the embeddings (KNN, clustering metrics, confusion matrices, PCA plots) and save results.

# Check GPU
device = get_device()
print(f" Using {device}")

# ----------- Config ------------ 
# training parameters
EPOCHS = 1
DATASET_SIZE = 1000000
LEARNING_RATE = 1e-2
BATCH_SIZE = 1024
RANDOM_SEED = 42 # seeds NumPy's global RNG and is passed as seed to make_balanced_reference

# model parameters
EMBEDDING_DIM = 32
VERSION = "v1(PSC)"
DROPOUT = 0.3

# data preparation 
SEGMENT_LENGTH = 10000 # in ms here is 10 seconds
VAL_FRACTION = 0.15 # leading fraction of each channel's segments used for validation
TEST_FRACTION = 0.15 # following fraction of each channel's segments used for testing
MIN_SEGMENTS = 10 # channels that yield fewer segments than this are skipped
hold_out_check = 4 # holds 1 channel out every n channels

# visualization
show_model = True
#--------------------------------

# Make the segment shuffle (and any other NumPy-global randomness) reproducible across runs.
# NOTE(review): torch's RNG is not seeded here, so DataLoader(shuffle=True) ordering still varies.
np.random.seed(RANDOM_SEED)

dataPath = "D:\\transfer\\micro voluntary dataset\\lfs data" # path to data folder (replace with your own data path)
experiment = "Multi session training" 
subject_counter = 0

# region name -> segments, one dictionary per split
training_segments_dict = {}
validation_segments_dict = {}
testing_segments_dict = {}
held_out_channels_dict = {}

for isubject in subject_sessions.keys():
    subject_counter += 1
    for iperiod in subject_sessions[isubject]:
        print("Loading data from subject: S" + str(subject_counter) + " period: " + iperiod)
        # load microelectrode labels which contains the names of all microelectrodes in this session
        micro_electrode_path = dataPath  + os.sep + isubject + os.sep + iperiod + os.sep
        try:
            microelectrodes = load_micro_labels(micro_electrode_path + "microLabels.mat")
        except Exception:
            # Fallback reader for label files the project loader cannot handle.
            # `except Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
            mat_contents = loadmat(micro_electrode_path + "microLabels.mat")
            micro_labels = mat_contents['microLabels'] 
            microelectrodes = [str(label[0]) if isinstance(label, np.ndarray) else str(label) for label in micro_labels.squeeze()]

        # number of usable channels seen so far per true region, reset for every session
        region_electrode_count = {}

        # Parse and print them
        for microelectrode in microelectrodes:
            region, side, electrode = parse_label(microelectrode)

            # get true region of this electrode; contacts without a known region are ignored
            true_region = get_true_region(isubject, region, side, electrode)
            if true_region == 'unknown':
                continue
            print(f"Loading {microelectrode} --> Region: {region}, Side: {side}, Electrode: {electrode} true region: {true_region}")

            # load data
            mat_path = os.path.join(micro_electrode_path, microelectrode + ".mat")
            raw_signal = None 
            try:
                # First try HDF5 (MATLAB v7.3 format)
                with h5py.File(mat_path, "r") as f:
                    raw_signal = np.array(f["data"]).squeeze()
                    fs = int(np.array(f["fs"]).squeeze())
            except FileNotFoundError:
                print(f"⚠️ File not found: {mat_path}. Skipping electrode.")
                continue  # Skip to next electrode
            except (OSError, KeyError):
                # Fall back to standard .mat format (pre-v7.3)
                try:
                    # Use the `loadmat` name imported at the top of the notebook: `scipy` itself is
                    # never imported explicitly, and a NameError here would be swallowed below and
                    # silently skip every pre-v7.3 file.
                    mat = loadmat(mat_path)
                    raw_signal = np.array(mat["data"]).squeeze()
                    fs = int(np.array(mat["fs"]).squeeze())
                except Exception as e:
                    print(f"⚠️ Failed to load (not a valid .mat file): {mat_path}. Error: {e}")
                    continue  # Skip if even fallback fails

            # Remove big 60Hz artifacts
            filtered_signal = notch_filter(raw_signal,fs)

            # Normalize session recording
            normalized_signal = normalize_signal(filtered_signal)

            # Apply notch filters at 60Hz and its harmonics
            normalized_signal = notch_filter(normalized_signal,fs,freq=60)
            normalized_signal = notch_filter(normalized_signal,fs,freq=120)
            normalized_signal = notch_filter(normalized_signal,fs,freq=180)

            # Segment data into windows
            signal_segments = segment_data(normalized_signal, SEGMENT_LENGTH)

            # Skip sessions with too few segments
            if len(signal_segments) < MIN_SEGMENTS:
                print("session too short")
                continue

            # Every `hold_out_check`-th usable channel of a region (within this session) is held out
            # entirely; all other channels are split along time.
            region_electrode_count[true_region] = region_electrode_count.get(true_region ,0) + 1
            if region_electrode_count[true_region] % hold_out_check == 0:
                split_assignments = [(held_out_channels_dict, signal_segments)]
            else:
                # Split segments into validation (first), testing (next) and training (rest) sets
                num_val = int(VAL_FRACTION * len(signal_segments))
                num_test = int(TEST_FRACTION * len(signal_segments))
                val_segments = signal_segments[:num_val]
                test_segments = signal_segments[num_val:num_val+num_test]
                train_segments = signal_segments[num_val+num_test:]
                split_assignments = [
                    (training_segments_dict, train_segments),
                    (validation_segments_dict, val_segments),
                    (testing_segments_dict, test_segments),
                ]

            # Append the new segments to the region entry of each target dictionary.
            # np.concatenate is used instead of np.concat because the latter only exists in NumPy >= 2.0.
            for split_dict, split_segments in split_assignments:
                if true_region in split_dict:
                    split_dict[true_region] = np.concatenate([split_dict[true_region], split_segments])
                else:
                    split_dict[true_region] = split_segments

        # print the number of usable channels per region for this session
        print("held out channels : ", region_electrode_count)


# Shuffle segments within each region key before training
print("Shuffle segments...")
for d in [training_segments_dict, validation_segments_dict, testing_segments_dict, held_out_channels_dict]:
    for k in d:
        np.random.shuffle(d[k])

# save datasets
splits_dir = os.path.join("data","multi_subject_training","data_splits")
os.makedirs(splits_dir, exist_ok=True)
for file_name, segments_dict in [
    ("training_segments.pkl", training_segments_dict),
    ("validation_segments.pkl", validation_segments_dict),
    ("testing_segments.pkl", testing_segments_dict),
    ("heldout_channels.pkl", held_out_channels_dict),
]:
    with open(os.path.join(splits_dir, file_name), "wb") as f:
        pickle.dump(segments_dict, f)


# Generate training and validation pairs for the Pairwise Siamese Contrastive (PSC) method
print("Generating training and validation pairs...")
train_pairs, train_labels = create_balanced_pairs(training_segments_dict, DATASET_SIZE)
val_pairs, val_labels = create_balanced_pairs(validation_segments_dict, min(int(0.2 * DATASET_SIZE),20000))

# Create DataLoaders
train_ds = LFPDataset(train_pairs, train_labels)
val_ds = LFPDataset(val_pairs, val_labels)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
print(f"Train set: {len(train_ds)} pairs")
print(f"Validation set: {len(val_ds)} pairs")
print(f"Batch size: {BATCH_SIZE}")

# Print shapes for sanity check (only the first batch of each loader is inspected)
print("===== Training Input Shape Info =====")
for x1, x2, _ in train_loader:
    print(f"Train input shapes: x1: {x1.shape}, x2: {x2.shape}")
    break
print("===== Validation Input Shape Info =====")
for x1, x2, _ in val_loader:
    print(f"Validation input shapes: x1: {x1.shape}, x2: {x2.shape}\n\n")
    break

# Model setup
print("Training model...")
model = SiameseNetMultiSubject(embedding_dim=EMBEDDING_DIM,dropout=DROPOUT, normalized_output=True).to(device)
if show_model:
    print_model_summary(model, train_loader, device)
    show_model = False

# train model
save_path = os.path.join("results","multi_subject","PSC_model")
os.makedirs(save_path, exist_ok=True)
trained_model = training_validate_loop(model, train_loader, val_loader, device, save_path, EPOCHS, LEARNING_RATE)

# Extract encoder and evaluate
trained_model.eval()
encoder_model = trained_model.encoder
encoder_model.eval()

# Extract embeddings
print("Extracting embeddings...")
train_emb, train_labels_text = embed_segments_dict(training_segments_dict, encoder_model, device)
val_emb, val_labels_text = embed_segments_dict(validation_segments_dict, encoder_model, device)
test_emb, test_labels_text = embed_segments_dict(testing_segments_dict, encoder_model, device)
held_emb, held_labels_text = embed_segments_dict(held_out_channels_dict, encoder_model, device)

# KNN Evaluation
print("Running KNN evaluation...")
for k in training_segments_dict:
    print(f"{k}: {len(training_segments_dict[k])} segments")

# make region balanced reference set for knn
ref_emb, ref_labels = make_balanced_reference(train_emb, train_labels_text, per_class=None, seed=RANDOM_SEED)
results_train = evaluate_knn(ref_emb, ref_labels, train_emb, train_labels_text)
results_val = evaluate_knn(ref_emb, ref_labels, val_emb, val_labels_text)
results_test = evaluate_knn(ref_emb, ref_labels, test_emb, test_labels_text)
results_held = evaluate_knn(ref_emb, ref_labels, held_emb, held_labels_text)

# Chance & clustering
print ("Computing clustering metrics...")
chance_train = compute_chance_levels(train_labels_text)
chance_val = compute_chance_levels(val_labels_text)
chance_test = compute_chance_levels(test_labels_text)    
chance_held = compute_chance_levels(held_labels_text)

metrics_train = evaluate_clustering_metrics(train_emb, train_labels_text)
metrics_val = evaluate_clustering_metrics(val_emb, val_labels_text)
metrics_test = evaluate_clustering_metrics(test_emb, test_labels_text)
metrics_held = evaluate_clustering_metrics(held_emb, held_labels_text)

# Save confusion matrix 
print("Generating and saving confusion matrices...")
plot_and_save_confusion_matrix(val_labels_text, results_val['predictions'], sorted(set(val_labels_text)), "Validation Confusion", os.path.join(save_path, "cm_val.png"))
plot_and_save_confusion_matrix(test_labels_text, results_test['predictions'], sorted(set(test_labels_text)), "Test Confusion", os.path.join(save_path, "cm_test.png"))
# The held-out set contains held-out channels (not held-out subjects), so the title says "channel".
plot_and_save_confusion_matrix(held_labels_text, results_held['predictions'], sorted(set(held_labels_text)), "Held-Out channel Confusion", os.path.join(save_path, "cm_heldout.png"))

# Embedding visualization (side / electrode annotations are left blank; points are labelled by region only)
test_region_list = list(test_labels_text)
test_side_list = ["" for _ in test_labels_text]
test_electrode_list = [""] * len(test_labels_text)
fig,_ = plot_interactive_embeddings(test_emb, test_region_list, test_electrode_list, test_side_list, dim=2, method='pca', metric='euclidean', show_ellipses=True, verbose=False)
pio.write_image(fig, os.path.join(save_path, "embedding_testset_pca_2D.png"), format="png", width=950, height=700)
pio.write_json(fig,  os.path.join(save_path, "embedding_testset_pca_2D.json"), pretty=True)

fig,_ = plot_interactive_embeddings(test_emb, test_region_list, test_electrode_list, test_side_list, dim=3, method='pca', metric='euclidean', show_ellipses=True, verbose=False)
pio.write_image(fig, os.path.join(save_path, "embedding_testset_pca_3D.png"), format="png", width=950, height=700)
pio.write_json(fig,  os.path.join(save_path, "embedding_testset_pca_3D.json"), pretty=True)

# Only draw the held-out embedding when at least one held-out channel was collected.
has_heldout = any(len(v) > 0 for v in held_out_channels_dict.values())
if has_heldout:
    held_out_region_list = list(held_labels_text)
    held_out_side_list = ["" for _ in held_labels_text]
    held_out_electrode_list = [""] * len(held_labels_text)
    fig,_ = plot_interactive_embeddings(held_emb, held_out_region_list, held_out_electrode_list, held_out_side_list, dim=3, method='pca', metric='euclidean', show_ellipses=True, verbose=False)
    pio.write_image(fig, os.path.join(save_path, "embedding_heldout_pca_3D.png"), format="png", width=950, height=700)
    pio.write_json(fig,  os.path.join(save_path, "embedding_heldout_pca_3D.json"), pretty=True)
else:
    print("No held-out channels available; skipping held-out embedding plot.")


# Save results
# NOTE(review): validation results (results_val / metrics_val / chance_val) are computed above but not
# passed here - confirm whether save_evaluation_to_excel is expected to receive them.
save_evaluation_to_excel(
    subject_id="Multi-subject training",
    model_name=VERSION,
    results_train=results_train,
    results_test=results_test,
    results_heldout=results_held,
    clustering_train=metrics_train,
    clustering_test=metrics_test,
    clustering_heldout=metrics_held,
    chance_train=chance_train,
    chance_test=chance_test,
    chance_heldout=chance_held,
    save_path=os.path.join("results","multi_subject" , "evaluation_summary_multi_subject.xlsx")
)


 Using cuda
Loading data from subject: S1 period: p7
Loading microSTN_L_1_CommonFiltered_lfs --> Region: STN, Side: L, Electrode: 1 true region: STN
Segments before cleaning: 144 | After cleaning: 118
Loading microSTN_L_2_CommonFiltered_lfs --> Region: STN, Side: L, Electrode: 2 true region: STN
Segments before cleaning: 144 | After cleaning: 141
Loading microSTN_L_3_CommonFiltered_lfs --> Region: STN, Side: L, Electrode: 3 true region: STN
Segments before cleaning: 144 | After cleaning: 78
Loading microVoaVop_L_1_CommonFiltered_lfs --> Region: VoaVop, Side: L, Electrode: 1 true region: VO
Segments before cleaning: 144 | After cleaning: 78
Loading microVoaVop_L_2_CommonFiltered_lfs --> Region: VoaVop, Side: L, Electrode: 2 true region: VO
Segments before cleaning: 144 | After cleaning: 81
Loading microVoaVop_L_3_CommonFiltered_lfs --> Region: VoaVop, Side: L, Electrode: 3 true region: VO
Segments before cleaning: 144 | After cleaning: 113
Loading microSTN_R_1_CommonFiltered_lfs --> Reg

# Training functional encoder on aggregated dataset including several subjects and sessions Using MSC method
This section builds a single dataset from all recording sessions, trains a functional encoder on it with the MSC method, and evaluates the encoder on held-out time segments (the test set) and on held-out channels.

In [None]:
# MSC pipeline for the aggregated multi-subject dataset:
#   1. load every usable micro-electrode channel of every subject/session,
#   2. filter, normalize and segment it,
#   3. route every `hold_out_check`-th channel of a region (per session) to the held-out set and
#      split the remaining channels into validation / test / train segments,
#   4. train the SupCon encoder on the training segments (pair loader used for validation),
#   5. evaluate the embeddings (KNN, clustering metrics, confusion matrices, PCA plots) and save results.

# Check GPU
device = get_device()
print(f" Using {device}")

# ----------- Config ------------ 
# training parameters
EPOCHS = 200
DATASET_SIZE = 1000000
LEARNING_RATE = 1e-3
BATCH_SIZE = 128
TEMPERATURE = 0.2

# model parameters
EMBEDDING_DIM = 32
PROJ_DIM = 32
VERSION = "v1(MSC)"
DROPOUT = 0.35
RANDOM_SEED=42 # seeds NumPy's global RNG and is passed as seed to make_balanced_reference


# data preparation 
SEGMENT_LENGTH = 10000 # in ms here is 10 seconds
VAL_FRACTION = 0.15 # leading fraction of each channel's segments used for validation
TEST_FRACTION = 0.15 # following fraction of each channel's segments used for testing
MIN_SEGMENTS = 10 # channels that yield fewer segments than this are skipped
hold_out_check = 4 # holds 1 channel out every n channels

# visualization
show_model = True
#--------------------------------

# RANDOM_SEED was previously defined but never applied; seed NumPy so the segment shuffle
# (and any other NumPy-global randomness) is reproducible across runs.
# NOTE(review): torch's RNG is not seeded here, so DataLoader(shuffle=True) ordering still varies.
np.random.seed(RANDOM_SEED)

dataPath = "D:\\transfer\\micro voluntary dataset\\lfs data" # path to data folder (replace with your own data path)
experiment = "Multi session training" 
subject_counter = 0

# region name -> segments, one dictionary per split
training_segments_dict = {}
validation_segments_dict = {}
testing_segments_dict = {}
held_out_channels_dict = {}

for isubject in subject_sessions.keys():
    subject_counter += 1
    for iperiod in subject_sessions[isubject]:
        print("Loading data from subject: S" + str(subject_counter) + " period: " + iperiod)
        # load microelectrode labels which contains the names of all microelectrodes in this session
        micro_electrode_path = dataPath  + os.sep + isubject + os.sep + iperiod + os.sep
        try:
            microelectrodes = load_micro_labels(micro_electrode_path + "microLabels.mat")
        except Exception:
            # Fallback reader for label files the project loader cannot handle.
            # `except Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
            mat_contents = loadmat(micro_electrode_path + "microLabels.mat")
            micro_labels = mat_contents['microLabels'] 
            microelectrodes = [str(label[0]) if isinstance(label, np.ndarray) else str(label) for label in micro_labels.squeeze()]

        # number of usable channels seen so far per true region, reset for every session
        region_electrode_count = {}

        # Parse and print them
        for microelectrode in microelectrodes:
            region, side, electrode = parse_label(microelectrode)

            # get true region of this electrode; contacts without a known region are ignored
            true_region = get_true_region(isubject, region, side, electrode)
            if true_region == 'unknown':
                continue
            print(f"Loading {microelectrode} --> Region: {region}, Side: {side}, Electrode: {electrode} true region: {true_region}")

            # load data
            mat_path = os.path.join(micro_electrode_path, microelectrode + ".mat")
            raw_signal = None 
            try:
                # First try HDF5 (MATLAB v7.3 format)
                with h5py.File(mat_path, "r") as f:
                    raw_signal = np.array(f["data"]).squeeze()
                    fs = int(np.array(f["fs"]).squeeze())
            except FileNotFoundError:
                print(f"⚠️ File not found: {mat_path}. Skipping electrode.")
                continue  # Skip to next electrode
            except (OSError, KeyError):
                # Fall back to standard .mat format (pre-v7.3)
                try:
                    # Use the `loadmat` name imported at the top of the notebook: `scipy` itself is
                    # never imported explicitly, and a NameError here would be swallowed below and
                    # silently skip every pre-v7.3 file.
                    mat = loadmat(mat_path)
                    raw_signal = np.array(mat["data"]).squeeze()
                    fs = int(np.array(mat["fs"]).squeeze())
                except Exception as e:
                    print(f"⚠️ Failed to load (not a valid .mat file): {mat_path}. Error: {e}")
                    continue  # Skip if even fallback fails

            # Remove big 60Hz artifacts
            filtered_signal = notch_filter(raw_signal,fs)

            # Normalize session recording
            normalized_signal = normalize_signal(filtered_signal)

            # Apply notch filters at 60Hz and its harmonics
            normalized_signal = notch_filter(normalized_signal,fs,freq=60)
            normalized_signal = notch_filter(normalized_signal,fs,freq=120)
            normalized_signal = notch_filter(normalized_signal,fs,freq=180)

            # Segment data into windows
            signal_segments = segment_data(normalized_signal, SEGMENT_LENGTH)

            # Skip sessions with too few segments
            if len(signal_segments) < MIN_SEGMENTS:
                print("session too short")
                continue

            # Every `hold_out_check`-th usable channel of a region (within this session) is held out
            # entirely; all other channels are split along time.
            region_electrode_count[true_region] = region_electrode_count.get(true_region ,0) + 1
            if region_electrode_count[true_region] % hold_out_check == 0:
                split_assignments = [(held_out_channels_dict, signal_segments)]
            else:
                # Split segments into validation (first), testing (next) and training (rest) sets
                num_val = int(VAL_FRACTION * len(signal_segments))
                num_test = int(TEST_FRACTION * len(signal_segments))
                val_segments = signal_segments[:num_val]
                test_segments = signal_segments[num_val:num_val+num_test]
                train_segments = signal_segments[num_val+num_test:]
                split_assignments = [
                    (training_segments_dict, train_segments),
                    (validation_segments_dict, val_segments),
                    (testing_segments_dict, test_segments),
                ]

            # Append the new segments to the region entry of each target dictionary.
            # np.concatenate is used instead of np.concat because the latter only exists in NumPy >= 2.0.
            for split_dict, split_segments in split_assignments:
                if true_region in split_dict:
                    split_dict[true_region] = np.concatenate([split_dict[true_region], split_segments])
                else:
                    split_dict[true_region] = split_segments

        # print the number of usable channels per region for this session
        print("held out channels : ", region_electrode_count)


# Shuffle segments within each region key before training
print("Shuffle segments...")
for d in [training_segments_dict, validation_segments_dict, testing_segments_dict, held_out_channels_dict]:
    for k in d:
        np.random.shuffle(d[k])

# save datasets
# NOTE(review): this writes to the same files as the PSC cell, so the earlier splits are overwritten.
splits_dir = os.path.join("data","multi_subject_training","data_splits")
os.makedirs(splits_dir, exist_ok=True)
for file_name, segments_dict in [
    ("training_segments.pkl", training_segments_dict),
    ("validation_segments.pkl", validation_segments_dict),
    ("testing_segments.pkl", testing_segments_dict),
    ("heldout_channels.pkl", held_out_channels_dict),
]:
    with open(os.path.join(splits_dir, file_name), "wb") as f:
        pickle.dump(segments_dict, f)


# Build the supervised-contrastive datasets for the MSC method (z-score transform applied per item)
print("Generating training and validation datasets...")
train_supcon_ds = SupConDataset(training_segments_dict, transform=zscore_transform)
val_supcon_ds   = SupConDataset(validation_segments_dict, transform=zscore_transform)

train_supcon_loader = DataLoader(train_supcon_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# val_supcon_loader is created here but the training loop below is given val_pair_loader instead.
val_supcon_loader   = DataLoader(val_supcon_ds,   batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
print(f"Train set: {len(train_supcon_ds)} pairs")
print(f"Validation set: {len(val_supcon_ds)} pairs")
print(f"Batch size: {BATCH_SIZE}")

# --- keep  pair dataset for validation after each epoch to compare with PSC ---
val_pairs, val_labels = create_balanced_pairs(validation_segments_dict, max_dataset_size=min(20000, len(val_supcon_ds)*2))
val_pair_ds = LFPDataset(val_pairs, val_labels)
val_pair_loader = DataLoader(val_pair_ds, batch_size=BATCH_SIZE)


# Model setup
print("Training model...")
model = SupConNetMultiSubject(embedding_dim=EMBEDDING_DIM, proj_dim=PROJ_DIM, dropout=DROPOUT).to(device)
if show_model:
    print_model_summary(model, train_supcon_loader, device)
    show_model = False

save_path = os.path.join("results","multi_subject","MSC_model")
os.makedirs(save_path, exist_ok=True)
trained_model = training_validate_loop_supcon(
    model, train_supcon_loader, val_pair_loader, device, save_path,
    epochs=EPOCHS, lr=LEARNING_RATE,
    temperature=TEMPERATURE,
    add_nn=True,  # NOTE(review): the previous comment said "disable NN positives", but True is passed - confirm intent
)


# Extract encoder and evaluate
trained_model.eval()
encoder_model = trained_model.encoder
encoder_model.eval()

# Extract embeddings
print("Extracting embeddings...")
train_emb, train_labels_text = embed_segments_dict(training_segments_dict, encoder_model, device)
val_emb, val_labels_text = embed_segments_dict(validation_segments_dict, encoder_model, device)
test_emb, test_labels_text = embed_segments_dict(testing_segments_dict, encoder_model, device)
held_emb, held_labels_text = embed_segments_dict(held_out_channels_dict, encoder_model, device)

# KNN Evaluation
print("Running KNN evaluation...")
for k in training_segments_dict:
    print(f"{k}: {len(training_segments_dict[k])} segments")

# make region balanced reference set for knn
ref_emb, ref_labels = make_balanced_reference(train_emb, train_labels_text, per_class=None, seed=RANDOM_SEED)
results_train = evaluate_knn(ref_emb, ref_labels, train_emb, train_labels_text)
results_val = evaluate_knn(ref_emb, ref_labels, val_emb, val_labels_text)
results_test = evaluate_knn(ref_emb, ref_labels, test_emb, test_labels_text)
results_held = evaluate_knn(ref_emb, ref_labels, held_emb, held_labels_text)

# Chance & clustering
print ("Computing clustering metrics...")
chance_train = compute_chance_levels(train_labels_text)
chance_val = compute_chance_levels(val_labels_text)
chance_test = compute_chance_levels(test_labels_text)    
chance_held = compute_chance_levels(held_labels_text)

metrics_train = evaluate_clustering_metrics(train_emb, train_labels_text)
metrics_val = evaluate_clustering_metrics(val_emb, val_labels_text)
metrics_test = evaluate_clustering_metrics(test_emb, test_labels_text)
metrics_held = evaluate_clustering_metrics(held_emb, held_labels_text)

# Save confusion matrix 
print("Generating and saving confusion matrices...")
plot_and_save_confusion_matrix(val_labels_text, results_val['predictions'], sorted(set(val_labels_text)), "Validation Confusion", os.path.join(save_path, "cm_val.png"))
plot_and_save_confusion_matrix(test_labels_text, results_test['predictions'], sorted(set(test_labels_text)), "Test Confusion", os.path.join(save_path, "cm_test.png"))
plot_and_save_confusion_matrix(held_labels_text, results_held['predictions'], sorted(set(held_labels_text)), "Held-Out channel Confusion", os.path.join(save_path, "cm_heldout.png"))

# Embedding visualization (side / electrode annotations are left blank; points are labelled by region only)
test_region_list = list(test_labels_text)
test_side_list = ["" for _ in test_labels_text]
test_electrode_list = [""] * len(test_labels_text)
fig,_ = plot_interactive_embeddings(test_emb, test_region_list, test_electrode_list, test_side_list, dim=2, method='pca', metric='euclidean', show_ellipses=True, verbose=False)
pio.write_image(fig, os.path.join(save_path, "embedding_testset_pca_2D.png"), format="png", width=950, height=700)
pio.write_json(fig,  os.path.join(save_path, "embedding_testset_pca_2D.json"), pretty=True)

fig,_ = plot_interactive_embeddings(test_emb, test_region_list, test_electrode_list, test_side_list, dim=3, method='pca', metric='euclidean', show_ellipses=True, verbose=False)
pio.write_image(fig, os.path.join(save_path, "embedding_testset_pca_3D.png"), format="png", width=950, height=700)
pio.write_json(fig,  os.path.join(save_path, "embedding_testset_pca_3D.json"), pretty=True)

# Only draw the held-out embedding when at least one held-out channel was collected.
has_heldout = any(len(v) > 0 for v in held_out_channels_dict.values())
if has_heldout:
    held_out_region_list = list(held_labels_text)
    held_out_side_list = ["" for _ in held_labels_text]
    held_out_electrode_list = [""] * len(held_labels_text)
    fig,_ = plot_interactive_embeddings(held_emb, held_out_region_list, held_out_electrode_list, held_out_side_list, dim=3, method='pca', metric='euclidean', show_ellipses=True, verbose=False)
    pio.write_image(fig, os.path.join(save_path, "embedding_heldout_pca_3D.png"), format="png", width=950, height=700)
    pio.write_json(fig,  os.path.join(save_path, "embedding_heldout_pca_3D.json"), pretty=True)
else:
    print("No held-out channels available; skipping held-out embedding plot.")


# Save results
# NOTE(review): validation results (results_val / metrics_val / chance_val) are computed above but not
# passed here - confirm whether save_evaluation_to_excel is expected to receive them.
save_evaluation_to_excel(
    subject_id="Multi-subject training",
    model_name=VERSION,
    results_train=results_train,
    results_test=results_test,
    results_heldout=results_held,
    clustering_train=metrics_train,
    clustering_test=metrics_test,
    clustering_heldout=metrics_held,
    chance_train=chance_train,
    chance_test=chance_test,
    chance_heldout=chance_held,
    save_path=os.path.join("results","multi_subject" , "evaluation_summary_multi_subject.xlsx")
)


 Using cuda
Loading data from subject: S1 period: p7
Loading microSTN_L_1_CommonFiltered_lfs --> Region: STN, Side: L, Electrode: 1 true region: STN
Segments before cleaning: 144 | After cleaning: 118
Loading microSTN_L_2_CommonFiltered_lfs --> Region: STN, Side: L, Electrode: 2 true region: STN
Segments before cleaning: 144 | After cleaning: 141
Loading microSTN_L_3_CommonFiltered_lfs --> Region: STN, Side: L, Electrode: 3 true region: STN
Segments before cleaning: 144 | After cleaning: 78
Loading microVoaVop_L_1_CommonFiltered_lfs --> Region: VoaVop, Side: L, Electrode: 1 true region: VO
Segments before cleaning: 144 | After cleaning: 78
Loading microVoaVop_L_2_CommonFiltered_lfs --> Region: VoaVop, Side: L, Electrode: 2 true region: VO
Segments before cleaning: 144 | After cleaning: 81
Loading microVoaVop_L_3_CommonFiltered_lfs --> Region: VoaVop, Side: L, Electrode: 3 true region: VO
Segments before cleaning: 144 | After cleaning: 113
Loading microSTN_R_1_CommonFiltered_lfs --> Reg