# BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

# https://arxiv.org/pdf/1810.04805


In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Mini-BERT: A Didactic Replication of *BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding* (2019)
==============================================================================================================

One-line summary:
A compact, classroom-ready Transformer **encoder-only** classifier with a [CLS] token, LayerNorm+Residual blocks,
and GELU MLPs—mirroring the *BERT* encoder stack—adapted to **CIFAR-10** by using **patch embeddings** (like ViT)
to turn images into token sequences.

Teaching notes (mapping BERT → this demo):
- BERT core we replicate: multi-layer bidirectional **Transformer encoder** with self-attention, residuals, LayerNorm, and GELU MLP.
- [CLS] classification token retained; its final hidden state feeds a linear classifier head.
- MLM/NSP pretraining is **omitted** for simplicity; we do **supervised classification** on CIFAR-10.
- Images → tokens via **patch embeddings** (Conv2d with kernel=stride=patch_size). This mirrors text-token embeddings in BERT.
- Position embeddings are **learnable** (as with BERT).
- Name: **MiniBERT** (a didactic stand-in for BERT’s encoder on an image task).

Run budget (default):
    epochs=5, batch_size=64, lr=1e-3
"""

import math
import os
import random
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, utils as vutils
import matplotlib.pyplot as plt
import numpy as np

# -------------------------
# Reproducibility utilities
# -------------------------
def set_seed(seed: int = 42, *, deterministic: bool = False):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    then configures cuDNN.

    Args:
        seed: Value passed to every generator.
        deterministic: When ``False`` (the default, matching the historical
            behaviour) cuDNN is allowed to autotune and pick fast,
            possibly non-deterministic kernels. When ``True`` cuDNN is asked
            for deterministic kernels and autotuning is switched off, trading
            speed for run-to-run repeatability.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The two cuDNN switches are kept in lock-step: benchmarking (autotuning)
    # is only enabled when determinism was not requested.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic

# -------------------------
# Config (didactic defaults)
# -------------------------
@dataclass
class Config:
    """All hyper-parameters and paths for the MiniBERT CIFAR-10 demo.

    Every field has a default, so ``Config()`` yields the documented run
    budget. The ``device`` and ``amp`` defaults are computed once, when the
    class body is executed, from ``torch.cuda.is_available()``.
    """

    # --- Data / training ---
    data_root: str = "./data"      # where torchvision stores/downloads CIFAR-10
    num_classes: int = 10          # output width of the classification head
    epochs: int = 5                # number of full passes over the training set
    batch_size: int = 64           # used for both the train and the test loader
    lr: float = 1e-3               # AdamW learning rate
    weight_decay: float = 0.05     # AdamW weight decay
    num_workers: int = 2           # DataLoader worker processes

    # --- Image / patches ---
    img_size: int = 32             # square input side; must be divisible by patch_size
    in_chans: int = 3              # input image channels
    patch_size: int = 4  # 32/4 → 8x8 = 64 tokens
    drop_rate: float = 0.1         # dropout applied right after adding position embeddings

    # --- "BERT-style" encoder (mini) ---
    embed_dim: int = 128           # token embedding width
    depth: int = 4                 # number of stacked encoder blocks
    num_heads: int = 4             # attention heads per block
    mlp_ratio: float = 4.0         # MLP hidden width = int(embed_dim * mlp_ratio)
    attn_drop: float = 0.0         # dropout inside multi-head attention
    proj_drop: float = 0.0         # dropout on each sub-layer output and inside the MLP

    # --- Logging / outputs ---
    out_dir: str = "./mini_bert_outputs"  # checkpoint and figures are written here
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    amp: bool = torch.cuda.is_available()  # mixed precision if CUDA

# --------------------------------
# Dataset & preprocessing pipeline
# --------------------------------
def get_cifar10_loaders(cfg: Config) -> Tuple[DataLoader, DataLoader]:
    """Build the CIFAR-10 training and test data loaders.

    The training pipeline augments with a random horizontal flip and a
    padded random crop; the test pipeline only converts and normalizes.
    The dataset is downloaded into ``cfg.data_root`` if it is not there yet.

    Args:
        cfg: Run configuration; reads ``data_root``, ``img_size``,
            ``batch_size``, ``num_workers`` and ``device``.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    # CIFAR-10 normalization constants
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)
    to_normalized_tensor = [transforms.ToTensor(), transforms.Normalize(mean, std)]

    train_tfms = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(cfg.img_size, padding=4),
        *to_normalized_tensor,
    ])
    test_tfms = transforms.Compose(list(to_normalized_tensor))

    train_set = datasets.CIFAR10(root=cfg.data_root, train=True, transform=train_tfms, download=True)
    test_set = datasets.CIFAR10(root=cfg.data_root, train=False, transform=test_tfms, download=True)

    # Page-locked host memory only speeds up host -> GPU copies. Requesting it
    # on a CPU-only run is pointless (and recent PyTorch releases warn about
    # it), so it is enabled only when the configured device is a CUDA device.
    pin_memory = str(cfg.device).startswith("cuda")
    shared_kwargs = dict(
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        pin_memory=pin_memory,
    )

    train_loader = DataLoader(train_set, shuffle=True, **shared_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **shared_kwargs)
    return train_loader, test_loader

