# In[ ]:
# ===== ViT (3-band from 5) EEG Topomaps — subject-independent split =====


import os
import json
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision.transforms import functional as TF
from torchvision.transforms.functional import InterpolationMode

from transformers import AutoConfig, ViTForImageClassification, TrainingArguments, Trainer, set_seed

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, roc_auc_score, precision_score, recall_score, f1_score,
    matthews_corrcoef, cohen_kappa_score, confusion_matrix, classification_report
)
from sklearn.utils.class_weight import compute_class_weight


# ---- Configuration ----
# Directory of per-chunk "sample*.npz" files read by get_subject_ids / load_chunks_by_subject_ids.
DATA_DIR = "/content/gdrive/MyDrive/1segment_topomap_5channel_11classes"
# Each sweep run writes into RESULTS_ROOT/<run_name>; the sweep summary JSON is written here too.
RESULTS_ROOT = "./vit_topomap_results_sweep"
os.makedirs(RESULTS_ROOT, exist_ok=True)

# Defaults used when a sweep config does not set batch_size / epochs.
BATCH_SIZE = 32
EPOCHS = 60
# Number of classifier outputs of the ViT head.
# NOTE(review): LABEL_MAP (defined further down) maps to only three distinct labels {0, 1, 2};
# confirm that a 4-way head is intended.
NUM_CLASSES = 4
IMAGE_SIZE_FOR_VIT = 224         # vit-base-patch16-224
RANDOM_SEED = 42
MODEL_NAME = "google/vit-base-patch16-224-in21k"  # native 3-channel ViT

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f">> Device: {device}")


# Use ONLY three bands: theta, alpha, beta (indices 1,2,3)
THREE_BAND_IDX = [1, 2, 3]  # 0=delta,1=theta,2=alpha,3=beta,4=gamma

# Subject IDs that are skipped entirely when building the subject list.
# NOTE(review): presumably recordings with known problems — document the reason here.
EXCLUDED_SUBJECTS = {"S038", "S088", "S089", "S092", "S100", "S104"}

def get_subject_ids():
    """Return the sorted subject IDs that own at least one sample file in DATA_DIR.

    A subject ID is the first four characters of the second "_"-separated
    field of a "sample*.npz" file name, e.g. "sample_S001R01_chunk00.npz"
    -> "S001". IDs listed in EXCLUDED_SUBJECTS are left out.
    """
    sample_files = (
        name for name in os.listdir(DATA_DIR)
        if name.endswith(".npz") and name.startswith("sample")
    )
    discovered = {name.split('_')[1][:4] for name in sample_files}
    return sorted(discovered - EXCLUDED_SUBJECTS)

# Subject-level shuffle with a fixed seed: for the same set of files in DATA_DIR,
# the same RANDOM_SEED always produces the same subject order (and hence split).
subject_ids = get_subject_ids()
random.seed(RANDOM_SEED)
random.shuffle(subject_ids)

def _slice(lst, a, b):
    """Return ``lst[a:b]`` with both bounds capped at ``len(lst)``."""
    size = len(lst)
    start, stop = min(a, size), min(b, size)
    return lst[start:stop]

# Subject-level split by position in the shuffled list: 70 train / 15 dev / up to 16 test.
# The three ranges are disjoint and the IDs are unique, so no subject contributes
# chunks to more than one split.
# NOTE(review): subjects beyond position 101 of the shuffled list (if any) are not
# assigned to any split.
# Same slicing pattern as used in earlier runs.
train_ids = _slice(subject_ids, 0, 70)
dev_ids   = _slice(subject_ids, 70, 85)
test_ids  = _slice(subject_ids, 85, 101)

print(f">> Subjects total={len(subject_ids)} | train={len(train_ids)} dev={len(dev_ids)} test={len(test_ids)}")
print("   train_ids[:5]:", train_ids[:5])

# Raw label -> training label (kept as-is). Raw labels 3 and 7 are merged into
# class 1, so the mapped labels are {0, 1, 2}.
LABEL_MAP = {2: 0, 3: 1, 6: 2, 7: 1}
# Raw labels to keep; chunks with any other raw label are dropped while loading.
VALID_CLASSES = set(LABEL_MAP.keys())

