In [1]:
# Mount Google Drive at /content/drive (later cells read the archive from its MyDrive folder).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Paths: the archive on the mounted Drive, its local copy, and the directory it is extracted into.
zip_su_drive = '/content/drive/MyDrive/semantic_correspondence.zip'
zip_locale = '/content/semantic_correspondence.zip'
cartella_destinazione = '/content/'

# Copy the zip from Drive to the local path (the returned destination path is the cell's output).
import shutil
shutil.copy(zip_su_drive, zip_locale)

'/content/semantic_correspondence.zip'

In [3]:
# Unpack the local archive into the destination directory (created first if missing).
import os
import zipfile

os.makedirs(cartella_destinazione, exist_ok=True)
with zipfile.ZipFile(zip_locale, mode='r') as archive:
    archive.extractall(path=cartella_destinazione)


In [4]:
# Report the name of CUDA device 0, or 'No GPU' when CUDA is unavailable.
import torch

if torch.cuda.is_available():
    gpu_label = torch.cuda.get_device_name(0)
else:
    gpu_label = 'No GPU'
print(f"\n✓ GPU: {gpu_label}")



✓ GPU: Tesla T4


In [8]:

# Absolute path of the SPair71k folder inside the extracted archive.
# NOTE(review): no visible cell reads `base`; the evaluation cell defines its own relative `BASE` — confirm whether this is still needed.
base = '/content/semantic_correspondence/SPair71k'

In [None]:
import json
from collections import defaultdict
import numpy as np
import torch
import torch.nn.functional as F
import time
import os
from datetime import datetime
import pandas as pd
from pathlib import Path
import shutil
import sys
# Make the extracted project root and its vendored segment_anything folder importable.
# Both inserts are guarded so re-running the cell neither duplicates entries nor changes
# their relative priority (segment_anything stays ahead of the project root).
research_path = "/content/semantic_correspondence/models/segment_anything"
for _import_root in ('/content/semantic_correspondence', research_path):
    if _import_root not in sys.path:
        sys.path.insert(0, _import_root)

from SPair71k.devkit.SPairDataset import SPairDataset
from helper_functions import extract_dense_features, extract_dense_features_SAM, pixel_to_patch_coord, patch_to_pixel_coord
from matching_strategies import find_best_match_argmax, find_best_match_window_softargmax
from pck import compute_pck_spair71k
from models.dinov3.dinov3.models.vision_transformer import vit_base as dinov3_vit_base
from models.dinov2.dinov2.models.vision_transformer import vit_base as dinov2_vit_base
from models.segment_anything.segment_anything import sam_model_registry

# ==================== CONFIG ====================
# Square input side (pixels) and patch size used for each backbone.
IMG_SIZE_DINOV2, PATCH_SIZE_DINOV2 = 518, 14
IMG_SIZE_DINOV3, PATCH_SIZE_DINOV3 = 512, 16
IMG_SIZE_SAM, PATCH_SIZE_SAM = 512, 16

# Dataset locations; these are relative paths, resolved against the working directory.
BASE = 'semantic_correspondence/SPair71k'
PAIR_ANN_PATH = f'{BASE}/PairAnnotation'
LAYOUT_PATH = f'{BASE}/Layout'
IMAGE_PATH = f'{BASE}/JPEGImages'

# Dataset variant and PCK settings passed to the dataset / metric code.
DATASET_SIZE = 'large'
PCK_ALPHA = 0.1
THRESHOLDS = [0.05, 0.1, 0.2]

# Show the working directory the relative paths above depend on.
print(os.getcwd())

# Fine-tuned checkpoint file for each backbone, keyed by model name.
b = "semantic_correspondence/models"
CHECKPOINT_PATHS = dict(
    DINOv2=f"{b}/dinov2/weights/dinov2_vitb14_finetuned_only_model_10temp.pth",
    DINOv3=f"{b}/dinov3/weights/finetuned/dinov3_vitb16_finetuned_3bl_0.0001lr_15t.pth",
    SAM=f"{b}/segment_anything/weights/finetuned/SAM_finetuned_4bl_15t_0.0001lr.pth",
)


