In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Instruction Tuning Pipeline: Teaching a Multimodal Model to Follow Instructions

*Part 2 of the Vizuara series on Multimodal Instruction Tuning*
*Estimated time: 50 minutes*

## 1. Why Does This Matter?

A model that can project images into language space is interesting, but not yet useful. It can "see" the image but it cannot answer questions about it, describe it in detail, or reason about its contents.

Instruction tuning teaches the model to **follow instructions about images** -- "What color is the car?", "Describe this scene in detail", "What is unusual about this image?". This turns a simple image-to-text mapper into an interactive visual assistant.

**By the end of this notebook, you will have:**
- Implemented the two-stage training procedure (alignment + instruction tuning)
- Built and trained a model that can answer simple visual questions
- Visualized how the loss changes across training stages
- Understood why data quality matters more than architectural complexity

In [None]:
# Setup
!pip install -q torch torchvision matplotlib numpy

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import json
import random
from collections import defaultdict

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Building Intuition

Think about how a human interpreter is trained:

**Stage 1 -- Basic Vocabulary:**
The interpreter first learns to translate simple words and phrases. The Mandarin word for "apple" maps to "pomme" in French, and the phrase for "red car" maps to "voiture rouge." This is pure vocabulary alignment -- mapping concepts from one representation to another.

**Stage 2 -- Complex Conversations:**
Once the basic vocabulary is aligned, the interpreter practices translating complex negotiations, nuanced arguments, and multi-step reasoning. This requires not just word-for-word translation but understanding context, intent, and inference.

LLaVA's two-stage training follows the same pattern:

- **Stage 1 (Feature Alignment):** Only the projection layer trains. The model learns that "these visual features" correspond to "these caption words." Simple mapping.
- **Stage 2 (Instruction Tuning):** The projection layer AND the LLM both train. The model learns to follow complex instructions about images -- answering questions, providing descriptions, reasoning about visual content.
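
The two bullets above amount to a small lookup table of which components receive updates in each stage. The sketch below is plain Python for illustration only; `STAGE_TRAINABLE` and `trainable_components` are hypothetical names, and the vision encoder is listed as frozen in both stages because this notebook never trains it.

```python
# Which components are updated in each stage (hypothetical helper table).
STAGE_TRAINABLE = {
    1: {"vision_encoder": False, "projector": True, "llm": False},
    2: {"vision_encoder": False, "projector": True, "llm": True},
}

def trainable_components(stage):
    """Return the names of the components that are updated in `stage`."""
    return [name for name, on in STAGE_TRAINABLE[stage].items() if on]

for stage in (1, 2):
    print(f"Stage {stage}: {trainable_components(stage)}")
```

In PyTorch, the same table is typically applied by setting `requires_grad` on each component's parameters.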

### Think About This

Why do we freeze the LLM during Stage 1 but unfreeze it during Stage 2? What would happen if we unfroze everything from the start?

## 3. The Mathematics

### The Training Objective

Both stages use the same fundamental objective: **autoregressive next-token prediction**. Given visual tokens $H_v$ and a sequence of text tokens $x_1, x_2, ..., x_T$, we minimize the negative log-likelihood:

$$\mathcal{L} = -\sum_{t=1}^{T} \log p_\theta(x_t \mid x_{<t}, H_v)$$

**Computationally, this says:** for each text token in the sequence, compute the probability of that token given all previous tokens and the visual context. Take the negative log of that probability. Sum over all positions. We want this to be small, meaning the model assigns high probability to the correct next token.

Let us walk through a numerical example. Suppose our caption is "a red car" (3 tokens), and the model predicts:

- $p(\text{"a"} \mid H_v) = 0.7$ --> $-\log(0.7) = 0.357$
- $p(\text{"red"} \mid H_v, \text{"a"}) = 0.4$ --> $-\log(0.4) = 0.916$
- $p(\text{"car"} \mid H_v, \text{"a"}, \text{"red"}) = 0.6$ --> $-\log(0.6) = 0.511$

Total loss: $\mathcal{L} = 0.357 + 0.916 + 0.511 = 1.784$

This tells us the model is fairly confident about "a" and "car" but less sure about "red." Training raises the probability of "red" by adjusting whichever weights are trainable in the current stage: the projector alone in Stage 1, the projector and the LLM in Stage 2.
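
The arithmetic in this example can be checked with a few lines of plain Python. This is a minimal sketch that hard-codes the three probabilities from the bullets; `token_probs` is a name introduced here for illustration.

```python
import math

# Probability assigned to the correct token at each position,
# copied from the worked example.
token_probs = {"a": 0.7, "red": 0.4, "car": 0.6}

# Negative log-likelihood per token, then summed and averaged.
per_token_nll = {tok: -math.log(p) for tok, p in token_probs.items()}
total_nll = sum(per_token_nll.values())
mean_nll = total_nll / len(per_token_nll)

for tok, nll in per_token_nll.items():
    print(f"-log p({tok!r}) = {nll:.3f}")
print(f"sum  = {total_nll:.3f}")
print(f"mean = {mean_nll:.3f}")
```

The formula uses the sum over positions. `F.cross_entropy` averages by default, so it reports the mean unless `reduction='sum'` is passed.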

### What Changes Between Stages

In Stage 1, only the projector parameters $W_1, b_1, W_2, b_2$ are updated. Gradients are still backpropagated through the frozen LLM to reach the projector, but the LLM's own weights do not change.

In Stage 2, both the projector and LLM parameters update. This means the LLM itself adapts to the visual inputs, learning deeper multimodal reasoning patterns.
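
To see why a frozen LLM still shapes the projector's update, consider a scalar toy model written in plain Python. This is only a sketch under strong simplifying assumptions: one weight `w_proj` stands in for the projector, one weight `w_llm` for the LLM, and the loss is a squared error rather than cross-entropy. All names are hypothetical.

```python
# Scalar stand-ins: one "projector" weight and one frozen "LLM" weight.
w_proj, w_llm = 0.5, 2.0
x, target, lr = 1.0, 3.0, 0.1

def forward(w_proj, w_llm, x):
    h = w_proj * x        # projector output (the "visual token")
    return w_llm * h      # output of the "LLM"

def stage1_step(w_proj, w_llm):
    """One Stage 1 update: only w_proj changes."""
    y = forward(w_proj, w_llm, x)
    dloss_dy = 2.0 * (y - target)        # loss = (y - target) ** 2
    grad_w_proj = dloss_dy * w_llm * x   # chain rule passes through w_llm
    return w_proj - lr * grad_w_proj, w_llm  # w_llm is returned unchanged

losses = []
for _ in range(20):
    losses.append((forward(w_proj, w_llm, x) - target) ** 2)
    w_proj, w_llm = stage1_step(w_proj, w_llm)

print(f"w_proj = {w_proj:.4f}, w_llm = {w_llm:.4f}")
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.2e}")
```

The gradient for `w_proj` contains `w_llm` as a factor, so the frozen weight determines how the projector must move even though it is never updated itself.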

In [None]:
# Numerical demonstration of the loss
import torch.nn.functional as F

# Simulated model outputs (logits) for a 3-token caption
# Vocabulary: {0: "PAD", 1: "a", 2: "red", 3: "car", 4: "blue", 5: "the"}
logits = torch.tensor([
    [2.0, 3.5, 0.5, 0.3, 0.2, 1.0],   # Position 1: should predict "a" (idx 1)
    [0.5, 0.3, 2.8, 0.8, 0.5, 0.1],   # Position 2: should predict "red" (idx 2)
    [0.1, 0.2, 0.5, 3.2, 0.3, 0.4],   # Position 3: should predict "car" (idx 3)
])
targets = torch.tensor([1, 2, 3])  # "a", "red", "car"

# Compute probabilities
probs = F.softmax(logits, dim=-1)
print("Token probabilities:")
for i, (t, p) in enumerate(zip(["a", "red", "car"], probs)):
    token_prob = p[targets[i]].item()
    print(f"  p('{t}') = {token_prob:.3f}  -->  -log(p) = {-np.log(token_prob):.3f}")

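# Compute loss. F.cross_entropy averages over positions by default,
# so use reduction='sum' to match the summed objective in the formula above.
loss = F.cross_entropy(logits, targets, reduction='sum')
print(f"\nTotal loss (cross-entropy): {loss.item():.4f}")
manual_loss = -torch.log(probs[torch.arange(3), targets]).sum()
print(f"Manual calculation:         {manual_loss.item():.4f}")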

## 4. Let's Build It -- Component by Component

### 4.1 Synthetic Visual Question Answering Dataset

For training, we need image-text pairs. Let us create a simple but illustrative dataset: images of colored shapes with questions and answers.

In [None]:
class VisualQADataset:
    """Synthetic VQA dataset with colored shapes.

    Image types:
    - Colored circles/squares in different positions
    Questions:
    - "What color?" -> "red" / "blue" / "green"
    - "What shape?" -> "circle" / "square"
    - "Where is it?" -> "top-left" / "top-right" / "bottom-left" / "bottom-right" / "center"
    """

    COLORS = {0: "red", 1: "blue", 2: "green"}
    SHAPES = {0: "circle", 1: "square"}
    POSITIONS = {0: "top-left", 1: "top-right", 2: "bottom-left", 3: "bottom-right", 4: "center"}

    # Token vocabulary
    VOCAB = {
        "<pad>": 0, "<start>": 1, "<end>": 2,
        "what": 3, "color": 4, "shape": 5, "where": 6, "is": 7, "it": 8, "?": 9,
        "red": 10, "blue": 11, "green": 12,
        "circle": 13, "square": 14,
        "top-left": 15, "top-right": 16, "bottom-left": 17, "bottom-right": 18, "center": 19,
    }
    INV_VOCAB = {v: k for k, v in VOCAB.items()}

    def __init__(self, n_samples: int = 500, image_size: int = 64):
        self.image_size = image_size
        self.data = []

        for _ in range(n_samples):
            color_idx = random.randint(0, 2)
            shape_idx = random.randint(0, 1)
            pos_idx = random.randint(0, 4)
            q_type = random.choice(["color", "shape", "position"])

            img = self._create_image(color_idx, shape_idx, pos_idx)

            if q_type == "color":
                q_tokens = [self.VOCAB["what"], self.VOCAB["color"], self.VOCAB["?"]]
                a_tokens = [self.VOCAB[self.COLORS[color_idx]]]
            elif q_type == "shape":
                q_tokens = [self.VOCAB["what"], self.VOCAB["shape"], self.VOCAB["?"]]
                a_tokens = [self.VOCAB[self.SHAPES[shape_idx]]]
            else:
                q_tokens = [self.VOCAB["where"], self.VOCAB["is"], self.VOCAB["it"], self.VOCAB["?"]]
                a_tokens = [self.VOCAB[self.POSITIONS[pos_idx]]]

            # Also create a simple caption for Stage 1
            caption_tokens = [
                self.VOCAB[self.COLORS[color_idx]],
                self.VOCAB[self.SHAPES[shape_idx]],
            ]

            self.data.append({
                "image": img,
                "question_tokens": q_tokens,
                "answer_tokens": a_tokens,
                "caption_tokens": caption_tokens,
                "color": self.COLORS[color_idx],
                "shape": self.SHAPES[shape_idx],
                "position": self.POSITIONS[pos_idx],
            })

    def _create_image(self, color_idx, shape_idx, pos_idx):
        img = torch.zeros(3, self.image_size, self.image_size)

        # Position offsets
        positions = {
            0: (self.image_size//4, self.image_size//4),
            1: (self.image_size//4, 3*self.image_size//4),
            2: (3*self.image_size//4, self.image_size//4),
            3: (3*self.image_size//4, 3*self.image_size//4),
            4: (self.image_size//2, self.image_size//2),
        }
        cy, cx = positions[pos_idx]
        r = self.image_size // 6

        # Draw shape
        y, x = torch.meshgrid(torch.arange(self.image_size), torch.arange(self.image_size), indexing='ij')

        if shape_idx == 0:  # Circle
            mask = ((x - cx)**2 + (y - cy)**2) < r**2
        else:  # Square
            mask = (abs(x - cx) < r) & (abs(y - cy) < r)

        # Set color
        color_channels = {0: [1.0, 0.1, 0.1], 1: [0.1, 0.1, 1.0], 2: [0.1, 0.8, 0.1]}
        for c, val in enumerate(color_channels[color_idx]):
            img[c][mask] = val

        # Add noise
        img += torch.randn_like(img) * 0.05
        return img.clamp(0, 1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# Create datasets
train_dataset = VisualQADataset(n_samples=600)
test_dataset = VisualQADataset(n_samples=120)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples:     {len(test_dataset)}")
print(f"Vocabulary size:  {len(VisualQADataset.VOCAB)}")

In [None]:
# Visualize some examples
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.flatten()

for i in range(10):
    sample = train_dataset[i]
    img = sample["image"].permute(1, 2, 0).numpy()
    axes[i].imshow(img)

    q = " ".join([VisualQADataset.INV_VOCAB[t] for t in sample["question_tokens"]])
    a = " ".join([VisualQADataset.INV_VOCAB[t] for t in sample["answer_tokens"]])
    axes[i].set_title(f"Q: {q}\nA: {a}", fontsize=8)
    axes[i].axis('off')

plt.suptitle("Sample Training Data: Colored Shapes with Questions", fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

### 4.2 The Multimodal Model with Frozen/Unfrozen Components
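
The model in this section places the visual tokens first and the text tokens after them in one combined sequence, so every text position is offset by the number of image patches. The sketch below works out that bookkeeping in plain Python, assuming the default `image_size=64` and `patch_size=8` of `InstructionTunableModel`; `text_slice` is a hypothetical helper.

```python
image_size, patch_size = 64, 8   # defaults of InstructionTunableModel
num_patches = (image_size // patch_size) ** 2   # one visual token per patch

def text_slice(num_text_tokens, num_patches=num_patches):
    """Return (start, stop) indices of the text tokens in the combined sequence."""
    start = num_patches
    return start, start + num_text_tokens

text_tokens = ["<start>", "what", "color", "?"]
start, stop = text_slice(len(text_tokens))

print(f"visual tokens occupy positions 0..{num_patches - 1}")
print(f"text tokens occupy positions {start}..{stop - 1}")
print(f"last question token sits at position {stop - 1}")
```

The logits at the last question token are the ones that predict the first answer token, which is why the loss and the evaluation code slice the logits starting at `num_patches`.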

In [None]:
class InstructionTunableModel(nn.Module):
    """Multimodal model with configurable frozen/trainable components."""

    def __init__(self, image_size=64, patch_size=8, vision_dim=128,
                 llm_dim=64, vocab_size=20, num_heads=2, num_layers=2):
        super().__init__()
        self.num_patches = (image_size // patch_size) ** 2

        # Vision encoder (kept frozen, as in LLaVA). Here a randomly initialized
        # patch-embedding convolution stands in for a pretrained encoder.
        self.vision_encoder = nn.Sequential(
            nn.Conv2d(3, vision_dim, kernel_size=patch_size, stride=patch_size),
            nn.Flatten(2),
        )

        # Projector (trainable in both stages)
        self.projector = nn.Sequential(
            nn.Linear(vision_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

        # LLM components (frozen in Stage 1, trainable in Stage 2)
        self.text_embedding = nn.Embedding(vocab_size, llm_dim)
        self.position_embedding = nn.Embedding(512, llm_dim)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=llm_dim, nhead=num_heads, dim_feedforward=llm_dim * 4,
            batch_first=True, dropout=0.1
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.output_head = nn.Linear(llm_dim, vocab_size)

    def set_stage(self, stage: int):
        """Configure which components are trainable.

        Stage 1: Only projector is trainable
        Stage 2: Projector + LLM are trainable
        """
        # Vision encoder is always frozen
        for p in self.vision_encoder.parameters():
            p.requires_grad = False

        # Projector is always trainable
        for p in self.projector.parameters():
            p.requires_grad = True

        # LLM components
        llm_trainable = (stage == 2)
        for module in [self.text_embedding, self.position_embedding,
                       self.decoder, self.output_head]:
            for p in module.parameters():
                p.requires_grad = llm_trainable

        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        print(f"Stage {stage}: {trainable:,} / {total:,} parameters trainable ({trainable/total:.1%})")

    def forward(self, images, token_ids):
        B = images.shape[0]

        # Vision path
        with torch.no_grad():
            vis = self.vision_encoder(images)       # (B, vision_dim, N)
            vis = vis.transpose(1, 2)                # (B, N, vision_dim)
        vis_tokens = self.projector(vis)              # (B, N, llm_dim)

        # Text path
        text_tokens = self.text_embedding(token_ids)  # (B, M, llm_dim)

        # Combine
        combined = torch.cat([vis_tokens, text_tokens], dim=1)
        seq_len = combined.shape[1]

        positions = torch.arange(seq_len, device=combined.device).unsqueeze(0)
        combined = combined + self.position_embedding(positions)

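        # Causal attention: each position attends only to itself and earlier positions
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(combined.device)
        # nn.TransformerDecoder requires a `memory` input for cross-attention;
        # a single all-zero vector is a dummy so the stack acts like a decoder-only LLM.
        memory = torch.zeros(B, 1, combined.shape[-1], device=combined.device)
        output = self.decoder(combined, memory, tgt_mask=mask)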

        logits = self.output_head(output)
        return logits

## 5. Your Turn

### TODO: Implement Stage 1 Training (Feature Alignment)

Implement the training loop for Stage 1, where only the projector is trainable and the data is image-caption pairs.
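
Teacher forcing means the input and the target are the same caption shifted by one position. The sketch below shows that shift for a single caption using plain lists. It assumes the ids `1` and `2` for `<start>` and `<end>` from `VisualQADataset.VOCAB`; `build_caption_pair` is a hypothetical helper, not part of the notebook.

```python
START, END = 1, 2   # ids of <start> and <end> in VisualQADataset.VOCAB

def build_caption_pair(caption_tokens):
    """Return (input_ids, target_ids) for one caption under teacher forcing."""
    input_ids = [START] + list(caption_tokens)    # what the model reads
    target_ids = list(caption_tokens) + [END]     # what it should predict next
    return input_ids, target_ids

caption = [10, 13]   # "red", "circle"
input_ids, target_ids = build_caption_pair(caption)

for inp, tgt in zip(input_ids, target_ids):
    print(f"input {inp:2d} -> target {tgt:2d}")
```

At each position the model sees the tokens up to and including the input token and is scored on the token that follows it.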

In [None]:
def train_stage1(model, dataset, num_epochs=30, batch_size=32, lr=1e-3):
    """Stage 1: Feature alignment with captioning data.

    Only the projector should be trainable. The training data is
    image-caption pairs (not instruction-following data).

    Args:
        model: InstructionTunableModel
        dataset: VisualQADataset (we use the caption_tokens)
        num_epochs: number of training epochs
        batch_size: batch size
        lr: learning rate

    Returns:
        list of per-epoch losses
    """
    # ============ TODO ============
    # Step 1: Call model.set_stage(1) to freeze LLM
    # Step 2: Create optimizer with only trainable parameters
    #         Hint: filter(lambda p: p.requires_grad, model.parameters())
    # Step 3: Training loop:
    #   a) Sample a batch of images and captions
    #   b) Input tokens = [<start>] + caption_tokens (teacher forcing)
    #   c) Target tokens = caption_tokens + [<end>]
    #   d) Forward pass through model
    #   e) Compute cross-entropy loss on caption positions only
    #      (ignore visual token positions in the loss)
    #   f) Backprop and optimize
    # ==============================

    losses = []  # YOUR CODE HERE

    return losses

In [None]:
# Verification - check your implementation produces decreasing loss
model = InstructionTunableModel().to(device)
stage1_losses = train_stage1(model, train_dataset, num_epochs=30)

assert len(stage1_losses) == 30, f"Expected 30 loss values, got {len(stage1_losses)}"
assert stage1_losses[-1] < stage1_losses[0], "Loss should decrease during training"
print(f"\nStage 1 complete!")
print(f"  Initial loss: {stage1_losses[0]:.4f}")
print(f"  Final loss:   {stage1_losses[-1]:.4f}")
print("All checks passed!")

### TODO: Implement Stage 2 Training (Instruction Tuning)

Now implement Stage 2, where both the projector and LLM are trainable and the data is question-answer pairs.
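
In Stage 2 the loss should cover only the answer, so question positions get a sentinel target that the loss skips. The sketch below builds such targets with plain lists and averages the loss over the unmasked positions by hand. It assumes the sentinel `-100` (the default `ignore_index` of `F.cross_entropy`) and the `<start>`/`<end>` ids from `VisualQADataset.VOCAB`; both helper functions are hypothetical.

```python
import math

IGNORE = -100        # positions with this target contribute no loss
START, END = 1, 2    # ids of <start> and <end> in VisualQADataset.VOCAB

def build_qa_pair(question_tokens, answer_tokens):
    """Return (input_ids, target_ids) with the loss restricted to the answer."""
    input_ids = [START] + list(question_tokens) + list(answer_tokens)
    target_ids = [IGNORE] * len(input_ids)
    first = len(question_tokens)   # index of the last question token
    for k, tok in enumerate(list(answer_tokens) + [END]):
        target_ids[first + k] = tok
    return input_ids, target_ids

def masked_mean_nll(correct_token_probs, target_ids):
    """Average -log p over positions whose target is not IGNORE."""
    kept = [-math.log(p) for p, t in zip(correct_token_probs, target_ids)
            if t != IGNORE]
    return sum(kept) / len(kept)

question, answer = [3, 4, 9], [10]   # "what color ?" -> "red"
input_ids, target_ids = build_qa_pair(question, answer)
print("input_ids :", input_ids)
print("target_ids:", target_ids)

# Made-up probabilities of the correct token at each position.
probs = [0.1, 0.1, 0.1, 0.5, 0.25]
print(f"masked loss: {masked_mean_nll(probs, target_ids):.4f}")
```

Only the last two positions count here: the one that predicts the answer and the one that predicts `<end>`. The low probabilities on the question positions do not affect the loss.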

In [None]:
def train_stage2(model, dataset, num_epochs=50, batch_size=32, lr=5e-4):
    """Stage 2: Instruction tuning with QA data.

    Both projector and LLM should be trainable. The training data is
    image + question -> answer pairs.

    Args:
        model: InstructionTunableModel (already Stage 1 trained)
        dataset: VisualQADataset (we use question_tokens and answer_tokens)
        num_epochs: number of training epochs
        batch_size: batch size
        lr: learning rate

    Returns:
        list of per-epoch losses
    """
    # ============ TODO ============
    # Step 1: Call model.set_stage(2) to unfreeze LLM
    # Step 2: Create optimizer (note: lower learning rate!)
    # Step 3: Training loop:
    #   a) Sample a batch of images, questions, and answers
    #   b) Input tokens = [<start>] + question_tokens + answer_tokens
    #   c) Target = answer_tokens + [<end>] (only compute loss on answer positions)
    #   d) Forward pass, compute loss on answer positions only
    #   e) Backprop and optimize
    # ==============================

    losses = []  # YOUR CODE HERE

    return losses

In [None]:
# Verification
stage2_losses = train_stage2(model, train_dataset, num_epochs=50)

assert len(stage2_losses) == 50, f"Expected 50 loss values, got {len(stage2_losses)}"
assert stage2_losses[-1] < stage2_losses[0], "Loss should decrease during training"
print(f"\nStage 2 complete!")
print(f"  Initial loss: {stage2_losses[0]:.4f}")
print(f"  Final loss:   {stage2_losses[-1]:.4f}")
print("All checks passed!")

## 6. Putting It All Together

In [None]:
# Reference implementation of both stages for those who want to see the full code

def train_both_stages(model, dataset, stage1_epochs=30, stage2_epochs=50):
    """Complete two-stage training pipeline."""

    all_losses = {"stage1": [], "stage2": []}

    # ---- Stage 1: Feature Alignment ----
    model.set_stage(1)
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3
    )

    for epoch in range(stage1_epochs):
        model.train()
        epoch_loss = 0
        n_batches = 0

        indices = list(range(len(dataset)))
        random.shuffle(indices)

        for i in range(0, len(indices), 32):
            batch_idx = indices[i:i+32]
            batch = [dataset[j] for j in batch_idx]

            images = torch.stack([b["image"] for b in batch]).to(device)
            captions = [b["caption_tokens"] for b in batch]

            # Build input/target sequences
            max_len = max(len(c) for c in captions) + 1  # +1 for start token
            input_ids = torch.zeros(len(batch), max_len, dtype=torch.long, device=device)
            target_ids = torch.full((len(batch), max_len), -100, dtype=torch.long, device=device)

            for j, cap in enumerate(captions):
                input_ids[j, 0] = 1  # <start>
                for k, t in enumerate(cap):
                    input_ids[j, k+1] = t
                    target_ids[j, k] = t
                target_ids[j, len(cap)] = 2  # <end>

            logits = model(images, input_ids)
            # Only compute loss on text positions (after visual tokens)
            num_patches = model.num_patches
            text_logits = logits[:, num_patches:, :]
            loss = F.cross_entropy(text_logits.reshape(-1, text_logits.shape[-1]),
                                   target_ids.reshape(-1), ignore_index=-100)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            n_batches += 1

        all_losses["stage1"].append(epoch_loss / n_batches)
        if (epoch + 1) % 10 == 0:
            print(f"  Stage 1 Epoch {epoch+1}: loss = {all_losses['stage1'][-1]:.4f}")

    # ---- Stage 2: Instruction Tuning ----
    model.set_stage(2)
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=5e-4
    )

    for epoch in range(stage2_epochs):
        model.train()
        epoch_loss = 0
        n_batches = 0

        indices = list(range(len(dataset)))
        random.shuffle(indices)

        for i in range(0, len(indices), 32):
            batch_idx = indices[i:i+32]
            batch = [dataset[j] for j in batch_idx]

            images = torch.stack([b["image"] for b in batch]).to(device)

            # Build input: [<start>] + question + answer
            all_tokens = []
            for b in batch:
                all_tokens.append([1] + b["question_tokens"] + b["answer_tokens"])

            max_len = max(len(t) for t in all_tokens)
            input_ids = torch.zeros(len(batch), max_len, dtype=torch.long, device=device)
            target_ids = torch.full((len(batch), max_len), -100, dtype=torch.long, device=device)

            for j, tokens in enumerate(all_tokens):
                for k, t in enumerate(tokens):
                    input_ids[j, k] = t
                # Only compute loss on answer tokens
                q_len = len(batch[j]["question_tokens"]) + 1  # +1 for start
                a_tokens = batch[j]["answer_tokens"]
                for k, t in enumerate(a_tokens):
                    target_ids[j, q_len - 1 + k] = t
                target_ids[j, q_len - 1 + len(a_tokens)] = 2  # <end>

            logits = model(images, input_ids)
            num_patches = model.num_patches
            text_logits = logits[:, num_patches:, :]
            loss = F.cross_entropy(text_logits.reshape(-1, text_logits.shape[-1]),
                                   target_ids.reshape(-1), ignore_index=-100)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            n_batches += 1

        all_losses["stage2"].append(epoch_loss / n_batches)
        if (epoch + 1) % 10 == 0:
            print(f"  Stage 2 Epoch {epoch+1}: loss = {all_losses['stage2'][-1]:.4f}")

    return all_losses

# Train the model
model = InstructionTunableModel().to(device)
losses = train_both_stages(model, train_dataset)

## 7. Training and Results

In [None]:
# Visualization: Two-stage training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Stage 1
axes[0].plot(losses["stage1"], color='steelblue', linewidth=2)
axes[0].set_title("Stage 1: Feature Alignment\n(Only projector trains)", fontsize=12)
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].grid(True, alpha=0.3)
axes[0].axhline(y=losses["stage1"][-1], color='steelblue', linestyle='--', alpha=0.5)

# Stage 2
axes[1].plot(losses["stage2"], color='coral', linewidth=2)
axes[1].set_title("Stage 2: Instruction Tuning\n(Projector + LLM train)", fontsize=12)
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Loss")
axes[1].grid(True, alpha=0.3)
axes[1].axhline(y=losses["stage2"][-1], color='coral', linestyle='--', alpha=0.5)

plt.suptitle("Two-Stage Training: the same two-stage recipe LLaVA uses",
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print(f"Stage 1 final loss: {losses['stage1'][-1]:.4f}")
print(f"Stage 2 final loss: {losses['stage2'][-1]:.4f}")

In [None]:
# Evaluate on test set
model.eval()
correct = 0
total = 0
results_by_type = defaultdict(lambda: {"correct": 0, "total": 0})

with torch.no_grad():
    for sample in test_dataset:
        img = sample["image"].unsqueeze(0).to(device)
        q_tokens = sample["question_tokens"]
        a_tokens = sample["answer_tokens"]

        # Input: [<start>] + question
        input_ids = torch.tensor([[1] + q_tokens], device=device)
        logits = model(img, input_ids)

        # The logits at the last question token ("?") predict the first answer token
        num_patches = model.num_patches
        answer_pos = num_patches + len(q_tokens)
        pred_id = logits[0, answer_pos, :].argmax().item()
        true_id = a_tokens[0]

        is_correct = (pred_id == true_id)
        correct += int(is_correct)
        total += 1

        # Track by question type (look up token ids in VOCAB instead of hard-coding them)
        if VisualQADataset.VOCAB["color"] in q_tokens:
            q_type = "color"
        elif VisualQADataset.VOCAB["shape"] in q_tokens:
            q_type = "shape"
        else:
            q_type = "position"
        results_by_type[q_type]["correct"] += int(is_correct)
        results_by_type[q_type]["total"] += 1

print(f"Overall test accuracy: {correct/total:.1%}")
print(f"\nBreakdown by question type:")
for q_type, stats in sorted(results_by_type.items()):
    acc = stats["correct"] / stats["total"]
    print(f"  {q_type:10s}: {acc:.1%} ({stats['correct']}/{stats['total']})")

## 8. Final Output

In [None]:
# Interactive demo: show images and model's answers
fig, axes = plt.subplots(3, 5, figsize=(16, 10))
axes = axes.flatten()

model.eval()
inv_vocab = VisualQADataset.INV_VOCAB

demo_samples = random.sample(range(len(test_dataset)), 15)

with torch.no_grad():
    for i, idx in enumerate(demo_samples):
        sample = test_dataset[idx]
        img = sample["image"].unsqueeze(0).to(device)
        q_tokens = sample["question_tokens"]
        a_tokens = sample["answer_tokens"]

        input_ids = torch.tensor([[1] + q_tokens], device=device)
        logits = model(img, input_ids)

        num_patches = model.num_patches
        answer_pos = num_patches + len(q_tokens)
        pred_id = logits[0, answer_pos, :].argmax().item()
        true_id = a_tokens[0]

        question = " ".join([inv_vocab[t] for t in q_tokens])
        pred_answer = inv_vocab.get(pred_id, "?")
        true_answer = inv_vocab[true_id]
        is_correct = pred_id == true_id

        display_img = sample["image"].permute(1, 2, 0).numpy()
        axes[i].imshow(display_img)
        axes[i].set_title(
            f"Q: {question}\nPred: {pred_answer} | True: {true_answer}",
            fontsize=8,
            color='green' if is_correct else 'red',
            fontweight='bold'
        )
        axes[i].axis('off')

plt.suptitle("Visual QA: Model Predictions After Two-Stage Training",
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("Congratulations! You have trained a toy multimodal model with the same two-stage")
print("recipe used by LLaVA -- feature alignment followed by instruction tuning!")

## 9. Reflection and Next Steps

### Reflection Questions

1. In our experiment, Stage 1 used caption data and Stage 2 used QA data. What would happen if you skipped Stage 1 entirely and went straight to Stage 2? Try it and compare the results.

2. LLaVA uses 595K caption pairs for Stage 1 but only 158K instruction-following examples for Stage 2. Why might the ratio be weighted toward simpler data? What happens if you use more Stage 2 data?

3. During Stage 2, the LLM is unfrozen. This means it might "forget" some of its pretraining knowledge (catastrophic forgetting). How does the lower learning rate in Stage 2 help mitigate this?

### Optional Challenges

1. **Skip Stage 1:** Modify the code to train directly with Stage 2. Compare final accuracy with the two-stage approach. Which works better?

2. **Data augmentation:** Add color jitter and random rotations to the training images. Does the model become more robust?

3. **Multi-turn conversations:** Extend the dataset to include follow-up questions (e.g., "What color?" -> "red" -> "And what shape?" -> "circle"). Modify the training to handle multi-turn dialogues.