In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Building a Mini Vision-Language Model from Scratch

**Vizuara AI** | Cross-Attention & Token Alignment Series — Notebook 3 of 3

In Notebooks 1 and 2, we built cross-attention and token alignment as standalone components. Now we put everything together into a **working mini VLM** — a model that takes image patches and text tokens, processes them through a full transformer block with cross-attention, and produces enriched text representations.

We will train this model on a toy patch-classification task and visualize what it learns.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math

%matplotlib inline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
torch.manual_seed(42)
np.random.seed(42)

## 1. Why Does This Matter?

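Most modern Vision-Language Models, such as LLaVA, Flamingo, and BLIP-2, follow a similar core pattern:

1. **Encode the image** into patch tokens (typically with a ViT)
2. **Align** image tokens to the language space (projection layer)
3. **Fuse** the modalities: let text tokens attend to image tokens, either through dedicated cross-attention layers (Flamingo) or by concatenating image tokens with text tokens and using ordinary self-attention (LLaVA)
4. **Generate** text conditioned on the visual information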
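To make the fusion step concrete, here is a shape-level sketch of the two common ways to combine modalities. It uses PyTorch's built-in `nn.MultiheadAttention` as a stand-in for both mechanisms; the real models add their own details (Flamingo gates its cross-attention layers, LLaVA feeds the combined sequence to a causal language model), which this sketch leaves out. The sizes are arbitrary toy values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_text, n_image = 32, 5, 9  # toy sizes chosen for illustration

text_tokens = torch.randn(1, n_text, d_model)    # (batch, n_text, d_model)
image_tokens = torch.randn(1, n_image, d_model)  # assumed already projected to d_model

# Option 1 (cross-attention idea): queries from text, keys/values from image
cross_attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
fused_cross, _ = cross_attn(query=text_tokens, key=image_tokens, value=image_tokens)
print(fused_cross.shape)   # torch.Size([1, 5, 32]): one output per text token

# Option 2 (concatenation idea): one joint sequence, ordinary self-attention
self_attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
sequence = torch.cat([image_tokens, text_tokens], dim=1)  # (1, 14, 32)
fused_concat, _ = self_attn(sequence, sequence, sequence)
print(fused_concat.shape)  # torch.Size([1, 14, 32]): image tokens join the sequence
```

The tradeoff shows up in the shapes: cross-attention leaves the text sequence length unchanged, while concatenation lengthens the sequence by the number of image tokens, so every later self-attention layer pays for those extra tokens.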

In this notebook, we build a simplified version of this pipeline. Our model will:
- Take synthetic "images" (structured patch embeddings)
- Take text token sequences
- Run cross-attention to enrich text with image info
- Learn which patches matter for which words

The output: a trained model whose cross-attention patterns we can visualize to check whether it has learned to ground language in vision.

## 2. Building Intuition

### The Full Pipeline

Think of a VLM block as a three-stage process for each text token:

1. **Self-attention:** "Let me first understand the context of my sentence" — text tokens talk to each other
2. **Cross-attention:** "Now let me look at the image" — text tokens query image patches
3. **Feed-forward:** "Let me process all this information" — nonlinear transformation

Each text token comes out the other side carrying both linguistic context (from self-attention) and visual information (from cross-attention).

### Think About This
Before we build:
- Why does self-attention come BEFORE cross-attention in this block (as in the original Transformer decoder)?
- What role does the feed-forward network play after attention?

## 3. The Mathematics

### The VLM Block

A single VLM block applies three sub-layers with residual connections and layer normalization:

$$\hat{x} = x + \text{SelfAttention}(\text{LayerNorm}(x))$$

$$\tilde{x} = \hat{x} + \text{CrossAttention}(\text{LayerNorm}(\hat{x}), z_{\text{image}})$$

$$y = \tilde{x} + \text{FFN}(\text{LayerNorm}(\tilde{x}))$$

Where:
- $x$ = text token sequence
- $z_{\text{image}}$ = aligned image tokens (after projection)
- LayerNorm normalizes each token independently
- Residual connections ($+$) give gradients a direct path around each sub-layer, which helps mitigate vanishing gradients

**Numerical intuition:** Suppose a text token is $[1.0, 2.0]$ and the cross-attention sub-layer outputs $[0.3, 0.7]$ (information pulled from the image). The residual connection adds the two: $[1.0 + 0.3, 2.0 + 0.7] = [1.3, 2.7]$. The original information is preserved, with image information added on top.
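A small sketch of the three block equations, with the sub-layers replaced by stand-in functions so that the residual path is easy to see. `vlm_block_update` and the stubs `zero` and `fixed` are hypothetical helpers for illustration only, and LayerNorm is applied without its learnable $\gamma, \beta$.

```python
import torch
import torch.nn.functional as F

def layer_norm(t, eps=1e-5):
    # LayerNorm over the last dimension, without learnable gamma/beta
    return F.layer_norm(t, t.shape[-1:], eps=eps)

def vlm_block_update(x, z_image, self_attn, cross_attn, ffn):
    """Hypothetical helper: the three pre-norm residual equations, sub-layers passed in."""
    x_hat = x + self_attn(layer_norm(x))
    x_tilde = x_hat + cross_attn(layer_norm(x_hat), z_image)
    return x_tilde + ffn(layer_norm(x_tilde))

x = torch.tensor([[1.0, 2.0]])  # one text token with d = 2
z_image = torch.randn(9, 2)     # stand-in image tokens (ignored by the stubs below)

def zero(t, *args):
    # a sub-layer that contributes nothing
    return torch.zeros_like(t)

def fixed(t, z):
    # pretend cross-attention output, matching the numerical example
    return torch.tensor([[0.3, 0.7]]).expand_as(t)

# With every sub-layer outputting zero, the residual path returns x unchanged
print(vlm_block_update(x, z_image, zero, zero, zero))   # tensor([[1., 2.]])
# With only cross-attention contributing [0.3, 0.7], we get [1.3, 2.7]
print(vlm_block_update(x, z_image, zero, fixed, zero))  # tensor([[1.3000, 2.7000]])
```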

### Layer Normalization

For a vector $x = [x_1, x_2, \ldots, x_d]$:

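$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

Where $\mu$ and $\sigma^2$ are the mean and (biased) variance of the $d$ entries of $x$, $\epsilon$ is a small constant for numerical stability, and $\gamma, \beta \in \mathbb{R}^d$ are learnable (initialized to $1$ and $0$).

**Example:** If $x = [1, 3]$, then $\mu = 2$ and $\sigma^2 = 1$:

$$\text{LayerNorm}([1, 3]) \approx [-1, 1] \odot \gamma + \beta$$

Before $\gamma$ and $\beta$ are applied, each token has zero mean and unit variance; keeping activations on a consistent scale like this stabilizes training.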
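A quick numerical check of the example, comparing a hand computation with PyTorch's `F.layer_norm`. This follows the PyTorch convention, which uses the biased variance, puts $\epsilon$ inside the square root, and (with no weight or bias passed) corresponds to $\gamma = 1$, $\beta = 0$.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1.0, 3.0])
eps = 1e-5

mu = x.mean()                              # 2.0
var = ((x - mu) ** 2).mean()               # 1.0: biased variance (divide by d, not d - 1)
manual = (x - mu) / torch.sqrt(var + eps)  # gamma = 1, beta = 0

builtin = F.layer_norm(x, normalized_shape=(2,), eps=eps)
print(manual)   # approximately [-1, 1]
print(builtin)  # should match the manual result
```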

## 4. Let's Build It — Component by Component

### 4.1 The Building Blocks (from Notebooks 1 & 2)

In [None]:
# Re-implement our core components

class MultiHeadCrossAttention(nn.Module):
    """Multi-head cross-attention: Q from text, K/V from image."""
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, text_tokens, image_tokens):
        n_text = text_tokens.size(0)
        n_image = image_tokens.size(0)

        Q = self.W_Q(text_tokens).view(n_text, self.num_heads, self.d_k).transpose(0, 1)
        K = self.W_K(image_tokens).view(n_image, self.num_heads, self.d_k).transpose(0, 1)
        V = self.W_V(image_tokens).view(n_image, self.num_heads, self.d_k).transpose(0, 1)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)

        context = context.transpose(0, 1).contiguous().view(n_text, self.d_model)
        output = self.W_O(context)

        return output, attn_weights


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention: Q, K, V all from the same input."""
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        n = x.size(0)

        Q = self.W_Q(x).view(n, self.num_heads, self.d_k).transpose(0, 1)
        K = self.W_K(x).view(n, self.num_heads, self.d_k).transpose(0, 1)
        V = self.W_V(x).view(n, self.num_heads, self.d_k).transpose(0, 1)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)

        context = context.transpose(0, 1).contiguous().view(n, self.d_model)
        return self.W_O(context), attn_weights

print("Building blocks ready!")
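Both modules share the same three core lines: scaled scores, softmax, and a weighted sum of values. As a sanity check, this standalone sketch runs those lines on random per-head tensors and compares the result with PyTorch's fused `F.scaled_dot_product_attention` (available in PyTorch 2.0 and later), whose default scaling is also $1/\sqrt{d_k}$. The tensor sizes are arbitrary toy choices.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_heads, d_k = 4, 8
n_text, n_image = 5, 9

# Per-head queries from text and keys/values from image: (num_heads, n, d_k)
Q = torch.randn(num_heads, n_text, d_k)
K = torch.randn(num_heads, n_image, d_k)
V = torch.randn(num_heads, n_image, d_k)

# Same computation as inside MultiHeadCrossAttention.forward
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
attn_weights = F.softmax(scores, dim=-1)
context_manual = attn_weights @ V

# PyTorch's fused kernel with its default 1/sqrt(d_k) scale
context_fused = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(context_manual, context_fused, atol=1e-5))
print(attn_weights.sum(dim=-1))  # each row of attention weights sums to 1
```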

### 4.2 The Vision-Language Block

In [None]:
class VisionLanguageBlock(nn.Module):
    """
    A single transformer block for a vision-language model.
    Sub-layers: self-attention → cross-attention → FFN
    Each with residual connection and layer norm.
    """
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()

        self.self_attn = MultiHeadSelfAttention(d_model, num_heads)
        self.cross_attn = MultiHeadCrossAttention(d_model, num_heads)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model)
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, text_tokens, image_tokens):
        """
        Args:
            text_tokens:  (n_text, d_model)
            image_tokens: (n_image, d_model) — already aligned
        Returns:
            output: (n_text, d_model) — text enriched with image info
            cross_attn_weights: (num_heads, n_text, n_image)
        """
        # 1. Self-attention (text talks to text)
        normed = self.norm1(text_tokens)
        sa_out, _ = self.self_attn(normed)
        text_tokens = text_tokens + sa_out  # residual

        # 2. Cross-attention (text queries image)
        normed = self.norm2(text_tokens)
        ca_out, cross_attn_weights = self.cross_attn(normed, image_tokens)
        text_tokens = text_tokens + ca_out  # residual

        # 3. Feed-forward network
        normed = self.norm3(text_tokens)
        ff_out = self.ffn(normed)
        text_tokens = text_tokens + ff_out  # residual

        return text_tokens, cross_attn_weights

# Test
d_model = 32
block = VisionLanguageBlock(d_model=d_model, num_heads=4, d_ff=64)

text_in = torch.randn(5, d_model)
image_in = torch.randn(9, d_model)

text_out, attn = block(text_in, image_in)
print(f"Text input:  {text_in.shape}")
print(f"Image input: {image_in.shape}")
print(f"Text output: {text_out.shape} (same shape — enriched with image info)")
print(f"Attention:   {attn.shape} (4 heads, 5 text tokens, 9 image patches)")
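To get a feel for the block's size, this sketch counts its parameters by building the same layers the block uses (four bias-free projections per attention module, a two-layer FFN with biases, three LayerNorms) without the class itself. Running `sum(p.numel() for p in block.parameters())` on the block above should give the same total, assuming its layer choices are unchanged.

```python
import torch.nn as nn

d_model, d_ff = 32, 64  # same sizes as the test cell

# Layers mirroring VisionLanguageBlock
self_attn = [nn.Linear(d_model, d_model, bias=False) for _ in range(4)]   # W_Q, W_K, W_V, W_O
cross_attn = [nn.Linear(d_model, d_model, bias=False) for _ in range(4)]  # W_Q, W_K, W_V, W_O
ffn = [nn.Linear(d_model, d_ff), nn.Linear(d_ff, d_model)]                # with biases
norms = [nn.LayerNorm(d_model) for _ in range(3)]                         # gamma and beta each

def count(layers):
    return sum(p.numel() for layer in layers for p in layer.parameters())

counts = {name: count(layers) for name, layers in
          [("self_attn", self_attn), ("cross_attn", cross_attn), ("ffn", ffn), ("norms", norms)]}
print(counts)                # {'self_attn': 4096, 'cross_attn': 4096, 'ffn': 4192, 'norms': 192}
print(sum(counts.values()))  # 12576
```

Self-attention and cross-attention cost exactly the same number of parameters here; the only difference between them is where K and V come from.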

### 4.3 The Complete Mini-VLM

In [None]:
# ============ TODO 1 ============
# Complete the MiniVLM by filling in the forward method.
# The pipeline is:
# 1. Project image tokens from d_vision to d_model
# 2. Pass text and projected image through VLM blocks
# 3. Produce output predictions
# ================================

class MiniVLM(nn.Module):
    """
    A mini Vision-Language Model.
    Takes image patches (d_vision) and text tokens (d_model),
    runs cross-attention, outputs predictions.
    """
    def __init__(self, d_vision, d_model, num_heads, d_ff, num_layers, vocab_size):
        super().__init__()

        # Token alignment: project image from vision to text space
        self.image_proj = nn.Linear(d_vision, d_model)

        # Text embedding (simple lookup table)
        self.text_embedding = nn.Embedding(vocab_size, d_model)

        # Stack of VLM blocks
        self.blocks = nn.ModuleList([
            VisionLanguageBlock(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

        # Output head: predict next token
        self.output_head = nn.Linear(d_model, vocab_size)

    def forward(self, text_ids, image_patches):
        """
        Args:
            text_ids:      (n_text,) — integer token IDs
            image_patches: (n_image, d_vision) — raw image patch features
        Returns:
            logits:       (n_text, vocab_size)
            all_attn:     list of (num_heads, n_text, n_image) per layer
        """
        # ============ YOUR CODE HERE ============
        # Replace each `None` below with the correct expression.
        # Step 1: Embed text tokens
        text_tokens = None  # TODO: use self.text_embedding

        # Step 2: Project image tokens
        image_tokens = None  # TODO: use self.image_proj

        # Step 3: Pass through VLM blocks
        all_attn = []
        for block in self.blocks:
            text_tokens, attn_weights = None  # TODO: call block
            all_attn.append(attn_weights)

        # Step 4: Compute output logits
        logits = None  # TODO: use self.output_head
        # ========================================

        return logits, all_attn

In [None]:
# Verification for TODO 1
class MiniVLMRef(nn.Module):
    def __init__(self, d_vision, d_model, num_heads, d_ff, num_layers, vocab_size):
        super().__init__()
        self.image_proj = nn.Linear(d_vision, d_model)
        self.text_embedding = nn.Embedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            VisionLanguageBlock(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        self.output_head = nn.Linear(d_model, vocab_size)

    def forward(self, text_ids, image_patches):
        text_tokens = self.text_embedding(text_ids)
        image_tokens = self.image_proj(image_patches)
        all_attn = []
        for block in self.blocks:
            text_tokens, attn_weights = block(text_tokens, image_tokens)
            all_attn.append(attn_weights)
        logits = self.output_head(text_tokens)
        return logits, all_attn

torch.manual_seed(42)
ref_model = MiniVLMRef(d_vision=64, d_model=32, num_heads=4, d_ff=64, num_layers=2, vocab_size=100)
torch.manual_seed(42)
student_model = MiniVLM(d_vision=64, d_model=32, num_heads=4, d_ff=64, num_layers=2, vocab_size=100)
student_model.load_state_dict(ref_model.state_dict())

test_text = torch.tensor([5, 12, 37, 88])
test_image = torch.randn(9, 64)

try:
    ref_logits, ref_attn = ref_model(test_text, test_image)
    stu_logits, stu_attn = student_model(test_text, test_image)

    assert stu_logits.shape == (4, 100), f"Expected (4, 100), got {stu_logits.shape}"
    assert len(stu_attn) == 2, f"Expected 2 attention layers, got {len(stu_attn)}"
    assert torch.allclose(stu_logits, ref_logits, atol=1e-4), "Logits don't match"
    print("Correct! Your MiniVLM works perfectly.")
    print(f"  Logits: {stu_logits.shape}")
    print(f"  Attention layers: {len(stu_attn)}")
    print(f"  Each attention: {stu_attn[0].shape}")
except Exception as e:
    print(f"Not quite: {e}")

## 5. Your Turn — Train on a Toy Task

Let us create a simple toy task: given an "image" in which one patch is activated, predict that patch's index (0-8) from the output of the final [PREDICT] text token.
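One subtlety before training: `MiniVLMRef` adds no positional embeddings to the image patches, and cross-attention treats its keys and values as an unordered set. Shuffling the patches therefore leaves every text token's output unchanged, so the model cannot tell *where* the hot patch sits. It can only succeed because the data generator in the next cell shifts the hot patch's values by its index (`+ targets[i].float()`). This standalone sketch, a single-head cross-attention with random weights and arbitrary sizes, illustrates the permutation invariance.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 16
text = torch.randn(5, d)   # queries
image = torch.randn(9, d)  # keys/values, with no positional embedding added
W_Q, W_K, W_V = (torch.randn(d, d) / math.sqrt(d) for _ in range(3))

def cross_attend(text, image):
    # Single-head cross-attention: same math as MultiHeadCrossAttention with one head
    q, k, v = text @ W_Q, image @ W_K, image @ W_V
    weights = F.softmax(q @ k.T / math.sqrt(d), dim=-1)
    return weights @ v

perm = torch.randperm(9)
out_original = cross_attend(text, image)
out_shuffled = cross_attend(text, image[perm])  # same patches, different order
print(torch.allclose(out_original, out_shuffled, rtol=1e-4, atol=1e-5))  # order is invisible
```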

In [None]:
# ============ TODO 2 ============
# Create a simple training loop for the MiniVLM.
# We will train it to associate image patches with text tokens.
#
# The task: given an image where one patch is "activated" (has a special pattern),
# predict which text token corresponds to that patch.
# ================================

# Synthetic dataset
def generate_batch(batch_size=16, n_patches=9, d_vision=64, vocab_size=10):
    """Generate a batch of (image, text, target) tuples.

    Each image has one "hot" patch. The text is a fixed prompt [CLS, PAD, PAD, PAD, PREDICT]
    (token IDs [0, 1, 1, 1, 2]). The target is the index of the hot patch (0-8),
    used directly as the class label.
    """
    images = torch.randn(batch_size, n_patches, d_vision) * 0.1  # base noise
    targets = torch.randint(0, n_patches, (batch_size,))

    for i in range(batch_size):
        # Make the target patch distinctive: larger values, plus a mean shift equal to
        # its index (the model has no positional embeddings, so this shift is its only
        # cue to which patch index is hot)
        images[i, targets[i]] = torch.randn(d_vision) * 2.0 + targets[i].float()

    # Fixed text prompt: [CLS=0, PAD=1, PAD=1, PAD=1, PREDICT=2]
    text_ids = torch.tensor([[0, 1, 1, 1, 2]] * batch_size)

    return images, text_ids, targets


# Training
torch.manual_seed(42)
model = MiniVLMRef(d_vision=64, d_model=32, num_heads=4, d_ff=64, num_layers=2, vocab_size=10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

losses = []
accuracies = []

for step in range(200):
    images, text_ids, targets = generate_batch(batch_size=32)

    # Our modules take one (unbatched) example at a time, so loop over the batch
    # and keep the logits of the last token ([PREDICT]) as the prediction.
    batch_logits = []
    for i in range(images.size(0)):
        logit, _ = model(text_ids[i], images[i])
        batch_logits.append(logit[-1])  # last-token prediction

    batch_logits = torch.stack(batch_logits)  # (batch, vocab_size)

    # Targets 0-8 are used directly as class indices (the first 9 of the 10 output classes)
    loss = loss_fn(batch_logits, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Accuracy
    preds = batch_logits.argmax(dim=-1)
    acc = (preds == targets).float().mean().item()

    losses.append(loss.item())
    accuracies.append(acc)

    if (step + 1) % 50 == 0:
        print(f"Step {step+1}: loss={loss.item():.4f}, accuracy={acc:.2%}")
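The loop above calls the model once per example because our attention modules expect unbatched `(n, d_model)` inputs. Below is a sketch of a batched cross-attention module that processes `(batch, n, d_model)` tensors in one call. `BatchedCrossAttention` is a hypothetical name, not part of this notebook, and training with it would also require batching the self-attention module and `MiniVLMRef.forward` in the same way.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class BatchedCrossAttention(nn.Module):
    """Hypothetical batched variant of MultiHeadCrossAttention: inputs are (batch, n, d_model)."""
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads, self.d_k = num_heads, d_model // num_heads
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def split_heads(self, t):
        # (batch, n, d_model) -> (batch, num_heads, n, d_k)
        b, n, _ = t.shape
        return t.view(b, n, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, text_tokens, image_tokens):
        Q = self.split_heads(self.W_Q(text_tokens))
        K = self.split_heads(self.W_K(image_tokens))
        V = self.split_heads(self.W_V(image_tokens))
        weights = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(self.d_k), dim=-1)
        context = (weights @ V).transpose(1, 2).flatten(2)  # (batch, n_text, d_model)
        return self.W_O(context), weights  # weights: (batch, num_heads, n_text, n_image)

torch.manual_seed(0)
attn = BatchedCrossAttention(d_model=32, num_heads=4)
text = torch.randn(8, 5, 32)
image = torch.randn(8, 9, 32)

out_batched, w_batched = attn(text, image)
# Running the examples one at a time, as the training loop does, should agree
out_looped = torch.stack([attn(text[i:i + 1], image[i:i + 1])[0][0] for i in range(8)])
print(out_batched.shape, torch.allclose(out_batched, out_looped, atol=1e-5))
```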

In [None]:
# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(losses, color='#e74c3c', alpha=0.7)
ax1.set_xlabel('Training Step')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss', fontsize=13)
ax1.grid(alpha=0.3)

ax2.plot(accuracies, color='#2ecc71', alpha=0.7)
ax2.set_xlabel('Training Step')
ax2.set_ylabel('Accuracy')
ax2.set_title('Training Accuracy', fontsize=13)
ax2.axhline(y=1/9, color='gray', linestyle='--', alpha=0.5, label='Random chance (1/9)')
ax2.legend()
ax2.grid(alpha=0.3)

plt.suptitle('MiniVLM Training on Patch Classification Task', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 6. Putting It All Together — Visualize the Trained Model

In [None]:
# Generate a test example and visualize attention
model.eval()

test_images, test_text, test_targets = generate_batch(batch_size=4)

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for i in range(4):
    with torch.no_grad():
        logits, attn_layers = model(test_text[i], test_images[i])

    pred = logits[-1].argmax().item()
    target = test_targets[i].item()

    # Show image patches (as heatmap of patch norms)
    patch_norms = test_images[i].norm(dim=-1).numpy()
    axes[0, i].bar(range(9), patch_norms, color=['#e74c3c' if j == target else '#3498db'
                                                   for j in range(9)])
    axes[0, i].set_title(f'Target patch: {target}, Pred: {pred}',
                          fontsize=11, fontweight='bold',
                          color='green' if pred == target else 'red')
    axes[0, i].set_xlabel('Patch Index')
    axes[0, i].set_ylabel('Patch Norm')

    # Show attention of the PREDICT token (last text token) to image patches
    # Use the last layer's attention, averaged across heads
    last_layer_attn = attn_layers[-1].mean(dim=0)  # (n_text, n_image)
    predict_token_attn = last_layer_attn[-1].detach().numpy()  # last text token

    axes[1, i].bar(range(9), predict_token_attn,
                   color=['#e74c3c' if j == target else '#95a5a6' for j in range(9)])
    axes[1, i].set_xlabel('Image Patch')
    axes[1, i].set_ylabel('Attention Weight')
    axes[1, i].set_title('Cross-Attn of [PREDICT] Token', fontsize=10)

axes[0, 0].set_ylabel('Patch Norm\n(higher = activated)', fontsize=10)
axes[1, 0].set_ylabel('Attention Weight\n(higher = more focus)', fontsize=10)

plt.suptitle('Trained MiniVLM: Does the Model Attend to the Right Patch?',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("Red bars show the target patch.")
print("If the model relies on cross-attention to locate the activated patch, the attention")
print("weights (bottom row) should peak at the same position as that patch (top row).")
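The bar charts above cover only four examples. A minimal sketch of a summary number: the fraction of examples whose [PREDICT] token places its largest attention weight on the target patch. `attention_hit_rate` is a hypothetical helper, checked here on a hand-made toy batch.

```python
import torch

def attention_hit_rate(predict_attn, targets):
    """Hypothetical metric: fraction of examples whose [PREDICT] token puts
    its largest cross-attention weight on the target patch.

    predict_attn: (batch, n_image) attention of the [PREDICT] token
                  (e.g. last layer, averaged over heads)
    targets:      (batch,) index of the hot patch
    """
    return (predict_attn.argmax(dim=-1) == targets).float().mean().item()

# Toy check with hand-made attention rows: rows 0 and 2 hit, row 1 misses
predict_attn = torch.tensor([[0.1, 0.8, 0.1],
                             [0.6, 0.2, 0.2],
                             [0.3, 0.3, 0.4]])
targets = torch.tensor([1, 2, 2])
print(attention_hit_rate(predict_attn, targets))
```

To apply it to the trained model, one option is to stack `attn_layers[-1].mean(dim=0)[-1]` (shape `(n_image,)`) for each test example and pass the result together with `test_targets`.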

## 7. Training and Results — Architecture Comparison

In [None]:
# Compare: 1 layer vs 2 layers vs 4 layers
configs = [
    ("1 layer", 1),
    ("2 layers", 2),
    ("4 layers", 4),
]

results = {}

for name, n_layers in configs:
    torch.manual_seed(42)
    m = MiniVLMRef(d_vision=64, d_model=32, num_heads=4, d_ff=64,
                    num_layers=n_layers, vocab_size=10)
    opt = torch.optim.Adam(m.parameters(), lr=1e-3)

    accs = []
    for step in range(200):
        imgs, txt, tgt = generate_batch(batch_size=32)
        batch_logits = []
        for i in range(imgs.size(0)):
            logit, _ = m(txt[i], imgs[i])
            batch_logits.append(logit[-1])
        batch_logits = torch.stack(batch_logits)
        loss = loss_fn(batch_logits, tgt)
        opt.zero_grad()
        loss.backward()
        opt.step()
        accs.append((batch_logits.argmax(-1) == tgt).float().mean().item())

    results[name] = accs
    print(f"{name}: mean accuracy over last 20 steps = {np.mean(accs[-20:]):.2%}")

# Plot comparison
plt.figure(figsize=(10, 5))
colors = ['#3498db', '#e74c3c', '#2ecc71']
for (name, _), color in zip(configs, colors):
    plt.plot(results[name], label=name, color=color, alpha=0.8)

plt.xlabel('Training Step')
plt.ylabel('Accuracy')
plt.title('Model Depth Comparison on Patch Classification', fontsize=14, fontweight='bold')
plt.axhline(y=1/9, color='gray', linestyle='--', alpha=0.5, label='Random chance')
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()
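The per-step accuracy curves are noisy because each point comes from a single batch of 32. A simple moving average makes the trends across depths easier to compare; this `moving_average` helper is a sketch, not part of the notebook.

```python
import numpy as np

def moving_average(values, window=10):
    """Smooth a noisy per-step curve by averaging each run of `window` consecutive values."""
    values = np.asarray(values, dtype=float)
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")  # length: len(values) - window + 1

print(moving_average([1, 2, 3, 4], window=2))  # [1.5 2.5 3.5]
```

For example, `plt.plot(range(9, len(results[name])), moving_average(results[name], window=10), label=name)` would plot each smoothed curve aligned to the last step of its window.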

## 8. Final Output — The Complete VLM Pipeline Visualization

In [None]:
# Final comprehensive visualization
model.eval()

# Generate a single test case
torch.manual_seed(99)
test_img, test_txt, test_tgt = generate_batch(batch_size=1)

with torch.no_grad():
    logits, attn_layers = model(test_txt[0], test_img[0])

pred = logits[-1].argmax().item()
target = test_tgt[0].item()

fig = plt.figure(figsize=(18, 12))
gs = fig.add_gridspec(3, 4, hspace=0.4, wspace=0.3)

# Row 1: raw image patch norms (before projection)
ax1 = fig.add_subplot(gs[0, :2])
patch_norms = test_img[0].norm(dim=-1).numpy()
bars = ax1.bar(range(9), patch_norms,
               color=['#e74c3c' if j == target else '#3498db' for j in range(9)],
               edgecolor='white', linewidth=0.5)
ax1.set_title('Image Patches (red = target)', fontsize=13, fontweight='bold')
ax1.set_xlabel('Patch Index')
ax1.set_ylabel('Feature Norm')

# Row 1: Text tokens
ax2 = fig.add_subplot(gs[0, 2:])
text_labels = ['[CLS]', 'PAD', 'PAD', 'PAD', '[PREDICT]']
text_colors = ['#9b59b6', '#bdc3c7', '#bdc3c7', '#bdc3c7', '#f39c12']
ax2.barh(range(5), [1]*5, color=text_colors, edgecolor='white')
for i, label in enumerate(text_labels):
    ax2.text(0.5, i, label, ha='center', va='center', fontsize=12, fontweight='bold')
ax2.set_title('Text Token Sequence', fontsize=13, fontweight='bold')
ax2.set_yticks([])
ax2.set_xticks([])
ax2.invert_yaxis()  # list tokens top to bottom in sequence order

# Row 2: Cross-attention per head (last layer, PREDICT token)
for h in range(4):
    ax = fig.add_subplot(gs[1, h])
    attn_map = attn_layers[-1][h, -1].detach().numpy().reshape(3, 3)
    im = ax.imshow(attn_map, cmap='YlOrRd', vmin=0)
    for r in range(3):
        for c in range(3):
            ax.text(c, r, f'{attn_map[r,c]:.2f}', ha='center', va='center',
                    fontsize=9, color='white' if attn_map[r,c] > 0.13 else 'black')
    ax.set_title(f'Head {h+1}', fontsize=12)
    ax.set_xticks([])
    ax.set_yticks([])

# Row 3: Average attention + prediction
ax5 = fig.add_subplot(gs[2, :2])
avg_attn = attn_layers[-1].mean(dim=0)[-1].detach().numpy()
ax5.bar(range(9), avg_attn,
        color=['#e74c3c' if j == target else '#95a5a6' for j in range(9)])
ax5.set_title('Average Cross-Attention (all heads)', fontsize=12, fontweight='bold')
ax5.set_xlabel('Image Patch')
ax5.set_ylabel('Attention Weight')

ax6 = fig.add_subplot(gs[2, 2:])
probs = F.softmax(logits[-1], dim=-1).detach().numpy()[:9]
ax6.bar(range(9), probs,
        color=['#2ecc71' if j == pred else '#bdc3c7' for j in range(9)])
ax6.set_title(f'Prediction: patch {pred} (target: {target})',
              fontsize=12, fontweight='bold',
              color='green' if pred == target else 'red')
ax6.set_xlabel('Class (Patch Index)')
ax6.set_ylabel('Probability')

plt.suptitle('Mini VLM: Complete Cross-Attention Pipeline',
             fontsize=16, fontweight='bold', y=1.0)
plt.show()

print(f"\nTarget patch: {target}")
print(f"Predicted patch: {pred}")
print(f"Correct: {'Yes!' if pred == target else 'No'}")
print("\nThis visualization shows the complete VLM pipeline:")
print("1. Image patches → 2. Text tokens → 3. Per-head attention → 4. Prediction")

## 9. Reflection and Next Steps

### Key Takeaways

1. A **VLM block** has three sub-layers: self-attention (text context), cross-attention (visual grounding), and FFN (nonlinear processing)
2. **Residual connections** preserve the original information while adding new information from attention
3. **Layer normalization** stabilizes training by normalizing each token independently
4. A trained model can learn to **focus its cross-attention on the relevant image patch**: this is visual grounding in action (the attention plots above show whether yours did)
5. **Multiple heads** can attend to different aspects of the image simultaneously

### What We Built

Across these three notebooks, we built every component of a modern VLM from scratch:
- **Notebook 1:** Scaled dot-product attention, self-attention, cross-attention
- **Notebook 2:** Token alignment (linear/MLP projection), multi-head attention
- **Notebook 3:** Full VLM block, training loop, attention visualization

### Reflection Questions
- How would you modify this model for actual image captioning (generating words sequentially)?
- What is the computational cost of cross-attention compared to self-attention? (Hint: think about the shapes of Q, K, V)
- In practice, LLaVA skips cross-attention entirely and just concatenates image tokens with text tokens. What are the tradeoffs?
- How would you add causal masking to the self-attention (needed for autoregressive generation)?

### Going Further
- Read the LLaVA paper (Liu et al., 2023) to see how a simple linear projection achieves strong results
- Read the Flamingo paper (Alayrac et al., 2022) to see gated cross-attention in action
- Try extending this notebook to generate text autoregressively