In [30]:
from typing import TypedDict, Tuple

import torch


class SPairSample(TypedDict):
    pair_id: int
    filename: str
    src_imname: str
    trg_imname: str
    src_imsize: Tuple[int, int]
    trg_imsize: Tuple[int, int]
    src_bbox: Tuple[int, int, int, int]
    trg_bbox: Tuple[int, int, int, int]
    category: str
    src_pose: str
    trg_pose: str
    src_img: torch.Tensor
    trg_img: torch.Tensor
    src_kps: torch.Tensor
    trg_kps: torch.Tensor
    mirror: int
    vp_var: int
    sc_var: int
    truncn: int
    occlsn: int
    pck_threshold_0_05: float
    pck_threshold_0_1: float
    pck_threshold_0_2: float

In [31]:
import numpy as np
from typing import List
import torch
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
import glob
import json

using_colab = False
base_dir = os.path.abspath(os.path.curdir)

if using_colab:
    !wget -P ./AML-polito/dataset/ "https://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz"
    !tar -xvzf ./AML-polito/dataset/SPair-71k.tar.gz -C ./AML-polito/dataset/
    base_dir = os.path.join(os.path.abspath(os.path.curdir), 'AML-polito')


def read_img(path: str) -> torch.Tensor:
    img = np.array(Image.open(path).convert('RGB'))
    return torch.tensor(img.transpose(2, 0, 1).astype(np.float32))


def collate_single(batch_list: List[SPairSample]) -> SPairSample:
    # batch_size deve essere 1
    return batch_list[0]


class SPairDataset(Dataset):
    def __init__(self, pair_ann_path, layout_path, image_path, dataset_size, datatype):
        self.datatype = datatype
        self.ann_files = open(os.path.join(layout_path, dataset_size, datatype + '.txt'), "r").read().split('\n')
        self.ann_files = self.ann_files[:len(self.ann_files) - 1]
        self.pair_ann_path = pair_ann_path
        self.image_path = image_path
        self.categories = list(map(lambda x: os.path.basename(x), glob.glob('%s/*' % image_path)))
        self.categories.sort()

    def __len__(self):
        return len(self.ann_files)

    def __getitem__(self, idx) -> SPairSample:
        ann_filename = self.ann_files[idx]
        ann_file = ann_filename + '.json'
        json_path = os.path.join(self.pair_ann_path, self.datatype, ann_file)

        with open(json_path) as f:
            annotation = json.load(f)

        category = annotation['category']
        src_img = read_img(str(os.path.join(self.image_path, category, annotation['src_imname'])))
        trg_img = read_img(str(os.path.join(self.image_path, category, annotation['trg_imname'])))

        sx1, sy1, sx2, sy2 = annotation["src_bndbox"]
        tx1, ty1, tx2, ty2 = annotation["trg_bndbox"]

        sample: SPairSample = {'pair_id': int(annotation['pair_id']),
                               'filename': str(annotation['filename']),
                               'src_imname': str(annotation['src_imname']),
                               'trg_imname': str(annotation['trg_imname']),
                               'src_imsize': (int(src_img.shape[1]), int(src_img.shape[2])),  # height, width
                               'trg_imsize': (int(trg_img.shape[1]), int(trg_img.shape[2])),  # height, width

                               'src_bbox': (int(sx1), int(sy1), int(sx2), int(sy2)),
                               'trg_bbox': (int(tx1), int(ty1), int(tx2), int(ty2)),
                               'category': str(annotation['category']),

                               'src_pose': str(annotation['src_pose']),
                               'trg_pose': str(annotation['trg_pose']),

                               'src_img': src_img,
                               'trg_img': trg_img,
                               'src_kps': torch.tensor(annotation['src_kps']).float(),
                               'trg_kps': torch.tensor(annotation['trg_kps']).float(),

                               'mirror': int(annotation['mirror']),
                               'vp_var': int(annotation['viewpoint_variation']),
                               'sc_var': int(annotation['scale_variation']),
                               'truncn': int(annotation['truncation']),
                               'occlsn': int(annotation['occlusion']),

                               'pck_threshold_0_05': float(max(tx2 - tx1, ty2 - ty1) * 0.05),
                               'pck_threshold_0_1': float(max(tx2 - tx1, ty2 - ty1) * 0.1),
                               'pck_threshold_0_2': float(max(tx2 - tx1, ty2 - ty1) * 0.2)
                               }

        return sample


