In [16]:
# ============================================================
# Task 2 — DINOv3 Light Fine-tuning (Last Blocks) — SPair-71k
# - Supervision: keypoint -> target patch index (CrossEntropy over target patches)
# - Eval: PCK@{0.05,0.10,0.20} with argmax cosine matching (Task1-style)
# - Geometry: pad bottom/right to multiple of PATCH=16
#
# PROTOCOL (tune on SMALL → final run on LARGE):
# 0) Overfit sanity (SMALL): train for a fixed number of steps, eval on 200 VAL pairs (SMALL)
# 1) LR finder (SMALL): try a few learning rates, eval on 300 VAL pairs (SMALL)
# 2) Sweep n_last_blocks (SMALL): try {0,1,2,4}, select the BEST on VAL (SMALL) using KP@0.10
#    - Note: the TEST split is NOT used during the sweep (avoids leakage and saves time)
# 3) Final run on LARGE (NO sweep):
#    - compute the FROZEN baseline on VAL (LARGE) and TEST (LARGE)
#    - restart from the pretrained weights and train ONLY the BEST setting (n_last_blocks chosen on SMALL)
#    - evaluate BEST on VAL (LARGE) and TEST (LARGE)
#    - print global (KP + IMG) and per-category (KP + IMG) results for baseline and best, on VAL and TEST
#
# CONFIG:
# - DATASET_SIZE_TUNE="small"
# - DATASET_SIZE_FINAL="large"
# - sanity: eval_pairs=200 (VAL SMALL)
# - lr_finder: eval_pairs=300 (VAL SMALL)
# - sweep: val_pairs=None (FULL VAL SMALL)
# - finale: max_pairs=None (FULL VAL/TEST LARGE)
# ============================================================

!nvidia-smi

Sun Jan  4 16:04:33 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   35C    P0             62W /  400W |    1919MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
from google.colab import drive
drive.mount("/content/drive")

from pathlib import Path
import os, json, time, copy, random
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd

# ----------------------------
# Paths
# ----------------------------
# Dataset root on the mounted Google Drive and the three sub-folders read below.
SPAIR_ROOT = Path("/content/drive/MyDrive/AMLDataset/SPair-71k")
PAIR_ANN_PATH = SPAIR_ROOT / "PairAnnotation"
LAYOUT_PATH   = SPAIR_ROOT / "Layout"
IMAGE_PATH    = SPAIR_ROOT / "JPEGImages"
# Fail fast if the dataset is not where it is expected to be.
assert SPAIR_ROOT.exists(), f"SPair-71k non trovato: {SPAIR_ROOT}"
assert PAIR_ANN_PATH.exists() and LAYOUT_PATH.exists() and IMAGE_PATH.exists(), "Cartelle SPair mancanti"

# ----------------------------
# Setup
# ----------------------------
# Prefer the GPU when one is visible; falls back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

PATCH = 16  # patch size in pixels (the backbone loaded later is "dinov3_vitb16")
# Per-channel statistics used by normalize_img_chw_0_255 (applied after scaling to [0, 1]).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

