In [2]:
import os
import torch
import torchvision
from torchvision import datasets
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt
import requests
from zipfile import ZipFile
from io import BytesIO
import numpy as np
import zipfile
import os
from pathlib import Path

# Monta Google Drive
from google.colab import drive
if not os.path.exists('/content/drive'):
    drive.mount("/content/drive")

# Configurazione Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Crea la cartella di destinazione
!mkdir -p /content/spair_local

# 2. Estrai il contenuto (cambia il percorso con quello corretto del tuo Drive)
!tar -xf "/content/drive/MyDrive/AMLDataset/SPair-71k.tar" -C /content/spair_local

print("Scompattamento completato!")

# Percorsi Globali
SPAIR_ROOT = Path("/content/spair_local/SPair-71k")
PAIR_ANN_PATH = SPAIR_ROOT / "PairAnnotation"
IMAGE_PATH    = SPAIR_ROOT / "JPEGImages"
LAYOUT_PATH   = SPAIR_ROOT / "Layout"

# Verifica rapida
assert SPAIR_ROOT.exists(), "ERRORE: Dataset locale non trovato in /content/spair_local"

Scompattamento completato!


In [3]:
# ----------------------------
# Load DINOv3
# ----------------------------
%cd /content
!test -d dinov3 || git clone https://github.com/facebookresearch/dinov3.git
%cd /content/dinov3
!pip -q install einops timm opencv-python torchmetrics fvcore iopath

DINOV3_DIR =Path("/content/dinov3")

DINOV3_WEIGHTS = Path("/content/drive/MyDrive/AMLDataset/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth")

assert os.path.exists(DINOV3_WEIGHTS), f"Pesi DINOv3 non trovati: {DINOV3_WEIGHTS}"

DEFAULT_MODEL_NAME = "dinov3_vitb16"

