# In [None]:
import os
import json
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision.transforms import functional as TF
from torchvision.transforms.functional import InterpolationMode
from torchvision.models import resnet18, ResNet18_Weights

from sklearn.metrics import (
    accuracy_score, roc_auc_score, precision_score, recall_score, f1_score,
    matthews_corrcoef, cohen_kappa_score, confusion_matrix, classification_report
)
from sklearn.utils.class_weight import compute_class_weight

# --- Paths ------------------------------------------------------------------
# Directory scanned for the per-sample ``sample*.npz`` files.
DATA_DIR = "/content/gdrive/MyDrive/1segment_topomap_5channel_11classes"
# Root folder that receives one sub-directory per run plus the sweep summaries.
RESULTS_ROOT = "./resnet_topomap_results_sweep"
os.makedirs(RESULTS_ROOT, exist_ok=True)

# --- Defaults (a run config may override batch size and epoch count) --------
BATCH_SIZE = 32
EPOCHS = 60
NUM_CLASSES = 3      # number of distinct target indices produced by LABEL_MAP
IMAGE_SIZE = 224     # side length the samples are resized to before the model
RANDOM_SEED = 42
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# --- Data selection ----------------------------------------------------------
# Indices along the first axis of each sample that are kept as input channels.
THREE_BAND_IDX = [1, 2, 3]
# Subject IDs that are never included in any split.
EXCLUDED_SUBJECTS = {"S038", "S088", "S089", "S092", "S100", "S104"}
# Raw label stored in a file -> training target index (raw 3 and 7 share index 1).
LABEL_MAP = {2: 0, 3: 1, 6: 2, 7: 1}
# Samples whose raw label is not a key of LABEL_MAP are skipped at load time.
VALID_CLASSES = set(LABEL_MAP)

