In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Building LLaVA and Flamingo from Scratch

*Part 2 of the Vizuara series on Multimodal Fusion Architectures*
*Estimated time: 60 minutes*

## 1. Why Does This Matter?

In the previous notebook, we compared three fusion strategies on a synthetic task. Now we are going to build the two most influential vision-language architectures from scratch: **LLaVA** and **Flamingo**.

LLaVA answers the question: what is the simplest way to make an LLM see? Flamingo answers a different question: how do you inject vision into a frozen LLM without breaking it?

By the end of this notebook, you will have:
- Built a simplified LLaVA model that processes images and generates text
- Built a Flamingo-style gated cross-attention module
- Compared both architectures on a visual question answering task with CIFAR-10 images

In [None]:
# Setup
!pip install torch torchvision matplotlib numpy -q

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(42)
np.random.seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Building Intuition

Let us think about LLaVA and Flamingo through a classroom analogy.

**LLaVA is like a student who takes notes on a lecture slide before class discussion.** The student (LLM) first looks at the slide (image), converts it into personal notes (visual tokens via projection), and then discusses the topic with those notes integrated right into the conversation transcript.

**Flamingo is like a student who can glance at a reference book during an exam.** The student (frozen LLM) takes the exam (processes text) normally, but at certain checkpoints, they are allowed to glance at a reference book (visual tokens via gated cross-attention). The glancing is controlled -- at first they barely look, and gradually they learn which parts of the book are helpful.

The key difference: LLaVA changes the input (longer sequence), while Flamingo changes the architecture (adds cross-attention layers).

### Think About This

If you wanted to add vision to an existing LLM without retraining the entire model, which approach would be cheaper? Which would preserve more of the LLM's original language ability?

## 3. The Mathematics

### 3.1 LLaVA Forward Pass

$$y = \text{LLM}\left([W \cdot \text{ViT}(I) \;;\; \text{Embed}(T)]\right)$$

Let us trace the dimensions. If $\text{ViT}(I) \in \mathbb{R}^{P \times d_v}$ ($P$ patches, each $d_v$-dimensional) and $W \in \mathbb{R}^{d_{LLM} \times d_v}$ is applied to each patch vector (in matrix form, $\text{ViT}(I)\,W^\top$), then:

$$W \cdot \text{ViT}(I) \in \mathbb{R}^{P \times d_{LLM}}$$

After concatenation with text embeddings $\text{Embed}(T) \in \mathbb{R}^{N \times d_{LLM}}$:

$$[W \cdot \text{ViT}(I) ; \text{Embed}(T)] \in \mathbb{R}^{(P+N) \times d_{LLM}}$$

**Computational meaning:** We project visual patches into the same dimensional space as text tokens, then treat them as extra tokens. The LLM sees a longer sequence but does not know (or care) which tokens are visual and which are textual.

Let us plug in concrete numbers. Suppose $P = 4$ patches, $d_v = 8$, $d_{LLM} = 16$, $N = 3$ text tokens:
- ViT output: $4 \times 8$ matrix
- After projection: $4 \times 16$ matrix
- Text embeddings: $3 \times 16$ matrix
- Concatenated: $7 \times 16$ matrix (7 tokens total, 16 dims each)

This is exactly what we want -- the LLM now processes a unified 7-token sequence.
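The dimension trace above can be checked in a few lines. This is a minimal sketch with random tensors standing in for the ViT output and the text embeddings; the variable names are illustrative and not part of the models built later.

```python
import torch

P, d_v, d_llm, N = 4, 8, 16, 3        # numbers from the worked example

vit_out = torch.randn(P, d_v)          # stand-in for ViT(I): one row per patch
W = torch.randn(d_llm, d_v)            # projection matrix, stored as (d_LLM, d_v)
text_emb = torch.randn(N, d_llm)       # stand-in for Embed(T)

# Apply W to every patch row: (P, d_v) @ (d_v, d_LLM) -> (P, d_LLM)
projected = vit_out @ W.T

# Visual tokens first, then text tokens, along the sequence dimension
combined = torch.cat([projected, text_emb], dim=0)

print(projected.shape)  # torch.Size([4, 16])
print(combined.shape)   # torch.Size([7, 16])
```

Storing `W` as `(d_LLM, d_v)` matches how `nn.Linear` keeps its weight, which is why the projection is written with a transpose here.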

### 3.2 Flamingo Gated Cross-Attention

$$h' = h + \tanh(\alpha) \cdot \text{CrossAttn}(h, v)$$

Here $\alpha$ starts at 0, so $\tanh(0) = 0$, meaning no visual information flows at first. As training progresses, $\alpha$ moves away from zero.

Let us compute with $\alpha = 0.5$, $h = [1.0, 2.0]$, and $\text{CrossAttn}(h, v) = [0.3, -0.1]$:

$$h' = [1.0, 2.0] + \tanh(0.5) \cdot [0.3, -0.1]$$
$$= [1.0, 2.0] + 0.462 \cdot [0.3, -0.1]$$
$$= [1.0, 2.0] + [0.139, -0.046]$$
$$= [1.139, 1.954]$$

This tells us that at $\alpha = 0.5$, about 46% of the cross-attention signal passes through. The visual information gently modifies the LLM hidden states. This is exactly what we want -- a smooth, controlled injection.
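The same arithmetic can be reproduced with the standard library. The sketch below hard-codes the cross-attention output from the example rather than computing it, and also evaluates the gate at $\alpha = 0$ to show the starting point.

```python
import math

alpha = 0.5
h = [1.0, 2.0]
cross_attn = [0.3, -0.1]   # taken from the worked example, not computed here

gate = math.tanh(alpha)
h_new = [hi + gate * ci for hi, ci in zip(h, cross_attn)]

print(f"gate at alpha=0.5: {gate:.3f}")        # 0.462
print([round(x, 3) for x in h_new])            # [1.139, 1.954]
print(f"gate at alpha=0.0: {math.tanh(0.0)}")  # 0.0
```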

## 4. Let's Build It -- Component by Component

### 4.1 A Simple Vision Encoder

First, we need something to convert images into visual tokens. We will build a simplified patch-based vision encoder (inspired by ViT but much smaller).

In [None]:
class SimplePatchEncoder(nn.Module):
    """
    A simplified vision encoder that splits an image into patches
    and processes them into visual tokens.

    For a 32x32 CIFAR image with patch_size=8:
    - We get 4x4 = 16 patches
    - Each patch is 8x8x3 = 192 values (64 pixels x 3 channels)
    - Projected to embed_dim dimensions
    """
    def __init__(self, img_size=32, patch_size=8, in_channels=3, embed_dim=64):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # Linear projection of flattened patches
        patch_dim = patch_size * patch_size * in_channels
        self.patch_proj = nn.Linear(patch_dim, embed_dim)

        # Learnable position embeddings
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim) * 0.02)

        # A small transformer layer for inter-patch reasoning
        self.transformer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=4, dim_feedforward=embed_dim*2,
            batch_first=True, dropout=0.1
        )

    def forward(self, x):
        """x: (B, C, H, W) -> (B, num_patches, embed_dim)"""
        B, C, H, W = x.shape
        p = self.patch_size

        # Extract patches: (B, C, H, W) -> (B, num_patches, patch_dim)
        patches = x.unfold(2, p, p).unfold(3, p, p)  # (B, C, H/p, W/p, p, p)
        patches = patches.contiguous().view(B, C, -1, p, p)  # (B, C, num_patches, p, p)
        patches = patches.permute(0, 2, 1, 3, 4)  # (B, num_patches, C, p, p)
        patches = patches.reshape(B, self.num_patches, -1)  # (B, num_patches, patch_dim)

        # Project patches to embedding dim and add position embeddings
        tokens = self.patch_proj(patches) + self.pos_embed

        # Process with transformer
        tokens = self.transformer(tokens)

        return tokens  # (B, num_patches, embed_dim)

# Test the encoder
encoder = SimplePatchEncoder()
dummy_img = torch.randn(2, 3, 32, 32)
visual_tokens = encoder(dummy_img)
print(f"Input image: {dummy_img.shape}")
print(f"Visual tokens: {visual_tokens.shape}")
print(f"Number of patches: {encoder.num_patches}")
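The `unfold`-based patch extraction is easy to get subtly wrong, so it can help to compare it against plain slicing. The sketch below uses a hypothetical helper, `extract_patches`, that permutes the grid dimensions in a slightly different order from the encoder above but should produce the same row-major patch layout.

```python
import torch

def extract_patches(x, p):
    """(B, C, H, W) -> (B, num_patches, C*p*p), patches in row-major grid order."""
    B, C, H, W = x.shape
    patches = x.unfold(2, p, p).unfold(3, p, p)    # (B, C, H/p, W/p, p, p)
    patches = patches.permute(0, 2, 3, 1, 4, 5)    # (B, H/p, W/p, C, p, p)
    return patches.reshape(B, -1, C * p * p)

# Use distinct values everywhere so a misplaced patch cannot match by accident
x = torch.arange(2 * 3 * 32 * 32, dtype=torch.float32).reshape(2, 3, 32, 32)
patches = extract_patches(x, p=8)

# Patch index 5 is grid row 1, column 1 of the 4x4 grid
manual = x[:, :, 8:16, 8:16].reshape(2, -1)

print(patches.shape)                        # torch.Size([2, 16, 192])
print(torch.equal(patches[:, 5], manual))   # True
```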

In [None]:
# Visualization checkpoint: what do patches look like?
transform = transforms.Compose([transforms.ToTensor()])
cifar = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

img, label = cifar[0]
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

fig, axes = plt.subplots(1, 5, figsize=(15, 3))

# Show original image
axes[0].imshow(img.permute(1, 2, 0))
axes[0].set_title(f'Original: {class_names[label]}')
axes[0].axis('off')

# Show the first 4 of the 16 patches (the top row of the 4x4 grid)
p = 8
patches = img.unfold(1, p, p).unfold(2, p, p)
patches = patches.permute(1, 2, 0, 3, 4).reshape(-1, 3, p, p)
for i in range(4):
    axes[i+1].imshow(patches[i].permute(1, 2, 0))
    axes[i+1].set_title(f'Patch {i}')
    axes[i+1].axis('off')

plt.suptitle('Image split into patches (the visual tokens)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

### 4.2 A Simple Text Encoder and Decoder

For our task, we will use a simple vocabulary for visual question answering on CIFAR-10.

In [None]:
class SimpleTextProcessor(nn.Module):
    """
    Simple text processor with a small vocabulary for VQA on CIFAR-10.

    Vocabulary: class names + special tokens + question words
    """
    def __init__(self, embed_dim=64, max_seq_len=20):
        super().__init__()
        # Build vocabulary
        self.vocab = ['<pad>', '<bos>', '<eos>', '<unk>',
                      'what', 'is', 'this', 'in', 'the', 'image', '?',
                      'airplane', 'automobile', 'bird', 'cat', 'deer',
                      'dog', 'frog', 'horse', 'ship', 'truck',
                      'a', 'an', 'it', 'of', 'photo']
        self.word2idx = {w: i for i, w in enumerate(self.vocab)}
        self.idx2word = {i: w for w, i in self.word2idx.items()}
        self.vocab_size = len(self.vocab)

        self.embedding = nn.Embedding(self.vocab_size, embed_dim, padding_idx=0)
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, embed_dim) * 0.02)

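    def encode(self, text):
        """Convert a whitespace-separated string to token IDs.

        Punctuation must be space-separated ('this ?'), otherwise 'this?' maps to <unk>.
        """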
        words = text.lower().split()
        ids = [self.word2idx.get(w, self.word2idx['<unk>']) for w in words]
        return torch.tensor(ids)

    def forward(self, token_ids):
        """token_ids: (B, N) -> (B, N, embed_dim)"""
        B, N = token_ids.shape
        embeds = self.embedding(token_ids) + self.pos_embed[:, :N, :]
        return embeds

text_proc = SimpleTextProcessor()
print(f"Vocabulary size: {text_proc.vocab_size}")
print(f"Example encoding: 'what is this' -> {text_proc.encode('what is this')}")
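The notebook only ever uses one fixed-length question, but a batch of different questions would need padding. The following is a minimal sketch, assuming the same whitespace tokenization and the same `<pad>` = 0 convention; the shortened vocabulary and the free function `encode` are illustrative stand-ins for `SimpleTextProcessor`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Shortened copy of the notebook vocabulary (illustrative)
vocab = ['<pad>', '<bos>', '<eos>', '<unk>',
         'what', 'is', 'this', 'in', 'the', 'image', '?']
word2idx = {w: i for i, w in enumerate(vocab)}

def encode(text):
    """Whitespace-tokenize and map unknown words to <unk>."""
    ids = [word2idx.get(w, word2idx['<unk>']) for w in text.lower().split()]
    return torch.tensor(ids)

questions = ["what is this", "what is in the image ?", "what colour is this"]
batch = pad_sequence([encode(q) for q in questions],
                     batch_first=True, padding_value=word2idx['<pad>'])
pad_mask = batch == word2idx['<pad>']   # True where a position is padding

print(batch)
# tensor([[ 4,  5,  6,  0,  0,  0],
#         [ 4,  5,  7,  8,  9, 10],
#         [ 4,  3,  5,  6,  0,  0]])
```

A mask like `pad_mask` could be passed to a transformer layer as `src_key_padding_mask` so that padded positions are ignored; the models in this notebook do not need it because every question has the same length.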

### 4.3 Building LLaVA (Simplified)

Now let us build the complete LLaVA architecture. The key is the **projection layer** that maps visual tokens to the LLM's dimension.

In [None]:
class SimpleLLaVA(nn.Module):
    """
    Simplified LLaVA architecture:
    1. Image -> Patch Encoder -> Visual tokens (P x d_v)
    2. Visual tokens -> Projection -> Visual tokens (P x d_llm)
    3. [Visual tokens ; Text tokens] -> Transformer -> Output
    """
    def __init__(self, img_size=32, patch_size=8, embed_dim=64, num_heads=4, num_layers=2):
        super().__init__()
        self.embed_dim = embed_dim

        # Vision encoder
        self.vision_encoder = SimplePatchEncoder(img_size, patch_size, embed_dim=embed_dim)

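        # Visual projection into the LLM embedding space. This is the 2-layer MLP
        # form used by LLaVA-1.5; the original LLaVA used a single linear layer.
        # Both dims equal embed_dim here; in real LLaVA this maps ViT dim -> LLM dim.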
        self.visual_proj = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim)
        )

        # Text processor
        self.text_proc = SimpleTextProcessor(embed_dim=embed_dim)

        # Transformer (acts as the "LLM")
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=embed_dim*4, batch_first=True, dropout=0.1
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output head: classify into CIFAR-10 classes
        self.classifier = nn.Linear(embed_dim, 10)

    def forward(self, images, text_ids):
        """
        images: (B, 3, 32, 32)
        text_ids: (B, N)
        """
        # Step 1-2: Encode and project visual tokens
        vis_tokens = self.vision_encoder(images)        # (B, P, embed_dim)
        vis_tokens = self.visual_proj(vis_tokens)        # (B, P, embed_dim)

        # Step 3: Encode text tokens
        txt_tokens = self.text_proc(text_ids)            # (B, N, embed_dim)

        # Step 4: Concatenate [visual ; text]
        combined = torch.cat([vis_tokens, txt_tokens], dim=1)  # (B, P+N, embed_dim)

        # Step 5: Process through transformer
        hidden = self.transformer(combined)              # (B, P+N, embed_dim)

        # Pool over all tokens and classify
        pooled = hidden.mean(dim=1)                      # (B, embed_dim)
        logits = self.classifier(pooled)                 # (B, 10)

        return logits

llava_model = SimpleLLaVA().to(device)
print(f"LLaVA parameters: {sum(p.numel() for p in llava_model.parameters()):,}")

### 4.4 Building Flamingo (Simplified)

Now let us build Flamingo. The key innovations are the **Perceiver Resampler** and **Gated Cross-Attention**.

In [None]:
class PerceiverResampler(nn.Module):
    """
    Perceiver Resampler: compresses variable-length visual tokens
    into a fixed number of "summary" tokens.

    Uses a set of learned queries that attend to the visual tokens.
    """
    def __init__(self, embed_dim=64, num_queries=4, num_heads=4):
        super().__init__()
        # Learnable query tokens
        self.queries = nn.Parameter(torch.randn(1, num_queries, embed_dim) * 0.02)

        # Cross-attention: queries attend to visual tokens
        self.cross_attn = nn.MultiheadAttention(
            embed_dim, num_heads, batch_first=True
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.GELU(),
            nn.Linear(embed_dim * 2, embed_dim)
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, visual_tokens):
        """
        visual_tokens: (B, P, embed_dim) -- variable length P
        Returns: (B, num_queries, embed_dim) -- fixed length
        """
        B = visual_tokens.shape[0]
        queries = self.queries.expand(B, -1, -1)  # (B, num_queries, embed_dim)

        # Cross-attention: queries attend to visual tokens
        attn_out, _ = self.cross_attn(queries, visual_tokens, visual_tokens)
        x = self.norm(queries + attn_out)
        x = self.norm2(x + self.ffn(x))

        return x  # (B, num_queries, embed_dim)

# Test
resampler = PerceiverResampler()
vis = torch.randn(2, 16, 64)  # 16 visual tokens
summary = resampler(vis)
print(f"Visual tokens: {vis.shape} -> Summary tokens: {summary.shape}")
print(f"Compressed {vis.shape[1]} tokens into {summary.shape[1]} summary tokens!")
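The reason a fixed set of learned queries gives a fixed-length output is visible in the attention shapes alone. This sketch strips the resampler down to a bare `nn.MultiheadAttention` call (no norm or feed-forward) and feeds it different numbers of visual tokens; the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_queries = 64, 4
queries = nn.Parameter(torch.randn(1, num_queries, embed_dim) * 0.02)
cross_attn = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)

shapes = []
for P in (16, 64, 256):                      # number of visual tokens varies
    vis = torch.randn(2, P, embed_dim)
    out, weights = cross_attn(queries.expand(2, -1, -1), vis, vis)
    shapes.append((tuple(out.shape), tuple(weights.shape)))
    print(f"P={P:3d}: output {tuple(out.shape)}, attention weights {tuple(weights.shape)}")
```

The output always has `num_queries` rows; only the attention-weight matrix, of shape `(B, num_queries, P)`, grows with the number of visual tokens.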

In [None]:
class GatedCrossAttentionLayer(nn.Module):
    """
    Gated Cross-Attention: the core Flamingo innovation.

    h' = h + tanh(alpha) * CrossAttn(h, v)

    alpha starts at 0, so tanh(0) = 0: no visual info at first.
    As training progresses, alpha learns to open the gate.
    """
    def __init__(self, embed_dim=64, num_heads=4):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim, num_heads, batch_first=True
        )
        self.norm = nn.LayerNorm(embed_dim)
        # THE KEY: gate initialized to zero
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, text_hidden, visual_summary):
        """
        text_hidden: (B, N, embed_dim) -- LLM hidden states
        visual_summary: (B, K, embed_dim) -- from Perceiver Resampler
        """
        # Cross-attention: text queries, visual keys/values
        attn_out, attn_weights = self.cross_attn(
            self.norm(text_hidden), visual_summary, visual_summary
        )

        # Gated residual connection
        gate_value = torch.tanh(self.gate)
        output = text_hidden + gate_value * attn_out

        return output, gate_value.item()


class SimpleFlamingo(nn.Module):
    """
    Simplified Flamingo architecture:
    1. Image -> Patch Encoder -> Visual tokens
    2. Visual tokens -> Perceiver Resampler -> Visual summary (fixed K tokens)
    3. Text -> Embedding -> Text tokens
    4. Text tokens processed by Transformer layers with Gated Cross-Attention
    """
    def __init__(self, img_size=32, patch_size=8, embed_dim=64,
                 num_heads=4, num_layers=2, num_visual_queries=4):
        super().__init__()
        self.embed_dim = embed_dim

        # Vision encoder (frozen in real Flamingo)
        self.vision_encoder = SimplePatchEncoder(img_size, patch_size, embed_dim=embed_dim)

        # Perceiver Resampler
        self.resampler = PerceiverResampler(embed_dim, num_visual_queries, num_heads)

        # Text processor
        self.text_proc = SimpleTextProcessor(embed_dim=embed_dim)

        # Interleaved: Transformer layer -> Gated Cross-Attention -> Transformer layer -> ...
        self.layers = nn.ModuleList()
        self.cross_attn_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=num_heads,
                dim_feedforward=embed_dim*4, batch_first=True, dropout=0.1
            ))
            self.cross_attn_layers.append(GatedCrossAttentionLayer(embed_dim, num_heads))

        # Classifier
        self.classifier = nn.Linear(embed_dim, 10)

    def forward(self, images, text_ids):
        # Step 1: Encode image -> visual tokens -> summary
        vis_tokens = self.vision_encoder(images)
        vis_summary = self.resampler(vis_tokens)  # (B, K, embed_dim)

        # Step 2: Encode text
        txt_hidden = self.text_proc(text_ids)     # (B, N, embed_dim)

        # Step 3: Interleaved processing
        gate_values = []
        for transformer_layer, cross_attn_layer in zip(self.layers, self.cross_attn_layers):
            txt_hidden = transformer_layer(txt_hidden)
            txt_hidden, gate_val = cross_attn_layer(txt_hidden, vis_summary)
            gate_values.append(gate_val)

        # Pool and classify
        pooled = txt_hidden.mean(dim=1)
        logits = self.classifier(pooled)

        self.last_gate_values = gate_values
        return logits

flamingo_model = SimpleFlamingo().to(device)
print(f"Flamingo parameters: {sum(p.numel() for p in flamingo_model.parameters()):,}")
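The code comments note that the vision encoder is frozen in real Flamingo, and the introduction describes the LLM as frozen too, whereas this notebook trains everything end to end. Below is a rough sketch of what freezing would look like in PyTorch. The `nn.ModuleDict` and its three entries are hypothetical stand-ins, not the `SimpleFlamingo` class itself.

```python
import torch.nn as nn

# Hypothetical stand-ins for the three parameter groups of a Flamingo-style model
model = nn.ModuleDict({
    "vision_encoder": nn.Linear(192, 64),
    "llm_layer": nn.TransformerEncoderLayer(
        d_model=64, nhead=4, dim_feedforward=256, batch_first=True),
    "gated_xattn": nn.MultiheadAttention(64, num_heads=4, batch_first=True),
})

# Freeze everything except the newly added cross-attention parameters
for name in ("vision_encoder", "llm_layer"):
    model[name].requires_grad_(False)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,}")
```

When building the optimizer in such a setup, one would typically pass only the parameters with `requires_grad=True`.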

In [None]:
# Visualization checkpoint: gate values at initialization
dummy_img = torch.randn(1, 3, 32, 32).to(device)
dummy_text = torch.tensor([[4, 5, 6]]).to(device)  # "what is this"

with torch.no_grad():
    _ = flamingo_model(dummy_img, dummy_text)

print("Gate values at initialization:")
for i, gv in enumerate(flamingo_model.last_gate_values):
    print(f"  Layer {i}: tanh(alpha) = {gv:.6f}")
print("\nAll gates are exactly zero -- the model starts as a pure text model!")
print("Visual information will gradually flow in during training.")
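A zero gate has a concrete consequence: the layer output cannot depend on the image at all. The sketch below uses a hypothetical `gated_block` function (a bare gated residual around `nn.MultiheadAttention`, without the layer norm used above) and feeds it two different sets of visual tokens.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(16, num_heads=2, batch_first=True)
gate = nn.Parameter(torch.zeros(1))          # alpha = 0, as at initialization

def gated_block(h, v):
    """h + tanh(alpha) * CrossAttn(h, v)"""
    attn_out, _ = attn(h, v, v)
    return h + torch.tanh(gate) * attn_out

h = torch.randn(1, 3, 16)                    # text hidden states
v1, v2 = torch.randn(1, 4, 16), torch.randn(1, 4, 16)   # two different "images"

with torch.no_grad():
    same_at_init = torch.equal(gated_block(h, v1), gated_block(h, v2))
    gate.fill_(0.5)                          # pretend training has opened the gate
    same_when_open = torch.allclose(gated_block(h, v1), gated_block(h, v2))

print(same_at_init, same_when_open)          # True False
```

Since the derivative of $\tanh(\alpha)$ at $\alpha = 0$ is 1, the gate parameter itself still receives a gradient at initialization, which is what allows it to move away from zero.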

## 5. Your Turn

### TODO: Implement Multi-Image Flamingo

Real Flamingo can process multiple images. Modify the Flamingo architecture to accept a list of images and produce a visual summary that combines all of them.

In [None]:
class MultiImageFlamingo(nn.Module):
    """
    Extended Flamingo that processes multiple images per input.

    Instead of a single image, accepts a list of images.
    Each image gets its own visual tokens, then all are
    concatenated before the Perceiver Resampler.
    """
    def __init__(self, img_size=32, patch_size=8, embed_dim=64,
                 num_heads=4, num_visual_queries=8):
        super().__init__()
        # ============ TODO ============
        # Step 1: Create a SimplePatchEncoder
        # Step 2: Create a PerceiverResampler with more queries
        #         (since we have more visual tokens from multiple images)
        # Step 3: Create a GatedCrossAttentionLayer
        # ==============================

        self.vision_encoder = None   # YOUR CODE HERE
        self.resampler = None        # YOUR CODE HERE
        self.cross_attn = None       # YOUR CODE HERE

    def encode_multiple_images(self, images_list):
        """
        images_list: list of (B, C, H, W) tensors

        Returns: (B, total_patches, embed_dim) -- all visual tokens concatenated
        """
        # ============ TODO ============
        # For each image in images_list:
        #   1. Encode with vision_encoder
        #   2. Collect all visual tokens
        # Concatenate all visual tokens along sequence dim
        # ==============================

        return None  # YOUR CODE HERE

In [None]:
# Verification
if MultiImageFlamingo().vision_encoder is not None:
    multi = MultiImageFlamingo()
    imgs = [torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)]
    all_tokens = multi.encode_multiple_images(imgs)
    expected_patches = multi.vision_encoder.num_patches * 2
    assert all_tokens.shape == (2, expected_patches, 64), f"Expected (2, {expected_patches}, 64)"
    print(f"Correct! {len(imgs)} images -> {all_tokens.shape[1]} visual tokens")
else:
    print("TODO: Implement MultiImageFlamingo above")

### TODO: Implement a Visual Projection with MLP (LLaVA-1.5)

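LLaVA-1.5 upgraded the single linear projection to a 2-layer MLP with GELU activation. Implement it as a standalone module with separate input and output dimensions. Our `SimpleLLaVA` already uses this form inline, so to compare the two, swap its `visual_proj` for a single `nn.Linear`.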

In [None]:
class MLPProjection(nn.Module):
    """
    LLaVA-1.5 style MLP projection:
    visual_tokens -> Linear -> GELU -> Linear -> projected_tokens
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # ============ TODO ============
        # Create a 2-layer MLP with GELU activation between layers
        # Hint: nn.Sequential(Linear, GELU, Linear)
        # ==============================
        self.mlp = None  # YOUR CODE HERE

    def forward(self, x):
        # ============ TODO ============
        return None  # YOUR CODE HERE

In [None]:
# Verification
if MLPProjection(64, 64).mlp is not None:
    proj = MLPProjection(64, 128)
    x = torch.randn(2, 16, 64)
    out = proj(x)
    assert out.shape == (2, 16, 128), f"Expected (2, 16, 128), got {out.shape}"
    print("Correct! MLP projection works.")
else:
    print("TODO: Implement MLPProjection above")

## 6. Putting It All Together

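Let us prepare the CIFAR-10 dataset for our VQA task. Because the question is always the same ("what is this"), the task reduces to 10-way image classification conditioned on a fixed text prompt.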

In [None]:
# Prepare CIFAR-10 data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                         download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)

# Use subset for speed
train_subset = torch.utils.data.Subset(trainset, range(5000))
test_subset = torch.utils.data.Subset(testset, range(1000))

train_loader = torch.utils.data.DataLoader(train_subset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_subset, batch_size=64, shuffle=False)

# The "question" is always "what is this" encoded as token IDs
question = torch.tensor([4, 5, 6]).unsqueeze(0)  # "what is this"
print(f"Question tokens: {question}")
print(f"Training on {len(train_subset)} images, testing on {len(test_subset)}")

## 7. Training and Results

In [None]:
def train_vqa(model, train_loader, test_loader, question, epochs=15, lr=1e-3):
    """Train a VQA model (either LLaVA or Flamingo)."""
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    train_losses, test_accs = [], []

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            q = question.expand(imgs.size(0), -1).to(device)

            optimizer.zero_grad()
            logits = model(imgs, q)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        train_losses.append(epoch_loss / len(train_loader))

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                q = question.expand(imgs.size(0), -1).to(device)
                logits = model(imgs, q)
                _, pred = logits.max(1)
                correct += (pred == labels).sum().item()
                total += labels.size(0)

        test_accs.append(correct / total)
        if (epoch + 1) % 5 == 0:
            print(f"  Epoch {epoch+1:3d}: Loss={train_losses[-1]:.4f}, Acc={test_accs[-1]:.4f}")

    return train_losses, test_accs

print("Training LLaVA...")
llava_model = SimpleLLaVA().to(device)
llava_losses, llava_accs = train_vqa(llava_model, train_loader, test_loader, question)

print("\nTraining Flamingo...")
flamingo_model = SimpleFlamingo().to(device)
flamingo_losses, flamingo_accs = train_vqa(flamingo_model, train_loader, test_loader, question)

In [None]:
# Visualization checkpoint: training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(llava_losses, label='LLaVA', color='#2196F3', linewidth=2)
axes[0].plot(flamingo_losses, label='Flamingo', color='#FF5722', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Training Loss')
axes[0].set_title('Training Loss: LLaVA vs Flamingo')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].plot(llava_accs, label='LLaVA', color='#2196F3', linewidth=2)
axes[1].plot(flamingo_accs, label='Flamingo', color='#FF5722', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Test Accuracy')
axes[1].set_title('Test Accuracy: LLaVA vs Flamingo')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nFinal Accuracies:")
print(f"  LLaVA:    {llava_accs[-1]:.4f}")
print(f"  Flamingo: {flamingo_accs[-1]:.4f}")

In [None]:
# Inspect Flamingo gate values after training
print("\nFlamingo Gate Values (after training):")
flamingo_model.eval()
with torch.no_grad():
    # The gate is a learned scalar, so any input works for reading it out
    sample_img = torch.randn(1, 3, 32, 32).to(device)
    sample_q = question.to(device)
    _ = flamingo_model(sample_img, sample_q)

for i, gv in enumerate(flamingo_model.last_gate_values):
    bar = '#' * int(abs(gv) * 50)
    print(f"  Layer {i}: tanh(alpha) = {gv:+.4f} |{bar}|")
if any(abs(gv) > 1e-3 for gv in flamingo_model.last_gate_values):
    print("\nThe gates have opened! Visual info is now flowing through the model.")
else:
    print("\nThe gates are still near zero: little visual info is reaching the text stream.")
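The cell above reads the gates once, after training. To watch them move over time, one option is to record `tanh(alpha)` after every optimizer step. The following is a toy sketch: `ToyGatedLayer`, the random data, and the regression target are all made up for illustration and are not the notebook's VQA setup.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class ToyGatedLayer(nn.Module):
    """Bare gated cross-attention: h + tanh(alpha) * CrossAttn(h, v)."""
    def __init__(self, dim=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=2, batch_first=True)
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, h, v):
        out, _ = self.attn(h, v, v)
        return h + torch.tanh(self.gate) * out

layer = ToyGatedLayer()
opt = torch.optim.Adam(layer.parameters(), lr=1e-2)
gate_history = []

for step in range(50):
    h = torch.randn(16, 3, 8)
    v = torch.randn(16, 4, 8)
    # Made-up target that depends on the visual tokens, so using v can reduce the loss
    target = h + v.mean(dim=1, keepdim=True)
    loss = ((layer(h, v) - target) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    gate_history.append(torch.tanh(layer.gate).item())   # log after each step

print(f"gate after step 1:  {gate_history[0]:+.4f}")
print(f"gate after step 50: {gate_history[-1]:+.4f}")
```

In the real training loop the same idea would amount to appending each layer's `torch.tanh(layer.gate).item()` to a list once per epoch and plotting the result.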

## 8. Final Output

In [None]:
# Generate predictions on test images
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

fig, axes = plt.subplots(2, 5, figsize=(20, 8))

llava_model.eval()
flamingo_model.eval()

for idx in range(10):
    img, label = test_subset[idx]
    img_input = img.unsqueeze(0).to(device)
    q = question.to(device)

    with torch.no_grad():
        llava_pred = llava_model(img_input, q).argmax(1).item()
        flamingo_pred = flamingo_model(img_input, q).argmax(1).item()

    row, col = idx // 5, idx % 5
    axes[row, col].imshow(img.permute(1, 2, 0) * 0.5 + 0.5)  # Unnormalize
    axes[row, col].set_title(
        f'True: {class_names[label]}\n'
        f'LLaVA: {class_names[llava_pred]}\n'
        f'Flamingo: {class_names[flamingo_pred]}',
        fontsize=9
    )
    axes[row, col].axis('off')

plt.suptitle('LLaVA vs Flamingo Predictions on CIFAR-10', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("Congratulations! You have built both LLaVA and Flamingo from scratch!")

## 9. Reflection and Next Steps

### Reflection Questions
1. LLaVA's sequence length grows with the number of visual patches. What happens if we use a higher-resolution image (more patches)? How does this affect memory and compute?
2. Flamingo's gate starts at zero. Why is this important? What would happen if we initialized it to a large value?
3. The Perceiver Resampler compresses 16 visual tokens to 4 summary tokens. What information might be lost? For what tasks would this matter most?
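As a starting point for question 1, here is a rough back-of-the-envelope sketch that counts attention-score entries for one layer of each design. It is a deliberate simplification: it ignores the vision encoder's own self-attention (present in both models), counts a single resampler pass, and ignores heads and feed-forward cost. The function names are illustrative.

```python
def llava_attention_entries(img_size, patch_size, n_text):
    """Self-attention over the concatenated [visual ; text] sequence."""
    p = (img_size // patch_size) ** 2
    return (p + n_text) ** 2

def flamingo_attention_entries(img_size, patch_size, n_text, k):
    """Text self-attention + text-to-summary cross-attention + one resampler pass."""
    p = (img_size // patch_size) ** 2
    return n_text ** 2 + n_text * k + k * p

for size in (32, 64, 128):
    a = llava_attention_entries(size, patch_size=8, n_text=3)
    b = flamingo_attention_entries(size, patch_size=8, n_text=3, k=4)
    print(f"{size}x{size}: LLaVA-style {a:6d}   Flamingo-style {b:5d}")
```

Under these assumptions the LLaVA-style count grows quadratically in the number of patches, while the Flamingo-style count grows linearly because only the resampler sees all patches.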

### Optional Challenges
1. Replace the `SimplePatchEncoder` with a pretrained ResNet backbone. Compare accuracy.
2. Implement the Perceiver Resampler with multiple cross-attention layers (as in the real Flamingo paper).
3. Add a text generation head (instead of classification) and generate captions for CIFAR images.