# --------------------------
# Model: Patch Embeddings
# --------------------------
class PatchEmbed(nn.Module):
    """
    Turn a square image into a sequence of non-overlapping patch embeddings.

    A single Conv2d whose kernel size equals its stride cuts the image into
    ``(img_size // patch_size) ** 2`` patches and projects each one to
    ``embed_dim`` channels in one step.
    """
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=128):
        super().__init__()
        patches_per_side, leftover = divmod(img_size, patch_size)
        assert leftover == 0, "img_size must be divisible by patch_size"

        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = patches_per_side
        self.num_patches = patches_per_side * patches_per_side

        # kernel == stride, so neighbouring patches never overlap
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, D, Gh, Gw)
        feature_map = self.proj(x)
        # merge the spatial grid into one token axis, then put channels last:
        # (B, D, Gh, Gw) -> (B, D, N) -> (B, N, D)
        return feature_map.flatten(start_dim=2).permute(0, 2, 1)

# --------------------------
# Model: BERT-style Encoder
# --------------------------
class MLP(nn.Module):
    """
    Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The same Dropout module (probability ``drop``) is applied after the
    activation and again after the output projection.
    """
    def __init__(self, in_features, hidden_features, out_features, drop=0.0):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)   # expand
        self.act = nn.GELU()                                  # BERT uses GELU
        self.fc2 = nn.Linear(hidden_features, out_features)  # project back
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class TransformerEncoderBlock(nn.Module):
    """
    One encoder layer with two residual sub-layers, each normalized first
    (pre-norm layout):
      - x + Dropout(SelfAttention(LayerNorm(x)))
      - x + Dropout(MLP(LayerNorm(x)))

    Attention is unmasked, so every token attends to every other token.
    NOTE(review): the original BERT places LayerNorm after the residual add
    (post-norm), as far as I know — confirm before describing this block as an
    exact mirror of a BERT layer.
    """
    def __init__(self, embed_dim=128, num_heads=4, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        # --- attention sub-layer ---
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim,
            num_heads,
            dropout=attn_drop,
            batch_first=True,
        )
        # plain Dropout on the sub-layer output, despite the "drop_path" name
        self.drop_path1 = nn.Dropout(proj_drop)

        # --- feed-forward sub-layer ---
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), embed_dim, drop=proj_drop)
        self.drop_path2 = nn.Dropout(proj_drop)

    def forward(self, x):
        # Self-attention: queries, keys and values are all the normalized input
        normed = self.norm1(x)
        attended = self.attn(normed, normed, normed, need_weights=False)[0]
        x = x + self.drop_path1(attended)

        # Feed-forward on the normalized attention output
        return x + self.drop_path2(self.mlp(self.norm2(x)))