def get_subject_ids(data_dir=None, excluded=None):
    """Return the sorted, de-duplicated subject IDs found in a data directory.

    A file contributes an ID when its name starts with ``"sample"`` and ends
    with ``".npz"``; the ID is the first four characters of the second
    underscore-separated field of the file name.

    Args:
        data_dir: Directory to scan. Defaults to the module-level ``DATA_DIR``.
        excluded: Collection of subject IDs to leave out. Defaults to the
            module-level ``EXCLUDED_SUBJECTS``.

    Returns:
        Sorted list of unique subject-ID strings.

    Raises:
        OSError: If ``data_dir`` cannot be listed.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    if excluded is None:
        excluded = EXCLUDED_SUBJECTS

    ids = set()
    for fname in os.listdir(data_dir):
        if not (fname.startswith("sample") and fname.endswith(".npz")):
            continue
        fields = fname.split("_")
        # A matching file without an underscore (e.g. "sample.npz") carries no
        # subject field; skip it instead of failing with IndexError.
        if len(fields) < 2:
            continue
        sid = fields[1][:4]
        if sid and sid not in excluded:
            ids.add(sid)
    return sorted(ids)

# Seed Python's global RNG first so the shuffle below (and therefore the
# subject-to-split assignment) is the same on every run.
random.seed(RANDOM_SEED)
subject_ids = get_subject_ids()
random.shuffle(subject_ids)

def _slice(lst, a, b):
    """Return ``lst[a:b]`` with both bounds capped at ``len(lst)``."""
    upper = len(lst)
    start = a if a < upper else upper
    stop = b if b < upper else upper
    return lst[start:stop]

# Subject-level split over the shuffled ID list. The bounds are half-open
# and contiguous (each split starts where the previous one stops) so every
# subject lands in exactly one split. The earlier bounds (0-70, 71-86,
# 87-103) silently dropped the subjects at positions 70 and 86.
train_ids = _slice(subject_ids, 0, 70)
dev_ids   = _slice(subject_ids, 70, 86)
test_ids  = _slice(subject_ids, 86, 103)

def load_chunks_by_subject_ids(subject_ids):
    """Load, filter and normalise every sample belonging to the given subjects.

    Files are taken from ``DATA_DIR``; a file belongs to subject ``sid`` when
    its name contains ``"_{sid}R"``. Each file must provide arrays ``"X"`` and
    ``"y"``. A sample is skipped when its label is not in ``VALID_CLASSES``,
    when ``X`` is not 3-D, or when its first axis is too short to index with
    ``THREE_BAND_IDX``. Kept samples are reduced to the ``THREE_BAND_IDX``
    channels and standardised with the mean and standard deviation of the
    whole selected array.

    Args:
        subject_ids: Iterable of subject-ID strings.

    Returns:
        Tuple ``(X_all, y_all)``: a stacked float32 array with one entry per
        kept sample, and an int array of labels remapped through ``LABEL_MAP``.

    Raises:
        RuntimeError: If no sample survives the filters.
        ValueError: From ``np.stack`` if kept samples differ in shape, or from
            ``y.item()`` if a stored label is not a single value.
    """
    # Sort the listing: os.listdir returns entries in arbitrary order, which
    # would make the sample order (and anything seeded on top of it) depend on
    # the filesystem.
    files = sorted(
        f for f in os.listdir(DATA_DIR)
        if f.endswith(".npz") and f.startswith("sample")
    )
    min_channels = max(THREE_BAND_IDX)

    X_all, y_all = [], []
    for sid in subject_ids:
        sid_token = f"_{sid}R"
        for fname in files:
            if sid_token not in fname:
                continue
            # NpzFile keeps the archive open until closed; the context manager
            # releases the handle once both arrays have been read into memory.
            with np.load(os.path.join(DATA_DIR, fname)) as data:
                X, y = data["X"], data["y"]
            if isinstance(y, np.ndarray):
                y = y.item()
            if y not in VALID_CLASSES:
                continue
            if X.ndim != 3 or X.shape[0] <= min_channels:
                continue
            X = X[THREE_BAND_IDX, :, :]
            # Per-sample standardisation; the epsilon guards constant inputs.
            X = (X - X.mean()) / (X.std() + 1e-8)
            X_all.append(X.astype(np.float32))
            y_all.append(LABEL_MAP[y])

    if not X_all:
        raise RuntimeError("No samples after subject-based loading.")
    return np.stack(X_all), np.array(y_all, dtype=int)

# Materialise the three subject-disjoint splits (train, then dev, then test).
(X_train, y_train), (X_val, y_val), (X_test, y_test) = [
    load_chunks_by_subject_ids(ids) for ids in (train_ids, dev_ids, test_ids)
]

def resize_224(x: torch.Tensor) -> torch.Tensor:
    """Bilinearly resize one sample to ``IMAGE_SIZE`` x ``IMAGE_SIZE``.

    A leading batch axis is added for ``F.interpolate`` and removed again, so
    the result has the same number of dimensions as the input.
    """
    batched = x[None]
    resized = F.interpolate(
        batched,
        size=(IMAGE_SIZE, IMAGE_SIZE),
        mode="bilinear",
        align_corners=False,
    )
    return resized[0]

class TensorAug:
    """Callable transform: resize, then optionally flip and rotate at random.

    Args:
        do_aug: When False only the resize is applied.
        max_rotate: Rotation angle is drawn uniformly from
            ``[-max_rotate, max_rotate)`` degrees.
        hflip_p: Probability of flipping along the last axis.

    NOTE(review): the flip mirrors the last (width) axis of each map. If the
    labels depend on left/right orientation this augmentation would not be
    label-preserving - confirm against the dataset definition.
    """

    def __init__(self, do_aug=True, max_rotate=15.0, hflip_p=0.5):
        self.do_aug = do_aug
        self.max_rotate = max_rotate
        self.hflip_p = hflip_p

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        out = resize_224(x)
        if self.do_aug:
            # One random draw decides the flip, a second one picks the angle.
            flip_draw = torch.rand(1).item()
            if flip_draw < self.hflip_p:
                out = out.flip(dims=[2])
            angle_draw = torch.rand(1).item()
            angle = (angle_draw * 2 * self.max_rotate) - self.max_rotate
            out = TF.rotate(
                out,
                angle,
                interpolation=InterpolationMode.BILINEAR,
                expand=False,
            )
        return out

# Shared transform instances: augmentation for training, resize-only for evaluation.
train_transform, eval_transform = TensorAug(do_aug=True), TensorAug(do_aug=False)

class EEGDataset(Dataset):
    """Map-style dataset over in-memory arrays of samples and labels.

    Args:
        X: Indexable collection of NumPy sample arrays.
        y: Indexable collection of labels, aligned with ``X``.
        transform: Optional callable applied to each sample tensor.
    """

    def __init__(self, X, y, transform=None):
        self.X, self.y = X, y
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` as tensors; the label has dtype int64."""
        sample = torch.from_numpy(self.X[idx])
        target = torch.tensor(self.y[idx], dtype=torch.long)
        if self.transform is None:
            return sample, target
        return self.transform(sample), target

