# DINOv2 ViT-B_reg Configuration for Light Fine-tuning

This notebook demonstrates how to load and configure the DINOv3 ViT-Base model with registration tokens (ViT-B_reg) for light fine-tuning of the last layers.

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import sys
import os

# Load DINOv2 ViT-B/14 with registration tokens via torch.hub
dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg', pretrained=True)


def setup_light_finetuning(model, num_last_blocks=3):
    """
    Setup light fine-tuning by unfreezing only the last transformer blocks

    Args:
        model: DINOv2 model
        num_last_blocks: Number of last blocks to unfreeze (recommended: 2-4)
    """
    if model is None:
        print("Model not available")
        return None

    # First, freeze all parameters
    for param in model.parameters():
        param.requires_grad = False

    # Get total number of blocks
    total_blocks = len(model.blocks)
    print(f"Total transformer blocks: {total_blocks}")

    # Unfreeze the last num_last_blocks blocks
    blocks_to_unfreeze = list(range(total_blocks - num_last_blocks, total_blocks))

    for block_idx in blocks_to_unfreeze:
        for param in model.blocks[block_idx].parameters():
            param.requires_grad = True
        print(f"‚úì Unfrozen block {block_idx}")

    # Also unfreeze the final layer norm
    for param in model.norm.parameters():
        param.requires_grad = True
    print("‚úì Unfrozen final layer norm")

    # Count trainable parameters after setup
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"\nAfter light fine-tuning setup:")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Percentage trainable: {100 * trainable_params / total_params:.2f}%")

    return model

# Apply light fine-tuning configuration
dinov2_vitb14_reg = setup_light_finetuning(dinov2_vitb14_reg, num_last_blocks=3)

# Dataset loading

SHARED_FOLDER_ID = "1O9YEQGv9j-4DGQp6adu5zbo3UH9K4ohJ"

# Download shared data from Google Drive
print("üì• Downloading shared data...")
!pip install -q gdown
!gdown --folder https://drive.google.com/drive/folders/1O9YEQGv9j-4DGQp6adu5zbo3UH9K4ohJ -O /content/shared_data


# 3. Link dataset from shared data
print("\nüîó Setting up dataset from shared data...")
!mkdir -p /content/semantic-correspondence/data

# Extract SPair-71k.tar.gz if it exists, then link
SPAIR_TAR_PATH = '/content/shared_data/SPair-71k.tar.gz'
SPAIR_DIR_PATH = '/content/shared_data/SPair-71k'

if os.path.exists(SPAIR_TAR_PATH):
    print(f"Found {SPAIR_TAR_PATH}, extracting...")
    !tar -xzf {SPAIR_TAR_PATH} -C /content/shared_data
    print("‚úÖ SPair-71k.tar.gz extracted.")

# Link SPair-71k directly from shared data
if os.path.exists(SPAIR_DIR_PATH):
    !ln -sf {SPAIR_DIR_PATH} /content/semantic-correspondence/data/SPair-71k
    print("‚úÖ SPair-71k dataset linked from shared data")
else:
    print("‚ùå SPair-71k not found in shared data")
    print("Available folders:")
    !ls -la /content/shared_data

# Verify dataset setup
if os.path.exists('/content/semantic-correspondence/data/SPair-71k'):
    print("‚úÖ Dataset verification successful")
    !ls -la /content/semantic-correspondence/data/SPair-71k | head -5
else:
    print("‚ùå Dataset link verification failed")

Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip




Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vitb14_reg4_pretrain.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 330M/330M [00:00<00:00, 442MB/s]


Total transformer blocks: 12
‚úì Unfrozen block 9
‚úì Unfrozen block 10
‚úì Unfrozen block 11
‚úì Unfrozen final layer norm

