In [None]:
import gc
import os

import mne
import numpy as np
import tensorflow as tf
from sklearn.model_selection import LeaveOneOut
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import (
    Input, Conv1D, MaxPooling1D, TimeDistributed,
    LSTM, Dense, Dropout, BatchNormalization, Reshape, Permute
)
from tensorflow.keras.models import Model


In [None]:
# ---------------------------------------------------------------------------
# CONFIGURATION
# ---------------------------------------------------------------------------
# Data location: the loader reads <DATA_ROOT>/S001/S001R04.edf, <DATA_ROOT>/S002/S002R04.edf, ...
DATA_ROOT = r'/content/drive/MyDrive/EEG'
N_SUBJECTS = 109
RUN_ID = '04'

# Signal geometry
N_CHANNELS = 64
SFREQ = 160.0              # nominal source sampling rate (Hz); not referenced by the visible code
DOWNSAMPLE_FREQ = 80.0     # target rate passed to raw.resample (Hz)
EPOCH_DURATION_SEC = 4.0
N_TIMESTEPS = int(DOWNSAMPLE_FREQ * EPOCH_DURATION_SEC)  # 80 Hz * 4 s = 320 samples per epoch

# Labels and regularisation
N_CLASSES = 2
EVENT_ID = dict(T1=1, T2=2)  # annotation name -> event code (labels become code - 1)
CHANNEL_DROPOUT_RATE = 0.2   # probability of zeroing an entire input channel