dataset_dir = os.path.join(base_dir, 'dataset', "SPair-71k")
pair_ann_path = os.path.join(dataset_dir, 'PairAnnotation')
layout_path = os.path.join(dataset_dir, 'Layout')
image_path = os.path.join(dataset_dir, 'JPEGImages')
dataset_size = 'large'
#pck_alpha = 0.05

# Verifica che i percorsi esistano prima di creare il dataset
if os.path.exists(pair_ann_path) and os.path.exists(layout_path) and os.path.exists(image_path):
    trn_dataset = SPairDataset(pair_ann_path, layout_path, image_path, dataset_size, datatype='trn')
    val_dataset = SPairDataset(pair_ann_path, layout_path, image_path, dataset_size, datatype='val')
    test_dataset = SPairDataset(pair_ann_path, layout_path, image_path, dataset_size, datatype='test')

    trn_dataloader = DataLoader(trn_dataset, num_workers=0)
    val_dataloader = DataLoader(val_dataset, num_workers=0)
    test_dataloader = DataLoader(test_dataset, num_workers=0, batch_size=1, collate_fn=collate_single)
    print("Dataset caricati correttamente.")
else:
    raise RuntimeError(
        f"Errore: Impossibile trovare i percorsi del dataset in '{base_dir}'.\nVerifica l'estrazione e controlla se la struttura delle cartelle corrisponde.")

Dataset caricati correttamente.