After light fine-tuning setup:
Total parameters: 86,583,552
Trainable parameters: 21,269,760
Percentage trainable: 24.57%
üì• Downloading shared data...
Retrieving folder contents
Retrieving folder 1TB7W0mx4rhLShdsyGNnT3BvrDAD-vCE6 results
Retrieving folder 18UvZK8jzUhku4YasNUk9c0avpXJZww7I metrics
Retrieving folder 1ftqXNTeL6lWLi_cAakQXLtX-bvtidaVq plots
Retrieving folder 1zWFUZxXVqrLemn-ZWn5jh-frtA4Z1GWM weights
Processing file 1lvmW5fGM_O2DNbb9eEtKvmh69_TSBxqM dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth
Processing file 1j0wXloH_YepR9w2xkXQOcVm4Zy8QVzQR sam_vit_b_01ec64.pth
Processing file 14WscxapP53HYX9JOd9euD5wqJtBqIFmP DINOv2_eval.ipynb
Processing file 1Awnsme0UfwOwevbKPT3SXW-FL-Hjlopg SPair-71k.tar.gz
Processing file 1j32zxdYhEMmEfq9DoTA3gpyVUz3yHZgR Steps.ipynb
Retrieving folder contents completed
Building directory structure
Building directory structur

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset # Added Dataset import
import sys
import os
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as T
import json

# Set device at the beginning of this cell to ensure it's always defined
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Training hyperparameters
LEARNING_RATE = 1e-4  # Lower LR for fine-tuning
BATCH_SIZE = 8        # Adjust based on GPU memory
NUM_EPOCHS = 10
WARMUP_EPOCHS = 1
WEIGHT_DECAY = 1e-4

