In [1]:
# Mount Google Drive at /content/drive so later cells can read from / write to it
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Paths
zip_su_drive = '/content/drive/MyDrive/semantic_correspondence.zip'  # archive location on the mounted Drive
zip_locale = '/content/semantic_correspondence.zip'  # destination of the copy, outside the Drive mount
cartella_destinazione = '/content/'  # directory the archive is extracted into by the next cell

# Copy the zip from Drive to the local path
import shutil
shutil.copy(zip_su_drive, zip_locale)

'/content/semantic_correspondence.zip'

In [3]:
# Extract the zip
import zipfile, os
# Make sure the extraction target exists (no-op when it already does)
os.makedirs(cartella_destinazione, exist_ok=True)
with zipfile.ZipFile(zip_locale, 'r') as z:
    z.extractall(cartella_destinazione)


In [4]:
# Report which CUDA device (if any) is visible to PyTorch in this runtime
import torch
gpu_label = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU'
print(f"\n✓ GPU: {gpu_label}")



✓ GPU: Tesla T4


In [8]:

# Root folder of the SPair71k data inside the extracted archive
# (the same value is assigned again in the __main__ block of the next cell).
base = '/content/semantic_correspondence/SPair71k'

In [None]:
import json
import time
import os
import sys
from pathlib import Path

# Add the extracted directory to the Python path
sys.path.insert(0, '/content/semantic_correspondence')

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

research_path = "/content/semantic_correspondence/models/segment_anything"
if research_path not in sys.path:
    sys.path.insert(0, research_path)

from SPair71k.devkit.SPairDataset import SPairDataset
from helper_functions import extract_dense_features_SAM, pixel_to_patch_coord, patch_to_pixel_coord
from matching_strategies import find_best_match_window_softargmax
# from models.dinov2.dinov2.models.vision_transformer import vit_base
# from models.dinov3.dinov3.models.vision_transformer import vit_base
from models.segment_anything.segment_anything import SamPredictor, sam_model_registry
from pck import compute_pck_spair71k


