In [None]:
"""
DiT Age Regression Training – Expanded HHD Dataset (Colab-ready)

Expects:
- DATA_ROOT/<Set>/<File> where Set ∈ {train, val, test}
- CSV at DATA_ROOT/NewAgeSplit.csv with columns: File, Age, Set

Trains these models and saves each under its own subfolder in DATA_ROOT/DiT/:
- microsoft/dit-base
- microsoft/dit-large
- microsoft/dit-base-finetuned-rvlcdip
- microsoft/dit-large-finetuned-rvlcdip

Per-model outputs:
- <MODEL_DIR>/stage1_best.pt
- <MODEL_DIR>/final_best.pt
- <MODEL_DIR>/metrics.json
"""

import os
import json
import random
from pathlib import Path
from typing import Dict

# Mount Google Drive if running in Colab (no-op elsewhere). Only the import
# is guarded, so errors raised by the mount itself are not swallowed.
try:
    from google.colab import drive  # type: ignore
except ImportError:
    drive = None  # Not in Colab
else:
    if not os.path.exists("/content/drive"):
        drive.mount("/content/drive")

# Paths & hyper-parameters.
# DATA_ROOT may be overridden with the HHD_DATA_ROOT environment variable so
# the script can also run outside Colab; the default is the Drive location.
DATA_ROOT = os.environ.get("HHD_DATA_ROOT", "/content/drive/MyDrive/Expanded_HHD_AgeSplit")
CSV_PATH  = f"{DATA_ROOT}/NewAgeSplit.csv"
OUTPUT_DIR = Path(DATA_ROOT) / "DiT"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

EPOCHS_STAGE1 = 15   # frozen backbone
EPOCHS_STAGE2 = 30   # unfrozen backbone
BATCH_SIZE    = 64
LR_BASE       = 1e-4
SEED          = 42

# Imports after (possible) Drive mount
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from transformers import AutoModel, BeitImageProcessor
from PIL import Image

# Reproducibility: seed Python's `random`, NumPy and PyTorch with the same
# value, plus every visible GPU when CUDA is available.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Augmentation helpers
class AddGaussianNoise(nn.Module):
    """Perturb a tensor with Gaussian noise, then clamp it back into [0, 1].

    Args:
        mean: constant offset added to every element after the noise.
        std: standard deviation applied to standard-normal noise.
    """

    def __init__(self, mean: float = 0.0, std: float = 0.05):
        super().__init__()
        self.std = std
        self.mean = mean

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        noisy = x + torch.randn_like(x) * self.std
        noisy = noisy + self.mean
        return noisy.clamp(0.0, 1.0)