# ==================== HELPER FUNCTIONS ====================

def _unwrap_state_dict(checkpoint):
    """Return (state_dict, was_wrapped), unwrapping a 'model_state_dict' entry when the checkpoint has one."""
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        return checkpoint['model_state_dict'], True
    return checkpoint, False


def load_models(device):
    """Build the DINOv2, DINOv3 and SAM (vit_b) models and load their fine-tuned weights.

    Every checkpoint may be either a bare state dict or a dict that stores the
    weights under 'model_state_dict'; both layouts are accepted for all three
    models.

    Args:
        device: torch device (or device string) the checkpoints are mapped to
            and the models are moved to.

    Returns:
        Tuple (dinov2, dinov3, sam), each moved to `device` and set to eval mode.
    """
    print("Loading models...")

    # NOTE(review): depending on the PyTorch version, torch.load may unpickle arbitrary
    # objects — only load checkpoint files from a trusted source.

    # DINOv2
    dinov2 = dinov2_vit_base(
        img_size=(IMG_SIZE_DINOV2, IMG_SIZE_DINOV2),
        patch_size=PATCH_SIZE_DINOV2,
        num_register_tokens=0,
        block_chunks=0,
        init_values=1.0,
    )
    state_dinov2, _ = _unwrap_state_dict(
        torch.load(CHECKPOINT_PATHS["DINOv2"], map_location=device)
    )
    dinov2.load_state_dict(state_dinov2, strict=True)
    dinov2.to(device)
    dinov2.eval()

    # DINOv3
    dinov3 = dinov3_vit_base(
        img_size=(IMG_SIZE_DINOV3, IMG_SIZE_DINOV3),
        patch_size=PATCH_SIZE_DINOV3,
        n_storage_tokens=4,
        layerscale_init=1.0,
        mask_k_bias=True,
    )
    state_dinov3, _ = _unwrap_state_dict(
        torch.load(CHECKPOINT_PATHS["DINOv3"], map_location=device)
    )
    dinov3.load_state_dict(state_dinov3, strict=True)
    dinov3.to(device)
    dinov3.eval()

    # SAM: build the architecture without the registry loading any weights (checkpoint=None),
    # then load the fine-tuned weights explicitly.
    sam = sam_model_registry["vit_b"](checkpoint=None)
    sam.to(device)

    print(f"Loading finetuned SAM checkpoint from {CHECKPOINT_PATHS['SAM']}")
    state_sam, sam_was_wrapped = _unwrap_state_dict(
        torch.load(CHECKPOINT_PATHS["SAM"], map_location=device)
    )
    sam.load_state_dict(state_sam)
    if sam_was_wrapped:
        print("Successfully loaded 'model_state_dict' from checkpoint.")
    else:
        print("Successfully loaded checkpoint directly as state_dict.")
    sam.eval()

    print("✓ All models loaded successfully")
    return dinov2, dinov3, sam


def normalize_features(features):
    """L2-normalize features along the last (feature) dimension.

    The feature axis is always taken to be the last one, so the documented
    [B, H, W, D] and [H*W, D] layouts behave as before, and other ranks
    (e.g. [H, W, D] or a single [D] vector) are normalized over D as well
    instead of over axis 1.

    Args:
        features: tensor whose last axis holds the feature vector.

    Returns:
        Tensor of the same shape, normalized with F.normalize(p=2) over the last axis.
    """
    return F.normalize(features, p=2, dim=-1)


def _resize_square(images, side):
    """Bilinearly resize an image batch to side x side."""
    return F.interpolate(images, size=(side, side), mode='bilinear', align_corners=False)


def _flatten_target_features(tgt_feat):
    """Drop the leading batch axis and return (L2-normalized [H*W, D] rows, H, W)."""
    grid = tgt_feat.squeeze(0)
    grid_h, grid_w, feat_dim = grid.shape
    flat = F.normalize(grid.reshape(grid_h * grid_w, feat_dim), dim=1)
    return flat, grid_h, grid_w