class MiniBERT(nn.Module):
    """
    Encoder-only classifier for CIFAR-10.

    Pipeline: image -> patch tokens -> prepend learnable [CLS] -> add learnable
    position embeddings -> dropout -> stack of encoder blocks -> final
    LayerNorm -> linear head applied to the [CLS] position.
    """
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        dim = cfg.embed_dim

        self.patch_embed = PatchEmbed(cfg.img_size, cfg.patch_size, cfg.in_chans, dim)
        seq_len = 1 + self.patch_embed.num_patches  # one extra slot for [CLS]
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, dim))
        self.pos_drop = nn.Dropout(cfg.drop_rate)

        block_kwargs = dict(
            embed_dim=dim,
            num_heads=cfg.num_heads,
            mlp_ratio=cfg.mlp_ratio,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
        )
        self.blocks = nn.ModuleList(
            [TransformerEncoderBlock(**block_kwargs) for _ in range(cfg.depth)]
        )
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, cfg.num_classes)

        self._init_weights()

    def _init_weights(self):
        """Re-initialize the embeddings and every Linear / Conv2d / LayerNorm submodule."""
        # BERT/ViT-style init for the two free-standing parameters
        for embedding in (self.cls_token, self.pos_embed):
            nn.init.normal_(embedding, std=0.02)

        for module in self.modules():
            if isinstance(module, nn.LayerNorm):
                # identity affine transform
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue

            if isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=0.02)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out")
            else:
                continue

            # Linear and Conv2d both start with a zero bias when they have one
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        batch = x.shape[0]
        tokens = self.patch_embed(x)  # (B, N, D)

        # prepend one copy of [CLS] per sample -> (B, 1+N, D)
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)

        # position information, then dropout
        tokens = self.pos_drop(tokens + self.pos_embed[:, :tokens.size(1), :])

        # encoder stack followed by the final LayerNorm
        for block in self.blocks:
            tokens = block(tokens)
        tokens = self.norm(tokens)

        # the head only sees the [CLS] position: (B, D) -> (B, num_classes)
        return self.head(tokens[:, 0])

# --------------------------
# Training / Evaluation
# --------------------------
def accuracy(logits, targets):
    """Return the fraction of rows whose arg-max class equals the target, as a Python float."""
    hits = logits.argmax(dim=1).eq(targets)
    return hits.float().mean().item()

def train_one_epoch(model, loader, optimizer, scaler, device, epoch, cfg: Config):
    """Run one optimization pass over ``loader`` and return the epoch averages.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of ``(images, targets)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        scaler: Gradient scaler; only used when ``cfg.amp`` is true.
        device: Device the batches are moved to.
        epoch: Current epoch number. Not used by this function; kept so the
            call signature stays stable.
        cfg: Run configuration; only ``cfg.amp`` is read.

    Returns:
        ``(mean_loss, mean_accuracy)``, both weighted by batch size.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    model.train()
    total_loss, total_acc, total_count = 0.0, 0.0, 0
    criterion = nn.CrossEntropyLoss()
    for images, targets in loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        if cfg.amp:
            # torch.autocast(device_type="cuda") is the supported spelling of
            # the deprecated torch.cuda.amp.autocast(), which emits a
            # FutureWarning on recent PyTorch releases.
            with torch.autocast(device_type="cuda"):
                logits = model(images)
                loss = criterion(logits, targets)
            # Scale the loss before backward so fp16 gradients do not
            # underflow; step() unscales before applying the update.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(images)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean.
        bs = images.size(0)
        total_loss += loss.item() * bs
        total_acc += accuracy(logits.detach(), targets) * bs
        total_count += bs
    return total_loss / total_count, total_acc / total_count

@torch.no_grad()
def evaluate(model, loader, device):
    """Score ``model`` on ``loader`` without tracking gradients.

    The model is put in eval mode. Returns ``(mean_loss, mean_accuracy)``,
    both weighted by batch size.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)
        outputs = model(batch_x)
        batch_loss = loss_fn(outputs, batch_y)

        # per-batch means are re-weighted by the number of samples they cover
        n = batch_x.size(0)
        loss_sum += batch_loss.item() * n
        acc_sum += accuracy(outputs, batch_y) * n
        seen += n
    return loss_sum / seen, acc_sum / seen

