In [None]:
# ============================================================================
# CELL 1: Setup directories and imports
# ============================================================================
from __future__ import annotations
import sys, os, time, torch
import logging
from pathlib import Path
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from dataclasses import dataclass
from typing import Optional, Literal
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import LinearLR, SequentialLR, CosineAnnealingLR

# Project layout. Defaults live under the user's home directory; set the AML_PROJECT_ROOT /
# AML_REPO_DIR environment variables to relocate them without editing the notebook.
PROJECT_ROOT = os.environ.get("AML_PROJECT_ROOT", str(Path.home() / "AML"))
LOCAL_REPO_NAME = os.environ.get("AML_REPO_DIR", str(Path.home() / "AML_SemanticCorrespondence"))
DATA_DIR = f'{PROJECT_ROOT}/dataset'
CHECKPOINT_DIR = f'{PROJECT_ROOT}/checkpoints'
LOG_DIR = f'{PROJECT_ROOT}/logs'
SPAIR_ROOT = f'{DATA_DIR}/Spair-71k'

# Add local repository to path for module imports. Guarded so that re-running this cell in
# the same kernel does not keep prepending duplicate entries to sys.path.
if not sys.path or sys.path[0] != LOCAL_REPO_NAME:
    sys.path.insert(0, LOCAL_REPO_NAME)

# Create required directories if they do not exist
for _required_dir in (DATA_DIR, CHECKPOINT_DIR, LOG_DIR):
    os.makedirs(_required_dir, exist_ok=True)

# Import local modules
from models import CorrespondenceMatcher, UnifiedEvaluator, FinetunableBackbone, CorrespondenceLoss
from dataset import SPairDataset

# Performance optimizations for faster training
try:
    # Allow faster reduced-precision float32 matmuls; purely a speed hint, so any failure
    # (e.g. an older torch without this API) is deliberately ignored.
    torch.set_float32_matmul_precision('high')
except Exception:
    pass

_cuda_ok = torch.cuda.is_available()
if _cuda_ok:
    # Let cuDNN pick the fastest kernels for the input shapes it sees.
    torch.backends.cudnn.benchmark = True

# Report the compute device that will be used
_gpu_name = torch.cuda.get_device_name(0) if _cuda_ok else 'No GPU'
print(f" CUDA available: {_cuda_ok}")
print(f"\n  GPU: {_gpu_name}")
if _cuda_ok:
    _vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"   VRAM: {_vram_gb:.1f} GB")
device = "cuda" if _cuda_ok else "cpu"
print("\n Setup complete!\n")

In [None]:
# ============================================================================
# CELL 2: Install dependencies
# ============================================================================

print(" Installing dependencies...\n")
!pip install -q -r {LOCAL_REPO_NAME}/requirements.txt
print("\n Dependencies installed!\n")

In [None]:
# ============================================================================
# CELL 3: Fine-tuning configuration
# ============================================================================