class SemanticCorrespondenceLoss(nn.Module):
    """
    Comprehensive loss function for semantic correspondence training
    """
    def __init__(self,
                 lambda_correspondence=1.0,
                 lambda_consistency=0.5,
                 lambda_smooth=0.1,
                 temperature=0.1):
        super().__init__()
        self.lambda_correspondence = lambda_correspondence
        self.lambda_consistency = lambda_consistency
        self.lambda_smooth = lambda_smooth
        self.temperature = temperature

    def correspondence_loss(self, feat1, feat2, kpts1, kpts2, valid_mask):
        """
        Keypoint correspondence loss using feature similarity
        """
        # Normalize features
        feat1 = F.normalize(feat1, dim=-1)
        feat2 = F.normalize(feat2, dim=-1)

        # Compute similarity matrix
        similarity = torch.matmul(feat1, feat2.transpose(-2, -1)) / self.temperature

        # Get ground truth correspondence indices
        batch_size = feat1.size(0)
        losses = []

        for b in range(batch_size):
            if valid_mask[b].sum() == 0:
                continue

            # Get valid keypoints for this batch
            valid_kpts1 = kpts1[b][valid_mask[b]]  # [N_valid, 2]
            valid_kpts2 = kpts2[b][valid_mask[b]]  # [N_valid, 2]

            if len(valid_kpts1) == 0:
                continue

            # Convert keypoints to feature indices (assuming features are spatial)
            H, W = int(feat1.size(1) ** 0.5), int(feat1.size(1) ** 0.5)

            # Get feature similarities for valid correspondences
            sim_batch = similarity[b]  # [H*W, H*W]

            # Create correspondence targets (identity for matched keypoints)
            n_valid = len(valid_kpts1)
            targets = torch.arange(n_valid, device=feat1.device)

            # Select relevant similarities
            if n_valid > 1:
                selected_sim = sim_batch[:n_valid, :n_valid]
                loss_batch = F.cross_entropy(selected_sim, targets)
                losses.append(loss_batch)

        if len(losses) == 0:
            return torch.tensor(0.0, device=feat1.device, requires_grad=True)

        return torch.stack(losses).mean()

    def cycle_consistency_loss(self, feat1, feat2):
        """
        Enforce cycle consistency: feat1 -> feat2 -> feat1
        """
        # Normalize features
        feat1_norm = F.normalize(feat1, dim=-1)
        feat2_norm = F.normalize(feat2, dim=-1)

        # Forward correspondence: feat1 -> feat2
        sim_12 = torch.matmul(feat1_norm, feat2_norm.transpose(-2, -1))
        correspondence_12 = F.softmax(sim_12 / self.temperature, dim=-1)

        # Backward correspondence: feat2 -> feat1
        sim_21 = torch.matmul(feat2_norm, feat1_norm.transpose(-2, -1))
        correspondence_21 = F.softmax(sim_21 / self.temperature, dim=-1)

        # Cycle consistency: should get back to identity
        cycle_consistency = torch.matmul(correspondence_12, correspondence_21)
        identity = torch.eye(feat1.size(1), device=feat1.device).unsqueeze(0).expand(feat1.size(0), -1, -1)

        return F.mse_loss(cycle_consistency, identity)

    def smoothness_loss(self, feat1, feat2):
        """
        Encourage smooth correspondences in spatial neighborhoods
        """
        # Simple spatial smoothness - encourage similar features for nearby patches
        B, N, D = feat1.shape
        H = W = int(N ** 0.5)

        if H * W != N:
            return torch.tensor(0.0, device=feat1.device, requires_grad=True)

        # Reshape to spatial dimensions
        feat1_spatial = feat1.view(B, H, W, D)
        feat2_spatial = feat2.view(B, H, W, D)

        # Compute differences with neighbors (4-connectivity)
        smooth_loss = 0.0
        count = 0

        for dx, dy in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
            # The original condition `0 <= dx < H-1 and 0 <= dy < W-1` might be too restrictive.
            # It should be `0 <= x+dx < H` and `0 <= y+dy < W` when iterating over elements.
            # For simpler tensor slicing, ensure the dimensions are valid for `dx` and `dy`.
            if dx == 0 and dy == 1: # Right neighbor
                diff1 = feat1_spatial[:, :, :-1] - feat1_spatial[:, :, 1:]
                diff2 = feat2_spatial[:, :, :-1] - feat2_spatial[:, :, 1:]
                smooth_loss += F.mse_loss(diff1, diff2)
                count += 1
            elif dx == 1 and dy == 0: # Down neighbor
                diff1 = feat1_spatial[:, :-1, :] - feat1_spatial[:, 1:, :]
                diff2 = feat2_spatial[:, :-1, :] - feat2_spatial[:, 1:, :]
                smooth_loss += F.mse_loss(diff1, diff2)
                count += 1
            # For (0, -1) and (-1, 0), it's more complex to implement with direct slicing like this
            # and usually requires padding or careful boundary handling if using naive subtraction.
            # Given this is 'simple spatial smoothness', sticking to right and down neighbors for now.

        return smooth_loss / max(count, 1)

    def forward(self, feat1, feat2, kpts1, kpts2, valid_mask):
        """
        Compute total loss
        """
        # Main correspondence loss
        corr_loss = self.correspondence_loss(feat1, feat2, kpts1, kpts2, valid_mask)
        #print(f"[DEBUG Loss] correspondence_loss: {corr_loss}")

        # Cycle consistency loss
        consistency_loss = self.cycle_consistency_loss(feat1, feat2)
        #print(f"[DEBUG Loss] consistency_loss: {consistency_loss}")

        # Smoothness loss
        smooth_loss = self.smoothness_loss(feat1, feat2)
        #print(f"[DEBUG Loss] smoothness_loss: {smooth_loss}")

        # Total loss
        total_loss = (self.lambda_correspondence * corr_loss +
                     self.lambda_consistency * consistency_loss +
                     self.lambda_smooth * smooth_loss)
        #print(f"[DEBUG Loss] total_loss: {total_loss}")

        return {
            'total_loss': total_loss,
            'correspondence_loss': corr_loss,
            'consistency_loss': consistency_loss,
            'smoothness_loss': smooth_loss
        }

# Initialize loss function
criterion = SemanticCorrespondenceLoss(
    lambda_correspondence=1.0,
    lambda_consistency=0.5,
    lambda_smooth=0.1,
    temperature=0.1
)
print("‚úì Loss function initialized")