class HHDataset(Dataset):
    """Age-regression samples read from ``<root>/<set>/<filename>``.

    ``<set>`` is the lower-cased ``Set`` column of each row. CSV columns:
    File, Age, Set.

    Args:
        df: rows belonging to this split.
        root: dataset root directory.
        processor: image processor whose ``size["height"]``/``size["width"]``
            set the output resolution. It is called with rescaling and
            resizing disabled, since both are already done here.
        augment: enable training-time augmentation (rotation, random
            resized crop, brightness jitter, Gaussian noise).
    """

    def __init__(self, df: pd.DataFrame, root: str, processor, augment: bool = False):
        self.df = df.reset_index(drop=True)
        self.root = Path(root)
        self.proc = processor
        self.aug = augment

        size = (processor.size["height"], processor.size["width"])
        # Build the pipeline conditionally instead of padding it with identity
        # lambdas. Lambdas cannot be pickled, which breaks DataLoader workers
        # under the "spawn" multiprocessing start method.
        if augment:
            steps = [
                transforms.RandomRotation(15),
                # `scale` is a fraction of the source image area. A crop can
                # never exceed the image, so an upper bound > 1.0 only produces
                # rejected samples.
                transforms.RandomResizedCrop(size, scale=(0.9, 1.0)),
                transforms.ColorJitter(brightness=0.1),
                transforms.ToTensor(),                       # -> tensor in [0,1]
                AddGaussianNoise(),
            ]
        else:
            steps = [
                transforms.Resize(size),
                transforms.ToTensor(),                       # -> tensor in [0,1]
            ]
        self.pre = transforms.Compose(steps)

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        """Return ``{"pixel_values": normalized CxHxW tensor, "labels": float32 age}``.

        Raises:
            FileNotFoundError: if the image file for row ``idx`` is missing.
        """
        row = self.df.iloc[idx]
        img_path = self.root / row["Set"].lower() / row["File"]
        if not img_path.exists():
            raise FileNotFoundError(f"Image not found: {img_path}")

        # Close the file deterministically. Pillow keeps the handle of
        # multi-frame formats (e.g. TIFF) open until the image is closed.
        with Image.open(img_path) as src:
            image = src.convert("RGB")
        image = self.pre(image)  # tensor CxHxW in [0,1]

        # Avoid double rescaling/resizing inside the processor
        pixel_values = self.proc(
            images=image, return_tensors="pt",
            do_rescale=False, do_resize=False
        )["pixel_values"].squeeze(0)

        label = torch.tensor(row["Age"], dtype=torch.float32)
        return {"pixel_values": pixel_values, "labels": label}

# Model
class DiTReg(nn.Module):
    """Pretrained DiT/BEiT backbone with a one-output linear regression head.

    Args:
        name: model id passed to ``AutoModel.from_pretrained``.
        p: dropout probability applied before the linear layer.
    """

    def __init__(self, name: str = "microsoft/dit-base", p: float = 0.1):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(name)
        width = self.backbone.config.hidden_size
        # Attribute names and the Sequential layout fix the state_dict keys
        # (e.g. "head.1.weight") that saved checkpoints rely on.
        self.head = nn.Sequential(nn.Dropout(p), nn.Linear(width, 1))

    def forward(self, pixel_values, labels=None):
        """Predict one value per sample from the first (CLS) token.

        Returns ``{"preds": ...}``. When ``labels`` is given, the dict also
        holds ``"loss"``, the L1 (MAE) loss between predictions and labels.
        """
        hidden = self.backbone(pixel_values=pixel_values).last_hidden_state
        first_token = hidden[:, 0]
        preds = self.head(first_token).squeeze(1)
        result = {"preds": preds}
        if labels is None:
            return result
        result["loss"] = nn.functional.l1_loss(preds, labels)  # MAE loss
        return result

