In [7]:
import torch
from data.spair import SPairDataset
from torch.utils.data import DataLoader
import os

# Resolve the project root: on Google Colab the repository lives in an
# 'AML-polito' sub-folder of the working directory, otherwise the working
# directory itself is the root.
_notebook_cwd = os.path.abspath(os.path.curdir)
using_colab = 'google.colab' in str(get_ipython())
base_dir = os.path.join(_notebook_cwd, 'AML-polito') if using_colab else _notebook_cwd


def collate_single(batch_list):
    """Collate function for ``batch_size=1`` loaders.

    Returns the first element of ``batch_list`` unchanged instead of stacking
    it, so the sample keeps the structure produced by the dataset.
    """
    first_sample = batch_list[0]
    return first_sample


dataset_size = 'large'  # 'small' or 'large'

# Build the test split and wrap it in a loader that hands back one
# un-stacked sample per step (see collate_single).
test_dataset = SPairDataset(dataset_size=dataset_size, datatype='test')

test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=2,
    collate_fn=collate_single,
)
print("Dataset loaded")

Dataset loaded


In [2]:
model_size = {
    "vit_b": ("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", "sam_vit_b_01ec64.pth"),
    "vit_l": ("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", "sam_vit_l_0b3195.pth"),
    "vit_h": ("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "sam_vit_h_4b8939.pth")
}

selected_model = "vit_b"

if using_colab:
    !pip install git+https://github.com/facebookresearch/segment-anything.git -q
    !wget -P ./AML-polito/models/ {model_size[selected_model][0]}
    !clear

from segment_anything import SamPredictor, sam_model_registry

# SAM initialization
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

sam_checkpoint_path = os.path.join(base_dir, 'models', model_size[selected_model][1])
sam = sam_model_registry[selected_model](checkpoint=sam_checkpoint_path)
sam.to(device)
sam.eval()
predictor = SamPredictor(sam)
transform = predictor.transform

# Model parameters
IMG_SIZE = predictor.model.image_encoder.img_size  # 1024
PATCH = int(predictor.model.image_encoder.patch_embed.proj.kernel_size[0])  # 16

print(f"SAM '{selected_model}' loaded. IMG_SIZE={IMG_SIZE}, PATCH={PATCH}")

Device: cuda
SAM 'vit_b' loaded. IMG_SIZE=1024, PATCH=16


In [3]:
import shutil
from pathlib import Path
from utils.utils_correspondence import hard_argmax
from utils.utils_results import *
from tqdm import tqdm
from itertools import islice
import math
import torch
import torch.nn.functional as F

# Directory that receives one "<category>.pth" file of image embeddings per category.
save_root = Path(base_dir) / "data" / "features"
save_root.mkdir(parents=True, exist_ok=True)
print("Saving features to:", save_root)

current_cat = None
output_dict = {}

with torch.no_grad():
    for img_tensor, img_size, img_category, img_name in tqdm(
            test_dataset.iter_images(),
            total=test_dataset.num_images(),
            desc="Generating embeddings"
    ):
        if current_cat is None:
            # First image: open the first category group.
            current_cat = img_category
        elif img_category != current_cat:
            # Category changed: persist the finished group and start a new one.
            torch.save(output_dict, save_root / f"{current_cat}.pth")
            current_cat, output_dict = img_category, {}

        # Add a batch axis, resize with SAM's own transform and run the encoder.
        batched_img = img_tensor.to(device).unsqueeze(0)  # [1,3,H,W]
        predictor.set_torch_image(
            predictor.transform.apply_image_torch(batched_img),  # [1,3,H',W']
            tuple(img_size[1:]),  # (H,W)
        )

        # Keep the embedding on the CPU so the dict does not pin CUDA memory.
        # (Appending .to(torch.float16) would save disk space.)
        output_dict[img_name] = predictor.get_image_embedding()[0].detach().cpu()  # [C,h',w']

# Persist the last category group.
if current_cat is not None and output_dict:
    torch.save(output_dict, save_root / f"{current_cat}.pth")




Saving features to: /home/pasquale/PycharmProjects/AML-polito/data/features


Generating embeddings: 100%|██████████| 1800/1800 [25:27<00:00,  1.18it/s]


In [4]:
def get_scale_factor(img_size: tuple) -> float:
    """Return the resize factor SAM's preprocessing applies to an image.

    Args:
        img_size: original image size as ``(height, width)``.

    Returns:
        Ratio between the resized and the original length of the longer side
        (the width when both sides are equal).
    """
    orig_h, orig_w = img_size
    new_h, new_w = predictor.transform.get_preprocess_shape(orig_h, orig_w, IMG_SIZE)

    # Measure the scale along the longer side of the original image.
    width_is_longer = orig_w >= orig_h
    return new_w / orig_w if width_is_longer else new_h / orig_h


def kp_src_to_featmap(kp_src_coordinates: torch.Tensor, img_src_size: tuple):
    """Map a source keypoint from original-image pixels to a feature-map cell.

    Args:
        kp_src_coordinates: keypoint coordinates as (x, y); only the first row
            of the transformed result is used.
        img_src_size: original source image size as ``(height, width)``.

    Returns:
        ``(xf, yf)`` integer cell indices, clamped to the part of the feature
        map that corresponds to the resized image (padding excluded).
    """
    src_h, src_w = img_src_size

    # Keypoint position in the resized (not yet padded) image.
    resized_xy = predictor.transform.apply_coords_torch(kp_src_coordinates, img_src_size)[0]
    x_resized, y_resized = resized_xy

    # Actual size of the resized image.
    resized_h, resized_w = predictor.transform.get_preprocess_shape(src_h, src_w, IMG_SIZE)

    # Number of feature-map columns/rows that hold real image content.
    valid_cols = math.ceil(resized_w / PATCH)
    valid_rows = math.ceil(resized_h / PATCH)

    # Resized pixel -> patch index, clamped to the valid region.
    col = min(max(int(x_resized // PATCH), 0), valid_cols - 1)
    row = min(max(int(y_resized // PATCH), 0), valid_rows - 1)

    return col, row


def kp_featmap_to_trg(featmap_coords: tuple, trg_img_size: tuple):
    """Map a feature-map cell back to pixel coordinates of the target image.

    Args:
        featmap_coords: ``(xf, yf)`` cell indices in the feature map.
        trg_img_size: original target image size as ``(height, width)``.

    Returns:
        ``(x, y)`` of the cell centre expressed in original target pixels.
    """
    col, row = featmap_coords
    scale = get_scale_factor(trg_img_size)

    # Centre of the patch in the resized image, then undo the resize.
    x_orig = (col + 0.5) * PATCH / scale
    y_orig = (row + 0.5) * PATCH / scale
    return x_orig, y_orig

In [8]:
from collections import OrderedDict

# Least-recently-used cache of per-category embedding dicts, oldest entry first.
cat_cache = OrderedDict()
MAX_CATS_IN_RAM = 2  # or 4


def get_cat_dict(category):
    """Return the embeddings dict of ``category``, loading it from disk on a miss.

    At most ``MAX_CATS_IN_RAM`` categories stay cached; the least recently
    used ones are dropped first.
    """
    if category not in cat_cache:
        loaded = torch.load(save_root / f"{category}.pth", map_location="cpu")
        # A newly inserted key already sits at the most-recent end.
        cat_cache[category] = loaded

        # Evict from the least-recent end until the cache fits again.
        while len(cat_cache) > MAX_CATS_IN_RAM:
            cat_cache.popitem(last=False)
        return loaded

    # Cache hit: mark the entry as most recently used.
    cat_cache.move_to_end(category)
    return cat_cache[category]

In [10]:
results = []

# The predictor is not used inside the loop (embeddings come from the on-disk
# cache), so release the image it may still hold and the CUDA cache once here
# instead of once per pair.
predictor.reset_image()
torch.cuda.empty_cache()

with torch.no_grad():
    for batch in tqdm(test_dataloader, total=len(test_dataloader), desc=f"Elaborazione con SAM {selected_model}"):

        category = batch["category"]
        orig_size_src = tuple(batch["src_imsize"][1:])  # (Hs, Ws)
        orig_size_trg = tuple(batch["trg_imsize"][1:])  # (Ht, Wt)

        cat_dict = get_cat_dict(category)

        src_emb = cat_dict[batch['src_imname']]         # [C,hs,ws]
        trg_emb = cat_dict[batch['trg_imname']]         # [C,ht,wt]

        # Keypoints & metadata.
        # The embeddings are loaded with map_location="cpu", so keypoints are
        # kept on the CPU as well; this avoids a device sync per keypoint.
        src_kps = batch["src_kps"].cpu()  # [N,2]
        trg_kps = batch["trg_kps"].cpu()  # [N,2]
        pck_thr_0_05 = batch["pck_threshold_0_05"]
        pck_thr_0_1 = batch["pck_threshold_0_1"]
        pck_thr_0_2 = batch["pck_threshold_0_2"]

        # Target valid region: the part of the feature map covered by the
        # resized image, i.e. excluding the padded area.
        Ht, Wt = orig_size_trg
        H_prime, W_prime = predictor.transform.get_preprocess_shape(
            Ht, Wt, predictor.transform.target_length
        )

        hv_t = (H_prime + PATCH - 1) // PATCH
        wv_t = (W_prime + PATCH - 1) // PATCH

        N_kps = src_kps.shape[0]

        C_ft = trg_emb.shape[0]

        trg_valid = trg_emb[:, :hv_t, :wv_t]  # [C,hv,wv]
        trg_flat = trg_valid.permute(1, 2, 0).reshape(-1, C_ft)  # [Pvalid,C]

        distances_this_image = []

        for i in range(N_kps):
            src_keypoint = src_kps[i].unsqueeze(0)  # [1,2] (x,y)
            trg_keypoint = trg_kps[i]  # [2]   (x,y)

            # Skip keypoints that are NaN in either image.
            if torch.isnan(src_keypoint).any() or torch.isnan(trg_keypoint).any():
                continue

            # original source pixels -> source feature-map cell
            x_idx, y_idx = kp_src_to_featmap(src_keypoint, orig_size_src)

            # source feature vector
            src_vec = src_emb[:, y_idx, x_idx]  # [256]

            # cosine similarity against every valid target position
            sim = torch.cosine_similarity(trg_flat, src_vec.unsqueeze(0), dim=1)  # [P]

            sim2d = sim.view(hv_t, wv_t)
            x_idx_t, y_idx_t = hard_argmax(sim2d)

            # target feature-map cell -> original target pixels
            x_pred, y_pred = kp_featmap_to_trg((x_idx_t, y_idx_t), orig_size_trg)

            dist = math.sqrt((x_pred - trg_keypoint[0]) ** 2 + (y_pred - trg_keypoint[1]) ** 2)
            distances_this_image.append(dist)

        results.append(
            CorrespondenceResult(
                category=category,
                distances=distances_this_image,
                pck_threshold_0_05=pck_thr_0_05,
                pck_threshold_0_1=pck_thr_0_1,
                pck_threshold_0_2=pck_thr_0_2
            )
        )

        # No explicit `del` of the per-pair names: `src_vec` and `sim` are
        # unbound when a pair has no valid keypoint, so deleting them raised
        # NameError. All of them are rebound on the next iteration anyway.

Elaborazione con SAM vit_b: 100%|██████████| 12234/12234 [06:58<00:00, 29.22it/s]


In [11]:
# Compute and print results
# NOTE(review): the compute_* helpers come from `from utils.utils_results import *`;
# their implementation is not visible in this notebook.
# Presumably evaluates the collected distances against the PCK thresholds stored
# in each result, grouped by category — confirm against utils_results.
correct = compute_correct_per_category(results)
# Presumably reports PCK aggregated per keypoint and per image respectively — confirm.
compute_pckt_keypoints(correct)
compute_pckt_images(correct)

PCK Results per keypoints (%):
       Category   PCK 0.05    PCK 0.1    PCK 0.2
0     aeroplane  17.938931  26.045075  39.912759
1       bicycle  10.191412  16.063114  27.340921
2          bird  20.377185  30.910764  45.216191
3          boat  10.336680  18.133491  32.959244
4        bottle  17.284335  26.720648  41.902834
5           bus  13.421170  18.133095  28.115230
6           car  15.033408  20.545657  30.846325
7           cat  28.170378  39.415941  54.969345
8         chair   9.364732  13.910186  22.426068
9           cow  19.186264  27.323628  40.854797
10          dog  12.911287  20.741358  35.485214
11        horse   9.885204  16.347789  27.848639
12    motorbike   8.877131  16.460905  26.807760
13       person  18.179702  30.726257  45.297952
14  pottedplant  12.887511  23.095660  33.392383
15        sheep   9.000000  15.631579  26.842105
16        train  19.630485  30.542725  49.907621
17    tvmonitor  14.320719  24.612034  40.525456
18          All  14.833141  23.075550 