In [None]:
"""
DiT Age Regression – Stratified Group K-Fold Cross-Validation (Colab-ready)
==========================================================================

This script performs **stratified cross validation** for DiT models, matching the
CNN CV style in your V4 notebook:

* **Splitter**: `StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)`
  * **Stratify by**: `AgeGroup`
  * **Group by**: `WriterNumber` (prevents writer leakage)
* **Folds**: 5
* **Schedule**: 2-stage training per fold — (1) frozen backbone (only the regression head is trained), then (2) full fine-tuning with the backbone unfrozen
* **Inputs**: CSV at `DATA_ROOT/NewAgeSplit.csv` with columns: `File`, `Age`, `Set`, `AgeGroup`, `WriterNumber`
* **Images**: expected at `DATA_ROOT/<Set>/<File>` (e.g., `train/…`, `val/…`, `test/…`)

It trains every model enabled in the model list in `main()` (up to the four DiT
checkpoints shown below) and saves each under its own folder, with **per-fold subfolders**:

```
/content/drive/MyDrive/Expanded_HHD_AgeSplit/DiT/CV_STRAT_GROUP/
  microsoft__dit-base/
    fold_1/{stage1_best.pt, final_best.pt, metrics.json}
    …
    summary.json
  microsoft__dit-large/
    …
  microsoft__dit-base-finetuned-rvlcdip/
    …
  microsoft__dit-large-finetuned-rvlcdip/
    …
```
"""

import os, json, random
from pathlib import Path
from typing import Dict, List

# ── Colab Drive mount ───────────────────────────────────────────
# Mount Google Drive when running inside Colab. If `google.colab` cannot be
# imported we are running somewhere else and the mount step is skipped.
try:
    import google.colab  # type: ignore
    from google.colab import drive  # type: ignore
    # Only call mount when the mount-point directory does not exist yet.
    if not os.path.exists("/content/drive"):
        drive.mount("/content/drive")
except ImportError:
    # Not a Colab runtime: nothing to mount.
    pass

# ── Paths & hyper-parameters ────────────────────────────────────
# Dataset root on the mounted Drive; the CSV and the CV outputs live below it.
DATA_ROOT = "/content/drive/MyDrive/Expanded_HHD_AgeSplit"
CSV_PATH = DATA_ROOT + "/NewAgeSplit.csv"
OUTPUT_DIR = Path(DATA_ROOT) / "DiT" / "CV_STRAT_GROUP"
# Create the output tree up front so per-model/per-fold folders can be added.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Training schedule.
EPOCHS_STAGE1 = 15  # epochs with the backbone frozen
EPOCHS_STAGE2 = 30  # epochs with the backbone unfrozen
BATCH_SIZE = 64
LR_BASE = 1e-4

# Cross-validation setup.
SEED = 42
N_SPLITS = 5

# ── Imports after mount ─────────────────────────────────────────
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import StratifiedGroupKFold
from transformers import AutoModel, BeitImageProcessor
from PIL import Image