@dataclass
class FinetuneConfig:
    """All tunable options for fine-tuning a backbone on SPair-71k correspondences.

    Defaults are the values used when the notebook is run as-is; edit them here (or pass
    keyword arguments) before executing the later cells.

    Raises:
        ValueError: from ``__post_init__`` if a mode string is not one of the documented
            values, or if a counter used as a divisor/modulus is < 1.
    """

    # -------------------------
    # Model
    # -------------------------
    backbone_name: str = 'sam_vit_b'        # 'dinov2_vitb14', 'dinov3_vitb16', 'sam_vit_b'
    num_layers_to_finetune: int = 2         # last transformer blocks to unfreeze

    # -------------------------
    # Training
    # -------------------------
    num_epochs: int = 5                     # total epochs to train
    learning_rate: float = 5e-6             # base LR
    weight_decay: float = 1e-2              # weight decay
    warmup_epochs: int = 1                  # number of warmup epochs in which LR increases linearly to base LR
    grad_clip_norm: float = 1.0             # max norm for gradient clipping (set to 0 to disable)
    tqdm_update_interval: int = 100         # update progress bar every N batches

    # Memory optimizations
    use_amp: bool = True                        # automatic mixed precision
    gradient_accumulation_steps: int = 8        # number of batches to accumulate before stepping optimizer
    use_gradient_checkpointing: bool = False    # save memory by recomputing activations during backward pass

    # DataLoader
    batch_size: int = 1                         # per-GPU batch size (keep 1: features are squeezed per image)
    num_workers: int = 0                        # DataLoader workers (keep 0)
    max_train_pairs: Optional[int] = None       # set None for full training set (53340 pairs)
    max_val_pairs: Optional[int] = 1000         # set None for full val set (5384 pairs)

    # Loss
    loss_type: str = 'contrastive'              # 'cosine' | 'l2' | 'contrastive'
    temperature: float = 0.07                   # temp for contrastive loss to sharpen similarities

    # -------------------------
    # Evaluation modes
    # -------------------------
    val_mode: Literal["slice", "full", "hybrid"] = "hybrid" # validation mode
    val_slice_max_batches: int = 50                         # used in slice/hybrid
    val_full_max_batches: Optional[int] = None              # None = full validation
    val_full_every: int = 1                                 # run full val every N epochs (hybrid/full)

    # -------------------------
    # Logging / checkpoints
    # -------------------------
    use_tensorboard: bool = True                            # TensorBoard logging
    tb_logdir: str = "runs/task2_finetune"                  # TensorBoard log directory

    use_wandb: bool = True                                          # Weights & Biases logging
    wandb_entity: str = "luffy1"                                    # your W&B username or team name
    wandb_project: str = "AML-project-semantic-correspondence"      # W&B project name

    log_checkpoints_per_epoch: int = 10                     # intra-epoch logs
    run_name: str = ""                                      # auto if empty

    # Training mode
    training_mode: Literal["fresh", "resume", "continue"] = "fresh"     # 'fresh' | 'resume' | 'continue'
    resume_checkpoint: str = ""                                         # used if training_mode == 'resume'

    def __post_init__(self):
        # `Literal` annotations are not enforced at runtime; without this check a typo would
        # silently fall into the "else" branch of the training/validation logic.
        if self.val_mode not in ("slice", "full", "hybrid"):
            raise ValueError(f"val_mode must be 'slice', 'full' or 'hybrid', got {self.val_mode!r}")
        if self.training_mode not in ("fresh", "resume", "continue"):
            raise ValueError(
                f"training_mode must be 'fresh', 'resume' or 'continue', got {self.training_mode!r}"
            )
        # Used as a divisor (loss scaling, scheduler length) and as a modulus respectively.
        if self.gradient_accumulation_steps < 1:
            raise ValueError(
                f"gradient_accumulation_steps must be >= 1, got {self.gradient_accumulation_steps}"
            )
        if self.val_full_every < 1:
            raise ValueError(f"val_full_every must be >= 1, got {self.val_full_every}")

config = FinetuneConfig()

# Long side (pixels) images are resized to. Presumably chosen as a multiple of the ViT patch
# size (518 = 37 * 14 for DINOv2, 512 = 32 * 16 otherwise) — confirm against the backbones.
IMG_SIZE = 518 if config.backbone_name.startswith("dinov2") else 512

# Default run name encodes backbone, #unfrozen layers, long side and LR. The long side comes
# from IMG_SIZE so the name matches the resolution actually used.
# NOTE(review): for DINOv2 this yields "ls518", so older "ls512" DINOv2 checkpoints will not be
# picked up automatically on resume — pass them via resume_checkpoint.
if not config.run_name:
    config.run_name = f"{config.backbone_name}_L{config.num_layers_to_finetune}_ls{IMG_SIZE}_lr{config.learning_rate}"

In [None]:
# ============================================================================
# CELL 4: Initialize Weights & Biases logging
# ============================================================================
# W&B is an optional dependency: import it only when logging is enabled. Every later use of
# `wandb` in this notebook is guarded by `config.use_wandb`.
if config.use_wandb:
    import wandb

    # Attempt to authenticate with W&B using cached credentials
    try:
        wandb.login(relogin=False)
    except Exception as e:
        print(f"W&B login required ({e}). Run 'wandb login' in terminal if needed.")

    # Coarse backbone family, used to group runs in the W&B UI
    if config.backbone_name.startswith("sam"):
        model_family = "sam"
    elif config.backbone_name.startswith("dinov2"):
        model_family = "dinov2"
    else:
        model_family = "dinov3"

    wandb.init(
        entity=config.wandb_entity,
        project=config.wandb_project,
        name=config.run_name,  # same name as the TensorBoard dir, log file and checkpoint
        config={
            "learning_rate": config.learning_rate,
            "epochs": config.num_epochs,
            "batch_size": config.batch_size,
            "num_layers_to_unfreeze": config.num_layers_to_finetune,
            "model_to_finetune": model_family,
            "temperature": config.temperature,
            "loss_type": config.loss_type,
            "use_amp": config.use_amp,
            "gradient_accumulation_steps": config.gradient_accumulation_steps,
            "use_gradient_checkpointing": config.use_gradient_checkpointing,
            "val_mode": config.val_mode,
            # Remaining hyper-parameters, so a run is fully described in W&B
            "backbone_name": config.backbone_name,
            "weight_decay": config.weight_decay,
            "warmup_epochs": config.warmup_epochs,
            "grad_clip_norm": config.grad_clip_norm,
            "max_train_pairs": config.max_train_pairs,
            "max_val_pairs": config.max_val_pairs,
        }
    )
    print("W&B run started.")
