# Implementation 

This notebook is structured as follows:
- implementation of uSCION and SCION
- Experiment A: performance comparison of SCION/uSCION vs AdamW/SGD/Muon on Fashion-MNIST
- Experiment B: hyperparameter transfer across widths on SVHN
- Experiment C: norm control, SCG vs uSCG
- Experiment D: MiniGPT on WikiText2 (SCION vs AdamW)

In [14]:
import os, math, time, random
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as T

import matplotlib.pyplot as plt

# Report the installed PyTorch version, then choose the compute device used by
# the rest of the notebook: CUDA if available, otherwise MPS, otherwise CPU.
print("torch:", torch.__version__)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
device  # last expression of the cell, so the notebook displays the choice

torch: 2.10.0


device(type='mps')

In [24]:
def set_seed(seed: int = 0):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also sets the cuDNN flags ``deterministic=False`` and ``benchmark=True``.
    NOTE(review): with these flags cuDNN runs are presumably not bit-for-bit
    reproducible even with a fixed seed -- confirm this trade-off is intended.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True


set_seed(0)

## SCION / uSCION recap (implementation choices)

We implement two variants:

- **uSCION (unconstrained)**:  x <- x + γ * LMO(d)
- **SCION (constrained / SCG-style)**: x <- (1-γ) x + γ * ρ * LMO(d)

where:
- d is a momentum-filtered stochastic gradient direction
- LMO(d) returns an extreme point of the unit ball of a chosen norm. We implement:
  - sign-LMO for every non-2D parameter (biases, LayerNorm, conv kernels, etc.)
  - spectral/polar-LMO for 2D matrix parameters
- ρ is an optional radius (default 1.0)

Notes:
- This notebook uses a *practical* per-parameter rule: matrices -> spectral/polar LMO, others -> sign LMO.
- You can easily swap the LMO mapping if your paper uses a specific table of norms per layer.