def load_chunks_by_subject_ids(subject_ids, data_dir=None, band_idx=None, label_map=None):
    """
    Load chunk files for the provided subjects ONLY.

    - Select the bands in ``band_idx`` (default THREE_BAND_IDX: theta, alpha, beta)
    - Per-chunk z-normalize AFTER selection
    - Keep chunks whose raw label is a key of ``label_map`` (default LABEL_MAP,
      filtered via VALID_CLASSES) and map it to the training label

    Files are visited in sorted name order, so the sample order does not depend
    on the filesystem's ``os.listdir`` ordering.

    Args:
        subject_ids: iterable of subject IDs such as "S001".
        data_dir: directory holding "sample*.npz" files; defaults to DATA_DIR.
        band_idx: band indices to keep (axis 0 of each chunk); defaults to THREE_BAND_IDX.
        label_map: raw label -> training label; defaults to LABEL_MAP.

    Returns:
        X (N, len(band_idx), H, W) float32, y (N,) int

    Raises:
        RuntimeError: if no chunk survives the filters.
    """
    data_dir = DATA_DIR if data_dir is None else data_dir
    band_idx = THREE_BAND_IDX if band_idx is None else list(band_idx)
    if label_map is None:
        label_map = LABEL_MAP
        valid_classes = VALID_CLASSES
    else:
        valid_classes = set(label_map)
    max_band = max(band_idx)

    X_all, y_all = [], []
    files = sorted(
        f for f in os.listdir(data_dir) if f.endswith(".npz") and f.startswith("sample")
    )

    for sid in subject_ids:
        sid_token = f"_{sid}R"  # matches ..._S001R01_...
        sid_files = [f for f in files if sid_token in f]
        if not sid_files:
            continue
        print(f"   {sid}: {len(sid_files)} files (e.g., {sid_files[:2]})")

        for fname in sid_files:
            # np.load on an .npz returns an NpzFile that keeps the file open;
            # the context manager closes it once X and y are read into memory.
            with np.load(os.path.join(data_dir, fname)) as data:
                X, y = data["X"], data["y"]

            if isinstance(y, np.ndarray):
                y = y.item()

            if y not in valid_classes:
                continue

            # Every requested band index must exist on axis 0
            if X.ndim != 3 or X.shape[0] <= max_band:
                continue

            # Select bands, then normalize per-chunk
            X = X[band_idx, :, :]
            X = (X - X.mean()) / (X.std() + 1e-8)

            X_all.append(X.astype(np.float32))
            y_all.append(label_map[y])

    if not X_all:
        raise RuntimeError("No samples after subject-based loading. Check DATA_DIR and subject IDs.")

    X_all = np.stack(X_all)             # (N, len(band_idx), H, W)
    y_all = np.array(y_all, dtype=int)  # (N,)
    return X_all, y_all

# Materialize each split in memory. The three ID lists come from non-overlapping
# slices of the shuffled subject list, so the splits share no subject.
X_train, y_train = load_chunks_by_subject_ids(train_ids)
X_val,   y_val   = load_chunks_by_subject_ids(dev_ids)
X_test,  y_test  = load_chunks_by_subject_ids(test_ids)

print("  Data Shapes (3 bands):")
print("   Train:", X_train.shape, y_train.shape)
print("   Val  :", X_val.shape,   y_val.shape)
print("   Test :", X_test.shape,  y_test.shape)
# Histogram of the mapped training labels (train split only).
print("   Class counts train:", {int(c): int((y_train==c).sum()) for c in np.unique(y_train)})


def resize_224(x: torch.Tensor) -> torch.Tensor:
    """Bilinearly resize a single (C, H, W) tensor to (C, S, S), S = IMAGE_SIZE_FOR_VIT."""
    target_hw = (IMAGE_SIZE_FOR_VIT, IMAGE_SIZE_FOR_VIT)
    batched = x[None]  # F.interpolate expects a leading batch dimension
    resized = F.interpolate(batched, size=target_hw, mode="bilinear", align_corners=False)
    return resized[0]

