In [None]:
# ============================================================================
# CELL 1: Project setup (reused from eval.ipynb)
# ============================================================================

from google.colab import drive
import sys
import os
from pathlib import Path

print(" AML Semantic Correspondence - Fine-tuning Stage\n")

# 1. Google Drive: mount it only when the mount point is not there yet.
if Path('/content/drive').exists():
    print(" Google Drive already mounted\n")
else:
    drive.mount('/content/drive')
    print(" Google Drive mounted\n")

# 2. Project folders, all living under one root on Drive.
PROJECT_ROOT = '/content/drive/MyDrive/AML'
DATA_DIR = PROJECT_ROOT + '/dataset'
CHECKPOINT_DIR = PROJECT_ROOT + '/checkpoints'
RESULTS_DIR = PROJECT_ROOT + '/results'
FINETUNED_DIR = PROJECT_ROOT + '/finetuned_models'

# Create every folder up front; existing folders are left untouched.
for _project_dir in (DATA_DIR, CHECKPOINT_DIR, RESULTS_DIR, FINETUNED_DIR):
    os.makedirs(_project_dir, exist_ok=True)

# 3. Clone repository
GITHUB_REPO_URL = 'https://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@github.com/SamueleCarrea/AML_SemanticCorrespondence'
LOCAL_REPO_NAME = 'AML_SemanticCorrespondence'

if not Path(LOCAL_REPO_NAME).exists():
    print(f"\n Cloning repository...")
    !git clone {GITHUB_REPO_URL} {LOCAL_REPO_NAME}
    print(" Repository cloned")
else:
    print(f"\n Repository {LOCAL_REPO_NAME} already exists.")
    if Path(LOCAL_REPO_NAME, '.git').exists():
        print(" Pulling latest changes...")
        %cd {LOCAL_REPO_NAME}
        !git pull
        %cd ..
        print(" Repository updated")

sys.path.insert(0, LOCAL_REPO_NAME)

# 4. Report which accelerator (if any) this runtime provides.
import torch

if torch.cuda.is_available():
    print(f"\n  GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; show it in decimal gigabytes.
    print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("\n  GPU: No GPU")

print("\n Setup complete!\n")

In [None]:
# ============================================================================
# CELL 2: Install Dependencies
# ============================================================================

print(" Installing dependencies...\n")

!pip install -q -r {LOCAL_REPO_NAME}/requirements.txt
!pip install -q tensorboard

import torch
print(f"\n PyTorch {torch.__version__}")
print(f" CUDA available: {torch.cuda.is_available()}")
print("\n Dependencies installed!\n")

In [None]:
import torch
import torch.nn.functional as F

def _pad_and_stack_images(images):
    """Zero-pad (C, H, W) images on the bottom/right to a common size and stack them."""
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    # F.pad takes (left, right, top, bottom) for the last two dims; padding only on the
    # right/bottom keeps the origin fixed, so keypoint coordinates stay valid.
    return torch.stack(
        [F.pad(img, (0, max_w - img.shape[2], 0, max_h - img.shape[1])) for img in images],
        dim=0,
    )


def _collate_tuple_samples(batch):
    """Collate legacy (imgA, imgB, kpA, kpB, meta) tuples; keypoint counts must match."""
    imgs_a, imgs_b, kps_a, kps_b, meta = (list(column) for column in zip(*batch))
    return (
        _pad_and_stack_images(imgs_a),
        _pad_and_stack_images(imgs_b),
        torch.stack(kps_a, dim=0),
        torch.stack(kps_b, dim=0),
        meta,
    )


def _collate_dict_samples(batch):
    """Collate dict samples, padding images and keypoints and building a valid mask."""
    out = {
        'src_img': _pad_and_stack_images([sample['src_img'] for sample in batch]),
        'tgt_img': _pad_and_stack_images([sample['tgt_img'] for sample in batch]),
    }

    # Keypoint counts differ between pairs: pad with -1 and mark real entries in the mask.
    max_kps = max(sample['src_kps'].shape[0] for sample in batch)
    src_kps = torch.full((len(batch), max_kps, 2), -1.0, dtype=torch.float32)
    tgt_kps = torch.full((len(batch), max_kps, 2), -1.0, dtype=torch.float32)
    valid_mask = torch.zeros((len(batch), max_kps), dtype=torch.bool)
    for row, sample in enumerate(batch):
        n = sample['src_kps'].shape[0]
        src_kps[row, :n] = sample['src_kps']
        tgt_kps[row, :n] = sample['tgt_kps']
        mask = sample['valid_mask'] if 'valid_mask' in sample else None
        # Use the sample's own mask when it provides one, otherwise all n entries are valid.
        valid_mask[row, :n] = mask[:n] if torch.is_tensor(mask) else True
    out['src_kps'] = src_kps
    out['tgt_kps'] = tgt_kps
    out['valid_mask'] = valid_mask

    # Remaining fields: stack same-shape tensors, keep everything else as a list.
    for key in batch[0].keys():
        if key in out:
            continue
        values = [sample[key] for sample in batch]
        if torch.is_tensor(values[0]):
            try:
                out[key] = torch.stack(values, dim=0)
            except RuntimeError:
                out[key] = values
        else:
            out[key] = values
    return out


def spair_pad_collate(batch):
    """Collate SPair-71K samples of different image sizes into one batch.

    Images are zero-padded on the bottom and right only, so keypoint coordinates
    are not altered and batch sizes greater than 1 become possible.

    The training DataLoader in this notebook is built with this function, and the
    training loop reads batches as dicts, so dict samples must be supported here.
    Two sample layouts are accepted:

    * dict samples with 'src_img', 'tgt_img', 'src_kps', 'tgt_kps' and an optional
      'valid_mask': returns a dict with padded images, keypoints padded with -1 to
      the largest keypoint count, a boolean 'valid_mask', and every other field
      either stacked (same-shape tensors) or kept as a list;
    * legacy 5-tuples (imgA, imgB, kpA, kpB, meta): returns the 5-tuple
      (imgsA, imgsB, kpsA, kpsB, meta_list); all samples must then have the same
      number of keypoints, because the keypoints are stacked without padding.

    Args:
        batch: non-empty list of samples produced by the dataset.

    Returns:
        A dict or a 5-tuple, matching the layout of the incoming samples.
    """
    if hasattr(batch[0], 'keys'):
        return _collate_dict_samples(batch)
    return _collate_tuple_samples(batch)

In [None]:
# ============================================================================
# CELL 3: Load Datasets (Train, Val, Test)
# ============================================================================

from dataset.spair import SPairDataset
from torch.utils.data import DataLoader

SPAIR_ROOT = f'{DATA_DIR}/Spair-71k'

# The three splits are built with exactly the same preprocessing options;
# only the split name differs.
_SPLIT_OPTIONS = dict(
    root=SPAIR_ROOT,
    size='large',
    long_side=518,
    normalize=True,
    load_segmentation=False,
)
train_dataset = SPairDataset(split='train', **_SPLIT_OPTIONS)
val_dataset = SPairDataset(split='val', **_SPLIT_OPTIONS)
test_dataset = SPairDataset(split='test', **_SPLIT_OPTIONS)

# Training: shuffled mini-batches, padded to a common size by the collate function.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    collate_fn=spair_pad_collate,
)

