In [29]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [59]:
%cd /content/drive/MyDrive/EEG/Dataset/
%ls

/content/drive/MyDrive/EEG/Dataset
[0m[01;34mA[0m/  [01;34mB[0m/  [01;34mC[0m/  [01;34mD[0m/  [01;34mE[0m/


In [40]:
!pip install captum

Collecting captum
  Downloading captum-0.8.0-py3-none-any.whl.metadata (26 kB)
Collecting numpy<2.0 (from captum)
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
Downloading captum-0.8.0-py3-none-any.whl (1.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m57.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.0/18.0 MB[0m [31m125.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy, captum
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
[31mERROR: pip's dependency resolver does not currently take into account a

In [83]:
import os
import glob
import random
from dataclasses import dataclass
from typing import List, Tuple, Dict

import numpy as np
import matplotlib.pyplot as plt

from scipy.signal import butter, filtfilt

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix, silhouette_score
from sklearn.manifold import TSNE
from sklearn.linear_model import LogisticRegression

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from captum.attr import IntegratedGradients

print("Imports OK")

Imports OK


In [84]:
c_dir = "/content/drive/MyDrive/EEG/Dataset/C"

files = glob.glob(os.path.join(c_dir, "*.TXT"))

print("Files to rename:", len(files))

for f in files:
    new_name = f[:-4] + ".txt"   # replace .TXT with .txt
    os.rename(f, new_name)

print("Renaming completed.")

Files to rename: 0
Renaming completed.


In [85]:
class Config:
    data_root: str = "/content/drive/MyDrive/EEG/Dataset"
    fs: float = 173.61
    segment_len: int = 4096

    # Preprocessing
    bandpass: Tuple[float, float] = (0.5, 40.0)
    filter_order: int = 4
    standardize_per_segment: bool = True

    # Windowing (keep as full segment by default)
    window_len: int = 4096
    window_stride: int = 4096

    # Training
    seed: int = 42
    batch_size: int = 32
    lr: float = 1e-3
    weight_decay: float = 1e-4
    epochs: int = 30

    # Splits
    test_size: float = 0.2
    val_size: float = 0.2

    # Best model selection metric on test: "f1" | "auc" | "accuracy"
    selection_metric: str = "f1"

    # Base output directory (task subfolders will be created inside)
    out_base: str = "/content/drive/MyDrive/EEG/Outputs_Bonn_Project_Combined"

cfg = Config()
cfg


<__main__.Config at 0x7f82c03d34d0>

In [86]:
for s in ["A","B","C","D","E"]:
    folder = os.path.join(cfg.data_root, s)
    n_txt = len(glob.glob(os.path.join(folder, "*.txt")))
    n_TXT = len(glob.glob(os.path.join(folder, "*.TXT")))
    print(s, "exists:", os.path.isdir(folder), "| .txt:", n_txt, "| .TXT:", n_TXT)


A exists: True | .txt: 100 | .TXT: 0
B exists: True | .txt: 100 | .TXT: 0
C exists: True | .txt: 100 | .TXT: 0
D exists: True | .txt: 100 | .TXT: 0
E exists: True | .txt: 100 | .TXT: 0


In [87]:
def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def ensure_dir(path: str) -> None:
    os.makedirs(path, exist_ok=True)

def compute_metrics(y_true: np.ndarray, y_prob: np.ndarray, threshold: float = 0.5) -> Dict:
    y_pred = (y_prob >= threshold).astype(int)
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    auc = roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) == 2 else float("nan")
    cm = confusion_matrix(y_true, y_pred)
    return {"accuracy": acc, "f1": f1, "auc": auc, "confusion_matrix": cm}

set_seed(cfg.seed)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

ensure_dir(cfg.out_base)
print("Base output:", cfg.out_base)

Device: cuda
Base output: /content/drive/MyDrive/EEG/Outputs_Bonn_Project_Combined


In [88]:
def bandpass_filter(x: np.ndarray, fs: float, low: float, high: float, order: int = 4) -> np.ndarray:
    nyq = 0.5 * fs
    b, a = butter(order, [low / nyq, high / nyq], btype="band")
    return filtfilt(b, a, x).astype(np.float32)

def zscore(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    return ((x - x.mean()) / (x.std() + eps)).astype(np.float32)

def make_windows(x: np.ndarray, window_len: int, stride: int) -> np.ndarray:
    if window_len == len(x) and stride == len(x):
        return x[None, :]
    windows = []
    for start in range(0, len(x) - window_len + 1, stride):
        windows.append(x[start:start + window_len])
    return np.stack(windows, axis=0).astype(np.float32)

print("Preprocessing OK")

Preprocessing OK


In [89]:
def gather_files(data_root: str, sets: List[str]) -> List[str]:
    paths = []
    for s in sets:
        paths.extend(sorted(glob.glob(os.path.join(data_root, s, "*.txt"))))
        paths.extend(sorted(glob.glob(os.path.join(data_root, s, "*.TXT"))))
    return paths

def load_txt_signal(path: str) -> np.ndarray:
    return np.loadtxt(path).astype(np.float32)

class BonnEEGDataset(Dataset):
    def __init__(self, cfg: Config, file_list: List[Tuple[str, int]]):
        self.cfg = cfg
        self.files = file_list
        self._cache = {}
        self.index = []

        for i, (path, _) in enumerate(self.files):
            seg = self._load_and_preprocess(path)
            wins = make_windows(seg, cfg.window_len, cfg.window_stride)
            for w_idx in range(len(wins)):
                self.index.append((i, w_idx))

    def _load_and_preprocess(self, path: str) -> np.ndarray:
        if path in self._cache:
            return self._cache[path]

        x = load_txt_signal(path)

        low, high = self.cfg.bandpass
        x = bandpass_filter(x, self.cfg.fs, low, high, self.cfg.filter_order)

        if self.cfg.standardize_per_segment:
            x = zscore(x)

        self._cache[path] = x
        return x

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx: int):
        file_idx, w_idx = self.index[idx]
        path, label = self.files[file_idx]

        seg = self._load_and_preprocess(path)
        wins = make_windows(seg, self.cfg.window_len, self.cfg.window_stride)
        w = wins[w_idx]

        x_t = torch.tensor(w, dtype=torch.float32).unsqueeze(0)  # (1,L)
        y_t = torch.tensor(label, dtype=torch.long)
        return x_t, y_t, path

print("Dataset OK")


Dataset OK


In [90]:
class CNN1D(nn.Module):
    def __init__(self, in_ch=1, num_classes=2, feat_dim=128):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv1d(in_ch, 32, 7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),

            nn.Conv1d(32, 64, 5, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),

            nn.Conv1d(64, 128, 3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
        )
        self.gap = nn.AdaptiveAvgPool1d(1)
        self.fc_feat = nn.Linear(128, feat_dim)
        self.fc_out = nn.Linear(feat_dim, num_classes)

    def extract_features(self, x):
        h = self.backbone(x)
        h = self.gap(h).squeeze(-1)
        f = self.fc_feat(h)
        return f

    def forward(self, x):
        f = self.extract_features(x)
        return self.fc_out(F.relu(f))


class CNN_BiLSTM(nn.Module):
    def __init__(self, in_ch=1, num_classes=2, feat_dim=128, lstm_hidden=64):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(in_ch, 32, 7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(4),

            nn.Conv1d(32, 64, 5, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(4),
        )
        self.lstm = nn.LSTM(
            input_size=64,
            hidden_size=lstm_hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
        self.fc_feat = nn.Linear(2 * lstm_hidden, feat_dim)
        self.fc_out = nn.Linear(feat_dim, num_classes)

    def extract_features(self, x):
        h = self.cnn(x)
        h = h.transpose(1, 2)
        out, _ = self.lstm(h)
        last = out[:, -1, :]
        f = self.fc_feat(last)
        return f

    def forward(self, x):
        f = self.extract_features(x)
        return self.fc_out(F.relu(f))


class Transformer1D(nn.Module):
    def __init__(self, in_ch=1, num_classes=2, feat_dim=128, d_model=128, nhead=4, num_layers=2):
        super().__init__()
        self.patch = nn.Conv1d(in_ch, d_model, kernel_size=16, stride=8, padding=8)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        self.fc_feat = nn.Linear(d_model, feat_dim)
        self.fc_out = nn.Linear(feat_dim, num_classes)

    def extract_features(self, x):
        h = self.patch(x)
        h = h.transpose(1, 2)
        z = self.encoder(h)
        pooled = z.mean(dim=1)
        f = self.fc_feat(pooled)
        return f

    def forward(self, x):
        f = self.extract_features(x)
        return self.fc_out(F.relu(f))

print("Models OK")

Models OK


In [91]:
@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: str) -> Dict:
    model.eval()
    y_true, y_prob = [], []
    for x, y, _ in loader:
        x = x.to(device)
        logits = model(x)
        prob = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy()
        y_true.append(y.numpy())
        y_prob.append(prob)
    y_true = np.concatenate(y_true)
    y_prob = np.concatenate(y_prob)
    return compute_metrics(y_true, y_prob)

def train_one_model(model, train_loader, val_loader, device, cfg, class_weights):
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    loss_fn = nn.CrossEntropyLoss(weight=class_weights.to(device))

    best_val = -1.0
    best_state = None

    for epoch in range(1, cfg.epochs + 1):
        model.train()
        for x, y, _ in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            logits = model(x)
            loss = loss_fn(logits, y)
            loss.backward()
            opt.step()

        val_metrics = evaluate(model, val_loader, device)
        score = val_metrics[cfg.selection_metric]
        print(f"Epoch {epoch:02d} | val_{cfg.selection_metric}={score:.4f} | val_acc={val_metrics['accuracy']:.4f} | val_auc={val_metrics['auc']:.4f}")

        if score > best_val:
            best_val = score
            best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}

    model.load_state_dict(best_state)
    return model, best_val


# ---------- Feature extraction ----------
@torch.no_grad()
def extract_embeddings(model: nn.Module, loader: DataLoader, device: str):
    model.eval()
    feats, labels = [], []
    for x, y, _ in loader:
        x = x.to(device)
        f = model.extract_features(x).detach().cpu().numpy()
        feats.append(f)
        labels.append(y.numpy())
    return np.concatenate(feats), np.concatenate(labels)

def linear_probe(feats: np.ndarray, labels: np.ndarray, seed: int = 42):
    Xtr, Xte, ytr, yte = train_test_split(feats, labels, test_size=0.3, random_state=seed, stratify=labels)
    clf = LogisticRegression(max_iter=2000)
    clf.fit(Xtr, ytr)
    pred = clf.predict(Xte)
    return {"probe_acc": accuracy_score(yte, pred), "probe_f1": f1_score(yte, pred)}

def save_tsne_plot(z2: np.ndarray, labels: np.ndarray, title: str, out_path: str):
    plt.figure()
    plt.scatter(z2[:, 0], z2[:, 1], c=labels)
    plt.title(title)
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close()


# ---------- XAI: IG + Occlusion ----------
@torch.no_grad()
def model_prob_class1(model: nn.Module, x: torch.Tensor) -> float:
    return torch.softmax(model(x), dim=1)[0, 1].item()

def occlusion_attribution_1d(model: nn.Module, x: torch.Tensor, window: int = 128, stride: int = 64, baseline: float = 0.0) -> np.ndarray:
    model.eval()
    x = x.clone()
    L = x.shape[-1]
    base_p = model_prob_class1(model, x)

    attr = np.zeros(L, dtype=np.float32)
    counts = np.zeros(L, dtype=np.float32)

    for start in range(0, L - window + 1, stride):
        x_occ = x.clone()
        x_occ[..., start:start+window] = baseline
        p_occ = model_prob_class1(model, x_occ)
        drop = base_p - p_occ
        attr[start:start+window] += drop
        counts[start:start+window] += 1.0

    counts[counts == 0] = 1.0
    attr = attr / counts
    attr = np.maximum(attr, 0)
    attr = attr / (attr.max() + 1e-8)
    return attr

def deletion_insertion_curve(model: nn.Module, x: torch.Tensor, attr: np.ndarray, target_class: int = 1, steps: int = 20):
    model.eval()
    L = x.shape[-1]
    idx = np.argsort(-attr)
    baseline = torch.zeros_like(x)

    @torch.no_grad()
    def prob(inp):
        return torch.softmax(model(inp), dim=1)[0, target_class].item()

    del_probs = [prob(x)]
    ins_probs = [prob(baseline)]

    k = max(1, L // steps)
    x_del = x.clone()
    x_ins = baseline.clone()

    for s in range(steps):
        sel = idx[s*k:(s+1)*k]
        x_del[..., sel] = 0.0
        x_ins[..., sel] = x[..., sel]
        del_probs.append(prob(x_del))
        ins_probs.append(prob(x_ins))

    return np.array(del_probs), np.array(ins_probs)

print("Core helpers OK")


Core helpers OK


In [92]:
TASKS = {
    "AB_vs_E": {"non_seizure": ("A","B"), "seizure": ("E",)},
    "CD_vs_E": {"non_seizure": ("C","D"), "seizure": ("E",)},
}

final_summary_rows = []  # will store task-level summary for final comparison

for task_name, task_def in TASKS.items():
    print("\n" + "="*70)
    print("RUNNING TASK:", task_name)
    print("="*70)

    # task-specific out dir
    out_dir = os.path.join(cfg.out_base, task_name)
    ensure_dir(out_dir)

    # Build file list
    non_sets = list(task_def["non_seizure"])
    seiz_sets = list(task_def["seizure"])

    files_0 = [(p, 0) for p in gather_files(cfg.data_root, non_sets)]
    files_1 = [(p, 1) for p in gather_files(cfg.data_root, seiz_sets)]
    files = files_0 + files_1

    if len(files) == 0:
        raise RuntimeError(f"No files found for task {task_name}. Check folders/extensions.")

    labels = [y for _, y in files]

    # Split
    train_files, test_files = train_test_split(files, test_size=cfg.test_size, random_state=cfg.seed, stratify=labels)
    train_labels = [y for _, y in train_files]
    train_files, val_files = train_test_split(train_files, test_size=cfg.val_size, random_state=cfg.seed, stratify=train_labels)

    print("Counts:", {"train": len(train_files), "val": len(val_files), "test": len(test_files)})

    # Datasets/loaders
    train_ds = BonnEEGDataset(cfg, train_files)
    val_ds   = BonnEEGDataset(cfg, val_files)
    test_ds  = BonnEEGDataset(cfg, test_files)

    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True)
    val_loader   = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False)
    test_loader  = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False)

    # Class weights
    y_train = np.array([y for _, y in train_files])
    n0 = int((y_train == 0).sum()); n1 = int((y_train == 1).sum())
    w0 = (n0 + n1) / (2.0 * max(n0, 1))
    w1 = (n0 + n1) / (2.0 * max(n1, 1))
    class_weights = torch.tensor([w0, w1], dtype=torch.float32)
    print("Class weights:", class_weights.tolist(), f"(n0={n0}, n1={n1})")

    # Train 3 models
    candidates = {
        "CNN1D": CNN1D(),
        "CNN_BiLSTM": CNN_BiLSTM(),
        "Transformer1D": Transformer1D(),
    }

    results = {}

    for model_name, model in candidates.items():
        print("\n--- Training model:", model_name, "| task:", task_name)
        model, best_val = train_one_model(model, train_loader, val_loader, device, cfg, class_weights)
        test_metrics = evaluate(model, test_loader, device)
        print("TEST:", test_metrics)

        ckpt_path = os.path.join(out_dir, f"{model_name}_best.pt")
        torch.save(model.state_dict(), ckpt_path)

        results[model_name] = {"best_val": best_val, "test": test_metrics, "ckpt": ckpt_path}

    # Best model by selection metric on TEST
    best_model_name = max(results.keys(), key=lambda k: results[k]["test"][cfg.selection_metric])
    best_ckpt = results[best_model_name]["ckpt"]

    print("\nBEST MODEL for", task_name, "=", best_model_name)
    print("Best test metrics:", results[best_model_name]["test"])

    # Save summary.txt
    summary_path = os.path.join(out_dir, "summary.txt")
    with open(summary_path, "w") as f:
        for m, r in results.items():
            f.write(f"{m}\n  ckpt: {r['ckpt']}\n  best_val: {r['best_val']}\n  test: {r['test']}\n\n")
        f.write(f"BEST: {best_model_name}\n")
    print("Saved:", summary_path)

    # -----------------------
    # Feature extraction & comparison
    # -----------------------
    print("\nFEATURE EXTRACTION for", task_name)

    full_ds = BonnEEGDataset(cfg, files)
    full_loader = DataLoader(full_ds, batch_size=cfg.batch_size, shuffle=False)

    feature_report_path = os.path.join(out_dir, "feature_report.txt")
    with open(feature_report_path, "w") as fr:
        for m in candidates.keys():
            # reload model
            if m == "CNN1D":
                model = CNN1D()
            elif m == "CNN_BiLSTM":
                model = CNN_BiLSTM()
            else:
                model = Transformer1D()

            model.load_state_dict(torch.load(os.path.join(out_dir, f"{m}_best.pt"), map_location="cpu"))
            model.to(device).eval()

            feats, y_np = extract_embeddings(model, full_loader, device)

            sil = silhouette_score(feats, y_np) if len(np.unique(y_np)) == 2 else float("nan")
            probe = linear_probe(feats, y_np, cfg.seed)

            msg = f"{m}\n  emb_shape={feats.shape}\n  silhouette={sil}\n  linear_probe={probe}\n"
            print(msg)
            fr.write(msg + "\n")

            tsne = TSNE(n_components=2, perplexity=20, random_state=cfg.seed)
            z2 = tsne.fit_transform(feats)
            tsne_path = os.path.join(out_dir, f"tsne_{m}.png")
            save_tsne_plot(z2, y_np, f"t-SNE Embeddings: {m} ({task_name})", tsne_path)
            print("Saved:", tsne_path)

    print("Saved feature report:", feature_report_path)

    # -----------------------
    # XAI on best model: IG + Occlusion
    # -----------------------
    print("\nXAI for best model (IG + Occlusion):", best_model_name, "| task:", task_name)

    # load best model
    if best_model_name == "CNN1D":
        best_model = CNN1D()
    elif best_model_name == "CNN_BiLSTM":
        best_model = CNN_BiLSTM()
    else:
        best_model = Transformer1D()

    best_model.load_state_dict(torch.load(best_ckpt, map_location="cpu"))
    best_model.to(device).eval()

    ig = IntegratedGradients(best_model)

    # sample explanations from TEST
    xai_ds = BonnEEGDataset(cfg, test_files)
    xai_loader = DataLoader(xai_ds, batch_size=1, shuffle=True)

    samples = []
    for x, y, path in xai_loader:
        samples.append((x.to(device), int(y.item()), path[0]))
        if len(samples) >= 20:   # more stable than 6
            break

    xai_rows = []
    for i, (x, y, p) in enumerate(samples):
        # IG
        ig_attr = ig.attribute(x, target=1, baselines=torch.zeros_like(x))
        ig_attr = ig_attr.squeeze().detach().cpu().numpy()
        ig_attr = np.abs(ig_attr)
        ig_attr = ig_attr / (ig_attr.max() + 1e-8)

        del_ig, ins_ig = deletion_insertion_curve(best_model, x, ig_attr, target_class=1, steps=20)

        # Occlusion
        occ_attr = occlusion_attribution_1d(best_model, x, window=128, stride=64, baseline=0.0)
        del_occ, ins_occ = deletion_insertion_curve(best_model, x, occ_attr, target_class=1, steps=20)

        xai_rows.append({
            "sample": i, "label": y, "path": p,
            "IG_deletion_auc": float(np.trapz(del_ig)),
            "IG_insertion_auc": float(np.trapz(ins_ig)),
            "OCC_deletion_auc": float(np.trapz(del_occ)),
            "OCC_insertion_auc": float(np.trapz(ins_occ)),
        })

        # Save attribution plots (optional but good for report)
        sig = x.squeeze().detach().cpu().numpy()

        plt.figure(); plt.plot(sig); plt.title(f"Signal sample={i} label={y} ({task_name})")
        plt.savefig(os.path.join(out_dir, f"xai_signal_{i}.png"), dpi=200, bbox_inches="tight"); plt.close()

        plt.figure(); plt.plot(ig_attr); plt.title(f"Integrated Gradients sample={i} ({task_name})")
        plt.savefig(os.path.join(out_dir, f"xai_ig_attr_{i}.png"), dpi=200, bbox_inches="tight"); plt.close()

        plt.figure(); plt.plot(occ_attr); plt.title(f"Occlusion Attribution sample={i} ({task_name})")
        plt.savefig(os.path.join(out_dir, f"xai_occ_attr_{i}.png"), dpi=200, bbox_inches="tight"); plt.close()

    # Aggregate XAI
    ig_del = float(np.mean([r["IG_deletion_auc"] for r in xai_rows]))
    ig_ins = float(np.mean([r["IG_insertion_auc"] for r in xai_rows]))
    occ_del = float(np.mean([r["OCC_deletion_auc"] for r in xai_rows]))
    occ_ins = float(np.mean([r["OCC_insertion_auc"] for r in xai_rows]))

    if (occ_del < ig_del) and (occ_ins > ig_ins):
        xai_winner = "Occlusion"
    elif (ig_del < occ_del) and (ig_ins > occ_ins):
        xai_winner = "Integrated Gradients"
    else:
        xai_winner = "Mixed"

    print("Aggregate faithfulness:")
    print("IG   del_auc (lower better):", ig_del, "| ins_auc (higher better):", ig_ins)
    print("OCC  del_auc (lower better):", occ_del, "| ins_auc (higher better):", occ_ins)
    print("Winner XAI:", xai_winner)

    # Save XAI report
    xai_report_path = os.path.join(out_dir, "xai_report.txt")
    with open(xai_report_path, "w") as f:
        for r in xai_rows:
            f.write(str(r) + "\n")
        f.write("\nAGGREGATE\n")
        f.write(f"IG_del={ig_del}, IG_ins={ig_ins}\n")
        f.write(f"OCC_del={occ_del}, OCC_ins={occ_ins}\n")
        f.write(f"WINNER_XAI={xai_winner}\n")
    print("Saved:", xai_report_path)

    # -----------------------
    # Collect task-level summary row for final comparison table
    # -----------------------
    best_test = results[best_model_name]["test"]
    final_summary_rows.append({
        "task": task_name,
        "non_seizure_sets": str(tuple(non_sets)),
        "seizure_sets": str(tuple(seiz_sets)),
        "best_model": best_model_name,
        "best_test_accuracy": best_test["accuracy"],
        "best_test_f1": best_test["f1"],
        "best_test_auc": best_test["auc"],
        "xai_winner": xai_winner,
        "IG_del_auc": ig_del,
        "IG_ins_auc": ig_ins,
        "OCC_del_auc": occ_del,
        "OCC_ins_auc": occ_ins,
        "out_dir": out_dir
    })

print("\n" + "="*70)
print("ALL TASKS COMPLETED")
print("="*70)


RUNNING TASK: AB_vs_E
Counts: {'train': 192, 'val': 48, 'test': 60}
Class weights: [0.75, 1.5] (n0=128, n1=64)

--- Training model: CNN1D | task: AB_vs_E
Epoch 01 | val_f1=0.5079 | val_acc=0.3542 | val_auc=0.8789
Epoch 02 | val_f1=0.5000 | val_acc=0.3333 | val_auc=0.8926
Epoch 03 | val_f1=0.5000 | val_acc=0.3333 | val_auc=0.9160
Epoch 04 | val_f1=0.5000 | val_acc=0.3333 | val_auc=0.9453
Epoch 05 | val_f1=0.5000 | val_acc=0.3333 | val_auc=0.9629
Epoch 06 | val_f1=0.5000 | val_acc=0.3333 | val_auc=0.9844
Epoch 07 | val_f1=0.5246 | val_acc=0.3958 | val_auc=0.9961
Epoch 08 | val_f1=0.5246 | val_acc=0.3958 | val_auc=1.0000
Epoch 09 | val_f1=0.7111 | val_acc=0.7292 | val_auc=1.0000
Epoch 10 | val_f1=0.9412 | val_acc=0.9583 | val_auc=1.0000
Epoch 11 | val_f1=1.0000 | val_acc=1.0000 | val_auc=1.0000
Epoch 12 | val_f1=1.0000 | val_acc=1.0000 | val_auc=1.0000
Epoch 13 | val_f1=0.8649 | val_acc=0.8958 | val_auc=1.0000
Epoch 14 | val_f1=1.0000 | val_acc=1.0000 | val_auc=1.0000
Epoch 15 | val_f1=1

In [93]:
import pandas as pd

df = pd.DataFrame(final_summary_rows)
df = df.sort_values(by=["task"]).reset_index(drop=True)

display(df)

csv_path = os.path.join(cfg.out_base, "FINAL_COMPARISON_TABLE.csv")
df.to_csv(csv_path, index=False)
print("Saved final comparison table to:", csv_path)

Unnamed: 0,task,non_seizure_sets,seizure_sets,best_model,best_test_accuracy,best_test_f1,best_test_auc,xai_winner,IG_del_auc,IG_ins_auc,OCC_del_auc,OCC_ins_auc,out_dir
0,AB_vs_E,"('A', 'B')","('E',)",Transformer1D,0.983333,0.974359,1.0,Occlusion,19.356909,9.340502,18.908969,18.963056,/content/drive/MyDrive/EEG/Outputs_Bonn_Projec...
1,CD_vs_E,"('C', 'D')","('E',)",CNN1D,0.966667,0.947368,0.9825,Mixed,7.475885,14.42531,1.493621,2.520964,/content/drive/MyDrive/EEG/Outputs_Bonn_Projec...


Saved final comparison table to: /content/drive/MyDrive/EEG/Outputs_Bonn_Project_Combined/FINAL_COMPARISON_TABLE.csv
