In [None]:
# ============================================================================
# CELLA 1: Setup Progetto (riutilizza da eval.ipynb)
# ============================================================================

from google.colab import drive
import sys
import os
from pathlib import Path

print(" AML Semantic Correspondence - Fine-tuning Stage\n")

# 1. Mount Google Drive
if not Path('/content/drive').exists():
    drive.mount('/content/drive')
    print(" Google Drive mounted\n")
else:
    print(" Google Drive already mounted\n")

# 2. Setup directories
PROJECT_ROOT = '/content/drive/MyDrive/AML'
DATA_DIR = f'{PROJECT_ROOT}/dataset'
CHECKPOINT_DIR = f'{PROJECT_ROOT}/checkpoints'
RESULTS_DIR = f'{PROJECT_ROOT}/results'
FINETUNED_DIR = f'{PROJECT_ROOT}/finetuned_models'

os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(FINETUNED_DIR, exist_ok=True)

# 3. Clone repository
GITHUB_REPO_URL = 'https://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@github.com/SamueleCarrea/AML_SemanticCorrespondence'
LOCAL_REPO_NAME = 'AML_SemanticCorrespondence'

if not Path(LOCAL_REPO_NAME).exists():
    print(f"\n Cloning repository...")
    !git clone {GITHUB_REPO_URL} {LOCAL_REPO_NAME}
    print(" Repository cloned")
else:
    print(f"\n Repository {LOCAL_REPO_NAME} already exists.")
    if Path(LOCAL_REPO_NAME, '.git').exists():
        print(" Pulling latest changes...")
        %cd {LOCAL_REPO_NAME}
        !git pull
        %cd ..
        print(" Repository updated")

sys.path.insert(0, LOCAL_REPO_NAME)

# 4. GPU info
import torch
print(f"\n  GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU'}")
if torch.cuda.is_available():
    print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

print("\n Setup complete!\n")

In [None]:
# ============================================================================
# CELLA 2: Install Dependencies
# ============================================================================

print(" Installing dependencies...\n")

!pip install -q -r {LOCAL_REPO_NAME}/requirements.txt
!pip install -q tensorboard

import torch
print(f"\n PyTorch {torch.__version__}")
print(f" CUDA available: {torch.cuda.is_available()}")
print("\n Dependencies installed!\n")

In [None]:
# ============================================================================
# CELL 3: Load datasets (train, val, test)
# ============================================================================

from dataset.spair import SPairDataset
from torch.utils.data import DataLoader

SPAIR_ROOT = f'{DATA_DIR}/Spair-71k'

# Preprocessing shared by every split, kept in one place so the splits cannot drift apart.
SPAIR_KWARGS = dict(size='large', long_side=518, normalize=True, load_segmentation=False)


def make_spair_split(split: str) -> SPairDataset:
    """Build the SPair-71k dataset for one split with the shared preprocessing."""
    return SPairDataset(root=SPAIR_ROOT, split=split, **SPAIR_KWARGS)


def make_loader(dataset, batch_size: int, shuffle: bool, num_workers: int) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader with pinned host memory."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True
    )


# Load all splits
train_dataset = make_spair_split('train')
val_dataset = make_spair_split('val')
test_dataset = make_spair_split('test')

# DataLoaders
# NOTE(review): batch_size=8 relies on the default collate being able to stack
# 'src_img'/'tgt_img' and the keypoint tensors; with long_side resizing, images
# of different aspect ratios may differ in shape unless SPairDataset pads them --
# confirm. Keep this in sync with FinetuneConfig.batch_size (next cell).
train_loader = make_loader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
# The evaluator reads only element [0] of each batch, so the test loader must keep batch_size=1.
val_loader = make_loader(val_dataset, batch_size=1, shuffle=False, num_workers=2)
test_loader = make_loader(test_dataset, batch_size=1, shuffle=False, num_workers=2)

print(f" Dataset Statistics:")
print(f"   Train: {len(train_dataset)} pairs")
print(f"   Val:   {len(val_dataset)} pairs")
print(f"   Test:  {len(test_dataset)} pairs")
print(f"\n Datasets loaded!")

In [None]:
# ============================================================================
# CELL 4: Fine-tuning configuration
# ============================================================================

from dataclasses import dataclass
from typing import List, Optional

# Loss types implemented by CorrespondenceLoss.forward.
SUPPORTED_LOSS_TYPES = ('cosine', 'l2', 'contrastive')


