## Data Preprocessing

In [1]:
import os

# Set before importing torch so it is in place before any CUDA context is created.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # syncs CUDA so you see the true Python stack

import shutil
import socket
import urllib.request
import zipfile

import torch
from datasets import load_dataset

# NOTE(review): anomaly detection is a debugging aid that slows every backward pass;
# disable it for real training runs.
torch.autograd.set_detect_anomaly(True)   # pinpoints the backward op

# Check if datasets already exist

imagenet_path = os.path.join('.', 'imagenet-100')
ade_path = os.path.join('.', 'ADEChallengeData2016')

print("Loading datasets...")

# Load ImageNet-100 from the Hugging Face Hub once and cache it under ./imagenet-100/
if os.path.exists(imagenet_path):
    print("ImageNet-100 already exists. Skipping download.")
else:
    print("Downloading ImageNet-100...")
    imagenet_saved = False
    try:
        img100 = load_dataset("clane9/imagenet-100")
        img100.save_to_disk(imagenet_path)
        imagenet_saved = True
        print("ImageNet-100 saved to ./imagenet-100/")
        print(f"Train samples: {len(img100['train'])}")
        print(f"Val samples: {len(img100['validation'])}")
    except Exception as e:
        print(f"ImageNet-100 failed: {e}")
        if not imagenet_saved:
            # Drop a half-written copy; otherwise the existence check above would
            # skip the download on every later run and load a broken dataset.
            shutil.rmtree(imagenet_path, ignore_errors=True)

# Download ADE20K manually if not exists
if os.path.exists(ade_path):
    print("ADE20K already exists. Skipping download.")
else:
    print("Downloading ADE20K...")
    url = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
    zip_path = "ADEChallengeData2016.zip"
    # The socket timeout is process-global; scope it to this download only.
    previous_timeout = socket.getdefaulttimeout()
    socket.setdefaulttimeout(60)
    try:
        urllib.request.urlretrieve(url, zip_path)

        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall('.')
        print("ADE20K downloaded")
    except Exception as e:
        print(f"ADE20K download failed: {e}")
        # Remove a partially extracted tree so the next run retries instead of skipping.
        shutil.rmtree(ade_path, ignore_errors=True)
    finally:
        socket.setdefaulttimeout(previous_timeout)
        # Never leave the (large, possibly truncated) archive behind.
        if os.path.exists(zip_path):
            try:
                os.remove(zip_path)
            except OSError as e:
                print(f"Could not remove {zip_path}: {e}")

print("Ready.")

Loading datasets...
ImageNet-100 already exists. Skipping download.
ADE20K already exists. Skipping download.
Ready.


In [2]:
import os
import glob
import numpy as np
import torch
import random
from PIL import Image
from datasets import load_from_disk
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from math import ceil

# Reproducibility: seed every RNG this notebook draws from (Python, NumPy, PyTorch).
seed = 1337
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

IMG_SIZE = 384  # square image side length, in pixels

