### Setup

In [None]:
# Mount Google Drive at /content/drive (Colab only).
# NOTE(review): nothing below reads from /content/drive -- the dataset is fetched
# with gdown into /content -- so this mount may be unnecessary.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# gdown is used in the next cell to download the dataset archive from Google Drive by file id.
# NOTE(review): consider pinning a version (e.g. `%pip install -q gdown==<version>`) for reproducible runs.
!pip install gdown -q

In [None]:
# Download the EEG dataset archive (~1.24 GB per the recorded output) and extract it under /content.
# Presumably this produces /content/Data/eeg_2d_outputs, which later cells read -- confirm after unzip.
!gdown 14vLVZyHhrZp0PIW_2VmcWrhG0Hu0mt2f -O eeg.zip
!unzip -q eeg.zip -d /content

Downloading...
From (original): https://drive.google.com/uc?id=14vLVZyHhrZp0PIW_2VmcWrhG0Hu0mt2f
From (redirected): https://drive.google.com/uc?id=14vLVZyHhrZp0PIW_2VmcWrhG0Hu0mt2f&confirm=t&uuid=82451550-3beb-4808-95d9-c203eebc9cc9
To: /content/eeg.zip
100% 1.24G/1.24G [00:21<00:00, 57.7MB/s]


In [None]:
# Install the `tree` CLI (presumably for eyeballing the extracted directory layout).
# NOTE(review): `tree` is not invoked anywhere in the visible notebook; this cell may be removable.
!apt-get update -qq
!apt-get install tree -y -qq

W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)
Selecting previously unselected package tree.
(Reading database ... 121852 files and directories currently installed.)
Preparing to unpack .../tree_2.0.2-1_amd64.deb ...
Unpacking tree (2.0.2-1) ...
Setting up tree (2.0.2-1) ...
Processing triggers for man-db (2.10.2-1) ...


In [None]:
# ============================================================================
# IMPORTS
# ============================================================================

from __future__ import annotations

import argparse
import json
import random
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from transformers import AutoImageProcessor, CvtForImageClassification

### Loading and Normalizing Data

In [None]:
# Sanity check: count the immediate subdirectories of the extracted outputs folder
# (presumably one per signal_id -- the recorded run reported 3070).
!find /content/Data/eeg_2d_outputs -mindepth 1 -maxdepth 1 -type d | wc -l

3070


In [None]:
# ============================================================================
# PATH DEFAULTS AND REPRODUCIBILITY
# ============================================================================

def default_data_root() -> Path:
    """Locate the directory holding the per-signal 2-D EEG outputs.

    Known Colab extraction locations are checked first; when none of them
    exists the relative ``eeg_2d_outputs`` path is returned unchanged so the
    caller gets a sensible default (and a clear "not found" later).
    """
    search_order = (
        Path("/content/Data/Data/eeg_2d_outputs"),
        Path("/content/Data/eeg_2d_outputs"),
        Path("eeg_2d_outputs"),
    )
    return next((candidate for candidate in search_order if candidate.exists()), search_order[-1])


def default_output_dir() -> Path:
    """Return ``/content/Data`` when running on Colab (it exists), else the local ``cvt_run`` dir."""
    colab_base = Path("/content/Data")
    return colab_base if colab_base.exists() else Path("cvt_run")