def evaluate_with_params(model, dataset, device, K, temperature, img_size, patch_size, thresholds=(0.05, 0.1, 0.2)):
    """Evaluate keypoint matching over ``dataset`` for one (K, temperature) setting.

    For every sample, both images are resized to ``img_size`` x ``img_size``,
    dense features are extracted with ``extract_dense_features_SAM``, and each
    source keypoint is matched to the target feature map through cosine
    similarity followed by ``find_best_match_window_softargmax``. The predicted
    target locations are then scored with ``compute_pck_spair71k`` at every
    requested threshold.

    Args:
        model: Passed unchanged to ``extract_dense_features_SAM``.
        dataset: Sized iterable of sample dicts providing the keys ``src_img``,
            ``trg_img``, ``src_imsize``, ``trg_imsize``, ``src_kps``,
            ``trg_kps``, ``trg_bbox`` and ``category``.
        device: Device the image tensors are moved to.
        K: Window size forwarded to ``find_best_match_window_softargmax``.
        temperature: Temperature forwarded to ``find_best_match_window_softargmax``.
        img_size: Side length (pixels) both images are resized to.
        patch_size: Patch size forwarded to the pixel<->patch coordinate helpers.
        thresholds: PCK thresholds to score. Any iterable works; the default is
            a tuple so the default object cannot be mutated between calls.

    Returns:
        list[dict]: One entry per sample with ``category``, ``num_keypoints``
        and ``pck_scores`` (threshold -> first value returned by
        ``compute_pck_spair71k``).

    NOTE(review): feature extraction does not depend on ``K`` or
    ``temperature``, so a grid search that calls this function once per
    combination recomputes the same features every time.
    """
    per_image_metrics = []

    with torch.no_grad():
        for idx, sample in enumerate(dataset):
            src_tensor = sample['src_img'].unsqueeze(0).to(device)
            tgt_tensor = sample['trg_img'].unsqueeze(0).to(device)

            # resize both images to img_size x img_size
            src_tensor = F.interpolate(src_tensor, size=(img_size, img_size), mode='bilinear', align_corners=False)
            tgt_tensor = F.interpolate(tgt_tensor, size=(img_size, img_size), mode='bilinear', align_corners=False)

            # original sizes, built from entries [2] and [1] of the stored imsize
            # (presumably imsize is (C, H, W), giving (width, height) -- TODO confirm)
            src_original_size = (sample['src_imsize'][2], sample['src_imsize'][1])
            tgt_original_size = (sample['trg_imsize'][2], sample['trg_imsize'][1])

            # extract dense features
            src_features = extract_dense_features_SAM(model, src_tensor, image_size=img_size)
            tgt_features = extract_dense_features_SAM(model, tgt_tensor, image_size=img_size)

            # flatten the target grid to (H*W, D) so one call scores every patch
            _, H, W, D = tgt_features.shape
            tgt_flat = tgt_features.reshape(H * W, D)

            # extract keypoints
            # NOTE(review): every row of src_kps is treated as a valid keypoint;
            # confirm the dataset does not pad keypoints with sentinel values.
            src_kps = sample['src_kps'].numpy()
            trg_kps = sample['trg_kps'].numpy()
            trg_bbox = sample['trg_bbox']

            pred_matches = []

            # iterate over keypoints
            for i in range(src_kps.shape[0]):
                src_x, src_y = src_kps[i]
                patch_x, patch_y = pixel_to_patch_coord(src_x, src_y, src_original_size, patch_size=patch_size, resized_size=img_size)

                # feature of the patch containing the source keypoint (row = y, col = x)
                src_feature = src_features[0, patch_y, patch_x, :]

                # cosine similarity between that feature and every target patch
                similarities = F.cosine_similarity(
                    src_feature.unsqueeze(0),
                    tgt_flat,
                    dim=1
                )

                # find best match with windowed softargmax
                match_patch_x, match_patch_y = find_best_match_window_softargmax(
                    similarities, W, H, K=K, temperature=temperature
                )
                match_x, match_y = patch_to_pixel_coord(
                    match_patch_x, match_patch_y, tgt_original_size,
                    patch_size=patch_size, resized_size=img_size
                )

                pred_matches.append([match_x, match_y])

            # compute PCK for each threshold
            image_pcks = {}
            for threshold in thresholds:
                pck, _, _ = compute_pck_spair71k(
                    pred_matches,
                    trg_kps.tolist(),
                    trg_bbox,
                    threshold
                )
                image_pcks[threshold] = pck

            per_image_metrics.append({
                'category': sample['category'],
                'num_keypoints': src_kps.shape[0],
                'pck_scores': image_pcks,
            })
            # sparse progress log: first image, image 101, then every 1000th
            if idx == 100 or idx % 1000 == 0:
                print(f"  Processed {idx+1}/{len(dataset)} images", flush=True)

    return per_image_metrics