In [25]:
@torch.no_grad()
def lmo_sign(g: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Sign LMO: argmin_{||s||_inf <= 1} <s, g> = -sign(g).

    ``eps`` is accepted for interface compatibility but is not used.
    """
    signs = torch.sign(g)
    bounded = torch.clamp(signs, min=-1.0, max=1.0)
    return torch.neg(bounded)

@torch.no_grad()
def newton_schulz_polar_factor(G: torch.Tensor, iters: int = 5, eps: float = 1e-6) -> torch.Tensor:
    """
    Approximate the polar factor Q of a 2D matrix G (G = Q H, Q orthogonal-ish).

    The inverse square root of the smaller Gram matrix is approximated with
    coupled Newton–Schulz iterations and applied to the Frobenius-normalized
    input X:
        tall / square (m >= n):  Q ≈ X (X^T X)^(-1/2)
        wide          (m <  n):  Q ≈ (X X^T)^(-1/2) X

    Args:
        G: 2D tensor.
        iters: number of Newton–Schulz iterations.
        eps: added to norms before dividing; a matrix whose Frobenius norm is
            below ``eps`` yields an all-zero result.

    Returns:
        Tensor with the same shape as G. The result is NOT negated; callers
        that need the LMO direction must use ``-Q``.

    Works best for reasonably-conditioned matrices; for small models/demos it's fine.
    """
    assert G.ndim == 2
    # scale by Fro norm to keep spectral radius in a good range
    fro = torch.linalg.norm(G, ord="fro")
    if fro < eps:
        return torch.zeros_like(G)

    X = G / (fro + eps)
    m, n = X.shape
    tall = m >= n

    # Iterate on the smaller Gram matrix: X^T X (n x n) if tall, X X^T (m x m) if wide.
    A = X.T @ X if tall else X @ X.T
    # The Gram matrix is normalized for convergence; remember the factor so the
    # inverse square root can be rescaled afterwards.
    scale = torch.linalg.norm(A, ord="fro") + eps
    Y = A / scale

    eye = torch.eye(A.shape[0], device=X.device, dtype=X.dtype)
    three_eye = 3.0 * eye  # loop-invariant
    Z = eye
    for _ in range(iters):
        Tm = 0.5 * (three_eye - Z @ Y)
        Y = Y @ Tm
        Z = Tm @ Z

    # Z approximates (A / scale)^(-1/2) = sqrt(scale) * A^(-1/2); divide the
    # normalization back out so that Q has singular values close to 1.
    inv_sqrt = Z / torch.sqrt(scale)
    return X @ inv_sqrt if tall else inv_sqrt @ X

@torch.no_grad()
def lmo_polar(G: torch.Tensor, iters: int = 5) -> torch.Tensor:
    """Spectral/polar LMO: the negated approximate polar factor of ``G``."""
    polar = newton_schulz_polar_factor(G, iters=iters)
    return torch.neg(polar)

@torch.no_grad()
def lmo_dispatch(g: torch.Tensor, polar_iters: int = 5) -> torch.Tensor:
    """Pick the LMO by tensor rank: polar LMO for 2D tensors, sign LMO otherwise."""
    is_matrix = g.ndim == 2
    return lmo_polar(g, iters=polar_iters) if is_matrix else lmo_sign(g)

In [26]:
class SCION(torch.optim.Optimizer):
    """
    Optimizer implementing both SCION (constrained) and uSCION (unconstrained).

    Each parameter keeps an exponential moving average of its gradients. The
    average is passed to ``lmo_dispatch`` and the returned direction ``s`` is
    applied as
      - constrained:   x <- (1 - lr) * x + lr * radius * s
      - unconstrained: x <- x + lr * s

    Params:
      lr: step size γ (must be > 0)
      momentum: EMA coefficient beta in [0,1)
      constrained: True for the convex-combination update (SCION/SCG),
                   False for the additive update (uSCION/uSCG)
      radius: ρ, scales the LMO output in constrained mode only
      polar_iters: Newton–Schulz iteration count forwarded to ``lmo_dispatch``
      weight_decay: when non-zero, parameters are first multiplied by
                    (1 - lr * weight_decay)
    """

    def __init__(
        self,
        params,
        lr: float = 1e-2,
        momentum: float = 0.9,
        constrained: bool = False,
        radius: float = 1.0,
        polar_iters: int = 5,
        weight_decay: float = 0.0,
    ):
        if lr <= 0:
            raise ValueError("lr must be > 0")
        if not (0.0 <= momentum < 1.0):
            raise ValueError("momentum must be in [0,1)")
        hyperparams = {
            "lr": lr,
            "momentum": momentum,
            "constrained": constrained,
            "radius": radius,
            "polar_iters": polar_iters,
            "weight_decay": weight_decay,
        }
        super().__init__(params, hyperparams)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one update; returns the closure's loss, or None without a closure."""
        loss = None
        if closure is not None:
            # The closure needs autograd even though step() itself runs under no_grad.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            step_size = group["lr"]
            beta = group["momentum"]
            use_constraint = group["constrained"]
            rho = group["radius"]
            n_polar = group["polar_iters"]
            decay = group["weight_decay"]

            for param in group["params"]:
                grad = param.grad
                if grad is None:
                    continue

                # Multiplicative weight decay, applied before the main update.
                if decay != 0.0:
                    param.mul_(1.0 - step_size * decay)

                # Lazily create the per-parameter gradient EMA.
                state = self.state[param]
                ema = state.get("buf")
                if ema is None:
                    ema = state["buf"] = torch.zeros_like(grad)

                ema.mul_(beta).add_(grad, alpha=(1.0 - beta))

                direction = lmo_dispatch(ema, polar_iters=n_polar)

                if use_constraint:
                    # x <- (1-γ)x + γ * ρ*s
                    param.mul_(1.0 - step_size)
                    param.add_(direction, alpha=step_size * rho)
                else:
                    # x <- x + γ * s
                    param.add_(direction, alpha=step_size)

        return loss

In [27]:
def make_optimizer(name: str, model: nn.Module, lr: float, wd: float = 0.0, momentum: float = 0.9):
    """Build the optimizer(s) for ``model`` selected by ``name`` (case-insensitive).

    Supported names: "adamw", "sgd", "muon", "scion", "uscion".

    Args:
        name: optimizer name.
        model: module whose parameters are optimized.
        lr: learning rate.
        wd: weight decay passed to the optimizer.
        momentum: momentum used by SGD, SCION and uSCION.

    Returns:
        A single optimizer, except for "muon" when ``torch.optim.Muon`` exists:
        then a tuple ``(muon_for_2d_params, adamw_for_the_rest)``.

    Raises:
        ValueError: if ``name`` is not one of the supported names.
    """
    name = name.lower()

    if name == "adamw":
        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    if name == "sgd":
        # torch.optim.SGD rejects nesterov=True together with momentum == 0,
        # so only enable Nesterov when there is momentum to apply it to.
        return torch.optim.SGD(
            model.parameters(),
            lr=lr,
            momentum=momentum,
            weight_decay=wd,
            nesterov=momentum > 0,
        )

    if name == "muon":
        # PyTorch has torch.optim.Muon in newer versions. Otherwise, fallback.
        if hasattr(torch.optim, "Muon"):
            # IMPORTANT: Muon is meant for 2D weight matrices; bias/embeddings typically use AdamW.
            # Here we do a simple split: 2D -> Muon, rest -> AdamW
            params = list(model.parameters())
            mat_params = [p for p in params if p.ndim == 2]
            other_params = [p for p in params if p.ndim != 2]

            opt_muon = torch.optim.Muon(mat_params, lr=lr, weight_decay=wd)
            opt_other = torch.optim.AdamW(other_params, lr=lr, weight_decay=wd)
            return (opt_muon, opt_other)
        print("WARNING: torch.optim.Muon not found; falling back to AdamW.")
        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    if name == "scion":
        return SCION(model.parameters(), lr=lr, momentum=momentum, constrained=True, radius=1.0, weight_decay=wd)

    if name == "uscion":
        return SCION(model.parameters(), lr=lr, momentum=momentum, constrained=False, radius=1.0, weight_decay=wd)

    raise ValueError(f"Unknown optimizer: {name}")

In [28]:
class SmallCNN(nn.Module):
    """Four 3x3 conv+ReLU layers (two at ``width``, two at ``2*width`` channels),
    each pair followed by 2x2 max-pooling, then global average pooling and a
    dropout + linear classification head.
    """

    def __init__(self, in_ch: int, num_classes: int, width: int = 64, dropout: float = 0.1):
        super().__init__()
        wide = 2 * width

        def conv_relu(c_in, c_out):
            # 3x3 convolution with padding 1, followed by an in-place ReLU
            return [nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(inplace=True)]

        stage1 = conv_relu(in_ch, width) + conv_relu(width, width) + [nn.MaxPool2d(2)]
        stage2 = conv_relu(width, wide) + conv_relu(wide, wide) + [nn.MaxPool2d(2)]
        self.net = nn.Sequential(*stage1, *stage2, nn.AdaptiveAvgPool2d(1))

        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(wide, num_classes),
        )

    def forward(self, x):
        features = self.net(x)
        logits = self.head(features)
        return logits