def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (plus every CUDA device if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

In [None]:
# ============================================================================
# FRAME NORMALIZATION AND SHAPE HELPERS
# ============================================================================

def normalize_to_uint8(
    x: np.ndarray,
    low_pct: float = 1.0,
    high_pct: float = 99.0,
) -> np.ndarray:
    """Robustly rescale an array into the ``uint8`` range [0, 255].

    Non-finite values are replaced with 0. The array is then mapped linearly
    so that its ``low_pct``-th percentile becomes 0 and its ``high_pct``-th
    percentile becomes 255, with values outside that window clipped. Using
    percentiles instead of min/max keeps a few outlier pixels from squashing
    the rest of the frame into a narrow intensity band.

    Args:
        x: Array of any shape convertible to float32.
        low_pct: Percentile (0-100) mapped to 0. Defaults to 1.
        high_pct: Percentile (0-100) mapped to 255. Defaults to 99.

    Returns:
        ``uint8`` array with the same shape as ``x``. It is all zeros when
        ``x`` has no dynamic range, and empty when ``x`` is empty.
    """
    x = np.asarray(x, dtype=np.float32)
    x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
    if x.size == 0:
        # np.percentile / x.min() fail on empty input; there is nothing to scale.
        return np.zeros_like(x, dtype=np.uint8)

    lo = float(np.percentile(x, low_pct))
    hi = float(np.percentile(x, high_pct))
    if hi <= lo:
        # Percentile window collapsed (e.g. a mostly-constant frame):
        # fall back to the full data range before giving up.
        lo = float(x.min())
        hi = float(x.max())
    if hi <= lo:
        # Truly constant input: no contrast to map.
        return np.zeros_like(x, dtype=np.uint8)

    x = (x - lo) / (hi - lo + 1e-8)
    x = np.clip(x, 0.0, 1.0)
    return (x * 255.0).astype(np.uint8)


def to_frame_stack(arr: np.ndarray) -> np.ndarray:
    """Coerce a 2-D, 3-D or 4-D array into a ``(n_frames, H, W)`` float32 stack.

    Non-finite values are replaced with 0. Shape heuristics, applied in order:

    * 2-D ``(H, W)``: a single frame -> ``(1, H, W)``.
    * 3-D whose last axis is 1/3/4 and first two axes are > 8: one HWC image,
      channels averaged -> ``(1, H, W)``.
    * 3-D whose first axis is 1/3/4 and last two axes are > 8: one CHW image,
      channels averaged -> ``(1, H, W)``.
    * any other 3-D: the smallest axis is treated as the frame axis and moved
      to the front.
    * 4-D: a size-1/3/4 channel axis (last axis first, then axis 1) is averaged
      away; otherwise the last axis is averaged and the 3-D rules are applied.

    NOTE(review): a genuine stack of 3 or 4 frames stored as ``(3, H, W)`` or
    ``(H, W, 3)`` is indistinguishable from an image and is averaged into ONE
    frame by the HWC/CHW rules -- confirm the npz layout never hits this.

    Raises:
        ValueError: for arrays with fewer than 2 or more than 4 dimensions.
    """
    arr = np.asarray(arr)
    arr = np.nan_to_num(arr.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0)

    if arr.ndim == 2:
        return arr[None, :, :]

    if arr.ndim == 3:
        a, b, c = arr.shape

        # Single image HWC
        if c in (1, 3, 4) and a > 8 and b > 8:
            return arr.mean(axis=2, keepdims=False)[None, :, :]

        # Single image CHW
        if a in (1, 3, 4) and b > 8 and c > 8:
            return arr.mean(axis=0, keepdims=False)[None, :, :]

        # Ambiguous 3D: treat smallest dimension as frame axis.
        frame_axis = int(np.argmin(arr.shape))
        if frame_axis == 0:
            return arr
        if frame_axis == 1:
            return np.moveaxis(arr, 1, 0)
        return np.moveaxis(arr, 2, 0)

    if arr.ndim == 4:
        # NHWC
        if arr.shape[-1] in (1, 3, 4):
            return arr.mean(axis=-1)
        # NCHW
        if arr.shape[1] in (1, 3, 4):
            return arr.mean(axis=1)
        # Fallback: collapse the last axis and re-run the 3-D heuristics.
        return to_frame_stack(arr.mean(axis=-1))

    # 0-D, 1-D and >4-D arrays cannot be interpreted as frames.
    raise ValueError(f"Unsupported array shape for frames: {arr.shape}")


def resize_2d(frame: np.ndarray, target_hw: Tuple[int, int]) -> np.ndarray:
    """Bilinearly resize a single 2-D frame to ``target_hw`` = (height, width).

    The frame is lifted to a ``(1, 1, H, W)`` float tensor because
    ``F.interpolate`` expects batch and channel axes; they are dropped again
    before converting back to NumPy.
    """
    batched = torch.from_numpy(frame).float().unsqueeze(0).unsqueeze(0)
    resized = F.interpolate(batched, size=target_hw, mode="bilinear", align_corners=False)
    return resized.squeeze(0).squeeze(0).cpu().numpy()


def merge_freq_topo(freq_frame: np.ndarray, topo_frame: np.ndarray) -> np.ndarray:
    """Place a resized topographic frame to the right of a time-frequency frame.

    ``topo_frame`` is bilinearly resized to exactly ``freq_frame``'s
    (height, width) so the two panels can be joined along the width axis.

    Returns:
        2-D array of shape ``(H, 2 * W)`` where ``(H, W) == freq_frame.shape``.
    """
    height, width = freq_frame.shape
    panels = [freq_frame, resize_2d(topo_frame, (height, width))]
    return np.concatenate(panels, axis=1)

In [None]:
# ============================================================================
# NPZ INSPECTION UTILITIES
# ============================================================================

def pick_key(keys: List[str], include_tokens: List[str]) -> str | None:
    for k in keys:
        lk = k.lower()
        if any(tok in lk for tok in include_tokens):
            return k
    return None


def inspect_npz(npz_path: Path) -> Dict[str, str | int]:
    """Decide how frames should be read from one ``.npz`` file.

    Candidate arrays are numeric, non-empty and 2-4 dimensional, which are the
    ranks ``to_frame_stack`` accepts. Selection order:

    1. ``"combined"``: the first candidate whose name contains a combined-style
       token (``combined``, ``frame``, ``arr_0``, ``data``, ``stack``).
    2. ``"split"``: distinct time-frequency and topographic candidates.
    3. ``"combined"`` fallback: the largest candidate array.

    Returns:
        Dict with ``mode`` and ``n_frames``, plus either ``combined_key`` or
        ``tf_key``/``topo_key``. In ``split`` mode ``n_frames`` is the shorter
        of the two stacks.

    Raises:
        ValueError: if the file contains no usable numeric array.
    """
    with np.load(npz_path, allow_pickle=False) as z:
        # Every ``z[k]`` access re-reads (and decompresses) the member, so each
        # array is read once here and only the metadata needed for selection kept.
        sizes: Dict[str, int] = {}
        for k in z.files:
            arr = z[k]
            # Scalars / 1-D metadata (e.g. a sampling rate named "...freq", or a
            # "n_frames" counter) would match a token and then make
            # to_frame_stack raise, so they are not frame candidates.
            if np.issubdtype(arr.dtype, np.number) and arr.size > 0 and 2 <= arr.ndim <= 4:
                sizes[k] = int(arr.size)
        numeric_keys = list(sizes)
        if not numeric_keys:
            raise ValueError(f"No numeric arrays in {npz_path}")

        def n_frames_of(key: str) -> int:
            return int(to_frame_stack(z[key]).shape[0])

        combined_key = pick_key(
            numeric_keys,
            include_tokens=["combined", "frame", "arr_0", "data", "stack"],
        )
        tf_key = pick_key(
            numeric_keys,
            include_tokens=["tf", "freq", "spect", "time_frequency"],
        )
        topo_key = pick_key(
            numeric_keys,
            include_tokens=["topo", "map"],
        )

        if combined_key is not None:
            return {"mode": "combined", "n_frames": n_frames_of(combined_key), "combined_key": combined_key}

        if tf_key is not None and topo_key is not None and tf_key != topo_key:
            return {
                "mode": "split",
                "n_frames": min(n_frames_of(tf_key), n_frames_of(topo_key)),
                "tf_key": tf_key,
                "topo_key": topo_key,
            }

        # Fallback: largest candidate array (first one wins on ties).
        largest_key = max(numeric_keys, key=sizes.__getitem__)
        return {"mode": "combined", "n_frames": n_frames_of(largest_key), "combined_key": largest_key}

In [None]:
# ============================================================================
# DATA RECORD TYPES
# ============================================================================

@dataclass
class FrameRecord:
    """One training/validation sample: a single frame inside one ``.npz`` file.

    Only the location of the frame is stored; ``CombinedNPZDataset`` loads the
    array lazily when the sample is requested.
    """

    npz_path: str  # path of the source .npz file (stored as str)
    signal_id: str  # signal identifier; the train/val split is done per signal_id
    label: int  # class index: labels.csv `attended` (1/2) minus 1 -> 0/1
    frame_idx: int  # index into the frame stack (the dataset wraps it modulo the stack length)
    mode: str  # "combined" (one array key) or "split" (tf + topo keys merged side by side)
    combined_key: str | None = None  # array key read when mode == "combined"
    tf_key: str | None = None  # time-frequency array key read when mode == "split"
    topo_key: str | None = None  # topographic-map array key read when mode == "split"

### Dataset and Data Splitting

In [None]:
# ============================================================================
# DATASET
# ============================================================================

class CombinedNPZDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs, one per ``FrameRecord``.

    Each access re-opens the record's ``.npz`` file and reads its frame. The
    frame is scaled to uint8 and replicated into three identical channels
    (presumably to match the 3-channel input of the pretrained image model).
    It is then wrapped as a PIL image and passed through ``transform`` when
    one is given.
    """

    def __init__(self, records: List[FrameRecord], transform=None):
        self.records = records
        self.transform = transform

    def __len__(self) -> int:
        return len(self.records)

    def _load_frame(self, rec: FrameRecord) -> np.ndarray:
        """Read the 2-D frame that ``rec`` points to.

        ``rec.frame_idx`` is wrapped modulo the stack length. In ``split`` mode
        the time-frequency and topographic frames are merged side by side.

        Raises:
            ValueError: if the keys required by ``rec.mode`` are missing, or
                the mode is not recognised.
        """
        def nth(stack: np.ndarray) -> np.ndarray:
            return stack[rec.frame_idx % len(stack)]

        with np.load(rec.npz_path, allow_pickle=False) as archive:
            if rec.mode == "split":
                keys_missing = rec.tf_key is None or rec.topo_key is None
                if keys_missing:
                    raise ValueError(f"Missing tf/topo keys for {rec.npz_path}")
                tf_frames = to_frame_stack(archive[rec.tf_key])
                topo_frames = to_frame_stack(archive[rec.topo_key])
                return merge_freq_topo(nth(tf_frames), nth(topo_frames))

            if rec.mode == "combined":
                if rec.combined_key is None:
                    raise ValueError(f"Missing combined key for {rec.npz_path}")
                return nth(to_frame_stack(archive[rec.combined_key]))

        raise ValueError(f"Unsupported mode: {rec.mode}")

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        record = self.records[idx]
        gray = normalize_to_uint8(self._load_frame(record))
        rgb = np.stack([gray, gray, gray], axis=-1)
        picture = Image.fromarray(rgb)
        sample = picture if self.transform is None else self.transform(picture)
        return sample, torch.tensor(record.label, dtype=torch.long)

In [None]:
# ============================================================================
# DATA MATCHING AND SPLIT
# ============================================================================

def infer_signal_id(npz_path: Path, label_map: Dict[str, int]) -> str | None:
    # Typical case from eeg_j_01 pipeline: root / SIGNAL_ID / combined_*.npz
    parent_name = npz_path.parent.name
    if parent_name in label_map:
        return parent_name

    stem = npz_path.stem
    if stem in label_map:
        return stem

    # Last fallback: find any known signal id as substring.
    for sid in label_map:
        if sid in str(npz_path):
            return sid
    return None


def stratified_signal_split(
    signal_to_label: Dict[str, int],
    val_ratio: float,
    seed: int,
) -> Tuple[set[str], set[str]]:
    """Split signal ids into train/val sets, stratified by label.

    Within each label the ids are shuffled with ``random.Random(seed)``, and the
    first ``max(1, int(n * val_ratio))`` go to validation. If that would leave a
    label with no training ids (a single-signal label, or ``val_ratio >= 1``),
    one id is moved back to training so every label stays trainable.

    Splitting at the signal level keeps all frames of one signal on the same
    side, so frames of a validation signal never leak into training.

    Returns:
        ``(train_ids, val_ids)``: disjoint sets whose union is every key of
        ``signal_to_label``.
    """
    rng = random.Random(seed)
    by_label: Dict[int, List[str]] = {}
    for sid, y in signal_to_label.items():
        by_label.setdefault(y, []).append(sid)

    train_ids: set[str] = set()
    val_ids: set[str] = set()
    for sids in by_label.values():
        shuffled = list(sids)
        rng.shuffle(shuffled)
        n_val = max(1, int(len(shuffled) * val_ratio))
        val_part = shuffled[:n_val]
        train_part = shuffled[n_val:]
        if not train_part:
            # Take the fallback from the seeded shuffle order, not from set
            # iteration order: str hashing depends on PYTHONHASHSEED, so picking
            # via next(iter(a_set)) would change the split between runs.
            train_part, val_part = val_part[:1], val_part[1:]
        train_ids.update(train_part)
        val_ids.update(val_part)
    return train_ids, val_ids

In [None]:
# ============================================================================
# RECORD BUILDING
# ============================================================================

def _select_frame_indices(n_frames: int, max_frames: int) -> List[int]:
    """Evenly spaced frame indices, capped at ``max_frames`` (all frames when ``max_frames <= 0``)."""
    if max_frames > 0 and n_frames > max_frames:
        return np.linspace(0, n_frames - 1, num=max_frames, dtype=int).tolist()
    return list(range(n_frames))


def build_records(
    data_root: Path,
    labels_csv: Path,
    max_frames_per_signal: int,
) -> Tuple[List[FrameRecord], Dict[str, int]]:
    """Create one ``FrameRecord`` per selected frame of every labelled ``.npz`` file.

    ``labels_csv`` must provide ``signal_id`` and ``attended`` columns. Only rows
    with ``attended`` in {1, 2} are kept, and they are mapped to class indices
    0/1. Files named ``*combined*.npz`` under ``data_root`` are preferred; when
    none exist, every ``*.npz`` is used.

    NOTE(review): ``max_frames_per_signal`` is applied per ``.npz`` file. If a
    signal directory holds several matching files, that signal contributes more
    frames than the cap -- confirm one file per signal.

    Returns:
        ``(records, kept_signals)``, where ``kept_signals`` maps every signal
        that produced at least one record to its class index.

    Raises:
        FileNotFoundError: if ``labels_csv`` is missing or no ``.npz`` exists.
        ValueError: if ``labels_csv`` lacks the required columns.
        RuntimeError: if no file could be matched to a labelled signal.
    """
    if not labels_csv.exists():
        raise FileNotFoundError(f"labels.csv not found: {labels_csv}")

    # Read ids as text. Letting pandas infer them would turn "007" into 7, and
    # (when any row is missing a value) 101 into "101.0" after astype(str);
    # neither would match the on-disk folder names.
    labels_df = pd.read_csv(labels_csv, dtype={"signal_id": str})
    if "signal_id" not in labels_df.columns or "attended" not in labels_df.columns:
        raise ValueError("labels.csv must have columns: signal_id, attended")

    labels_df = labels_df.dropna(subset=["signal_id", "attended"]).copy()
    labels_df["signal_id"] = labels_df["signal_id"].astype(str).str.strip()
    labels_df = labels_df[labels_df["signal_id"] != ""]
    labels_df["attended"] = labels_df["attended"].astype(int)
    labels_df = labels_df[labels_df["attended"].isin([1, 2])]

    # 1/2 -> 0/1
    label_map: Dict[str, int] = {
        sid: int(attended) - 1
        for sid, attended in zip(labels_df["signal_id"], labels_df["attended"])
    }

    # Prefer combined-style npz names first.
    npz_files = sorted(data_root.rglob("*combined*.npz"))
    if not npz_files:
        npz_files = sorted(data_root.rglob("*.npz"))
    if not npz_files:
        raise FileNotFoundError(f"No .npz files found under: {data_root}")

    records: List[FrameRecord] = []
    kept_signals: Dict[str, int] = {}
    for npz_path in npz_files:
        signal_id = infer_signal_id(npz_path, label_map)
        if signal_id is None:
            continue

        meta = inspect_npz(npz_path)
        n_frames = int(meta["n_frames"])  # type: ignore[arg-type]
        if n_frames <= 0:
            continue

        label = label_map[signal_id]
        for i in _select_frame_indices(n_frames, max_frames_per_signal):
            records.append(
                FrameRecord(
                    npz_path=str(npz_path),
                    signal_id=signal_id,
                    label=label,
                    frame_idx=int(i),
                    mode=str(meta["mode"]),
                    combined_key=meta.get("combined_key"),  # type: ignore[arg-type]
                    tf_key=meta.get("tf_key"),  # type: ignore[arg-type]
                    topo_key=meta.get("topo_key"),  # type: ignore[arg-type]
                )
            )
        kept_signals[signal_id] = label

    if not records:
        raise RuntimeError(
            "No training records built. Check that npz folders match signal_id values in labels.csv."
        )

    return records, kept_signals

### Training Loop, Metrics and Evaluation Artifacts

In [None]:
# ============================================================================
# TRANSFORMS AND EPOCH LOOP
# ============================================================================

def build_transforms(image_size: int, mean: List[float], std: List[float]):
    """Return ``(train_transform, val_transform)`` for PIL images.

    Both pipelines resize to a square ``image_size`` x ``image_size``, convert
    to a float tensor and normalise with the given per-channel ``mean``/``std``.
    They are currently identical (no training-time augmentation) but are built
    as separate objects so either one can be changed independently.
    """
    def make_pipeline():
        steps = [
            T.Resize((image_size, image_size)),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ]
        return T.Compose(steps)

    return make_pipeline(), make_pipeline()


def run_epoch(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    optimizer: torch.optim.Optimizer | None,
) -> Dict[str, float]:
    train_mode = optimizer is not None
    model.train(train_mode)

    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        if train_mode:
            optimizer.zero_grad(set_to_none=True)

        with torch.set_grad_enabled(train_mode):
            outputs = model(pixel_values=images)
            logits = outputs.logits
            loss = criterion(logits, labels)

            if train_mode:
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

        total_loss += float(loss.item()) * labels.size(0)
        preds = torch.argmax(logits, dim=1)
        total_correct += int((preds == labels).sum().item())
        total_samples += int(labels.size(0))

    return {
        "loss": total_loss / max(total_samples, 1),
        "acc": total_correct / max(total_samples, 1),
    }


def collect_predictions(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run ``model`` over ``loader`` in eval mode and gather labels, predictions and probabilities.

    Returns:
        ``(y_true, y_pred, y_prob)``, where ``y_true``/``y_pred`` are int64
        vectors and ``y_prob`` is a float32 ``(n_samples, n_classes)`` softmax
        matrix. For an empty loader the shapes are ``(0,)``, ``(0,)`` and ``(0, 0)``.
    """
    model.eval()
    per_batch: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = []

    with torch.no_grad():
        for images, labels in loader:
            logits = model(pixel_values=images.to(device, non_blocking=True)).logits
            probs = torch.softmax(logits, dim=1)
            per_batch.append(
                (
                    labels.to(device, non_blocking=True).cpu().numpy(),
                    probs.argmax(dim=1).cpu().numpy(),
                    probs.cpu().numpy(),
                )
            )

    if not per_batch:
        return (
            np.empty((0,), dtype=np.int64),
            np.empty((0,), dtype=np.int64),
            np.empty((0, 0), dtype=np.float32),
        )

    true_parts, pred_parts, prob_parts = zip(*per_batch)
    y_true = np.concatenate(true_parts).astype(np.int64)
    y_pred = np.concatenate(pred_parts).astype(np.int64)
    y_prob = np.concatenate(prob_parts).astype(np.float32)
    return y_true, y_pred, y_prob


def make_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int) -> np.ndarray:
    """Count (true, predicted) label pairs into an ``n_classes x n_classes`` int64 matrix.

    Rows index the true label and columns the prediction. Pairs where either
    label falls outside ``[0, n_classes)`` are skipped silently.
    """
    counts = np.zeros((n_classes, n_classes), dtype=np.int64)

    def in_range(label) -> bool:
        return 0 <= label < n_classes

    for actual, predicted in zip(y_true.tolist(), y_pred.tolist()):
        if not (in_range(actual) and in_range(predicted)):
            continue
        counts[actual, predicted] += 1
    return counts


