In [1]:
import math, torch
import torch.nn as nn
import torch.nn.functional as F

class AbcLinear(nn.Module):
    """
    A fully-connected layer that realises the abc-parametrisation:
        W = n^{-a} · w,       w_ij ~ N(0, n^{-2b})
    where n is the (potentially-infinite) input width.
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 a: float = 0.0,
                 b: float = 0.0,
                 bias: bool = False):
        super().__init__()
        self.in_features  = in_features
        self.out_features = out_features
        self.n_infty      = in_features            # width that may go →∞
        self.a, self.b    = a, b

        # “small-w” parameters (no scaling yet)
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

        # stash n_infty on the tensor so the optimizer can see it
        self.weight.n_infty = self.n_infty

    def reset_parameters(self):
        # Initialise  w_ij  ~ 𝒩(0, n^{-2b})
        std = self.n_infty ** (-self.b)
        with torch.no_grad():
            self.weight.normal_(0.0, std)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        scale = self.n_infty ** (-self.a)      # multiply by n^{-a}
        return F.linear(x, scale * self.weight, self.bias)

In [None]:
class AbcMLP(nn.Module):
    """
    Minimal 2-layer MLP (Input→H1→H2→Read-out) governed by per-layer (a,b)
    and a *global* exponent c that the optimiser will use.
    """
    def __init__(self,
                 width: int          = 256,
                 a_b_list: list      = None,   # [(a1,b1), (a2,b2), (a3,b3)]
                 act_fn              = F.relu):
        super().__init__()
        if a_b_list is None:
            # -- defaults reproduce μP -----------------------------
            a_b_list = [(0.0, 0.0),   # first layer
                        (0.0, 0.5),   # hidden
                        (0.0, 0.5)]   # read-out
        (a1,b1), (a2,b2), (a3,b3) = a_b_list

        self.act      = act_fn
        self.fc1      = AbcLinear(32*32*3, width, a=a1, b=b1, bias=False)
        self.fc2      = AbcLinear(width,     width, a=a2, b=b2, bias=False)
        self.readout  = AbcLinear(width,        10, a=a3, b=b3, bias=False)

        with torch.no_grad():
            self.readout.weight.zero_()

    def forward(self, x):
        x  = x.flatten(1)           # B × 3072
        h1 = self.act(self.fc1(x))
        h2 = self.act(self.fc2(h1))
        return self.readout(h2)


In [3]:
def make_abc_sgd(model, base_lr: float = 0.1, c: float = 0.0, momentum=0.9):
    """
    Create torch.optim.SGD with one param-group per distinct n_infty so
    each group gets η · n^{-c}.  (If c=0 you recover μP.)
    """
    groups = {}
    for p in model.parameters():
        n = getattr(p, "n_infty", None)
        if n is None:                      # bias / non-scaled param  → same lr
            n = 1
        eff_lr = base_lr * (n ** (-c))
        groups.setdefault(eff_lr, []).append(p)

    param_groups = [ {"params": v, "lr": k, "momentum": momentum}
                     for k, v in groups.items() ]
    return torch.optim.SGD(param_groups)


In [4]:
abc_defaults = [(0,0), (0,0.5), (0,0.5)]   # a,b per layer → μP
net      = AbcMLP(width=256, a_b_list=abc_defaults).cuda()
opt      = make_abc_sgd(net, base_lr=0.1, c=0.0, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CIFAR-10 normalization: mean and std for each channel
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),  # mean for R,G,B
                         (0.5, 0.5, 0.5))  # std for R,G,B
])

# Load training and test sets
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset  = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)


100%|██████████| 170M/170M [00:06<00:00, 27.9MB/s] 


In [9]:
with torch.no_grad():
    x,_  = next(iter(train_loader))
    x    = x.cuda()
    for w in [64, 512, 4096]:
        mdl = AbcMLP(width=w, a_b_list=abc_defaults).cuda()
        out = mdl(x)
        print(f"width {w:<8}  logits mean {out.mean():+.20e}   std {out.std():.20e}")


width 64        logits mean +0.00000000000000000000e+00   std 0.00000000000000000000e+00
width 512       logits mean +0.00000000000000000000e+00   std 0.00000000000000000000e+00
width 4096      logits mean +0.00000000000000000000e+00   std 0.00000000000000000000e+00
