## Data Preprocessing

In [1]:
import os, torch
# Debugging aids for locating the failing op in a CUDA / autograd error.
# NOTE(review): both settings are generally understood to slow execution — confirm and
# consider disabling them for real training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # make CUDA kernel launches synchronous so the Python stack points at the real call
torch.autograd.set_detect_anomaly(True)   # have autograd report which forward op produced a failing backward

import urllib.request
import zipfile
import socket
from datasets import load_dataset

# Dataset bootstrap: fetch ImageNet-100 and ADE20K once, then reuse the local copies.
# The existence of each target directory is the only "already downloaded" marker.

imagenet_path = os.path.join('.', 'imagenet-100')
ade_path = os.path.join('.', 'ADEChallengeData2016')

print("Loading datasets...")

# ImageNet-100: pulled through `datasets.load_dataset` and persisted with `save_to_disk`
# so later cells can reopen it with `load_from_disk`.
if os.path.exists(imagenet_path):
    print("ImageNet-100 already exists. Skipping download.")
else:
    print("Downloading ImageNet-100...")
    try:
        img100 = load_dataset("clane9/imagenet-100")
        img100.save_to_disk(imagenet_path)
        print("ImageNet-100 saved to ./imagenet-100/")
        print(f"Train samples: {len(img100['train'])}")
        print(f"Val samples: {len(img100['validation'])}")
    except Exception as e:
        # Best-effort: report and keep going so the rest of the cell still runs.
        print(f"ImageNet-100 failed: {e}")

# ADE20K: downloaded as a zip archive and extracted into the working directory.
if os.path.exists(ade_path):
    print("ADE20K already exists. Skipping download.")
else:
    print("Downloading ADE20K...")
    zip_path = "ADEChallengeData2016.zip"
    try:
        # Process-wide socket timeout so a stalled connection cannot hang the cell forever.
        socket.setdefaulttimeout(60)
        # NOTE(review): plain-HTTP download with no checksum; consider verifying the archive.
        url = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
        urllib.request.urlretrieve(url, zip_path)

        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall('.')
        print("ADE20K downloaded")
    except Exception as e:
        print(f"ADE20K download failed: {e}")
    finally:
        # Always drop the archive, including after a failed/partial download, so a broken
        # zip is never left behind in the working directory.
        try:
            os.remove(zip_path)
        except FileNotFoundError:
            pass
        except OSError as e:
            print(f"Could not remove {zip_path}: {e}")

print("Ready.")

Loading datasets...
ImageNet-100 already exists. Skipping download.
ADE20K already exists. Skipping download.
Ready.


In [2]:
import os
import glob
import numpy as np
import torch
import random
from PIL import Image
from datasets import load_from_disk
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from math import ceil

# Reproducibility: seed the torch, Python `random`, and NumPy global RNGs with one value.
seed = 1337
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# Square image side length in pixels.
# NOTE(review): the dataset classes in this cell default to img_size=128 as well but do not
# read this constant — confirm callers pass IMG_SIZE explicitly where it matters.
IMG_SIZE = 128