def run_grid_search(model, val_dataset, device, results_dir, img_size=None, patch_size=None, *,
                    K_values=None, temperature_values=None, thresholds=None,
                    drive_results_base_path='/content/drive/MyDrive/Colab_SAM_finetuned_grid_search_results/'):
    """Run a grid search over the windowed-softargmax K and temperature.

    Every (K, temperature) combination is evaluated with
    ``evaluate_with_params``; per-threshold PCK statistics (mean, std, median,
    25th/75th percentile) are collected, written to ``results_dir`` and,
    best-effort, copied to Google Drive.

    Args:
        model: Passed unchanged to ``evaluate_with_params``.
        val_dataset: Dataset to evaluate on (must support ``len``).
        device: Passed unchanged to ``evaluate_with_params``.
        results_dir: Directory for ``grid_search_results.csv``,
            ``best_parameters.csv`` and ``best_parameters.json``; created if
            missing.
        img_size: Image side length forwarded to ``evaluate_with_params``.
            When ``None`` the module-level ``img_size`` is used if defined
            (the behaviour older callers relied on), otherwise 512.
        patch_size: Patch size forwarded to ``evaluate_with_params``. When
            ``None`` the module-level ``patch_size`` is used if defined,
            otherwise 16.
        K_values: Window sizes to try. Defaults to ``[5]``.
        temperature_values: Temperatures to try. Defaults to
            ``[0.01, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0]``.
        thresholds: PCK thresholds to report. Defaults to ``[0.05, 0.1, 0.2]``.
        drive_results_base_path: Folder under which ``results_dir`` is copied
            once the search finishes. Pass ``None`` to skip the copy.

    Returns:
        tuple: ``(df_results, best_params_summary)`` -- a DataFrame with one
        row per combination, and a list with one dict per threshold
        describing the combination with the highest mean PCK.

    Raises:
        ValueError: If ``K_values``, ``temperature_values`` or ``thresholds``
            is empty.
    """
    # Imported here so the function does not depend on another notebook cell
    # having imported shutil into the global namespace.
    import shutil

    # Backward compatibility: this function used to read img_size / patch_size
    # straight from the global scope.
    if img_size is None:
        img_size = globals().get('img_size', 512)
    if patch_size is None:
        patch_size = globals().get('patch_size', 16)

    # hyperparameter ranges
    K_values = [5] if K_values is None else list(K_values)
    temperature_values = (
        [0.01, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0]
        if temperature_values is None else list(temperature_values)
    )
    thresholds = [0.05, 0.1, 0.2] if thresholds is None else list(thresholds)
    if not K_values or not temperature_values or not thresholds:
        raise ValueError("K_values, temperature_values and thresholds must all be non-empty")

    Path(results_dir).mkdir(parents=True, exist_ok=True)

    total_combinations = len(K_values) * len(temperature_values)

    print("=" * 80)
    print("GRID SEARCH FOR WINDOWED SOFTARGMAX HYPERPARAMETERS")
    print("=" * 80)
    print(f"K values: {K_values}")
    print(f"Temperature values: {temperature_values}")
    print(f"Total combinations: {total_combinations}")
    print(f"Validation set size: {len(val_dataset)}")
    print("=" * 80)

    all_results = []
    current_combo = 0

    for K in K_values:
        for temp in temperature_values:
            current_combo += 1
            print(f"\n[{current_combo}/{total_combinations}] Testing K={K}, temperature={temp}")

            start_time = time.time()
            per_image_metrics = evaluate_with_params(
                model, val_dataset, device, K, temp, img_size, patch_size, thresholds
            )
            inference_time = time.time() - start_time

            result = {
                'K': K,
                'temperature': temp,
                'inference_time_sec': inference_time,
            }

            for threshold in thresholds:
                all_pcks = np.array([img['pck_scores'][threshold] for img in per_image_metrics])
                prefix = f'pck@{threshold:.2f}'

                result[f'{prefix}_mean'] = float(np.mean(all_pcks))
                result[f'{prefix}_std'] = float(np.std(all_pcks))
                result[f'{prefix}_median'] = float(np.median(all_pcks))
                result[f'{prefix}_p25'] = float(np.percentile(all_pcks, 25))
                result[f'{prefix}_p75'] = float(np.percentile(all_pcks, 75))

                mean_pck = result[f'{prefix}_mean']
                median_pck = result[f'{prefix}_median']
                print(f"  PCK@{threshold:.2f}: mean={mean_pck:.2f}%, "
                      f"median={median_pck:.2f}%")

            all_results.append(result)
            print(f"  Time: {inference_time:.2f}s")

    # save the full grid to CSV
    df_results = pd.DataFrame(all_results)
    csv_path = f'{results_dir}/grid_search_results.csv'
    df_results.to_csv(csv_path, index=False)
    print(f"\n{'=' * 80}")
    print(f"Saved grid search results to '{csv_path}'")

    # find best parameters for each threshold (highest mean PCK)
    print(f"\n{'=' * 80}")
    print("BEST PARAMETERS FOR EACH THRESHOLD")
    print("=" * 80)

    best_params_summary = []
    for threshold in thresholds:
        metric_col = f'pck@{threshold:.2f}_mean'
        best_idx = df_results[metric_col].idxmax()
        best_row = df_results.loc[best_idx]

        best_params = {
            'threshold': threshold,
            'best_K': int(best_row['K']),
            'best_temperature': float(best_row['temperature']),
            'best_pck_mean': float(best_row[metric_col]),
            'best_pck_median': float(best_row[f'pck@{threshold:.2f}_median']),
            'best_pck_std': float(best_row[f'pck@{threshold:.2f}_std']),
        }
        best_params_summary.append(best_params)

        print(f"\nPCK@{threshold:.2f}:")
        print(f"  Best K: {best_params['best_K']}")
        print(f"  Best temperature: {best_params['best_temperature']}")
        print(f"  Mean PCK: {best_params['best_pck_mean']:.2f}%")
        print(f"  Median PCK: {best_params['best_pck_median']:.2f}%")
        print(f"  Std PCK: {best_params['best_pck_std']:.2f}%")

    df_best = pd.DataFrame(best_params_summary)
    best_csv_path = f'{results_dir}/best_parameters.csv'
    df_best.to_csv(best_csv_path, index=False)
    print(f"\nSaved best parameters to '{best_csv_path}'")

    best_json_path = f'{results_dir}/best_parameters.json'
    with open(best_json_path, 'w') as f:
        json.dump(best_params_summary, f, indent=2)
    print(f"Saved best parameters to '{best_json_path}'")

    print("=" * 80)

    if drive_results_base_path is not None:
        # normpath drops a trailing slash so basename is never empty
        drive_destination_path = os.path.join(
            drive_results_base_path, os.path.basename(os.path.normpath(results_dir))
        )

        # Best-effort copy: a Drive problem must not discard the computed results.
        try:
            os.makedirs(drive_results_base_path, exist_ok=True)
            # dirs_exist_ok lets a re-run refresh a previous copy instead of
            # failing because the destination already exists.
            shutil.copytree(results_dir, drive_destination_path, dirs_exist_ok=True)
            print(f"\n✓ Successfully copied results to Google Drive: {drive_destination_path}")
        except Exception as e:
            print(f"\n✗ Error copying results to Google Drive: {e}")

    return df_results, best_params_summary