# =============================================================================
# JEPA ImageNet Dataset for Pretraining
# =============================================================================
class JEPADataset(Dataset):
    """Unlabeled ImageNet-100 images for JEPA pretraining.

    Loads one split of a dataset saved with ``datasets.save_to_disk`` and
    returns, per item, a dict with a single key ``"image"`` holding an
    augmented (random resized crop + horizontal flip), ImageNet-normalised
    RGB tensor of side ``img_size``.
    """

    def __init__(self, root="./imagenet-100", split="train", img_size=384):
        all_splits = load_from_disk(root)
        self.dataset = all_splits[split]
        self.img_size = img_size

        augmentations = [
            transforms.RandomResizedCrop(img_size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        to_normalized_tensor = [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(augmentations + to_normalized_tensor)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        rgb_image = record["image"].convert("RGB")
        return {"image": self.transform(rgb_image)}

# =============================================================================
# ADE20K Dataset for Segmentation
# =============================================================================
class ADE20KDataset(Dataset):
    """ADE20K (ADEChallengeData2016) image/mask pairs for semantic segmentation.

    Expects the official archive layout: ``<root>/images/<split>/*.jpg`` with
    matching ``<root>/annotations/<split>/<stem>.png``. Images without a
    matching annotation are skipped. Each item is ``(image, mask)``:

    * image: float tensor [3, img_size, img_size], ImageNet-normalised
      (bilinear resize).
    * mask:  int64 tensor [img_size, img_size] holding the raw PNG label ids.
      Nearest-neighbour resize keeps ids intact, and no label shift is applied
      here.
    """
    def __init__(self, root="ADEChallengeData2016", split="training", img_size=384):
        self.img_dir = os.path.join(root, "images", split)
        self.ann_dir = os.path.join(root, "annotations", split)
        self.img_size = img_size
        self.items = []

        # Find image-mask pairs (same file stem in images/ and annotations/)
        for img_path in glob.glob(os.path.join(self.img_dir, "*.jpg")):
            stem = os.path.splitext(os.path.basename(img_path))[0]
            ann_path = os.path.join(self.ann_dir, stem + ".png")
            if os.path.exists(ann_path):
                self.items.append((img_path, ann_path))
        self.items.sort()  # deterministic order independent of filesystem listing

        # Image transforms. Use the InterpolationMode enum: passing PIL integer
        # constants to torchvision's Resize is deprecated.
        self.img_transform = transforms.Compose([
            transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        img_path, ann_path = self.items[idx]

        # Context managers close file handles deterministically. That matters
        # with many DataLoader workers each opening thousands of files.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        with Image.open(ann_path) as ann_file:
            # NEAREST so label ids are never blended into new values
            mask = ann_file.resize((self.img_size, self.img_size), resample=Image.NEAREST)

        img = self.img_transform(img)
        mask = torch.from_numpy(np.array(mask, dtype="int64"))

        return img, mask

def jepa_collate(batch):
    """Collate JEPA samples: stack each sample's ``"image"`` tensor along a new
    leading batch dim and return it under the key ``"images"``."""
    stacked = torch.stack([sample["image"] for sample in batch], dim=0)
    return dict(images=stacked)

def ade_collate(batch):
    """Collate ``(image, mask)`` pairs into ``{"images": [B,...], "masks": [B,...]}``."""
    images, masks = zip(*batch)
    return dict(images=torch.stack(images), masks=torch.stack(masks))

## Setup Dataloader

## Helper functions

In [3]:
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from typing import Optional
import numpy as np
import matplotlib.pyplot as plt


def compute_patch_grid(image_shape, patch_size):
    """
    Describe how a channel-first image is tiled into square, non-overlapping patches.

    image_shape: (C, H, W), exactly three entries.
    Returns (n_h, n_w, P, (H_crop, W_crop)):
      n_h = H // patch_size, n_w = W // patch_size, P = n_h * n_w, and
      (H_crop, W_crop) is the area covered by whole patches (remainder dropped).
    """
    _, height, width = image_shape
    rows, cols = height // patch_size, width // patch_size
    return rows, cols, rows * cols, (rows * patch_size, cols * patch_size)


def extract_patch_embeddings_from_feature_map(feats: torch.Tensor) -> torch.Tensor:
    """
    Flatten a channel-first feature map into a token sequence.

    feats: [N, D, n_h, n_w] -> returns [N, P, D] with P = n_h * n_w, where
    token p corresponds to spatial cell (p // n_w, p % n_w).
    """
    batch, channels, _, _ = feats.shape
    channels_last = feats.movedim(1, -1)  # [N, n_h, n_w, D]
    return channels_last.reshape(batch, -1, channels)


def compute_denoising_loss(self, denoised_prediction, original_input):
    """
    MSE between a denoised prediction and the input resized to the prediction's
    spatial size (bilinear, align_corners=False).

    NOTE(review): this is a module-level function, yet its first parameter is
    ``self``, which is never used. Callers must pass a placeholder or bind it as
    a method; confirm which is intended.

    Args:
        denoised_prediction: tensor whose last two dims give the target (H, W).
        original_input: 4-D (N, C, H, W) tensor (bilinear interpolation needs 4-D).
    Returns:
        Scalar mean-squared error.
    """
    target_hw = denoised_prediction.shape[-2:]
    resized_target = F.interpolate(
        original_input,
        size=target_hw,
        mode='bilinear',
        align_corners=False,
    )
    return F.mse_loss(denoised_prediction, resized_target)

def compute_reconstruction_loss(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Mean-squared error between predicted and target embeddings (mean over all elements)."""
    return F.mse_loss(input=preds, target=targets, reduction="mean")

@torch.no_grad()
def update_ema(target_net: nn.Module, online_net: nn.Module, tau: float):
    """
    In-place EMA step over parameters only (buffers are untouched):
        target <- tau * target + (1 - tau) * online
    Parameters are paired positionally via ``.parameters()`` order.
    """
    online_weight = 1 - tau
    for ema_param, live_param in zip(target_net.parameters(), online_net.parameters()):
        ema_param.data.mul_(tau)
        ema_param.data.add_(live_param.data, alpha=online_weight)

def unpatchify_embeddings(emb: torch.Tensor, n_h: int, n_w: int) -> torch.Tensor:
    """
    Inverse of patch flattening: [N, P, D] tokens -> contiguous [N, D, n_h, n_w] map.
    Token p is placed at row p // n_w, column p % n_w; requires P == n_h * n_w.
    """
    batch, _num_patches, dim = emb.shape
    grid = emb.view(batch, n_h, n_w, dim)
    return grid.movedim(-1, 1).contiguous()

def generate_fi1_mask(fi1_shape: tuple, mask_ratio: float = 0.5, patch_size: int = 8, device='cuda'):
    """
    Build a random block mask over a Fi1 feature map.

    The (H8, W8) map is tiled into ``patch_size x patch_size`` blocks. For each
    sample, ``int(mask_ratio * n_blocks)`` blocks are chosen uniformly at random
    (one ``torch.randperm`` draw per sample), and every position inside a chosen
    block is masked. Rows/columns left over when H8 or W8 is not a multiple of
    ``patch_size`` are never masked.

    Args:
        fi1_shape: (B, D, H8, W8); only B, H8 and W8 are used.
        mask_ratio: fraction of whole blocks to mask (rounded down).
        patch_size: block side length in Fi1 positions (not image pixels).
        device: device for the returned mask and for the random draws.

    Returns:
        Bool tensor [B, H8 * W8], row-major over (H8, W8); True = masked.
    """
    B, _, H8, W8 = fi1_shape

    n_patches_h = H8 // patch_size
    n_patches_w = W8 // patch_size
    total_patches = n_patches_h * n_patches_w
    num_masked = int(mask_ratio * total_patches)

    # Pick masked blocks per sample (same RNG stream as one randperm per sample).
    block_mask = torch.zeros(B, total_patches, dtype=torch.bool, device=device)
    for b in range(B):
        masked_patch_ids = torch.randperm(total_patches, device=device)[:num_masked]
        block_mask[b, masked_patch_ids] = True

    # Expand each block flag to its patch_size x patch_size footprint with tensor
    # ops. A per-pixel Python loop indexing with device tensors would force a
    # host/device sync per element on CUDA.
    covered = (
        block_mask.view(B, n_patches_h, n_patches_w)
        .repeat_interleave(patch_size, dim=1)
        .repeat_interleave(patch_size, dim=2)
    )
    fi1_mask = torch.zeros(B, H8, W8, dtype=torch.bool, device=device)
    fi1_mask[:, : n_patches_h * patch_size, : n_patches_w * patch_size] = covered

    return fi1_mask.view(B, H8 * W8)  # [B, H8*W8]

def apply_fi1_mask_tokens(fi1_features: torch.Tensor, fi1_mask: torch.Tensor, mask_token: torch.Tensor):
    """
    Swap masked Fi1 positions for a learned mask token.

    Args:
        fi1_features: (B, D, H8, W8) Fi1 feature maps.
        fi1_mask: (B, H8*W8) bool, row-major over (H8, W8); True = replace.
        mask_token: (1, D, 1, 1) learned token, broadcast to every masked position.

    Returns:
        (B, D, H8, W8) tensor equal to ``fi1_features`` except at masked positions,
        which hold ``mask_token``.
    """
    batch, channels, height, width = fi1_features.shape
    full_shape = (batch, channels, height, width)

    spatial_mask = fi1_mask.reshape(batch, 1, height, width).expand(full_shape)
    token_map = mask_token.expand(full_shape)
    return torch.where(spatial_mask, token_map, fi1_features)

def visualize_jepa_patch_quality(
    original: torch.Tensor,
    predicted_features: torch.Tensor,
    target_features: torch.Tensor,
    patch_mask: torch.Tensor,
    epoch: int,
    save_path: str,
    patch_size: int = 16,
):
    """
    Save a 3-panel diagnostic figure for the first sample of a JEPA batch.

    Panels: (1) the input image; (2) masked patches tinted red; (3) each masked
    patch coloured by its relative prediction error (green = lowest MSE in this
    image, red = highest).

    Args:
        original: image batch; only ``original[0]`` is drawn. Channel-first
            (C, H, W) with C in {1, 3}, or a 2-D (H, W) map.
        predicted_features: [B, M, D] predictions for masked patches.
        target_features: [B, M, D] targets aligned row-for-row with the predictions.
        patch_mask: per-sample mask flattened to [P]; nonzero = masked. Index p
            maps to grid cell divmod(p, W // patch_size).
        epoch: shown in the centre panel title.
        save_path: output path passed to ``fig.savefig``.
        patch_size: side length, in displayed pixels, of one mask cell.

    Notes:
        The k-th prediction row is paired with the k-th masked index of
        ``patch_mask[0]`` (ascending order). Errors are min-max normalised
        within this image, so colours are relative and not comparable across
        epochs.
    """

    # ----- robust image-to-display (returns H x W x 3 in [0, 1]) -----
    def _to_display_img(x: torch.Tensor) -> np.ndarray:
        # float32 first: .numpy() rejects bfloat16 (e.g. autocast outputs).
        x = x.detach().cpu().float()
        if x.ndim == 3 and x.shape[0] in (1, 3):
            xc = x.clone()
            mn, mx = float(xc.min()), float(xc.max())
            if 0.0 <= mn and mx <= 1.0:
                pass  # already [0,1]
            elif -3.5 <= mn <= 3.5 and -3.5 <= mx <= 3.5:
                # assume ImageNet norm
                mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
                std  = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
                xc = xc * std + mean
            else:
                # min-max to [0,1]
                xc = (xc - mn) / (max(mx - mn, 1e-6))
            img = xc.permute(1, 2, 0).numpy()
        else:
            arr = x.numpy()
            arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-6)
            img = arr
        # The overlays below write RGB colours into img[..., :]. Grayscale (H, W)
        # or (H, W, 1) images would fail there, so promote them to 3 channels.
        if img.ndim == 2:
            img = np.stack([img, img, img], axis=-1)
        elif img.ndim == 3 and img.shape[-1] == 1:
            img = np.repeat(img, 3, axis=-1)
        # NOTE(review): a (C, H, W) input with C not in {1, 3} is still returned
        # channel-first and will not render correctly.
        return np.clip(img, 0.0, 1.0)

    # ---- prep first image ----
    original_np = _to_display_img(original[0])
    H, W = original_np.shape[:2]
    n_h, n_w = H // patch_size, W // patch_size

    # ---- per-masked-tile losses (float32 so bf16/fp16 inputs convert cleanly) ----
    pred0 = predicted_features[0].detach().float().cpu().numpy()   # [M, D]
    tgt0  = target_features[0].detach().float().cpu().numpy()      # [M, D]
    if pred0.size == 0:
        per_patch_losses = np.zeros((0,), dtype=np.float32)
    else:
        diff = pred0 - tgt0
        per_patch_losses = (diff * diff).mean(axis=-1)     # [M]

    # Min-max normalise within this image and invert: 1 = best, 0 = worst.
    if per_patch_losses.size > 0:
        lo, hi = float(per_patch_losses.min()), float(per_patch_losses.max())
        denom = (hi - lo) if (hi > lo) else 1.0
        normalized_quality = 1.0 - ((per_patch_losses - lo) / denom)
    else:
        normalized_quality = np.zeros((0,), dtype=np.float32)

    # ---- masked indices ----
    mask0 = patch_mask[0].detach().cpu().view(-1)          # [P]
    masked_indices = mask0.nonzero(as_tuple=False).squeeze(1).numpy()  # [K]

    # ---- figure ----
    fig, axs = plt.subplots(1, 3, figsize=(14, 5))

    # Left: original
    axs[0].imshow(original_np, interpolation='nearest')
    axs[0].set_title("Original Image")
    axs[0].axis('off')

    # Center: mask overlay (red)
    masked_img = original_np.copy()
    overlay = masked_img.copy()
    red = np.array([1.0, 0.0, 0.0], dtype=np.float32)

    for pidx in masked_indices:
        if pidx < 0 or pidx >= n_h * n_w:
            continue
        ih, iw = divmod(int(pidx), n_w)
        h0, h1 = ih * patch_size, (ih + 1) * patch_size
        w0, w1 = iw * patch_size, (iw + 1) * patch_size
        overlay[h0:h1, w0:w1, :] = red

    alpha_center = 0.35
    masked_img = (1 - alpha_center) * masked_img + alpha_center * overlay
    axs[1].imshow(np.clip(masked_img, 0.0, 1.0), interpolation='nearest')
    axs[1].set_title(f"Epoch {epoch} - JEPA Analysis\nMasked Patches (Red)\n{int(mask0.sum())}/{mask0.numel()} masked")
    axs[1].axis('off')

    # Right: reconstruction quality (bold colored squares)
    quality_img = original_np.copy()
    colormap = plt.get_cmap('RdYlGn')  # green=good, red=bad
    alpha_patch = 0.85                 # strong overlay for bold squares
    grid_thick = max(1, patch_size // 16)  # grid line thickness in pixels

    # Pair the k-th quality value with the k-th masked index.
    limit = min(len(normalized_quality), len(masked_indices))
    for idx in range(limit):
        patch_idx = int(masked_indices[idx])
        if patch_idx < 0 or patch_idx >= n_h * n_w:
            continue

        ih, iw = divmod(patch_idx, n_w)
        h0, h1 = ih * patch_size, (ih + 1) * patch_size
        w0, w1 = iw * patch_size, (iw + 1) * patch_size

        q = float(np.asarray(normalized_quality[idx]).mean())
        if not np.isfinite(q):
            q = 0.0
        q = float(np.clip(q, 0.0, 1.0))

        color = np.asarray(colormap(q))[:3]  # (3,)
        patch = quality_img[h0:h1, w0:w1, :]
        quality_img[h0:h1, w0:w1, :] = (1 - alpha_patch) * patch + alpha_patch * color[None, None, :]

        # black grid lines on the top and left edge of each coloured patch
        quality_img[h0:h0+grid_thick, w0:w1, :] = 0.0
        quality_img[h0:h1, w0:w0+grid_thick, :] = 0.0

    axs[2].imshow(np.clip(quality_img, 0.0, 1.0), interpolation='nearest')
    axs[2].set_title("Reconstruction Quality\n(Green=Good, Red=Poor)")
    axs[2].axis('off')

    plt.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)



## Models

#### Patch Embed

In [4]:
# --- PatchEmbed2D: strided-conv patchify + LayerNorm; no positional embedding is added here ---
class PatchEmbed2D(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A Conv2d with kernel = stride = ``patch_size`` maps [B, C, H, W] to a
    [B, D, H/p, W/p] grid. That grid is flattened row-major to [B, P, D] tokens
    and LayerNorm-ed. No positional embedding is added.
    """

    def __init__(self, in_chans: int, embed_dim: int, patch_size: int):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        patch_grid = self.proj(x)                                # [B, D, n_h, n_w]
        tokens = patch_grid.flatten(start_dim=2).transpose(1, 2)  # [B, P, D]
        return self.norm(tokens)


#### Context Encoder

In [5]:
# 2D image → Swin backbone (patch embed + hierarchical stages) → final-stage 2D feature map → flattened tokens

In [6]:
import math
import torch
import torch.nn as nn
import timm

class ContextEncoder2D(nn.Module):
    """
    timm Swin backbone used as the JEPA context encoder.

    By default this builds Swin Transformer v2 Tiny ("swinv2_tiny_window8_256")
    at img_size=384, with the classifier head and global pooling removed.
    ``forward`` returns the final-stage patch tokens (no CLS) and their token
    grid size (Ht, Wt).
    """
    def __init__(
        self,
        model_name: str = "swinv2_tiny_window8_256",
        pretrained: bool = True,
        img_size: int = 384,
        strict_img_size: bool = False,
        dynamic_img_pad: bool = True,
    ):
        super().__init__()
        # Build the timm Swin model at `img_size` and allow dynamic padding
        self.swin = timm.create_model(
            model_name,
            pretrained=pretrained,
            img_size=img_size,          # <- run at 384 by default
            num_classes=0,              # no classifier head
            global_pool='',             # keep token grid
            features_only=False,
            strict_img_size=strict_img_size,
            dynamic_img_pad=dynamic_img_pad,
        )
        self.embed_dim = self.swin.num_features
        self.patch_size = 4  # Swin uses patch4 embed

    @staticmethod
    def _infer_grid(num_tokens: int, in_h: int, in_w: int):
        """Factor ``num_tokens`` into (Ht, Wt), Ht * Wt == num_tokens, choosing the
        factorisation whose aspect ratio best matches the input's in_h / in_w.
        For square inputs with a perfect-square token count this is (sqrt, sqrt)."""
        if num_tokens <= 0:
            return 0, 0
        target_h = math.sqrt(num_tokens * in_h / max(in_w, 1))
        best_h = min(
            (h for h in range(1, num_tokens + 1) if num_tokens % h == 0),
            key=lambda h: abs(h - target_h),
        )
        return best_h, num_tokens // best_h

    def forward(self, x: torch.Tensor):
        """
        x: [B, C, H, W]  (H,W multiples of 32 recommended; 384 works)
        returns:
          tokens: [B, P, D]  (CLS-free)
          (Ht, Wt): token grid size at the final stage (~ H/32, W/32)
        """
        feats = self.swin.forward_features(x)   # [B, L, D] or sometimes [B, Ht, Wt, D] / [B, D, Ht, Wt]

        if feats.dim() == 3:                    # [B, L, D]
            tokens = feats
            # The grid shape is not carried by a flat sequence. Recover it from
            # the token count and the input's aspect ratio; assuming a square
            # grid is wrong for non-square inputs.
            Ht, Wt = self._infer_grid(tokens.shape[1], x.shape[-2], x.shape[-1])
        else:
            # Handle both [B, D, Ht, Wt] and [B, Ht, Wt, D]
            if feats.shape[1] == self.embed_dim:        # [B, D, Ht, Wt]
                B, D, Ht, Wt = feats.shape
                tokens = feats.permute(0, 2, 3, 1).reshape(B, Ht * Wt, D)
            else:                                       # [B, Ht, Wt, D]
                B, Ht, Wt, D = feats.shape
                tokens = feats.reshape(B, Ht * Wt, D)

        return tokens, (Ht, Wt)


#### Pixel Decoder

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional

def _simple_pos_embed_2d(x: torch.Tensor) -> torch.Tensor:
    """Simple 2D positional embedding - just scaled coordinates"""
    B, D, H, W = x.shape
    device, dtype = x.device, x.dtype
    
    # Create coordinate grids
    y_coords = torch.linspace(0, 1, H, device=device, dtype=dtype)
    x_coords = torch.linspace(0, 1, W, device=device, dtype=dtype) 
    yy, xx = torch.meshgrid(y_coords, x_coords, indexing='ij')
    
    # Simple embedding: just use x,y coordinates repeated
    pos_embed = torch.stack([xx, yy], dim=0).unsqueeze(0)  # [1, 2, H, W]
    pos_embed = pos_embed.expand(B, -1, -1, -1)            # [B, 2, H, W]
    
    # Repeat to match channel dimension
    pos_embed = pos_embed.repeat(1, D//2, 1, 1)            # [B, D, H, W]
    if pos_embed.shape[1] < D:
        pos_embed = F.pad(pos_embed, (0, 0, 0, 0, 0, D - pos_embed.shape[1]))
    
    return pos_embed * 0.1  # Scale down to not overwhelm features

class PixelDecoder2D(nn.Module):
    """
    FPN-style pixel decoder producing the two feature maps used by Mask-JEPA.

    It takes three backbone levels [C3, C4, C5] (expected strides 1/8, 1/16, 1/32)
    and projects each to ``embed_dim`` (1x1 conv + GroupNorm + ReLU). It then adds
    a fixed coordinate embedding, fuses top-down (C5 -> C4 -> C3) and emits:
      * Fi1    [B, D, H/8, W/8] from the fused 1/8 level, and
      * F_last [B, D, H/4, W/4] from that level, bilinearly upsampled to 1/4.

    Args:
        in_channels: channel count of every input level, unless the keyword
            ``in_channels_per_level=(c3, c4, c5)`` is given.
        embed_dim: output channel width D.
        *args, **kwargs: accepted for call-site compatibility; only
            ``in_channels_per_level`` is read, everything else is ignored.
    """

    @staticmethod
    def _num_groups(channels: int) -> int:
        """GroupNorm group count: 1 below 32 channels, otherwise gcd(32, channels).

        This equals 32 for any multiple of 32. GroupNorm requires
        ``channels % num_groups == 0``, so a fixed 32 would raise for widths such
        as 48 or 80.
        """
        from math import gcd
        return 1 if channels < 32 else gcd(32, channels)

    def __init__(self,
                 in_channels: int,
                 embed_dim: int,
                 *args, **kwargs):
        super().__init__()
        self.embed_dim = embed_dim
        groups = self._num_groups(embed_dim)

        # Handle per-level channels (C3, C4, C5 might be different)
        in_chs: Optional[Tuple[int,int,int]] = kwargs.get("in_channels_per_level", None)
        if in_chs is None:
            in_chs = (in_channels, in_channels, in_channels)

        # Project each level to common embedding dimension
        self.lateral_convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c, embed_dim, kernel_size=1),
                nn.GroupNorm(groups, embed_dim),
                nn.ReLU(inplace=True)
            ) for c in in_chs
        ])

        # FPN fusion layers (reduce aliasing during upsampling).
        # NOTE: forward() uses only fpn_convs[0] and fpn_convs[1]; fpn_convs[2] is
        # kept so the module's state_dict keys stay unchanged.
        self.fpn_convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1),
                nn.GroupNorm(groups, embed_dim),
                nn.ReLU(inplace=True)
            ) for _ in range(3)
        ])

        # Output heads
        self.fi1_head = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1),
            nn.GroupNorm(groups, embed_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
        )

        self.flast_head = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1),
            nn.GroupNorm(groups, embed_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
        )

    def forward(self, feats_multi: List[torch.Tensor], input_hw: Tuple[int, int]):
        """
        Simple FPN forward pass.

        feats_multi: [C3, C4, C5] at strides [1/8, 1/16, 1/32]
        input_hw: (H, W) full resolution size

        Returns:
            Fi1: [B, D, H/8, W/8] - feature at stride 8
            F_last: [B, D, H/4, W/4] - feature at stride 4
        """
        H, W = input_hw
        c3, c4, c5 = feats_multi

        # 1. Lateral connections - project to common dimension
        p5 = self.lateral_convs[2](c5)  # [B, D, H/32, W/32]
        p4 = self.lateral_convs[1](c4)  # [B, D, H/16, W/16]
        p3 = self.lateral_convs[0](c3)  # [B, D, H/8, W/8]

        # Add simple positional encoding
        p5 = p5 + _simple_pos_embed_2d(p5)
        p4 = p4 + _simple_pos_embed_2d(p4)
        p3 = p3 + _simple_pos_embed_2d(p3)

        # 2. Top-down pathway (FPN fusion)
        # P5 -> P4
        p5_up = F.interpolate(p5, size=p4.shape[-2:], mode='bilinear', align_corners=False)
        p4 = p4 + p5_up
        p4 = self.fpn_convs[1](p4)

        # P4 -> P3
        p4_up = F.interpolate(p4, size=p3.shape[-2:], mode='bilinear', align_corners=False)
        p3 = p3 + p4_up
        p3 = self.fpn_convs[0](p3)  # This is our main feature at 1/8 stride

        # 3. Generate outputs
        # Fi1 at stride 8 (same as p3)
        Fi1 = self.fi1_head(p3)  # [B, D, H/8, W/8]

        # F_last at stride 4 (upsample p3)
        p3_upsampled = F.interpolate(p3, size=(H // 4, W // 4), mode='bilinear', align_corners=False)
        F_last = self.flast_head(p3_upsampled)  # [B, D, H/4, W/4]

        return Fi1, F_last

#### PredictorHead

In [10]:
import torch
import torch.nn as nn

class CrossAttentionBlock2D(nn.Module):
    """Post-norm transformer block in which a query set attends to a feature sequence.

    Layout: x = LN(x + CrossAttn(x, feats, feats)); x = LN(x + FFN(x)).
    Batch-first inputs: queries [B, Q, D], features [B, N, D] -> [B, Q, D].
    """

    def __init__(self, embed_dim: int, num_heads: int = 8):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Sub-module creation order fixes the parameter-init RNG sequence; keep it stable.
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        hidden_dim = 4 * embed_dim  # 4x FFN expansion
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, queries, features):
        context = self.cross_attn(queries, features, features)[0]
        queries = self.norm1(queries + context)
        return self.norm2(queries + self.ffn(queries))

class SelfAttentionBlock2D(nn.Module):
    """Post-norm transformer encoder block that refines a query set in place.

    Layout: x = LN(x + SelfAttn(x)); x = LN(x + FFN(x)).
    Batch-first input/output: [B, Q, D].
    """

    def __init__(self, embed_dim: int, num_heads: int = 8):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Sub-module creation order fixes the parameter-init RNG sequence; keep it stable.
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        hidden_dim = 4 * embed_dim  # 4x FFN expansion
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, queries):
        mixed = self.self_attn(queries, queries, queries)[0]
        queries = self.norm1(queries + mixed)
        return self.norm2(queries + self.ffn(queries))