def classification_report_from_cm(cm: np.ndarray, class_names: List[str]) -> pd.DataFrame:
    """Build a sklearn-style classification report from a confusion matrix.

    ``cm`` has true labels on rows and predictions on columns. Row ``i`` of
    ``cm`` is reported under ``class_names[i]``.

    Returns:
        DataFrame with columns ``label, precision, recall, f1, support``: one
        row per class, followed by ``accuracy`` (same value in all three metric
        columns), ``macro_avg`` and ``weighted_avg`` (weighted by class support).
        Undefined ratios (zero denominators) are reported as 0.0.
    """
    rows: List[Dict[str, float | int | str]] = []
    total = int(cm.sum())

    precisions: List[float] = []
    recalls: List[float] = []
    f1s: List[float] = []
    supports: List[int] = []

    for i, class_name in enumerate(class_names):
        tp = int(cm[i, i])
        fp = int(cm[:, i].sum() - tp)
        fn = int(cm[i, :].sum() - tp)
        support = int(cm[i, :].sum())

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (2.0 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0

        rows.append(
            {
                "label": class_name,
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "support": support,
            }
        )
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
        supports.append(support)

    def _mean(values: List[float]) -> float:
        return float(np.mean(values)) if values else 0.0

    def _weighted(values: List[float]) -> float:
        # Guard on the weights themselves, not on `total`: when class_names does
        # not cover every row of `cm`, supports can sum to 0 while total > 0, and
        # np.average would raise ZeroDivisionError.
        if sum(supports) == 0:
            return 0.0
        return float(np.average(np.asarray(values), weights=np.asarray(supports)))

    accuracy = float(np.trace(cm) / total) if total > 0 else 0.0

    rows.append(
        {
            "label": "accuracy",
            "precision": accuracy,
            "recall": accuracy,
            "f1": accuracy,
            "support": total,
        }
    )
    rows.append(
        {
            "label": "macro_avg",
            "precision": _mean(precisions),
            "recall": _mean(recalls),
            "f1": _mean(f1s),
            "support": total,
        }
    )
    rows.append(
        {
            "label": "weighted_avg",
            "precision": _weighted(precisions),
            "recall": _weighted(recalls),
            "f1": _weighted(f1s),
            "support": total,
        }
    )
    return pd.DataFrame(rows)


def save_learning_curves(history: List[Dict[str, float]], output_dir: Path) -> None:
    """Plot train/val loss and accuracy per epoch to ``output_dir/learning_curves.png``.

    Does nothing for an empty history. If matplotlib cannot be imported, a
    message is printed and the plot is skipped (best effort, never raises).
    """
    if not history:
        return
    try:
        import matplotlib.pyplot as plt
    except Exception as exc:
        print(f"Skipping plots (matplotlib unavailable): {exc}")
        return

    frame = pd.DataFrame(history)
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))
    panels = (
        (loss_ax, "loss", "Loss"),
        (acc_ax, "acc", "Accuracy"),
    )
    for ax, metric, title in panels:
        for split in ("train", "val"):
            column = f"{split}_{metric}"
            ax.plot(frame["epoch"], frame[column], label=column, linewidth=2)
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(title)
        if metric == "acc":
            # Accuracy is a fraction; pin the axis so runs are comparable.
            ax.set_ylim(0.0, 1.0)
        ax.grid(alpha=0.3)
        ax.legend()

    fig.tight_layout()
    fig.savefig(output_dir / "learning_curves.png", dpi=180, bbox_inches="tight")
    plt.close(fig)