else:
    print("W&B disabled by config.use_wandb=False")

In [None]:
# ============================================================================
# CELL 5: Install SAM dependency if needed
# ============================================================================
if 'sam' in config.backbone_name:
    print(f" Installing SAM dependency for {config.backbone_name}...")
    !pip install -q git+https://github.com/facebookresearch/segment-anything.git
    print(" SAM dependency installed.\n")
else:
    print(" SAM not selected â€” skipping.\n")


In [None]:
# ============================================================================
# CELL 6: DataLoaders (train/val)
# ============================================================================

def maybe_subset(ds, max_pairs, seed: Optional[int] = None):
    """Return a random subset of at most `max_pairs` items of `ds`, or `ds` itself.

    Args:
        ds: any sized, indexable dataset.
        max_pairs: number of items to keep; None returns `ds` unchanged.
        seed: optional seed for a private torch.Generator, so the same subset is drawn on
            every run (useful for a fixed validation slice). None draws from torch's global
            RNG, as before.

    Returns:
        `ds` when `max_pairs` is None, otherwise a `torch.utils.data.Subset` over a random
        permutation of indices truncated to `max_pairs`.

    Raises:
        ValueError: if `max_pairs` is negative. A negative slice bound would otherwise silently
            drop the last |max_pairs| items instead of limiting the size.
    """
    if max_pairs is None:
        return ds
    if max_pairs < 0:
        raise ValueError(f"max_pairs must be >= 0 or None, got {max_pairs}")
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    idx = torch.randperm(len(ds), generator=generator)[:max_pairs].tolist()
    return torch.utils.data.Subset(ds, idx)

def _build_spair(split):
    """Construct one SPair-71k split with the preprocessing shared by train and val."""
    return SPairDataset(
        root=SPAIR_ROOT,
        split=split,
        size='large',
        long_side=IMG_SIZE,
        normalize=True,
        load_segmentation=False,
    )


# Full splits first (train, then val), then the optional random subsets in the same order.
train_dataset_full = _build_spair('train')
val_dataset_full = _build_spair('val')

train_ds = maybe_subset(train_dataset_full, config.max_train_pairs)
val_ds = maybe_subset(val_dataset_full, config.max_val_pairs)

# Loader settings shared by both splits; only shuffling differs.
_common_loader_kwargs = {
    "batch_size": config.batch_size,
    "num_workers": config.num_workers,
    "pin_memory": torch.cuda.is_available(),
    "drop_last": False,
}
train_loader = DataLoader(train_ds, shuffle=True, **_common_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_common_loader_kwargs)

print(f" Train pairs: {len(train_ds)} | batches: {len(train_loader)}")
print(f"   Val pairs: {len(val_ds)} | batches: {len(val_loader)}")


In [None]:
# ============================================================================
# CELL 7: Adapter for FinetunableBackbone compatibility with CorrespondenceMatcher
# ============================================================================

class FinetunableBackboneAdapter:
    """Present a FinetunableBackbone through the interface CorrespondenceMatcher uses.

    Exposes `.model`, `.device`, `.config.patch_size` (taken from the backbone's `stride`),
    a `.training` flag, chainable `train()` / `eval()`, and `extract_features(img)`.
    """

    def __init__(self, finetunable_backbone):
        backbone = finetunable_backbone
        self.model = backbone
        self.device = backbone.device
        # Minimal config object: only `patch_size` is provided.
        patch_config_cls = type('obj', (object,), {'patch_size': backbone.stride})
        self.config = patch_config_cls()
        self.training = True

    def train(self):
        """Put the wrapped backbone in training mode; returns self for chaining."""
        self.model.train()
        self.training = True
        return self

    def eval(self):
        """Put the wrapped backbone in eval mode; returns self for chaining."""
        self.model.eval()
        self.training = False
        return self

    @torch.no_grad()
    def extract_features(self, img):
        """Run the backbone without gradients and drop the leading batch axis.

        The backbone output is (B, H, W, D); `squeeze(0)` only removes the batch axis when
        B == 1. The result is returned as a 1-tuple for compatibility with the matcher.
        """
        batched_features = self.model(img)
        return (batched_features.squeeze(0),)