# Validation / test: one pair at a time, in dataset order, with the default collate.
_EVAL_LOADER_OPTIONS = dict(batch_size=1, shuffle=False, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, **_EVAL_LOADER_OPTIONS)
test_loader = DataLoader(test_dataset, **_EVAL_LOADER_OPTIONS)

print(f" Dataset Statistics:")
print(f"   Train: {len(train_dataset)} pairs")
print(f"   Val:   {len(val_dataset)} pairs")
print(f"   Test:  {len(test_dataset)} pairs")
print(f"\n Datasets loaded!")


In [None]:
# =========================
# Collate function for SPair-71K
# - Pads images in the batch to the max H/W (bottom/right padding only)
# - Pads keypoints to the max number of keypoints in the batch
# - Produces a padded valid_mask so downstream code can ignore padded entries
#   (padding does not shift origin, so keypoints stay correct)
# =========================
import torch
import torch.nn.functional as F

def spair_pad_collate(batch):
    """Collate a list of SPairDataset sample dicts into one batch dict.

    * 'src_img' / 'tgt_img' are zero-padded on the bottom/right to the largest
      height and width in the batch, then stacked.
    * 'src_kps' / 'tgt_kps' are padded with -1 up to the largest keypoint count.
    * 'valid_mask' is a boolean (B, max_kps) tensor; padded slots are False.
    * Every other key of the first sample is stacked when its values are
      same-shape tensors and otherwise returned as a plain list.
    """

    def pad_and_stack(key):
        # Bottom/right padding leaves the origin in place, so keypoints remain correct.
        tensors = [sample[key] for sample in batch]
        target_h = max(t.shape[1] for t in tensors)
        target_w = max(t.shape[2] for t in tensors)
        padded = [
            F.pad(t, (0, target_w - t.shape[2], 0, target_h - t.shape[1]), value=0.0)
            for t in tensors
        ]
        return torch.stack(padded, dim=0)

    collated = {
        'src_img': pad_and_stack('src_img'),
        'tgt_img': pad_and_stack('tgt_img'),
    }

    # Keypoints: the count varies per pair, so allocate for the largest and fill row by row.
    n_samples = len(batch)
    kp_counts = [sample['src_kps'].shape[0] for sample in batch]
    widest = max(kp_counts)
    src_kps = torch.full((n_samples, widest, 2), -1.0, dtype=torch.float32)
    tgt_kps = torch.full((n_samples, widest, 2), -1.0, dtype=torch.float32)
    valid = torch.zeros((n_samples, widest), dtype=torch.bool)

    for row, (sample, count) in enumerate(zip(batch, kp_counts)):
        src_kps[row, :count] = sample['src_kps']
        tgt_kps[row, :count] = sample['tgt_kps']
        has_mask = 'valid_mask' in sample and torch.is_tensor(sample['valid_mask'])
        if has_mask:
            # Dataset-provided mask takes precedence.
            valid[row, :count] = sample['valid_mask'][:count]
        else:
            # No usable mask: every real (non-padded) keypoint counts as valid.
            valid[row, :count] = True

    collated['src_kps'] = src_kps
    collated['tgt_kps'] = tgt_kps
    collated['valid_mask'] = valid

    # Everything else: stack when possible, fall back to a list otherwise.
    for key in list(batch[0].keys()):
        if key in collated:
            continue
        column = [sample[key] for sample in batch]
        if not torch.is_tensor(column[0]):
            collated[key] = column
            continue
        try:
            collated[key] = torch.stack(column, dim=0)
        except RuntimeError:
            # Tensors of differing shape cannot be stacked; hand them back unchanged.
            collated[key] = column

    return collated


In [None]:
# ============================================================================
# CELL 4: Fine-tuning Configuration
# ============================================================================

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FinetuneConfig:
    """Hyper-parameters and bookkeeping options for one fine-tuning run.

    Fields left at ``None`` are filled in by ``__post_init__``:
    ``experiment_name`` is derived from the backbone name and the number of
    unfrozen layers, ``pck_thresholds`` defaults to [0.05, 0.10, 0.15, 0.20] and
    ``use_scheduler`` is enabled only for DINOv2 backbones.

    Raises:
        ValueError: if ``loss_type`` is not implemented by the loss, if
            ``gradient_accumulation_steps`` or ``validate_every_epochs`` is
            smaller than 1, or if ``best_pck_threshold`` is not one of
            ``pck_thresholds``.
    """

    # Model
    backbone_name: str = 'dinov2_vitb14'  # 'dinov2_*', 'dinov3_*', 'sam_vit_b/l/h'
    num_layers_to_finetune: int = 2      # last transformer blocks to unfreeze (ignored for SAM)

    # Training
    num_epochs: int = 10
    learning_rate: float = 1e-5
    head_lr_mult: float = 10.0   # head LR = learning_rate * head_lr_mult
    proj_out_dim: Optional[int] = None  # None = keep same dim as backbone
    use_projection_head: bool = True
    weight_decay: float = 1e-2
    warmup_epochs: int = 1

    # Scheduler
    use_scheduler: Optional[bool] = None   # if None: enabled only for DINOv2
    step_lr_step_size: int = 5             # used only if scheduler enabled
    step_lr_gamma: float = 0.5             # used only if scheduler enabled

    # Loss
    loss_type: str = 'cosine'  # 'cosine', 'l2', or 'contrastive' (the types the loss implements)
    negative_margin: float = 0.2

    # Optimization
    batch_size: int = 8
    gradient_accumulation_steps: int = 1
    max_grad_norm: float = 1.0

    # Validation / checkpointing
    validate_every_epochs: int = 1
    val_num_samples: Optional[int] = None   # set an int for faster validation
    pck_thresholds: Optional[List[float]] = None
    best_pck_threshold: float = 0.10        # threshold whose PCK selects the best checkpoint
    resume_from: Optional[str] = None       # path to last.pt

    # Logging
    log_interval: int = 50

    # Paths
    checkpoint_dir: str = FINETUNED_DIR
    experiment_name: Optional[str] = None

    def __post_init__(self):
        # Fill in the defaults that depend on other fields.
        if self.experiment_name is None:
            self.experiment_name = f"{self.backbone_name}_ft{self.num_layers_to_finetune}"
        if self.pck_thresholds is None:
            self.pck_thresholds = [0.05, 0.10, 0.15, 0.20]
        if self.use_scheduler is None:
            self.use_scheduler = ('dinov2' in self.backbone_name)

        # Fail fast on values that would otherwise only blow up in the middle of training.
        if self.loss_type not in ('cosine', 'l2', 'contrastive'):
            raise ValueError(f"Unknown loss type: {self.loss_type}")
        if self.gradient_accumulation_steps < 1:
            raise ValueError("gradient_accumulation_steps must be >= 1")
        if self.validate_every_epochs < 1:
            raise ValueError("validate_every_epochs must be >= 1")
        # Validation looks results up by the 'PCK@x.xx' key, so compare on the same
        # two-decimal formatting rather than on raw floats.
        formatted_thresholds = {f"{t:.2f}" for t in self.pck_thresholds}
        if f"{self.best_pck_threshold:.2f}" not in formatted_thresholds:
            raise ValueError(
                f"best_pck_threshold={self.best_pck_threshold} is not in pck_thresholds={self.pck_thresholds}"
            )