def save_confusion_matrix_plot(cm: np.ndarray, class_names: List[str], output_path: Path) -> None:
    """Render ``cm`` as an annotated heatmap (true labels on rows) and save it to ``output_path``.

    Each cell shows its count, in white on dark cells (above half the maximum
    count) and black otherwise. If matplotlib cannot be imported, a message is
    printed and the plot is skipped.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception as exc:
        print(f"Skipping confusion matrix plot (matplotlib unavailable): {exc}")
        return

    fig, ax = plt.subplots(figsize=(5, 4))
    heatmap = ax.imshow(cm, cmap="Blues")
    fig.colorbar(heatmap, ax=ax, fraction=0.046, pad=0.04)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    ax.set_title("Validation Confusion Matrix")
    tick_positions = np.arange(len(class_names))
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(class_names)
    ax.set_yticklabels(class_names)

    # Contrast threshold is loop-invariant; only computed when there are cells.
    threshold = cm.max() * 0.5 if cm.size else 0
    for (row, col), raw in np.ndenumerate(cm):
        count = int(raw)
        ax.text(
            col,
            row,
            str(count),
            ha="center",
            va="center",
            color="white" if count > threshold else "black",
        )

    fig.tight_layout()
    fig.savefig(output_path, dpi=180, bbox_inches="tight")
    plt.close(fig)


def save_eval_artifacts(
    output_dir: Path,
    val_records: List[FrameRecord],
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_prob: np.ndarray,
    class_names: List[str],
) -> None:
    """Write validation metrics, confusion matrices and per-frame predictions to ``output_dir``.

    Files written: ``val_confusion_matrix.csv``/``.png``,
    ``val_confusion_matrix_normalized.csv`` (row-normalised, rows with no samples
    left at 0), ``val_classification_report.csv``/``.json`` and
    ``val_predictions.csv``. In the predictions file, row ``i`` is paired with
    ``val_records[i]``, so ``val_records`` must be in loader order. Does
    nothing when ``y_true`` is empty.
    """
    if len(y_true) == 0:
        return

    n_classes = len(class_names)
    true_axis = [f"true_{name}" for name in class_names]
    pred_axis = [f"pred_{name}" for name in class_names]

    cm = make_confusion_matrix(y_true=y_true, y_pred=y_pred, n_classes=n_classes)
    pd.DataFrame(cm, index=true_axis, columns=pred_axis).to_csv(output_dir / "val_confusion_matrix.csv")

    per_row_totals = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(
        cm,
        per_row_totals,
        out=np.zeros_like(cm, dtype=np.float64),
        where=per_row_totals != 0,
    )
    pd.DataFrame(cm_norm, index=true_axis, columns=pred_axis).to_csv(
        output_dir / "val_confusion_matrix_normalized.csv"
    )
    save_confusion_matrix_plot(
        cm=cm,
        class_names=class_names,
        output_path=output_dir / "val_confusion_matrix.png",
    )

    report_df = classification_report_from_cm(cm=cm, class_names=class_names)
    report_df.to_csv(output_dir / "val_classification_report.csv", index=False)
    with open(output_dir / "val_classification_report.json", "w", encoding="utf-8") as fh:
        json.dump(report_df.to_dict(orient="records"), fh, indent=2)

    def label_name(label_id: int) -> str:
        return class_names[label_id] if 0 <= label_id < n_classes else "unknown"

    has_prob_matrix = y_prob.ndim == 2
    pred_rows: List[Dict[str, str | int | float]] = []
    for i, (yt, yp) in enumerate(zip(y_true.tolist(), y_pred.tolist())):
        true_id, pred_id = int(yt), int(yp)
        row: Dict[str, str | int | float] = {
            "true_label_id": true_id,
            "pred_label_id": pred_id,
            "true_label_name": label_name(true_id),
            "pred_label_name": label_name(pred_id),
            "correct": int(yt == yp),
        }
        if has_prob_matrix and y_prob.shape[0] > i:
            probs_i = y_prob[i]
            width = probs_i.shape[0]
            row["pred_confidence"] = float(probs_i[pred_id]) if 0 <= pred_id < width else float(np.max(probs_i))
            for cls_idx in range(min(n_classes, width)):
                row[f"prob_{class_names[cls_idx]}"] = float(probs_i[cls_idx])
        rec = val_records[i] if i < len(val_records) else None
        if rec is not None:
            row["signal_id"] = rec.signal_id
            row["frame_idx"] = int(rec.frame_idx)
            row["npz_path"] = rec.npz_path
        pred_rows.append(row)

    pd.DataFrame(pred_rows).to_csv(output_dir / "val_predictions.csv", index=False)

In [None]:
# ============================================================================
# TRAINING CONFIG
# ============================================================================

@dataclass
class TrainConfig:
    """Paths and hyper-parameters for ``train``.

    Normally built by ``parse_args``, whose CLI flag defaults use the same
    values as the field defaults here.
    """

    # Directory searched recursively for *.npz frame files.
    data_root: Path = field(default_factory=default_data_root)
    # CSV with `signal_id` and `attended` columns; None -> data_root / "labels.csv".
    labels_csv: Path | None = None
    # Parent directory; run artifacts are written to output_dir / "Training".
    output_dir: Path = field(default_factory=default_output_dir)
    model_name: str = "microsoft/cvt-13"  # smallest CvT family checkpoint
    epochs: int = 10
    batch_size: int = 16
    lr: float = 3e-5  # AdamW learning rate, cosine-annealed over `epochs`
    weight_decay: float = 1e-4
    val_ratio: float = 0.2  # fraction of signals per label held out for validation
    max_frames_per_signal: int = 16  # evenly subsampled cap; <= 0 keeps every frame
    num_workers: int = 2  # DataLoader worker processes
    seed: int = 42

### Training

In [None]:
# ============================================================================
# TRAINING PIPELINE
# ============================================================================

def _resolve_image_size(processor) -> int:
    """Square input side length taken from a Hugging Face image processor's ``size`` config."""
    size_cfg = processor.size
    if "height" in size_cfg and "width" in size_cfg:
        return int(size_cfg["height"])
    return int(size_cfg.get("shortest_edge", 224))