In [None]:
# ============================================================================
# CELL 8: Training Loop (AMP + Warmup + Grad Clip + TB + Resume + Val modes)
# ============================================================================

# Evaluator setup: PCK reported at four thresholds over the validation loader
evaluator = UnifiedEvaluator(
    dataloader=val_loader,
    device=device,
    thresholds=[0.05, 0.10, 0.15, 0.20],
)

# Model setup: backbone configured from the fine-tuning options, then moved to `device`
backbone_kwargs = {
    "backbone_name": config.backbone_name,
    "num_layers_to_finetune": config.num_layers_to_finetune,
    "device": device,
    "use_gradient_checkpointing": config.use_gradient_checkpointing,
}
model = FinetunableBackbone(**backbone_kwargs)
model = model.to(device)

# Loss setup
criterion = CorrespondenceLoss(temperature=config.temperature, loss_type=config.loss_type)
criterion = criterion.to(device)
print(" Loss:", config.loss_type)

# Optimizer: AdamW on trainable params only
trainable_params = [p for p in model.parameters() if p.requires_grad]
if len(trainable_params) == 0:
    raise RuntimeError("No trainable parameters found. Check num_layers_to_finetune > 0.")

optimizer = optim.AdamW(trainable_params, lr=config.learning_rate, weight_decay=config.weight_decay)

# Scheduler: linear warmup -> cosine decay, stepped once per optimizer update (not per batch).
# train_one_epoch performs an update every `gradient_accumulation_steps` batches and one more
# for the trailing partial window at the end of each epoch. An epoch therefore has
# ceil(n_batches / accum) updates. Floor division under-counts them, so the cosine phase would
# run past T_max, where CosineAnnealingLR starts raising the LR again.
updates_per_epoch = -(-len(train_loader) // config.gradient_accumulation_steps)  # ceil division
total_updates = max(1, updates_per_epoch * config.num_epochs)
warmup_updates = max(1, updates_per_epoch * config.warmup_epochs)

# LR ramps linearly from 0.1 * base LR to base LR during warmup, then follows a cosine decay.
warmup = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_updates)
cosine = CosineAnnealingLR(optimizer, T_max=max(1, total_updates - warmup_updates))
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_updates])

# AMP: only enabled when requested AND a CUDA device is present
use_amp = bool(config.use_amp and torch.cuda.is_available())
scaler = GradScaler(enabled=use_amp)

print(f" AMP: {use_amp} | Grad accumulation: {config.gradient_accumulation_steps} | Grad checkpointing: {config.use_gradient_checkpointing}")

# TensorBoard: one sub-directory per run under config.tb_logdir; writer stays None when disabled
tb_writer = None
if config.use_tensorboard:
    logdir = Path(config.tb_logdir, config.run_name)
    logdir.mkdir(exist_ok=True, parents=True)
    tb_writer = SummaryWriter(log_dir=str(logdir))
    print(f" TensorBoard logdir: {logdir}")

# File logging: one timestamped log file per execution of this cell
log_dir_path = Path(LOG_DIR)
log_dir_path.mkdir(parents=True, exist_ok=True)
log_filename = log_dir_path / f"train_{config.run_name}_{time.strftime('%Y%m%d_%H%M%S')}.log"
logger = logging.getLogger(f"finetune_{config.run_name}")
logger.setLevel(logging.INFO)

# Loggers are process-global. When this cell is re-run in the same kernel, the logger still holds
# the handler for the PREVIOUS log file. Close and drop it so output goes to the new
# `log_filename` instead of silently appending to the old file.
for _old_handler in list(logger.handlers):
    logger.removeHandler(_old_handler)
    _old_handler.close()

fh = logging.FileHandler(str(log_filename))
fh.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)
logger.propagate = False  # keep training logs out of the root logger / notebook output
logger.info(f"Starting training: run_name={config.run_name} backbone={config.backbone_name} epochs={config.num_epochs} lr={config.learning_rate}")

