# ViT (5-channel) — Experiment Sweep with Subject-ID Loader

This notebook preserves the original ViT experiment sweep and **replaces the data loading** with a subject-based loader
that reads per-sample `.npz` files from Google Drive, cuts each sample into sliding windows, aggregates each window into a single image,
and prepares `(N, 5, 8, 8)` tensors for ViT.


In [None]:
!pip -q install transformers==4.42.4 timm==0.9.16 torchvision --upgrade

In [None]:
import os, json, time, math, random, csv
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import ViTConfig, ViTForImageClassification, TrainingArguments, Trainer, set_seed
from sklearn.metrics import (
    accuracy_score, roc_auc_score, precision_score, recall_score, f1_score,
    matthews_corrcoef, cohen_kappa_score, confusion_matrix, classification_report
)
from sklearn.utils.class_weight import compute_class_weight

# Mount Google Drive so the files under DATA_DIR are reachable (requires a Colab runtime).
from google.colab import drive
drive.mount('/content/gdrive')

# DATA_DIR is the folder the loaders list for `.npz` files; RESULTS_ROOT receives
# one sub-folder per run plus the combined summary table.
DATA_DIR = '/content/gdrive/MyDrive/preprocessed_per_sample'
RESULTS_ROOT = './vit_experiments_subject_split'
os.makedirs(RESULTS_ROOT, exist_ok=True)

# Seed the Python, NumPy and PyTorch global RNGs.
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# Model/input geometry handed to ViTConfig: IMAGE_SIZE x IMAGE_SIZE inputs with
# NUM_CHANNELS channels, split into PATCH_SIZE x PATCH_SIZE patches, NUM_CLASSES outputs.
NUM_CLASSES = 3
IMAGE_SIZE = 8
PATCH_SIZE = 2
NUM_CHANNELS = 5

# Subject IDs left out when collecting subjects.
# NOTE(review): the reason for excluding these subjects is not recorded in this notebook.
EXCLUDED_SUBJECTS = {"S038", "S088", "S089", "S092", "S100", "S104"}