In [29]:
def get_fashion_mnist(batch_size: int = 128, num_workers: int = 0):
    """Return ``(train_loader, test_loader)`` for Fashion-MNIST.

    The dataset is downloaded to ``./data`` if missing. Images are converted to
    tensors and normalized with mean 0.5 and std 0.5. Only the training loader
    shuffles.
    """
    tfm = T.Compose([
        T.ToTensor(),
        T.Normalize((0.5,), (0.5,))
    ])
    train = torchvision.datasets.FashionMNIST(root="./data", train=True, download=True, transform=tfm)
    test  = torchvision.datasets.FashionMNIST(root="./data", train=False, download=True, transform=tfm)

    # Pinned host memory only speeds up host-to-CUDA copies; requesting it
    # without CUDA (e.g. on MPS or CPU) just makes DataLoader emit a warning.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin)
    test_loader  = DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin)
    return train_loader, test_loader

def get_svhn(batch_size: int = 128, num_workers: int = 0):
    """Return ``(train_loader, test_loader)`` for SVHN ("train" / "test" splits).

    The dataset is downloaded to ``./data`` if missing. Images are converted to
    tensors and normalized per channel with mean 0.5 and std 0.5. Only the
    training loader shuffles.
    """
    tfm = T.Compose([
        T.ToTensor(),
        T.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)),
    ])
    train = torchvision.datasets.SVHN(root="./data", split="train", download=True, transform=tfm)
    test  = torchvision.datasets.SVHN(root="./data", split="test",  download=True, transform=tfm)

    # Pinned host memory only speeds up host-to-CUDA copies; requesting it
    # without CUDA (e.g. on MPS or CPU) just makes DataLoader emit a warning.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin)
    test_loader  = DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin)
    return train_loader, test_loader

In [30]:
@torch.no_grad()
def accuracy(logits, y):
    """Fraction of rows of ``logits`` whose argmax over dim 1 equals ``y`` (Python float)."""
    predictions = torch.argmax(logits, dim=1)
    hits = predictions.eq(y).float()
    return hits.mean().item()

