EEG Seizure Detection Pipeline (CHB-MIT or TUH) - PyTorch implementation
Single-file end-to-end example that:
- Loads EDF EEG files (CHB-MIT assumed) using mne
- Windowing & label extraction (supports annotations if available)
- Computes spectrograms (mel-spectrograms with librosa)
- Trains a convolutional autoencoder to learn compact representations
- Uses the encoder + a CNN classifier head to detect seizures


Requirements:
- Python 3.8+
- pip install mne librosa numpy scipy matplotlib torch torchvision tqdm scikit-learn


Notes:
- Place CHB-MIT EDF files in data/chb-mit/ (subfolders are allowed). If you use TUH, adapt the EDF-reading code accordingly.
- If seizure annotations are present as MNE annotations they will be used. Otherwise you can provide a CSV
with columns: filename,start_sec,end_sec (relative to file start) and the loader will use that.
- This script is intended as a readable, extendable starting point, not a production-ready training suite.

In [None]:
import os
import glob
import argparse
from typing import List, Tuple, Dict


import numpy as np
import mne
import librosa
from scipy import signal
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from tqdm import tqdm
import matplotlib.pyplot as plt


# ----------------------------- Configuration & Utilities -----------------------------


# Every recording is resampled to this rate (Hz) before windowing.
DEFAULT_SAMPLE_RATE = 256

# Sliding-window geometry, in seconds: 10 s windows started every 5 s,
# so consecutive windows overlap by 5 s.
WINDOW_SEC = 10
WINDOW_STEP_SEC = 5

# Mel-spectrogram settings passed to librosa.
N_MELS = 64
HOP_LENGTH = 128

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')




def find_edf_files(data_dir: str) -> List[str]:
    """Return a sorted list of EDF files found anywhere under *data_dir*.

    The extension match is case-insensitive (``.edf``, ``.EDF``, ``.Edf``, ...).
    *data_dir* itself is glob-escaped, so directory names containing glob
    metacharacters such as ``[`` or ``*`` are matched literally.

    Args:
        data_dir: Root directory to search recursively.

    Returns:
        Sorted, de-duplicated list of matching file paths. The list is empty
        when nothing matches or *data_dir* does not exist.
    """
    # One character-class pattern covers every capitalisation of the extension;
    # glob.escape keeps metacharacters in the root path from being interpreted.
    pattern = os.path.join(glob.escape(data_dir), '**', '*.[eE][dD][fF]')
    # set() is a defensive de-duplication before sorting.
    return sorted(set(glob.glob(pattern, recursive=True)))




def load_annotations_csv(csv_path: str) -> Dict[str, List[Tuple[float,float]]]:
    """Load seizure intervals from a CSV file.

    Expected columns: ``filename,start_sec,end_sec``, with times in seconds
    relative to the start of the file. Parsing rules:

    - Blank lines and rows whose first field starts with ``#`` are ignored.
    - Rows with fewer than three fields are skipped.
    - An optional header row is skipped; this is the first data row whose
      times are not numeric.
    - Fields may be quoted and are whitespace-stripped.
    - Only the basename of the filename field is used as the key.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        Dict mapping file basename to a list of ``(start_sec, end_sec)`` tuples
        in file order. Empty (with a printed warning) if *csv_path* does not
        exist.

    Raises:
        ValueError: If a non-header row has a non-numeric start or end time.
    """
    import csv

    ann: Dict[str, List[Tuple[float, float]]] = {}
    if not os.path.exists(csv_path):
        # Previously silent; without annotations every window is labelled 0.
        print('Annotation CSV not found:', csv_path)
        return ann
    header_allowed = True  # only the first data row may be a header
    # utf-8-sig strips a BOM that would otherwise stick to the first filename.
    with open(csv_path, 'r', newline='', encoding='utf-8-sig') as f:
        reader = csv.reader(f)
        for row in reader:
            fields = [field.strip() for field in row]
            if not fields or fields[0].startswith('#') or len(fields) < 3:
                continue
            try:
                start, end = float(fields[1]), float(fields[2])
            except ValueError:
                if header_allowed:
                    header_allowed = False
                    continue
                raise ValueError(
                    f'{csv_path}:{reader.line_num}: non-numeric start/end time in row {row!r}'
                ) from None
            header_allowed = False
            ann.setdefault(os.path.basename(fields[0]), []).append((start, end))
    return ann



