In [41]:
from typing import TypedDict, Tuple, Any
import torch


class SPairSample(TypedDict):
    """One source/target image pair as returned by ``SPairDataset.__getitem__``.

    Most fields are copied from the pair's JSON annotation; the images are
    loaded from disk and the PCK thresholds are derived from the target
    bounding box.
    """
    # Identification of the pair (taken from the annotation file).
    pair_id: int
    filename: str
    src_imname: str
    trg_imname: str
    # (height, width) read from the loaded image tensors.
    src_imsize: Tuple[int, int]
    trg_imsize: Tuple[int, int]
    # Bounding boxes; unpacked as (x1, y1, x2, y2) by the dataset — presumably pixel coordinates, TODO confirm.
    src_bbox: Tuple[int, int, int, int]
    trg_bbox: Tuple[int, int, int, int]
    category: str
    src_pose: str
    trg_pose: str
    # Images as float32 tensors in channel-first (C, H, W) layout, values not rescaled.
    src_img: torch.Tensor
    trg_img: torch.Tensor
    # Keypoints as float tensors built from the annotation lists (used downstream as [N, 2] in (x, y) order).
    src_kps: torch.Tensor
    trg_kps: torch.Tensor
    # Difficulty flags copied from the annotation (mirror, viewpoint/scale variation, truncation, occlusion).
    mirror: int
    vp_var: int
    sc_var: int
    truncn: int
    occlsn: int
    # PCK distance thresholds: alpha * max(width, height) of the TARGET bounding box.
    pck_threshold_0_05: float
    pck_threshold_0_1: float
    pck_threshold_0_2: float

In [42]:
import numpy as np
from typing import List
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
import glob
import json

# Set to True when running on Google Colab: the dataset archive is then downloaded
# and extracted under ./AML-polito/dataset/ and base_dir is pointed at ./AML-polito.
using_colab = False
# Default project root: the directory the notebook is run from.
base_dir = os.path.abspath(os.path.curdir)

if using_colab:
    !wget -P ./AML-polito/dataset/ "https://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz"
    !tar -xvzf ./AML-polito/dataset/SPair-71k.tar.gz -C ./AML-polito/dataset/
    base_dir = os.path.join(os.path.abspath(os.path.curdir), 'AML-polito')


def read_img(path: str) -> torch.Tensor:
    """Load an image file as a float32 tensor in channel-first layout.

    Args:
        path: Path of the image file to read.

    Returns:
        A ``torch.float32`` tensor of shape (3, H, W). Pixel values are kept
        in their original 0-255 range (no rescaling or normalisation).
    """
    # Context manager so the underlying file handle is released deterministically
    # instead of depending on garbage collection.
    with Image.open(path) as pil_img:
        pixels = np.array(pil_img.convert('RGB'))  # (H, W, 3)
    # HWC -> CHW, then cast to float32.
    return torch.tensor(pixels.transpose(2, 0, 1).astype(np.float32))


def collate_single(batch_list: List[SPairSample]) -> SPairSample:
    """Collate function for ``DataLoader(batch_size=1)`` that returns the sample unchanged.

    Args:
        batch_list: The list of samples handed over by the DataLoader; must
            contain exactly one sample.

    Returns:
        The single sample, without any added batch dimension.

    Raises:
        ValueError: If the batch does not contain exactly one sample. Returning
            only the first element would silently drop the remaining samples
            from the evaluation.
    """
    if len(batch_list) != 1:
        raise ValueError(
            f"collate_single expects exactly one sample per batch, got {len(batch_list)}")
    return batch_list[0]