def run_epoch(model, loader, optimizer, train: bool = True):
    """Run one pass over ``loader`` and return ``(mean_loss, mean_accuracy)``.

    Args:
        model: module to train or evaluate; put in train/eval mode via ``train``.
        loader: iterable of ``(x, y)`` batches.
        optimizer: a single optimizer or a tuple of optimizers that are all
            stepped on every batch. Not used when ``train`` is False.
        train: if True, back-propagate and step the optimizer(s); if False,
            only evaluate (with autograd disabled).

    Returns:
        Cross-entropy loss and accuracy, each averaged over all samples seen.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.train(train)
    # Treat a single optimizer as a 1-tuple so both cases share one code path.
    optimizers = optimizer if isinstance(optimizer, tuple) else (optimizer,)

    total_loss = 0.0
    total_acc = 0.0
    n = 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)

        if train:
            for opt in optimizers:
                opt.zero_grad(set_to_none=True)

        # Only record an autograd graph when we are going to call backward().
        with torch.set_grad_enabled(train):
            logits = model(x)
            loss = F.cross_entropy(logits, y)

        if train:
            loss.backward()
            for opt in optimizers:
                opt.step()

        # Weight per-batch means by batch size so the final average is per sample.
        bs = x.size(0)
        total_loss += loss.item() * bs
        total_acc  += accuracy(logits, y) * bs
        n += bs

    return total_loss / n, total_acc / n

def fit(model, train_loader, test_loader, optimizer, epochs: int = 10):
    """Train for ``epochs`` epochs, evaluating on ``test_loader`` after each one.

    Prints one summary line per epoch and returns a DataFrame with columns
    epoch, train_loss, train_acc, test_loss, test_acc and sec (wall-clock
    seconds spent on the train + test pass).
    """
    records = []
    ep = 0
    while ep < epochs:
        ep += 1
        started = time.time()
        tr_loss, tr_acc = run_epoch(model, train_loader, optimizer, train=True)
        te_loss, te_acc = run_epoch(model, test_loader, optimizer, train=False)
        dt = time.time() - started
        records.append({
            "epoch": ep,
            "train_loss": tr_loss,
            "train_acc": tr_acc,
            "test_loss": te_loss,
            "test_acc": te_acc,
            "sec": dt,
        })
        print(f"ep {ep:02d} | tr {tr_loss:.4f}/{tr_acc:.4f} | te {te_loss:.4f}/{te_acc:.4f} | {dt:.1f}s")
    return pd.DataFrame(records)

## Experiment A — Fashion-MNIST

Compare:
- AdamW
- SGD (Nesterov)
- Muon (if available in torch.optim)
- SCION (constrained)
- uSCION (unconstrained)

In [31]:
# Re-seed so Experiment A starts from a known RNG state, then build the
# Fashion-MNIST loaders shared (as globals) by every run_expA call.
set_seed(0)

train_loader, test_loader = get_fashion_mnist(batch_size=256)

def run_expA(optim_name: str, lr: float, wd: float = 0.0, momentum: float = 0.9, width: int = 64, epochs: int = 10):
    """Train a fresh 1-channel, 10-class SmallCNN with the named optimizer.

    Uses the module-level ``train_loader`` / ``test_loader`` and returns the
    per-epoch history from ``fit`` with the run's optimizer, lr, wd and width
    added as columns.
    """
    net = SmallCNN(in_ch=1, num_classes=10, width=width).to(device)
    optimizer = make_optimizer(optim_name, net, lr=lr, wd=wd, momentum=momentum)
    history = fit(net, train_loader, test_loader, optimizer, epochs=epochs)

    # Tag every row with the configuration that produced it.
    run_tags = (("optimizer", optim_name), ("lr", lr), ("wd", wd), ("width", width))
    for column, value in run_tags:
        history[column] = value
    return history

# Reasonable starting points (adjust if needed)
# Each entry is (optimizer name, learning rate, weight decay).
configs = [
    ("adamw",  1e-3, 1e-4),
    ("sgd",    5e-2, 5e-4),
    ("muon",   2e-3, 1e-4),
    ("scion",  5e-2, 0.0),
    ("uscion", 5e-2, 0.0),
]

# Train one fresh model per configuration and collect the per-epoch histories.
dfs = []
for name, lr, wd in configs:
    print("\n===", name, "===")
    dfs.append(run_expA(name, lr=lr, wd=wd, epochs=10))

# Stack all histories into one long table; the "optimizer" column tells runs apart.
dfA = pd.concat(dfs, ignore_index=True)
dfA.tail()


=== adamw ===


  super().__init__(loader)


ep 01 | tr 1.0309/0.6206 | te 0.6506/0.7601 | 13.7s
ep 02 | tr 0.5885/0.7905 | te 0.5270/0.8171 | 11.0s
ep 03 | tr 0.4842/0.8268 | te 0.4873/0.8209 | 11.3s
ep 04 | tr 0.4261/0.8464 | te 0.3865/0.8611 | 12.0s
ep 05 | tr 0.3816/0.8621 | te 0.3618/0.8738 | 11.9s
ep 06 | tr 0.3551/0.8725 | te 0.3757/0.8688 | 11.1s
ep 07 | tr 0.3351/0.8793 | te 0.3281/0.8845 | 11.1s
ep 08 | tr 0.3154/0.8869 | te 0.3251/0.8856 | 12.4s
ep 09 | tr 0.3046/0.8904 | te 0.3044/0.8920 | 12.6s
ep 10 | tr 0.2928/0.8941 | te 0.2967/0.8950 | 11.5s

=== sgd ===
ep 01 | tr 1.2860/0.5164 | te 0.7170/0.7386 | 12.2s


KeyboardInterrupt: 

In [None]:
def plot_metric(df, metric: str, title: str):
    """Plot ``metric`` against epoch in a new figure, one line per ``optimizer`` value."""
    plt.figure()
    for label, rows in df.groupby("optimizer"):
        epochs = rows["epoch"]
        values = rows[metric]
        plt.plot(epochs, values, label=label)
    plt.title(title)
    plt.xlabel("epoch")
    plt.ylabel(metric)
    plt.legend()
    plt.show()

# Per-optimizer test accuracy and test loss curves for Experiment A.
plot_metric(dfA, "test_acc",  "Experiment A — FashionMNIST test accuracy")
plot_metric(dfA, "test_loss", "Experiment A — FashionMNIST test loss")

## Experiment B — SVHN hyperparameter transfer across widths

Protocol:
1) Choose a base width (e.g., 64) and tune γ (learning rate) on it.
2) Reuse the best γ for larger widths (e.g., 128, 256).
3) Compare how stable the transferred hyperparameter is (especially for SCION/uSCION vs AdamW/SGD).

In [None]:
# Re-seed for Experiment B and build the SVHN loaders shared (as globals)
# by every run_once_svhn call.
set_seed(1)
svhn_train, svhn_test = get_svhn(batch_size=256)

In [None]:
def run_once_svhn(optim_name: str, lr: float, width: int, epochs: int = 5, wd: float = 0.0):
    """Train a fresh 3-channel, 10-class SmallCNN on SVHN with the named optimizer.

    Uses the module-level ``svhn_train`` / ``svhn_test`` loaders and returns the
    per-epoch history from ``fit`` with the run's optimizer, lr, width and wd
    added as columns.
    """
    net = SmallCNN(in_ch=3, num_classes=10, width=width).to(device)
    optimizer = make_optimizer(optim_name, net, lr=lr, wd=wd)
    history = fit(net, svhn_train, svhn_test, optimizer, epochs=epochs)

    # Tag every row with the configuration that produced it.
    run_tags = (("optimizer", optim_name), ("lr", lr), ("width", width), ("wd", wd))
    for column, value in run_tags:
        history[column] = value
    return history

# Learning-rate sweep at the base width: every optimizer is trained once per
# grid value for 5 epochs; only AdamW uses weight decay (1e-4).
base_width = 64
grid = [1e-4, 3e-4, 1e-3, 3e-3, 1e-2]  # keep small; SVHN is harder

optimizers_B = ["adamw", "sgd", "scion", "uscion"]  # include muon if you want too

dfs = []
for opt_name in optimizers_B:
    for lr in grid:
        print(f"\n[SVHN tune] opt={opt_name} width={base_width} lr={lr}")
        dfs.append(run_once_svhn(opt_name, lr=lr, width=base_width, epochs=5, wd=1e-4 if opt_name=="adamw" else 0.0))

# One long table with every (optimizer, lr) run of the sweep.
dfB_tune = pd.concat(dfs, ignore_index=True)
dfB_tune.tail()

In [None]:
# pick best lr by best test_acc at last epoch
# NOTE(review): the selection criterion is accuracy on svhn_test, i.e. the
# learning rate is tuned on the test split -- confirm this is acceptable.
# `last`: the final-epoch row of every (optimizer, lr, width) run.
last = dfB_tune.sort_values("epoch").groupby(["optimizer","lr","width"]).tail(1)
# `best`: for each optimizer, the final-epoch row with the highest test accuracy.
best = last.sort_values("test_acc", ascending=False).groupby("optimizer").head(1)
# Map optimizer name -> its selected learning rate.
best_lrs = {row["optimizer"]: float(row["lr"]) for _, row in best.iterrows()}
best_lrs

In [None]:
# Transfer step: reuse each optimizer's learning rate selected at the base
# width for every width below, training 10 epochs per run.
widths = [64, 128, 256]
dfs = []

for opt_name in optimizers_B:
    lr = best_lrs[opt_name]
    for w in widths:
        print(f"\n[SVHN transfer] opt={opt_name} width={w} lr={lr}")
        dfs.append(run_once_svhn(opt_name, lr=lr, width=w, epochs=10, wd=1e-4 if opt_name=="adamw" else 0.0))

# One long table with every (optimizer, width) transfer run.
dfB = pd.concat(dfs, ignore_index=True)
dfB.tail()

In [None]:
# Plot last-epoch test accuracy against width, one line per optimizer.
plt.figure()
for opt_name in optimizers_B:
    runs = dfB[dfB["optimizer"] == opt_name]
    # Keep the final-epoch row of every width, then order by width so the line
    # is drawn left to right regardless of how the epoch sort broke ties.
    sub_last = runs.sort_values("epoch").groupby("width").tail(1).sort_values("width")
    plt.plot(sub_last["width"], sub_last["test_acc"], marker="o", label=opt_name)

plt.xlabel("width")
plt.ylabel("test_acc (last epoch)")
plt.title("Experiment B — SVHN hyperparam transfer across widths (best lr @ width=64)")
plt.legend()
plt.show()

## Experiment C — Norm control: constrained (SCION/SCG) vs unconstrained (uSCION/uSCG)

We log per-epoch statistics of weight matrices, e.g.:
- spectral norm (approx via power iteration)
- Frobenius norm
and compare trajectories for SCION vs uSCION.

In [None]:
@torch.no_grad()
def power_iteration_spectral_norm(W: torch.Tensor, iters: int = 10, eps: float = 1e-12) -> float:
    """Estimate the spectral norm ||W||_2 of a 2D tensor by power iteration.

    Args:
        W: tensor to measure; anything that is not 2D yields NaN.
        iters: number of power-iteration rounds (each applies W and W^T once).
        eps: added to vector norms before dividing.

    Returns:
        The estimate as a Python float. The start vector is drawn with
        ``torch.randn``, so the global RNG is consumed and the result for small
        ``iters`` depends on the RNG state. With ``iters <= 0`` no round is run
        and ``||W v||`` for the random unit start vector is returned.
    """
    # approximate ||W||_2 for 2D tensor
    if W.ndim != 2:
        return float("nan")
    n = W.shape[1]
    v = torch.randn(n, device=W.device, dtype=W.dtype)
    v = v / (v.norm() + eps)
    u = None
    for _ in range(iters):
        u = W @ v
        u = u / (u.norm() + eps)
        v = W.T @ u
        v = v / (v.norm() + eps)

    Wv = W @ v
    if u is None:
        # No round ran, so there is no left vector to project on.
        return Wv.norm().item()
    return (u @ Wv).abs().item()

@torch.no_grad()
def model_norm_stats(model: nn.Module) -> Dict[str, float]:
    """Summarize the norms of the model's 2D parameters.

    Returns a dict with keys ``spec_mean``, ``spec_max`` (power-iteration
    spectral-norm estimates) and ``fro_mean``, ``fro_max`` (Frobenius norms).
    All four are NaN when the model has no 2D parameter.
    """
    matrices = [p for p in model.parameters() if p.ndim == 2]
    if not matrices:
        return dict.fromkeys(("spec_mean", "spec_max", "fro_mean", "fro_max"), float("nan"))

    spectral = [power_iteration_spectral_norm(p) for p in matrices]
    frobenius = [torch.linalg.norm(p, ord="fro").item() for p in matrices]
    return {
        "spec_mean": float(np.mean(spectral)),
        "spec_max": float(np.max(spectral)),
        "fro_mean": float(np.mean(frobenius)),
        "fro_max": float(np.max(frobenius)),
    }

In [None]:
def fit_with_norm_logs(model, train_loader, test_loader, optimizer, epochs: int = 10):
    """Like ``fit``, but also logs ``model_norm_stats`` after every epoch.

    Returns a DataFrame with the columns epoch, train_loss, train_acc,
    test_loss, test_acc, sec, followed by the norm statistics. ``sec`` includes
    the time spent computing the statistics.
    """
    records = []
    ep = 0
    while ep < epochs:
        ep += 1
        started = time.time()
        tr_loss, tr_acc = run_epoch(model, train_loader, optimizer, train=True)
        te_loss, te_acc = run_epoch(model, test_loader, optimizer, train=False)
        stats = model_norm_stats(model)
        dt = time.time() - started
        record = {
            "epoch": ep,
            "train_loss": tr_loss,
            "train_acc": tr_acc,
            "test_loss": te_loss,
            "test_acc": te_acc,
            "sec": dt,
        }
        record.update(stats)
        records.append(record)
        print(f"ep {ep:02d} | te_acc {te_acc:.4f} | spec_max {stats['spec_max']:.3f} | {dt:.1f}s")
    return pd.DataFrame(records)

In [None]:
# Re-seed for Experiment C and rebuild the Fashion-MNIST loaders; this rebinds
# the global train_loader / test_loader that run_expC reads.
set_seed(2)
train_loader, test_loader = get_fashion_mnist(batch_size=256)

def run_expC(constrained: bool, lr: float = 5e-2, width: int = 64, epochs: int = 10):
    """Train a fresh SmallCNN with SCION (constrained) or uSCION and log weight norms.

    Uses the module-level ``train_loader`` / ``test_loader`` and returns the
    history from ``fit_with_norm_logs`` with ``variant``, ``lr`` and ``width``
    columns added.
    """
    net = SmallCNN(in_ch=1, num_classes=10, width=width).to(device)
    optimizer = SCION(
        net.parameters(),
        lr=lr,
        momentum=0.9,
        constrained=constrained,
        radius=1.0,
        weight_decay=0.0,
    )
    history = fit_with_norm_logs(net, train_loader, test_loader, optimizer, epochs=epochs)

    if constrained:
        variant = "SCION(constrained)"
    else:
        variant = "uSCION(unconstrained)"
    history["variant"] = variant
    history["lr"] = lr
    history["width"] = width
    return history

# Same learning rate and epoch budget for both variants; only the update rule differs.
dfC1 = run_expC(constrained=True,  lr=5e-2, epochs=10)
dfC2 = run_expC(constrained=False, lr=5e-2, epochs=10)
# Stack both histories; the "variant" column tells them apart.
dfC = pd.concat([dfC1, dfC2], ignore_index=True)
dfC.tail()

In [None]:
def plot_norm(df, col: str, title: str):
    """Plot column ``col`` against epoch in a new figure, one line per ``variant`` value."""
    plt.figure()
    for label, rows in df.groupby("variant"):
        epochs = rows["epoch"]
        values = rows[col]
        plt.plot(epochs, values, label=label)
    plt.title(title)
    plt.xlabel("epoch")
    plt.ylabel(col)
    plt.legend()
    plt.show()

# Norm trajectories for both variants, then test accuracy. plot_metric groups
# by an "optimizer" column, so "variant" is renamed for that call only.
plot_norm(dfC, "spec_max", "Experiment C — spectral norm max (matrices)")
plot_norm(dfC, "fro_max",  "Experiment C — Frobenius norm max (matrices)")
plot_metric(dfC.rename(columns={"variant":"optimizer"}), "test_acc", "Experiment C — test accuracy")

## Checklist for submission

- [ ] uSCION and SCION implementation included
- [ ] Experiment A: Fashion-MNIST comparison vs Adam/SGD/Muon
- [ ] Experiment B: SVHN hyperparameter transfer across widths
- [ ] Experiment C: Norm control constrained vs unconstrained (norm logs + plot)
- [ ] Run on data not used in the original paper (Fashion-MNIST + SVHN satisfy this)

# EXPERIMENT D — MiniGPT + WikiText2

In [None]:
import torchtext
from torchtext.datasets import WikiText2
from collections import Counter
from torchtext.vocab import vocab

# ---- Tokenization ----
# Module-level tokenizer shared by build_vocab and data_process.
tokenizer = torchtext.data.utils.get_tokenizer("basic_english")

def build_vocab():
    """Build the vocabulary from the tokenized WikiText2 training split.

    ``<unk>`` is registered as a special token and used as the default index
    for tokens that are not in the vocabulary.
    """
    token_counts = Counter(
        token
        for line in WikiText2(split="train")
        for token in tokenizer(line)
    )
    vocabulary = vocab(token_counts, specials=["<unk>"])
    unk_index = vocabulary["<unk>"]
    vocabulary.set_default_index(unk_index)
    return vocabulary

# Build the vocabulary once; vocab_size is read as a global by the GPT
# training and evaluation code.
vocab_obj = build_vocab()
vocab_size = len(vocab_obj)
print("Vocab size:", vocab_size)

In [None]:
def data_process(raw_text_iter):
    """Tokenize every item of ``raw_text_iter`` and return all token ids,
    concatenated in order, as one 1D ``torch.long`` tensor.
    """
    token_ids = [
        vocab_obj[token]
        for item in raw_text_iter
        for token in tokenizer(item)
    ]
    return torch.tensor(token_ids, dtype=torch.long)

# Flat 1D token-id tensors for the three WikiText2 splits.
train_data = data_process(WikiText2(split="train"))
val_data   = data_process(WikiText2(split="valid"))
test_data  = data_process(WikiText2(split="test"))

def batchify(data, batch_size):
    """Split a 1D tensor into ``batch_size`` parallel streams.

    Trailing elements that do not fill a complete column are dropped. The
    result has shape ``(len(data) // batch_size, batch_size)``; column ``b``
    holds the ``b``-th contiguous chunk of ``data``.
    """
    steps = data.size(0) // batch_size
    trimmed = data[:steps * batch_size]
    streams = trimmed.view(batch_size, -1)
    return streams.transpose(0, 1).contiguous()

def get_batch(source, i, seq_len=128):
    """Return ``(data, target)`` slices of ``source`` starting at row ``i``.

    ``target`` is ``data`` shifted forward by one row. The window is shortened
    near the end so that ``target`` never runs past the last row.
    """
    remaining = len(source) - 1 - i
    span = min(seq_len, remaining)
    stop = i + span
    return source[i:stop], source[i + 1:stop + 1]

In [None]:
class MiniGPT(nn.Module):
    """Small causal language model: token embedding plus a learned positional
    parameter (1024 positions, initialized to zero), a stack of
    ``nn.TransformerEncoderLayer`` blocks run under a causal mask, a final
    LayerNorm and a linear projection to vocabulary logits.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=6, dim_ff=512, dropout=0.1):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, 1024, d_model))

        block = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(block, num_layers)
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map token ids of shape (B, T) to logits of shape (B, T, vocab_size)."""
        _, seq_len = x.shape
        hidden = self.tok_emb(x) + self.pos_emb[:, :seq_len, :]

        # True strictly above the diagonal: position t may not attend to positions > t.
        causal_mask = torch.ones(seq_len, seq_len, device=x.device).triu(diagonal=1).bool()
        hidden = self.transformer(hidden, mask=causal_mask)
        return self.head(self.ln(hidden))

In [None]:
def evaluate_gpt(model, data_source, batch_size=32, seq_len=128):
    """Return the perplexity of ``model`` on ``data_source``.

    Args:
        model: language model called on a ``(B, T)`` id tensor and returning
            ``(B, T, V)`` logits.
        data_source: batchified ids of shape ``(steps, B)`` as produced by
            ``batchify``.
        batch_size: unused; kept for interface compatibility.
        seq_len: maximum window length per forward pass.

    Returns:
        ``exp`` of the mean per-token cross-entropy over all predicted tokens.

    Raises:
        ValueError: if ``data_source`` is too short to yield any target token.
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, seq_len):
            data, targets = get_batch(data_source, i, seq_len)
            data = data.to(device)
            targets = targets.to(device)
            output = model(data.T)  # (B, T, V)
            # The logits are flattened batch-major, so the (T, B) targets must
            # be transposed before flattening to line up token by token.
            flat_targets = targets.T.reshape(-1)
            # Sum (not mean) so a shorter final window is weighted by its token count.
            loss = F.cross_entropy(
                output.reshape(-1, output.size(-1)),
                flat_targets,
                reduction="sum",
            )
            total_loss += loss.item()
            total_tokens += flat_targets.numel()

    if total_tokens == 0:
        raise ValueError("data_source is too short to evaluate")
    return math.exp(total_loss / total_tokens)