In [None]:
# ----------------------------- EDF reading, windowing, spectrograms -----------------------------




def read_raw_edf(path: str, picks: List[str] = None, resample_sfreq: int = DEFAULT_SAMPLE_RATE):
    """Read an EDF file with MNE, optionally select channels, then resample.

    Args:
        path: Path to the EDF file.
        picks: Channel names to keep, or None to keep every channel.
            - Names absent from the file are ignored, and duplicates are dropped.
            - The kept channels are ordered as listed in *picks*, so all files
              share the same channel layout.
            - If none of the names exist in the file, all channels are kept and
              a warning is printed.
        resample_sfreq: Target sampling rate in Hz, or None to keep the file's
            native rate.

    Returns:
        The preloaded MNE Raw object.
    """
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    if picks is not None:
        available = set(raw.ch_names)
        # dict.fromkeys de-duplicates while preserving the caller's order.
        picks_present = [ch for ch in dict.fromkeys(picks) if ch in available]
        if picks_present:
            # reorder_channels drops unlisted channels AND applies the requested
            # order. pick_channels keeps the file's order by default in older
            # MNE versions, which gives different layouts across files.
            raw.reorder_channels(picks_present)
        else:
            print('None of the requested channels found in', path, '- keeping all channels')
    if resample_sfreq is not None and raw.info['sfreq'] != resample_sfreq:
        raw.resample(resample_sfreq)
    return raw




def windows_from_raw(raw: mne.io.BaseRaw, window_sec=WINDOW_SEC, step_sec=WINDOW_STEP_SEC):
    """Cut a recording into fixed-length, possibly overlapping windows.

    Args:
        raw: Preloaded MNE Raw object. Its ``info['sfreq']`` and ``get_data()``
            are used.
        window_sec: Window length in seconds.
        step_sec: Offset between consecutive window starts, in seconds.

    Returns:
        List of ``(segment, t0, t1)`` tuples:
        - ``segment`` is the ``(n_channels, window_samples)`` slice of the data.
        - ``t0`` and ``t1`` are the window's start and end times in seconds.
        Trailing samples that do not fill a whole window are dropped.

    Raises:
        ValueError: If the window or step rounds to fewer than one sample.
            A zero step would otherwise loop forever.
    """
    # Keep the true (possibly non-integer) rate so t0/t1 line up with
    # annotation times. Round sample counts so float error (e.g. 2.9999999)
    # does not drop a sample.
    sf = float(raw.info['sfreq'])
    window_samples = int(round(window_sec * sf))
    step_samples = int(round(step_sec * sf))
    if window_samples < 1 or step_samples < 1:
        raise ValueError(
            f'window_sec={window_sec} and step_sec={step_sec} must each span at least '
            f'one sample at {sf} Hz'
        )
    data = raw.get_data()
    n_samples = data.shape[1]
    return [
        (data[:, start:start + window_samples], start / sf, (start + window_samples) / sf)
        for start in range(0, n_samples - window_samples + 1, step_samples)
    ]




def label_window_from_annotations(t0: float, t1: float, annotations: List[Tuple[float,float]]) -> int:
    """Return 1 if the window (t0, t1) overlaps any annotated interval, else 0.

    An interval (start, end) counts as overlapping when it starts strictly
    before the window ends and ends strictly after the window starts.
    Intervals that merely touch the window at an endpoint do not count.
    """
    hit = any(start < t1 and t0 < end for start, end in annotations)
    return int(hit)