def set_seed(seed=0):
    """Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices) with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same order as before: stdlib, NumPy, torch CPU, torch CUDA.
    for seed_fn in seeders:
        seed_fn(seed)

set_seed(0)

Mounted at /content/drive
Device: cuda


In [3]:
# ----------------------------
# Utilities
# ----------------------------
def read_img(image_path: str) -> torch.Tensor:
    """Load an image as RGB and return it as a float CHW tensor with values in [0, 255]."""
    rgb_hwc = np.array(Image.open(image_path).convert("RGB"))
    # HWC -> CHW, then cast to float.
    chw = torch.from_numpy(rgb_hwc.transpose((2, 0, 1)))
    return chw.float()

def normalize_img_chw_0_255(img_chw_0_255: torch.Tensor) -> torch.Tensor:
    """Scale a CHW image from [0, 255] to [0, 1], then standardize with the ImageNet mean/std."""
    scaled = img_chw_0_255 / 255.0
    # Build both statistics on the input's device, shaped to broadcast over H and W.
    mean, std = (
        torch.tensor(stat, device=scaled.device).view(3, 1, 1)
        for stat in (IMAGENET_MEAN, IMAGENET_STD)
    )
    return (scaled - mean) / std

def pad_to_patch_multiple(x_chw: torch.Tensor, patch=16):
    """
    Zero-pad a CHW tensor on the bottom and right so H and W become multiples of `patch`.

    Returns:
      (padded tensor, (H, W) before padding, (H_pad, W_pad) after padding)
    """
    _, height, width = x_chw.shape
    # Ceil-divide each side by the patch size, then scale back up.
    padded_h = -(-height // patch) * patch
    padded_w = -(-width // patch) * patch
    # F.pad order for the last two dims: (left, right, top, bottom).
    padding = (0, padded_w - width, 0, padded_h - height)
    padded = F.pad(x_chw, padding, value=0.0)
    return padded, (height, width), (padded_h, padded_w)

def ensure_kps_k2(kps: torch.Tensor) -> torch.Tensor:
    """
    Return keypoints as a [K,2] (x,y) tensor.

    Accepted layouts are [K,2], [2,K], [K,3] and [3,K]; a third (visibility)
    component is dropped. Raises ValueError for anything else.
    """
    if kps.ndim != 2:
        raise ValueError(f"kps must be 2D, got {kps.shape}")

    coord_sizes = (2, 3)
    n_rows, n_cols = kps.shape
    # Coordinates along dim 0 (and not plausibly along dim 1): flip to [K, 2|3].
    if n_rows in coord_sizes and n_cols not in coord_sizes:
        kps = kps.t()

    width = kps.shape[1]
    if width == 3:
        return kps[:, :2]
    if width == 2:
        return kps
    raise ValueError(f"Cannot convert to [K,2], got {kps.shape}")

def kps_to_flat_indices(kps_k2: torch.Tensor, H_pad: int, W_pad: int, patch=16):
    """
    Map pixel keypoints to row-major indices on the patch grid.

    kps_k2: [K,2] pixel coords (x, y).
    Returns:
      flat_idx [K] long  (row * wg + col, clamped into the grid)
      valid    [K] bool  (non-negative and inside the padded image)
      hg, wg             (grid height / width in patches)
    """
    grid_h, grid_w = H_pad // patch, W_pad // patch
    xs, ys = kps_k2[:, 0], kps_k2[:, 1]

    inside = (xs >= 0) & (ys >= 0) & (xs < W_pad) & (ys < H_pad)

    # Out-of-range points are clamped so they stay indexable; `inside` marks them invalid.
    col = (xs // patch).long().clamp(0, grid_w - 1)
    row = (ys // patch).long().clamp(0, grid_h - 1)
    return row * grid_w + col, inside, grid_h, grid_w

def find_layout_file(layout_root: Path, dataset_size: str, split: str):
    """
    Locate the layout file for `split` inside `layout_root / dataset_size`.

    Tries `<split>.txt`, `<split>n.txt` and `<split>_.txt` in that order, then
    falls back to any `*.txt` whose stem equals `split` ignoring case.
    Raises FileNotFoundError when nothing matches.
    """
    size_dir = layout_root / dataset_size

    for suffix in ("", "n", "_"):
        candidate = size_dir / f"{split}{suffix}.txt"
        if candidate.exists():
            return candidate

    wanted_stem = split.lower()
    for txt_file in size_dir.glob("*.txt"):
        if txt_file.stem.lower() == wanted_stem:
            return txt_file

    raise FileNotFoundError(f"Nessun layout file trovato per split='{split}' in {size_dir}.")

In [4]:
# ----------------------------
# Dataset
# ----------------------------
class SPairDataset(Dataset):
    """
    SPair-71k pair dataset.

    The list of pair ids is read from the layout file resolved by
    `find_layout_file(layout_path, dataset_size, split)`. Each item loads
    `<pair_ann_path>/<split>/<pair_id>.json` and the two images it references
    under `<image_path>/<category>/`.

    Args:
      pair_ann_path: folder holding one sub-folder of JSON annotations per split.
      layout_path:   folder holding `<dataset_size>/<split>.txt` layout files.
      image_path:    folder holding one sub-folder of images per category.
      dataset_size:  layout sub-folder name (e.g. "small", "large").
      split:         split name (e.g. "trn", "val", "test").
      max_retries:   number of attempts for each file read that fails with
                     OSError; values below 1 are treated as a single attempt.
    """

    def __init__(self, pair_ann_path, layout_path, image_path,
                 dataset_size="large", split="trn", max_retries=3):
        self.split = split
        self.pair_ann_path = Path(pair_ann_path)
        self.layout_path   = Path(layout_path)
        self.image_path    = Path(image_path)
        self.max_retries   = int(max_retries)

        layout_file = find_layout_file(self.layout_path, dataset_size, split)
        with open(layout_file, "r") as f:
            # One pair id per non-empty line.
            self.ann_files = [x.strip() for x in f.read().splitlines() if len(x.strip()) > 0]

    def __len__(self):
        return len(self.ann_files)

    def _retry_io(self, read_fn):
        """
        Call `read_fn()` and retry on OSError with a linearly growing pause.

        Always makes at least one attempt, so the real I/O error is raised even
        when `max_retries` is 0 or negative (previously that path ended in
        `raise None`, i.e. a TypeError). The last failure is re-raised
        immediately instead of sleeping first.
        """
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                return read_fn()
            except OSError:
                if attempt == attempts - 1:
                    raise
                time.sleep(0.2 * (attempt + 1))

    @staticmethod
    def _read_json(path):
        """Parse the JSON file at `path`."""
        with open(path, "r") as f:
            return json.load(f)

    def __getitem__(self, index):
        """
        Returns a dict with keys: pair_id, category, src_bbox, trg_bbox (float [4]),
        src_img, trg_img (whatever `read_img` returns), src_kps, trg_kps (float [K,2]).

        Raises:
          OSError:    a file could not be read after all attempts.
          ValueError: a bounding box does not have exactly 4 values, or
                      `ensure_kps_k2` rejects the keypoint layout.
        """
        pair_id = self.ann_files[index]
        ann_filepath = self.pair_ann_path / self.split / f"{pair_id}.json"

        ann = self._retry_io(lambda: self._read_json(ann_filepath))

        category = ann["category"]
        src_img_path = self.image_path / category / ann["src_imname"]
        trg_img_path = self.image_path / category / ann["trg_imname"]

        # Image reads use the same retry policy as the annotation read.
        src_img = self._retry_io(lambda: read_img(str(src_img_path)))  # CHW float [0..255]
        trg_img = self._retry_io(lambda: read_img(str(trg_img_path)))

        # keypoints -> torch float [K,2]
        src_kps = ensure_kps_k2(torch.tensor(ann["src_kps"], dtype=torch.float32))
        trg_kps = ensure_kps_k2(torch.tensor(ann["trg_kps"], dtype=torch.float32))

        # bboxes -> torch float [4]
        src_bbox = torch.tensor(ann["src_bndbox"], dtype=torch.float32).view(-1)
        trg_bbox = torch.tensor(ann["trg_bndbox"], dtype=torch.float32).view(-1)
        if src_bbox.numel() != 4 or trg_bbox.numel() != 4:
            raise ValueError(f"Bad bbox size for pair_id={pair_id}: "
                             f"src_bbox={src_bbox.tolist()} trg_bbox={trg_bbox.tolist()}")

        return {
            "pair_id": pair_id,
            "category": category,
            "src_bbox": src_bbox,
            "trg_bbox": trg_bbox,
            "src_img": src_img,
            "trg_img": trg_img,
            "src_kps": src_kps,
            "trg_kps": trg_kps,
        }

In [5]:
# ----------------------------
# Choose dataset sizes
# ----------------------------
# Layout sub-folder names: one for the tuning stages, one for the final run.
# Each is passed as `dataset_size` to make_loaders.
DATASET_SIZE_TUNE  = "small"
DATASET_SIZE_FINAL = "large"
print("DATASET_SIZE_TUNE :", DATASET_SIZE_TUNE)
print("DATASET_SIZE_FINAL:", DATASET_SIZE_FINAL)

def make_loaders(dataset_size: str):
    """
    Build (train_loader, val_loader, test_loader) for one dataset size.

    All loaders use batch_size=1, num_workers=0 and pin_memory=False; only the
    training loader shuffles. The number of pairs per split is printed.
    """
    # (split name, label used in the printout, shuffle?)
    split_specs = (
        ("trn", "Train", True),
        ("val", "Val  ", False),
        ("test", "Test ", False),
    )

    datasets = [
        SPairDataset(PAIR_ANN_PATH, LAYOUT_PATH, IMAGE_PATH, dataset_size=dataset_size, split=split)
        for split, _, _ in split_specs
    ]
    loaders = [
        DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=0, pin_memory=False)
        for dataset, (_, _, shuffle) in zip(datasets, split_specs)
    ]

    for dataset, (_, label, _) in zip(datasets, split_specs):
        print(f"[{dataset_size}] {label} pairs:", len(dataset))

    return tuple(loaders)

# Build both loader triples up front: *_s for the tuning stages, *_l for the final run.
train_loader_s, val_loader_s, test_loader_s = make_loaders(DATASET_SIZE_TUNE)
train_loader_l, val_loader_l, test_loader_l = make_loaders(DATASET_SIZE_FINAL)


DATASET_SIZE_TUNE : small
DATASET_SIZE_FINAL: large
[small] Train pairs: 10652
[small] Val   pairs: 1070
[small] Test  pairs: 2438
[large] Train pairs: 53340
[large] Val   pairs: 5384
[large] Test  pairs: 12234


In [6]:
# ----------------------------
# Load DINOv3
# ----------------------------
# Clone the DINOv3 repository once (skipped when /content/dinov3 already exists)
# and install the extra packages listed below.
# NOTE(review): the packages are unpinned and installed with `!pip`; consider
# `%pip` with pinned versions for reproducibility.
%cd /content
!test -d dinov3 || git clone https://github.com/facebookresearch/dinov3.git
%cd /content/dinov3
!pip -q install einops timm opencv-python torchmetrics fvcore iopath

# Local repo checkout used as the torch.hub source, and the checkpoint on Drive.
DINOV3_DIR = "/content/dinov3"
DINOV3_WEIGHTS = "/content/drive/MyDrive/AMLDataset/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth"
assert os.path.exists(DINOV3_WEIGHTS), f"Pesi DINOv3 non trovati: {DINOV3_WEIGHTS}"

# Build the "dinov3_vitb16" entry point from the local checkout and move it to `device`.
dinov3 = torch.hub.load(
    DINOV3_DIR,
    "dinov3_vitb16",
    source="local",
    weights=DINOV3_WEIGHTS,
).to(device)

# set_trainable_last_blocks relies on `model.blocks`, so check for it right away.
assert hasattr(dinov3, "blocks"), "Model has no attribute 'blocks'—cannot unfreeze last blocks."
print("DINOv3 blocks:", len(dinov3.blocks))

# Save pretrained snapshot (fair compare): a deep copy of the state dict taken
# after the model was moved to `device`; restore_pretrained reloads it.
pretrained_state = copy.deepcopy(dinov3.state_dict())

def restore_pretrained():
    """Reload the saved `pretrained_state` into the global `dinov3` model (strict) and move it to `device`."""
    dinov3.load_state_dict(pretrained_state, strict=True)
    dinov3.to(device)

/content
Cloning into 'dinov3'...
remote: Enumerating objects: 538, done.[K
remote: Counting objects: 100% (362/362), done.[K
remote: Compressing objects: 100% (263/263), done.[K
remote: Total 538 (delta 199), reused 99 (delta 99), pack-reused 176 (from 1)[K
Receiving objects: 100% (538/538), 9.88 MiB | 10.63 MiB/s, done.
Resolving deltas: 100% (222/222), done.
/content/dinov3
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m60.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for fvcore (setup.py) ... [?25l[?25hdone
  Building wheel for iopath (setup.py) ... [?25l[?25hdone
Downloading

100%|██████████| 327M/327M [00:08<00:00, 39.7MB/s]


DINOv3 blocks: 12


In [7]:
# ----------------------------
# Trainability control
# ----------------------------
def set_trainable_last_blocks(model, n_last_blocks: int, train_final_norm=True):
    """
    Freeze every parameter of `model`, then re-enable gradients for the last
    `n_last_blocks` entries of `model.blocks` and, when `train_final_norm` is
    true and the model has a `norm` attribute, for `model.norm` as well.
    """
    for param in model.parameters():
        param.requires_grad_(False)

    # Collect the sub-modules to unfreeze, then flip them in one pass.
    to_unfreeze = []
    if n_last_blocks > 0:
        to_unfreeze.extend(model.blocks[-n_last_blocks:])
    if train_final_norm and hasattr(model, "norm"):
        to_unfreeze.append(model.norm)

    for module in to_unfreeze:
        for param in module.parameters():
            param.requires_grad_(True)

In [8]:
# ----------------------------
# Patch tokens extraction (robust)
# ----------------------------
def get_patch_tokens(model, x_bchw: torch.Tensor, expected_n: int | None = None) -> torch.Tensor:
    """
    Returns patch tokens [B, N, C] (no CLS).
    Nota: non verificato che le chiavi siano sempre le stesse tra versioni repo.
    """
    def strip_cls_if_needed(t: torch.Tensor) -> torch.Tensor:
        if t.ndim != 3:
            return t
        if expected_n is None:
            return t[:, 1:, :] if t.shape[1] > 1 else t
        if t.shape[1] == expected_n:
            return t
        if t.shape[1] == expected_n + 1:
            return t[:, 1:, :]
        return t

    if hasattr(model, "forward_features"):
        out = model.forward_features(x_bchw)
        if isinstance(out, dict):
            for key in ["x_patchtokens", "patchtokens", "x_norm_patchtokens", "x_norm", "x"]:
                if key in out and isinstance(out[key], torch.Tensor):
                    return strip_cls_if_needed(out[key])
        elif torch.is_tensor(out):
            return strip_cls_if_needed(out)

    out = model(x_bchw)
    if torch.is_tensor(out):
        return strip_cls_if_needed(out)

    raise RuntimeError("Could not extract patch tokens from DINOv3 output.")

In [9]:
# ----------------------------
# Training loss (Keypoint CE)
# ----------------------------
def kp_ce_loss(src_tok_bnc, trg_tok_bnc, src_flat_idx, trg_flat_idx, valid_mask, temp=0.1):
    """
    Cross-entropy over target patches for each valid source keypoint.

    Only batch element 0 of each token tensor is used. Tokens are L2-normalized,
    the source token at each keypoint's patch is compared with every target
    token (cosine similarity / `temp`), and the target is the patch index of
    the matching keypoint.

    Returns:
      (loss, n_valid); (zero tensor, 0) when no keypoint is valid.
    """
    src_unit = F.normalize(src_tok_bnc[0], dim=-1)  # [Ns,C]
    trg_unit = F.normalize(trg_tok_bnc[0], dim=-1)  # [Nt,C]

    n_valid = int(valid_mask.sum().item())
    if n_valid == 0:
        return src_unit.new_tensor(0.0), 0

    # [K,C] x [C,Nt] -> [K,Nt] similarity logits, one row per keypoint.
    similarity = src_unit[src_flat_idx].matmul(trg_unit.t()) / temp
    loss = F.cross_entropy(similarity[valid_mask], trg_flat_idx[valid_mask])
    return loss, n_valid

In [10]:
# ----------------------------
# Evaluation helpers
# ----------------------------
@torch.no_grad()
def tokens_to_featuremap(tok_bnc, hg, wg):
    """
    Reshape the tokens of batch element 0 into an L2-normalized [hg, wg, C] grid.

    Raises RuntimeError when the token count is not hg * wg.
    """
    tokens = tok_bnc[0]  # [N,C]
    n_tokens, n_cells = tokens.shape[0], hg * wg
    if n_tokens != n_cells:
        raise RuntimeError(f"N={n_tokens} != hg*wg={n_cells}")
    grid = tokens.view(hg, wg, -1)
    return F.normalize(grid, dim=-1)

@torch.no_grad()
def argmax_cosine(Ft_flat: torch.Tensor, f_src: torch.Tensor, chunk: int = 4096) -> int:
    """
    Index of the row of `Ft_flat` with the largest dot product with `f_src`.

    Rows are scanned in blocks of `chunk`; a later block only wins with a
    strictly larger score. Returns 0 when `Ft_flat` has no rows.
    """
    top_score = None
    top_index = 0
    n_rows = Ft_flat.shape[0]
    for offset in range(0, n_rows, chunk):
        scores = (Ft_flat[offset:offset + chunk] * f_src).sum(dim=-1)
        block_max, block_arg = scores.max(dim=0)
        block_max = float(block_max.item())
        if top_score is not None and not (block_max > top_score):
            continue
        top_score = block_max
        top_index = int(block_arg.item()) + offset
    return top_index

@torch.no_grad()
def match_one_pair(sample, sim_chunk=4096):
    """
    Predict target locations for the source keypoints of one pair.

    Both images are normalized, padded to a multiple of PATCH and embedded with
    the global `dinov3` model (switched to eval mode). For each keypoint pair
    whose source and target points lie inside their un-padded images, the
    source patch feature is matched to the most similar target patch and the
    centre of that patch is the prediction.

    Returns:
      (pred [M,2], gt [M,2]) float32 tensors; both are [0,2] when no keypoint qualifies.
    """
    dinov3.eval()

    src_img = sample["src_img"].to(device)
    trg_img = sample["trg_img"].to(device)
    src_kps = sample["src_kps"].to(device)
    trg_kps = sample["trg_kps"].to(device)

    def embed(img_chw):
        """Return (normalized feature grid, (H, W) before padding, (grid_h, grid_w))."""
        padded, (height, width), (h_pad, w_pad) = pad_to_patch_multiple(
            normalize_img_chw_0_255(img_chw), PATCH
        )
        grid_h, grid_w = h_pad // PATCH, w_pad // PATCH
        tokens = get_patch_tokens(dinov3, padded.unsqueeze(0), expected_n=grid_h * grid_w)
        return tokens_to_featuremap(tokens, grid_h, grid_w), (height, width), (grid_h, grid_w)

    src_feat, (src_h, src_w), (src_gh, src_gw) = embed(src_img)
    trg_feat, (trg_h, trg_w), (_, trg_gw) = embed(trg_img)
    trg_feat_flat = trg_feat.view(-1, trg_feat.shape[-1])

    half_patch = PATCH / 2.0
    predicted, ground_truth = [], []
    for src_pt, trg_pt in zip(src_kps, trg_kps):
        x_src, y_src = float(src_pt[0].item()), float(src_pt[1].item())
        x_gt, y_gt = float(trg_pt[0].item()), float(trg_pt[1].item())

        # The range checks also reject negative coordinates.
        src_inside = (0.0 <= x_src < src_w) and (0.0 <= y_src < src_h)
        trg_inside = (0.0 <= x_gt < trg_w) and (0.0 <= y_gt < trg_h)
        if not (src_inside and trg_inside):
            continue

        col_src = min(int(x_src) // PATCH, src_gw - 1)
        row_src = min(int(y_src) // PATCH, src_gh - 1)

        best = argmax_cosine(trg_feat_flat, src_feat[row_src, col_src], chunk=sim_chunk)
        row_trg, col_trg = divmod(best, trg_gw)

        # Predict the centre of the winning target patch, in pixels.
        predicted.append((col_trg * PATCH + half_patch, row_trg * PATCH + half_patch))
        ground_truth.append((x_gt, y_gt))

    if not predicted:
        return torch.zeros((0,2)), torch.zeros((0,2))
    return (
        torch.tensor(predicted, dtype=torch.float32),
        torch.tensor(ground_truth, dtype=torch.float32),
    )

@torch.no_grad()
def evaluate_pck_report(loader, name="EVAL", max_pairs=None, sim_chunk=4096, print_per_category=True):
    """
    Evaluate PCK at thresholds 0.05 / 0.10 / 0.20 over `loader` and print a report.

    For every pair, predictions come from `match_one_pair`; a keypoint counts as
    correct when its prediction is within `T * max(bbox width, bbox height)` of
    the ground truth, using the target bounding box.

    Args:
      loader:             iterable of batches; element 0 of each field is used.
      name:               label used in progress lines and report headers.
      max_pairs:          stop after this many pairs have been seen (None = all).
      sim_chunk:          chunk size forwarded to `match_one_pair`.
      print_per_category: also print a per-category table.

    Returns:
      dict with keys "pairs_seen", "pairs_run", "minutes", "global_img"
      (mean of per-pair PCK), "global_kp" (PCK pooled over all keypoints) —
      both keyed by threshold — and "per_category" (list of row dicts, in
      percent; empty when `print_per_category` is false).
    """
    dinov3.eval()
    thresholds = (0.05, 0.10, 0.20)

    # Pooled (per-keypoint) and per-pair accumulators, globally and per category.
    global_correct = {T: 0.0 for T in thresholds}
    global_total_kp = 0.0
    all_pck_img = {T: [] for T in thresholds}

    cat_correct = {}
    cat_total   = {}
    cat_pck_img = {}

    t0 = time.time()
    pairs_seen = 0   # pairs pulled from the loader
    pairs_used = 0   # pairs that contributed to the metrics

    for batch in loader:
        if max_pairs is not None and pairs_seen >= int(max_pairs):
            break
        pairs_seen += 1

        # Take element 0 of every field of the batch.
        sample = {
            "src_img": batch["src_img"][0],
            "trg_img": batch["trg_img"][0],
            "src_kps": batch["src_kps"][0],
            "trg_kps": batch["trg_kps"][0],
            "category": batch["category"][0],
            "trg_bbox": batch["trg_bbox"][0],
        }

        pred, gt = match_one_pair(sample, sim_chunk=sim_chunk)
        if pred.shape[0] == 0:
            continue

        # assumes trg_bbox is (x0, y0, x1, y1) — TODO confirm against the annotations
        x0, y0, x1, y1 = sample["trg_bbox"].detach().cpu().view(-1).tolist()
        norm = max(float(x1 - x0), float(y1 - y0))
        if norm <= 1e-6:
            # Degenerate box: thresholds would be meaningless, skip the pair.
            continue

        dists = torch.linalg.norm(pred - gt, dim=1)
        N = float(dists.numel())
        if N <= 0:
            continue

        pairs_used += 1

        global_total_kp += N
        cat = sample["category"]
        cat_correct.setdefault(cat, {T: 0.0 for T in thresholds})
        cat_total.setdefault(cat, 0.0)
        cat_pck_img.setdefault(cat, {T: [] for T in thresholds})
        cat_total[cat] += N

        for T in thresholds:
            thr = T * norm
            correct = float((dists <= thr).float().sum().item())
            pck_img = correct / N

            global_correct[T] += correct
            all_pck_img[T].append(pck_img)

            cat_correct[cat][T] += correct
            cat_pck_img[cat][T].append(pck_img)

        # Progress line; not reached for pairs skipped by the `continue`s above.
        if pairs_seen % 200 == 0:
            print(f"[{name}] seen={pairs_seen} used={pairs_used}")

    minutes = (time.time() - t0) / 60.0

    # Per-image: mean of per-pair PCK. Per-keypoint: pooled correct / pooled total.
    mean_pck_img = {T: float(np.mean(all_pck_img[T])) if len(all_pck_img[T]) else 0.0 for T in thresholds}
    global_pck_kp = {T: float(global_correct[T] / max(global_total_kp, 1.0)) for T in thresholds}

    print("\n" + "="*18 + f" {name} REPORT " + "="*18)
    print(f"Pairs run: {pairs_used} (seen: {pairs_seen})")
    print("\nGlobal PCK (per-image mean):")
    for T in thresholds:
        print(f"  PCK@{T:.2f}: {100.0*mean_pck_img[T]:.2f}%")
    print("\nGlobal PCK (per-keypoint):")
    for T in thresholds:
        print(f"  PCK@{T:.2f}: {100.0*global_pck_kp[T]:.2f}%")

    per_cat_rows = []
    if print_per_category and len(cat_total) > 0:
        for cat in sorted(cat_total.keys()):
            row = {"Category": cat}
            for T in thresholds:
                kp = float(cat_correct[cat][T] / max(cat_total[cat], 1.0))
                im = float(np.mean(cat_pck_img[cat][T])) if len(cat_pck_img[cat][T]) else 0.0
                row[f"KP@{T:.2f}"]  = 100.0 * kp
                row[f"IMG@{T:.2f}"] = 100.0 * im
            per_cat_rows.append(row)

        df_cat = pd.DataFrame(per_cat_rows)
        print("\n" + "="*16 + " PER-CATEGORY RESULTS " + "="*16)
        cols = ["Category","KP@0.05","KP@0.10","KP@0.20","IMG@0.05","IMG@0.10","IMG@0.20"]
        print(df_cat[cols].to_string(index=False))

    print(f"\nMinutes: {minutes:.4f}")

    return {
        "pairs_seen": pairs_seen,
        "pairs_run": pairs_used,
        "minutes": minutes,
        "global_img": mean_pck_img,
        "global_kp": global_pck_kp,
        "per_category": per_cat_rows
    }

In [11]:
# ----------------------------
# Training (steps-based) + AMP
# ----------------------------
from torch.amp import autocast, GradScaler

def train_steps(
    loader,
    n_last_blocks=1,
    lr=1e-5,
    weight_decay=0.05,
    temp=0.1,
    max_steps=2000,
    log_every=200,
    use_amp=True
):
    """
    Fine-tune the global `dinov3` model for a fixed number of steps.

    Unfreezes the last `n_last_blocks` blocks (plus the final norm), then runs
    `max_steps` iterations of AdamW on the keypoint cross-entropy loss, one
    pair per step, restarting the loader whenever it is exhausted. Mixed
    precision is used when `use_amp` is true and `device` is CUDA. Steps whose
    pair has no valid keypoint are skipped without an optimizer update.

    Returns:
      The keypoint-weighted average training loss over all counted steps.
    """
    set_trainable_last_blocks(dinov3, n_last_blocks=n_last_blocks, train_final_norm=True)
    dinov3.train()

    trainable = [p for p in dinov3.parameters() if p.requires_grad]
    n_trainable = sum(p.numel() for p in trainable)
    print(f"[train] n_last_blocks={n_last_blocks} | trainable params: {n_trainable:,}")
    assert n_trainable > 0, "No trainable parameters!"

    optimizer = torch.optim.AdamW(trainable, lr=lr, weight_decay=weight_decay)
    scaler = GradScaler("cuda", enabled=(use_amp and device.type == "cuda"))

    def prepare(img_chw, kps_k2):
        """Return (model input [1,C,H,W], flat patch indices, validity mask, patch count)."""
        padded, _, (h_pad, w_pad) = pad_to_patch_multiple(normalize_img_chw_0_255(img_chw), PATCH)
        flat_idx, in_bounds, grid_h, grid_w = kps_to_flat_indices(kps_k2, h_pad, w_pad, PATCH)
        return padded.unsqueeze(0), flat_idx, in_bounds, grid_h * grid_w

    loss_sum = 0.0   # sum of loss * number of valid keypoints
    kps_seen = 0
    started = time.time()

    batches = iter(loader)
    for step in range(int(max_steps)):
        try:
            batch = next(batches)
        except StopIteration:
            # Loader exhausted: start another pass.
            batches = iter(loader)
            batch = next(batches)

        src_img, trg_img, src_kps, trg_kps = (
            batch[key][0].to(device) for key in ("src_img", "trg_img", "src_kps", "trg_kps")
        )

        src_in, src_flat, src_ok, n_src = prepare(src_img, src_kps)
        trg_in, trg_flat, trg_ok, n_trg = prepare(trg_img, trg_kps)
        both_ok = src_ok & trg_ok

        optimizer.zero_grad(set_to_none=True)

        with autocast("cuda", enabled=scaler.is_enabled()):
            src_tok = get_patch_tokens(dinov3, src_in, expected_n=n_src)
            trg_tok = get_patch_tokens(dinov3, trg_in, expected_n=n_trg)
            loss, n_valid = kp_ce_loss(src_tok, trg_tok, src_flat, trg_flat, both_ok, temp=temp)

        if n_valid != 0:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            loss_sum += float(loss.item()) * n_valid
            kps_seen += n_valid

            if (step + 1) % log_every == 0:
                avg = loss_sum / max(kps_seen, 1)
                elapsed = time.time() - started
                print(f"[train] step {step+1}/{max_steps} | avg_loss {avg:.4f} | seen_kps {kps_seen} | {elapsed:.1f}s")

    return loss_sum / max(kps_seen, 1)

In [12]:
# ----------------------------
# Sanity: overfit quick
# ----------------------------
def overfit_sanity(
    train_loader, val_loader,
    n_last_blocks=2,
    lr=5e-5,
    weight_decay=0.05,
    temp=0.1,
    steps=800,
    eval_pairs=200
):
    """
    Quick sanity check: evaluate the frozen pretrained model on the first
    `eval_pairs` pairs of `val_loader`, train for `steps` steps from the
    pretrained weights, evaluate again and print both KP@0.10 values.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("SANITY CHECK: OVERFIT")
    print(rule)

    def evaluate(tag):
        dinov3.eval()
        return evaluate_pck_report(val_loader, name=tag, max_pairs=eval_pairs, print_per_category=False)

    # Baseline frozen
    restore_pretrained()
    set_trainable_last_blocks(dinov3, 0)
    report_before = evaluate("VAL (baseline, frozen)")

    # Train from the pretrained weights
    restore_pretrained()
    avg_loss = train_steps(
        train_loader,
        n_last_blocks=n_last_blocks,
        lr=lr,
        weight_decay=weight_decay,
        temp=temp,
        max_steps=steps,
        log_every=200
    )

    # Evaluate after training
    report_after = evaluate("VAL (after overfit)")

    kp10_before = 100 * report_before["global_kp"][0.10]
    kp10_after = 100 * report_after["global_kp"][0.10]
    print(f"\n[overfit] avg_train_loss: {avg_loss:.4f}")
    print(f"[overfit] baseline VAL KP@0.10: {kp10_before:.2f}% -> after: {kp10_after:.2f}%")


