In [None]:
# Mount Google Drive (Colab-only API) so the dataset archive can be copied from it
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import shutil

# Paths: archive on Drive, its local copy, and the folder it will be extracted into
zip_su_drive = '/content/drive/MyDrive/semantic_correspondence.zip'
zip_locale = '/content/semantic_correspondence.zip'
cartella_destinazione = '/content/'

# Copy the archive to local storage so Drive can be unmounted before extraction.
# Left as the last expression so the notebook displays the destination path.
shutil.copy(zip_su_drive, zip_locale)

'/content/semantic_correspondence.zip'

In [None]:
# Unmount Google Drive: the archive has already been copied to local storage
drive.flush_and_unmount()

In [None]:
# Extract the local copy of the archive into the destination folder
import os
import zipfile

os.makedirs(cartella_destinazione, exist_ok=True)
with zipfile.ZipFile(zip_locale, 'r') as archive:
    archive.extractall(path=cartella_destinazione)


In [None]:
# Verify that PyTorch can see a CUDA GPU
import torch

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
else:
    gpu_name = 'No GPU'
print(f"\n✓ GPU: {gpu_name}")



✓ GPU: Tesla T4


In [None]:
# Show driver/CUDA versions and current GPU memory usage (IPython shell escape)
!nvidia-smi

Sun Dec 21 18:27:31 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   54C    P8             12W /   70W |       2MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# Work from the extracted project root so relative paths such as
# "models/dinov2/..." used by later cells resolve correctly.
import os

PROJECT_ROOT = '/content/semantic_correspondence'
os.chdir(PROJECT_ROOT)

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
import torch
import glob
import json
import os


class Normalize(object):
    """Rescale selected image entries of a sample dict and apply ImageNet normalisation.

    Each tensor listed in ``image_keys`` is divided by 255 in place, then passed
    through ``transforms.Normalize`` with the ImageNet mean/std; the normalised
    result replaces the dict entry.
    """

    def __init__(self, image_keys):
        self.image_keys = image_keys
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    def __call__(self, image):
        for name in self.image_keys:
            pixels = image[name]
            pixels /= 255.0  # in-place on tensors: the caller's tensor is rescaled too
            image[name] = self.normalize(pixels)
        return image


def read_img(path):
    """Load an image file as a float32 tensor laid out (C, H, W) with RGB values in [0, 255].

    The file is opened in a ``with`` block so its handle is released deterministically
    (PIL otherwise keeps the file open until the image object is garbage-collected
    for some formats).
    """
    with Image.open(path) as pil_img:
        rgb = np.array(pil_img.convert('RGB'))  # (H, W, 3) uint8

    return torch.tensor(rgb.transpose(2, 0, 1).astype(np.float32))


class SPairDataset(Dataset):
    """SPair-71k image-pair dataset.

    Pair identifiers are read from ``<layout_path>/<dataset_size>/<datatype>.txt``;
    each identifier names a JSON annotation under ``<pair_ann_path>/<datatype>/``.
    Items are dicts holding both images (normalised by ``Normalize``), their
    keypoints, bounding boxes, annotation metadata and the PCK threshold
    ``max(trg_bbox_w, trg_bbox_h) * pck_alpha``.
    """

    def __init__(self, pair_ann_path, layout_path, image_path, dataset_size, pck_alpha, datatype):

        self.datatype = datatype
        self.pck_alpha = pck_alpha
        layout_file = os.path.join(layout_path, dataset_size, datatype + '.txt')
        # Close the file deterministically, and skip blank lines (e.g. the trailing
        # newline) rather than unconditionally dropping the last entry, which lost a
        # real pair when the file had no trailing newline.
        with open(layout_file, "r") as f:
            self.ann_files = [line.strip() for line in f if line.strip()]
        self.pair_ann_path = pair_ann_path
        self.image_path = image_path
        # Category names are the sub-directory names of the image folder, sorted
        self.categories = sorted(os.path.basename(x) for x in glob.glob('%s/*' % image_path))
        self.transform = Normalize(['src_img', 'trg_img'])

    def __len__(self):
        return len(self.ann_files)

    def __getitem__(self, idx):
        # load the pair annotation
        ann_file = self.ann_files[idx] + '.json'
        with open(os.path.join(self.pair_ann_path, self.datatype, ann_file)) as f:
            annotation = json.load(f)

        category = annotation['category']
        src_img = read_img(os.path.join(self.image_path, category, annotation['src_imname']))
        trg_img = read_img(os.path.join(self.image_path, category, annotation['trg_imname']))

        # bbox is [xmin, ymin, xmax, ymax]; threshold scales with the larger side
        trg_bbox = annotation['trg_bndbox']
        pck_threshold = max(trg_bbox[2] - trg_bbox[0],  trg_bbox[3] - trg_bbox[1]) * self.pck_alpha

        sample = {'pair_id': annotation['pair_id'],
                  'filename': annotation['filename'],
                  'src_imname': annotation['src_imname'],
                  'trg_imname': annotation['trg_imname'],
                  # torch.Size (C, H, W) of the un-resized images
                  'src_imsize': src_img.size(),
                  'trg_imsize': trg_img.size(),

                  'src_bbox': annotation['src_bndbox'],
                  'trg_bbox': trg_bbox,
                  'category': category,

                  'src_pose': annotation['src_pose'],
                  'trg_pose': annotation['trg_pose'],

                  'src_img': src_img,
                  'trg_img': trg_img,
                  'src_kps': torch.tensor(annotation['src_kps']).float(),
                  'trg_kps': torch.tensor(annotation['trg_kps']).float(),
                  'kps_ids': annotation['kps_ids'],

                  'mirror': annotation['mirror'],
                  'vp_var': annotation['viewpoint_variation'],
                  'sc_var': annotation['scale_variation'],
                  'truncn': annotation['truncation'],
                  'occlsn': annotation['occlusion'],

                  'pck_threshold': pck_threshold}

        if self.transform:
            sample = self.transform(sample)

        return sample


In [None]:
import torch