# Create parameter groups with different learning rates
trainable_params = []
for name, param in dinov2_vitb14_reg.named_parameters():
    if param.requires_grad:
        trainable_params.append(param)
        print(f"Trainable parameter: {name}, shape: {param.shape}")

print(f"\nTotal trainable parameters: {sum(p.numel() for p in trainable_params):,}")

# Optimizer setup
optimizer = torch.optim.Adam(
    trainable_params,
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    betas=(0.9, 0.999)
)

print(f"‚úì Optimizer configured: Adam with LR={LEARNING_RATE}")

# Move model to device and ensure proper setup
dinov2_vitb14_reg = dinov2_vitb14_reg.to(device)
dinov2_vitb14_reg.train()
print(f"‚úì Model moved to device: {device}")

def extract_features(model, images, layer_idx=-1):
    """
    Extract features from DINOv2 model

    Args:
        model: DINOv2 model
        images: Input images [B, 3, H, W]
        layer_idx: Which layer to extract features from (-1 for last layer)

    Returns:
        features: Extracted features [B, N, D] where N is number of patches
    """
    # Removed mixed precision autocast for compatibility
    # autocast_dtype = torch.float16 if device.type == 'cuda' else torch.bfloat16

    # with torch.amp.autocast(device_type=device.type, dtype=autocast_dtype):
    if layer_idx == -1:
        # Use get_intermediate_layers to ensure we get the full sequence of tokens from the last block
        # For the last layer, n should be [len(model.blocks) - 1]
        last_block_idx = len(model.blocks) - 1
        intermediate_output = model.get_intermediate_layers(
            images,
            n=[last_block_idx],
            return_class_token=False  # Set to False to get just reg + patch tokens
        )

        features = intermediate_output[0]  # This will be [B, num_reg+N_patch, D]

    else:
        # This branch already uses get_intermediate_layers correctly with return_class_token=False
        intermediate_output = model.get_intermediate_layers(
            images,
            n=[layer_idx],
            return_class_token=False
        )
        features = intermediate_output[0]
        # Remove register tokens if present
        features = features[:, model.num_register_tokens:]

    return features

def prepare_batch_data(batch):
    """
    Prepare batch data for training

    Args:
        batch: Batch from dataloader containing images and keypoints

    Returns:
        Dictionary with prepared data
    """
    # Extract data from batch (adjust according to your dataset format)
    img1 = batch['img1'].to(device)  # [B, 3, H, W]
    img2 = batch['img2'].to(device)  # [B, 3, H, W]
    kpts1 = batch['kpts1'].to(device)  # [B, N, 2]
    kpts2 = batch['kpts2'].to(device)  # [B, N, 2]
    valid_mask = batch['valid'].to(device)  # [B, N] - which keypoints are valid

    return {
        'img1': img1,
        'img2': img2,
        'kpts1': kpts1,
        'kpts2': kpts2,
        'valid_mask': valid_mask
    }


