In [None]:
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Function


In [None]:
# Per-sample inputs: index i of every array below is used for the same sample
# (EEGDataset indexes all of them with one shared idx).
meta_data = np.load("clinical_metadata.npy")
labels = np.load("labels.npy")

# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# only open this file if it comes from a trusted source.
with open("patient_id_encoder (1).pkl", "rb") as f:
    patient_ids = np.array(pickle.load(f))

# Paths of the per-sample EEG files (opened one at a time by EEGDataset).
# NOTE(security): allow_pickle=True also unpickles; same trust caveat as above.
eeg_files = np.load("eeg_file_paths.npy", allow_pickle=True)

# Fail fast if the arrays are misaligned: a length mismatch would otherwise
# surface later as an IndexError mid-training or as silently ignored samples.
_sample_counts = {
    "eeg_files": len(eeg_files),
    "meta_data": len(meta_data),
    "labels": len(labels),
    "patient_ids": len(patient_ids),
}
if len(set(_sample_counts.values())) != 1:
    raise ValueError(f"Per-sample arrays have different lengths: {_sample_counts}")

# Sorted distinct patient IDs; each one is held out once during LOPO.
unique_pids = np.unique(patient_ids)


In [None]:
# Dataset definition: each item loads one EEG file plus its clinical metadata, label, and patient ID