# Checkpoint bookkeeping (complete state for resume); the resume logic below may overwrite these.
# best_pck: best validation PCK@0.10 so far (-1.0 = none yet); global_step: optimizer updates done.
best_pck, start_epoch, global_step = -1.0, 0, 0
best_path = Path(CHECKPOINT_DIR, f"finetuned_{config.run_name}.pt")
best_path.parent.mkdir(exist_ok=True, parents=True)

def save_ckpt(path: Path, epoch: int, best_pck_val: float):
    """Write a complete training checkpoint (enough state to resume) to `path`.

    Args:
        path: destination file.
        epoch: 0-based index of the epoch that just finished; the resume logic restarts
            at epoch + 1.
        best_pck_val: best validation PCK@0.10 (percent) to record.

    Reads the module-level `model`, `optimizer`, `scheduler`, `scaler`, `use_amp`,
    `global_step` and `config`. The data is first written to a temporary sibling file and
    then moved over `path` with os.replace (atomic on POSIX within one filesystem). An
    interruption mid-save therefore cannot corrupt an existing checkpoint at `path`.
    """
    path = Path(path)
    state = {
        "epoch": epoch,
        "global_step": global_step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict() if use_amp else None,
        "best_pck@0.10": best_pck_val,
        "config": config.__dict__,
    }
    tmp_path = path.with_name(path.name + ".tmp")
    torch.save(state, str(tmp_path))
    os.replace(tmp_path, path)

# Restore training state according to config.training_mode:
#   "resume"  -> load model / optimizer / scheduler / (AMP) scaler state from a checkpoint
#                produced by save_ckpt and continue from the epoch after the saved one.
#   "fresh"   -> keep the model constructed earlier in this cell.
#   otherwise -> ("continue") also keeps the model constructed earlier in this cell.
#                NOTE(review): `model` is re-created earlier in this cell on every run, so
#                "continue" currently behaves like "fresh" — confirm whether weights from a
#                previous execution were meant to be reused.
if config.training_mode == "resume":
    # Explicit checkpoint if configured, otherwise this run's best-checkpoint path.
    # Only the best checkpoint is saved by the training loop, so the default resumes after
    # the best epoch, which is not necessarily the last epoch that was trained.
    ckpt_path = config.resume_checkpoint or str(best_path)
    if ckpt_path and os.path.exists(ckpt_path):
        print(f"\n Resuming from checkpoint: {ckpt_path}")
        logger.info(f"Resuming from checkpoint: {ckpt_path}")
        # Deserialize onto the CPU; load_state_dict then copies values into the live objects.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt["model_state_dict"], strict=True)
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        # The scaler state is saved as None when AMP was disabled.
        if use_amp and ckpt.get("scaler_state_dict") is not None:
            scaler.load_state_dict(ckpt["scaler_state_dict"])
        # The saved "epoch" is the 0-based epoch that produced the checkpoint; resume with the next one.
        start_epoch = int(ckpt.get("epoch", 0)) + 1
        global_step = int(ckpt.get("global_step", 0))
        best_pck = float(ckpt.get("best_pck@0.10", best_pck))
        print(f"   start_epoch={start_epoch} | global_step={global_step} | best_pck@0.10={best_pck:.2f}")
        logger.info(f"Resume stats: start_epoch={start_epoch} global_step={global_step} best_pck@0.10={best_pck:.2f}")
    else:
        # A missing checkpoint is not fatal: training proceeds from the current in-memory weights.
        print(f" resume requested but checkpoint not found: {ckpt_path}")
        print("   Continuing from current weights in memory.")
        logger.warning(f"Resume requested but checkpoint not found: {ckpt_path}. Continuing.")
elif config.training_mode == "fresh":
    print("\n Training mode: fresh (starting from pretrained weights loaded by backbone extractor)")
    logger.info("Training mode: fresh")
else:
    print("\n Training mode: continue (keeping current weights in memory)")
    logger.info("Training mode: continue")

# Intra-epoch log checkpoints
def get_log_steps(n_batches: int, k: int):
    """Pick up to `k` batch indices spread evenly over [0, n_batches - 1].

    Returns an empty set when k <= 0 and {0} when there is at most one batch. Indices that
    round to the same value collapse, so fewer than `k` may be returned.
    """
    if k > 0 and n_batches > 1:
        positions = np.linspace(0, n_batches - 1, k)
        return {int(round(pos)) for pos in positions}
    return set() if k <= 0 else {0}