class SPairDataset(Dataset):
    """Map-style dataset over SPair-71k image pairs.

    The list of pairs for a split is read from
    ``<layout_path>/<dataset_size>/<datatype>.txt`` (one annotation name per
    line). Each item loads ``<pair_ann_path>/<datatype>/<name>.json`` plus the
    two images it references under ``<image_path>/<category>/``.
    """

    def __init__(self, pair_ann_path, layout_path, image_path, dataset_size, datatype):
        """Read the split file and remember where annotations and images live.

        Args:
            pair_ann_path: Directory containing one sub-directory of JSON pair annotations per split.
            layout_path: Directory containing the split lists.
            image_path: Directory containing one sub-directory of images per category.
            dataset_size: Name of the layout sub-directory (e.g. ``'large'``).
            datatype: Split name, used for both the ``.txt`` list and the annotation sub-directory.
        """
        self.datatype = datatype

        split_file = os.path.join(layout_path, dataset_size, datatype + '.txt')
        # `with` closes the handle; the previous bare open() left it to the garbage collector.
        with open(split_file, "r") as f:
            ann_files = f.read().split('\n')
        # Drop the empty entry produced by a trailing newline, but only if it is
        # really empty, so a file without a final newline does not lose its last pair.
        if ann_files and ann_files[-1] == '':
            ann_files.pop()
        self.ann_files = ann_files

        self.pair_ann_path = pair_ann_path
        self.image_path = image_path
        # Category names are the (sorted) entry names found directly under image_path.
        self.categories = sorted(os.path.basename(p) for p in glob.glob('%s/*' % image_path))

    def __len__(self):
        """Number of pairs listed in the split file."""
        return len(self.ann_files)

    def __getitem__(self, idx) -> SPairSample:
        """Load the idx-th pair: annotation fields, both images and PCK thresholds."""
        ann_filename = self.ann_files[idx]
        json_path = os.path.join(self.pair_ann_path, self.datatype, ann_filename + '.json')

        with open(json_path) as f:
            annotation = json.load(f)

        category = annotation['category']
        src_img = read_img(os.path.join(self.image_path, category, annotation['src_imname']))
        trg_img = read_img(os.path.join(self.image_path, category, annotation['trg_imname']))

        sx1, sy1, sx2, sy2 = annotation["src_bndbox"]
        tx1, ty1, tx2, ty2 = annotation["trg_bndbox"]

        # PCK thresholds are a fraction of the longer side of the TARGET bounding box.
        trg_bbox_side = max(tx2 - tx1, ty2 - ty1)

        sample: SPairSample = {'pair_id': int(annotation['pair_id']),
                               'filename': str(annotation['filename']),
                               'src_imname': str(annotation['src_imname']),
                               'trg_imname': str(annotation['trg_imname']),
                               'src_imsize': (int(src_img.shape[1]), int(src_img.shape[2])),  # height, width
                               'trg_imsize': (int(trg_img.shape[1]), int(trg_img.shape[2])),  # height, width

                               'src_bbox': (int(sx1), int(sy1), int(sx2), int(sy2)),
                               'trg_bbox': (int(tx1), int(ty1), int(tx2), int(ty2)),
                               'category': str(annotation['category']),

                               'src_pose': str(annotation['src_pose']),
                               'trg_pose': str(annotation['trg_pose']),

                               'src_img': src_img,
                               'trg_img': trg_img,
                               'src_kps': torch.tensor(annotation['src_kps']).float(),
                               'trg_kps': torch.tensor(annotation['trg_kps']).float(),

                               'mirror': int(annotation['mirror']),
                               'vp_var': int(annotation['viewpoint_variation']),
                               'sc_var': int(annotation['scale_variation']),
                               'truncn': int(annotation['truncation']),
                               'occlsn': int(annotation['occlusion']),

                               'pck_threshold_0_05': float(trg_bbox_side * 0.05),
                               'pck_threshold_0_1': float(trg_bbox_side * 0.1),
                               'pck_threshold_0_2': float(trg_bbox_side * 0.2)
                               }

        return sample


# Locations of the extracted SPair-71k dataset, relative to the project root.
dataset_dir = os.path.join(base_dir, 'dataset', "SPair-71k")
pair_ann_path = os.path.join(dataset_dir, 'PairAnnotation')
layout_path = os.path.join(dataset_dir, 'Layout')
image_path = os.path.join(dataset_dir, 'JPEGImages')
dataset_size = 'large'

# Fail fast when the dataset folders are missing, before building any dataset object.
if not (os.path.exists(pair_ann_path) and os.path.exists(layout_path) and os.path.exists(image_path)):
    raise RuntimeError(
        f"Errore: Impossibile trovare i percorsi del dataset in '{base_dir}'.\nVerifica l'estrazione e controlla se la struttura delle cartelle corrisponde.")

trn_dataset = SPairDataset(pair_ann_path, layout_path, image_path, dataset_size, datatype='trn')
val_dataset = SPairDataset(pair_ann_path, layout_path, image_path, dataset_size, datatype='val')
test_dataset = SPairDataset(pair_ann_path, layout_path, image_path, dataset_size, datatype='test')

