In [2]:
import math
import os

import numpy as np
import torch
from torch.utils.data import DataLoader

from data.spair import SPairDataset

# Set to False when running from a local checkout of the repository
using_colab = True

# Start from the current working directory; on Colab the project files are
# expected one level below it, inside the 'AML-polito' folder
base_dir = os.path.abspath(os.path.curdir)
if using_colab:
    base_dir = os.path.join(base_dir, 'AML-polito')


def collate_single(batch_list):
    """Collate function for batch_size=1: hand back the only sample untouched.

    Returns the first element of ``batch_list`` as-is, without stacking or
    adding a batch dimension.
    """
    first_sample = batch_list[0]
    return first_sample


dataset_size = 'large'  # 'small' or 'large'
dataset_dir = os.path.join(base_dir, 'dataset')

# Fail early with a clear message when the dataset folder is missing
if not os.path.exists(dataset_dir):
    raise IOError(f"Cannot find dataset files in '{dataset_dir}'.")

# Build the test split and wrap it in a loader that yields one sample at a time
# (collate_single returns the sample itself instead of a stacked batch)
test_dataset = SPairDataset(datatype='test', dataset_size=dataset_size)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=0,
    collate_fn=collate_single,
)
print("Dataset loaded")

Merged annotation file already exists: /home/pasquale/PycharmProjects/AML-polito/dataset/ap-10k/annotations/ap10k-test-merged.jsonl
[INTRA-SPECIES] Total pairs: 4802
Intra-species pairs generated.
[CROSS-SPECIES] Total pairs: 4250
Cross-species pairs generated.
[CROSS-FAMILY] Total pairs: 4200
Cross-family pairs generated.
Processed annotation file created: /home/pasquale/PycharmProjects/AML-polito/dataset/ap-10k/annotations/ap10k-test-processed.jsonl
Dataset loaded


In [3]:
# Number of batches the loader yields; the loader was built with batch_size=1,
# so this equals the number of items in test_dataset
len(test_dataloader)

13252