In [None]:
class EEGDataset(Dataset):
    """Dataset that reads one EEG file from disk per item.

    Args:
        eeg_files: Indexable collection of paths; each file must expose an
            ``"eeg"`` entry when opened with ``np.load``.
        meta: Indexable collection of per-sample metadata (converted to float32).
        labels: Indexable collection of per-sample class labels (converted to long).
        pids: Indexable collection of per-sample integer patient IDs (converted to long).

    All four collections are indexed with the same ``idx``; the dataset length
    is taken from ``eeg_files``.
    """

    def __init__(self, eeg_files, meta, labels, pids):
        self.eeg_files = eeg_files
        self.meta = meta
        self.labels = labels
        self.pids = pids

    def __len__(self):
        return len(self.eeg_files)

    def __getitem__(self, idx):
        """Return ``(eeg, meta, label, pid)`` tensors for sample ``idx``."""
        loaded = np.load(self.eeg_files[idx])
        try:
            # torch.tensor copies the data, so it stays valid after closing.
            eeg = torch.tensor(loaded["eeg"], dtype=torch.float32)
        finally:
            # An .npz archive keeps its file handle open until closed; release
            # it right away so long training runs do not accumulate open files.
            # Plain arrays have no close(), hence the getattr guard.
            close = getattr(loaded, "close", None)
            if close is not None:
                close()

        meta = torch.tensor(self.meta[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        pid = torch.tensor(self.pids[idx], dtype=torch.long)

        return eeg, meta, label, pid


Gradient reversal for domain adaptation

In [None]:
class GradReverse(Function):
    """Autograd function: identity going forward, gradient scaled by ``-alpha`` going back."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Remember the scale so backward can apply it; the data passes through untouched.
        ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad_output):
        flipped = grad_output * (-ctx.alpha)
        # One entry per forward input: the flipped gradient for x, nothing for alpha.
        return flipped, None


CNN feature extractor

In [None]:
class EEGCNN(nn.Module):
    """Two-stage 2-D CNN that reduces each input to a 64-dimensional feature vector."""

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 32 channels, then halve the last axis only.
        stage_one = [
            nn.Conv2d(1, 32, kernel_size=(3, 7), padding=(1, 3)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2)),
        ]
        # Stage 2: 32 -> 64 channels, spatial size preserved by the padding.
        stage_two = [
            nn.Conv2d(32, 64, kernel_size=(3, 5), padding=(1, 2)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        ]
        # Global average pool collapses both spatial axes to size 1.
        head = [nn.AdaptiveAvgPool2d(output_size=(1, 1))]
        self.net = nn.Sequential(*stage_one, *stage_two, *head)

    def forward(self, x):
        """Map a 3-D batch ``(N, H, W)`` to features of shape ``(N, 64)``."""
        # Insert the single input channel that the first convolution expects.
        feature_map = self.net(x.unsqueeze(1))
        # Drop the two size-1 spatial axes left by the global pool.
        pooled = feature_map.squeeze(-1)
        return pooled.squeeze(-1)


Attention Pooling

In [None]:
class AttentionPooling(nn.Module):
    """Collapse axis 1 of a ``(B, L, dim)`` tensor with learned softmax weights."""

    def __init__(self, dim):
        super().__init__()
        # Small MLP producing one scalar score per position.
        scorer = [nn.Linear(dim, 128), nn.Tanh(), nn.Linear(128, 1)]
        self.attn = nn.Sequential(*scorer)

    def forward(self, x):
        """Return the attention-weighted sum over axis 1, shape ``(B, dim)``."""
        scores = self.attn(x)
        # Normalise across positions so the weights along axis 1 sum to one.
        weights = scores.softmax(dim=1)
        weighted = weights * x
        return weighted.sum(dim=1)


Full CNN-BiLSTM model with attention pooling, clinical-metadata fusion, and a domain head

In [None]:
class EpilepsyNet(nn.Module):
    """CNN -> BiLSTM -> attention model with metadata fusion and a domain head.

    Args:
        meta_dim: Size of the per-sample metadata vector.
        num_domains: Number of output classes of the domain head.
    """

    def __init__(self, meta_dim, num_domains):
        super().__init__()

        self.cnn = EEGCNN()
        # Bidirectional with hidden size 128 -> 256 features per sequence step.
        self.lstm = nn.LSTM(64, 128, bidirectional=True, batch_first=True)
        self.attn = AttentionPooling(256)

        self.meta_fc = nn.Sequential(
            nn.Linear(meta_dim, 32),
            nn.ReLU()
        )

        # Classifier sees the EEG embedding (256) concatenated with the metadata embedding (32).
        self.classifier = nn.Sequential(
            nn.Linear(256 + 32, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 2)
        )

        # Domain head sees only the EEG embedding.
        self.domain = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, num_domains)
        )

    def forward(self, eeg, meta, alpha=0.0):
        """Run the model.

        Args:
            eeg: 4-D tensor ``(B, W, C, T)`` (presumably windows, channels and
                time samples; confirm against the data files).
            meta: Tensor ``(B, meta_dim)``.
            alpha: Gradient-reversal strength applied in front of the domain head.

        Returns:
            ``(cls_out, dom_out)``: class logits ``(B, 2)`` and domain logits
            ``(B, num_domains)``.

        Raises:
            ValueError: If ``eeg`` is not 4-dimensional.
        """
        if eeg.dim() != 4:
            raise ValueError(
                f"eeg must be 4-D (B, W, C, T), got shape {tuple(eeg.shape)}"
            )
        B, W, C, T = eeg.shape

        # Fold the W axis into the batch so the CNN processes every slice
        # independently. reshape (not view) also accepts non-contiguous input,
        # e.g. a permuted or sliced tensor.
        x = eeg.reshape(B * W, C, T)
        x = self.cnn(x)
        x = x.reshape(B, W, -1)

        x, _ = self.lstm(x)
        eeg_embed = self.attn(x)

        meta_embed = self.meta_fc(meta)
        fused = torch.cat([eeg_embed, meta_embed], dim=1)

        cls_out = self.classifier(fused)
        # Gradient reversal: the domain head learns to predict the domain while
        # the reversed gradient pushes the shared encoder the opposite way.
        dom_out = self.domain(GradReverse.apply(eeg_embed, alpha))

        return cls_out, dom_out


Leave-one-patient-out (LOPO) training and evaluation

In [None]:
def run_lopo(epochs=40, batch_size=4, lr=1e-4, domain_weight=0.2, alpha_warmup_epochs=20):
    """Leave-one-patient-out training and evaluation.

    For every patient in ``unique_pids`` a fresh ``EpilepsyNet`` is trained on
    all other patients and evaluated on the held-out one. The patient-level
    prediction is 1 when the mean class-1 probability over that patient's
    samples exceeds 0.5, and is compared with the label of the patient's first
    sample (this assumes all samples of a patient share one label). The mean
    accuracy over patients is printed.

    Args:
        epochs: Training epochs per fold.
        batch_size: Training batch size.
        lr: Adam learning rate.
        domain_weight: Weight of the domain loss in the total loss.
        alpha_warmup_epochs: Epochs over which the gradient-reversal strength
            ramps linearly from 0 to 1.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # The domain head has len(unique_pids) outputs, so its targets must lie in
    # [0, len(unique_pids)). Raw patient IDs need not be contiguous or
    # zero-based, so map each ID to its rank in the sorted unique_pids array
    # (the identity mapping when the IDs are already 0..N-1).
    domain_ids = np.searchsorted(unique_pids, patient_ids)
    dataset = EEGDataset(eeg_files, meta_data, labels, domain_ids)

    results = []

    for test_pid in unique_pids:

        train_idx = np.where(patient_ids != test_pid)[0]
        test_idx = np.where(patient_ids == test_pid)[0]

        train_ds = torch.utils.data.Subset(dataset, train_idx)
        test_ds = torch.utils.data.Subset(dataset, test_idx)

        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
        # Batch size 1 so each test sample yields exactly one probability below.
        test_loader = DataLoader(test_ds, batch_size=1)

        # New model per fold so nothing leaks from previously held-out patients.
        model = EpilepsyNet(
            meta_dim=meta_data.shape[1],
            num_domains=len(unique_pids)
        ).to(device)

        optimizer = optim.Adam(model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            model.train()
            # Ramp the gradient-reversal strength linearly, capped at 1.
            if alpha_warmup_epochs > 0:
                alpha = min(1.0, epoch / alpha_warmup_epochs)
            else:
                alpha = 1.0

            for eeg, meta, label, pid in train_loader:
                eeg, meta = eeg.to(device), meta.to(device)
                label, pid = label.to(device), pid.to(device)

                optimizer.zero_grad()
                cls, dom = model(eeg, meta, alpha)
                loss = loss_fn(cls, label) + domain_weight * loss_fn(dom, pid)
                loss.backward()
                optimizer.step()

        model.eval()
        probs = []

        with torch.no_grad():
            for eeg, meta, _, _ in test_loader:
                eeg, meta = eeg.to(device), meta.to(device)
                out, _ = model(eeg, meta)
                # Probability of class 1 for the single sample in this batch.
                probs.append(torch.softmax(out, 1)[0,1].item())

        # Aggregate sample-level probabilities into one patient-level decision.
        pred = int(np.mean(probs) > 0.5)
        true = labels[test_idx[0]]

        results.append(pred == true)

    print("Final LOPO Accuracy:", np.mean(results))


Run LOPO

In [None]:
# Train and evaluate one fold per patient; run_lopo prints the final accuracy itself.
run_lopo()