# NOTE(review): only the test loader uses collate_single; the train/val loaders
# keep the DataLoader defaults — confirm this difference is intended.
trn_dataloader = DataLoader(trn_dataset, num_workers=0)
val_dataloader = DataLoader(val_dataset, num_workers=0)
test_dataloader = DataLoader(test_dataset, num_workers=0, batch_size=1, collate_fn=collate_single)
print("Dataset caricati correttamente.")

Dataset caricati correttamente.


In [43]:
# For each SAM backbone variant: (checkpoint download URL, local checkpoint file name).
model_size = {
    "vit_b": ("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", "sam_vit_b_01ec64.pth"),
    "vit_l": ("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", "sam_vit_l_0b3195.pth"),
    "vit_h": ("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "sam_vit_h_4b8939.pth")
}

# Key into model_size (and into sam_model_registry below) selecting the backbone to use.
selected_model = "vit_b"

if using_colab:
    # Colab only: install segment-anything and download the selected checkpoint into ./AML-polito/models/.
    !pip install git+https://github.com/facebookresearch/segment-anything.git -q
    !wget -P ./AML-polito/models/ {model_size[selected_model][0]}
    !clear

from segment_anything import SamPredictor, sam_model_registry

# SAM initialisation

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# === 1) Load SAM ===
# The checkpoint is expected at <base_dir>/models/<checkpoint file name>.
sam_checkpoint_path = os.path.join(base_dir, 'models', model_size[selected_model][1])
sam = sam_model_registry[selected_model](checkpoint=sam_checkpoint_path)
sam.to(device)
sam.eval()
predictor = SamPredictor(sam)
transform = predictor.transform

# Useful parameters
IMG_SIZE = predictor.model.image_encoder.img_size  # 1024
#PATCH = IMG_SIZE // 64  # 16
# Patch size read from the kernel size of the patch-embedding projection.
PATCH = int(predictor.model.image_encoder.patch_embed.proj.kernel_size[0])  # 16
print(f"SAM modello '{selected_model}' caricato. IMG_SIZE={IMG_SIZE}, PATCH={PATCH}")

Device: cuda
SAM modello 'vit_b' caricato. IMG_SIZE=1024, PATCH=16


In [44]:
def get_scale_factor(img_size: tuple) -> float:
    """Return the resize factor SAM's transform applies to an image of the given size.

    Args:
        img_size: Original image size as (height, width).

    Returns:
        Ratio between the resized and the original length of the longer image
        side (the width is used when both sides are equal).
    """
    orig_h, orig_w = img_size
    new_h, new_w = predictor.transform.get_preprocess_shape(orig_h, orig_w, IMG_SIZE)

    # Measure the factor on the longer side: original long side -> resized long side.
    width_is_long_side = orig_w >= orig_h
    return new_w / orig_w if width_is_long_side else new_h / orig_h


def kp_src_to_featmap(kp_src_coordinates: torch.Tensor, img_src_size: tuple):
    """Map a keypoint from original image pixels to a cell of the feature map.

    Args:
        kp_src_coordinates: Keypoint coordinates handed to
            ``predictor.transform.apply_coords_torch``; only the first
            returned (x, y) entry is used.
        img_src_size: Original image size as (height, width).

    Returns:
        ``(xf, yf)`` integer feature-map indices, clamped to the region
        covered by the resized image (i.e. excluding the padding).
    """
    src_h, src_w = img_src_size
    sam_transform = predictor.transform

    # Keypoint position inside the resized (not yet padded) image.
    x_resized, y_resized = sam_transform.apply_coords_torch(kp_src_coordinates, img_src_size)[0]

    # Actual size of the resized image before padding.
    resized_h, resized_w = sam_transform.get_preprocess_shape(src_h, src_w, IMG_SIZE)

    # Index of the last feature-map column/row that still overlaps the image.
    last_col = math.ceil(resized_w / PATCH) - 1
    last_row = math.ceil(resized_h / PATCH) - 1

    # Resized pixels -> feature-map cell (one cell per PATCH x PATCH pixels).
    col = int(x_resized // PATCH)
    row = int(y_resized // PATCH)

    return min(max(col, 0), last_col), min(max(row, 0), last_row)


def kp_featmap_to_trg(featmap_coords: tuple, trg_img_size: tuple):
    """Map a feature-map cell back to pixel coordinates of the original target image.

    Args:
        featmap_coords: ``(xf, yf)`` indices of the feature-map cell.
        trg_img_size: Original target image size as (height, width).

    Returns:
        ``(x, y)`` of the cell centre expressed in original target-image pixels.
    """
    col, row = featmap_coords

    # Centre of the cell in resized-image pixels.
    x_center = (col + 0.5) * PATCH
    y_center = (row + 0.5) * PATCH

    # Undo the resize to get back to original-image pixels.
    resize_factor = get_scale_factor(trg_img_size)
    return x_center / resize_factor, y_center / resize_factor

In [45]:
# PCK@T per keypoint

import pandas as pd
import numpy as np

def compute_pckt_keypoints(category_results: dict):
    """Print and return the per-keypoint PCK table (in %), one row per category.

    For each category the correct-keypoint counts of all pairs are summed and
    divided by the total number of keypoints of that category. A final "All"
    row holds the macro-average over categories (NaN rows are skipped).

    Args:
        category_results: Mapping ``category -> list of dicts`` with the keys
            ``correct_0_05``, ``correct_0_1``, ``correct_0_2`` and
            ``num_keypoints`` (one dict per image pair).

    Returns:
        A DataFrame with columns ``Category``, ``PCK 0.05``, ``PCK 0.1`` and
        ``PCK 0.2``, sorted by category and followed by the "All" row. With an
        empty input only the "All" row (all NaN) is returned instead of
        failing on the missing ``Category`` column.
    """
    metric_columns = ["PCK 0.05", "PCK 0.1", "PCK 0.2"]
    count_keys = ["correct_0_05", "correct_0_1", "correct_0_2"]
    columns = ["Category"] + metric_columns

    rows_keypoints = []
    for cat, stats_list in category_results.items():
        tot_keypoints = sum(s["num_keypoints"] for s in stats_list)
        row = {"Category": cat}
        for column, key in zip(metric_columns, count_keys):
            tot_correct = sum(s[key] for s in stats_list)
            # A category without any keypoint has no defined PCK.
            row[column] = tot_correct / tot_keypoints * 100 if tot_keypoints > 0 else np.nan
        rows_keypoints.append(row)

    # Explicit columns keep the frame well-formed even when there are no rows.
    df_categories = pd.DataFrame(rows_keypoints, columns=columns).sort_values("Category")

    # "All" = macro-average on categories
    mean_row_kp = {"Category": "All"}
    for column in metric_columns:
        mean_row_kp[column] = df_categories[column].astype(float).mean(skipna=True)

    df_keypoints = pd.DataFrame(df_categories.to_dict("records") + [mean_row_kp], columns=columns)

    print("PCK Results per keypoints (%):")
    print(df_keypoints)
    return df_keypoints

In [46]:

def compute_correct_per_category(results: List[Any]) -> dict:
    """Count, per image pair, how many keypoints fall within each PCK threshold.

    Args:
        results: One dict per image pair with the keys ``category``,
            ``distances`` (list of prediction errors) and
            ``pck_threshold_0_05`` / ``pck_threshold_0_1`` / ``pck_threshold_0_2``.

    Returns:
        Mapping ``category -> list`` holding, for every pair of that category,
        a dict with ``correct_0_05``, ``correct_0_1``, ``correct_0_2`` and
        ``num_keypoints``.
    """
    per_category = {}

    for entry in results:
        bucket = per_category.setdefault(entry["category"], [])

        distances = entry["distances"]
        dist_tensor = torch.tensor(distances)

        def hits(threshold_key):
            # Number of keypoints whose error does not exceed the given threshold.
            return (dist_tensor <= entry[threshold_key]).sum().item()

        bucket.append({
            "correct_0_05": hits("pck_threshold_0_05"),
            "correct_0_1": hits("pck_threshold_0_1"),
            "correct_0_2": hits("pck_threshold_0_2"),
            "num_keypoints": len(distances),
        })

    return per_category

In [47]:
# PCK@T per image
def compute_pckt_images(category_results: dict):
    """Print and return the per-image PCK table (in %), one row per category.

    For each image pair the fraction of correct keypoints is computed first;
    these fractions are then averaged within each category. Pairs without
    keypoints are skipped. A final "All" row holds the macro-average over
    categories (NaN rows are skipped).

    Args:
        category_results: Mapping ``category -> list of dicts`` with the keys
            ``correct_0_05``, ``correct_0_1``, ``correct_0_2`` and
            ``num_keypoints`` (one dict per image pair).

    Returns:
        A DataFrame with columns ``Category``, ``PCK 0.05``, ``PCK 0.1`` and
        ``PCK 0.2``, sorted by category and followed by the "All" row. With an
        empty input only the "All" row (all NaN) is returned instead of
        failing on the missing ``Category`` column.
    """
    metric_columns = ["PCK 0.05", "PCK 0.1", "PCK 0.2"]
    count_keys = ["correct_0_05", "correct_0_1", "correct_0_2"]
    columns = ["Category"] + metric_columns

    rows_images = []
    for cat, stats_list in category_results.items():
        # Pairs without keypoints have no defined per-image PCK.
        scored = [s for s in stats_list if s["num_keypoints"] != 0]
        row = {"Category": cat}
        for column, key in zip(metric_columns, count_keys):
            per_image = [s[key] / s["num_keypoints"] for s in scored]
            row[column] = np.mean(per_image) * 100 if per_image else np.nan
        rows_images.append(row)

    # Explicit columns keep the frame well-formed even when there are no rows.
    df_categories = pd.DataFrame(rows_images, columns=columns).sort_values("Category")

    # "All" = macro-average on categories
    all_row = {"Category": "All"}
    for column in metric_columns:
        all_row[column] = df_categories[column].astype(float).mean(skipna=True)

    df_image = pd.DataFrame(df_categories.to_dict("records") + [all_row], columns=columns)

    print("PCK per-image (%):")
    print(df_image)
    return df_image

In [48]:
# Shared store filled by the forward hooks: hook name -> detached output of the hooked module.
activation = {}


def get_activation(name):
    """Build a forward hook that stores the hooked module's output in ``activation[name]``.

    Args:
        name: Key under which the output is stored.

    Returns:
        A ``hook(module, inputs, output)`` callable. When the output is a
        tuple only its first element is kept; the stored value is detached.
    """
    def hook(module, inputs, output):
        captured = output[0] if isinstance(output, tuple) else output
        activation[name] = captured.detach()

    return hook

# Hook registration (do it ONLY ONCE, before the loop).
# NOTE(review): re-running this cell presumably registers an additional set of hooks
# on the same blocks — restart the kernel or keep/remove the returned handles if that matters.
# We pick a few block indices. ViT-B (base) has 12 blocks (0-11).
target_layers = [2, 5, 8, 11]
for i in target_layers:
    predictor.model.image_encoder.blocks[i].register_forward_hook(get_activation(f'layer_{i}'))

In [49]:
from tqdm import tqdm
from itertools import islice
import math
import torch
import torch.nn.functional as F

# Final results: key = layer index, value = list of per-pair result dicts
all_results = {li: [] for li in target_layers}

# (optional) qualitative data per layer (filled from the first pair only)
qualitative_by_layer = {
    li: {
        "src_img": None,
        "trg_img": None,
        "src_kps": [],
        "trg_gt_kps": [],
        "trg_pred_kps": []
    }
    for li in target_layers
}

# Outside Colab only the first `max_images` pairs are evaluated; on Colab the whole test split is used.
max_images = 1500
if using_colab:
    data = test_dataloader
    size = len(test_dataloader)
else:
    # Cap at the real number of pairs so the progress-bar total is right
    # even when the split holds fewer than `max_images` pairs.
    size = min(max_images, len(test_dataloader))
    data = islice(test_dataloader, size)

img_enc = predictor.model.image_encoder  # also holds .neck, applied below to the deeper layers

with torch.no_grad():
    # The loop counter is named `pair_idx` (not `iter`) so the builtin iter()
    # is not shadowed for the rest of the notebook.
    for pair_idx, batch in enumerate(tqdm(data, total=size, desc=f"Elaborazione con SAM {selected_model}")):

        category = batch["category"]
        src_img = batch["src_img"].to(device).unsqueeze(0)  # [1,3,Hs,Ws]
        trg_img = batch["trg_img"].to(device).unsqueeze(0)  # [1,3,Ht,Wt]
        orig_size_src = batch["src_imsize"]  # (Hs, Ws)
        orig_size_trg = batch["trg_imsize"]  # (Ht, Wt)

        src_resized = predictor.transform.apply_image_torch(src_img)  # [1,3,Hs',Ws']
        trg_resized = predictor.transform.apply_image_torch(trg_img)  # [1,3,Ht',Wt']

        # --- 1) Source forward (fills `activation` with the outputs of the hooked blocks) ---
        # NOTE(review): in segment_anything the encoder forward appears to run inside
        # set_torch_image; get_image_embedding() only returns the cached features — confirm.
        predictor.set_torch_image(src_resized, orig_size_src)
        _ = predictor.get_image_embedding()[0]
        # Copy the activations now: the target forward below overwrites the same keys.
        src_intermediate_emb = {k: v.detach().clone() for k, v in activation.items()}

        # --- 2) Target forward (fills `activation` with the outputs of the hooked blocks) ---
        predictor.set_torch_image(trg_resized, orig_size_trg)
        _ = predictor.get_image_embedding()[0]
        trg_intermediate_emb = {k: v.detach().clone() for k, v in activation.items()}

        # --- 3) Keypoints & metadata ---
        src_kps = batch["src_kps"].to(device)  # [N,2]
        trg_kps = batch["trg_kps"].to(device)  # [N,2]
        pck_thr_0_05 = batch["pck_threshold_0_05"]
        pck_thr_0_1 = batch["pck_threshold_0_1"]
        pck_thr_0_2 = batch["pck_threshold_0_2"]

        pair_id = batch["pair_id"]
        filename = batch["filename"]

        # --- valid target region (to avoid matching into the padding) ---
        H_prime, W_prime = trg_resized.shape[-2:]
        hv_t = math.ceil(H_prime / PATCH)
        wv_t = math.ceil(W_prime / PATCH)

        N_kps = src_kps.shape[0]

        # (optional) store the images once per layer, for the first pair only
        if pair_idx == 0:
            for li in target_layers:
                qualitative_by_layer[li]["src_img"] = batch["src_img"]
                qualitative_by_layer[li]["trg_img"] = batch["trg_img"]

        # ===== Loop over layers: reuse the stored activations, NO extra forward pass =====
        for selected_layer in target_layers:

            # Hook output: [1,64,64,768] (NHWC)
            src_hook = src_intermediate_emb[f"layer_{selected_layer}"]
            trg_hook = trg_intermediate_emb[f"layer_{selected_layer}"]

            # NHWC -> NCHW : [1,768,64,64]
            src_feat = src_hook.permute(0, 3, 1, 2).contiguous()
            trg_feat = trg_hook.permute(0, 3, 1, 2).contiguous()

            # ======= pre-neck / post-neck choice =======
            if selected_layer <= 4:  # "early" layer (threshold can be changed: 3/4/5)
                # PRE-NECK: use the 768 channels directly
                # (optional but recommended) L2-normalise for the cosine similarity
                src_emb = F.normalize(src_feat, dim=1)[0]   # [768,64,64]
                trg_emb = F.normalize(trg_feat, dim=1)[0]   # [768,64,64]
            else:
                # POST-NECK: project to 256 channels like the standard SAM output
                src_emb = img_enc.neck(src_feat)[0]         # [256,64,64]
                trg_emb = img_enc.neck(trg_feat)[0]         # [256,64,64]
                # (optional) normalise here as well for full consistency
                # src_emb = F.normalize(src_emb, dim=0)
                # trg_emb = F.normalize(trg_emb, dim=0)
            # ================================================

            C_ft = trg_emb.shape[0]

            trg_valid = trg_emb[:, :hv_t, :wv_t]  # [C,hv,wv]
            trg_flat = trg_valid.permute(1, 2, 0).reshape(-1, C_ft)  # [Pvalid,C]

            distances_this_image = []

            for i in range(N_kps):
                src_keypoint = src_kps[i].unsqueeze(0)  # [1,2] (x,y)
                trg_keypoint = trg_kps[i]              # [2]   (x,y)

                if torch.isnan(src_keypoint).any() or torch.isnan(trg_keypoint).any():
                    continue

                # original source pixels -> source feature-map cell
                x_idx, y_idx = kp_src_to_featmap(src_keypoint, orig_size_src)

                # source feature vector
                src_vec = src_emb[:, y_idx, x_idx]  # [C]

                # cosine similarity with every valid target position
                sim = torch.cosine_similarity(trg_flat, src_vec.unsqueeze(0), dim=1)  # [P]
                max_idx = torch.argmax(sim).item()
                y_idx_t = max_idx // wv_t
                x_idx_t = max_idx % wv_t

                # target feature-map cell -> original target pixels
                x_pred, y_pred = kp_featmap_to_trg((x_idx_t, y_idx_t), orig_size_trg)

                if pair_idx == 0:
                    qualitative_by_layer[selected_layer]["src_kps"].append(src_keypoint.squeeze(0).tolist())
                    qualitative_by_layer[selected_layer]["trg_gt_kps"].append(trg_keypoint.tolist())
                    qualitative_by_layer[selected_layer]["trg_pred_kps"].append([x_pred, y_pred])

                dist = math.sqrt((x_pred - trg_keypoint[0]) ** 2 + (y_pred - trg_keypoint[1]) ** 2)
                distances_this_image.append(dist)

            # Same record structure as the single-layer version, stored per layer
            all_results[selected_layer].append({
                "pair_id": pair_id,
                #"filename": filename,
                "category": category,
                "pck_threshold_0_05": pck_thr_0_05,
                "pck_threshold_0_1": pck_thr_0_1,
                "pck_threshold_0_2": pck_thr_0_2,
                "distances": distances_this_image
            })

            # per-layer cleanup of the large tensors.
            # `src_vec` and `sim` are deliberately NOT deleted here: they only exist when at
            # least one keypoint was valid, so deleting them raised NameError otherwise.
            del src_hook, trg_hook, src_feat, trg_feat, src_emb, trg_emb, trg_valid, trg_flat
            torch.cuda.empty_cache()

        # per-pair cleanup
        predictor.reset_image()
        torch.cuda.empty_cache()

# Report: for every evaluated layer, aggregate the per-pair distances into
# correct-keypoint counts and print both the per-keypoint and the per-image PCK tables.
print("Elaborazione completata.")
for li in target_layers:
    print(f"Layer {li}")
    correct = compute_correct_per_category(all_results[li])
    compute_pckt_keypoints(correct)
    compute_pckt_images(correct)
    print("#" * 50)
    print("")


Elaborazione con SAM vit_b: 100%|██████████| 1500/1500 [43:30<00:00,  1.74s/it]


Elaborazione completata.
Layer 2
PCK Results per keypoints (%):
    Category  PCK 0.05    PCK 0.1    PCK 0.2
0  aeroplane  6.779353  14.649218  30.625227
1    bicycle  1.991723   4.966374  14.614589
2       bird  4.057971  11.497585  25.507246
3        All  4.276349  10.371059  23.582354
PCK per-image (%):
    Category  PCK 0.05    PCK 0.1    PCK 0.2
0  aeroplane  5.990497  13.294406  27.212843
1    bicycle  1.809424   4.412904  13.332184
2       bird  3.911007  11.629419  24.729933
3        All  3.903643   9.778910  21.758320
##################################################

Layer 5
PCK Results per keypoints (%):
    Category   PCK 0.05    PCK 0.1    PCK 0.2
0  aeroplane  13.486005  20.556161  35.823337
1    bicycle   6.052768  10.475944  20.563890
2       bird  14.782609  23.768116  36.908213
3        All  11.440461  18.266740  31.098480
PCK per-image (%):
    Category   PCK 0.05    PCK 0.1    PCK 0.2
0  aeroplane  12.016869  18.656085  33.334769
1    bicycle   5.740243  10.118393 

In [50]:
# Visualization and saving of qualitative results
import matplotlib.pyplot as plt


def tensor_to_image(img: torch.Tensor) -> np.ndarray:
    """Convert a channel-first image tensor into an ``uint8`` array for plotting.

    Args:
        img: Image tensor of shape [3, H, W].

    Returns:
        ``np.uint8`` array of shape [H, W, 3].
    """
    # detach + cpu for safety, then CHW -> HWC
    hwc = img.detach().cpu().permute(1, 2, 0)
    return hwc.numpy().astype(np.uint8)


def draw_keypoints(ax, image, keypoints, color, label=None, marker='o'):
    """Show ``image`` on ``ax`` and overlay the given keypoints as a scatter plot.

    Args:
        ax: Axes-like object providing ``imshow``, ``scatter`` and ``axis``.
        image: Image passed to ``ax.imshow``.
        keypoints: Sequence of points; element 0 is used as x and element 1 as y.
        color: Marker colour.
        label: Optional legend label for the markers.
        marker: Marker style.
    """
    ax.imshow(image)

    # Nothing is scattered for an empty keypoint list.
    if len(keypoints) > 0:
        x_coords, y_coords = [], []
        for point in keypoints:
            x_coords.append(point[0])
            y_coords.append(point[1])
        ax.scatter(x_coords, y_coords, c=color, s=40, marker=marker, label=label)

    ax.axis("off")


def _save_keypoint_figure(image, title, out_path, dpi, draw_overlay) -> None:
    """Render ``image`` with a title, let ``draw_overlay(ax)`` add the markers, and save it to ``out_path``."""
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")
    draw_overlay(ax)
    fig.tight_layout()
    fig.savefig(out_path, dpi=dpi, bbox_inches="tight")
    # Close explicitly so figures do not accumulate in memory.
    plt.close(fig)


def plot_keypoints_save(
        src_img_chw: torch.Tensor,
        trg_img_chw: torch.Tensor,
        qualitative_result: dict,
        dpi: int = 200
) -> None:
    """Save three qualitative figures for one image pair.

    The figures are written to ``<base_dir>/qualitative-results`` as
    ``<selected_model>_src.png`` (source keypoints), ``<selected_model>_trg_gt.png``
    (ground-truth target keypoints) and ``<selected_model>_trg_pred.png``
    (predicted target keypoints, each joined to its ground truth by a dashed line).
    Existing files with the same names are overwritten.

    Args:
        src_img_chw: Source image tensor in [3, H, W] layout.
        trg_img_chw: Target image tensor in [3, H, W] layout.
        qualitative_result: Dict with the lists ``src_kps``, ``trg_gt_kps`` and
            ``trg_pred_kps`` of ``[x, y]`` points; only the first
            ``min(len(...))`` entries of the three lists are drawn.
        dpi: Resolution used when saving the figures.
    """
    # --- output folder ---
    save_dir = os.path.join(base_dir, "qualitative-results")
    os.makedirs(save_dir, exist_ok=True)

    src_path = os.path.join(save_dir, f"{selected_model}_src.png")
    gt_path = os.path.join(save_dir, f"{selected_model}_trg_gt.png")
    pr_path = os.path.join(save_dir, f"{selected_model}_trg_pred.png")

    # --- images ---
    src_img = tensor_to_image(src_img_chw)
    trg_img = tensor_to_image(trg_img_chw)

    src_kps = qualitative_result["src_kps"]  # [[x,y],...]
    trg_gt = qualitative_result["trg_gt_kps"]  # [[x,y],...]
    trg_pr = qualitative_result["trg_pred_kps"]  # [[x,y],...]

    n = min(len(src_kps), len(trg_gt), len(trg_pr))
    # One colour per keypoint index, shared by the three figures.
    cmap = plt.get_cmap("tab10", max(n, 1))

    def draw_filled(points):
        """Overlay that draws filled, index-labelled markers at ``points``."""
        def overlay(ax):
            for i in range(n):
                c = cmap(i)
                x, y = points[i]
                ax.scatter(x, y, s=60, color=c, marker="o")
                ax.text(x + 6, y + 6, str(i), color=c, fontsize=10)
        return overlay

    def draw_predictions(ax):
        for i in range(n):
            c = cmap(i)
            xg, yg = trg_gt[i]
            xp, yp = trg_pr[i]

            # prediction: hollow circle (edge only), so it is distinguishable from the
            # filled markers used for annotated keypoints in the other figures
            ax.scatter(xp, yp, s=60, facecolors="none", edgecolors=[c], linewidths=2, marker="o")
            ax.text(xp + 6, yp + 6, str(i), color=c, fontsize=10)

            # error line GT -> prediction
            ax.plot([xg, xp], [yg, yp], color=c, linewidth=1, linestyle="--")

    # ---------- 1) SOURCE ----------
    _save_keypoint_figure(src_img, "Source", src_path, dpi, draw_filled(src_kps))
    # ---------- 2) TARGET GT ----------
    _save_keypoint_figure(trg_img, "Target GT", gt_path, dpi, draw_filled(trg_gt))
    # ---------- 3) TARGET PRED ----------
    _save_keypoint_figure(trg_img, "Target Pred", pr_path, dpi, draw_predictions)


# Plot the qualitative result collected for the first evaluated pair.
# The images are taken from `qualitative_by_layer` (stored together with the keypoints)
# rather than from `batch`, which after the loop holds the LAST processed pair and
# therefore would not match the stored keypoints.
# plot_keypoints_save writes fixed file names, so a single layer is plotted: the deepest one.
qualitative_layer = target_layers[-1]
qualitative_result = qualitative_by_layer[qualitative_layer]

plot_keypoints_save(
    src_img_chw=qualitative_result["src_img"],
    trg_img_chw=qualitative_result["trg_img"],
    qualitative_result=qualitative_result
)



