This notebook is for the first task: evaluate models on SPair-71k using PCK as metric.\
This file is intended to be run on Colab, not locally.

In [None]:
# Repositories
!git clone https://github.com/Luffy65/Semantic-Correspondence.git # Clone repo
!git clone https://github.com/facebookresearch/dinov3.git # DINOv3
!pip install git+https://github.com/facebookresearch/segment-anything.git # SAM

# Install requirements
!pip install -r Semantic-Correspondence/requirements.txt
!pip install -r dinov3/requirements.txt

Cloning into 'Semantic-Correspondence'...
remote: Enumerating objects: 87, done.[K
remote: Counting objects: 100% (87/87), done.[K
remote: Compressing objects: 100% (64/64), done.[K
remote: Total 87 (delta 25), reused 73 (delta 13), pack-reused 0 (from 0)[K
Receiving objects: 100% (87/87), 4.66 MiB | 11.06 MiB/s, done.
Resolving deltas: 100% (25/25), done.
Cloning into 'dinov3'...
remote: Enumerating objects: 538, done.[K
remote: Counting objects: 100% (363/363), done.[K
remote: Compressing objects: 100% (264/264), done.[K
remote: Total 538 (delta 201), reused 99 (delta 99), pack-reused 175 (from 1)[K
Receiving objects: 100% (538/538), 9.88 MiB | 19.02 MiB/s, done.
Resolving deltas: 100% (223/223), done.
Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-8tptbpqq
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.

In [None]:
# Libraries
import torch
import os
import shutil
import gzip
import cv2
from google.colab import drive
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import numpy as np
import json
from torchvision import transforms
from tqdm import tqdm # optional

In [None]:
# Connect to Google Drive, load and unzip data
# 1. Mount Drive
drive.mount('/content/drive')

# 2. Define Paths
DRIVE_ROOT = '/content/drive/MyDrive/AML-PROJECT-DATA/'
DATASET_ROOT = os.path.join(DRIVE_ROOT, 'dataset/')
DATASET_ARCHIVE = os.path.join(DATASET_ROOT, 'SPair-71k.tar.gz')

LOCAL_DATA_DIR = '/content/data'

# 3. Copy and Extract
if not os.path.exists(LOCAL_DATA_DIR):
    print(f"Extracting {DATASET_ARCHIVE} to local VM...")
    os.makedirs(LOCAL_DATA_DIR, exist_ok=True)
    shutil.unpack_archive(DATASET_ARCHIVE, LOCAL_DATA_DIR, format='gztar')
    print("Done! Data is ready at:", LOCAL_DATA_DIR)
else:
    print("Data already loaded.")

Mounted at /content/drive
Extracting /content/drive/MyDrive/AML-PROJECT-DATA/dataset/SPair-71k.tar.gz to local VM...
Done! Data is ready at: /content/data


In [None]:
# Instantiate the 3 foundation models
from segment_anything import SamPredictor, sam_model_registry

CHECKPOINTS_ROOT = os.path.join(DRIVE_ROOT, 'checkpoints/')
SAM_WEIGHTS_PATH = os.path.join(CHECKPOINTS_ROOT, 'sam_vit_b_01ec64.pth')
DINOV3_WEIGHTS_PATH = os.path.join(CHECKPOINTS_ROOT, 'dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth')
DINOV3_REPO_DIR = "dinov3"

sam = sam_model_registry["vit_b"](checkpoint=SAM_WEIGHTS_PATH)
sampredictor = SamPredictor(sam)

dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
dinov3_vitb16 = torch.hub.load(DINOV3_REPO_DIR, 'dinov3_vitb16', source='local', weights=DINOV3_WEIGHTS_PATH) # DINOv3 ViT model pretrained on web images

Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


In [None]:
# Exploration: see if the models work

# SAM (it works). multimask_output=True outputs 3 masks; multimask_output=False outputs 1 mask. Maybe 3 is better.
"""
# print(sam) # prints the layers

AEROPLANES_DIR = os.path.join(LOCAL_DATA_DIR, 'SPair-71k/JPEGImages/aeroplane')

for image_name in os.listdir(AEROPLANES_DIR):
    image_path = os.path.join(AEROPLANES_DIR, image_name)
    image = cv2.imread(image_path)

    # Convert BGR to RGB for displaying with matplotlib
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    print("Original Image:")
    plt.imshow(image_rgb)
    plt.axis('off')
    plt.show()

    sampredictor.set_image(image_rgb) # Set RGB image for prediction
    masks, scores, logits = sampredictor.predict(
        point_coords=None,
        point_labels=None,
        box=None,
        multimask_output=True,
    )

    # Display all predicted masks
    if masks is not None and len(masks) > 0:
        for i, (mask, score) in enumerate(zip(masks, scores)):
            # Overlay the mask on the original image (optional, for better visualization)
            segmented_image = image_rgb.copy()
            alpha = 0.5
            color = (255, 0, 0) # Red color for the mask
            for c in range(3):
                segmented_image[:, :, c] = segmented_image[:, :, c] * (1 - alpha) + mask * alpha * color[c]

            print(f"Segmented Image (Mask {i}, Score: {score:.4f}):")
            plt.imshow(segmented_image)
            plt.axis('off')
            plt.show()
    else:
        print("No masks found for this image.")

    # Break after the first image for demonstration purposes
    break
"""


In [None]:
# --- Extract features from the foundation models ---

def extract_dino_features(model, img_tensor):
    """
    Extracts dense features from DINO-like models (ViT).
    Returns: (1, Feature_Dim, H_grid, W_grid)
    """
    model.eval()
    with torch.no_grad():
        if hasattr(model, 'forward_features'):
            out = model.forward_features(img_tensor)
            # Handle dictionary output (common in DINOv2/v3)
            if isinstance(out, dict):
                patch_tokens = out.get("x_norm_patchtokens", out.get("x_norm_patch_tokens"))
            else:
                patch_tokens = out

            if patch_tokens is None:
                raise ValueError(f"Could not find patch tokens. Keys: {out.keys() if isinstance(out, dict) else 'N/A'}")

            # Reshape: (B, N, D) -> (B, D, H, W)
            B, N, D = patch_tokens.shape
            grid_size = int(np.sqrt(N))
            feature_map = patch_tokens.permute(0, 2, 1).reshape(B, D, grid_size, grid_size)
            return feature_map
    return None

def extract_sam_features(predictor, image_np):
    """
    Extracts features using SAM Image Encoder.
    """
    predictor.set_image(image_np) # Expects HxWxC uint8
    with torch.no_grad():
        # Get the image embedding (1, 256, 64, 64)
        features = predictor.get_image_embedding()
    return features

In [None]:
# --- Main Evaluation Function ---

def computePCKatT(model, dataset_root, thresholds=[0.05, 0.1, 0.2], img_size=(224, 224)):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    is_sam = False
    if "SamPredictor" in str(type(model)):
        is_sam = True
    else:
        model = model.to(device)

    # Base Paths
    pair_ann_root = os.path.join(dataset_root, 'SPair-71k/PairAnnotation')
    image_dir = os.path.join(dataset_root, 'SPair-71k/JPEGImages')

    # Find Annotation Files (Focus on 'test' split)
    search_paths = [os.path.join(pair_ann_root, 'test')]
    pair_files = []

    print("Searching for annotation files...")
    for path in search_paths:
        if os.path.isdir(path):
            # Recursively find all json files (handles test/aeroplane/*.json structure)
            for root, dirs, files in os.walk(path):
                for f in files:
                    if f.endswith('.json'):
                        pair_files.append(os.path.join(root, f))
            break

    if not pair_files:
        print(f"Error: No annotation files found in {search_paths}")
        return {t: 0.0 for t in thresholds}

    # Filter for 'aeroplane' only (for faster testing)
    pair_files = [f for f in pair_files if 'aeroplane' in f]
    print(f"Evaluating on {len(pair_files)} aeroplane pairs...")

    # Optional: limit to 50 for speed
    pair_files = pair_files[:50]

    # Transform for DINO
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    correct_kps = {t: 0 for t in thresholds}
    total_kps = 0

    for pair_path in tqdm(pair_files):
        with open(pair_path, 'r') as f:
            ann = json.load(f)

        # --- FIX: Use 'category' to build correct path ---
        category = ann['category']
        src_img_path = os.path.join(image_dir, category, ann['src_imname'])
        trg_img_path = os.path.join(image_dir, category, ann['trg_imname'])

        if not os.path.exists(src_img_path) or not os.path.exists(trg_img_path):
            # Debug print only once if missing
            if total_kps == 0:
                print(f"Missing Image: {src_img_path}")
            continue

        # Load Images
        src_pil = Image.open(src_img_path).convert('RGB')
        trg_pil = Image.open(trg_img_path).convert('RGB')
        src_w, src_h = src_pil.size
        trg_w, trg_h = trg_pil.size

        # Extract Features
        if is_sam:
            f_src = extract_sam_features(model, np.array(src_pil))
            f_trg = extract_sam_features(model, np.array(trg_pil))
        else:
            src_tensor = transform(src_pil).unsqueeze(0).to(device)
            trg_tensor = transform(trg_pil).unsqueeze(0).to(device)
            f_src = extract_dino_features(model, src_tensor)
            f_trg = extract_dino_features(model, trg_tensor)

        f_src = F.normalize(f_src, dim=1)
        f_trg = F.normalize(f_trg, dim=1)
        fh, fw = f_src.shape[2], f_src.shape[3]

        # Keypoints & BBox
        src_kps = ann['src_kps']
        trg_kps = ann['trg_kps']
        trg_bbox = ann['trg_bndbox']

        # PCK Normalization Scale
        bbox_w = trg_bbox[2] - trg_bbox[0]
        bbox_h = trg_bbox[3] - trg_bbox[1]
        norm_factor = max(bbox_w, bbox_h)

        kp_indices = range(len(src_kps)) if isinstance(src_kps, list) else src_kps.keys()

        for idx in kp_indices:
            p_src = src_kps[idx]
            p_trg = trg_kps[idx]

            if p_src is None or p_trg is None: continue

            # 1. Map Source Point -> Feature Grid
            if is_sam:
                 # SAM: Simple relative mapping
                 feat_x = int(p_src[0] / src_w * fw)
                 feat_y = int(p_src[1] / src_h * fh)
            else:
                 # DINO: relative mapping
                 feat_x = int(p_src[0] / src_w * fw)
                 feat_y = int(p_src[1] / src_h * fh)

            feat_x = min(max(feat_x, 0), fw - 1)
            feat_y = min(max(feat_y, 0), fh - 1)

            # 2. Get Source Descriptor
            target_feat = f_src[:, :, feat_y, feat_x] # (1, C)

            # 3. Compute Similarity Map
            sim = torch.einsum('nc,nchw->nhw', target_feat, f_trg)

            # 4. Find Best Match
            best_idx = torch.argmax(sim.flatten())
            pred_y_idx = best_idx // fw
            pred_x_idx = best_idx % fw

            # 5. Map back to Target Pixels
            pred_x = (pred_x_idx.item() / fw) * trg_w
            pred_y = (pred_y_idx.item() / fh) * trg_h

            # 6. Evaluate
            dist = np.sqrt((pred_x - p_trg[0])**2 + (pred_y - p_trg[1])**2)

            total_kps += 1
            for t in thresholds:
                if dist <= (t * norm_factor):
                    correct_kps[t] += 1

    return {t: (correct_kps[t] / total_kps) if total_kps > 0 else 0.0 for t in thresholds}


In [38]:
# Move the SAM model to GPU
sam.to(device='cuda')
sampredictor = SamPredictor(sam) # Re-wrap it just to be safe
print("SAM model moved to CUDA.")

SAM model moved to CUDA.


In [None]:
# --- 3. Execute ---
print("Evaluating DINOv2 (Aeroplane)")
dinov2_pck = computePCKatT(dinov2_vitb14, LOCAL_DATA_DIR)
print(f"DINOv2 PCK: {dinov2_pck}")

print("\n=== Evaluating DINOv3 (Aeroplane) ===")
dinov3_pck = computePCKatT(dinov3_vitb16, LOCAL_DATA_DIR)
print(f"DINOv3 PCK: {dinov3_pck}")

print("\n=== Evaluating SAM (Aeroplane) ===")
sam_pck = computePCKatT(sampredictor, LOCAL_DATA_DIR)
print(f"SAM PCK: {sam_pck}")

Evaluating DINOv2 (Aeroplane)
Searching for annotation files...
Evaluating on 600 aeroplane pairs...


100%|██████████| 50/50 [00:02<00:00, 19.14it/s]


DINOv2 PCK: {0.05: 0.21146245059288538, 0.1: 0.5395256916996047, 0.2: 0.7707509881422925}
=== Evaluating DINOv3 (Aeroplane) ===
Searching for annotation files...
Evaluating on 600 aeroplane pairs...


100%|██████████| 50/50 [00:02<00:00, 24.70it/s]


DINOv3 PCK: {0.05: 0.15612648221343872, 0.1: 0.44466403162055335, 0.2: 0.7509881422924901}

=== Evaluating SAM (Aeroplane) ===
Searching for annotation files...
Evaluating on 600 aeroplane pairs...


100%|██████████| 50/50 [00:40<00:00,  1.24it/s]

SAM PCK: {0.05: 0.029644268774703556, 0.1: 0.11462450592885376, 0.2: 0.30632411067193677}



