In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# --- 1. Helper: Map (x,y) coordinates to Patch Index ---
def get_patch_indices(keypoints, img_size=518, patch_size=14):
    """
    Map pixel-space (x, y) keypoints to flat, row-major patch indices.

    With the defaults (518 px image, 14 px patches) the grid is 37 x 37, so the
    returned indices fall in [0, 1368].

    keypoints:  [Batch, N_kps, 2] tensor of (x, y) pixel coordinates.
    img_size:   side length of the (assumed square) image in pixels.
    patch_size: side length of one patch in pixels.
    Returns:    [Batch, N_kps] long tensor, index = row * cells_per_side + col.
    """
    cells_per_side = img_size // patch_size
    last_cell = cells_per_side - 1

    # Truncate to integer cell coordinates, then clamp so that out-of-range
    # keypoints (e.g. padded/invalid ones) land on the nearest border cell.
    col = (keypoints[:, :, 0] / patch_size).long().clamp(0, last_cell)
    row = (keypoints[:, :, 1] / patch_size).long().clamp(0, last_cell)

    return row * cells_per_side + col

# --- 2. Helper: Extract Features at those Indices ---
def extract_features_at_indices(features, indices):
    """
    Gather one feature vector per index from a batch of patch-token maps.

    features: [Batch, N_patches, Dim] tensor (e.g. [B, 1369, Dim] for a 518 px
              image with 14 px patches).
    indices:  [Batch, N_kps] integer tensor (torch.gather expects int64);
              every value must lie in [0, N_patches - 1].
    Returns:  [Batch, N_kps, Dim] tensor with
              out[b, k] = features[b, indices[b, k]].
    Raises:   ValueError if the ranks are wrong or the batch sizes differ.
    """
    if features.dim() != 3 or indices.dim() != 2:
        raise ValueError(
            f"expected features [B, N_patches, Dim] and indices [B, N_kps], "
            f"got {tuple(features.shape)} and {tuple(indices.shape)}"
        )
    if features.shape[0] != indices.shape[0]:
        # torch.gather only requires index.size(d) <= input.size(d) for
        # d != dim, so a smaller index batch would otherwise silently return a
        # truncated batch instead of failing.
        raise ValueError(
            f"batch size mismatch: features has {features.shape[0]}, "
            f"indices has {indices.shape[0]}"
        )

    dim = features.shape[-1]
    # [B, N_kps] -> [B, N_kps, Dim] so gather copies the full feature vector.
    index = indices.unsqueeze(-1).expand(-1, -1, dim)
    return torch.gather(features, 1, index)

# --- 3. The Contrastive Loss ---
def contrastive_loss(feat_src_kps, feat_trg_all, trg_kps_indices, mask, temp=0.1):
    """
    Cross-entropy keypoint-matching loss over all target patches.

    Every valid source-keypoint feature is a query; all target patch features
    are the candidates; the patch holding the matching target keypoint is the
    correct class. Similarities are cosine similarities divided by `temp`.

    feat_src_kps:    [B, N, Dim]   (Query: Feature at Source Nose)
    feat_trg_all:    [B, P, Dim]   (Keys: All patches in Target Image, e.g. P = 1369)
    trg_kps_indices: [B, N]        (Label: int64 index of Target Nose patch)
    mask:            [B, N]        (Valid points only; nonzero/True = use it)
    temp:            softmax temperature.
    Returns: scalar tensor. Mean cross-entropy over valid keypoints, or a
             zero still attached to the autograd graph when none are valid.
    """
    # Cosine similarity via L2-normalised dot products.
    queries = F.normalize(feat_src_kps, dim=-1)
    keys = F.normalize(feat_trg_all, dim=-1)

    # [B, N, Dim] @ [B, Dim, P] -> [B, N, P]: every source keypoint against
    # every target patch.
    logits = torch.bmm(queries, keys.transpose(1, 2)) / temp

    valid = mask.bool()
    if not valid.any():
        # cross_entropy's mean over zero rows is NaN. Return a graph-connected
        # zero instead so loss.item()/backward() still work (zero gradient).
        return logits.sum() * 0.0

    # Boolean indexing flattens to [Total_Valid_Kps, P] / [Total_Valid_Kps].
    return F.cross_entropy(logits[valid], trg_kps_indices[valid])