# --------------------------
# Utilities: parameter count
# --------------------------
def count_parameters(model):
    """Return the total number of scalar elements across parameters that require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# --------------------------
# Visualization
# --------------------------
def plot_curves(train_losses, val_losses, train_accs, val_accs, out_path):
    """Write the loss and accuracy curves as two PNG files inside ``out_path``.

    Produces ``loss_curves.png`` and ``accuracy_curves.png``; each figure
    overlays the training and validation series.
    """
    # (train series, train label, val series, val label, y-axis label, title, file name)
    figures = (
        (train_losses, "Train Loss", val_losses, "Val Loss",
         "Loss", "MiniBERT Training/Validation Loss", "loss_curves.png"),
        (train_accs, "Train Acc", val_accs, "Val Acc",
         "Accuracy", "MiniBERT Training/Validation Accuracy", "accuracy_curves.png"),
    )
    for train_series, train_label, val_series, val_label, y_label, heading, filename in figures:
        plt.figure(figsize=(7, 5))
        plt.plot(train_series, label=train_label)
        plt.plot(val_series, label=val_label)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.title(heading)
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(out_path, filename))
        plt.close()

@torch.no_grad()
def save_prediction_grid(model, loader, class_names, device, out_path, max_images=36):
    """Save a grid of images from the first batch of ``loader`` as ``predictions_grid.png``.

    The model is put in eval mode and run on at most ``max_images`` images.
    The figure title lists the predicted class names of the first ten images.
    """
    model.eval()
    batch, _ = next(iter(loader))
    batch = batch[:max_images].to(device)
    predicted = model(batch).argmax(dim=1).cpu().tolist()

    # Undo the CIFAR-10 normalization so the pixels are displayable again
    per_channel = (1, 3, 1, 1)
    mean = torch.tensor([0.4914, 0.4822, 0.4465], device=batch.device).view(*per_channel)
    std = torch.tensor([0.2470, 0.2435, 0.2616], device=batch.device).view(*per_channel)
    displayable = (batch * std + mean).cpu().clamp(0, 1)

    images_per_row = int(math.sqrt(max_images))
    grid = vutils.make_grid(displayable, nrow=images_per_row)

    plt.figure(figsize=(8, 8))
    # make_grid returns channels-first; imshow expects channels-last
    plt.imshow(grid.numpy().transpose(1, 2, 0))
    plt.axis("off")
    first_names = ", ".join(class_names[idx] for idx in predicted[:10])
    plt.title(f"Predictions: {first_names} ...")
    plt.tight_layout()
    plt.savefig(os.path.join(out_path, "predictions_grid.png"))
    plt.close()

# --------------------------
# Main
# --------------------------
def main():
    """Train MiniBERT on CIFAR-10 end to end and write all artifacts.

    Steps: seed the RNGs, build the data loaders and the model, train for
    ``cfg.epochs`` epochs while evaluating on the test split after each one,
    then save the checkpoint, the metric curves and a prediction grid into
    ``cfg.out_dir``.
    """
    set_seed(42)
    cfg = Config()
    os.makedirs(cfg.out_dir, exist_ok=True)
    print("==> Using device:", cfg.device)

    # Data
    train_loader, test_loader = get_cifar10_loaders(cfg)
    class_names = train_loader.dataset.classes
    print(f"Train batches: {len(train_loader)} | Test batches: {len(test_loader)}")

    # Model
    model = MiniBERT(cfg).to(cfg.device)
    print(model)
    print(f"Trainable params: {count_parameters(model)/1e6:.2f} M")

    # Optimizer & scaler
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    # torch.cuda.amp.GradScaler is deprecated on recent PyTorch releases and
    # emits a FutureWarning; prefer the device-agnostic torch.amp.GradScaler
    # and fall back to the old location on versions that do not ship it.
    if hasattr(torch.amp, "GradScaler"):
        scaler = torch.amp.GradScaler("cuda", enabled=cfg.amp)
    else:
        scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

    # Train / eval
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    for epoch in range(1, cfg.epochs + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, scaler, cfg.device, epoch, cfg)
        # The CIFAR-10 test split doubles as the validation set in this demo.
        va_loss, va_acc = evaluate(model, test_loader, cfg.device)

        train_losses.append(tr_loss)
        val_losses.append(va_loss)
        train_accs.append(tr_acc)
        val_accs.append(va_acc)

        print(f"Epoch {epoch:02d}/{cfg.epochs} | "
              f"Train Loss {tr_loss:.4f} Acc {tr_acc*100:5.2f}% | "
              f"Val Loss {va_loss:.4f} Acc {va_acc*100:5.2f}%")

    # Save artifacts
    torch.save({"model_state": model.state_dict(), "cfg": cfg.__dict__},
               os.path.join(cfg.out_dir, "mini_bert.ckpt"))
    plot_curves(train_losses, val_losses, train_accs, val_accs, cfg.out_dir)
    save_prediction_grid(model, test_loader, class_names, cfg.device, cfg.out_dir, max_images=36)
    print(f"Artifacts saved in: {os.path.abspath(cfg.out_dir)}")
    print("Done.")

# Start the training run only when this file is executed directly, so that
# importing it (e.g. to reuse the model classes) has no side effects.
if __name__ == "__main__":
    main()


==> Using device: cuda


100%|██████████| 170M/170M [00:13<00:00, 12.7MB/s]


Train batches: 782 | Test batches: 157
MiniBERT(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
  )
  (pos_drop): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-3): 4 x TransformerEncoderBlock(
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (drop_path1): Dropout(p=0.0, inplace=False)
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=128, out_features=512, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=512, out_features=128, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path2): Dropout(p=0.0, inplace=False)
    )
  )
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=128, out_features=10, bias=True)
)

  scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)
  with torch.cuda.amp.autocast():


Epoch 01/5 | Train Loss 1.8477 Acc 30.09% | Val Loss 1.6361 Acc 39.69%
Epoch 02/5 | Train Loss 1.6253 Acc 39.62% | Val Loss 1.5917 Acc 41.78%
Epoch 03/5 | Train Loss 1.4942 Acc 45.15% | Val Loss 1.3975 Acc 48.59%
Epoch 04/5 | Train Loss 1.4238 Acc 47.85% | Val Loss 1.3866 Acc 49.57%
Epoch 05/5 | Train Loss 1.3747 Acc 49.96% | Val Loss 1.3424 Acc 51.58%
Artifacts saved in: /content/mini_bert_outputs
Done.


In [3]:
# ================================
# Visualization Block (Loss & Acc)
# ================================
import matplotlib.pyplot as plt
import os

def plot_all_curves(train_losses, val_losses, train_accs, val_accs, out_dir="./outputs"):
    """Save three metric figures into ``out_dir`` (created if missing).

    Writes ``loss_curve.png``, ``accuracy_curve.png`` and a side-by-side
    ``metrics_curves.png``, then prints where the files were saved.
    """
    os.makedirs(out_dir, exist_ok=True)
    grid_style = {"linestyle": "--", "alpha": 0.6}

    # 1-2. One standalone figure per metric
    # (train series, train label, val series, val label, y-axis label, title, file name)
    standalone = (
        (train_losses, "Train Loss", val_losses, "Validation Loss",
         "Loss", "MiniBERT Training vs Validation Loss", "loss_curve.png"),
        (train_accs, "Train Accuracy", val_accs, "Validation Accuracy",
         "Accuracy", "MiniBERT Training vs Validation Accuracy", "accuracy_curve.png"),
    )
    for train_y, train_label, val_y, val_label, y_label, heading, filename in standalone:
        plt.figure(figsize=(8, 6))
        plt.plot(train_y, label=train_label, marker="o")
        plt.plot(val_y, label=val_label, marker="s")
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.title(heading)
        plt.legend()
        plt.grid(True, **grid_style)
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, filename))
        plt.close()

    # 3. Combined figure: loss on the left axis, accuracy on the right
    _, axes = plt.subplots(1, 2, figsize=(14, 5))
    side_by_side = (
        (train_losses, "Train Loss", val_losses, "Val Loss", "Loss", "Loss Curves"),
        (train_accs, "Train Acc", val_accs, "Val Acc", "Accuracy", "Accuracy Curves"),
    )
    for ax, (train_y, train_label, val_y, val_label, y_label, heading) in zip(axes, side_by_side):
        ax.plot(train_y, label=train_label, marker="o")
        ax.plot(val_y, label=val_label, marker="s")
        ax.set_title(heading)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True, **grid_style)

    plt.suptitle("MiniBERT Training Metrics", fontsize=14, fontweight="bold")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "metrics_curves.png"))
    plt.close()

    print(f"📊 Metric curves saved in {out_dir}")