# Experiment presets, keyed by a short name.

# Both DINOv2 presets share everything except the number of unfrozen blocks and the LR.
_DINOV2_SHARED = dict(
    backbone_name='dinov2_vitb14',
    num_epochs=10,
    weight_decay=1e-2,
    use_scheduler=True,
    step_lr_step_size=5,
    step_lr_gamma=0.5,
)

configs = {
    # DINOv2: fine-tune the last N blocks, with the LR scheduler
    'dinov2_2layers': FinetuneConfig(
        num_layers_to_finetune=2,
        learning_rate=1e-5,
        **_DINOV2_SHARED,
    ),
    'dinov2_4layers': FinetuneConfig(
        num_layers_to_finetune=4,
        learning_rate=5e-6,
        **_DINOV2_SHARED,
    ),

    # DINOv3: fine-tune the last N blocks, no scheduler (LR kept very low)
    'dinov3_2layers': FinetuneConfig(
        backbone_name='dinov3_vitb16',
        num_layers_to_finetune=2,
        num_epochs=10,
        learning_rate=5e-6,
        weight_decay=5e-5,
        use_scheduler=False,
    ),

    # SAM: the image encoder stays frozen (num_layers_to_finetune is ignored)
    'sam_frozen': FinetuneConfig(
        backbone_name='sam_vit_b',
        num_layers_to_finetune=0,
        num_epochs=10,
        learning_rate=1e-4,
        weight_decay=1e-4,
        use_scheduler=False,
    ),
}

# Preset used for this run
config = configs['dinov2_2layers']

# Print a short summary of the selected preset.
_summary_lines = [
    f"   Fine-tuning Configuration:",
    f"   Backbone: {config.backbone_name}",
    f"   Layers to finetune: {config.num_layers_to_finetune} (SAM ignores this)",
    f"   Epochs: {config.num_epochs}",
    f"   Learning rate: {config.learning_rate}",
    f"   Weight decay: {config.weight_decay}",
    f"   Scheduler enabled: {config.use_scheduler}",
    f"   Batch size: {config.batch_size}",
    f"   Validate every: {config.validate_every_epochs} epoch(s)",
    f"   Config ready!",
]
print("\n".join(_summary_lines))

In [None]:
# ============================================================================
# CELL 5: Loss Functions for Correspondence
# ============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F

class CorrespondenceLoss(nn.Module):
    """Loss for semantic-correspondence fine-tuning on annotated keypoint pairs.

    For every valid keypoint pair, the feature at the source keypoint's patch is
    compared with the feature at the target keypoint's patch:

    * 'cosine':      1 - cosine similarity
    * 'l2':          mean squared difference over the feature dimension
    * 'contrastive': InfoNCE-style loss with the target keypoint as the positive and
                     ``num_negatives`` uniformly sampled target patches as negatives

    The result is averaged over all valid keypoints in the batch.

    Args:
        loss_type: 'cosine', 'l2' or 'contrastive'.
        negative_margin: stored for interface compatibility; none of the
            implemented losses reads it.
        temperature: softmax temperature of the contrastive loss.
        num_negatives: number of random negatives per keypoint (contrastive only).
    """

    _LOSS_TYPES = ('cosine', 'l2', 'contrastive')

    def __init__(
        self,
        loss_type: str = 'cosine',
        negative_margin: float = 0.2,
        temperature: float = 0.1,
        num_negatives: int = 8
    ):
        super().__init__()
        self.loss_type = loss_type
        self.negative_margin = negative_margin
        self.temperature = temperature
        self.num_negatives = num_negatives

    @staticmethod
    def _gather_keypoint_features(feat: torch.Tensor, kps: torch.Tensor, patch_size: int) -> torch.Tensor:
        """Return the (N, D) features of the patches that contain the (N, 2) keypoints."""
        n_rows, n_cols = feat.shape[:2]
        # Pixel (x, y) -> patch index; clamp so keypoints on the image border or inside
        # padding still map to an existing patch.
        cells = (kps / patch_size).long()
        cols = cells[:, 0].clamp(0, n_cols - 1)
        rows = cells[:, 1].clamp(0, n_rows - 1)
        return feat[rows, cols]

    def _pair_losses(self, src_vecs: torch.Tensor, tgt_vecs: torch.Tensor, tgt_feat: torch.Tensor) -> torch.Tensor:
        """Return the (N,) per-keypoint losses for one image pair."""
        if self.loss_type == 'cosine':
            return 1.0 - F.cosine_similarity(src_vecs, tgt_vecs, dim=1)

        if self.loss_type == 'l2':
            return (src_vecs - tgt_vecs).pow(2).mean(dim=1)

        # 'contrastive': the positive logit sits in column 0, negatives in the others.
        n_pairs, dim = src_vecs.shape
        pos_sim = F.cosine_similarity(src_vecs, tgt_vecs, dim=1)              # (N,)
        tgt_flat = tgt_feat.reshape(-1, dim)                                  # (H_t*W_t, D)
        neg_idx = torch.randint(
            0, tgt_flat.shape[0], (n_pairs, self.num_negatives), device=src_vecs.device
        )
        neg_sim = F.cosine_similarity(src_vecs.unsqueeze(1), tgt_flat[neg_idx], dim=2)  # (N, K)
        logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1) / self.temperature
        # -log softmax of the positive == -log(exp(pos) / (exp(pos) + sum exp(neg))),
        # computed in a numerically stable way.
        return -F.log_softmax(logits, dim=1)[:, 0]

    def forward(
        self,
        src_features: torch.Tensor,
        tgt_features: torch.Tensor,
        src_kps: torch.Tensor,
        tgt_kps: torch.Tensor,
        valid_mask: torch.Tensor,
        patch_size: int
    ) -> torch.Tensor:
        """
        Compute correspondence loss.

        Args:
            src_features: (B, H_s, W_s, D)
            tgt_features: (B, H_t, W_t, D)
            src_kps: (B, N, 2) source keypoints as (x, y) pixel coordinates
            tgt_kps: (B, N, 2) target keypoints as (x, y) pixel coordinates
            valid_mask: (B, N) valid keypoint mask
            patch_size: int, pixels per feature-map cell

        Returns:
            loss: 0-dim (scalar) tensor; 0.0 when the batch has no valid keypoint

        Raises:
            ValueError: if ``loss_type`` is not one of the implemented types.
        """
        if self.loss_type not in self._LOSS_TYPES:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

        total_loss = None
        n_valid = 0

        # Images are handled one at a time because the number of valid keypoints differs
        # per pair; the keypoints of each pair are processed in one vectorised step.
        for b in range(src_features.shape[0]):
            valid = valid_mask[b].bool()
            if not valid.any():
                continue

            src_vecs = self._gather_keypoint_features(src_features[b], src_kps[b][valid], patch_size)
            tgt_vecs = self._gather_keypoint_features(tgt_features[b], tgt_kps[b][valid], patch_size)

            pair_sum = self._pair_losses(src_vecs, tgt_vecs, tgt_features[b]).sum()
            total_loss = pair_sum if total_loss is None else total_loss + pair_sum
            n_valid += src_vecs.shape[0]

        if n_valid == 0:
            # Nothing to learn from in this batch; return a zero that still supports backward().
            return torch.tensor(0.0, device=src_features.device, requires_grad=True)

        return total_loss / n_valid


