In [1]:
# ============================================================
# Task 2 — DINOv3 Light Fine-tuning (Last Blocks) — SPair-71k
# - Training supervision: keypoint -> target patch index (CrossEntropy)
# - Evaluation: PCK@{0.05,0.10,0.20} using argmax matching (Task1-style)
# - Geometry: pad bottom/right so H and W are multiples of PATCH=16
# - Task2-complete: LR finder + overfit sanity + sweep blocks (select on VAL), test only best
# ============================================================

# Show the GPU attached to this runtime (model, memory, driver/CUDA version).
!nvidia-smi

Fri Jan  2 19:20:36 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   60C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
# Colab only: mount Google Drive. The dataset and checkpoint paths used below
# live under /content/drive.
from google.colab import drive
drive.mount("/content/drive")

from pathlib import Path
import os, json, time, copy, random
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd

# ----------------------------
# Paths (SPair-71k on the mounted Drive)
# ----------------------------
SPAIR_ROOT = Path("/content/drive/MyDrive/AMLDataset/SPair-71k")
PAIR_ANN_PATH, LAYOUT_PATH, IMAGE_PATH = (
    SPAIR_ROOT / sub for sub in ("PairAnnotation", "Layout", "JPEGImages")
)
# Fail early if the dataset root or any expected sub-folder is missing.
assert SPAIR_ROOT.exists(), f"SPair-71k non trovato: {SPAIR_ROOT}"
assert all(p.exists() for p in (PAIR_ANN_PATH, LAYOUT_PATH, IMAGE_PATH)), "Cartelle SPair mancanti"

# ----------------------------
# Setup (device + preprocessing constants)
# ----------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

PATCH = 16  # vitb16
# Per-channel statistics used by normalize_img_chw_0_255.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

def set_seed(seed=0):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


set_seed(0)

Mounted at /content/drive
Device: cuda


In [3]:
# ----------------------------
# Utilities
# ----------------------------
def read_img(image_path: str) -> torch.Tensor:
    """Open an image, force RGB, and return it as a CHW float tensor in [0..255]."""
    rgb = Image.open(image_path).convert("RGB")
    hwc = np.array(rgb)
    # HWC -> CHW
    chw = hwc.transpose(2, 0, 1)
    return torch.from_numpy(chw).float()

def normalize_img_chw_0_255(img_chw_0_255: torch.Tensor) -> torch.Tensor:
    """Scale a CHW float image from [0..255] to [0..1], then standardize per channel
    with IMAGENET_MEAN / IMAGENET_STD (statistics are created on the input's device)."""
    scaled = img_chw_0_255 / 255.0
    mean, std = (
        torch.tensor(stats, device=scaled.device).view(3, 1, 1)
        for stats in (IMAGENET_MEAN, IMAGENET_STD)
    )
    return (scaled - mean) / std