class SPair71kDataset(Dataset):
    """
    SPair-71k Dataset for semantic correspondence

    Expected data structure:
    SPair-71k/
    ‚îú‚îÄ‚îÄ JPEGImages/
    ‚îÇ   ‚îú‚îÄ‚îÄ cat/
    ‚îÇ   ‚îú‚îÄ‚îÄ dog/
    ‚îÇ   ‚îî‚îÄ‚îÄ ...
    ‚îú‚îÄ‚îÄ PairAnnotation/
    ‚îÇ   ‚îú‚îÄ‚îÄ trn/  # <--- Note 'trn', 'test', 'val' instead of 'train', 'val'
    ‚îÇ   ‚îÇ   ‚îú‚îÄ‚îÄ 000001-img1-img2:category.json
    ‚îÇ   ‚îÇ   ‚îî‚îÄ‚îÄ ...
    ‚îÇ   ‚îú‚îÄ‚îÄ val/
    ‚îÇ   ‚îÇ   ‚îú‚îÄ‚îÄ 000002-img3-img4:category.json
    ‚îÇ   ‚îÇ   ‚îî‚îÄ‚îÄ ...
    ‚îÇ   ‚îî‚îÄ‚îÄ test/
    ‚îÇ       ‚îú‚îÄ‚îÄ 000003-img5-img4:category.json
    ‚îÇ       ‚îî‚îÄ‚îÄ ...
    ‚îî‚îÄ‚îÄ ImageAnnotation/
    """

    def __init__(self,
                 data_root,
                 split='train', # 'train', 'val', 'test'
                 image_size=224,
                 category=None):
        self.data_root = data_root
        # Map 'train' to 'trn' for directory names as per dataset structure
        self.split_dir = 'trn' if split == 'train' else split
        self.image_size = image_size
        self.category = category

        # Image transforms
        self.transform = T.Compose([
            T.Resize((image_size, image_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
        ])

        # Load annotations
        self.pairs = self._load_pairs()
        print(f"Loaded {len(self.pairs)} pairs for {split} split")

    def _load_pairs(self):
        """Load pair annotations from SPair-71k format, with error handling for missing fields."""
        pairs = []

        # Construct the path to the split-specific annotation directory
        # e.g., /content/semantic-correspondence/data/SPair-71k/PairAnnotation/trn
        split_annotation_root = os.path.join(self.data_root, 'PairAnnotation', self.split_dir)

        if not os.path.exists(split_annotation_root) or not os.path.isdir(split_annotation_root):
            print(f"Warning: Split annotation directory not found: {split_annotation_root}. Returning empty dataset.")
            return []

        # JSON files are directly within the split_annotation_root
        for pair_file_name in os.listdir(split_annotation_root):
            if not pair_file_name.endswith('.json'):
                continue

            pair_path = os.path.join(split_annotation_root, pair_file_name)
            try:
                with open(pair_path, 'r') as f:
                    pair_data = json.load(f)

                src_imname = pair_data.get('src_imname')
                trg_imname = pair_data.get('trg_imname')
                src_kps = pair_data.get('src_kps')
                trg_kps = pair_data.get('trg_kps')
                actual_category_from_json = pair_data.get('category')

                # Handle 'valid_kpts' - not explicitly present in sample JSON, assume all true
                valid_kpts = pair_data.get('valid_kpts')
                if valid_kpts is None:
                    valid_kpts = [True] * len(src_kps) if src_kps else []
                elif not isinstance(valid_kpts, list):
                    print(f"Warning: 'valid_kpts' in {pair_file_name} is not a list. Defaulting to all True.")
                    valid_kpts = [True] * len(src_kps) if src_kps else []

                # Essential check: Ensure img paths and keypoints are present and are not None
                if not (isinstance(src_imname, str) and isinstance(trg_imname, str) and
                        isinstance(src_kps, list) and isinstance(trg_kps, list) and
                        isinstance(actual_category_from_json, str)):
                    print(f"Warning: Skipping malformed pair file {pair_file_name} due to missing or invalid essential data (src_imname:{type(src_imname)}, trg_imname:{type(trg_imname)}, src_kps:{type(src_kps)}, trg_kps:{type(trg_kps)}, category:{type(actual_category_from_json)}).")
                    continue

                # Check for consistent lengths of keypoint lists
                if not (len(src_kps) == len(trg_kps) and len(src_kps) == len(valid_kpts)):
                    print(f"Warning: Skipping pair {pair_file_name} due to mismatch in keypoint/valid_mask lengths. "
                          f"src_kps:{len(src_kps)}, trg_kps:{len(trg_kps)}, valid_kpts:{len(valid_kpts)}.")
                    continue

                # Filter by category if specified
                if self.category and actual_category_from_json != self.category:
                    continue

                pairs.append({
                    'category': actual_category_from_json, # Use the category from JSON
                    'src_img': src_imname,
                    'trg_img': trg_imname,
                    'src_kpts': src_kps,
                    'trg_kpts': trg_kps,
                    'valid': valid_kpts
                })
            except json.JSONDecodeError:
                print(f"Warning: Skipping malformed JSON file {pair_file_name}.")
            except Exception as e:
                print(f"Warning: An unexpected error occurred while processing {pair_file_name}: {e}")

        return pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        pair = self.pairs[idx]

        # Load images
        img1_path = os.path.join(self.data_root, 'JPEGImages', pair['category'], pair['src_img'])
        img2_path = os.path.join(self.data_root, 'JPEGImages', pair['category'], pair['trg_img'])

        img1 = Image.open(img1_path).convert('RGB')
        img2 = Image.open(img2_path).convert('RGB')

        # Get original dimensions for keypoint scaling
        orig_w1, orig_h1 = img1.size
        orig_w2, orig_h2 = img2.size

        # Transform images
        img1_tensor = self.transform(img1)
        img2_tensor = self.transform(img2)

        # Scale keypoints to match resized image
        kpts1 = torch.tensor(pair['src_kpts'], dtype=torch.float32)
        kpts2 = torch.tensor(pair['trg_kpts'], dtype=torch.float32)

        # Scale keypoints from original size to self.image_size
        kpts1[:, 0] = kpts1[:, 0] * self.image_size / orig_w1
        kpts1[:, 1] = kpts1[:, 1] * self.image_size / orig_h1
        kpts2[:, 0] = kpts2[:, 0] * self.image_size / orig_w2
        kpts2[:, 1] = kpts2[:, 1] * self.image_size / orig_h2

        # Valid mask
        valid_mask = torch.tensor(pair['valid'], dtype=torch.bool)

        return {
            'img1': img1_tensor,
            'img2': img2_tensor,
            'kpts1': kpts1,
            'kpts2': kpts2,
            'valid': valid_mask,
            'category': pair['category']
        }

# Custom collate_fn for handling varying number of keypoints
def custom_collate_fn(batch):
    imgs1 = torch.stack([item['img1'] for item in batch])
    imgs2 = torch.stack([item['img2'] for item in batch])

    # Find the maximum number of keypoints in the current batch
    max_kpts = max([item['kpts1'].shape[0] for item in batch])

    padded_kpts1 = []
    padded_kpts2 = []
    padded_valid_mask = []

    for item in batch:
        num_kpts = item['kpts1'].shape[0]
        # Pad kpts1 to max_kpts
        pad_kpts1 = F.pad(item['kpts1'], (0, 0, 0, max_kpts - num_kpts), 'constant', 0)
        padded_kpts1.append(pad_kpts1)

        # Pad kpts2 to max_kpts
        pad_kpts2 = F.pad(item['kpts2'], (0, 0, 0, max_kpts - num_kpts), 'constant', 0)
        padded_kpts2.append(pad_kpts2)

        # Pad valid mask to max_kpts (fill with False for padded values)
        pad_valid_mask = F.pad(item['valid'], (0, max_kpts - num_kpts), 'constant', False)
        padded_valid_mask.append(pad_valid_mask)

    kpts1_batch = torch.stack(padded_kpts1)
    kpts2_batch = torch.stack(padded_kpts2)
    valid_mask_batch = torch.stack(padded_valid_mask)

    categories = [item['category'] for item in batch]

    return {
        'img1': imgs1,
        'img2': imgs2,
        'kpts1': kpts1_batch,
        'kpts2': kpts2_batch,
        'valid': valid_mask_batch,
        'category': categories
    }

# Define training and validation functions
def train_one_epoch(model, dataloader, criterion, optimizer, scheduler, scaler, epoch):
    """
    Train for one epoch
    """
    model.train()
    total_loss = 0
    total_corr_loss = 0
    total_consistency_loss = 0
    total_smooth_loss = 0
    num_batches = 0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")

    # Disable mixed precision to fix type mismatch issues
    # autocast_dtype = torch.float16 if device.type == 'cuda' else torch.bfloat16

    for batch_idx, batch in enumerate(progress_bar):
        try:
            # Prepare batch data
            data = prepare_batch_data(batch)

            # Zero gradients
            optimizer.zero_grad()

            # Standard forward pass without autocast to avoid type mismatches
            # with torch.amp.autocast(device_type=device.type, dtype=autocast_dtype):
            # Extract features
            feat1 = extract_features(model, data['img1'])
            feat2 = extract_features(model, data['img2'])

            # Compute loss
            loss_dict = criterion(
                feat1, feat2,
                data['kpts1'], data['kpts2'],
                data['valid_mask']
            )

            loss = loss_dict['total_loss']

            # Standard backward pass without gradient scaling
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Update metrics
            total_loss += loss.item()
            total_corr_loss += loss_dict['correspondence_loss'].item()
            total_consistency_loss += loss_dict['consistency_loss'].item()
            total_smooth_loss += loss_dict['smoothness_loss'].item()
            num_batches += 1

            # Update progress bar
            current_lr = optimizer.param_groups[0]['lr']
            progress_bar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'LR': f'{current_lr:.2e}'
            })

        except Exception as e:
            print(f"Error in batch {batch_idx}: {e}")
            continue

    # Calculate average losses
    avg_loss = total_loss / max(num_batches, 1)
    avg_corr_loss = total_corr_loss / max(num_batches, 1)
    avg_consistency_loss = total_consistency_loss / max(num_batches, 1)
    avg_smooth_loss = total_smooth_loss / max(num_batches, 1)

    return {
        'total_loss': avg_loss,
        'correspondence_loss': avg_corr_loss,
        'consistency_loss': avg_consistency_loss,
        'smoothness_loss': avg_smooth_loss
    }

