# Exercise 4: Transformers on Images + GLU-MLP Ablations (ViT × GLU Variants)

## In this exercise you will combine two influential ideas:

Vision Transformers (ViT), from “An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale” (Dosovitskiy et al., 2020), https://arxiv.org/pdf/2010.11929.
ViT shows that you can treat an image like a sequence of tokens by splitting it into non-overlapping patches (e.g. 16×16 in the paper), embedding each patch into a vector, adding positional information, and then applying standard Transformer blocks for classification.

Gated MLPs (GLU variants), from “GLU Variants Improve Transformer” (Shazeer, 2020), https://arxiv.org/pdf/2002.05202.
Shazeer proposes replacing the standard Transformer feed-forward layer (FFN/MLP) with gated linear unit (GLU) variants such as GEGLU and SwiGLU, which often improves training dynamics and final performance under comparable compute/parameter budgets.
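For reference, the FFN variants compared in Shazeer (2020) can be written as below, following the paper's bias-free notation. Here σ is the logistic sigmoid, ⊗ is element-wise multiplication, Swish₁(z) = z·σ(z) (also called SiLU), W₁, W and V are d_model × d_ff projections, and W₂ maps d_ff back to d_model. In the notation of the GLU paragraph further down, a = xW and b = xV; only the activation applied to xW changes between variants.

$$
\begin{aligned}
\mathrm{FFN}_{\mathrm{GELU}}(x) &= \mathrm{GELU}(xW_1)\,W_2 \\
\mathrm{FFN}_{\mathrm{GLU}}(x) &= \big(\sigma(xW) \otimes xV\big)\,W_2 \\
\mathrm{FFN}_{\mathrm{Bilinear}}(x) &= \big(xW \otimes xV\big)\,W_2 \\
\mathrm{FFN}_{\mathrm{ReGLU}}(x) &= \big(\max(0, xW) \otimes xV\big)\,W_2 \\
\mathrm{FFN}_{\mathrm{GEGLU}}(x) &= \big(\mathrm{GELU}(xW) \otimes xV\big)\,W_2 \\
\mathrm{FFN}_{\mathrm{SwiGLU}}(x) &= \big(\mathrm{Swish}_1(xW) \otimes xV\big)\,W_2
\end{aligned}
$$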

## What you will do

You will implement a tiny ViT-style classifier for MNIST, then run a controlled ablation where you replace the MLP inside each Transformer block:

Baseline FFN (GELU):
Linear(d_model → d_ff) → GELU → Linear(d_ff → d_model)

GLU-family MLPs (choose at least two and justify):

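GEGLU, SwiGLU, or GLU variants with other gate activations (Shazeer also evaluates GLU with a sigmoid gate, ReGLU with ReLU, and Bilinear with no activation)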

Your goal is to evaluate whether these GLU variants change:

- convergence speed (loss vs steps),

- final test accuracy,

- and/or stability across runs.
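One simple way to put a number on convergence speed from per-epoch test accuracy is the first epoch at which a run crosses a fixed threshold. A minimal sketch; the helper name `epochs_to_reach` and the 0.95 threshold are illustrative choices, not part of the exercise. The accuracy lists are copied from the baseline and GEGLU runs printed at the end of this notebook; with a single seed per variant this is one observation, not evidence of a systematic difference.

```python
from typing import Optional, Sequence


def epochs_to_reach(accs: Sequence[float], threshold: float) -> Optional[int]:
    """Return the 1-based epoch at which accuracy first reaches `threshold`, or None."""
    for epoch, acc in enumerate(accs, start=1):
        if acc >= threshold:
            return epoch
    return None


# Per-epoch test accuracies as printed by two of the runs in this notebook.
baseline = [0.8905, 0.9363, 0.9532, 0.9574, 0.9584]
geglu = [0.9251, 0.9546, 0.9646, 0.9698, 0.9726]
print(epochs_to_reach(baseline, 0.95), epochs_to_reach(geglu, 0.95))  # 3 2
```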

## Key ViT concepts you will implement

- Patchify each 28×28 MNIST image into non-overlapping P×P patches, which become the Transformer tokens. With P=4 you get a 7×7 patch grid, i.e. 49 tokens per image.

- Embed patches with a linear layer: patch vectors → d_model.

- Add positional embeddings so the model knows where each patch came from.

- Apply n_layers Transformer encoder blocks.

- Pool token features (e.g., mean pooling) and project to 10 classes.
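The patchify step boils down to a reshape plus a transpose. A NumPy sketch for a single-channel (H, W) image, independent of the PyTorch `patchify` cell below (`patchify_np` is a hypothetical name). Tokens come out in row-major order over the 7×7 grid, which is the order the positional embeddings index.

```python
import numpy as np


def patchify_np(img: np.ndarray, p: int) -> np.ndarray:
    """Split an (H, W) image into row-major non-overlapping p x p patches -> (N, p*p)."""
    h, w = img.shape
    assert h % p == 0 and w % p == 0, "image size must be divisible by the patch size"
    gh, gw = h // p, w // p
    # (gh, p, gw, p) -> (gh, gw, p, p): patch-grid coordinates first, pixels second.
    patches = img.reshape(gh, p, gw, p).transpose(0, 2, 1, 3)
    return patches.reshape(gh * gw, p * p)


img = np.arange(28 * 28, dtype=np.float32).reshape(28, 28)  # dummy MNIST-sized image
tokens = patchify_np(img, 4)
print(tokens.shape)  # (49, 16)
# Token 1 is the second patch in the top row: rows 0-3, columns 4-7.
assert np.array_equal(tokens[1], img[0:4, 4:8].ravel())
```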

## Key GLU concept you will implement

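GLU-style MLPs replace the standard FFN with a gating mechanism: compute two linear projections a and b of the input, apply a variant-dependent nonlinearity to a, multiply elementwise to get act(a) * b, and project the result back to d_model.
Because a gated FFN has three weight matrices instead of two, keep the comparison fair by using Shazeer's 2/3 rule: shrink the hidden width to 2/3 of d_ff so that the parameter count and compute stay roughly equal to the baseline.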
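The 2/3 rule can be checked with plain arithmetic. This sketch counts weights (and optionally biases) of dense layers the way `nn.Linear` does, for the notebook's `d_model = 64` and `d_ff = 256`, assuming the gate and content projections are fused into one layer as in the `GLUFeedForward` cell. Without biases, exact 2/3 scaling would reproduce the baseline's 2·d_model·d_ff = 32768 weights; rounding 256·2/3 ≈ 170.67 down to 170 leaves the GLU version slightly smaller.

```python
def linear_params(d_in: int, d_out: int, bias: bool = True) -> int:
    """Number of parameters in a dense layer d_in -> d_out."""
    return d_in * d_out + (d_out if bias else 0)


def ffn_params(d_model: int, d_ff: int, bias: bool = True) -> int:
    # Linear(d_model -> d_ff) + Linear(d_ff -> d_model)
    return linear_params(d_model, d_ff, bias) + linear_params(d_ff, d_model, bias)


def glu_ffn_params(d_model: int, d_ff: int, bias: bool = True) -> int:
    h = int(d_ff * 2 / 3)  # 2/3 rule: shrink the hidden width
    # Fused Linear(d_model -> 2h) for gate and content, then Linear(h -> d_model)
    return linear_params(d_model, 2 * h, bias) + linear_params(h, d_model, bias)


print(ffn_params(64, 256), glu_ffn_params(64, 256))                          # 33088 33044
print(ffn_params(64, 256, bias=False), glu_ffn_params(64, 256, bias=False))  # 32768 32640
```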

## What we provide vs. what you implement

### We provide:

- MNIST loading + dataloaders

- a minimal training loop structure (AdamW)

- a suggested small model configuration that runs on CPU

### You implement:

- patch tokenization (patchify)

- patch embedding + positional embedding strategy

- a pre-LN Transformer encoder block using nn.MultiheadAttention

- at least two GLU MLP variants + one FFN baseline

- metric logging sufficient to support your conclusion

## Deliverables

Run at least 3 variants (the baseline plus the GLU variants you choose) and report:

- final and best test accuracy

- number of trainable parameters

- a plot or printed summary of loss/accuracy over epochs

- a short discussion of your results

In [1]:
from __future__ import annotations

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
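def patchify(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split images (B, C, H, W) into non-overlapping patch tokens (B, N, P*P*C).

    Tokens are ordered row-major over the patch grid, with N = (H // P) * (W // P).
    """
    B, C, H, W = x.shape
    P = patch_size
    gh, gw = H // P, W // P  # patch-grid rows (height) and columns (width)
    # Memory order is (H, W), so split H into (gh, P) first and W into (gw, P) second.
    x = x.reshape(B, C, gh, P, gw, P)
    # -> (B, gh, gw, P, P, C): grid position first, then the pixels inside each patch.
    x = x.permute(0, 2, 4, 3, 5, 1)
    x = x.flatten(1, 2)  # (B, N, P, P, C)
    x = x.flatten(2)     # (B, N, P*P*C)
    return x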

In [3]:
# TODO: Add positional encoding as done in the ViT paper and patch projection
class PatchEmbed(nn.Module):
    """Linear projection of flattened patches: (B, N, patch_dim) -> (B, N, d_model)."""
    def __init__(self, patch_dim: int, d_model: int):
        super().__init__()
        self.proj = nn.Linear(patch_dim, d_model)
        # Xavier/Glorot-uniform weights, bound sqrt(6 / (fan_in + fan_out)), and zero bias.
        nn.init.xavier_uniform_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x_patches: torch.Tensor) -> torch.Tensor:
        return self.proj(x_patches)


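class PositionalEmbedding(nn.Module):
    """Learnable 1D position embeddings (as in ViT), one vector per patch position."""
    def __init__(self, num_tokens: int, d_model: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(1, num_tokens, d_model))
        with torch.no_grad():
            # Glorot-style bound sqrt(6 / (fan_in + fan_out)) with fan_in = fan_out = d_model.
            bound = (6 / (d_model + d_model)) ** 0.5
            self.weight.uniform_(-bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, N, d_model) + (1, N, d_model) broadcasts over the batch dimension.
        return x + self.weight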
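The Glorot/Xavier-uniform bound used by these initializations is sqrt(6 / (fan_in + fan_out)). Because the square root covers the whole fraction, misplacing the parentheses as 6 / sqrt(fan_in + fan_out) inflates the bound by exactly √6 ≈ 2.45 for any fan sizes. A plain-Python check for `d_model = 64`:

```python
import math

d_model = 64
fan_in = fan_out = d_model

correct = math.sqrt(6 / (fan_in + fan_out))  # sqrt(6/128)
misplaced = 6 / math.sqrt(fan_in + fan_out)  # 6/sqrt(128): square root in the wrong place
print(f"{correct:.4f} {misplaced:.4f}")      # 0.2165 0.5303
```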

In [37]:
# TODO: Define the variants you want to compare against each other from the GLU paper. Justify your choice.
class FeedForward(nn.Module):
    """
    Standard Transformer FFN:
      x -> Linear(d_model->d_ff) -> GELU -> Dropout -> Linear(d_ff->d_model) -> Dropout
    """
    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        # TODO: implement
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # TODO: implement
        return self.net(x)


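class GLUFeedForward(nn.Module):
    """
    GLU-family FFN (Shazeer, 2020):
      x -> Linear(d_model->2h) -> split into (gate, content) -> act(gate) * content
        -> Dropout -> Linear(h->d_model) -> Dropout,   with h = int(2/3 * d_ff)
    """
    VARIANTS = ("GLU", "Bilinear", "ReGLU", "GEGLU", "SwiGLU")

    def __init__(self, d_model: int, d_ff: int, dropout: float, variant: str):
        super().__init__()
        if variant not in self.VARIANTS:
            raise ValueError(f"Unknown GLU variant {variant!r}; expected one of {self.VARIANTS}")
        self.variant = variant
        # 2/3 rule: three projections of width h instead of two of width d_ff keeps the
        # parameter count and compute roughly equal to the baseline FFN.
        self.d_ff_glu = int(d_ff * 2 / 3)

        # One fused projection produces both the gate and the content halves.
        self.w1 = nn.Linear(d_model, self.d_ff_glu * 2)
        self.w2 = nn.Linear(self.d_ff_glu, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, content = torch.chunk(self.w1(x), 2, dim=-1)

        if self.variant == "GLU":
            gate = torch.sigmoid(gate)
        elif self.variant == "ReGLU":
            gate = F.relu(gate)
        elif self.variant == "GEGLU":
            gate = F.gelu(gate)
        elif self.variant == "SwiGLU":
            gate = F.silu(gate)
        # "Bilinear": no activation on the gate.

        x = self.dropout(gate * content)  # element-wise gating; hidden dropout as in the baseline FFN
        return self.dropout(self.w2(x))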
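To see how the variants differ per hidden unit, here is a scalar sketch of act(a) * b in plain Python. It assumes the exact erf-based GELU (the default of `F.gelu`) and SiLU, i.e. Swish with β = 1 (what `F.silu` computes); `glu_gate` and `GATE_ACTS` are hypothetical names for illustration.

```python
import math


def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def silu(x: float) -> float:
    # SiLU / Swish-1: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


GATE_ACTS = {
    "GLU": sigmoid,               # sigmoid gate
    "Bilinear": lambda a: a,      # no nonlinearity
    "ReGLU": lambda a: max(0.0, a),
    "GEGLU": gelu,
    "SwiGLU": silu,
}


def glu_gate(variant: str, a: float, b: float) -> float:
    """Compute act(a) * b for one hidden unit (hypothetical helper)."""
    if variant not in GATE_ACTS:
        raise ValueError(f"unknown GLU variant: {variant!r}")
    return GATE_ACTS[variant](a) * b


for name in GATE_ACTS:
    print(f"{name:8s} {glu_gate(name, -1.0, 2.0):+.4f} {glu_gate(name, 1.0, 2.0):+.4f}")
```

For a negative gate input, ReGLU zeroes the unit, GEGLU and SwiGLU pass a small negative value, and Bilinear passes it through unchanged, which makes its hidden unit a pure product of two linear functions of the input.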

In [26]:
class TransformerEncoderBlock(nn.Module):
    """
    Pre-LN encoder block:
      x = x + Dropout(SelfAttn(LN(x)))
      x = x + Dropout(MLP(LN(x)))
    """
    def __init__(self, d_model: int, n_heads: int, mlp: nn.Module, dropout: float):
        super().__init__()
        # TODO: implement. For attention use nn.MultiheadAttention
        self.ln1 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = mlp

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # TODO: implement
        norm_x = self.ln1(x)
        # need_weights=False skips returning the (unused) attention weights.
        attn_out, _ = self.self_attn(norm_x, norm_x, norm_x, need_weights=False)
        x = x + self.dropout(attn_out)
        x = x + self.dropout(self.mlp(self.ln2(x)))
        return x

In [27]:
class TinyViT(nn.Module):
    """
    Tiny ViT-style classifier for MNIST.
    - patchify -> patch embed -> pos embed -> blocks -> mean pool -> head
    """
    def __init__(
        self,
        patch_size: int,
        d_model: int,
        n_heads: int,
        n_layers: int,
        d_ff: int,
        dropout: float,
        mlp_kind: str,
    ):
        super().__init__()
        assert 28 % patch_size == 0
        grid = 28 // patch_size
        self.num_tokens = grid * grid
        self.patch_size = patch_size
        patch_dim = patch_size * patch_size

        # TODO: implement a strategy for embedding the patches
        self.patch_embed = PatchEmbed(patch_dim, d_model)
        self.pos_embed = PositionalEmbedding(self.num_tokens, d_model)

        # TODO: implement a strategy to select the right mlp version for your experiment
        def make_mlp() -> nn.Module:
            # "baseline" -> GELU FFN; any other name is passed to GLUFeedForward as the variant.
            if mlp_kind == "baseline":
                return FeedForward(d_model, d_ff, dropout)
            return GLUFeedForward(d_model, d_ff, dropout, mlp_kind)

        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(
                d_model=d_model,
                n_heads=n_heads,
                mlp=make_mlp(),  # a fresh MLP per block, so no weights are shared
                dropout=dropout,
            )
            for _ in range(n_layers)
        ])

        # TODO: Add a head to project to the number of output classes (10 MNIST digits)
        self.ln_final = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # TODO: Implement
        x = patchify(x, self.patch_size)

        x = self.patch_embed(x)
        x = self.pos_embed(x)

        for block in self.blocks:
            x = block(x)

        x = self.ln_final(x)
        x = x.mean(dim=1)
        logits = self.head(x)
        return logits
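As a cross-check on the `num_params` value reported by `train_one_run`, the trainable-parameter count of this configuration can be derived by hand. The sketch assumes the layer layout in the cells above (bias-enabled linear layers, `nn.MultiheadAttention` with its default packed input projection, learnable position embeddings, no CLS token) and the notebook's hyperparameters; `tiny_vit_params` is a hypothetical helper.

```python
def tiny_vit_params(
    patch_size: int = 4,
    d_model: int = 64,
    n_layers: int = 2,
    d_ff: int = 256,
    num_classes: int = 10,
    glu: bool = False,
    image_size: int = 28,
    in_chans: int = 1,
) -> int:
    """Hand count of trainable parameters for the TinyViT layout in this notebook."""

    def lin(d_in: int, d_out: int) -> int:
        return d_in * d_out + d_out  # nn.Linear with bias

    layer_norm = 2 * d_model  # LayerNorm weight + bias
    n_tokens = (image_size // patch_size) ** 2
    patch_dim = in_chans * patch_size ** 2

    embed = lin(patch_dim, d_model) + n_tokens * d_model  # patch projection + position embeddings
    attn = 4 * d_model * d_model + 4 * d_model  # packed Q/K/V in_proj (3d^2 + 3d) + out_proj (d^2 + d)
    if glu:
        h = int(d_ff * 2 / 3)  # 2/3 rule
        mlp = lin(d_model, 2 * h) + lin(h, d_model)
    else:
        mlp = lin(d_model, d_ff) + lin(d_ff, d_model)
    block = 2 * layer_norm + attn + mlp
    head = layer_norm + lin(d_model, num_classes)  # final LayerNorm + classifier
    return embed + n_layers * block + head


print(tiny_vit_params(glu=False), tiny_vit_params(glu=True))  # 104970 104882
```

Every GLU variant gets the same count because the gate activations have no parameters. The small gap to the baseline comes from rounding the hidden width down to 170 (fewer weights) partly offset by the larger fused bias.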

In [29]:
@dataclass(frozen=True)
class TrainConfig:
    seed: int = 0
    batch_size: int = 128
    epochs: int = 3
    lr: float = 3e-4
    weight_decay: float = 0.01
    device: str = "cuda" if torch.cuda.is_available() else "cpu"  # picks CUDA automatically when available

In [38]:
def train_one_run(
    mlp_kind: str,
    model: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    cfg: TrainConfig,
) -> dict:
    model.to(cfg.device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    criterion = nn.CrossEntropyLoss()  # TODO: Your criterion
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    step_losses: list[float] = []   # loss after every optimizer step (loss vs steps)
    train_losses: list[float] = []  # mean training loss per epoch
    test_accs: list[float] = []

    for epoch in range(cfg.epochs):

        # Train loop
        model.train()
        epoch_loss = 0.0
        for xb, yb in train_loader:
            xb = xb.to(cfg.device)
            yb = yb.to(cfg.device)

            logits = model(xb)
            loss = criterion(logits, yb)

            opt.zero_grad()
            loss.backward()
            opt.step()
            step_losses.append(loss.item())
            epoch_loss += step_losses[-1]

        avg_training_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_training_loss)

        # Evaluation loop NOTE: Should be no need to change this
        model.eval()
        correct = 0.0
        total = 0.0
        with torch.no_grad():
            for xb, yb in test_loader:
                xb = xb.to(cfg.device)
                yb = yb.to(cfg.device)
                logits = model(xb)
                correct += (logits.argmax(dim=-1) == yb).float().sum().item()
                total += yb.numel()

        test_accs.append(correct / total)
        print(f"[{mlp_kind}] epoch {epoch+1}/{cfg.epochs} | test acc: {test_accs[-1]:.4f}")

    return {
        "mlp_kind": mlp_kind,
        "step_losses": step_losses,
        "train_losses": train_losses,
        "test_accs": test_accs,
        "final_acc": test_accs[-1],
        "best_acc": max(test_accs),
        "num_params": num_params,
    }
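Per-step losses are noisy, so loss-vs-steps curves are easier to compare after smoothing. A minimal trailing moving average in plain Python, assuming you have recorded one loss value per optimizer step; `moving_average` is a hypothetical helper and the 50-step window is an arbitrary default (with batch size 128, one MNIST epoch is 469 steps).

```python
from collections import deque
from typing import Iterable


def moving_average(values: Iterable[float], window: int = 50) -> list[float]:
    """Trailing moving average; early points average over the steps seen so far."""
    if window < 1:
        raise ValueError("window must be >= 1")
    buf: deque = deque(maxlen=window)
    total = 0.0
    out: list[float] = []
    for v in values:
        if len(buf) == buf.maxlen:
            total -= buf[0]  # this value is evicted by the append below
        buf.append(v)
        total += v
        out.append(total / len(buf))
    return out


print(moving_average([4.0, 2.0, 0.0, 2.0], window=2))  # [4.0, 3.0, 1.0, 1.0]
```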

In [40]:
cfg = TrainConfig(seed=0, batch_size=128, epochs=5, lr=3e-4, weight_decay=0.01, device="cuda" if torch.cuda.is_available() else "cpu")

tfm = transforms.Compose([transforms.ToTensor()])

train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=tfm)
test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=tfm)

train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=0)

# Tiny model example. TODO: You're welcome to experiment with these parameters
patch_size = 4
d_model = 64
n_heads = 4
n_layers = 2
d_ff = 256
dropout = 0.1

runs = ["baseline", "Bilinear", "GEGLU", "SwiGLU"] # TODO: Name your runs
results = []

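for kind in runs:
    # Seed before building each model so every variant gets a reproducible init and data order;
    # repeat the loop with different cfg.seed values to measure run-to-run stability.
    torch.manual_seed(cfg.seed)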
    model = TinyViT(
        patch_size=patch_size,
        d_model=d_model,
        n_heads=n_heads,
        n_layers=n_layers,
        d_ff=d_ff,
        dropout=dropout,
        mlp_kind=kind,
    )
    # TODO: print anything you might want here
    print(f"\nRun: {kind} | " )
    out = train_one_run(kind, model, train_loader, test_loader, cfg)
    results.append(out)


Run: baseline | 
[baseline] epoch 1/5 | test acc: 0.8905
[baseline] epoch 2/5 | test acc: 0.9363
[baseline] epoch 3/5 | test acc: 0.9532
[baseline] epoch 4/5 | test acc: 0.9574
[baseline] epoch 5/5 | test acc: 0.9584

Run: Bilinear | 
[Bilinear] epoch 1/5 | test acc: 0.9238
[Bilinear] epoch 2/5 | test acc: 0.9529
[Bilinear] epoch 3/5 | test acc: 0.9577
[Bilinear] epoch 4/5 | test acc: 0.9672
[Bilinear] epoch 5/5 | test acc: 0.9702

Run: GEGLU | 
[GEGLU] epoch 1/5 | test acc: 0.9251
[GEGLU] epoch 2/5 | test acc: 0.9546
[GEGLU] epoch 3/5 | test acc: 0.9646
[GEGLU] epoch 4/5 | test acc: 0.9698
[GEGLU] epoch 5/5 | test acc: 0.9726

Run: SwiGLU | 
[SwiGLU] epoch 1/5 | test acc: 0.9225
[SwiGLU] epoch 2/5 | test acc: 0.9527
[SwiGLU] epoch 3/5 | test acc: 0.9656
[SwiGLU] epoch 4/5 | test acc: 0.9687
[SwiGLU] epoch 5/5 | test acc: 0.9735
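To report final and best accuracy alongside parameter counts, and to gauge stability once each variant has been run with several seeds, the result dictionaries returned by `train_one_run` can be grouped by `mlp_kind`. A standard-library sketch that relies only on the `mlp_kind`, `final_acc`, `best_acc` and `num_params` keys; `summarize` is a hypothetical helper. The standard deviation needs at least two runs per variant, so with one seed per variant (as above) it prints `nan`, and gaps of a few tenths of a percent cannot be separated from seed-to-seed variation.

```python
import statistics
from collections import defaultdict


def _mean_sd(xs: list[float]) -> str:
    m = statistics.mean(xs)
    s = statistics.stdev(xs) if len(xs) > 1 else float("nan")
    return f"{m:.4f} ± {s:.4f}"


def summarize(results: list[dict]) -> list[str]:
    """One line per variant: run count, mean ± std of final/best accuracy, parameter count."""
    groups: dict[str, list[dict]] = defaultdict(list)
    for r in results:
        groups[r["mlp_kind"]].append(r)

    lines = [f"{'variant':10s} {'runs':>4s} {'final acc':>16s} {'best acc':>16s} {'params':>8s}"]
    for kind, runs in groups.items():
        finals = [r["final_acc"] for r in runs]
        bests = [r["best_acc"] for r in runs]
        lines.append(
            f"{kind:10s} {len(runs):4d} {_mean_sd(finals):>16s} "
            f"{_mean_sd(bests):>16s} {runs[0]['num_params']:8d}"
        )
    return lines
```

In the notebook this would be used as `for line in summarize(results): print(line)`.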