if __name__ == "__main__":
    # Cell entry point: build SAM (vit_b), load the finetuned weights, build the
    # SPair71k validation split and run the grid search on it.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Kept as module-level names: run_grid_search reads img_size / patch_size
    # from the global scope.
    img_size = 512
    patch_size = 16
    print("\nLoading SAM model...")
    model_type = "vit_b"
    checkpoint_path = "/content/semantic_correspondence/models/segment_anything/weights/finetuned/SAM_finetuned_4bl_15t_0.0001lr.pth"
    # Initialize the SAM model without loading a checkpoint yet
    sam_model = sam_model_registry[model_type](checkpoint=None)  # Pass None to initialize without loading
    sam_model.to(device)

    # Load the custom finetuned checkpoint
    print(f"Loading finetuned SAM checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # The checkpoint may be a dict wrapping the weights under 'model_state_dict',
    # or the state_dict itself; handle both layouts.
    if 'model_state_dict' in checkpoint:
        sam_model.load_state_dict(checkpoint['model_state_dict'])
        print("Successfully loaded 'model_state_dict' from checkpoint.")
    else:
        # If the checkpoint itself is just the state_dict, try loading it directly
        sam_model.load_state_dict(checkpoint)
        print("Successfully loaded checkpoint directly as state_dict.")

    # Evaluation only: make sure no train-mode layer behaviour affects the features.
    sam_model.eval()

    base = '/content/semantic_correspondence/SPair71k'
    pair_ann_path = f'{base}/PairAnnotation'
    layout_path = f'{base}/Layout'
    image_path = f'{base}/JPEGImages'
    dataset_size = 'large'
    pck_alpha = 0.1  # mock, it's not used in evaluation
    val_dataset = SPairDataset(pair_ann_path, layout_path, image_path, dataset_size, pck_alpha, datatype='val')

    df_results, best_params = run_grid_search(sam_model, val_dataset, device, 'grid_search_results/SAM/SAM_finetuned')

GRID SEARCH FOR WINDOWED SOFTARGMAX HYPERPARAMETERS
Temperature values: [0.05, 0.1, 0.2, 0.5, 1.0, 2.0]
Total combinations: 30
Validation set size: 5384

[1/30] Testing K=3, temperature=0.05
  Processed 1/5384 images
  Processed 101/5384 images
  Processed 1001/5384 images
  Processed 2001/5384 images
  Processed 3001/5384 images
  Processed 4001/5384 images
  Processed 5001/5384 images
  PCK@0.05: mean=56.75%, median=60.00%
  PCK@0.10: mean=68.98%, median=75.00%
  PCK@0.20: mean=78.86%, median=87.50%
  Time: 927.29s

[2/30] Testing K=3, temperature=0.1
  Processed 1/5384 images


KeyboardInterrupt: 

In [None]:
# Unmount Google Drive (`drive` is the module imported in the first cell)
drive.flush_and_unmount()