# --- 4. Main Execution ---
if __name__ == '__main__':
    # Step-2 sanity check: fine-tune only the last two DINOv2 blocks on ONE
    # fixed batch. If the loss cannot be driven down here, something in the
    # keypoint -> patch-index -> feature -> loss pipeline is broken.
    print("\n--- Step 2: Overfit Small Sample ---")

    # A. Setup Model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading Model on {device}...")
    model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').to(device)

    # Freeze the whole backbone, then unfreeze only the last 2 blocks.
    for param in model.parameters():
        param.requires_grad = False
    for block in model.blocks[-2:]:
        for param in block.parameters():
            param.requires_grad = True

    # B. Setup Optimizer (High LR, No Decay), over trainable parameters only.
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=1e-3, weight_decay=0)

    # C. Get ONE Fixed Batch
    print("Grabbing one batch...")
    # trn_loader is expected to be defined by an earlier (dataset) cell. Only
    # the name lookup sits inside the try, so a NameError raised by the
    # dataset code itself is not misreported as a missing loader.
    try:
        loader = trn_loader
    except NameError as err:
        # Raise rather than exit(): in a notebook kernel exit() shuts the
        # kernel down and discards the state of the earlier cells.
        raise RuntimeError(
            "Error: trn_loader not defined. Run dataset code first."
        ) from err
    batch = next(iter(loader))

    src = batch['src_img'].to(device)
    trg = batch['trg_img'].to(device)
    src_kps = batch['src_kps'].to(device)
    trg_kps = batch['trg_kps'].to(device)
    mask    = batch['valid_mask'].to(device)

    # Derive the patch grid from the actual batch instead of relying on the
    # 518 px / 14 px defaults of get_patch_indices: a mismatch there maps
    # keypoints to the wrong patches without any error.
    # NOTE(review): assumes images are [..., H, W] tensors — confirm with the dataset.
    img_h, img_w = src.shape[-2:]
    if img_h != img_w or tuple(trg.shape[-2:]) != (img_h, img_w):
        raise ValueError(
            f"get_patch_indices assumes square images of equal size, got "
            f"src {tuple(src.shape[-2:])} and trg {tuple(trg.shape[-2:])}"
        )
    patch_size = getattr(model, 'patch_size', 14)

    # Pre-calculate the patch INDEX of every source/target keypoint; the
    # target indices are the classification labels for the loss.
    src_indices = get_patch_indices(src_kps, img_size=img_w, patch_size=patch_size)  # [B, N]
    trg_indices = get_patch_indices(trg_kps, img_size=img_w, patch_size=patch_size)  # [B, N] (Labels)

    print("Starting Training Loop (100 Iterations)...")
    model.train()

    # D. The Loop — range(101) so that iteration 100 is also logged.
    for i in range(101):
        optimizer.zero_grad()

        # 1. Forward Pass
        dict_A = model.forward_features(src)
        dict_B = model.forward_features(trg)

        feat_A_all = dict_A['x_norm_patchtokens']  # [B, N_patches, Dim]
        feat_B_all = dict_B['x_norm_patchtokens']  # [B, N_patches, Dim]

        # 2. Extract specific features at Source Keypoints
        feat_A_kps = extract_features_at_indices(feat_A_all, src_indices)

        # 3. Calculate Loss
        # Query: Src Keypoints | Keys: All Trg Patches | Correct: Trg Keypoints
        loss = contrastive_loss(feat_A_kps, feat_B_all, trg_indices, mask)

        # 4. Backward
        loss.backward()
        optimizer.step()

        # 5. Monitor (convergence is only checked on logged iterations)
        if i % 10 == 0:
            print(f"Iter {i:3d}: Loss = {loss.item():.5f}")

            if loss.item() < 0.1:
                print(">>> Converged! Model has memorized the batch.")
                break

    # `loss` is the last computed loss (from before that iteration's step).
    print("\n--- Result Analysis ---")
    if loss.item() > 1.0:
        print("❌ FAILED. Loss is still high. Try unfreezing more blocks (Last 4) or check LR.")
    else:
        print("✅ PASSED. Model is capable of learning.")


--- Step 2: Overfit Small Sample ---
Loading Model on cuda...


Using cache found in C:\Users\nicol/.cache\torch\hub\facebookresearch_dinov2_main


Grabbing one batch...
Starting Training Loop (100 Iterations)...
Iter   0: Loss = 4.52095
Iter  10: Loss = 1.39504
Iter  20: Loss = 0.59009
Iter  30: Loss = 0.37727
Iter  40: Loss = 0.27629
Iter  50: Loss = 0.12469
Iter  60: Loss = 0.62037
Iter  70: Loss = 0.21654
Iter  80: Loss = 0.53351
Iter  90: Loss = 0.24069
Iter 100: Loss = 0.08809
>>> Converged! Model has memorized the batch.

--- Result Analysis ---
✅ PASSED. Model is capable of learning.