# Smoke test: run the loss once on random tensors to check that it executes.
print(" Testing loss function...")

loss_fn = CorrespondenceLoss(loss_type='cosine')

# Random stand-in batch: 2 pairs, a 37x37 feature grid with 768 channels and
# 10 keypoints per pair with pixel coordinates drawn from [0, 500).
_n_pairs, _grid, _channels, _n_kps = 2, 37, 768, 10
src_feat = torch.randn(_n_pairs, _grid, _grid, _channels)
tgt_feat = torch.randn(_n_pairs, _grid, _grid, _channels)
src_kps = torch.randint(0, 500, (_n_pairs, _n_kps, 2)).float()
tgt_kps = torch.randint(0, 500, (_n_pairs, _n_kps, 2)).float()
valid_mask = torch.ones(_n_pairs, _n_kps, dtype=torch.bool)  # every keypoint is valid

loss = loss_fn(
    src_feat, tgt_feat,
    src_kps, tgt_kps,
    valid_mask,
    patch_size=14,
)
print(f" Loss computed: {loss.item():.4f}")

In [None]:
# ============================================================================
# CELL 6: Trainable Backbone Wrapper
# ============================================================================

import torch
import torch.nn as nn
from models.backbones import DINOv2Extractor, DINOv3Extractor, SAMImageEncoder


class ProjectionHead(nn.Module):
    # Simple, trainable projection head for patch features.
    # - If out_dim is None -> identity-like (keeps same dimensionality)
    # - Otherwise projects to out_dim and L2-normalizes (good for cosine similarity)
    def __init__(self, in_dim: int, out_dim: int | None = None):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim if out_dim is not None else in_dim
        self.proj = nn.Linear(self.in_dim, self.out_dim)
        self.norm = nn.LayerNorm(self.out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, H, W, D)
        y = self.proj(x)
        y = self.norm(y)
        return torch.nn.functional.normalize(y, dim=-1)

class FinetunableBackbone(nn.Module):
    """Wrapper to make backbone partially trainable.

    NOTE: We keep feature extraction from the LAST layer (as in the baseline).
    """

    def __init__(
        self,
        backbone_name: str,
        num_layers_to_finetune: int = 2,
        proj_out_dim: int | None = None,
        use_projection_head: bool = True,
        device: str = 'cuda'
    ):
        super().__init__()
        self.backbone_name = backbone_name
        self.num_layers_to_finetune = num_layers_to_finetune
        self.device = device
        self.proj_out_dim = proj_out_dim
        self.use_projection_head = use_projection_head

        # Load backbone extractor
        if 'dinov2' in backbone_name:
            self.extractor = DINOv2Extractor(
                variant=backbone_name,
                device=device,
                allow_hub_download=True
            )
        elif 'dinov3' in backbone_name:
            self.extractor = DINOv3Extractor(
                variant=backbone_name,
                device=device
            )
        elif 'sam' in backbone_name:
            # backbone_name expected: 'sam_vit_b' / 'sam_vit_l' / 'sam_vit_h'
            sam_variant = backbone_name.replace('sam_', '')
            self.extractor = SAMImageEncoder(
                variant=sam_variant,
                checkpoint_path=None,
                device=device,
                allow_hub_download=True
            )
        else:
            raise ValueError(f"Unsupported backbone: {backbone_name}")

        self.stride = self.extractor.stride

        # Projection head (initialized lazily on first forward, because feature dim depends on backbone variant)
        self.proj_head = None

        # Freeze everything first
        for p in self.extractor.parameters():
            p.requires_grad = False

        # Unfreeze last N blocks ONLY for DINO backbones
        if ('dinov2' in backbone_name) or ('dinov3' in backbone_name):
            self._unfreeze_last_layers(num_layers_to_finetune)
        else:
            print("  SAM selected: image encoder kept frozen (no backbone fine-tuning).")

        # Count trainable parameters
        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        n_total = sum(p.numel() for p in self.parameters())

        print(f"   Trainable parameters:")
        print(f"   Total: {n_total:,}")
        print(f"   Trainable: {n_trainable:,} ({(n_trainable/n_total*100 if n_total>0 else 0):.2f}%)")

    def _unfreeze_last_layers(self, num_layers: int):
        """Unfreeze last N transformer blocks."""
        model = self.extractor.model

        # Access transformer blocks
        if hasattr(model, 'blocks'):
            blocks = model.blocks
        elif hasattr(model, 'encoder') and hasattr(model.encoder, 'layers'):
            blocks = model.encoder.layers
        else:
            raise AttributeError("Cannot find transformer blocks in model")

        total_blocks = len(blocks)
        start_idx = max(0, total_blocks - num_layers)

        print(f"  Unfreezing blocks {start_idx} to {total_blocks-1} (total: {total_blocks})")

        for i in range(start_idx, total_blocks):
            for p in blocks[i].parameters():
                p.requires_grad = True

    def extract_features(self, image: torch.Tensor) -> torch.Tensor:
        """Extract features (with gradients if training)."""
        feat_map, stride = self.extractor.extract_feats(image)
        # (B, C, H, W) -> (B, H, W, C)
        features = feat_map.permute(0, 2, 3, 1)

        if self.use_projection_head:
            if self.proj_head is None:
                in_dim = features.shape[-1]
                self.proj_head = ProjectionHead(in_dim, self.proj_out_dim).to(features.device)
            features = self.proj_head(features)

        return features

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        return self.extract_features(image)


