In [1]:
import os
import torch
import torchvision
from torchvision import datasets
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt
import requests
from zipfile import ZipFile
from io import BytesIO
import numpy as np
import zipfile
import os


zip_file_path = r'C:\Users\nicol\Documents\PoliTo\AdvancedML\project\SPair-71k.zip' 
extract_dir = r'C:\Users\nicol\Documents\PoliTo\AdvancedML\project\SPair-71k_extracted'

# Crea la directory di estrazione se non esiste
os.makedirs(extract_dir, exist_ok=True)

# Estrai il file ZIP solo se esiste
if os.path.exists(zip_file_path):
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)
    print(f"File '{zip_file_path}' estratto con successo nella directory '{extract_dir}'")
    print(f"Contenuti della directory '{extract_dir}':\n{os.listdir(extract_dir)}")
else:
    print(f"File zip '{zip_file_path}' non trovato. Assicurati che il dataset sia estratto in '{extract_dir}'.")



File 'C:\Users\nicol\Documents\PoliTo\AdvancedML\project\SPair-71k.zip' estratto con successo nella directory 'C:\Users\nicol\Documents\PoliTo\AdvancedML\project\SPair-71k_extracted'
Contenuti della directory 'C:\Users\nicol\Documents\PoliTo\AdvancedML\project\SPair-71k_extracted':
['SPair-71k']


In [2]:
import os
import glob
import json
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# --- 1. Define the Augmentation Pipeline ---
def get_transforms(mode='train', img_size=518):
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    if mode == 'train':
        return A.Compose([
            # Geometric Augmentations (Hard - Moves Keypoints)
            A.Resize(height=img_size, width=img_size), # Force DINOv2 size
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
            
            # Pixel Augmentations (Safe - Colors only)
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
            A.GaussianBlur(p=0.1),
            
            # Normalization & Conversion
            A.Normalize(mean=mean, std=std),
            ToTensorV2(), # Converts to (C, H, W)
        ], keypoint_params=A.KeypointParams(format='xy', remove_invisible=False))
    
    else:
        # Validation/Test: Only Resize & Normalize
        return A.Compose([
            A.Resize(height=img_size, width=img_size),
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ], keypoint_params=A.KeypointParams(format='xy', remove_invisible=False))

# --- 2. Simple Image Reader ---
def read_img(path):
    # Keep as HWC (Standard for Albumentations)
    # Do not transpose or convert to Tensor here yet
    img = np.array(Image.open(path).convert('RGB'))
    return img