# ----------------------------
# LR finder quick
# ----------------------------
def lr_finder(
    train_loader, val_loader,
    n_last_blocks=2,
    lrs=(1e-6, 3e-6, 1e-5, 3e-5, 1e-4),
    weight_decay=0.05,
    temp=0.1,
    train_steps_each=200,
    eval_pairs=300
):
    """
    Short learning-rate comparison.

    For each value in `lrs`: restore the pretrained weights, train for
    `train_steps_each` steps with `n_last_blocks` unfrozen, and evaluate on the
    first `eval_pairs` pairs of `val_loader`.

    Returns:
      DataFrame sorted by lr with columns lr, n_last_blocks, avg_train_loss and
      val_kp_PCK@{0.05,0.10,0.20} (percent).
    """
    print("\n" + "="*70)
    print("LR FINDER")
    print("="*70)

    rows = []
    for lr in lrs:
        print("\n" + "-"*60)
        print(f"[lr_finder] lr={lr:g} | blocks={n_last_blocks}")
        print("-"*60)

        # Every learning rate starts from the same pretrained weights.
        restore_pretrained()
        avg_loss = train_steps(
            train_loader,
            n_last_blocks=n_last_blocks,
            lr=lr,
            weight_decay=weight_decay,
            temp=temp,
            max_steps=train_steps_each,
            log_every=200
        )

        dinov3.eval()
        val_rep = evaluate_pck_report(val_loader, name=f"VAL (lr={lr:g})", max_pairs=eval_pairs, print_per_category=False)

        rows.append({
            "lr": lr,
            "n_last_blocks": n_last_blocks,
            "avg_train_loss": float(avg_loss),
            "val_kp_PCK@0.05": 100*val_rep["global_kp"][0.05],
            "val_kp_PCK@0.10": 100*val_rep["global_kp"][0.10],
            "val_kp_PCK@0.20": 100*val_rep["global_kp"][0.20],
        })

    df = pd.DataFrame(rows).sort_values("lr").reset_index(drop=True)
    print("\n=== LR FINDER SUMMARY ===")
    # The lr column gets its own formatter: the generic 4-decimal float format
    # would print every learning rate below 5e-5 as "0.0000".
    print(df.to_string(
        index=False,
        formatters={"lr": "{:g}".format},
        float_format=lambda x: f"{x:.4f}" if abs(x) < 1 else f"{x:.2f}",
    ))
    return df


