# EEG Subject Identification with Channel-Split CNN-RNN

This notebook trains a CNN-RNN model to identify the subject from EEG epochs.  
It uses **50 channels for training** and evaluates on **14 unseen channels** to test channel generalization.


In [None]:
import os
import mne
import numpy as np
import gc
import random
import tensorflow as tf
from sklearn.model_selection import LeaveOneOut
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv1D, MaxPooling1D, TimeDistributed,
    LSTM, Dense, Dropout, BatchNormalization, Reshape
)
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras import backend as K

# --- CONFIGURATION ---
# Root folder; each subject's files are read from DATA_ROOT/S###/S###R<RUN_ID>.edf.
DATA_ROOT = r'/content/drive/MyDrive/EEG'
N_SUBJECTS = 109
RUN_ID = '04'
N_CHANNELS = 64
# NOTE(review): SFREQ is not referenced by the code visible in this notebook;
# presumably the nominal recording rate before resampling - confirm.
SFREQ = 160.0
DOWNSAMPLE_FREQ = 80.0
EPOCH_DURATION_SEC = 4.0
N_TIMESTEPS = int(EPOCH_DURATION_SEC * DOWNSAMPLE_FREQ) # 320

# --- CLASSIFICATION GOAL ---
N_CLASSES = N_SUBJECTS  # Subject ID classification
N_CHANNELS_TRAIN = 50
# Derived so that the train and test channel sets always partition N_CHANNELS.
N_CHANNELS_TEST = N_CHANNELS - N_CHANNELS_TRAIN  # 14
CHANNEL_DROPOUT_RATE = 0.2
# Annotation descriptions mapped to the integer event codes used for epoching.
EVENT_ID = {'T1': 1, 'T2': 2}

# Reproducibility
random.seed(42)
np.random.seed(42)
# TensorFlow keeps its own RNG (weight initialisation, dropout masks); without
# seeding it the Python/NumPy seeds above do not make training repeatable.
tf.random.set_seed(42)


## CNN-RNN Model with Channel Dropout

- **Input:** (N_CHANNELS_TRAIN, N_TIMESTEPS, 1)  
- **Channel Dropout:** Improves robustness to missing or noisy channels  
- **TimeDistributed Conv1D + MaxPooling:** Extract temporal features for each channel  
- **LSTM:** Capture temporal dependencies  
- **Dense → Softmax:** Predicts the subject ID among 109 classes


In [None]:
def create_cnn_rnn_model(input_shape):
    """Build the CNN-RNN subject-ID classifier with channel dropout.

    Args:
        input_shape: Per-sample shape ``(n_channels, n_timesteps, 1)``.

    Returns:
        An uncompiled Keras ``Model`` mapping an EEG epoch to a softmax over
        ``N_CLASSES`` outputs.
    """
    n_channels, n_steps = input_shape[0], input_shape[1]
    input_layer = Input(shape=input_shape, name='eeg_input')

    # noise_shape broadcasts one keep/drop decision over the last two axes, so
    # a whole channel is zeroed at once instead of individual samples.
    x = Dropout(CHANNEL_DROPOUT_RATE, noise_shape=(None, n_channels, 1, 1),
                name='channel_dropout')(input_layer)

    # TimeDistributed iterates over axis 1 (channels): each channel's
    # (n_timesteps, 1) series is convolved along time with shared weights.
    x = TimeDistributed(Conv1D(16, 3, activation='relu', padding='same'), name='td_conv_spatial')(x)
    x = TimeDistributed(BatchNormalization(), name='td_bn_1')(x)
    x = TimeDistributed(MaxPooling1D(2), name='td_maxpool_temporal')(x)

    # Pooling by 2 leaves floor(n_steps / 2) steps; derive it from the actual
    # input rather than the global N_TIMESTEPS so the shapes always agree.
    pooled_steps = n_steps // 2
    n_features_combined = n_channels * 16

    # The tensor is (channels, time, filters) here. Reshaping it directly to
    # (time, channels * filters) would only reinterpret memory and mix
    # consecutive time samples of one channel into a single "timestep".
    # Move time to the front first so each LSTM step sees every channel's
    # features for that same instant.
    x = tf.keras.layers.Permute((2, 1, 3), name='permute_time_first')(x)
    x = Reshape((pooled_steps, n_features_combined), name='reshape_for_rnn')(x)

    x = LSTM(32, return_sequences=True, name='lstm_1')(x)
    x = Dropout(0.5, name='dropout_1')(x)
    x = LSTM(16, return_sequences=False, name='lstm_2')(x)

    x = Dense(16, activation='relu', name='dense_1')(x)
    output_layer = Dense(N_CLASSES, activation='softmax', name='output')(x)

    model = Model(inputs=input_layer, outputs=output_layer)
    return model