print(" Backbone wrapper ready!")


In [None]:
# ============================================================================
# CELL 7: Training Loop (with PCK validation, checkpoints, TensorBoard)
# ============================================================================

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LinearLR, SequentialLR, StepLR
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import json
from pathlib import Path
import numpy as np

from dataset.spair import compute_pck

class Trainer:
    """Fine-tuning trainer."""

    def __init__(
        self,
        model: FinetunableBackbone,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: FinetuneConfig,
        device: str = 'cuda'
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device

        # Loss function
        self.criterion = CorrespondenceLoss(
            loss_type=config.loss_type,
            negative_margin=config.negative_margin
        ).to(device)

        # Optimizer (AdamW)
        # Use param groups so the projection head can learn faster than the backbone.
        head_lr = self.config.learning_rate * getattr(self.config, 'head_lr_mult', 10.0)

        head_params = []
        backbone_params = []

        # Projection head (if present)
        if hasattr(self.model, 'proj_head') and (self.model.proj_head is not None):
            head_params = [p for p in self.model.proj_head.parameters() if p.requires_grad]

        # Backbone trainable params (unfrozen blocks)
        # Note: extractor params include both backbone and (possibly) other modules, but requires_grad filters correctly.
        backbone_params = [p for p in self.model.extractor.parameters() if p.requires_grad]

        # If proj_head is lazily created, we still set up optimizer for backbone now and will add head params later on first forward.
        param_groups = []
        if backbone_params:
            param_groups.append({'params': backbone_params, 'lr': self.config.learning_rate})
        if head_params:
            param_groups.append({'params': head_params, 'lr': head_lr})

        self.optimizer = optim.AdamW(
            param_groups if param_groups else [p for p in model.parameters() if p.requires_grad],
            weight_decay=config.weight_decay
        )

        # Scheduler
        self.scheduler = None
        if config.use_scheduler:
            # Warmup (linear) + StepLR
            warmup_steps = max(1, config.warmup_epochs * len(train_loader))
            warmup_scheduler = LinearLR(
                self.optimizer,
                start_factor=0.1,
                total_iters=warmup_steps
            )
            step_scheduler = StepLR(
                self.optimizer,
                step_size=config.step_lr_step_size * len(train_loader),
                gamma=config.step_lr_gamma
            )
            # Apply warmup first, then StepLR
            self.scheduler = SequentialLR(
                self.optimizer,
                schedulers=[warmup_scheduler, step_scheduler],
                milestones=[warmup_steps]
            )

        # State
        self.global_step = 0
        self.best_pck = -1.0

        # Directories
        self.ckpt_dir = Path(config.checkpoint_dir) / config.experiment_name
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.tb_dir = self.ckpt_dir / 'tb'
        self.tb_dir.mkdir(parents=True, exist_ok=True)

        # TensorBoard writer
        self.writer = SummaryWriter(log_dir=str(self.tb_dir))

        print(f"  Checkpoints: {self.ckpt_dir}")
        print(f"  TensorBoard: {self.tb_dir}")

        # Optionally resume
        if config.resume_from is not None:
            self._resume_if_possible(config.resume_from)

    def _resume_if_possible(self, ckpt_path: str):
        ckpt_path = Path(ckpt_path)
        if not ckpt_path.exists():
            print(f"  Resume requested, but checkpoint not found: {ckpt_path}")
            return
        ckpt = torch.load(ckpt_path, map_location='cpu')
        self.model.load_state_dict(ckpt['model_state_dict'])
        self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        if self.scheduler is not None and ckpt.get('scheduler_state_dict') is not None:
            self.scheduler.load_state_dict(ckpt['scheduler_state_dict'])
        self.global_step = ckpt.get('global_step', 0)
        self.best_pck = ckpt.get('best_pck', -1.0)
        start_epoch = ckpt.get('epoch', -1) + 1
        print(f"  Resumed from {ckpt_path} at epoch {start_epoch}, step {self.global_step}, best_pck {self.best_pck:.4f}")

    def _maybe_add_head_params(self):
        # If projection head is created lazily during the first forward, add its params to optimizer once.
        if not hasattr(self.model, 'proj_head') or self.model.proj_head is None:
            return
        # Check if already present
        for pg in self.optimizer.param_groups:
            if any(p is list(self.model.proj_head.parameters())[0] for p in pg['params'][:1]):
                return
        head_lr = self.config.learning_rate * getattr(self.config, 'head_lr_mult', 10.0)
        self.optimizer.add_param_group({'params': [p for p in self.model.proj_head.parameters() if p.requires_grad], 'lr': head_lr})

    def _log_lrs(self):
        for i, pg in enumerate(self.optimizer.param_groups):
            self.writer.add_scalar(f'lr/group_{i}', pg['lr'], self.global_step)

    def train_epoch(self, epoch: int):
        self.model.train()
        epoch_loss = 0.0
        n_batches = 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.config.num_epochs}")

        for batch_idx, batch in enumerate(pbar):
            src_img = batch['src_img'].to(self.device)
            tgt_img = batch['tgt_img'].to(self.device)
            src_kps = batch['src_kps'].to(self.device)
            tgt_kps = batch['tgt_kps'].to(self.device)
            valid_mask = batch['valid_mask'].to(self.device)

            src_features = self.model(src_img)
            self._maybe_add_head_params()
            tgt_features = self.model(tgt_img)

            loss = self.criterion(
                src_features, tgt_features,
                src_kps, tgt_kps, valid_mask,
                patch_size=self.model.stride
            )

            loss = loss / self.config.gradient_accumulation_steps
            loss.backward()

            if (batch_idx + 1) % self.config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.max_grad_norm
                )
                self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()
                self.optimizer.zero_grad(set_to_none=True)

            # Stats
            epoch_loss += loss.item() * self.config.gradient_accumulation_steps
            n_batches += 1
            self.global_step += 1

            if self.global_step % self.config.log_interval == 0:
                avg_loss = epoch_loss / n_batches
                lr = self.optimizer.param_groups[0]['lr']
                pbar.set_postfix({'loss': f'{avg_loss:.4f}', 'lr': f'{lr:.2e}'})

                # TensorBoard
                self.writer.add_scalar('train/loss_step', avg_loss, self.global_step)
                self._log_lrs()

        avg_epoch_loss = epoch_loss / max(1, n_batches)
        self.writer.add_scalar('train/loss_epoch', avg_epoch_loss, epoch)
        return avg_epoch_loss

    @torch.no_grad()
    def validate_pck(self, epoch: int):
        """Compute PCK@T on validation split (cosine similarity matching)."""
        self.model.eval()

        all_scores = {f'PCK@{t:.2f}': [] for t in self.config.pck_thresholds}

        n_processed = 0
        pbar = tqdm(self.val_loader, desc='Validation (PCK)', leave=False)

        for batch in pbar:
            if self.config.val_num_samples is not None and n_processed >= self.config.val_num_samples:
                break

            src_img = batch['src_img'].to(self.device)
            tgt_img = batch['tgt_img'].to(self.device)

            # (K,2)
            src_kps = batch['src_kps'][0]
            tgt_kps = batch['tgt_kps'][0]
            valid_mask = batch['valid_mask'][0]

            src_kps_valid = src_kps[valid_mask]
            tgt_kps_valid = tgt_kps[valid_mask]
            if len(src_kps_valid) == 0:
                continue

            # Predict target kps
            tgt_kps_pred = self._predict_keypoints(src_img, tgt_img, src_kps_valid)

            H, W = tgt_img.shape[2:]
            pck = compute_pck(tgt_kps_pred, tgt_kps_valid, (H, W), thresholds=self.config.pck_thresholds)

            for k, v in pck.items():
                all_scores[k].append(float(v))

            n_processed += 1
            if len(all_scores[f'PCK@{self.config.best_pck_threshold:.2f}']) > 0:
                avg = float(np.mean(all_scores[f'PCK@{self.config.best_pck_threshold:.2f}']))
                pbar.set_postfix({f'PCK@{self.config.best_pck_threshold:.2f}': f'{avg:.4f}'})

        results = {k: (float(np.mean(v)) if len(v) else 0.0) for k, v in all_scores.items()}

        # Log to TensorBoard
        for k, v in results.items():
            self.writer.add_scalar(f'val/{k}', v, epoch)

        return results

    @torch.no_grad()
    def _predict_keypoints(self, src_img: torch.Tensor, tgt_img: torch.Tensor, src_kps: torch.Tensor) -> torch.Tensor:
        """Predict target keypoints by cosine similarity argmax over target patches."""
        import torch.nn.functional as F

        src_feat = self.model(src_img)[0]  # (Hs, Ws, D)
        tgt_feat = self.model(tgt_img)[0]  # (Ht, Wt, D)

        H_s, W_s, D = src_feat.shape
        H_t, W_t, _ = tgt_feat.shape
        patch_size = self.model.stride

        src_kps_patch = (src_kps / patch_size).long()
        src_kps_patch[:, 0] = src_kps_patch[:, 0].clamp(0, W_s - 1)
        src_kps_patch[:, 1] = src_kps_patch[:, 1].clamp(0, H_s - 1)

        N = src_kps.shape[0]
        tgt_kps_pred = torch.zeros(N, 2, device=src_kps.device)

        for i in range(N):
            x, y = src_kps_patch[i]
            src_vec = src_feat[y, x]

            sim = F.cosine_similarity(
                src_vec.view(1, 1, 1, D),
                tgt_feat.unsqueeze(0),
                dim=-1
            ).squeeze(0)  # (Ht, Wt)

            max_idx = sim.flatten().argmax()
            pred_y = max_idx // W_t
            pred_x = max_idx % W_t

            tgt_kps_pred[i, 0] = pred_x * patch_size + patch_size // 2
            tgt_kps_pred[i, 1] = pred_y * patch_size + patch_size // 2

        return tgt_kps_pred

    def _save_checkpoint(self, path: Path, epoch: int, extra: dict | None = None):
        payload = {
            'epoch': epoch,
            'global_step': self.global_step,
            'best_pck': self.best_pck,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': (self.scheduler.state_dict() if self.scheduler is not None else None),
            'config': vars(self.config),
        }
        if extra:
            payload.update(extra)
        torch.save(payload, path)

    def train(self):
        print(f"   Starting training: {self.config.experiment_name}")
        print(f"   Epochs: {self.config.num_epochs}")
        print(f"   Batches per epoch: {len(self.train_loader)}")

        for epoch in range(self.config.num_epochs):
            train_loss = self.train_epoch(epoch)
            print(f"  Epoch {epoch+1} - Train Loss: {train_loss:.4f}")

            # Always save LAST checkpoint every epoch (safe resume)
            last_path = self.ckpt_dir / 'last.pt'
            self._save_checkpoint(last_path, epoch)

            # Validation every N epochs (PCK)
            if (epoch + 1) % self.config.validate_every_epochs == 0:
                val_pck = self.validate_pck(epoch)
                key = f"PCK@{self.config.best_pck_threshold:.2f}"
                current = float(val_pck.get(key, 0.0))
                print(f"  Validation {key}: {current:.4f}")

                # Save best
                if current > self.best_pck:
                    self.best_pck = current
                    best_path = self.ckpt_dir / 'best.pt'
                    self._save_checkpoint(best_path, epoch, extra={'val_pck': val_pck})
                    print(f"  New best! Saved: {best_path} (best_pck={self.best_pck:.4f})")

        self.writer.close()
        print("  Training complete!")