def _make_loader(dataset: Dataset, cfg: TrainConfig, shuffle: bool, pin_memory: bool) -> DataLoader:
    """DataLoader using the batch size and worker count from ``cfg``."""
    return DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        num_workers=cfg.num_workers,
        pin_memory=pin_memory,
    )


def _save_best_checkpoint(model, processor, optimizer, epoch: int, best_val_acc: float, training_dir: Path) -> None:
    """Save the HF model/processor to ``best_model/`` and a raw checkpoint to ``best_state.pt``."""
    best_dir = training_dir / "best_model"
    best_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(best_dir)
    processor.save_pretrained(best_dir)
    torch.save(
        {
            "epoch": epoch,
            "best_val_acc": best_val_acc,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        training_dir / "best_state.pt",
    )


def _write_train_config(
    cfg: TrainConfig,
    labels_csv: Path,
    output_dir: Path,
    training_dir: Path,
    best_val_acc: float,
    best_epoch: int,
) -> None:
    """Record the run configuration, best score and artifact list in ``train_config.json``."""
    with open(training_dir / "train_config.json", "w", encoding="utf-8") as f:
        json.dump(
            {
                "data_root": str(cfg.data_root),
                "labels_csv": str(labels_csv),
                "output_dir": str(output_dir),
                "training_dir": str(training_dir),
                "model_name": cfg.model_name,
                "epochs": cfg.epochs,
                "batch_size": cfg.batch_size,
                "lr": cfg.lr,
                "weight_decay": cfg.weight_decay,
                "val_ratio": cfg.val_ratio,
                "max_frames_per_signal": cfg.max_frames_per_signal,
                "num_workers": cfg.num_workers,
                "seed": cfg.seed,
                "best_val_acc": best_val_acc,
                "best_epoch": best_epoch,
                "saved_artifacts": [
                    "history.csv",
                    "learning_curves.png",
                    "best_state.pt",
                    "best_model/",
                    "val_predictions.csv",
                    "val_confusion_matrix.csv",
                    "val_confusion_matrix_normalized.csv",
                    "val_confusion_matrix.png",
                    "val_classification_report.csv",
                    "val_classification_report.json",
                ],
            },
            f,
            indent=2,
        )


def train(cfg: TrainConfig) -> None:
    """Fine-tune a CvT classifier on EEG frames and export evaluation artifacts.

    The pipeline:

    1. Build frame records from ``cfg.data_root``.
    2. Split them by signal id, stratified by label.
    3. Fine-tune ``cfg.model_name`` for ``cfg.epochs`` epochs, keeping the
       checkpoint with the best validation accuracy.
    4. Reload that checkpoint and write validation predictions, metrics and
       plots.

    Everything is written to ``<cfg.output_dir>/Training``.

    Note: the checkpoint is selected on the same validation set the final
    metrics are computed on, so those metrics are optimistically biased.

    Raises:
        RuntimeError: if the signal-level split leaves train or val empty.
    """
    set_seed(cfg.seed)

    data_root = cfg.data_root
    labels_csv = cfg.labels_csv if cfg.labels_csv is not None else (data_root / "labels.csv")
    output_dir = cfg.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    # All artifacts of this run go into a dedicated subfolder.
    training_dir = output_dir / "Training"
    training_dir.mkdir(parents=True, exist_ok=True)
    best_state_path = training_dir / "best_state.pt"

    records, signal_to_label = build_records(
        data_root=data_root,
        labels_csv=labels_csv,
        max_frames_per_signal=cfg.max_frames_per_signal,
    )
    train_signals, val_signals = stratified_signal_split(
        signal_to_label=signal_to_label,
        val_ratio=cfg.val_ratio,
        seed=cfg.seed,
    )
    train_records = [r for r in records if r.signal_id in train_signals]
    val_records = [r for r in records if r.signal_id in val_signals]
    if not train_records or not val_records:
        raise RuntimeError(
            "Train/val split is empty. Try reducing --val-ratio or increasing data."
        )

    processor = AutoImageProcessor.from_pretrained(cfg.model_name)
    train_t, val_t = build_transforms(
        image_size=_resolve_image_size(processor),
        mean=list(processor.image_mean),
        std=list(processor.image_std),
    )

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    train_loader = _make_loader(
        CombinedNPZDataset(train_records, transform=train_t), cfg, shuffle=True, pin_memory=use_cuda
    )
    # Unshuffled so prediction order matches val_records in save_eval_artifacts.
    val_loader = _make_loader(
        CombinedNPZDataset(val_records, transform=val_t), cfg, shuffle=False, pin_memory=use_cuda
    )

    class_names = ["attended_1", "attended_2"]
    model = CvtForImageClassification.from_pretrained(
        cfg.model_name,
        num_labels=len(class_names),
        id2label=dict(enumerate(class_names)),
        label2id={name: i for i, name in enumerate(class_names)},
        # The pretrained head has a different number of outputs; replace it.
        ignore_mismatched_sizes=True,
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=max(cfg.epochs, 1),
    )

    history: List[Dict[str, float]] = []
    best_val_acc = -1.0
    best_epoch = 0  # 0 means no checkpoint has been written by this run

    print(f"Train records: {len(train_records):,}")
    print(f"Val records  : {len(val_records):,}")
    print(f"Train signals: {len(train_signals)}")
    print(f"Val signals  : {len(val_signals)}")
    print(f"Device       : {device}")
    print(f"Model        : {cfg.model_name}")

    for epoch in range(1, cfg.epochs + 1):
        train_stats = run_epoch(model, train_loader, device, optimizer=optimizer)
        val_stats = run_epoch(model, val_loader, device, optimizer=None)
        scheduler.step()

        row = {
            "epoch": epoch,
            "train_loss": train_stats["loss"],
            "train_acc": train_stats["acc"],
            "val_loss": val_stats["loss"],
            "val_acc": val_stats["acc"],
            "lr": optimizer.param_groups[0]["lr"],
        }
        history.append(row)

        print(
            f"[{epoch:02d}/{cfg.epochs}] "
            f"train_loss={row['train_loss']:.4f} train_acc={row['train_acc']:.4f} "
            f"val_loss={row['val_loss']:.4f} val_acc={row['val_acc']:.4f} "
            f"lr={row['lr']:.2e}"
        )

        if row["val_acc"] > best_val_acc:
            best_val_acc = row["val_acc"]
            best_epoch = epoch
            _save_best_checkpoint(model, processor, optimizer, epoch, best_val_acc, training_dir)

    pd.DataFrame(history).to_csv(training_dir / "history.csv", index=False)
    save_learning_curves(history=history, output_dir=training_dir)

    # Only reload a checkpoint written by *this* run: a best_state.pt left over
    # from an earlier run in the same directory must never be evaluated.
    if best_epoch > 0:
        checkpoint = torch.load(best_state_path, map_location=device)
        model.load_state_dict(checkpoint["state_dict"])

    y_true, y_pred, y_prob = collect_predictions(model=model, loader=val_loader, device=device)
    if len(y_true) == len(val_records):
        save_eval_artifacts(
            output_dir=training_dir,
            val_records=val_records,
            y_true=y_true,
            y_pred=y_pred,
            y_prob=y_prob,
            class_names=class_names,
        )
    else:
        print(
            "Skipping val predictions artifact export due to size mismatch: "
            f"preds={len(y_true)} records={len(val_records)}"
        )

    _write_train_config(
        cfg=cfg,
        labels_csv=labels_csv,
        output_dir=output_dir,
        training_dir=training_dir,
        best_val_acc=best_val_acc,
        best_epoch=best_epoch,
    )

    print(f"Training complete. Best val_acc={best_val_acc:.4f}")
    print(f"Saved outputs in: {training_dir}")

In [None]:
# ============================================================================
# CLI
# ============================================================================

def parse_args(argv: List[str] | None = None) -> TrainConfig:
    """Parse command-line options into a ``TrainConfig``.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` makes
            argparse read ``sys.argv[1:]``.

    Returns:
        A ``TrainConfig`` built from the parsed options. ``--labels-csv``
        defaults to ``None``; ``--data-root`` and ``--output-dir`` default to
        ``default_data_root()`` and ``default_output_dir()``.

    Raises:
        SystemExit: via ``ArgumentParser.error`` when a numeric option is out
            of its valid range, or when unrecognized arguments are passed
            outside an IPython kernel.
    """
    p = argparse.ArgumentParser(
        description=(
            "Train smallest Microsoft CvT (microsoft/cvt-13) on combined EEG npz frames."
        )
    )
    p.add_argument("--data-root", type=Path, default=default_data_root())
    p.add_argument("--labels-csv", type=Path, default=None)
    p.add_argument("--output-dir", type=Path, default=default_output_dir())
    p.add_argument("--model-name", type=str, default="microsoft/cvt-13")
    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--lr", type=float, default=3e-5)
    p.add_argument("--weight-decay", type=float, default=1e-4)
    p.add_argument("--val-ratio", type=float, default=0.2)
    p.add_argument("--max-frames-per-signal", type=int, default=16)
    p.add_argument("--num-workers", type=int, default=2)
    p.add_argument("--seed", type=int, default=42)
    a, unknown = p.parse_known_args(argv)
    if unknown:
        # Notebook kernels are launched with their own arguments
        # (e.g. ``-f <kernel connection file>``) that end up in sys.argv.
        # Tolerate them inside a kernel, but stay strict on a real CLI.
        if "ipykernel" in sys.modules:
            print(f"Ignoring extra notebook args: {' '.join(unknown)}")
        else:
            p.error(f"unrecognized arguments: {' '.join(unknown)}")

    # Fail fast on values that cannot work downstream. Without these checks
    # they would only surface later, e.g. DataLoader rejects batch_size=0 and
    # negative num_workers, and an empty/whole validation split is unusable.
    if not 0.0 < a.val_ratio < 1.0:
        p.error(f"--val-ratio must be strictly between 0 and 1, got {a.val_ratio}")
    for flag, value in (("--epochs", a.epochs), ("--batch-size", a.batch_size)):
        if value < 1:
            p.error(f"{flag} must be >= 1, got {value}")
    if a.num_workers < 0:
        p.error(f"--num-workers must be >= 0, got {a.num_workers}")
    if a.lr <= 0:
        p.error(f"--lr must be > 0, got {a.lr}")
    if a.weight_decay < 0:
        p.error(f"--weight-decay must be >= 0, got {a.weight_decay}")

    return TrainConfig(
        data_root=a.data_root,
        labels_csv=a.labels_csv,
        output_dir=a.output_dir,
        model_name=a.model_name,
        epochs=a.epochs,
        batch_size=a.batch_size,
        lr=a.lr,
        weight_decay=a.weight_decay,
        val_ratio=a.val_ratio,
        max_frames_per_signal=a.max_frames_per_signal,
        num_workers=a.num_workers,
        seed=a.seed,
    )


# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    # Entry point: build the config from the command line and start training.
    # In a notebook cell __name__ is "__main__" too, so running this cell
    # trains immediately; parse_args() ignores the kernel's own extra args.
    train(parse_args())

Ignoring extra notebook args: -f /root/.local/share/jupyter/runtime/kernel-a359ae4c-18c4-49be-a266-7fb88f7a8306.json


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/266 [00:00<?, ?B/s]

The image processor of type `ConvNextImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/80.2M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/459 [00:00<?, ?it/s]

CvtForImageClassification LOAD REPORT from: microsoft/cvt-13
Key               | Status   |                                                                                        
------------------+----------+----------------------------------------------------------------------------------------
classifier.weight | MISMATCH | Reinit due to size mismatch ckpt: torch.Size([1000, 384]) vs model:torch.Size([2, 384])
classifier.bias   | MISMATCH | Reinit due to size mismatch ckpt: torch.Size([1000]) vs model:torch.Size([2])          

Notes:
- MISMATCH	:ckpt weights were loaded, but they did not match the original empty weight shapes.


Train records: 4,912
Val records  : 1,226
Train signals: 2456
Val signals  : 613
Device       : cuda
Model        : microsoft/cvt-13
[01/10] train_loss=0.7030 train_acc=0.5063 val_loss=0.6975 val_acc=0.5065 lr=2.93e-05


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

[02/10] train_loss=0.7019 train_acc=0.5026 val_loss=0.6988 val_acc=0.4967 lr=2.71e-05
[03/10] train_loss=0.6936 train_acc=0.5311 val_loss=0.6880 val_acc=0.5481 lr=2.38e-05


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

[04/10] train_loss=0.6897 train_acc=0.5326 val_loss=0.6867 val_acc=0.5277 lr=1.96e-05
[05/10] train_loss=0.6665 train_acc=0.5770 val_loss=0.6850 val_acc=0.5677 lr=1.50e-05


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

[06/10] train_loss=0.6505 train_acc=0.5947 val_loss=0.6748 val_acc=0.5767 lr=1.04e-05


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

[07/10] train_loss=0.6360 train_acc=0.6067 val_loss=0.6995 val_acc=0.5767 lr=6.18e-06
[08/10] train_loss=0.6133 train_acc=0.6356 val_loss=0.6890 val_acc=0.5979 lr=2.86e-06


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

[09/10] train_loss=0.6041 train_acc=0.6466 val_loss=0.6904 val_acc=0.5987 lr=7.34e-07


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

[10/10] train_loss=0.5918 train_acc=0.6610 val_loss=0.6829 val_acc=0.6085 lr=0.00e+00


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Training complete. Best val_acc=0.6085
Saved outputs in: /content/Data/Training


### Inspect outputs and back up to Drive

In [None]:
# List what is under /content/Data to check which training artifacts exist.
# NOTE(review): the listing below shows artifacts at the top level and no
# Training/ folder, although the run above reported /content/Data/Training;
# it looks stale, so re-run this cell to refresh it.
!ls /content/Data

best_model	     val_classification_report.csv
best_state.pt	     val_classification_report.json
eeg_2d_outputs	     val_confusion_matrix.csv
history.csv	     val_confusion_matrix_normalized.csv
learning_curves.png  val_confusion_matrix.png
train_config.json    val_predictions.csv


In [None]:
# List the Drive backup destination. The folder name "EEG " really ends with a
# space (this listing succeeds with it), so keep the quotes and the space.
!ls "/content/drive/MyDrive/Colab Notebooks/EEG "

 Data   Data.zip  'EEG J 01.ipynb'  'EEG J 02.ipynb'   pyz


In [None]:
!cp -r /content/Data/training "/content/drive/MyDrive/Colab Notebooks/EEG "

In [None]:
# Commented-out one-off clean-up: moves artifacts written directly into
# /content/Data (as in the listing above) into a lowercase `training/` folder.
# NOTE(review): train() now reports writing to /content/Data/Training, so this
# appears obsolete. Confirm before deleting. Note the different case
# (`training` vs `Training`).
# !cd /content/Data && mkdir -p training && mv history.csv learning_curves.png best_state.pt train_config.json val_* best_model/ training/ 2>/dev/null || true