class TensorAug:
    """Resize to the ViT input size and optionally apply a random flip + rotation.

    Works directly on (C, H, W) float tensors (no PIL round-trip), so any channel
    count is accepted.

    Args:
        do_aug: when False, only the resize is applied.
        max_rotate: rotation angle in degrees is drawn uniformly from
            [-max_rotate, max_rotate).
        hflip_p: probability of mirroring along the width axis.
    """

    def __init__(self, do_aug=True, max_rotate=15.0, hflip_p=0.5):
        self.do_aug, self.max_rotate, self.hflip_p = do_aug, max_rotate, hflip_p

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        out = resize_224(x)
        if self.do_aug:
            # NOTE(review): mirroring a scalp topomap swaps the left and right
            # hemispheres; if any class distinction is lateralized this flip can
            # change the meaning of the sample — confirm it is intended.
            flip_draw = torch.rand(1).item()
            if flip_draw < self.hflip_p:
                out = torch.flip(out, dims=[2])  # dim 2 is W for a (C, H, W) tensor
            angle = (torch.rand(1).item() * 2 * self.max_rotate) - self.max_rotate
            out = TF.rotate(out, angle, interpolation=InterpolationMode.BILINEAR, expand=False)
        return out

# Shared transform instances: random flip/rotation after resizing for training,
# resize only for evaluation.
train_transform = TensorAug(do_aug=True)
eval_transform  = TensorAug(do_aug=False)


class EEGDataset(Dataset):
    """Map-style dataset over pre-loaded topomap arrays, shaped for the HF Trainer.

    Args:
        X: array indexable by sample; each item is converted with torch.from_numpy.
        y: integer labels aligned with X.
        transform: optional callable applied to each image tensor.

    Each item is a dict {"pixel_values": tensor, "labels": int64 scalar tensor}.
    """

    def __init__(self, X, y, transform=None):
        self.X = X
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        image = torch.from_numpy(self.X[idx])
        target = torch.tensor(self.y[idx]).long()
        if self.transform is not None:
            image = self.transform(image)
        # Debug trace: fires on every fetch of index 0 (i.e. once per pass over
        # the dataset), not only on the first call.
        if idx < 1:
            print(f"[EEGDataset] idx={idx} x.shape={tuple(image.shape)} y={int(target)}")
        return {"pixel_values": image, "labels": target}

train_ds = EEGDataset(X_train, y_train, transform=train_transform)
val_ds   = EEGDataset(X_val,   y_val,   transform=eval_transform)
test_ds  = EEGDataset(X_test,  y_test,  transform=eval_transform)


# ---- Class weights for the weighted cross-entropy loss ----
# CrossEntropyLoss indexes `weight` by label value and requires exactly one entry
# per model output (NUM_CLASSES). A vector built from np.unique(y_train) alone has
# one entry per *present* class, which makes the loss raise whenever the training
# labels do not cover 0..NUM_CLASSES-1 (LABEL_MAP currently yields only {0, 1, 2}).
# Scatter the balanced weights into a full-length vector instead; classes that never
# occur in y_train get a neutral weight of 1.0.
_present_classes = np.unique(y_train)
if _present_classes.min() < 0 or _present_classes.max() >= NUM_CLASSES:
    raise ValueError(
        f"Training labels {_present_classes.tolist()} fall outside [0, {NUM_CLASSES})"
    )
class_weights_np = np.ones(NUM_CLASSES, dtype=np.float64)
class_weights_np[_present_classes] = compute_class_weight(
    'balanced', classes=_present_classes, y=y_train
)
CLASS_WEIGHTS_TENSOR = torch.tensor(class_weights_np, dtype=torch.float32)
print(">> Class weights (train):", CLASS_WEIGHTS_TENSOR.tolist())

def build_sampler(y_array):
    """Build a class-balanced ``WeightedRandomSampler`` for the given labels.

    Each sample is weighted by the scikit-learn "balanced" weight of its class,
    ``n_samples / (n_classes * count(class))``, so every class is drawn with equal
    expected frequency. Draws are with replacement; one epoch is ``len(y_array)``
    draws.

    Weights are looked up by each label's position among the sorted distinct
    labels, so non-contiguous labels (e.g. {0, 2}) are handled correctly.

    Args:
        y_array: 1-D array-like of class labels.

    Returns:
        torch.utils.data.WeightedRandomSampler
    """
    labels = np.asarray(y_array).ravel()
    classes, position, counts = np.unique(labels, return_inverse=True, return_counts=True)
    class_weights = len(labels) / (len(classes) * counts.astype(np.float64))
    sample_weights = torch.tensor(class_weights[position], dtype=torch.float32)
    return WeightedRandomSampler(weights=sample_weights,
                                 num_samples=len(sample_weights),
                                 replacement=True)