In [32]:
model_size = {
    "vit_b": ("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", "sam_vit_b_01ec64.pth"),
    "vit_l": ("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", "sam_vit_l_0b3195.pth"),
    "vit_h": ("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "sam_vit_h_4b8939.pth")
}

selected_model = "vit_b"

if using_colab:
    !pip install git+https://github.com/facebookresearch/segment-anything.git -q
    !wget -P ./AML-polito/models/ {model_size[selected_model][0]}

import torch
from segment_anything import SamPredictor, sam_model_registry

# inizializzazione SAM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# === 1) Carico SAM ===
sam_checkpoint_path = os.path.join(base_dir, 'models', model_size[selected_model][1])
sam = sam_model_registry["vit_b"](checkpoint=sam_checkpoint_path)
sam.to(device)
sam.eval()
predictor = SamPredictor(sam)
transform = predictor.transform

# parametri utili
IMG_SIZE = predictor.model.image_encoder.img_size  # 1024
PATCH = IMG_SIZE // 64  # 16

Device: cuda


In [33]:
def get_scale_factor(img_size: tuple) -> float:
    img_h, img_w = img_size
    resized_h, resized_w = predictor.transform.get_preprocess_shape(img_h, img_w, IMG_SIZE)

    # lato lungo originale → lato lungo resized
    if img_w >= img_h:
        return resized_w / img_w
    else:
        return resized_h / img_h


def kp_src_to_featmap(kp_src_coordinates: torch.Tensor, img_src_size: tuple):
    img_src_h, img_src_w = img_src_size

    # Compute coordinates in resized image (without padding)
    x_prepad, y_prepad = predictor.transform.apply_coords_torch(kp_src_coordinates, img_src_size)[0]

    # dimensioni resized reali
    img_resized_h_prepad, img_resized_w_prepad = predictor.transform.get_preprocess_shape(img_src_h, img_src_w,
                                                                                          IMG_SIZE)

    # resized -> feature map
    xf = int(x_prepad // PATCH)
    yf = int(y_prepad // PATCH)

    # regione valida (no padding)
    wv = math.ceil(img_resized_w_prepad / PATCH)
    hv = math.ceil(img_resized_h_prepad / PATCH)

    xf = min(max(xf, 0), wv - 1)
    yf = min(max(yf, 0), hv - 1)

    return xf, yf


def kp_featmap_to_trg(featmap_coords: tuple, trg_img_size: tuple):
    xf, yf = featmap_coords
    xr = (xf + 0.5) * PATCH
    yr = (yf + 0.5) * PATCH

    scale = get_scale_factor(trg_img_size)
    xo = xr / scale
    yo = yr / scale

    return xo, yo

In [34]:
from tqdm import tqdm
from itertools import islice
import math
import torch

results = []
qualitative_result = {
    "src_img": None,
    "trg_img": None,
    "src_kps": [],
    "trg_gt_kps": [],
    "trg_pred_kps": []
}

max_images = 1
data = islice(test_dataloader, max_images)
size = max_images
if using_colab:
    data = test_dataloader
    size = len(test_dataloader)

with torch.no_grad():
    for iter, batch in enumerate(tqdm(data, total=size, desc="Elaborazione con SAM")):

        category = batch["category"]
        src_img = batch["src_img"].to(device).unsqueeze(0)  # [1,3,Hs,Ws]
        trg_img = batch["trg_img"].to(device).unsqueeze(0)  # [1,3,Ht,Wt]
        orig_size_src = batch["src_imsize"]  # (Hs, Ws)
        orig_size_trg = batch["trg_imsize"]  # (Ht, Wt)

        src_resized = predictor.transform.apply_image_torch(src_img)  # [1,3,Hs',Ws']
        trg_resized = predictor.transform.apply_image_torch(trg_img)  # [1,3,Ht',Wt']

        # --- 1) Embedding SAM sorgente ---
        predictor.set_torch_image(src_resized, orig_size_src)
        src_emb = predictor.get_image_embedding()[0]  # [C,64,64]

        # --- 2) Embedding SAM target ---
        predictor.set_torch_image(trg_resized, orig_size_trg)
        trg_emb = predictor.get_image_embedding()[0]  # [C,64,64]
        C_ft = trg_emb.shape[0]

        # --- 3) Keypoints & metadata ---
        src_kps = batch["src_kps"].to(device)  # [N,2]
        trg_kps = batch["trg_kps"].to(device)  # [N,2]
        pck_thr_0_05 = batch["pck_threshold_0_05"]
        pck_thr_0_1 = batch["pck_threshold_0_1"]
        pck_thr_0_2 = batch["pck_threshold_0_2"]

        pair_id = batch["pair_id"]
        filename = batch["filename"]

        # --- regione valida target (per evitare padding) ---
        # resized (prima padding)
        H_prime, W_prime = trg_resized.shape[-2:]

        # regione valida feature map
        hv_t = math.ceil(H_prime / PATCH)
        wv_t = math.ceil(W_prime / PATCH)

        # Get valid part of trg embedding
        trg_valid = trg_emb[:, :hv_t, :wv_t]  # [C,hv,wv]
        # Flatten valid part
        trg_flat = trg_valid.permute(1, 2, 0).reshape(-1, C_ft)  # [Pvalid,C]

        # --- 4) Loop sui keypoint sorgente ---
        N_kps = src_kps.shape[0]
        distances_this_image = []

        if iter == 0:
            qualitative_result["src_img"] = batch["src_img"]
            qualitative_result["trg_img"] = batch["trg_img"]

        for i in range(N_kps):
            src_keypoint = src_kps[i].unsqueeze(0)  # [x, y]
            trg_keypoint = trg_kps[i]  # [x, y]

            # salta keypoint mancanti
            if torch.isnan(src_keypoint).any() or torch.isnan(trg_keypoint).any():
                continue

            # 4.1: originale src -> feature src (coerente con SAM)
            x_idx, y_idx = kp_src_to_featmap(src_keypoint, orig_size_src)

            # 4.2: vettore feature sorgente in quel punto
            src_vec = src_emb[:, y_idx, x_idx]  # [C]

            # 4.3: cosine similarity con tutte le posizioni valide del target (senza normalizzare a mano)
            sim = torch.cosine_similarity(trg_flat, src_vec.unsqueeze(0), dim=1)  # [P]

            # 4.4: argmax su regione valida
            max_idx = torch.argmax(sim).item()
            y_idx_t = max_idx // wv_t
            x_idx_t = max_idx % wv_t

            # 4.5: feature target -> pixel originali target (passando per resized)
            x_pred, y_pred = kp_featmap_to_trg((x_idx_t, y_idx_t), orig_size_trg)

            if iter == 0:
                qualitative_result["src_kps"].append(src_keypoint.squeeze(0).tolist())
                qualitative_result["trg_gt_kps"].append(trg_keypoint.tolist())
                qualitative_result["trg_pred_kps"].append([x_pred, y_pred])

            # 4.6: distanza in pixel nello spazio originale target
            dist = math.sqrt((x_pred - trg_keypoint[0]) ** 2 + (y_pred - trg_keypoint[1]) ** 2)
            distances_this_image.append(dist)

        results.append({
            "pair_id": pair_id,
            "filename": filename,
            "category": category,
            "pck_threshold_0_05": pck_thr_0_05,
            "pck_threshold_0_1": pck_thr_0_1,
            "pck_threshold_0_2": pck_thr_0_2,
            "distances": distances_this_image
        })

        # --- cleanup GPU ---
        predictor.reset_image()
        for name in ("src_emb", "trg_emb", "trg_flat", "src_vec", "sim"):
            if name in locals():
                del locals()[name]
        torch.cuda.empty_cache()

print("Elaborazione completata. Numero di coppie elaborate:", len(results))
print(qualitative_result)

Elaborazione con SAM: 100%|██████████| 1/1 [00:01<00:00,  1.67s/it]

Elaborazione completata. Numero di coppie elaborate: 1
{'src_img': tensor([[[129., 129., 130.,  ..., 210., 210., 210.],
         [129., 129., 129.,  ..., 210., 210., 210.],
         [128., 128., 128.,  ..., 210., 210., 210.],
         ...,
         [124., 129., 135.,  ..., 169., 170., 171.],
         [129., 134., 140.,  ..., 171., 172., 172.],
         [136., 139., 143.,  ..., 172., 173., 173.]],

        [[146., 146., 147.,  ..., 211., 211., 211.],
         [146., 146., 146.,  ..., 211., 211., 211.],
         [145., 145., 145.,  ..., 211., 211., 211.],
         ...,
         [ 86.,  91.,  97.,  ..., 174., 175., 176.],
         [ 91.,  96., 102.,  ..., 176., 177., 177.],
         [ 98., 101., 105.,  ..., 177., 178., 178.]],

        [[162., 162., 163.,  ..., 213., 213., 213.],
         [162., 162., 162.,  ..., 213., 213., 213.],
         [163., 163., 163.,  ..., 213., 213., 213.],
         ...,
         [ 63.,  68.,  74.,  ..., 180., 181., 182.],
         [ 68.,  73.,  79.,  ..., 182.,




In [35]:
# Cella 4: calcolo PCK
import pandas as pd
import numpy as np
import torch

tot_keypoints = 0
tot_0_05 = 0
tot_0_1 = 0
tot_0_2 = 0
image_results = []

for res in results:
    dists_list = res["distances"]
    num_keypoints = len(dists_list)

    # Se non ci sono keypoint validi, salva comunque info ma non fare divisioni
    if num_keypoints == 0:
        image_results.append({
            "filename": res["filename"],
            "category": res["category"],
            "correct_0_05": 0,
            "correct_0_1": 0,
            "correct_0_2": 0,
            "num_keypoints": 0
        })
        continue

    dists = torch.tensor(dists_list)
    thr_0_05 = res["pck_threshold_0_05"]
    thr_0_1 = res["pck_threshold_0_1"]
    thr_0_2 = res["pck_threshold_0_2"]

    tot_keypoints += num_keypoints

    correct_0_05 = (dists <= thr_0_05).sum().item()
    correct_0_1 = (dists <= thr_0_1).sum().item()
    correct_0_2 = (dists <= thr_0_2).sum().item()

    tot_0_05 += correct_0_05
    tot_0_1 += correct_0_1
    tot_0_2 += correct_0_2

    image_results.append({
        "filename": res["filename"],
        "category": res["category"],
        "correct_0_05": correct_0_05,
        "correct_0_1": correct_0_1,
        "correct_0_2": correct_0_2,
        "num_keypoints": num_keypoints
    })

# --- Per keypoints (micro-average) ---
print("PCK Results per keypoints:")
df_keypoints = pd.DataFrame({
    "PCK 0.05": [tot_0_05 / tot_keypoints if tot_keypoints > 0 else np.nan],
    "PCK 0.1": [tot_0_1 / tot_keypoints if tot_keypoints > 0 else np.nan],
    "PCK 0.2": [tot_0_2 / tot_keypoints if tot_keypoints > 0 else np.nan],
})
print(df_keypoints)

# --- Per image (macro-average) SOLO su immagini con num_keypoints > 0 ---
valid_imgs = [r for r in image_results if r["num_keypoints"] > 0]

print("PCK Results per image:")
df_image = pd.DataFrame({
    "PCK 0.05": [np.mean([r["correct_0_05"] / r["num_keypoints"] for r in valid_imgs]) if valid_imgs else np.nan],
    "PCK 0.1": [np.mean([r["correct_0_1"] / r["num_keypoints"] for r in valid_imgs]) if valid_imgs else np.nan],
    "PCK 0.2": [np.mean([r["correct_0_2"] / r["num_keypoints"] for r in valid_imgs]) if valid_imgs else np.nan],
})
print(df_image)

PCK Results per keypoints:
   PCK 0.05   PCK 0.1  PCK 0.2
0  0.333333  0.666667      1.0
PCK Results per image:
   PCK 0.05   PCK 0.1  PCK 0.2
0  0.333333  0.666667      1.0


In [36]:
import pandas as pd
import numpy as np

category_results = {}

for r in image_results:
    cat = r["category"]
    if cat not in category_results:
        category_results[cat] = {
            "correct_0_05": 0,
            "correct_0_1": 0,
            "correct_0_2": 0,
            "total_keypoints": 0
        }

    category_results[cat]["correct_0_05"] += r["correct_0_05"]
    category_results[cat]["correct_0_1"] += r["correct_0_1"]
    category_results[cat]["correct_0_2"] += r["correct_0_2"]
    category_results[cat]["total_keypoints"] += r["num_keypoints"]

rows = []
for cat, stats in category_results.items():
    total = stats["total_keypoints"]

    rows.append({
        "Category": cat,
        "Total keypoints": total,
        "PCK 0.05": (stats["correct_0_05"] / total) if total > 0 else np.nan,
        "PCK 0.1": (stats["correct_0_1"] / total) if total > 0 else np.nan,
        "PCK 0.2": (stats["correct_0_2"] / total) if total > 0 else np.nan,
    })

df_cat = pd.DataFrame(rows).sort_values("Category")

print("PCK Results per category:")
print(df_cat)

PCK Results per category:
    Category  Total keypoints  PCK 0.05   PCK 0.1  PCK 0.2
0  aeroplane                3  0.333333  0.666667      1.0


In [43]:
# Visualization and saving of qualitative results
import matplotlib.pyplot as plt


def tensor_to_image(img: torch.Tensor) -> np.ndarray:
    img = img.detach().cpu()  # sicurezza
    img = img.permute(1, 2, 0)  # [H,W,3]
    img = img.numpy()
    img = img.astype(np.uint8)
    return img


def draw_keypoints(ax, image, keypoints, color, label=None, marker='o'):
    ax.imshow(image)
    if len(keypoints) > 0:
        xs = [kp[0] for kp in keypoints]
        ys = [kp[1] for kp in keypoints]
        ax.scatter(xs, ys, c=color, s=40, marker=marker, label=label)
    ax.axis("off")


def plot_keypoints_save(
        src_img_chw: torch.Tensor,
        trg_img_chw: torch.Tensor,
        qualitative_result: dict,
        dpi: int = 200
) -> None:
    # --- cartella output ---
    save_dir = os.path.join(base_dir, "qualitative-results")
    os.makedirs(save_dir, exist_ok=True)

    src_path = os.path.join(save_dir, f"{selected_model}_src.png")
    gt_path = os.path.join(save_dir, f"{selected_model}_trg_gt.png")
    pr_path = os.path.join(save_dir, f"{selected_model}_trg_pred.png")

    # --- immagini ---
    src_img = tensor_to_image(src_img_chw)
    trg_img = tensor_to_image(trg_img_chw)

    src_kps = qualitative_result["src_kps"]  # [[x,y],...]
    trg_gt = qualitative_result["trg_gt_kps"]  # [[x,y],...]
    trg_pr = qualitative_result["trg_pred_kps"]  # [[x,y],...]

    n = min(len(src_kps), len(trg_gt), len(trg_pr))
    cmap = plt.get_cmap("tab10", max(n, 1))

    # ---------- 1) SOURCE ----------
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(src_img)
    ax.set_title("Source")
    ax.axis("off")
    for i in range(n):
        c = cmap(i)
        xs, ys = src_kps[i]
        ax.scatter(xs, ys, s=60, color=c, marker="o")
        ax.text(xs + 6, ys + 6, str(i), color=c, fontsize=10)
    plt.tight_layout()
    plt.savefig(src_path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)

    # ---------- 2) TARGET GT ----------
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(trg_img)
    ax.set_title("Target GT")
    ax.axis("off")
    for i in range(n):
        c = cmap(i)
        xg, yg = trg_gt[i]
        ax.scatter(xg, yg, s=60, color=c, marker="o")
        ax.text(xg + 6, yg + 6, str(i), color=c, fontsize=10)
    plt.tight_layout()
    plt.savefig(gt_path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)

    # ---------- 3) TARGET PRED ----------
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(trg_img)
    ax.set_title("Target Pred")
    ax.axis("off")
    for i in range(n):
        c = cmap(i)
        xg, yg = trg_gt[i]
        xp, yp = trg_pr[i]

        # pred: pallino vuoto
        ax.scatter(xp, yp, s=60, color=c,  linewidths=2, marker="o")
        ax.text(xp + 6, yp + 6, str(i), color=c, fontsize=10)

        # (opzionale) linea di errore GT->Pred
        ax.plot([xg, xp], [yg, yp], color=c, linewidth=1, linestyle="--")

    plt.tight_layout()
    plt.savefig(pr_path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)


plot_keypoints_save(
    src_img_chw=batch["src_img"],
    trg_img_chw=batch["trg_img"],
    qualitative_result=qualitative_result
)



