<a href="https://colab.research.google.com/github/AhmedToto23/Person-Identification-using-Time-Frequency-CNN-Temporal-RNN-EEG-/blob/main/Cnn_Rnn_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [27]:
from google.colab import drive
drive.mount('/content/drive')


import os
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [28]:
# Paths (adjust if needed)
SEGMENTS_DIR = "/content/drive/MyDrive/files/segments" # where .npy segments are stored
META_CSV = os.path.join(SEGMENTS_DIR, 'metadata.csv')
MODEL_SAVE = '/content/drive/MyDrive/files/best_cnn_rnn.pth'
os.makedirs(SEGMENTS_DIR, exist_ok=True)

In [29]:
# Build metadata if missing
if not os.path.exists(META_CSV):
    records = []
    for f in sorted(glob(os.path.join(SEGMENTS_DIR, '*.npy'))):
        fname = os.path.basename(f)
        # expected pattern: Sxxx_<orig>_segN.npy
        parts = fname.split('_')
        if len(parts) < 2:
            continue
        subj = parts[0]
        session = 1
        # try to detect session in filename
        for p in parts:
            if p.lower().startswith('session'):
                try:
                    session = int(p.replace('session',''))
                except:
                    session = 1
        # subject like S001 -> label 0
        label = int(subj.replace('S','')) - 1
        records.append({'path': f, 'label': label, 'session': session})
    meta = pd.DataFrame.from_records(records)
    meta.to_csv(META_CSV, index=False)
    print('Created metadata.csv with', len(meta), 'entries')
else:
    meta = pd.read_csv(META_CSV)
    print('Loaded metadata.csv with', len(meta), 'entries')

Loaded metadata.csv with 65737 entries


In [30]:
# Dataset: compute spectrogram per segment on-the-fly
class EEGSpectrogramDataset(Dataset):
    def __init__(self, meta_df, sample_rate=160, n_fft=256, hop=64, freq_range=(1,50), transform=None):
        self.meta = meta_df.reset_index(drop=True)
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop = hop
        self.freq_range = freq_range
        self.transform = transform

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        row = self.meta.iloc[idx]
        arr = np.load(row['path']).astype(np.float32)  # (C, T)
        x = torch.from_numpy(arr)
        # channel-wise STFT -> magnitude
        specs = []
        for ch in range(x.shape[0]):
            stft = torch.stft(x[ch], n_fft=self.n_fft, hop_length=self.hop, win_length=self.n_fft, center=True, return_complex=True)
            mag = stft.abs()
            specs.append(mag.unsqueeze(0))
        spec = torch.cat(specs, dim=0)  # (C, F, Tfr)
        freqs = torch.linspace(0, self.sample_rate / 2, steps=(self.n_fft // 2 + 1))
        fmin, fmax = self.freq_range
        fmask = (freqs >= fmin) & (freqs <= fmax)
        spec = spec[:, fmask, :]
        spec = torch.log1p(spec)
        label = int(row['label'])
        session = int(row['session']) if 'session' in row else 1
        if self.transform:
            spec = self.transform(spec)
        return spec, label, session

In [31]:
# Simple Hybrid Model
# CNN encoder produces timewise features -> GRU processes temporal sequence
class CNNEncoder(nn.Module):
    def __init__(self, in_ch, cnn_channels=(32,64,128)):
        super().__init__()
        layers = []
        prev = in_ch
        for outc in cnn_channels:
            layers += [
                nn.Conv2d(prev, outc, kernel_size=3, padding=1),
                nn.BatchNorm2d(outc),
                nn.ReLU(inplace=True),
                nn.MaxPool2d((2,2))
            ]
            prev = outc
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class HybridCNNRNN(nn.Module):
    def __init__(self, in_ch, n_freq_bins, cnn_channels=(32,64,128), gru_hidden=256, gru_layers=2, n_classes=109, dropout=0.4):
        super().__init__()
        self.encoder = CNNEncoder(in_ch, cnn_channels=cnn_channels)
        # compute feature dim after cnn pooling for freq dimension
        # assume freq dimension is divisible by 2**len(cnn_channels)
        self.freq_reduction = 2 ** len(cnn_channels)
        feat_dim = cnn_channels[-1] * (n_freq_bins // self.freq_reduction)
        self.gru = nn.GRU(input_size=feat_dim, hidden_size=gru_hidden, num_layers=gru_layers, batch_first=True, bidirectional=True)
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(gru_hidden*2, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, n_classes)
        )

    def forward(self, x):
        # x: (B, C, F, T)
        out = self.encoder(x)  # (B, C', F', T')
        B, C1, Fp, Tp = out.shape
        out = out.permute(0, 3, 1, 2).contiguous()  # (B, Tp, C', F')
        out = out.view(B, Tp, C1 * Fp)  # (B, Tp, feat)
        gru_out, _ = self.gru(out)
        pooled = gru_out.mean(dim=1)
        logits = self.classifier(pooled)
        return logits

In [40]:
# Train / Eval helpers
from sklearn.metrics import accuracy_score, top_k_accuracy_score
import torch.nn.functional as F

def train_epoch(model, loader, optim, device):
    model.train()
    losses = []
    all_preds, all_labels = [], []
    for spec, label, _ in tqdm(loader, desc='train', leave=False):
        spec = spec.to(device)
        label = label.to(device)
        optim.zero_grad()
        out = model(spec)
        loss = F.cross_entropy(out, label)
        loss.backward()
        optim.step()
        losses.append(loss.item())
        preds = out.argmax(dim=1).cpu().numpy()
        all_preds.extend(preds.tolist())
        all_labels.extend(label.cpu().numpy().tolist())
    return np.mean(losses), accuracy_score(all_labels, all_preds)

@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    probs_list = []
    labels = []
    for spec, label, _ in tqdm(loader, desc='eval', leave=False):
        spec = spec.to(device)
        out = model(spec)
        probs = F.softmax(out, dim=1).cpu().numpy()
        probs_list.append(probs)
        labels.extend(label.numpy().tolist())
    probs = np.concatenate(probs_list, axis=0)
    preds = probs.argmax(axis=1)
    top1 = accuracy_score(labels, preds)
    try:
        top5 = top_k_accuracy_score(labels, probs, k=5)
    except Exception:
        top5 = None
    return top1, top5


In [39]:
# Prepare splits, dataloaders
from sklearn.model_selection import train_test_split

# Perform a random train-test split on the entire metadata
train_meta, test_meta = train_test_split(meta, test_size=0.2, random_state=42, stratify=meta['label'])

train_meta = train_meta.reset_index(drop=True)
test_meta = test_meta.reset_index(drop=True)

print(f'Train samples: {len(train_meta)}')
print(f'Test samples: {len(test_meta)}')

train_ds = EEGSpectrogramDataset(train_meta)
test_ds = EEGSpectrogramDataset(test_meta)

# quick dimension probe
spec0, _, _ = train_ds[0]
C, n_freq_bins_val, T = spec0.shape # Renamed F to n_freq_bins_val
print('Spec shape per sample:', spec0.shape)

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)


Train samples: 52589
Test samples: 13148
Spec shape per sample: torch.Size([64, 79, 6])


In [41]:
# Instantiate model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HybridCNNRNN(in_ch=C, n_freq_bins=n_freq_bins_val, n_classes=109, cnn_channels=(32,64)).to(device) # Changed F to n_freq_bins_val
optim = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=8, gamma=0.5)