# Metrics
@torch.no_grad()
def compute_metrics(y: torch.Tensor, yhat: torch.Tensor) -> Dict[str, float]:
    """Compute regression metrics for predictions ``yhat`` against targets ``y``.

    Both tensors are flattened to 1-D before any arithmetic. An ``(N,)``
    target and an ``(N, 1)`` prediction are therefore compared element-wise,
    instead of silently broadcasting to an ``(N, N)`` error matrix.

    Returns:
        Dict with ``MAE``, ``RMSE``, ``R2`` and ``MAPE`` (percent). It also
        holds ``Within2``/``Within5``/``Within10``, the percent of samples
        whose absolute error is <= 2/5/10. Finally it holds
        ``MaxErr``/``MedianErr``/``MinErr`` of the absolute error.

    Raises:
        ValueError: if the inputs are empty or differ in element count.
    """
    y_np = y.detach().cpu().reshape(-1).double().numpy()
    yhat_np = yhat.detach().cpu().reshape(-1).double().numpy()
    if y_np.size == 0:
        raise ValueError("compute_metrics() received empty inputs")
    if y_np.size != yhat_np.size:
        raise ValueError(
            f"compute_metrics() size mismatch: {y_np.size} targets vs {yhat_np.size} predictions"
        )

    ae = np.abs(y_np - yhat_np)

    # MAE/RMSE straight from the absolute errors. These equal sklearn's
    # unweighted definitions and do not depend on the sklearn version
    # (`squared=False` was deprecated in 1.4 and removed in 1.6). Non-finite
    # predictions (e.g. a diverged run) give NaN/inf metrics instead of an
    # sklearn input-validation error that would abort training.
    mae = float(ae.mean())
    rmse = float(np.sqrt(np.mean(ae ** 2)))

    # R^2 needs at least two samples, and sklearn rejects NaN/inf inputs.
    if y_np.size >= 2 and np.isfinite(y_np).all() and np.isfinite(yhat_np).all():
        r2 = float(r2_score(y_np, yhat_np))
    else:
        r2 = float("nan")

    # MAPE: clamp targets away from zero to avoid division by zero.
    mape = float(np.mean(ae / np.clip(y_np, 1e-6, None)) * 100.0)

    return {
        "MAE": mae,
        "RMSE": rmse,
        "R2": r2,
        "MAPE": mape,
        "Within2": float(np.mean(ae <= 2) * 100.0),
        "Within5": float(np.mean(ae <= 5) * 100.0),
        "Within10": float(np.mean(ae <= 10) * 100.0),
        "MaxErr": float(ae.max()),
        # np.median averages the two middle values for even N; torch.median
        # would return only the lower one.
        "MedianErr": float(np.median(ae)),
        "MinErr": float(ae.min()),
    }

# Train / eval helpers

def train_one_epoch(model: nn.Module, loader: DataLoader, opt: torch.optim.Optimizer, device) -> float:
    """Run one optimisation pass over ``loader``; return the per-sample mean loss.

    Every tensor in each batch dict is moved to ``device``, and the dict is
    forwarded to the model as keyword arguments. The model must return a
    dict containing ``"loss"``. Each batch loss is weighted by its
    ``labels`` batch size before averaging over the whole dataset.
    """
    model.train()
    running_loss = 0.0
    for batch in loader:
        inputs = {name: tensor.to(device) for name, tensor in batch.items()}
        opt.zero_grad()
        outputs = model(**inputs)
        loss = outputs["loss"]
        loss.backward()
        opt.step()
        batch_size = inputs["labels"].size(0)
        running_loss += loss.item() * batch_size
    n_samples = len(loader.dataset)
    return running_loss / n_samples

@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device) -> Dict[str, float]:
    """Predict over ``loader`` in eval mode and score with ``compute_metrics``.

    Only ``pixel_values`` is sent to the model. Predictions and labels are
    gathered on the CPU and concatenated before scoring.
    """
    model.eval()
    all_preds = []
    all_targets = []
    for batch in loader:
        images = batch["pixel_values"].to(device)
        all_preds.append(model(images)["preds"].cpu())
        all_targets.append(batch["labels"].cpu())
    targets = torch.cat(all_targets)
    predictions = torch.cat(all_preds)
    return compute_metrics(targets, predictions)

# Per-model training routine

def _build_loaders(proc) -> Dict[str, DataLoader]:
    """Create the train/val/test DataLoaders from CSV_PATH, preprocessing with ``proc``."""
    df = pd.read_csv(CSV_PATH)
    loaders = {}
    for split in ("train", "val", "test"):
        is_train = split == "train"
        loaders[split] = DataLoader(
            HHDataset(df[df["Set"].str.lower() == split], DATA_ROOT, proc, augment=is_train),
            batch_size=(BATCH_SIZE if is_train else BATCH_SIZE * 2),
            shuffle=is_train,
            num_workers=2,
            pin_memory=True,
        )
    return loaders