@dataclass
class FinetuneConfig:
    """Hyper-parameters of one fine-tuning run.

    ``experiment_name`` defaults to ``"<backbone_name>_ft<num_layers_to_finetune>"``
    and is used as the checkpoint sub-directory and results file prefix.

    Raises:
        ValueError: on an unsupported ``loss_type`` or a negative
            ``num_layers_to_finetune`` (checked at construction, i.e. before
            any model is loaded).
    """

    # Model
    backbone_name: str = 'dinov2_vitb14'
    num_layers_to_finetune: int = 2  # Number of last transformer blocks to unfreeze

    # Training
    num_epochs: int = 10
    learning_rate: float = 1e-5
    weight_decay: float = 0.01
    warmup_epochs: int = 1

    # Loss
    loss_type: str = 'cosine'  # 'cosine', 'l2', or 'contrastive'
    negative_margin: float = 0.2  # passed to CorrespondenceLoss, whose forward() does not use it

    # Optimization
    batch_size: int = 8  # informational: train_loader is built with its own literal batch size
    gradient_accumulation_steps: int = 1
    max_grad_norm: float = 1.0

    # Logging (all intervals are counted in training batches)
    log_interval: int = 50
    val_interval: int = 500
    save_interval: int = 1000

    # Paths
    checkpoint_dir: str = FINETUNED_DIR
    experiment_name: Optional[str] = None

    def __post_init__(self):
        # Fail fast: an unknown loss type would otherwise only surface at the first training step.
        if self.loss_type not in SUPPORTED_LOSS_TYPES:
            raise ValueError(
                f"Unknown loss type: {self.loss_type} (expected one of {SUPPORTED_LOSS_TYPES})"
            )
        if self.num_layers_to_finetune < 0:
            raise ValueError(
                f"num_layers_to_finetune must be >= 0, got {self.num_layers_to_finetune}"
            )
        if self.experiment_name is None:
            self.experiment_name = f"{self.backbone_name}_ft{self.num_layers_to_finetune}"


# Experiment presets: keyword arguments for FinetuneConfig, keyed by run name.
_PRESETS = {
    'dinov2_2layers': dict(
        backbone_name='dinov2_vitb14', num_layers_to_finetune=2,
        num_epochs=10, learning_rate=1e-5,
    ),
    'dinov2_4layers': dict(
        backbone_name='dinov2_vitb14', num_layers_to_finetune=4,
        num_epochs=10, learning_rate=5e-6,
    ),
    'dinov3_2layers': dict(
        backbone_name='dinov3_vitb16', num_layers_to_finetune=2,
        num_epochs=10, learning_rate=1e-5,
    ),
}
configs = {run_name: FinetuneConfig(**overrides) for run_name, overrides in _PRESETS.items()}

# The run executed by the rest of the notebook.
config = configs['dinov2_2layers']

print(f"   Fine-tuning Configuration:")
print(f"   Backbone: {config.backbone_name}")
print(f"   Layers to finetune: {config.num_layers_to_finetune}")
print(f"   Epochs: {config.num_epochs}")
print(f"   Learning rate: {config.learning_rate}")
print(f"   Batch size: {config.batch_size}")
print(f"\n Config ready!")

In [None]:
# ============================================================================
# CELL 5: Loss functions for correspondence
# ============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F