def train_gpt(model, train_data, val_data, optimizer, epochs=3, batch_size=32, seq_len=128):
    """Train ``model`` as a next-token predictor and track validation perplexity.

    Args:
        model: language model called on a ``(B, T)`` id tensor and returning
            ``(B, T, V)`` logits.
        train_data, val_data: flat 1D token-id tensors; both are batchified
            with ``batch_size`` and moved to ``device``.
        optimizer: a single optimizer stepped once per window.
        epochs: number of passes over the training data.
        batch_size: number of parallel token streams.
        seq_len: maximum window length per step.

    Returns:
        List with the validation perplexity after each epoch.
    """
    train_data_b = batchify(train_data, batch_size).to(device)
    val_data_b   = batchify(val_data, batch_size).to(device)

    history = []

    for ep in range(1, epochs+1):
        model.train()

        for i in range(0, train_data_b.size(0)-1, seq_len):
            data, targets = get_batch(train_data_b, i, seq_len)
            data = data.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            output = model(data.T)  # (B, T, V)
            # The logits are flattened batch-major, so the (T, B) targets must
            # be transposed before flattening to line up token by token.
            loss = F.cross_entropy(
                output.reshape(-1, output.size(-1)),
                targets.T.reshape(-1)
            )
            loss.backward()
            optimizer.step()

        val_ppl = evaluate_gpt(model, val_data_b, batch_size, seq_len)
        print(f"Epoch {ep} | Val Perplexity: {val_ppl:.2f}")
        history.append(val_ppl)

    return history