# =============================================================================
# JEPA ImageNet Dataset for Pretraining
# =============================================================================
class JEPADataset(Dataset):
    """Image-only dataset for JEPA pretraining.

    Opens one split of a dataset previously stored with `save_to_disk` and serves
    randomly augmented, normalised image tensors. Labels are not returned.

    Args:
        root: directory passed to `load_from_disk`.
        split: key of the split to read from the loaded dataset dict.
        img_size: output side length handed to `RandomResizedCrop`.
    """

    def __init__(self, root="./imagenet-100", split="train", img_size=128):
        self.dataset = load_from_disk(root)[split]
        self.img_size = img_size

        # Augmentation pipeline: random crop (20%-100% of the area) resized to img_size,
        # random horizontal flip, tensor conversion, then per-channel normalisation.
        augmentation_steps = [
            transforms.RandomResizedCrop(img_size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(augmentation_steps)

    def __len__(self):
        """Number of records in the selected split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return `{"image": tensor}` for record `idx` after RGB conversion and augmentation."""
        record = self.dataset[idx]
        rgb_image = record["image"].convert("RGB")
        return {"image": self.transform(rgb_image)}

# =============================================================================
# ADE20K Dataset for Segmentation
# =============================================================================
class ADE20KDataset(Dataset):
    """Paired image / annotation dataset read from an ADEChallengeData2016 folder.

    Images are taken from `<root>/images/<split>/*.jpg`; an image is kept only when
    `<root>/annotations/<split>/<same stem>.png` exists. Pairs are stored sorted.

    Args:
        root: dataset root directory.
        split: sub-folder name under `images/` and `annotations/`.
        img_size: side length both image and mask are resized to.
    """

    def __init__(self, root="ADEChallengeData2016", split="training", img_size=128):
        self.img_dir = os.path.join(root, "images", split)
        self.ann_dir = os.path.join(root, "annotations", split)
        self.img_size = img_size

        # Pair every JPEG with the PNG that shares its file stem; drop unpaired images.
        jpeg_paths = glob.glob(os.path.join(self.img_dir, "*.jpg"))
        candidate_pairs = (
            (jpeg, os.path.join(self.ann_dir, os.path.splitext(os.path.basename(jpeg))[0] + ".png"))
            for jpeg in jpeg_paths
        )
        self.items = sorted(pair for pair in candidate_pairs if os.path.exists(pair[1]))

        # Deterministic preprocessing: bilinear resize, tensor conversion, normalisation.
        preprocessing_steps = [
            transforms.Resize((img_size, img_size), interpolation=Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        self.img_transform = transforms.Compose(preprocessing_steps)

    def __len__(self):
        """Number of (image, annotation) pairs found."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return `(image_tensor, mask_tensor)`; the mask is an int64 tensor of size img_size x img_size."""
        image_file, annotation_file = self.items[idx]

        rgb_image = Image.open(image_file).convert("RGB")
        annotation = Image.open(annotation_file)

        image_tensor = self.img_transform(rgb_image)
        # Nearest-neighbour resampling so resized pixels keep original label values.
        side = self.img_size
        annotation = annotation.resize((side, side), resample=Image.NEAREST)
        mask_tensor = torch.from_numpy(np.array(annotation, dtype="int64"))

        return image_tensor, mask_tensor

def jepa_collate(batch):
    """Stack the "image" tensor of every sample into one batch under the key "images"."""
    per_sample_images = [sample["image"] for sample in batch]
    stacked = torch.stack(per_sample_images)
    return {"images": stacked}

def ade_collate(batch):
    """Turn a list of `(image, mask)` pairs into `{"images": ..., "masks": ...}` batch tensors."""
    image_seq, mask_seq = zip(*batch)
    collated = {}
    collated["images"] = torch.stack(image_seq)
    collated["masks"] = torch.stack(mask_seq)
    return collated

## Setup Dataloader

## Helper functions

In [3]:
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from typing import Optional
import numpy as np
import matplotlib.pyplot as plt


def compute_patch_grid(image_shape, patch_size):
    """
    Work out how many whole patches fit into an image.

    image_shape: (C, H, W); the channel entry is ignored.
    patch_size:  side length of one square patch.

    Returns (n_h, n_w, P, cropped_shape) where n_h / n_w are the patch counts along
    height / width (floor division), P = n_h * n_w, and cropped_shape is the
    (height, width) covered by those whole patches.
    """
    _channels, height, width = image_shape
    rows, cols = height // patch_size, width // patch_size
    covered_hw = (rows * patch_size, cols * patch_size)
    return rows, cols, rows * cols, covered_hw


def extract_patch_embeddings_from_feature_map(feats: torch.Tensor) -> torch.Tensor:
    """
    Flatten a channel-first feature map into a sequence of patch embeddings.

    feats: [N, D, n_h, n_w]
    returns: [N, P, D] with P = n_h * n_w, patches ordered row-major over (n_h, n_w).
    """
    batch, dim, _rows, _cols = feats.shape
    # Channels go last first, so each spatial location becomes one D-vector.
    channels_last = feats.permute(0, 2, 3, 1)
    return channels_last.reshape(batch, -1, dim)


def compute_denoising_loss(self, denoised_prediction, original_input):
    """
    MSE between a denoised prediction and the input resized to the prediction's size.

    `self` is accepted but never used by the body. `original_input` is bilinearly
    resampled (align_corners=False) to the last two dimensions of
    `denoised_prediction` before the mean-squared error is taken.
    """
    prediction_hw = denoised_prediction.shape[-2:]
    resized_target = F.interpolate(
        original_input,
        size=prediction_hw,
        mode='bilinear',
        align_corners=False,
    )
    loss = F.mse_loss(denoised_prediction, resized_target)
    return loss

def compute_reconstruction_loss(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Mean-squared error between predicted and target embeddings, averaged over all elements.
    """
    mean_squared_error = F.mse_loss(preds, targets, reduction="mean")
    return mean_squared_error

@torch.no_grad()
def update_ema(target_net: nn.Module, online_net: nn.Module, tau: float):
    """
    In-place exponential-moving-average step: target <- tau * target + (1 - tau) * online.

    Parameters are paired positionally via `parameters()` of both networks; only the
    target network is modified. Buffers are not touched.
    """
    online_weight = 1 - tau
    paired_params = zip(target_net.parameters(), online_net.parameters())
    for target_param, online_param in paired_params:
        target_param.data.mul_(tau)
        target_param.data.add_(online_param.data, alpha=online_weight)

def unpatchify_embeddings(emb: torch.Tensor, n_h: int, n_w: int) -> torch.Tensor:
    """
    Fold a patch sequence back into a channel-first 2D feature map.

    emb: [N, P, D] with P == n_h * n_w (patches in row-major order)
    returns: contiguous [N, D, n_h, n_w]
    """
    batch, _num_patches, dim = emb.shape
    grid = emb.view(batch, n_h, n_w, dim)
    channels_first = grid.permute(0, 3, 1, 2)
    return channels_first.contiguous()

def generate_fi1_mask(fi1_shape: tuple, mask_ratio: float = 0.5, patch_size: int = 8, device='cuda'):
    """
    Draw a random block mask over an Fi1 feature map.

    The H8 x W8 grid is tiled with non-overlapping `patch_size` x `patch_size` blocks
    (only whole blocks; any remainder rows/columns are never masked). For every batch
    element, `int(mask_ratio * total_blocks)` blocks are chosen with `torch.randperm`
    and all positions inside them are set to True.

    Args:
        fi1_shape: (B, D, H8, W8); D is ignored.
        mask_ratio: fraction of blocks to mask per sample.
        patch_size: block side length, in Fi1 grid cells.
        device: device for the mask and the random permutation.

    Returns:
        Bool tensor [B, H8 * W8], row-major over (H8, W8); True marks masked positions.
    """
    B, _, H8, W8 = fi1_shape

    # Whole blocks only, e.g. a 16x16 map with patch_size=8 gives a 2x2 block grid.
    n_patches_h = H8 // patch_size
    n_patches_w = W8 // patch_size
    total_patches = n_patches_h * n_patches_w

    num_masked = int(mask_ratio * total_patches)

    fi1_mask = torch.zeros(B, H8 * W8, dtype=torch.bool, device=device)
    # 2D view onto the same storage so a whole block is written with one slice assignment
    # instead of one indexed write per grid cell.
    mask_grid = fi1_mask.view(B, H8, W8)

    for b in range(B):
        # One permutation per sample, drawn in batch order (keeps the RNG stream unchanged).
        masked_patch_ids = torch.randperm(total_patches, device=device)[:num_masked]

        # .tolist() moves the ids to Python once per sample rather than once per use.
        for patch_id in masked_patch_ids.tolist():
            ph, pw = divmod(patch_id, n_patches_w)
            h_start = ph * patch_size
            w_start = pw * patch_size
            # Blocks always lie fully inside the map because only whole blocks are counted.
            mask_grid[b, h_start:h_start + patch_size, w_start:w_start + patch_size] = True

    return fi1_mask  # [B, H8*W8]

def apply_fi1_mask_tokens(fi1_features: torch.Tensor, fi1_mask: torch.Tensor, mask_token: torch.Tensor):
    """
    Substitute a mask token at every masked spatial position of an Fi1 feature map.

    Args:
        fi1_features: (B, D, H8, W8) feature map.
        fi1_mask: (B, H8*W8) boolean mask, row-major over (H8, W8); True = masked.
        mask_token: (1, D, 1, 1) token broadcast to the full feature-map shape.

    Returns:
        Tensor shaped like `fi1_features` holding the token where the mask is True and
        the original features elsewhere.
    """
    batch, dim, height, width = fi1_features.shape

    # Lift the flat mask to (B, 1, H8, W8), then repeat it across all D channels.
    spatial_mask = fi1_mask.reshape(batch, height, width)
    channel_mask = spatial_mask.unsqueeze(1).expand(-1, dim, -1, -1)

    token_map = mask_token.expand(batch, dim, height, width)
    return torch.where(channel_mask, token_map, fi1_features)

def visualize_jepa_patch_quality(
    original: torch.Tensor,
    predicted_features: torch.Tensor,
    target_features: torch.Tensor,
    patch_mask: torch.Tensor,
    epoch: int,
    save_path: str,
    patch_size: int = 16,
) -> None:
    """
    Save a three-panel diagnostic figure for the FIRST sample of a batch.

    Panels, left to right:
      1. the image itself;
      2. the image with masked patches tinted red;
      3. the image with each masked patch tinted by reconstruction quality
         (RdYlGn colormap: lowest per-patch error -> green, highest -> red).

    Only index 0 of `original`, `predicted_features`, `target_features` and
    `patch_mask` is read. Patch indices are interpreted row-major on a grid of
    (H // patch_size) x (W // patch_size) patches, where H, W come from the image.

    Args:
        original: batch of images; element 0 is displayed.
            # assumes (B, C, H, W) with C in {1, 3} — TODO confirm against caller
        predicted_features: per-sample predictions; element 0 is treated as [M, D].
        target_features: per-sample targets with the same layout as the predictions.
        patch_mask: per-sample mask; element 0 is flattened and its non-zero entries
            are taken as the masked patch indices.
        epoch: only used in the centre panel's title.
        save_path: file the figure is written to (dpi=150, tight bounding box).
        patch_size: patch side length in image pixels.

    NOTE(review): the i-th feature row is paired with the i-th masked index in
    ascending index order — confirm the caller emits features in that order.
    """

    # ----- robust image-to-display -----
    def _to_display_img(x: torch.Tensor) -> np.ndarray:
        """Convert a tensor to a float array clipped to [0, 1] for `imshow`.

        A 3-D tensor whose first dimension is 1 or 3 is treated as channel-first and
        returned channel-last; its value range decides the rescaling (see branches).
        Any other tensor is min-max scaled as is.
        """
        x = x.detach().cpu()
        if x.ndim == 3 and x.shape[0] in (1, 3):
            xc = x.clone()
            mn, mx = float(xc.min()), float(xc.max())
            if 0.0 <= mn and mx <= 1.0:
                pass  # already [0,1]
            elif -3.5 <= mn <= 3.5 and -3.5 <= mx <= 3.5:
                # values look like ImageNet-normalised data: undo mean/std normalisation
                mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
                std  = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
                xc = xc * std + mean
            else:
                # anything else: min-max to [0,1] (denominator floored to avoid /0)
                xc = (xc - mn) / (max(mx - mn, 1e-6))
            img = xc.permute(1, 2, 0).numpy()
        else:
            arr = x.numpy()
            arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-6)
            img = arr
        return np.clip(img, 0.0, 1.0)

    # ---- prep first image ----
    original_np = _to_display_img(original[0])
    H, W = original_np.shape[:2]
    n_h, n_w = H // patch_size, W // patch_size  # whole patches per column / row

    # ---- per-masked-tile losses ----
    pred0 = predicted_features[0].detach().cpu().numpy()   # [M, D]
    tgt0  = target_features[0].detach().cpu().numpy()      # [M, D]
    if pred0.size == 0:
        per_patch_losses = np.zeros((0,), dtype=np.float32)
    else:
        diff = pred0 - tgt0
        per_patch_losses = (diff * diff).mean(axis=-1)     # [M] mean squared error per row

    # Min-max normalise the losses and invert them so 1.0 = best (lowest loss) and
    # 0.0 = worst; the scale is relative to this sample only.
    if per_patch_losses.size > 0:
        lo, hi = float(per_patch_losses.min()), float(per_patch_losses.max())
        denom = (hi - lo) if (hi > lo) else 1.0  # all-equal losses: avoid division by zero
        normalized_quality = 1.0 - ((per_patch_losses - lo) / denom)
    else:
        normalized_quality = np.zeros((0,), dtype=np.float32)

    # ---- masked indices ----
    mask0 = patch_mask[0].detach().cpu().view(-1)          # [P]
    masked_indices = mask0.nonzero(as_tuple=False).squeeze(1).numpy()  # [K], ascending

    # ---- figure ----
    fig, axs = plt.subplots(1, 3, figsize=(14, 5))

    # Left: original
    axs[0].imshow(original_np, interpolation='nearest')
    axs[0].set_title("Original Image")
    axs[0].axis('off')

    # Center: masked patches tinted red (no grid lines are drawn on this panel)
    masked_img = original_np.copy()
    overlay = masked_img.copy()
    red = np.array([1.0, 0.0, 0.0], dtype=np.float32)

    for pidx in masked_indices:
        # skip indices that do not fit the patch grid derived from the image size
        if pidx < 0 or pidx >= n_h * n_w:
            continue
        ih, iw = divmod(int(pidx), n_w)  # row-major patch index -> (row, col)
        h0, h1 = ih * patch_size, (ih + 1) * patch_size
        w0, w1 = iw * patch_size, (iw + 1) * patch_size
        overlay[h0:h1, w0:w1, :] = red

    # Blend the whole image with the overlay; unmasked areas are unchanged because the
    # overlay equals the image there.
    alpha_center = 0.35
    masked_img = (1 - alpha_center) * masked_img + alpha_center * overlay
    axs[1].imshow(np.clip(masked_img, 0.0, 1.0), interpolation='nearest')
    axs[1].set_title(f"Epoch {epoch} - JEPA Analysis\nMasked Patches (Red)\n{int(mask0.sum())}/{mask0.numel()} masked")
    axs[1].axis('off')

    # Right: reconstruction quality (bold colored squares)
    quality_img = original_np.copy()
    colormap = plt.get_cmap('RdYlGn')  # green=good, red=bad
    alpha_patch = 0.85                 # strong overlay for bold squares
    grid_thick = max(1, patch_size // 16)  # grid line thickness in pixels (1 px until patch_size >= 32)

    # Pair quality scores with masked indices positionally; extra entries on either
    # side are ignored.
    limit = min(len(normalized_quality), len(masked_indices))
    for idx in range(limit):
        patch_idx = int(masked_indices[idx])
        if patch_idx < 0 or patch_idx >= n_h * n_w:
            continue

        ih, iw = divmod(patch_idx, n_w)
        h0, h1 = ih * patch_size, (ih + 1) * patch_size
        w0, w1 = iw * patch_size, (iw + 1) * patch_size

        # Sanitise the score: non-finite values are treated as worst quality.
        q = float(np.asarray(normalized_quality[idx]).mean())
        if not np.isfinite(q):
            q = 0.0
        q = float(np.clip(q, 0.0, 1.0))

        color = np.asarray(colormap(q))[:3]  # (3,) RGB, alpha channel dropped
        patch = quality_img[h0:h1, w0:w1, :]
        quality_img[h0:h1, w0:w1, :] = (1 - alpha_patch) * patch + alpha_patch * color[None, None, :]

        # black lines along the top and left edges of this patch
        quality_img[h0:h0+grid_thick, w0:w1, :] = 0.0
        quality_img[h0:h1, w0:w0+grid_thick, :] = 0.0

    axs[2].imshow(np.clip(quality_img, 0.0, 1.0), interpolation='nearest')
    axs[2].set_title("Reconstruction Quality\n(Green=Good, Red=Poor)")
    axs[2].axis('off')

    plt.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)  # release the figure so repeated calls do not accumulate open figures



## Models

#### Patch Embed

In [4]:
# --- PatchEmbed2D: remove undefined pos_embed add (or implement it properly) ---
class PatchEmbed2D(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A strided convolution (kernel = stride = patch_size) produces one D-vector per
    patch; the result is flattened to a token sequence and layer-normalised.
    No positional embedding is added here.
    """

    def __init__(self, in_chans: int, embed_dim: int, patch_size: int):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Map [B, C, H, W] images to normalised patch tokens [B, P, D]."""
        patch_map = self.proj(x)                                   # [B, D, n_h, n_w]
        batch, dim, rows, cols = patch_map.shape
        flat = patch_map.view(batch, dim, rows * cols)             # [B, D, P]
        tokens = flat.permute(0, 2, 1)                             # [B, P, D]
        return self.norm(tokens)


#### Context Encoder

In [5]:
# 2D Image → Patches → 2D Feature Map → ViT → 2D Feature Map → Tokens

In [6]:
import math
import torch
import torch.nn as nn
import timm

class ContextEncoder2D(nn.Module):
    """
    Swin Transformer v2 context encoder built through timm.

    With the default arguments this is Swin v2 Tiny configured for 384×384 inputs.
    `forward` returns patch tokens (no CLS) together with the token grid size (Ht, Wt).

    Attributes:
        swin: the timm backbone (classifier head and global pooling disabled).
        embed_dim: channel width of the returned tokens (`swin.num_features`).
        patch_size: patch-embedding stride recorded for callers (fixed to 4 here).
    """
    def __init__(
        self,
        model_name: str = "swinv2_tiny_window8_256",
        pretrained: bool = True,
        img_size: int = 384,
        strict_img_size: bool = False,
        dynamic_img_pad: bool = True,
    ):
        super().__init__()
        # Build timm Swin v2 at the requested img_size and allow dynamic padding
        self.swin = timm.create_model(
            model_name,
            pretrained=pretrained,
            img_size=img_size,          # <- defaults to 384
            num_classes=0,              # no classifier head
            global_pool='',             # keep token grid
            features_only=False,
            strict_img_size=strict_img_size,
            dynamic_img_pad=dynamic_img_pad,
        )
        self.embed_dim = self.swin.num_features
        self.patch_size = 4  # Swin uses patch4 embed

    @staticmethod
    def _infer_token_grid(num_tokens: int, in_h: int, in_w: int):
        """Recover (Ht, Wt) for a flattened token sequence of length `num_tokens`.

        For non-square inputs the grid is assumed to share the input's aspect ratio;
        that guess is used only when it factors `num_tokens` exactly. Otherwise (and
        for square inputs) the near-square factorisation floor(sqrt(P)) x (P // Ht)
        is returned.
        """
        if in_h != in_w and in_h > 0 and in_w > 0:
            guess_h = round(math.sqrt(num_tokens * in_h / in_w))
            if guess_h > 0 and num_tokens % guess_h == 0:
                return guess_h, num_tokens // guess_h
        # isqrt is exact for integers; max() keeps the division defined for P == 0.
        square_h = max(1, math.isqrt(num_tokens))
        return square_h, num_tokens // square_h

    def forward(self, x: torch.Tensor):
        """
        x: [B, C, H, W]  (H,W multiples of 32 recommended; 384 works)
        returns:
          tokens: [B, P, D]  (CLS-free)
          (Ht, Wt): token grid size at the final stage (~ H/32, W/32)
        """
        feats = self.swin.forward_features(x)   # [B, L, D] or sometimes [B, Ht, Wt, D] / [B, D, Ht, Wt]

        if feats.dim() == 3:                    # [B, L, D]
            tokens = feats
            # The sequence carries no grid shape, so derive it from the token count and
            # the input aspect ratio (a plain sqrt is wrong for non-square inputs).
            Ht, Wt = self._infer_token_grid(tokens.shape[1], int(x.shape[-2]), int(x.shape[-1]))
        else:
            # Handle both [B, D, Ht, Wt] and [B, Ht, Wt, D]
            if feats.shape[1] == self.embed_dim:        # [B, D, Ht, Wt]
                B, D, Ht, Wt = feats.shape
                tokens = feats.permute(0, 2, 3, 1).reshape(B, Ht * Wt, D)
            else:                                       # [B, Ht, Wt, D]
                B, Ht, Wt, D = feats.shape
                tokens = feats.reshape(B, Ht * Wt, D)

        return tokens, (Ht, Wt)


#### Pixel Decoder

In [7]:
# # === CELL A: REPLACE YOUR WHOLE PixelDecoder2D CLASS WITH THIS ===
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from typing import List, Tuple, Optional

# def _build_2d_sincos_pos_like(x: torch.Tensor) -> torch.Tensor:
#     """Fixed 2D sin/cos positional embedding with the same shape as x: [B, D, H, W]."""
#     B, D, H, W = x.shape
#     device, dtype = x.device, x.dtype
#     y = torch.linspace(-1, 1, H, device=device, dtype=dtype)
#     z = torch.linspace(-1, 1, W, device=device, dtype=dtype)
#     yy, zz = torch.meshgrid(y, z, indexing='ij')               # [H,W]
#     pos = torch.stack([yy, zz], dim=0).unsqueeze(0).expand(B, -1, -1, -1)  # [B,2,H,W]
#     half = D // 2
#     sin_bank = pos[:, :1].expand(B, half, H, W)                # y
#     cos_bank = pos[:, 1:2].expand(B, D - half, H, W)           # x
#     pe = torch.cat([torch.sin(sin_bank), torch.cos(cos_bank)], dim=1)      # [B,D,H,W]
#     if pe.shape[1] < D:  # pad if D is odd
#         pe = F.pad(pe, (0,0,0,0,0, D - pe.shape[1]))
#     return pe

# def _make_groupnorm(C: int) -> nn.GroupNorm:
#     """Pick a GroupNorm that divides C (fallback to 1). Safe for D=192,256 etc."""
#     for g in (32, 16, 8, 4, 2, 1):
#         if C % g == 0:
#             return nn.GroupNorm(g, C)
#     return nn.GroupNorm(1, C)

# class _DeformableAttn2D(nn.Module):
#     """
#     MSDeformAttn-lite:
#       - q:               [B, D, Hq, Wq]  (query map; produces offsets & weights)
#       - feats_per_level: list of L tensors, each [B, D, Hi, Wi] (already projected + pos + level id)
#       - returns:         [B, D, Hq, Wq]
#     """
#     def __init__(self, embed_dim: int, num_points: int = 4, num_levels: int = 3, num_heads: int = 8):
#         super().__init__()
#         assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
#         self.D = embed_dim
#         self.H = num_heads
#         self.Dh = embed_dim // num_heads
#         self.K = num_points
#         self.L = num_levels

#         # per-level value projections (like MSDeformAttn’s value proj)
#         self.value_proj = nn.ModuleList([nn.Conv2d(self.D, self.D, 1) for _ in range(self.L)])

#         # sampling offsets & attention weights predicted from q
#         self.sampling_offsets = nn.Conv2d(self.D, self.H * self.L * self.K * 2, kernel_size=3, padding=1)
#         self.attention_weights = nn.Conv2d(self.D, self.H * self.L * self.K,   kernel_size=3, padding=1)

#         self.out_proj = nn.Conv2d(self.D, self.D, 1)
#         self.norm = _make_groupnorm(self.D)

#     @staticmethod
#     def _make_ref_points(B, Hq, Wq, device, dtype):
#         # reference points at the centers of each query location in [0,1] coords
#         ys = (torch.arange(Hq, device=device, dtype=dtype) + 0.5) / Hq
#         xs = (torch.arange(Wq, device=device, dtype=dtype) + 0.5) / Wq
#         yy, xx = torch.meshgrid(ys, xs, indexing='ij')  # [Hq,Wq]
#         ref = torch.stack([xx, yy], dim=-1)             # [Hq,Wq,2], (x,y) in [0,1]
#         ref = ref.unsqueeze(0).expand(B, Hq, Wq, 2)
#         return ref

#     @staticmethod
#     def _to_grid_sample(ref_xy_01):
#         # convert [0,1] → [-1,1] for grid_sample
#         return ref_xy_01 * 2.0 - 1.0

#     def forward(self, q: torch.Tensor, feats_per_level: list[torch.Tensor]) -> torch.Tensor:
#         B, D, Hq, Wq = q.shape
#         device, dtype = q.device, q.dtype

#         # predict offsets & weights
#         offsets = self.sampling_offsets(q)   # [B, H*L*K*2, Hq, Wq]
#         weights = self.attention_weights(q)  # [B, H*L*K,   Hq, Wq]

#         # reshape
#         offsets = offsets.view(B, self.H, self.L, self.K, 2, Hq, Wq)  # (B,H,L,K,2,Hq,Wq)
#         # softmax over levels×points for each head/location
#         weights = weights.view(B, self.H, self.L, self.K, Hq, Wq)     # (B,H,L,K,Hq,Wq)
#         weights = weights.permute(0,1,4,5,2,3)                        # (B,H,Hq,Wq,L,K)
#         weights = torch.softmax(weights.reshape(B, self.H, Hq, Wq, self.L * self.K), dim=-1)
#         weights = weights.view(B, self.H, Hq, Wq, self.L, self.K).permute(0,1,4,5,2,3)  # (B,H,L,K,Hq,Wq)

#         # reference points per query location (shared across heads)
#         ref = self._make_ref_points(B, Hq, Wq, device, dtype)         # (B,Hq,Wq,2) in [0,1]

#         # project values per level, split heads
#         vals = []
#         shapes = []
#         for l, x in enumerate(feats_per_level):
#             v = self.value_proj[l](x)                                 # (B,D,Hi,Wi)
#             Bi, Di, Hi, Wi = v.shape
#             v = v.view(Bi, self.H, self.Dh, Hi, Wi)                   # (B,H,Dh,Hi,Wi)
#             vals.append(v)
#             shapes.append((Hi, Wi))

#         out = torch.zeros(B, self.H, self.Dh, Hq, Wq, device=device, dtype=dtype)

#         # accumulate over levels and points
#         for l in range(self.L):
#             Hi, Wi = shapes[l]
#             # precompute normalization for this level
#             # offsets are predicted in (approx) pixel units; normalize by feature map size
#             norm_x = Wi
#             norm_y = Hi

#             for k in range(self.K):
#                 # offsets for all heads at this (l,k)
#                 off_lk = offsets[:, :, l, k]                           # (B,H,2,Hq,Wq)
#                 # to [0,1]-space: add normalized offsets to ref
#                 dx = off_lk[:, :, 0] / max(1.0, float(norm_x))         # (B,H,Hq,Wq)
#                 dy = off_lk[:, :, 1] / max(1.0, float(norm_y))         # (B,H,Hq,Wq)
#                 ref_xy = torch.stack([ref[..., 0].unsqueeze(1) + dx,
#                                       ref[..., 1].unsqueeze(1) + dy], dim=-1)  # (B,H,Hq,Wq,2) in [0,1]
#                 # to grid_sample coordinates
#                 grid = self._to_grid_sample(ref_xy).clamp(-1.0, 1.0)   # (B,H,Hq,Wq,2)

#                 # sample vals[l] for each head; vectorize by merging (B,H)
#                 v_l = vals[l]                                          # (B,H,Dh,Hi,Wi)
#                 v_l = v_l.reshape(B * self.H, self.Dh, Hi, Wi)
#                 grid_l = grid.reshape(B * self.H, Hq, Wq, 2)

#                 sampled = F.grid_sample(
#                     v_l, grid_l, mode='bilinear', align_corners=False
#                 )  # (B*H, Dh, Hq, Wq)

#                 sampled = sampled.view(B, self.H, self.Dh, Hq, Wq)     # (B,H,Dh,Hq,Wq)

#                 # attention weights for this (l,k)
#                 att_lk = weights[:, :, l, k].unsqueeze(2)              # (B,H,1,Hq,Wq)
#                 out = out + sampled * att_lk

#         out = out.view(B, self.D, Hq, Wq)                              # merge heads
#         out = self.out_proj(self.norm(F.relu(out)))
#         return out

# class PixelDecoder2D(nn.Module):
#     """
#     Mask2Former-style pixel decoder (multi-level + deformable attention).
#     (Name unchanged for drop-in.)

#     __init__ keeps your signature:
#       - in_channels: int (ignored if in_channels_per_level is provided)
#       - embed_dim:   int

#     Extra kwarg:
#       - in_channels_per_level: Tuple[int,int,int] for [C3,C4,C5]. If omitted, use (in_channels,)*3.
#     """
#     def __init__(self,
#                  in_channels: int,
#                  embed_dim: int,
#                  *args, **kwargs):
#         super().__init__()
#         self.embed_dim = embed_dim

#         in_chs: Optional[Tuple[int,int,int]] = kwargs.get("in_channels_per_level", None)
#         if in_chs is None:
#             in_chs = (in_channels, in_channels, in_channels)

#         # per-level projection to shared D
#         self.in_proj = nn.ModuleList([nn.Conv2d(c, embed_dim, 1) for c in in_chs])

#         # learned level embeddings
#         self.level_embed = nn.Parameter(torch.randn(3, embed_dim))

#         # deformable encoder (3 layers)
#         self.enc_layers = nn.ModuleList([
#             nn.ModuleDict({
#                 "attn": _DeformableAttn2D(embed_dim, num_points=4, num_levels=3),
#                 "ffn": nn.Sequential(
#                     nn.Conv2d(embed_dim, 4*embed_dim, 1),
#                     nn.GELU(),
#                     nn.Conv2d(4*embed_dim, embed_dim, 1),
#                 ),
#                 "norm1": _make_groupnorm(embed_dim),
#                 "norm2": _make_groupnorm(embed_dim),
#             }) for _ in range(3)
#         ])

#         # heads for Fi1 (s=8) and F_last (s=4)
#         self.fi1_head   = nn.Conv2d(embed_dim, embed_dim, 1)
#         self.flast_head = nn.Conv2d(embed_dim, embed_dim, 1)

#     def _prepare_levels(self, feats_multi: List[torch.Tensor]) -> List[torch.Tensor]:
#         outs = []
#         for lvl, (x, proj) in enumerate(zip(feats_multi, self.in_proj)):
#             x = proj(x)                                  # [B,D,Hi,Wi]
#             x = x + _build_2d_sincos_pos_like(x)         # pos enc
#             x = x + self.level_embed[lvl].view(1,-1,1,1) # level id
#             outs.append(x)
#         return outs  # [C3@1/8, C4@1/16, C5@1/32], all [B,D,Hi,Wi]

#     def forward(self, feats_multi: List[torch.Tensor], input_hw: Tuple[int,int]):
#         """
#         feats_multi: [C3, C4, C5] at strides [1/8, 1/16, 1/32], each [B,Ci,Hi,Wi]
#         input_hw: (H, W) full-res size
#         returns:
#           Fi1:   [B,D,H/8, W/8]
#           F_last:[B,D,H/4, W/4]
#         """
#         H, W = input_hw

#         # 1) project + pos + level embeddings
#         levels = self._prepare_levels(feats_multi)

#         # 2) use 1/8 map as query grid for encoder
#         q = levels[0]  # [B,D,H/8,W/8]

#         # 3) deformable encoder layers
#         for l in self.enc_layers:
#             q = q + l["attn"](l["norm1"](q), levels)
#             q = q + l["ffn"](l["norm2"](q))

#         # 4) heads at required strides
#         Fi1 = self.fi1_head(q)  # stride 8 already
#         q4 = F.interpolate(q, size=(H // 4, W // 4), mode='bilinear', align_corners=False)
#         F_last = self.flast_head(q4)

#         return Fi1, F_last


In [8]:
## test pixel decoder

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional

def _simple_pos_embed_2d(x: torch.Tensor) -> torch.Tensor:
    """Simple 2D positional embedding - just scaled coordinates"""
    B, D, H, W = x.shape
    device, dtype = x.device, x.dtype
    
    # Create coordinate grids
    y_coords = torch.linspace(0, 1, H, device=device, dtype=dtype)
    x_coords = torch.linspace(0, 1, W, device=device, dtype=dtype) 
    yy, xx = torch.meshgrid(y_coords, x_coords, indexing='ij')
    
    # Simple embedding: just use x,y coordinates repeated
    pos_embed = torch.stack([xx, yy], dim=0).unsqueeze(0)  # [1, 2, H, W]
    pos_embed = pos_embed.expand(B, -1, -1, -1)            # [B, 2, H, W]
    
    # Repeat to match channel dimension
    pos_embed = pos_embed.repeat(1, D//2, 1, 1)            # [B, D, H, W]
    if pos_embed.shape[1] < D:
        pos_embed = F.pad(pos_embed, (0, 0, 0, 0, 0, D - pos_embed.shape[1]))
    
    return pos_embed * 0.1  # Scale down to not overwhelm features

class PixelDecoder2D(nn.Module):
    """
    Simple FPN-style pixel decoder. Much cleaner than deformable attention.
    Keeps same interface as the complex version.

    Args:
        in_channels: channel count used for all three input levels unless
            `in_channels_per_level` is given.
        embed_dim: shared channel width D of every internal and output feature map.
        in_channels_per_level (keyword only, optional): per-level channel counts for
            (C3, C4, C5); overrides `in_channels`.

    GroupNorm uses 32 groups when `embed_dim` is a multiple of 32, a single group when
    `embed_dim < 32`, and otherwise the largest of 16/8/4/2/1 that divides `embed_dim`
    (so widths such as 48 or 80 can be constructed).
    """

    @staticmethod
    def _num_groups(channels: int) -> int:
        """Pick a GroupNorm group count that is guaranteed to divide `channels`."""
        if channels < 32:
            return 1
        for groups in (32, 16, 8, 4, 2):
            if channels % groups == 0:
                return groups
        return 1

    def __init__(self,
                 in_channels: int,
                 embed_dim: int,
                 *args, **kwargs):
        super().__init__()
        self.embed_dim = embed_dim
        groups = self._num_groups(embed_dim)

        # Handle per-level channels (C3, C4, C5 might be different)
        in_chs: Optional[Tuple[int,int,int]] = kwargs.get("in_channels_per_level", None)
        if in_chs is None:
            in_chs = (in_channels, in_channels, in_channels)

        # Project each level to common embedding dimension
        self.lateral_convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c, embed_dim, kernel_size=1),
                nn.GroupNorm(groups, embed_dim),
                nn.ReLU(inplace=True)
            ) for c in in_chs
        ])

        # FPN fusion layers (reduce aliasing during upsampling).
        # Three are created but forward() only uses indices 0 and 1; the third is kept
        # so parameter names stay compatible with existing checkpoints.
        self.fpn_convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1),
                nn.GroupNorm(groups, embed_dim),
                nn.ReLU(inplace=True)
            ) for _ in range(3)
        ])

        # Output heads
        self.fi1_head = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1),
            nn.GroupNorm(groups, embed_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
        )

        self.flast_head = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1),
            nn.GroupNorm(groups, embed_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
        )

    def forward(self, feats_multi: List[torch.Tensor], input_hw: Tuple[int, int]):
        """
        Simple FPN forward pass.

        feats_multi: [C3, C4, C5] at strides [1/8, 1/16, 1/32]
        input_hw: (H, W) full resolution size

        Returns:
            Fi1: [B, D, H/8, W/8] - feature at stride 8
            F_last: [B, D, H/4, W/4] - feature at stride 4
        """
        H, W = input_hw
        c3, c4, c5 = feats_multi

        # 1. Lateral connections - project to common dimension
        p5 = self.lateral_convs[2](c5)  # [B, D, H/32, W/32]
        p4 = self.lateral_convs[1](c4)  # [B, D, H/16, W/16]
        p3 = self.lateral_convs[0](c3)  # [B, D, H/8, W/8]

        # Add simple positional encoding
        p5 = p5 + _simple_pos_embed_2d(p5)
        p4 = p4 + _simple_pos_embed_2d(p4)
        p3 = p3 + _simple_pos_embed_2d(p3)

        # 2. Top-down pathway (FPN fusion)
        # P5 -> P4
        p5_up = F.interpolate(p5, size=p4.shape[-2:], mode='bilinear', align_corners=False)
        p4 = p4 + p5_up
        p4 = self.fpn_convs[1](p4)

        # P4 -> P3
        p4_up = F.interpolate(p4, size=p3.shape[-2:], mode='bilinear', align_corners=False)
        p3 = p3 + p4_up
        p3 = self.fpn_convs[0](p3)  # This is our main feature at 1/8 stride

        # 3. Generate outputs
        # Fi1 at stride 8 (same as p3)
        Fi1 = self.fi1_head(p3)  # [B, D, H/8, W/8]

        # F_last at stride 4 (upsample p3)
        p3_upsampled = F.interpolate(p3, size=(H // 4, W // 4), mode='bilinear', align_corners=False)
        F_last = self.flast_head(p3_upsampled)  # [B, D, H/4, W/4]

        return Fi1, F_last

#### PredictorHead

In [10]:
import torch
import torch.nn as nn

class CrossAttentionBlock2D(nn.Module):
    """
    Post-norm cross-attention block.

    Queries attend to an external feature sequence; the result is added back
    to the queries and layer-normalised, then a 4x-wide GELU MLP is applied
    with a second residual connection and layer norm.
    """

    def __init__(self, embed_dim: int, num_heads: int = 8):
        super().__init__()
        self.embed_dim, self.num_heads = embed_dim, num_heads
        hidden_dim = embed_dim * 4

        # Sub-modules are created in the order attention -> norms -> MLP.
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim)
        )

    def forward(self, queries, features):
        """
        Args:
            queries: batch-first query sequence (last dim = embed_dim).
            features: batch-first sequence used as both keys and values.

        Returns:
            Updated queries, same shape as the `queries` argument.
        """
        # Only the attended values are needed; the attention weights are dropped.
        attended = self.cross_attn(queries, features, features)[0]
        hidden = self.norm1(queries + attended)

        # Residual MLP followed by the second normalisation.
        return self.norm2(hidden + self.ffn(hidden))