def load_dinov3_backbone(
    *,
    dinov3_dir: str | Path,
    weights_path: str | Path,
    device: torch.device | str = "cpu",
    sanity_input_size: int = 512,
    verbose: bool = True,
) -> torch.nn.Module:
    """
    Load DINOv3 ViT-B/16 backbone for Task 1 (training-free).

    Requirements:
      - Official DINOv3 repository cloned locally
      - Pretrained checkpoint downloaded separately (licensed)
    """

    # ------------------------------------------------------------------
    # 2) Load model from local repo via torch.hub (same behavior as notebook)
    # ------------------------------------------------------------------
    model = torch.hub.load(
        str(dinov3_dir),
        DEFAULT_MODEL_NAME,
        source="local",
        weights=str(weights_path),
    )

    model.to(device).eval()

    # ------------------------------------------------------------------
    # 3) Freeze backbone (Task 1 compliant)
    # ------------------------------------------------------------------
    for p in model.parameters():
        p.requires_grad_(False)

    # ------------------------------------------------------------------
    # 4) Sanity checks: architecture + patch size
    # ------------------------------------------------------------------
    if not hasattr(model, "blocks"):
        raise RuntimeError("[DINOv3] Loaded model has no attribute 'blocks' (API mismatch?)")

    if not hasattr(model, "patch_embed") or not hasattr(model.patch_embed, "patch_size"):
        raise RuntimeError("[DINOv3] patch_embed.patch_size not found (API mismatch?)")

    n_blocks = len(model.blocks)
    patch = model.patch_embed.patch_size
    patch_int = patch[0] if isinstance(patch, (tuple, list)) else int(patch)

    if patch_int != 16:
        raise RuntimeError(f"[DINOv3] Expected patch size 16 for ViT-B/16, got {patch}")

    # ------------------------------------------------------------------
    # 5) Token-grid sanity check (important for correspondence)
    # ------------------------------------------------------------------
    if sanity_input_size is not None:
        x = torch.randn(1, 3, sanity_input_size, sanity_input_size, device=device)
        with torch.no_grad():
            feats = model.get_intermediate_layers(x, n=1)[0]  # [B, N, C]

        expected_n = (sanity_input_size // patch_int) ** 2
        if feats.shape[1] != expected_n:
            raise RuntimeError(
                f"[DINOv3] Unexpected token count: {feats.shape[1]} vs expected {expected_n}. "
                "Check input size / patch size / token handling."
            )

    if verbose:
        print(
            f"[DINOv3] loaded ViT-B/16 | blocks={n_blocks} | patch={patch_int} | "
            f"checkpoint={weights_path.name}"
        )

    return model

# Build the frozen backbone on the selected device.
model = load_dinov3_backbone(
    dinov3_dir=DINOV3_DIR,
    weights_path=DINOV3_WEIGHTS,
    device=device,
    sanity_input_size=512,
    verbose=True,
)
model = model.eval().to(device)

# Smoke test: run an all-zero 512x512 image through the model and report the
# shape of the normalised patch tokens.
with torch.no_grad():
    x = model.forward_features(torch.zeros((1, 3, 512, 512), device=device))
    print("x_norm_patchtokens:", x["x_norm_patchtokens"].shape)

/content
Cloning into 'dinov3'...
remote: Enumerating objects: 538, done.[K
remote: Counting objects: 100% (363/363), done.[K
remote: Compressing objects: 100% (264/264), done.[K
remote: Total 538 (delta 201), reused 99 (delta 99), pack-reused 175 (from 1)[K
Receiving objects: 100% (538/538), 9.88 MiB | 16.64 MiB/s, done.
Resolving deltas: 100% (223/223), done.
/content/dinov3
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m36.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for fvcore (setup.py) ... [?25l[?25hdone
  Building wheel for iopath (setup.py) ... [?25l[?25hdone
Downloading

100%|██████████| 327M/327M [00:06<00:00, 52.4MB/s]


[DINOv3] loaded ViT-B/16 | blocks=12 | patch=16 | checkpoint=dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth
x_norm_patchtokens: torch.Size([1, 1024, 768])


In [4]:
from PIL import Image
import glob
import json
import numpy as np
import torch

class Normalize(object):
    """
    Sample transform: rescale selected image tensors from [0, 255] to [0, 1]
    and then apply per-channel mean/std normalisation.
    """

    def __init__(self, image_keys):
        """
        Args:
            image_keys: keys of the sample dict whose values are image tensors
                to normalise.
        """
        self.image_keys = image_keys
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __call__(self, image):
        """
        Normalise ``image[key]`` for every configured key.

        The sample dict is updated and returned; the tensors originally stored
        in it are left untouched.
        """
        for key in self.image_keys:
            # Out-of-place division: does not modify the caller's tensor and
            # also works for integer dtypes, where in-place `/=` raises.
            scaled = image[key] / 255.0
            image[key] = self.normalize(scaled)
        return image


def read_img(path):
    """Read the image at `path` as RGB and return it as a float32 tensor in (C, H, W) order."""
    rgb_image = Image.open(path).convert('RGB')
    height_width_channels = np.array(rgb_image)
    # Move the channel axis first, then cast to float32 before building the tensor.
    channels_first = height_width_channels.transpose(2, 0, 1).astype(np.float32)
    return torch.tensor(channels_first)


class SPairDataset(Dataset):
    """
    Dataset of annotated (source, target) image pairs.

    Each item bundles both images (as normalised float tensors), the keypoints
    of both images, metadata copied from the pair's annotation JSON, and the
    PCK threshold computed from the target bounding box.
    """

    def __init__(self, pair_ann_path, layout_path, image_path, dataset_size, pck_alpha, datatype):
        """
        Args:
            pair_ann_path: directory holding one sub-directory of pair JSON files per split.
            layout_path: directory holding `<dataset_size>/<datatype>.txt` split listings.
            image_path: directory holding one sub-directory of images per category.
            dataset_size: name of the layout sub-directory (e.g. 'large').
            pck_alpha: factor applied to the larger target-bbox side to obtain the PCK threshold.
            datatype: split name; selects both the listing file and the annotation sub-directory.
        """
        self.datatype = datatype
        self.pck_alpha = pck_alpha

        # One pair identifier per line. Blank lines are skipped and surrounding
        # whitespace removed, so the listing is read correctly with or without
        # a trailing newline and with CRLF line endings.
        split_file = os.path.join(layout_path, dataset_size, datatype + '.txt')
        with open(split_file, "r") as listing:
            self.ann_files = [line.strip() for line in listing if line.strip()]

        self.pair_ann_path = pair_ann_path
        self.image_path = image_path
        # Category names are the entries found directly under image_path.
        self.categories = sorted(os.path.basename(entry) for entry in glob.glob('%s/*' % image_path))
        self.transform = Normalize(['src_img', 'trg_img'])

    def __len__(self):
        """Number of pairs listed for this split."""
        return len(self.ann_files)

    def __getitem__(self, idx):
        """
        Load the pair at position `idx`.

        Returns:
            dict with the pair metadata, 'src_img' / 'trg_img' tensors,
            'src_kps' / 'trg_kps' float tensors and 'pck_threshold'.
        """
        raw_line = self.ann_files[idx]
        ann_file = raw_line + '.json'
        json_path = os.path.join(self.pair_ann_path, self.datatype, ann_file)

        with open(json_path) as f:
            annotation = json.load(f)

        category = annotation['category']
        src_img = read_img(os.path.join(self.image_path, category, annotation['src_imname']))
        trg_img = read_img(os.path.join(self.image_path, category, annotation['trg_imname']))

        # Threshold = pck_alpha * larger side of the target bounding box.
        # NOTE(review): assumes the bbox is stored as [x_min, y_min, x_max, y_max] - confirm.
        trg_bbox = annotation['trg_bndbox']
        pck_threshold = max(trg_bbox[2] - trg_bbox[0],  trg_bbox[3] - trg_bbox[1]) * self.pck_alpha

        sample = {'pair_id': annotation['pair_id'],
                  'filename': annotation['filename'],
                  'src_imname': annotation['src_imname'],
                  'trg_imname': annotation['trg_imname'],
                  'src_imsize': src_img.size(),
                  'trg_imsize': trg_img.size(),

                  'src_bbox': annotation['src_bndbox'],
                  'trg_bbox': annotation['trg_bndbox'],
                  'category': annotation['category'],

                  'src_pose': annotation['src_pose'],
                  'trg_pose': annotation['trg_pose'],

                  'src_img': src_img,
                  'trg_img': trg_img,
                  'src_kps': torch.tensor(annotation['src_kps']).float(),
                  'trg_kps': torch.tensor(annotation['trg_kps']).float(),

                  'mirror': annotation['mirror'],
                  'vp_var': annotation['viewpoint_variation'],
                  'sc_var': annotation['scale_variation'],
                  'truncn': annotation['truncation'],
                  'occlsn': annotation['occlusion'],

                  'pck_threshold': pck_threshold}

        if self.transform:
            sample = self.transform(sample)

        return sample


if __name__ == '__main__':
    SPAIR_ROOT = Path("/content/spair_local/SPair-71k")
    pair_ann_path = os.path.join(SPAIR_ROOT, 'PairAnnotation')
    layout_path = os.path.join(SPAIR_ROOT, 'Layout')
    image_path = os.path.join(SPAIR_ROOT, 'JPEGImages')
    dataset_size = 'large'
    pck_alpha = 0.1

    # Build the datasets only when every expected directory is present.
    if os.path.exists(pair_ann_path) and os.path.exists(layout_path) and os.path.exists(image_path):
        trn_dataset = SPairDataset(pair_ann_path, layout_path, image_path, dataset_size, pck_alpha, datatype='trn')
        val_dataset = SPairDataset(pair_ann_path, layout_path, image_path, dataset_size, pck_alpha, datatype='val')
        test_dataset = SPairDataset(pair_ann_path, layout_path, image_path, dataset_size, pck_alpha, datatype='test')

        # Default DataLoader batch size (1): images of different sizes are never stacked.
        trn_dataloader = DataLoader(trn_dataset, num_workers=0)
        val_dataloader = DataLoader(val_dataset, num_workers=0)
        test_dataloader = DataLoader(test_dataset, num_workers=0)
        print("Dataset caricati correttamente.")
    else:
        # Report the dataset root that was actually checked.
        print(f"Errore: Impossibile trovare i percorsi del dataset in '{SPAIR_ROOT}'.\nVerifica l'estrazione e controlla se la struttura delle cartelle corrisponde.")

Dataset caricati correttamente.
Dataset caricati correttamente.


In [5]:
import torch
import math
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import os

# --- CONFIGURAZIONE ---
# --- CONFIGURATION ---
output_dir = "/content/drive/MyDrive/AMLDataset/qualitative_results_grouped"
os.makedirs(output_dir, exist_ok=True)


TARGET_CLASSES = ['aeroplane', 'chair']
POINTS_NEEDED = 3  # Number of keypoints drawn together on the same image

# State: per-category flag, set once that category's figure has been saved
category_done = dict.fromkeys(TARGET_CLASSES, False)

# Helper functions
def denormalize_image(tensor):
    """
    Undo the per-channel mean/std normalisation of an image tensor.

    The tensor is moved to the CPU, its leading dimension is squeezed, and the
    channel axis is moved last. Returns a NumPy array clipped to [0, 1].
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    channels_last = tensor.cpu().squeeze(0).permute(1, 2, 0).numpy()
    restored = channels_last * channel_std + channel_mean
    return restored.clip(0, 1)

def pad_to_multiple(x, k=16):
    """
    Zero-pad the last two dimensions of `x` up to the next multiple of `k`.

    Padding is added on the bottom and right only. When both dimensions are
    already multiples of `k`, `x` itself is returned.
    """
    height, width = x.shape[-2:]
    # Rows / columns missing to reach the next multiple of k (0 when aligned).
    extra_rows = (-height) % k
    extra_cols = (-width) % k
    if extra_rows or extra_cols:
        return F.pad(x, (0, extra_cols, 0, extra_rows), value=0)
    return x

print(f"Searching for 1 valid image per class with at least {POINTS_NEEDED} keypoints...")

patch_size = 16  # side of one ViT-B/16 patch, in pixels

with torch.no_grad():
    for i, data in enumerate(tqdm(test_dataloader, desc="Scanning")):

        category = data['category'][0]

        # Skip pairs whose category was not requested or is already done.
        if category not in TARGET_CLASSES: continue
        if category_done[category]: continue

        # --- PROCESS THE PAIR ---
        src_img = data['src_img'].to(device)
        trg_img = data['trg_img'].to(device)
        src_img_padded = pad_to_multiple(src_img, patch_size)
        trg_img_padded = pad_to_multiple(trg_img, patch_size)

        # Forward pass: one feature vector per patch.
        feats_src = model.forward_features(src_img_padded)["x_norm_patchtokens"]
        feats_trg = model.forward_features(trg_img_padded)["x_norm_patchtokens"]

        # Grid info. Source and target can have different sizes, so each image
        # gets its own patch grid; a flat target index must be decoded with the
        # TARGET grid width.
        _, _, H_orig, W_orig = data['src_img'].shape
        h_grid_src = src_img_padded.shape[-2] // patch_size
        w_grid_src = src_img_padded.shape[-1] // patch_size
        w_grid_trg = trg_img_padded.shape[-1] // patch_size

        kps_list_src = data['src_kps'][0]
        trg_kps_gt = data['trg_kps'][0]

        # --- LISTS ACCUMULATING THE SELECTED POINTS ---
        valid_src_points = []   # (x, y) in the source image
        valid_pred_points = []  # (x, y) predicted in the target image
        valid_gt_points = []    # (x, y) ground truth in the target image

        # Scan the keypoints of this pair.
        for n_keypoint, keypoint_src in enumerate(kps_list_src):

            # Stop as soon as enough points have been collected.
            if len(valid_src_points) >= POINTS_NEEDED:
                break

            # 1. Source point must be defined and inside the (unpadded) image.
            x_src = keypoint_src[0].item()
            y_src = keypoint_src[1].item()
            if math.isnan(x_src) or math.isnan(y_src): continue

            x_src, y_src = int(x_src), int(y_src)
            if not (0 <= x_src < W_orig and 0 <= y_src < H_orig): continue

            # 2. Ground-truth point must be defined. Checked before matching so
            #    no similarity map is computed for a point that gets discarded.
            x_gt = trg_kps_gt[n_keypoint, 0].item()
            y_gt = trg_kps_gt[n_keypoint, 1].item()
            if math.isnan(x_gt) or math.isnan(y_gt): continue

            # 3. Prediction: nearest target patch by cosine similarity.
            x_patch = min(x_src // patch_size, w_grid_src - 1)
            y_patch = min(y_src // patch_size, h_grid_src - 1)
            patch_idx = min(y_patch * w_grid_src + x_patch, feats_src.shape[1] - 1)

            source_vec = feats_src[0, patch_idx, :]
            sim_map = torch.cosine_similarity(source_vec, feats_trg[0], dim=-1)
            best_idx = torch.argmax(sim_map).item()

            # Centre of the best-matching target patch, in target pixel coordinates.
            x_pred = (best_idx % w_grid_trg) * patch_size + (patch_size // 2)
            y_pred = (best_idx // w_grid_trg) * patch_size + (patch_size // 2)

            # --- VALID POINT: STORE IT ---
            valid_src_points.append((x_src, y_src))
            valid_pred_points.append((x_pred, y_pred))
            valid_gt_points.append((x_gt, y_gt))

        # --- ENOUGH POINTS FOUND: DRAW ONE COMBINED FIGURE ---
        if len(valid_src_points) == POINTS_NEEDED:

            img_s_vis = denormalize_image(src_img)
            img_t_vis = denormalize_image(trg_img)

            fig, ax = plt.subplots(1, 3, figsize=(18, 6))

            # PANEL 1: source image, blue crosses
            ax[0].imshow(img_s_vis)
            ax[0].set_title(f"SOURCE ({category})\n(Blue Crosses)")
            for idx, (x, y) in enumerate(valid_src_points):
                ax[0].scatter(x, y, c='blue', s=150, marker='+', linewidth=3)
                # Small index label so points can be matched across panels.
                ax[0].text(x+5, y+5, str(idx+1), color='white', fontsize=12, fontweight='bold')

            # PANEL 2: target image with predictions, red X marks
            ax[1].imshow(img_t_vis)
            ax[1].set_title("PREDICTION (DINOv3)\n(Red X)")
            for idx, (x, y) in enumerate(valid_pred_points):
                ax[1].scatter(x, y, c='red', s=150, marker='x', linewidth=3)
                ax[1].text(x+5, y+5, str(idx+1), color='white', fontsize=12, fontweight='bold')

            # PANEL 3: target image with ground truth, hollow green circles.
            # `c` would override `facecolors`, so the colour goes on the edge only.
            ax[2].imshow(img_t_vis)
            ax[2].set_title("GROUND TRUTH\n(Green Circles)")
            for idx, (x, y) in enumerate(valid_gt_points):
                ax[2].scatter(x, y, s=150, marker='o', facecolors='none', edgecolors='green', linewidth=3)
                ax[2].text(x+5, y+5, str(idx+1), color='white', fontsize=12, fontweight='bold')

            # Clean up the axes.
            for a in ax: a.axis('off')

            # Save to disk and release the figure.
            save_path = os.path.join(output_dir, f"Comparison_{category}_ID{i}.png")
            plt.tight_layout()
            plt.savefig(save_path)
            plt.close(fig)

            print(f"--> [SUCCESS] Saved comparison for {category} (Image ID: {i})")

            # Mark the category as completed.
            category_done[category] = True

        # Stop scanning once every requested class has its figure.
        if all(category_done.values()):
            print("\nGenerated all requested images. Exiting.")
            break
    else:
        # Loop exhausted without the early exit: report what is still missing.
        missing = [cat for cat, done in category_done.items() if not done]
        print(f"\nNo valid image found for: {missing}")

Searching for 1 valid image per class with at least 3 keypoints...


Scanning:   0%|          | 5/12234 [00:02<1:32:16,  2.21it/s]

--> [SUCCESS] Saved comparison for aeroplane (Image ID: 0)


Scanning:  44%|████▍     | 5422/12234 [01:12<01:31, 74.80it/s]

--> [SUCCESS] Saved comparison for chair (Image ID: 5422)

Generated all requested images. Exiting.