def _support_weighted_ovr_auc(y_true, probs):
    """One-vs-rest ROC AUC over the classes present in y_true, averaged by class support.

    For two probability columns this is the plain binary AUC on column 1. Otherwise it
    equals sklearn's ``roc_auc_score(y_true, probs, multi_class='ovr', average='weighted')``
    when every column of ``probs`` occurs in ``y_true``. Unlike that call, it does not
    fail when the model has output columns that never occur in ``y_true``; those classes
    have zero support and so zero weight anyway. Returns NaN when fewer than two
    classes occur.
    """
    present, support = np.unique(y_true, return_counts=True)
    if present.size < 2:
        return float('nan')
    if probs.shape[1] == 2:
        return float(roc_auc_score(y_true, probs[:, 1]))
    per_class = [roc_auc_score(y_true == c, probs[:, c]) for c in present]
    return float(np.average(per_class, weights=support))


def compute_full_metrics(p):
    """Compute classification metrics for an HF ``EvalPrediction``.

    Args:
        p: object with ``predictions`` (logits, one row per sample) and
           ``label_ids`` (integer labels, or one-hot rows which are argmax-ed).

    Returns:
        dict with accuracy, support-weighted one-vs-rest AUC (NaN when it cannot be
        computed), weighted/macro precision/recall/F1, MCC and Cohen's kappa.
    """
    preds = np.argmax(p.predictions, axis=1)
    probs = torch.nn.functional.softmax(torch.tensor(p.predictions), dim=-1).numpy()

    y_true = p.label_ids
    if y_true.ndim > 1:
        y_true = np.argmax(y_true, axis=1)

    try:
        auc = _support_weighted_ovr_auc(y_true, probs)
    except (ValueError, IndexError):
        # e.g. a label outside the logit columns; AUC is reported as undefined.
        auc = float('nan')

    return {
        "accuracy": accuracy_score(y_true, preds),
        "auc": auc,
        "precision_weighted": precision_score(y_true, preds, average="weighted", zero_division=0),
        "recall_weighted": recall_score(y_true, preds, average="weighted", zero_division=0),
        "f1_weighted": f1_score(y_true, preds, average="weighted", zero_division=0),
        "precision_macro": precision_score(y_true, preds, average="macro", zero_division=0),
        "recall_macro": recall_score(y_true, preds, average="macro", zero_division=0),
        "f1_macro": f1_score(y_true, preds, average="macro", zero_division=0),
        "matthews_corrcoef": matthews_corrcoef(y_true, preds),
        "cohen_kappa": cohen_kappa_score(y_true, preds),
    }