In [4]:
# SAM checkpoints by backbone name: (download URL, checkpoint file name)
model_size = dict(
    vit_b=("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", "sam_vit_b_01ec64.pth"),
    vit_l=("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", "sam_vit_l_0b3195.pth"),
    vit_h=("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "sam_vit_h_4b8939.pth"),
)

# Backbone used for the rest of the notebook; must be a key of model_size
selected_model = "vit_b"

if using_colab:
    !pip install git+https://github.com/facebookresearch/segment-anything.git -q
    !wget -P ./AML-polito/models/ {model_size[selected_model][0]}
    !clear

from segment_anything import SamPredictor, sam_model_registry

# SAM initialization: run on the GPU when one is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

# Load the checkpoint that matches the selected backbone from <base_dir>/models
sam_checkpoint_path = os.path.join(base_dir, 'models', model_size[selected_model][1])
sam = sam_model_registry[selected_model](checkpoint=sam_checkpoint_path)
sam.to(device)
sam.eval()  # inference only

predictor = SamPredictor(sam)
transform = predictor.transform

# Model parameters read from the image encoder
# (the recorded run with vit_b printed IMG_SIZE=1024, PATCH=16)
image_encoder = predictor.model.image_encoder
IMG_SIZE = image_encoder.img_size
PATCH = int(image_encoder.patch_embed.proj.kernel_size[0])
del image_encoder  # temporary alias, not part of the notebook state

print(f"SAM '{selected_model}' loaded. IMG_SIZE={IMG_SIZE}, PATCH={PATCH}")

Device: cuda
SAM 'vit_b' loaded. IMG_SIZE=1024, PATCH=16


In [5]:
def get_scale_factor(img_size: tuple) -> float:
    """Return the original -> resized scale factor of an image.

    Args:
        img_size: (height, width) of the original image.

    Returns:
        Ratio between the resized and the original length of the longer side,
        where the resized shape comes from the predictor's preprocessing
        transform with IMG_SIZE as target.
    """
    orig_h, orig_w = img_size
    new_h, new_w = predictor.transform.get_preprocess_shape(orig_h, orig_w, IMG_SIZE)

    # Original long side -> resized long side (a square image uses the width)
    width_is_long_side = orig_w >= orig_h
    return new_w / orig_w if width_is_long_side else new_h / orig_h


def kp_src_to_featmap(kp_src_coordinates: torch.Tensor, img_src_size: tuple):
    """Map a keypoint given in original-image pixels to a feature-map cell.

    Args:
        kp_src_coordinates: keypoint coordinates as (x, y); only the first row
            of the transformed result is used.
        img_src_size: (height, width) of the original image.

    Returns:
        (xf, yf): integer column and row of the PATCH-sized cell containing the
        keypoint, clamped to the part of the feature map covered by the resized
        image (i.e. excluding padding).
    """
    src_h, src_w = img_src_size

    # Keypoint position in the resized image (before padding)
    resized_xy = predictor.transform.apply_coords_torch(kp_src_coordinates, img_src_size)[0]
    x_resized, y_resized = resized_xy

    # Actual size of the resized image
    resized_h, resized_w = predictor.transform.get_preprocess_shape(src_h, src_w, IMG_SIZE)

    # Number of feature-map columns / rows that overlap the resized image
    valid_cols = math.ceil(resized_w / PATCH)
    valid_rows = math.ceil(resized_h / PATCH)

    # Resized pixel -> cell index, then clamp into the valid region
    col = int(x_resized // PATCH)
    row = int(y_resized // PATCH)
    col = min(max(col, 0), valid_cols - 1)
    row = min(max(row, 0), valid_rows - 1)

    return col, row


def kp_featmap_to_trg(featmap_coords: tuple, trg_img_size: tuple):
    """Map a feature-map cell back to pixel coordinates of the original target image.

    Args:
        featmap_coords: (xf, yf) cell index in the feature map.
        trg_img_size: (height, width) of the original target image.

    Returns:
        (x, y) of the cell centre, expressed in original-image pixels.
    """
    col, row = featmap_coords
    scale = get_scale_factor(trg_img_size)

    # Cell centre in the resized image, then undo the resize
    x_orig = ((col + 0.5) * PATCH) / scale
    y_orig = ((row + 0.5) * PATCH) / scale
    return x_orig, y_orig

In [6]:
# Most recent output captured by each registered hook, keyed by the hook's name
activation = {}


def get_activation(name):
    """Build a forward hook that stores the module output in ``activation[name]``.

    The stored tensor is detached from the autograd graph. When the module
    returns a tuple, only its first element is kept.
    """

    def hook(model, input, output):
        captured = output[0] if isinstance(output, tuple) else output
        activation[name] = captured.detach()

    return hook


# Hook registration
# Indices of the image-encoder blocks whose output is captured.
# NOTE(review): the original note says SAM has 12 layers, which presumably refers
# to vit_b — re-check these indices when changing selected_model.
target_layers = [7, 9, 10, 11]

# Re-running this cell must not stack a second set of hooks on the same blocks,
# so hooks registered by a previous run of the cell are removed first.
for _old_handle in globals().get("hook_handles", []):
    _old_handle.remove()

# Keep the handles so the hooks can be removed later (handle.remove())
hook_handles = [
    predictor.model.image_encoder.blocks[i].register_forward_hook(get_activation(f'layer_{i}'))
    for i in target_layers
]

In [7]:
from utils.utils_correspondence import hard_argmax
from utils.utils_results import *
from tqdm import tqdm
from itertools import islice
import math
import torch
import torch.nn.functional as F

# Final results: key = layer index, value = list with one CorrespondenceResult per image pair
all_results = {li: [] for li in target_layers}

# Qualitative data per layer, collected only for the first processed pair
qualitative_by_layer = {
    li: {
        "src_img": None,
        "trg_img": None,
        "src_kps": [],
        "trg_gt_kps": [],
        "trg_pred_kps": []
    }
    for li in target_layers
}

# On Colab the whole test set is evaluated; otherwise only the first `max_images` pairs
max_images = 300
if using_colab:
    data = test_dataloader
    size = len(test_dataloader)
else:
    # Never announce more items to tqdm than the loader can actually yield
    size = min(max_images, len(test_dataloader))
    data = islice(test_dataloader, size)

img_enc = predictor.model.image_encoder  # provides the .neck applied to the hooked outputs

with torch.no_grad():
    progress = tqdm(data, total=size, desc=f"Elaborazione con SAM {selected_model}")
    # `pair_idx` instead of `iter`: the latter would shadow the builtin for the rest of the notebook
    for pair_idx, batch in enumerate(progress):

        category = batch["category"]
        src_img = batch["src_img"].to(device).unsqueeze(0)  # add batch dimension
        trg_img = batch["trg_img"].to(device).unsqueeze(0)  # add batch dimension
        orig_size_src = tuple(batch["src_imsize"][1:])  # used as (H, W) of the source
        orig_size_trg = tuple(batch["trg_imsize"][1:])  # used as (H, W) of the target

        src_resized = predictor.transform.apply_image_torch(src_img)
        trg_resized = predictor.transform.apply_image_torch(trg_img)

        # --- 1) Source forward pass; the hooks fill `activation` with the block outputs ---
        predictor.set_torch_image(src_resized, orig_size_src)
        _ = predictor.get_image_embedding()[0]  # final embedding itself is not used
        src_intermediate_emb = {k: v.detach().clone() for k, v in activation.items()}

        # --- 2) Target forward pass; `activation` is overwritten, hence the copy above ---
        predictor.set_torch_image(trg_resized, orig_size_trg)
        _ = predictor.get_image_embedding()[0]  # final embedding itself is not used
        trg_intermediate_emb = {k: v.detach().clone() for k, v in activation.items()}

        # --- 3) Keypoints & metadata ---
        src_kps = batch["src_kps"].to(device)  # indexed below as rows of (x, y)
        trg_kps = batch["trg_kps"].to(device)  # indexed below as rows of (x, y)
        pck_thr_0_05 = batch["pck_threshold_0_05"]
        pck_thr_0_1 = batch["pck_threshold_0_1"]
        pck_thr_0_2 = batch["pck_threshold_0_2"]

        # --- Valid target region: feature-map cells covered by the resized image (no padding) ---
        H_prime, W_prime = trg_resized.shape[-2:]
        hv_t = math.ceil(H_prime / PATCH)
        wv_t = math.ceil(W_prime / PATCH)

        N_kps = src_kps.shape[0]

        # Resolve the keypoints once per pair: the NaN filter and the source
        # pixel -> feature-cell mapping do not depend on the layer.
        valid_kps = []
        for i in range(N_kps):
            src_keypoint = src_kps[i].unsqueeze(0)  # [1,2] (x,y)
            trg_keypoint = trg_kps[i]  # [2]   (x,y)

            # Keypoints missing in either image are skipped
            if torch.isnan(src_keypoint).any() or torch.isnan(trg_keypoint).any():
                continue

            # original source pixel -> source feature-map cell
            x_idx, y_idx = kp_src_to_featmap(src_keypoint, orig_size_src)
            valid_kps.append((src_keypoint, trg_keypoint, x_idx, y_idx))

        # Remember the images of the first pair for the qualitative plots
        if pair_idx == 0:
            for li in target_layers:
                qualitative_by_layer[li]["src_img"] = batch["src_img"]
                qualitative_by_layer[li]["trg_img"] = batch["trg_img"]

        # ===== Loop over layers: reuse the captured activations, no extra forward pass =====
        for selected_layer in target_layers:

            src_hook = src_intermediate_emb[f"layer_{selected_layer}"]
            trg_hook = trg_intermediate_emb[f"layer_{selected_layer}"]

            # Hooked outputs are channel-last; move channels to dim 1 before the neck
            src_feat = src_hook.permute(0, 3, 1, 2).contiguous()
            trg_feat = trg_hook.permute(0, 3, 1, 2).contiguous()

            src_emb = img_enc.neck(src_feat)[0]  # [C,H,W]
            trg_emb = img_enc.neck(trg_feat)[0]  # [C,H,W]

            C_ft = trg_emb.shape[0]

            trg_valid = trg_emb[:, :hv_t, :wv_t]  # [C,hv,wv]
            trg_flat = trg_valid.permute(1, 2, 0).reshape(-1, C_ft)  # [hv*wv,C]

            distances_this_image = []

            for src_keypoint, trg_keypoint, x_idx, y_idx in valid_kps:
                # Source feature vector at the keypoint's cell
                src_vec = src_emb[:, y_idx, x_idx]

                # Cosine similarity against every valid target position
                sim = torch.cosine_similarity(trg_flat, src_vec.unsqueeze(0), dim=1)
                sim2d = sim.view(hv_t, wv_t)
                x_idx_t, y_idx_t = hard_argmax(sim2d)

                # target feature-map cell -> original target pixels
                x_pred, y_pred = kp_featmap_to_trg((x_idx_t, y_idx_t), orig_size_trg)

                if pair_idx == 0:
                    qualitative_by_layer[selected_layer]["src_kps"].append(src_keypoint.squeeze(0).tolist())
                    qualitative_by_layer[selected_layer]["trg_gt_kps"].append(trg_keypoint.tolist())
                    # Stored as plain floats so they can be plotted later regardless
                    # of the type/device returned by the helpers above
                    qualitative_by_layer[selected_layer]["trg_pred_kps"].append([float(x_pred), float(y_pred)])

                dist = math.sqrt((x_pred - trg_keypoint[0]) ** 2 + (y_pred - trg_keypoint[1]) ** 2)
                distances_this_image.append(dist)

            # One result per pair and per layer (empty `distances` when no keypoint was valid)
            all_results[selected_layer].append(
                CorrespondenceResult(
                    category=category,
                    distances=distances_this_image,
                    pck_threshold_0_05=pck_thr_0_05,
                    pck_threshold_0_1=pck_thr_0_1,
                    pck_threshold_0_2=pck_thr_0_2
                )
            )

            # Per-layer cleanup. Only names that are always bound are deleted:
            # `src_vec` / `sim` do not exist when the pair has no valid keypoint,
            # and deleting them unconditionally raised NameError in that case.
            del src_hook, trg_hook, src_feat, trg_feat, src_emb, trg_emb, trg_valid, trg_flat

        # Per-pair cleanup (the cache is released once per pair rather than once per layer)
        predictor.reset_image()
        torch.cuda.empty_cache()

Elaborazione con SAM vit_b:   0%|          | 10/13252 [00:17<6:29:13,  1.76s/it]


KeyboardInterrupt: 

In [None]:
# Per-layer report: the per-category correctness is computed once and handed to
# both compute_pckt_* helpers (presumably they print keypoint-level and
# image-level PCK — confirm in utils.utils_results)
separator = "<" + "=" * 50 + ">"
for li in target_layers:
    print(f"Layer {li}")
    correct = compute_correct_per_category(all_results[li])
    compute_pckt_keypoints(correct)
    compute_pckt_images(correct)
    print(separator)
    print("")

In [None]:
# Visualization and saving of qualitative results
import matplotlib.pyplot as plt


def tensor_to_image(img: torch.Tensor) -> np.ndarray:
    """Convert a channel-first image tensor into a channel-last uint8 NumPy array.

    Args:
        img: image tensor laid out as [C,H,W].

    Returns:
        Array laid out as [H,W,C] with dtype uint8 (values are cast, not rescaled).
    """
    # Detach and move to host memory first so the NumPy conversion is always possible
    channel_last = img.detach().cpu().permute(1, 2, 0)
    return channel_last.numpy().astype(np.uint8)


def draw_keypoints(ax, image, keypoints, color, label=None, marker='o'):
    """Show ``image`` on ``ax`` and overlay ``keypoints`` as scatter markers.

    Args:
        ax: axes to draw on.
        image: image passed to ``ax.imshow``.
        keypoints: sequence of points; element 0 is x and element 1 is y.
        color: marker colour.
        label: optional legend label for the markers.
        marker: marker style.
    """
    ax.imshow(image)

    has_points = len(keypoints) > 0
    if has_points:
        xs, ys = [], []
        for point in keypoints:
            xs.append(point[0])
            ys.append(point[1])
        ax.scatter(xs, ys, c=color, s=40, marker=marker, label=label)

    # Axes are hidden whether or not any keypoint was drawn
    ax.axis("off")


def _save_keypoint_figure(image, title, points, colors, out_path, dpi, hollow=False, gt_points=None):
    """Draw numbered keypoints over ``image`` and save the figure to ``out_path``.

    Args:
        image: image array passed to ``imshow``.
        title: figure title.
        points: list of [x, y] keypoints to draw; the list index is the printed label.
        colors: callable mapping a keypoint index to a colour (e.g. a colormap).
        out_path: destination file.
        dpi: resolution used when saving.
        hollow: draw unfilled markers (used for predictions).
        gt_points: optional list of [x, y]; when given, a dashed line connects
            ``gt_points[i]`` to ``points[i]`` to show the error.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")

    for i, (x, y) in enumerate(points):
        # Plain floats: stored coordinates may be tensors rather than numbers
        x, y = float(x), float(y)
        c = colors(i)

        if hollow:
            ax.scatter(x, y, s=60, facecolors="none", edgecolors=[c], linewidths=2, marker="o")
        else:
            ax.scatter(x, y, s=60, color=c, marker="o")
        ax.text(x + 6, y + 6, str(i), color=c, fontsize=10)

        if gt_points is not None:
            xg, yg = gt_points[i]
            ax.plot([float(xg), x], [float(yg), y], color=c, linewidth=1, linestyle="--")

    fig.tight_layout()
    fig.savefig(out_path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)


def plot_keypoints_save(
        src_img_chw: torch.Tensor,
        trg_img_chw: torch.Tensor,
        qualitative_result: dict,
        dpi: int = 200,
        tag: str = ""
) -> None:
    """Save three figures for one image pair: source, target ground truth, target prediction.

    Files are written to ``<base_dir>/qualitative-results`` as
    ``<selected_model>[_<tag>]_src.png``, ``..._trg_gt.png`` and ``..._trg_pred.png``.

    Args:
        src_img_chw: source image tensor, channel-first.
        trg_img_chw: target image tensor, channel-first.
        qualitative_result: dict with the lists "src_kps", "trg_gt_kps" and
            "trg_pred_kps", each holding [x, y] entries in matching order.
        dpi: resolution of the saved figures.
        tag: optional suffix inserted in the file names (e.g. the layer) so that
            several results do not overwrite each other; empty keeps the
            original file names.
    """
    # --- output folder ---
    save_dir = os.path.join(base_dir, "qualitative-results")
    os.makedirs(save_dir, exist_ok=True)

    prefix = f"{selected_model}_{tag}" if tag else f"{selected_model}"
    src_path = os.path.join(save_dir, f"{prefix}_src.png")
    gt_path = os.path.join(save_dir, f"{prefix}_trg_gt.png")
    pr_path = os.path.join(save_dir, f"{prefix}_trg_pred.png")

    # --- images ---
    src_img = tensor_to_image(src_img_chw)
    trg_img = tensor_to_image(trg_img_chw)

    src_kps = qualitative_result["src_kps"]  # [[x,y],...]
    trg_gt = qualitative_result["trg_gt_kps"]  # [[x,y],...]
    trg_pr = qualitative_result["trg_pred_kps"]  # [[x,y],...]

    # Only keypoints present in all three lists are drawn, with one colour per index
    n = min(len(src_kps), len(trg_gt), len(trg_pr))
    cmap = plt.get_cmap("tab10", max(n, 1))

    _save_keypoint_figure(src_img, "Source", src_kps[:n], cmap, src_path, dpi)
    _save_keypoint_figure(trg_img, "Target GT", trg_gt[:n], cmap, gt_path, dpi)
    # Predictions: hollow markers plus a dashed GT -> prediction error line
    _save_keypoint_figure(trg_img, "Target Pred", trg_pr[:n], cmap, pr_path, dpi,
                          hollow=True, gt_points=trg_gt[:n])


# One set of figures per evaluated layer. The images stored together with the
# keypoints are used (first processed pair): `batch` would be the LAST pair of
# the evaluation loop and would not match the stored keypoints.
for li in target_layers:
    layer_result = qualitative_by_layer[li]
    if layer_result["src_img"] is None:
        continue  # the evaluation loop has not processed any pair yet
    plot_keypoints_save(
        src_img_chw=layer_result["src_img"],
        trg_img_chw=layer_result["trg_img"],
        qualitative_result=layer_result,
        tag=f"layer{li}"
    )



