In [None]:
pip install --quiet torch==2.6.0 torchvision==0.21.0 einops scipy

In [None]:
#!/usr/bin/env python3
"""
torus_transformer_mnist.py
14-layer residual Transformer – TORUS observer-mute test (MNIST 10 k subset)
Outputs drift_summary.json with χ² Benford divergence.
"""

import torch, torch.nn as nn, torch.optim as optim, torchvision as tv
from itertools import product; import math, random, json
from scipy.stats import chisquare

# ---------- CONFIG ----------
# Hyper-parameters shared by every run in this script.
blocks   = 14           # number of stacked encoder layers (num_layers in make_encoder)
width    = 128          # model dimension: encoder d_model and read-out input size
epochs   = 5            # 5 epochs over 10 k imgs ≈ 30 min CPU
batch    = 128          # DataLoader batch size
lr       = 1e-3         # SGD learning rate
seeds    = range(5)     # each seed is run once per condition (live / mute)
# -----------------------------

# ----- data: first 10 000 images of the MNIST training split -----
# The full training set is downloaded into the working directory and each
# image is converted to a tensor by ToTensor().
ds_all = tv.datasets.MNIST(
    root=".",
    train=True,
    download=True,
    transform=tv.transforms.ToTensor(),
)
ds = torch.utils.data.Subset(ds_all, range(10_000))
# drop_last=True discards the final short batch, so every batch yielded by
# the loader holds exactly `batch` samples.
loader = torch.utils.data.DataLoader(
    ds,
    batch_size=batch,
    shuffle=True,
    drop_last=True,
)

def make_encoder():
    """Build a fresh Transformer encoder stack.

    The stack has ``blocks`` layers. Each layer is configured with model
    dimension ``width``, 4 attention heads, a 256-unit feed-forward
    sub-layer, dropout 0.1, and batch-first input layout.

    Returns:
        A new ``nn.TransformerEncoder`` instance.
    """
    layer_kwargs = {
        "d_model": width,
        "nhead": 4,
        "dim_feedforward": 256,
        "dropout": 0.1,
        "batch_first": True,
    }
    template_layer = nn.TransformerEncoderLayer(**layer_kwargs)
    return nn.TransformerEncoder(template_layer, num_layers=blocks)

device  = "cpu"
summary = {"live": [], "mute": []}


def _benford_chi2(values):
    """Return the χ² statistic of the leading digits of *values* vs. Benford's law.

    Only finite, strictly positive entries contribute; zeros, negatives,
    NaN and infinities have no usable leading digit and are skipped.

    Args:
        values: iterable of floats.

    Returns:
        The χ² statistic as a plain ``float``, or ``nan`` when no entry
        is usable (an empty sample has no meaningful divergence).
    """
    # The first character of the scientific-notation string is the leading
    # significant digit (1-9) for any finite positive number.
    digits = [int(f"{v:.6e}"[0]) for v in values
              if math.isfinite(v) and v > 0]
    if not digits:
        return float("nan")

    counts = [0] * 9
    for d in digits:
        counts[d - 1] += 1
    # Benford's law: P(d) = log10(1 + 1/d), scaled to the sample size.
    expected = [math.log10(1 + 1 / d) * len(digits) for d in range(1, 10)]
    stat, _ = chisquare(counts, expected)
    return float(stat)


# Two conditions are compared: "live" trains everything, "mute" zeroes the
# read-out layer's gradients so that layer never updates.
for mute in (False, True):
    label = "mute" if mute else "live"
    for seed in seeds:
        torch.manual_seed(seed); random.seed(seed)
        net   = make_encoder().to(device)
        read  = nn.Linear(width, 10).to(device)
        # Random projection from 28 pixel columns to `width` features.
        # Drawn ONCE per run so it really is static; it is sampled after the
        # model is built so the model's initial weights for a given seed are
        # the same as before this change.
        proj  = torch.randn(28, width, device=device)
        opt   = optim.SGD(list(net.parameters())+list(read.parameters()), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        upd_norms = []

        for _ in range(epochs):
            for imgs, lbl in loader:
                # Treat each image as a sequence of 28 rows of 28 pixels and
                # project every row to `width` features. Using -1 for the
                # batch dimension avoids depending on a fixed batch size.
                imgs = imgs.to(device).view(-1, 28, 28)
                lbl  = lbl.to(device)
                seq  = imgs @ proj                         # shape (B, 28, width)

                opt.zero_grad(set_to_none=True)
                enc = net(seq)

                # classify from the last sequence position
                out = read(enc[:, -1])
                loss = loss_fn(out, lbl)

                loss.backward()
                if mute:
                    read.weight.grad.zero_()   # block update of last layer
                    read.bias.grad.zero_()

                opt.step()

                # --- collect weight-update size (L2) for Benford ---
                # The optimizer is plain SGD (only lr is set), so the step
                # applied to the encoder is lr * grad; record its global norm.
                grads = [p.grad.flatten() for p in net.parameters()
                         if p.grad is not None]
                upd   = lr * torch.cat(grads).norm().item()
                upd_norms.append(upd)

        # ------ Benford χ² divergence ------
        chi2 = _benford_chi2(upd_norms)
        summary[label].append(chi2)
        print(f"{label} seed {seed}: χ² = {chi2:7.2f}")

print("\nMean χ²  ->", {k: sum(v)/len(v) for k,v in summary.items()})
# Use a context manager so the file is flushed and closed deterministically.
with open("drift_summary.json", "w") as fh:
    json.dump(summary, fh)
print("Saved drift_summary.json")

100%|██████████| 9.91M/9.91M [00:00<00:00, 42.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.17MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 8.42MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.92MB/s]


live seed 0: χ² =  513.23
live seed 1: χ² =  395.30
live seed 2: χ² =  571.49
live seed 3: χ² =  565.09
live seed 4: χ² =  401.35
mute seed 0: χ² =  431.36
mute seed 1: χ² =  405.14
mute seed 2: χ² =  588.06
mute seed 3: χ² =  560.59
mute seed 4: χ² =  343.18

Mean χ²  -> {'live': np.float64(489.2911303828888), 'mute': np.float64(465.66543603841166)}
Saved drift_summary.json