class WeightedTrainer(Trainer):
    """HF ``Trainer`` with class-weighted, label-smoothed cross-entropy and an optional
    class-balanced training sampler.

    Extra keyword args (beyond ``Trainer``'s):
        class_weights: 1-D float tensor with one weight per model output, passed as
            ``weight`` to ``CrossEntropyLoss``; ``None`` for an unweighted loss.
        use_weighted_sampler: if True, the training DataLoader samples with
            ``build_sampler(train_labels)`` instead of the default shuffling.
        train_labels: integer training labels (required with ``use_weighted_sampler``).
        train_batch_size: batch size of the sampler-based training DataLoader.

    NOTE(review): with ``use_weighted_sampler=True`` the sampler already balances the
    classes and ``class_weights`` rebalances them again in the loss — confirm this
    double correction is intended.
    """

    def __init__(self, *args, class_weights=None, use_weighted_sampler=False, train_labels=None, train_batch_size=32, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights
        self.use_weighted_sampler = use_weighted_sampler
        self.train_labels = train_labels
        self.train_batch_size = train_batch_size
        self._logged_batch_shape = False

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.get("labels")
        pixel_values = inputs.get("pixel_values")
        if pixel_values.ndim == 3:
            # (C,H,W) -> (B=1,C,H,W) for safety
            pixel_values = pixel_values.unsqueeze(0)
        if not self._logged_batch_shape:
            # One-time shape trace (does not consume the global RNG, unlike a random gate).
            print(f"[compute_loss] pixel_values batch shape: {tuple(pixel_values.shape)} (expect B,3,224,224)")
            self._logged_batch_shape = True
        # Labels are deliberately NOT passed to the model: it would compute its own
        # unweighted loss that is immediately discarded.
        outputs = model(pixel_values=pixel_values)
        logits = outputs.logits
        weight = None if self.class_weights is None else self.class_weights.to(logits.device)
        # Overriding compute_loss bypasses Trainer's built-in label smoother, so the
        # TrainingArguments label_smoothing_factor has to be applied here explicitly.
        loss_fct = torch.nn.CrossEntropyLoss(
            weight=weight,
            label_smoothing=float(self.args.label_smoothing_factor),
        )
        loss = loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss

    def get_train_dataloader(self):
        if not self.use_weighted_sampler:
            return super().get_train_dataloader()
        if self.train_labels is None:
            raise ValueError("use_weighted_sampler=True requires train_labels")

        sampler = build_sampler(self.train_labels)
        # Honour the DataLoader settings from TrainingArguments, as the default path does.
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            sampler=sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

def load_vit_3ch(num_labels, model_name=None):
    """
    Load a pretrained 3-channel ViT with a freshly sized classification head.

    The default checkpoint (MODEL_NAME) expects 3-channel inputs, so no
    patch-embedding surgery is needed when feeding the 3 selected bands.

    Args:
        num_labels: number of classifier outputs.
        model_name: Hugging Face model id or local path; defaults to MODEL_NAME.

    Returns:
        ViTForImageClassification instance (the caller moves it to a device).
    """
    model = ViTForImageClassification.from_pretrained(
        MODEL_NAME if model_name is None else model_name,
        num_labels=num_labels
    )
    print(f" model.config.num_channels: {model.config.num_channels} (expect 3)")
    print(f" model.config.hidden_size : {model.config.hidden_size}")
    return model

# ==============================
# ---- Single run executor  ----
# ==============================
def _freeze_vit_backbone(model):
    """Disable gradients for the ViT embeddings and encoder; return (frozen, trainable) counts."""
    frozen, trainable = 0, 0
    for name, p in model.named_parameters():
        if name.startswith("vit.encoder") or name.startswith("vit.embeddings"):
            p.requires_grad = False
            frozen += p.numel()
        else:
            trainable += p.numel()
    return frozen, trainable


def _evaluate_split(trainer, ds, split_name):
    """Run one inference pass over ``ds``; return metrics, confusion matrix and report."""
    print(f">> Evaluating: {split_name}")
    # predict() returns both predictions and compute_metrics output in one pass;
    # the "eval" prefix makes every metric key start with "eval_" for the filter below.
    preds_out = trainer.predict(ds, metric_key_prefix="eval")
    clean = {k.replace("eval_", ""): float(v) for k, v in preds_out.metrics.items() if k.startswith("eval_")}
    preds = np.argmax(preds_out.predictions, axis=1)
    y_true = preds_out.label_ids
    clean["confusion_matrix"] = confusion_matrix(y_true, preds).tolist()
    clean["classification_report"] = classification_report(y_true, preds, zero_division=0)
    print(f"   {split_name} acc={clean.get('accuracy'):.4f} f1_w={clean.get('f1_weighted'):.4f}")
    return clean


def run_experiment(run_cfg):
    """Train one sweep configuration and evaluate it on the train, dev and test splits.

    Keys read from ``run_cfg`` (all optional): run_name, seed, augment, batch_size,
    epochs, lr, weight_decay, label_smoothing, grad_accum, lr_schedule, warmup_ratio,
    freeze_backbone, weighted_sampler. Any other keys are only echoed into the results.

    Writes ``results.json`` to RESULTS_ROOT/<run_name> and returns the same dict with
    keys "config", "train_metrics", "val_metrics", "test_metrics".
    """
    import gc  # only needed for the post-run cleanup

    seed = run_cfg.get("seed", RANDOM_SEED)
    set_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    run_name = run_cfg.get("run_name", f"run_{int(time.time())}")
    out_dir  = os.path.join(RESULTS_ROOT, run_name)
    os.makedirs(out_dir, exist_ok=True)

    print(f"\n========== RUN: {run_name} ==========")
    print("Config:", json.dumps(run_cfg, indent=2))

    # datasets with chosen transforms
    tr_transform = train_transform if run_cfg.get("augment", True) else TensorAug(do_aug=False)
    train_ds = EEGDataset(X_train, y_train, transform=tr_transform)
    # Un-augmented view of the training set for the final train metrics; measuring
    # them through random flips/rotations would make them depend on the draw.
    train_eval_ds = EEGDataset(X_train, y_train, transform=eval_transform)
    val_ds   = EEGDataset(X_val,   y_val,   transform=eval_transform)
    test_ds  = EEGDataset(X_test,  y_test,  transform=eval_transform)

    # -------- Model: ViT (3-channel) ----------
    model = load_vit_3ch(num_labels=NUM_CLASSES).to(device)

    # Dry run: push one sample through the model and print the shapes before training.
    with torch.no_grad():
        tmp = train_ds[0]["pixel_values"].unsqueeze(0).to(device)
        print(f" Dry-run input shape: {tuple(tmp.shape)} (expect (1,3,224,224))")
        out = model(pixel_values=tmp)
        print(f" Dry-run logits shape: {tuple(out.logits.shape)} (expect (1,{NUM_CLASSES}))")

    if run_cfg.get("freeze_backbone", False):
        frozen, trainable = _freeze_vit_backbone(model)
        print(f">> Frozen params: {frozen:,} | Trainable params: {trainable:,}")

    label_smoothing = float(run_cfg.get("label_smoothing", 0.0))
    batch_size = run_cfg.get("batch_size", BATCH_SIZE)

    training_args = TrainingArguments(
        output_dir=out_dir,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=run_cfg.get("epochs", EPOCHS),
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=run_cfg.get("lr", 1e-5),
        weight_decay=run_cfg.get("weight_decay", 0.01),
        load_best_model_at_end=True,
        logging_steps=10,
        report_to="none",
        metric_for_best_model="eval_f1_weighted",
        greater_is_better=True,
        seed=seed,
        label_smoothing_factor=label_smoothing,
        gradient_accumulation_steps=run_cfg.get("grad_accum", 1),
        lr_scheduler_type=run_cfg.get("lr_schedule", "cosine"),
        warmup_ratio=run_cfg.get("warmup_ratio", 0.0),
        dataloader_num_workers=2,
        save_total_limit=2
    )

    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_full_metrics,
        class_weights=CLASS_WEIGHTS_TENSOR,
        use_weighted_sampler=run_cfg.get("weighted_sampler", False),
        train_labels=y_train,
        train_batch_size=batch_size,
    )

    print(">> Starting training ...")
    trainer.train()
    print(">> Training done.")

    results = {
        "config": run_cfg,
        "train_metrics": _evaluate_split(trainer, train_eval_ds, "train"),
        "val_metrics":   _evaluate_split(trainer, val_ds, "val"),
        "test_metrics":  _evaluate_split(trainer, test_ds, "test"),
    }

    # save json
    with open(os.path.join(out_dir, "results.json"), "w") as f:
        json.dump(results, f, indent=2)

    print(f"Finished {run_name}. Results saved to {out_dir}")

    # Release this run's model / optimizer state before the next sweep entry
    # allocates its own; otherwise GPU memory accumulates across runs.
    del trainer, model, out, tmp
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return results