# ── Reproducibility ─────────────────────────────────────────────
# Seed every random number generator the script touches with the same value.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# ── Augmentations & Dataset ─────────────────────────────────────
class AddGaussianNoise(nn.Module):
    """Add Gaussian noise to a tensor and clamp the result to [0, 1].

    Args:
        mean: Constant offset added to every element.
        std: Scale applied to the standard-normal noise sample.
    """

    def __init__(self, mean: float = 0.0, std: float = 0.05):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus scaled noise plus ``mean``, clamped to [0, 1]."""
        noise = torch.randn_like(x) * self.std
        noisy = x + noise + self.mean
        return noisy.clamp(0.0, 1.0)

class HHDataset(Dataset):
    """Image/age dataset reading files from ``<root>/<Set>/<File>``.

    The data frame must provide the columns ``File``, ``Age`` and ``Set``.
    The ``Set`` value is lower-cased to form the sub-directory name.

    Args:
        df: Rows to serve; the index is reset so positions run from 0.
        root: Directory that contains the per-set image folders.
        processor: Image processor exposing ``size["height"]`` /
            ``size["width"]`` and callable as
            ``processor(images=..., return_tensors="pt", ...)``.
        augment: When true, apply random rotation, random resized crop,
            brightness jitter and Gaussian noise; otherwise only resize.
    """

    def __init__(self, df: pd.DataFrame, root: str, processor, augment: bool = False):
        self.df = df.reset_index(drop=True)
        self.root = Path(root)
        self.proc = processor
        target_size = (processor.size["height"], processor.size["width"])
        # Build the pipeline from real transforms only. Identity
        # ``Lambda(lambda x: x)`` placeholders cannot be pickled, which breaks
        # DataLoader workers whenever the "spawn" start method is in use.
        if augment:
            steps = [
                transforms.RandomRotation(15),
                transforms.RandomResizedCrop(target_size, scale=(0.9, 1.1)),
                transforms.ColorJitter(brightness=0.1),
                transforms.ToTensor(),
                AddGaussianNoise(),
            ]
        else:
            steps = [
                transforms.Resize(target_size),
                transforms.ToTensor(),
            ]
        self.pre = transforms.Compose(steps)

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        """Return ``{"pixel_values": tensor, "labels": float32 scalar tensor}``.

        Raises:
            FileNotFoundError: If the image file for the row does not exist.
        """
        row = self.df.iloc[idx]
        img_path = self.root / str(row["Set"]).lower() / row["File"]
        if not img_path.exists():
            raise FileNotFoundError(f"Image not found: {img_path}")
        # Decode inside a context manager so the file handle is released
        # as soon as the RGB copy has been made.
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")
        img = self.pre(img)  # tensor CxHxW in [0,1]
        # Avoid double-rescale/resize; preprocessing already set size & [0,1]
        encoded = self.proc(images=img, return_tensors="pt", do_rescale=False, do_resize=False)
        px = encoded["pixel_values"].squeeze(0)
        label = torch.tensor(row["Age"], dtype=torch.float32)
        return {"pixel_values": px, "labels": label}

# ── Model ───────────────────────────────────────────────────────
class DiTReg(nn.Module):
    """Pretrained backbone followed by a dropout + linear regression head.

    Args:
        name: Identifier passed to ``AutoModel.from_pretrained``.
        p: Dropout probability applied before the linear layer.
    """

    def __init__(self, name: str = "microsoft/dit-base", p: float = 0.1):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(name)
        hidden_dim = self.backbone.config.hidden_size
        self.head = nn.Sequential(
            nn.Dropout(p),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, pixel_values, labels=None):
        """Predict one value per sample from the first token of the last hidden state.

        Returns a dict with ``"preds"``; when ``labels`` is given it also
        contains ``"loss"`` (L1 loss between predictions and labels).
        """
        hidden_states = self.backbone(pixel_values=pixel_values).last_hidden_state
        first_token = hidden_states[:, 0]
        pred = self.head(first_token).squeeze(1)
        if labels is None:
            return {"preds": pred}
        loss = nn.functional.l1_loss(pred, labels)
        return {"loss": loss, "preds": pred}

# ── Metrics ─────────────────────────────────────────────────────
@torch.no_grad()
def compute_metrics(y: torch.Tensor, yhat: torch.Tensor) -> Dict[str, float]:
    """Compute regression metrics between targets ``y`` and predictions ``yhat``.

    Both tensors are flattened to 1-D before any metric is computed, so the
    sklearn-based metrics and the tensor-based metrics always see the same
    element pairing (an ``(N,)`` vs ``(N, 1)`` pair no longer broadcasts to
    an ``N x N`` error matrix).

    Args:
        y: Ground-truth values.
        yhat: Predicted values with the same number of elements as ``y``.

    Returns:
        Dict with ``MAE``, ``RMSE``, ``R2``, ``MAPE`` (percent), ``Within2`` /
        ``Within5`` / ``Within10`` (percent of samples whose absolute error is
        at most 2 / 5 / 10), ``MaxErr``, ``MedianErr`` and ``MinErr``.

    Raises:
        ValueError: If the inputs contain no samples.
    """
    y_flat = y.detach().cpu().reshape(-1)
    yhat_flat = yhat.detach().cpu().reshape(-1)
    if y_flat.numel() == 0:
        raise ValueError("compute_metrics requires at least one sample.")

    y_np, yhat_np = y_flat.numpy(), yhat_flat.numpy()
    ae_t = (y_flat - yhat_flat).abs()

    mae = float(mean_absolute_error(y_np, yhat_np))
    # Take the square root ourselves: the `squared=False` keyword of
    # mean_squared_error is deprecated/removed in newer scikit-learn releases.
    rmse = float(np.sqrt(mean_squared_error(y_np, yhat_np)))
    try:
        r2 = float(r2_score(y_np, yhat_np))
    except ValueError:
        r2 = float("nan")

    # Clamp the denominator so zero targets do not divide by zero.
    mape = float((ae_t / y_flat.clamp(min=1e-6)).mean().item() * 100.0)
    within2, within5, within10 = (
        float((ae_t <= t).float().mean().item() * 100.0) for t in (2, 5, 10)
    )
    return {
        "MAE": mae, "RMSE": rmse, "R2": r2, "MAPE": mape,
        "Within2": within2, "Within5": within5, "Within10": within10,
        "MaxErr": float(ae_t.max().item()),
        "MedianErr": float(ae_t.median().item()),
        "MinErr": float(ae_t.min().item()),
    }

# ── Train / Eval helpers ────────────────────────────────────────

def train_one_epoch(model: nn.Module, loader: DataLoader, opt: torch.optim.Optimizer, device) -> float:
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Each batch is a dict of tensors that is moved to ``device`` and passed to
    the model as keyword arguments; the model must return a dict containing
    ``"loss"``. The returned value is the sum of per-batch losses weighted by
    batch size, divided by ``len(loader.dataset)``.
    """
    model.train()
    running_loss = 0.0
    for batch in loader:
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        opt.zero_grad()
        outputs = model(**batch)
        loss = outputs["loss"]
        loss.backward()
        opt.step()
        batch_size = batch["labels"].size(0)
        running_loss += loss.item() * batch_size
    return running_loss / len(loader.dataset)