def extract_dense_features(model, img_tensor, training=False):
    """Extract a dense patch-feature map from a DINOv2-style model.

    Args:
        model: object exposing ``forward_features(x)`` that returns a dict with an
            ``'x_norm_patchtokens'`` entry of shape [B, N_patches, D]. When the model
            also has an integer ``patch_size`` attribute, the patch grid is derived
            from the input height/width, so non-square inputs are supported.
        img_tensor: input batch, assumed [B, C, H, W] — TODO confirm for other models.
        training: if True gradients are enabled (finetuning); otherwise the forward
            pass runs under ``torch.no_grad()``.

    Returns:
        Tensor of shape [B, H_patches, W_patches, D].
    """
    context = torch.enable_grad() if training else torch.no_grad()

    with context:
        features_dict = model.forward_features(img_tensor)
        patch_tokens = features_dict['x_norm_patchtokens']  # [B, N_patches, D]
        B, N, D = patch_tokens.shape

        # Derive the grid from the input size (e.g. 518x518 with patch 14 -> 37x37);
        # the old square-only assumption broke on non-square inputs.
        patch_size = getattr(model, 'patch_size', None)
        H_patches = W_patches = None
        if isinstance(patch_size, int) and patch_size > 0:
            H_patches = img_tensor.shape[-2] // patch_size
            W_patches = img_tensor.shape[-1] // patch_size
        if H_patches is None or H_patches * W_patches != N:
            # Fallback: assume a square grid of patches
            H_patches = W_patches = int(N ** 0.5)

        dense_features = patch_tokens.reshape(B, H_patches, W_patches, D)
    return dense_features