In [None]:
# Training loop
best = 0.0
epochs = 30
for ep in range(1, epochs+1):
    loss, train_acc = train_epoch(model, train_loader, optim, device)
    top1, top5 = evaluate(model, test_loader, device)
    scheduler.step()
    print(f'Epoch {ep:02d} loss={loss:.4f} train_acc={train_acc:.3f} test_top1={top1:.3f} test_top5={top5}')
    if top1 > best:
        best = top1
        torch.save(model.state_dict(), MODEL_SAVE)
        print('Saved best model ->', MODEL_SAVE)

print('Training complete. Best test top1:', best)




Epoch 01 loss=1.5594 train_acc=0.550 test_top1=0.207 test_top5=None
Saved best model -> /content/drive/MyDrive/files/best_cnn_rnn.pth




Epoch 02 loss=0.3921 train_acc=0.885 test_top1=0.209 test_top5=None
Saved best model -> /content/drive/MyDrive/files/best_cnn_rnn.pth




Epoch 03 loss=0.2328 train_acc=0.934 test_top1=0.477 test_top5=None
Saved best model -> /content/drive/MyDrive/files/best_cnn_rnn.pth




Epoch 04 loss=0.1758 train_acc=0.950 test_top1=0.651 test_top5=None
Saved best model -> /content/drive/MyDrive/files/best_cnn_rnn.pth




Epoch 05 loss=0.1389 train_acc=0.960 test_top1=0.323 test_top5=None




Epoch 06 loss=0.1260 train_acc=0.965 test_top1=0.759 test_top5=None
Saved best model -> /content/drive/MyDrive/files/best_cnn_rnn.pth




Epoch 07 loss=0.1147 train_acc=0.969 test_top1=0.436 test_top5=None




Epoch 08 loss=0.0958 train_acc=0.974 test_top1=0.669 test_top5=None




Epoch 09 loss=0.0433 train_acc=0.988 test_top1=0.985 test_top5=None
Saved best model -> /content/drive/MyDrive/files/best_cnn_rnn.pth




Epoch 10 loss=0.0371 train_acc=0.990 test_top1=0.865 test_top5=None




Epoch 11 loss=0.0357 train_acc=0.991 test_top1=0.978 test_top5=None




Epoch 12 loss=0.0273 train_acc=0.992 test_top1=0.905 test_top5=None


train:  55%|█████▍    | 1793/3287 [09:01<06:45,  3.69it/s]

In [None]:
# Compute per-subject accuracy and save a simple report
@torch.no_grad()
def per_subject_accuracy(model, loader, device, meta_df):
    model.eval()
    preds = []
    labels = []
    paths = []
    for spec, label, _ in tqdm(loader, desc='per-subject', leave=False):
        spec = spec.to(device)
        out = model(spec)
        p = out.argmax(dim=1).cpu().numpy()
        preds.extend(p.tolist())
        labels.extend(label.numpy().tolist())
    df = pd.DataFrame({'label': labels, 'pred': preds})
    accs = df.groupby('label').apply(lambda d: (d['label']==d['pred']).mean())
    accs = accs.sort_values()
    accs.to_csv(os.path.join(os.path.dirname(MODEL_SAVE), 'per_subject_acc.csv'))
    print('Saved per-subject accuracies.')

In [None]:
# load best model and run per-subject report
model.load_state_dict(torch.load(MODEL_SAVE, map_location=device))
per_subject_accuracy(model, test_loader, device, test_meta)
print('Report saved near model file.')