# ---- Hyper-parameter sweep ----
# Each dict is passed to run_experiment(). Keys it reads: run_name, seed, lr,
# weight_decay, batch_size, epochs, augment, weighted_sampler, freeze_backbone,
# label_smoothing, lr_schedule (-> lr_scheduler_type), warmup_ratio, grad_accum.
# NOTE(review): the commented-out configs also use keys run_experiment does not
# read (num_cycles, ema_decay, early_stopping, head_lr, layer_decay); re-enabling
# them as-is would silently ignore those settings.
# NOTE(review): run_experiment always passes CLASS_WEIGHTS_TENSOR to the loss, so a
# config with weighted_sampler=True rebalances the classes twice (sampler + loss).
SWEEP = [
      # {
      #   "run_name": "S1_reg_ema_es",
      #   "lr": 2e-5, "weight_decay": 0.02, "batch_size": 32, "epochs": 100,
      #   "augment": True, "weighted_sampler": True, "freeze_backbone": False,
      #   "label_smoothing": 0.10, "lr_schedule": "cosine_with_restarts",
      #   "num_cycles": 5, "warmup_ratio": 0.15, "ema_decay": 0.999,
      #   "early_stopping": {"metric": "macro_f1", "patience": 10}
      # },
      # {
      #   "run_name": "S2_llrd",
      #   "lr": 2e-5, "head_lr": 5e-5, "layer_decay": 0.75,
      #   NOTE(review): the "True" key below looks like a typo for "weight_decay".
      #   "True": 0.02, "batch_size": 32, "epochs": 80,
      #   "augment": True, "weighted_sampler": True, "freeze_backbone": False,
      #   "label_smoothing": 0.05, "lr_schedule": "cosine", "warmup_ratio": 0.10
      # },
      {"run_name": "A_baseline_lr1e-5_b32_e60_wd0.01_augY_freezeN",
      "lr": 1e-5, "weight_decay": 0.01, "batch_size": 32, "epochs": 60,
      "augment": True, "weighted_sampler": False, "freeze_backbone": False,
      "label_smoothing": 0.0, "lr_schedule": "cosine", "warmup_ratio": 0.0, "seed": 42},

      {"run_name": "S_lr3e-5_b64_e60_wd0.01_augY_freezeN_ls0.05",
      "lr": 3e-5, "weight_decay": 0.01, "batch_size": 64, "epochs": 60,
      "augment": True, "weighted_sampler": False, "freeze_backbone": False,
      "label_smoothing": 0.05, "lr_schedule": "cosine", "warmup_ratio": 0.05, "seed": 42},

      {"run_name": "S_lr2e-5_b32_e80_wd0.01_augY_freezeN_sampler",
      "lr": 2e-5, "weight_decay": 0.01, "batch_size": 32, "epochs": 80,
      "augment": True, "weighted_sampler": True, "freeze_backbone": False,
      "label_smoothing": 0.05, "lr_schedule": "cosine", "warmup_ratio": 0.1, "seed": 123},

      # Linear-probe style run: embeddings + encoder frozen, only the remaining layers train.
      {"run_name": "S_freeze_lr5e-5_b64_e40_wd0.02_augY_freezeY",
      "lr": 5e-5, "weight_decay": 0.02, "batch_size": 64, "epochs": 40,
      "augment": True, "weighted_sampler": False, "freeze_backbone": True,
      "label_smoothing": 0.1, "lr_schedule": "cosine", "warmup_ratio": 0.1, "seed": 7},

      # No augmentation; grad_accum=4 gives an effective batch of 4 x 32 per device.
      {"run_name": "S_noaug_lr1e-5_b32_e60_wd0.01_freezeN_accum4",
      "lr": 1e-5, "weight_decay": 0.01, "batch_size": 32, "epochs": 60,
      "augment": False, "weighted_sampler": False, "freeze_backbone": False,
      "label_smoothing": 0.0, "lr_schedule": "cosine", "warmup_ratio": 0.0, "grad_accum": 4, "seed": 42},
]