class CorrespondenceLoss(nn.Module):
    """Loss function for semantic correspondence fine-tuning.

    Keypoints (pixel coordinates, x first) are mapped to the feature-map patch
    that contains them; the source and target feature vectors at those patches
    are gathered and a per-keypoint loss is averaged over all valid keypoints
    of the batch.

    Supported ``loss_type`` values:
        * ``'cosine'``      -- ``1 - cos(src, tgt)``
        * ``'l2'``          -- mean squared difference over the feature dimension
        * ``'contrastive'`` -- InfoNCE with ``num_negatives`` random patches of
                               the same target image (never the positive) as negatives
    """

    SUPPORTED_LOSS_TYPES = ('cosine', 'l2', 'contrastive')

    def __init__(
        self,
        loss_type: str = 'cosine',
        negative_margin: float = 0.2,
        temperature: float = 0.1,
        num_negatives: int = 8
    ):
        """
        Args:
            loss_type: one of ``SUPPORTED_LOSS_TYPES``.
            negative_margin: stored for configuration compatibility; none of the
                implemented losses uses it.
            temperature: softmax temperature of the contrastive loss.
            num_negatives: negatives sampled per keypoint by the contrastive loss.

        Raises:
            ValueError: if ``loss_type`` is not supported (checked eagerly so a
                bad configuration fails before training starts).
        """
        super().__init__()
        if loss_type not in self.SUPPORTED_LOSS_TYPES:
            raise ValueError(f"Unknown loss type: {loss_type}")
        self.loss_type = loss_type
        self.negative_margin = negative_margin
        self.temperature = temperature
        self.num_negatives = num_negatives

    def forward(
        self,
        src_features: torch.Tensor,
        tgt_features: torch.Tensor,
        src_kps: torch.Tensor,
        tgt_kps: torch.Tensor,
        valid_mask: torch.Tensor,
        patch_size: int
    ) -> torch.Tensor:
        """
        Compute correspondence loss.

        Args:
            src_features: (B, H_s, W_s, D)
            tgt_features: (B, H_t, W_t, D)
            src_kps: (B, N, 2) source keypoints in pixels, (x, y) order
            tgt_kps: (B, N, 2) target keypoints in pixels, (x, y) order
            valid_mask: (B, N) valid keypoint mask (bool or 0/1 integers)
            patch_size: pixels per feature-map cell

        Returns:
            0-dim tensor: mean loss over all valid keypoints of the batch. With no
            valid keypoint, a zero leaf tensor (requires_grad=True, no graph), so
            backward() works but produces no parameter gradients.
        """
        _, H_s, W_s, _ = src_features.shape
        _, H_t, W_t, _ = tgt_features.shape

        # A uint8/int mask must act as a mask, not as integer indices.
        valid = valid_mask.bool()
        if not valid.any():
            return torch.tensor(0.0, device=src_features.device, requires_grad=True)

        # Flatten the valid keypoints of the whole batch into K entries; boolean
        # indexing and nonzero() both enumerate (b, n) in row-major order.
        batch_idx = valid.nonzero(as_tuple=True)[0]                           # (K,)
        src_xy = self._to_patch_coords(src_kps[valid], patch_size, W_s, H_s)  # (K, 2)
        tgt_xy = self._to_patch_coords(tgt_kps[valid], patch_size, W_t, H_t)  # (K, 2)

        src_vecs = src_features[batch_idx, src_xy[:, 1], src_xy[:, 0]]  # (K, D)
        tgt_vecs = tgt_features[batch_idx, tgt_xy[:, 1], tgt_xy[:, 0]]  # (K, D)

        if self.loss_type == 'cosine':
            # Maximize cosine similarity between corresponding points
            per_kp = 1.0 - F.cosine_similarity(src_vecs, tgt_vecs, dim=1)
        elif self.loss_type == 'l2':
            # Per-keypoint MSE over the feature dimension
            per_kp = (src_vecs - tgt_vecs).pow(2).mean(dim=1)
        elif self.loss_type == 'contrastive':
            per_kp = self._info_nce(src_vecs, tgt_vecs, tgt_features, batch_idx, tgt_xy)
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

        return per_kp.mean()

    @staticmethod
    def _to_patch_coords(kps: torch.Tensor, patch_size: int, width: int, height: int) -> torch.Tensor:
        """Map (K, 2) pixel (x, y) keypoints to integer (col, row) patch indices inside the grid."""
        # .long() truncates toward zero; out-of-grid values are clamped to the border.
        patch = (kps / patch_size).long()
        cols = patch[:, 0].clamp(0, width - 1)
        rows = patch[:, 1].clamp(0, height - 1)
        return torch.stack((cols, rows), dim=1)

    def _info_nce(
        self,
        src_vecs: torch.Tensor,
        tgt_vecs: torch.Tensor,
        tgt_features: torch.Tensor,
        batch_idx: torch.Tensor,
        tgt_xy: torch.Tensor
    ) -> torch.Tensor:
        """Per-keypoint InfoNCE loss (K,) with random negatives from the same target image."""
        B, H_t, W_t, D = tgt_features.shape
        num_patches = H_t * W_t
        K = src_vecs.shape[0]
        device = src_vecs.device

        pos_flat = tgt_xy[:, 1] * W_t + tgt_xy[:, 0]  # (K,) flat index of each positive patch
        if num_patches > 1:
            # Sample among the other num_patches - 1 patches: draw in [0, P-1) and
            # shift indices >= positive by one, so the positive is never a negative.
            neg_flat = torch.randint(0, num_patches - 1, (K, self.num_negatives), device=device)
            neg_flat = neg_flat + (neg_flat >= pos_flat.unsqueeze(1)).long()
        else:
            neg_flat = torch.zeros(K, self.num_negatives, dtype=torch.long, device=device)

        tgt_flat = tgt_features.reshape(B, num_patches, D)
        neg_vecs = tgt_flat[batch_idx.unsqueeze(1), neg_flat]  # (K, num_negatives, D)

        pos_sim = F.cosine_similarity(src_vecs, tgt_vecs, dim=1)               # (K,)
        neg_sim = F.cosine_similarity(src_vecs.unsqueeze(1), neg_vecs, dim=2)  # (K, num_negatives)

        logits = torch.cat((pos_sim.unsqueeze(1), neg_sim), dim=1) / self.temperature
        # -log(exp(l_pos) / sum_j exp(l_j)) written in the numerically stable logsumexp form.
        return torch.logsumexp(logits, dim=1) - logits[:, 0]


# Smoke test of the loss on random tensors. A local generator makes the demo
# reproducible without touching the global RNG that training relies on.
print(" Testing loss function...")

loss_fn = CorrespondenceLoss(loss_type='cosine')

# Dummy data
_demo_gen = torch.Generator().manual_seed(0)
src_feat = torch.randn(2, 37, 37, 768, generator=_demo_gen)
tgt_feat = torch.randn(2, 37, 37, 768, generator=_demo_gen)
src_kps = torch.randint(0, 500, (2, 10, 2), generator=_demo_gen).float()
tgt_kps = torch.randint(0, 500, (2, 10, 2), generator=_demo_gen).float()
valid_mask = torch.ones(2, 10).bool()

loss = loss_fn(src_feat, tgt_feat, src_kps, tgt_kps, valid_mask, patch_size=14)
print(f" Loss computed: {loss.item():.4f}")