def compute_full_metrics(y_true, logits):
    """Compute the full set of classification metrics from raw logits.

    Predictions are the arg-max over the last axis; AUC uses the softmax
    probabilities (positive-class column for two classes, weighted
    one-vs-rest otherwise) and falls back to NaN when it cannot be computed.

    Args:
        y_true: Ground-truth class indices.
        logits: 2-D array of per-class scores, one row per sample.

    Returns:
        Dict with accuracy, auc, weighted and macro precision/recall/F1,
        Matthews correlation, Cohen's kappa, the confusion matrix (nested
        list) and the text classification report.
    """
    preds = np.argmax(logits, axis=1)
    probs = torch.softmax(torch.tensor(logits), dim=-1).numpy()

    try:
        auc = (
            roc_auc_score(y_true, probs[:, 1])
            if probs.shape[1] == 2
            else roc_auc_score(y_true, probs, multi_class="ovr", average="weighted")
        )
    except Exception:
        auc = float("nan")

    metrics = {"accuracy": accuracy_score(y_true, preds), "auc": float(auc)}
    # Insertion order is kept identical for both averaging modes: precision,
    # recall, f1 - first weighted, then macro.
    scorers = (("precision", precision_score), ("recall", recall_score), ("f1", f1_score))
    for averaging in ("weighted", "macro"):
        for label, scorer in scorers:
            metrics[f"{label}_{averaging}"] = scorer(
                y_true, preds, average=averaging, zero_division=0
            )
    metrics["matthews_corrcoef"] = matthews_corrcoef(y_true, preds)
    metrics["cohen_kappa"] = cohen_kappa_score(y_true, preds)
    metrics["confusion_matrix"] = confusion_matrix(y_true, preds).tolist()
    metrics["classification_report"] = classification_report(y_true, preds, zero_division=0)
    return metrics

def build_sampler(y_array):
    """Build a class-balancing ``WeightedRandomSampler`` for the given labels.

    Every sample receives the "balanced" class weight of its label, and the
    sampler draws ``len(y_array)`` indices with replacement.

    Args:
        y_array: 1-D sequence of class labels, one per training sample.

    Returns:
        A ``WeightedRandomSampler`` over the samples of ``y_array``.
    """
    y_array = np.asarray(y_array)
    # compute_class_weight returns weights ordered like ``classes``. Look the
    # weights up through each sample's position in ``classes`` rather than its
    # raw label value, so labels that are not exactly 0..K-1 (for example when
    # a class is absent) still receive the correct weight.
    classes, class_pos = np.unique(y_array, return_inverse=True)
    cw = compute_class_weight("balanced", classes=classes, y=y_array)
    per_class = torch.tensor(cw, dtype=torch.float32)
    sample_weights = per_class[torch.as_tensor(class_pos.reshape(-1), dtype=torch.long)]
    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )

def load_resnet(num_classes, freeze_backbone=False):
    """Create a ResNet-18 with default pretrained weights and a new classifier.

    Args:
        num_classes: Output size of the replacement ``fc`` layer.
        freeze_backbone: When True, every parameter whose name does not start
            with ``"fc"`` has ``requires_grad`` switched off.

    Returns:
        The configured model.
    """
    net = resnet18(weights=ResNet18_Weights.DEFAULT)
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    if not freeze_backbone:
        return net
    for pname, param in net.named_parameters():
        if pname.startswith("fc"):
            continue
        param.requires_grad = False
    return net

def _collect_predictions(model, loader):
    """Run ``model`` in eval mode over ``loader``; return ``(labels, logits)`` as NumPy arrays."""
    model.eval()
    logits, labels = [], []
    with torch.no_grad():
        for xb, yb in loader:
            out = model(xb.to(DEVICE))
            logits.append(out.cpu().numpy())
            labels.append(yb.numpy())
    return np.concatenate(labels, axis=0), np.concatenate(logits, axis=0)


def _train_one_epoch(model, loader, loss_fn, optimizer, grad_accum):
    """Train for one pass over ``loader``; return the mean (unscaled) batch loss."""
    model.train()
    num_batches = len(loader)
    running_loss = 0.0
    optimizer.zero_grad(set_to_none=True)
    for i, (xb, yb) in enumerate(loader, start=1):
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        loss = loss_fn(model(xb), yb)
        # Only the back-propagated value is scaled for accumulation; the
        # logged loss stays comparable across grad_accum settings.
        (loss / grad_accum).backward()
        # Also step on the final batch so a trailing, incomplete accumulation
        # group is applied instead of being wiped by the next zero_grad.
        if i % grad_accum == 0 or i == num_batches:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
        running_loss += loss.item()
    return running_loss / max(1, num_batches)