# ----------------------------
# Sweep on SMALL (VAL selection only) — NO TEST HERE
# ----------------------------
def sweep_blocks_on_val(
    train_loader, val_loader,
    settings=(0,1,2,4),
    lr=3e-5,
    temp=0.1,
    weight_decay=0.05,
    train_steps_per_setting=2000,
    val_pairs=None,
    select_on="KP@0.10"  # either "KP@0.10" or "IMG@0.10"
):
    """
    Sweep the number of unfrozen last blocks and pick the best on validation.

    A frozen baseline is evaluated first. Then, for each value in `settings`,
    the pretrained weights are restored, the model is trained for
    `train_steps_per_setting` steps (skipped for 0 blocks) and evaluated on
    `val_loader` (first `val_pairs` pairs, or all when None). The test split
    is never touched here.

    Args:
      select_on: "IMG@0.10" selects on per-image PCK@0.10; any other value
                 selects on per-keypoint PCK@0.10.

    Returns:
      (summary DataFrame sorted by n_last_blocks, best n_last_blocks, baseline report)
    """
    print("\n" + "="*80)
    print("SWEEP (VAL ONLY) — selecting best on VAL")
    print("="*80)

    # Baseline frozen on VAL
    restore_pretrained()
    set_trainable_last_blocks(dinov3, 0)
    dinov3.eval()
    base_val = evaluate_pck_report(val_loader, name="VAL BASELINE (frozen)", max_pairs=val_pairs, print_per_category=False)

    rows = []
    for n_last_blocks in settings:
        print("\n" + "="*70)
        print(f"RUN setting: n_last_blocks={n_last_blocks}")
        print("="*70)

        # Every setting starts from the same pretrained weights.
        restore_pretrained()
        avg_loss = None

        if n_last_blocks > 0:
            avg_loss = train_steps(
                train_loader,
                n_last_blocks=n_last_blocks,
                lr=lr,
                weight_decay=weight_decay,
                temp=temp,
                max_steps=train_steps_per_setting,
                log_every=200
            )

        dinov3.eval()
        val_rep = evaluate_pck_report(
            val_loader,
            name=f"VAL FINETUNED (blocks={n_last_blocks})",
            max_pairs=val_pairs,
            print_per_category=False
        )

        rows.append({
            "n_last_blocks": n_last_blocks,
            "lr": lr,
            "temp": temp,
            "train_steps": int(train_steps_per_setting) if n_last_blocks > 0 else 0,
            "avg_train_loss": float(avg_loss) if avg_loss is not None else None,
            "val_kp_PCK@0.05": 100 * val_rep["global_kp"][0.05],
            "val_kp_PCK@0.10": 100 * val_rep["global_kp"][0.10],
            "val_kp_PCK@0.20": 100 * val_rep["global_kp"][0.20],
            "val_img_PCK@0.10": 100 * val_rep["global_img"][0.10],
        })

    df = pd.DataFrame(rows).sort_values("n_last_blocks").reset_index(drop=True)
    print("\n=== SWEEP SUMMARY (VAL) ===")
    # The lr column gets its own formatter: the generic 4-decimal float format
    # would print a learning rate such as 3e-05 as "0.0000".
    print(df.to_string(
        index=False,
        formatters={"lr": "{:g}".format},
        float_format=lambda x: f"{x:.4f}" if abs(x) < 1 else f"{x:.2f}",
    ))

    # Select best
    if select_on == "IMG@0.10":
        metric_col = "val_img_PCK@0.10"
    else:
        metric_col = "val_kp_PCK@0.10"

    best_idx = int(df[metric_col].astype(float).idxmax())
    best_blocks = int(df.loc[best_idx, "n_last_blocks"])
    print("\n" + "="*70)
    print(f"BEST on VAL by {select_on}: n_last_blocks={best_blocks}")
    print("="*70)

    return df, best_blocks, base_val

