In [45]:
# ============================================================
# DINOv3 — Task1 PCA + Qualitative
# - JOINT PCA: fitted on [source_patches ; target_patches]
# - Shows the PCA for layer 10 and for the last layer (Source + Target)
# - Qualitative, last layer: SOURCE (keypoints, blue) | TARGET (prediction, red) | TARGET (GT, green)
# - SPair-71k dataset layout:
#     PairAnnotation/test/<filename>:<category>.json
#     JPEGImages/<cat>/<filename>.jpg
# ============================================================

from google.colab import drive
drive.mount("/content/drive")

import os, json, math
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt

!pip -q install scikit-learn
from sklearn.decomposition import PCA

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# The selected device is printed so the cell output records where the run happened.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
device: cuda


In [46]:
# -----------------------------
# PATHS
# -----------------------------
# Dataset root and model checkpoint, both stored on the mounted Google Drive.
SPAIR_ROOT = Path("/content/drive/MyDrive/AMLDataset/SPair-71k")
DINOv3_WEIGHTS = Path("/content/drive/MyDrive/AMLDataset/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth")

# Fail early, before any model loading, if the inputs are missing.
assert SPAIR_ROOT.exists(), f"SPair-71k non trovato: {SPAIR_ROOT}"
assert DINOv3_WEIGHTS.exists(), f"Pesi DINOv3 non trovati: {DINOv3_WEIGHTS}"

# Pair annotation JSON files (test split) and the per-category image folders.
PAIR_ANN_ROOT = SPAIR_ROOT / "PairAnnotation" / "test"
IMG_ROOT      = SPAIR_ROOT / "JPEGImages"
assert PAIR_ANN_ROOT.exists(), f"Manca: {PAIR_ANN_ROOT}"
assert IMG_ROOT.exists(), f"Manca: {IMG_ROOT}"

# -----------------------------
# OUTPUT
# -----------------------------
# One folder for the joint-PCA figures and one for the qualitative matching
# figures; both are created if they do not exist yet.
OUT_ROOT = Path("/content/drive/MyDrive/AML_Project_Results/DINOv3_Task1")
PCA_DIR  = OUT_ROOT / "final_analysis_results_dinov3_jointPCA"
QUAL_DIR = OUT_ROOT / "qualitative_results_last_layer_dinov3"
PCA_DIR.mkdir(parents=True, exist_ok=True)
QUAL_DIR.mkdir(parents=True, exist_ok=True)
print("Saving PCA  to:", PCA_DIR)
print("Saving QUAL to:", QUAL_DIR)

# -----------------------------
# SETTINGS
# -----------------------------
CATEGORIES = ["aeroplane", "chair"]
PAIRS_PER_CATEGORY = 1          # number of figures per category (for both PCA and QUAL)
MAX_SCAN_PER_CAT   = 200        # cap on annotation files considered; rarely matters since only the first N are taken
LAYER_INTER        = 10         # intermediate layer, used as an index into the model's blocks
PATCH              = 16         # patch size in pixels (vitb16); drives padding and grid size


Saving PCA  to: /content/drive/MyDrive/AML_Project_Results/DINOv3_Task1/final_analysis_results_dinov3_jointPCA
Saving QUAL to: /content/drive/MyDrive/AML_Project_Results/DINOv3_Task1/qualitative_results_last_layer_dinov3


In [47]:
# ============================================================
# Load DINOv3
# ============================================================
%cd /content
!test -d dinov3 || git clone https://github.com/facebookresearch/dinov3.git
%cd /content/dinov3
!pip -q install einops timm opencv-python torchmetrics fvcore iopath

# Local checkout of the DINOv3 repository, used as the torch.hub source.
DINOV3_DIR = "/content/dinov3"

# Build the "dinov3_vitb16" entry point from the local repo with the checkpoint
# configured in DINOv3_WEIGHTS, switch to eval mode and move it to `device`.
dinov3 = torch.hub.load(
    DINOV3_DIR,
    "dinov3_vitb16",
    source="local",
    weights=str(DINOv3_WEIGHTS),
).eval().to(device)

# The model is only used for feature extraction: disable gradients on every parameter.
for p in dinov3.parameters():
    p.requires_grad_(False)

# The intermediate-feature hook indexes `dinov3.blocks`, so make sure it exists.
assert hasattr(dinov3, "blocks"), "Il modello non ha .blocks: API diversa dal previsto"
N_BLOCKS = len(dinov3.blocks)
print("DINOv3 blocks:", N_BLOCKS)