In [None]:
def create_cnn_rnn_model(input_shape=(N_CHANNELS, N_TIMESTEPS, 1)):
    """Build the per-channel CNN -> LSTM classifier for a single EEG epoch.

    Args:
        input_shape: ``(n_channels, n_timesteps, 1)``, e.g. ``X.shape[1:]`` of the
            array built by ``convert_epochs_to_numpy``. Channel and time sizes are
            read from this argument rather than from the config constants, so other
            montages / epoch lengths work as well.

    Returns:
        An uncompiled ``tf.keras.Model`` producing ``N_CLASSES`` softmax scores.
    """
    n_channels, n_timesteps = int(input_shape[0]), int(input_shape[1])
    n_filters = 16  # feature maps per channel from the temporal convolution

    input_layer = Input(shape=input_shape, name='eeg_input')

    # Channel dropout: the mask is broadcast over time and features, so during
    # training whole channels are zeroed together.
    x = Dropout(CHANNEL_DROPOUT_RATE, noise_shape=(None, n_channels, 1, 1),
                name='channel_dropout')(input_layer)

    # TimeDistributed iterates over axis 1 (channels): each channel's series is
    # convolved along time with shared weights, then max-pooled by 2 in time.
    # Shape: (C, T, 1) -> (C, T, F) -> (C, T // 2, F).
    x = TimeDistributed(Conv1D(n_filters, 3, activation='relu', padding='same'),
                        name='td_conv_spatial')(x)
    x = TimeDistributed(BatchNormalization(), name='td_bn_1')(x)
    x = TimeDistributed(MaxPooling1D(2), name='td_maxpool_temporal')(x)

    # Put time first before flattening: (C, T // 2, F) -> (T // 2, C, F).
    # Reshaping the channel-major tensor straight to (T // 2, C * F) would mix
    # channel and time samples, so LSTM steps would not be time steps.
    x = Permute((2, 1, 3), name='time_major')(x)
    x = Reshape((n_timesteps // 2, n_channels * n_filters), name='reshape_for_rnn')(x)

    # Two stacked LSTMs; only the final hidden state of lstm_2 is kept.
    x = LSTM(32, return_sequences=True, name='lstm_1')(x)
    x = Dropout(0.5, name='dropout_1')(x)
    x = LSTM(16, return_sequences=False, name='lstm_2')(x)

    x = Dense(16, activation='relu', name='dense_1')(x)
    output_layer = Dense(N_CLASSES, activation='softmax', name='output')(x)

    return Model(inputs=input_layer, outputs=output_layer)


In [None]:
def load_and_preprocess_data():
    """Load, band-pass filter, resample and epoch run ``RUN_ID`` for every subject.

    For subjects 1..N_SUBJECTS it reads ``<DATA_ROOT>/Sxxx/SxxxR<RUN_ID>.edf``,
    applies an 8-30 Hz FIR band-pass, resamples to ``DOWNSAMPLE_FREQ`` and cuts
    epochs from 0 to ``EPOCH_DURATION_SEC`` seconds after each ``EVENT_ID``
    annotation.

    Subjects are skipped (with a printed message) when the file is missing,
    the channel count differs from ``N_CHANNELS``, no epochs are produced, or
    any other exception is raised. Errors are reported, not fatal.

    Returns:
        list of ``mne.Epochs``, one per successfully loaded subject, in subject
        order. Skipped subjects leave no placeholder, so a list position is NOT
        the original subject number.
    """
    all_epochs = []
    print("Loading and preprocessing data...")

    for sub_idx in range(1, N_SUBJECTS + 1):
        sub_str = f'S{sub_idx:03d}'
        file_name = f'{sub_str}R{RUN_ID}.edf'
        file_path = os.path.join(DATA_ROOT, sub_str, file_name)

        try:
            raw = mne.io.read_raw_edf(file_path, preload=True, verbose=False)

            # All subjects are concatenated into one array later and the model
            # input is sized for N_CHANNELS, so a different montage must be
            # rejected here rather than crash the downstream concatenation.
            n_ch = len(raw.ch_names)
            if n_ch != N_CHANNELS:
                print(f"Subject {sub_str} has {n_ch} channels (expected {N_CHANNELS}), skipping.")
                del raw
                continue

            # Band-pass 8-30 Hz; segments split by 'edge' annotations are filtered separately.
            raw.filter(8., 30., fir_design='firwin', skip_by_annotation='edge', verbose=False)

            # Downsample (annotation onsets are in seconds, so events below stay aligned).
            raw.resample(DOWNSAMPLE_FREQ, npad="auto", verbose=False)

            # Epoching around the T1/T2 annotations.
            events, _ = mne.events_from_annotations(raw, event_id=EVENT_ID, verbose=False)
            epochs = mne.Epochs(
                raw, events, EVENT_ID,
                tmin=0., tmax=EPOCH_DURATION_SEC,
                baseline=None, preload=True, verbose=False
            )

            # Epochs are preloaded, so the continuous recording can be freed now.
            del raw
            gc.collect()

            if len(epochs) == 0:
                print(f"No epochs for subject {sub_idx}, skipping.")
                continue

            all_epochs.append(epochs)

        except FileNotFoundError:
            print(f"Missing: {file_path}")
        except Exception as e:
            # Best effort: report and move on to the next subject.
            print(f"Error in subject {sub_str}: {e}")

    print(f"\nLoaded {len(all_epochs)} subjects.")
    return all_epochs


In [None]:
def convert_epochs_to_numpy(all_epochs, n_timesteps=None):
    """Stack per-subject epochs into model-ready arrays.

    Args:
        all_epochs: iterable of ``mne.Epochs`` (one per subject), e.g. the output
            of ``load_and_preprocess_data``.
        n_timesteps: samples kept per epoch; defaults to ``N_TIMESTEPS``. Longer
            epochs are truncated, shorter ones zero-padded at the end.

    Returns:
        (X, Y, S):
          X -- float32 array ``(n_epochs, n_channels, n_timesteps, 1)`` built from
               ``get_data(units='uV')`` and standardised with one global scalar
               mean/std.
          Y -- labels ``event_code - 1`` (EVENT_ID codes 1/2 become 0/1).
          S -- subject index per epoch: the 1-based position in ``all_epochs``,
               not the original subject number.

    Raises:
        ValueError: if ``all_epochs`` is empty.
    """
    if n_timesteps is None:
        n_timesteps = N_TIMESTEPS

    all_epochs = list(all_epochs)
    if not all_epochs:
        raise ValueError("all_epochs is empty: nothing to convert")

    X_list, Y_list, S_list = [], [], []

    for subject_id, epochs in enumerate(all_epochs, start=1):
        # Truncate to n_timesteps, then zero-pad at the end if still too short.
        data = epochs.get_data(units='uV')[:, :, :n_timesteps]
        missing = n_timesteps - data.shape[2]
        if missing > 0:
            data = np.pad(data, ((0, 0), (0, 0), (0, missing)), mode='constant')

        X_list.append(np.expand_dims(data, axis=-1))           # trailing feature axis
        Y_list.append(np.asarray(epochs.events[:, 2]) - 1)     # event codes 1/2 -> labels 0/1
        S_list.append(np.full(data.shape[0], subject_id))      # one id per epoch row

    X = np.concatenate(X_list, axis=0)
    Y = np.concatenate(Y_list, axis=0)
    S = np.concatenate(S_list, axis=0)

    # Global standardisation (single scalar mean/std). A zero std (constant
    # signal) would otherwise fill X with NaN/inf.
    mean, std = X.mean(), X.std()
    scale = std if std > 0 else 1.0
    # The Keras Input layer defaults to float32; casting here halves memory.
    X = ((X - mean) / scale).astype(np.float32)

    print(f"Shapes → X={X.shape}, Y={Y.shape}, S={S.shape}")
    return X, Y, S


In [None]:
def train_loso(X, Y, S, max_epochs=10, batch_size=8, patience=3, seed=None):
    """Leave-one-subject-out (LOSO) cross-validation of the CNN-LSTM model.

    For every distinct id in ``S`` a fresh model is trained on all other
    subjects and evaluated on the held-out one. The mean fold accuracy is
    printed at the end.

    Args:
        X: epoch array whose first axis indexes epochs; ``X.shape[1:]`` is the
            model input shape.
        Y: integer class labels, aligned with ``X`` along axis 0.
        S: numpy array of subject ids, aligned with ``X`` along axis 0.
        max_epochs: maximum training epochs per fold (EarlyStopping may stop sooner).
        batch_size: mini-batch size for ``model.fit``.
        patience: EarlyStopping patience (in epochs) on ``val_loss``.
        seed: if not None, ``tf.keras.utils.set_random_seed(seed + fold)`` is
            called before building each fold's model so runs are repeatable.

    Returns:
        list[float]: test accuracy of each fold, in ``np.unique(S)`` order.

    Raises:
        ValueError: if fewer than two subjects are present (no training data).
    """
    unique_subjects = np.unique(S)
    n_folds = len(unique_subjects)
    if n_folds < 2:
        raise ValueError(f"LOSO needs at least 2 subjects, got {n_folds}")

    accuracies = []

    for fold, test_subject in enumerate(unique_subjects):
        test_mask = (S == test_subject)
        train_mask = ~test_mask

        X_train, Y_train = X[train_mask], Y[train_mask]
        X_test, Y_test = X[test_mask], Y[test_mask]

        # Standardise with training-fold statistics only, so the held-out
        # subject never influences the scaling. Any positive affine scaling
        # applied to X beforehand is cancelled out by this step.
        train_mean = X_train.mean()
        train_std = X_train.std()
        if train_std == 0:
            train_std = 1.0
        X_train = (X_train - train_mean) / train_std
        X_test = (X_test - train_mean) / train_std

        print(f"\n--- Fold {fold+1}/{len(unique_subjects)} | Test Subject {test_subject} ({len(X_test)} samples) ---")

        if seed is not None:
            tf.keras.utils.set_random_seed(seed + fold)

        model = create_cnn_rnn_model(X.shape[1:])
        model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

        early = EarlyStopping(monitor='val_loss', patience=patience, restore_best_weights=True)

        # Keras takes validation_split from the LAST samples of X_train (before
        # shuffling); if X is ordered by subject, validation = the last subjects.
        model.fit(
            X_train, Y_train,
            epochs=max_epochs,
            batch_size=batch_size,
            validation_split=0.1,
            callbacks=[early],
            verbose=2
        )

        _, acc = model.evaluate(X_test, Y_test, verbose=0)
        accuracies.append(float(acc))
        print(f"Fold Accuracy: {acc:.4f}")

        # Release this fold's model/graph before building the next one.
        del model
        K.clear_session()
        gc.collect()

    print("\n========================================")
    print(f"FINAL MEAN ACCURACY: {np.mean(accuracies):.4f}")
    print("========================================")
    return accuracies


In [None]:
# Run the whole pipeline: load every subject's run, build the arrays, then evaluate with LOSO.
all_epochs = load_and_preprocess_data()

if not all_epochs:
    print("No data found.")
else:
    X, Y, S = convert_epochs_to_numpy(all_epochs)
    train_loso(X, Y, S)
