In [None]:

import os
os.environ['KERAS_BACKEND'] = 'torch'

import matplotlib.pyplot as plt
from IPython.display import clear_output
from typing import List
import torch
import keras
import xarray as xr
from src.EEGModalNet import WGAN_GP_old
from scipy.signal import butter, sosfiltfilt
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
from sklearn.model_selection import ShuffleSplit
from sklearn.utils.class_weight import compute_class_weight
from keras import regularizers, layers

In [None]:
# PARAMS
# Input recordings (NetCDF via h5netcdf) and the trained WGAN-GP checkpoint
# whose critic layers are used below to embed the EEG windows.
DATA_PATH = 'data/LEMON_DATA/EC_all_channels_processed_downsampled.nc5'
CHECKPOINT_PATH = 'logs/06022025/06.02.2025_epoch_3000.model.keras'

# EEG channels to load; load_data selects them in this order.
CHANNELS = [
    'O1', 'O2',
    'F1', 'F2',
    'C1', 'C2',
    'P1', 'P2',
]

# Experiment switches:
#   DO_GROUPED_SHUFFLE -> split train/val by subject (no subject in both sets)
#   DO_BALANCE         -> undersample windows so both classes are equally sized
DO_GROUPED_SHUFFLE, DO_BALANCE = True, True

In [None]:

def load_data(data_path: str,
              channels: List[str],
              n_subjects: int = 202,
              bandpass_filter: float = 1.0,
              time_dim: int = 1024,
              exclude_sub_ids=None,
              demographics_path: str = 'data/LEMON_DATA/Demographics.csv',
              sampling_rate: float = 128) -> tuple:
    """Load EEG recordings, cut them into fixed-length windows and attach age labels.

    Args:
        data_path: NetCDF file (read with the h5netcdf engine) holding a DataArray
            with ``subject`` and ``channel`` dimensions; time is assumed to be
            the last axis (TODO confirm).
        channels: Channel names to keep, in the order they should appear.
        n_subjects: Keep only the first ``n_subjects`` subjects of the file.
        bandpass_filter: Cut-off (Hz) of a 4th-order Butterworth *high-pass*
            filter applied forward-backward; ``None`` disables filtering.
        time_dim: Window length in samples. Trailing samples that do not fill a
            whole window are dropped by ``torch.Tensor.unfold``.
        exclude_sub_ids: Optional subject IDs removed after the ``n_subjects`` cut.
        demographics_path: CSV indexed by ``ID`` with an ``Age`` column formatted
            like ``"<lower>-<upper>"``.
        sampling_rate: Sampling rate (Hz) used to design the high-pass filter.

    Returns:
        ``(x, y, sub_ids)``:
          * ``x``  -- tensor ``(n_kept_subjects * n_windows, time_dim, n_channels)``,
            windows ordered subject by subject;
          * ``y``  -- per-window label, 1 if the subject's lower age bound is
            >= 55, else 0;
          * ``sub_ids`` -- per-window subject index in ``0..n_kept_subjects-1``.

    Raises:
        ValueError: if no subject remains after selection/exclusion.
    """
    # Context manager closes the HDF5 handle once the selection is in memory.
    with xr.open_dataarray(data_path, engine='h5netcdf') as xarray:
        demog = pd.read_csv(demographics_path, index_col="ID")
        # 'Age' is a range string; lower bound >= 55 -> 2 ("old"), otherwise 1.
        demog['is_old'] = demog['Age'].apply(
            lambda age: 2 if int(age.split('-')[0]) >= 55 else 1)
        is_old = demog.loc[xarray["subject"].values, "is_old"]
        # Attach the label as a per-subject coordinate so it follows every
        # subsequent selection along the subject dimension.
        xarray["is_old"] = xr.DataArray(is_old, dims='subject')

        x = xarray.sel(subject=xarray.subject[:n_subjects], channel=channels)

        if exclude_sub_ids is not None:
            x = x.sel(subject=~x.subject.isin(exclude_sub_ids))

        # Take labels from the *selected* subjects (not the whole file) so they
        # stay aligned with the data after the n_subjects cut and exclusions.
        labels = x["is_old"].values - 1
        x = x.to_numpy()

    n_subjects = x.shape[0]
    if n_subjects == 0:
        raise ValueError('No subjects left after applying n_subjects/exclude_sub_ids.')

    if bandpass_filter is not None:
        sos = butter(4, bandpass_filter, btype='high', fs=sampling_rate, output='sos')
        x = sosfiltfilt(sos, x, axis=-1)

    # Assumes x is (subject, channel, time) -- TODO confirm against the file.
    # unfold  -> (subject, channel, window, time_dim)
    # permute -> (subject, window, time_dim, channel)
    # flatten -> (subject * window, time_dim, channel), subject-major.
    # .copy(): torch.tensor failed without it; presumably sosfiltfilt returns a
    # negatively-strided view that torch rejects -- TODO confirm.
    x = torch.tensor(x.copy()).unfold(2, time_dim, time_dim).permute(0, 2, 3, 1).flatten(0, 1)

    # Every subject contributes the same number of windows.
    n_windows = x.shape[0] // n_subjects
    sub_ids_classifier = np.arange(0, n_subjects).repeat(n_windows)
    y = labels.repeat(n_windows)
    assert len(y) == len(sub_ids_classifier) == x.shape[0]

    return x, y, sub_ids_classifier