def pixel_to_patch_coord(x, y, original_size, patch_size=14, resized_size=518):
    """Map a pixel location to the (column, row) index of the patch that contains it.

    The image is assumed to be resized to ``resized_size`` x ``resized_size`` before
    patchification; ``original_size`` is (width, height) of the un-resized image.
    Indices outside the grid are clamped to [0, resized_size // patch_size - 1].
    """
    last_patch = resized_size // patch_size - 1

    def _axis_to_patch(coord, original_extent):
        scale = resized_size / original_extent
        index = int((coord * scale) // patch_size)
        return min(max(index, 0), last_patch)

    return _axis_to_patch(x, original_size[0]), _axis_to_patch(y, original_size[1])


def patch_to_pixel_coord(patch_x, patch_y, original_size, patch_size=14, resized_size=518):
    """Map a (column, row) patch index to the pixel position of the patch centre.

    The centre is computed in the ``resized_size`` x ``resized_size`` image and then
    scaled back to ``original_size`` given as (width, height).
    """
    half_patch = patch_size / 2

    def _patch_to_axis(index, original_extent):
        centre = index * patch_size + half_patch
        scale = original_extent / resized_size
        return centre * scale

    return _patch_to_axis(patch_x, original_size[0]), _patch_to_axis(patch_y, original_size[1])

In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable

import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_

from models.dinov2.dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block


logger = logging.getLogger(name="dinov2")  # module logger shared with the upstream dinov2 package name


def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` on ``module``'s descendants.

    Child names are dotted paths relative to ``name``. With ``depth_first`` the
    callback runs after a module's children (post-order), otherwise before them
    (pre-order). The root itself is visited only when ``include_root`` is true.
    Returns ``module``.
    """
    visit_before = include_root and not depth_first
    visit_after = include_root and depth_first

    if visit_before:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        qualified_name = f"{name}.{child_name}" if name else child_name
        named_apply(
            fn=fn,
            module=child_module,
            name=qualified_name,
            depth_first=depth_first,
            include_root=True,
        )
    if visit_after:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    """A ModuleList that applies its modules one after another, in order."""

    def forward(self, x):
        for position in range(len(self)):
            x = self[position](x)
        return x


class DinoVisionTransformer(nn.Module):
    """DINOv2 Vision Transformer backbone (vendored from facebookresearch/dinov2).

    Token layout after ``prepare_tokens_with_masks`` is
    [CLS, register tokens..., patch tokens...]; ``forward_features`` returns the
    normalised CLS, register and patch tokens separately.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap;
                0 keeps a flat nn.ModuleList of blocks
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings

        Note: there is no ``weight_init`` argument; weights are always initialised by
        ``init_weights`` at the end of the constructor.
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1  # the CLS token; register tokens are tracked separately
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # positional embeddings cover CLS + patches only; register tokens get none
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = np.linspace(0, drop_path_rate, depth).tolist()  # stochastic depth decay rule

        # resolve the string name of the FFN into a constructor
        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                # (chunk starting at block i is left-padded with i nn.Identity placeholders)
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        # substituted for masked patch embeddings in prepare_tokens_with_masks
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Initialise positional/CLS/register embeddings and all Linear layers."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Resize the learned patch positional embeddings to the input's patch grid.

        ``x`` is the token sequence including the CLS token (hence the ``- 1``);
        ``w``/``h`` are the spatial sizes of the input image in pixels.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Patch-embed ``x``, apply optional masks, prepend CLS, add positions, insert registers."""
        # NOTE(review): assuming NCHW input, dim 2 is height and dim 3 width; the
        # names w/h are swapped but are used consistently in interpolate_pos_encoding.
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        # register tokens go right after CLS and carry no positional embedding
        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        """List variant of ``forward_features``: one output dict per input tensor.

        NOTE(review): ``masks_list`` is zipped with ``x_list``, so it must be an
        iterable (e.g. a list of None); passing None raises a TypeError.
        """
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Run all blocks and return a dict with normalised CLS/register/patch tokens."""
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # the last chunk is left-padded with nn.Identity, so its length equals the full depth
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch tokens (optionally normalised/reshaped) from selected blocks.

        With ``reshape`` the outputs are permuted to (B, D, w_patches, h_patches);
        with ``return_class_token`` each entry is a (patch_tokens, cls_token) pair.
        """
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        """Return the full feature dict when ``is_training``, else ``head`` applied to the CLS token."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility).

    Only ``nn.Linear`` modules are touched: weights get a truncated normal
    (std 0.02) and biases, when present, are zeroed. ``name`` is unused and exists
    so the function fits the ``named_apply`` callback signature.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-S DINOv2 backbone: 384-dim embeddings, 12 blocks, 6 heads.

    Extra keyword arguments are forwarded to ``DinoVisionTransformer``.
    """
    return DinoVisionTransformer(
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        patch_size=patch_size,
        num_register_tokens=num_register_tokens,
        block_fn=partial(Block, attn_class=MemEffAttention),
        **kwargs,
    )


def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-B DINOv2 backbone: 768-dim embeddings, 12 blocks, 12 heads.

    Extra keyword arguments are forwarded to ``DinoVisionTransformer``.
    """
    return DinoVisionTransformer(
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        patch_size=patch_size,
        num_register_tokens=num_register_tokens,
        block_fn=partial(Block, attn_class=MemEffAttention),
        **kwargs,
    )


def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-L DINOv2 backbone: 1024-dim embeddings, 24 blocks, 16 heads.

    Extra keyword arguments are forwarded to ``DinoVisionTransformer``.
    """
    return DinoVisionTransformer(
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        patch_size=patch_size,
        num_register_tokens=num_register_tokens,
        block_fn=partial(Block, attn_class=MemEffAttention),
        **kwargs,
    )


def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64.

    40 transformer blocks; extra keyword arguments are forwarded to
    ``DinoVisionTransformer``.
    """
    return DinoVisionTransformer(
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        patch_size=patch_size,
        num_register_tokens=num_register_tokens,
        block_fn=partial(Block, attn_class=MemEffAttention),
        **kwargs,
    )




In [None]:
"""
Quick test script to verify finetuning pipeline works correctly.
Runs only a few iterations to check for errors before full training.
"""

import torch
import torch.nn.functional as F
import torch.optim as optim
#from SPair71k.devkit.SPairDataset import SPairDataset
#from dinov2 import extract_dense_features, pixel_to_patch_coord
#from models.dinov2.dinov2.models.vision_transformer import vit_base


def freeze_model(model):
    """Turn off gradient tracking for every parameter of ``model`` (in place)."""
    for weight in list(model.parameters()):
        weight.requires_grad_(False)


def unfreeze_last_n_blocks(model, n_blocks):
    """Unfreeze the last ``n_blocks`` transformer blocks plus the final norm layer.

    Works for both flat (``block_chunks=0``) and chunked models: chunked models
    store ``BlockChunk`` containers in ``model.blocks`` (left-padded with
    ``nn.Identity`` placeholders), so counting entries of ``model.blocks`` would
    unfreeze whole chunks. The chunks are flattened into the real blocks first.
    ``n_blocks`` is clamped to [0, number of blocks].
    """
    import torch.nn as nn

    blocks = []
    for entry in model.blocks:
        if isinstance(entry, nn.ModuleList):  # BlockChunk: skip Identity padding
            blocks.extend(b for b in entry if not isinstance(b, nn.Identity))
        else:
            blocks.append(entry)

    n = max(0, min(n_blocks, len(blocks)))
    for block in blocks[len(blocks) - n:]:
        for param in block.parameters():
            param.requires_grad = True
    for param in model.norm.parameters():
        param.requires_grad = True
    print(f"✓ Unfrozen last {n} blocks + norm layer")


def compute_cross_entropy_loss(src_features, tgt_features, src_kps, trg_kps,
                               src_original_size, tgt_original_size, temperature=10.0):
    """Cross-entropy loss for keypoint-level semantic correspondence.

    For each source keypoint, the feature of the patch containing it is compared by
    cosine similarity with every target patch feature. The similarities, multiplied by
    ``temperature`` (a multiplier, i.e. an inverse temperature), are logits over
    target patches. The patch containing the paired target keypoint is the label.
    The loss is averaged over keypoints.

    Args:
        src_features, tgt_features: dense feature maps [B, H, W, D]; only batch
            element 0 of each is used.
        src_kps, trg_kps: sequences of (x, y) pixel coordinates in the original
            images, paired by position.
        src_original_size, tgt_original_size: (width, height) of the original images.
        temperature: multiplier applied to the cosine similarities.

    Returns:
        Scalar tensor: mean loss over keypoints.

    Raises:
        ValueError: if there are no keypoints or the two keypoint lists differ in length.
    """
    n_kps = len(src_kps)
    if n_kps == 0:
        raise ValueError("compute_cross_entropy_loss requires at least one keypoint")
    if len(trg_kps) != n_kps:
        raise ValueError(
            f"src_kps and trg_kps must have the same length (got {n_kps} and {len(trg_kps)})"
        )

    _, H, W, D = tgt_features.shape
    # Use batch element 0 for the target too (previously reshape failed for B > 1)
    tgt_flat = tgt_features[0].reshape(H * W, D)

    src_rows, src_cols, gt_idx = [], [], []
    for (src_x, src_y), (tgt_x, tgt_y) in zip(src_kps, trg_kps):
        src_patch_x, src_patch_y = pixel_to_patch_coord(src_x, src_y, src_original_size)
        tgt_patch_x, tgt_patch_y = pixel_to_patch_coord(tgt_x, tgt_y, tgt_original_size)
        src_rows.append(src_patch_y)
        src_cols.append(src_patch_x)
        gt_idx.append(tgt_patch_y * W + tgt_patch_x)  # row-major index in the flat grid

    device = src_features.device
    rows = torch.tensor(src_rows, device=device)
    cols = torch.tensor(src_cols, device=device)
    src_vecs = src_features[0, rows, cols]  # [K, D]

    # Cosine similarity of every source keypoint feature with every target patch: [K, H*W]
    similarities = F.normalize(src_vecs, dim=-1) @ F.normalize(tgt_flat, dim=-1).T

    target = torch.tensor(gt_idx, device=similarities.device)
    # cross_entropy (mean reduction) == mean over keypoints of -log_softmax(logits)[gt]
    return F.cross_entropy(similarities * temperature, target)


def _prepare_pair(sample, device, size):
    """Return the sample's source/target images as [1, C, size, size] tensors on ``device``."""
    src = sample['src_img'].unsqueeze(0).to(device)
    tgt = sample['trg_img'].unsqueeze(0).to(device)
    src = F.interpolate(src, size=(size, size), mode='bilinear', align_corners=False)
    tgt = F.interpolate(tgt, size=(size, size), mode='bilinear', align_corners=False)
    return src, tgt


def quick_test(n_blocks=2, n_train_samples=5, n_test_samples=3, temperature=10.0,
               learning_rate=1e-4, base='/content/semantic_correspondence/SPair71k',
               checkpoint_path="models/dinov2/dinov2_vitb14_pretrain.pth"):
    """Smoke-test the finetuning pipeline on a handful of SPair-71k pairs.

    Loads a DINOv2 ViT-B/14 checkpoint and freezes it except the last ``n_blocks``
    blocks and the final norm. It runs ``n_train_samples`` optimisation steps with
    the cross-entropy correspondence loss, then ``n_test_samples`` forward passes in
    eval mode, and finally checks that gradients reached trainable parameters.

    Args:
        n_blocks: number of final transformer blocks to unfreeze.
        n_train_samples / n_test_samples: pairs to use (clamped to the split size).
        temperature: multiplier for the similarity logits in the loss.
        learning_rate: AdamW learning rate.
        base: SPair-71k root directory.
        checkpoint_path: path to the pretrained DINOv2 state dict.

    Returns:
        bool: True if gradients were found on at least one trainable parameter.
    """
    img_size = 518  # must match the default ``resized_size`` of pixel_to_patch_coord

    print("="*60)
    print("QUICK FINETUNING TEST")
    print("="*60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\n✓ Using device: {device}")

    # Load model
    print("\n[1/5] Loading DINOv2 model...")
    model = vit_base(
        img_size=(img_size, img_size),
        patch_size=14,
        num_register_tokens=0,
        block_chunks=0,
        init_values=1.0,
    )
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt, strict=True)
    model.to(device)
    print("✓ Model loaded")

    # Freeze everything, then re-enable gradients on the last blocks + norm
    print(f"\n[2/5] Freezing model and unfreezing last {n_blocks} blocks...")
    freeze_model(model)
    unfreeze_last_n_blocks(model, n_blocks)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"✓ Trainable: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)")

    # Load dataset
    print(f"\n[3/5] Loading {n_train_samples} training samples...")
    train_dataset = SPairDataset(
        f'{base}/PairAnnotation',
        f'{base}/Layout',
        f'{base}/JPEGImages',
        'large',
        0.1,
        datatype='trn'
    )
    print(f"✓ Dataset loaded (total: {len(train_dataset)} samples)")
    n_train = min(n_train_samples, len(train_dataset))  # never index past the split

    optimizer = optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=learning_rate,
        weight_decay=0.01
    )

    # Training loop
    print(f"\n[4/5] Testing training loop with {n_train} samples...")
    model.train()

    for idx in range(n_train):
        sample = train_dataset[idx]
        src_tensor, tgt_tensor = _prepare_pair(sample, device, img_size)

        # imsize is torch.Size (C, H, W); the coordinate helpers expect (width, height)
        src_original_size = (sample['src_imsize'][2], sample['src_imsize'][1])
        tgt_original_size = (sample['trg_imsize'][2], sample['trg_imsize'][1])

        src_kps = sample['src_kps'].numpy()
        trg_kps = sample['trg_kps'].numpy()

        src_features = extract_dense_features(model, src_tensor, training=True)
        tgt_features = extract_dense_features(model, tgt_tensor, training=True)

        loss = compute_cross_entropy_loss(
            src_features, tgt_features,
            src_kps, trg_kps,
            src_original_size, tgt_original_size,
            temperature=temperature
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"  Sample {idx+1}/{n_train}: Loss = {loss.item():.4f}")

    print("✓ Training loop completed successfully!")

    # Evaluation (forward only)
    test_dataset = SPairDataset(
        f'{base}/PairAnnotation',
        f'{base}/Layout',
        f'{base}/JPEGImages',
        'large',
        0.1,
        datatype='test'
    )
    n_test = min(n_test_samples, len(test_dataset))
    print(f"\n[5/5] Testing evaluation with {n_test} samples...")
    model.eval()

    with torch.no_grad():
        for idx in range(n_test):
            sample = test_dataset[idx]
            src_tensor, tgt_tensor = _prepare_pair(sample, device, img_size)

            src_features = extract_dense_features(model, src_tensor, training=False)
            tgt_features = extract_dense_features(model, tgt_tensor, training=False)

            print(f"  Sample {idx+1}/{n_test}: Features shape = {src_features.shape}")

    print("✓ Evaluation completed successfully!")

    # Gradients from the last optimisation step are still stored on the parameters
    print("\n[VERIFICATION] Checking gradient flow...")
    grad_layers = [
        name for name, param in model.named_parameters()
        if param.requires_grad and param.grad is not None
    ]
    has_grads = bool(grad_layers)

    if has_grads:
        print(f"✓ Gradients detected in {len(grad_layers)} layers")
        print("  Sample layers with gradients:")
        for layer in grad_layers[:3]:  # Show first 3
            print(f"    - {layer}")
    else:
        print("✗ WARNING: No gradients detected!")

    # Summary — only claim success when the gradient check actually passed
    print("\n" + "="*60)
    print("TEST SUMMARY")
    print("="*60)
    print("✓ Model loading: OK")
    print("✓ Freeze/Unfreeze: OK")
    print("✓ Training loop: OK")
    print("✓ Backward pass: OK")
    print("✓ Evaluation: OK")
    print(f"{'✓' if has_grads else '✗'} Gradient flow: {'OK' if has_grads else 'FAILED'}")
    if has_grads:
        print("\n✓ All tests passed! Ready for full training.")
    else:
        print("\n✗ Gradient check failed - do not start full training.")
    print("="*60)
    return has_grads


if __name__ == "__main__":
    import traceback

    # Run the smoke test; on failure print a banner, the message and the full traceback
    try:
        quick_test()
    except Exception as err:
        banner = "=" * 60
        print("\n" + banner)
        print("✗ TEST FAILED")
        print(banner)
        print(f"Error: {err}")
        traceback.print_exc()

QUICK FINETUNING TEST

✓ Using device: cuda

[1/5] Loading DINOv2 model...
✓ Model loaded

[2/5] Freezing model and unfreezing last 2 blocks...
✓ Unfrozen last 2 blocks + norm layer
✓ Trainable: 14,180,352 / 86,580,480 (16.38%)

[3/5] Loading 5 training samples...
✓ Dataset loaded (total: 53340 samples)

[4/5] Testing training loop with 5 samples...
  Sample 1/5: Loss = 4.5495
  Sample 2/5: Loss = 3.7546
  Sample 3/5: Loss = 3.0361
  Sample 4/5: Loss = 2.6244
  Sample 5/5: Loss = 1.9461
✓ Training loop completed successfully!

[5/5] Testing evaluation with 3 samples...
  Sample 1/3: Features shape = torch.Size([1, 37, 37, 768])
  Sample 2/3: Features shape = torch.Size([1, 37, 37, 768])
  Sample 3/3: Features shape = torch.Size([1, 37, 37, 768])
✓ Evaluation completed successfully!

[VERIFICATION] Checking gradient flow...
✓ Gradients detected in 30 layers
  Sample layers with gradients:
    - blocks.10.norm1.weight
    - blocks.10.norm1.bias
    - blocks.10.attn.qkv.weight

TEST SUMMA

In [None]:
def find_best_match_argmax(s, width):
    """Return the (x, y) grid position of the largest value in ``s``.

    ``s`` is a tensor of similarities over a patch grid ``width`` columns wide; its
    flat (row-major) argmax is converted to a column ``x`` and row ``y``.
    """
    row, col = divmod(s.argmax().item(), width)
    return col, row

In [None]:
import numpy as np

def compute_pck(pred_points, gt_points, img_size, threshold):
    """
    Compute PCK@threshold, normalising distances by the image diagonal.

    Args:
        pred_points: sequence of [x, y] predictions
        gt_points: sequence of [x, y] ground truth, paired with pred_points
        img_size: (width, height) of the image
        threshold: normalized threshold (e.g., 0.05, 0.1, 0.2)

    Returns:
        pck: percentage of correct keypoints (NaN when there are no keypoints)
        correct_mask: boolean array indicating which keypoints are correct
        normalized_distances: per-keypoint distance divided by the image diagonal

    Raises:
        ValueError: if pred_points and gt_points have different shapes.
    """
    pred_points = np.asarray(pred_points, dtype=np.float64)
    gt_points = np.asarray(gt_points, dtype=np.float64)
    # An empty list becomes shape (0,), which would break the axis=1 reduction below
    if pred_points.size == 0:
        pred_points = pred_points.reshape(0, 2)
    if gt_points.size == 0:
        gt_points = gt_points.reshape(0, 2)
    if pred_points.shape != gt_points.shape:
        raise ValueError(
            f"pred_points and gt_points shapes differ: {pred_points.shape} vs {gt_points.shape}"
        )

    # Euclidean distance per keypoint
    distances = np.sqrt(np.sum((pred_points - gt_points) ** 2, axis=1))

    img_diagonal = np.sqrt(img_size[0] ** 2 + img_size[1] ** 2)
    normalized_distances = distances / img_diagonal

    correct_mask = normalized_distances <= threshold
    pck = np.mean(correct_mask) * 100 if correct_mask.size else float('nan')  # percentage

    return pck, correct_mask, normalized_distances

def compute_pck_spair71k(pred_points, gt_points, bbox, threshold):
    """
    Compute PCK@threshold

    Args:
        pred_points: list of [x, y] predictions
        gt_points: list of [x, y] ground truth
        bbox: [xmin, ymin, xmax, ymax]
        threshold: normalized threshold (e.g., 0.05, 0.1, 0.2)

    Returns:
        pck: percentage of correct keypoints
        correct_mask: boolean array indicating which keypoints are correct
    """
    pred_points = np.array(pred_points)
    gt_points = np.array(gt_points)

    #compute Euclidean distance
    distances = np.sqrt(np.sum((pred_points - gt_points) ** 2, axis=1))

    # Normalize by max(bbox_width, bbox_height) - STANDARD SPAIR-71K
    bbox_width = bbox[2] - bbox[0]
    bbox_height = bbox[3] - bbox[1]
    normalization_factor = max(bbox_width, bbox_height)
    normalized_distances = distances / normalization_factor

    #check which keypoints are within threshold
    correct_mask = normalized_distances <= threshold
    pck = np.mean(correct_mask) * 100  # percentage

    return pck, correct_mask, normalized_distances


In [None]:
import json
from collections import defaultdict
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import os
from datetime import datetime
import torch.nn.functional as F

#from SPair71k.devkit.SPairDataset import SPairDataset
#from dinov2 import extract_dense_features, pixel_to_patch_coord, patch_to_pixel_coord
#from matching_strategies import find_best_match_argmax
#from pck import compute_pck_spair71k
#from models.dinov2.dinov2.models.vision_transformer import vit_base


def freeze_model(model):
    """Turn off gradient tracking for every parameter of ``model``."""
    for p in model.parameters():
        p.requires_grad_(False)


def unfreeze_last_n_blocks(model, n_blocks):
    """
    Unfreeze the last ``n_blocks`` transformer blocks plus the final norm layer.

    Args:
        model: DINOv2 model (must expose ``.blocks`` and ``.norm``)
        n_blocks: number of blocks to unfreeze, counting from the end
            (0 unfreezes only the norm layer)

    Raises:
        ValueError: if ``n_blocks`` is negative or exceeds ``len(model.blocks)``.
            Previously such values produced negative indices that wrapped
            around and unfroze the wrong blocks.
    """
    total_blocks = len(model.blocks)
    if not 0 <= n_blocks <= total_blocks:
        raise ValueError(
            f"n_blocks must be between 0 and {total_blocks}, got {n_blocks}"
        )

    # Unfreeze the trailing blocks only (indices total_blocks - n_blocks .. end)
    for block in list(model.blocks)[total_blocks - n_blocks:]:
        for param in block.parameters():
            param.requires_grad = True

    # Also unfreeze the final normalization layer
    for param in model.norm.parameters():
        param.requires_grad = True

    print(f"Unfrozen last {n_blocks} blocks + norm layer")


def compute_cross_entropy_loss(src_features, tgt_features, src_kps, trg_kps,
                               src_original_size, tgt_original_size, temperature=10.0):
    """
    Compute cross-entropy loss for semantic correspondence.

    Correspondence is treated as classification where each target patch is a
    class. For each source keypoint, the logits are ``temperature`` times its
    cosine similarity to every target patch. The label is the patch that
    contains the ground-truth target keypoint. All keypoints are scored with
    one batched matrix product, so the target-feature normalization is done
    once rather than once per keypoint.

    Args:
        src_features: [1, H, W, D] source dense features
        tgt_features: [1, H, W, D] target dense features
        src_kps: [N, 2] source keypoints (x, y) in pixel coordinates
        trg_kps: [N, 2] target keypoints (x, y) in pixel coordinates
        src_original_size: (width, height) of original source image
        tgt_original_size: (width, height) of original target image
        temperature: multiplier applied to the similarities before the softmax
            (acts as an inverse temperature: higher = more peaked distribution)

    Returns:
        loss: scalar tensor, mean cross-entropy across all keypoints

    Raises:
        ValueError: if there are no keypoints, the keypoint lists differ in
            length, or a keypoint maps outside the feature grid.
    """
    n_kps = len(src_kps)
    if n_kps == 0:
        raise ValueError("compute_cross_entropy_loss needs at least one keypoint")
    if len(trg_kps) != n_kps:
        raise ValueError(
            f"src_kps and trg_kps differ in length ({n_kps} vs {len(trg_kps)})"
        )

    _, src_h, src_w, _ = src_features.shape
    _, H, W, D = tgt_features.shape

    src_rows, src_cols, gt_indices = [], [], []
    for i, ((src_x, src_y), (tgt_x, tgt_y)) in enumerate(zip(src_kps, trg_kps)):
        sx, sy = (int(v) for v in pixel_to_patch_coord(src_x, src_y, src_original_size))
        tx, ty = (int(v) for v in pixel_to_patch_coord(tgt_x, tgt_y, tgt_original_size))
        # Out-of-grid indices would otherwise wrap around (negative values) or,
        # for the target, silently select a patch in the neighbouring row.
        if not (0 <= sx < src_w and 0 <= sy < src_h and 0 <= tx < W and 0 <= ty < H):
            raise ValueError(
                f"keypoint {i} maps outside the patch grid: "
                f"src patch ({sx}, {sy}) in {src_w}x{src_h}, "
                f"tgt patch ({tx}, {ty}) in {W}x{H}"
            )
        src_rows.append(sy)
        src_cols.append(sx)
        gt_indices.append(ty * W + tx)  # flatten 2D patch coords to a class index

    row_idx = torch.as_tensor(src_rows, dtype=torch.long, device=src_features.device)
    col_idx = torch.as_tensor(src_cols, dtype=torch.long, device=src_features.device)
    targets = torch.as_tensor(gt_indices, dtype=torch.long, device=tgt_features.device)

    # Unit-normalise once so a single matmul yields every cosine similarity.
    src_unit = F.normalize(src_features[0, row_idx, col_idx, :], dim=1)  # [N, D]
    tgt_unit = F.normalize(tgt_features.reshape(H * W, D), dim=1)  # [H*W, D]
    logits = temperature * (src_unit @ tgt_unit.T)  # [N, H*W]

    # Mean over keypoints of -log_softmax(logits)[gt_index]
    return F.cross_entropy(logits, targets)

def evaluate(model, dataset, device, thresholds=(0.05, 0.1, 0.2)):
    """
    Evaluate ``model`` on ``dataset`` using the SPair-71k PCK metric.

    For each source keypoint, the target patch with the highest cosine
    similarity is chosen (argmax). It is mapped back to target pixel
    coordinates and scored with ``compute_pck_spair71k`` against the target
    bounding box.

    Args:
        model: DINOv2 model
        dataset: dataset yielding un-batched samples (a batch dim is added here)
        device: 'cuda' or 'cpu'
        thresholds: iterable of PCK thresholds to evaluate

    Returns:
        (results, per_image_metrics):
            results: {'pck@T': {'mean', 'std', 'median'}} over images, one entry per threshold
            per_image_metrics: list of {'category', 'pck_scores': {T: pck}}, one per pair

    The model's train/eval mode is restored before returning.
    """
    # Materialize once: thresholds are iterated per image and again for the summary,
    # so a generator would otherwise leave the summary empty.
    thresholds = tuple(thresholds)
    was_training = model.training
    model.eval()
    per_image_metrics = []

    print(f"Evaluating on {len(dataset)} image pairs...")

    try:
        with torch.no_grad():
            for idx, sample in enumerate(dataset):
                # Prepare images: add a batch dim and resize to the 518x518 model input
                src_tensor = sample['src_img'].unsqueeze(0).to(device)
                tgt_tensor = sample['trg_img'].unsqueeze(0).to(device)
                src_tensor = F.interpolate(src_tensor, size=(518, 518), mode='bilinear', align_corners=False)
                tgt_tensor = F.interpolate(tgt_tensor, size=(518, 518), mode='bilinear', align_corners=False)

                # (width, height) of the original images, used for coordinate conversion
                src_original_size = (sample['src_imsize'][2], sample['src_imsize'][1])
                tgt_original_size = (sample['trg_imsize'][2], sample['trg_imsize'][1])

                # Extract features
                src_features = extract_dense_features(model, src_tensor)
                tgt_features = extract_dense_features(model, tgt_tensor)

                _, H, W, D = tgt_features.shape
                # Normalise the target grid once per pair. A dot product with a unit
                # source vector is then the cosine similarity to every patch.
                tgt_unit = F.normalize(tgt_features.reshape(H * W, D), dim=1)

                # Get keypoints and bbox
                src_kps = sample['src_kps'].numpy()
                trg_kps = sample['trg_kps'].numpy()
                trg_bbox = sample['trg_bbox']

                # Predict a match for every source keypoint
                pred_matches = []
                for src_x, src_y in src_kps:
                    patch_x, patch_y = pixel_to_patch_coord(src_x, src_y, src_original_size)
                    src_unit = F.normalize(src_features[0, patch_y, patch_x, :], dim=0)
                    similarities = tgt_unit @ src_unit  # [H*W]

                    match_patch_x, match_patch_y = find_best_match_argmax(similarities, W)
                    match_x, match_y = patch_to_pixel_coord(
                        match_patch_x, match_patch_y, tgt_original_size
                    )
                    pred_matches.append([match_x, match_y])

                # PCK of this pair at every threshold
                image_pcks = {
                    threshold: compute_pck_spair71k(
                        pred_matches, trg_kps.tolist(), trg_bbox, threshold
                    )[0]
                    for threshold in thresholds
                }

                per_image_metrics.append({
                    'category': sample['category'],
                    'pck_scores': image_pcks,
                })

                if (idx + 1) % 100 == 0:
                    print(f"Evaluated {idx + 1}/{len(dataset)} pairs...")
    finally:
        # Do not leak eval mode into the caller (e.g. a training loop).
        model.train(was_training)

    # Aggregate per-image PCK into overall statistics
    results = {}
    for threshold in thresholds:
        all_pcks = [img['pck_scores'][threshold] for img in per_image_metrics]
        results[f'pck@{threshold:.2f}'] = {
            'mean': float(np.mean(all_pcks)),
            'std': float(np.std(all_pcks)),
            'median': float(np.median(all_pcks)),
        }

    return results, per_image_metrics


def train_epoch(model, dataloader, optimizer, device, epoch, temperature=10.0):
    """
    Train for one epoch.

    Every batch must hold exactly one image pair. SPair-71k images have
    variable sizes, and the keypoint/size handling below reads only batch
    element 0.

    Args:
        model: DINOv2 model
        dataloader: training data loader (batch_size=1)
        optimizer: optimizer
        device: 'cuda' or 'cpu'
        epoch: current epoch number (used in progress messages)
        temperature: softmax temperature for loss

    Returns:
        avg_loss: mean of the per-batch losses over the epoch

    Raises:
        ValueError: if a batch holds more than one pair, or the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for idx, sample in enumerate(dataloader):
        batch_len = sample['src_img'].shape[0]
        if batch_len != 1:
            # Larger batches would silently train on element 0's keypoints only.
            raise ValueError(f"train_epoch expects batch_size=1, got a batch of {batch_len}")

        # Prepare data: [1, 3, H, W], resized to 518x518 (DINOv2 expects this size)
        src_tensor = sample['src_img'].to(device)
        tgt_tensor = sample['trg_img'].to(device)
        src_tensor = F.interpolate(src_tensor, size=(518, 518), mode='bilinear', align_corners=False)
        tgt_tensor = F.interpolate(tgt_tensor, size=(518, 518), mode='bilinear', align_corners=False)

        # (width, height) of the original images, for coordinate conversion
        src_original_size = (sample['src_imsize'][2], sample['src_imsize'][1])
        tgt_original_size = (sample['trg_imsize'][2], sample['trg_imsize'][1])

        # Keypoints of the single pair in the batch: [N, 2]
        src_kps = sample['src_kps'].numpy()[0]
        trg_kps = sample['trg_kps'].numpy()[0]

        # Extract dense features
        src_features = extract_dense_features(model, src_tensor, training=True)
        tgt_features = extract_dense_features(model, tgt_tensor, training=True)

        # Compute loss
        loss = compute_cross_entropy_loss(
            src_features, tgt_features,
            src_kps, trg_kps,
            src_original_size, tgt_original_size,
            temperature=temperature
        )

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once: .item() forces a device sync
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1

        # Print progress
        if (idx + 1) % 50 == 0:
            print(f"Epoch {epoch}, Batch {idx + 1}/{len(dataloader)}, Loss: {loss_value:.4f}")

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot compute an average loss")
    return total_loss / num_batches



def _load_spair_split(base, datatype):
    """Build the SPair-71k ``SPairDataset`` for one split ('trn', 'val' or 'test')."""
    return SPairDataset(
        f'{base}/PairAnnotation',
        f'{base}/Layout',
        f'{base}/JPEGImages',
        'large',
        0.1,  # pck_alpha for the devkit; this script computes PCK via compute_pck_spair71k
        datatype=datatype,
    )


def _per_category_stats(per_image_metrics):
    """Aggregate per-image PCK scores into mean/std/count per category and threshold."""
    category_results = defaultdict(lambda: defaultdict(list))
    for img_metric in per_image_metrics:
        for threshold, pck in img_metric['pck_scores'].items():
            category_results[img_metric['category']][threshold].append(pck)

    return {
        category: {
            f'pck@{threshold:.2f}': {
                'mean': float(np.mean(pcks)),
                'std': float(np.std(pcks)),
                'n_samples': len(pcks),
            }
            for threshold, pcks in thresholds_dict.items()
        }
        for category, thresholds_dict in category_results.items()
    }


def main():
    """Main training and evaluation pipeline.

    1. Build DINOv2 ViT-B/14, load pretrained weights, and unfreeze the last N blocks.
    2. Train on the 'trn' split, keeping the epoch with the best validation PCK@0.10.
    3. Reload that checkpoint and evaluate it on the held-out 'test' split.
    4. Write the training history, final metrics, and per-category PCK as JSON.
    """

    # ========== CONFIGURATION ==========
    n_blocks_to_unfreeze = 2  # Try: 1, 2, 3, 4, 6, 12
    num_epochs = 3
    learning_rate = 1e-4
    batch_size = 1  # SPair-71k has variable-sized images; train_epoch requires 1
    temperature = 10.0  # Softmax temperature for cross-entropy loss
    seed = 42  # Seeds torch's RNG so DataLoader shuffling is reproducible

    torch.manual_seed(seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Create results directory with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = f'results_SPair71K/finetuned_{n_blocks_to_unfreeze}blocks_{timestamp}'
    os.makedirs(results_dir, exist_ok=True)
    print(f"Results will be saved to: {results_dir}")

    # ========== LOAD MODEL ==========
    print("\nLoading DINOv2-base model...")
    model = vit_base(
        img_size=(518, 518),
        patch_size=14,
        num_register_tokens=0,
        block_chunks=0,
        init_values=1.0,
    )

    # Load pretrained weights
    ckpt = torch.load("models/dinov2/dinov2_vitb14_pretrain.pth", map_location=device)
    model.load_state_dict(ckpt, strict=True)
    model.to(device)

    # Freeze entire model, then unfreeze last N blocks
    freeze_model(model)
    unfreeze_last_n_blocks(model, n_blocks_to_unfreeze)

    # Count trainable parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nTrainable parameters: {trainable_params:,} / {total_params:,} "
          f"({100 * trainable_params / total_params:.2f}%)")

    # ========== LOAD DATASETS ==========
    print("\nLoading SPair-71k dataset...")
    base = '/content/semantic_correspondence/SPair71k'

    train_dataset = _load_spair_split(base, 'trn')
    val_dataset = _load_spair_split(base, 'val')    # model selection
    test_dataset = _load_spair_split(base, 'test')  # final held-out evaluation

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    # Create data loader
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=(device == 'cuda')
    )

    # ========== OPTIMIZER ==========
    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=learning_rate,
        weight_decay=0.01
    )

    # Cosine learning-rate decay, stepped once per epoch
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # ========== TRAINING LOOP ==========
    print("\n" + "=" * 60)
    print("STARTING TRAINING")
    print("=" * 60)

    # Start below any achievable PCK so the first epoch always writes a checkpoint.
    # With 0, a run scoring 0% every epoch would never save and the reload would fail.
    best_pck = float('-inf')
    best_epoch = 0
    training_history = []

    for epoch in range(num_epochs):
        print(f"\n{'=' * 60}")
        print(f"Epoch {epoch + 1}/{num_epochs}")
        print('=' * 60)

        # Train for one epoch
        train_loss = train_epoch(
            model, train_loader, optimizer, device, epoch + 1, temperature=temperature
        )
        print(f"\nAverage training loss: {train_loss:.4f}")

        # Update learning rate
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Learning rate: {current_lr:.6f}")

        # Evaluate on val set
        print("\nEvaluating on val set...")
        results, _ = evaluate(model, val_dataset, device)

        print("\nValidation Results:")
        for key, value in results.items():
            print(f"  {key}: {value['mean']:.2f}% ± {value['std']:.2f}% "
                  f"(median: {value['median']:.2f}%)")

        # Save best model (selected on validation PCK@0.10)
        current_pck = results['pck@0.10']['mean']
        if current_pck > best_pck:
            best_pck = current_pck
            best_epoch = epoch + 1

            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'pck': best_pck,
                'n_blocks': n_blocks_to_unfreeze,
            }, f'{results_dir}/best_model.pth')

            print(f"\n✓ New best model saved! PCK@0.1: {best_pck:.2f}%")

        # Store training history. The 'test_results' key is kept for compatibility
        # with existing history files; it holds the validation metrics.
        training_history.append({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'learning_rate': current_lr,
            'test_results': results
        })

        # Save intermediate results after every epoch
        with open(f'{results_dir}/training_history.json', 'w') as f:
            json.dump(training_history, f, indent=2)

    if best_epoch == 0:
        raise RuntimeError(
            "No checkpoint was saved: num_epochs must be >= 1 and validation PCK must not be NaN"
        )

    # ========== FINAL RESULTS ==========
    print("\n" + "=" * 60)
    print("TRAINING COMPLETED")
    print("=" * 60)
    print(f"Best PCK@0.1: {best_pck:.2f}% (Epoch {best_epoch})")
    print(f"Results saved to: {results_dir}")

    # Load best model and evaluate on the held-out test split
    print("\nLoading best model for final evaluation...")
    checkpoint = torch.load(f'{results_dir}/best_model.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    print("\nEvaluating best model on test set...")
    final_results, final_per_image = evaluate(model, test_dataset, device)

    print("\nTest Results:")
    for key, value in final_results.items():
        print(f"  {key}: {value['mean']:.2f}% ± {value['std']:.2f}% "
              f"(median: {value['median']:.2f}%)")

    # Save final detailed results
    with open(f'{results_dir}/final_results.json', 'w') as f:
        json.dump({
            'best_epoch': best_epoch,
            'n_blocks_unfrozen': n_blocks_to_unfreeze,
            'temperature': temperature,
            'learning_rate': learning_rate,
            'num_epochs': num_epochs,
            'seed': seed,
            'results_SPair71K': final_results
        }, f, indent=2)

    # Save per-category analysis
    category_stats = _per_category_stats(final_per_image)
    with open(f'{results_dir}/per_category_results.json', 'w') as f:
        json.dump(category_stats, f, indent=2)

    print("\nPer-category results_SPair71K:")
    for category, stats in sorted(category_stats.items()):
        pck_01 = stats['pck@0.10']['mean']
        n_samples = stats['pck@0.10']['n_samples']
        print(f"  {category:20s}: {pck_01:.2f}% (n={n_samples})")

    print("\n" + "=" * 60)


# Inside a Jupyter/Colab kernel __name__ is "__main__", so executing this cell
# immediately runs the full fine-tuning + evaluation pipeline in main().
if __name__ == "__main__":
    main()

Using device: cuda
Results will be saved to: results/finetuned_2blocks_20251221_182948

Loading DINOv2-base model...
Unfrozen last 2 blocks + norm layer

Trainable parameters: 14,180,352 / 86,580,480 (16.38%)

Loading SPair-71k dataset...
Training samples: 53340
Test samples: 5384

STARTING TRAINING

Epoch 1/3
Epoch 1, Batch 50/53340, Loss: 5.2169
Epoch 1, Batch 100/53340, Loss: 2.2020
Epoch 1, Batch 150/53340, Loss: 2.9084
Epoch 1, Batch 200/53340, Loss: 3.6872
Epoch 1, Batch 250/53340, Loss: 4.0350
Epoch 1, Batch 300/53340, Loss: 2.9882
Epoch 1, Batch 350/53340, Loss: 4.0329
Epoch 1, Batch 400/53340, Loss: 2.6079
Epoch 1, Batch 450/53340, Loss: 3.6571
Epoch 1, Batch 500/53340, Loss: 1.9698
Epoch 1, Batch 550/53340, Loss: 2.6591
Epoch 1, Batch 600/53340, Loss: 3.5188
Epoch 1, Batch 650/53340, Loss: 2.7611
Epoch 1, Batch 700/53340, Loss: 1.4283
Epoch 1, Batch 750/53340, Loss: 1.5851
Epoch 1, Batch 800/53340, Loss: 3.8043
Epoch 1, Batch 850/53340, Loss: 3.7350
Epoch 1, Batch 900/53340, 