def _fit_stage(model, loaders, optimizer, device, epochs: int, tag: str, ckpt_path: Path, best_mae: float) -> float:
    """Train for ``epochs`` epochs and return the updated best validation MAE.

    The model's state_dict is saved to ``ckpt_path`` whenever validation MAE
    improves on ``best_mae``.
    """
    for ep in range(1, epochs + 1):
        train_one_epoch(model, loaders["train"], optimizer, device)
        m_val = evaluate(model, loaders["val"], device)
        print(f"[{tag} {ep}/{epochs}] MAE {m_val['MAE']:.3f}")
        if m_val["MAE"] < best_mae:
            best_mae = m_val["MAE"]
            torch.save(model.state_dict(), ckpt_path)
    return best_mae


def train_model(model_id: str):
    """Fine-tune ``model_id`` for age regression in two stages, then test it.

    Stage 1 trains only the regression head with the backbone frozen.
    Stage 2 starts from the best stage-1 weights, unfreezes everything, and
    trains at a 10x lower learning rate. The weights with the lowest
    validation MAE across both stages are evaluated on the test split.

    Writes ``stage1_best.pt``, ``final_best.pt`` and ``metrics.json`` to
    ``OUTPUT_DIR/<model_id with "/" replaced by "__">/``.

    Raises:
        RuntimeError: if stage 1 never produced a checkpoint (no epoch had a
            validation MAE below infinity, e.g. all NaN).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Image processor specific to the model (ensures correct size & normalization)
    proc = BeitImageProcessor.from_pretrained(model_id)

    # Data
    print(f"\n===== Building dataloaders for {model_id} … =====")
    loaders = _build_loaders(proc)
    print({k: len(v.dataset) for k, v in loaders.items()})

    # Model & per-model output dir
    model = DiTReg(name=model_id).to(device)
    safe_name = model_id.replace("/", "__")
    model_dir = OUTPUT_DIR / safe_name
    model_dir.mkdir(parents=True, exist_ok=True)
    stage1_ckpt = model_dir / "stage1_best.pt"
    final_ckpt = model_dir / "final_best.pt"
    # Remove checkpoints left by earlier runs so stale weights are never loaded.
    for ckpt in (stage1_ckpt, final_ckpt):
        ckpt.unlink(missing_ok=True)

    # Stage 1 – freeze backbone
    model.backbone.requires_grad_(False)
    optimizer = torch.optim.AdamW(model.head.parameters(), lr=LR_BASE, weight_decay=1e-4)
    best_mae = _fit_stage(
        model, loaders, optimizer, device,
        EPOCHS_STAGE1, f"{safe_name} | Frozen", stage1_ckpt, float("inf"),
    )
    if not stage1_ckpt.exists():
        raise RuntimeError(f"[{safe_name}] stage 1 produced no checkpoint (validation MAE never finite)")

    # Stage 2 – unfreeze backbone
    print(f"\n[{safe_name}] Unfreezing backbone …")
    model.load_state_dict(torch.load(stage1_ckpt, map_location=device))
    # best_mae carries over from stage 1. Seed final_best.pt with the stage-1
    # winner so that it exists even if no fine-tuning epoch improves on it.
    torch.save(model.state_dict(), final_ckpt)
    model.backbone.requires_grad_(True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR_BASE / 10, weight_decay=1e-5)
    _fit_stage(
        model, loaders, optimizer, device,
        EPOCHS_STAGE2, f"{safe_name} | FT", final_ckpt, best_mae,
    )

    # Final test evaluation
    print(f"\n[{safe_name}] Testing best model …")
    model.load_state_dict(torch.load(final_ckpt, map_location=device))
    m_test = evaluate(model, loaders["test"], device)
    for k, v in m_test.items():
        print(f"[{safe_name}] {k}: {v:.3f}")
    (model_dir / "metrics.json").write_text(json.dumps(m_test, indent=2))

    # Free memory between models. Drop the optimizer too: it references the
    # parameters and holds AdamW state tensors on the same device.
    del model, optimizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Main loop over requested models

def main():
    """Run the two-stage training and test pipeline for each DiT variant in turn."""
    dit_variants = (
        "microsoft/dit-base",
        "microsoft/dit-large",
        "microsoft/dit-base-finetuned-rvlcdip",
        "microsoft/dit-large-finetuned-rvlcdip",
    )
    for variant in dit_variants:
        train_model(variant)


if __name__ == "__main__":
    # Start the sweep when run as a script or notebook cell, not on import.
    main()


  image_processor = cls(**image_processor_dict)



===== Building dataloaders for microsoft/dit-base … =====
{'train': 786, 'val': 146, 'test': 116}


Some weights of BeitModel were not initialized from the model checkpoint at microsoft/dit-base and are newly initialized: ['pooler.layernorm.bias', 'pooler.layernorm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[microsoft__dit-base | Frozen 1/15] MAE 8.923
[microsoft__dit-base | Frozen 2/15] MAE 8.586
[microsoft__dit-base | Frozen 3/15] MAE 8.947
[microsoft__dit-base | Frozen 4/15] MAE 8.806
[microsoft__dit-base | Frozen 5/15] MAE 8.634
[microsoft__dit-base | Frozen 6/15] MAE 8.679
[microsoft__dit-base | Frozen 7/15] MAE 9.011
[microsoft__dit-base | Frozen 8/15] MAE 8.906
[microsoft__dit-base | Frozen 9/15] MAE 8.863
[microsoft__dit-base | Frozen 10/15] MAE 9.133
[microsoft__dit-base | Frozen 11/15] MAE 9.156
[microsoft__dit-base | Frozen 12/15] MAE 9.249
[microsoft__dit-base | Frozen 13/15] MAE 8.839
[microsoft__dit-base | Frozen 14/15] MAE 9.303
[microsoft__dit-base | Frozen 15/15] MAE 8.833

[microsoft__dit-base] Unfreezing backbone …
[microsoft__dit-base | FT 1/30] MAE 9.379
[microsoft__dit-base | FT 2/30] MAE 7.722
[microsoft__dit-base | FT 3/30] MAE 7.670
[microsoft__dit-base | FT 4/30] MAE 8.962
[microsoft__dit-base | FT 5/30] MAE 8.807
[microsoft__dit-base | FT 6/30] MAE 7.791
[micros

  image_processor = cls(**image_processor_dict)



===== Building dataloaders for microsoft/dit-large … =====
{'train': 786, 'val': 146, 'test': 116}


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.25G [00:00<?, ?B/s]

Some weights of BeitModel were not initialized from the model checkpoint at microsoft/dit-large and are newly initialized: ['pooler.layernorm.bias', 'pooler.layernorm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/1.25G [00:00<?, ?B/s]

[microsoft__dit-large | Frozen 1/15] MAE 132.753
[microsoft__dit-large | Frozen 2/15] MAE 92.678
[microsoft__dit-large | Frozen 3/15] MAE 53.926
[microsoft__dit-large | Frozen 4/15] MAE 17.380
[microsoft__dit-large | Frozen 5/15] MAE 17.001
[microsoft__dit-large | Frozen 6/15] MAE 16.020
[microsoft__dit-large | Frozen 7/15] MAE 10.228
[microsoft__dit-large | Frozen 8/15] MAE 11.229
[microsoft__dit-large | Frozen 9/15] MAE 13.834
[microsoft__dit-large | Frozen 10/15] MAE 12.156
[microsoft__dit-large | Frozen 11/15] MAE 11.841
[microsoft__dit-large | Frozen 12/15] MAE 12.471
[microsoft__dit-large | Frozen 13/15] MAE 12.024
[microsoft__dit-large | Frozen 14/15] MAE 11.697
[microsoft__dit-large | Frozen 15/15] MAE 10.666

[microsoft__dit-large] Unfreezing backbone …
[microsoft__dit-large | FT 1/30] MAE 16.237
[microsoft__dit-large | FT 2/30] MAE 11.642
[microsoft__dit-large | FT 3/30] MAE 8.118
[microsoft__dit-large | FT 4/30] MAE 10.278
[microsoft__dit-large | FT 5/30] MAE 9.832
[microsof

  image_processor = cls(**image_processor_dict)



===== Building dataloaders for microsoft/dit-base-finetuned-rvlcdip … =====
{'train': 786, 'val': 146, 'test': 116}


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/343M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/343M [00:00<?, ?B/s]

[microsoft__dit-base-finetuned-rvlcdip | Frozen 1/15] MAE 8.370
[microsoft__dit-base-finetuned-rvlcdip | Frozen 2/15] MAE 9.076
[microsoft__dit-base-finetuned-rvlcdip | Frozen 3/15] MAE 8.492
[microsoft__dit-base-finetuned-rvlcdip | Frozen 4/15] MAE 8.961
[microsoft__dit-base-finetuned-rvlcdip | Frozen 5/15] MAE 8.468
[microsoft__dit-base-finetuned-rvlcdip | Frozen 6/15] MAE 9.050
[microsoft__dit-base-finetuned-rvlcdip | Frozen 7/15] MAE 8.386
[microsoft__dit-base-finetuned-rvlcdip | Frozen 8/15] MAE 8.532
[microsoft__dit-base-finetuned-rvlcdip | Frozen 9/15] MAE 8.289
[microsoft__dit-base-finetuned-rvlcdip | Frozen 10/15] MAE 8.801
[microsoft__dit-base-finetuned-rvlcdip | Frozen 11/15] MAE 8.508
[microsoft__dit-base-finetuned-rvlcdip | Frozen 12/15] MAE 8.434
[microsoft__dit-base-finetuned-rvlcdip | Frozen 13/15] MAE 8.638
[microsoft__dit-base-finetuned-rvlcdip | Frozen 14/15] MAE 8.175
[microsoft__dit-base-finetuned-rvlcdip | Frozen 15/15] MAE 8.420

[microsoft__dit-base-finetuned-rv

preprocessor_config.json:   0%|          | 0.00/302 [00:00<?, ?B/s]

  image_processor = cls(**image_processor_dict)



===== Building dataloaders for microsoft/dit-large-finetuned-rvlcdip … =====
{'train': 786, 'val': 146, 'test': 116}


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.21G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.21G [00:00<?, ?B/s]

[microsoft__dit-large-finetuned-rvlcdip | Frozen 1/15] MAE 43.521
[microsoft__dit-large-finetuned-rvlcdip | Frozen 2/15] MAE 24.916
[microsoft__dit-large-finetuned-rvlcdip | Frozen 3/15] MAE 10.950
[microsoft__dit-large-finetuned-rvlcdip | Frozen 4/15] MAE 10.826
[microsoft__dit-large-finetuned-rvlcdip | Frozen 5/15] MAE 9.083
[microsoft__dit-large-finetuned-rvlcdip | Frozen 6/15] MAE 8.576
[microsoft__dit-large-finetuned-rvlcdip | Frozen 7/15] MAE 8.705
[microsoft__dit-large-finetuned-rvlcdip | Frozen 8/15] MAE 8.549
[microsoft__dit-large-finetuned-rvlcdip | Frozen 9/15] MAE 8.342
[microsoft__dit-large-finetuned-rvlcdip | Frozen 10/15] MAE 8.234
[microsoft__dit-large-finetuned-rvlcdip | Frozen 11/15] MAE 8.372
[microsoft__dit-large-finetuned-rvlcdip | Frozen 12/15] MAE 8.646
[microsoft__dit-large-finetuned-rvlcdip | Frozen 13/15] MAE 8.144
[microsoft__dit-large-finetuned-rvlcdip | Frozen 14/15] MAE 8.090
[microsoft__dit-large-finetuned-rvlcdip | Frozen 15/15] MAE 8.854

[microsoft__di