def _keypoint_similarity_map(src_feat, patch_x, patch_y, tgt_flat, grid_h, grid_w):
    """Cosine similarity of one source patch descriptor against every target patch, as a [grid_h, grid_w] map."""
    src_vec = F.normalize(src_feat[0, patch_y, patch_x, :], dim=0)
    return F.cosine_similarity(src_vec.unsqueeze(0), tgt_flat, dim=1).view(grid_h, grid_w)


def _resize_sim_map(sim_map, out_h, out_w):
    """Bilinearly resample a 2-D similarity map to (out_h, out_w); returned unchanged if already that size."""
    if tuple(sim_map.shape) == (out_h, out_w):
        return sim_map
    return F.interpolate(sim_map.unsqueeze(0).unsqueeze(0), size=(out_h, out_w),
                         mode='bilinear', align_corners=False).squeeze(0).squeeze(0)


def evaluate_ensemble_with_params(
    models_dict,
    dataset,
    device,
    K,
    temperature,
    weights,
    thresholds=None
):
    """
    Evaluate the three-model ensemble with weighted-average (score-level) fusion.

    For every source keypoint, a cosine-similarity map is built per model,
    resampled to the SAM feature grid, combined with `weights`, and the match
    is picked with the windowed soft-argmax. PCK is then computed per pair.

    Args:
        models_dict: dict with 'dinov2', 'dinov3', 'sam' models
        dataset: iterable of samples providing 'src_img', 'trg_img',
            'src_imsize', 'trg_imsize', 'src_kps', 'trg_kps', 'trg_bbox'
            and 'category'
        device: torch device
        K: window size for softargmax
        temperature: softmax temperature
        weights: [w_dinov2, w_dinov3, w_sam] for weighted fusion; must have
            exactly three entries (expected to sum to 1, which is not checked)
        thresholds: PCK thresholds (defaults to THRESHOLDS)

    Returns:
        per_image_metrics: list of dicts, one per sample, with 'category' and
        'pck_scores' (threshold -> first value returned by compute_pck_spair71k)

    Raises:
        ValueError: if `weights` does not contain exactly three entries.
    """
    if thresholds is None:
        thresholds = THRESHOLDS

    # Fail before any inference: a shorter list would crash on the first keypoint,
    # a longer one would silently ignore the extra weights.
    if len(weights) != 3:
        raise ValueError(
            f"weights must contain exactly 3 values [DINOv2, DINOv3, SAM], got {len(weights)}"
        )

    per_image_metrics = []
    dinov2, dinov3, sam = models_dict['dinov2'], models_dict['dinov3'], models_dict['sam']

    with torch.no_grad():
        for idx, sample in enumerate(dataset):
            src_tensor = sample['src_img'].unsqueeze(0).to(device)
            tgt_tensor = sample['trg_img'].unsqueeze(0).to(device)

            # Each backbone sees the pair resized to its own input resolution.
            src_feat_dinov2 = extract_dense_features(dinov2, _resize_square(src_tensor, IMG_SIZE_DINOV2))
            tgt_feat_dinov2 = extract_dense_features(dinov2, _resize_square(tgt_tensor, IMG_SIZE_DINOV2))

            src_feat_dinov3 = extract_dense_features(dinov3, _resize_square(src_tensor, IMG_SIZE_DINOV3))
            tgt_feat_dinov3 = extract_dense_features(dinov3, _resize_square(tgt_tensor, IMG_SIZE_DINOV3))

            src_feat_sam = extract_dense_features_SAM(
                sam, _resize_square(src_tensor, IMG_SIZE_SAM), image_size=IMG_SIZE_SAM)
            tgt_feat_sam = extract_dense_features_SAM(
                sam, _resize_square(tgt_tensor, IMG_SIZE_SAM), image_size=IMG_SIZE_SAM)

            # Original image sizes, taken from entries 2 and 1 of the stored size.
            # presumably (width, height) out of a (C, H, W) tuple — TODO confirm against SPairDataset
            src_original_size = (sample['src_imsize'][2], sample['src_imsize'][1])
            tgt_original_size = (sample['trg_imsize'][2], sample['trg_imsize'][1])

            src_kps = sample['src_kps'].numpy()
            trg_kps = sample['trg_kps'].numpy()
            trg_bbox = sample['trg_bbox']
            category = sample['category']

            # Normalized, flattened target descriptors per model (computed once per pair).
            tgt_v2_flat, H2, W2 = _flatten_target_features(tgt_feat_dinov2)
            tgt_v3_flat, H3, W3 = _flatten_target_features(tgt_feat_dinov3)
            tgt_s_flat, Hs, Ws = _flatten_target_features(tgt_feat_sam)

            # The SAM grid is the reference grid all similarity maps are fused on.
            H_ref, W_ref = Hs, Ws

            pred_matches = []

            # NOTE(review): every row of src_kps is treated as a real keypoint; if the
            # dataset pads keypoint arrays, padded rows get matched and scored too — confirm.
            for src_x, src_y in src_kps:
                # Patch holding the keypoint in each model's own grid
                px2, py2 = pixel_to_patch_coord(src_x, src_y, src_original_size,
                                                patch_size=PATCH_SIZE_DINOV2, resized_size=IMG_SIZE_DINOV2)
                px3, py3 = pixel_to_patch_coord(src_x, src_y, src_original_size,
                                                patch_size=PATCH_SIZE_DINOV3, resized_size=IMG_SIZE_DINOV3)
                pxs, pys = pixel_to_patch_coord(src_x, src_y, src_original_size,
                                                patch_size=PATCH_SIZE_SAM, resized_size=IMG_SIZE_SAM)

                # Score-level fusion: per-model similarity maps, resampled to the reference grid
                sim2_r = _resize_sim_map(
                    _keypoint_similarity_map(src_feat_dinov2, px2, py2, tgt_v2_flat, H2, W2), H_ref, W_ref)
                sim3_r = _resize_sim_map(
                    _keypoint_similarity_map(src_feat_dinov3, px3, py3, tgt_v3_flat, H3, W3), H_ref, W_ref)
                sims_r = _resize_sim_map(
                    _keypoint_similarity_map(src_feat_sam, pxs, pys, tgt_s_flat, Hs, Ws), H_ref, W_ref)

                similarities = (weights[0] * sim2_r + weights[1] * sim3_r + weights[2] * sims_r).reshape(-1)

                # Best match on the fused similarity map
                match_patch_x, match_patch_y = find_best_match_window_softargmax(
                    similarities, W_ref, H_ref, K=K, temperature=temperature
                )

                # Back to original target-image coordinates (reference grid = SAM)
                match_x, match_y = patch_to_pixel_coord(
                    match_patch_x, match_patch_y, tgt_original_size,
                    patch_size=PATCH_SIZE_SAM, resized_size=IMG_SIZE_SAM
                )
                pred_matches.append([match_x, match_y])

            # PCK for this pair at every requested threshold
            image_pcks = {}
            for threshold in thresholds:
                pck, _, _ = compute_pck_spair71k(
                    pred_matches,
                    trg_kps.tolist(),
                    trg_bbox,
                    threshold
                )
                image_pcks[threshold] = pck

            per_image_metrics.append({
                'category': category,
                'pck_scores': image_pcks,
            })

            if (idx + 1) % 100 == 0:
                print(f"  Processed {idx + 1}/{len(dataset)} images")

    return per_image_metrics