In [None]:
# LOAD DATA
# 512-sample windows with a 0.5 Hz high-pass; time_dim must match the
# time_dim the WGAN-GP model below is constructed with.
X_input, y, groups = load_data(
    DATA_PATH,
    channels=CHANNELS,
    n_subjects=202,
    bandpass_filter=0.5,
    time_dim=512,
    exclude_sub_ids=None,
)

# 'balanced' class weights computed on the full (unbalanced, unsplit) labels.
# NOTE(review): Keras documents `class_weight` as mapping *integer* class
# indices to weights; the string keys '0'/'1' below are likely never matched,
# and the weights predate balancing/splitting -- check before relying on them.
class_weights = compute_class_weight('balanced', classes=np.unique(y), y=y)
class_weights = dict([('0', class_weights[0]), ('1', class_weights[1])])
class_weights


In [None]:
# LOAD MODEL AND EXTRACT X_EMBEDDING (X_e and X_e_subj)

keras.utils.clear_session(free_memory=True)

# NOTE(review): these hyper-parameters must match the ones used to train
# CHECKPOINT_PATH, otherwise load_weights will not line up.
model = WGAN_GP_old(time_dim=512, feature_dim=len(CHANNELS),
                    latent_dim=128, n_subjects=202,
                    use_sublayer_generator=True,
                    use_sublayer_critic=True,
                    use_channel_merger_g=False,
                    use_channel_merger_c=False,
                    interpolation='bilinear')

model.load_weights(CHECKPOINT_PATH)

# Choose the accelerator instead of hard-coding 'mps' so the cell also runs on
# CUDA/CPU machines. Preference order (MPS, CUDA, CPU) is intended to match the
# device Keras' torch backend placed the weights on -- verify on new hardware.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Pure inference: no autograd graph is needed over the whole dataset.
with torch.no_grad():
    # 'dis_flatten' applied directly to the raw windows (no preceding critic
    # layers); X_e is not used by the cells below.
    X_e = model.critic.model.get_layer('dis_flatten')(X_input)

    # NOTE(review): 'torch_module_wrapper_1' is an auto-generated Keras layer
    # name and may change if the model definition changes. It is called with
    # the windows and their per-window subject ids.
    X_input_subj = model.critic.get_layer('torch_module_wrapper_1')(
        X_input.float().to(device),
        torch.tensor(groups).to(device))
    X_e_subj = model.critic.model.get_layer('dis_flatten')(X_input_subj)

In [None]:
# BALANCE DATASET and SPLIT
# Optionally undersample windows so both classes have the minority-class count,
# then hold out 30% of the data for validation (by subject when grouped).

if not DO_BALANCE:
    groups_resampled, y_resampled = groups, y
    X_e_resampled = X_e_subj.detach().cpu()
else:
    y_df = pd.DataFrame(y, columns=['y'])
    per_class = {label: y_df.query(f'y=={label}') for label in (0, 1)}
    minority_size = min(len(frame) for frame in per_class.values())

    # Each class is sampled with its own fixed seed; class-0 rows come first.
    resample_idx = np.concatenate([
        per_class[label].sample(minority_size, random_state=1).index.values
        for label in (0, 1)
    ])

    groups_resampled = groups[resample_idx]
    X_e_resampled = X_e_subj[resample_idx].detach().cpu()
    y_resampled = y[resample_idx]