log_steps = get_log_steps(n_batches=len(train_loader), k=config.log_checkpoints_per_epoch)  # candidate intra-epoch eval steps


# Validation function using UnifiedEvaluator
@torch.no_grad()
def run_validation_unified(model_wrapper, num_samples=None):
    """Evaluate `model_wrapper` on the validation loader via the shared UnifiedEvaluator.

    Uses hard (non soft-argmax) matching, consistent with eval.ipynb.

    Args:
        model_wrapper: backbone wrapper accepted by CorrespondenceMatcher.
        num_samples: passed through to the evaluator; None evaluates everything it provides.

    Returns:
        (pck_010, results): overall keypoint PCK@0.10 as a percentage, and the full results
        dict returned by the evaluator.
    """
    results = evaluator.evaluate(
        matcher=CorrespondenceMatcher(model_wrapper, use_soft_argmax=False),
        backbone_name=f"finetuned_{config.backbone_name}",
        num_samples=num_samples,
        show_progress=False,
    )
    mean_pck_010 = results['overall_keypoint']['PCK@0.10']['mean']
    return mean_pck_010 * 100.0, results

# One epoch training
def _current_lr() -> float:
    """Return the latest scheduler LR, falling back to the optimizer's first param group."""
    try:
        return scheduler.get_last_lr()[0]
    except Exception:
        return optimizer.param_groups[0].get('lr', config.learning_rate)


def _intra_epoch_eval(epoch: int, step: int, n_batches: int, avg_train_loss: float, pbar) -> None:
    """Run a quick 10-sample validation mid-epoch and report it to the progress bar and W&B."""
    t_eval0 = time.time()
    model_wrapper = FinetunableBackboneAdapter(model)
    score, _ = run_validation_unified(model_wrapper, num_samples=10)
    t_eval = time.time() - t_eval0

    pbar.write(
        f"[Intra-epoch eval] epoch={epoch+1} step={step+1}/{n_batches} | "
        f"PCK@0.10={score:.2f} | {t_eval:.1f}s"
    )

    wandb.log({
        "training loss": avg_train_loss,
        "validation pck": score,
        "epoch": epoch + 1,
        "batch": step + 1,
        "learning_rate": _current_lr(),
    })

    model.train()  # ensure model is back in train mode after eval


def train_one_epoch(epoch: int):
    """Train `model` for one pass over `train_loader`.

    Uses AMP when enabled. Gradients are accumulated over `config.gradient_accumulation_steps`
    batches, with one extra optimizer update for the trailing partial window at the end of the
    epoch. Gradient-norm clipping is optional, and the LR scheduler is stepped once per
    optimizer update. Increments the module-level `global_step` once per optimizer update.

    Args:
        epoch: 0-based epoch index (used for display and logging only).

    Returns:
        Mean per-batch training loss over the epoch (0.0 if the loader is empty).
    """
    global global_step
    model.train()
    epoch_losses = []

    n_batches = len(train_loader)
    accum_steps = config.gradient_accumulation_steps
    refresh_every = max(1, config.tqdm_update_interval)  # guard against a 0 interval

    optimizer.zero_grad(set_to_none=True)
    pbar = tqdm(total=n_batches, desc=f"Train {epoch+1}/{config.num_epochs}", leave=False)

    accumulated = 0  # batches whose gradients are pending in the current accumulation window

    for step, batch in enumerate(train_loader):
        src = batch["src_img"].to(device, non_blocking=True)
        tgt = batch["tgt_img"].to(device, non_blocking=True)
        src_kps = batch["src_kps"].to(device, non_blocking=True)
        tgt_kps = batch["tgt_kps"].to(device, non_blocking=True)
        vm = batch["valid_mask"].to(device, non_blocking=True)

        # Windows are [0, k), [k, 2k), ... plus a shorter trailing window at the end of the
        # epoch. Average over the window's real length so the trailing update is not down-scaled.
        window_start = step - accumulated
        window_size = min(accum_steps, n_batches - window_start)

        with autocast(enabled=use_amp):
            src_feat = model(src)
            tgt_feat = model(tgt)
            batch_loss = criterion(src_feat, tgt_feat, src_kps, tgt_kps, vm, patch_size=model.stride)
            loss = batch_loss / window_size

        scaler.scale(loss).backward()
        accumulated += 1

        if accumulated == accum_steps or step + 1 == n_batches:
            if config.grad_clip_norm is not None and config.grad_clip_norm > 0:
                # Unscale first so the clip threshold applies to the true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

            scheduler.step()
            accumulated = 0
            global_step += 1

        epoch_losses.append(float(batch_loss.item()))

        if (step + 1) % refresh_every == 0 or (step + 1) == n_batches:
            pbar.update((step + 1) - pbar.n)
            lr_now = _current_lr()
            running_loss = np.mean(epoch_losses)
            pbar.set_postfix({"loss": f"{running_loss:.4f}", "lr": f"{lr_now:.2e}"})
            logger.info(f"epoch={epoch+1} step={step+1}/{n_batches} loss={running_loss:.4f} lr={lr_now:.2e} global_step={global_step}")

        if config.use_wandb and step > 0 and step in log_steps:
            avg_train_loss = float(np.mean(epoch_losses)) if epoch_losses else 0.0
            _intra_epoch_eval(epoch, step, n_batches, avg_train_loss, pbar)

    pbar.update(n_batches - pbar.n)
    pbar.close()
    return float(np.mean(epoch_losses)) if epoch_losses else 0.0