print(" Trainer ready!")


In [None]:
# ============================================================================
# CELL 8: Launch Fine-tuning
# ============================================================================

# Initialize model (falls back to CPU when no GPU is available)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = FinetunableBackbone(
    backbone_name=config.backbone_name,
    num_layers_to_finetune=config.num_layers_to_finetune,
    proj_out_dim=getattr(config, "proj_out_dim", None),
    use_projection_head=getattr(config, "use_projection_head", True),
    device=device
)

# Initialize trainer
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    device=device
)

# Each magic has to be on its own line to work, so print them as two separate lines
# that can be pasted into a cell as they are.
print("To monitor TensorBoard in Colab, run these two lines in a separate cell:")
print("  %load_ext tensorboard")
print(f"  %tensorboard --logdir {trainer.tb_dir}")

# Start training
trainer.train()


In [None]:
# ============================================================================
# CELL 9: Evaluate Fine-tuned Model
# ============================================================================

from collections import defaultdict
import time
import numpy as np

class FinetunedEvaluator:
    """Evaluate a fine-tuned backbone on keypoint correspondence using PCK.

    For every batch the evaluator extracts features for the source and target
    images, transfers each valid source keypoint to the target image by
    nearest-neighbour matching in cosine-similarity space, and scores the
    predictions with ``compute_pck`` from ``dataset.spair``.

    Args:
        model: Callable returning features such that ``model(img)[0]`` has
            shape (H, W, D); it must expose ``eval()`` and a ``stride``
            attribute (pixels per feature cell).
        dataloader: Iterable of dict batches with keys ``src_img``,
            ``tgt_img``, ``src_kps``, ``tgt_kps``, ``valid_mask`` and
            ``category``.
        device: Device the images are moved to before the forward pass.
        thresholds: PCK thresholds. ``None`` selects ``DEFAULT_THRESHOLDS``.
            The values are copied, so instances never share or alias a list.
    """

    # Immutable defaults: a list literal in the signature would be shared by
    # every instance created without an explicit ``thresholds`` argument.
    DEFAULT_THRESHOLDS = (0.05, 0.10, 0.15, 0.20)

    def __init__(
        self,
        model: "FinetunableBackbone",
        dataloader: "DataLoader",
        device: str = 'cuda',
        thresholds: list = None
    ):
        self.model = model
        self.dataloader = dataloader
        self.device = device
        self.thresholds = list(
            self.DEFAULT_THRESHOLDS if thresholds is None else thresholds
        )

    @staticmethod
    def _mean(values) -> float:
        """Mean of ``values`` as a plain float; NaN (without warning) if empty."""
        return float(np.mean(values)) if len(values) else float('nan')

    @staticmethod
    def _std(values) -> float:
        """Std-dev of ``values`` as a plain float; NaN (without warning) if empty."""
        return float(np.std(values)) if len(values) else float('nan')

    @torch.no_grad()
    def evaluate(self, num_samples: int = None):
        """Run the evaluation loop and aggregate PCK statistics.

        Args:
            num_samples: Stop after this many scored pairs. ``None`` or ``0``
                evaluates the whole dataloader.

        Returns:
            dict with keys:
                ``overall``: ``{metric: {'mean': float, 'std': float}}``
                ``per_category``: ``{category: {metric: float}}``
                ``num_pairs``: number of pairs that were scored
                ``inference_time_ms``: mean prediction time per pair in ms
            Metrics are named ``PCK@<threshold>`` with two decimals. Entries
            are NaN when nothing was scored.
        """
        # Project-local import kept lazy, but resolved once instead of on
        # every batch.
        from dataset.spair import compute_pck

        self.model.eval()

        all_pck = defaultdict(list)
        per_category = defaultdict(lambda: defaultdict(list))
        inference_times = []
        metric_names = [f'PCK@{t:.2f}' for t in self.thresholds]

        n_processed = 0
        pbar = tqdm(self.dataloader, desc="Evaluating")

        for batch in pbar:
            if num_samples and n_processed >= num_samples:
                break

            src_img = batch['src_img'].to(self.device)
            tgt_img = batch['tgt_img'].to(self.device)
            # Only the first element of each batch is scored.
            # NOTE(review): this presumes the loader uses batch_size=1 - confirm.
            src_kps = batch['src_kps'][0]
            tgt_kps = batch['tgt_kps'][0]
            valid_mask = batch['valid_mask'][0]
            category = batch['category'][0]

            src_kps_valid = src_kps[valid_mask]
            tgt_kps_valid = tgt_kps[valid_mask]

            # Pairs without any valid keypoint cannot be scored.
            if len(src_kps_valid) == 0:
                continue

            # Predict (perf_counter is the monotonic clock intended for
            # measuring short durations).
            start = time.perf_counter()
            tgt_kps_pred = self._predict_keypoints(
                src_img, tgt_img, src_kps_valid
            )
            inference_times.append(time.perf_counter() - start)

            # Compute metrics relative to the target image size.
            H, W = tgt_img.shape[2:]
            pck_scores = compute_pck(
                tgt_kps_pred, tgt_kps_valid, (H, W), thresholds=self.thresholds
            )

            for metric, value in pck_scores.items():
                all_pck[metric].append(value)
                per_category[category][metric].append(value)

            n_processed += 1

            # .get() avoids inserting an empty entry into the defaultdict when
            # 0.10 is not among the configured thresholds.
            running = all_pck.get('PCK@0.10')
            if running:
                pbar.set_postfix({'PCK@0.10': f'{np.mean(running):.4f}'})

        # Aggregate results as plain Python floats.
        results = {
            'overall': {},
            'per_category': {},
            'num_pairs': n_processed,
            'inference_time_ms': self._mean(inference_times) * 1000
        }

        for metric in metric_names:
            values = all_pck[metric]
            results['overall'][metric] = {
                'mean': self._mean(values),
                'std': self._std(values)
            }

        for cat, metrics in per_category.items():
            results['per_category'][cat] = {
                metric: self._mean(metrics[metric]) for metric in metric_names
            }

        return results

    def _predict_keypoints(
        self,
        src_img: torch.Tensor,
        tgt_img: torch.Tensor,
        src_kps: torch.Tensor
    ) -> torch.Tensor:
        """Transfer source keypoints to the target image.

        Each keypoint is mapped to its feature cell (pixel coordinate divided
        by ``model.stride``, clamped to the feature grid). The target cell
        with the highest cosine similarity is selected, and its centre is
        returned in pixel coordinates.

        Args:
            src_img: Source image batch passed to the model; only the first
                feature map of the output is used.
            tgt_img: Target image batch, handled the same way.
            src_kps: (N, 2) keypoints as (x, y) pixel coordinates.

        Returns:
            (N, 2) float tensor of predicted (x, y) target coordinates,
            allocated on ``src_kps.device``.
        """
        src_feat = self.model(src_img)[0]
        tgt_feat = self.model(tgt_img)[0]

        H_s, W_s, D = src_feat.shape
        W_t = tgt_feat.shape[1]

        patch_size = self.model.stride
        half_patch = patch_size // 2

        # Pixel -> feature-cell coordinates, clamped to the source grid.
        src_kps_patch = (src_kps / patch_size).long()
        src_kps_patch[:, 0] = src_kps_patch[:, 0].clamp(0, W_s - 1)
        src_kps_patch[:, 1] = src_kps_patch[:, 1].clamp(0, H_s - 1)

        # Flatten the target grid once; row-major order means
        # flat index = row * W_t + col.
        tgt_flat = tgt_feat.reshape(-1, tgt_feat.shape[-1])

        N = src_kps.shape[0]
        tgt_kps_pred = torch.zeros(N, 2, device=src_kps.device)

        for i, (x, y) in enumerate(src_kps_patch.tolist()):
            similarity = F.cosine_similarity(
                src_feat[y, x].unsqueeze(0), tgt_flat, dim=-1
            )
            pred_y, pred_x = divmod(int(similarity.argmax()), W_t)

            # Report the centre of the winning cell in pixel coordinates.
            tgt_kps_pred[i, 0] = pred_x * patch_size + half_patch
            tgt_kps_pred[i, 1] = pred_y * patch_size + half_patch

        return tgt_kps_pred