class SPairDataset(Dataset):
    def __init__(self, pair_ann_path, layout_path, image_path, dataset_size, pck_alpha, datatype):
        self.datatype = datatype
        self.pck_alpha = pck_alpha
        self.ann_files = open(os.path.join(layout_path, dataset_size, datatype + '.txt'), "r").read().split('\n')
        self.ann_files = [x for x in self.ann_files if x] # Remove empty strings
        self.pair_ann_path = pair_ann_path
        self.image_path = image_path
        
        # Initialize the Transform Pipeline
        mode = 'train' if datatype == 'trn' else 'test'
        self.transform = get_transforms(mode=mode, img_size=518)

    def __len__(self):
        return len(self.ann_files)

    def __getitem__(self, idx):
        raw_line = self.ann_files[idx]
        ann_filename = raw_line.replace(':', '_')
        ann_file = ann_filename + '.json'
        json_path = os.path.join(self.pair_ann_path, self.datatype, ann_file)

        with open(json_path) as f:
            annotation = json.load(f)

        category = annotation['category']
        src_path = os.path.join(self.image_path, category, annotation['src_imname'])
        trg_path = os.path.join(self.image_path, category, annotation['trg_imname'])

        # 1. Load Images
        src_img_raw = read_img(src_path)
        trg_img_raw = read_img(trg_path)

        # 2. Get Keypoints
        src_kps = np.array(annotation['src_kps']).astype(np.float32)
        trg_kps = np.array(annotation['trg_kps']).astype(np.float32)

        # 3. Apply Augmentations
        src_aug = self.transform(image=src_img_raw, keypoints=src_kps)
        src_tensor = src_aug['image']
        src_kps_aug = np.array(src_aug['keypoints'])
        
        trg_aug = self.transform(image=trg_img_raw, keypoints=trg_kps)
        trg_tensor = trg_aug['image']
        trg_kps_aug = np.array(trg_aug['keypoints'])

        # ==========================================================
        # ‚ö†Ô∏è CRITICAL FIX: PADDING LOGIC (Prevents Stack Error)
        # ==========================================================
        # We enforce a fixed size of 40 points per image.
        MAX_KPS = 40 
        
        # Create empty containers filled with zeros (Shape: [40, 2])
        src_kps_padded = np.zeros((MAX_KPS, 2), dtype=np.float32)
        trg_kps_padded = np.zeros((MAX_KPS, 2), dtype=np.float32)
        
        # Get the actual number of points (limit to 40 just in case)
        n_src = min(len(src_kps_aug), MAX_KPS)
        n_trg = min(len(trg_kps_aug), MAX_KPS)
        
        # Copy the real points into the empty container
        if n_src > 0:
            src_kps_padded[:n_src] = src_kps_aug[:n_src]
        if n_trg > 0:
            trg_kps_padded[:n_trg] = trg_kps_aug[:n_trg]

        # Check which points are inside the image (Visibility)
        src_vis = self._check_visibility(src_kps_padded, 518, 518)
        trg_vis = self._check_visibility(trg_kps_padded, 518, 518)
        
        # Create the Valid Mask
        # A point is valid ONLY if:
        # 1. It existed in the original file (index < n_src)
        # 2. It is still inside the image boundaries (vis=1)
        valid_mask = np.zeros(MAX_KPS, dtype=np.float32)
        
        # We assume points are ordered pairs (1st src matches 1st trg)
        # So we only mark as valid if BOTH exist and are visible
        common_len = min(n_src, n_trg)
        valid_mask[:common_len] = src_vis[:common_len] * trg_vis[:common_len]
        # ==========================================================

        pck_threshold = 518 * self.pck_alpha

        sample = {
            'pair_id': annotation['pair_id'],
            'src_img': src_tensor,
            'trg_img': trg_tensor,
            
            # Now these are ALWAYS [40, 2], so PyTorch won't crash!
            'src_kps': torch.from_numpy(src_kps_padded).float(), 
            'trg_kps': torch.from_numpy(trg_kps_padded).float(), 
            'valid_mask': torch.from_numpy(valid_mask).float(), 
            
            'pck_threshold': pck_threshold,
            'category': category
        }

        return sample

    def _check_visibility(self, kps, h, w):
        """Returns a binary mask (N,) where 1=visible, 0=out of bounds"""
        # kps is shape (N, 2) -> (x, y)
        x = kps[:, 0]
        y = kps[:, 1]
        
        # Check boundaries
        vis_x = (x >= 0) & (x < w)
        vis_y = (y >= 0) & (y < h)
        return (vis_x & vis_y).astype(np.float32)

if __name__ == '__main__':
    # Update this path to your actual path
    base_dir = r"C:\Users\nicol\Documents\PoliTo\AdvancedML\project\SPair-71k_extracted\SPair-71k\SPair-71k"    
    
    pair_ann_path = os.path.join(base_dir, 'PairAnnotation')
    layout_path = os.path.join(base_dir, 'Layout')
    image_path = os.path.join(base_dir, 'JPEGImages')

    # Check paths
    if os.path.exists(base_dir):
        
        # --- 1. Load TRAINING Set ---
        print("Loading Training Set...")
        trn_dataset = SPairDataset(
            pair_ann_path, layout_path, image_path, 
            dataset_size='large', pck_alpha=0.05, 
            datatype='trn'  # <--- Loads from trn.txt
        )
        # SHUFFLE = TRUE for training (important for learning)
        trn_loader = DataLoader(trn_dataset, batch_size=4, shuffle=True)


        # --- 2. Load VALIDATION Set ---
        print("Loading Validation Set...")
        val_dataset = SPairDataset(
            pair_ann_path, layout_path, image_path, 
            dataset_size='large', pck_alpha=0.05, 
            datatype='val'  # <--- Loads from val.txt
        )
        # SHUFFLE = FALSE for validation (keep order stable)
        val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

        
        # --- 3. Test Loading ---
        print("Testing batches...")
        
        # Grab a training batch
        trn_batch = next(iter(trn_loader))
        print(f"Train Batch Images: {trn_batch['src_img'].shape}")
        
        # Grab a validation batch
        val_batch = next(iter(val_loader))
        print(f"Val Batch Images:   {val_batch['src_img'].shape}")
        
        print("Dataset setup complete. Ready for Training Loop.")
        
    else:
        print(f"Path not found: {base_dir}")