class SelfAttentionBlock2D(nn.Module):
    """
    Post-norm self-attention block used to refine a set of queries.

    The queries attend to themselves; the result is added back and
    layer-normalised, then a 4x-wide GELU MLP is applied with a second
    residual connection and layer norm.
    """

    def __init__(self, embed_dim: int, num_heads: int = 8):
        super().__init__()
        self.embed_dim, self.num_heads = embed_dim, num_heads
        hidden_dim = embed_dim * 4

        # Sub-modules are created in the order attention -> norms -> MLP.
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim)
        )

    def forward(self, queries):
        """
        Args:
            queries: batch-first query sequence (last dim = embed_dim).

        Returns:
            Refined queries, same shape as the input.
        """
        # Query, key and value are all the same sequence; weights are discarded.
        mixed = self.self_attn(queries, queries, queries)[0]
        hidden = self.norm1(queries + mixed)

        # Residual MLP followed by the second normalisation.
        return self.norm2(hidden + self.ffn(hidden))

In [11]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class Predictor2D(nn.Module):
    """
    Query-based predictor that reconstructs masked Fi1 positions.

    A fixed set of learnable queries cross-attends to the (partially masked)
    Fi1 feature map, is refined by self-attention layers and projected back to
    the Fi1 embedding width. Each masked position is then predicted as a
    softmax-weighted mixture of the projected queries.

    Args:
        embed_dim: channel width D of Fi1 and of every internal layer.
        num_queries: number of learnable queries Q.
        num_heads: attention heads; when None, the largest value in
            (16, 12, 8, 6, 4, 3, 2, 1) that divides `embed_dim` is used.
        num_cross_blocks / num_cross_attn: number of cross-attention blocks.
            Both spellings are accepted; `num_cross_attn` wins when both are
            given, and 9 is used when neither is.
        num_self_blocks / num_self_attn: number of self-attention layers, with
            the same precedence rule and a default of 2.
    """

    def __init__(self,
                 embed_dim: int,
                 num_queries: int,
                 num_heads: int = None,
                 # accept both old and new arg names:
                 num_cross_blocks: int = None,
                 num_self_blocks: int = None,
                 num_cross_attn: int = None,
                 num_self_attn: int = None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_queries = num_queries

        # choose a valid num_heads if not provided (1 always divides, so this terminates)
        if num_heads is None:
            num_heads = next(h for h in (16, 12, 8, 6, 4, 3, 2, 1) if embed_dim % h == 0)
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads

        # harmonize naming: prefer *attn if provided, else *blocks, else paper defaults (9/2)
        n_cross = self._first_given(num_cross_attn, num_cross_blocks, default=9)
        n_self = self._first_given(num_self_attn, num_self_blocks, default=2)

        # learnable queries
        self.query_embed = nn.Parameter(torch.zeros(1, num_queries, embed_dim))
        nn.init.trunc_normal_(self.query_embed, std=0.02)

        # Cross-attention blocks (Mask2Former-ish: norm -> cross-attn -> add, norm -> FFN -> add)
        self.cross_blocks = nn.ModuleList([
            self._make_cross_block(embed_dim, num_heads) for _ in range(n_cross)
        ])

        # Extra self-attention blocks on queries
        self.self_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
            for _ in range(n_self)
        ])

        # projection head f_L to map query outputs back to Fi1 embedding space
        self.proj = nn.Linear(embed_dim, embed_dim)

        # learnable mask token for masked Fi1 tiles (used to build K/V when Fi1 is masked)
        self.kv_mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.kv_mask_token, std=0.02)

    @staticmethod
    def _first_given(*candidates, default):
        """Return the first candidate that is not None, or `default` if all are None."""
        return next((c for c in candidates if c is not None), default)

    @staticmethod
    def _make_cross_block(embed_dim: int, num_heads: int) -> nn.ModuleDict:
        """Build one pre-norm cross-attention block (keys: attn, ffn, norm1, norm2)."""
        return nn.ModuleDict(dict(
            attn=nn.MultiheadAttention(embed_dim, num_heads, batch_first=True),
            ffn=nn.Sequential(
                nn.Linear(embed_dim, 4 * embed_dim),
                nn.ReLU(inplace=True),
                nn.Linear(4 * embed_dim, embed_dim)
            ),
            norm1=nn.LayerNorm(embed_dim),
            norm2=nn.LayerNorm(embed_dim),
        ))

    @staticmethod
    def _add_2d_sincos_pos(feat_2d: torch.Tensor):
        """
        Add a fixed coordinate encoding and flatten to a token sequence.

        feat_2d: [B, D, H, W] -> returns [B, H*W, D].

        Coordinates are normalised to [-1, 1]. The first D // 2 channels all
        receive sin(x) and the remaining channels all receive cos(y); no
        per-channel frequencies are used.
        """
        B, D, H, W = feat_2d.shape
        device = feat_2d.device

        y = torch.linspace(-1, 1, steps=H, device=device)
        x = torch.linspace(-1, 1, steps=W, device=device)
        yy, xx = torch.meshgrid(y, x, indexing='ij')  # each [H, W]

        half = D // 2
        # half + (D - half) == D by construction, so no padding/cropping is required.
        pos_embed = torch.cat([
            torch.sin(xx).unsqueeze(0).expand(half, H, W),
            torch.cos(yy).unsqueeze(0).expand(D - half, H, W),
        ], dim=0).unsqueeze(0)  # [1, D, H, W]

        feat = feat_2d + pos_embed.to(feat_2d.dtype)  # [B, D, H, W]
        return feat.flatten(2).transpose(1, 2)        # [B, H*W, D]

    def forward(self, Fi1_online: torch.Tensor, Fi1_mask: torch.Tensor):
        """
        Fi1_online: [B, D, H8, W8]   (online pixel-decoder Fi1)
        Fi1_mask:   [B, H8*W8] mask  (True/non-zero = masked positions in Fi1 to
                    reconstruct). Any layout with B*H8*W8 elements and any dtype
                    convertible to bool is accepted; None means "nothing masked".
        Returns:
           pred_masked_feats: [B, M, D] predictions for masked Fi1 positions, where
                              M is the largest masked count in the batch; rows past
                              a sample's own count are zero
           masked_indices:    [B, M] flat indices of the masked Fi1 positions, padded with -1
           query_feats:       [B, Q, D] final (projected) query embeddings
        When nothing is masked, the first two outputs have M == 0.
        """
        B, D, H8, W8 = Fi1_online.shape
        Q = self.num_queries

        # Normalise the mask once so every later use can rely on [B, P] bool.
        if Fi1_mask is not None:
            Fi1_mask = Fi1_mask.reshape(B, H8 * W8).to(torch.bool)

        # Build K/V from Fi1; masked tiles are replaced with kv_mask_token (NOT zeros).
        # torch.where broadcasts and returns a new tensor, so the input is never modified.
        kv = Fi1_online
        if Fi1_mask is not None:
            kv = torch.where(Fi1_mask.reshape(B, 1, H8, W8),
                             self.kv_mask_token.view(1, D, 1, 1),
                             kv)

        kv_seq = self._add_2d_sincos_pos(kv)  # [B, H8*W8, D] as K/V

        # Queries
        q = self.query_embed.expand(B, Q, D)  # [B,Q,D]

        # Pre-norm cross-attention blocks with residual connections
        for blk in self.cross_blocks:
            attn_out, _ = blk['attn'](blk['norm1'](q), kv_seq, kv_seq)
            q = q + attn_out
            q = q + blk['ffn'](blk['norm2'](q))

        # Self-attention refinement of the queries
        for sblk in self.self_blocks:
            q = sblk(q)

        # map to Fi1 embedding space
        q_proj = self.proj(q)  # [B,Q,D]

        # Nothing to reconstruct: return empty predictions with consistent shapes/dtypes.
        if Fi1_mask is None or not Fi1_mask.any():
            pred_masked_feats = q_proj.new_zeros((B, 0, D))
            masked_indices = q_proj.new_zeros((B, 0), dtype=torch.long)
            return pred_masked_feats, masked_indices, q_proj

        # Route query embeddings to tiles: scaled dot-product scores, softmax over queries.
        scores = torch.einsum('bpd,bqd->bpq', kv_seq, q_proj) / (D ** 0.5)  # [B,P,Q]
        probs = scores.softmax(dim=-1)

        # Per-sample gather. A sample without masked tiles simply yields empty
        # [0, D] / [0] tensors, so it needs no special case.
        preds, indices = [], []
        for b in range(B):
            mask_b = Fi1_mask[b]                                  # [P]
            preds.append(probs[b, mask_b] @ q_proj[b])            # [Mb, D]
            indices.append(mask_b.nonzero(as_tuple=False).squeeze(1))

        # Pad every sample to the largest masked count so the batch can be stacked.
        max_masked = max(idx.numel() for idx in indices)
        pred_masked_feats = torch.stack(
            [F.pad(p, (0, 0, 0, max_masked - p.size(0))) for p in preds], dim=0
        )  # [B, M, D]
        masked_indices = torch.stack(
            [F.pad(idx, (0, max_masked - idx.numel()), value=-1) for idx in indices], dim=0
        )  # [B, M]

        return pred_masked_feats, masked_indices, q_proj