def pad_to_patch_multiple(x_chw: torch.Tensor, patch=16):
    """Zero-pad a CHW tensor on the bottom/right so H and W become multiples of `patch`.

    Returns:
      (padded tensor, (H, W) original size, (H_pad, W_pad) padded size)
    """
    _, height, width = x_chw.shape
    # -(-n // p) * p rounds n up to the next multiple of p.
    padded_h = -(-height // patch) * patch
    padded_w = -(-width // patch) * patch
    # F.pad order for the last two dims: (left, right, top, bottom).
    padding = (0, padded_w - width, 0, padded_h - height)
    padded = F.pad(x_chw, padding, value=0.0)
    return padded, (height, width), (padded_h, padded_w)

def ensure_kps_k2(kps: torch.Tensor) -> torch.Tensor:
    """
    Return keypoints as a [K,2] tensor of (x,y).

    Accepted layouts: [K,2], [2,K], [K,3], [3,K]. A third component (visibility)
    is dropped. Raises ValueError for non-2D input or shapes that cannot be coerced.
    """
    if kps.ndim != 2:
        raise ValueError(f"kps must be 2D, got {kps.shape}")

    coord_sizes = (2, 3)
    n_rows, n_cols = kps.shape
    # Coordinates along dim 0 (and not plausibly along dim 1) -> transpose.
    if n_rows in coord_sizes and n_cols not in coord_sizes:
        kps = kps.t()

    width = kps.shape[1]
    if width == 3:
        return kps[:, :2]
    if width == 2:
        return kps
    raise ValueError(f"Cannot convert to [K,2], got {kps.shape}")

def kps_to_flat_indices(kps_k2: torch.Tensor, H_pad: int, W_pad: int, patch=16):
    """
    Map [K,2] pixel keypoints (x,y) to flattened patch-grid indices.

    Returns:
      flat_idx [K] long  -- row * wg + col, with row/col clamped into the grid
      valid    [K] bool  -- keypoint is non-negative and inside the padded image
      hg, wg             -- patch-grid height and width (H_pad // patch, W_pad // patch)
    """
    xs, ys = kps_k2[:, 0], kps_k2[:, 1]
    grid_h, grid_w = H_pad // patch, W_pad // patch

    inside = (xs >= 0) & (ys >= 0) & (xs < W_pad) & (ys < H_pad)

    # Clamp so that invalid keypoints still yield an in-range (but masked) index.
    col = (xs // patch).long().clamp(0, grid_w - 1)
    row = (ys // patch).long().clamp(0, grid_h - 1)
    return row * grid_w + col, inside, grid_h, grid_w

def find_layout_file(layout_root: Path, dataset_size: str, split: str):
    """
    Locate the layout file for `split` inside Layout/<dataset_size>/.

    Tries `<split>.txt`, `<split>n.txt` and `<split>_.txt` in that order, then falls
    back to any `*.txt` whose stem equals `split` case-insensitively.
    Raises FileNotFoundError when nothing matches.
    """
    folder = layout_root / dataset_size
    names = (
        f"{split}.txt",
        f"{split}n.txt",   # e.g., trnn/valn
        f"{split}_.txt",
    )
    direct_hit = next((folder / n for n in names if (folder / n).exists()), None)
    if direct_hit is not None:
        return direct_hit

    wanted = split.lower()
    fallback = next((f for f in folder.glob("*.txt") if f.stem.lower() == wanted), None)
    if fallback is not None:
        return fallback
    raise FileNotFoundError(f"Nessun layout file trovato per split='{split}' in {folder}.")

In [4]:
# ----------------------------
# Dataset
# ----------------------------
class SPairDataset(Dataset):
    """
    SPair-71k pair dataset.

    Each item is one (source, target) image pair described by a JSON annotation in
    `<pair_ann_path>/<split>/<pair_id>.json`; the list of pair ids comes from the
    layout file resolved by `find_layout_file`.

    Args:
      pair_ann_path: root folder of the pair annotations.
      layout_path:   root folder of the layout (split list) files.
      image_path:    root folder of the images, organised as `<category>/<imname>`.
      dataset_size:  layout sub-folder name (e.g. "small" or "large").
      split:         split name, also used as the annotation sub-folder.
      max_retries:   number of attempts when opening the annotation file; values
                     below 1 are treated as 1 (a single attempt, no retry).
    """

    def __init__(self, pair_ann_path, layout_path, image_path,
                 dataset_size="large", split="trn", max_retries=3):
        self.split = split
        self.pair_ann_path = Path(pair_ann_path)
        self.layout_path   = Path(layout_path)
        self.image_path    = Path(image_path)
        self.max_retries   = int(max_retries)

        layout_file = find_layout_file(self.layout_path, dataset_size, split)
        with open(layout_file, "r") as f:
            # One pair id per non-empty line.
            self.ann_files = [x.strip() for x in f.read().splitlines() if len(x.strip()) > 0]

    def __len__(self):
        """Number of pairs listed in the layout file."""
        return len(self.ann_files)

    def _load_annotation(self, ann_filepath):
        """Read and parse one annotation JSON, retrying on OSError with a linear back-off.

        Always makes at least one attempt. The OSError of the final attempt is
        re-raised as-is (no sleep after it), so callers never see a `raise None`
        TypeError when `max_retries` is 0 or negative.
        """
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                with open(ann_filepath, "r") as f:
                    return json.load(f)
            except OSError:
                if attempt == attempts - 1:
                    raise
                time.sleep(0.2 * (attempt + 1))

    def __getitem__(self, index):
        """Return a dict with pair id, category, both images, keypoints and bboxes.

        Raises:
          OSError:    the annotation file could not be read after all attempts.
          ValueError: a bounding box does not have exactly 4 values, or keypoints
                      cannot be coerced to [K,2].
        """
        pair_id = self.ann_files[index]
        ann_filepath = self.pair_ann_path / self.split / f"{pair_id}.json"

        ann = self._load_annotation(ann_filepath)

        category = ann["category"]
        src_img_path = self.image_path / category / ann["src_imname"]
        trg_img_path = self.image_path / category / ann["trg_imname"]

        src_img = read_img(str(src_img_path))  # CHW float [0..255]
        trg_img = read_img(str(trg_img_path))

        # keypoints -> torch float [K,2]
        src_kps = ensure_kps_k2(torch.tensor(ann["src_kps"], dtype=torch.float32))
        trg_kps = ensure_kps_k2(torch.tensor(ann["trg_kps"], dtype=torch.float32))

        # bboxes -> torch float [4] (so DataLoader collate is always [B,4])
        src_bbox = torch.tensor(ann["src_bndbox"], dtype=torch.float32).view(-1)
        trg_bbox = torch.tensor(ann["trg_bndbox"], dtype=torch.float32).view(-1)
        if src_bbox.numel() != 4 or trg_bbox.numel() != 4:
            raise ValueError(f"Bad bbox size for pair_id={pair_id}: "
                            f"src_bbox={src_bbox.tolist()} trg_bbox={trg_bbox.tolist()}")

        return {
            "pair_id": pair_id,        # string (ok with batch_size=1)
            "category": category,      # string (ok with batch_size=1)
            "src_bbox": src_bbox,      # tensor [4]
            "trg_bbox": trg_bbox,      # tensor [4]
            "src_img": src_img,        # tensor [3,H,W]
            "trg_img": trg_img,        # tensor [3,H,W]
            "src_kps": src_kps,        # tensor [K,2]
            "trg_kps": trg_kps,        # tensor [K,2]
        }

In [5]:
# ----------------------------
# Choose dataset size here
# ----------------------------
DATASET_SIZE = "small"   # <-- set "small" or "large"
print("DATASET_SIZE:", DATASET_SIZE)

# One dataset per split, built in trn -> val -> test order.
train_dataset, val_dataset, test_dataset = (
    SPairDataset(PAIR_ANN_PATH, LAYOUT_PATH, IMAGE_PATH, dataset_size=DATASET_SIZE, split=split_name)
    for split_name in ("trn", "val", "test")
)

# batch_size=1: images have different sizes and string fields are not collated.
# Only the training loader shuffles.
train_loader, val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=1, shuffle=do_shuffle, num_workers=0, pin_memory=False)
    for split_dataset, do_shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

print("Train pairs:", len(train_dataset))
print("Val   pairs:", len(val_dataset))
print("Test  pairs:", len(test_dataset))

DATASET_SIZE: small
Train pairs: 10652
Val   pairs: 1070
Test  pairs: 2438


In [6]:
# ----------------------------
# Load DINOv3
# ----------------------------
# Clone the DINOv3 repo once (skipped when /content/dinov3 already exists) and
# install the extra packages it is run with here.
# NOTE(review): the pip packages are not version-pinned, so a fresh runtime may
# resolve different versions than the ones this notebook was last run with.
%cd /content
!test -d dinov3 || git clone https://github.com/facebookresearch/dinov3.git
%cd /content/dinov3
!pip -q install einops timm opencv-python torchmetrics fvcore iopath

# Local repo checkout used as the torch.hub source, and the ViT-B/16 checkpoint on Drive.
DINOV3_DIR = "/content/dinov3"
DINOV3_WEIGHTS = "/content/drive/MyDrive/AMLDataset/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth"
assert os.path.exists(DINOV3_WEIGHTS), f"Pesi DINOv3 non trovati: {DINOV3_WEIGHTS}"

# source="local" makes torch.hub use the cloned repo's hubconf instead of GitHub.
dinov3 = torch.hub.load(
    DINOV3_DIR,
    "dinov3_vitb16",
    source="local",
    weights=DINOV3_WEIGHTS,
).to(device)

# set_trainable_last_blocks relies on `model.blocks` to unfreeze the last blocks.
assert hasattr(dinov3, "blocks"), "Model has no attribute 'blocks'—cannot unfreeze last blocks."
print("DINOv3 blocks:", len(dinov3.blocks))

# Save pretrained snapshot (for fair comparison across settings).
# deepcopy clones every tensor, so later fine-tuning cannot modify the snapshot;
# the copies stay on the device the model is on at this point.
pretrained_state = copy.deepcopy(dinov3.state_dict())

def restore_pretrained():
    """Reset the global `dinov3` model to the pretrained weights saved in `pretrained_state`.

    Only the weights are restored: `requires_grad` flags and train/eval mode are
    not touched by this function.
    """
    dinov3.load_state_dict(pretrained_state, strict=True)
    dinov3.to(device)

/content
Cloning into 'dinov3'...
remote: Enumerating objects: 538, done.[K
remote: Counting objects: 100% (363/363), done.[K
remote: Compressing objects: 100% (264/264), done.[K
remote: Total 538 (delta 201), reused 99 (delta 99), pack-reused 175 (from 1)[K
Receiving objects: 100% (538/538), 9.88 MiB | 19.96 MiB/s, done.
Resolving deltas: 100% (223/223), done.
/content/dinov3
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m67.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for fvcore (setup.py) ... [?25l[?25hdone
  Building wheel for iopath (setup.py) ... [?25l[?25hdone
Downloading

100%|██████████| 327M/327M [00:10<00:00, 33.3MB/s]


DINOv3 blocks: 12


In [7]:
# ----------------------------
# Trainability control (unfreeze last blocks + optional final norm)
# ----------------------------
def set_trainable_last_blocks(model, n_last_blocks: int, train_final_norm=True):
    """
    Freeze every parameter of `model`, then re-enable gradients for the last
    `n_last_blocks` entries of `model.blocks` and, if `train_final_norm` is set
    and the model has a `norm` attribute, for `model.norm` as well.
    """
    for param in model.parameters():
        param.requires_grad_(False)

    to_unfreeze = list(model.blocks[-n_last_blocks:]) if n_last_blocks > 0 else []
    if train_final_norm and hasattr(model, "norm"):
        to_unfreeze.append(model.norm)

    for module in to_unfreeze:
        for param in module.parameters():
            param.requires_grad_(True)

In [8]:
# ----------------------------
# Patch tokens extraction (robust + expected_n safeguard)
# ----------------------------
def get_patch_tokens(model, x_bchw: torch.Tensor, expected_n: int | None = None) -> torch.Tensor:
    """
    Returns patch tokens [B, N, C] (no CLS).
    """
    def strip_cls_if_needed(t: torch.Tensor) -> torch.Tensor:
        if t.ndim != 3:
            return t
        if expected_n is None:
            return t[:, 1:, :] if t.shape[1] > 1 else t
        if t.shape[1] == expected_n:
            return t
        if t.shape[1] == expected_n + 1:
            return t[:, 1:, :]
        return t

    if hasattr(model, "forward_features"):
        out = model.forward_features(x_bchw)
        if isinstance(out, dict):
            for key in ["x_norm_patchtokens", "x_patchtokens", "patchtokens", "x_norm", "x"]:
                if key in out and isinstance(out[key], torch.Tensor):
                    return strip_cls_if_needed(out[key])
        elif torch.is_tensor(out):
            return strip_cls_if_needed(out)

    out = model(x_bchw)
    if torch.is_tensor(out):
        return strip_cls_if_needed(out)

    raise RuntimeError("Could not extract patch tokens from DINOv3 output.")

In [9]:
# ----------------------------
# Training loss (Keypoint CE)
# ----------------------------
def kp_ce_loss(src_tok_bnc, trg_tok_bnc, src_flat_idx, trg_flat_idx, valid_mask, temp=0.1):
    """
    Cross-entropy over target patches for each valid source keypoint.

    Only batch element 0 of each token tensor is used. Tokens are L2-normalized,
    so logits are cosine similarities divided by `temp`.

    Returns:
      (loss, n_valid): mean CE over valid keypoints and their count; when no
      keypoint is valid, returns a zero tensor and 0.
    """
    src_feats, trg_feats = (
        F.normalize(tokens[0], dim=-1) for tokens in (src_tok_bnc, trg_tok_bnc)
    )  # [Ns,C], [Nt,C]

    n_valid = int(valid_mask.sum().item())
    if n_valid == 0:
        return src_feats.new_tensor(0.0), 0

    similarity = src_feats[src_flat_idx] @ trg_feats.t()   # [K,Nt]
    logits = similarity / temp
    loss = F.cross_entropy(logits[valid_mask], trg_flat_idx[valid_mask])
    return loss, n_valid

In [10]:
# ----------------------------
# Evaluation helpers
# ----------------------------
@torch.no_grad()
def tokens_to_featuremap(tok_bnc, hg, wg):
    """Reshape the tokens of batch element 0 into an [hg, wg, C] grid, L2-normalized
    along C. Raises RuntimeError when the token count is not hg * wg."""
    t = tok_bnc[0]  # [N,C]
    if t.shape[0] == hg * wg:
        grid = t.view(hg, wg, -1)
        return F.normalize(grid, dim=-1)
    raise RuntimeError(f"N={t.shape[0]} != hg*wg={hg*wg}")

@torch.no_grad()
def argmax_cosine(Ft_flat: torch.Tensor, f_src: torch.Tensor, chunk: int = 4096) -> int:
    """Index of the row of `Ft_flat` with the largest dot product with `f_src`.

    Rows are scanned in chunks of `chunk`; across chunks the earliest maximum wins
    (a later chunk replaces the best only when strictly greater). Returns 0 for an
    empty `Ft_flat`.
    """
    top_score, top_index = None, 0
    n_rows = Ft_flat.shape[0]
    for start in range(0, n_rows, chunk):
        scores = (Ft_flat[start:start + chunk] * f_src).sum(dim=-1)
        chunk_max, chunk_arg = scores.max(dim=0)
        score = float(chunk_max.item())
        is_better = (top_score is None) or (score > top_score)
        if is_better:
            top_score = score
            top_index = int(chunk_arg.item()) + start
    return top_index

@torch.no_grad()
def match_one_pair(sample, sim_chunk=4096):
    """
    Predict target locations for the source keypoints of one pair by nearest-patch
    matching with the global `dinov3` model.

    Args:
      sample:    dict with "src_img"/"trg_img" (CHW, [0..255]) and
                 "src_kps"/"trg_kps" ([K,2] pixel (x,y)).
      sim_chunk: chunk size forwarded to `argmax_cosine`.

    Returns:
      (preds, gts): float32 tensors [M,2] of predicted and ground-truth target
      (x,y) for the M keypoints that were kept; both are [0,2] when none is kept.
      A keypoint is kept only if both its source and target coordinates are
      non-negative and inside the ORIGINAL (unpadded) image.
    """
    dinov3.eval()  # <<< FIX: guarantees no dropout/stochasticity

    src_img = sample["src_img"].to(device)
    trg_img = sample["trg_img"].to(device)
    src_kps = sample["src_kps"].to(device)
    trg_kps = sample["trg_kps"].to(device)

    src_x = normalize_img_chw_0_255(src_img)
    trg_x = normalize_img_chw_0_255(trg_img)

    # Pad after normalization, so padded pixels are 0 in normalized space.
    src_pad, (Hs, Ws), (Hs_pad, Ws_pad) = pad_to_patch_multiple(src_x, PATCH)
    trg_pad, (Ht, Wt), (Ht_pad, Wt_pad) = pad_to_patch_multiple(trg_x, PATCH)

    # Patch-grid sizes of the padded images.
    hg_s, wg_s = Hs_pad // PATCH, Ws_pad // PATCH
    hg_t, wg_t = Ht_pad // PATCH, Wt_pad // PATCH

    Ns = hg_s * wg_s
    Nt = hg_t * wg_t

    src_tok = get_patch_tokens(dinov3, src_pad.unsqueeze(0), expected_n=Ns)
    trg_tok = get_patch_tokens(dinov3, trg_pad.unsqueeze(0), expected_n=Nt)

    # L2-normalized feature grids, so dot products below are cosine similarities.
    Fs = tokens_to_featuremap(src_tok, hg_s, wg_s)
    Ft = tokens_to_featuremap(trg_tok, hg_t, wg_t)
    Ft_flat = Ft.view(-1, Ft.shape[-1])

    preds, gts = [], []
    for sp, gt in zip(src_kps, trg_kps):
        # Skip keypoints with negative coordinates.
        if (sp[0] < 0) or (sp[1] < 0) or (gt[0] < 0) or (gt[1] < 0):
            continue

        x_src, y_src = float(sp[0].item()), float(sp[1].item())
        x_gt,  y_gt  = float(gt[0].item()), float(gt[1].item())

        # Bounds are checked against the original image size, not the padded one.
        if not (0.0 <= x_src < Ws and 0.0 <= y_src < Hs):
            continue
        if not (0.0 <= x_gt  < Wt and 0.0 <= y_gt  < Ht):
            continue

        # Source patch (row isrc, col jsrc) containing the keypoint.
        jsrc = min(int(x_src) // PATCH, wg_s - 1)
        isrc = min(int(y_src) // PATCH, hg_s - 1)
        f_src = Fs[isrc, jsrc]

        # Best-matching target patch, as a flat index into the hg_t x wg_t grid.
        best = argmax_cosine(Ft_flat, f_src, chunk=sim_chunk)
        it = best // wg_t
        jt = best %  wg_t

        # The prediction is the pixel centre of the matched patch.
        x_pred = jt * PATCH + (PATCH / 2.0)
        y_pred = it * PATCH + (PATCH / 2.0)

        preds.append((x_pred, y_pred))
        gts.append((x_gt, y_gt))

    if len(preds) == 0:
        return torch.zeros((0,2)), torch.zeros((0,2))
    return torch.tensor(preds, dtype=torch.float32), torch.tensor(gts, dtype=torch.float32)


@torch.no_grad()
def evaluate_pck_report(loader, name="EVAL", max_pairs=None, sim_chunk=4096, print_per_category=True):
    """
    Evaluate PCK@{0.05, 0.10, 0.20} of the global `dinov3` model on `loader` and
    print a report.

    A keypoint is correct at threshold T when its prediction error is at most
    T * max(bbox width, bbox height) of the TARGET bounding box.

    Args:
      loader:             DataLoader with batch_size=1 (only element 0 of each batch is used).
      name:               label used in progress lines and in the report header.
      max_pairs:          stop after this many pairs have been seen (None = whole loader).
      sim_chunk:          chunk size forwarded to `match_one_pair`.
      print_per_category: also print the per-category table.

    Returns:
      dict with "pairs_seen", "pairs_run", "minutes", "global_img" (mean of per-pair
      PCK), "global_kp" (PCK over all keypoints pooled), both keyed by threshold as
      fractions in [0,1], and "per_category" (list of row dicts in percent; empty
      when `print_per_category` is False).
    """
    dinov3.eval()  # important: always eval during metric

    thresholds = (0.05, 0.10, 0.20)

    # Global accumulators: pooled keypoint counts and per-pair PCK values.
    global_correct = {T: 0.0 for T in thresholds}
    global_total_kp = 0.0
    all_pck_img = {T: [] for T in thresholds}

    # Same accumulators, split by category.
    cat_correct = {}
    cat_total   = {}
    cat_pck_img = {}

    t0 = time.time()
    pairs_seen = 0   # pairs pulled from the loader
    pairs_used = 0   # pairs that contributed to the metric

    for batch in loader:
        # The limit is checked after the batch was fetched.
        if max_pairs is not None and pairs_seen >= int(max_pairs):
            break
        pairs_seen += 1

        sample = {
            "src_img": batch["src_img"][0],
            "trg_img": batch["trg_img"][0],
            "src_kps": batch["src_kps"][0],
            "trg_kps": batch["trg_kps"][0],
            "category": batch["category"][0],
            "trg_bbox": batch["trg_bbox"][0],   # tensor [4]
        }

        pred, gt = match_one_pair(sample, sim_chunk=sim_chunk)
        # Pairs without any usable keypoint are skipped.
        if pred.shape[0] == 0:
            continue

        # PCK normalizer: longer side of the target bounding box.
        x0, y0, x1, y1 = sample["trg_bbox"].detach().cpu().view(-1).tolist()
        x0, y0, x1, y1 = float(x0), float(y0), float(x1), float(y1)
        norm = max(x1 - x0, y1 - y0)
        if norm <= 1e-6:
            continue

        # Euclidean pixel error per keypoint.
        dists = torch.linalg.norm(pred - gt, dim=1)
        N = float(dists.numel())
        if N <= 0:
            continue

        pairs_used += 1

        global_total_kp += N
        cat = sample["category"]
        cat_correct.setdefault(cat, {T: 0.0 for T in thresholds})
        cat_total.setdefault(cat, 0.0)
        cat_pck_img.setdefault(cat, {T: [] for T in thresholds})
        cat_total[cat] += N

        for T in thresholds:
            thr = T * norm
            correct = float((dists <= thr).float().sum().item())
            pck_img = correct / N

            global_correct[T] += correct
            all_pck_img[T].append(pck_img)

            cat_correct[cat][T] += correct
            cat_pck_img[cat][T].append(pck_img)

        # Progress line; not reached for pairs skipped by a `continue` above.
        if pairs_seen % 200 == 0:
            print(f"[{name}] seen={pairs_seen} used={pairs_used}")

    minutes = (time.time() - t0) / 60.0

    # Per-image PCK averages the per-pair ratios; per-keypoint PCK pools all keypoints.
    mean_pck_img = {T: float(np.mean(all_pck_img[T])) if len(all_pck_img[T]) else 0.0 for T in thresholds}
    global_pck_kp = {T: float(global_correct[T] / max(global_total_kp, 1.0)) for T in thresholds}

    print("\n" + "="*18 + f" {name} REPORT " + "="*18)
    print(f"Pairs run: {pairs_used} (seen: {pairs_seen})")
    print("\nGlobal PCK (per-image mean):")
    for T in thresholds:
        print(f"  PCK@{T:.2f}: {100.0*mean_pck_img[T]:.2f}%")
    print("\nGlobal PCK (per-keypoint):")
    for T in thresholds:
        print(f"  PCK@{T:.2f}: {100.0*global_pck_kp[T]:.2f}%")

    # The per-category rows are only built when the table is requested.
    per_cat_rows = []
    if print_per_category and len(cat_total) > 0:
        for cat in sorted(cat_total.keys()):
            row = {"Category": cat}
            for T in thresholds:
                kp = float(cat_correct[cat][T] / max(cat_total[cat], 1.0))
                im = float(np.mean(cat_pck_img[cat][T])) if len(cat_pck_img[cat][T]) else 0.0
                row[f"KP@{T:.2f}"]  = 100.0 * kp
                row[f"IMG@{T:.2f}"] = 100.0 * im
            per_cat_rows.append(row)

        df_cat = pd.DataFrame(per_cat_rows)
        print("\n" + "="*16 + " PER-CATEGORY RESULTS " + "="*16)
        cols = ["Category",
                "KP@0.05","KP@0.10","KP@0.20",
                "IMG@0.05","IMG@0.10","IMG@0.20"]
        cols = [c for c in cols if c in df_cat.columns]
        print(df_cat[cols].to_string(index=False))

    print(f"\nMinutes: {minutes:.4f}")

    return {
        "pairs_seen": pairs_seen,
        "pairs_run": pairs_used,
        "minutes": minutes,
        "global_img": mean_pck_img,
        "global_kp": global_pck_kp,
        "per_category": per_cat_rows
    }

In [11]:
# ----------------------------
# Training (steps-based) + AMP
# ----------------------------
from torch.amp import autocast, GradScaler

def train_steps(
    loader,
    n_last_blocks=1,
    lr=1e-5,
    weight_decay=0.05,
    temp=0.1,
    max_steps=2000,
    log_every=200,
    use_amp=True
):
    """
    Fine-tune the global `dinov3` model for a fixed number of steps with the
    keypoint cross-entropy loss (`kp_ce_loss`).

    Args:
      loader:        DataLoader with batch_size=1; restarted whenever it is exhausted.
      n_last_blocks: number of trailing transformer blocks to unfreeze (the final
                     norm is always unfrozen as well).
      lr, weight_decay: AdamW hyper-parameters.
      temp:          softmax temperature forwarded to `kp_ce_loss`.
      max_steps:     number of loader batches to process.
      log_every:     print a progress line every this many steps.
      use_amp:       enable mixed precision (only takes effect on CUDA).

    Returns:
      Average loss over all keypoints seen (loss weighted by valid keypoints per step).

    Note: the model is left in train mode; callers switch to eval before evaluating.
    """
    set_trainable_last_blocks(dinov3, n_last_blocks=n_last_blocks, train_final_norm=True)
    dinov3.train()

    # Only the unfrozen parameters are handed to the optimizer.
    params = [p for p in dinov3.parameters() if p.requires_grad]
    n_trainable = sum(p.numel() for p in params)
    print(f"[train] n_last_blocks={n_last_blocks} | trainable params: {n_trainable:,}")
    assert n_trainable > 0, "No trainable parameters!"

    opt = torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    scaler = GradScaler("cuda", enabled=(use_amp and device.type == "cuda"))

    running_loss = 0.0   # sum of loss * n_valid_keypoints
    running_kps  = 0     # total valid keypoints seen
    t0 = time.time()

    it = iter(loader)
    for step in range(int(max_steps)):
        # Steps-based training: restart the loader when an epoch ends.
        try:
            batch = next(it)
        except StopIteration:
            it = iter(loader)
            batch = next(it)

        src_img = batch["src_img"][0].to(device)
        trg_img = batch["trg_img"][0].to(device)
        src_kps = batch["src_kps"][0].to(device)
        trg_kps = batch["trg_kps"][0].to(device)

        src_x = normalize_img_chw_0_255(src_img)
        trg_x = normalize_img_chw_0_255(trg_img)
        src_pad, _, (Hs_pad, Ws_pad) = pad_to_patch_multiple(src_x, PATCH)
        trg_pad, _, (Ht_pad, Wt_pad) = pad_to_patch_multiple(trg_x, PATCH)

        # Keypoints -> patch indices; a keypoint counts only if valid in both images.
        src_flat, src_valid, hg_s, wg_s = kps_to_flat_indices(src_kps, Hs_pad, Ws_pad, PATCH)
        trg_flat, trg_valid, hg_t, wg_t = kps_to_flat_indices(trg_kps, Ht_pad, Wt_pad, PATCH)
        valid = src_valid & trg_valid

        Ns = hg_s * wg_s
        Nt = hg_t * wg_t

        opt.zero_grad(set_to_none=True)

        with autocast("cuda", enabled=scaler.is_enabled()):
            src_tok = get_patch_tokens(dinov3, src_pad.unsqueeze(0), expected_n=Ns)
            trg_tok = get_patch_tokens(dinov3, trg_pad.unsqueeze(0), expected_n=Nt)
            loss, nvalid = kp_ce_loss(src_tok, trg_tok, src_flat, trg_flat, valid, temp=temp)

        # No valid keypoint: skip the update (this also skips the log check for this step).
        if nvalid == 0:
            continue

        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        running_loss += float(loss.item()) * nvalid
        running_kps  += nvalid

        if (step + 1) % log_every == 0:
            avg = running_loss / max(running_kps, 1)
            dt = time.time() - t0
            print(f"[train] step {step+1}/{max_steps} | avg_loss {avg:.4f} | seen_kps {running_kps} | {dt:.1f}s")

    return running_loss / max(running_kps, 1)

In [12]:
# ----------------------------
# Sanity checks (optional but very useful)
# ----------------------------
def overfit_sanity(
    n_last_blocks=2,
    lr=5e-5,
    weight_decay=0.05,
    temp=0.1,
    steps=800,
    eval_pairs=200
):
    """
    Quick sanity check: evaluate the frozen pretrained model on the first
    `eval_pairs` VAL pairs, fine-tune from the pretrained weights for `steps`
    steps, evaluate again, and print the before/after KP@0.10.
    """
    def _val_report(tag):
        # Evaluate the current global model on the VAL subset, in eval mode.
        dinov3.eval()
        return evaluate_pck_report(val_loader, name=tag, max_pairs=eval_pairs, print_per_category=False)

    rule = "="*70
    print("\n" + rule)
    print("SANITY CHECK: OVERFIT (quick)")
    print(rule)

    # Baseline: pretrained weights, everything frozen.
    restore_pretrained()
    set_trainable_last_blocks(dinov3, 0)
    base = _val_report("VAL (baseline, frozen)")

    # Short fine-tuning run, again starting from the pretrained weights.
    restore_pretrained()
    avg_loss = train_steps(
        train_loader,
        n_last_blocks=n_last_blocks,
        lr=lr,
        weight_decay=weight_decay,
        temp=temp,
        max_steps=steps,
        log_every=200
    )
    fin = _val_report("VAL (after overfit)")

    print(f"\n[overfit] avg_train_loss: {avg_loss:.4f}")
    print(f"[overfit] baseline val KP@0.10: {100*base['global_kp'][0.10]:.2f}%  -> after: {100*fin['global_kp'][0.10]:.2f}%")

def lr_finder(
    n_last_blocks=2,
    lrs=(1e-6, 3e-6, 1e-5, 3e-5, 1e-4),
    weight_decay=0.05,
    temp=0.1,
    train_steps_each=300,
    eval_pairs=300
):
    """
    Short learning-rate search: for each value in `lrs`, restart from the
    pretrained weights, train `train_steps_each` steps, and evaluate on the first
    `eval_pairs` VAL pairs.

    Returns a DataFrame (one row per learning rate, sorted by `lr`) with the
    training loss and the VAL PCK values in percent.
    """
    def _trial(lr):
        # One independent run at this learning rate, summarised as a result row.
        print("\n" + "-"*60)
        print(f"[lr_finder] lr={lr:g} | blocks={n_last_blocks}")
        print("-"*60)

        restore_pretrained()
        avg_loss = train_steps(
            train_loader,
            n_last_blocks=n_last_blocks,
            lr=lr,
            weight_decay=weight_decay,
            temp=temp,
            max_steps=train_steps_each,
            log_every=200
        )

        dinov3.eval()
        val_rep = evaluate_pck_report(val_loader, name=f"VAL (lr={lr:g})", max_pairs=eval_pairs, print_per_category=False)
        kp, img = val_rep["global_kp"], val_rep["global_img"]

        return {
            "lr": lr,
            "n_last_blocks": n_last_blocks,
            "temp": temp,
            "weight_decay": weight_decay,
            "train_steps": train_steps_each,
            "avg_train_loss": avg_loss,
            "val_kp_PCK@0.05": 100*kp[0.05],
            "val_kp_PCK@0.10": 100*kp[0.10],
            "val_kp_PCK@0.20": 100*kp[0.20],
            "val_img_PCK@0.05": 100*img[0.05],
            "val_img_PCK@0.10": 100*img[0.10],
            "val_img_PCK@0.20": 100*img[0.20],
        }

    print("\n" + "="*70)
    print("LR FINDER")
    print("="*70)

    rows = [_trial(lr) for lr in lrs]

    df = pd.DataFrame(rows).sort_values("lr").reset_index(drop=True)
    print("\n=== LR FINDER SUMMARY (choose on VAL, e.g. KP@0.10) ===")
    print(df.to_string(index=False))
    return df

In [13]:
# ----------------------------
# Task2: sweep blocks, select best on VAL, test only best
# ----------------------------
def run_task2_sweep(
    settings=(0,1,2,4),
    lr=1e-5,
    temp=0.1,
    weight_decay=0.05,
    train_steps_per_setting=2000,
    val_pairs=1000,
    test_pairs=1000,
    compute_test_baseline=False   # <<< IMPORTANT
):
    """
    Sweep the number of unfrozen last blocks, select the best setting on VAL
    (per-keypoint PCK@0.10), and evaluate only that setting on TEST.

    Every setting starts from the pretrained weights. The weights of the best
    setting seen so far are snapshotted during the sweep, so the TEST evaluation
    runs on exactly the model that was selected on VAL, not on a re-trained copy
    that would see a different shuffle order, and no extra training run is needed.

    Args:
      settings:                iterable of `n_last_blocks` values (0 = frozen model).
      lr, temp, weight_decay:  forwarded to `train_steps`.
      train_steps_per_setting: training steps for every setting with blocks > 0.
      val_pairs, test_pairs:   max pairs to evaluate (None = full split).
      compute_test_baseline:   also evaluate the frozen model on TEST.

    Returns:
      (df, best, base_val, base_test, best_test): sweep summary DataFrame, the best
      row as a dict, frozen-model VAL report, frozen-model TEST report (or None),
      and the TEST report of the selected model.
    """
    print("\n" + "="*80)
    print("TASK2 SWEEP (select best on VAL; test only best)")
    print("="*80)

    # ---------- BASELINE (VAL ONLY) ----------
    restore_pretrained()
    set_trainable_last_blocks(dinov3, 0)
    dinov3.eval()

    base_val = evaluate_pck_report(
        val_loader,
        name="VAL BASELINE (frozen)",
        max_pairs=val_pairs,
        print_per_category=True
    )

    base_test = None
    if compute_test_baseline:
        base_test = evaluate_pck_report(
            test_loader,
            name="TEST BASELINE (frozen)",
            max_pairs=test_pairs,
            print_per_category=False
        )

    # ---------- SWEEP ----------
    rows = []
    best_key = None           # (val score, -n_last_blocks) of the running best
    best_state = None         # weights of the running best
    best_state_blocks = None  # n_last_blocks the snapshot belongs to

    for n_last_blocks in settings:
        print("\n" + "="*70)
        print(f"RUN setting: n_last_blocks={n_last_blocks}")
        print("="*70)

        restore_pretrained()

        avg_loss = None
        if n_last_blocks > 0:
            avg_loss = train_steps(
                train_loader,
                n_last_blocks=n_last_blocks,
                lr=lr,
                weight_decay=weight_decay,
                temp=temp,
                max_steps=train_steps_per_setting,
                log_every=200
            )

        dinov3.eval()
        val_rep = evaluate_pck_report(
            val_loader,
            name=f"VAL FINETUNED (blocks={n_last_blocks})",
            max_pairs=val_pairs,
            print_per_category=False
        )
        val_score = 100 * val_rep["global_kp"][0.10]

        rows.append({
            "n_last_blocks": n_last_blocks,
            "lr": lr,
            "temp": temp,
            "train_steps": train_steps_per_setting if n_last_blocks > 0 else 0,
            "avg_train_loss": avg_loss,
            "val_kp_PCK@0.10": val_score,
        })

        # Same rule as the selection below: highest score, ties go to fewer blocks.
        # Only one snapshot is kept at a time (same deepcopy pattern as pretrained_state).
        key = (val_score, -n_last_blocks)
        if best_key is None or key > best_key:
            best_key = key
            best_state = copy.deepcopy(dinov3.state_dict())
            best_state_blocks = int(n_last_blocks)

    df = pd.DataFrame(rows).sort_values("n_last_blocks").reset_index(drop=True)
    print("\n=== SWEEP SUMMARY (VAL) ===")
    print(df.to_string(index=False))

    # ---------- SELECT BEST ----------
    best_idx = int(df["val_kp_PCK@0.10"].values.argmax())
    best = df.iloc[best_idx].to_dict()
    best_blocks = int(best["n_last_blocks"])

    print("\n" + "="*70)
    print(f"BEST (by VAL KP@0.10): n_last_blocks={best_blocks}")
    print("="*70)

    # ---------- LOAD BEST + TEST ----------
    if best_state is not None and best_state_blocks == best_blocks:
        # Reuse the exact weights that produced the selected VAL score.
        dinov3.load_state_dict(best_state, strict=True)
        dinov3.to(device)
    else:
        # Fallback (snapshot and table disagree, e.g. NaN scores): retrain the selected setting.
        restore_pretrained()
        if best_blocks > 0:
            _ = train_steps(
                train_loader,
                n_last_blocks=best_blocks,
                lr=lr,
                weight_decay=weight_decay,
                temp=temp,
                max_steps=train_steps_per_setting,
                log_every=200
            )

    dinov3.eval()
    best_test = evaluate_pck_report(
        test_loader,
        name=f"TEST FINETUNED (BEST blocks={best_blocks})",
        max_pairs=test_pairs,
        print_per_category=True
    )

    return df, best, base_val, base_test, best_test

In [14]:
# ----------------------------
# 0) (Optional) sanity checks, same as in the colleague's notebook
# ----------------------------
# overfit_sanity(n_last_blocks=2, lr=5e-5, weight_decay=0.05, steps=800, eval_pairs=200)

# df_lr = lr_finder(
    # n_last_blocks=2,
    # lrs=(1e-6, 3e-6, 1e-5, 3e-5, 1e-4),
    # weight_decay=0.05,
    # train_steps_each=300,
    # eval_pairs=300
# )


SANITY CHECK: OVERFIT (quick)
[VAL (baseline, frozen)] seen=200 used=200

Pairs run: 200 (seen: 200)

Global PCK (per-image mean):
  PCK@0.05: 34.60%
  PCK@0.10: 52.54%
  PCK@0.20: 63.43%

Global PCK (per-keypoint):
  PCK@0.05: 37.14%
  PCK@0.10: 55.20%
  PCK@0.20: 67.32%

Minutes: 4.6610
[train] n_last_blocks=2 | trainable params: 14,180,352
[train] step 200/800 | avg_loss 3.7358 | seen_kps 1422 | 676.0s
[train] step 400/800 | avg_loss 3.3616 | seen_kps 2916 | 1037.6s
[train] step 600/800 | avg_loss 3.1668 | seen_kps 4338 | 1362.1s
[train] step 800/800 | avg_loss 3.0120 | seen_kps 5817 | 1629.1s
[VAL (after overfit)] seen=200 used=200

Pairs run: 200 (seen: 200)

Global PCK (per-image mean):
  PCK@0.05: 54.71%
  PCK@0.10: 73.42%
  PCK@0.20: 85.25%

Global PCK (per-keypoint):
  PCK@0.05: 54.96%
  PCK@0.10: 72.79%
  PCK@0.20: 85.38%

Minutes: 0.3625

[overfit] avg_train_loss: 3.0120
[overfit] baseline val KP@0.10: 55.20%  -> after: 72.79%

LR FINDER

-----------------------------------

In [14]:
# ----------------------------
# 1) Task 2 sweep: select the setting on VAL, evaluate TEST only for the best one
# ----------------------------
# Configuration for this run. These values are passed to run_task2_sweep below,
# so editing them here actually changes the run.
settings = (0, 1, 2, 4)
VAL_PAIRS  = None   # set None for full
TEST_PAIRS = None   # set None for full

df_sweep, best_cfg, base_val, base_test, best_test = run_task2_sweep(
    settings=settings,
    lr=3e-5,
    temp=0.1,
    weight_decay=0.05,
    train_steps_per_setting=2000,
    val_pairs=VAL_PAIRS,
    test_pairs=TEST_PAIRS,
    compute_test_baseline=False
)

print("\n=== DONE ===")



TASK2 SWEEP (select best on VAL; test only best)
[VAL BASELINE (frozen)] seen=200 used=200
[VAL BASELINE (frozen)] seen=400 used=400
[VAL BASELINE (frozen)] seen=600 used=600
[VAL BASELINE (frozen)] seen=800 used=800
[VAL BASELINE (frozen)] seen=1000 used=1000

Pairs run: 1070 (seen: 1070)

Global PCK (per-image mean):
  PCK@0.05: 28.81%
  PCK@0.10: 46.24%
  PCK@0.20: 61.11%

Global PCK (per-keypoint):
  PCK@0.05: 31.74%
  PCK@0.10: 50.21%
  PCK@0.20: 66.26%

   Category   KP@0.05   KP@0.10   KP@0.20  IMG@0.05  IMG@0.10  IMG@0.20
  aeroplane 39.436620 55.281690 68.485915 36.977169 53.396684 65.001936
    bicycle 34.604106 51.319648 64.222874 35.101065 49.650837 62.893437
       bird 41.275168 68.120805 78.859060 40.067875 68.800438 78.537358
       boat 17.153285 24.817518 39.416058 14.290344 21.128968 33.134921
     bottle 22.676580 41.635688 60.966543 21.303461 39.168320 57.379299
        bus 28.295820 42.443730 53.376206 19.877840 30.819328 40.307345
        car 33.333333 48.888889