def run_experiment(cfg):
    """Train and evaluate one configuration, then persist its artefacts.

    Keys read from ``cfg`` (all optional): ``seed``, ``run_name``, ``augment``,
    ``weighted_sampler``, ``batch_size``, ``freeze_backbone``,
    ``label_smoothing``, ``lr``, ``weight_decay``, ``lr_schedule``
    (``"cosine"`` enables cosine annealing, anything else disables the
    scheduler), ``epochs`` and ``grad_accum``. Any other key (for example
    ``warmup_ratio``) is stored with the results but has no effect here.

    The checkpoint with the best validation weighted F1 is restored before
    the final evaluation. Written to ``RESULTS_ROOT/<run_name>``:
    ``train_state.json`` (per-epoch log), ``results.json`` (config plus full
    metrics for each split) and ``best_model.pth``.

    Args:
        cfg: Dict of run settings; must be JSON-serialisable.

    Returns:
        Dict with ``config``, ``run_dir`` and scalar metric summaries under
        ``train``, ``val`` and ``test``.
    """
    seed = cfg.get("seed", RANDOM_SEED)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    run_name = cfg.get("run_name", f"run_{int(time.time())}")
    out_dir = os.path.join(RESULTS_ROOT, run_name)
    os.makedirs(out_dir, exist_ok=True)

    epochs = int(cfg.get("epochs", EPOCHS))
    grad_accum = max(1, int(cfg.get("grad_accum", 1)))
    loader_kwargs = {
        "batch_size": cfg.get("batch_size", BATCH_SIZE),
        "num_workers": 2,
        "pin_memory": True,
    }

    tr_transform = train_transform if cfg.get("augment", True) else TensorAug(do_aug=False)
    train_ds = EEGDataset(X_train, y_train, transform=tr_transform)
    val_ds = EEGDataset(X_val, y_val, transform=eval_transform)
    test_ds = EEGDataset(X_test, y_test, transform=eval_transform)

    if cfg.get("weighted_sampler", False):
        train_loader = DataLoader(train_ds, sampler=build_sampler(y_train), **loader_kwargs)
    else:
        train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    # Final train metrics are measured on the unaugmented training set, each
    # sample exactly once, rather than through the augmented / resampled
    # training loader.
    train_eval_loader = DataLoader(
        EEGDataset(X_train, y_train, transform=eval_transform), shuffle=False, **loader_kwargs
    )
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

    model = load_resnet(NUM_CLASSES, freeze_backbone=cfg.get("freeze_backbone", False)).to(DEVICE)

    class_weights = compute_class_weight("balanced", classes=np.unique(y_train), y=y_train)
    class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(DEVICE)
    loss_fn = nn.CrossEntropyLoss(
        weight=class_weights_tensor,
        label_smoothing=float(cfg.get("label_smoothing", 0.0)),
    )
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.get("lr", 1e-4),
        weight_decay=cfg.get("weight_decay", 0.0),
    )
    if cfg.get("lr_schedule", "cosine") == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    else:
        scheduler = None

    best_f1 = -1.0
    best_state = None
    logs = []

    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_loader, loss_fn, optimizer, grad_accum)
        if scheduler is not None:
            scheduler.step()

        val_labels, val_logits = _collect_predictions(model, val_loader)
        val_metrics = compute_full_metrics(val_labels, val_logits)
        logs.append({
            "epoch": epoch + 1,
            "train_loss": float(train_loss),
            "val_f1_weighted": float(val_metrics["f1_weighted"]),
            "val_accuracy": float(val_metrics["accuracy"]),
        })
        if val_metrics["f1_weighted"] > best_f1:
            best_f1 = val_metrics["f1_weighted"]
            # copy=True forces a real snapshot. A plain .cpu() returns the
            # live tensor when the model already sits on the CPU, so later
            # epochs would overwrite the saved "best" weights.
            best_state = {
                k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()
            }

    if best_state is not None:
        model.load_state_dict(best_state)

    train_metrics = compute_full_metrics(*_collect_predictions(model, train_eval_loader))
    val_metrics = compute_full_metrics(*_collect_predictions(model, val_loader))
    test_metrics = compute_full_metrics(*_collect_predictions(model, test_loader))

    with open(os.path.join(out_dir, "train_state.json"), "w") as f:
        json.dump(logs, f, indent=2)
    with open(os.path.join(out_dir, "results.json"), "w") as f:
        json.dump({"config": cfg, "train_metrics": train_metrics, "val_metrics": val_metrics, "test_metrics": test_metrics}, f, indent=2)
    torch.save(model.state_dict(), os.path.join(out_dir, "best_model.pth"))

    summary_keys = ["accuracy","auc","f1_weighted","f1_macro","precision_macro","recall_macro","matthews_corrcoef","cohen_kappa"]
    return {
        "config": cfg,
        "train": {k: float(train_metrics[k]) for k in summary_keys},
        "val":   {k: float(val_metrics[k]) for k in summary_keys},
        "test":  {k: float(test_metrics[k]) for k in summary_keys},
        "run_dir": out_dir
    }

