In [None]:
import torch
from data.spair import SPairDataset
from torch.utils.data import DataLoader
import os
from pathlib import Path

from utils.utils_featuremaps import PreComputedFeaturemaps

using_colab = 'google.colab' in str(get_ipython())
base_dir = os.path.abspath(os.path.curdir)

if using_colab:
    base_dir = os.path.join(os.path.abspath(os.path.curdir), 'AML-polito')


def collate_single(batch_list):
    return batch_list[0]

save_dir = Path(base_dir) / "data" / "features"
dataset_size = 'large'  # 'small' or 'large'

# Load dataset and construct dataloader

test_dataset = SPairDataset(datatype='test', dataset_size=dataset_size)

test_dataloader = DataLoader(test_dataset, num_workers=4, batch_size=1, collate_fn=collate_single)
print("Dataset loaded")

In [None]:
model_size = {
    "vit_b": ("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", "sam_vit_b_01ec64.pth"),
    "vit_l": ("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", "sam_vit_l_0b3195.pth"),
    "vit_h": ("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "sam_vit_h_4b8939.pth")
}

selected_model = "vit_b"

if using_colab:
    !pip install git+https://github.com/facebookresearch/segment-anything.git -q
    !wget -P ./AML-polito/models/ {model_size[selected_model][0]}
    !clear

from segment_anything import SamPredictor, sam_model_registry

# SAM initialization
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

sam_checkpoint_path = os.path.join(base_dir, 'models', model_size[selected_model][1])
sam = sam_model_registry[selected_model](checkpoint=sam_checkpoint_path)
sam.to(device)
sam.eval()
predictor = SamPredictor(sam)
transform = predictor.transform

# Model parameters
IMG_SIZE = predictor.model.image_encoder.img_size  # 1024
PATCH = int(predictor.model.image_encoder.patch_embed.proj.kernel_size[0])  # 16

print(f"SAM '{selected_model}' loaded. IMG_SIZE={IMG_SIZE}, PATCH={PATCH}")

In [None]:
from pathlib import Path
from tqdm import tqdm

save_dir = Path(base_dir) / "data" / "features"
print("Saving features to:", save_dir)

with torch.no_grad():
    with PreComputedFeaturemaps(save_dir, device=device) as pcm:
        for img_tensor, img_size, img_category, img_name in tqdm(
                test_dataset.iter_images(),
                total=test_dataset.num_images(),
                desc="Generating embeddings"
        ):
            img_tensor = img_tensor.to(device).unsqueeze(0)  # [1,3,H,W]
            orig_size = tuple(img_size[1:])  # (H,W)
            resized = predictor.transform.apply_image_torch(img_tensor)  # [1,3,H',W']
            predictor.set_torch_image(resized, orig_size)
            img_emb = predictor.get_image_embedding()[0]  # [C,h',w']

            pcm.save_featuremaps(img_emb, img_category, img_name)



In [None]:
import math

def kps_src_to_featmap(kps_src: torch.Tensor, img_src_size: torch.Size):
    img_h = int(img_src_size[-2])
    img_w = int(img_src_size[-1])

    # (N,2) coords nella resized (no padding)
    coords = predictor.transform.apply_coords_torch(kps_src, (img_h, img_w))  # (N,2)

    img_resized_h, img_resized_w = predictor.transform.get_preprocess_shape(img_h, img_w, IMG_SIZE)

    xf = torch.floor(coords[:, 0] / PATCH).long()
    yf = torch.floor(coords[:, 1] / PATCH).long()

    wv = math.ceil(img_resized_w / PATCH)
    hv = math.ceil(img_resized_h / PATCH)

    xf = xf.clamp(0, wv - 1)
    yf = yf.clamp(0, hv - 1)

    return torch.stack([xf, yf], dim=1)  # (N,2) (x_idx,y_idx)


def kp_featmap_to_trg(y_featmap, x_featmap, trg_img_size: torch.Size):
    img_h = int(trg_img_size[-2])
    img_w = int(trg_img_size[-1])

    resized_h, resized_w = predictor.transform.get_preprocess_shape(img_h, img_w, IMG_SIZE)

    # token center in preprocessed (padded) image coords
    yr = (y_featmap + 0.5) * 16
    xr = (x_featmap + 0.5) * 16

    # discard tokens that fall into padding
    if xr < 0 or yr < 0 or xr >= resized_w or yr >= resized_h:
        return None

    scale = resized_w / img_w  # uniforme
    y_trg = yr / scale
    x_trg = xr / scale

    return y_trg, x_trg

In [12]:
from utils.utils_correspondence import hard_argmax
from utils.utils_results import CorrespondenceResult
from tqdm import tqdm

results = []
torch.cuda.empty_cache()

with torch.no_grad():
    with PreComputedFeaturemaps(save_dir, device=device) as pcm:
        for batch in tqdm(
            test_dataloader,
            total=len(test_dataloader),
            desc=f"Elaborazione con SAM {selected_model}"
        ):
            category = batch["category"]

            orig_size_src = batch["src_imsize"]  # torch.Size([C, Hs, Ws]) o simile
            orig_size_trg = batch["trg_imsize"]  # torch.Size([C, Ht, Wt]) o simile

            src_imname = batch["src_imname"]
            trg_imname = batch["trg_imname"]

            # Embeddings SAM (C, h, w) tipicamente (256, 64, 64) o simili
            src_emb = pcm.load_featuremaps(category, src_imname)  # [C,hs,ws]
            trg_emb = pcm.load_featuremaps(category, trg_imname)  # [C,ht,wt]

            # Keypoints (N,2) in pixel originali, ordine (x,y)
            src_kps = batch["src_kps"].to(device)
            trg_kps = batch["trg_kps"].to(device)

            # -------------------------
            # Target: dimensioni originali + dimensioni resize (no padding)
            # -------------------------
            Ht = int(orig_size_trg[-2])
            Wt = int(orig_size_trg[-1])

            H_prime, W_prime = predictor.transform.get_preprocess_shape(
                Ht, Wt, predictor.transform.target_length
            )

            # Regione valida in token (no padding)
            hv_t = (H_prime + PATCH - 1) // PATCH
            wv_t = (W_prime + PATCH - 1) // PATCH

            # -------------------------
            # Prepara target flat sulla regione valida
            # -------------------------
            C_ft = trg_emb.shape[0]
            trg_valid = trg_emb[:, :hv_t, :wv_t]                       # [C, hv, wv]
            trg_flat = trg_valid.permute(1, 2, 0).reshape(-1, C_ft)    # [Pvalid, C]

            # -------------------------
            # Mappa i keypoint SRC -> indici featuremap SRC (token space)
            # (Assumo che kps_src_to_featmap ritorni (N,2) (x_idx, y_idx) long)
            # -------------------------
            src_kps_idx = kps_src_to_featmap(src_kps, orig_size_src)   # (N,2) (x_idx,y_idx)

            N_kps = src_kps_idx.shape[0]
            distances_this_image = []

            # scala SAM (uniforme sul lato lungo)
            if Wt >= Ht:
                scale = W_prime / Wt
            else:
                scale = H_prime / Ht

            # -------------------------
            # Loop keypoints
            # -------------------------
            for i in range(N_kps):
                src_idx = src_kps_idx[i]   # (x_idx, y_idx) su featuremap
                trg_kp = trg_kps[i]        # (x,y) originale

                if torch.isnan(src_idx).any() or torch.isnan(trg_kp).any():
                    continue

                x_idx = int(src_idx[0].item())
                y_idx = int(src_idx[1].item())

                # Feature vector sorgente (C,)
                src_vec = src_emb[:, y_idx, x_idx]  # [C]

                # Cosine similarity su tutte le posizioni target valide
                sim = torch.cosine_similarity(trg_flat, src_vec.unsqueeze(0), dim=1)  # [Pvalid]

                # (hv_t, wv_t) similarity map in token space
                sim2d = sim.view(hv_t, wv_t)

                # ------------------------------------------------------------
                # UPSAMPLE SOLO DELLA SIMILARITY MAP (hard argmax più fine)
                # token space (hv,wv) -> resized pixel space (H_prime,W_prime)
                # ------------------------------------------------------------
                sim_r = torch.nn.functional.interpolate(
                    sim2d[None, None],               # (1,1,hv,wv)
                    size=(H_prime, W_prime),         # resized (no pad)
                    mode="bilinear",
                    align_corners=False
                )[0, 0]                              # (H_prime, W_prime)

                # argmax in resized pixel coords
                x_r, y_r = hard_argmax(sim_r)

                # resized -> originale
                x_pred = x_r / scale
                y_pred = y_r / scale

                # distanza in pixel originali
                dx = x_pred - float(trg_kp[0])
                dy = y_pred - float(trg_kp[1])
                dist = math.sqrt(dx * dx + dy * dy)
                distances_this_image.append(dist)

            # salva risultato
            results.append(
                CorrespondenceResult(
                    category=category,
                    distances=distances_this_image,
                    pck_threshold_0_05=batch["pck_threshold_0_05"],
                    pck_threshold_0_1=batch["pck_threshold_0_1"],
                    pck_threshold_0_2=batch["pck_threshold_0_2"]
                )
            )

Elaborazione con SAM vit_b: 100%|██████████| 12234/12234 [02:02<00:00, 100.21it/s]


In [13]:
from utils.utils_results import compute_pckt_images, compute_correct_per_category, compute_pckt_keypoints

# Compute and print results
correct = compute_correct_per_category(results)
compute_pckt_keypoints(correct)
compute_pckt_images(correct)

PCK Results per keypoints (%):
       Category   PCK 0.05    PCK 0.1    PCK 0.2
0     aeroplane  17.938931  26.045075  39.912759
1       bicycle  10.194049  16.067270  27.347995
2          bird  20.414747  30.967742  45.299539
3          boat  10.366214  18.133491  32.959244
4        bottle  17.274143  26.728972  41.915888
5           bus  13.424168  18.137145  28.121510
6           car  15.033408  20.545657  30.846325
7           cat  28.170378  39.415941  54.969345
8         chair   9.364732  13.910186  22.426068
9           cow  19.186264  27.323628  40.854797
10          dog  12.919358  20.795999  35.507397
11        horse   9.868141  16.333475  27.839217
12    motorbike   8.847737  16.460905  26.807760
13       person  18.179702  30.726257  45.297952
14  pottedplant  12.887511  23.095660  33.392383
15        sheep   9.000000  15.631579  26.842105
16        train  19.630485  30.542725  49.896074
17    tvmonitor  14.326570  24.622089  40.542013
18          All  14.834808  23.082433 