## Setup

#### Dataset & Dataloader

In [12]:
# =============================================================================
# Datasets and DataLoaders for pretraining and downstream fine-tuning
# =============================================================================

# Per-stage batch sizes
batch_size_pretrain = 424
batch_size_downstream = 48

# One dataset for JEPA pretraining, plus the two ADE20K splits
jepa_dataset = JEPADataset()
ade_train_dataset = ADE20KDataset(split="training")
ade_val_dataset = ADE20KDataset(split="validation")

print(f"Dataset sizes:")
print(f"JEPA (ImageNet): {len(jepa_dataset):,} samples")
print(f"ADE20K train: {len(ade_train_dataset):,} samples")
print(f"ADE20K val: {len(ade_val_dataset):,} samples")


def _build_loader(dataset, batch_size, shuffle, collate_fn):
    """Create a DataLoader with the worker/pinning settings shared by all three loaders."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=4,
        pin_memory=True,
        collate_fn=collate_fn
    )


# JEPA pretraining loader (shuffled)
pretrain_loader = _build_loader(jepa_dataset, batch_size_pretrain, True, jepa_collate)

# Downstream fine-tuning loaders: shuffled for training, fixed order for validation
downstream_train_loader = _build_loader(ade_train_dataset, batch_size_downstream, True, ade_collate)
downstream_val_loader = _build_loader(ade_val_dataset, batch_size_downstream, False, ade_collate)

print(f"\nDataLoader info:")
print(f"Pretrain: {len(pretrain_loader)} batches")
print(f"Train: {len(downstream_train_loader)} batches")
print(f"Val: {len(downstream_val_loader)} batches")

Loading dataset from disk:   0%|          | 0/17 [00:00<?, ?it/s]

Dataset sizes:
JEPA (ImageNet): 126,689 samples
ADE20K train: 20,210 samples
ADE20K val: 2,000 samples

DataLoader info:
Pretrain: 299 batches
Train: 422 batches
Val: 42 batches


In [13]:
# Smoke test for the 2D JEPA pipeline: pull one batch from the pretraining
# loader and print the image tensor's shape.
# Note: next(iter(...)) builds a fresh loader iterator just for this one batch.
batch = next(iter(pretrain_loader))
imgs = batch["images"] # image tensor laid out as (B, C, H, W)

print(imgs.shape)

torch.Size([424, 3, 128, 128])


#### Model Setup

#### Mask-JEPA setup

In [14]:
import copy
import math
from typing import Tuple, Optional
import torch
import time
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

class MaskJEPA2D(nn.Module):
    """
    Mask-JEPA pretraining model: online branch, EMA target branch, predictor
    and denoising head.

    The online branch (context encoder -> feature pyramid -> pixel decoder)
    sees a noisy copy of the input. The target branch (frozen copies of the
    encoder and pixel decoder, updated through `update_ema`) sees the clean
    input. The predictor reconstructs masked Fi1 positions of the online
    branch, and a 1x1 conv on F_last produces the denoising output.

    Args:
        in_chans: number of image channels (also the denoising head's output width).
        num_queries, num_cross_attn, num_self_attn: forwarded to `Predictor2D`.
        tau: initial EMA coefficient; also stored as `tau_base` for ramping.
        fi1_mask_ratio, patch_size: forwarded to `generate_fi1_mask`.
        model_name, pretrained: forwarded to `ContextEncoder2D`.
        noise_sigma: standard deviation of the Gaussian input noise
            (default 0.4, the value that used to be hard-coded).
    """

    # Side length of the constant-noise tiles. Inputs must have H and W
    # divisible by this value, otherwise the expanded noise does not match x.
    NOISE_STRIDE = 4

    def __init__(self,
                 in_chans: int,
                 num_queries: int = 32,
                 num_cross_attn: int = 2,
                 num_self_attn: int = 1,
                 tau: float = 0.996,
                 fi1_mask_ratio: float = 0.5,
                 patch_size: int = 8,
                 model_name: str = "swin_tiny_patch4_window7_224",
                 pretrained: bool = True,
                 noise_sigma: float = 0.4):
        super().__init__()

        self.patch_size = patch_size

        # === Context encoder (backbone selected by model_name) ===
        self.context_encoder = ContextEncoder2D(model_name=model_name, pretrained=pretrained)

        # embed_dim comes from the encoder; patch_size is the constructor argument above
        self.embed_dim = self.context_encoder.embed_dim
        self.in_chans = in_chans  # needed for denoising head output

        # === Target encoder (EMA, frozen) ===
        self.target_encoder = self._frozen_copy(self.context_encoder)

        # === Pixel decoders (online + frozen EMA copy) ===
        self.pixel_decoder = PixelDecoder2D(
            in_channels=self.embed_dim,
            embed_dim=self.embed_dim
        )
        self.pixel_decoder_ema = self._frozen_copy(self.pixel_decoder)

        # === Downsampling convs for C4 / C5 ===
        self.ds16 = self._make_downsampler(self.embed_dim)
        self.ds32 = self._make_downsampler(self.embed_dim)

        # === JEPA Predictor ===
        self.predictor = Predictor2D(
            embed_dim=self.embed_dim,
            num_queries=num_queries,
            num_cross_attn=num_cross_attn,
            num_self_attn=num_self_attn
        )

        # === Denoising head ===
        self.denoising_head = nn.Conv2d(self.embed_dim, in_chans, kernel_size=1)

        # === EMA tau config ===
        self.tau = tau
        self.tau_base = tau
        self.tau_final = 1.0

        self.fi1_mask_ratio = fi1_mask_ratio
        self.noise_sigma = noise_sigma

    @staticmethod
    def _frozen_copy(module: nn.Module) -> nn.Module:
        """Deep-copy `module` and disable gradients on every parameter of the copy."""
        clone = copy.deepcopy(module)
        for p in clone.parameters():
            p.requires_grad = False
        return clone

    @staticmethod
    def _make_downsampler(dim: int) -> nn.Sequential:
        """Stride-2 3x3 conv + GELU followed by a stride-1 3x3 conv + GELU."""
        return nn.Sequential(
            nn.Conv2d(dim, dim, 3, stride=2, padding=1), nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),           nn.GELU(),
        )

    def _build_pyramid(self, tokens: torch.Tensor, enc_hw: Tuple[int, int], input_hw: Tuple[int, int]):
        """
        Turn encoder tokens [B, P, D] into the [C3, C4, C5] list fed to a pixel decoder.

        The tokens are reshaped to [B, D, enc_h, enc_w] and bilinearly resized to
        (H//8, W//8) for C3; C4 and C5 come from ds16 / ds32 and are resized to
        (H//16, W//16) and (H//32, W//32).
        """
        H, W = input_hw
        enc_h, enc_w = enc_hw
        feat = tokens.transpose(1, 2).reshape(tokens.shape[0], self.embed_dim, enc_h, enc_w)

        def at_stride(t, stride):
            return F.interpolate(t, size=(H // stride, W // stride), mode='bilinear', align_corners=False)

        c3 = at_stride(feat, 8)
        x16 = self.ds16(c3)
        x32 = self.ds32(x16)
        return [c3, at_stride(x16, 16), at_stride(x32, 32)]

    def forward(self, x: torch.Tensor):
        """
        Run both branches on a batch of images.

        Noise path:
          - Gaussian noise is drawn on an (H/4, W/4) grid and expanded to
            4x4 constant tiles (H and W must be multiples of 4)
          - Online branch sees the noisy input
          - Target branch sees the clean input

        Noise is added on every call, in eval mode as well as in training.

        Returns a dict with:
            predicted_features: [B, M, D] predictor output for masked Fi1 positions
            target_masked:      [B, M, D] layer-normed target Fi1 at those positions,
                                zeroed where the index is padding (-1)
            mask_info:          (projected queries, masked indices)
            denoised_prediction: denoising head output on the online F_last
            original_input:     the clean input `x` (same object)
            fi1_mask:           mask returned by `generate_fi1_mask`
            mask_indices:       [B, M] masked indices (-1 padded)
            eps_target:         the low-resolution noise [B, C, H/4, W/4]
        """
        B, C, H, W = x.shape
        device = x.device

        # ---------- (A) add noise ----------
        s_last = self.NOISE_STRIDE
        eps_lr = torch.randn(B, C, H // s_last, W // s_last, device=device, dtype=x.dtype) * self.noise_sigma
        eps_full = eps_lr.repeat_interleave(s_last, dim=2).repeat_interleave(s_last, dim=3)
        x_noisy = x + eps_full

        # ---------- (B) ONLINE BRANCH ----------
        tokens_online, (enc_h, enc_w) = self.context_encoder(x_noisy)  # [B,P,D], (Ht,Wt)
        pyramid_online = self._build_pyramid(tokens_online, (enc_h, enc_w), (H, W))
        f_i1_online, f_last_online = self.pixel_decoder(pyramid_online, (H, W))

        # ---------- (C) TARGET BRANCH ----------
        # NOTE(review): the target pyramid reuses the online ds16/ds32 (under
        # no_grad), not EMA copies - confirm this is intended.
        with torch.no_grad():
            tokens_target, _ = self.target_encoder(x)
            pyramid_target = self._build_pyramid(tokens_target, (enc_h, enc_w), (H, W))
            f_i1_target, _ = self.pixel_decoder_ema(pyramid_target, (H, W))

        # ---------- (D) Fi1 MASK ----------
        fi1_mask = generate_fi1_mask(
            fi1_shape=f_i1_online.shape,
            mask_ratio=self.fi1_mask_ratio,
            patch_size=self.patch_size,
            device=device
        )  # [B, H8*W8] bool

        # ---------- (E) PREDICTOR ----------
        predicted_features, masked_indices, q_proj = self.predictor(f_i1_online, fi1_mask)

        # ---------- (F) TARGET GATHER + LN ----------
        D = f_i1_target.shape[1]
        fi1_h, fi1_w = f_i1_target.shape[-2:]
        target_seq = f_i1_target.permute(0, 2, 3, 1).reshape(B, fi1_h * fi1_w, D)
        target_seq = F.layer_norm(target_seq, (D,))

        if masked_indices.numel() > 0:
            # Padded slots carry index -1: gather from position 0 instead and
            # zero the result so they contribute nothing.
            pad_mask = (masked_indices >= 0)
            safe_idx = masked_indices.clamp_min(0)
            b_idx = torch.arange(B, device=device).unsqueeze(-1).expand_as(safe_idx)
            target_masked_full = target_seq[b_idx, safe_idx]
            target_masked = target_masked_full * pad_mask.unsqueeze(-1).to(target_masked_full.dtype)
        else:
            target_masked = target_seq.new_zeros((B, 0, D))

        # ---------- (G) DENOISING ----------
        denoised_prediction = self.denoising_head(f_last_online)

        return {
            'predicted_features': predicted_features,
            'target_masked':      target_masked,
            'mask_info':          (q_proj, masked_indices),
            'denoised_prediction': denoised_prediction,
            'original_input':     x,
            'fi1_mask':           fi1_mask,
            'mask_indices':       masked_indices,
            'eps_target':         eps_lr
        }

    def set_ema_tau(self, tau: float):
        """Set the EMA coefficient used by the next `update_ema` call."""
        self.tau = float(tau)

    @torch.no_grad()
    def update_ema(self):
        """Update both frozen target modules from their online counterparts."""
        # Inside this method the name `update_ema` resolves to the module-level
        # helper function, not to this method.
        update_ema(self.target_encoder, self.context_encoder, tau=self.tau)
        update_ema(self.pixel_decoder_ema, self.pixel_decoder, tau=self.tau)


In [15]:
# import copy
# import math
# import torch
# import torch.nn.functional as F
# from torch.optim import AdamW
# from torch.optim.lr_scheduler import LambdaLR

# # Learning rate schedule (linear warmup + cosine decay)
# def lr_lambda(epoch: int) -> float:
#     if epoch < warmup_epochs:
#         # Linear warmup: scale from 0 → 1 over warmup_epochs
#         return float(epoch) / float(max(1, warmup_epochs))
#     else:
#         # Cosine decay: 1 → 0 after warmup
#         decay_epoch = epoch - warmup_epochs
#         decay_total = num_epochs - warmup_epochs
#         cosine = 0.5 * (1 + math.cos(math.pi * decay_epoch / decay_total))
#         return cosine


In [16]:
def lr_lambda(epoch: int, *, hold_epochs: int = 3, decay_rate: float = 0.7,
              min_factor: float = 0.3) -> float:
    """
    Per-epoch multiplier applied to the base learning rate.

    The schedule has three phases and reads the module-level `warmup_epochs`
    at call time:

    1. Warmup (`epoch < warmup_epochs`): linear ramp
       (epoch + 1) / (warmup_epochs + 1). The ramp deliberately starts above
       zero: a multiplier of 0 would make the whole first epoch a no-op,
       because the scheduler is stepped once per epoch.
    2. Hold: `hold_epochs` epochs at the full rate (1.0).
    3. Decay: `decay_rate ** k` for the k-th epoch after the hold phase,
       never dropping below `min_factor`.

    Args:
        epoch: zero-based epoch index.
        hold_epochs: epochs spent at full rate after warmup.
        decay_rate: per-epoch multiplicative decay factor.
        min_factor: lower bound on the returned multiplier.

    Returns:
        The multiplier for `epoch`.
    """
    if epoch < warmup_epochs:
        return float(epoch + 1) / float(warmup_epochs + 1)

    decay_start = warmup_epochs + hold_epochs
    if epoch < decay_start:
        # Aggressive phase: stay at full rate right after warmup
        return 1.0

    # Exponential decay phase, floored so the LR never collapses
    return max(decay_rate ** (epoch - decay_start), min_factor)

In [17]:
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.amp import GradScaler, autocast
import os
import matplotlib.pyplot as plt
import numpy as np
import gc
import math
import time   # for epoch timing
# bf16 capability flag (True on A100/RTX 40xx/etc). The availability check comes
# first so a CPU-only machine gets False without ever querying a CUDA device.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()


device = "cuda" if torch.cuda.is_available() else "cpu"

# Training configuration
num_epochs = 20
warmup_epochs = 0
base_lr = 1e-4
weight_decay = 0.05

# Early stopping config
early_stop_patience = 5
best_total_sc = float('inf')
epochs_no_improve = 0

# Speed settings: let cuDNN benchmark conv algorithms and allow TF32 math
# for matmuls and cuDNN convolutions.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


def enable_gradient_checkpointing(model):
    """
    Enable gradient checkpointing on the ViT blocks of both encoders.

    For `model.context_encoder` and `model.target_encoder`, every element of
    `<encoder>.vit.blocks` that exposes `set_grad_checkpointing` has it called
    with True. Encoders without a `vit` attribute, backbones without `blocks`,
    and blocks without the method are skipped.

    Returns:
        The number of blocks that were actually switched. A return value of 0
        means nothing was enabled, so callers should not report success.

    NOTE(review): some backbones may expose `set_grad_checkpointing` on the
    model object and not on individual blocks; in that case this function
    returns 0 - confirm against the encoder wrapper.
    """
    switched = 0
    for encoder in (model.context_encoder, model.target_encoder):
        backbone = getattr(encoder, 'vit', None)
        for block in getattr(backbone, 'blocks', ()):
            if hasattr(block, 'set_grad_checkpointing'):
                block.set_grad_checkpointing(True)
                switched += 1
    return switched

# Build the Mask-JEPA model, then move it to the training device
print("Creating model...")
model = MaskJEPA2D(
    in_chans=3, tau=0.996, fi1_mask_ratio=0.5,
    num_queries=8, num_cross_attn=5, num_self_attn=1,
    patch_size=8,
)
model = model.to(device)

# Embedding width; the training loop divides the logged reconstruction loss by it
D = model.embed_dim

# Ask both encoders to checkpoint their blocks
enable_gradient_checkpointing(model)
print("Gradient checkpointing enabled")

# Loss scaler used around backward() / optimizer.step()
scaler = GradScaler('cuda')

# AdamW over all model parameters; LR multiplier comes from lr_lambda
optimizer = AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

# Output directory for checkpoints and visualizations
save_dir = "./jepa_training_output"
os.makedirs(save_dir, exist_ok=True)
best_ckpt_path = os.path.join(save_dir, "best_jepa_model.pt")

def clear_memory():
    """
    Release unused Python objects and cached CUDA memory.

    The garbage collector runs first so that tensors kept alive only by
    unreachable reference cycles are freed before the CUDA cache is emptied.
    The CUDA call is skipped when no GPU is available.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Print model info
print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
print(f"Batch size: {batch_size_pretrain}")

ps = model.patch_size
mr = model.fi1_mask_ratio
nq = model.predictor.num_queries
nca = len(model.predictor.cross_blocks)
nsa = len(model.predictor.self_blocks)

print(f"Model config: patch_size={ps}, mask_ratio={mr}, queries={nq}, cross_attn={nca}, self_attn={nsa}")


def _detached_state_dict(module):
    """
    Return a copy of `module.state_dict()` whose tensors are independent CPU copies.

    `state_dict()` hands back references to the live parameters, so a snapshot
    built from it directly keeps changing as training continues. Copying the
    tensors freezes the values at the moment of the call. The mapping type and
    its `_metadata` attribute (if any) are preserved.
    """
    live = module.state_dict()
    frozen = type(live)()
    for key, value in live.items():
        frozen[key] = value.detach().to('cpu', copy=True) if torch.is_tensor(value) else value
    if hasattr(live, '_metadata'):
        frozen._metadata = live._metadata
    return frozen


def _snapshot_pretrained_weights(jepa_model):
    """Collect frozen copies of the three sub-modules exported for downstream use."""
    return {
        'backbone_state_dict': _detached_state_dict(jepa_model.context_encoder),
        'pixel_decoder_state_dict': _detached_state_dict(jepa_model.pixel_decoder),
        'transformer_decoder_cross_blocks_state_dict': _detached_state_dict(jepa_model.predictor.cross_blocks),
    }


# EMA ramp setup
planned_updates_per_epoch = len(pretrain_loader)
max_updates = num_epochs * planned_updates_per_epoch
global_update = 0

# Training loop
train_losses = []
best_snapshot = None  # frozen copy of the best weights, written again at the end

for epoch in range(num_epochs):
    epoch_start_time = time.time()   # start timer
    print(f"Epoch {epoch+1}/{num_epochs}")
    
    model.train()
    epoch_loss = 0.0
    epoch_recon_loss = 0.0
    epoch_denoise_loss = 0.0
    
    clear_memory()
    optimizer.zero_grad()
    
    for batch_idx, batch in enumerate(pretrain_loader):
        images = batch["images"].to(device, non_blocking=True)
        
        # Forward pass. Autocast is explicitly disabled here, so the model runs
        # in full precision; the GradScaler calls below are kept so that
        # switching autocast on later needs no other change.
        with autocast(device_type='cuda', enabled=False):
            outputs = model(images)
            
            pred = outputs['predicted_features']
            tgt = outputs['target_masked']
            idx = outputs['mask_indices']
            valid = (idx >= 0).unsqueeze(-1)  # False on -1 padded slots
            
            # Reconstruction loss: squared error summed over the feature dim,
            # averaged over valid (non-padded) masked positions
            if pred.numel() == 0 or valid.sum() == 0:
                recon_loss = pred.new_zeros(())
            else:
                diff = (pred - tgt) * valid
                recon_loss = diff.pow(2).sum() / valid.sum().clamp_min(1)

            # Option 2: Predict clean image x (current default)
            x4 = F.interpolate(
                outputs['original_input'],
                size=outputs['denoised_prediction'].shape[-2:],
                mode='bilinear', 
                align_corners=False
            )
            denoise_loss = F.mse_loss(outputs['denoised_prediction'], x4)
            
            # Debug info (first batch only). The model returns the clean input
            # unchanged as 'original_input', so x4 already is the probe target.
            if batch_idx == 0:
                td = x4
                pd = outputs['denoised_prediction'].detach()
                print(f"[probe] target: mean={td.mean().item():.3f} std={td.std().item():.3f}")
                print(f"[probe] pred  : mean={pd.mean().item():.3f} std={pd.std().item():.3f}")

            total_loss = recon_loss + denoise_loss
        
        # Backward pass
        scaler.scale(total_loss).backward()
        
        # Gradient clipping (on unscaled gradients)
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        # Optimizer step
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        
        # EMA update with a linear tau ramp from tau_base to tau_final
        progress = global_update / max(1, max_updates - 1)
        tau_now = model.tau_base + (model.tau_final - model.tau_base) * progress
        model.set_ema_tau(tau_now)
        model.update_ema()
        global_update += 1
        
        # Track losses
        epoch_loss += total_loss.item()
        epoch_recon_loss += recon_loss.item()
        epoch_denoise_loss += denoise_loss.item()
        
        # Cleanup
        del outputs, recon_loss, denoise_loss, total_loss, images
        
        # Progress logging (running averages; recon is divided by D to make it per-element)
        if batch_idx % 50 == 0:
            recon_sc = (epoch_recon_loss / (batch_idx + 1)) / D
            denoise_sc = (epoch_denoise_loss / (batch_idx + 1))
            total_sc = recon_sc + denoise_sc
            
            print(f"  Batch {batch_idx}/{len(pretrain_loader)} - "
                  f"TotalSc: {total_sc:.4f}, ReconSc: {recon_sc:.4f}, DenoiseSc: {denoise_sc:.4f}")
    
    # Average losses
    epoch_loss /= len(pretrain_loader)
    epoch_recon_loss /= len(pretrain_loader)
    epoch_denoise_loss /= len(pretrain_loader)
    
    train_losses.append(epoch_loss)

    # Record the LR this epoch actually trained with before the scheduler
    # advances it to the next epoch's value.
    epoch_lr = optimizer.param_groups[0]['lr']
    scheduler.step()
    
    # Epoch summary
    recon_sc_epoch = epoch_recon_loss / D
    denoise_sc_epoch = epoch_denoise_loss
    total_sc_epoch = recon_sc_epoch + denoise_sc_epoch
    
    epoch_end_time = time.time()   # end timer
    epoch_duration = epoch_end_time - epoch_start_time
    print(f"  Avg losses - TotalSc: {total_sc_epoch:.4f}, "
          f"ReconSc: {recon_sc_epoch:.4f}, DenoiseSc: {denoise_sc_epoch:.4f}")
    print(f"  LR: {epoch_lr:.2e}")
    print(f"  GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB")
    print(f"  Time for epoch {epoch+1}: {epoch_duration/60:.2f} minutes")
    
    # Periodic evaluation (every 1 epoch here)
    if (epoch + 1) % 1 == 0:
        print("  Running evaluation...")
        model.eval()
        
        with torch.no_grad():
            # Collate a small, fixed set of samples directly instead of starting
            # a new shuffled loader iterator (workers + a full batch) for 4 images.
            # Using the same samples every epoch keeps the figures comparable.
            n_vis = min(4, len(pretrain_loader.dataset))
            eval_batch = pretrain_loader.collate_fn(
                [pretrain_loader.dataset[i] for i in range(n_vis)]
            )
            eval_images = eval_batch["images"].to(device)
            
            with autocast(device_type='cuda', enabled=False):
                eval_outputs = model(eval_images)
            
            # Visualization
            H, W = eval_images.shape[-2:]
            fi1_tile = max(H // (H // 8), 1)
            
            vis_path = os.path.join(save_dir, f"reconstruction_epoch_{epoch+1:03d}.png")
            
            visualize_jepa_patch_quality(
                original=eval_images,
                predicted_features=eval_outputs['predicted_features'].float(),
                target_features=eval_outputs['target_masked'].float(),
                patch_mask=eval_outputs['fi1_mask'],
                epoch=epoch+1,
                save_path=vis_path,
                patch_size=fi1_tile
            )
            
            del eval_batch, eval_images, eval_outputs
        
        print(f"    Saved visualization: {vis_path}")
        model.train()
        clear_memory()
    
    # ---- Save ONLY when we have a new best TotalSc ----
    if total_sc_epoch < best_total_sc:
        best_total_sc = total_sc_epoch
        epochs_no_improve = 0
        # Frozen copy: later optimizer steps must not alter the stored best weights.
        best_snapshot = _snapshot_pretrained_weights(model)
        torch.save(best_snapshot, best_ckpt_path)
        print(f"    [best] New best TotalSc={best_total_sc:.4f}. Saved: {best_ckpt_path}")
    else:
        epochs_no_improve += 1
        print(f"  [early-stop] No improvement ({epochs_no_improve}/{early_stop_patience}).")
        if epochs_no_improve >= early_stop_patience:
            print(f"  [early-stop] Patience exceeded. Stopping training early.")
            break
    # --------------------------------------

print("Training completed!")

# Save final pretrained weights for downstream use (best-only)
final_weights_path = os.path.join(save_dir, 'mask_jepa_pretrained_weights.pt')
if best_snapshot is not None:
    torch.save(best_snapshot, final_weights_path)
else:
    # Fallback: save current (shouldn't happen unless no epoch produced a finite best)
    torch.save(_snapshot_pretrained_weights(model), final_weights_path)
print(f"Final pretrained weights saved (best-only): {final_weights_path}")


Creating model...
Gradient checkpointing enabled
Model parameters: 119,820,413
Batch size: 424
Model config: patch_size=8, mask_ratio=0.5, queries=8, cross_attn=5, self_attn=1
Epoch 1/20
[probe] target: mean=-0.041 std=1.135
[probe] pred  : mean=-0.009 std=0.267
  Batch 0/299 - TotalSc: 2.6922, ReconSc: 1.3513, DenoiseSc: 1.3409
  Batch 50/299 - TotalSc: 1.2750, ReconSc: 0.2401, DenoiseSc: 1.0349
  Batch 100/299 - TotalSc: 0.9345, ReconSc: 0.1465, DenoiseSc: 0.7880
  Batch 150/299 - TotalSc: 0.7808, ReconSc: 0.1086, DenoiseSc: 0.6722
  Batch 200/299 - TotalSc: 0.6919, ReconSc: 0.0913, DenoiseSc: 0.6006
  Batch 250/299 - TotalSc: 0.6364, ReconSc: 0.0858, DenoiseSc: 0.5506
  Avg losses - TotalSc: 0.6069, ReconSc: 0.0897, DenoiseSc: 0.5172
  LR: 1.00e-04
  GPU Memory: 2.06GB
  Time for epoch 1: 25.47 minutes
  Running evaluation...
    Saved visualization: ./jepa_training_output/reconstruction_epoch_001.png
    [best] New best TotalSc=0.6069. Saved: ./jepa_training_output/best_jepa_model.

KeyboardInterrupt: 

## Fine Tuning

In [None]:
# # === NaN-Fixed Mask-JEPA fine-tuning @ 224x224 ===
# import os, re, math, numpy as np, torch
# import torch.nn as nn
# import torch.nn.functional as F
# from collections import OrderedDict
# from torch.optim import AdamW
# from torch.optim.lr_scheduler import LambdaLR
# from torch.amp import GradScaler, autocast
# import matplotlib.pyplot as plt
# from tqdm import tqdm

# # Enable anomaly detection for debugging
# torch.autograd.set_detect_anomaly(False)  # Turn off for performance, enable if debugging

# # -------------------------
# # CONFIG
# # -------------------------
# device = "cuda" if torch.cuda.is_available() else "cpu"

# finetune_epochs = 40
# finetune_lr = 1e-3            # Reduced LR for stability
# finetune_weight_decay = 0.05
# num_classes = 150
# BACKBONE_IM_SIZE = 224
# IGNORE_INDEX = 255

# print(f"Fine-tuning for {finetune_epochs} epochs")
# vis_dir = "./jepa_finetuning_visualizations"
# os.makedirs(vis_dir, exist_ok=True)

# # -------------------------
# # NUMERICALLY STABLE SEGMENTATION MODEL
# # -------------------------
# class SegmentationModel(nn.Module):
#     def __init__(self, backbone, pixel_decoder, transformer_cross_blocks, num_classes, num_queries=32):
#         super().__init__()
#         self.backbone = backbone
#         self.pixel_decoder = pixel_decoder
#         self.transformer_decoder = nn.ModuleDict({'cross_blocks': transformer_cross_blocks})

#         # Ensure embed_dim is properly set
#         if not hasattr(self.pixel_decoder, "embed_dim"):
#             self.pixel_decoder.embed_dim = getattr(self.backbone, "embed_dim")
#         embed_dim = self.pixel_decoder.embed_dim

#         self.num_queries = num_queries
        
#         # STABLE: Much smaller initialization to prevent NaN
#         self.query_embed = nn.Parameter(torch.zeros(1, num_queries, embed_dim))
#         nn.init.normal_(self.query_embed, std=0.01)  # Very small std
        
#         # Add learnable positional encodings for queries
#         self.query_pos = nn.Parameter(torch.zeros(1, num_queries, embed_dim))
#         nn.init.normal_(self.query_pos, std=0.01)

#         # STABLE: Simpler, more robust heads
#         self.class_head = nn.Sequential(
#             nn.LayerNorm(embed_dim),
#             nn.Linear(embed_dim, num_classes)
#         )
        
#         self.mask_head = nn.Sequential(
#             nn.LayerNorm(embed_dim),
#             nn.Linear(embed_dim, embed_dim)
#         )

#         # Pyramid feature processing
#         self.ds16 = nn.Sequential(
#             nn.Conv2d(embed_dim, embed_dim, 3, stride=2, padding=1), 
#             nn.GroupNorm(min(32, embed_dim), embed_dim),
#             nn.GELU(),
#         )
#         self.ds32 = nn.Sequential(
#             nn.Conv2d(embed_dim, embed_dim, 3, stride=2, padding=1), 
#             nn.GroupNorm(min(32, embed_dim), embed_dim),
#             nn.GELU(),
#         )

#         # STABLE: Proper normalization for stability
#         self.kv_norm = nn.LayerNorm(embed_dim)
#         self.query_norm = nn.LayerNorm(embed_dim)

#         # STABLE: Fixed temperature (no learnable parameter to avoid instability)
#         self.register_buffer('mask_temp', torch.tensor(10.0))

#         # STABLE: Conservative initialization
#         self._init_weights()

#     def _init_weights(self):
#         # Very conservative initialization to prevent NaN
#         for m in [self.class_head, self.mask_head]:
#             for layer in m.modules():
#                 if isinstance(layer, nn.Linear):
#                     nn.init.xavier_uniform_(layer.weight, gain=0.01)  # Very small gain
#                     if layer.bias is not None:
#                         nn.init.constant_(layer.bias, 0)
        
#         # Set class head bias for better convergence  
#         if hasattr(self.class_head[-1], 'bias') and self.class_head[-1].bias is not None:
#             nn.init.constant_(self.class_head[-1].bias, -math.log(num_classes - 1))

#     @staticmethod
#     def _to_tokens(x: torch.Tensor) -> torch.Tensor:
#         """[B,D,H,W] -> [B,HW,D]"""
#         B, D, H, W = x.shape
#         return x.permute(0,2,3,1).reshape(B, H*W, D)

#     def _check_for_nan_inf(self, tensor, name):
#         """Debug helper to catch NaN/Inf"""
#         if torch.isnan(tensor).any():
#             print(f"NaN detected in {name}")
#             return False
#         if torch.isinf(tensor).any():
#             print(f"Inf detected in {name}")
#             return False
#         return True

#     def forward(self, x):
#         B, C, H_raw, W_raw = x.shape

#         # 1) Backbone
#         x_in = x if (H_raw == BACKBONE_IM_SIZE and W_raw == BACKBONE_IM_SIZE) else \
#                F.interpolate(x, size=(BACKBONE_IM_SIZE, BACKBONE_IM_SIZE), mode='bilinear', align_corners=False)

#         tokens, (enc_h, enc_w) = self.backbone(x_in)
#         feat = tokens.transpose(1, 2).reshape(B, self.backbone.embed_dim, enc_h, enc_w)
        
#         if not self._check_for_nan_inf(feat, "backbone_feat"):
#             raise ValueError("NaN/Inf in backbone features")

#         # 2) Build pyramid
#         C3 = F.interpolate(feat, size=(H_raw // 8,  W_raw // 8),  mode='bilinear', align_corners=False)
#         x16 = self.ds16(C3)
#         x32 = self.ds32(x16)
#         C4 = F.interpolate(x16, size=(H_raw // 16, W_raw // 16), mode='bilinear', align_corners=False)
#         C5 = F.interpolate(x32, size=(H_raw // 32, W_raw // 32), mode='bilinear', align_corners=False)

#         # 3) Pixel decoder
#         Fi1, _ = self.pixel_decoder([C3, C4, C5], (H_raw, W_raw))
        
#         if not self._check_for_nan_inf(Fi1, "Fi1"):
#             raise ValueError("NaN/Inf in pixel decoder output")

#         # 4) STABLE: More robust transformer decoder
#         # Simple positional encoding (no complex sin/cos to avoid numerical issues)
#         def _simple_pos_encoding(x):
#             B, D, H, W = x.shape
#             pos_h = torch.linspace(-1, 1, H, device=x.device, dtype=x.dtype).view(1, 1, H, 1).expand(B, D//2, H, W)
#             pos_w = torch.linspace(-1, 1, W, device=x.device, dtype=x.dtype).view(1, 1, 1, W).expand(B, D-D//2, H, W)
#             return torch.cat([pos_h, pos_w], dim=1) * 0.1  # Small scale factor

#         Fi1_pos = Fi1 + _simple_pos_encoding(Fi1)
#         kv_seq = self._to_tokens(Fi1_pos).contiguous()
#         kv_seq = self.kv_norm(kv_seq)
        
#         if not self._check_for_nan_inf(kv_seq, "kv_seq"):
#             raise ValueError("NaN/Inf in KV sequence")

#         # Initialize queries
#         queries = self.query_embed.expand(B, -1, -1) + self.query_pos.expand(B, -1, -1)
        
#         # STABLE: Add small noise to break symmetry without causing instability
#         if self.training:
#             queries = queries + torch.randn_like(queries) * 0.001
        
#         if not self._check_for_nan_inf(queries, "initial_queries"):
#             raise ValueError("NaN/Inf in initial queries")
        
#         # Transformer layers with careful gradient monitoring
#         for i, blk in enumerate(self.transformer_decoder['cross_blocks']):
#             queries_norm = self.query_norm(queries)
#             if not self._check_for_nan_inf(queries_norm, f"queries_norm_{i}"):
#                 raise ValueError(f"NaN/Inf in queries_norm layer {i}")
            
#             residual = queries
#             queries = blk(queries_norm, kv_seq)
            
#             if not self._check_for_nan_inf(queries, f"queries_after_block_{i}"):
#                 raise ValueError(f"NaN/Inf in queries after block {i}")
            
#             # Add residual connection with gradient clipping
#             queries = residual + queries
            
#             # Gradient clipping within forward pass to prevent explosion
#             if queries.requires_grad:
#                 queries.register_hook(lambda grad: torch.clamp(grad, -10.0, 10.0))

#         # 5) STABLE: Robust heads with numerical safeguards
#         class_logits = self.class_head(queries)
#         mask_features = self.mask_head(queries)
        
#         if not self._check_for_nan_inf(class_logits, "class_logits"):
#             raise ValueError("NaN/Inf in class_logits")
#         if not self._check_for_nan_inf(mask_features, "mask_features"):
#             raise ValueError("NaN/Inf in mask_features")

#         # STABLE: Very careful mask computation to avoid NaN in einsum
#         Fi1_for_mask = Fi1_pos
#         Fi1_flat = Fi1_for_mask.flatten(2)  # [B,D,HW']
        
#         # STABLE: L2 normalize with epsilon for numerical stability
#         eps = 1e-8
#         mask_features_norm = F.normalize(mask_features + eps, dim=-1, p=2, eps=eps)  # [B,Q,D]
#         Fi1_flat_norm = F.normalize(Fi1_flat + eps, dim=1, p=2, eps=eps)  # [B,D,HW']
        
#         if not self._check_for_nan_inf(mask_features_norm, "mask_features_norm"):
#             raise ValueError("NaN/Inf in mask_features_norm")
#         if not self._check_for_nan_inf(Fi1_flat_norm, "Fi1_flat_norm"):
#             raise ValueError("NaN/Inf in Fi1_flat_norm")
        
#         # STABLE: Use bmm instead of einsum for better numerical control
#         # einsum('bqd,bdl->bql') equivalent to bmm
#         mask_logits = torch.bmm(mask_features_norm, Fi1_flat_norm)  # [B,Q,HW']
        
#         if not self._check_for_nan_inf(mask_logits, "mask_logits_raw"):
#             raise ValueError("NaN/Inf in raw mask_logits")
        
#         # Apply temperature with clamping
#         mask_logits = mask_logits * self.mask_temp.clamp(min=0.1, max=50.0)
        
#         # STABLE: Clamp logits to prevent extreme values
#         mask_logits = torch.clamp(mask_logits, min=-50.0, max=50.0)

#         # Reshape and upsample
#         H_fi1, W_fi1 = Fi1.shape[-2], Fi1.shape[-1]
#         mask_logits = mask_logits.reshape(B, self.num_queries, H_fi1, W_fi1)
#         mask_logits = F.interpolate(mask_logits, size=(H_raw, W_raw), mode='bilinear', align_corners=False)
        
#         if not self._check_for_nan_inf(mask_logits, "final_mask_logits"):
#             raise ValueError("NaN/Inf in final mask_logits")

#         return {'class_logits': class_logits, 'mask_logits': mask_logits}

# # -------------------------
# # STABLE LOSS FUNCTION
# # -------------------------
# def stable_semantic_ce(outputs, targets, *, ignore_index=IGNORE_INDEX, num_classes=None, label_smoothing=0.0):
#     """
#     STABLE: More numerically robust mixture-of-queries loss
#     """
#     class_logits = outputs['class_logits'].float()   # [B,Q,C]
#     mask_logits  = outputs['mask_logits'].float()    # [B,Q,H,W]

#     B, Q, C = class_logits.shape
#     _, _, H, W = mask_logits.shape

#     # STABLE: Clamp inputs to prevent extreme values
#     class_logits = torch.clamp(class_logits, min=-50, max=50)
#     mask_logits = torch.clamp(mask_logits, min=-50, max=50)
    
#     # STABLE: Use softmax with temperature for better numerical properties
#     temp = 2.0  # Temperature to soften distributions
#     class_probs = F.softmax(class_logits / temp, dim=-1)  # [B,Q,C]
#     mask_probs = F.softmax(mask_logits / temp, dim=1)     # [B,Q,H,W]
    
#     # Check for NaN in probabilities
#     if torch.isnan(class_probs).any() or torch.isnan(mask_probs).any():
#         print("NaN detected in probabilities - using fallback loss")
#         # Fallback: just use class logits from first query
#         first_query_logits = class_logits[:, 0]  # [B,C]
#         first_query_logits = first_query_logits.unsqueeze(-1).unsqueeze(-1).expand(B, C, H, W)
#         return F.cross_entropy(first_query_logits, targets, ignore_index=ignore_index)
    
#     # STABLE: Mixture computation with numerical safeguards
#     class_probs_exp = class_probs.permute(0,2,1).unsqueeze(-1).unsqueeze(-1)  # [B,C,Q,1,1]
    
#     # Element-wise multiplication and sum
#     pixel_probs = (class_probs_exp * mask_probs.unsqueeze(1)).sum(dim=2)  # [B,C,H,W]
    
#     # STABLE: Add small epsilon and renormalize
#     eps = 1e-8
#     pixel_probs = pixel_probs + eps
#     pixel_probs = pixel_probs / pixel_probs.sum(dim=1, keepdim=True)
#     pixel_probs = torch.clamp(pixel_probs, min=eps, max=1.0 - eps)
    
#     # Convert to logits
#     pixel_logits = torch.log(pixel_probs / (1 - pixel_probs + eps))
#     pixel_logits = torch.clamp(pixel_logits, min=-50, max=50)
    
#     if torch.isnan(pixel_logits).any():
#         print("NaN in pixel_logits - using fallback")
#         # Another fallback
#         return F.cross_entropy(class_logits.mean(1).unsqueeze(-1).unsqueeze(-1).expand(B, C, H, W), 
#                               targets, ignore_index=ignore_index)

#     # Standard cross entropy
#     targets = targets.long()
#     if num_classes is not None:
#         ok = (targets != ignore_index)
#         if ok.any():
#             targets = torch.where(ok, targets.clamp(0, num_classes - 1), targets)

#     loss = F.cross_entropy(
#         pixel_logits, targets,
#         ignore_index=ignore_index,
#         label_smoothing=label_smoothing
#     )
    
#     return loss

# # -------------------------
# # UTILITIES (SIMPLIFIED TO AVOID ISSUES)
# # -------------------------

# def remap_and_sanitize_targets(t, num_classes=num_classes, ignore_index=IGNORE_INDEX):
#     """Clean up ADE-style labels"""
#     t = t.long()
#     has_numclass = (t == num_classes).any()
#     valid = (t != ignore_index)
#     max_lab = int(t[valid].max()) if valid.any() else -1
#     one_based = has_numclass or (max_lab == num_classes)

#     if one_based:
#         zero_mask = (t == 0)
#         t = torch.where((t >= 1) & (t <= num_classes), t - 1, t)
#         t = torch.where(zero_mask, torch.as_tensor(ignore_index, device=t.device, dtype=t.dtype), t)

#     bad_lo = (t < 0) & (t != ignore_index)
#     bad_hi = (t >= num_classes) & (t != ignore_index)
#     if bad_lo.any() or bad_hi.any():
#         t = torch.where(bad_lo | bad_hi, torch.as_tensor(ignore_index, device=t.device, dtype=t.dtype), t)
#     return t

# # -------------------------
# # VISUALIZATION FUNCTIONS
# # -------------------------
# def visualize_segmentation(images, targets, outputs, epoch, save_path):
#     """Create segmentation visualization: Original | Ground Truth | Prediction"""
#     B = min(2, images.shape[0])
    
#     # Denormalize images for visualization
#     mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(images.device)
#     std  = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(images.device)
#     images_vis = torch.clamp(images * std + mean, 0, 1)

#     # Get predictions from model outputs
#     class_logits = outputs['class_logits']  # [B,Q,C]
#     mask_logits  = outputs['mask_logits']   # [B,Q,H,W]
    
#     # For each pixel, find the query with highest mask probability
#     mask_probs = torch.sigmoid(mask_logits)  # [B,Q,H,W]
#     best_queries = mask_probs.argmax(dim=1)  # [B,H,W] - which query wins each pixel
    
#     # Get class predictions for each query
#     class_preds = class_logits.argmax(dim=-1)  # [B,Q] - class per query
    
#     # Create pixel-wise predictions
#     H, W = best_queries.shape[-2:]
#     pixel_preds = torch.zeros(B, H, W, device=targets.device, dtype=torch.long)
#     for b in range(B):
#         pixel_preds[b] = class_preds[b][best_queries[b]]

#     # Create visualization
#     fig, axes = plt.subplots(B, 3, figsize=(12, 4*B))
#     if B == 1: 
#         axes = axes.reshape(1, -1)
    
#     for i in range(B):
#         # Original image
#         img = images_vis[i].permute(1, 2, 0).detach().cpu().numpy()
#         axes[i, 0].imshow(img)
#         axes[i, 0].set_title('Original')
#         axes[i, 0].axis('off')

#         # Ground truth
#         gt = targets[i].detach().cpu().numpy()
#         gt_vis = np.where(gt == IGNORE_INDEX, 0, gt % 20)  # Mod 20 for colormap
#         im1 = axes[i, 1].imshow(gt_vis, cmap='tab20', vmin=0, vmax=19)
#         axes[i, 1].set_title('Ground Truth')
#         axes[i, 1].axis('off')

#         # Prediction
#         pred = pixel_preds[i].detach().cpu().numpy()
#         pred_vis = pred % 20  # Mod 20 for colormap consistency
#         im2 = axes[i, 2].imshow(pred_vis, cmap='tab20', vmin=0, vmax=19)
#         axes[i, 2].set_title('Prediction')
#         axes[i, 2].axis('off')

#     plt.suptitle(f'Segmentation Results - Epoch {epoch}')
#     plt.tight_layout()
#     plt.savefig(save_path, dpi=100, bbox_inches='tight')
#     plt.close()

# def compute_miou(outputs, targets, num_classes):
#     """Compute mean IoU for evaluation"""
#     class_logits = outputs['class_logits']   # [B,Q,C]
#     mask_logits  = outputs['mask_logits']    # [B,Q,H,W]

#     mask_probs = torch.sigmoid(mask_logits)
#     best_queries = mask_probs.argmax(dim=1)  # [B,H,W]

#     B, H, W = best_queries.shape
#     pixel_preds = torch.zeros(B, H, W, device=targets.device, dtype=torch.long)
#     for b in range(B):
#         class_preds = class_logits[b].argmax(dim=-1)  # [Q]
#         pixel_preds[b] = class_preds[best_queries[b]]

#     ious = []
#     for cls in range(min(20, num_classes)):  # Limit to 20 classes for computational efficiency
#         pred_mask = (pixel_preds == cls)
#         target_mask = (targets == cls)
#         inter = (pred_mask & target_mask).sum().float()
#         union = (pred_mask | target_mask).sum().float()
#         if union > 0:
#             ious.append((inter / union).item())
#     return float(np.mean(ious)) if ious else 0.0

# @torch.no_grad()
# def evaluate_model(model, val_loader, num_classes):
#     """Evaluate model on validation set"""
#     model.eval()
#     total_miou, num_batches = 0.0, 0
#     max_eval_batches = min(30, len(val_loader))
    
#     for batch_idx, batch in enumerate(val_loader):
#         if batch_idx >= max_eval_batches:
#             break
#         try:
#             images = batch['images'].to(device)
#             masks = remap_and_sanitize_targets(batch['masks'].to(device), num_classes)
            
#             outputs = model(images)
#             miou = compute_miou(outputs, masks, num_classes)
#             total_miou += miou
#             num_batches += 1
#         except Exception as e:
#             print(f"Error in evaluation batch {batch_idx}: {e}")
#             continue
    
#     return total_miou / num_batches if num_batches > 0 else 0.0

# # Placeholder for load_pretrained_weights - use your existing implementation
# def load_pretrained_weights(checkpoint_path, prefer_branch="online"):
#     print(f"Loading pretrained weights from {checkpoint_path}")
#     ckpt = torch.load(checkpoint_path, map_location='cpu')
    
#     # Create modules (you need to ensure these classes are defined)
#     backbone = ContextEncoder2D(model_name="vit_tiny_patch16_224", pretrained=False)
#     pixel_decoder = PixelDecoder2D(in_channels=backbone.embed_dim, embed_dim=backbone.embed_dim)
#     embed_dim = backbone.embed_dim
    
#     # Load cross blocks
#     saved_x_sd = ckpt.get("transformer_decoder_cross_blocks_state_dict", None)
#     n_x = 9  # or infer from saved state
#     transformer_cross_blocks = nn.ModuleList([CrossAttentionBlock2D(embed_dim, num_heads=8) for _ in range(n_x)])
    
#     # Load weights (simplified - you can use your detailed version)
#     if "backbone_state_dict" in ckpt:
#         backbone.load_state_dict(ckpt["backbone_state_dict"], strict=True)
#     if "pixel_decoder_state_dict" in ckpt:
#         pixel_decoder.load_state_dict(ckpt["pixel_decoder_state_dict"], strict=False)
#     if saved_x_sd:
#         transformer_cross_blocks.load_state_dict(saved_x_sd, strict=False)
    
#     return backbone, pixel_decoder, transformer_cross_blocks

# # -------------------------
# # MAIN TRAINING LOOP
# # -------------------------
# print("="*60)
# print("LOADING MASK-JEPA PRETRAINED WEIGHTS")
# print("="*60)

# pretrained_path = "/home/sks6nv/Projects/RL-JEPA/jepa_training_output//mask_jepa_pretrained_weights.pt"
# backbone, pixel_decoder, transformer_cross_blocks = load_pretrained_weights(pretrained_path)

# # Use fewer queries initially for stability
# model = SegmentationModel(backbone, pixel_decoder, transformer_cross_blocks, num_classes, num_queries=32).to(device)

# # STABLE: More conservative parameter groups
# head_params = []
# head_params.extend(list(model.class_head.parameters()))
# head_params.extend(list(model.mask_head.parameters()))
# head_params.extend([model.query_embed, model.query_pos])

# pretrained_params = []
# pretrained_params.extend(list(model.backbone.parameters()))
# pretrained_params.extend(list(model.pixel_decoder.parameters()))
# pretrained_params.extend(list(model.transformer_decoder['cross_blocks'].parameters()))

# # STABLE: Much more conservative learning rates
# optimizer = AdamW([
#     {"params": head_params, "lr": finetune_lr * 2, "weight_decay": finetune_weight_decay},
#     {"params": pretrained_params, "lr": finetune_lr * 0.1, "weight_decay": finetune_weight_decay * 0.1},
# ], betas=(0.9, 0.999), eps=1e-8)

# scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.1, (1 - epoch / finetune_epochs) ** 0.5))
# scaler = GradScaler(enabled=torch.cuda.is_available())

# print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
# print("Starting training...")

# best_miou = 0.0

# for epoch in range(finetune_epochs):
#     model.train()
#     epoch_loss, num_batches = 0.0, 0
#     max_batches = min(150, len(downstream_train_loader))
    
#     pbar = tqdm(downstream_train_loader, desc=f"Epoch {epoch+1}/{finetune_epochs}", total=max_batches)

#     for batch_idx, batch in enumerate(pbar):
#         if batch_idx >= max_batches: 
#             break
            
#         try:
#             images = batch['images'].to(device)
#             masks = remap_and_sanitize_targets(batch['masks'].to(device), num_classes)

#             optimizer.zero_grad(set_to_none=True)
            
#             with autocast('cuda', enabled=torch.cuda.is_available()):
#                 outputs = model(images)
                
#                 # STABLE: Simple loss without complex regularization initially
#                 loss = stable_semantic_ce(outputs, masks, ignore_index=IGNORE_INDEX,
#                                         num_classes=num_classes, label_smoothing=0.0)

#             # Check for finite loss
#             if not torch.isfinite(loss):
#                 print(f"Non-finite loss at batch {batch_idx}: {loss.item()}")
#                 continue

#             scaler.scale(loss).backward()
#             scaler.unscale_(optimizer)
            
#             # STABLE: More aggressive gradient clipping
#             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            
#             scaler.step(optimizer)
#             scaler.update()

#             epoch_loss += float(loss.item())
#             num_batches += 1
#             pbar.set_postfix({'loss': f"{loss.item():.4f}"})

#         except Exception as e:
#             print(f"Error in batch {batch_idx}: {e}")
#             # Clear any corrupted gradients
#             optimizer.zero_grad(set_to_none=True)
#             continue

#     scheduler.step()
#     avg_loss = epoch_loss / num_batches if num_batches > 0 else 0
#     cur_lr = optimizer.param_groups[0]['lr']
#     print(f"\nEpoch {epoch+1}/{finetune_epochs} - Loss: {avg_loss:.4f}, LR: {cur_lr:.2e}")

#     # Visualization every epoch
#     if (epoch + 1) % 1 == 0:
#         print("  Creating visualizations...")
#         model.eval()
#         try:
#             with torch.no_grad():
#                 vis_batch = next(iter(downstream_val_loader))
#                 vis_images = vis_batch['images'][:2].to(device)
#                 vis_masks = remap_and_sanitize_targets(vis_batch['masks'][:2].to(device), num_classes)
#                 vis_outputs = model(vis_images)
                
#                 vis_path = os.path.join(vis_dir, f"seg_epoch_{epoch+1:03d}.png")
#                 visualize_segmentation(vis_images, vis_masks, vis_outputs, epoch+1, vis_path)
#                 print(f"    Saved visualization: {vis_path}")
#         except Exception as e:
#             print(f"    Visualization failed: {e}")
#         model.train()

#     # Evaluation every 5 epochs
#     if (epoch + 1) % 5 == 0:
#         print("  Running evaluation...")
#         try:
#             miou = evaluate_model(model, downstream_val_loader, num_classes)
#             print(f"  Validation mIoU: {miou:.4f}")
#             if miou > best_miou:
#                 best_miou = miou
#                 best_checkpoint_path = f"./jepa_training_output/mask_jepa_finetuned_best.pt"
#                 torch.save({
#                     'model_state_dict': model.state_dict(),
#                     'best_miou': best_miou,
#                     'epoch': epoch + 1
#                 }, best_checkpoint_path)
#                 print(f"  ✓ New best model saved! mIoU: {best_miou:.4f}")
#         except Exception as e:
#             print(f"  Evaluation failed: {e}")

#     # Checkpoint saving
#     if (epoch + 1) % 10 == 0:
#         try:
#             checkpoint_path = f"./jepa_training_output/mask_jepa_finetuned_epoch_{epoch+1}.pt"
#             torch.save({
#                 'model_state_dict': model.state_dict(),
#                 'epoch': epoch + 1,
#                 'loss': avg_loss
#             }, checkpoint_path)
#             print(f"  ✓ Checkpoint saved: {checkpoint_path}")
#         except Exception as e:
#             print(f"  Failed to save checkpoint: {e}")

# print("Fine-tuning completed!")

In [18]:
# ==== Fine-tuning: Fusion head (Fi1 + F_last), CE-only, with fast GPU metrics & batch prints ====
import os, gc, math, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.amp import GradScaler, autocast
import matplotlib.pyplot as plt

# Run on the GPU when torch reports one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Label value that fix_masks() writes over out-of-range mask entries; it equals the
# default `ignore_index` of StreamingSegMetrics, so those pixels are left out of metrics.
IGNORE_INDEX = 255
# Number of segmentation classes; fix_masks() treats labels 0 .. NUM_CLASSES-1 as valid.
NUM_CLASSES  = 150

# -----------------------
# Helpers
# -----------------------
def _gn_groups(C):
    for g in (32,16,8,4,2,1):
        if C % g == 0: return g
    return 1

class DWSepResBlock(nn.Module):
    """Depthwise-separable residual block that preserves channel count and spatial size.

    Layout: depthwise 3x3 conv (optionally dilated) -> GroupNorm -> ReLU ->
    pointwise 1x1 conv -> GroupNorm, then the input is added back and a final
    ReLU is applied.

    Args:
        channels: number of input (and output) channels.
        dilation: dilation of the depthwise conv; padding is set to the same
            value so the 3x3 conv keeps the spatial size unchanged.
    """
    def __init__(self, channels: int, dilation: int = 1):
        super().__init__()
        norm_groups = _gn_groups(channels)
        # groups=channels makes the 3x3 conv depthwise (one filter per channel).
        self.dw = nn.Conv2d(
            channels, channels, kernel_size=3,
            padding=dilation, dilation=dilation,
            groups=channels, bias=False,
        )
        self.dw_g = nn.GroupNorm(norm_groups, channels)
        self.pw = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
        self.pw_g = nn.GroupNorm(norm_groups, channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        depthwise = self.dw_g(self.dw(x))
        depthwise = self.act(depthwise)
        pointwise = self.pw_g(self.pw(depthwise))
        # Residual connection; the in-place ReLU acts on the fresh sum, not on x.
        return self.act(x + pointwise)

class SpatialGate(nn.Module):
    """Per-pixel sigmoid gate multiplied onto the input feature map.

    A depthwise 3x3 conv followed by a pointwise 1x1 conv collapses the features
    to a single-channel map; its sigmoid is broadcast across all channels of ``x``.
    """
    def __init__(self, channels: int):
        super().__init__()
        # depthwise 3x3 (groups=channels), then 1x1 projection down to one gate channel
        self.dw = nn.Conv2d(channels, channels, kernel_size=3, padding=1,
                            groups=channels, bias=False)
        self.pw = nn.Conv2d(channels, 1, kernel_size=1, bias=True)

    def forward(self, x):
        gate_logits = self.pw(self.dw(x))
        gate = torch.sigmoid(gate_logits)
        return x * gate

class ChannelSE(nn.Module):
    """Squeeze-and-excite style channel attention.

    The input is averaged over its spatial dimensions, passed through a
    two-layer 1x1-conv bottleneck (ReLU in between, sigmoid at the end), and
    the resulting per-channel weights rescale the input.

    Args:
        channels: number of input/output channels.
        r: reduction ratio; the bottleneck width is ``max(1, channels // r)``.
    """
    def __init__(self, channels: int, r: int = 8):
        super().__init__()
        hidden = max(1, channels // r)
        self.fc1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=True)
        self.fc2 = nn.Conv2d(hidden, channels, kernel_size=1, bias=True)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        # squeeze: global average over H and W, kept as [B, C, 1, 1]
        squeezed = x.mean(dim=(2, 3), keepdim=True)
        hidden = self.act(self.fc1(squeezed))
        weights = torch.sigmoid(self.fc2(hidden))
        return x * weights

class FusionSegHead(nn.Module):
    """
    Segmentation head that fuses two feature maps into per-class logits.

    Pipeline: Fi1 is bilinearly resized to F_last's spatial size (per the
    original author's note, Fi1 is ~s/8 and F_last ~s/4), both are reduced to
    ``mid_channels`` with 1x1 conv + GroupNorm + ReLU, concatenated, fused with
    a depthwise 3x3 + pointwise 1x1, refined by three DW-separable residual
    blocks (dilation 1, 2, 4), passed through a spatial gate and channel SE,
    and finally projected to ``num_classes`` with a 1x1 conv.

    Args:
        in_channels: channel count of both ``fi1`` and ``flast``.
        mid_channels: working width inside the head.
        num_classes: number of output logit channels.
    """
    def __init__(self, in_channels: int, mid_channels: int, num_classes: int):
        super().__init__()
        fused_channels = 2 * mid_channels
        self.fi1_reduce   = self._conv1x1_gn_relu(in_channels, mid_channels)
        self.flast_reduce = self._conv1x1_gn_relu(in_channels, mid_channels)
        # fusion: depthwise 3x3 over the concatenated maps, then pointwise 1x1 back to mid_channels
        self.fuse_dw = nn.Conv2d(fused_channels, fused_channels, kernel_size=3, padding=1,
                                 groups=fused_channels, bias=False)
        self.fuse_pw = self._conv1x1_gn_relu(fused_channels, mid_channels)
        # refinement with growing dilation
        self.refine1 = DWSepResBlock(mid_channels, dilation=1)
        self.refine2 = DWSepResBlock(mid_channels, dilation=2)
        self.refine3 = DWSepResBlock(mid_channels, dilation=4)
        # lightweight spatial + channel attention, then the classifier
        self.spatial = SpatialGate(mid_channels)
        self.channel = ChannelSE(mid_channels, r=8)
        self.cls     = nn.Conv2d(mid_channels, num_classes, 1)

    @staticmethod
    def _conv1x1_gn_relu(c_in, c_out):
        """Bias-free 1x1 conv -> GroupNorm -> in-place ReLU."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 1, bias=False),
            nn.GroupNorm(_gn_groups(c_out), c_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, fi1, flast, out_hw):
        """Return logits of shape [B, num_classes, H, W] where (H, W) = ``out_hw``."""
        out_h, out_w = out_hw
        # Bring fi1 onto flast's grid so the two maps can be concatenated.
        fi1_resized = F.interpolate(fi1, size=flast.shape[-2:], mode='bilinear', align_corners=False)
        fused = torch.cat([self.fi1_reduce(fi1_resized), self.flast_reduce(flast)], dim=1)
        fused = self.fuse_pw(self.fuse_dw(fused))
        stages = (self.refine1, self.refine2, self.refine3, self.spatial, self.channel)
        for stage in stages:
            fused = stage(fused)
        logits_lowres = self.cls(fused)
        return F.interpolate(logits_lowres, size=(out_h, out_w), mode='bilinear', align_corners=False)

class JEPASegmentationModel(nn.Module):
    """
    Segmentation model assembled from parts of a pretrained JEPA model.

    Reuses ``context_encoder``, ``pixel_decoder``, ``ds16`` and ``ds32`` from
    ``backbone_model`` and attaches a freshly initialised FusionSegHead
    (no ASPP, no auxiliary head).

    Args:
        backbone_model: object exposing ``context_encoder``, ``pixel_decoder``,
            ``embed_dim``, ``ds16`` and ``ds32``.
        num_classes: number of output classes.
        mid_channels: working width of the fusion head.
    """
    def __init__(self, backbone_model, num_classes=150, mid_channels=128):
        super().__init__()
        self.backbone = backbone_model.context_encoder
        self.pixel_decoder = backbone_model.pixel_decoder
        self.embed_dim = backbone_model.embed_dim
        self.ds16 = backbone_model.ds16
        self.ds32 = backbone_model.ds32
        self.head = FusionSegHead(self.embed_dim, mid_channels, num_classes)

    def forward(self, x):
        batch, _, height, width = x.shape
        tokens, (enc_h, enc_w) = self.backbone(x)                 # [B, P, D]
        # tokens -> channel-first feature map [B, D, enc_h, enc_w]
        feat = tokens.transpose(1, 2).reshape(batch, self.embed_dim, enc_h, enc_w)

        def _to_stride(t, stride):
            # bilinear resize to the (H // stride, W // stride) grid
            return F.interpolate(t, size=(height // stride, width // stride),
                                 mode='bilinear', align_corners=False)

        # pyramid as in pretrain
        C3  = _to_stride(feat, 8)
        x16 = self.ds16(C3)
        x32 = self.ds32(x16)
        C4  = _to_stride(x16, 16)
        C5  = _to_stride(x32, 32)

        Fi1, F_last = self.pixel_decoder([C3, C4, C5], (height, width))    # Fi1 ~ s/8, F_last ~ s/4
        return self.head(Fi1, F_last, (height, width))                     # [B, K, H, W]

def fix_masks(m):
    """Replace every label outside ``[0, NUM_CLASSES)`` with ``IGNORE_INDEX``.

    Returns a new int64 tensor; the input tensor is not modified.
    """
    out_of_range = (m < 0) | (m >= NUM_CLASSES)
    ignore_fill = torch.full_like(m, IGNORE_INDEX)
    cleaned = torch.where(out_of_range, ignore_fill, m)
    return cleaned.long()

class StreamingSegMetrics:
    """Streaming confusion matrix for mIoU / mean Dice, accumulated on ``device``.

    ``conf[t, p]`` counts pixels whose target label is ``t`` and whose predicted
    label is ``p``. Call ``update`` once per batch, ``get`` to read the metrics,
    and ``reset`` to start a new accumulation.

    Args:
        num_classes: number of classes C; the matrix is C x C.
        ignore_index: target value excluded from the statistics.
        device: where the matrix lives; defaults to "cuda" when available, else "cpu".
    """
    def __init__(self, num_classes, ignore_index=255, device=None):
        self.C = num_classes
        self.ignore = ignore_index
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.conf = torch.zeros((self.C, self.C), dtype=torch.float64, device=self.device)
    @torch.no_grad()
    def update(self, logits, target):
        """Accumulate one batch.

        Args:
            logits: [B, C, ...] scores; the prediction is ``argmax`` over dim 1.
            target: integer labels with the same shape as the prediction.

        Targets equal to ``ignore_index`` or outside ``[0, C)`` are skipped.
        """
        pred = logits.argmax(1)
        # Cast to int64 first: with a uint8 mask, `tgt * C` would wrap around and
        # silently put counts into the wrong bins.
        tgt = target.to(device=pred.device, dtype=torch.long)
        # Out-of-range labels are dropped as well; otherwise bincount would either
        # raise (negative index) or return more than C*C bins and break the reshape.
        valid = (tgt != self.ignore) & (tgt >= 0) & (tgt < self.C)
        if valid.any():
            pred = pred[valid]
            tgt  = tgt[valid]
            idx = tgt * self.C + pred            # flat index into the C x C matrix
            bins = torch.bincount(idx, minlength=self.C*self.C).reshape(self.C, self.C)
            self.conf += bins.to(self.conf)      # match dtype and device of the accumulator
    @torch.no_grad()
    def get(self):
        """Return ``(mIoU, mDice)`` as Python floats.

        Classes that never appear in either targets or predictions are excluded
        from the means; both values are NaN if nothing has been accumulated.
        """
        h = self.conf
        diag = torch.diag(h)
        sum_row = h.sum(1)
        sum_col = h.sum(0)
        denom_iou = sum_row + sum_col - diag
        denom_dice = sum_row + sum_col
        # Explicit NaN tensor (instead of a Python scalar) keeps torch.where
        # independent of scalar-overload support across torch versions.
        nan = torch.full_like(diag, float("nan"))
        iou  = torch.where(denom_iou > 0, diag / denom_iou, nan)
        dice = torch.where(denom_dice > 0, (2*diag) / denom_dice, nan)
        miou  = torch.nanmean(iou).item()
        mdice = torch.nanmean(dice).item()
        return miou, mdice
    def reset(self):
        """Zero the confusion matrix in place."""
        self.conf.zero_()

def visualize_segmentation(images, true_masks, logits, epoch, save_path, num_samples=4):
    """Save a 3-row figure (input / ground truth / prediction) to ``save_path``.

    Args:
        images: batch of normalized images; ``images[i]`` is de-normalized with
            the hard-coded per-channel mean/std below before display.
        true_masks: per-pixel integer labels for the batch.
        logits: class scores; the prediction shown is ``logits.argmax(1)``.
        epoch: number shown in the figure title.
        save_path: output path handed to ``savefig``.
        num_samples: number of columns; at most ``images.size(0)`` are filled.

    Labels outside ``[0, NUM_CLASSES)`` are drawn black.
    """
    pred = logits.argmax(1)
    n_show = min(num_samples, images.size(0))
    # squeeze=False keeps `axes` 2-D, so axes[row, col] also works when
    # num_samples == 1 (the default squeeze would return a 1-D array).
    fig, axes = plt.subplots(3, num_samples, figsize=(16, 12), squeeze=False)
    try:
        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=images.device).view(3, 1, 1)
        # One RGB colour per class (tab20 cycled), so each mask is coloured with
        # a single lookup instead of a Python loop over every class.
        palette = np.array([plt.cm.tab20(cls % 20)[:3] for cls in range(NUM_CLASSES)])

        def colorize(labels):
            """Turn an (H, W) integer label map into an (H, W, 3) RGB array."""
            rgb = np.zeros((*labels.shape, 3))
            known = (labels >= 0) & (labels < NUM_CLASSES)
            rgb[known] = palette[labels[known]]
            return rgb

        for i in range(n_show):
            img = torch.clamp(images[i].float() * std + mean, 0, 1).permute(1, 2, 0).cpu().numpy()
            axes[0, i].imshow(img)
            axes[0, i].set_title(f"Original {i+1}")
            axes[0, i].axis('off')

            axes[1, i].imshow(colorize(true_masks[i].cpu().numpy()))
            axes[1, i].set_title(f"Ground Truth {i+1}")
            axes[1, i].axis('off')

            axes[2, i].imshow(colorize(pred[i].cpu().numpy()))
            axes[2, i].set_title(f"Prediction {i+1}")
            axes[2, i].axis('off')

        # Hide the columns that have no sample (batch smaller than num_samples).
        for col in range(n_show, num_samples):
            for row in range(3):
                axes[row, col].axis('off')

        fig.suptitle(f"Epoch {epoch} - Segmentation Results", fontsize=16)
        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure even if drawing/saving failed, so figures
        # do not pile up across epochs.
        plt.close(fig)

# -----------------------
# One-time ADE sanity peek
# -----------------------
# Best-effort look at the label values in one training batch; any failure is
# reported and the cell carries on.
try:
    peek = next(iter(downstream_train_loader))
    peek_labels = torch.unique(peek["masks"])
    print("ADE20K uniques (peek):", peek_labels[:20].cpu())
except Exception as peek_err:
    print("ADE peek skipped:", peek_err)

# -----------------------
# Load JEPA parts & build model
# -----------------------
print("Loading pretrained JEPA model...")
jepa_model = MaskJEPA2D(
    in_chans=3,
    tau=0.996,
    fi1_mask_ratio=0.5,
    num_queries=50,
    num_cross_attn=5,
    num_self_attn=1,
    patch_size=8,
).to(device)

# Restore the pretrained sub-modules when a checkpoint is present; the ds16 /
# ds32 entries are optional and only loaded if the checkpoint contains them.
weights_path = "./jepa_training_output/jepa_model_epoch_1.pt"
if not os.path.exists(weights_path):
    print("WARNING: no JEPA weights found; FT from scratch")
else:
    ckpt = torch.load(weights_path, map_location=device)
    jepa_model.context_encoder.load_state_dict(ckpt['backbone_state_dict'])
    jepa_model.pixel_decoder.load_state_dict(ckpt['pixel_decoder_state_dict'])
    has_ds16 = 'ds16_state_dict' in ckpt
    has_ds32 = 'ds32_state_dict' in ckpt
    if has_ds16:
        jepa_model.ds16.load_state_dict(ckpt['ds16_state_dict'])
    if has_ds32:
        jepa_model.ds32.load_state_dict(ckpt['ds32_state_dict'])
    print("Loaded: context_encoder, pixel_decoder",
          "+ ds16" if has_ds16 else "",
          "+ ds32" if has_ds32 else "")

# Wrap the JEPA parts in the segmentation model and switch it to channels_last.
model = JEPASegmentationModel(jepa_model, num_classes=NUM_CLASSES, mid_channels=128)
model = model.to(device).to(memory_format=torch.channels_last)

# -----------------------
# Train config
# -----------------------
# Backend switches: allow TF32 matmul/conv kernels and let cuDNN autotune.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# Fine-tuning hyper-parameters.
num_epochs_ft = 40      # number of fine-tuning epochs
base_lr_ft = 2e-4       # base learning rate, scaled each epoch by cosine()
weight_decay = 0.01
print_every = 50        # batches between progress prints (CE, mIoU, Dice)
GRAD_ACCUM = 1          # >1 splits each batch into micro-batches to save memory


def cosine(epoch):
    """Return the LR multiplier for ``epoch``: a half-cosine over ``num_epochs_ft``."""
    phase = np.pi * epoch / num_epochs_ft
    return 0.5 * (1 + np.cos(phase))


optimizer = AdamW(model.parameters(), lr=base_lr_ft, weight_decay=weight_decay)
scheduler = LambdaLR(optimizer, lr_lambda=cosine)
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
scaler = GradScaler('cuda')

# Output directory for checkpoints and visualizations.
ft_save_dir = "./jepa_finetuning_output"
os.makedirs(ft_save_dir, exist_ok=True)

print("Starting fine-tuning (Fusion head, CE only)...")
print(f"Train batches: {len(downstream_train_loader)}, Val batches: {len(downstream_val_loader)}")
print(f"Model trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Best validation mIoU so far and the per-epoch history lists.
best_miou = 0.0
train_ce_hist = []
val_ce_hist = []
val_miou_hist = []
val_dice_hist = []

# Fine-tuning loop: one training pass, one validation pass, best-checkpoint
# saving and a visualization per epoch.
for epoch in range(num_epochs_ft):
    model.train()
    epoch_ce_accum = 0.0

    # per-epoch meters
    # NOTE(review): epoch_meter is accumulated below but never read in this
    # loop; only print_meter feeds the batch log.
    epoch_meter = StreamingSegMetrics(NUM_CLASSES, IGNORE_INDEX, device=device)
    print_meter = StreamingSegMetrics(NUM_CLASSES, IGNORE_INDEX, device=device)

    for bidx, batch in enumerate(downstream_train_loader):
        images = batch["images"].to(device, non_blocking=True).to(memory_format=torch.channels_last)
        masks  = batch["masks"].to(device, non_blocking=True)
        # Sanitize the labels once per batch; the loss and both meters reuse it.
        targets = fix_masks(masks)

        optimizer.zero_grad(set_to_none=True)

        # micro-batching if GRAD_ACCUM > 1
        num_micro = max(1, GRAD_ACCUM)
        batch_size = images.size(0)
        chunk = math.ceil(batch_size / num_micro)

        loss_sum = 0.0
        for i in range(0, batch_size, chunk):
            imgs = images[i:i+chunk]
            tgts = targets[i:i+chunk]
            with autocast('cuda'):
                logits = model(imgs)
                # Weight each micro-batch by its share of the batch. Dividing by
                # num_micro under-scales the gradient whenever the batch splits
                # into fewer (or uneven) chunks than GRAD_ACCUM.
                loss   = criterion(logits, tgts) * (imgs.size(0) / batch_size)

            # update GPU metrics before backward (keeps VRAM usage similar)
            with torch.no_grad():
                epoch_meter.update(logits, tgts)
                print_meter.update(logits, tgts)

            scaler.scale(loss).backward()
            loss_sum += float(loss.detach())

        # Unscale before clipping so the norm threshold applies to true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer); scaler.update()

        epoch_ce_accum += loss_sum

        if (bidx % print_every) == 0:
            miou_b, dice_b = print_meter.get()  # only 2 scalar syncs
            print(f"  Batch {bidx:4d}/{len(downstream_train_loader)} | CE: {loss_sum:.4f} | mIoU: {miou_b:.4f} | Dice: {dice_b:.4f}")
            print_meter.reset()

        del images, masks, targets, logits, loss

    epoch_ce = epoch_ce_accum / max(1, len(downstream_train_loader))
    train_ce_hist.append(epoch_ce)
    scheduler.step()

    # ---- Validation ----
    model.eval()
    val_ce = 0.0
    val_meter = StreamingSegMetrics(NUM_CLASSES, IGNORE_INDEX, device=device)
    with torch.no_grad():
        for batch in downstream_val_loader:
            images = batch["images"].to(device, non_blocking=True).to(memory_format=torch.channels_last)
            masks  = batch["masks"].to(device, non_blocking=True)
            targets = fix_masks(masks)
            with autocast('cuda'):
                logits = model(images)
                ce = criterion(logits, targets)
            val_ce += ce.item()
            val_meter.update(logits, targets)
            del images, masks, targets, logits, ce

    val_ce /= max(1, len(downstream_val_loader))
    mean_miou, mean_dice = val_meter.get()
    val_ce_hist.append(val_ce); val_miou_hist.append(mean_miou); val_dice_hist.append(mean_dice)

    # The LR printed here is the scheduler's value after this epoch's step().
    print(f"Epoch {epoch+1:02d}/{num_epochs_ft} | Train CE: {epoch_ce:.4f} | Val CE: {val_ce:.4f} | mIoU: {mean_miou:.4f} | Dice: {mean_dice:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")

    # Save best by mIoU
    if mean_miou > best_miou:
        best_miou = mean_miou
        best_path = os.path.join(ft_save_dir, "best_segmentation_model.pt")
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'miou': mean_miou,
            'dice': mean_dice
        }, best_path)
        print(f"  New best mIoU! Saved -> {best_path}")

    # Visualizations every epoch (raise the modulus to save less often)
    if (epoch + 1) % 1 == 0:
        with torch.no_grad():
            vis_batch = next(iter(downstream_val_loader))
            vis_images = vis_batch["images"].to(device).to(memory_format=torch.channels_last)
            vis_masks  = vis_batch["masks"].to(device)
            with autocast('cuda'):
                vis_logits = model(vis_images)
            vis_path = os.path.join(ft_save_dir, f"seg_epoch_{epoch+1:03d}.png")
            visualize_segmentation(vis_images, fix_masks(vis_masks), vis_logits, epoch+1, vis_path)
            print(f"  Saved visualization: {vis_path}")
            del vis_batch, vis_images, vis_masks, vis_logits

    torch.cuda.empty_cache(); gc.collect()

print("Fine-tuning completed.")
print(f"Best mIoU: {best_miou:.4f}")

# Save final: the model's current weights together with the recorded
# per-epoch histories and the best validation mIoU.
final_path = os.path.join(ft_save_dir, "final_segmentation_model.pt")
final_payload = dict(
    model_state_dict=model.state_dict(),
    train_ce_losses=train_ce_hist,
    val_ce_losses=val_ce_hist,
    val_mious=val_miou_hist,
    val_dices=val_dice_hist,
    best_miou=best_miou,
)
torch.save(final_payload, final_path)
print(f"Final model saved: {final_path}")


ADE20K uniques (peek): tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
Loading pretrained JEPA model...
Starting fine-tuning (Fusion head, CE only)...
Train batches: 422, Val batches: 42
Model trainable params: 78,578,721
  Batch    0/422 | CE: 5.0341 | mIoU: 0.0003 | Dice: 0.0007
  Batch   50/422 | CE: 3.2219 | mIoU: 0.0078 | Dice: 0.0128
  Batch  100/422 | CE: 2.7778 | mIoU: 0.0150 | Dice: 0.0220
  Batch  150/422 | CE: 2.5199 | mIoU: 0.0230 | Dice: 0.0327
  Batch  200/422 | CE: 2.2715 | mIoU: 0.0281 | Dice: 0.0387
  Batch  250/422 | CE: 2.1032 | mIoU: 0.0324 | Dice: 0.0453
  Batch  300/422 | CE: 2.2173 | mIoU: 0.0363 | Dice: 0.0511
  Batch  350/422 | CE: 2.0572 | mIoU: 0.0426 | Dice: 0.0597
  Batch  400/422 | CE: 2.0698 | mIoU: 0.0459 | Dice: 0.0646
Epoch 01/40 | Train CE: 2.5257 | Val CE: 1.9439 | mIoU: 0.0478 | Dice: 0.0665 | LR: 2.00e-04
  New best mIoU! Saved -> ./jepa_finetuning_output/best_segmentation_model.pt
  Saved visualizat

KeyboardInterrupt: 