In [13]:
def final_large_run_no_sweep(
    train_loader_l, val_loader_l, test_loader_l,
    best_blocks: int,
    lr: float,
    temp=0.1,
    weight_decay=0.05,
    train_steps_best=2000,
):
    """
    Final run: frozen baseline vs. the selected fine-tuning setting.

    Evaluates the pretrained model on the full VAL and TEST loaders, restores
    the pretrained weights, trains with `best_blocks` unfrozen blocks (skipped
    when it is 0) and evaluates again on both loaders.

    Returns:
      (base_val, best_val, base_test, best_test) reports from `evaluate_pck_report`.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("FINAL LARGE RUN (NO SWEEP)")
    print(rule)
    print(f"Using best_blocks={best_blocks} | lr={lr:g} | temp={temp} | wd={weight_decay} | steps={train_steps_best}")

    def evaluate_val_then_test(val_name, test_name):
        """Put the model in eval mode and report on VAL, then TEST, over all pairs."""
        dinov3.eval()
        val_report = evaluate_pck_report(val_loader_l, name=val_name, max_pairs=None, print_per_category=True)
        test_report = evaluate_pck_report(test_loader_l, name=test_name, max_pairs=None, print_per_category=True)
        return val_report, test_report

    # ---- BASELINE (FROZEN) ----
    restore_pretrained()
    set_trainable_last_blocks(dinov3, 0)
    base_val_l, base_test_l = evaluate_val_then_test(
        "VAL LARGE BASELINE (frozen)", "TEST LARGE BASELINE (frozen)"
    )

    # ---- TRAIN BEST ON TRAIN LARGE ----
    restore_pretrained()
    if best_blocks > 0:
        train_steps(
            train_loader_l,
            n_last_blocks=best_blocks,
            lr=lr,
            weight_decay=weight_decay,
            temp=temp,
            max_steps=train_steps_best,
            log_every=200
        )

    # ---- EVAL BEST ----
    best_val_l, best_test_l = evaluate_val_then_test(
        f"VAL LARGE BEST (blocks={best_blocks})", f"TEST LARGE BEST (blocks={best_blocks})"
    )

    return base_val_l, best_val_l, base_test_l, best_test_l

In [14]:
# ----------------------------
# Final summary table (KP only) + deltas
# ----------------------------
def make_final_kp_table(base_val, best_val, base_test, best_test, best_blocks):
    """
    Summary table of per-keypoint PCK (percent) at 0.05 / 0.10 / 0.20.

    For VAL, and for TEST when `base_test` is not None, emits three rows:
    frozen baseline, best fine-tuned model, and their difference.
    """
    thresholds = (0.05, 0.10, 0.20)
    best_label = f"Best (blocks={best_blocks})"

    def build_row(split, model_name, value_at):
        record = {"Split": split, "Model": model_name}
        for thr in thresholds:
            record[f"KP@{thr:.2f}"] = value_at(thr)
        return record

    def rows_for(split, frozen_rep, tuned_rep):
        frozen, tuned = frozen_rep["global_kp"], tuned_rep["global_kp"]
        return [
            build_row(split, "Frozen", lambda thr: 100 * frozen[thr]),
            build_row(split, best_label, lambda thr: 100 * tuned[thr]),
            build_row(split, "Δ Best - Frozen", lambda thr: 100*tuned[thr] - 100*frozen[thr]),
        ]

    records = rows_for("VAL", base_val, best_val)
    if base_test is not None:
        records.extend(rows_for("TEST", base_test, best_test))
    return pd.DataFrame(records)

def make_final_img_table(base_val, best_val, base_test, best_test, best_blocks):
    def row(split, model_name, rep):
        return {
            "Split": split,
            "Model": model_name,
            "IMG@0.05": 100 * rep["global_img"][0.05],
            "IMG@0.10": 100 * rep["global_img"][0.10],
            "IMG@0.20": 100 * rep["global_img"][0.20],
        }

    rows = []
    rows.append(row("VAL", "Frozen", base_val))
    rows.append(row("VAL", f"Best (blocks={best_blocks})", best_val))
    rows.append({
        "Split": "VAL",
        "Model": "Δ Best - Frozen",
        "IMG@0.05": 100*best_val["global_img"][0.05] - 100*base_val["global_img"][0.05],
        "IMG@0.10": 100*best_val["global_img"][0.10] - 100*base_val["global_img"][0.10],
        "IMG@0.20": 100*best_val["global_img"][0.20] - 100*base_val["global_img"][0.20],
    })

    if base_test is not None:
        rows.append(row("TEST", "Frozen", base_test))
        rows.append(row("TEST", f"Best (blocks={best_blocks})", best_test))
        rows.append({
            "Split": "TEST",
            "Model": "Δ Best - Frozen",
            "IMG@0.05": 100*best_test["global_img"][0.05] - 100*base_test["global_img"][0.05],
            "IMG@0.10": 100*best_test["global_img"][0.10] - 100*base_test["global_img"][0.10],
            "IMG@0.20": 100*best_test["global_img"][0.20] - 100*base_test["global_img"][0.20],
        })

    return pd.DataFrame(rows)


In [33]:
# ============================================================
# RUN SECTION (FULL: val_pairs=None, test_pairs=None)
# ============================================================

In [34]:
# Settings shared by the two SMALL-split probes below: same number of
# unfrozen blocks, same weight decay and same temperature.
_probe_common = dict(n_last_blocks=2, weight_decay=0.05, temp=0.1)

# --- 0) Overfit sanity on SMALL ---
# Fixed-step training run, evaluated on 200 VAL pairs.
overfit_sanity(
    train_loader_s,
    val_loader_s,
    lr=5e-5,
    steps=800,
    eval_pairs=200,
    **_probe_common,
)

# --- 1) LR finder on SMALL ---
# Short run per candidate learning rate, each evaluated on 300 VAL pairs.
df_lr = lr_finder(
    train_loader_s,
    val_loader_s,
    lrs=(1e-6, 3e-6, 1e-5, 3e-5, 1e-4),
    train_steps_each=200,
    eval_pairs=300,
    **_probe_common,
)


SANITY CHECK: OVERFIT
[VAL (baseline, frozen)] seen=200 used=200

Pairs run: 200 (seen: 200)

Global PCK (per-image mean):
  PCK@0.05: 34.60%
  PCK@0.10: 52.54%
  PCK@0.20: 63.43%

Global PCK (per-keypoint):
  PCK@0.05: 37.14%
  PCK@0.10: 55.20%
  PCK@0.20: 67.32%

Minutes: 0.1622
[train] n_last_blocks=2 | trainable params: 14,180,352
[train] step 200/800 | avg_loss 3.7585 | seen_kps 1428 | 172.6s
[train] step 400/800 | avg_loss 3.3729 | seen_kps 2868 | 348.5s
[train] step 600/800 | avg_loss 3.1745 | seen_kps 4391 | 521.3s
[train] step 800/800 | avg_loss 3.0239 | seen_kps 5859 | 695.0s
[VAL (after overfit)] seen=200 used=200

Pairs run: 200 (seen: 200)

Global PCK (per-image mean):
  PCK@0.05: 54.29%
  PCK@0.10: 73.54%
  PCK@0.20: 84.75%

Global PCK (per-keypoint):
  PCK@0.05: 55.04%
  PCK@0.10: 73.49%
  PCK@0.20: 85.54%

Minutes: 0.1528

[overfit] avg_train_loss: 3.0239
[overfit] baseline VAL KP@0.10: 55.20% -> after: 73.49%

LR FINDER

-----------------------------------------------

In [15]:
# Learning rate for the sweep and the final run, chosen by hand after
# inspecting df_lr.
LR_FINAL = 3e-5

# --- 2) Sweep blocks on SMALL (VAL only) ---
# Candidate numbers of unfrozen last blocks are compared on the full SMALL
# VAL split (val_pairs=None); the winner is picked on KP@0.10.
_sweep_config = dict(
    settings=(0, 1, 2, 4),
    select_on="KP@0.10",
    val_pairs=None,
    train_steps_per_setting=2000,
    lr=LR_FINAL,
    weight_decay=0.05,
    temp=0.1,
)

df_sweep_s, best_blocks_s, base_val_s = sweep_blocks_on_val(
    train_loader_s,
    val_loader_s,
    **_sweep_config,
)

print("\n[TUNING DONE] best_blocks_s =", best_blocks_s, "| LR_FINAL =", LR_FINAL)


SWEEP (VAL ONLY) — selecting best on VAL
[VAL BASELINE (frozen)] seen=200 used=200
[VAL BASELINE (frozen)] seen=400 used=400
[VAL BASELINE (frozen)] seen=600 used=600
[VAL BASELINE (frozen)] seen=800 used=800
[VAL BASELINE (frozen)] seen=1000 used=1000

Pairs run: 1070 (seen: 1070)

Global PCK (per-image mean):
  PCK@0.05: 28.82%
  PCK@0.10: 46.25%
  PCK@0.20: 61.13%

Global PCK (per-keypoint):
  PCK@0.05: 31.76%
  PCK@0.10: 50.23%
  PCK@0.20: 66.28%

Minutes: 20.8490

RUN setting: n_last_blocks=0
[VAL FINETUNED (blocks=0)] seen=200 used=200
[VAL FINETUNED (blocks=0)] seen=400 used=400
[VAL FINETUNED (blocks=0)] seen=600 used=600
[VAL FINETUNED (blocks=0)] seen=800 used=800
[VAL FINETUNED (blocks=0)] seen=1000 used=1000

Pairs run: 1070 (seen: 1070)

Global PCK (per-image mean):
  PCK@0.05: 28.82%
  PCK@0.10: 46.25%
  PCK@0.20: 61.13%

Global PCK (per-keypoint):
  PCK@0.05: 31.76%
  PCK@0.10: 50.23%
  PCK@0.20: 66.28%

Minutes: 0.8384

RUN setting: n_last_blocks=1
[train] n_last_block

In [17]:
# --- 3) FINAL on LARGE (baseline + best only, no sweep) ---
# Cast the tuned block count once so the training run and both summary
# tables receive the same plain Python int.
_best_blocks_final = int(best_blocks_s)

base_val_l, best_val_l, base_test_l, best_test_l = final_large_run_no_sweep(
    train_loader_l,
    val_loader_l,
    test_loader_l,
    best_blocks=_best_blocks_final,
    lr=float(LR_FINAL),
    temp=0.1,
    weight_decay=0.05,
    train_steps_best=2000,
)

# Reports in the positional order both table builders take them.
_final_reports = (base_val_l, best_val_l, base_test_l, best_test_l)
# Two-decimal rendering for every float column of the printed tables.
_two_decimals = lambda x: f"{x:.2f}"

print("\n=== FINAL KP SUMMARY (VAL + TEST) ===")
df_final_kp = make_final_kp_table(*_final_reports, _best_blocks_final)
print(df_final_kp.to_string(index=False, float_format=_two_decimals))

print("\n=== FINAL IMG SUMMARY (VAL + TEST) ===")
df_final_img = make_final_img_table(*_final_reports, _best_blocks_final)
print(df_final_img.to_string(index=False, float_format=_two_decimals))

print("\n=== DONE ===")


FINAL LARGE RUN (NO SWEEP)
Using best_blocks=4 | lr=3e-05 | temp=0.1 | wd=0.05 | steps=2000
[VAL LARGE BASELINE (frozen)] seen=200 used=200
[VAL LARGE BASELINE (frozen)] seen=400 used=400
[VAL LARGE BASELINE (frozen)] seen=600 used=600
[VAL LARGE BASELINE (frozen)] seen=800 used=800
[VAL LARGE BASELINE (frozen)] seen=1000 used=1000
[VAL LARGE BASELINE (frozen)] seen=1200 used=1200
[VAL LARGE BASELINE (frozen)] seen=1400 used=1400
[VAL LARGE BASELINE (frozen)] seen=1600 used=1600
[VAL LARGE BASELINE (frozen)] seen=1800 used=1800
[VAL LARGE BASELINE (frozen)] seen=2000 used=2000
[VAL LARGE BASELINE (frozen)] seen=2200 used=2200
[VAL LARGE BASELINE (frozen)] seen=2400 used=2400
[VAL LARGE BASELINE (frozen)] seen=2600 used=2600
[VAL LARGE BASELINE (frozen)] seen=2800 used=2800
[VAL LARGE BASELINE (frozen)] seen=3000 used=3000
[VAL LARGE BASELINE (frozen)] seen=3200 used=3200
[VAL LARGE BASELINE (frozen)] seen=3400 used=3400
[VAL LARGE BASELINE (frozen)] seen=3600 used=3600
[VAL LARGE BASE