In [11]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class Predictor2D(nn.Module):
    """
    JEPA predictor: learnable queries read a (partially masked) Fi1 map through
    cross-attention and are then routed back to the masked Fi1 positions.

    Steps in ``forward``:
      1. Replace masked Fi1 positions with ``kv_mask_token`` and add a fixed 2D
         sin/cos signal. This gives the key/value sequence [B, H8*W8, D].
      2. Run ``L`` pre-norm cross-attention blocks (queries attend to K/V),
         then ``M`` ``nn.TransformerEncoderLayer`` self-attention blocks.
      3. Project the queries with ``proj``. Each masked position's prediction is
         a softmax-weighted mix of the projected queries. The weights come from
         scaled dot products between that position's K/V token and every query.
    """
    def __init__(self,
                 embed_dim: int,
                 num_queries: int,
                 num_heads: int = None,
                 # accept both old and new arg names:
                 num_cross_blocks: int = None,
                 num_self_blocks: int = None,
                 num_cross_attn: int = None,
                 num_self_attn: int = None):
        """
        Args:
            embed_dim: channel width D shared by Fi1, K/V and the queries.
            num_queries: number Q of learnable queries.
            num_heads: attention heads. If None, uses the first of
                (16, 12, 8, 6, 4, 3, 2, 1) that divides ``embed_dim``.
            num_cross_attn / num_cross_blocks: number L of cross-attention
                blocks. ``num_cross_attn`` wins if both are given; default 9.
            num_self_attn / num_self_blocks: number M of self-attention
                blocks. ``num_self_attn`` wins if both are given; default 2.
        Raises:
            AssertionError: if ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_queries = num_queries

        # choose a valid num_heads if not provided
        if num_heads is None:
            for h in (16, 12, 8, 6, 4, 3, 2, 1):
                if embed_dim % h == 0:
                    num_heads = h
                    break
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads

        # harmonize naming: prefer *attn if provided, else *blocks, else paper defaults (9/2)
        if num_cross_attn is not None:
            L = num_cross_attn
        elif num_cross_blocks is not None:
            L = num_cross_blocks
        else:
            L = 9
        if num_self_attn is not None:
            M = num_self_attn
        elif num_self_blocks is not None:
            M = num_self_blocks
        else:
            M = 2

        # learnable queries [1, Q, D], expanded over the batch in forward()
        self.query_embed = nn.Parameter(torch.zeros(1, num_queries, embed_dim))
        nn.init.trunc_normal_(self.query_embed, std=0.02)

        # Cross-attention blocks (Mask2Former-ish: norm -> cross-attn -> add, norm -> FFN -> add)
        # NOTE: sub-module creation order determines parameter init under a fixed seed.
        self.cross_blocks = nn.ModuleList([
            nn.ModuleDict(dict(
                attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True),
                ffn  = nn.Sequential(
                    nn.Linear(embed_dim, 4*embed_dim),
                    nn.ReLU(inplace=True),
                    nn.Linear(4*embed_dim, embed_dim)
                ),
                norm1 = nn.LayerNorm(embed_dim),
                norm2 = nn.LayerNorm(embed_dim),
            )) for _ in range(L)
        ])

        # Extra self-attention blocks on queries (nn.TransformerEncoderLayer with its
        # default hyper-parameters apart from d_model/nhead/batch_first)
        self.self_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
            for _ in range(M)
        ])

        # projection head f_L to map query outputs back to Fi1 embedding space
        self.proj = nn.Linear(embed_dim, embed_dim)

        # learnable mask token for masked Fi1 tiles (used to build K/V when Fi1 is masked);
        # stored as [1, 1, D] and viewed as [1, D, 1, 1] in forward()
        self.kv_mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.kv_mask_token, std=0.02)

    @staticmethod
    def _add_2d_sincos_pos(feat_2d: torch.Tensor):
        """
        feat_2d: [B, D, H, W] -> returns [B, H*W, D] with fixed 2D sin/cos pos enc added

        The first D // 2 channels all receive sin(x) and the remaining channels all
        receive cos(y), with x, y in [-1, 1]. So only two distinct positional
        values exist per location.
        """
        import torch.nn.functional as F
        B, D, H, W = feat_2d.shape
        device = feat_2d.device
    
        # build sincos grid
        y = torch.linspace(-1, 1, steps=H, device=device)
        x = torch.linspace(-1, 1, steps=W, device=device)
        yy, xx = torch.meshgrid(y, x, indexing='ij')
        pos = torch.stack([xx, yy], dim=-1).reshape(1, H, W, 2)  # [1,H,W,2]
    
        # project to channel dim D using sin/cos
        half = D // 2
        sin_in = pos[..., 0:1].repeat(1, 1, 1, half)
        cos_in = pos[..., 1:2].repeat(1, 1, 1, D - half)
        pos_embed = torch.cat([torch.sin(sin_in), torch.cos(cos_in)], dim=-1)  # [1,H,W,D]
        # NOTE(review): half + (D - half) == D, so this pad branch never triggers.
        if pos_embed.shape[-1] != D:
            pos_embed = F.pad(pos_embed, (0, D - pos_embed.shape[-1]))[:, :, :, :D]
    
        # >>> key fix: match [B, D, H, W] before addition
        pos_embed = pos_embed.permute(0, 3, 1, 2)  # [1, D, H, W]
    
        feat = feat_2d + pos_embed.to(feat_2d.dtype)  # [B, D, H, W]
        feat = feat.flatten(2).transpose(1, 2)        # [B, H*W, D]
        return feat

    def forward(self, Fi1_online: torch.Tensor, Fi1_mask: torch.Tensor):
        """
        Fi1_online: [B, D, H8, W8]   (online pixel-decoder Fi1)
        Fi1_mask:   [B, H8*W8] bool  (True=masked positions in Fi1 to reconstruct)
        Returns:
           pred_masked_feats: [B, M, D] predictions for masked Fi1 positions (M = #masked)
           masked_indices:    [B, M] indices of the masked Fi1 positions (or -1 padded)
           query_feats:       [B, Q, D] final query embeddings

        M is the largest per-sample masked count. Samples with fewer masked
        positions are padded with zero rows in pred_masked_feats and -1 in
        masked_indices; consumers must skip index -1. Predictions for a sample
        follow ascending masked-index order.
        """
        B, D, H8, W8 = Fi1_online.shape
        Q = self.num_queries

        # Build K/V from Fi1 with 2D sincos pos; replace masked tiles with kv_mask_token (NOT zeros)
        kv = Fi1_online.clone()  # [B,D,H8,W8]
        if Fi1_mask is not None:
            mask_2d = Fi1_mask.reshape(B, H8, W8).unsqueeze(1).expand(-1, D, -1, -1)  # [B,D,H8,W8]
            kv = torch.where(mask_2d, self.kv_mask_token.view(1, D, 1, 1).expand(B, D, H8, W8), kv)

        kv_seq = self._add_2d_sincos_pos(kv)  # [B, H8*W8, D] as K/V

        # Queries
        q = self.query_embed.expand(B, Q, D)  # [B,Q,D]

        # L cross-attention blocks (pre-norm residual: attn then FFN)
        for blk in self.cross_blocks:
            qn = blk['norm1'](q)
            attn_out, _ = blk['attn'](qn, kv_seq, kv_seq)  # cross-attn to Fi1+pos
            q = q + attn_out
            q = q + blk['ffn'](blk['norm2'](q))

        # M self-attn blocks on queries
        for sblk in self.self_blocks:
            q = sblk(q)

        # map to Fi1 embedding space
        q_proj = self.proj(q)  # [B,Q,D]

        # Route query embeddings to masked tiles (simple attention routing):
        # every Fi1 position gets a softmax distribution over the Q queries.
        scores = torch.einsum('bpd,bqd->bpq', kv_seq, q_proj) / (D ** 0.5)  # [B,P,Q]
        probs = scores.softmax(dim=-1)  # over queries

        if Fi1_mask is not None and Fi1_mask.any():
            # Per sample: gather routing weights of masked positions and mix the queries.
            pred_list, idx_list = [], []
            for b in range(B):
                mask_b = Fi1_mask[b]  # [P]
                if mask_b.any():
                    prob_b = probs[b, mask_b]     # [Mb, Q]
                    q_b    = q_proj[b]            # [Q, D]
                    pred_b = prob_b @ q_b         # [Mb, D]
                    pred_list.append(pred_b)
                    idx_list.append(mask_b.nonzero(as_tuple=False).squeeze(1))
                else:
                    pred_list.append(q_proj.new_zeros((0, D)))
                    idx_list.append(torch.zeros((0,), dtype=torch.long, device=q_proj.device))
            # Pad ragged per-sample results to a common length M = maxM.
            maxM = max([p.size(0) for p in pred_list])
            if maxM == 0:
                pred_masked_feats = q_proj.new_zeros((B, 0, D))
                masked_indices    = q_proj.new_zeros((B, 0), dtype=torch.long)
            else:
                pred_masked_feats, masked_indices = [], []
                for b in range(B):
                    m = pred_list[b].size(0)
                    pad = maxM - m
                    if pad > 0:
                        pred_masked_feats.append(F.pad(pred_list[b], (0,0,0,pad)))
                        masked_indices.append(F.pad(idx_list[b], (0,pad), value=-1))
                    else:
                        pred_masked_feats.append(pred_list[b])
                        masked_indices.append(idx_list[b])
                pred_masked_feats = torch.stack(pred_masked_feats, dim=0)  # [B, M, D]
                masked_indices    = torch.stack(masked_indices, dim=0)     # [B, M]
        else:
            # No mask given (or nothing masked): empty predictions with M = 0.
            pred_masked_feats = q_proj.new_zeros((B, 0, D))
            masked_indices    = q_proj.new_zeros((B, 0), dtype=torch.long)

        return pred_masked_feats, masked_indices, q_proj

## Setup

#### Dataset & Dataloader

In [12]:
# =============================================================================
# Create Datasets and DataLoaders
# =============================================================================

# Batch sizes for the two training phases
batch_size_pretrain = 72
batch_size_downstream = 48

# ImageNet-100 for JEPA pretraining; ADE20K train/val for segmentation fine-tuning
jepa_dataset = JEPADataset()
ade_train_dataset = ADE20KDataset(split="training")
ade_val_dataset = ADE20KDataset(split="validation")

print(f"Dataset sizes:")
for _label, _ds in (
    ("JEPA (ImageNet)", jepa_dataset),
    ("ADE20K train", ade_train_dataset),
    ("ADE20K val", ade_val_dataset),
):
    print(f"{_label}: {len(_ds):,} samples")

# Worker / pinning options shared by every loader
_loader_opts = dict(num_workers=4, pin_memory=True)

# JEPA pretraining loader
pretrain_loader = DataLoader(
    jepa_dataset, batch_size=batch_size_pretrain, shuffle=True,
    collate_fn=jepa_collate, **_loader_opts,
)

# Downstream fine-tuning loaders (only the training split is shuffled)
downstream_train_loader = DataLoader(
    ade_train_dataset, batch_size=batch_size_downstream, shuffle=True,
    collate_fn=ade_collate, **_loader_opts,
)
downstream_val_loader = DataLoader(
    ade_val_dataset, batch_size=batch_size_downstream, shuffle=False,
    collate_fn=ade_collate, **_loader_opts,
)

print(f"\nDataLoader info:")
for _label, _loader in (
    ("Pretrain", pretrain_loader),
    ("Train", downstream_train_loader),
    ("Val", downstream_val_loader),
):
    print(f"{_label}: {len(_loader)} batches")

Loading dataset from disk:   0%|          | 0/17 [00:00<?, ?it/s]

Dataset sizes:
JEPA (ImageNet): 126,689 samples
ADE20K train: 20,210 samples
ADE20K val: 2,000 samples

DataLoader info:
Pretrain: 1760 batches
Train: 422 batches
Val: 42 batches


In [13]:
# Smoke test: draw one pretraining batch and show its (B, C, H, W) shape.
batch = next(iter(pretrain_loader))
imgs = batch["images"]
print(imgs.shape)

torch.Size([72, 3, 384, 384])


#### Model Setup

#### Mask-JEPA setup

In [14]:
import copy
import math
from typing import Tuple, Optional
import torch
import time
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

class MaskJEPA2D(nn.Module):
    """Mask-JEPA pretraining model for 2-D images.

    Online branch (trained): noisy image -> ``context_encoder`` -> [C3, C4, C5]
    pyramid (H/8, H/16, H/32 via ``ds16``/``ds32``) -> ``pixel_decoder`` ->
    (Fi1, F_last). Fi1 is masked and passed to ``predictor``; F_last feeds the
    1x1 ``denoising_head``.

    Target branch (no grad): clean image -> ``target_encoder`` (EMA copy) ->
    same pyramid convs -> ``pixel_decoder_ema`` (EMA copy) -> Fi1 targets,
    layer-normed per token and gathered at the predictor's masked indices.

    Args:
        in_chans: input image channels; also the denoising head's output channels.
        num_queries, num_cross_attn, num_self_attn: forwarded to ``Predictor2D``.
        tau: initial EMA momentum. ``tau_base``/``tau_final`` are the endpoints
            the training loop ramps between via ``set_ema_tau``.
        fi1_mask_ratio: mask ratio passed to ``generate_fi1_mask``.
        patch_size: mask patch size passed to ``generate_fi1_mask``.
        model_name, pretrained: forwarded to ``ContextEncoder2D``.
        noise_stride: Gaussian noise is sampled at 1/noise_stride resolution and
            repeated to full resolution (default 4, the previous hard-coded value).
        noise_sigma: standard deviation of that noise (default 0.4, the previous
            hard-coded value).
    """
    def __init__(self,
                 in_chans: int,
                 num_queries: int = 32,
                 num_cross_attn: int = 2,
                 num_self_attn: int = 1,
                 tau: float = 0.996,
                 fi1_mask_ratio: float = 0.5,
                 patch_size: int = 8,
                 model_name: str = "swin_tiny_patch4_window7_224",
                 pretrained: bool = True,
                 noise_stride: int = 4,
                 noise_sigma: float = 0.4):
        super().__init__()

        self.patch_size = patch_size

        # === Context encoder (online backbone) ===
        self.context_encoder = ContextEncoder2D(model_name=model_name, pretrained=pretrained)

        # embed_dim comes from the encoder backbone; patch_size is the mask patch size above
        self.embed_dim = self.context_encoder.embed_dim
        self.in_chans = in_chans  # needed for denoising head output

        # === Target encoder (EMA, frozen) ===
        self.target_encoder = copy.deepcopy(self.context_encoder)
        for p in self.target_encoder.parameters():
            p.requires_grad = False

        # === Pixel decoders (online + frozen EMA copy) ===
        self.pixel_decoder = PixelDecoder2D(
            in_channels=self.embed_dim,
            embed_dim=self.embed_dim
        )
        self.pixel_decoder_ema = copy.deepcopy(self.pixel_decoder)
        for p in self.pixel_decoder_ema.parameters():
            p.requires_grad = False

        # === Downsampling convs for C4 / C5 (each halves the spatial size) ===
        self.ds16 = nn.Sequential(
            nn.Conv2d(self.embed_dim, self.embed_dim, 3, stride=2, padding=1), nn.GELU(),
            nn.Conv2d(self.embed_dim, self.embed_dim, 3, padding=1),           nn.GELU(),
        )
        self.ds32 = nn.Sequential(
            nn.Conv2d(self.embed_dim, self.embed_dim, 3, stride=2, padding=1), nn.GELU(),
            nn.Conv2d(self.embed_dim, self.embed_dim, 3, padding=1),           nn.GELU(),
        )

        # === JEPA Predictor ===
        self.predictor = Predictor2D(
            embed_dim=self.embed_dim,
            num_queries=num_queries,
            num_cross_attn=num_cross_attn,
            num_self_attn=num_self_attn
        )

        # === Denoising head ===
        self.denoising_head = nn.Conv2d(self.embed_dim, in_chans, kernel_size=1)

        # === EMA tau config ===
        self.tau = tau
        self.tau_base  = tau
        self.tau_final = 1.0

        self.fi1_mask_ratio = fi1_mask_ratio

        # === Input-noise config ===
        self.noise_stride = noise_stride
        self.noise_sigma = noise_sigma

    def _block_noise(self, x: torch.Tensor):
        """Sample Gaussian noise on a coarse grid and expand it to x's spatial size.

        Noise is drawn at ceil(H/s) x ceil(W/s) (s = ``noise_stride``), each value
        is repeated s times along both spatial axes, and the result is cropped to
        (H, W). Every s x s block therefore shares one noise value. Using ceil
        plus crop means H and W need not be multiples of s. When they are
        multiples, this is identical to sampling at H//s x W//s.

        Returns:
            (eps_lr, eps_full): the coarse sample and its full-resolution expansion.
        """
        B, C, H, W = x.shape
        s = self.noise_stride
        h_lr, w_lr = -(-H // s), -(-W // s)  # integer ceil division
        eps_lr = torch.randn(B, C, h_lr, w_lr, device=x.device, dtype=x.dtype) * self.noise_sigma
        eps_full = eps_lr.repeat_interleave(s, dim=2).repeat_interleave(s, dim=3)[..., :H, :W]
        return eps_lr, eps_full

    def _feature_pyramid(self, tokens: torch.Tensor, enc_hw, H: int, W: int):
        """Reshape encoder tokens [B, P, D] to a grid and build [C3, C4, C5].

        C3/C4/C5 are resized to (H//8, W//8), (H//16, W//16) and (H//32, W//32).
        C4 and C5 come from ``ds16``/``ds32`` applied to C3.
        """
        enc_h, enc_w = enc_hw
        feat = tokens.transpose(1, 2).reshape(tokens.shape[0], self.embed_dim, enc_h, enc_w)
        C3 = F.interpolate(feat, size=(H // 8, W // 8), mode='bilinear', align_corners=False)
        x16 = self.ds16(C3)
        x32 = self.ds32(x16)
        C4 = F.interpolate(x16, size=(H // 16, W // 16), mode='bilinear', align_corners=False)
        C5 = F.interpolate(x32, size=(H // 32, W // 32), mode='bilinear', align_corners=False)
        return [C3, C4, C5]

    def forward(self, x: torch.Tensor):
        """Run both JEPA branches, the predictor and the denoising head.

        The online branch sees ``x`` plus block noise; the target branch sees the
        clean ``x`` under ``torch.no_grad()``.

        Args:
            x: image batch, (B, C, H, W).

        Returns:
            dict with
              'predicted_features': predictor output for the masked Fi1 positions;
              'target_masked': layer-normed target Fi1 tokens gathered at
                  ``masked_indices`` (entries with index < 0 are padding and zeroed);
              'mask_info': (q_proj, masked_indices) from the predictor;
              'denoised_prediction': denoising head output on F_last;
              'original_input': the clean ``x``;
              'fi1_mask': mask returned by ``generate_fi1_mask``;
              'mask_indices': same tensor as in 'mask_info';
              'eps_target': the coarse (pre-expansion) noise sample.
        """
        B, C, H, W = x.shape
        device = x.device

        # ---------- (A) add noise ----------
        eps_lr, eps_full = self._block_noise(x)
        x_noisy = x + eps_full

        # ---------- (B) ONLINE BRANCH ----------
        tokens_online, enc_hw = self.context_encoder(x_noisy)  # [B,P,D], (Ht,Wt)
        f_i1_online, f_last_online = self.pixel_decoder(
            self._feature_pyramid(tokens_online, enc_hw, H, W), (H, W)
        )

        # ---------- (C) TARGET BRANCH ----------
        with torch.no_grad():
            # Same architecture and input size as the online branch, so the
            # online token grid (enc_hw) is reused for the reshape.
            tokens_target, _ = self.target_encoder(x)
            # NOTE(review): ds16/ds32 here are the *online* convs (there is no EMA
            # copy of them), unlike the encoder and pixel decoder — confirm intended.
            f_i1_target, _ = self.pixel_decoder_ema(
                self._feature_pyramid(tokens_target, enc_hw, H, W), (H, W)
            )

        # ---------- (D) Fi1 MASK ----------
        fi1_mask = generate_fi1_mask(
            fi1_shape=f_i1_online.shape,
            mask_ratio=self.fi1_mask_ratio,
            patch_size=self.patch_size,
            device=device
        )

        # ---------- (E) PREDICTOR ----------
        predicted_features, masked_indices, q_proj = self.predictor(f_i1_online, fi1_mask)

        # ---------- (F) TARGET GATHER + LN ----------
        D = f_i1_target.shape[1]
        fi1_h, fi1_w = f_i1_target.shape[-2:]
        target_seq = f_i1_target.permute(0, 2, 3, 1).reshape(B, fi1_h * fi1_w, D)
        target_seq = F.layer_norm(target_seq, (D,))

        if masked_indices.numel() > 0:
            # Negative indices mark padding: gather from a clamped index, then zero those rows.
            pad_mask = (masked_indices >= 0)
            safe_idx = masked_indices.clamp_min(0)
            b_idx = torch.arange(B, device=device).unsqueeze(-1).expand_as(safe_idx)
            target_masked_full = target_seq[b_idx, safe_idx]
            target_masked = target_masked_full * pad_mask.unsqueeze(-1).to(target_masked_full.dtype)
        else:
            target_masked = target_seq.new_zeros((B, 0, D))

        # ---------- (G) DENOISING ----------
        denoised_prediction = self.denoising_head(f_last_online)

        return {
            'predicted_features': predicted_features,
            'target_masked':      target_masked,
            'mask_info':          (q_proj, masked_indices),
            'denoised_prediction': denoised_prediction,
            'original_input':     x,
            'fi1_mask':           fi1_mask,
            'mask_indices':       masked_indices,
            'eps_target':         eps_lr
        }

    def set_ema_tau(self, tau: float):
        """Set the EMA momentum used by the next ``update_ema`` call."""
        self.tau = float(tau)

    @torch.no_grad()
    def update_ema(self):
        """EMA-update the target encoder and EMA pixel decoder from their online counterparts."""
        # Resolves to the module-level update_ema(target, source, tau) helper, not this method.
        update_ema(self.target_encoder, self.context_encoder, tau=self.tau)
        update_ema(self.pixel_decoder_ema, self.pixel_decoder, tau=self.tau)


In [16]:
def lr_lambda(epoch: int,
              warmup: Optional[int] = None,
              hold_epochs: int = 3,
              decay: float = 0.7,
              min_factor: float = 0.3) -> float:
    """Per-epoch LR multiplier for ``LambdaLR``: warmup -> hold -> exponential decay.

    Args:
        epoch: epoch index passed by the scheduler (0-based).
        warmup: number of linear-warmup epochs. ``None`` (the scheduler's call)
            uses the global ``warmup_epochs``, resolved at call time.
        hold_epochs: epochs kept at the full rate after warmup.
        decay: multiplicative decay per epoch once the hold phase ends.
        min_factor: lower bound on the multiplier.

    Returns:
        float multiplier applied to the optimizer's base LR.
    """
    if warmup is None:
        warmup = warmup_epochs
    if epoch < warmup:
        # Linear warmup from 0 -> 1 (epoch 0 of a non-empty warmup trains at LR 0).
        return float(epoch) / float(max(1, warmup))
    decay_start = warmup + hold_epochs
    if epoch < decay_start:
        return 1.0
    return max(decay ** (epoch - decay_start), min_factor)

In [17]:
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.amp import GradScaler, autocast
import os
import matplotlib.pyplot as plt
import numpy as np
import gc
import math
import time   # for epoch timing
# bf16 autocast when the GPU supports it (e.g. A100 / RTX 40xx), else fp16.
# Checking is_available() first keeps this cell from failing on a CPU-only kernel.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

device = "cuda" if torch.cuda.is_available() else "cpu"

# Training configuration
num_epochs = 20
warmup_epochs = 0      # read by lr_lambda when the scheduler calls it
base_lr = 1e-4
weight_decay = 0.05

# Early stopping config (on the epoch-average TotalSc training loss)
early_stop_patience = 5
best_total_sc = float('inf')
epochs_no_improve = 0

# Enable memory / speed optimizations
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


def enable_gradient_checkpointing(model):
    """Turn on activation (gradient) checkpointing for the online context encoder.

    timm models expose ``set_grad_checkpointing(enable)`` on the model itself.
    Their transformer ``Block`` modules do not, so a per-block loop alone
    silently enables nothing. This prefers the model-level switch and keeps the
    per-block hook as a fallback for backbones that provide one.
    NOTE(review): the model-level flag only takes effect if ContextEncoder2D
    runs the backbone through its own forward (e.g. ``forward_features``).
    Confirm against ContextEncoder2D.

    The target encoder is skipped on purpose. It only runs inside
    ``torch.no_grad()`` in MaskJEPA2D.forward, so no activations are kept for
    backward and checkpointing it cannot save memory.

    Returns:
        bool: True if checkpointing was enabled on at least one module.
    """
    backbone = getattr(model.context_encoder, 'vit', None)
    if backbone is None:
        print("[grad-ckpt] context encoder has no `.vit` backbone; checkpointing not enabled")
        return False

    if hasattr(backbone, 'set_grad_checkpointing'):
        backbone.set_grad_checkpointing(True)
        return True

    enabled = False
    for block in getattr(backbone, 'blocks', []):
        if hasattr(block, 'set_grad_checkpointing'):
            block.set_grad_checkpointing(True)
            enabled = True
    if not enabled:
        print("[grad-ckpt] backbone exposes no set_grad_checkpointing; checkpointing not enabled")
    return enabled

# Create model
print("Creating model...")
model = MaskJEPA2D(
    in_chans=3,
    tau=0.996,
    fi1_mask_ratio=0.5,
    num_queries=8,
    num_cross_attn=5,
    num_self_attn=1,
    patch_size=8
).to(device)

# Feature width; the training loop divides the per-token recon loss by it for logging.
D = model.embed_dim

# Enable gradient checkpointing
enable_gradient_checkpointing(model)
print("Gradient checkpointing enabled")

# Mixed precision scaler. Loss scaling protects fp16 gradients from underflow.
# bf16 has fp32's exponent range, so scaling is disabled there; with
# enabled=False, scale/unscale_/step/update are pass-throughs.
scaler = GradScaler('cuda', enabled=not use_bf16)

# Optimizer and scheduler (scheduler.step() is called once per epoch)
optimizer = AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

# Create save directory
save_dir = "./jepa_training_output"
os.makedirs(save_dir, exist_ok=True)
best_ckpt_path = os.path.join(save_dir, "best_jepa_model.pt")

def clear_memory():
    """Release unreferenced host objects, then return cached CUDA blocks to the driver.

    ``gc.collect()`` runs first: tensors kept alive only by reference cycles
    still occupy their allocator blocks until collected, and
    ``torch.cuda.empty_cache()`` can only release blocks that are already free.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Print model info
print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
print(f"Batch size: {batch_size_pretrain}")

ps = model.patch_size
mr = model.fi1_mask_ratio
nq = model.predictor.num_queries
nca = len(model.predictor.cross_blocks)
nsa = len(model.predictor.self_blocks)

print(f"Model config: patch_size={ps}, mask_ratio={mr}, queries={nq}, cross_attn={nca}, self_attn={nsa}")


def _snapshot_pretrained_parts(jepa):
    """Return frozen CPU copies of the sub-module weights that fine-tuning reuses.

    ``state_dict()`` tensors share storage with the live parameters, so a
    snapshot holding them keeps changing as the optimizer steps. Copying here
    freezes the values at call time. ds16/ds32 are included because the
    fine-tuning cell loads them when present (otherwise they stay randomly
    initialised).
    """
    parts = {
        'backbone_state_dict': jepa.context_encoder,
        'pixel_decoder_state_dict': jepa.pixel_decoder,
        'transformer_decoder_cross_blocks_state_dict': jepa.predictor.cross_blocks,
        'ds16_state_dict': jepa.ds16,
        'ds32_state_dict': jepa.ds32,
    }
    return {
        key: {name: t.detach().to('cpu', copy=True) for name, t in module.state_dict().items()}
        for key, module in parts.items()
    }


# Autocast dtype: bf16 if available, else fp16
amp_dtype = torch.bfloat16 if use_bf16 else torch.float16

# EMA ramp setup: tau goes linearly from tau_base to tau_final over all planned updates
planned_updates_per_epoch = len(pretrain_loader)
max_updates = num_epochs * planned_updates_per_epoch
global_update = 0

# Training loop
train_losses = []
best_snapshot = None  # frozen copy of the best weights for the final save

for epoch in range(num_epochs):
    epoch_start_time = time.time()   # start timer
    print(f"Epoch {epoch+1}/{num_epochs}")

    model.train()
    epoch_loss = 0.0
    epoch_recon_loss = 0.0
    epoch_denoise_loss = 0.0

    clear_memory()
    optimizer.zero_grad()

    for batch_idx, batch in enumerate(pretrain_loader):
        images = batch["images"].to(device, non_blocking=True)

        # Forward pass with mixed precision
        with autocast(device_type='cuda', dtype=amp_dtype):
            outputs = model(images)

            pred = outputs['predicted_features']
            tgt = outputs['target_masked']
            idx = outputs['mask_indices']
            valid = (idx >= 0).unsqueeze(-1)

            # Reconstruction loss: summed squared error per valid masked token
            if pred.numel() == 0 or valid.sum() == 0:
                recon_loss = pred.new_zeros(())
            else:
                diff = (pred - tgt) * valid
                recon_loss = diff.pow(2).sum() / valid.sum().clamp_min(1)

            # Denoising target: the clean image resized to the head's output size
            x4 = F.interpolate(
                outputs['original_input'],
                size=outputs['denoised_prediction'].shape[-2:],
                mode='bilinear',
                align_corners=False
            )
            denoise_loss = F.mse_loss(outputs['denoised_prediction'], x4)

            # Debug info (first batch only)
            if batch_idx == 0:
                td = F.interpolate(images, size=outputs['denoised_prediction'].shape[-2:],
                                   mode='bilinear', align_corners=False)
                pd = outputs['denoised_prediction'].detach()
                print(f"[probe] target: mean={td.mean().item():.3f} std={td.std().item():.3f}")
                print(f"[probe] pred  : mean={pd.mean().item():.3f} std={pd.std().item():.3f}")

            total_loss = recon_loss + denoise_loss

        # Backward pass
        scaler.scale(total_loss).backward()

        # Gradient clipping (on unscaled gradients)
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Optimizer step
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # EMA update with ramp
        progress = global_update / max(1, max_updates - 1)
        tau_now = model.tau_base + (model.tau_final - model.tau_base) * progress
        model.set_ema_tau(tau_now)
        model.update_ema()
        global_update += 1

        # Track losses
        epoch_loss += total_loss.item()
        epoch_recon_loss += recon_loss.item()
        epoch_denoise_loss += denoise_loss.item()

        # Cleanup
        del outputs, recon_loss, denoise_loss, total_loss, images

        # Progress logging (recon divided by D -> per-element scale)
        if batch_idx % 50 == 0:
            recon_sc = (epoch_recon_loss / (batch_idx + 1)) / D
            denoise_sc = (epoch_denoise_loss / (batch_idx + 1))
            total_sc = recon_sc + denoise_sc

            print(f"  Batch {batch_idx}/{len(pretrain_loader)} - "
                  f"TotalSc: {total_sc:.4f}, ReconSc: {recon_sc:.4f}, DenoiseSc: {denoise_sc:.4f}")

    # Average losses
    epoch_loss /= len(pretrain_loader)
    epoch_recon_loss /= len(pretrain_loader)
    epoch_denoise_loss /= len(pretrain_loader)

    train_losses.append(epoch_loss)
    scheduler.step()

    # Epoch summary
    recon_sc_epoch = epoch_recon_loss / D
    denoise_sc_epoch = epoch_denoise_loss
    total_sc_epoch = recon_sc_epoch + denoise_sc_epoch

    epoch_end_time = time.time()   # end timer
    epoch_duration = epoch_end_time - epoch_start_time
    print(f"  Avg losses - TotalSc: {total_sc_epoch:.4f}, "
          f"ReconSc: {recon_sc_epoch:.4f}, DenoiseSc: {denoise_sc_epoch:.4f}")
    print(f"  LR: {scheduler.get_last_lr()[0]:.2e}")
    print(f"  GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB")
    print(f"  Time for epoch {epoch+1}: {epoch_duration/60:.2f} minutes")

    # Periodic evaluation (every 1 epoch here)
    if (epoch + 1) % 1 == 0:
        print("  Running evaluation...")
        model.eval()

        with torch.no_grad():
            eval_batch = next(iter(pretrain_loader))
            eval_images = eval_batch["images"][:4].to(device)

            with autocast(device_type='cuda', dtype=amp_dtype):
                eval_outputs = model(eval_images)

            # Visualization: Fi1 lives on an (H//8, W//8) grid -> pixels per Fi1 cell
            H, W = eval_images.shape[-2:]
            fi1_tile = max(H // max(H // 8, 1), 1)

            vis_path = os.path.join(save_dir, f"reconstruction_epoch_{epoch+1:03d}.png")

            visualize_jepa_patch_quality(
                original=eval_images,
                predicted_features=eval_outputs['predicted_features'].float(),
                target_features=eval_outputs['target_masked'].float(),
                patch_mask=eval_outputs['fi1_mask'],
                epoch=epoch+1,
                save_path=vis_path,
                patch_size=fi1_tile
            )

            del eval_batch, eval_images, eval_outputs

        print(f"    Saved visualization: {vis_path}")
        model.train()
        clear_memory()

    # ---- Save ONLY when we have a new best TotalSc ----
    if total_sc_epoch < best_total_sc:
        best_total_sc = total_sc_epoch
        epochs_no_improve = 0
        # Copy (not alias) the weights so later epochs cannot overwrite the best snapshot.
        best_snapshot = _snapshot_pretrained_parts(model)
        torch.save(best_snapshot, best_ckpt_path)
        print(f"    [best] New best TotalSc={best_total_sc:.4f}. Saved: {best_ckpt_path}")
    else:
        epochs_no_improve += 1
        print(f"  [early-stop] No improvement ({epochs_no_improve}/{early_stop_patience}).")
        if epochs_no_improve >= early_stop_patience:
            print(f"  [early-stop] Patience exceeded. Stopping training early.")
            break
    # --------------------------------------

print("Training completed!")

# Save final pretrained weights for downstream use (best-only)
final_weights_path = os.path.join(save_dir, 'mask_jepa_pretrained_weights.pt')
if best_snapshot is not None:
    torch.save(best_snapshot, final_weights_path)
else:
    # Fallback: save current (shouldn't happen unless no batches ran)
    torch.save(_snapshot_pretrained_parts(model), final_weights_path)
print(f"Final pretrained weights saved (best-only): {final_weights_path}")

Creating model...
Gradient checkpointing enabled
Model parameters: 119,820,413
Batch size: 72
Model config: patch_size=8, mask_ratio=0.5, queries=8, cross_attn=5, self_attn=1
Epoch 1/20
[probe] target: mean=-0.018 std=1.196
[probe] pred  : mean=-0.007 std=0.271
  Batch 0/1760 - TotalSc: 2.8302, ReconSc: 1.3562, DenoiseSc: 1.4740


KeyboardInterrupt: 

## Fine-Tuning (ADE20K Semantic Segmentation)

In [None]:
# ==== Fine-tuning: Fusion head (Fi1 + F_last), CE-only, with fast GPU metrics & batch prints ====
import os, gc, math, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.amp import GradScaler, autocast
import matplotlib.pyplot as plt

NUM_CLASSES = 150    # semantic classes predicted by the segmentation head
IGNORE_INDEX = 255   # label skipped by the CE loss and the metrics (fix_masks maps out-of-range labels here)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# -----------------------
# Helpers
# -----------------------
def _gn_groups(C):
    for g in (32,16,8,4,2,1):
        if C % g == 0: return g
    return 1

class DWSepResBlock(nn.Module):
    """Residual depthwise-separable conv block.

    3x3 depthwise conv (dilated by ``dilation``) -> GroupNorm -> ReLU ->
    1x1 pointwise conv -> GroupNorm. The result is added to the input and
    passed through a final ReLU. Channel count and spatial size are preserved
    (padding equals dilation for the 3x3 kernel).
    """
    def __init__(self, channels: int, dilation: int = 1):
        super().__init__()
        n_groups = _gn_groups(channels)
        self.dw = nn.Conv2d(channels, channels, kernel_size=3, padding=dilation,
                            dilation=dilation, groups=channels, bias=False)
        self.dw_g = nn.GroupNorm(n_groups, channels)
        self.pw = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
        self.pw_g = nn.GroupNorm(n_groups, channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = x
        out = self.dw(x)
        out = self.act(self.dw_g(out))
        out = self.pw(out)
        out = self.pw_g(out)
        return self.act(shortcut + out)

class SpatialGate(nn.Module):
    """Per-pixel sigmoid gate: a 3x3 conv predicts one logit map that scales every channel."""
    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, 1, 3, padding=1)

    def forward(self, x):
        gate = self.conv(x).sigmoid()  # (B, 1, H, W), broadcast over channels
        return x * gate

class FusionSegHead(nn.Module):
    """Segmentation head fusing Fi1 (nominally s/8) with F_last (nominally s/4).

    Fi1 is bilinearly resized to F_last's spatial size. Each input is reduced
    to ``mid_channels`` by a 1x1 conv + GroupNorm + ReLU, and the two are
    concatenated and fused by another such block. The fused map is refined by
    two residual depthwise-separable blocks (dilation 1, then 2) and spatially
    gated. It is then projected to ``num_classes`` logits, which are resized
    to ``out_hw``.
    """
    def __init__(self, in_channels: int, mid_channels: int, num_classes: int):
        super().__init__()
        self.fi1_reduce = self._conv1x1_gn_relu(in_channels, mid_channels)
        self.flast_reduce = self._conv1x1_gn_relu(in_channels, mid_channels)
        self.fuse = self._conv1x1_gn_relu(2 * mid_channels, mid_channels)
        self.refine1 = DWSepResBlock(mid_channels, dilation=1)
        self.refine2 = DWSepResBlock(mid_channels, dilation=2)
        self.spatial = SpatialGate(mid_channels)
        self.cls = nn.Conv2d(mid_channels, num_classes, 1)

    @staticmethod
    def _conv1x1_gn_relu(cin, cout):
        """1x1 conv (no bias) -> GroupNorm -> ReLU as an index-addressable Sequential."""
        return nn.Sequential(
            nn.Conv2d(cin, cout, 1, bias=False),
            nn.GroupNorm(_gn_groups(cout), cout),
            nn.ReLU(inplace=True),
        )

    def forward(self, fi1, flast, out_hw):
        H, W = out_hw
        fi1_resized = F.interpolate(fi1, size=flast.shape[-2:], mode='bilinear', align_corners=False)
        fused = self.fuse(torch.cat([self.fi1_reduce(fi1_resized), self.flast_reduce(flast)], dim=1))
        for stage in (self.refine1, self.refine2, self.spatial):
            fused = stage(fused)
        return F.interpolate(self.cls(fused), size=(H, W), mode='bilinear', align_corners=False)

class JEPASegmentationModel(nn.Module):
    """Semantic segmentation model built on pretrained Mask-JEPA parts.

    Shares (by reference, not copies) the online context encoder, the pixel
    decoder and the ds16/ds32 convs of ``backbone_model``, and adds a
    FusionSegHead on (Fi1, F_last). There is no ASPP and no auxiliary head.
    """
    def __init__(self, backbone_model, num_classes=150, mid_channels=128):
        super().__init__()
        self.backbone = backbone_model.context_encoder
        self.pixel_decoder = backbone_model.pixel_decoder
        self.embed_dim = backbone_model.embed_dim
        self.ds16 = backbone_model.ds16
        self.ds32 = backbone_model.ds32
        self.head = FusionSegHead(self.embed_dim, mid_channels, num_classes)

    def _pyramid(self, grid, H, W):
        """[C3, C4, C5] at (H//8, W//8), (H//16, W//16), (H//32, W//32), as in pretraining."""
        def resize(t, stride):
            return F.interpolate(t, size=(H // stride, W // stride), mode='bilinear', align_corners=False)

        c3 = resize(grid, 8)
        d16 = self.ds16(c3)
        d32 = self.ds32(d16)
        return [c3, resize(d16, 16), resize(d32, 32)]

    def forward(self, x):
        batch, _, H, W = x.shape
        tokens, (grid_h, grid_w) = self.backbone(x)  # tokens: [B, P, D]
        grid = tokens.transpose(1, 2).reshape(batch, self.embed_dim, grid_h, grid_w)
        fi1, f_last = self.pixel_decoder(self._pyramid(grid, H, W), (H, W))
        return self.head(fi1, f_last, (H, W))  # logits [B, K, H, W]

def fix_masks(m):
    """Replace labels outside [0, NUM_CLASSES) with IGNORE_INDEX and return int64 labels."""
    out_of_range = (m < 0) | (m >= NUM_CLASSES)
    return m.masked_fill(out_of_range, IGNORE_INDEX).long()

class StreamingSegMetrics:
    """Streaming confusion matrix for mIoU / mean Dice, accumulated on ``device``.

    ``conf[t, p]`` counts pixels with target class ``t`` predicted as ``p``
    (float64, shape (C, C)). ``get()`` is the only call that converts results
    to Python floats.
    """
    def __init__(self, num_classes, ignore_index=255, device=None):
        self.C = num_classes
        self.ignore = ignore_index
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.conf = torch.zeros((self.C, self.C), dtype=torch.float64, device=self.device)

    @torch.no_grad()
    def update(self, logits, target):
        """Accumulate one batch.

        Args:
            logits: class scores with classes on dim 1, e.g. (B, C, H, W).
            target: integer labels matching ``logits.argmax(1)``'s shape.

        Pixels labelled ``ignore_index`` or outside [0, C) are skipped. Without
        that range check, a stray label pushed bincount past C*C bins and the
        reshape failed.
        """
        pred = logits.argmax(1)
        tgt = target.long()  # widen first: uint8 labels would overflow in tgt * C
        valid = (tgt != self.ignore) & (tgt >= 0) & (tgt < self.C)
        idx = tgt[valid] * self.C + pred[valid]
        # bincount of an empty index tensor yields minlength zeros, so no guard is needed.
        counts = torch.bincount(idx, minlength=self.C * self.C)
        self.conf += counts.reshape(self.C, self.C).to(self.conf.dtype)

    @torch.no_grad()
    def get(self):
        """Return (mIoU, mean Dice) over classes present in target or prediction.

        Classes with no target and no predicted pixels are NaN and excluded by
        nanmean. The result is NaN if nothing has been accumulated.
        """
        h = self.conf
        tp = torch.diag(h)
        support = h.sum(1) + h.sum(0)  # |target| + |pred| per class
        union = support - tp
        nan = torch.full_like(tp, float('nan'))
        iou = torch.where(union > 0, tp / union, nan)
        dice = torch.where(support > 0, (2 * tp) / support, nan)
        return torch.nanmean(iou).item(), torch.nanmean(dice).item()

    def reset(self):
        """Zero the confusion matrix in place."""
        self.conf.zero_()

def visualize_segmentation(images, true_masks, logits, epoch, save_path, num_samples=4):
    """Save a 3-row grid (input / ground truth / prediction) for the first samples.

    Args:
        images: normalized image batch (B, 3, H, W); de-normalized with the
            ImageNet mean/std below for display.
        true_masks: label maps (B, H, W). Labels outside [0, NUM_CLASSES)
            (e.g. IGNORE_INDEX) are drawn black.
        logits: class scores with classes on dim 1; argmax(1) gives the prediction.
        epoch: shown in the figure title.
        save_path: output image path.
        num_samples: number of columns. Columns beyond the batch size stay blank.
    """
    n_show = min(num_samples, images.size(0))
    pred = logits.argmax(1)
    # One colour per class (tab20 cycled), built once instead of a per-class mask loop.
    palette = np.asarray(plt.cm.tab20(np.arange(NUM_CLASSES) % 20))[:, :3]

    def _colorize(labels):
        lab = labels.cpu().numpy().astype(np.int64)
        rgb = np.zeros((*lab.shape, 3))
        in_range = (lab >= 0) & (lab < NUM_CLASSES)
        rgb[in_range] = palette[lab[in_range]]
        return rgb

    # squeeze=False keeps `axes` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(3, num_samples, figsize=(16, 12), squeeze=False)
    for ax in axes.ravel():
        ax.axis('off')

    mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=images.device).view(3, 1, 1)
    for i in range(n_show):
        img = torch.clamp(images[i] * std + mean, 0, 1).permute(1, 2, 0).float().cpu().numpy()
        axes[0, i].imshow(img)
        axes[0, i].set_title(f"Original {i+1}")
        axes[1, i].imshow(_colorize(true_masks[i]))
        axes[1, i].set_title(f"Ground Truth {i+1}")
        axes[2, i].imshow(_colorize(pred[i]))
        axes[2, i].set_title(f"Prediction {i+1}")

    fig.suptitle(f"Epoch {epoch} - Segmentation Results", fontsize=16)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

# -----------------------
# One-time ADE sanity peek
# -----------------------
try:
    peek = next(iter(downstream_train_loader))
    print("ADE20K uniques (peek):", torch.unique(peek["masks"])[:20].cpu())
except Exception as e:
    print("ADE peek skipped:", e)

# -----------------------
# Load JEPA parts & build model
# -----------------------
print("Loading pretrained JEPA model...")
# Only context_encoder / pixel_decoder / ds16 / ds32 are reused below; the
# predictor config simply mirrors the pretraining cell.
jepa_model = MaskJEPA2D(
    in_chans=3, tau=0.996, fi1_mask_ratio=0.5,
    num_queries=8, num_cross_attn=5, num_self_attn=1, patch_size=8
).to(device)

# Checkpoints written by the pretraining cell, in order of preference.
# "jepa_model_epoch_1.pt" is not written by that cell and is tried last, for older runs only.
pretrain_dir = "./jepa_training_output"
weight_candidates = [
    os.path.join(pretrain_dir, "best_jepa_model.pt"),               # saved on every new best epoch
    os.path.join(pretrain_dir, "mask_jepa_pretrained_weights.pt"),  # end-of-training export
    os.path.join(pretrain_dir, "jepa_model_epoch_1.pt"),            # legacy name
]
weights_path = next((p for p in weight_candidates if os.path.exists(p)), None)
if weights_path is not None:
    # The checkpoint holds only state dicts, so the restricted (weights-only) unpickler suffices.
    ckpt = torch.load(weights_path, map_location=device, weights_only=True)
    jepa_model.context_encoder.load_state_dict(ckpt['backbone_state_dict'])
    jepa_model.pixel_decoder.load_state_dict(ckpt['pixel_decoder_state_dict'])
    if 'ds16_state_dict' in ckpt: jepa_model.ds16.load_state_dict(ckpt['ds16_state_dict'])
    if 'ds32_state_dict' in ckpt: jepa_model.ds32.load_state_dict(ckpt['ds32_state_dict'])
    print(f"Loaded {weights_path}: context_encoder, pixel_decoder",
          "+ ds16" if 'ds16_state_dict' in ckpt else "",
          "+ ds32" if 'ds32_state_dict' in ckpt else "")
else:
    print("WARNING: no JEPA weights found; FT from scratch")

model = JEPASegmentationModel(jepa_model, num_classes=NUM_CLASSES, mid_channels=128).to(device)
model = model.to(memory_format=torch.channels_last)

# -----------------------
# Train config
# -----------------------
num_epochs_ft = 40
base_lr_ft    = 2e-4
weight_decay  = 0.01
print_every   = 50     # batch print interval (CE, mIoU, Dice)
vis_every     = 1      # save a validation visualization every N epochs
GRAD_ACCUM    = 1      # set >1 if you need micro-batching for memory

def cosine(epoch):
    """Cosine LR multiplier from 1 at epoch 0 towards 0 at num_epochs_ft."""
    return 0.5*(1 + np.cos(np.pi*epoch/num_epochs_ft))

optimizer = AdamW(model.parameters(), lr=base_lr_ft, weight_decay=weight_decay)
scheduler = LambdaLR(optimizer, lr_lambda=cosine)
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
scaler = GradScaler('cuda')

ft_save_dir = "./jepa_finetuning_output"
os.makedirs(ft_save_dir, exist_ok=True)

print("Starting fine-tuning (Fusion head, CE only)...")
print(f"Train batches: {len(downstream_train_loader)}, Val batches: {len(downstream_val_loader)}")
print(f"Model trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

best_miou = 0.0
train_ce_hist, val_ce_hist, val_miou_hist, val_dice_hist = [], [], [], []

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

for epoch in range(num_epochs_ft):
    model.train()
    epoch_ce_accum = 0.0

    # per-epoch meters: whole-epoch train metrics + a window reset at each print
    epoch_meter = StreamingSegMetrics(NUM_CLASSES, IGNORE_INDEX, device=device)
    print_meter = StreamingSegMetrics(NUM_CLASSES, IGNORE_INDEX, device=device)

    for bidx, batch in enumerate(downstream_train_loader):
        images = batch["images"].to(device, non_blocking=True).to(memory_format=torch.channels_last)
        # Sanitize labels once per batch; reused for the loss and both meters.
        masks  = fix_masks(batch["masks"].to(device, non_blocking=True))

        optimizer.zero_grad(set_to_none=True)

        # micro-batching if GRAD_ACCUM > 1
        num_micro = max(1, GRAD_ACCUM)
        chunk = math.ceil(images.size(0) / num_micro)

        loss_sum = 0.0
        for i in range(0, images.size(0), chunk):
            imgs = images[i:i+chunk]
            msks = masks[i:i+chunk]
            with autocast('cuda'):
                logits = model(imgs)
                loss   = criterion(logits, msks) / num_micro

            # update GPU metrics before backward (keeps VRAM usage similar)
            with torch.no_grad():
                epoch_meter.update(logits, msks)
                print_meter.update(logits, msks)

            scaler.scale(loss).backward()
            loss_sum += float(loss.detach())

        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer); scaler.update()

        epoch_ce_accum += loss_sum

        if (bidx % print_every) == 0:
            miou_b, dice_b = print_meter.get()
            print(f"  Batch {bidx:4d}/{len(downstream_train_loader)} | CE: {loss_sum:.4f} | mIoU: {miou_b:.4f} | Dice: {dice_b:.4f}")
            print_meter.reset()

        del images, masks, logits, loss

    epoch_ce = epoch_ce_accum / max(1, len(downstream_train_loader))
    train_miou, _ = epoch_meter.get()
    train_ce_hist.append(epoch_ce)
    scheduler.step()

    # ---- Validation ----
    model.eval()
    val_ce = 0.0
    val_meter = StreamingSegMetrics(NUM_CLASSES, IGNORE_INDEX, device=device)
    with torch.no_grad():
        for batch in downstream_val_loader:
            images = batch["images"].to(device, non_blocking=True).to(memory_format=torch.channels_last)
            masks  = fix_masks(batch["masks"].to(device, non_blocking=True))
            with autocast('cuda'):
                logits = model(images)
                ce = criterion(logits, masks)
            val_ce += ce.item()
            val_meter.update(logits, masks)
            del images, masks, logits, ce

    val_ce /= max(1, len(downstream_val_loader))
    mean_miou, mean_dice = val_meter.get()
    val_ce_hist.append(val_ce); val_miou_hist.append(mean_miou); val_dice_hist.append(mean_dice)

    print(f"Epoch {epoch+1:02d}/{num_epochs_ft} | Train CE: {epoch_ce:.4f} | Train mIoU: {train_miou:.4f} | Val CE: {val_ce:.4f} | mIoU: {mean_miou:.4f} | Dice: {mean_dice:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")

    # Save best by validation mIoU
    if mean_miou > best_miou:
        best_miou = mean_miou
        best_path = os.path.join(ft_save_dir, "best_segmentation_model.pt")
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'miou': mean_miou,
            'dice': mean_dice
        }, best_path)
        print(f"  New best mIoU! Saved -> {best_path}")

    # Visualizations every `vis_every` epochs (model is still in eval mode here)
    if (epoch + 1) % vis_every == 0:
        with torch.no_grad():
            vis_batch = next(iter(downstream_val_loader))
            vis_images = vis_batch["images"].to(device).to(memory_format=torch.channels_last)
            vis_masks  = vis_batch["masks"].to(device)
            with autocast('cuda'):
                vis_logits = model(vis_images)
            vis_path = os.path.join(ft_save_dir, f"seg_epoch_{epoch+1:03d}.png")
            visualize_segmentation(vis_images, fix_masks(vis_masks), vis_logits, epoch+1, vis_path)
            print(f"  Saved visualization: {vis_path}")
            del vis_batch, vis_images, vis_masks, vis_logits

    torch.cuda.empty_cache(); gc.collect()

print("Fine-tuning completed.")
print(f"Best mIoU: {best_miou:.4f}")

# Save final
final_path = os.path.join(ft_save_dir, "final_segmentation_model.pt")
torch.save({
    'model_state_dict': model.state_dict(),
    'train_ce_losses': train_ce_hist,
    'val_ce_losses': val_ce_hist,
    'val_mious': val_miou_hist,
    'val_dices': val_dice_hist,
    'best_miou': best_miou
}, final_path)
print(f"Final model saved: {final_path}")