# ==================== MAIN ====================

if __name__ == "__main__":
    # End-to-end run: load the models, evaluate the weighted-average ensemble on the
    # validation split, then save the aggregate and per-pair results locally and on Drive.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load models
    dinov2, dinov3, sam = load_models(device)
    models_dict = {'dinov2': dinov2, 'dinov3': dinov3, 'sam': sam}

    # Fixed window soft-argmax params
    K = 5
    temperature = 0.2

    # Weighted-average fusion weights: [DINOv2, DINOv3, SAM]
    weights = [0.25, 0.65, 0.10]

    # Load the validation split (datatype='val'); the messages name it accordingly.
    print("\nLoading validation dataset...")
    val_dataset = SPairDataset(PAIR_ANN_PATH, LAYOUT_PATH, IMAGE_PATH, DATASET_SIZE,
                                PCK_ALPHA, datatype='val')
    print(f"✓ Validation set loaded: {len(val_dataset)} pairs")

    # Results dir: named after the hyperparameters plus a timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    wtag = f"{weights[0]:.2f}-{weights[1]:.2f}-{weights[2]:.2f}"
    results_dir = f'results_SPair71K/ensemble/weighted_avg/K{K}_T{temperature}_w{wtag}_{timestamp}'
    os.makedirs(results_dir, exist_ok=True)
    print(f"Results will be saved to: {results_dir}")

    # Evaluate with weighted_avg fusion
    start = time.time()
    per_image_metrics = evaluate_ensemble_with_params(
        models_dict=models_dict,
        dataset=val_dataset,
        device=device,
        K=K,
        temperature=temperature,
        weights=weights,
        thresholds=THRESHOLDS
    )
    elapsed = time.time() - start
    print(f"Total inference time: {elapsed:.2f} seconds")

    # Aggregate overall stats across all evaluated pairs
    overall_stats = {"inference_time_sec": elapsed}
    for threshold in THRESHOLDS:
        all_pcks = np.array([img['pck_scores'][threshold] for img in per_image_metrics])
        stats = {
            "mean": float(np.mean(all_pcks)),
            "std": float(np.std(all_pcks)),
            "median": float(np.median(all_pcks)),
            "p25": float(np.percentile(all_pcks, 25)),
            "p75": float(np.percentile(all_pcks, 75)),
        }
        overall_stats[f"pck@{threshold:.2f}"] = stats
        print(f"PCK@{threshold:.2f}: mean={stats['mean']:.2f}% "
              f"std={stats['std']:.2f}% "
              f"median={stats['median']:.2f}% "
              f"p25={stats['p25']:.2f}% "
              f"p75={stats['p75']:.2f}%")

    # Save outputs: aggregate stats as JSON, one CSV row per evaluated pair
    with open(f'{results_dir}/overall_stats.json', 'w') as f:
        json.dump(overall_stats, f, indent=2)
    df_all = pd.DataFrame([
        {"category": m["category"], **{f"pck@{t:.2f}": m["pck_scores"][t] for t in THRESHOLDS}}
        for m in per_image_metrics
    ])
    df_all.to_csv(f'{results_dir}/per_image_metrics.csv', index=False)
    print(f"Saved overall_stats.json and per_image_metrics.csv to {results_dir}")

    # Mirror the results folder to Google Drive.
    # NOTE(review): the folder name spells "ensmble"; kept as-is so results keep landing in the same Drive folder.
    drive_results_base_path = '/content/drive/MyDrive/Colab_ensmble_validation_results/'
    drive_destination_path = os.path.join(drive_results_base_path, os.path.basename(results_dir))

    # Best effort: a failed Drive copy is reported but does not abort the run,
    # since the results are already saved locally.
    try:
        os.makedirs(drive_results_base_path, exist_ok=True)
        shutil.copytree(results_dir, drive_destination_path)
        print(f"\n✓ Successfully copied results to Google Drive: {drive_destination_path}")
    except Exception as e:
        print(f"\n✗ Error copying results to Google Drive: {e}")

GRID SEARCH FOR WINDOWED SOFTARGMAX HYPERPARAMETERS
Temperature values: [0.05, 0.1, 0.2, 0.5, 1.0, 2.0]
Total combinations: 30
Validation set size: 5384

[1/30] Testing K=3, temperature=0.05
  Processed 1/5384 images
  Processed 101/5384 images
  Processed 1001/5384 images
  Processed 2001/5384 images
  Processed 3001/5384 images
  Processed 4001/5384 images
  Processed 5001/5384 images
  PCK@0.05: mean=56.75%, median=60.00%
  PCK@0.10: mean=68.98%, median=75.00%
  PCK@0.20: mean=78.86%, median=87.50%
  Time: 927.29s

[2/30] Testing K=3, temperature=0.1
  Processed 1/5384 images


KeyboardInterrupt: 

In [None]:
# Flush pending writes and unmount Google Drive.
drive.flush_and_unmount()