# Sanity check: identical features at identical keypoints give ~zero cosine loss.
_zero_loss = loss_fn(src_feat, src_feat, src_kps, src_kps, valid_mask, patch_size=14)
assert abs(_zero_loss.item()) < 1e-5, f"expected ~0 loss for identical features, got {_zero_loss.item()}"

In [None]:
# ============================================================================
# CELL 6: Trainable backbone wrapper
# ============================================================================

import torch
import torch.nn as nn
from models.backbones import DINOv2Extractor, DINOv3Extractor

class FinetunableBackbone(nn.Module):
    """Backbone wrapper that freezes everything except the last N transformer blocks.

    ``forward`` returns the extractor's feature map permuted to channels-last
    ``(B, H, W, D)``; ``stride`` is copied from the extractor.
    """

    def __init__(
        self,
        backbone_name: str,
        num_layers_to_finetune: int = 2,
        device: str = 'cuda'
    ):
        super().__init__()
        self.backbone_name = backbone_name
        self.num_layers_to_finetune = num_layers_to_finetune
        self.device = device

        self.extractor = self._build_extractor(backbone_name, device)
        self.stride = self.extractor.stride

        # Start fully frozen, then re-enable gradients on the tail blocks only.
        for weight in self.extractor.parameters():
            weight.requires_grad = False
        self._unfreeze_last_layers(num_layers_to_finetune)

        n_total = 0
        n_trainable = 0
        for weight in self.parameters():
            n_total += weight.numel()
            if weight.requires_grad:
                n_trainable += weight.numel()

        print(f"\n Trainable parameters:")
        print(f"   Total: {n_total:,}")
        print(f"   Trainable: {n_trainable:,} ({n_trainable/n_total*100:.2f}%)")

    @staticmethod
    def _build_extractor(backbone_name: str, device: str):
        """Instantiate the feature extractor matching ``backbone_name``."""
        if 'dinov2' in backbone_name:
            return DINOv2Extractor(
                variant=backbone_name,
                device=device,
                allow_hub_download=True
            )
        if 'dinov3' in backbone_name:
            return DINOv3Extractor(
                variant=backbone_name,
                device=device
            )
        raise ValueError(f"Unsupported backbone: {backbone_name}")

    def _unfreeze_last_layers(self, num_layers: int):
        """Set ``requires_grad=True`` on the last ``num_layers`` transformer blocks.

        Two model layouts are recognised: ``model.blocks`` or
        ``model.encoder.layers``; anything else raises AttributeError.
        """
        net = self.extractor.model

        if hasattr(net, 'blocks'):
            blocks = net.blocks
        elif hasattr(net, 'encoder') and hasattr(net.encoder, 'layers'):
            blocks = net.encoder.layers
        else:
            raise AttributeError("Cannot find transformer blocks in model")

        total_blocks = len(blocks)
        start_idx = max(0, total_blocks - num_layers)

        print(f"\n Unfreezing blocks {start_idx} to {total_blocks-1} (total: {total_blocks})")

        tail_params = (w for idx in range(start_idx, total_blocks) for w in blocks[idx].parameters())
        for weight in tail_params:
            weight.requires_grad = True

    def extract_features(self, image: torch.Tensor) -> torch.Tensor:
        """Return channels-last features ``(B, H, W, D)`` for a batch of images.

        NOTE(review): gradients reach the unfrozen blocks only if
        ``extractor.extract_feats`` does not run under ``torch.no_grad()`` --
        confirm in models/backbones.
        """
        # assumes extract_feats returns (feat_map (B, D, H, W), stride) -- TODO confirm
        feat_map, _ = self.extractor.extract_feats(image)
        return feat_map.permute(0, 2, 3, 1)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Alias of ``extract_features``."""
        return self.extract_features(image)


# Smoke test: the wrapper's features must be part of the autograd graph,
# otherwise loss.backward() in the Trainer cannot update any weight.
print(" Testing trainable backbone...")

backbone = FinetunableBackbone(
    backbone_name='dinov2_vitb14',
    num_layers_to_finetune=2,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

test_img = torch.randn(1, 3, 518, 518).to(backbone.device)
with torch.set_grad_enabled(True):
    features = backbone(test_img)
    print(f"\n Features shape: {features.shape}")
    print(f" Requires grad: {features.requires_grad}")

assert features.requires_grad, (
    "Backbone features do not require grad -- extract_feats probably runs under "
    "torch.no_grad(), so fine-tuning would not update any weight."
)

# Release the throw-away model and its activation graph before the real model
# is built for training, so two backbones do not sit in GPU memory at once.
del backbone, test_img, features
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# ============================================================================
# CELL 7: Training loop
# ============================================================================

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from tqdm import tqdm
import json
from pathlib import Path

class Trainer:
    """Fine-tuning trainer for a FinetunableBackbone.

    Optimizes only the parameters with ``requires_grad=True`` using AdamW and
    a linear-warmup + cosine-decay schedule measured in optimizer steps, with
    gradient accumulation, periodic logging, validation and checkpointing.

    Attributes:
        train_losses: ``{'step', 'loss', 'lr'}`` dicts, every ``log_interval`` batches.
        val_losses: ``{'step', 'loss'}`` dicts, every ``val_interval`` batches and
            at the end of each epoch (epoch-end entries also carry ``'epoch'``).
        global_step: number of training batches processed so far.
    """

    def __init__(
        self,
        model: FinetunableBackbone,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: FinetuneConfig,
        device: str = 'cuda'
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device

        # Loss function
        self.criterion = CorrespondenceLoss(
            loss_type=config.loss_type,
            negative_margin=config.negative_margin
        ).to(device)

        # Only the unfrozen parameters are optimized and gradient-clipped.
        self.trainable_params = [p for p in model.parameters() if p.requires_grad]
        self.optimizer = optim.AdamW(
            self.trainable_params,
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Learning rate scheduler with warmup
        self.scheduler = self._build_scheduler()

        # Logging
        self.train_losses = []
        self.val_losses = []
        self.global_step = 0

        # Checkpoint directory
        self.ckpt_dir = Path(config.checkpoint_dir) / config.experiment_name
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n  Checkpoints: {self.ckpt_dir}")

    def _accum_steps(self) -> int:
        """Gradient-accumulation factor, never below 1."""
        return max(1, int(self.config.gradient_accumulation_steps))

    def _optimizer_steps_per_epoch(self) -> int:
        """Optimizer (= scheduler) steps per epoch, i.e. ceil(batches / accumulation)."""
        return max(1, -(-len(self.train_loader) // self._accum_steps()))

    def _build_scheduler(self):
        """Linear warmup (from 0.1x lr) then cosine decay, counted in optimizer steps.

        Counting optimizer steps (not batches) keeps the schedule correct when
        gradient accumulation is used.
        """
        steps_per_epoch = self._optimizer_steps_per_epoch()
        total_steps = self.config.num_epochs * steps_per_epoch
        warmup_steps = min(max(0, self.config.warmup_epochs) * steps_per_epoch, total_steps)

        if warmup_steps == 0:
            # No warmup requested: plain cosine decay (T_max >= 1 avoids a division by zero).
            return CosineAnnealingLR(self.optimizer, T_max=max(1, total_steps))

        warmup_scheduler = LinearLR(
            self.optimizer,
            start_factor=0.1,
            total_iters=warmup_steps
        )
        cosine_scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=max(1, total_steps - warmup_steps)
        )
        return SequentialLR(
            self.optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[warmup_steps]
        )

    def _batch_to_device(self, batch):
        """Move the five tensors the loss needs from a batch dict to ``self.device``."""
        # non_blocking overlaps the copy with compute when the loader pins memory.
        return tuple(
            batch[key].to(self.device, non_blocking=True)
            for key in ('src_img', 'tgt_img', 'src_kps', 'tgt_kps', 'valid_mask')
        )

    def _optimizer_step(self):
        """Clip gradients, apply one optimizer and scheduler step, then reset gradients."""
        torch.nn.utils.clip_grad_norm_(self.trainable_params, self.config.max_grad_norm)
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad(set_to_none=True)

    def train_epoch(self, epoch: int):
        """Train for one epoch; return the mean per-batch loss (NaN if no batches)."""
        self.model.train()
        accum = self._accum_steps()
        n_total_batches = len(self.train_loader)
        epoch_loss = 0.0
        n_batches = 0

        # Start from clean gradients so nothing leaks in from a previous epoch.
        self.optimizer.zero_grad(set_to_none=True)

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.config.num_epochs}")

        for batch_idx, batch in enumerate(pbar):
            src_img, tgt_img, src_kps, tgt_kps, valid_mask = self._batch_to_device(batch)

            # Forward pass
            src_features = self.model(src_img)
            tgt_features = self.model(tgt_img)

            # Compute loss
            loss = self.criterion(
                src_features, tgt_features,
                src_kps, tgt_kps, valid_mask,
                patch_size=self.model.stride
            )

            # Scale so accumulated gradients average over the accumulation group.
            (loss / accum).backward()

            # Step at the end of every accumulation group and on the last batch,
            # so a trailing partial group is applied instead of leaking into the next epoch.
            if (batch_idx + 1) % accum == 0 or (batch_idx + 1) == n_total_batches:
                self._optimizer_step()

            # Logging
            epoch_loss += loss.item()
            n_batches += 1
            self.global_step += 1

            if self.global_step % self.config.log_interval == 0:
                avg_loss = epoch_loss / n_batches
                lr = self.optimizer.param_groups[0]['lr']
                pbar.set_postfix({
                    'loss': f'{avg_loss:.4f}',
                    'lr': f'{lr:.2e}'
                })
                self.train_losses.append({
                    'step': self.global_step,
                    'loss': avg_loss,
                    'lr': lr
                })

            # Validation
            if self.global_step % self.config.val_interval == 0:
                val_loss = self.validate()
                self.val_losses.append({
                    'step': self.global_step,
                    'loss': val_loss
                })
                self.model.train()

            # Save checkpoint
            if self.global_step % self.config.save_interval == 0:
                self.save_checkpoint(epoch, batch_idx)

        return epoch_loss / n_batches if n_batches else float('nan')

    @torch.no_grad()
    def validate(self):
        """Return the mean per-batch loss over ``val_loader`` (NaN if it is empty).

        Leaves the model in eval mode; callers that keep training must call
        ``self.model.train()``. With the 'contrastive' loss the value is
        stochastic (random negatives).
        """
        self.model.eval()
        val_loss = 0.0
        n_batches = 0

        for batch in tqdm(self.val_loader, desc="Validation", leave=False):
            src_img, tgt_img, src_kps, tgt_kps, valid_mask = self._batch_to_device(batch)

            src_features = self.model(src_img)
            tgt_features = self.model(tgt_img)

            loss = self.criterion(
                src_features, tgt_features,
                src_kps, tgt_kps, valid_mask,
                patch_size=self.model.stride
            )

            val_loss += loss.item()
            n_batches += 1

        if n_batches == 0:
            print("\n Validation skipped: val_loader is empty")
            return float('nan')

        avg_val_loss = val_loss / n_batches
        print(f"\n Validation Loss: {avg_val_loss:.4f}")

        return avg_val_loss

    def save_checkpoint(self, epoch: int, batch_idx: int):
        """Save model/optimizer/scheduler state and loss history; rewrite config.json.

        The config is stored as a plain dict rather than the dataclass instance,
        so the file loads with ``torch.load(..., weights_only=True)`` (the
        default since PyTorch 2.6), which refuses arbitrary pickled classes.
        """
        config_dict = dict(vars(self.config))
        checkpoint = {
            'epoch': epoch,
            'batch_idx': batch_idx,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'config': config_dict,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
        }

        ckpt_path = self.ckpt_dir / f'checkpoint_step{self.global_step}.pt'
        torch.save(checkpoint, ckpt_path)
        print(f"\n Saved checkpoint: {ckpt_path}")

        # Save config
        config_path = self.ckpt_dir / 'config.json'
        with open(config_path, 'w') as f:
            json.dump(config_dict, f, indent=2)

    def train(self):
        """Full training loop: one epoch of training, validation and a checkpoint per epoch."""
        print(f"\n  Starting training: {self.config.experiment_name}")
        print(f"   Epochs: {self.config.num_epochs}")
        print(f"   Batches per epoch: {len(self.train_loader)}")

        for epoch in range(self.config.num_epochs):
            train_loss = self.train_epoch(epoch)
            print(f"\n Epoch {epoch+1} - Train Loss: {train_loss:.4f}")

            # End-of-epoch validation, recorded in the history (reused if
            # train_epoch already validated at this exact step).
            if not (self.val_losses and self.val_losses[-1]['step'] == self.global_step):
                self.val_losses.append({
                    'step': self.global_step,
                    'loss': self.validate()
                })
            self.val_losses[-1]['epoch'] = epoch + 1

            # Save epoch checkpoint
            self.save_checkpoint(epoch, len(self.train_loader))

        print(f"\n Training complete!")


print(" Trainer ready!")

In [None]:
# ============================================================================
# CELL 8: Launch fine-tuning
# ============================================================================

import random
import numpy as np

# Seed every RNG touched during training (data shuffling, contrastive negatives, ...)
# so runs are repeatable. Bitwise determinism on GPU is still not guaranteed.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Initialize model
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = FinetunableBackbone(
    backbone_name=config.backbone_name,
    num_layers_to_finetune=config.num_layers_to_finetune,
    device=device
)

# Initialize trainer
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    device=device
)

# Start training
trainer.train()

In [None]:
# ============================================================================
# CELL 9: Evaluate fine-tuned model
# ============================================================================

from collections import defaultdict
import time
import numpy as np

class FinetunedEvaluator:
    """PCK evaluation of a fine-tuned backbone by nearest-neighbour matching.

    For each valid source keypoint, the target patch whose feature has the
    highest cosine similarity with the source patch feature is predicted; its
    centre (in pixels) is scored against the ground truth with ``compute_pck``.
    """

    def __init__(
        self,
        model: FinetunableBackbone,
        dataloader: DataLoader,
        device: str = 'cuda',
        thresholds: tuple = (0.05, 0.10, 0.15, 0.20)
    ):
        """
        Args:
            model: backbone returning (B, H, W, D) features and exposing ``stride``.
            dataloader: yields batches of which only element [0] is evaluated,
                so it should use batch_size=1.
            device: device the images are moved to.
            thresholds: PCK thresholds (immutable default; copied to a list).
        """
        self.model = model
        self.dataloader = dataloader
        self.device = device
        self.thresholds = list(thresholds)

    @staticmethod
    def _safe_stat(fn, values) -> float:
        """``float(fn(values))``, or NaN for an empty list (avoids numpy empty-mean warnings)."""
        return float(fn(values)) if len(values) else float('nan')

    @torch.no_grad()
    def evaluate(self, num_samples: int = None):
        """Evaluate the model on ``self.dataloader``.

        Args:
            num_samples: stop after this many evaluated pairs (pairs without a
                valid keypoint are skipped and not counted); None means all.

        Returns:
            dict with 'overall' ({metric: {'mean', 'std'}}), 'per_category'
            ({category: {metric: mean}}), 'num_pairs' and 'inference_time_ms'
            (mean per pair). Statistics are Python floats (NaN when empty), so
            the dict is JSON-serializable.
        """
        from dataset.spair import compute_pck  # imported once, not once per batch

        self.model.eval()
        metric_names = [f'PCK@{t:.2f}' for t in self.thresholds]

        all_pck = defaultdict(list)
        per_category = defaultdict(lambda: defaultdict(list))
        inference_times = []

        n_processed = 0
        pbar = tqdm(self.dataloader, desc="Evaluating")

        for batch in pbar:
            if num_samples is not None and n_processed >= num_samples:
                break

            src_img = batch['src_img'].to(self.device)
            tgt_img = batch['tgt_img'].to(self.device)
            src_kps = batch['src_kps'][0]
            tgt_kps = batch['tgt_kps'][0]
            valid_mask = batch['valid_mask'][0]
            category = batch['category'][0]

            src_kps_valid = src_kps[valid_mask]
            tgt_kps_valid = tgt_kps[valid_mask]

            if len(src_kps_valid) == 0:
                continue

            # Predict (the result is copied back to the keypoints' device, which
            # synchronizes with the GPU, so the timing covers the actual compute).
            start = time.time()
            tgt_kps_pred = self._predict_keypoints(
                src_img, tgt_img, src_kps_valid
            )
            inference_times.append(time.time() - start)

            # Compute metrics
            H, W = tgt_img.shape[2:]
            pck_scores = compute_pck(
                tgt_kps_pred, tgt_kps_valid, (H, W), thresholds=self.thresholds
            )

            for metric, value in pck_scores.items():
                all_pck[metric].append(value)
                per_category[category][metric].append(value)

            n_processed += 1

            # Running PCK@0.10 in the progress bar (when that threshold is evaluated).
            running = all_pck.get('PCK@0.10')
            if running:
                pbar.set_postfix({'PCK@0.10': f'{np.mean(running):.4f}'})

        # Aggregate results
        results = {
            'overall': {},
            'per_category': {},
            'num_pairs': n_processed,
            'inference_time_ms': self._safe_stat(np.mean, inference_times) * 1000
        }

        for metric in metric_names:
            values = all_pck[metric]
            results['overall'][metric] = {
                'mean': self._safe_stat(np.mean, values),
                'std': self._safe_stat(np.std, values)
            }

        for cat, metrics in per_category.items():
            results['per_category'][cat] = {
                metric: self._safe_stat(np.mean, metrics[metric]) for metric in metric_names
            }

        return results

    def _predict_keypoints(
        self,
        src_img: torch.Tensor,
        tgt_img: torch.Tensor,
        src_kps: torch.Tensor
    ) -> torch.Tensor:
        """Predict target keypoints (N, 2) in pixels, (x, y) order, float32.

        All keypoints are matched at once: cosine similarity between every
        source-keypoint feature and every target patch is one matrix product.
        The result is returned on ``src_kps.device``.
        """
        src_feat = self.model(src_img)[0]  # (H_s, W_s, D)
        tgt_feat = self.model(tgt_img)[0]  # (H_t, W_t, D)

        H_s, W_s, D = src_feat.shape
        H_t, W_t, _ = tgt_feat.shape

        patch_size = self.model.stride

        patch_idx = (src_kps / patch_size).long().to(src_feat.device)
        cols = patch_idx[:, 0].clamp(0, W_s - 1)
        rows = patch_idx[:, 1].clamp(0, H_s - 1)

        src_vecs = F.normalize(src_feat[rows, cols], dim=-1)           # (N, D)
        tgt_vecs = F.normalize(tgt_feat.reshape(-1, D), dim=-1)        # (H_t * W_t, D)
        similarity = src_vecs @ tgt_vecs.T                             # (N, H_t * W_t)

        best = similarity.argmax(dim=1)
        pred_y = best // W_t
        pred_x = best % W_t

        # Predict the centre of the best-matching target patch.
        tgt_kps_pred = torch.stack((pred_x, pred_y), dim=1) * patch_size + patch_size // 2
        return tgt_kps_pred.to(device=src_kps.device, dtype=torch.float32)


# Evaluate the fine-tuned model on the whole test split.
evaluator = FinetunedEvaluator(model=model, dataloader=test_loader, device=device)
results = evaluator.evaluate()

# Print results
separator = "="*70
print("\n" + separator)
print("FINE-TUNED MODEL RESULTS")
print(separator)
for metric_name, stats in results['overall'].items():
    print(f"   {metric_name}: {stats['mean']:.4f} ± {stats['std']:.4f}")
print(f"\n⏱  Inference: {results['inference_time_ms']:.2f} ms/pair")
print(f" Evaluated: {results['num_pairs']} pairs")
print(separator)

# Save results as JSON next to the other experiment outputs.
results_path = f"{RESULTS_DIR}/{config.experiment_name}_results.json"
with open(results_path, 'w') as results_file:
    json.dump(results, results_file, indent=2)
print(f"\n Results saved: {results_path}")

In [None]:
# ============================================================================
# CELL 10: Compare baseline vs fine-tuned
# ============================================================================

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

def plot_comparison(
    baseline_results,
    finetuned_results,
    thresholds=(0.05, 0.10, 0.15, 0.20),
    finetuned_label=None,
    save_path=None
):
    """Plot PCK curves and per-threshold improvement of fine-tuned vs baseline.

    Args:
        baseline_results: dict with ``['overall'][f'PCK@{t:.2f}']['mean']`` per threshold.
        finetuned_results: same structure as ``baseline_results``.
        thresholds: PCK thresholds to plot.
        finetuned_label: legend label of the fine-tuned curve; defaults to
            ``"Fine-tuned (<config.num_layers_to_finetune> layers)"``.
        save_path: output PNG path; defaults to
            ``"<RESULTS_DIR>/<config.experiment_name>_comparison.png"``.

    Returns:
        The matplotlib Figure (it is also saved and shown).

    Note: the "Improvement (%)" values are differences of PCK fractions times
    100, i.e. percentage points.
    """
    thresholds = list(thresholds)
    if finetuned_label is None:
        finetuned_label = f'Fine-tuned ({config.num_layers_to_finetune} layers)'
    if save_path is None:
        save_path = f'{RESULTS_DIR}/{config.experiment_name}_comparison.png'

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Extract PCK values
    baseline_pck = [baseline_results['overall'][f'PCK@{t:.2f}']['mean']
                    for t in thresholds]
    finetuned_pck = [finetuned_results['overall'][f'PCK@{t:.2f}']['mean']
                     for t in thresholds]

    # Plot 1: PCK Curves
    ax1 = axes[0]
    ax1.plot(thresholds, baseline_pck, marker='o', linewidth=2,
             label='Baseline (Frozen)', markersize=8, color='#E74C3C')
    ax1.plot(thresholds, finetuned_pck, marker='s', linewidth=2,
             label=finetuned_label,
             markersize=8, color='#27AE60')

    ax1.set_xlabel('Threshold', fontsize=12, fontweight='bold')
    ax1.set_ylabel('PCK', fontsize=12, fontweight='bold')
    ax1.set_title('PCK Comparison', fontsize=14, fontweight='bold')
    ax1.legend(loc='lower right')
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim([0, 1])

    # Plot 2: Improvement Bar Chart
    ax2 = axes[1]
    improvements = [(ft - bl) * 100 for bl, ft in zip(baseline_pck, finetuned_pck)]
    colors = ['#27AE60' if imp > 0 else '#E74C3C' for imp in improvements]

    bars = ax2.bar(range(len(thresholds)), improvements, color=colors, alpha=0.8)
    ax2.set_xticks(range(len(thresholds)))
    ax2.set_xticklabels([f'PCK@{t:.2f}' for t in thresholds])
    ax2.set_ylabel('Improvement (%)', fontsize=12, fontweight='bold')
    ax2.set_title('Fine-tuning Improvement', fontsize=14, fontweight='bold')
    ax2.grid(axis='y', alpha=0.3)
    ax2.axhline(y=0, color='black', linestyle='-', linewidth=0.8)

    # Value labels sit above positive bars and below negative ones.
    for bar, val in zip(bars, improvements):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2, height,
                f'{val:+.2f}%', ha='center',
                va='bottom' if height > 0 else 'top',
                fontsize=10, fontweight='bold')

    plt.tight_layout()

    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"\n Saved: {save_path}")

    plt.show()
    return fig


# Load baseline results: the frozen-backbone evaluation of the SAME backbone as
# this run, so e.g. the dinov3 config is not compared against a dinov2 baseline.
# Assumes eval.ipynb saves its results as "<backbone_name>_results.json" (as the
# dinov2_vitb14 file does) -- confirm for other backbones.
baseline_path = f'{RESULTS_DIR}/{config.backbone_name}_results.json'
if not Path(baseline_path).exists():
    raise FileNotFoundError(
        f"Baseline results not found: {baseline_path} "
        f"(run eval.ipynb for {config.backbone_name} first)"
    )
with open(baseline_path, 'r') as f:
    baseline_results = json.load(f)

# Plot comparison
plot_comparison(baseline_results, results)

# Print summary table
print("\n" + "="*70)
print("SUMMARY COMPARISON")
print("="*70)

COMPARISON_THRESHOLDS = (0.05, 0.10, 0.15, 0.20)

comparison_data = []
for t in COMPARISON_THRESHOLDS:
    metric = f'PCK@{t:.2f}'
    bl = baseline_results['overall'][metric]['mean']
    ft = results['overall'][metric]['mean']
    imp = (ft - bl) * 100  # difference in percentage points

    comparison_data.append({
        'Metric': metric,
        'Baseline': f'{bl:.4f}',
        'Fine-tuned': f'{ft:.4f}',
        'Improvement': f'{imp:+.2f}%'
    })

df = pd.DataFrame(comparison_data)
print(df.to_string(index=False))
print("="*70)