# Vision Transformer (ViT) Technical Deep Dive
## Interactive Onboarding Guide for DINOv3 + ViT-Adapter + Mask2Former

---

Welcome! This notebook serves as a comprehensive technical guide to our Vision Transformer-based semantic segmentation pipeline. You'll learn:

1. **ViT Fundamentals & Attention Visualization** - Understanding how ViT processes images and visualizing attention
2. **Multi-Scale Input Handling** - How positional encodings are interpolated for different resolutions
3. **ViT-Adapter Architecture** - Bridging ViT to dense prediction tasks
4. **MS-Deformable Attention** - Efficient multi-scale attention mechanism
5. **Mask2Former Framework** - Universal segmentation architecture

**Prerequisites:** Familiarity with CNNs, attention mechanisms, and PyTorch.

---
# Setup & Imports

In [None]:
import sys
import os

# Add project root to path
PROJECT_ROOT = os.path.dirname(os.getcwd())
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

---
# Part 1: Vision Transformer (ViT) Fundamentals & Attention Visualization

## 1.1 How ViT Works: From Pixels to Patches to Predictions

Unlike CNNs that use local convolution kernels, ViT processes images as sequences of patches using self-attention:

```
Input Image (H x W x 3)
      |
      v
+------------------+
| Patch Embedding  |  Split image into P×P patches, flatten, project to D dimensions
+------------------+
      |
      v
+------------------+
| [CLS] + Patches  |  Prepend a learnable [CLS] token
+------------------+
      |
      v
+------------------+
| + Positional     |  Add learnable position embeddings to every token
|   Embeddings     |
+------------------+
      |
      v
+------------------+
| Transformer      |  L layers of Multi-Head Self-Attention + MLP
| Encoder Blocks   |
+------------------+
      |
      v
Output Features (N+1 tokens × D)
```

### Key Equations:

**Patch Embedding:**
$$z_0 = [x_{\text{cls}}; x_1^{\text{patch}}E; x_2^{\text{patch}}E; ...; x_N^{\text{patch}}E] + E_{\text{pos}}$$

where:
- $E \in \mathbb{R}^{(P^2 \cdot C) \times D}$ is the patch projection matrix
- $E_{\text{pos}} \in \mathbb{R}^{(N+1) \times D}$ are learnable position embeddings
- $N = \frac{H \times W}{P^2}$ is the number of patches
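As a rough illustration of the patch-embedding equation, the sketch below splits an image array into non-overlapping $P \times P$ patches, projects them, prepends a [CLS] token, and adds position embeddings. The helper name `patchify` and the random arrays standing in for $E$, $x_{\text{cls}}$, and $E_{\text{pos}}$ are assumptions for illustration; they are not taken from timm or from our codebase.

```python
import numpy as np

def patchify(image, patch_size):
    """Split an (H, W, C) array into (N, P*P*C) flattened patches, row-major."""
    H, W, C = image.shape
    P = patch_size
    assert H % P == 0 and W % P == 0, "H and W must be divisible by P"
    # (H/P, P, W/P, P, C) -> (H/P, W/P, P, P, C) -> (N, P*P*C)
    patches = image.reshape(H // P, P, W // P, P, C).transpose(0, 2, 1, 3, 4)
    return patches.reshape(-1, P * P * C)

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))
P, D = 16, 768

patches = patchify(image, P)                       # (N, P^2 * C)
E = rng.standard_normal((P * P * 3, D)) * 0.02     # stand-in for the learned projection
x_cls = np.zeros((1, D))                           # stand-in for the learned [CLS] token
E_pos = rng.standard_normal((patches.shape[0] + 1, D)) * 0.02

# z_0 = [x_cls; x_1 E; ...; x_N E] + E_pos
z0 = np.concatenate([x_cls, patches @ E], axis=0) + E_pos
print(patches.shape, z0.shape)  # (196, 768) (197, 768)
```

In practice the projection is usually implemented as a strided convolution, which is equivalent to this flatten-and-multiply form.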

**Self-Attention:**
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
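The attention formula can be sketched in a few lines of NumPy. This is a single-head version on 2D arrays, written only to make the shapes concrete; the function name and the random inputs are illustrative assumptions.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Return (output, weights) for single-head attention on 2D arrays."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ V, weights

rng = np.random.default_rng(0)
tokens = rng.standard_normal((5, 8))  # 5 tokens, d_k = 8
out, attn = scaled_dot_product_attention(tokens, tokens, tokens)
print(out.shape, attn.shape)  # (5, 8) (5, 5)
print(attn.sum(axis=-1))      # each row sums to 1
```

Each row of `attn` says how much one token draws from every other token; these per-row distributions are what the visualizations in section 1.4 display.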

## 1.2 Loading a Pre-trained ViT Model

We'll use `timm` (PyTorch Image Models) to load a pre-trained ViT and explore its architecture.

In [None]:
import timm

# Load a pre-trained ViT model
model_name = 'vit_base_patch16_224'
vit_model = timm.create_model(model_name, pretrained=True)
vit_model.eval()
vit_model.to(device)

# Inspect architecture
print(f"Model: {model_name}")
print(f"Patch size: {vit_model.patch_embed.patch_size}")
print(f"Embed dimension: {vit_model.embed_dim}")
print(f"Number of transformer blocks: {len(vit_model.blocks)}")
print(f"Number of attention heads: {vit_model.blocks[0].attn.num_heads}")

In [None]:
# Download and prepare a sample image
from urllib.request import urlretrieve

# Prefer a local LoveDA image; otherwise fall back to a small sample image from Wikimedia
sample_image_path = "sample_image.jpg"

# Try to use an existing image from the dataset, or download a sample
try:
    # Check if we have local test images
    local_paths = [
        os.path.join(PROJECT_ROOT, 'data', 'loveDA', 'Test', 'Rural', 'images_png', '0.png'),
        os.path.join(PROJECT_ROOT, 'data', 'loveDA', 'Val', 'Rural', 'images_png', '0.png'),
    ]
    for path in local_paths:
        if os.path.exists(path):
            sample_image_path = path
            break
    else:
        # Download a sample image
        url = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/PNG_transparency_demonstration_1.png/280px-PNG_transparency_demonstration_1.png"
        urlretrieve(url, sample_image_path)
except Exception as e:
    print(f"Using placeholder image: {e}")
    # Create a synthetic image
    img = Image.new('RGB', (224, 224), color='blue')
    img.save(sample_image_path)

# Load and display the image
img = Image.open(sample_image_path).convert('RGB')
plt.figure(figsize=(6, 6))
plt.imshow(img)
plt.title('Sample Input Image')
plt.axis('off')
plt.show()

print(f"Image size: {img.size}")

In [None]:
# Prepare image for ViT
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img_tensor = transform(img).unsqueeze(0).to(device)
print(f"Input tensor shape: {img_tensor.shape}")  # (1, 3, 224, 224)

## 1.3 Understanding Patch Tokenization

Let's visualize how the image is divided into patches:

In [None]:
def visualize_patches(image, patch_size=16):
    """Visualize how an image is divided into patches."""
    img_np = np.array(image.resize((224, 224)))
    H, W = img_np.shape[:2]
    
    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    ax.imshow(img_np)
    
    # Draw grid
    for i in range(0, H + 1, patch_size):
        ax.axhline(i, color='red', linewidth=0.5)
    for j in range(0, W + 1, patch_size):
        ax.axvline(j, color='red', linewidth=0.5)
    
    # Number some patches
    patch_idx = 0
    for i in range(0, H, patch_size):
        for j in range(0, W, patch_size):
            if patch_idx < 10:  # Only label first 10
                ax.text(j + patch_size//2, i + patch_size//2, str(patch_idx), 
                       ha='center', va='center', color='white', fontsize=8,
                       bbox=dict(boxstyle='round', facecolor='red', alpha=0.7))
            patch_idx += 1
    
    ax.set_title(f'Image divided into {patch_idx} patches ({H//patch_size}x{W//patch_size} grid)\n'
                 f'Each patch: {patch_size}x{patch_size} pixels')
    ax.axis('off')
    plt.tight_layout()
    plt.show()
    
    return patch_idx

num_patches = visualize_patches(img, patch_size=16)
print(f"\nTotal patches: {num_patches}")
print(f"With [CLS] token: {num_patches + 1} tokens")

## 1.4 Attention Map Visualization

### 1.4.1 Extracting Raw Attention Weights

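timm's attention modules return only their output tensor, not the attention weights, so a plain forward hook cannot capture them. Instead, we run the forward pass block by block and recompute the softmax-normalized $QK^T$ scores at each layer.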

In [None]:
class AttentionExtractor:
    """Hook-based attention extraction for ViT models."""
    
    def __init__(self, model):
        self.model = model
        self.attentions = []
        self.hooks = []
        
    def _attention_hook(self, module, input, output):
        """Placeholder hook; unused, kept only to show where a hook would attach."""
        # timm's attention forward returns only the output tensor, not the
        # attention weights, so a forward hook cannot read them directly.
        # get_attention_maps() recomputes the weights instead.
        pass
    
    def get_attention_maps(self, x):
        """Extract attention maps from all layers."""
        self.attentions = []
        
        # We'll manually compute attention for visualization
        # First, get patch embeddings
        B = x.shape[0]
        
        # Patch embedding
        x = self.model.patch_embed(x)
        
        # Add CLS token
        cls_tokens = self.model.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        # Add position embedding
        x = x + self.model.pos_embed
        x = self.model.pos_drop(x)
        
        # Process through transformer blocks and capture attention
        for block in self.model.blocks:
            # Get attention weights by computing Q, K, V manually
            B, N, C = x.shape
            qkv = block.attn.qkv(block.norm1(x)).reshape(
                B, N, 3, block.attn.num_heads, C // block.attn.num_heads
            ).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
            
            # Compute attention weights
            attn = (q @ k.transpose(-2, -1)) * block.attn.scale
            attn = attn.softmax(dim=-1)
            self.attentions.append(attn.detach().cpu())
            
            # Continue forward pass normally
            x = x + block.drop_path1(block.ls1(block.attn(block.norm1(x))))
            x = x + block.drop_path2(block.ls2(block.mlp(block.norm2(x))))
        
        return self.attentions

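# Alternative hook-based approach for models whose attention modules return
# (output, weights). timm's ViT returns only the output, so for `vit_model`
# this yields an empty list; it is included for comparison.
def get_attention_maps_simple(model, x):
    """Collect attention weights via forward hooks, if the model exposes them."""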
    attentions = []
    
    def hook_fn(module, input, output):
        # Some models return attention weights as second output
        if isinstance(output, tuple) and len(output) > 1:
            attentions.append(output[1])
    
    hooks = []
    for block in model.blocks:
        hook = block.attn.register_forward_hook(hook_fn)
        hooks.append(hook)
    
    with torch.no_grad():
        _ = model(x)
    
    for hook in hooks:
        hook.remove()
    
    return attentions

# Extract attention maps
extractor = AttentionExtractor(vit_model)
with torch.no_grad():
    attention_maps = extractor.get_attention_maps(img_tensor)

print(f"Extracted attention from {len(attention_maps)} layers")
print(f"Each attention map shape: {attention_maps[0].shape}")
print("  -> (batch, num_heads, num_tokens, num_tokens)")

### 1.4.2 Attention Rollout

**Attention Rollout** recursively combines attention maps from all layers to trace how information flows from input patches to output tokens.

**Algorithm:**
1. For each layer, average in the identity matrix to account for the residual connection: $\tilde{A}^{(l)} = 0.5 \cdot A^{(l)} + 0.5 \cdot I$
2. Re-normalize rows to sum to 1 (needed once low-attention entries have been discarded)
3. Multiply the matrices, with later layers on the left: $R = \tilde{A}^{(L)} \cdot ... \cdot \tilde{A}^{(2)} \cdot \tilde{A}^{(1)}$
4. Extract the [CLS] token's attention to all patches
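A tiny worked example may help make the steps concrete. The sketch below applies them to two made-up 3×3 attention matrices (one [CLS] token plus two patches) and multiplies each new layer on the left, the same accumulation order as `result = A @ result` in the `attention_rollout` cell that follows. The matrices are invented values, not model outputs.

```python
import numpy as np

def rollout(attentions):
    """Attention rollout over a list of row-stochastic (N, N) matrices."""
    n = attentions[0].shape[-1]
    result = np.eye(n)
    for A in attentions:                       # layer 1 first
        A_res = 0.5 * A + 0.5 * np.eye(n)      # account for the residual path
        A_res = A_res / A_res.sum(axis=-1, keepdims=True)
        result = A_res @ result                # later layers multiply on the left
    return result

# Two layers, three tokens ([CLS] + 2 patches); values are made up.
A1 = np.array([[0.2, 0.5, 0.3],
               [0.1, 0.8, 0.1],
               [0.3, 0.3, 0.4]])
A2 = np.array([[0.6, 0.1, 0.3],
               [0.2, 0.2, 0.6],
               [0.5, 0.4, 0.1]])
R = rollout([A1, A2])
cls_to_patches = R[0, 1:]       # row 0 is [CLS]; drop its self-attention column
print(R.sum(axis=-1))           # rows remain normalized
print(cls_to_patches)
```

Because every factor is row-stochastic, the product is too, so row 0 of `R` can be read as a distribution over input tokens.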

In [None]:
def attention_rollout(attention_maps, discard_ratio=0.1, head_fusion='mean'):
    """
    Compute attention rollout from a list of attention maps.
    
    Args:
        attention_maps: List of attention tensors (B, num_heads, N, N)
        discard_ratio: Fraction of lowest attention values to discard
        head_fusion: How to combine attention heads ('mean', 'max', 'min')
    
    Returns:
        Attention map from [CLS] to all patches
    """
    result = torch.eye(attention_maps[0].shape[-1])
    
    for attention in attention_maps:
        # Fuse attention heads (the trailing [0] selects the first image in the batch)
        if head_fusion == 'mean':
            attention_fused = attention.mean(dim=1)[0]  # (N, N)
        elif head_fusion == 'max':
            attention_fused = attention.max(dim=1).values[0]
        elif head_fusion == 'min':
            attention_fused = attention.min(dim=1).values[0]
        else:
            raise ValueError(f"Unknown head_fusion: {head_fusion!r}")
        
        # Discard the lowest attention values (skipped when nothing would be discarded)
        k = int(attention_fused.numel() * discard_ratio)
        if k > 0:
            threshold = attention_fused.flatten().kthvalue(k).values
            attention_fused = attention_fused * (attention_fused > threshold).float()
        
        # Add identity for residual connection
        I = torch.eye(attention_fused.shape[-1])
        attention_fused = 0.5 * attention_fused + 0.5 * I
        
        # Normalize rows
        attention_fused = attention_fused / attention_fused.sum(dim=-1, keepdim=True)
        
        # Accumulate
        result = torch.matmul(attention_fused, result)
    
    # Get attention from [CLS] token (index 0) to all patches
    mask = result[0, 1:]  # Exclude [CLS] token itself
    
    return mask

# Compute attention rollout
rollout_mask = attention_rollout(attention_maps)

# Reshape to 2D grid
num_patches_per_side = int(np.sqrt(len(rollout_mask)))
rollout_2d = rollout_mask.reshape(num_patches_per_side, num_patches_per_side)

print(f"Rollout mask shape: {rollout_2d.shape}")

In [None]:
def visualize_attention(image, attention_map, title="Attention Map"):
    """
    Overlay attention map on original image.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Original image
    img_resized = image.resize((224, 224))
    axes[0].imshow(img_resized)
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    
    # Attention heatmap
    attention_np = attention_map.numpy()
    # Normalize to [0, 1]
    attention_np = (attention_np - attention_np.min()) / (attention_np.max() - attention_np.min() + 1e-8)
    
    axes[1].imshow(attention_np, cmap='hot')
    axes[1].set_title('Attention Heatmap')
    axes[1].axis('off')
    
    # Overlay
    # Resize attention to image size
    attention_resized = Image.fromarray((attention_np * 255).astype(np.uint8))
    attention_resized = attention_resized.resize((224, 224), resample=Image.BILINEAR)
    attention_resized = np.array(attention_resized) / 255.0
    
    # Create colormap overlay
    cmap = plt.cm.jet
    heatmap = cmap(attention_resized)[:, :, :3]  # RGB
    
    # Blend with original
    img_np = np.array(img_resized) / 255.0
    overlay = 0.6 * img_np + 0.4 * heatmap
    overlay = np.clip(overlay, 0, 1)
    
    axes[2].imshow(overlay)
    axes[2].set_title(title)
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()

visualize_attention(img, rollout_2d, "Attention Rollout Overlay")

### 1.4.3 Per-Layer Attention Analysis

Let's visualize how attention patterns evolve through the network layers:

In [None]:
def visualize_layer_attentions(attention_maps, layers_to_show=(0, 3, 6, 9, 11)):
    """
    Visualize attention patterns at different layers.
    Shows [CLS] token's attention to all patches.
    """
    num_layers = len(layers_to_show)
    fig, axes = plt.subplots(1, num_layers, figsize=(4*num_layers, 4))
    
    for idx, layer_idx in enumerate(layers_to_show):
        if layer_idx >= len(attention_maps):
            continue
            
        # Get attention for this layer, average across heads
        attn = attention_maps[layer_idx].mean(dim=1)[0]  # (N, N)
        
        # Get [CLS] token's attention to patches (row 0, columns 1:)
        cls_attn = attn[0, 1:].numpy()  # Exclude [CLS] attending to itself
        
        # Reshape to 2D
        size = int(np.sqrt(len(cls_attn)))
        cls_attn_2d = cls_attn.reshape(size, size)
        
        axes[idx].imshow(cls_attn_2d, cmap='viridis')
        axes[idx].set_title(f'Layer {layer_idx + 1}')
        axes[idx].axis('off')
    
    plt.suptitle('[CLS] Token Attention to Patches Across Layers', fontsize=14)
    plt.tight_layout()
    plt.show()

visualize_layer_attentions(attention_maps)

### 1.4.4 Multi-Head Attention Visualization

Different attention heads learn to focus on different aspects of the image:

In [None]:
def visualize_attention_heads(attention_maps, layer_idx=-1, num_heads_to_show=6):
    """
    Visualize individual attention heads from a specific layer.
    """
    attn = attention_maps[layer_idx][0]  # (num_heads, N, N)
    num_heads = min(attn.shape[0], num_heads_to_show)
    
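    num_cols = (num_heads + 1) // 2  # Round up so an odd head count still fits
    fig, axes = plt.subplots(2, num_cols, figsize=(3 * num_cols, 6))
    axes = axes.flatten()
    for ax in axes[num_heads:]:
        ax.axis('off')  # Hide any unused subplot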
    
    for head_idx in range(num_heads):
        # Get [CLS] attention to patches for this head
        head_attn = attn[head_idx, 0, 1:].numpy()
        size = int(np.sqrt(len(head_attn)))
        head_attn_2d = head_attn.reshape(size, size)
        
        axes[head_idx].imshow(head_attn_2d, cmap='magma')
        axes[head_idx].set_title(f'Head {head_idx + 1}')
        axes[head_idx].axis('off')
    
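    # Report a 1-based layer number for both positive and negative indices
    layer_number = layer_idx + 1 if layer_idx >= 0 else len(attention_maps) + layer_idx + 1
    plt.suptitle(f'Attention Heads from Layer {layer_number}', fontsize=14)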
    plt.tight_layout()
    plt.show()

visualize_attention_heads(attention_maps, layer_idx=-1)  # Last layer

---
# Part 2: Handling Multi-Scale Inputs via Positional Encoding Interpolation

## 2.1 The Fixed Position Encoding Problem

### Why Standard ViT Fails on Different Resolutions

ViT learns positional embeddings for a **fixed number of patches**. For example:
- Training size: 224×224 with 16×16 patches → 14×14 = 196 patches
- Position embedding shape: (1, 197, D) including [CLS]

**Problem:** What happens if we want to use 512×512 images?
- 512×512 with 16×16 patches → 32×32 = 1024 patches
- We only have 196 position embeddings!

```
Training Resolution:              Inference at Higher Resolution:
+-+-+-+-+                         +-+-+-+-+-+-+-+-+
|0|1|2|3|   14×14 = 196 patches  |0|1|2|3|4|5|6|7|   32×32 = 1024 patches
+-+-+-+-+   with known positions +-+-+-+-+-+-+-+-+   MISSING positions!
|4|5|...|                        |8|9|...
```
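The mismatch is easy to check numerically. The helper below is a small illustrative function (not from timm or our codebase) that computes the patch grid and sequence length for a given image size.

```python
def num_tokens(height, width, patch_size=16, num_extra_tokens=1):
    """Patch grid and sequence length for an image, counting extra tokens such as [CLS]."""
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by the patch size")
    grid = (height // patch_size, width // patch_size)
    return grid, grid[0] * grid[1] + num_extra_tokens

train_grid, train_len = num_tokens(224, 224)
test_grid, test_len = num_tokens(512, 512)
print(train_grid, train_len)   # (14, 14) 197
print(test_grid, test_len)     # (32, 32) 1025
print(f"missing position embeddings: {test_len - train_len}")  # 828
```

Adding a `(1, 197, D)` position embedding to a `(B, 1025, D)` token sequence fails with a shape error, which is why the embeddings have to be resized first.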

## 2.2 Solution: 2D Bicubic Interpolation of Position Embeddings

The key insight is that position embeddings have **2D spatial structure**:

1. **Reshape** the flat position embeddings to a 2D grid
2. **Interpolate** to the target resolution using bicubic interpolation
3. **Flatten** back to 1D sequence

This works because learned position embeddings vary smoothly across the 2D patch grid: neighboring patches have similar embeddings, so interpolating between them yields plausible embeddings for the new positions.

In [None]:
def interpolate_pos_encoding(
    pos_embed: torch.Tensor,
    src_size: tuple,  # (H_src, W_src) in patches
    tgt_size: tuple,  # (H_tgt, W_tgt) in patches
    num_extra_tokens: int = 1,  # Usually 1 for [CLS], DINOv3 uses 5 (1 CLS + 4 registers)
) -> torch.Tensor:
    """
    Interpolate positional encodings from source to target resolution.
    
    This is the core function used in our codebase at:
    models/backbone/dinov3_adapter.py:387-396 (_get_pos_embed method)
    
    Args:
        pos_embed: Position embeddings of shape (1, N_src + extra, D)
        src_size: Original grid size in patches (H_src, W_src)
        tgt_size: Target grid size in patches (H_tgt, W_tgt)
        num_extra_tokens: Number of non-patch tokens ([CLS], registers)
    
    Returns:
        Interpolated position embeddings of shape (1, N_tgt + extra, D)
    """
    D = pos_embed.shape[-1]
    H_src, W_src = src_size
    H_tgt, W_tgt = tgt_size
    
    # Separate extra tokens (CLS, registers) from patch embeddings
    extra_tokens = pos_embed[:, :num_extra_tokens, :]  # (1, extra, D)
    patch_pos_embed = pos_embed[:, num_extra_tokens:, :]  # (1, H*W, D)
    
    # Reshape to 2D spatial grid: (1, H, W, D) -> (1, D, H, W)
    patch_pos_embed = patch_pos_embed.reshape(1, H_src, W_src, D).permute(0, 3, 1, 2)
    
    # Bicubic interpolation to target size
    patch_pos_embed = F.interpolate(
        patch_pos_embed,
        size=(H_tgt, W_tgt),
        mode='bicubic',
        align_corners=False
    )
    
    # Reshape back: (1, D, H, W) -> (1, H*W, D)
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, H_tgt * W_tgt, D)
    
    # Concatenate extra tokens back
    interpolated = torch.cat([extra_tokens, patch_pos_embed], dim=1)
    
    return interpolated

# Example: Interpolate from 14×14 (224px) to 32×32 (512px)
print("Position Embedding Interpolation Demo:")
print("="*50)

# Original position embeddings (simulated)
src_H, src_W = 14, 14  # 224×224 image with 16×16 patches
embed_dim = 768

# Create dummy position embeddings
pos_embed_original = torch.randn(1, 1 + src_H * src_W, embed_dim)
print(f"Original pos_embed shape: {pos_embed_original.shape}")
print(f"  - 1 CLS token + {src_H}×{src_W} = {src_H*src_W} patch tokens")

# Interpolate to larger resolution
tgt_H, tgt_W = 32, 32  # 512×512 image
pos_embed_interpolated = interpolate_pos_encoding(
    pos_embed_original,
    src_size=(src_H, src_W),
    tgt_size=(tgt_H, tgt_W),
    num_extra_tokens=1
)
print(f"\nInterpolated pos_embed shape: {pos_embed_interpolated.shape}")
print(f"  - 1 CLS token + {tgt_H}×{tgt_W} = {tgt_H*tgt_W} patch tokens")

In [None]:
def visualize_pos_encoding_interpolation():
    """
    Visualize the effect of position encoding interpolation.
    """
    # Create position embeddings that encode position (for visualization)
    src_H, src_W = 14, 14
    D = 3  # Use 3 dims for RGB visualization
    
    # Create embeddings that encode (row, col, diagonal) for visualization
    pos = torch.zeros(1, 1 + src_H * src_W, D)
    for i in range(src_H):
        for j in range(src_W):
            idx = 1 + i * src_W + j  # Skip CLS
            pos[0, idx, 0] = i / (src_H - 1)  # Normalized row position
            pos[0, idx, 1] = j / (src_W - 1)  # Normalized col position  
            pos[0, idx, 2] = (i + j) / (src_H + src_W - 2)  # Diagonal
    
    # Interpolate to higher resolution
    tgt_H, tgt_W = 45, 45  # 720×720 image (our training size!)
    pos_interp = interpolate_pos_encoding(
        pos, (src_H, src_W), (tgt_H, tgt_W), num_extra_tokens=1
    )
    
    # Visualize
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    
    # Original
    orig_vis = pos[0, 1:].reshape(src_H, src_W, 3).numpy()
    axes[0].imshow(orig_vis)
    axes[0].set_title(f'Original Position Encoding\n{src_H}×{src_W} = {src_H*src_W} patches\n(224×224 image)')
    axes[0].axis('off')
    
    # Interpolated
    interp_vis = pos_interp[0, 1:].reshape(tgt_H, tgt_W, 3).numpy()
    axes[1].imshow(interp_vis)
    axes[1].set_title(f'Interpolated Position Encoding\n{tgt_H}×{tgt_W} = {tgt_H*tgt_W} patches\n(720×720 image)')
    axes[1].axis('off')
    
    plt.suptitle('Position Encoding Interpolation via Bicubic Upsampling', fontsize=14)
    plt.tight_layout()
    plt.show()

visualize_pos_encoding_interpolation()

## 2.3 DINOv3 Position Encoding in Our Codebase

DINOv3 uses a slightly different token structure with **register tokens**:

```
Token Sequence: [CLS, patch_1, patch_2, ..., patch_N, reg_1, reg_2, reg_3, reg_4]
```

Our `DINOv3CompatibilityWrapper` handles this in `dinov3_mask2former_integration.py:85-155`:

In [None]:
# Code from our codebase showing how DINOv3 token structure is handled
dinov3_token_handling = """
# From dinov3_mask2former_integration.py

class DINOv3CompatibilityWrapper(nn.Module):
    '''
    Wraps HuggingFace DINOv3 to be compatible with adapter expectations.
    
    DINOv3 output structure:
    - Token 0: [CLS] token
    - Tokens 1 to N: Patch tokens (N = H*W / patch_size^2)
    - Tokens N+1 to N+4: Register tokens (4 registers)
    '''
    
    def get_intermediate_layers(self, x, n, return_class_token=True):
        # Get features at specific layers
        outputs = self.model(
            x, output_hidden_states=True, return_dict=True
        )
        hidden_states = outputs.hidden_states[1:]  # Skip embedding
        
        results = []
        for idx in n:  # n = interaction_indexes like [4, 11, 17, 23]
            h = hidden_states[idx]
            
            # Extract components from token sequence
            cls_token = h[:, 0:1, :]           # First token is CLS
            patch_tokens = h[:, 1:-4, :]       # Middle tokens are patches
            # register_tokens = h[:, -4:, :]  # Last 4 are registers (unused)
            
            results.append((patch_tokens, cls_token))
        
        return results
"""
print(dinov3_token_handling)
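
To make the slicing concrete, here is a minimal NumPy sketch that splits a hidden-state array into its three parts and reshapes the patch tokens into a feature map. It assumes the token order stated above ([CLS], then patches, then 4 registers); `split_tokens` is a hypothetical helper, not a function from our codebase, and the embedding size is kept tiny for the demo.

```python
import numpy as np

def split_tokens(hidden, num_registers=4):
    """Split (B, 1 + N + R, D) into CLS, patch, and register tokens.

    Assumes the order [CLS, patches..., registers...] described above.
    """
    end_of_patches = hidden.shape[1] - num_registers
    cls_token = hidden[:, :1, :]
    patch_tokens = hidden[:, 1:end_of_patches, :]
    register_tokens = hidden[:, end_of_patches:, :]
    return cls_token, patch_tokens, register_tokens

B, grid, D = 2, 45, 8          # 720 / 16 = 45 patches per side; small D for the demo
hidden = np.zeros((B, 1 + grid * grid + 4, D))
cls_token, patch_tokens, register_tokens = split_tokens(hidden)

# Patch tokens can be reshaped back into a (B, D, H_p, W_p) feature map.
feature_map = patch_tokens.reshape(B, grid, grid, D).transpose(0, 3, 1, 2)
print(cls_token.shape, patch_tokens.shape, register_tokens.shape, feature_map.shape)
# (2, 1, 8) (2, 2025, 8) (2, 4, 8) (2, 8, 45, 45)
```

Using an explicit end index rather than `-num_registers` keeps the slice correct when a model has no register tokens.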

---
# Part 3: The ViT-Adapter Architecture

## 3.1 Why Do We Need an Adapter?

### The Gap Between ViT and Dense Prediction

**Plain ViT outputs:**
- Single-scale features at patch resolution (H/16, W/16)
- Global receptive field from layer 1 (no local inductive bias)
- 1D sequence without explicit spatial structure

**Dense prediction (segmentation) needs:**
- Multi-scale features (H/4, H/8, H/16, H/32) for FPN
- Local features for fine boundaries
- Spatial relationships preserved

**ViT-Adapter bridges this gap!**

## 3.2 ViT-Adapter Architecture Overview

```
                              ViT Transformer (Frozen)
                              ========================
Input Image                   Layer 4    Layer 11   Layer 17   Layer 23
    |                            |           |          |          |
    |                            |           |          |          |
    v                            v           v          v          v
+--------+                   [Patch Tokens + CLS from each layer]
| Spatial|                       |           |          |          |
| Prior  |--c2,c3,c4-->[Interaction Blocks with Deformable Attention]
| Module |                       |           |          |          |
+--------+                       v           v          v          v
    |                        [Refined Multi-Scale Features]
    |                            |           |          |          |
    |                            v           v          v          v
    +------------------------> f1(H/4)    f2(H/8)   f3(H/16)  f4(H/32)
                                 |           |          |          |
                                 +-----------+----------+----------+
                                             |
                                             v
                                    [FPN / Pixel Decoder]
                                             |
                                             v
                                    [Mask2Former Head]
```

## 3.3 Spatial Prior Module (SPM)

The SPM is a lightweight CNN that provides **local inductive bias** and **multi-scale spatial features**.

**Architecture:**
```
Input: (B, 3, H, W)
    |
    v
Stem: 3 × [Conv(k=3) → BN → ReLU] (3→64→64→64, first conv stride 2) → MaxPool(stride 2)
    |
    +--> fc1: Conv(64→1024, k=1)  → c1: (B, 1024, H/4, W/4)
    v
Conv2: Conv(64→128, k=3, stride=2) → BN → ReLU
    |
    +--> fc2: Conv(128→1024, k=1) → c2: (B, 1024, H/8, W/8)
    v
Conv3: Conv(128→256, k=3, stride=2) → BN → ReLU
    |
    +--> fc3: Conv(256→1024, k=1) → c3: (B, 1024, H/16, W/16)
    v
Conv4: Conv(256→256, k=3, stride=2) → BN → ReLU
    |
    +--> fc4: Conv(256→1024, k=1) → c4: (B, 1024, H/32, W/32)
```
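The H/4 to H/32 sizes follow from the standard convolution output-size formula. The sketch below applies it with the kernel, stride, and padding values used in the `SpatialPriorModule` cell (kernel 3, stride 2, padding 1 for every downsampling step); the helper names are illustrative only.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Output length of a conv/pool layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def spm_sizes(size):
    """Spatial sizes of c1..c4 for a square input of the given side length."""
    after_stem_conv = conv_out(size)     # first stem conv, stride 2 -> H/2
    c1 = conv_out(after_stem_conv)       # max pool, stride 2 -> H/4
    c2 = conv_out(c1)                    # H/8
    c3 = conv_out(c2)                    # H/16
    c4 = conv_out(c3)                    # H/32, rounded up when c3 is odd
    return c1, c2, c3, c4

print(spm_sizes(720))  # (180, 90, 45, 23)
print(spm_sizes(512))  # (128, 64, 32, 16)
```

For a 720-pixel input the last stage produces 23 rather than 22 positions per side, because a stride-2 convolution with padding 1 rounds an odd size up. Input sizes divisible by 32 avoid this off-by-one.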

In [None]:
# Simplified implementation of Spatial Prior Module
class SpatialPriorModule(nn.Module):
    """
    Generates multi-scale spatial priors using convolutional stem.
    
    This provides the local inductive bias that ViT lacks.
    Located at: models/backbone/dinov3_adapter.py:234-302
    """
    
    def __init__(self, inplanes=64, embed_dim=1024):
        super().__init__()
        
        # Stem: 3 conv layers + max pool → H/4 resolution
        self.stem = nn.Sequential(
            nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(inplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(inplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        
        # Downsample stages
        self.conv2 = nn.Sequential(
            nn.Conv2d(inplanes, inplanes * 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(inplanes * 2),
            nn.ReLU(inplace=True),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(inplanes * 2, inplanes * 4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(inplanes * 4),
            nn.ReLU(inplace=True),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(inplanes * 4, inplanes * 4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(inplanes * 4),
            nn.ReLU(inplace=True),
        )
        
        # Projection to embedding dimension
        self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1)
        self.fc2 = nn.Conv2d(inplanes * 2, embed_dim, kernel_size=1)
        self.fc3 = nn.Conv2d(inplanes * 4, embed_dim, kernel_size=1)
        self.fc4 = nn.Conv2d(inplanes * 4, embed_dim, kernel_size=1)
    
    def forward(self, x):
        # x: (B, 3, H, W)
        c1 = self.stem(x)          # (B, 64, H/4, W/4)
        c2 = self.conv2(c1)        # (B, 128, H/8, W/8)
        c3 = self.conv3(c2)        # (B, 256, H/16, W/16)
        c4 = self.conv4(c3)        # (B, 256, H/32, W/32)
        
        # Project to embedding dimension and flatten to tokens
        c1 = self.fc1(c1)  # (B, 1024, H/4, W/4)
        c2 = self.fc2(c2).flatten(2).transpose(1, 2)  # (B, N2, 1024)
        c3 = self.fc3(c3).flatten(2).transpose(1, 2)  # (B, N3, 1024)
        c4 = self.fc4(c4).flatten(2).transpose(1, 2)  # (B, N4, 1024)
        
        return c1, c2, c3, c4

# Demo
spm = SpatialPriorModule(inplanes=64, embed_dim=1024)
dummy_input = torch.randn(1, 3, 720, 720)  # Our training resolution
c1, c2, c3, c4 = spm(dummy_input)

print("Spatial Prior Module Output Shapes:")
print(f"  c1 (H/4):  {c1.shape}  - kept as spatial for final upsample")
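# Token counts are read from the actual shapes: 720 // 32 is 22, but the
# stride-2 convolution rounds 45 / 2 up, so c4 has 23×23 tokens.
for name, c in [("c2 (H/8): ", c2), ("c3 (H/16):", c3), ("c4 (H/32):", c4)]:
    side = int(round(c.shape[1] ** 0.5))  # square input, so tokens = side × side
    print(f"  {name} {c.shape}  - {side}×{side} = {c.shape[1]} tokens")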

## 3.4 Interaction Blocks (Injectors)

Interaction blocks **fuse** spatial priors from SPM with semantic features from ViT using **MS-Deformable Attention**.

```
ViT Patch Tokens (single scale)    SPM Tokens (multi-scale: c2, c3, c4)
         |                                      |
         |                               [Add level embed]
         |                                      |
         |                               [Concatenate all scales]
         |                                      |
         +-------------> Deformable <-----------+
                         Cross-Attention
                              |
                              v
                    [Query = ViT tokens]
                    [Key/Value = SPM tokens]
                              |
                              v
              [ViT tokens enriched with local spatial info]
```

**Why this works:**
- SPM provides **local features** (edges, textures) via convolutions
- ViT provides **global semantics** via self-attention
- Deformable attention allows **adaptive fusion** - each ViT token queries relevant SPM locations
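To show the query/key/value roles without the machinery of deformable sampling, the sketch below substitutes ordinary dense cross-attention. This is a deliberately simplified stand-in, not the MS-Deformable Attention the real block uses; `inject`, the fixed `gamma` scale, and all tensor sizes are assumptions made for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def inject(vit_tokens, spm_levels, level_embed, gamma=0.1):
    """Dense cross-attention stand-in: ViT tokens query multi-scale SPM tokens."""
    # Tag each SPM scale with its level embedding, then concatenate all scales.
    spm = np.concatenate([c + e for c, e in zip(spm_levels, level_embed)], axis=0)
    d = vit_tokens.shape[-1]
    weights = softmax(vit_tokens @ spm.T / np.sqrt(d))    # (N_vit, N_spm)
    # Residual update scaled by gamma (a fixed constant in this sketch).
    return vit_tokens + gamma * (weights @ spm), weights

rng = np.random.default_rng(0)
D = 16
vit_tokens = rng.standard_normal((9, D))                         # 3×3 grid at H/16
spm_levels = [rng.standard_normal((n, D)) for n in (36, 9, 4)]   # H/8, H/16, H/32
level_embed = rng.standard_normal((3, D))

fused, weights = inject(vit_tokens, spm_levels, level_embed)
print(fused.shape, weights.shape)  # (9, 16) (9, 49)
```

Here every ViT token attends to all 49 SPM tokens. Deformable attention replaces that dense weight matrix with a handful of learned sampling locations per query.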

## 3.5 Complete ViT-Adapter Forward Pass

Here's the conceptual flow (simplified from `models/backbone/dinov3_adapter.py:408-484`):

In [None]:
# Conceptual forward pass (simplified pseudocode)
vit_adapter_forward = """
def forward(self, x):
    '''
    ViT-Adapter forward pass.
    Input: x (B, 3, H, W)
    Output: dict of multi-scale features {"1": f1, "2": f2, "3": f3, "4": f4}
    '''
    
    # 1. Generate spatial priors via CNN stem
    c1, c2, c3, c4 = self.spm(x)  # Multi-scale local features
    
    # 2. Add learnable level embeddings to distinguish scales
    c2 = c2 + self.level_embed[0]  # H/8 scale
    c3 = c3 + self.level_embed[1]  # H/16 scale
    c4 = c4 + self.level_embed[2]  # H/32 scale
    
    # 3. Concatenate all SPM tokens for deformable attention
    c = torch.cat([c2, c3, c4], dim=1)  # (B, N2+N3+N4, D)
    
    # 4. Get ViT intermediate features at interaction points
    # interaction_indexes = [4, 11, 17, 23] for 24-layer ViT
    vit_features = self.backbone.get_intermediate_layers(
        x, n=self.interaction_indexes, return_class_token=True
    )
    
    # 5. Interaction: Fuse ViT and SPM features
    outs = []
    for i, interaction_block in enumerate(self.interactions):
        vit_tokens, cls_token = vit_features[i]
        
        # Deformable attention: ViT queries SPM
        vit_tokens, c, cls = interaction_block(
            vit_tokens, c, cls_token,
            deform_inputs1, deform_inputs2,
            H, W
        )
        outs.append(vit_tokens)
    
    # 6. Split concatenated tokens back to scales
    c2, c3, c4 = split_by_scale(c)
    
    # 7. Reshape tokens to spatial feature maps
    c2 = c2.transpose(1,2).view(B, D, H//8, W//8)
    c3 = c3.transpose(1,2).view(B, D, H//16, W//16)
    c4 = c4.transpose(1,2).view(B, D, H//32, W//32)
    c1 = upsample(c2) + c1  # High-res via upsampling
    
    # 8. Add ViT features to output (if enabled)
    if self.add_vit_feature:
        x1, x2, x3, x4 = reshape_vit_outputs(outs)
        c1, c2, c3, c4 = c1+x1, c2+x2, c3+x3, c4+x4
    
    # 9. Final BatchNorm
    f1 = self.norm1(c1)  # H/4
    f2 = self.norm2(c2)  # H/8
    f3 = self.norm3(c3)  # H/16
    f4 = self.norm4(c4)  # H/32
    
    return {"1": f1, "2": f2, "3": f3, "4": f4}
"""
print(vit_adapter_forward)
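
Steps 6 and 7 of the pseudocode above rely on `split_by_scale`, which is left undefined there. One possible version is sketched below. The function signature, the token order (H/8, then H/16, then H/32), and the requirement that `H` and `W` are divisible by 32 are assumptions for illustration; the real implementation may differ.

```python
import torch


def split_by_scale(c, H, W):
    """Split concatenated SPM tokens (B, N2+N3+N4, D) into per-scale feature maps.

    Hypothetical helper: assumes H and W are divisible by 32 and that tokens
    are concatenated in the order H/8, H/16, H/32.
    """
    B, _, D = c.shape
    shapes = [(H // s, W // s) for s in (8, 16, 32)]
    sizes = [h * w for h, w in shapes]
    chunks = torch.split(c, sizes, dim=1)
    # (B, N_l, D) -> (B, D, N_l) -> (B, D, h, w)
    return [t.transpose(1, 2).reshape(B, D, h, w) for t, (h, w) in zip(chunks, shapes)]


H = W = 512
D = 32
n_tokens = sum((H // s) * (W // s) for s in (8, 16, 32))  # 4096 + 1024 + 256
c = torch.randn(2, n_tokens, D)
c2, c3, c4 = split_by_scale(c, H, W)
print(c2.shape, c3.shape, c4.shape)
```

Token `i` of a scale lands at row `i // w`, column `i % w`, which is the row-major order produced by flattening a feature map.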

---
# Part 4: Multi-Scale Deformable Attention (MS-Deform-Attn)

## 4.1 The Problem with Standard Attention for Dense Prediction

**Standard Self-Attention:**
- Complexity: $O(N^2)$ where $N$ = number of tokens
- For a 720×720 image with 16×16 patches: $N = 45 \times 45 = 2025$ tokens
- Attention matrix: $2025 \times 2025 \approx 4M$ elements per head

**For multi-scale dense prediction, it's worse:**
- H/8: 90×90 = 8,100 tokens
- H/16: 45×45 = 2,025 tokens  
- H/32: 22×22 = 484 tokens
- Total: 10,609 tokens → ~112.6M attention elements per head!

**Solution: Deformable Attention** - Only attend to a **small fixed number of sampling points** per query.

## 4.2 Deformable Attention: Key Concepts

### Reference Points
Each query token has a **reference point** - a normalized (x, y) coordinate indicating its "home" position:

```
Reference Point Grid (for H/16 = 45×45 tokens):

(0.0, 0.0)  (0.02, 0.0)  ...  (1.0, 0.0)
(0.0, 0.02) (0.02, 0.02) ...  (1.0, 0.02)
    ...          ...              ...
(0.0, 1.0)  (0.02, 1.0)  ...  (1.0, 1.0)
```
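
The grid above is idealized. The Deformable DETR reference implementation places each reference point at the center of its cell, so the coordinates lie strictly between 0 and 1. A minimal sketch of that convention follows; the function name `reference_points` is an assumption for this example.

```python
import torch


def reference_points(h, w):
    """Normalized (x, y) cell centers of an h x w token grid, shape (h*w, 2)."""
    ys = (torch.arange(h, dtype=torch.float32) + 0.5) / h
    xs = (torch.arange(w, dtype=torch.float32) + 0.5) / w
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    # Row-major flattening matches the order of flattened feature-map tokens.
    return torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=-1)


ref = reference_points(45, 45)
print(ref.shape)
print(ref[0], ref[-1])
```

For a 45×45 grid the first point is at about (0.011, 0.011) and neighbouring points are about 0.022 apart.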

### Sampling Offsets
Instead of attending to ALL positions, each query **learns offsets** to sample a small number of points:

$$\text{sampling\_location} = \text{reference\_point} + \Delta p$$

where $\Delta p$ is predicted by a linear layer from the query.

### Multi-Scale Sampling
For each query, sample points at **multiple scales** (H/8, H/16, H/32):

```
Query at position (0.5, 0.5):

Scale H/8:   Sample 4 points near (0.5, 0.5)
Scale H/16:  Sample 4 points near (0.5, 0.5)
Scale H/32:  Sample 4 points near (0.5, 0.5)

Total: 3 scales × 4 points × 8 heads = 96 sampled points per query, vs ~10,600 keys per head in full attention!
```
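
Sampling locations rarely fall exactly on a token, so values are read with bilinear interpolation. The sketch below shows this for a single scale using `F.grid_sample`. The helper name `sample_level` and the toy 4×4 feature map are assumptions for illustration. With `align_corners=False`, a normalized coordinate of `(col + 0.5) / W` is exactly the center of column `col`.

```python
import torch
import torch.nn.functional as F


def sample_level(feat, locations):
    """Bilinearly sample feat (B, C, H, W) at normalized (x, y) locations (B, Q, K, 2)."""
    grid = 2.0 * locations - 1.0  # grid_sample expects coordinates in [-1, 1]
    return F.grid_sample(feat, grid, mode="bilinear",
                         padding_mode="zeros", align_corners=False)  # (B, C, Q, K)


# A 4x4 single-channel map whose value at (row, col) is 4*row + col.
feat = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)
ref = torch.tensor([0.375, 0.625])  # center of the cell at row=2, col=1
offsets = torch.tensor([[0.0, 0.0],     # stay on the reference point
                        [0.25, 0.0],    # one cell to the right
                        [0.0, -0.25],   # one cell up
                        [0.125, 0.0]])  # half a cell to the right
locations = (ref + offsets).view(1, 1, 4, 2)  # (B=1, Q=1, K=4, 2)
samples = sample_level(feat, locations)
print(samples.flatten().tolist())  # [9.0, 10.0, 5.0, 9.5]
```

The last sample lies halfway between two cells, so it returns the average of their values. This is what makes the learned offsets differentiable.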

## 4.3 Mathematical Formulation

**Multi-Scale Deformable Attention:**

$$\text{MSDeformAttn}(z_q, \hat{p}_q, \{x^l\}_{l=1}^{L}) = \sum_{m=1}^{M} W_m \left[ \sum_{l=1}^{L} \sum_{k=1}^{K} A_{mlqk} \cdot W'_m x^l(\phi_l(\hat{p}_q) + \Delta p_{mlqk}) \right]$$

Where:
- $z_q$: Query embedding
- $\hat{p}_q$: Reference point (normalized coordinates)
- $x^l$: Features at scale $l$ (out of $L$ total scales)
- $M$: Number of attention heads
- $K$: Number of sampling points per head per scale (typically 4)
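- $A_{mlqk}$: Attention weight, predicted from the query by a linear layer and softmax-normalized over the $L \cdot K$ sampling points of each head
- $\Delta p_{mlqk}$: Sampling offset, predicted from the query by a linear layer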
- $W_m, W'_m$: Projection matrices
- $\phi_l$: Scale-specific coordinate adjustment

**Complexity:** $O(N \cdot M \cdot L \cdot K)$ - **linear** in the number of queries!

In [None]:
# Simplified MS-Deformable Attention Implementation
# Full implementation at: models/utils/ms_deform_attn.py

class MSDeformAttnSimplified(nn.Module):
    """
    Simplified Multi-Scale Deformable Attention for understanding.
    
    Key insight: Instead of attending to all N positions,
    each query only samples K points per scale per head.
    """
    
    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
        super().__init__()
        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points
        
        # Predict sampling offsets: query -> (n_heads * n_levels * n_points * 2)
        self.sampling_offsets = nn.Linear(
            d_model, n_heads * n_levels * n_points * 2
        )
        
        # Predict attention weights: query -> (n_heads * n_levels * n_points)
        self.attention_weights = nn.Linear(
            d_model, n_heads * n_levels * n_points
        )
        
        # Value projection
        self.value_proj = nn.Linear(d_model, d_model)
        
        # Output projection
        self.output_proj = nn.Linear(d_model, d_model)
    
    def forward(self, query, reference_points, input_flatten, 
                input_spatial_shapes, input_level_start_index):
        """
        Args:
            query: (N, Len_q, d_model) - Query embeddings
            reference_points: (N, Len_q, n_levels, 2) - Normalized (x,y) coords
            input_flatten: (N, sum(H_l*W_l), d_model) - Flattened multi-scale features
            input_spatial_shapes: (n_levels, 2) - [(H_0, W_0), (H_1, W_1), ...]
            input_level_start_index: (n_levels,) - Starting index for each level
        
        Returns:
            output: (N, Len_q, d_model)
        """
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape
        
        # 1. Project values
        value = self.value_proj(input_flatten)  # (N, Len_in, d_model)
        
        # 2. Predict sampling offsets from query
        sampling_offsets = self.sampling_offsets(query)  # (N, Len_q, n_heads*n_levels*n_points*2)
        sampling_offsets = sampling_offsets.view(
            N, Len_q, self.n_heads, self.n_levels, self.n_points, 2
        )
        
        # 3. Predict attention weights from query
        attention_weights = self.attention_weights(query)  # (N, Len_q, n_heads*n_levels*n_points)
        attention_weights = F.softmax(
            attention_weights.view(N, Len_q, self.n_heads, self.n_levels * self.n_points),
            dim=-1
        ).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
        
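        # 4. Compute sampling locations: reference_point + offset.
        # Offsets are divided by (W_l, H_l) so one unit equals one cell at that level.
        shapes = torch.as_tensor(input_spatial_shapes, device=query.device)
        offset_normalizer = shapes.flip(-1).to(query.dtype)  # (n_levels, 2) as (W_l, H_l)
        sampling_locations = (
            reference_points[:, :, None, :, None, :]
            + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
        )  # (N, Len_q, n_heads, n_levels, n_points, 2), normalized coordinates
        
        # 5. Sample values at the computed locations using bilinear interpolation.
        # Pure-PyTorch reference path; the CUDA kernel fuses this loop for speed.
        head_dim = self.d_model // self.n_heads
        value = value.view(N, Len_in, self.n_heads, head_dim)
        level_shapes = shapes.tolist()
        value_per_level = value.split([h * w for h, w in level_shapes], dim=1)
        grids = 2 * sampling_locations - 1  # F.grid_sample expects coords in [-1, 1]
        sampled = []
        for lvl, (h, w) in enumerate(level_shapes):
            # (N, h*w, n_heads, head_dim) -> (N*n_heads, head_dim, h, w)
            value_l = value_per_level[lvl].flatten(2).transpose(1, 2).reshape(
                N * self.n_heads, head_dim, h, w)
            # (N, Len_q, n_heads, n_points, 2) -> (N*n_heads, Len_q, n_points, 2)
            grid_l = grids[:, :, :, lvl].transpose(1, 2).flatten(0, 1)
            # Result: (N*n_heads, head_dim, Len_q, n_points)
            sampled.append(F.grid_sample(value_l, grid_l, mode='bilinear',
                                         padding_mode='zeros', align_corners=False))
        # Weighted sum over all levels and points
        weights = attention_weights.transpose(1, 2).reshape(
            N * self.n_heads, 1, Len_q, self.n_levels * self.n_points)
        output = (torch.stack(sampled, dim=-2).flatten(-2) * weights).sum(-1)
        output = output.view(N, self.d_model, Len_q).transpose(1, 2)
        
        # 6. Project output
        return self.output_proj(output)  # (N, Len_q, d_model)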

# Demo: Show complexity comparison
print("Complexity Comparison (720×720 image):")
print("="*50)

# Standard attention
N_h8 = (720 // 8) ** 2   # 8100 tokens at H/8
N_h16 = (720 // 16) ** 2  # 2025 tokens at H/16
N_h32 = (720 // 32) ** 2  # 484 tokens at H/32 (720 // 32 = 22)
N_total = N_h8 + N_h16 + N_h32

standard_complexity = N_total ** 2
print(f"Total multi-scale tokens: {N_total:,}")
print(f"Standard attention: O(N²) = {standard_complexity:,} elements")

# Deformable attention
n_heads = 16
n_levels = 3
n_points = 4
deform_complexity = N_total * n_heads * n_levels * n_points
print(f"\nDeformable attention: O(N × M × L × K)")
print(f"  N={N_total}, M={n_heads} heads, L={n_levels} levels, K={n_points} points")
print(f"  = {deform_complexity:,} elements")

print(f"\nSpeedup: {standard_complexity / deform_complexity:.1f}x fewer operations!")

In [None]:
def visualize_deformable_sampling():
    """
    Visualize how deformable attention samples points at different scales.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    scales = [(90, 90, 'H/8'), (45, 45, 'H/16'), (22, 22, 'H/32')]
    colors = ['red', 'blue', 'green', 'orange']
    
    # Reference point (center of image)
    ref_x, ref_y = 0.5, 0.5
    
    for idx, (H, W, label) in enumerate(scales):
        ax = axes[idx]
        
        # Draw grid representing feature map
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        
        # Grid lines
        for i in range(H + 1):
            ax.axhline(i / H, color='lightgray', linewidth=0.5)
        for j in range(W + 1):
            ax.axvline(j / W, color='lightgray', linewidth=0.5)
        
        # Reference point
        ax.scatter([ref_x], [ref_y], c='black', s=200, marker='*', 
                   zorder=10, label='Reference Point')
        
        # Simulate 4 sampling points with learned offsets
        np.random.seed(42 + idx)
        for k in range(4):
            # Random offset (in practice, these are learned)
            offset_x = np.random.uniform(-0.15, 0.15)
            offset_y = np.random.uniform(-0.15, 0.15)
            sample_x = ref_x + offset_x
            sample_y = ref_y + offset_y
            
            # Draw sampling point
            ax.scatter([sample_x], [sample_y], c=colors[k], s=100, 
                      marker='o', zorder=5)
            # Draw offset arrow
            ax.annotate('', xy=(sample_x, sample_y), xytext=(ref_x, ref_y),
                       arrowprops=dict(arrowstyle='->', color=colors[k], lw=2))
        
        ax.set_title(f'{label} Scale\n({H}×{W} = {H*W} positions)\n4 sampled points')
        ax.set_aspect('equal')
        ax.invert_yaxis()  # Match image coordinates
    
    plt.suptitle('Multi-Scale Deformable Attention: Sparse Sampling per Query', fontsize=14)
    plt.tight_layout()
    plt.show()
    
    print("\nKey Insight: Instead of attending to all ~10,600 positions,")
    print("each query samples only 4 points per scale = 12 points per attention head!")

visualize_deformable_sampling()

---
# Part 5: The Mask2Former Framework

## 5.1 Overview: Universal Segmentation

Mask2Former unifies semantic, instance, and panoptic segmentation under one architecture:

```
                        Input Image
                             |
                             v
                    +----------------+
                    | Backbone       |  ← DINOv3 + ViT-Adapter
                    | (Multi-scale)  |    Outputs: {f1, f2, f3, f4}
                    +----------------+
                             |
                             v
                    +----------------+
                    | Pixel Decoder  |  ← FPN-style feature fusion
                    | (FPN + Deform) |    Outputs: High-res features
                    +----------------+
                             |
                             v
                    +----------------+
                    | Transformer    |  ← Query-based decoding
                    | Decoder        |    with Masked Attention
                    +----------------+
                         /     \
                        /       \
                       v         v
              +------------+  +------------+
              | Class Head |  | Mask Head  |
              +------------+  +------------+
                    |              |
                    v              v
              Class Logits    Mask Logits
              (N×C)           (N×H×W)
```

## 5.2 Component 1: Pixel Decoder

The Pixel Decoder fuses multi-scale backbone features into high-resolution feature maps.

### Architecture (MSDeformAttn-based):

```
Backbone Features:     f4 (H/32)    f3 (H/16)    f2 (H/8)    f1 (H/4)
                           |            |            |           |
                           v            v            v           v
                       [Lateral Convs - project to 256 channels]
                           |            |            |           |
                           v            v            v           v
                       [MS-Deform-Attn Transformer Encoder]
                           |            |            |           |
                           v            v            v           v
                       +--- FPN Upsample Path (top-down) ---+
                           |            |            |           |
                           v            v            v           v
                      out_32        out_16       out_8       out_4
                                                               |
                                                               v
                                                    [Pixel Features]
                                                    (B, 256, H/4, W/4)
```

**Purpose:** Create rich, high-resolution features that the mask head can use for precise boundaries.
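
The lateral projection and top-down path in the diagram can be sketched as below. This is a simplified illustration that omits the MS-Deform-Attn encoder entirely; the class name `TopDownFusionSketch`, the input channel counts, and the 3×3 smoothing convolutions are assumptions, not the pipeline's actual pixel decoder.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TopDownFusionSketch(nn.Module):
    """Lateral 1x1 convs followed by an FPN-style top-down pathway."""

    def __init__(self, in_channels=(64, 128, 256, 512), dim=256):
        super().__init__()
        self.lateral = nn.ModuleList([nn.Conv2d(c, dim, kernel_size=1) for c in in_channels])
        self.smooth = nn.ModuleList([nn.Conv2d(dim, dim, kernel_size=3, padding=1) for _ in in_channels])

    def forward(self, feats):
        # feats ordered fine -> coarse: [f1 (H/4), f2 (H/8), f3 (H/16), f4 (H/32)]
        laterals = [conv(f) for conv, f in zip(self.lateral, feats)]
        outs = [None] * len(laterals)
        running = laterals[-1]
        outs[-1] = self.smooth[-1](running)
        # Walk from the coarsest level to the finest, upsampling and adding.
        for i in range(len(laterals) - 2, -1, -1):
            upsampled = F.interpolate(running, size=laterals[i].shape[-2:], mode="nearest")
            running = laterals[i] + upsampled
            outs[i] = self.smooth[i](running)
        return outs  # [out_4, out_8, out_16, out_32]


torch.manual_seed(0)
feats = [torch.randn(1, c, s, s) for c, s in zip((64, 128, 256, 512), (32, 16, 8, 4))]
outs = TopDownFusionSketch()(feats)
print([tuple(o.shape) for o in outs])
```

The first output (`out_4`) plays the role of the pixel features consumed by the mask head.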

## 5.3 Component 2: Transformer Decoder with Masked Attention

### Query-Based Approach

Instead of per-pixel classification, Mask2Former uses **learnable queries** that each predict one mask:

```
Learnable Queries: Q = {q_1, q_2, ..., q_N}  (e.g., N=100)

Each query learns to:
  - Focus on one object/segment
  - Predict its class
  - Predict its mask
```

### Masked Cross-Attention

The key innovation in Mask2Former is **masked attention** in the cross-attention layers:

$$\text{MaskedCrossAttn}(Q, K, V, M) = \text{softmax}\left(\frac{QK^T + M}{\sqrt{d_k}}\right)V$$

Where $M$ is derived from the predicted mask from the **previous decoder layer**:

$$M_{ij} = \begin{cases} 0 & \text{if } \hat{m}_i[j] > 0.5 \\ -\infty & \text{otherwise} \end{cases}$$

**Effect:** Each query only attends to pixels within its predicted mask region!
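
The formula can be checked numerically with a small single-head example. This is a sketch with toy shapes; the function name `masked_cross_attention` is an assumption. The fallback for empty masks is included because a row of all $-\infty$ would make the softmax undefined, and the Mask2Former reference implementation handles the same case by letting such queries attend everywhere.

```python
import math
import torch


def masked_cross_attention(q, k, v, prev_mask_logits):
    """q: (N, d) queries; k, v: (HW, d) pixel features; prev_mask_logits: (N, HW)."""
    keep = prev_mask_logits.sigmoid() > 0.5
    # A query whose predicted mask is empty falls back to attending everywhere.
    keep = keep | ~keep.any(dim=-1, keepdim=True)
    # M: 0 where attention is allowed, -inf where it is blocked.
    bias = torch.zeros(keep.shape, dtype=q.dtype).masked_fill(~keep, float("-inf"))
    scores = (q @ k.T + bias) / math.sqrt(q.shape[-1])
    attn = scores.softmax(dim=-1)
    return attn @ v, attn


torch.manual_seed(0)
N, HW, d = 3, 16, 8
q, k, v = torch.randn(N, d), torch.randn(HW, d), torch.randn(HW, d)
prev_mask_logits = torch.full((N, HW), -5.0)
prev_mask_logits[0, :4] = 5.0  # query 0 predicted pixels 0..3
prev_mask_logits[1, 8:] = 5.0  # query 1 predicted pixels 8..15
# query 2 predicted an empty mask, so it attends to all pixels
out, attn = masked_cross_attention(q, k, v, prev_mask_logits)
print(attn[0, 4:].sum().item(), attn[1, :8].sum().item())
```

Both printed sums are zero: no attention mass leaks outside each query's predicted region.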

In [None]:
def visualize_masked_attention():
    """
    Visualize how masked attention restricts cross-attention to predicted regions.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    
    # Create synthetic image with 3 regions
    H, W = 64, 64
    image = np.zeros((H, W, 3))
    
    # Region 1: Top-left circle (sky)
    y, x = np.ogrid[:H, :W]
    mask1 = ((x - 16)**2 + (y - 16)**2) < 12**2
    image[mask1] = [0.5, 0.7, 1.0]  # Light blue
    
    # Region 2: Center rectangle (building)
    mask2 = (x >= 24) & (x < 48) & (y >= 20) & (y < 50)
    image[mask2] = [0.8, 0.6, 0.4]  # Brown
    
    # Region 3: Bottom (ground)
    mask3 = (y >= 50) & ~mask2
    image[mask3] = [0.3, 0.6, 0.3]  # Green
    
    # Row 1: Image and predicted masks
    axes[0, 0].imshow(image)
    axes[0, 0].set_title('Input Image')
    axes[0, 0].axis('off')
    
    axes[0, 1].imshow(mask1.astype(float), cmap='Blues')
    axes[0, 1].set_title('Query 1 Mask\n(Sky)')
    axes[0, 1].axis('off')
    
    axes[0, 2].imshow(mask2.astype(float), cmap='Oranges')
    axes[0, 2].set_title('Query 2 Mask\n(Building)')
    axes[0, 2].axis('off')
    
    # Row 2: Attention patterns
    # Without masking - all pixels visible
    attn_unmasked = np.random.rand(H, W)
    axes[1, 0].imshow(attn_unmasked, cmap='hot')
    axes[1, 0].set_title('Standard Cross-Attention\n(Query attends to ALL pixels)')
    axes[1, 0].axis('off')
    
    # With masking - Query 1 only sees sky region
    attn_masked1 = np.random.rand(H, W) * mask1.astype(float)
    axes[1, 1].imshow(attn_masked1, cmap='hot')
    axes[1, 1].set_title('Masked Attention (Query 1)\n(Only attends to sky pixels)')
    axes[1, 1].axis('off')
    
    # With masking - Query 2 only sees building region
    attn_masked2 = np.random.rand(H, W) * mask2.astype(float)
    axes[1, 2].imshow(attn_masked2, cmap='hot')
    axes[1, 2].set_title('Masked Attention (Query 2)\n(Only attends to building pixels)')
    axes[1, 2].axis('off')
    
    plt.suptitle('Mask2Former: Masked Cross-Attention Visualization', fontsize=14)
    plt.tight_layout()
    plt.show()

visualize_masked_attention()

## 5.4 Why Masked Attention Improves Convergence

### Problem with Standard Cross-Attention:
- Query attends to **all pixels** equally at first
- Gradients are distributed across the entire image
- Learning signal is diluted

### Solution with Masked Attention:
1. **Focused Learning:** Each query only receives gradients from its relevant region
2. **Iterative Refinement:** Masks improve over decoder layers, attention becomes more focused
3. **Prevents Interference:** Different queries don't compete for the same pixels

```
Decoder Layer 1:  Rough masks → Wide attention regions
                       ↓
Decoder Layer 2:  Better masks → Tighter attention
                       ↓
Decoder Layer 3:  Refined masks → Precise attention
                       ↓
         ...           ↓
Decoder Layer 9:  Final masks → Highly focused attention
```

## 5.5 Complete Mask2Former Decoder Flow

In [None]:
mask2former_decoder_flow = """
Mask2Former Transformer Decoder (9 layers, iterative refinement)
================================================================

Inputs:
  - queries: (B, N, 256) - N learnable query embeddings (N=100)
  - pixel_features: (B, 256, H/4, W/4) - from Pixel Decoder (used by the mask head)
  - multi_scale_features: [H/32, H/16, H/8] - from Pixel Decoder (cross-attention keys/values)

Before layer 1:
    mask_logits[0] = MaskHead(queries, pixel_features)  # Initial masks from the raw queries

For each decoder layer l in [1, 2, ..., 9]:
    
    feats = multi_scale_features[(l - 1) % 3]  # Round-robin: H/32, H/16, H/8, H/32, ...
    
    1. Masked Cross-Attention
       -------------------------
       mask_prev = sigmoid(resize(mask_logits[l-1], feats.shape))  # From previous step
       attention_mask = (mask_prev > 0.5)   # True = pixel may be attended
       # A query whose mask is empty attends to all pixels instead
       
       queries = MaskedCrossAttn(
           query=queries,
           key=feats,
           value=feats,
           mask=attention_mask
       )
    
    2. Self-Attention (among queries)
       -------------------------
       queries = SelfAttn(queries)  # Queries interact with each other
    
    3. FFN
       -------------------------
       queries = FFN(queries)
    
    4. Prediction Heads
       -------------------------
       class_logits[l] = ClassHead(queries)     # (B, N, num_classes+1)
       mask_logits[l] = MaskHead(queries, pixel_features)  # (B, N, H/4, W/4)

Output:
  - Final predictions from last layer
  - Auxiliary predictions from intermediate layers (for training)
"""
print(mask2former_decoder_flow)
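
The `MaskHead` step in the flow above is a dot product between an embedding of each query and every pixel feature. A minimal sketch follows, assuming a 3-layer MLP as in the Mask2Former paper; the class name `MaskHeadSketch` and the toy dimensions are assumptions for this example.

```python
import torch
import torch.nn as nn


class MaskHeadSketch(nn.Module):
    """Turns each query into a mask by comparing it with every pixel feature."""

    def __init__(self, dim=32):
        super().__init__()
        self.mask_embed = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(),
            nn.Linear(dim, dim), nn.ReLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, queries, pixel_features):
        # queries: (B, N, C); pixel_features: (B, C, H, W) -> mask logits (B, N, H, W)
        embed = self.mask_embed(queries)
        return torch.einsum("bqc,bchw->bqhw", embed, pixel_features)


torch.manual_seed(0)
head = MaskHeadSketch(dim=32)
queries = torch.randn(2, 5, 32)              # 5 queries for brevity
pixel_features = torch.randn(2, 32, 16, 16)  # stands in for the H/4 pixel features
mask_logits = head(queries, pixel_features)
print(mask_logits.shape)
```

The mask logits have the resolution of the pixel features (H/4 in the pipeline) and are upsampled to the image size during post-processing.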

## 5.6 Task-Agnostic Queries

One powerful aspect of Mask2Former is that **the same queries work for all segmentation tasks**:

| Task | Query Interpretation | Post-Processing |
|------|---------------------|----------------|
| **Semantic Segmentation** | Each query = a region of one semantic class | Weight masks by class probability, sum over queries, argmax per pixel |
| **Instance Segmentation** | Each query = one object instance | Keep masks separate |
| **Panoptic Segmentation** | Queries = stuff + things | Combine semantic + instance |

The only difference is in **post-processing**, not the model architecture!
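
For the semantic case, the post-processing can be sketched as follows: each query's mask probability is weighted by its class probabilities, the results are summed over queries, and the highest-scoring class is taken per pixel. This follows the semantic inference described for MaskFormer-style models. The function name `semantic_inference` and the assumption that the "no object" class is the last logit are illustrative, not taken from our codebase.

```python
import torch


def semantic_inference(class_logits, mask_logits):
    """class_logits: (B, N, C+1), last index = no-object; mask_logits: (B, N, H, W)."""
    class_probs = class_logits.softmax(dim=-1)[..., :-1]  # drop the no-object column
    mask_probs = mask_logits.sigmoid()
    # Marginalize over queries: (B, N, C) x (B, N, H, W) -> (B, C, H, W)
    scores = torch.einsum("bqc,bqhw->bchw", class_probs, mask_probs)
    return scores.argmax(dim=1)  # (B, H, W) class ids


# Two queries, two real classes (plus no-object), on a 2x4 image.
class_logits = torch.tensor([[[-10.0, 10.0, -10.0],    # query 0 -> class 1
                              [10.0, -10.0, -10.0]]])  # query 1 -> class 0
mask_logits = torch.full((1, 2, 2, 4), -10.0)
mask_logits[0, 0, :, :2] = 10.0  # query 0 covers the left half
mask_logits[0, 1, :, 2:] = 10.0  # query 1 covers the right half
seg = semantic_inference(class_logits, mask_logits)
print(seg[0])
```

The left half of the map is labelled class 1 and the right half class 0. Several queries predicting the same class simply add their evidence.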

---
# Part 6: Putting It All Together - Our Pipeline

## 6.1 Complete Architecture Diagram

In [None]:
architecture_diagram = r"""
╔═══════════════════════════════════════════════════════════════════════════════════╗
║                    DINOv3 + ViT-Adapter + Mask2Former Pipeline                    ║
╠═══════════════════════════════════════════════════════════════════════════════════╣
║                                                                                   ║
║  Input Image (720×720×3)                                                          ║
║        │                                                                          ║
║        │                                                                          ║
║        ▼                                                                          ║
║  ┌─────────────────────────────────────────────────────────────────────────────┐  ║
║  │                        FROZEN: DINOv3-ViT-L/16 (~300M params)               │  ║
║  │  ┌──────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐              │  ║
║  │  │ Layer 4  │ →  │ Layer 11 │ →  │ Layer 17 │ →  │ Layer 23 │              │  ║
║  │  │ (patch)  │    │ (patch)  │    │ (patch)  │    │ (patch)  │              │  ║
║  │  └────┬─────┘    └────┬─────┘    └────┬─────┘    └────┬─────┘              │  ║
║  └───────│───────────────│───────────────│───────────────│────────────────────┘  ║
║          │               │               │               │                        ║
║          │               │               │               │                        ║
║  ┌───────▼───────────────▼───────────────▼───────────────▼────────────────────┐  ║
║  │                    TRAINABLE: ViT-Adapter (~50M params)                    │  ║
║  │                                                                            │  ║
║  │  ┌──────────────┐     ┌──────────────────────────────────────────────────┐ │  ║
║  │  │   Spatial    │     │           Interaction Blocks (×4)                │ │  ║
║  │  │    Prior     │────▶│  MS-Deformable Attention + ConvFFN               │ │  ║
║  │  │   Module     │     │  (Fuses CNN spatial priors with ViT semantics)   │ │  ║
║  │  │  (CNN Stem)  │     └──────────────────────────────────────────────────┘ │  ║
║  │  └──────────────┘                                                          │  ║
║  │                                                                            │  ║
║  │  Output: Multi-scale features {f1: H/4, f2: H/8, f3: H/16, f4: H/32}       │  ║
║  └────────────────────────────────────────────────────────────────────────────┘  ║
║          │                                                                        ║
║          │                                                                        ║
║          ▼                                                                        ║
║  ┌────────────────────────────────────────────────────────────────────────────┐  ║
║  │                    TRAINABLE: Mask2Former (~40M params)                    │  ║
║  │                                                                            │  ║
║  │  ┌─────────────────────┐    ┌────────────────────────────────────────────┐ │  ║
║  │  │    Pixel Decoder    │    │         Transformer Decoder (×9)          │ │  ║
║  │  │  (FPN + MS-Deform)  │───▶│  Learnable Queries (100)                   │ │  ║
║  │  │                     │    │  + Masked Cross-Attention                  │ │  ║
║  │  │  → Pixel Features   │    │  → Class Logits + Mask Logits              │ │  ║
║  │  └─────────────────────┘    └────────────────────────────────────────────┘ │  ║
║  └────────────────────────────────────────────────────────────────────────────┘  ║
║          │                                                                        ║
║          ▼                                                                        ║
║  ┌────────────────────────────────────────────────────────────────────────────┐  ║
║  │  Output: Semantic Segmentation (7 classes for LoveDA)                      │  ║
║  │         - Class predictions: (B, 100, 8)  # 7 classes + no-object          │  ║
║  │         - Mask predictions:  (B, 100, H, W)                                │  ║
║  │         → Post-process to (B, 7, H, W) semantic map                        │  ║
║  └────────────────────────────────────────────────────────────────────────────┘  ║
║                                                                                   ║
╚═══════════════════════════════════════════════════════════════════════════════════╝
"""
print(architecture_diagram)

## 6.2 Key File Locations in Our Codebase

| Component | File | Key Classes/Functions |
|-----------|------|----------------------|
| **DINOv3 Wrapper** | `dinov3_mask2former_integration.py:85-155` | `DINOv3CompatibilityWrapper` |
| **ViT-Adapter** | `models/backbone/dinov3_adapter.py:305-484` | `DINOv3_Adapter` |
| **Spatial Prior Module** | `models/backbone/dinov3_adapter.py:234-302` | `SpatialPriorModule` |
| **Interaction Block** | `models/backbone/dinov3_adapter.py:159-231` | `InteractionBlockWithCls` |
| **MS-Deformable Attention** | `models/utils/ms_deform_attn.py:99-214` | `MSDeformAttn` |
| **Model Creation** | `dinov3_mask2former_integration.py:248-444` | `create_dinov3_mask2former()` |
| **Training** | `train_hydra.py:1-391` | `SegmentationLightningModule` |

## 6.3 Running Inference (Quick Demo)

In [None]:
# This cell demonstrates how to load our model and run inference
# (Requires HuggingFace token for DINOv3 access)

inference_code = """
# To run inference with our trained model:

from dinov3_mask2former_integration import create_dinov3_mask2former
import torch

# 1. Create model
model, processor, _ = create_dinov3_mask2former(
    dinov3_model_name="facebook/dinov3-vitl16-pretrain-sat493m",
    num_classes=7  # LoveDA classes
)

# 2. Load trained checkpoint
checkpoint = torch.load('runs/your_run/checkpoints/best.ckpt', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
model.eval()

# 3. Prepare image
from PIL import Image
image = Image.open('your_image.png').convert('RGB')
inputs = processor(images=image, return_tensors='pt')

# 4. Run inference
with torch.no_grad():
    outputs = model(**inputs)

# 5. Post-process to get semantic segmentation map
semantic_map = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[(image.height, image.width)]
)[0]

print(f"Segmentation map shape: {semantic_map.shape}")
"""
print(inference_code)

---
# Summary: Key Takeaways

## 1. Vision Transformer (ViT)
- Processes images as sequences of patches
- Global attention from layer 1 (no local inductive bias)
- Attention maps can be visualized via rollout or per-layer analysis

## 2. Position Encoding Interpolation
- 2D bicubic interpolation enables variable input resolutions
- Preserves spatial relationships learned during pre-training
- Essential for using ViT at resolutions different from pre-training

## 3. ViT-Adapter
- Bridges ViT to dense prediction via multi-scale features
- Spatial Prior Module adds CNN-like local inductive bias
- Interaction blocks fuse SPM and ViT features via deformable attention

## 4. MS-Deformable Attention
- O(N) complexity vs O(N²) for standard attention
- Sparse sampling at learned offset positions
- Multi-scale: samples from features at different resolutions

## 5. Mask2Former
- Query-based universal segmentation
- Masked attention restricts cross-attention to predicted regions
- Iterative refinement across decoder layers
- Task-agnostic: same architecture for semantic/instance/panoptic

---
# Further Reading & Resources

## Papers
- **ViT:** [An Image is Worth 16x16 Words](https://arxiv.org/abs/2010.11929)
- **DINOv2:** [DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)
- **ViT-Adapter:** [Vision Transformer Adapter for Dense Predictions](https://arxiv.org/abs/2205.08534)
- **Deformable DETR:** [Deformable DETR: Deformable Transformers for End-to-End Object Detection](https://arxiv.org/abs/2010.04159)
- **Mask2Former:** [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527)

## Codebase Documentation
- `docs/ARCHITECTURE.md` - Detailed architecture explanation
- `docs/TRAINING.md` - Training guide
- `docs/CONFIGURATION.md` - Hydra config reference

## Code Deep Dives
- Attention visualization: This notebook Part 1
- Position interpolation: `models/backbone/dinov3_adapter.py:387-396`
- MS-Deformable Attention: `models/utils/ms_deform_attn.py`