# LAYER_INTER is used as a 0-based index into `dinov3.blocks`; reject out-of-range values.
if not (0 <= LAYER_INTER < N_BLOCKS):
    raise ValueError(f"LAYER_INTER={LAYER_INTER} fuori range [0,{N_BLOCKS-1}]")


/content
/content/dinov3
DINOv3 blocks: 12


In [48]:
# ============================================================
# Utils: IO + padding + preprocess
# ============================================================
def load_rgb(path: Path):
    """Read an image file and return its RGB-converted pixels as a numpy array.

    Args:
        path: Location of the image file on disk.

    Returns:
        The array produced by ``np.array`` on the image after ``convert("RGB")``.
    """
    # The context manager releases the underlying file handle deterministically
    # instead of leaving it to garbage collection, which matters when many
    # images are loaded from a long-lived notebook kernel.
    with Image.open(path) as img:
        return np.array(img.convert("RGB"))

def to_chw_float_0_255(img_np: np.ndarray):
    """Turn a channel-last image array into a channel-first float tensor.

    The values are only cast to float; they are not rescaled, so an input in
    the 0..255 range stays in the 0..255 range.

    Args:
        img_np: Image array whose last axis holds the channels.

    Returns:
        Float tensor with the channel axis moved to the front.
    """
    channel_last = torch.from_numpy(img_np)
    channel_first = channel_last.permute(2, 0, 1)
    return channel_first.float()