# train/test split: grouped by subject, or a plain random window split.
if DO_GROUPED_SHUFFLE:
    splitter = GroupShuffleSplit(n_splits=1, test_size=0.3, random_state=2)
    split_kwargs = {'groups': groups_resampled}
else:
    splitter = ShuffleSplit(n_splits=1, test_size=0.3, random_state=1)
    split_kwargs = {}

train_idx, val_idx = next(splitter.split(X_e_resampled, y_resampled, **split_kwargs))

# y.mean(), y_resampled[train_idx].mean(), y_resampled[val_idx].mean()

In [None]:
cls_model_path = 'logs/20.02.2025_classifier_v2'

# Keep the best-val_accuracy checkpoint, log every epoch to CSV and stop
# training as soon as the loss turns NaN.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    f'{cls_model_path}.model.keras', monitor='val_accuracy', save_best_only=True)
csv_logger_cb = keras.callbacks.CSVLogger(f'{cls_model_path}.csv')
nan_guard_cb = keras.callbacks.TerminateOnNaN()
callbacks = [checkpoint_cb, csv_logger_cb, nan_guard_cb]

# MLP head on the critic embeddings: two L2-regularised ReLU layers (128, 32),
# each followed by 20% dropout, then one sigmoid unit for the binary label.
cls_model = keras.models.Sequential()
for units in (128, 32):
    cls_model.add(layers.Dense(units, activation='relu',
                               kernel_regularizer=regularizers.l2(0.01)))
    cls_model.add(layers.Dropout(0.2))
cls_model.add(layers.Dense(1, activation='sigmoid'))

def smoothed_binary_crossentropy(epsilon=0.1):
    """Build a binary cross-entropy loss with label smoothing.

    Targets are pulled toward 0.5 before the loss is computed:
    ``y -> y * (1 - epsilon) + 0.5 * epsilon`` (1 becomes 1 - epsilon/2,
    0 becomes epsilon/2).

    Args:
        epsilon: Smoothing strength; 0 disables smoothing.

    Returns:
        A ``loss(y_true, y_pred)`` callable usable as a Keras loss.
    """
    keep = 1.0 - epsilon
    shift = 0.5 * epsilon

    def loss(y_true, y_pred):
        return keras.losses.binary_crossentropy(y_true * keep + shift, y_pred)

    return loss


# Class weights come from the labels actually trained on (after optional
# balancing and the train/val split). Keras' `class_weight` maps *integer*
# class indices to weights, so the keys are cast to int.
y_train = y_resampled[train_idx]
train_classes = np.unique(y_train)
train_class_weights = {
    int(label): weight
    for label, weight in zip(
        train_classes,
        compute_class_weight('balanced', classes=train_classes, y=y_train))
}

cls_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss='binary_crossentropy',
    # loss=keras.losses.Hinge(),
    # loss=smoothed_binary_crossentropy(),
    # loss='mse',
    # loss=keras.losses.BinaryFocalCrossentropy(),
    metrics=['accuracy'])

history = cls_model.fit(
    X_e_resampled[train_idx], y_train,
    epochs=1000,
    batch_size=20000,
    validation_data=(X_e_resampled[val_idx], y_resampled[val_idx]),
    callbacks=callbacks,
    class_weight=train_class_weights,
    shuffle=True)

# Persist the per-epoch metrics and the final-epoch model.
history_df = pd.DataFrame(history.history)
history_df.to_csv(f'{cls_model_path}_final.csv')
cls_model.save(f'{cls_model_path}_final.model.keras')

clear_output(wait=True)
print('Training finished!')

In [None]:
# PLOTTING LOSS AND ACCURACY

def plot_history(history_df):
    """Show train/validation loss and accuracy curves side by side.

    Args:
        history_df: DataFrame with 'loss', 'val_loss', 'accuracy' and
            'val_accuracy' columns, one row per epoch.
    """
    _, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    panels = (
        (loss_ax, 'loss', 'val_loss'),
        (acc_ax, 'accuracy', 'val_accuracy'),
    )
    for ax, train_col, val_col in panels:
        ax.plot(history_df[train_col], label='train')
        ax.plot(history_df[val_col], label='val')
        ax.legend()
        ax.set_xlabel('epoch')
        ax.set_title(train_col)

    plt.show()


plot_history(history_df)