@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device) -> Dict[str, float]:
    """Run ``model`` in eval mode over ``loader`` and return its metrics.

    Predictions (the model's ``"preds"`` output) and the batch ``"labels"``
    are gathered on the CPU, concatenated, and passed to ``compute_metrics``.
    """
    model.eval()
    pred_chunks, target_chunks = [], []
    for batch in loader:
        inputs = batch["pixel_values"].to(device)
        outputs = model(inputs)
        pred_chunks.append(outputs["preds"].cpu())
        target_chunks.append(batch["labels"].cpu())
    targets = torch.cat(target_chunks)
    predictions = torch.cat(pred_chunks)
    return compute_metrics(targets, predictions)

# ── CV routine per model ────────────────────────────────────────

def cross_validate_model(model_id: str, n_splits: int = N_SPLITS) -> Dict[str, Dict[str, float]]:
    """Run stratified group K-fold cross-validation for one backbone.

    For every fold the model is trained in two stages (head only with the
    backbone frozen, then all parameters), the checkpoint with the lowest
    validation MAE across both stages is kept as ``final_best.pt`` and
    re-evaluated, and its metrics are written to ``metrics.json``. After all
    folds a ``summary.json`` with the per-fold metrics and their mean/std is
    written to the model's output folder.

    Args:
        model_id: Identifier used for both the image processor and the
            backbone; ``/`` is replaced by ``__`` for the output folder name.
        n_splits: Number of cross-validation folds.

    Returns:
        Mapping ``metric name -> {"mean": ..., "std": ...}`` over the folds.

    Raises:
        KeyError: If the CSV lacks the ``AgeGroup`` or ``WriterNumber`` column.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Processor tied to backbone id
    proc = BeitImageProcessor.from_pretrained(model_id)

    # Read CSV once
    df_full = pd.read_csv(CSV_PATH)

    # Prepare SGKF splits using AgeGroup (stratify) and WriterNumber (group)
    if not {"AgeGroup", "WriterNumber"}.issubset(df_full.columns):
        raise KeyError("CSV must contain 'AgeGroup' and 'WriterNumber' for stratified group CV.")

    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    splits = list(sgkf.split(df_full.index.values, df_full["AgeGroup"].values, df_full["WriterNumber"].values))

    safe_name = model_id.replace("/", "__")
    model_root = OUTPUT_DIR / safe_name
    model_root.mkdir(parents=True, exist_ok=True)

    fold_results: List[Dict[str, float]] = []

    for fold, (train_idx, val_idx) in enumerate(splits, start=1):
        print(f"\n===== {safe_name} | Fold {fold}/{n_splits} =====")

        # Slice folds
        train_df = df_full.iloc[train_idx].reset_index(drop=True)
        val_df   = df_full.iloc[val_idx].reset_index(drop=True)

        # Dataloaders (augmentation on the training split only)
        loaders = {
            "train": DataLoader(HHDataset(train_df, DATA_ROOT, proc, augment=True),  batch_size=BATCH_SIZE,     shuffle=True,  num_workers=2, pin_memory=True),
            "val":   DataLoader(HHDataset(val_df,   DATA_ROOT, proc, augment=False), batch_size=BATCH_SIZE * 2, shuffle=False, num_workers=2, pin_memory=True),
        }

        # Output dir per fold
        fold_dir = model_root / f"fold_{fold}"
        fold_dir.mkdir(parents=True, exist_ok=True)
        stage1_ckpt = fold_dir / "stage1_best.pt"
        final_ckpt = fold_dir / "final_best.pt"

        # Build model
        model = DiTReg(name=model_id).to(device)

        # Stage 1 – freeze backbone, optimise the head only
        for p in model.backbone.parameters():
            p.requires_grad = False
        opt = torch.optim.AdamW(model.head.parameters(), lr=LR_BASE, weight_decay=1e-4)
        best_mae = float("inf")

        for ep in range(1, EPOCHS_STAGE1 + 1):
            train_one_epoch(model, loaders["train"], opt, device)
            m_val = evaluate(model, loaders["val"], device)
            print(f"[{safe_name} | Fold {fold} | Frozen {ep}/{EPOCHS_STAGE1}] MAE {m_val['MAE']:.3f}")
            if m_val["MAE"] < best_mae:
                best_mae = m_val["MAE"]
                torch.save(model.state_dict(), stage1_ckpt)

        # Stage 2 – unfreeze backbone, starting from the best stage-1 weights
        print(f"[{safe_name}] Unfreezing backbone …")
        model.load_state_dict(torch.load(stage1_ckpt, map_location=device))
        # best_mae carries over from stage 1, so stage 2 only overwrites the
        # final checkpoint when it beats the stage-1 best. Seed the final
        # checkpoint with the stage-1 best now; otherwise a stage 2 that never
        # improves would leave final_best.pt missing (or stale from an older
        # run) when it is loaded below.
        torch.save(model.state_dict(), final_ckpt)
        for p in model.backbone.parameters():
            p.requires_grad = True
        opt = torch.optim.AdamW(model.parameters(), lr=LR_BASE/10, weight_decay=1e-5)

        for ep in range(1, EPOCHS_STAGE2 + 1):
            train_one_epoch(model, loaders["train"], opt, device)
            m_val = evaluate(model, loaders["val"], device)
            print(f"[{safe_name} | Fold {fold} | FT {ep}/{EPOCHS_STAGE2}] MAE {m_val['MAE']:.3f}")
            if m_val["MAE"] < best_mae:
                best_mae = m_val["MAE"]
                torch.save(model.state_dict(), final_ckpt)

        # Final fold metrics from the best checkpoint across both stages
        print(f"[{safe_name}] Evaluating best checkpoint on fold {fold} …")
        model.load_state_dict(torch.load(final_ckpt, map_location=device))
        fold_metrics = evaluate(model, loaders["val"], device)
        (fold_dir / "metrics.json").write_text(json.dumps(fold_metrics, indent=2))
        fold_results.append(fold_metrics)

        # Cleanup: drop references before the next fold builds a new model
        del model, loaders
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Aggregate CV metrics (np.std default, i.e. population standard deviation)
    keys = fold_results[0].keys()
    summary = {
        k: {
            "mean": float(np.mean([fr[k] for fr in fold_results])),
            "std":  float(np.std([fr[k] for fr in fold_results]))
        }
        for k in keys
    }
    (model_root / "summary.json").write_text(json.dumps({"folds": fold_results, "summary": summary}, indent=2))

    print("\n══════ CV SUMMARY ══════")
    for k, v in summary.items():
        print(f"{k:<10}: {v['mean']:.3f} ± {v['std']:.3f}")

    return summary

# ── Main: run CV for all requested models ───────────────────────

def main():
    """Cross-validate every enabled model id, one after another."""
    # Uncomment an entry to include that checkpoint in the run.
    MODEL_IDS = [
        # "microsoft/dit-base",
        # "microsoft/dit-large",
        # "microsoft/dit-base-finetuned-rvlcdip",
        "microsoft/dit-large-finetuned-rvlcdip",
    ]
    for model_id in MODEL_IDS:
        cross_validate_model(model_id, n_splits=N_SPLITS)


# Run the full CV only when executed as a script / notebook cell.
if __name__ == "__main__":
    main()


Mounted at /content/drive


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

preprocessor_config.json:   0%|          | 0.00/302 [00:00<?, ?B/s]

  image_processor = cls(**image_processor_dict)



===== microsoft__dit-large-finetuned-rvlcdip | Fold 1/5 =====


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.21G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.21G [00:00<?, ?B/s]

[microsoft__dit-large-finetuned-rvlcdip | Fold 1 | Frozen 1/15] MAE 38.115
[microsoft__dit-large-finetuned-rvlcdip | Fold 1 | Frozen 2/15] MAE 18.483
[microsoft__dit-large-finetuned-rvlcdip | Fold 1 | Frozen 3/15] MAE 9.611
[microsoft__dit-large-finetuned-rvlcdip | Fold 1 | Frozen 4/15] MAE 19.120
[microsoft__dit-large-finetuned-rvlcdip | Fold 1 | Frozen 5/15] MAE 13.740
[microsoft__dit-large-finetuned-rvlcdip | Fold 1 | Frozen 6/15] MAE 12.401
[microsoft__dit-large-finetuned-rvlcdip | Fold 1 | Frozen 7/15] MAE 13.892
[microsoft__dit-large-finetuned-rvlcdip | Fold 1 | Frozen 8/15] MAE 13.804
[microsoft__dit-large-finetuned-rvlcdip | Fold 1 | Frozen 9/15] MAE 13.559
[microsoft__dit-large-finetuned-rvlcdip | Fold 1 | Frozen 10/15] MAE 13.128
[microsoft__dit-large-finetuned-rvlcdip | Fold 1 | Frozen 11/15] MAE 12.615
[microsoft__dit-large-finetuned-rvlcdip | Fold 1 | Frozen 12/15] MAE 13.497
[microsoft__dit-large-finetuned-rvlcdip | Fold 1 | Frozen 13/15] MAE 13.214
[microsoft__dit-large-