Loading Training Set...
Loading Validation Set...
Testing batches...
Train Batch Images: torch.Size([4, 3, 518, 518])


  original_init(self, **validated_kwargs)


Val Batch Images:   torch.Size([4, 3, 518, 518])
Dataset setup complete. Ready for Training Loop.


In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

# ==========================================
# 1. HELPER FUNCTIONS (Need these for Validation)
# ==========================================
def get_patch_indices(keypoints, img_size=518, patch_size=14):
    grid_w = img_size // patch_size
    grid_x = (keypoints[:, :, 0] / patch_size).long().clamp(0, grid_w-1)
    grid_y = (keypoints[:, :, 1] / patch_size).long().clamp(0, grid_w-1)
    flat_indices = grid_y * grid_w + grid_x 
    return flat_indices

def extract_features_at_indices(features, indices):
    B, N_patches, Dim = features.shape
    B, N_kps = indices.shape
    indices_expanded = indices.unsqueeze(-1).expand(-1, -1, Dim)
    kps_features = torch.gather(features, 1, indices_expanded)
    return kps_features

def contrastive_loss(feat_src_kps, feat_trg_all, trg_kps_indices, mask, temp=0.1):
    feat_src_kps = F.normalize(feat_src_kps, dim=-1)
    feat_trg_all = F.normalize(feat_trg_all, dim=-1)
    logits = torch.bmm(feat_src_kps, feat_trg_all.transpose(1, 2)) / temp
    valid = mask.bool()
    logits_valid = logits[valid]
    targets_valid = trg_kps_indices[valid]
    loss = F.cross_entropy(logits_valid, targets_valid)
    return loss

def validate_model(model, val_loader, device):
    model.eval()
    correct_points = 0
    total_points = 0
    
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            # Limit validation to 20 batches to be fast
            if i > 20: break 
            
            src = batch['src_img'].to(device)
            trg = batch['trg_img'].to(device)
            src_kps = batch['src_kps'].to(device)
            trg_kps = batch['trg_kps'].to(device)
            mask    = batch['valid_mask'].to(device)
            thresholds = batch['pck_threshold'].to(device)

            dict_A = model.forward_features(src)
            dict_B = model.forward_features(trg)
            feat_A_all = dict_A['x_norm_patchtokens']
            feat_B_all = dict_B['x_norm_patchtokens']
            
            src_indices = get_patch_indices(src_kps)
            feat_A_kps = extract_features_at_indices(feat_A_all, src_indices)
            
            sim = torch.bmm(F.normalize(feat_A_kps, dim=-1), 
                            F.normalize(feat_B_all, dim=-1).transpose(1, 2))
            
            best_match_idx = torch.argmax(sim, dim=2)
            
            grid_w = 37
            pred_y = best_match_idx // grid_w
            pred_x = best_match_idx % grid_w
            pred_x = pred_x * 14 + 7
            pred_y = pred_y * 14 + 7
            
            dist = torch.sqrt((pred_x - trg_kps[:, :, 0])**2 + (pred_y - trg_kps[:, :, 1])**2)
            thresh_expanded = thresholds.unsqueeze(1).expand(-1, 40)
            is_correct = (dist < thresh_expanded) & (mask.bool())
            
            correct_points += is_correct.sum().item()
            total_points += mask.sum().item()
            
    return correct_points / (total_points + 1e-6)