def validate_model(model, dataloader, criterion):
    """
    Validate the model
    """
    model.eval()
    total_loss = 0
    num_batches = 0

    # Disable mixed precision for validation as well
    # autocast_dtype = torch.float16 if device.type == 'cuda' else torch.bfloat16

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            try:
                data = prepare_batch_data(batch)

                # Standard forward pass without autocast
                # with torch.amp.autocast(device_type=device.type, dtype=autocast_dtype):
                # Extract features
                feat1 = extract_features(model, data['img1'])
                feat2 = extract_features(model, data['img2'])

                loss_dict = criterion(
                    feat1, feat2,
                    data['kpts1'], data['kpts2'],
                    data['valid_mask']
                )

                loss = loss_dict['total_loss']

                total_loss += loss.item()
                num_batches += 1

            except Exception as e:
                print(f"Validation error: {e}")
                continue

    return total_loss / max(num_batches, 1)

# Initialize mixed precision scaler using torch.amp.GradScaler (disabled for compatibility)
# scaler = torch.amp.GradScaler()

# Training history
train_losses = []
val_losses = []

# Create datasets using downloaded data
train_dataset = SPair71kDataset(
    data_root='/content/semantic-correspondence/data/SPair-71k',
    split='train',
    image_size=224
)

val_dataset = SPair71kDataset(
    data_root='/content/semantic-correspondence/data/SPair-71k',
    split='val',
    image_size=224
)