## Load and Preprocess EEG Data

- Load each subject's EDF file using **MNE**
- Apply bandpass filter (8–30 Hz)
- Downsample to 80 Hz
- Epoch into 4-second windows starting at each T1/T2 event onset
- Handle missing or empty files gracefully


In [None]:
def load_and_preprocess_data():
    """Read, filter, resample and epoch the configured run for every subject.

    For subjects 1..N_SUBJECTS the EDF file for RUN_ID is band-pass filtered
    (8-30 Hz), resampled to DOWNSAMPLE_FREQ and cut into epochs around the
    events listed in EVENT_ID. Subjects whose file is missing, fails to load,
    or yields no epochs are reported on stdout and skipped.

    Returns:
        A list with one ``mne.Epochs`` object per successfully loaded subject.
    """
    print("Loading and preprocessing data...")
    collected = []

    for subject_number in range(1, N_SUBJECTS + 1):
        subject_tag = f'S{subject_number:03d}'
        edf_path = os.path.join(DATA_ROOT, subject_tag, f'{subject_tag}R{RUN_ID}.edf')

        try:
            recording = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
            recording.filter(8., 30., fir_design='firwin',
                             skip_by_annotation='edge', verbose=False)
            recording.resample(DOWNSAMPLE_FREQ, npad="auto")

            event_array, _ = mne.events_from_annotations(
                recording, event_id=EVENT_ID, verbose=False)
            subject_windows = mne.Epochs(
                recording, event_array, EVENT_ID,
                tmin=0., tmax=EPOCH_DURATION_SEC,
                baseline=None, preload=True, verbose=False)

            # The epochs are preloaded, so the continuous recording can be
            # released before moving on to the next subject.
            del recording
            gc.collect()

            if len(subject_windows) > 0:
                collected.append(subject_windows)
            else:
                print(f"Warning: No epochs for subject {subject_number}. Skipping.")

        except FileNotFoundError:
            print(f"File not found: {edf_path}, skipping subject {subject_number}.")
        except Exception as e:
            print(f"Error {subject_tag}: {e}, skipping.")

    print(f"\nEpoching complete. Loaded {len(collected)} subjects.")
    return collected


## Convert Epochs to NumPy Arrays

- Concatenate all epochs across subjects
- Normalize globally (Z-score)
- Labels correspond to **subject ID**


In [None]:
def convert_epochs_to_numpy(all_epochs):
    """Stack per-subject epochs into model-ready arrays with subject labels.

    This cell previously held a second copy of ``load_and_preprocess_data``
    while the main cell calls ``convert_epochs_to_numpy``, which was never
    defined; this is that missing function.

    Args:
        all_epochs: Non-empty list with one epochs object per loaded subject.
            Each item must provide ``get_data()`` returning an array of shape
            ``(n_epochs, n_channels, n_times)`` (as ``mne.Epochs`` does), with
            matching channel and sample counts across subjects.

    Returns:
        Tuple ``(X, Y, S)``:
          * ``X`` - float32 array ``(total_epochs, n_channels, n_times, 1)``,
            z-scored with a single global mean and standard deviation.
          * ``Y`` - int64 class label per epoch: the 0-based position of the
            subject within ``all_epochs``. Skipped subjects shift positions,
            so this is not necessarily the original subject number.
          * ``S`` - copy of ``Y``, returned because the caller unpacks three
            values. NOTE(review): confirm what the third output was meant to
            carry; nothing visible in the notebook uses it.

    Raises:
        ValueError: If ``all_epochs`` is empty.
    """
    if not all_epochs:
        raise ValueError("all_epochs is empty; nothing to convert.")

    data_parts = []
    label_parts = []
    for label, epochs in enumerate(all_epochs):
        subject_data = np.asarray(epochs.get_data(), dtype=np.float32)
        data_parts.append(subject_data)
        label_parts.append(np.full(subject_data.shape[0], label, dtype=np.int64))

    X = np.concatenate(data_parts, axis=0)
    Y = np.concatenate(label_parts, axis=0)

    # Global z-score. Accumulate in float64 so the statistics stay accurate
    # even though the samples themselves are stored as float32.
    mean = X.mean(dtype=np.float64)
    std = X.std(dtype=np.float64)
    scale = std if std > 0 else 1.0  # constant input: avoid dividing by zero
    X = ((X - mean) / scale).astype(np.float32)

    # Trailing singleton axis: the model input is (channels, timesteps, 1).
    X = X[..., np.newaxis]
    S = Y.copy()
    return X, Y, S