In [None]:
def get_subject_ids(data_dir=None, excluded=None):
    """Return the sorted, de-duplicated subject IDs found in the data directory.

    A file contributes an ID when its name starts with ``sample`` and ends
    with ``.npz``. The ID is the first four characters of the second
    underscore-separated field of the file name.

    Args:
        data_dir: Directory to scan. Defaults to the module-level ``DATA_DIR``.
        excluded: Collection of subject IDs to leave out. Defaults to the
            module-level ``EXCLUDED_SUBJECTS``.

    Returns:
        Sorted list of subject-ID strings.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    if excluded is None:
        excluded = EXCLUDED_SUBJECTS

    ids = set()
    for fname in os.listdir(data_dir):
        if not (fname.startswith('sample') and fname.endswith('.npz')):
            continue
        parts = fname.split('_')
        if len(parts) < 2:
            # No subject field at all (e.g. "sample.npz"): skip the file
            # instead of failing the whole scan with an IndexError.
            continue
        sid = parts[1][:4]
        if sid and sid not in excluded:
            ids.add(sid)
    return sorted(ids)

# Subject-level split: every recording of a subject lands in exactly one of
# train / dev / test.
subject_ids = get_subject_ids()
# Shuffle with a dedicated RNG seeded from SEED, so re-running this cell always
# produces the same split regardless of how far the global `random` state has
# already advanced in this session.
random.Random(SEED).shuffle(subject_ids)

# Target sizes are 70 / 15 / 16 subjects; each is clamped to what is left.
n_total = len(subject_ids)
n_train = min(70, n_total)
n_dev = min(15, max(0, n_total - n_train))
n_test = min(16, max(0, n_total - n_train - n_dev))

train_ids = subject_ids[:n_train]
dev_ids   = subject_ids[n_train:n_train+n_dev]
test_ids  = subject_ids[n_train+n_dev:n_train+n_dev+n_test]

print('Subjects:', len(subject_ids))
print('Train IDs:', len(train_ids))
print('Dev IDs  :', len(dev_ids))
print('Test IDs :', len(test_ids))

# Make it visible when the fixed split sizes leave subjects out of every split.
n_unused = n_total - (n_train + n_dev + n_test)
if n_unused > 0:
    print(f'WARNING: {n_unused} subject(s) are not assigned to any split')


In [None]:
def segment_topomaps(X, y, window=16, stride=10):
    """Cut ``X`` into fixed-length windows along axis 0.

    Windows start at 0, ``stride``, 2*``stride``, ... and only full windows of
    ``window`` frames are kept. The label ``y`` is repeated once per window.

    Args:
        X: Array whose first axis is the frame axis.
        y: Label attached to every window.
        window: Number of frames per window.
        stride: Step between consecutive window starts.

    Returns:
        ``(segments, labels)`` as NumPy arrays; both are empty when ``X`` has
        fewer than ``window`` frames.
    """
    last_start = X.shape[0] - window
    starts = range(0, last_start + 1, stride)
    windows = [X[begin:begin + window] for begin in starts]
    window_labels = [y for _ in starts]
    return np.array(windows), np.array(window_labels)

def aggregate_segment(seg, method='mean'):
    """Collapse the first axis of ``seg`` into a single frame.

    Args:
        seg: Array whose first axis indexes the frames of one segment.
        method: ``'mean'`` or ``'median'`` reduce over the first axis; any
            other value selects the middle frame (index ``n_frames // 2``).

    Returns:
        Array with the first axis of ``seg`` removed.
    """
    if method == 'mean':
        return seg.mean(axis=0)
    if method == 'median':
        return np.median(seg, axis=0)
    # Fallback for every other method name: take the centre frame as-is.
    middle = seg.shape[0] // 2
    return seg[middle]

def _to_vit_image(img, target_channels, target_size):
    """Resize one aggregated image and force it to ``target_channels`` channels.

    Accepts a 3-D array (treated as channels, height, width) or a 4-D array
    with a leading singleton axis. Returns a float32 NumPy array of shape
    ``(target_channels, target_size, target_size)``.
    """
    if img.ndim == 4 and img.shape[0] == 1:
        img = img[0]
    elif img.ndim != 3:
        raise ValueError(f'Unexpected segment shape: {img.shape}')

    c_in = img.shape[0]
    # F.interpolate expects a batch axis; add it for the call and drop it after.
    t = torch.tensor(img, dtype=torch.float32).unsqueeze(0)
    t = F.interpolate(t, size=(target_size, target_size), mode='area').squeeze(0)

    if c_in < target_channels:
        # Too few channels: pad by repeating the last channel.
        pad = t[-1:].expand(target_channels - c_in, -1, -1)
        t = torch.cat([t, pad], dim=0)
    elif c_in > target_channels:
        # Too many channels: keep the leading ones.
        t = t[:target_channels]
    return t.numpy()


def load_subjects(subject_ids, window=16, stride=10, aggregate='mean', target_channels=NUM_CHANNELS, target_size=IMAGE_SIZE):
    """Load, segment and aggregate every sample file of the given subjects.

    A file in ``DATA_DIR`` is used when its name starts with ``sample``, ends
    with ``.npz`` and contains ``_<sid>R`` for one of ``subject_ids``. Its
    ``X`` array is cut into windows by ``segment_topomaps``, each window is
    reduced to one image by ``aggregate_segment`` and then resized/padded to
    ``(target_channels, target_size, target_size)``.

    Files are visited in sorted name order and each file is loaded at most
    once, so the returned sample order is reproducible across runs.

    Args:
        subject_ids: Iterable of subject-ID strings to include.
        window: Frames per window (forwarded to ``segment_topomaps``).
        stride: Step between windows (forwarded to ``segment_topomaps``).
        aggregate: Aggregation method (forwarded to ``aggregate_segment``).
        target_channels: Channel count of the output images.
        target_size: Height and width of the output images.

    Returns:
        ``(X_all, y_all)``: float32 array of shape
        ``(N, target_channels, target_size, target_size)`` and int64 array of
        the raw (unmapped) labels, one per image. ``N`` is 0 when nothing
        matched.

    Raises:
        ValueError: If an aggregated segment is neither 3-D nor 4-D with a
            leading singleton axis, or if a file's ``y`` holds more than one
            value.
    """
    markers = {f'_{sid}R' for sid in subject_ids}
    X_all, y_all = [], []
    # os.listdir returns names in arbitrary order; sort for a stable sample order.
    for fname in sorted(os.listdir(DATA_DIR)):
        if not (fname.endswith('.npz') and fname.startswith('sample')):
            continue
        if not any(marker in fname for marker in markers):
            continue
        # NOTE: allow_pickle=True lets NumPy unpickle object arrays, which can
        # execute arbitrary code; only load files from a trusted source.
        # The context manager closes the underlying file handle after reading.
        with np.load(os.path.join(DATA_DIR, fname), allow_pickle=True) as data:
            X = data['X']
            # .item() accepts any single-element array without relying on the
            # deprecated implicit array-to-scalar conversion.
            y = int(np.asarray(data['y']).item())
        segs, _ = segment_topomaps(X, y, window=window, stride=stride)
        for seg in segs:
            img = aggregate_segment(seg, method=aggregate)
            X_all.append(_to_vit_image(img, target_channels, target_size))
            y_all.append(y)
    if X_all:
        X_all = np.stack(X_all, axis=0)
    else:
        X_all = np.empty((0, target_channels, target_size, target_size), dtype=np.float32)
    y_all = np.array(y_all, dtype=np.int64)
    return X_all, y_all

def filter_and_map_labels(y, keep=(2,3,6), mapping=None):
    """Select the samples whose label is in ``keep`` and remap those labels.

    Args:
        y: 1-D array of raw integer labels.
        keep: Raw label values to retain.
        mapping: Dict from raw label to class index. Defaults to
            ``{2: 0, 3: 1, 6: 2}``.

    Returns:
        ``(mask, y_mapped)``: boolean mask over ``y`` marking retained
        samples, and an int64 array of mapped labels for those samples.
        ``y_mapped`` is empty (not an error) when nothing is retained.

    Raises:
        KeyError: If a retained label has no entry in ``mapping``.
    """
    if mapping is None:
        mapping = {2: 0, 3: 1, 6: 2}
    y = np.asarray(y)
    mask = np.isin(y, list(keep))
    # An explicit lookup handles an empty selection (np.vectorize without
    # otypes raises on size-0 input) and fails loudly on an unmapped label
    # instead of silently producing None.
    y_m = np.array([mapping[label] for label in y[mask].tolist()], dtype=np.int64)
    return mask, y_m


In [None]:
# Load the three subject splits with identical segmentation settings.
_SEGMENT_KWARGS = {"window": 16, "stride": 10, "aggregate": "mean"}
X_train, y_train = load_subjects(train_ids, **_SEGMENT_KWARGS)
X_dev, y_dev = load_subjects(dev_ids, **_SEGMENT_KWARGS)
X_test, y_test = load_subjects(test_ids, **_SEGMENT_KWARGS)

# Per split: keep only the retained labels, remap them to class indices, and
# apply the same mask to the images so X and y stay aligned.
m_tr, y_train_m = filter_and_map_labels(y_train)
X_train = X_train[m_tr]

m_de, y_dev_m = filter_and_map_labels(y_dev)
X_dev = X_dev[m_de]

m_te, y_test_m = filter_and_map_labels(y_test)
X_test = X_test[m_te]

print('Train:', X_train.shape, y_train_m.shape)
print('Dev  :', X_dev.shape,   y_dev_m.shape)
print('Test :', X_test.shape,  y_test_m.shape)


In [None]:
class EEGDataset(Dataset):
    """Map-style dataset pairing image tensors with integer class labels.

    Items are dicts with the keys ``pixel_values`` and ``labels``.
    """

    def __init__(self, X, y):
        """Convert NumPy arrays ``X`` and ``y`` to float32 / int64 tensors."""
        self.X = torch.from_numpy(X).to(torch.float32)
        self.y = torch.from_numpy(y).to(torch.int64)

    def __len__(self):
        """Return the number of samples (size of the first axis of ``X``)."""
        return self.X.size(0)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a ``pixel_values`` / ``labels`` dict."""
        sample = {}
        sample["pixel_values"] = self.X[idx]
        sample["labels"] = self.y[idx]
        return sample

train_ds = EEGDataset(X_train, y_train_m)
val_ds   = EEGDataset(X_dev,   y_dev_m)
test_ds  = EEGDataset(X_test,  y_test_m)

# 'balanced' class weights for the loss, computed from the training labels.
# compute_class_weight only accepts classes that occur in `y`, so the weights
# are computed for the classes present and written into a vector of length
# NUM_CLASSES. A class that is missing from the training split keeps the
# neutral weight 1.0, which keeps the vector the size the loss expects.
_present_classes = np.unique(y_train_m)
_weights = np.ones(NUM_CLASSES, dtype=np.float32)
_weights[_present_classes] = compute_class_weight(
    'balanced', classes=_present_classes, y=y_train_m
)
class_weights = torch.tensor(_weights, dtype=torch.float)


In [None]:
def compute_full_metrics(p):
    """Compute classification metrics from a prediction object.

    Args:
        p: Object exposing ``predictions`` (logits, one row per sample) and
            ``label_ids`` (true class indices).

    Returns:
        Dict with accuracy, weighted one-vs-rest AUC, weighted and macro
        precision/recall/F1, Matthews correlation and Cohen's kappa. ``auc``
        is NaN when scikit-learn cannot compute it for the given labels.
    """
    labels = p.label_ids
    preds = np.argmax(p.predictions, axis=1)
    # Softmax in float64: half-precision logits would give rows that do not
    # sum to 1 closely enough for roc_auc_score's multiclass input check.
    logits = torch.tensor(p.predictions, dtype=torch.float64)
    probs = torch.nn.functional.softmax(logits, dim=-1).numpy()
    try:
        auc = roc_auc_score(labels, probs, multi_class='ovr', average='weighted')
    except ValueError:
        # scikit-learn rejects inputs for which the score is undefined (for
        # example a class missing from the labels) with a ValueError; report
        # NaN for that case and let any other error surface.
        auc = float('nan')
    return {
        "accuracy": accuracy_score(labels, preds),
        "auc": auc,
        "precision_weighted": precision_score(labels, preds, average="weighted", zero_division=0),
        "recall_weighted": recall_score(labels, preds, average="weighted", zero_division=0),
        "f1_weighted": f1_score(labels, preds, average="weighted", zero_division=0),
        "precision_macro": precision_score(labels, preds, average="macro", zero_division=0),
        "recall_macro": recall_score(labels, preds, average="macro", zero_division=0),
        "f1_macro": f1_score(labels, preds, average="macro", zero_division=0),
        "matthews_corrcoef": matthews_corrcoef(labels, preds),
        "cohen_kappa": cohen_kappa_score(labels, preds),
    }

class WeightedTrainer(Trainer):
    """Trainer whose loss is cross-entropy weighted by the global ``class_weights``."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the class-weighted cross-entropy loss for one batch.

        Args:
            model: Model to run; it may be a wrapper around the actual model.
            inputs: Batch dict containing ``labels`` plus the model inputs.
            return_outputs: When true, return ``(loss, outputs)``.
            **kwargs: Accepted and ignored, for compatibility with Trainer
                versions that pass extra arguments.

        Returns:
            The loss tensor, or ``(loss, outputs)`` if ``return_outputs``.
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Take device and class count from the logits rather than from
        # `model`: a wrapped model (e.g. nn.DataParallel) does not expose
        # `.device` or `.config`.
        loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights.to(logits.device))
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss


In [None]:
def run_experiment(cfg, run_idx):
    """Train one ViT configuration and evaluate it on train, val and test.

    Creates a run directory under ``RESULTS_ROOT`` and writes the config, the
    training log history, per-split metrics / confusion matrix /
    classification report / raw predictions, a summary and the best model.

    Config keys applied (with defaults): ``seed`` (42), ``lr`` (1e-5),
    ``epochs`` (60), ``batch_size`` (32), ``weight_decay`` (0.01),
    ``dropout`` (0.1), ``hidden_size`` (64), ``num_hidden_layers`` (2),
    ``num_attention_heads`` (4), ``intermediate_size`` (128), ``fp16``
    (CUDA availability). Applied only when present: ``lr_schedule``,
    ``warmup_ratio``, ``grad_accum``, and ``num_cycles`` (cosine schedules
    only). Any other key is saved in ``config.json`` but has no effect; such
    keys are listed in a printed note.

    Args:
        cfg: Dict of hyperparameters for this run.
        run_idx: Integer run number, used in the run directory name and logs.

    Returns:
        ``(summary, run_dir)``: dict with the config and key metrics per
        split, and the ``Path`` of the run directory.
    """
    seed = cfg.get("seed", 42)
    # set_seed seeds Python's `random`, NumPy and PyTorch (CPU and CUDA).
    set_seed(seed)

    lr = cfg.get("lr", 1e-5)
    epochs = cfg.get("epochs", 60)
    batch_size = cfg.get("batch_size", 32)
    weight_decay = cfg.get("weight_decay", 0.01)
    dropout = cfg.get("dropout", 0.1)
    hidden_size = cfg.get("hidden_size", 64)
    num_hidden_layers = cfg.get("num_hidden_layers", 2)
    num_attention_heads = cfg.get("num_attention_heads", 4)
    intermediate_size = cfg.get("intermediate_size", 128)

    cfg_tag = (
        f"LR{lr}_E{epochs}_HS{hidden_size}_L{num_hidden_layers}_H{num_attention_heads}"
        f"_IS{intermediate_size}_WD{weight_decay}_DO{dropout}_S{seed}"
    )
    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    run_dir = Path(RESULTS_ROOT) / f"run{run_idx:02d}_{ts}_{cfg_tag}"
    run_dir.mkdir(parents=True, exist_ok=True)

    with open(run_dir / "config.json", "w") as f:
        json.dump(cfg, f, indent=2)

    vit_cfg = ViTConfig(
        image_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
        num_channels=NUM_CHANNELS,
        num_labels=NUM_CLASSES,
        hidden_size=hidden_size,
        num_hidden_layers=num_hidden_layers,
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size,
        hidden_act="gelu",
        hidden_dropout_prob=dropout,
        attention_probs_dropout_prob=dropout,
    )

    model = ViTForImageClassification(vit_cfg)

    # Optional schedule settings: forwarded only when the config names them,
    # so configs without these keys keep the TrainingArguments defaults.
    optional_args = {}
    if "lr_schedule" in cfg:
        optional_args["lr_scheduler_type"] = cfg["lr_schedule"]
    if "warmup_ratio" in cfg:
        optional_args["warmup_ratio"] = cfg["warmup_ratio"]
    if "grad_accum" in cfg:
        optional_args["gradient_accumulation_steps"] = cfg["grad_accum"]
    if "num_cycles" in cfg and cfg.get("lr_schedule") in ("cosine", "cosine_with_restarts"):
        # Only the cosine schedules take a `num_cycles` argument.
        optional_args["lr_scheduler_kwargs"] = {"num_cycles": cfg["num_cycles"]}

    fp16_flag = cfg.get("fp16", torch.cuda.is_available())
    training_args = TrainingArguments(
        output_dir=str(run_dir / "hf_out"),
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        learning_rate=lr,
        weight_decay=weight_decay,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        # Checkpoints are written every epoch; cap how many are kept on disk.
        # The best checkpoint is retained for load_best_model_at_end.
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1_weighted",
        greater_is_better=True,
        logging_steps=10,
        report_to="none",
        fp16=fp16_flag,
        # Trainer reseeds from args.seed when it is constructed; without this
        # it would fall back to its own default seed instead of the run's.
        seed=seed,
        **optional_args,
    )

    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_full_metrics,
    )

    applied_keys = {
        "run_name", "seed", "lr", "epochs", "batch_size", "weight_decay", "dropout",
        "hidden_size", "num_hidden_layers", "num_attention_heads", "intermediate_size",
        "fp16", "lr_schedule", "warmup_ratio", "grad_accum", "num_cycles",
    }
    ignored_keys = sorted(set(cfg) - applied_keys)

    print(f"===== RUN {run_idx} =====")
    print(f"Config: {cfg}")
    if ignored_keys:
        print(f"Note: config keys recorded but not applied by this runner: {ignored_keys}")
    print("Starting training...")
    trainer.train()
    with open(run_dir / "train_state.json", "w") as f:
        json.dump(trainer.state.log_history, f, indent=2)

    def eval_split(dset, name):
        """Evaluate ``dset``, save its artifacts under ``run_dir/name``, return key metrics."""
        preds = trainer.predict(dset)
        metrics = compute_full_metrics(preds)
        y_pred = np.argmax(preds.predictions, axis=1)
        cm = confusion_matrix(preds.label_ids, y_pred)
        rep = classification_report(preds.label_ids, y_pred, digits=4, zero_division=0)

        split_dir = run_dir / f"{name}"
        split_dir.mkdir(exist_ok=True, parents=True)
        with open(split_dir / "metrics.json", "w") as f:
            json.dump(metrics, f, indent=2)
        np.save(split_dir / "confusion_matrix.npy", cm)
        np.savetxt(split_dir / "confusion_matrix.csv", cm, fmt="%d", delimiter=",")
        with open(split_dir / "classification_report.txt", "w") as f:
            f.write(rep)
        np.save(split_dir / "predictions.npy", preds.predictions)
        np.save(split_dir / "labels.npy", preds.label_ids)

        print(f"— {name.upper()} METRICS —")
        for k, v in metrics.items():
            print(f"{k:20}: {v:.6f}" if isinstance(v, (int, float, np.floating)) else f"{k:20}: {v}")
        print("Confusion Matrix:")
        print(cm)
        print("Classification Report:")
        print(rep)

        summary_keys = ["accuracy","auc","f1_weighted","f1_macro","precision_macro","recall_macro","matthews_corrcoef","cohen_kappa"]
        return {k: float(metrics[k]) for k in summary_keys}

    summary = {"config": cfg}
    summary["train"] = eval_split(train_ds, "train")
    summary["val"]   = eval_split(val_ds,   "val")
    summary["test"]  = eval_split(test_ds,  "test")

    with open(run_dir / "summary.json", "w") as f:
        json.dump(summary, f, indent=2)

    trainer.save_model(str(run_dir / "best_model"))

    return summary, run_dir


In [None]:
# Sweep definition: each dict is one run configuration handed to run_experiment.
# run_experiment looks settings up with cfg.get(...) and falls back to its own
# defaults for missing keys; the whole dict is also written to the run's
# config.json. A key that run_experiment never reads is recorded there but has
# no effect on training — check run_experiment before relying on a key.
SWEEP = [
    {
        "run_name": "S1_reg_ema_es",
        "lr": 2e-5, "weight_decay": 0.02, "batch_size": 32, "epochs": 100,
        "augment": True, "weighted_sampler": True, "freeze_backbone": False,
        "label_smoothing": 0.10, "lr_schedule": "cosine_with_restarts",
        "num_cycles": 5, "warmup_ratio": 0.15, "ema_decay": 0.999,
        "early_stopping": {"metric": "macro_f1", "patience": 10}
    },
    {
        "run_name": "S2_llrd",
        "lr": 2e-5, "head_lr": 5e-5, "layer_decay": 0.75,
        "weight_decay": 0.02, "batch_size": 32, "epochs": 160,
        "augment": True, "weighted_sampler": True, "freeze_backbone": False,
        "label_smoothing": 0.05, "lr_schedule": "cosine", "warmup_ratio": 0.10
    },
    {"run_name": "A_baseline_lr1e-5_b32_e60_wd0.01_augY_freezeN",
     "lr": 1e-5, "weight_decay": 0.01, "batch_size": 32, "epochs": 60,
     "augment": True, "weighted_sampler": False, "freeze_backbone": False,
     "label_smoothing": 0.0, "lr_schedule": "cosine", "warmup_ratio": 0.0, "seed": 42},

    {"run_name": "S_lr3e-5_b64_e60_wd0.01_augY_freezeN_ls0.05",
     "lr": 3e-5, "weight_decay": 0.01, "batch_size": 64, "epochs": 60,
     "augment": True, "weighted_sampler": False, "freeze_backbone": False,
     "label_smoothing": 0.05, "lr_schedule": "cosine", "warmup_ratio": 0.05, "seed": 42},

    {"run_name": "S_lr2e-5_b32_e80_wd0.01_augY_freezeN_sampler",
     "lr": 2e-5, "weight_decay": 0.01, "batch_size": 32, "epochs": 80,
     "augment": True, "weighted_sampler": True, "freeze_backbone": False,
     "label_smoothing": 0.05, "lr_schedule": "cosine", "warmup_ratio": 0.1, "seed": 123},

    {"run_name": "S_freeze_lr5e-5_b64_e40_wd0.02_augY_freezeY",
     "lr": 5e-5, "weight_decay": 0.02, "batch_size": 64, "epochs": 40,
     "augment": True, "weighted_sampler": False, "freeze_backbone": True,
     "label_smoothing": 0.1, "lr_schedule": "cosine", "warmup_ratio": 0.1, "seed": 7},

    {"run_name": "S_noaug_lr1e-5_b32_e60_wd0.01_freezeN_accum4",
     "lr": 1e-5, "weight_decay": 0.01, "batch_size": 32, "epochs": 60,
     "augment": False, "weighted_sampler": False, "freeze_backbone": False,
     "label_smoothing": 0.0, "lr_schedule": "cosine", "warmup_ratio": 0.0, "grad_accum": 4, "seed": 42},
]

# Run every configuration in order (run numbers start at 1) and collect the
# returned summaries, each extended with the path of its run directory.
ALL_SUMMARIES = []
for i, cfg in enumerate(SWEEP, start=1):
    s, rdir = run_experiment(cfg, i)
    s["run_dir"] = str(rdir)
    ALL_SUMMARIES.append(s)

# Combined table: one row per run, the same eight metrics for each split.
table_path = Path(RESULTS_ROOT) / "all_runs_summary.csv"

# (key in the per-split summary dict, column-name suffix), in column order.
metric_columns = [
    ("accuracy", "acc"),
    ("auc", "auc"),
    ("f1_weighted", "f1w"),
    ("f1_macro", "f1m"),
    ("precision_macro", "precm"),
    ("recall_macro", "recm"),
    ("matthews_corrcoef", "mcc"),
    ("cohen_kappa", "kappa"),
]
splits = ["train", "val", "test"]

header = ["run", "run_dir"]
for split in splits:
    header.extend(f"{split}_{suffix}" for _, suffix in metric_columns)

with open(table_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(header)
    for idx, s in enumerate(ALL_SUMMARIES, start=1):
        row = [idx, s["run_dir"]]
        for split in splits:
            m = s[split]
            row.extend(m[key] for key, _ in metric_columns)
        writer.writerow(row)

print("Done. Per-run folders saved under:", RESULTS_ROOT)
print("Combined table:", table_path)

def fmt(v):
    """Format numbers with four decimals; return ``str(v)`` for anything else."""
    numeric_types = (float, int, np.floating)
    if not isinstance(v, numeric_types):
        return str(v)
    return format(v, ".4f")

# Console recap: accuracy, weighted F1 and AUC per split for every run.
print("\n===== SUMMARY (key metrics) =====")
for i, s in enumerate(ALL_SUMMARIES, start=1):
    tr, va, te = [s[split_key] for split_key in ("train", "val", "test")]
    print(f"Run {i}: dir={s['run_dir']}")
    # Labels are padded so the colons line up in the output.
    for label, m in (("Train", tr), ("Val  ", va), ("Test ", te)):
        print(f"  {label}: acc {fmt(m['accuracy'])}, f1w {fmt(m['f1_weighted'])}, auc {fmt(m['auc'])}")