def compute_mel_spectrogram(segment: np.ndarray, sr: int = DEFAULT_SAMPLE_RATE, n_mels=N_MELS, hop_length=HOP_LENGTH, n_fft: int = 2048):
    """Compute a log-mel spectrogram for every channel of a segment.

    Each channel is expressed in dB relative to its own peak power, so 0 dB
    marks the channel's strongest time-frequency bin. Values are floored at
    -80 dB by librosa's default ``top_db``.

    Args:
        segment: Array of shape (n_channels, n_samples).
        sr: Sampling rate of *segment* in Hz.
        n_mels: Number of mel bands.
        hop_length: Hop between STFT frames, in samples.
        n_fft: FFT window length in samples. 2048 matches librosa's default,
            which spans 8 s at 256 Hz.

    Returns:
        Array of shape (n_channels, n_mels, n_frames).
    """
    specs = []
    for channel in segment:
        # librosa expects float32 input.
        S = librosa.feature.melspectrogram(
            y=channel.astype(np.float32), sr=sr, n_mels=n_mels,
            hop_length=hop_length, n_fft=n_fft,
        )
        # Normalise by the peak BEFORE converting to dB. power_to_db(S, ref=np.max)
        # floors S at an absolute amin=1e-10. For inputs that are small in
        # absolute terms (e.g. EEG in volts from MNE's get_data), that floor
        # clips low-power bins and shrinks the dynamic range. Dividing first
        # makes the result unit-independent, and identical to
        # power_to_db(S, ref=np.max) wherever the floor was not hit.
        peak = S.max()
        if peak > 0:
            S_db = librosa.power_to_db(S / peak, ref=1.0)
        else:
            # Flat channel: power_to_db(S, ref=np.max) gives all zeros here too.
            S_db = np.zeros_like(S)
        specs.append(S_db)
    # Stack to (n_channels, n_mels, time_frames).
    return np.stack(specs, axis=0)



In [None]:
# ----------------------------- Dataset -----------------------------


class EEGSpectrogramDataset(Dataset):
    """Map-style dataset over precomputed spectrogram windows.

    Each record is a tuple ``(basename, label, window_index, spec)``, where
    ``spec`` is an array shaped (channels, n_mels, time_frames).

    Indexing returns ``(x, y)``:
    - ``x`` is ``spec`` standardised to zero mean and unit variance over the
      whole example, cast to float32. It is already in the (C, H, W) layout
      that conv2d expects.
    - ``y`` is the label as a long tensor.
    """

    def __init__(self, records: List[Tuple[str, int, int, np.ndarray]]):
        """Keep *records*, a list of (basename, label, idx, spec_array) tuples."""
        self.records = records

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        _, label, _, spec = self.records[idx]
        # Per-example standardisation; the epsilon guards constant spectrograms.
        centred = spec - spec.mean()
        scaled = centred / (spec.std() + 1e-8)
        features = torch.from_numpy(scaled.astype(np.float32))
        target = torch.tensor(label, dtype=torch.long)
        return features, target



In [None]:
# ----------------------------- Models -----------------------------