# ==========================================
# 2. Step 4: Coarse Grid Search
# ==========================================
if __name__ == '__main__':
    print("\n--- Step 4: Coarse Grid Search ---")
    
    # Grid Settings: Combining Best LRs with Prof's Weight Decays
    grid_configs = [
        # Set A: Using your Best LR (5e-5)
        {'lr': 5e-5, 'wd': 1e-4}, 
        {'lr': 5e-5, 'wd': 1e-5}, # <--- Added
        {'lr': 5e-5, 'wd': 0.0},  
        
        # Set B: Using the Safer LR (1e-5)
        {'lr': 1e-5, 'wd': 1e-4}, 
        {'lr': 1e-5, 'wd': 1e-5}, # <--- Added
        {'lr': 1e-5, 'wd': 0.0},  
    ]
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Ensure Loaders exist
    try:
        _ = len(trn_loader)
        _ = len(val_loader)
    except NameError:
        print("Error: Loaders not defined. Run Dataset setup first.")
        exit()

    best_acc = 0
    best_config = None

    for config in grid_configs:
        lr = config['lr']
        wd = config['wd']
        print(f"\n>>> Training Config: LR={lr}, WD={wd}")
        
        # 1. Reset Model
        model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').to(device)
        for param in model.parameters(): param.requires_grad = False
        for block in model.blocks[-2:]: 
            for param in block.parameters(): param.requires_grad = True
            
        optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), 
                                lr=lr, weight_decay=wd)
        
        # 2. Train (Short run: 50 batches)
        model.train()
        for i, batch in enumerate(trn_loader):
            if i >= 50: break # Stop early to save time
            
            src = batch['src_img'].to(device)
            trg = batch['trg_img'].to(device)
            src_kps = batch['src_kps'].to(device)
            trg_kps = batch['trg_kps'].to(device)
            mask    = batch['valid_mask'].to(device)
            
            src_indices = get_patch_indices(src_kps)
            trg_indices = get_patch_indices(trg_kps)

            optimizer.zero_grad()
            dict_A = model.forward_features(src)
            dict_B = model.forward_features(trg)
            feat_A_kps = extract_features_at_indices(dict_A['x_norm_patchtokens'], src_indices)
            feat_B_all = dict_B['x_norm_patchtokens']
            
            loss = contrastive_loss(feat_A_kps, feat_B_all, trg_indices, mask)
            loss.backward()
            optimizer.step()
            
            if i % 25 == 0: print(f"  Iter {i}: Loss {loss.item():.4f}")

        # 3. Validate
        print("  Validating...")
        acc = validate_model(model, val_loader, device)
        print(f"  >> Validation PCK: {acc*100:.2f}%")
        
        if acc > best_acc:
            best_acc = acc
            best_config = config

    print("\n===========================================")
    print(f"üèÜ Best Config: LR={best_config['lr']}, WD={best_config['wd']}")
    print(f"üèÜ Best Validation Accuracy: {best_acc*100:.2f}%")
    print("===========================================")
    print("Now perform Step 5 (Long Training) using these exact values!")


--- Step 4: Coarse Grid Search (Including Prof's Suggestions) ---

>>> Training Config: LR=5e-05, WD=0.0001


Using cache found in C:\Users\nicol/.cache\torch\hub\facebookresearch_dinov2_main


  Iter 0: Loss 3.5205
  Iter 25: Loss 2.7998
  Validating...
  >> Validation PCK: 59.30%

>>> Training Config: LR=5e-05, WD=1e-05


Using cache found in C:\Users\nicol/.cache\torch\hub\facebookresearch_dinov2_main


  Iter 0: Loss 3.7596
  Iter 25: Loss 3.2685
  Validating...
  >> Validation PCK: 63.87%

>>> Training Config: LR=5e-05, WD=0.0


Using cache found in C:\Users\nicol/.cache\torch\hub\facebookresearch_dinov2_main


  Iter 0: Loss 4.2249
  Iter 25: Loss 3.6745
  Validating...
  >> Validation PCK: 63.26%

>>> Training Config: LR=1e-05, WD=0.0001


Using cache found in C:\Users\nicol/.cache\torch\hub\facebookresearch_dinov2_main


  Iter 0: Loss 4.9820
  Iter 25: Loss 2.9578
  Validating...
  >> Validation PCK: 57.01%

>>> Training Config: LR=1e-05, WD=1e-05


Using cache found in C:\Users\nicol/.cache\torch\hub\facebookresearch_dinov2_main


  Iter 0: Loss 3.3437
  Iter 25: Loss 3.8029
  Validating...
  >> Validation PCK: 56.71%

>>> Training Config: LR=1e-05, WD=0.0


Using cache found in C:\Users\nicol/.cache\torch\hub\facebookresearch_dinov2_main


  Iter 0: Loss 4.4358
  Iter 25: Loss 3.6930
  Validating...
  >> Validation PCK: 55.95%

üèÜ Best Config: LR=5e-05, WD=1e-05
üèÜ Best Validation Accuracy: 63.87%
Now perform Step 5 (Long Training) using these exact values!