# Evaluate the model on the test split.
evaluator = FinetunedEvaluator(
    model=model,
    dataloader=test_loader,
    device=device
)

results = evaluator.evaluate()

# Print overall metrics (mean ± std per PCK threshold), timing and pair count.
print("\n" + "="*70)
print("FINE-TUNED MODEL RESULTS")
print("="*70)
for metric, vals in results['overall'].items():
    print(f"   {metric}: {vals['mean']:.4f} ± {vals['std']:.4f}")
print(f"\n⏱  Inference: {results['inference_time_ms']:.2f} ms/pair")
print(f" Evaluated: {results['num_pairs']} pairs")
print("="*70)

# Save results as JSON. `default=float` converts numeric scalars the encoder
# does not know natively (e.g. NumPy float32 values) instead of raising
# TypeError halfway through writing the file.
results_path = f"{RESULTS_DIR}/{config.experiment_name}_results.json"
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2, default=float)
print(f"\n Results saved: {results_path}")

In [None]:
# ============================================================================
# CELLA 10: Compare Baseline vs Fine-tuned
# ============================================================================

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

def plot_comparison(baseline_results, finetuned_results, thresholds=None):
    """Plot baseline vs fine-tuned PCK and save the figure.

    Draws two panels: the PCK curves of both models over the thresholds, and
    a bar chart of the difference (fine-tuned minus baseline, times 100).
    The figure is written to
    ``{RESULTS_DIR}/{config.experiment_name}_comparison.png`` and then shown.

    Args:
        baseline_results: Mapping with
            ``['overall']['PCK@<t>']['mean']`` for every threshold ``t``.
        finetuned_results: Same structure for the fine-tuned model.
        thresholds: Iterable of PCK thresholds to compare. ``None`` keeps the
            original fixed set ``[0.05, 0.10, 0.15, 0.20]``.

    Raises:
        KeyError: If a requested ``PCK@<t>`` metric is missing from either
            results mapping.
    """
    if thresholds is None:
        thresholds = [0.05, 0.10, 0.15, 0.20]
    else:
        thresholds = list(thresholds)

    metric_names = [f'PCK@{t:.2f}' for t in thresholds]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Extract PCK values
    baseline_pck = [baseline_results['overall'][name]['mean']
                    for name in metric_names]
    finetuned_pck = [finetuned_results['overall'][name]['mean']
                     for name in metric_names]

    # Plot 1: PCK Curves
    ax1 = axes[0]
    ax1.plot(thresholds, baseline_pck, marker='o', linewidth=2,
             label='Baseline (Frozen)', markersize=8, color='#E74C3C')
    ax1.plot(thresholds, finetuned_pck, marker='s', linewidth=2,
             label=f'Fine-tuned ({config.num_layers_to_finetune} layers)',
             markersize=8, color='#27AE60')

    ax1.set_xlabel('Threshold', fontsize=12, fontweight='bold')
    ax1.set_ylabel('PCK', fontsize=12, fontweight='bold')
    ax1.set_title('PCK Comparison', fontsize=14, fontweight='bold')
    ax1.legend(loc='lower right')
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim([0, 1])

    # Plot 2: Improvement Bar Chart (green for gains, red otherwise)
    ax2 = axes[1]
    improvements = [(ft - bl) * 100 for bl, ft in zip(baseline_pck, finetuned_pck)]
    colors = ['#27AE60' if imp > 0 else '#E74C3C' for imp in improvements]
    positions = range(len(thresholds))

    bars = ax2.bar(positions, improvements, color=colors, alpha=0.8)
    ax2.set_xticks(positions)
    ax2.set_xticklabels(metric_names)
    ax2.set_ylabel('Improvement (%)', fontsize=12, fontweight='bold')
    ax2.set_title('Fine-tuning Improvement', fontsize=14, fontweight='bold')
    ax2.grid(axis='y', alpha=0.3)
    ax2.axhline(y=0, color='black', linestyle='-', linewidth=0.8)

    # Annotate each bar with its signed value, placed above positive bars and
    # below the others.
    for bar, val in zip(bars, improvements):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2, height,
                f'{val:+.2f}%', ha='center',
                va='bottom' if height > 0 else 'top',
                fontsize=10, fontweight='bold')

    plt.tight_layout()

    save_path = f'{RESULTS_DIR}/{config.experiment_name}_comparison.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"\n Saved: {save_path}")

    plt.show()


# Read the stored baseline metrics from the results directory.
baseline_path = f'{RESULTS_DIR}/dinov2_vitb14_results.json'
with open(baseline_path, 'r') as f:
    baseline_results = json.load(f)

# Draw and save the side-by-side comparison figure.
plot_comparison(baseline_results, results)

# Textual summary: one row per PCK threshold.
separator = "="*70
print("\n" + separator)
print("SUMMARY COMPARISON")
print(separator)

comparison_data = []
for t in [0.05, 0.10, 0.15, 0.20]:
    metric = f'PCK@{t:.2f}'
    bl = baseline_results['overall'][metric]['mean']
    ft = results['overall'][metric]['mean']
    # Difference of the two means, scaled by 100.
    imp = (ft - bl) * 100

    row = {}
    row['Metric'] = metric
    row['Baseline'] = f'{bl:.4f}'
    row['Fine-tuned'] = f'{ft:.4f}'
    row['Improvement'] = f'{imp:+.2f}%'
    comparison_data.append(row)

df = pd.DataFrame(comparison_data)
print(df.to_string(index=False))
print(separator)