 AML Semantic Correspondence


  GPU: Tesla V100-SXM2-16GB
   VRAM: 16.9 GB

 Setup complete!



In [5]:
# ============================================================================
# CELLA 2: Install Dependencies
# ============================================================================

print(" Installing dependencies...\n")

!pip install -q -r {LOCAL_REPO_NAME}/requirements.txt
!pip install -q tensorboard
!pip install -q wandb

# Performance settings (safe defaults for Colab)
try:
    torch.set_float32_matmul_precision('high')
except Exception:
    pass
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
print(f"\n PyTorch {torch.__version__}")
print(f" CUDA available: {torch.cuda.is_available()}")
print("\n Dependencies installed!\n")

 Installing dependencies...


 PyTorch 2.10.0+cu128
 CUDA available: True

 Dependencies installed!



In [None]:
# ============================================================================
# CELLA 2: Fine-tuning Configuration
# ============================================================================

@dataclass
class FinetuneConfig:
    """Hyper-parameters and run settings for the fine-tuning notebook.

    Every field has a default, so ``FinetuneConfig()`` yields a complete
    configuration. New fields are appended at the end so that positional
    construction keeps its original argument order.
    """
    # -------------------------
    # Model
    # -------------------------
    backbone_name: str = 'sam_vit_b'   # 'dinov2_vitb14', 'dinov3_vitb16', 'sam_vit_b/l/h'
    num_layers_to_finetune: int = 2       # last transformer blocks to unfreeze

    # -------------------------
    # Training
    # -------------------------
    num_epochs: int = 3
    learning_rate: float = 5e-6
    weight_decay: float = 1e-2
    warmup_epochs: int = 1
    grad_clip_norm: float = 1.0

    # Memory optimizations
    use_amp: bool = True
    gradient_accumulation_steps: int = 4
    use_gradient_checkpointing: bool = False   # implemented with best-effort (see model wrapper)

    # DataLoader
    batch_size: int = 1
    num_workers: int = 0                       # keep 0 if dataset is on Drive
    max_train_pairs: Optional[int] = None     # set None for full train
    max_val_pairs: int = 1000                   # keep validation fast

    # Loss
    loss_type: str = 'contrastive'             # 'cosine' | 'l2' | 'contrastive'
    temperature: float = 0.07

    # -------------------------
    # Evaluation modes
    # -------------------------
    val_mode: Literal["slice", "full", "hybrid"] = "hybrid"
    val_slice_max_batches: int = 25            # used in slice/hybrid
    val_full_max_batches: Optional[int] = None # None = full validation
    val_full_every: int = 1                    # run full val every N epochs (hybrid/full)

    # -------------------------
    # Logging / checkpoints
    # -------------------------
    use_tensorboard: bool = True
    tb_logdir: str = "runs/task2_finetune"

    use_wandb: bool = True
    wandb_entity: str = "luffy1"
    wandb_project: str = "AML-project-semantic-correspondence"

    log_checkpoints_per_epoch: int = 4         # intra-epoch logs
    run_name: str = ""                         # auto if empty

    # Training mode
    training_mode: Literal["fresh", "resume", "continue"] = "fresh"
    resume_checkpoint: str = ""                # used if training_mode == 'resume'

    # Directory in which the training cell stores checkpoints; that cell reads
    # `config.save_dir`, so the attribute must exist on the config.
    save_dir: str = "checkpoints"

# Build the configuration from its defaults.
config = FinetuneConfig()

# An empty run name is replaced by one derived from the backbone, the number of
# unfrozen layers and the learning rate; a non-empty name is kept as is.
config.run_name = config.run_name or (
    f"{config.backbone_name}_L{config.num_layers_to_finetune}_ls512_lr{config.learning_rate}"
)

# Align image size with main.ipynb
def get_img_size(backbone_name: str) -> int:
    """Return the image size used for ``backbone_name``.

    Names starting with ``"dinov2"`` map to 518; every other name maps to 512.
    """
    return 518 if backbone_name.startswith("dinov2") else 512

# get_img_size returns 512 for every name that does not start with "dinov2",
# so SAM variants can be passed through unchanged.
IMG_SIZE = get_img_size(config.backbone_name)

 Fine-tuning config:

  - backbone_name: sam_vit_b
  - num_layers_to_finetune: 2
  - num_epochs: 3
  - learning_rate: 5e-06
  - weight_decay: 0.01
  - warmup_epochs: 1
  - grad_clip_norm: 1.0
  - use_amp: True
  - gradient_accumulation_steps: 4
  - use_gradient_checkpointing: False
  - batch_size: 1
  - num_workers: 0
  - max_train_pairs: 20000
  - max_val_pairs: 500
  - max_test_pairs: None
  - loss_type: contrastive
  - temperature: 0.07
  - val_mode: hybrid
  - val_slice_max_batches: 25
  - val_full_max_batches: None
  - val_full_every: 1
  - do_test_eval: True
  - use_tensorboard: True
  - tb_logdir: runs/task2_finetune
  - use_wandb: False
  - wandb_entity: luffy1
  - wandb_project: AML-project-semantic-correspondence
  - log_checkpoints_per_epoch: 4
  - run_name: sam_vit_b_L2_ls512_lr5e-06
  - training_mode: fresh
  - resume_checkpoint: 

 Image resize target (square padding): 512x512