# Configurations executed in order by the __main__ driver. Each dict is passed
# unchanged to run_experiment and is also written verbatim into that run's
# results.json, so key order and values are part of the recorded output.
# NOTE: "warmup_ratio" is stored with every config, but run_experiment never
# reads it, so it currently has no effect on training.
SWEEP = [
    # Baseline: augmentation on, full fine-tuning, no label smoothing.
    {"run_name": "A_baseline_lr1e-4_b32_e60_wd0.01_augY_freezeN",
     "lr": 1e-4, "weight_decay": 0.01, "batch_size": 32, "epochs": 60,
     "augment": True, "weighted_sampler": False, "freeze_backbone": False,
     "label_smoothing": 0.0, "lr_schedule": "cosine", "warmup_ratio": 0.0, "seed": 42},

    # Higher learning rate, larger batch, light label smoothing.
    {"run_name": "S_lr3e-4_b64_e60_wd0.01_augY_freezeN_ls0.05",
     "lr": 3e-4, "weight_decay": 0.01, "batch_size": 64, "epochs": 60,
     "augment": True, "weighted_sampler": False, "freeze_backbone": False,
     "label_smoothing": 0.05, "lr_schedule": "cosine", "warmup_ratio": 0.05, "seed": 42},

    # Class-balanced sampling, 80 epochs, different seed.
    {"run_name": "S_lr2e-4_b32_e80_wd0.01_augY_freezeN_sampler",
     "lr": 2e-4, "weight_decay": 0.01, "batch_size": 32, "epochs": 80,
     "augment": True, "weighted_sampler": True, "freeze_backbone": False,
     "label_smoothing": 0.05, "lr_schedule": "cosine", "warmup_ratio": 0.1, "seed": 123},

    # Frozen backbone: only the parameters under "fc" keep requires_grad.
    {"run_name": "S_freeze_lr5e-4_b64_e40_wd0.02_augY_freezeY",
     "lr": 5e-4, "weight_decay": 0.02, "batch_size": 64, "epochs": 40,
     "augment": True, "weighted_sampler": False, "freeze_backbone": True,
     "label_smoothing": 0.1, "lr_schedule": "cosine", "warmup_ratio": 0.1, "seed": 7},

    # No augmentation, gradients accumulated over 4 batches per optimizer step.
    {"run_name": "S_noaug_lr1e-4_b32_e60_wd0.01_freezeN_accum4",
     "lr": 1e-4, "weight_decay": 0.01, "batch_size": 32, "epochs": 60,
     "augment": False, "weighted_sampler": False, "freeze_backbone": False,
     "label_smoothing": 0.0, "lr_schedule": "cosine", "warmup_ratio": 0.0, "grad_accum": 4, "seed": 42},
]

if __name__ == "__main__":
    import csv

    # Execute every sweep entry in order and keep its summary.
    all_summaries = [run_experiment(cfg) for cfg in SWEEP]

    splits = ("train", "val", "test")
    # (key in the run summary, suffix used for the CSV column name)
    columns = (
        ("accuracy", "acc"),
        ("auc", "auc"),
        ("f1_weighted", "f1w"),
        ("f1_macro", "f1m"),
        ("precision_macro", "precm"),
        ("recall_macro", "recm"),
        ("matthews_corrcoef", "mcc"),
        ("cohen_kappa", "kappa"),
    )

    # One CSV row per run: run number, run directory, then the metrics of
    # each split in the column order above.
    table_path = os.path.join(RESULTS_ROOT, "all_runs_summary.csv")
    with open(table_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(
            ["run", "run_dir"]
            + [f"{split}_{suffix}" for split in splits for _, suffix in columns]
        )
        for idx, s in enumerate(all_summaries, start=1):
            writer.writerow(
                [idx, s["run_dir"]]
                + [s[split][key] for split in splits for key, _ in columns]
            )

    # The same summaries, including each config, as JSON.
    with open(os.path.join(RESULTS_ROOT, "sweep_summary.json"), "w") as f:
        json.dump(all_summaries, f, indent=2)