# Create dataloaders, now using the custom_collate_fn
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, collate_fn=custom_collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, collate_fn=custom_collate_fn)

# Learning rate scheduler - now we can calculate correct total_steps
total_steps = NUM_EPOCHS * len(train_dataloader)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=LEARNING_RATE,
    total_steps=total_steps,
    pct_start=0.1,  # 10% warmup
    anneal_strategy='cos'
)

print(f"‚úì Scheduler configured: OneCycleLR with {total_steps} total steps ({len(train_dataloader)} steps per epoch)")
print("Starting training...")
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    print(f"\n{'='*50}")
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    print(f"{'='*50}")

    # Training
    train_metrics = train_one_epoch(
        model=dinov2_vitb14_reg,
        dataloader=train_dataloader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=None,  # Disabled mixed precision
        epoch=epoch
    )

    # Validation
    val_loss = validate_model(
        model=dinov2_vitb14_reg,
        dataloader=val_dataloader,
        criterion=criterion
    )

    # Log metrics
    train_losses.append(train_metrics['total_loss'])
    val_losses.append(val_loss)

    print(f"Train Loss: {train_metrics['total_loss']:.4f}")
    print(f"  - Correspondence: {train_metrics['correspondence_loss']:.4f}")
    print(f"  - Consistency: {train_metrics['consistency_loss']:.4f}")
    print(f"  - Smoothness: {train_metrics['smoothness_loss']:.4f}")
    print(f"Val Loss: {val_loss:.4f}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': dinov2_vitb14_reg.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': val_loss,
            'train_metrics': train_metrics
        }, 'best_dinov2_light_finetune.pth')
        print(f"üíæ New best model saved! (val_loss: {val_loss:.4f})")

