# Exercise 4: Transformers on Images + GLU-MLP Ablations (ViT × GLU Variants)

## In this exercise you will combine two influential ideas:

Vision Transformers (ViT), from “An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale” (Dosovitskiy et al., 2020), https://arxiv.org/pdf/2010.11929:
ViT shows that you can treat an image as a sequence of tokens by splitting it into non-overlapping patches (e.g. 16×16 in the paper), embedding each patch into a vector, adding positional information, and then applying standard Transformer blocks for classification.

Gated MLPs (GLU variants), from “GLU Variants Improve Transformer” (Shazeer, 2020), https://arxiv.org/pdf/2002.05202:
Shazeer proposes replacing the standard Transformer feed-forward layer (FFN/MLP) with gated linear unit (GLU) variants such as GEGLU and SwiGLU. In the paper's experiments these variants reach better perplexity and downstream scores than the standard FFN under comparable compute and parameter budgets.

## What you will do

You will implement a tiny ViT-style classifier for MNIST, then run a controlled ablation where you replace the MLP inside each Transformer block:

Baseline FFN (GELU):
Linear(d_model → d_ff) → GELU → Linear(d_ff → d_model)

GLU-family MLPs (choose at least two and justify):

GEGLU, SwiGLU, other activation functions

Your goal is to evaluate whether these GLU variants change:

- convergence speed (loss vs steps),

- final test accuracy,

- and/or stability across runs.
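One way to make "stability across runs" measurable is to repeat each variant with several seeds and report the mean and standard deviation of the final test accuracy. The helper below is a minimal sketch of that idea. The name `summarize_runs` and the input layout (a dict mapping a variant name to one accuracy per seed) are assumptions made for illustration and are not part of the provided code.

```python
from statistics import mean, stdev


def summarize_runs(accs_by_variant: dict[str, list[float]]) -> dict[str, tuple[float, float]]:
    """Return (mean, sample standard deviation) of accuracy per variant.

    Hypothetical helper: expects one accuracy value per seed for each variant.
    """
    summary = {}
    for variant, accs in accs_by_variant.items():
        # stdev needs at least two samples, so report 0.0 for a single run.
        spread = stdev(accs) if len(accs) > 1 else 0.0
        summary[variant] = (mean(accs), spread)
    return summary
```

With only one seed per variant the spread is reported as 0.0, which says nothing about stability; three or more seeds per variant give a more meaningful comparison.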

## Key ViT concepts you will implement

To convert MNIST images into Transformer tokens, you will:

- Patchify each 28×28 image into non-overlapping P×P patches.
  If P=4, you get a 7×7 patch grid → 49 tokens per image.

- Embed patches with a linear layer: patch vectors → d_model.

- Add positional embeddings so the model knows where each patch came from.

- Apply n_layers Transformer encoder blocks.

- Pool token features (e.g., mean pooling) and project to 10 classes.
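The patch size fixes both the sequence length and the size of each token before embedding. The sketch below works through that arithmetic for a few patch sizes that divide 28. The helper name `token_grid` is an assumption introduced here, and the "attention pairs" figure is simply the number of token-to-token comparisons in one self-attention map.

```python
def token_grid(image_size: int, patch_size: int, channels: int = 1) -> tuple[int, int]:
    """Return (num_tokens, patch_dim) for square images and square patches."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    grid = image_size // patch_size
    return grid * grid, channels * patch_size * patch_size


for p in (2, 4, 7, 14):
    n, dim = token_grid(28, p)
    # Self-attention compares every token with every other token: n * n pairs.
    print(f"P={p:>2}: {n:>3} tokens, patch_dim={dim:>3}, attention pairs={n * n}")
```

Smaller patches give longer sequences with less information per token, so the attention cost grows quickly as P shrinks.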

## Key GLU concept you will implement

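GLU-style MLPs replace the standard FFN with a gating mechanism:
compute two projections a and b of the input, apply a variant-dependent nonlinearity to a, multiply elementwise (act(a) * b), and project the result back to d_model.
To keep the comparison fair, use the 2/3 width rule from Shazeer: shrink the gated hidden width to about 2/3 of d_ff, so that the three weight matrices of the GLU layer hold roughly as many parameters as the two matrices of the baseline FFN.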
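For reference, the two gated layers named above can be written as in Shazeer (2020), with biases omitted as in the paper, $\otimes$ denoting elementwise multiplication, and $\mathrm{Swish}_1(z) = z\,\sigma(z)$ (the same function as SiLU):

$$
\begin{aligned}
\mathrm{FFN}_{\mathrm{GEGLU}}(x, W, V, W_2) &= \big(\mathrm{GELU}(xW) \otimes xV\big)\, W_2 \\
\mathrm{FFN}_{\mathrm{SwiGLU}}(x, W, V, W_2) &= \big(\mathrm{Swish}_1(xW) \otimes xV\big)\, W_2
\end{aligned}
$$

The width rule follows from counting weights. Writing $d_h$ for the gated hidden width (a symbol introduced here for illustration), the baseline FFN has $2\, d_{\text{model}}\, d_{\text{ff}}$ weights and the gated layer has $3\, d_{\text{model}}\, d_h$, so matching them gives:

$$
3\, d_{\text{model}}\, d_h = 2\, d_{\text{model}}\, d_{\text{ff}} \quad\Longrightarrow\quad d_h = \tfrac{2}{3}\, d_{\text{ff}}
$$

For example, $d_{\text{ff}} = 256$ gives $d_h \approx 170.7$, which has to be rounded to an integer such as 170.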

## What we provide vs what you implement

### We provide:

- MNIST loading + dataloaders

- a minimal training loop structure (AdamW)

- a suggested small model configuration that runs on CPU

### You implement:

- patch tokenization (patchify)

- patch embedding + positional embedding strategy

- a pre-LN Transformer encoder block using nn.MultiheadAttention

- at least two GLU MLP variants + one FFN baseline

- metric logging sufficient to support your conclusion

## Deliverables

Run at least 3 variants (baseline + the activation functions you choose for GLU) and report:

- final and best test accuracy

- number of trainable parameters

- a plot or printed summary of loss/accuracy over epochs

- a short discussion of your results
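If losses are logged once per optimizer step, a printed per-epoch summary needs them averaged into one value per epoch. The following is a minimal sketch under that assumption. `epoch_means` is a hypothetical helper name, and `step_losses` stands for a flat list of per-step training losses.

```python
def epoch_means(step_losses: list[float], steps_per_epoch: int) -> list[float]:
    """Average a flat list of per-step losses into one value per epoch."""
    if steps_per_epoch <= 0:
        raise ValueError("steps_per_epoch must be positive")
    means = []
    for start in range(0, len(step_losses), steps_per_epoch):
        # The final chunk may be shorter if logging stopped mid-epoch.
        chunk = step_losses[start:start + steps_per_epoch]
        means.append(sum(chunk) / len(chunk))
    return means
```

With a PyTorch `DataLoader`, `len(train_loader)` gives the number of batches per epoch and can be passed as `steps_per_epoch`.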

In [1]:
from __future__ import annotations

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [10]:
def patchify(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Convert images to patch tokens."""
    B, C, H, W = x.shape
    assert H % patch_size == 0 and W % patch_size == 0, "Image dimensions must be divisible by patch size."
    x = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    x = x.contiguous().view(B, C, -1, patch_size * patch_size)
    x = x.permute(0, 2, 1, 3).contiguous().view(B, -1, C * patch_size * patch_size)
    return x

# Sanity check: patchify should return (batch_size, num_patches, patch_dim)
if __name__ == "__main__":
    # Example usage
    batch_size = 1
    channels = 1
    height = 28
    width = 28
    patch_size = 4

    # Create a random tensor simulating a batch of images
    images = torch.randn(batch_size, channels, height, width)

    # Patchify the images
    patches = patchify(images, patch_size)

    print(f"Original shape: {images.shape}")
    print(f"Patchified shape: {patches.shape}")
    # Expected output shape: (batch_size, num_patches, patch_dim)
    expected_num_patches = (height // patch_size) * (width // patch_size)
    expected_patch_dim = channels * patch_size * patch_size
    assert patches.shape == (batch_size, expected_num_patches, expected_patch_dim), "Patchified shape is incorrect."


Original shape: torch.Size([1, 1, 28, 28])
Patchified shape: torch.Size([1, 49, 16])
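The shape check above confirms the number of tokens but not their order. The pure-Python sketch below spells out the ordering for a single-channel image: patches are visited row by row across the grid, and pixels inside each patch are flattened row by row, which is the same ordering the unfold-based `patchify` produces for one channel. `patchify_list` is a hypothetical helper written only for illustration.

```python
def patchify_list(image: list[list[int]], patch_size: int) -> list[list[int]]:
    """Pure-Python patchify for one single-channel image (illustration only)."""
    height, width = len(image), len(image[0])
    assert height % patch_size == 0 and width % patch_size == 0
    tokens = []
    for top in range(0, height, patch_size):        # patch rows, top to bottom
        for left in range(0, width, patch_size):    # patch columns, left to right
            tokens.append([
                image[r][c]
                for r in range(top, top + patch_size)
                for c in range(left, left + patch_size)
            ])
    return tokens


# 4x4 image whose pixel values are their own flat indices 0..15.
image = [[4 * r + c for c in range(4)] for r in range(4)]
tokens = patchify_list(image, patch_size=2)
print(tokens)
# [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
```

Using index-valued pixels like this makes it easy to see at a glance whether a tensor implementation mixes up rows and columns.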


In [11]:
# TODO: Add positional encoding as done in the ViT paper and patch projection
class PatchEmbed(nn.Module):
    def __init__(self, patch_dim: int, d_model: int):
        super().__init__()
        self.proj = nn.Linear(patch_dim, d_model)


    def forward(self, x_patches: torch.Tensor) -> torch.Tensor:
        return self.proj(x_patches)


class PositionalEmbedding(nn.Module):
    def __init__(self, num_tokens: int, d_model: int):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.empty(1, num_tokens, d_model))
        nn.init.xavier_uniform_(self.pos_embedding)


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pos_embedding



In [16]:
# TODO: Define the variants you want to compare against each other from the GLU paper. Justify your choice.
class FeedForward(nn.Module):
    """
    Standard Transformer FFN:
      x -> Linear(d_model->d_ff) -> GELU -> Dropout -> Linear(d_ff->d_model) -> Dropout
    """
    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        x = self.dropout2(x)
        return x


class GLUFeedForward(nn.Module):
    """GLU-family FFN"""
    def __init__(self, d_model: int, d_ff_gated: int, dropout: float, variant: str = 'swiglu'):
        super().__init__()
        self.variant = variant.lower()

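        # Project the input with both matrices (W for the gate, V for the linear
        # branch) in one fused layer. This is equivalent to two separate linear
        # layers but more efficient. Note that nn.Linear adds a bias by default,
        # whereas the paper's formulation omits bias terms.
        # Weight shape: [d_model, 2 * d_ff_gated]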
        self.w_gate_linear = nn.Linear(d_model, d_ff_gated * 2)

        # The output projection layer (W2)
        self.w_out = nn.Linear(d_ff_gated, d_model)

        self.dropout = nn.Dropout(dropout)

        if self.variant not in ['geglu', 'swiglu']:
            raise ValueError("Variant must be 'geglu' or 'swiglu'")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1. Project to 2 * d_ff_gated
        projected = self.w_gate_linear(x)

        # 2. Split into two halves: xW (gate) and xV (linear)
        x_gate, x_linear = projected.chunk(2, dim=-1)

        # 3. Apply activation to the gate component
        if self.variant == 'geglu':
            x_gate = F.gelu(x_gate)
        elif self.variant == 'swiglu':
            # Swish with beta=1 is equivalent to SiLU
            x_gate = F.silu(x_gate)

        # 4. Element-wise multiplication (The "Gating" mechanism)
        x = x_gate * x_linear

        # 5. Dropout
        x = self.dropout(x)

        # 6. Output projection
        x = self.w_out(x)

        return x

*Justification for chosen GLU variants:*

According to the paper's pre-training results (Table 1), GEGLU and SwiGLU reached the lowest (best) held-out log-perplexity on the C4-based pre-training task, ahead of other variants such as ReGLU and Bilinear. Tables 2 and 3 report downstream fine-tuning results (GLUE and SuperGLUE) rather than perplexity.

SwiGLU has since become a common choice in large language models (including LLaMA, PaLM, and Mistral), making it the most practically relevant variant to implement. GEGLU is the direct gated counterpart of the GELU activation used in BERT and GPT-2. Both variants are therefore natural candidates to compare against the baseline FFN.



In [14]:
class TransformerEncoderBlock(nn.Module):
    """
    Pre-LN encoder block:
      x = x + Dropout(SelfAttn(LN(x)))
      x = x + Dropout(MLP(LN(x)))
    """
    def __init__(self, d_model: int, n_heads: int, mlp: nn.Module, dropout: float):
        super().__init__()
        # Pre-layer norms
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        # Multi-head self-attention (batch_first so x is (B, T, D))
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Feed-forward / MLP block (FeedForward or GLUFeedForward, passed in)
        self.mlp = mlp

        # Dropouts on residual branches
        self.dropout_attn = nn.Dropout(dropout)
        self.dropout_mlp = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, T, D)
        return: (B, T, D)
        """
        # Self-attention sub-layer
        x_norm = self.ln1(x)
        attn_out, _ = self.self_attn(x_norm, x_norm, x_norm)
        x = x + self.dropout_attn(attn_out)

        # MLP sub-layer
        x_norm = self.ln2(x)
        mlp_out = self.mlp(x_norm)
        x = x + self.dropout_mlp(mlp_out)

        return x

In [None]:
class TinyViT(nn.Module):
    """
    Tiny ViT-style classifier for MNIST.
    - patchify -> patch embed -> pos embed -> blocks -> mean pool -> head
    """
    def __init__(
        self,
        patch_size: int,
        d_model: int,
        n_heads: int,
        n_layers: int,
        d_ff: int,
        dropout: float,
        mlp_kind: str,
    ):
        super().__init__()
        assert 28 % patch_size == 0
        grid = 28 // patch_size
        self.num_tokens = grid * grid
        self.patch_size = patch_size
        patch_dim = patch_size * patch_size

        # TODO: implement a strategy for embedding the patches
        self.patch_embed = PatchEmbed(patch_dim=patch_dim, d_model=d_model)
        self.pos_embed = PositionalEmbedding(num_tokens=self.num_tokens, d_model=d_model)

        self.blocks = nn.ModuleList([])


        for _ in range(n_layers):
            # 1. Create a FRESH instance of the MLP for this specific block
            if mlp_kind == "standard":
                mlp_layer = FeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)
            elif mlp_kind == "geglu":
                # Optional: Adjust d_ff for fairness (approx 2/3)
                d_ff_gated = int(d_ff * 2/3)
                # d_ff_gated = d_ff alternative (NO FAIRNESS)
                mlp_layer = GLUFeedForward(d_model=d_model, d_ff_gated=d_ff_gated, dropout=dropout, variant="geglu")
            elif mlp_kind == "swiglu":
                # Optional: Adjust d_ff for fairness (approx 2/3)
                d_ff_gated = int(d_ff * 2/3)
                # d_ff_gated = d_ff alternative (NO FAIRNESS)
                mlp_layer = GLUFeedForward(d_model=d_model, d_ff_gated=d_ff_gated, dropout=dropout, variant="swiglu")
            else:
                raise ValueError(f"Unknown mlp_kind: {mlp_kind}")

            # 2. Pass the fresh instance to the block
            block = TransformerEncoderBlock(
                d_model=d_model,
                n_heads=n_heads,
                mlp=mlp_layer,
                dropout=dropout,
            )
            self.blocks.append(block)

        self.head = nn.Linear(d_model, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]

        # --- Patchify Logic ---

        # (B, 1, 28, 28) -> (B, 1, grid, patch, grid, patch)
        x = x.view(batch_size, 1, 28 // self.patch_size, self.patch_size, 28 // self.patch_size, self.patch_size)
        # Permute to (B, grid, grid, 1, patch, patch), then flatten each patch
        x = x.permute(0, 2, 4, 1, 3, 5).contiguous()
        x = x.view(batch_size, self.num_tokens, -1) # (B, N, patch_dim)

        # Embeddings
        x = self.patch_embed(x)
        x = self.pos_embed(x)

        # Blocks
        for block in self.blocks:
            x = block(x)

        # Head
        x = x.mean(dim=1) # Global Average Pooling
        logits = self.head(x)

        return logits

In [20]:
@dataclass(frozen=True)
class TrainConfig:
    seed: int = 0
    batch_size: int = 128
    epochs: int = 3
    lr: float = 3e-4
    weight_decay: float = 0.01
    device: str = "cpu"  # set to "cuda" if a GPU is available

In [21]:
def train_one_run(
    mlp_kind: str,
    model: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    cfg: TrainConfig,
) -> dict:
    model.to(cfg.device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    # Standard classification loss
    criterion = nn.CrossEntropyLoss() #could use ex3's crossentropy_from_logits too

    train_losses: list[float] = []
    test_accs: list[float] = []

    for epoch in range(cfg.epochs):

        # Train loop
        model.train()
        for xb, yb in train_loader:
            xb = xb.to(cfg.device)
            yb = yb.to(cfg.device)

            logits = model(xb)
            loss = criterion(logits, yb)

            opt.zero_grad()
            loss.backward()
            opt.step()

            train_losses.append(loss.item())

        # Evaluation loop NOTE: Should be no need to change this
        model.eval()
        correct = 0.0
        total = 0.0
        with torch.no_grad():
            for xb, yb in test_loader:
                xb = xb.to(cfg.device)
                yb = yb.to(cfg.device)
                logits = model(xb)
                correct += (logits.argmax(dim=-1) == yb).float().sum().item()
                total += yb.numel()

        acc = correct / total
        test_accs.append(acc)
        print(f"[{mlp_kind}] epoch {epoch+1}/{cfg.epochs} | test acc: {test_accs[-1]:.4f}")

    return {
        # TODO: Return your metrics
        "kind": mlp_kind,
        "train_losses": train_losses,  # Useful for plotting convergence speed
        "test_accs": test_accs,        # Useful for plotting stability/overfitting
        "final_acc": test_accs[-1],    # The final result
        "best_acc": max(test_accs)     # The peak performance
    }

In [None]:
cfg = TrainConfig(seed=0, batch_size=128, epochs=5, lr=3e-4, weight_decay=0.01, device="cpu")

tfm = transforms.Compose([transforms.ToTensor()])

train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=tfm)
test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=tfm)

train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=0)

# Tiny model example. TODO: You're welcome to experiment with these parameters
patch_size = 4
d_model = 64
n_heads = 4
n_layers = 2
d_ff = 256
dropout = 0.1

# comparing Standard baseline vs the two best performing GLU variants
runs = ["standard", "geglu", "swiglu"]
results = []

print(f"Device: {cfg.device}")

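for kind in runs:
    # Re-seed before building each model so runs are reproducible
    # (cfg.seed is otherwise never used).
    torch.manual_seed(cfg.seed)
    model = TinyViT(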
        patch_size=patch_size,
        d_model=d_model,
        n_heads=n_heads,
        n_layers=n_layers,
        d_ff=d_ff,
        dropout=dropout,
        mlp_kind=kind,
    )

    # Count trainable parameters to check that the comparison is fair
    # (GLU variants use a hidden width of ~2/3 * d_ff to roughly match the baseline)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # TODO: print anything you might want here
    print(f"Run: {kind.upper():<10} | Params: {num_params:,}")
    # Train
    out = train_one_run(kind, model, train_loader, test_loader, cfg)
    results.append(out)
    print("-" * 60)
# Simple summary table at the end
print("\nFinal Results Summary:")
print(f"{'Variant':<15} | {'Final Acc':<10} | {'Best Acc':<10}")
for res in results:
    print(f"{res['kind']:<15} | {res['final_acc']:.4f}     | {res['best_acc']:.4f}")

Device: cpu
Run: STANDARD   | Params: 104,842
[standard] epoch 1/5 | test acc: 0.8039
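The parameter count printed for the baseline can be reproduced by hand, which is a useful check that the model is wired as intended. The sketch below assumes the configuration used in this notebook (d_model=64, two layers, 49 tokens of dimension 16, biases enabled in every linear layer, and an attention module with fused query/key/value projections plus an output projection). All function names are hypothetical. The second printed number is a hand-computed prediction for the GLU variants with hidden width `int(256 * 2 / 3)`, not a logged result.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus bias of a dense layer with bias enabled."""
    return n_in * n_out + n_out


def ffn_params(d_model: int, d_ff: int) -> int:
    return linear_params(d_model, d_ff) + linear_params(d_ff, d_model)


def glu_params(d_model: int, d_hidden: int) -> int:
    # One fused projection to 2 * d_hidden, then the output projection.
    return linear_params(d_model, 2 * d_hidden) + linear_params(d_hidden, d_model)


def tiny_vit_params(mlp_params: int, d_model: int = 64, n_layers: int = 2,
                    patch_dim: int = 16, num_tokens: int = 49,
                    n_classes: int = 10) -> int:
    patch_embed = linear_params(patch_dim, d_model)
    pos_embed = num_tokens * d_model
    attention = 4 * d_model * d_model + 4 * d_model  # q, k, v, and output projections
    layer_norms = 2 * (2 * d_model)                  # scale and shift, two norms per block
    block = attention + layer_norms + mlp_params
    head = linear_params(d_model, n_classes)
    return patch_embed + pos_embed + n_layers * block + head


print(tiny_vit_params(ffn_params(64, 256)))               # 104842
print(tiny_vit_params(glu_params(64, int(256 * 2 / 3))))  # 104754
```

The two totals differ by fewer than 100 parameters, which suggests the 2/3 rule keeps the comparison close to parameter-matched at this model size.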