def pad_to_multiple(img_chw: torch.Tensor, k: int):
    """Zero-pad a CHW image on the bottom/right so H and W become multiples of ``k``.

    The original pixels stay in the top-left corner, so pixel coordinates of
    the input remain valid in the padded output.

    Args:
        img_chw: Image tensor of shape (C, H, W).
        k: Positive integer the padded height and width must be divisible by.

    Returns:
        A tuple ``(padded, (H, W), (Hpad, Wpad))`` with the padded tensor, the
        original spatial size and the padded spatial size.
    """
    C, H, W = img_chw.shape
    # Integer ceiling division: same result as math.ceil(H / k) * k without
    # going through floating point.
    Hpad = -(-H // k) * k
    Wpad = -(-W // k) * k
    # Allocate on the input's device as well as with its dtype, so an image that
    # already lives on an accelerator is not silently moved back to the CPU.
    out = torch.zeros((C, Hpad, Wpad), dtype=img_chw.dtype, device=img_chw.device)
    out[:, :H, :W] = img_chw
    return out, (H, W), (Hpad, Wpad)

def grid_hw(Hpad: int, Wpad: int, patch: int):
    """Return how many patch rows and patch columns fit in a padded image.

    Args:
        Hpad: Padded image height in pixels.
        Wpad: Padded image width in pixels.
        patch: Side length of one patch in pixels.

    Returns:
        ``(rows, cols)`` obtained by floor-dividing each size by ``patch``.
    """
    rows = Hpad // patch
    cols = Wpad // patch
    return rows, cols

# Per-channel normalisation statistics, shaped (1, 3, 1, 1) so that they
# broadcast against a batched NCHW tensor living on `device`.
mean = torch.tensor([0.485, 0.456, 0.406], device=device).reshape(1, 3, 1, 1)
std  = torch.tensor([0.229, 0.224, 0.225], device=device).reshape(1, 3, 1, 1)

@torch.no_grad()
def preprocess_for_model(img_chw_0_255: torch.Tensor) -> torch.Tensor:
    """Prepare a single CHW image in the 0..255 range as model input.

    The image is scaled by 1/255, given a leading batch axis, moved to
    `device` and standardised with the module-level `mean` and `std`.

    Args:
        img_chw_0_255: Image tensor of shape (C, H, W) with values in 0..255.

    Returns:
        Normalised tensor with a leading batch axis, located on `device`.
    """
    unit_range = img_chw_0_255 / 255.0
    batch = unit_range.unsqueeze(0).to(device)
    centered = batch - mean
    return centered / std

In [49]:
# ============================================================
# Dataset specifics
# ============================================================
def list_pair_jsons_for_category(cat: str):
    """List the pair-annotation JSON files of one category, sorted by path.

    Files are expected directly under ``PAIR_ANN_ROOT`` and named
    ``<anything>:<category>.json``; the category is lower-cased before matching.

    Args:
        cat: Category name.

    Returns:
        Sorted list of matching paths (empty when nothing matches).
    """
    pattern = f"*:{cat.lower()}.json"
    matches = list(PAIR_ANN_ROOT.glob(pattern))
    matches.sort()
    return matches

def extract_pair_fields(ann: dict):
    """Pull the fields used by this notebook out of one pair-annotation dict.

    Args:
        ann: Parsed annotation JSON for one source/target pair.

    Returns:
        ``(src_name, trg_name, src_kps, trg_kps, trg_bbox, pair_id)``. Every
        entry except the two image names may be ``None`` when the
        corresponding key is absent. ``pair_id`` is read from ``"pair_id"``
        and falls back to ``"id"`` only when ``"pair_id"`` is missing or null.

    Raises:
        KeyError: If ``src_imname`` or ``trg_imname`` is missing.
    """
    src_name = ann.get("src_imname")
    trg_name = ann.get("trg_imname")
    if src_name is None or trg_name is None:
        raise KeyError("Nel json mancano src_imname/trg_imname")

    src_kps = ann.get("src_kps")
    trg_kps = ann.get("trg_kps")
    trg_bbox = ann.get("trg_bbox")  # may be None; not needed by the callers in this notebook

    # Compare against None explicitly: a truthiness test (`a or b`) would throw
    # away a perfectly valid pair_id of 0 and return the fallback instead.
    pair_id = ann.get("pair_id")
    if pair_id is None:
        pair_id = ann.get("id")

    return src_name, trg_name, src_kps, trg_kps, trg_bbox, pair_id

def load_pair_images(cat: str, src_name: str, trg_name: str):
    """Load the source and target images of a pair from the category folder.

    Both files are looked up under ``IMG_ROOT / cat`` and checked for
    existence before either one is read.

    Args:
        cat: Category name, used as the sub-folder of ``IMG_ROOT``.
        src_name: File name of the source image.
        trg_name: File name of the target image.

    Returns:
        ``(source_rgb, target_rgb)`` as returned by ``load_rgb``.

    Raises:
        FileNotFoundError: If either image file does not exist.
    """
    category_dir = IMG_ROOT / cat
    src_path = category_dir / src_name
    trg_path = category_dir / trg_name

    if not src_path.exists():
        raise FileNotFoundError(f"Source image not found: {src_path}")
    if not trg_path.exists():
        raise FileNotFoundError(f"Target image not found: {trg_path}")

    source_rgb = load_rgb(src_path)
    target_rgb = load_rgb(trg_path)
    return source_rgb, target_rgb

In [50]:
# ============================================================
# Feature extraction: patch tokens -> grid
# ============================================================
def safe_tokens(out):
    """Return a rank-3 token tensor from a module output, or fail loudly.

    A tuple or list output is unwrapped to its first element; the result must
    then be a tensor with exactly three dimensions.

    Args:
        out: Raw module output (a tensor, or a tuple/list whose first item is one).

    Returns:
        The rank-3 tensor.

    Raises:
        RuntimeError: If the (unwrapped) output is not a 3-dimensional tensor.
    """
    out = out[0] if isinstance(out, (tuple, list)) else out
    looks_like_tokens = torch.is_tensor(out) and out.ndim == 3
    if looks_like_tokens:
        return out
    raise RuntimeError(f"Unexpected tokens output: {type(out)} shape={getattr(out,'shape',None)}")

def tokens_to_patchgrid(tokens_bnc: torch.Tensor, hg: int, wg: int) -> torch.Tensor:
    """Reshape a token sequence into a normalised (hg, wg, C) patch grid.

    Only the last ``hg * wg`` tokens are kept; any tokens before them are
    discarded (the intent is to drop CLS/register tokens, which are assumed
    to precede the patch tokens — TODO confirm for the model in use).

    Args:
        tokens_bnc: Token tensor; its leading axis is squeezed if it has size 1.
        hg: Number of patch rows.
        wg: Number of patch columns.

    Returns:
        Tensor of shape (hg, wg, C), normalised along the last axis.

    Raises:
        RuntimeError: If fewer than ``hg * wg`` tokens are available.
    """
    sequence = tokens_bnc.squeeze(0)  # (N, C)
    n_available = sequence.shape[0]
    n_patches = hg * wg
    if n_available < n_patches:
        raise RuntimeError(f"Ntok={sequence.shape[0]} < Npatch={n_patches}")
    trailing = sequence[-n_patches:]
    grid = trailing.view(hg, wg, -1)
    return F.normalize(grid, dim=-1)

@torch.no_grad()
def feat_last(img_pad_chw: torch.Tensor, hg: int, wg: int) -> torch.Tensor:
    """Compute the normalised last-layer patch-feature grid of a padded image.

    The image is preprocessed, passed through ``dinov3.forward_features`` and
    the ``"x_norm_patchtokens"`` entry of the result is reshaped to the grid.

    Args:
        img_pad_chw: Padded CHW image in the 0..255 range.
        hg: Expected number of patch rows.
        wg: Expected number of patch columns.

    Returns:
        Tensor of shape (hg, wg, C), normalised along the last axis.

    Raises:
        RuntimeError: If the number of patch tokens differs from ``hg * wg``.
    """
    model_input = preprocess_for_model(img_pad_chw)
    features = dinov3.forward_features(model_input)
    patch = features["x_norm_patchtokens"].squeeze(0)  # (Npatch, C)
    if patch.shape[0] != hg * wg:
        raise RuntimeError(f"Last patch tokens {patch.shape[0]} != hg*wg {hg*wg}")
    grid = patch.view(hg, wg, -1)
    return F.normalize(grid, dim=-1)

@torch.no_grad()
def feat_inter(img_pad_chw: torch.Tensor, layer_id: int, hg: int, wg: int) -> torch.Tensor:
    """Compute the normalised patch-feature grid emitted by one model block.

    A forward hook is attached to ``dinov3.blocks[layer_id]`` for the duration
    of a single forward pass; the block output it captures is converted to an
    (hg, wg, C) grid by ``tokens_to_patchgrid``.

    Args:
        img_pad_chw: Padded CHW image in the 0..255 range.
        layer_id: Index into ``dinov3.blocks`` of the block to tap.
        hg: Number of patch rows.
        wg: Number of patch columns.

    Returns:
        Tensor of shape (hg, wg, C), normalised along the last axis.

    Raises:
        RuntimeError: If the hook captured nothing, or if the block output is
            not a usable token tensor.
    """
    captured = {}

    def hook_fn(m, inp, out):
        captured["t"] = safe_tokens(out).detach()

    h = dinov3.blocks[layer_id].register_forward_hook(hook_fn)
    try:
        _ = dinov3.forward_features(preprocess_for_model(img_pad_chw))
    finally:
        # Always detach the hook, even when the forward pass or the hook itself
        # raises. Otherwise it would stay registered on the shared model and
        # keep firing (and capturing tensors) on every later forward pass.
        h.remove()

    if "t" not in captured:
        raise RuntimeError("Hook non ha catturato token (layer_id errato o API diversa)")
    return tokens_to_patchgrid(captured["t"], hg, wg)

In [51]:
# ============================================================
# JOINT PCA
# ============================================================
def joint_pca_rgb(Fs_hwC: torch.Tensor, Ft_hwC: torch.Tensor, n_components=3, random_state=0):
    """Project source and target feature grids with one shared PCA and map to [0, 1].

    The PCA is fitted on the concatenation of all source and target patch
    features, so the same component (and therefore the same colour) means the
    same thing in both images. Each component is then min-max normalised
    using the joint min/max over both images.

    The two grids only need to share the channel dimension; their spatial
    sizes may differ, since the fit operates on the flattened patches.

    Args:
        Fs_hwC: Source features of shape (hs, ws, C).
        Ft_hwC: Target features of shape (ht, wt, C).
        n_components: Number of principal components (3 yields an RGB image).
        random_state: Seed forwarded to ``PCA``. scikit-learn can choose a
            randomized SVD solver depending on the input size, in which case
            an unseeded fit is not reproducible between runs; fixing the seed
            makes the saved figures repeatable.

    Returns:
        ``(rgb_s, rgb_t)`` numpy arrays of shape (hs, ws, n_components) and
        (ht, wt, n_components) with values in [0, 1].
    """
    hs, ws, C = Fs_hwC.shape
    ht, wt, Ct = Ft_hwC.shape
    assert Ct == C, "source/target channel mismatch"

    Xs = Fs_hwC.reshape(-1, C).detach().cpu().numpy()
    Xt = Ft_hwC.reshape(-1, C).detach().cpu().numpy()

    pca = PCA(n_components=n_components, random_state=random_state)
    # Only the fit is needed here; the projection is computed once per image below.
    pca.fit(np.concatenate([Xs, Xt], axis=0))

    Ys = pca.transform(Xs)  # (hs*ws, n_components)
    Yt = pca.transform(Xt)  # (ht*wt, n_components)

    # Joint per-component range so both images share one colour scale.
    Y_all = np.concatenate([Ys, Yt], axis=0)
    mn = Y_all.min(axis=0, keepdims=True)
    mx = Y_all.max(axis=0, keepdims=True)
    span = mx - mn + 1e-8  # epsilon guards against a constant component

    rgb_s = ((Ys - mn) / span).reshape(hs, ws, n_components)
    rgb_t = ((Yt - mn) / span).reshape(ht, wt, n_components)
    return rgb_s, rgb_t

In [52]:
# ============================================================
# PCA Figure
# ============================================================
def save_pca_pair_figure_full(cat: str, pair_stem: str,
                              src_img: np.ndarray, trg_img: np.ndarray,
                              rgb_s_inter, rgb_t_inter,
                              rgb_s_last,  rgb_t_last):
    """Save the joint-PCA overview figure of one image pair to ``PCA_DIR``.

    The figure uses a 2x4 layout: the two input images occupy the first two
    slots of the top row, and the bottom row shows the intermediate-layer and
    last-layer PCA maps for source and target.

    Args:
        cat: Category name (used in titles and in the file name).
        pair_stem: Pair identifier (used in the title and in the file name).
        src_img: Source image.
        trg_img: Target image.
        rgb_s_inter: Source PCA map for layer ``LAYER_INTER``.
        rgb_t_inter: Target PCA map for layer ``LAYER_INTER``.
        rgb_s_last: Source PCA map for the last layer.
        rgb_t_last: Target PCA map for the last layer.
    """
    fig = plt.figure(figsize=(16, 9))

    # (slot in the 2x4 grid, image to show, panel title)
    panels = (
        (1, src_img, f"SOURCE ({cat})"),
        (2, trg_img, "TARGET"),
        (5, rgb_s_inter, f"JOINT PCA L{LAYER_INTER} (S)"),
        (6, rgb_t_inter, f"JOINT PCA L{LAYER_INTER} (T)"),
        (7, rgb_s_last, "JOINT PCA Last (S)"),
        (8, rgb_t_last, "JOINT PCA Last (T)"),
    )
    for slot, image, title in panels:
        ax = fig.add_subplot(2, 4, slot)
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")

    fig.suptitle(f"DINOv3 — JOINT PCA — {cat} — {pair_stem}", y=0.98)
    out = PCA_DIR / f"PCA_JOINT_FULL_{cat}_ID{pair_stem}.png"
    fig.tight_layout()
    fig.savefig(out, dpi=200)
    # Close explicitly so figures do not pile up in memory across pairs.
    plt.close(fig)


In [53]:
# ============================================================
# Qualitative last-layer matching (Task1-style)
# ============================================================
def patch_center(ix, iy, patch=16):
    """Return the pixel coordinates of the centre of patch ``(ix, iy)``.

    Args:
        ix: Patch column index.
        iy: Patch row index.
        patch: Patch side length in pixels.

    Returns:
        ``(x, y)`` of the patch centre as floats.
    """
    half = patch / 2.0
    center_x = ix * patch + half
    center_y = iy * patch + half
    return (center_x, center_y)

@torch.no_grad()
def match_keypoints_lastlayer(src_img_chw, trg_img_chw, src_kps):
    """Transfer source keypoints to the target via last-layer nearest patches.

    For each usable source keypoint, the feature of the source patch that
    contains it is compared with every target patch feature (dot product of
    normalised features), and the centre of the best-scoring target patch is
    returned, clipped to the unpadded target image.

    Args:
        src_img_chw: Source image, CHW in the 0..255 range.
        trg_img_chw: Target image, CHW in the 0..255 range.
        src_kps: Iterable of keypoints given as ``[x, y]`` or ``[x, y, vis]``;
            entries may be ``None``.

    Returns:
        List with one ``(x, y)`` tuple per input keypoint, in target pixel
        coordinates. Entries for ``None`` keypoints or keypoints with
        ``vis == 0`` are ``(nan, nan)``.
    """
    # Pad both images to a multiple of the patch size and derive the patch grids.
    src_padded, _, (src_h_pad, src_w_pad) = pad_to_multiple(src_img_chw, PATCH)
    trg_padded, (trg_h, trg_w), (trg_h_pad, trg_w_pad) = pad_to_multiple(trg_img_chw, PATCH)
    rows_s, cols_s = grid_hw(src_h_pad, src_w_pad, PATCH)
    rows_t, cols_t = grid_hw(trg_h_pad, trg_w_pad, PATCH)

    src_feats = feat_last(src_padded, rows_s, cols_s).reshape(rows_s * cols_s, -1)  # (Ps, C)
    trg_feats = feat_last(trg_padded, rows_t, cols_t).reshape(rows_t * cols_t, -1)  # (Pt, C)

    missing = (np.nan, np.nan)
    predictions = []
    for kp in src_kps:
        # Skip absent keypoints and those explicitly flagged as not visible.
        unusable = kp is None or (len(kp) >= 3 and kp[2] == 0)
        if unusable:
            predictions.append(missing)
            continue

        # Source patch containing the keypoint, clamped to the grid.
        col = int(np.clip(kp[0] // PATCH, 0, cols_s - 1))
        row = int(np.clip(kp[1] // PATCH, 0, rows_s - 1))
        query = src_feats[row * cols_s + col]

        # Similarity against every target patch; features are already normalised.
        scores = trg_feats @ query  # (Pt,)
        best = int(torch.argmax(scores).item())
        best_row, best_col = divmod(best, cols_t)

        cx, cy = patch_center(best_col, best_row, PATCH)
        # Keep the prediction inside the original (unpadded) target image.
        cx = float(np.clip(cx, 0, trg_w - 1))
        cy = float(np.clip(cy, 0, trg_h - 1))
        predictions.append((cx, cy))

    return predictions

def save_qual_figure(cat, pair_stem, src_np, trg_np, src_kps, trg_kps, pred_xy):
    """Save the three-panel qualitative figure of one pair to ``QUAL_DIR``.

    Panels, left to right: source image with its keypoints (blue), target
    image with the predicted points (red crosses), target image with the
    ground-truth keypoints (lime).

    Args:
        cat: Category name (used in the title and in the file name).
        pair_stem: Pair identifier (used in the file name).
        src_np: Source image.
        trg_np: Target image.
        src_kps: Source keypoints, ``[x, y]`` or ``[x, y, vis]``; may contain ``None``.
        trg_kps: Target keypoints, same format as ``src_kps``.
        pred_xy: Predicted ``(x, y)`` points; entries containing NaN are not drawn.
    """
    def _drawable(kps):
        # Yield keypoints that are present and not flagged invisible (vis == 0).
        for kp in kps:
            if kp is None or (len(kp) >= 3 and kp[2] == 0):
                continue
            yield kp

    def _image_panel(slot, image, title):
        # Create one of the three side-by-side panels showing `image`.
        ax = fig.add_subplot(1, 3, slot)
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")
        return ax

    fig = plt.figure(figsize=(16, 5))

    # Source keypoints.
    ax_src = _image_panel(1, src_np, f"SOURCE ({cat})")
    for kp in _drawable(src_kps):
        ax_src.scatter(kp[0], kp[1], s=120, c="blue")

    # Predictions on the target.
    ax_pred = _image_panel(2, trg_np, "PREDICTION (Last Layer)")
    for p in pred_xy:
        if not (np.isnan(p[0]) or np.isnan(p[1])):
            ax_pred.scatter(p[0], p[1], s=120, c="red", marker="x")

    # Ground truth on the target.
    ax_gt = _image_panel(3, trg_np, "GROUND TRUTH")
    for kp in _drawable(trg_kps):
        ax_gt.scatter(kp[0], kp[1], s=120, c="lime")

    out = QUAL_DIR / f"Qual_{cat}_ID{pair_stem}.png"
    fig.tight_layout()
    fig.savefig(out, dpi=200)
    # Close explicitly so figures do not pile up in memory across pairs.
    plt.close(fig)

In [54]:
# ============================================================
# Main: per categoria prendi le prime N coppie e salva PCA + QUAL
# ============================================================
@torch.no_grad()
def process_category(cat: str):
    """Export the joint-PCA and qualitative figures for one category.

    The first ``PAIRS_PER_CATEGORY`` annotation files (among at most
    ``MAX_SCAN_PER_CAT``, in sorted order) are processed. For each pair, the
    intermediate-layer and last-layer features are cropped to the grid common
    to both images, projected with a joint PCA and saved; then the last-layer
    keypoint matches are computed and saved.

    Args:
        cat: Category name.

    Raises:
        KeyError: If an annotation lacks ``src_kps`` or ``trg_kps``.
    """
    jsons = list_pair_jsons_for_category(cat)
    if not jsons:
        print(f"[WARN] Nessun json per {cat}")
        return

    picked = jsons[:MAX_SCAN_PER_CAT][:PAIRS_PER_CATEGORY]
    print(f"{cat}: selected {len(picked)} pair(s)")

    for jp in picked:
        with open(jp, "r") as f:
            ann = json.load(f)

        src_name, trg_name, src_kps, trg_kps, _, _ = extract_pair_fields(ann)
        if src_kps is None or trg_kps is None:
            raise KeyError("Nel json mancano src_kps / trg_kps")

        # The annotation file stem identifies the pair in the output file names.
        pair_stem = jp.stem

        # Load both images and convert them to CHW float tensors.
        src_np, trg_np = load_pair_images(cat, src_name, trg_name)
        src_chw = to_chw_float_0_255(src_np)
        trg_chw = to_chw_float_0_255(trg_np)

        # -------- PCA --------
        src_pad, _, (src_h_pad, src_w_pad) = pad_to_multiple(src_chw, PATCH)
        trg_pad, _, (trg_h_pad, trg_w_pad) = pad_to_multiple(trg_chw, PATCH)
        rows_s, cols_s = grid_hw(src_h_pad, src_w_pad, PATCH)
        rows_t, cols_t = grid_hw(trg_h_pad, trg_w_pad, PATCH)

        # Both grids are cropped to their common top-left extent before the PCA.
        rows = min(rows_s, rows_t)
        cols = min(cols_s, cols_t)

        def cropped_joint_pca(src_grid, trg_grid):
            return joint_pca_rgb(src_grid[:rows, :cols, :], trg_grid[:rows, :cols, :])

        src_inter = feat_inter(src_pad, LAYER_INTER, rows_s, cols_s)
        trg_inter = feat_inter(trg_pad, LAYER_INTER, rows_t, cols_t)
        rgb_s_inter, rgb_t_inter = cropped_joint_pca(src_inter, trg_inter)

        src_last = feat_last(src_pad, rows_s, cols_s)
        trg_last = feat_last(trg_pad, rows_t, cols_t)
        rgb_s_last, rgb_t_last = cropped_joint_pca(src_last, trg_last)

        save_pca_pair_figure_full(
            cat, pair_stem,
            src_np, trg_np,
            rgb_s_inter, rgb_t_inter,
            rgb_s_last, rgb_t_last,
        )

        # -------- QUAL (Last layer) --------
        pred_xy = match_keypoints_lastlayer(src_chw, trg_chw, src_kps)
        save_qual_figure(cat, pair_stem, src_np, trg_np, src_kps, trg_kps, pred_xy)

        print("  saved PCA :", f"PCA_JOINT_FULL_{cat}_ID{pair_stem}.png")
        print("  saved QUAL:", f"Qual_{cat}_ID{pair_stem}.png")

# Run the PCA + qualitative export for every configured category.
for cat in CATEGORIES:
    process_category(cat)

# Final summary: report the folders the figures were written to.
print("DONE.")
print("PCA_DIR :", PCA_DIR)
print("QUAL_DIR:", QUAL_DIR)

aeroplane: selected 1 pair(s)
  saved PCA : PCA_JOINT_FULL_aeroplane_ID000001-2008_002719-2008_004100:aeroplane.png
  saved QUAL: Qual_aeroplane_ID000001-2008_002719-2008_004100:aeroplane.png
chair: selected 1 pair(s)
  saved PCA : PCA_JOINT_FULL_chair_ID005637-2008_000089-2008_001467:chair.png
  saved QUAL: Qual_chair_ID005637-2008_000089-2008_001467:chair.png
DONE.
PCA_DIR : /content/drive/MyDrive/AML_Project_Results/DINOv3_Task1/final_analysis_results_dinov3_jointPCA
QUAL_DIR: /content/drive/MyDrive/AML_Project_Results/DINOv3_Task1/qualitative_results_last_layer_dinov3