In [None]:
# ============================================================================
# CELLA 3: wandb init 
# ============================================================================

if not config.use_wandb:
    print("wandb disabled by config.use_wandb=False")
else:
    # Reduce the concrete backbone name to its model family for the run metadata.
    if config.backbone_name.startswith("sam"):
        _model_family = "sam"
    elif config.backbone_name.startswith("dinov2"):
        _model_family = "dinov2"
    else:
        _model_family = "dinov3"

    # Hyper-parameters recorded with the wandb run.
    _wandb_run_config = {
        "learning_rate": config.learning_rate,
        "epochs": config.num_epochs,
        "batch_size": config.batch_size,
        "num_layers_to_unfreeze": config.num_layers_to_finetune,
        "model_to_finetune": _model_family,
        "temperature": config.temperature,
        "loss_type": config.loss_type,
        "use_amp": config.use_amp,
        "gradient_accumulation_steps": config.gradient_accumulation_steps,
        "use_gradient_checkpointing": config.use_gradient_checkpointing,
        "val_mode": config.val_mode,
    }
    wandb.init(
        entity=config.wandb_entity,
        project=config.wandb_project,
        config=_wandb_run_config,
    )
    print("wandb run started.")

wandb disabled by config.use_wandb=False


In [None]:

# ============================================================================
# CELLA 4: Install SAM dependency
# ============================================================================
if 'sam' in config.backbone_name:
    print(f" Installing SAM dependency for {config.backbone_name}...")
    !pip install -q git+https://github.com/facebookresearch/segment-anything.git
    print(" SAM dependency installed.\n")
else:
    print(" SAM not selected ‚Äî skipping.\n")


 Installing SAM dependency for sam_vit_b...
 SAM dependency installed.



In [None]:
# ============================================================================
# CELLA 5: Loss functions for semantic correspondence fine-tuning
# ============================================================================