In [None]:
# AdamW vs constrained SCION on MiniGPT, 3 epochs each.
# NOTE(review): the seed is set only once, before the first model is built, so
# the two models do not start from the same initial weights -- confirm intended.
set_seed(0)

model_adam = MiniGPT(vocab_size).to(device)
opt_adam = torch.optim.AdamW(model_adam.parameters(), lr=3e-4)

print("Training Adam...")
hist_adam = train_gpt(model_adam, train_data, val_data, opt_adam, epochs=3)

model_scion = MiniGPT(vocab_size).to(device)
opt_scion = SCION(model_scion.parameters(), lr=5e-3, constrained=True)

print("Training SCION...")
hist_scion = train_gpt(model_scion, train_data, val_data, opt_scion, epochs=3)

In [None]:
# Validation perplexity per epoch for both optimizers. The histories are plain
# lists, so the x axis is the 0-based list index.
plt.figure()
plt.plot(hist_adam, label="AdamW")
plt.plot(hist_scion, label="SCION")
plt.ylabel("Val Perplexity")
plt.xlabel("Epoch")
plt.legend()
plt.title("MiniGPT on WikiText2")
plt.show()

In [None]:
# Batch-size sweep for constrained SCION at a fixed learning rate (2 epochs each).
batch_sizes = [16, 32, 64]

# batch size -> validation perplexity after the last epoch
results_batch = {}

for bs in batch_sizes:
    print(f"\nBatch size {bs} — SCION")
    model = MiniGPT(vocab_size).to(device)
    opt = SCION(model.parameters(), lr=5e-3, constrained=True)
    hist = train_gpt(model, train_data, val_data, opt, epochs=2, batch_size=bs)
    results_batch[bs] = hist[-1]

In [None]:
# Width sweep for constrained SCION: the same learning rate is reused for every
# d_model (2 epochs each). Note this rebinds the global `widths`.
widths = [128, 256, 384]
lr_scion = 5e-3

# d_model -> validation perplexity after the last epoch
transfer_results = {}

for d_model in widths:
    print(f"\nWidth {d_model}")
    model = MiniGPT(vocab_size, d_model=d_model).to(device)
    opt = SCION(model.parameters(), lr=lr_scion, constrained=True)
    hist = train_gpt(model, train_data, val_data, opt, epochs=2)
    transfer_results[d_model] = hist[-1]