In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Cross-Modal Attention: How Language Models Learn to See

*Part 3 of the Vizuara series on Multimodal Instruction Tuning*
*Estimated time: 40 minutes*

## 1. Why Does This Matter?

We have built the projection layer and the training pipeline. But here is a question that might be nagging you: how does the language model actually "look at" the image when generating text?

The answer is beautiful in its simplicity: **self-attention does everything**. When the LLM processes the concatenated sequence [visual_tokens, text_tokens], the standard self-attention mechanism naturally allows text tokens to attend to visual tokens. No special cross-attention module is needed.
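
To make the concatenation concrete, here is a minimal sketch. The sizes (16 visual tokens, 5 text tokens, width 32) are made-up illustrative values, and the random tensors stand in for projected image features and embedded text. It uses PyTorch's built-in `nn.MultiheadAttention` rather than the attention class we build later in this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model = 32       # assumed shared embedding width (illustrative)
num_visual = 16    # assumed number of visual tokens (illustrative)
num_text = 5       # assumed number of text tokens (illustrative)

# Random stand-ins for projected visual features and embedded text tokens
visual_tokens = torch.randn(1, num_visual, d_model)
text_tokens = torch.randn(1, num_text, d_model)

# One sequence: visual tokens first, then text tokens
combined = torch.cat([visual_tokens, text_tokens], dim=1)

# Ordinary self-attention: query, key, and value are all the same sequence
self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=4, batch_first=True)
with torch.no_grad():
    _, weights = self_attn(combined, combined, combined, need_weights=True)

# weights is averaged over heads by default: (batch, seq_len, seq_len)
text_to_visual = weights[0, num_visual:, :num_visual]

print(combined.shape)        # torch.Size([1, 21, 32])
print(weights.shape)         # torch.Size([1, 21, 21])
print(text_to_visual.shape)  # torch.Size([5, 16])
```

The slice `weights[0, num_visual:, :num_visual]` is the part of the attention matrix where text queries read from visual keys. Nothing in the attention layer knows which tokens came from the image.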

**By the end of this notebook, you will have:**
- Visualized attention patterns in a multimodal model
- Understood the block structure of cross-modal attention
- Built and interpreted attention heatmaps
- Seen how different text tokens attend to different image patches

In [None]:
# Setup
!pip install -q torch torchvision matplotlib numpy seaborn

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Building Intuition

Imagine you are reading a detective novel with illustrations. As you read the sentence "The suspect wore a **red hat**," your eyes automatically glance at the illustration to find the red hat. When you read "standing near the **doorway**," your eyes shift to the doorway in the image.

This is exactly what cross-modal attention does. When the language model generates the word "red," the text token can attend to (look at) the visual tokens corresponding to the red region of the image. When it generates "doorway," it attends to the doorway patches.

The remarkable thing is that this behavior **emerges naturally** from standard self-attention. We do not need to design a special mechanism for it.

### Think About This

In a sequence of 196 visual tokens followed by 10 text tokens, the attention matrix is 206 x 206. How many of those attention weights represent "text attending to image" (cross-modal attention)? What fraction of the total attention matrix is this?
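
If you want to check your answer, a small helper like the one below counts the entries for any sequence layout. The function name `attention_block_sizes` is hypothetical, and the example call deliberately uses different numbers (64 visual tokens and 4 text tokens) so the question above stays yours to solve.

```python
def attention_block_sizes(num_visual: int, num_text: int) -> dict:
    """Count the entries of an (N + M) x (N + M) attention matrix by query/key type."""
    total = (num_visual + num_text) ** 2
    sizes = {
        "visual_to_visual": num_visual * num_visual,
        "visual_to_text": num_visual * num_text,
        "text_to_visual": num_text * num_visual,
        "text_to_text": num_text * num_text,
    }
    sizes["total"] = total
    sizes["text_to_visual_fraction"] = sizes["text_to_visual"] / total
    return sizes


sizes = attention_block_sizes(num_visual=64, num_text=4)
print(sizes["text_to_visual"], "of", sizes["total"])  # 256 of 4624
print(f"{sizes['text_to_visual_fraction']:.1%}")      # 5.5%
```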

## 3. The Mathematics

### Self-Attention on the Combined Sequence

Given the combined sequence $X = [h_1, ..., h_N, t_1, ..., t_M]$ where $h_i$ are visual tokens and $t_j$ are text tokens, self-attention computes:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

**Computationally, this says:** compute a similarity score between every pair of tokens. Normalize these scores with softmax. Use the normalized scores to create a weighted sum of value vectors. This happens for every token in the sequence.
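
The three steps above translate almost line for line into code. This is a minimal single-head sketch; the token count, width, and random projection matrices `W_q`, `W_k`, `W_v` are illustrative assumptions, not values from a real model.

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V):
    """Return (output, weights) for single-head attention."""
    d_k = Q.shape[-1]
    # Step 1: similarity score between every pair of tokens
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    # Step 2: normalize each query's scores so they sum to 1
    weights = F.softmax(scores, dim=-1)
    # Step 3: weighted sum of value vectors
    return weights @ V, weights


torch.manual_seed(0)
X = torch.randn(6, 8)  # 6 tokens (say 4 visual + 2 text), width 8
W_q, W_k, W_v = (torch.randn(8, 8) for _ in range(3))

output, weights = scaled_dot_product_attention(X @ W_q, X @ W_k, X @ W_v)
print(output.shape)          # torch.Size([6, 8])
print(weights.shape)         # torch.Size([6, 6])
print(weights.sum(dim=-1))   # every row sums to 1 (up to float error)
```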

### The Block Structure

The attention matrix naturally decomposes into four blocks:

| | Visual Keys | Text Keys |
|---|---|---|
| **Visual Queries** | V-V: Spatial relationships | V-T: Text grounding |
| **Text Queries** | T-V: Visual reasoning | T-T: Language modeling |

The **T-V block** is where "seeing" happens: text tokens attend to visual tokens.
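
One caveat is worth a quick sketch. The table describes attention without any mask, which is also how the toy model in this notebook runs. Assuming a decoder-only LLM that applies a causal mask, with visual tokens placed before text tokens, visual queries cannot see the later text keys, so the V-T block becomes zero while the T-V block is untouched. The sizes and random scores below are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_visual, num_text = 3, 2
seq_len = num_visual + num_text

# Random stand-in for the scaled scores QK^T / sqrt(d_k)
scores = torch.randn(seq_len, seq_len)

# Causal mask: position i may attend only to positions <= i
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
masked_scores = scores.masked_fill(~causal_mask, float("-inf"))
weights = F.softmax(masked_scores, dim=-1)

vt_block = weights[:num_visual, num_visual:]  # visual queries -> text keys
tv_block = weights[num_visual:, :num_visual]  # text queries -> visual keys

print(vt_block)               # all zeros under the causal mask
print(tv_block.sum(dim=-1))   # share of each text token's attention spent on the image
```

So under causal masking the "seeing" direction (T-V) is the only cross-modal block that carries weight.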

Let us compute a numerical example. Suppose we have 2 visual tokens and 2 text tokens, all in dimension $d_k = 4$.

$$Q = \begin{bmatrix} q_{v1} \\ q_{v2} \\ q_{t1} \\ q_{t2} \end{bmatrix}, \quad K = \begin{bmatrix} k_{v1} \\ k_{v2} \\ k_{t1} \\ k_{t2} \end{bmatrix}$$

The raw attention scores are $\frac{QK^\top}{\sqrt{d_k}}$. Say we get:

$$\text{scores} = \begin{bmatrix} 1.2 & 0.8 & 0.3 & 0.1 \\ 0.7 & 1.5 & 0.2 & 0.4 \\ 0.9 & 1.8 & 2.1 & 0.5 \\ 0.3 & 0.4 & 0.8 & 1.9 \end{bmatrix}$$

After softmax on each row, the third row (text token 1) becomes approximately:

$$[0.13, 0.33, 0.45, 0.09]$$

This tells us that text token 1 gives 13% attention to visual patch 1, 33% attention to visual patch 2, 45% to itself, and 9% to text token 2. The cross-modal attention (0.13 + 0.33 = 0.46) shows the text token is "looking at" the image for almost half its attention. This is exactly what we want.

In [None]:
# Numerical verification of cross-modal attention
d_k = 4
scores = torch.tensor([
    [1.2, 0.8, 0.3, 0.1],   # visual token 1
    [0.7, 1.5, 0.2, 0.4],   # visual token 2
    [0.9, 1.8, 2.1, 0.5],   # text token 1
    [0.3, 0.4, 0.8, 1.9],   # text token 2
])

# The scores above are already QK^T / sqrt(d_k) (see the text), so we do not divide again

# Softmax over the key dimension
attn_weights = F.softmax(scores, dim=-1)

print("Attention weights (each row sums to 1):")
labels = ["vis_1", "vis_2", "txt_1", "txt_2"]
for i, label in enumerate(labels):
    weights = [f"{w:.3f}" for w in attn_weights[i].tolist()]
    print(f"  {label} attends to: {dict(zip(labels, weights))}")

# Cross-modal attention for text token 1
cross_modal = attn_weights[2, :2].sum().item()
print(f"\nText token 1 cross-modal attention: {cross_modal:.1%} (looking at the image)")

## 4. Let's Build It -- Component by Component

### 4.1 Multi-Head Self-Attention with Visualization

In [None]:
class VisualizableAttention(nn.Module):
    """Multi-head self-attention that stores attention weights for visualization."""

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.attention_weights = None  # Stored for visualization

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        B, S, D = x.shape

        # Project to Q, K, V
        Q = self.W_q(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)

        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        self.attention_weights = attn_weights.detach()  # Store for visualization

        # Weighted sum
        output = torch.matmul(attn_weights, V)
        output = output.transpose(1, 2).contiguous().view(B, S, D)
        return self.W_o(output)


# Test
attn = VisualizableAttention(d_model=64, num_heads=4)
x = torch.randn(1, 10, 64)  # 10 tokens
output = attn(x)
print(f"Input:  {x.shape}")
print(f"Output: {output.shape}")
print(f"Attention weights shape: {attn.attention_weights.shape}")
print(f"  -> {attn.num_heads} heads, each with {x.shape[1]}x{x.shape[1]} attention matrix")

### 4.2 A Simple Multimodal Model for Attention Analysis

In [None]:
class AttentionAnalysisModel(nn.Module):
    """Simplified model designed for attention visualization."""

    def __init__(self, image_size=64, patch_size=8, vision_dim=64,
                 llm_dim=64, vocab_size=20, num_heads=4):
        super().__init__()
        self.num_patches = (image_size // patch_size) ** 2  # 64

        # Vision
        self.patch_embed = nn.Conv2d(3, vision_dim, kernel_size=patch_size, stride=patch_size)
        self.projector = nn.Sequential(
            nn.Linear(vision_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

        # Text
        self.text_embed = nn.Embedding(vocab_size, llm_dim)
        self.pos_embed = nn.Embedding(256, llm_dim)

        # Self-attention (with visualization support)
        self.attention = VisualizableAttention(llm_dim, num_heads)
        self.norm = nn.LayerNorm(llm_dim)
        self.output_head = nn.Linear(llm_dim, vocab_size)

    def forward(self, images, token_ids):
        # Vision path
        vis = self.patch_embed(images).flatten(2).transpose(1, 2)
        vis_tokens = self.projector(vis)

        # Text path
        text_tokens = self.text_embed(token_ids)

        # Combine
        combined = torch.cat([vis_tokens, text_tokens], dim=1)
        positions = torch.arange(combined.shape[1], device=combined.device).unsqueeze(0)
        combined = combined + self.pos_embed(positions)

        # Self-attention
        attended = self.attention(combined)
        output = self.norm(attended + combined)

        return self.output_head(output)


model_attn = AttentionAnalysisModel().to(device)
print(f"Model created with {model_attn.num_patches} visual patches")

In [None]:
# Create a test image with distinct regions
def create_test_image(size=64):
    """Create an image with a red circle in the top-left and blue square in the bottom-right."""
    img = torch.ones(3, size, size) * 0.9  # Light gray background

    # Red circle in top-left
    y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
    circle_mask = ((x - size//4)**2 + (y - size//4)**2) < (size//8)**2
    img[0][circle_mask] = 1.0
    img[1][circle_mask] = 0.1
    img[2][circle_mask] = 0.1

    # Blue square in bottom-right
    sq_mask = (abs(x - 3*size//4) < size//8) & (abs(y - 3*size//4) < size//8)
    img[0][sq_mask] = 0.1
    img[1][sq_mask] = 0.1
    img[2][sq_mask] = 1.0

    return img

test_img = create_test_image().unsqueeze(0).to(device)
test_tokens = torch.tensor([[1, 3, 4, 9]], device=device)  # <start> what color ?

with torch.no_grad():
    _ = model_attn(test_img, test_tokens)

print(f"Forward pass complete. Attention weights stored.")
print(f"Attention shape: {model_attn.attention.attention_weights.shape}")

In [None]:
# Visualize the full attention matrix with block structure
attn_weights = model_attn.attention.attention_weights[0, 0].cpu().numpy()  # Head 0

num_vis = model_attn.num_patches  # 64
num_txt = test_tokens.shape[1]     # 4
total = num_vis + num_txt

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Full attention matrix
im = axes[0].imshow(attn_weights, cmap='Blues', aspect='auto')
axes[0].set_title("Full Attention Matrix (Head 0)", fontsize=12)
axes[0].set_xlabel("Key position")
axes[0].set_ylabel("Query position")

# Draw block boundaries
axes[0].axhline(y=num_vis-0.5, color='red', linewidth=2, linestyle='--')
axes[0].axvline(x=num_vis-0.5, color='red', linewidth=2, linestyle='--')

# Label blocks
axes[0].text(num_vis//2, num_vis//2, "V-V\nSpatial", ha='center', va='center',
            fontsize=10, color='white', fontweight='bold',
            bbox=dict(boxstyle='round', facecolor='steelblue', alpha=0.8))
axes[0].text(num_vis + num_txt//2, num_vis//2, "V-T\nGrounding", ha='center', va='center',
            fontsize=10, color='white', fontweight='bold',
            bbox=dict(boxstyle='round', facecolor='green', alpha=0.8))
axes[0].text(num_vis//2, num_vis + num_txt//2, "T-V\nSeeing", ha='center', va='center',
            fontsize=10, color='white', fontweight='bold',
            bbox=dict(boxstyle='round', facecolor='orange', alpha=0.8))
axes[0].text(num_vis + num_txt//2, num_vis + num_txt//2, "T-T\nLanguage", ha='center', va='center',
            fontsize=10, color='white', fontweight='bold',
            bbox=dict(boxstyle='round', facecolor='purple', alpha=0.8))
plt.colorbar(im, ax=axes[0])

# Zoom into T-V block (text attending to visual)
tv_block = attn_weights[num_vis:, :num_vis]
token_labels = ["<start>", "what", "color", "?"]
im2 = axes[1].imshow(tv_block, cmap='Oranges', aspect='auto')
axes[1].set_title("T-V Block: Text Tokens Looking at Image", fontsize=12)
axes[1].set_xlabel("Visual patch index")
axes[1].set_ylabel("Text token")
axes[1].set_yticks(range(num_txt))
axes[1].set_yticklabels(token_labels)
plt.colorbar(im2, ax=axes[1])

plt.suptitle("Cross-Modal Attention: How Text Tokens 'See' the Image",
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

### 4.3 Spatial Attention Heatmap
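
The heatmap code in this section upsamples a small patch-level attention grid to pixel resolution with `np.kron`. In case that call looks unfamiliar, here is a tiny sketch of what it does, using a made-up 2x2 attention map and a patch size of 2.

```python
import numpy as np

# Made-up attention over a 2x2 patch grid (illustrative values)
attn_map = np.array([[0.1, 0.2],
                     [0.3, 0.4]])
patch_size = 2

# Kronecker product with a block of ones repeats each value
# over a patch_size x patch_size square (nearest-neighbor upsampling)
upsampled = np.kron(attn_map, np.ones((patch_size, patch_size)))

print(upsampled.shape)  # (4, 4)
print(upsampled)
# [[0.1 0.1 0.2 0.2]
#  [0.1 0.1 0.2 0.2]
#  [0.3 0.3 0.4 0.4]
#  [0.3 0.3 0.4 0.4]]
```

Each patch's attention weight is painted over the pixels that patch covers, which is why the overlays look blocky rather than smooth.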

In [None]:
# Reshape the T-V attention to the spatial grid of the image
def visualize_spatial_attention(attention_weights, num_patches_side, image, token_labels,
                               text_start_idx, num_text_tokens):
    """Overlay attention weights on the image to show which patches each text token attends to."""

    fig, axes = plt.subplots(1, num_text_tokens + 1, figsize=(4 * (num_text_tokens + 1), 4))

    # Show original image
    display_img = image[0].permute(1, 2, 0).cpu().numpy()
    axes[0].imshow(display_img)
    axes[0].set_title("Original Image", fontsize=11)
    axes[0].axis('off')

    # For each text token, show its attention over image patches
    for t in range(num_text_tokens):
        # Get attention weights from this text token to all visual patches
        text_pos = text_start_idx + t
        vis_attention = attention_weights[0, 0, text_pos, :text_start_idx].cpu().numpy()

        # Reshape to spatial grid
        attn_map = vis_attention.reshape(num_patches_side, num_patches_side)
        # Upsample to image size
        attn_map_up = np.kron(attn_map, np.ones((image.shape[2]//num_patches_side,
                                                    image.shape[3]//num_patches_side)))

        axes[t+1].imshow(display_img)
        axes[t+1].imshow(attn_map_up, cmap='hot', alpha=0.6)
        axes[t+1].set_title(f'"{token_labels[t]}"', fontsize=11, fontweight='bold')
        axes[t+1].axis('off')

    plt.suptitle("Where Each Text Token Looks in the Image",
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Visualize
num_patches_side = int(np.sqrt(model_attn.num_patches))
visualize_spatial_attention(
    model_attn.attention.attention_weights,
    num_patches_side,
    test_img,
    ["<start>", "what", "color", "?"],
    text_start_idx=model_attn.num_patches,
    num_text_tokens=4
)

## 5. Your Turn

### TODO: Compute Cross-Modal Attention Statistics

Implement a function that analyzes the attention patterns and computes key statistics about cross-modal attention.

In [None]:
def analyze_cross_modal_attention(
    attention_weights: torch.Tensor,
    num_visual_tokens: int,
    num_text_tokens: int,
) -> dict:
    """Analyze cross-modal attention patterns.

    Args:
        attention_weights: (batch, num_heads, seq_len, seq_len) attention matrix
        num_visual_tokens: number of visual tokens (N)
        num_text_tokens: number of text tokens (M)

    Returns:
        dict with keys:
        - "tv_mean": float, mean attention from text to visual tokens
        - "vt_mean": float, mean attention from visual to text tokens
        - "vv_mean": float, mean attention from visual to visual tokens
        - "tt_mean": float, mean attention from text to text tokens
        - "most_attended_patch": int, the visual patch index most attended by text tokens
        - "cross_modal_ratio": float, fraction of text attention going to visual tokens
    """
    # ============ TODO ============
    # Step 1: Average across batch and heads: (seq_len, seq_len)
    # Step 2: Extract the 4 blocks: V-V, V-T, T-V, T-T
    # Step 3: Compute mean attention for each block
    # Step 4: Find the most attended visual patch (from text perspective)
    # Step 5: Compute cross-modal ratio for text tokens
    # ==============================

    result = {}  # YOUR CODE HERE

    return result

In [None]:
# Verification
with torch.no_grad():
    _ = model_attn(test_img, test_tokens)

stats = analyze_cross_modal_attention(
    model_attn.attention.attention_weights,
    num_visual_tokens=model_attn.num_patches,
    num_text_tokens=test_tokens.shape[1]
)

assert "tv_mean" in stats, "Missing 'tv_mean' key"
assert "cross_modal_ratio" in stats, "Missing 'cross_modal_ratio' key"
assert 0 <= stats["cross_modal_ratio"] <= 1, f"Cross-modal ratio should be in [0,1], got {stats['cross_modal_ratio']}"
assert 0 <= stats["most_attended_patch"] < model_attn.num_patches, "Invalid patch index"

print("Cross-Modal Attention Statistics:")
for key, value in stats.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")
print("\nAll assertions passed!")

### TODO: Multi-Head Attention Comparison

Implement a visualization that shows how different attention heads specialize in different aspects.

In [None]:
def compare_attention_heads(
    attention_weights: torch.Tensor,
    num_visual_tokens: int,
    num_patches_side: int,
    image: torch.Tensor,
    text_token_idx: int,
    token_name: str,
):
    """Visualize how different attention heads attend to image patches for a given text token.

    Args:
        attention_weights: (1, num_heads, seq_len, seq_len)
        num_visual_tokens: number of visual patches
        num_patches_side: grid size (e.g., 8 for 64 patches)
        image: (1, 3, H, W) the input image
        text_token_idx: which text token to visualize (0-indexed from text start)
        token_name: name of the token for the title
    """
    # ============ TODO ============
    # Step 1: Get number of heads from attention_weights shape
    # Step 2: For each head, extract the attention from the text token to visual patches
    #         Position in the full sequence = num_visual_tokens + text_token_idx
    # Step 3: Reshape each head's attention to (num_patches_side, num_patches_side)
    # Step 4: Create a subplot grid with the image + one heatmap per head
    # Step 5: Overlay each head's attention on the image using imshow with alpha
    # ==============================

    pass  # YOUR CODE HERE

In [None]:
# Verification (visual check)
compare_attention_heads(
    model_attn.attention.attention_weights,
    num_visual_tokens=model_attn.num_patches,
    num_patches_side=int(np.sqrt(model_attn.num_patches)),
    image=test_img,
    text_token_idx=2,  # "color" token
    token_name="color"
)
print("Check the visualization above -- each head should show different attention patterns!")

## 6. Putting It All Together

In [None]:
# Compare attention patterns across all heads for all text tokens
num_heads = model_attn.attention.num_heads
num_txt = test_tokens.shape[1]
token_labels = ["<start>", "what", "color", "?"]

fig, axes = plt.subplots(num_txt, num_heads + 1, figsize=(3 * (num_heads + 1), 3 * num_txt))

display_img = test_img[0].permute(1, 2, 0).cpu().numpy()
num_ps = int(np.sqrt(model_attn.num_patches))

for t in range(num_txt):
    # Show image in first column
    axes[t, 0].imshow(display_img)
    axes[t, 0].set_title(f'"{token_labels[t]}"', fontsize=10, fontweight='bold')
    axes[t, 0].axis('off')

    for h in range(num_heads):
        text_pos = model_attn.num_patches + t
        vis_attn = model_attn.attention.attention_weights[0, h, text_pos, :model_attn.num_patches].cpu().numpy()
        attn_map = vis_attn.reshape(num_ps, num_ps)
        attn_up = np.kron(attn_map, np.ones((test_img.shape[2]//num_ps, test_img.shape[3]//num_ps)))

        axes[t, h+1].imshow(display_img)
        axes[t, h+1].imshow(attn_up, cmap='hot', alpha=0.6)
        axes[t, h+1].set_title(f'Head {h}', fontsize=9)
        axes[t, h+1].axis('off')

plt.suptitle("All Text Tokens x All Attention Heads: Spatial Attention Heatmaps",
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("Each row shows how one text token looks at the image across different heads.")
print("Note: model_attn is untrained here, so differences between heads come from random initialization, not learned specialization.")

## 7. Training and Results

In [None]:
# Train the attention model on our VQA task and observe how attention patterns change

# Quick training on colored shapes
def create_shape_data(n=200, size=64):
    data = []
    for _ in range(n):
        color = np.random.randint(0, 3)
        img = torch.zeros(3, size, size)
        y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
        mask = ((x - size//2)**2 + (y - size//2)**2) < (size//4)**2
        img[color][mask] = 1.0
        img += torch.randn_like(img) * 0.05
        data.append((img.clamp(0,1), color + 10))  # channel order is RGB: 10=red, 11=green, 12=blue
    return data

train_data = create_shape_data(300)
model_attn2 = AttentionAnalysisModel().to(device)

optimizer = torch.optim.Adam(model_attn2.parameters(), lr=1e-3)
losses = []

for epoch in range(40):
    random.shuffle(train_data)
    epoch_loss = 0
    for i in range(0, len(train_data), 32):
        batch = train_data[i:i+32]
        imgs = torch.stack([b[0] for b in batch]).to(device)
        labels = torch.tensor([b[1] for b in batch], device=device)

        input_ids = torch.ones(len(batch), 1, dtype=torch.long, device=device)  # <start>
        logits = model_attn2(imgs, input_ids)
        loss = F.cross_entropy(logits[:, -1, :], labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # Divide by the actual number of batches (ceil), including the final partial batch
    losses.append(epoch_loss / ((len(train_data) + 31) // 32))

plt.figure(figsize=(8, 4))
plt.plot(losses, color='steelblue', linewidth=2)
plt.xlabel("Epoch"); plt.ylabel("Loss")
plt.title("Training Loss: Attention Model Learning Color Classification")
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Now compare attention before and after training
# Create a red circle image
red_img = torch.zeros(1, 3, 64, 64)
y, x = torch.meshgrid(torch.arange(64), torch.arange(64), indexing='ij')
circle_mask = ((x - 32)**2 + (y - 32)**2) < 12**2
red_img[0, 0][circle_mask] = 1.0
red_img = red_img.to(device)

# Trained model attention
with torch.no_grad():
    _ = model_attn2(red_img, torch.ones(1, 1, dtype=torch.long, device=device))

trained_attn = model_attn2.attention.attention_weights.cpu()

# Untrained model attention (a freshly initialized model stands in for "before training")
untrained_model = AttentionAnalysisModel().to(device)
with torch.no_grad():
    _ = untrained_model(red_img, torch.ones(1, 1, dtype=torch.long, device=device))

untrained_attn = untrained_model.attention.attention_weights.cpu()

# Compare
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

display_img = red_img[0].permute(1, 2, 0).cpu().numpy()
num_ps = int(np.sqrt(model_attn2.num_patches))

axes[0].imshow(display_img)
axes[0].set_title("Input: Red Circle", fontsize=12)
axes[0].axis('off')

# Untrained
vis_attn = untrained_attn[0, 0, -1, :model_attn2.num_patches].numpy()
attn_map = vis_attn.reshape(num_ps, num_ps)
attn_up = np.kron(attn_map, np.ones((64//num_ps, 64//num_ps)))
axes[1].imshow(display_img)
axes[1].imshow(attn_up, cmap='hot', alpha=0.6)
axes[1].set_title("Before Training: Uniform Attention", fontsize=12)
axes[1].axis('off')

# Trained
vis_attn = trained_attn[0, 0, -1, :model_attn2.num_patches].numpy()
attn_map = vis_attn.reshape(num_ps, num_ps)
attn_up = np.kron(attn_map, np.ones((64//num_ps, 64//num_ps)))
axes[2].imshow(display_img)
axes[2].imshow(attn_up, cmap='hot', alpha=0.6)
axes[2].set_title("After Training: Focused Attention", fontsize=12)
axes[2].axis('off')

plt.suptitle("Training Teaches the Model WHERE to Look in the Image",
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 8. Final Output

In [None]:
# Final visualization: Attention flow in a multimodal model
fig = plt.figure(figsize=(16, 8))

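# Create a rich test scene
scene_img = torch.ones(1, 3, 64, 64) * 0.85
# Recompute the pixel grid so the masks below do not rely on x, y from the previous cell
y, x = torch.meshgrid(torch.arange(64), torch.arange(64), indexing='ij')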
# Red object top-left
mask1 = ((x - 16)**2 + (y - 16)**2) < 8**2
scene_img[0, 0][mask1] = 1.0; scene_img[0, 1][mask1] = 0.1; scene_img[0, 2][mask1] = 0.1
# Blue object bottom-right
mask2 = (abs(x - 48) < 8) & (abs(y - 48) < 8)
scene_img[0, 0][mask2] = 0.1; scene_img[0, 1][mask2] = 0.1; scene_img[0, 2][mask2] = 1.0
# Green object center
mask3 = ((x - 32)**2 + (y - 32)**2) < 6**2
scene_img[0, 0][mask3] = 0.1; scene_img[0, 1][mask3] = 0.8; scene_img[0, 2][mask3] = 0.1
scene_img = scene_img.to(device)

with torch.no_grad():
    _ = model_attn2(scene_img, torch.ones(1, 1, dtype=torch.long, device=device))

final_attn = model_attn2.attention.attention_weights.cpu()
display_scene = scene_img[0].permute(1, 2, 0).cpu().numpy()

# Show all 4 heads for this scene
gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

ax0 = fig.add_subplot(gs[0, 0])
ax0.imshow(display_scene)
ax0.set_title("Input Scene", fontsize=12, fontweight='bold')
ax0.axis('off')

for h in range(4):
    # Heads fill the grid cells after the input image: (0,1), (0,2), (1,0), (1,1)
    row, col = divmod(h + 1, 3)
    ax = fig.add_subplot(gs[row, col])

    vis_attn = final_attn[0, h, -1, :model_attn2.num_patches].numpy()
    attn_map = vis_attn.reshape(num_ps, num_ps)
    attn_up = np.kron(attn_map, np.ones((64//num_ps, 64//num_ps)))

    ax.imshow(display_scene)
    ax.imshow(attn_up, cmap='hot', alpha=0.6)
    ax.set_title(f"Attention Head {h}", fontsize=11)
    ax.axis('off')

# Add summary text
ax_text = fig.add_subplot(gs[1, 2])
ax_text.axis('off')
summary = (
    "Key Takeaways:\n\n"
    "1. Different heads attend to\n   different image regions\n\n"
    "2. Cross-modal attention\n   emerges from standard\n   self-attention\n\n"
    "3. No special architecture\n   needed -- just projection\n   into shared space"
)
ax_text.text(0.1, 0.9, summary, fontsize=10, verticalalignment='top',
            fontfamily='monospace', bbox=dict(boxstyle='round', facecolor='lightyellow'))

plt.suptitle("Cross-Modal Attention: How a Language Model Learns to See",
             fontsize=15, fontweight='bold')
plt.show()

print("Congratulations! You have visualized and understood cross-modal attention!")
print("The key insight: self-attention on the combined [visual, text] sequence")
print("naturally enables the LLM to 'see' relevant parts of the image.")

## 9. Reflection and Next Steps

### Reflection Questions

1. In the attention heatmaps, different heads often attend to different spatial regions. Why is multi-head attention particularly important for multimodal models compared to single-head attention?

2. The LLaVA architecture uses 576 visual tokens for a 336x336 image. Each visual token can attend to every other token. What is the computational cost (in terms of Big-O) of self-attention on a sequence with 576 visual + 100 text tokens? How does this compare to a text-only model with 100 tokens?

3. Some models (like InstructBLIP) compress the vision encoder's output tokens down to 32 "query tokens" before feeding them to the LLM. What are the trade-offs of this compression in terms of (a) computational cost, (b) detail preservation, and (c) attention pattern flexibility?
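
For question 2, one way to sanity-check your reasoning is to count the pairwise scores a single attention head computes. The helper name `attention_score_count` is hypothetical, and the sketch ignores the number of heads, layers, and the embedding width, all of which scale both cases equally.

```python
def attention_score_count(seq_len: int) -> int:
    """Pairwise scores in one head of full self-attention: seq_len squared."""
    return seq_len * seq_len


text_only = attention_score_count(100)
multimodal = attention_score_count(576 + 100)

print(text_only)                          # 10000
print(multimodal)                         # 456976
print(f"{multimodal / text_only:.1f}x")   # 45.7x
```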

### Optional Challenges

1. **Attention rollout:** Instead of looking at a single attention layer, implement attention rollout -- multiply attention matrices across layers to see how information flows from image patches to the final output token.

2. **Head pruning:** After training, set the attention weights of one head to uniform and measure the accuracy drop. Which heads are most important for the task?

3. **Positional attention bias:** Analyze whether the model develops positional biases -- do patches near the center of the image get more attention than edge patches? Plot average attention vs. patch position.
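
As a starting point for the attention rollout challenge, here is a hedged sketch of one common formulation: average over heads, mix in the identity matrix to account for the residual connection, renormalize the rows, and multiply across layers. The 0.5/0.5 mixing weights and the random stand-in attention tensors are assumptions. The model in this notebook has a single attention layer, so you would need to stack several layers before rollout tells you anything new.

```python
import torch


def attention_rollout(attn_per_layer):
    """Combine per-layer attention into one (seq_len, seq_len) matrix.

    attn_per_layer: list of (num_heads, seq_len, seq_len) tensors, first layer first.
    """
    seq_len = attn_per_layer[0].shape[-1]
    identity = torch.eye(seq_len)
    rollout = identity.clone()
    for attn in attn_per_layer:
        a = attn.mean(dim=0)                  # average over heads
        a = 0.5 * a + 0.5 * identity          # assumed weighting for the residual path
        a = a / a.sum(dim=-1, keepdim=True)   # keep each row summing to 1
        rollout = a @ rollout                 # later layers multiply on the left
    return rollout


torch.manual_seed(0)
# Random stand-ins for 3 layers of 4-head attention over 6 tokens
layers = [torch.softmax(torch.randn(4, 6, 6), dim=-1) for _ in range(3)]
rollout = attention_rollout(layers)

print(rollout.shape)          # torch.Size([6, 6])
print(rollout.sum(dim=-1))    # rows still sum to 1 (up to float error)
```

Row `i` of `rollout` estimates how much each input token contributes to position `i` after all layers, so the last row restricted to the visual columns is the multi-layer analogue of the heatmaps above.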