class CorrespondenceLoss(nn.Module):
    """Loss function for semantic correspondence fine-tuning.

    Supports multiple loss types:
    - 'cosine': Maximize cosine similarity between corresponding features
    - 'l2': Minimize L2 distance between corresponding features
    - 'contrastive': InfoNCE-style contrastive loss with negative sampling

    Args:
        loss_type: Type of loss ('cosine', 'l2', 'contrastive'). An unknown
            value raises ``ValueError`` when the first keypoint is processed.
        negative_margin: Stored on the module but not used by any of the
            implemented loss types; kept for interface compatibility.
        temperature: Temperature for contrastive loss
        num_negatives: Number of random target locations sampled as negatives
            for each keypoint in the contrastive loss.
    """

    def __init__(
        self,
        loss_type: str = 'cosine',
        negative_margin: float = 0.2,
        temperature: float = 0.1,
        num_negatives: int = 8
    ):
        super().__init__()
        if num_negatives < 1:
            raise ValueError(f"num_negatives must be >= 1, got {num_negatives}")
        self.loss_type = loss_type
        self.negative_margin = negative_margin
        self.temperature = temperature
        self.num_negatives = num_negatives

    @staticmethod
    def _to_patch_coords(kps: torch.Tensor, patch_size: int, width: int, height: int) -> torch.Tensor:
        """Convert (n, 2) pixel (x, y) keypoints to patch indices clamped to the feature grid."""
        patch_xy = (kps / patch_size).long()
        return torch.stack(
            [patch_xy[:, 0].clamp(0, width - 1), patch_xy[:, 1].clamp(0, height - 1)],
            dim=1,
        )

    def _point_loss(self, src_vec: torch.Tensor, tgt_vec: torch.Tensor, tgt_flat: torch.Tensor) -> torch.Tensor:
        """Return the 0-dim loss for one keypoint pair.

        Args:
            src_vec: (D,) source feature at the keypoint.
            tgt_vec: (D,) target feature at the matching keypoint.
            tgt_flat: (H_t * W_t, D) all target features, used to draw negatives.

        Raises:
            ValueError: if ``self.loss_type`` is not one of the supported types.
        """
        if self.loss_type == 'cosine':
            similarity = F.cosine_similarity(
                src_vec.unsqueeze(0), tgt_vec.unsqueeze(0), dim=1
            )
            return (1.0 - similarity).squeeze(0)

        if self.loss_type == 'l2':
            return F.mse_loss(src_vec, tgt_vec)

        if self.loss_type == 'contrastive':
            # Positive logit: similarity with the corresponding target point.
            pos_logit = F.cosine_similarity(
                src_vec.unsqueeze(0), tgt_vec.unsqueeze(0), dim=1
            ) / self.temperature  # (1,)

            # Negative logits: similarities with random target locations.
            neg_indices = torch.randint(
                0, tgt_flat.shape[0], (self.num_negatives,), device=src_vec.device
            )
            neg_logits = F.cosine_similarity(
                src_vec.unsqueeze(0).expand(self.num_negatives, -1),
                tgt_flat[neg_indices],
                dim=1
            ) / self.temperature  # (num_negatives,)

            # -log(exp(pos) / (exp(pos) + sum(exp(neg)))) written with logsumexp,
            # which stays finite where the raw exponentials would overflow.
            logits = torch.cat([pos_logit, neg_logits])
            return torch.logsumexp(logits, dim=0) - pos_logit.squeeze(0)

        raise ValueError(f"Unknown loss type: {self.loss_type}")

    def forward(
        self,
        src_features: torch.Tensor,
        tgt_features: torch.Tensor,
        src_kps: torch.Tensor,
        tgt_kps: torch.Tensor,
        valid_mask: torch.Tensor,
        patch_size: int
    ) -> torch.Tensor:
        """Compute correspondence loss.

        Args:
            src_features: (B, H_s, W_s, D)
            tgt_features: (B, H_t, W_t, D)
            src_kps: (B, N, 2) source keypoints in pixel coords
            tgt_kps: (B, N, 2) target keypoints in pixel coords
            valid_mask: (B, N) valid keypoint mask; non-bool masks are
                interpreted as "non-zero means valid"
            patch_size: Patch size of the backbone

        Returns:
            loss: 0-dim tensor holding the mean loss over all valid keypoints
            of the batch. When the batch has no valid keypoint the result is a
            zero that is still connected to the input features, so a backward
            pass produces (zero) gradients instead of none.

        Raises:
            ValueError: if ``loss_type`` is unknown and at least one valid
            keypoint is present.
        """
        point_losses = []

        for b in range(src_features.shape[0]):
            src_feat = src_features[b]  # (H_s, W_s, D)
            tgt_feat = tgt_features[b]  # (H_t, W_t, D)
            # Cast so that integer/float masks select rows instead of indexing them.
            valid = valid_mask[b].bool()  # (N,)

            if not valid.any():
                continue

            H_s, W_s, D = src_feat.shape
            H_t, W_t, _ = tgt_feat.shape

            # Pixel coords -> clamped patch coords for the valid keypoints only.
            src_xy = self._to_patch_coords(src_kps[b][valid], patch_size, W_s, H_s)
            tgt_xy = self._to_patch_coords(tgt_kps[b][valid], patch_size, W_t, H_t)

            # Flattened target grid, shared by all keypoints of this sample.
            tgt_flat = tgt_feat.reshape(-1, D)

            for (src_x, src_y), (tgt_x, tgt_y) in zip(src_xy, tgt_xy):
                src_vec = src_feat[src_y, src_x]  # (D,)
                tgt_vec = tgt_feat[tgt_y, tgt_x]  # (D,)
                point_losses.append(self._point_loss(src_vec, tgt_vec, tgt_flat))

        if not point_losses:
            zero = (src_features.sum() + tgt_features.sum()) * 0.0
            if not zero.requires_grad:
                # Inputs carry no graph: fall back to a differentiable leaf so
                # callers that invoke backward() on the result keep working.
                zero = torch.zeros((), device=src_features.device, requires_grad=True)
            return zero

        return torch.stack(point_losses).mean()

In [None]:
# ============================================================================
# CELLA 6: Trainable Backbone Wrapper (DINOv2 / DINOv3 / SAM) + Loss
# ============================================================================