class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder for multi-channel spectrogram images.

    Encoder: three stride-2 convolutions (3x3, padding 1) take an (H, W)
    input to (ceil(H/8), ceil(W/8)) with 64 feature maps. A linear layer then
    maps the flattened features to an ``embedding_dim`` code.

    Decoder: a linear layer followed by three stride-2 transposed convolutions,
    each of which doubles H and W.

    Args:
        in_ch: Number of input channels (one per EEG channel in this script).
        embedding_dim: Size of the latent code.
        input_size: Expected (H, W) of the input. Defaults to (64, 64), the
            size the rest of this script resizes spectrograms to. The decoder
            reproduces the input size exactly only when H and W are multiples
            of 8; otherwise its output is (8*ceil(H/8), 8*ceil(W/8)).

    ``forward(x)`` returns ``(reconstruction, z)``.
    """

    def __init__(self, in_ch=1, embedding_dim=128, input_size=(64, 64)):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(in_ch, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Flatten(),
        )
        self._embedding_dim = embedding_dim

        # Dummy forward (everything except Flatten) to find the encoder's
        # feature-map shape for input_size, instead of hard-coding 64 * 8 * 8.
        # Zeros in, no dropout, so no RNG is consumed and parameter init matches
        # a hard-coded build.
        with torch.no_grad():
            feat = self.encoder[:-1](torch.zeros(1, in_ch, *input_size))
        feat_shape = tuple(feat.shape[1:])  # (64, ceil(H/8), ceil(W/8))
        n_feat = feat[0].numel()
        self.fc_enc = nn.Linear(n_feat, embedding_dim)

        # Decoder
        self.fc_dec = nn.Linear(embedding_dim, n_feat)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, feat_shape),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, in_ch, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        """Encode *x* to a latent code and decode it; return (reconstruction, z)."""
        z = self.encoder(x)
        z = self.fc_enc(z)
        out = self.fc_dec(z)
        out = self.decoder(out)
        return out, z




class ClassifierHead(nn.Module):
    """Small MLP classifier stacked on top of an encoder.

    The encoder is expected to map an input batch to ``(B, embedding_dim)``
    features. The head turns those into ``n_classes`` logits through one
    64-unit hidden layer with ReLU and dropout (p=0.4). Freezing encoder
    parameters, if desired, is left to the caller.
    """

    def __init__(self, encoder: nn.Module, embedding_dim=128, n_classes=2):
        super().__init__()
        self.encoder = encoder
        hidden = 64
        self.fc = nn.Sequential(
            nn.Linear(embedding_dim, hidden),
            nn.ReLU(True),
            nn.Dropout(0.4),
            nn.Linear(hidden, n_classes),
        )

    def forward(self, x):
        """Return class logits for an input batch of shape (B, C, H, W)."""
        features = self.encoder(x)
        return self.fc(features)



In [None]:
# ----------------------------- Training loops -----------------------------
z = self.fc_enc(z)
return z


enc_wrap = EncoderWrapper(encoder_model).to(DEVICE)
clf = ClassifierHead(enc_wrap, embedding_dim=encoder_model._embedding_dim, n_classes=2).to(DEVICE)


# optionally freeze encoder
for p in enc_wrap.encoder.parameters():
p.requires_grad = False
for p in enc_wrap.fc_enc.parameters():
p.requires_grad = True


criterion = nn.CrossEntropyLoss()
opt = optim.Adam(filter(lambda p: p.requires_grad, clf.parameters()), lr=lr)


for epoch in range(epochs):
clf.train()
pbar = tqdm(train_loader, desc=f'CLF Epoch {epoch+1}/{epochs}')
for x,y in pbar:
x = x.to(DEVICE)
x = nn.functional.interpolate(x, size=(64,64), mode='bilinear', align_corners=False)
y = y.to(DEVICE)
logits = clf(x)
loss = criterion(logits, y)
opt.zero_grad()
loss.backward()
opt.step()
pbar.set_postfix(loss=loss.item())


# validation
clf.eval()
ys, ypreds, yprob = [], [], []
with torch.no_grad():
for x,y in val_loader:
x = x.to(DEVICE)
x = nn.functional.interpolate(x, size=(64,64), mode='bilinear', align_corners=False)
y = y.to(DEVICE)
logits = clf(x)
probs = torch.softmax(logits, dim=1)[:,1].cpu().numpy()
preds = logits.argmax(dim=1).cpu().numpy()
ys.extend(y.cpu().numpy().tolist())
ypreds.extend(preds.tolist())
yprob.extend(probs.tolist())
try:
auc = roc_auc_score(ys, yprob)
except Exception:
auc = None
print('\nValidation AUC:', auc)
print(classification_report(ys, ypreds))
if out_path:
torch.save(clf.state_dict(), out_path)
return clf



In [None]:
# ----------------------------- Full pipeline orchestration -----------------------------




def build_records_from_edf(files: List[str], ann_csv: str = None, picks: List[str] = None, max_files=None):
    """Read EDF files and turn them into labelled spectrogram windows.

    Each file goes through these steps:
    1. Read and resample it with ``read_raw_edf``.
    2. Cut it into windows with ``windows_from_raw``.
    3. Label each window 1 if it overlaps a seizure interval, else 0.
    4. Compute the window's mel spectrogram.

    Seizure intervals come from the CSV (*ann_csv*) when it lists the file.
    Otherwise they come from the annotations MNE read from the EDF, excluding
    annotations whose description starts with "bad" or "edge"
    (case-insensitive). MNE uses those prefixes for data-quality segments,
    not events.

    Some files are skipped:
    - files that fail to read;
    - files whose channel list differs from the first successfully read file.
      Windows with different channel layouts cannot be batched together, and
      would put different electrodes in the same input channel.

    Args:
        files: EDF paths to process, in order.
        ann_csv: Optional annotation CSV path (see ``load_annotations_csv``).
        picks: Optional channel names forwarded to ``read_raw_edf``.
        max_files: If truthy, only the first *max_files* entries of *files*
            are attempted. Failed reads count toward the limit.

    Returns:
        List of ``(basename, label, window_index, spec)`` tuples.
    """
    ann_map = load_annotations_csv(ann_csv) if ann_csv else {}
    records = []
    reference_channels = None  # channel names of the first file read successfully
    for i, path in enumerate(files):
        if max_files and i >= max_files:
            break
        basename = os.path.basename(path)
        print('Reading', basename)
        try:
            raw = read_raw_edf(path, picks=picks, resample_sfreq=DEFAULT_SAMPLE_RATE)
        except Exception as e:
            # Best-effort: one unreadable file should not abort the whole run.
            print('Failed to read', path, e)
            continue

        if reference_channels is None:
            reference_channels = list(raw.ch_names)
        elif list(raw.ch_names) != reference_channels:
            print('Skipping', basename, '- channel list differs from the first file read')
            continue

        # User-supplied CSV intervals take precedence. Previously any EDF
        # annotation, including data-quality markers, replaced them and was
        # treated as a seizure.
        anns = ann_map.get(basename, [])
        if not anns:
            anns = [
                (float(a['onset']), float(a['onset']) + float(a['duration']))
                for a in raw.annotations
                if not str(a['description']).lower().startswith(('bad', 'edge'))
            ]

        for widx, (segment, t0, t1) in enumerate(windows_from_raw(raw)):
            label = label_window_from_annotations(t0, t1, anns)
            spec = compute_mel_spectrogram(segment, sr=DEFAULT_SAMPLE_RATE)
            # All channels are kept as separate image channels.
            records.append((basename, label, widx, spec))
    return records



In [None]:
# ----------------------------- Example entrypoint -----------------------------




def _stratify_labels_or_none(recs):
    """Return the labels of *recs* for stratified splitting, or None if any class has < 2 windows."""
    labels = [r[1] for r in recs]
    if not labels:
        return None
    _, counts = np.unique(labels, return_counts=True)
    if counts.min() < 2:
        # train_test_split(stratify=...) raises in this case.
        print('Warning: a class has fewer than 2 windows; splitting without stratification')
        return None
    return labels


def _evaluate_classifier(model, loader):
    """Run *model* over *loader* and return (true labels, predicted labels, class-1 probabilities).

    Inputs are resized to 64x64, the same resize the classifier training loop
    applies.
    """
    model.eval()
    ys, ypreds, yprob = [], [], []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(DEVICE)
            x = nn.functional.interpolate(x, size=(64, 64), mode='bilinear', align_corners=False)
            logits = model(x)
            yprob.extend(torch.softmax(logits, dim=1)[:, 1].cpu().numpy().tolist())
            ypreds.extend(logits.argmax(dim=1).cpu().numpy().tolist())
            ys.extend(y.numpy().tolist())
    return ys, ypreds, yprob


def _print_metrics(ys, ypreds, yprob):
    """Print ROC AUC (None if undefined), the classification report and the confusion matrix."""
    try:
        auc = roc_auc_score(ys, yprob)
    except ValueError:
        # roc_auc_score raises ValueError when only one class is present.
        auc = None
    print('AUC:', auc)
    print(classification_report(ys, ypreds))
    print('Confusion matrix (rows = true, columns = predicted):')
    print(confusion_matrix(ys, ypreds))


def main(args):
    """Run the full pipeline: build windows, split, pretrain the AE, train the classifier, test.

    Args:
        args: Parsed CLI namespace with data_dir, ann_csv, max_files, batch_size,
            ae_epochs, ae_lr, ae_out, clf_epochs, clf_lr and clf_out.

    Raises:
        RuntimeError: If no EDF files are found or no windows could be built.
    """
    files = find_edf_files(args.data_dir)
    if len(files) == 0:
        raise RuntimeError('No EDF files found in ' + args.data_dir)
    print(f'Found {len(files)} EDF files')

    records = build_records_from_edf(files, ann_csv=args.ann_csv, picks=None, max_files=args.max_files)
    print(f'Built {len(records)} windows')
    if not records:
        raise RuntimeError('No windows could be built from the EDF files in ' + args.data_dir)

    # Split dataset (test 20%, then validation 20% of the remainder).
    # NOTE(review): windows are split at random, and consecutive windows overlap
    # (WINDOW_STEP_SEC < WINDOW_SEC). Near-identical windows from the same
    # recording can therefore land in both train and test, which inflates test
    # metrics. Splitting by recording (e.g. sklearn's GroupShuffleSplit on the
    # basename) gives a more honest estimate.
    train_recs, test_recs = train_test_split(
        records, test_size=0.2, random_state=42, stratify=_stratify_labels_or_none(records))
    train_recs, val_recs = train_test_split(
        train_recs, test_size=0.2, random_state=42, stratify=_stratify_labels_or_none(train_recs))

    train_ds = EEGSpectrogramDataset(train_recs)
    val_ds = EEGSpectrogramDataset(val_recs)
    test_ds = EEGSpectrogramDataset(test_recs)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=2)

    # Autoencoder pretraining on the training set.
    in_ch = train_recs[0][3].shape[0]
    ae = ConvAutoencoder(in_ch=in_ch, embedding_dim=128)
    print('Training autoencoder...')
    ae = train_autoencoder(ae, train_loader, epochs=args.ae_epochs, lr=args.ae_lr, out_path=args.ae_out)

    # Classifier training
    print('Training classifier...')
    clf = train_classifier(ae, train_loader, val_loader, epochs=args.clf_epochs, lr=args.clf_lr, out_path=args.clf_out)

    # Evaluate on the held-out test set.
    ys, ypreds, yprob = _evaluate_classifier(clf, test_loader)
    print('\nTest results:')
    _print_metrics(ys, ypreds, yprob)




if __name__ == '__main__':
    # Command-line entry point: parse options, resolve output paths, run main().
    parser = argparse.ArgumentParser(
        description='Train a convolutional autoencoder + classifier for EEG seizure detection on EDF recordings.')
    parser.add_argument('--data_dir', type=str, default='data/chb-mit',
                        help='Root directory searched recursively for EDF files.')
    parser.add_argument('--ann_csv', type=str, default='',
                        help='Optional CSV of seizure intervals with columns filename,start_sec,end_sec.')
    parser.add_argument('--out_dir', type=str, default='runs',
                        help='Directory where model weights are saved (created if missing).')
    parser.add_argument('--max_files', type=int, default=None,
                        help='Read at most this many EDF files (default: all files).')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Mini-batch size for all data loaders.')
    parser.add_argument('--ae_epochs', type=int, default=10,
                        help='Autoencoder pretraining epochs.')
    parser.add_argument('--clf_epochs', type=int, default=20,
                        help='Classifier training epochs.')
    parser.add_argument('--ae_lr', type=float, default=1e-3,
                        help='Autoencoder learning rate.')
    parser.add_argument('--clf_lr', type=float, default=1e-4,
                        help='Classifier learning rate.')
    parser.add_argument('--ae_out', type=str, default='ae.pth',
                        help='Autoencoder weights filename, relative to --out_dir.')
    parser.add_argument('--clf_out', type=str, default='clf.pth',
                        help='Classifier weights filename, relative to --out_dir.')
    args = parser.parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    args.ae_out = os.path.join(args.out_dir, args.ae_out)
    args.clf_out = os.path.join(args.out_dir, args.clf_out)
    main(args)