def _summary_row(r):
    """Flatten one run_experiment() result into a row of sweep_summary.json.

    Missing config keys fall back to the same defaults run_experiment() uses, so a
    sparse config cannot raise KeyError here after its (expensive) run has finished.
    """
    cfg = r["config"]
    test = r["test_metrics"]
    return {
        "run_name": cfg.get("run_name"),
        "lr": cfg.get("lr", 1e-5),
        "batch_size": cfg.get("batch_size", BATCH_SIZE),
        "epochs": cfg.get("epochs", EPOCHS),
        "weight_decay": cfg.get("weight_decay", 0.01),
        "augment": cfg.get("augment", True),
        "freeze_backbone": cfg.get("freeze_backbone", False),
        "weighted_sampler": cfg.get("weighted_sampler", False),
        "label_smoothing": cfg.get("label_smoothing", 0.0),
        "seed": cfg.get("seed", RANDOM_SEED),
        "test_acc": test.get("accuracy"),
        "test_auc": test.get("auc"),
        "test_f1_w": test.get("f1_weighted"),
        "test_f1_m": test.get("f1_macro"),
        "test_prec_m": test.get("precision_macro"),
        "test_rec_m": test.get("recall_macro"),
    }


if __name__ == "__main__":
    summary_path = os.path.join(RESULTS_ROOT, "sweep_summary.json")
    ALL_RESULTS = []
    summary = []
    for cfg in SWEEP:
        res = run_experiment(cfg)
        ALL_RESULTS.append(res)
        summary.append(_summary_row(res))
        # Rewrite the summary after every run so completed runs are recorded even
        # if a later run crashes; the final file is the same as writing it once.
        with open(summary_path, "w") as f:
            json.dump(summary, f, indent=2)

    print("\n Sweep complete. Summary saved to", summary_path)