class FinetunableBackbone(nn.Module):
    """
    Unified wrapper:
    - Loads extractor (DINOv2 / DINOv3 / SAM)
    - Freezes all params
    - Unfreezes last N transformer blocks
    - Optionally enables gradient checkpointing (best-effort, depending on backbone impl)
    - Returns features as (B, H, W, D)

    Args:
        backbone_name: name starting with "dinov2", "dinov3" or "sam"; any
            other value raises ``ValueError``.
        num_layers_to_finetune: number of trailing transformer blocks whose
            parameters are made trainable (<= 0 keeps everything frozen).
        device: device passed to the extractor and used for the probe forward.
        use_gradient_checkpointing: try to switch on gradient checkpointing.
    """
    def __init__(
        self,
        backbone_name: str,
        num_layers_to_finetune: int,
        device: str,
        use_gradient_checkpointing: bool = False,
    ):
        super().__init__()
        self.backbone_name = backbone_name
        self.num_layers_to_finetune = num_layers_to_finetune
        self.device = device

        # -------------------------
        # Load extractor
        # -------------------------
        if backbone_name.startswith("dinov2"):
            self.extractor = DINOv2Extractor(variant=backbone_name, device=device)
            self._model_for_unfreeze = self.extractor.model
        elif backbone_name.startswith("dinov3"):
            self.extractor = DINOv3Extractor(variant=backbone_name, device=device)
            self._model_for_unfreeze = self.extractor.model
        elif backbone_name.startswith("sam"):
            variant = backbone_name.replace("sam_", "")  # vit_b / vit_l / vit_h
            self.extractor = SAMImageEncoder(variant=variant, device=device, allow_hub_download=True)
            self._model_for_unfreeze = self.extractor.model.image_encoder
        else:
            raise ValueError(f"Unsupported backbone_name: {backbone_name}")

        self.stride = self.extractor.stride

        # Freeze all
        for p in self.extractor.parameters():
            p.requires_grad = False

        # Unfreeze last blocks
        self._unfreeze_last_blocks(num_layers_to_finetune)

        # Enable gradient checkpointing (best-effort)
        self._enable_gradient_checkpointing(use_gradient_checkpointing)

        # Infer feature dim with a tiny forward
        with torch.no_grad():
            dummy = torch.zeros((1, 3, 224, 224), device=self.device)
            feat_map, _ = self.extractor.extract_feats(dummy)
            self.feat_dim = int(feat_map.shape[-1])

        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        n_total = sum(p.numel() for p in self.parameters())

        print("\n Model summary:")
        print(f"  Backbone: {backbone_name}")
        print(f"  Stride:   {self.stride}")
        print(f"  Feat dim: {self.feat_dim}")
        print(f"  Trainable params: {n_trainable:,} / {n_total:,} ({(n_trainable/n_total*100 if n_total else 0):.2f}%)\n")

    def _get_blocks(self):
        """Return the backbone's sequence of transformer blocks, or None if not found.

        Looks for ``.blocks`` first, then for ``.encoder.layers``.
        """
        m = self._model_for_unfreeze
        if hasattr(m, "blocks"):
            return m.blocks
        if hasattr(m, "encoder") and hasattr(m.encoder, "layers"):
            return m.encoder.layers
        return None

    def _unfreeze_last_blocks(self, num_layers: int):
        """Set ``requires_grad=True`` on the parameters of the last ``num_layers`` blocks.

        Raises:
            AttributeError: if ``num_layers > 0`` and no block sequence is found.
        """
        if num_layers <= 0:
            print("\n Unfreezing: none (frozen backbone)\n")
            return

        blocks = self._get_blocks()
        if blocks is None:
            raise AttributeError("Cannot find transformer blocks to unfreeze in the selected backbone.")

        total = len(blocks)
        start = max(0, total - num_layers)
        # Report the number of blocks actually unfrozen (num_layers may exceed total).
        print(f"\n Unfreezing last {total - start} blocks: [{start}..{total-1}] out of {total}\n")

        for i in range(start, total):
            for p in blocks[i].parameters():
                p.requires_grad = True

    def _enable_gradient_checkpointing(self, enabled: bool):
        """Try the known gradient-checkpointing switches on the backbone, in order.

        Stops at the first one that succeeds; failures are printed and the next
        switch is tried. Nothing is raised when no switch is available.
        """
        if not enabled:
            return

        m = self._model_for_unfreeze

        # (method name, call arguments, how the call is shown in the log)
        method_apis = (
            ("gradient_checkpointing_enable", (), ".gradient_checkpointing_enable()"),
            ("set_grad_checkpointing", (True,), ".set_grad_checkpointing(True)"),
            ("set_gradient_checkpointing", (True,), ".set_gradient_checkpointing(True)"),
        )

        tried = []
        for name, args, shown in method_apis:
            if not hasattr(m, name):
                continue
            tried.append(name)
            try:
                getattr(m, name)(*args)
                print(f"✓ Gradient checkpointing enabled via {shown}")
                return
            except Exception as e:
                print(f"! {name} failed: {e}")

        # Last resort: a plain boolean flag on the model.
        if hasattr(m, "use_checkpoint"):
            tried.append("use_checkpoint")
            try:
                m.use_checkpoint = True
                print("✓ Gradient checkpointing enabled via .use_checkpoint=True")
                return
            except Exception as e:
                print(f"! use_checkpoint flag failed: {e}")

        print("! Gradient checkpointing requested but no supported API was found on this backbone.")
        if tried:
            print("  Tried:", tried)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Return the feature map produced by the extractor for ``image``.

        The second value returned by ``extract_feats`` is discarded.
        """
        feat_map, _ = self.extractor.extract_feats(image)
        return feat_map

# Instantiate the loss selected in the config and move it to the training device.
criterion = CorrespondenceLoss(
    temperature=config.temperature,
    loss_type=config.loss_type,
)
criterion = criterion.to(device)
print(" Loss:", config.loss_type)

 Loss: contrastive


In [None]:
# ============================================================================
# CELLA 7: DataLoaders (train/val)
# ============================================================================

def maybe_subset(ds, max_pairs, seed=0):
    """Return ``ds`` itself, or a random subset of at most ``max_pairs`` items.

    Args:
        ds: dataset supporting ``len()`` and integer indexing.
        max_pairs: maximum number of items to keep; ``None`` returns ``ds`` unchanged.
        seed: seed of the private generator that draws the subset, so the same
            items are selected on every run (needed for comparable validation
            numbers across runs and resumes). ``None`` draws from torch's
            global RNG instead.

    Returns:
        ``ds`` when ``max_pairs`` is None, otherwise a ``torch.utils.data.Subset``.
    """
    if max_pairs is None:
        return ds
    generator = None
    if seed is not None:
        # A private generator keeps the draw reproducible without touching the global RNG state.
        generator = torch.Generator().manual_seed(seed)
    idx = torch.randperm(len(ds), generator=generator)[:max_pairs].tolist()
    return torch.utils.data.Subset(ds, idx)

# The two splits share every constructor argument except `split`.
_shared_dataset_args = dict(
    root=SPAIR_ROOT,
    size='large',
    long_side=IMG_SIZE,
    normalize=True,
    load_segmentation=False,
)
train_dataset_full = SPairDataset(split='train', **_shared_dataset_args)
val_dataset_full = SPairDataset(split='val', **_shared_dataset_args)

# Optionally cap the number of pairs per split (None keeps the full split).
train_ds = maybe_subset(train_dataset_full, config.max_train_pairs)
val_ds = maybe_subset(val_dataset_full, config.max_val_pairs)

# Only the training loader shuffles; everything else is identical for both loaders.
_shared_loader_args = dict(
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    pin_memory=torch.cuda.is_available(),
    drop_last=False,
)
train_loader = DataLoader(train_ds, shuffle=True, **_shared_loader_args)
val_loader = DataLoader(val_ds, shuffle=False, **_shared_loader_args)

print(f" Train pairs: {len(train_ds)} | batches: {len(train_loader)}")
print(f"   Val pairs: {len(val_ds)} | batches: {len(val_loader)}")


 SPAIR_ROOT: /home/jupyter/AML/dataset/Spair-71k
 Loaded 53340 pairs from train split (large)
 Loaded 5384 pairs from val split (large)
 Loaded 12234 pairs from test split (large)
 Train pairs: 20000 | batches: 20000
   Val pairs: 500 | batches: 500
  Test pairs: 12234 | batches: 12234


In [None]:
# ============================================================================
# CELLA 8: Evaluation helpers — Using UnifiedEvaluator + CorrespondenceMatcher
# ============================================================================

class FinetunableBackboneAdapter:
    """Adapter that lets a FinetunableBackbone be used with CorrespondenceMatcher."""

    def __init__(self, finetunable_backbone):
        self.model = finetunable_backbone
        self.device = finetunable_backbone.device
        # Minimal config object exposing only `patch_size` (the backbone stride).
        config_cls = type('obj', (object,), {'patch_size': finetunable_backbone.stride})
        self.config = config_cls()

    @torch.no_grad()
    def extract_features(self, img):
        """Run the backbone and return its output, squeezed on dim 0, as a 1-tuple.

        For a batch of one this turns (B, H, W, D) into (H, W, D); the tuple
        wrapper is the return form expected by the consumer of this adapter.
        """
        batched_feat = self.model(img)  # (B, H, W, D)
        single_feat = batched_feat.squeeze(0)  # (H, W, D)
        return (single_feat,)

# Evaluator bound to the validation loader, configured with four thresholds.
evaluator = UnifiedEvaluator(
    thresholds=[0.05, 0.10, 0.15, 0.20],
    device=device,
    dataloader=val_loader,
)

print(" Eval helpers ready (using UnifiedEvaluator + CorrespondenceMatcher).")

 Eval helpers ready (using UnifiedEvaluator + CorrespondenceMatcher).


In [None]:
# ============================================================================
# CELLA 10: Training Loop (AMP + Warmup + Grad Clip + TB + Resume + Val modes)
# ============================================================================
import os
import time
from pathlib import Path
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import LinearLR, SequentialLR, CosineAnnealingLR

# -------------------------
# Build finetunable model
# -------------------------
# Wrap the selected backbone; only its last blocks are left trainable.
model = FinetunableBackbone(
    backbone_name=config.backbone_name,
    num_layers_to_finetune=config.num_layers_to_finetune,
    device=device,
    use_gradient_checkpointing=config.use_gradient_checkpointing,
)
model = model.to(device)

# -------------------------
# Optimizer (over the trainable parameters only)
# -------------------------
trainable_params = list(filter(lambda p: p.requires_grad, model.parameters()))
if not trainable_params:
    raise RuntimeError("No trainable parameters found. Check num_layers_to_finetune > 0.")

optimizer = optim.AdamW(
    trainable_params,
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
)

# -------------------------
# Scheduler: warmup -> cosine (step-based, matches main.ipynb style)
# -------------------------
# The training loop steps the optimizer/scheduler once every
# `gradient_accumulation_steps` batches AND on the last batch of each epoch,
# i.e. ceil(len(train_loader) / gradient_accumulation_steps) times per epoch.
# Counting with a floor over the whole run would undercount and push the cosine
# schedule past its T_max.
updates_per_epoch = -(-len(train_loader) // config.gradient_accumulation_steps)  # ceil division
total_updates = max(1, updates_per_epoch * config.num_epochs)
warmup_updates = max(1, updates_per_epoch * config.warmup_epochs)

# Linear warmup from 10% of the base LR, followed by cosine annealing over the remaining updates.
warmup = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_updates)
cosine = CosineAnnealingLR(optimizer, T_max=max(1, total_updates - warmup_updates))
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_updates])

# AMP
# Mixed precision is active only when requested and a CUDA device is present.
use_amp = config.use_amp and torch.cuda.is_available()
use_amp = bool(use_amp)
scaler = GradScaler(enabled=use_amp)

print(f" AMP: {use_amp} | Grad accumulation: {config.gradient_accumulation_steps} | Grad checkpointing: {config.use_gradient_checkpointing}")

# TensorBoard: one log directory per run name
tb_writer = None
if config.use_tensorboard:
    from torch.utils.tensorboard import SummaryWriter
    logdir = Path(config.tb_logdir, config.run_name)
    logdir.mkdir(parents=True, exist_ok=True)
    tb_writer = SummaryWriter(log_dir=str(logdir))
    print(f" TensorBoard logdir: {logdir}")

# -------------------------
# Checkpointing (complete state for resume)
# -------------------------
best_pck, start_epoch, global_step = -1.0, 0, 0
best_path = Path(config.save_dir, f"finetuned_{config.run_name}.pt")
best_path.parent.mkdir(parents=True, exist_ok=True)

def save_ckpt(path: Path, epoch: int, best_pck_val: float):
    """Write a checkpoint with everything needed to resume training to ``path``.

    Args:
        path: destination file.
        epoch: epoch index stored under "epoch".
        best_pck_val: value stored under "best_pck@0.10".

    The model, optimizer, scheduler, scaler, config and global step are read
    from the notebook's globals.
    """
    # The baseline metrics are not defined in this cell; store None when they
    # are absent so a finished epoch is not lost to a NameError at save time.
    notebook_globals = globals()
    torch.save({
        "epoch": epoch,
        "global_step": global_step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict() if use_amp else None,
        "best_pck@0.10": best_pck_val,
        "config": config.__dict__,
        "baseline_val_slice_pck": notebook_globals.get("baseline_pck"),
        "baseline_val_slice_loss": notebook_globals.get("baseline_vloss"),
    }, str(path))

# Select how training starts. Any value other than "resume" or "fresh" is
# treated as "continue".
if config.training_mode == "resume":
    # An explicit checkpoint path wins; otherwise fall back to this run's best checkpoint.
    ckpt_path = config.resume_checkpoint or str(best_path)
    if ckpt_path and os.path.exists(ckpt_path):
        print(f"\n🔧 Resuming from checkpoint: {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt["model_state_dict"], strict=True)
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        if use_amp and ckpt.get("scaler_state_dict") is not None:
            scaler.load_state_dict(ckpt["scaler_state_dict"])
        # The checkpoint stores the epoch it was saved after, so training continues with the next one.
        start_epoch = int(ckpt.get("epoch", 0)) + 1
        global_step = int(ckpt.get("global_step", 0))
        best_pck = float(ckpt.get("best_pck@0.10", best_pck))
        print(f"   ✓ start_epoch={start_epoch} | global_step={global_step} | best_pck@0.10={best_pck:.2f}")
    else:
        print(f"⚠️ resume requested but checkpoint not found: {ckpt_path}")
        print("   Continuing from current weights in memory.")
elif config.training_mode == "fresh":
    print("\n🔧 Training mode: fresh (starting from pretrained weights loaded by backbone extractor)")
else:
    # NOTE(review): `model` is rebuilt at the top of this cell, so the weights
    # "in memory" are those of the freshly constructed model — confirm intent.
    print("\n🔧 Training mode: continue (keeping current weights in memory)")

# -------------------------
# Intra-epoch log checkpoints
# -------------------------
def get_log_steps(n_batches: int, k: int):
    """Return a set of up to ``k`` batch indices spread evenly over ``[0, n_batches - 1]``.

    ``k <= 0`` yields an empty set; with at most one batch the result is ``{0}``.
    Indices are the rounded values of an evenly spaced grid, so duplicates collapse.
    """
    if k <= 0:
        return set()
    last_index = n_batches - 1
    if last_index <= 0:
        return {0}
    positions = np.linspace(0, last_index, k)
    return {int(round(pos)) for pos in positions}

# Batch indices within an epoch at which the training loop checks for intra-epoch logging.
log_steps = get_log_steps(len(train_loader), config.log_checkpoints_per_epoch)

# -------------------------
# Helper: evaluation with visible progress + final print
# -------------------------
@torch.no_grad()
def eval_metrics_visible(model, loader, alphas=(0.05, 0.1, 0.2), max_batches=None, compute_loss=True, desc="Eval"):
    """Evaluate ``model`` on ``loader`` and return ``(pck_by_alpha, mean_loss)``.

    Same as eval_metrics, but the progress bar total matches ``max_batches`` so
    it does not look like evaluation "stops early".

    Args:
        model: callable mapping an image batch to features; must expose ``stride``.
        loader: iterable of batches with keys "src_img", "tgt_img", "src_kps",
            "tgt_kps" and "valid_mask".
        alphas: PCK thresholds to report.
        max_batches: stop after this many batches (None = whole loader).
        compute_loss: also evaluate the global ``criterion`` on every batch.
        desc: progress-bar label.

    Returns:
        A dict mapping each alpha to a percentage (0.0 when no valid keypoint
        was seen) and the mean batch loss (None when no loss was computed).
        The model is left in eval mode.
    """
    model.eval()
    n_correct = {alpha: 0 for alpha in alphas}
    n_points = {alpha: 0 for alpha in alphas}
    batch_losses = []

    n_loader_batches = len(loader)
    if max_batches is None:
        bar_total = n_loader_batches
    else:
        bar_total = min(n_loader_batches, max_batches)

    progress = tqdm(loader, total=bar_total, desc=desc, leave=False)
    for batch_idx, batch in enumerate(progress):
        if max_batches is not None and batch_idx >= max_batches:
            break

        src_img = batch["src_img"].to(device)
        tgt_img = batch["tgt_img"].to(device)
        src_points = batch["src_kps"].to(device)
        tgt_points = batch["tgt_kps"].to(device)
        point_mask = batch["valid_mask"].to(device)

        src_feat = model(src_img)
        tgt_feat = model(tgt_img)

        if compute_loss:
            batch_loss = criterion(src_feat, tgt_feat, src_points, tgt_points, point_mask, patch_size=model.stride)
            batch_losses.append(float(batch_loss.item()))

        predicted = predict_argmax(src_feat, tgt_feat, src_points, point_mask, model.stride)

        predicted_np = predicted.cpu().numpy()
        target_np = tgt_points.cpu().numpy()
        mask_np = point_mask.cpu().numpy()

        for sample in range(predicted_np.shape[0]):
            for alpha in alphas:
                pck, n_valid = compute_pck(predicted_np[sample], target_np[sample], mask_np[sample], alpha=alpha)
                # compute_pck reports a percentage; convert it back to a count so
                # samples are pooled per keypoint rather than averaged per sample.
                n_correct[alpha] += int(round(pck * n_valid / 100.0))
                n_points[alpha] += n_valid

    pck_by_alpha = {}
    for alpha in alphas:
        seen = n_points[alpha]
        if seen == 0:
            pck_by_alpha[alpha] = 0.0
        else:
            pck_by_alpha[alpha] = (n_correct[alpha] / seen) * 100.0

    if compute_loss and len(batch_losses) > 0:
        mean_loss = float(np.mean(batch_losses))
    else:
        mean_loss = None
    return pck_by_alpha, mean_loss

# -------------------------
# One epoch
# -------------------------
def train_one_epoch(epoch: int):
    """Train for one pass over ``train_loader`` and return the mean batch loss.

    Uses the notebook-level ``model``, ``optimizer``, ``scheduler``, ``scaler``,
    ``criterion`` and ``config``. Gradients are accumulated over
    ``config.gradient_accumulation_steps`` batches (the last batch of the epoch
    always triggers an update) and ``global_step`` is incremented per update.
    When wandb is enabled, a validation slice is evaluated and logged at the
    batch indices in ``log_steps`` (except index 0).

    Args:
        epoch: zero-based epoch index, used for progress display and logging.

    Returns:
        Mean of the per-batch (un-normalised) training losses, 0.0 if the loader is empty.
    """
    global global_step
    model.train()
    epoch_losses = []

    optimizer.zero_grad(set_to_none=True)
    pbar = tqdm(train_loader, desc=f"Train {epoch+1}/{config.num_epochs}", leave=False)

    accumulated = 0

    for step, batch in enumerate(pbar):
        src = batch["src_img"].to(device, non_blocking=True)
        tgt = batch["tgt_img"].to(device, non_blocking=True)
        src_kps = batch["src_kps"].to(device, non_blocking=True)
        tgt_kps = batch["tgt_kps"].to(device, non_blocking=True)
        vm = batch["valid_mask"].to(device, non_blocking=True)

        with autocast(enabled=use_amp):
            src_feat = model(src)
            tgt_feat = model(tgt)
            loss = criterion(src_feat, tgt_feat, src_kps, tgt_kps, vm, patch_size=model.stride)
            # Normalise so that the accumulated gradient is an average over the window.
            loss = loss / config.gradient_accumulation_steps

        scaler.scale(loss).backward()
        accumulated += 1

        do_update = (accumulated == config.gradient_accumulation_steps) or (step + 1 == len(train_loader))

        if do_update:
            if config.grad_clip_norm is not None and config.grad_clip_norm > 0:
                # Gradients must be unscaled before clipping on their true norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm)

            scale_before = scaler.get_scale()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

            # GradScaler lowers its scale exactly when it skipped optimizer.step()
            # because of inf/NaN gradients; do not advance the LR schedule for a
            # step that never happened. With AMP disabled the scale is constant.
            if scaler.get_scale() >= scale_before:
                scheduler.step()
            accumulated = 0
            global_step += 1

        # Undo the accumulation normalisation for reporting.
        epoch_losses.append(float(loss.item() * config.gradient_accumulation_steps))
        pbar.set_postfix({"loss": f"{np.mean(epoch_losses):.4f}", "lr": f"{scheduler.get_last_lr()[0]:.2e}"})

        # ---- Intra-epoch logging + VISIBLE eval summary
        if config.use_wandb and (step in log_steps) and (step > 0):
            avg_train_loss = float(np.mean(epoch_losses)) if len(epoch_losses) else 0.0

            t_eval0 = time.time()
            vpck, vloss = eval_metrics_visible(
                model,
                val_loader,
                alphas=(0.05, 0.10, 0.20),
                max_batches=config.val_slice_max_batches,
                compute_loss=True,
                desc="Eval (val-slice)"
            )
            t_eval = time.time() - t_eval0

            # The evaluator returns None when no loss was computed; format defensively.
            vloss_text = "n/a" if vloss is None else f"{vloss:.4f}"

            # Print a compact progress line so you see the result even if tqdm rendering is weird
            print(
                f"[Intra-epoch eval] epoch={epoch+1} step={step+1}/{len(train_loader)} | "
                f"val_loss={vloss_text} | PCK@0.10={vpck.get(0.10, 0.0):.2f} | {t_eval:.1f}s"
            )

            wandb.log({
                "training loss": avg_train_loss,
                "validation loss": vloss if vloss is not None else 0.0,
                "validation pck": vpck.get(0.10, 0.0),
                "epoch": epoch + 1,
                "batch": step + 1,
                "learning_rate": scheduler.get_last_lr()[0]
            })

            # Evaluation switched the model to eval mode; restore training mode.
            model.train()

    return float(np.mean(epoch_losses)) if len(epoch_losses) else 0.0

print("\n Starting fine-tuning...\n")

# try/finally so the TensorBoard writer is flushed and the wandb run is closed
# even when training is interrupted or fails part-way.
try:
    for epoch in range(start_epoch, config.num_epochs):
        t0 = time.time()
        train_loss = train_one_epoch(epoch)
        dt = time.time() - t0

        # Validation per config.val_mode
        val_out = run_validation(model, config.val_mode, epoch)

        # Choose checkpoint selection metric: prefer FULL when available, otherwise slice
        if config.val_mode in ("full", "slice"):
            score = val_out["pck"][0.10]
            val_loss = val_out["loss"]
        else:  # hybrid
            if "pck_full" in val_out:
                score = val_out["pck_full"][0.10]
                val_loss = val_out["loss_full"]
            else:
                score = val_out["pck"][0.10]
                val_loss = val_out["loss"]

        # val_loss is treated as optional below, so format it defensively here as well.
        val_loss_text = "n/a" if val_loss is None else f"{val_loss:.4f}"
        print(f"\n Epoch {epoch+1}/{config.num_epochs} | train_loss={train_loss:.4f} | val_loss={val_loss_text} | PCK@0.10={score:.2f} | {dt:.1f}s")

        # TensorBoard
        if tb_writer is not None:
            tb_writer.add_scalar("train/loss", train_loss, epoch)
            tb_writer.add_scalar("val/pck@0.10", score, epoch)
            if val_loss is not None:
                tb_writer.add_scalar("val/loss", val_loss, epoch)

        # wandb end-of-epoch logging (main.ipynb style)
        if config.use_wandb:
            wandb.log({
                "training loss": train_loss,
                "validation loss": val_loss if val_loss is not None else 0.0,
                "validation pck": score,
                "epoch": epoch + 1,
                "batch": len(train_loader),
                "learning_rate": scheduler.get_last_lr()[0]
            })

        # Save best checkpoint
        if score > best_pck:
            best_pck = score
            save_ckpt(best_path, epoch, best_pck)
            print(f"  ✔ Saved best checkpoint: {best_path} (PCK@0.10={best_pck:.2f})")

    print("\n Fine-tuning done.")
    print(" Best PCK@0.10:", best_pck)
finally:
    if tb_writer is not None:
        tb_writer.close()

    # finish wandb run (if needed)
    if config.use_wandb:
        wandb.finish()