def _validation_plan(epoch: int):
    """Return (num_samples, is_full) for the end-of-epoch validation of 0-based `epoch`.

    - "slice":  always the capped slice (val_slice_max_batches).
    - "full":   always the full-validation budget (val_full_max_batches; None = whole val set).
    - "hybrid": full validation every `val_full_every` epochs and on the final epoch, and a
                slice otherwise. The final-epoch rule guarantees at least one full score.
    """
    if config.val_mode == "slice":
        return config.val_slice_max_batches, False
    if config.val_mode == "full":
        return config.val_full_max_batches, True
    is_full = (epoch + 1) % config.val_full_every == 0 or (epoch + 1) == config.num_epochs
    return (config.val_full_max_batches if is_full else config.val_slice_max_batches), is_full


print("\n Starting fine-tuning...\n")
logger.info("Starting fine-tuning loop")

_run_failed = True
try:
    for epoch in range(start_epoch, config.num_epochs):
        t0 = time.time()
        train_loss = train_one_epoch(epoch)
        dt = time.time() - t0

        model_wrapper = FinetunableBackboneAdapter(model)
        num_samples, is_full_val = _validation_plan(epoch)
        score, val_results = run_validation_unified(model_wrapper, num_samples=num_samples)

        print(f"\n Epoch {epoch+1}/{config.num_epochs} | train_loss={train_loss:.4f} | PCK@0.10={score:.2f} | {dt:.1f}s")
        logger.info(f"Epoch summary: epoch={epoch+1} train_loss={train_loss:.4f} PCK@0.10={score:.2f} time={dt:.1f}s")

        # TensorBoard
        if tb_writer is not None:
            tb_writer.add_scalar("train/loss", train_loss, epoch)
            tb_writer.add_scalar("val/pck@0.10", score, epoch)

        # wandb end-of-epoch logging
        if config.use_wandb:
            wandb.log({
                "training loss": train_loss,
                "validation pck": score,
                "epoch": epoch + 1,
                "learning_rate": scheduler.get_last_lr()[0]
            })

        # Save best checkpoint. Only compare scores measured on the same validation budget: in
        # hybrid mode a slice score (subset) is not comparable with a full-validation score.
        comparable = is_full_val or config.val_mode == "slice"
        if comparable and score > best_pck:
            best_pck = score
            save_ckpt(best_path, epoch, best_pck)
            print(f"Saved best checkpoint: {best_path} (PCK@0.10={best_pck:.2f})")
            logger.info(f"Saved best checkpoint: path={best_path} PCK@0.10={best_pck:.2f}")

    print("\nFine-tuning complete.")
    print(f"Best PCK@0.10: {best_pck:.2f}")
    logger.info(f"Fine-tuning complete. Best PCK@0.10={best_pck:.2f}")
    _run_failed = False
finally:
    # Always flush/close loggers, even if training raised or was interrupted, so TensorBoard
    # events are written and the W&B run is not left hanging (marked failed on error).
    if tb_writer is not None:
        tb_writer.close()

    if config.use_wandb:
        wandb.finish(exit_code=1 if _run_failed else 0)