## Channel-Split Training and Evaluation

- Randomly select **50 channels for training**  
- Reserve **14 channels for testing**
- Test data is **zero-padded to match model input** (the 14 held-out channels fill the first 14 of the 50 input slots; the rest are zeros)  
- Evaluate **subject identification accuracy** on unseen channels


In [None]:
def train_channel_split_subject_id(X, Y):
    """Train on a random subset of channels and score on the held-out ones.

    The channels are shuffled and split into the first ``N_CHANNELS_TRAIN``
    (used for fitting) and the remainder (used for evaluation). Because the
    model input has ``N_CHANNELS_TRAIN`` channel slots, the held-out channels
    are written into the first slots of a zero array of that shape before
    evaluation.

    Args:
        X: Array ``(n_epochs, n_channels, n_timesteps, 1)``.
        Y: Integer class label per epoch, aligned with ``X``.

    Returns:
        Accuracy on the zero-padded held-out channels, or ``0.0`` when ``X``
        does not have ``N_CHANNELS`` channels.
    """
    n_total_channels = X.shape[1]

    if n_total_channels != N_CHANNELS:
        print(f"Error: Expected {N_CHANNELS} channels, found {n_total_channels}. Aborting.")
        return 0.0

    channel_indices = np.arange(n_total_channels)
    np.random.shuffle(channel_indices)

    train_channels_idx = channel_indices[:N_CHANNELS_TRAIN]
    test_channels_idx = channel_indices[N_CHANNELS_TRAIN:]

    X_train_data = X[:, train_channels_idx, :, :]
    X_test_data = X[:, test_channels_idx, :, :]
    n_test_channels = X_test_data.shape[1]

    input_shape = X_train_data.shape[1:]
    model = create_cnn_rnn_model(input_shape)

    print(f"\n--- Training Model on {N_CHANNELS_TRAIN} Channels ---")
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    early_stop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

    # Keras carves validation_split off the END of the arrays before any
    # shuffling. If the samples arrive grouped by class, the validation slice
    # would hold only classes missing from the training slice, making
    # val_loss (and early stopping) meaningless. Shuffle the samples first.
    sample_order = np.random.permutation(X_train_data.shape[0])
    X_fit = X_train_data[sample_order]
    Y_fit = np.asarray(Y)[sample_order]

    model.fit(
        X_fit, Y_fit,
        epochs=10,
        batch_size=8,
        validation_split=0.1,
        callbacks=[early_stop],
        verbose=2
    )
    del X_fit, Y_fit

    # Zero-pad the held-out channels up to the model's channel count. This
    # raises if there are more held-out channels than training channels.
    X_test_padded = np.zeros(X_train_data.shape, dtype=X.dtype)
    X_test_padded[:, :n_test_channels, :, :] = X_test_data

    print(f"\n--- Evaluating Model on {n_test_channels} Unseen Channels ---")
    loss_test, acc_test = model.evaluate(X_test_padded, Y, verbose=0)

    print(f"Classification Accuracy ({n_test_channels} Unseen Channels): {acc_test:.4f}")

    K.clear_session()
    del model
    gc.collect()

    return acc_test


## Main Execution

1. Load EEG data  
2. Convert to NumPy arrays  
3. Train model on 50 channels  
4. Evaluate on 14 unseen channels


In [None]:
# --- MAIN ---
# Stage 1: read and epoch every subject's recording.
all_epochs = load_and_preprocess_data()

if not all_epochs:
    # Nothing could be loaded, so there is nothing to train on.
    print("No data available for training.")
else:
    # Stage 2: stack the epochs into arrays, then train and score the model.
    X, Y, S = convert_epochs_to_numpy(all_epochs)

    final_acc = train_channel_split_subject_id(X, Y)

    print("\n" + "=" * 50)
    print(f"FINAL SUBJECT ID ACCURACY (Train {N_CHANNELS_TRAIN} / Test {N_CHANNELS_TEST}): {final_acc:.4f}")
    print("=" * 50)