print("Training completed!")

from google.colab import drive
import shutil

# Mount Google Drive
print("Mounting Google Drive...")
drive.mount('/content/drive')

# Define the path to save the checkpoint in Google Drive
gdrive_save_path = '/content/drive/MyDrive/DINOv3_checkpoints'
checkpoint_filename = 'best_dinov2_light_finetune.pth'
local_checkpoint_path = checkpoint_filename

# Create the target directory in Google Drive if it doesn't exist
os.makedirs(gdrive_save_path, exist_ok=True)

# Check if the checkpoint file exists locally
if os.path.exists(local_checkpoint_path):
    # Copy the checkpoint to Google Drive
    destination_path = os.path.join(gdrive_save_path, checkpoint_filename)
    shutil.copy(local_checkpoint_path, destination_path)
    print(f"‚úÖ Checkpoint '{checkpoint_filename}' successfully uploaded to Google Drive at '{destination_path}'")
else:
    print(f"‚ùå Checkpoint file '{checkpoint_filename}' not found locally. Skipping upload.")


# Utility functions for when you're ready to train
def save_checkpoint(model, optimizer, scheduler, epoch, loss, filepath):
    """
    Save training checkpoint
    """
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
    }, filepath)

def load_checkpoint(filepath, model, optimizer=None, scheduler=None):
    """
    Load training checkpoint
    """
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    return checkpoint['epoch'], checkpoint['loss']

def plot_training_curves(train_losses, val_losses):
    """
    Plot training and validation loss curves
    """
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Loss Curves')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    if len(train_losses) > 1:
        plt.plot(np.gradient(train_losses), label='Train Loss Gradient')
        plt.plot(np.gradient(val_losses), label='Val Loss Gradient')
        plt.title('Loss Gradients (Learning Progress)')
        plt.xlabel('Epoch')
        plt.ylabel('Loss Gradient')
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.show()

Using device: cuda
‚úì Loss function initialized
Trainable parameter: blocks.9.norm1.weight, shape: torch.Size([768])
Trainable parameter: blocks.9.norm1.bias, shape: torch.Size([768])
Trainable parameter: blocks.9.attn.qkv.weight, shape: torch.Size([2304, 768])
Trainable parameter: blocks.9.attn.qkv.bias, shape: torch.Size([2304])
Trainable parameter: blocks.9.attn.proj.weight, shape: torch.Size([768, 768])
Trainable parameter: blocks.9.attn.proj.bias, shape: torch.Size([768])
Trainable parameter: blocks.9.ls1.gamma, shape: torch.Size([768])
Trainable parameter: blocks.9.norm2.weight, shape: torch.Size([768])
Trainable parameter: blocks.9.norm2.bias, shape: torch.Size([768])
Trainable parameter: blocks.9.mlp.fc1.weight, shape: torch.Size([3072, 768])
Trainable parameter: blocks.9.mlp.fc1.bias, shape: torch.Size([3072])
Trainable parameter: blocks.9.mlp.fc2.weight, shape: torch.Size([768, 3072])
Trainable parameter: blocks.9.mlp.fc2.bias, shape: torch.Size([768])
Trainable parameter: b

Epoch 1/10:  11%|‚ñà         | 723/6668 [04:37<37:59,  2.61it/s, Loss=0.1287, LR=2.15e-05]


KeyboardInterrupt: 