<a href="https://colab.research.google.com/github/Demon-Sheriff/Linear-Alg_ML_fs/blob/master/sumo_cifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import wandb

In [2]:
if (a:=3*4) > 1/8 * (b:=5*6):
  print("ok")
print(a, b)

ok
12 30


In [10]:
u,sig,v = torch.linalg.svd(torch.randn(4, 3), full_matrices=False)
u.shape, sig.shape, v.shape

(torch.Size([4, 3]), torch.Size([3]), torch.Size([3, 3]))

In [4]:
# u.shape, sig.shape, v.shape
# mxm, rxr, nxn
# sig = torch.eye(v.size(0))
# (u @ sig).shape
"""
1 0 0 0
0 1 0 0
0 0 1 0
"""

'\n1 0 0 0\n0 1 0 0\n0 0 1 0\n'

In [5]:
u,sig,v = torch.svd_lowrank(torch.randn(4, 1), q=3)
u.shape, sig.shape, v.shape
# torch.zeros(3,4)
# i think i know where this is going wrong : what about the case when r > min(m, n) which is happening in the case of (4, 1) param in the TinyNet model.

(torch.Size([4, 1]), torch.Size([1]), torch.Size([1, 1]))

In [51]:
# torch.eye(1).shape
torch.zeros(1, 4).shape

torch.Size([1, 4])

In [6]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x783add660f50>

In [102]:
@torch.compile
def orthogonalization_svd(M):
    # Fix: Use M.float() to prevent FP16 instability
    U, sig, Vh = torch.linalg.svd(M.float(), full_matrices=False)
    # Fix: PyTorch returns Vh (V conjugate transpose), so O = U @ Vh is correct
    return U @ Vh

class SUMO(torch.optim.Optimizer):
    def __init__(self, params, step_size=1e-4, scale_factor=2.0, weight_decay=0.01,
                 rank=8, K=200, step_clip_ratio=1.1, momentum_coeff=0.9):
        defaults = dict(step_size=step_size, scale_factor=scale_factor, weight_decay=weight_decay,
                        step_clip_ratio=step_clip_ratio, momentum_coeff=momentum_coeff, rank=rank, K=K)
        super(SUMO, self).__init__(params, defaults)

    def step(self):
        for group in self.param_groups:
            # Load params
            step_size = group['step_size']
            scale_factor = group['scale_factor']
            weight_decay = group['weight_decay']
            rank = group['rank']
            K = group['K']
            step_clip_ratio = group['step_clip_ratio']
            momentum_coeff = group['momentum_coeff']

            for p in group['params']:
                if p.grad is None: continue

                # FIX 1: Handle 1D params (Bias/LayerNorm) via standard SGD/Adam logic
                if p.dim() < 2:
                    # Simple SGD + Weight Decay fallback for vectors
                    p.data.mul_(1 - step_size * weight_decay)
                    p.data.add_(p.grad, alpha=-step_size)
                    continue

                g = p.grad
                state = self.state[p]

                # Init step tracker
                if 'step' not in state: state['step'] = 0

                # Transpose handling for Fat Matrices (m < n)
                m, n = g.size(-2), g.size(-1)
                rev = (m < n)
                if rev:
                    g = g.transpose(-2, -1)
                    # Update m, n after transpose to reflect the working shape
                    m, n = g.size(-2), g.size(-1)

                # Determine effective rank
                k_limit = min(m, n)

                # --- BLOCK 1: Subspace Update ---
                # FIX 2: Correct Variable Scoping for U
                if state['step'] % K == 0:
                    # Compute new subspace
                    if rank >= k_limit:
                        U, S, Vh = torch.linalg.svd(g.float(), full_matrices=False)
                        # In full rank case, effective U is (m, n)
                    else:
                        U, S, Vh = torch.svd_lowrank(g.float(), q=rank)
                        # U is (m, r)

                    # Initialize buffers if first step
                    if 'sub_moment_buffer' not in state:
                        state['sub_moment_buffer'] = U
                        state['moment_buffer'] = torch.zeros(U.shape[1], n, device=g.device)
                        state['orthogonal_buf'] = torch.zeros(U.shape[1], n, device=g.device)
                        # No rotation on very first step
                        moment_in_subspace = state['moment_buffer']
                    else:
                        # Rotation Logic
                        # R = Q_new.T @ Q_old
                        R = U.transpose(-2, -1) @ state['sub_moment_buffer']

                        # Rotate the existing moment into new subspace
                        # FIX 3: Use the rotated moment for the current update
                        moment_in_subspace = R @ state['moment_buffer']

                        # Store new basis
                        state['sub_moment_buffer'] = U
                else:
                    # Retrieve existing subspace
                    U = state['sub_moment_buffer']
                    moment_in_subspace = state['moment_buffer']

                # --- BLOCK 1.5: Projected Gradient ---
                # G_hat = Q.T @ G
                G_hat = U.transpose(-2, -1) @ g

                # --- BLOCK 2: Moment Update & Orthogonalization ---
                # FIX 3 (Cont): We use 'moment_in_subspace' which is either
                # correctly rotated (if t%K==0) or loaded from state (if t%K!=0)
                M = momentum_coeff * moment_in_subspace + G_hat

                # Store the updated M back to state immediately
                state['moment_buffer'] = M

                # Orthogonalize
                O = orthogonalization_svd(M)

                # --- BLOCK 3: Norm Growth Limiter ---
                o_norm = torch.norm(O)
                prev_o_norm = torch.norm(state['orthogonal_buf'])

                # FIX 4: Prevent zeroing out on the first step
                if prev_o_norm > 1e-8 and o_norm > step_clip_ratio * prev_o_norm:
                    scale = (step_clip_ratio * prev_o_norm) / o_norm
                    O = O * scale

                state['orthogonal_buf'] = O

                # --- BLOCK 4: Weight Update ---
                # Project back: Q @ (G_hat - O)
                # Correction term is: alpha * (G_perp + Q @ O)
                # Equivalent to: alpha * (G - Q @ G_hat + Q @ O) -> alpha * (G - Q @ (G_hat - O))

                # The orthogonal update term
                low_rank_update = U @ (G_hat - O)

                # Combine: G - (G - Q G_hat + Q O) = G - G_perp - Q O
                # Wait, paper formula: W_t = W_{t-1} - eta * (alpha * (G - Q(G_hat - O)))
                # Note: The code below implements this exactly.

                update_direction = g - low_rank_update.view_as(g)

                # FIX 5: Apply Weight Decay to Weights (AdamW style), not Gradients
                p.data.mul_(1 - step_size * weight_decay)

                # Apply Gradient Update
                final_update = scale_factor * update_direction

                # Handle transpose back if needed
                if rev:
                    final_update = final_update.transpose(-2, -1)

                p.data.add_(final_update, alpha=-step_size)

                state['step'] += 1


In [115]:
@torch.compile
def orthogonalization_svd(M):
  U,sig,V = torch.linalg.svd(M.float(), full_matrices=False) # TODO: figuring out full_matrices=True case
  return U @ V

class SUMO(torch.optim.Optimizer):
  def __init__(self, params, step_size=1e-4, scale_factor=2.0, weight_decay=0.01,
                rank=4, K=200, step_clip_ratio=1.1, momentum_coeff=0.9):
      defaults = dict(step_size=step_size, scale_factor=scale_factor, weight_decay=weight_decay,
                      step_clip_ratio=step_clip_ratio, momentum_coeff=momentum_coeff, rank=rank, K=K)
      super(SUMO, self).__init__(params, defaults)

  def __setstate__(self, state):
    super(SUMO, self).__setstate__(state)

  # tracking epoch number manually in the training loop
  def step(self):
    for group in self.param_groups:
      step_size = group['step_size']
      scale_factor = group['scale_factor']
      weight_decay = group['weight_decay']
      rank = group['rank']
      K = group['K']
      step_clip_ratio = group['step_clip_ratio']
      momentum_coeff = group['momentum_coeff']

      # mul_, add_ and other inplace ops would only be used when manipulating internal gradient ops.
      for p in group['params']:
        g = p.grad # the gradient matrix for w_t-1
        if g is None: continue

        if p.dim() < 2: # logic for handling bias/layernorm like param
           p.data.mul_(1 - step_size * weight_decay)
           p.data.add_(p.grad, alpha=-step_size)

        state = self.state[p] # current state of the optimizer

        m, n = g.size(-2), g.size(-1)
        # m >= n assumption without loss of generality. [if n > m we transpose]
        # also min(m, n) = n
        if (rev:=(m<n)) : g = g.transpose(-2, -1)
        if 'epoch_num' not in state.keys(): state['epoch_num'] = 0

        if state['epoch_num'] % K == 0:
          if rank > (k:=min(m, n)): U, sig, V = torch.linalg.svd(g.float(), full_matrices=False) # U shape : (m, n)
          else: U, sig, V = torch.svd_lowrank(g.float(), q=rank) # svd(n,m) => (n,r), (r,r), (r,m)
          if 'sub_moment_buffer' not in state.keys(): # first time we are encountering this at t=0
            if rank <= k:
              state['sub_moment_buffer'] = torch.eye(rank) # (rank, rank)
              state['moment_buffer'] = torch.zeros(rank, n) # (rank, n)
            else:
              state['sub_moment_buffer'] = torch.eye(k) # (n, n)
              state['moment_buffer'] = torch.zeros(k, k) # (k, n) -> (n, n)
          else:
            R = U.transpose(-2,-1) @ state['sub_moment_buffer'] # Q_t @ Q_(t-1)
            M = R @ state['moment_buffer'] # R @ M_(t-1)

        G_hat = U.transpose(-2, -1) @ g # (r, n) if rank <= k else (k, n) -> (n, n)
        M = momentum_coeff * state['moment_buffer'] + G_hat # (r, n) if rank <= k else (k, n)

        if 'orthogonal_buf' not in state.keys():
          state['orthogonal_buf'] = torch.zeros(rank, n) if rank <= k else torch.zeros(k, n) # (rank, n) if rank <= k else (k, n)
        O = orthogonalization_svd(M) # (r, n) if r <= k else (k, n)

        if ((o_norm:=torch.norm(O)) > step_clip_ratio * (moment_norm:=torch.norm(state['orthogonal_buf']))):
          O = (O / o_norm) * (step_clip_ratio * moment_norm) # (r, n) if r <= k else (k, n)

        # weight updates
        try:
          update = step_size*weight_decay*g + scale_factor*step_size*(g - (U@(G_hat - O)).view_as(g))
        except Exception as e:
          print(f"G_hat shape: {G_hat.shape}, O shape: {O.shape}, g shape: {g.shape}, U shape: {U.shape}, M shape: {M.shape}")

        p.data.add_(update.transpose(-2, -1) if rev else update, alpha=-1.0)
        # update the buffers
        state['sub_moment_buffer'] = U
        state['moment_buffer'] = M
        state['orthogonal_buf'] = O

In [116]:
class TinyNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(2, 4, bias=False)
    self.fc2 = nn.Linear(4, 8, bias=False)
    self.fc3 = nn.Linear(8, 16, bias=False)
    self.fc4 = nn.Linear(16, 64, bias=False)
    self.out = nn.Linear(64, 1, bias=False)

  def forward(self, x):
    x = torch.relu(self.fc1(x))
    x = torch.relu(self.fc3(self.fc2(x)))
    x = self.fc4(x)
    return torch.sigmoid(self.out(x))

def generate_big_xor(n=2000):
    X = torch.randn(n, 2)  # random points in 2D
    y = ((X[:, 0] > 0) ^ (X[:, 1] > 0)).float().unsqueeze(1)
    return X, y

# train/test split
X_train, y_train = generate_big_xor(2000)
X_test,  y_test  = generate_big_xor(2000)

train_data = [(X_train[i], y_train[i]) for i in range(len(X_train))]
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)

criterion = nn.BCELoss()
model = TinyNet()
# opt = SUMO(model.parameters(), 1e-3, 1.0, 1.0, 2, 5, 1.1, 1.0)
opt = SUMO(model.parameters())
opt.state['epoch_num'] = 0

In [117]:
def accuracy(X, y):
  with torch.no_grad():
    preds = (model(X) > 0.5).float()
    return (preds == y).float().mean().item()

for epoch in range(100):
  for xb, yb in train_loader:
    opt.zero_grad()
    loss = criterion(model(xb), yb)
    loss.backward()
    opt.step()
    opt.state['epoch_num'] += 1

  if epoch % 5 == 0:
    train_acc = accuracy(X_train, y_train)
    test_acc  = accuracy(X_test,  y_test)
    print(f"epoch={epoch}, loss={loss.item():.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")

epoch=0, loss=0.6941, train_acc=0.476, test_acc=0.495
epoch=5, loss=0.6958, train_acc=0.476, test_acc=0.495
epoch=10, loss=0.6946, train_acc=0.476, test_acc=0.495
epoch=15, loss=0.6957, train_acc=0.476, test_acc=0.495
epoch=20, loss=0.6966, train_acc=0.476, test_acc=0.495
epoch=25, loss=0.6949, train_acc=0.476, test_acc=0.495
epoch=30, loss=0.6933, train_acc=0.476, test_acc=0.495
epoch=35, loss=0.6941, train_acc=0.476, test_acc=0.495
epoch=40, loss=0.6944, train_acc=0.476, test_acc=0.495
epoch=45, loss=0.6951, train_acc=0.476, test_acc=0.495
epoch=50, loss=0.6963, train_acc=0.476, test_acc=0.495
epoch=55, loss=0.6965, train_acc=0.476, test_acc=0.495
epoch=60, loss=0.6940, train_acc=0.476, test_acc=0.495
epoch=65, loss=0.6917, train_acc=0.476, test_acc=0.495
epoch=70, loss=0.6944, train_acc=0.476, test_acc=0.495
epoch=75, loss=0.6959, train_acc=0.476, test_acc=0.495
epoch=80, loss=0.6965, train_acc=0.476, test_acc=0.495
epoch=85, loss=0.6959, train_acc=0.476, test_acc=0.495
epoch=90, lo

In [91]:
criterion_adam = nn.BCELoss()
model2 = TinyNet()
adam = torch.optim.Adam(model2.parameters())

In [92]:
def accuracy(X, y):
  with torch.no_grad():
    preds = (model2(X) > 0.5).float()
    return (preds == y).float().mean().item()

for epoch in range(50):
  for xb, yb in train_loader:
    adam.zero_grad()
    loss = criterion_adam(model2(xb), yb)
    loss.backward()
    adam.step()

  if epoch % 5 == 0:
    train_acc = accuracy(X_train, y_train)
    test_acc  = accuracy(X_test,  y_test)
    print(f"epoch={epoch}, loss={loss.item():.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")

epoch=0, loss=0.6345, train_acc=0.627, test_acc=0.618
epoch=5, loss=0.0592, train_acc=0.979, test_acc=0.986
epoch=10, loss=0.0314, train_acc=0.994, test_acc=0.996
epoch=15, loss=0.0087, train_acc=0.997, test_acc=0.995
epoch=20, loss=0.0052, train_acc=0.992, test_acc=0.996
epoch=25, loss=0.0003, train_acc=0.994, test_acc=0.997
epoch=30, loss=0.0711, train_acc=0.996, test_acc=0.999
epoch=35, loss=0.0000, train_acc=0.996, test_acc=0.995
epoch=40, loss=0.0308, train_acc=0.994, test_acc=0.994
epoch=45, loss=0.0000, train_acc=0.998, test_acc=0.998


In [None]:
"""
airbench94_muon.py
Runs in 2.59 seconds on a 400W NVIDIA A100
Attains 94.004 mean accuracy (n=200 trials)
Descends from https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
"""

#############################################
#                  Setup                    #
#############################################

import os
import sys
with open(sys.argv[0]) as f:
    code = f.read()
import uuid
from math import ceil

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T

torch.backends.cudnn.benchmark = True

#############################################
#               Muon optimizer              #
#############################################

@torch.compile
def zeropower_via_newtonschulz5(G, steps=3, eps=1e-7):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750,  2.0315)
    # X = G.bfloat16()
    X = G.half()
    X /= (X.norm() + eps) # ensure top singular value <= 1
    if G.size(0) > G.size(1):
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X

class Muon(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, momentum=0, nesterov=False):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if nesterov and momentum <= 0:
            raise ValueError("Nesterov momentum requires a momentum")
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov)
        super().__init__(params, defaults)

    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            for p in group['params']:
                g = p.grad
                if g is None:
                    continue
                state = self.state[p]

                if 'momentum_buffer' not in state.keys():
                    state['momentum_buffer'] = torch.zeros_like(g)
                buf = state['momentum_buffer']
                buf.mul_(momentum).add_(g)
                g = g.add(buf, alpha=momentum) if group['nesterov'] else buf

                p.data.mul_(len(p.data)**0.5 / p.data.norm()) # normalize the weight
                update = zeropower_via_newtonschulz5(g.reshape(len(g), -1)).view(g.shape) # whiten the update
                p.data.add_(update, alpha=-lr) # take a step

#############################################
#                DataLoader                 #
#############################################

CIFAR_MEAN = torch.tensor((0.4914, 0.4822, 0.4465))
CIFAR_STD = torch.tensor((0.2470, 0.2435, 0.2616))

def batch_flip_lr(inputs):
    flip_mask = (torch.rand(len(inputs), device=inputs.device) < 0.5).view(-1, 1, 1, 1)
    return torch.where(flip_mask, inputs.flip(-1), inputs)

def batch_crop(images, crop_size):
    r = (images.size(-1) - crop_size)//2
    shifts = torch.randint(-r, r+1, size=(len(images), 2), device=images.device)
    images_out = torch.empty((len(images), 3, crop_size, crop_size), device=images.device, dtype=images.dtype)
    # The two cropping methods in this if-else produce equivalent results, but the second is faster for r > 2.
    if r <= 2:
        for sy in range(-r, r+1):
            for sx in range(-r, r+1):
                mask = (shifts[:, 0] == sy) & (shifts[:, 1] == sx)
                images_out[mask] = images[mask, :, r+sy:r+sy+crop_size, r+sx:r+sx+crop_size]
    else:
        images_tmp = torch.empty((len(images), 3, crop_size, crop_size+2*r), device=images.device, dtype=images.dtype)
        for s in range(-r, r+1):
            mask = (shifts[:, 0] == s)
            images_tmp[mask] = images[mask, :, r+s:r+s+crop_size, :]
        for s in range(-r, r+1):
            mask = (shifts[:, 1] == s)
            images_out[mask] = images_tmp[mask, :, :, r+s:r+s+crop_size]
    return images_out

class CifarLoader:

    def __init__(self, path, train=True, batch_size=500, aug=None):
        data_path = os.path.join(path, 'train.pt' if train else 'test.pt')
        if not os.path.exists(data_path):
            dset = torchvision.datasets.CIFAR10(path, download=True, train=train)
            images = torch.tensor(dset.data)
            labels = torch.tensor(dset.targets)
            torch.save({'images': images, 'labels': labels, 'classes': dset.classes}, data_path)

        data = torch.load(data_path, map_location=torch.device('cuda'))
        self.images, self.labels, self.classes = data['images'], data['labels'], data['classes']
        # It's faster to load+process uint8 data than to load preprocessed fp16 data
        self.images = (self.images.half() / 255).permute(0, 3, 1, 2).to(memory_format=torch.channels_last)

        self.normalize = T.Normalize(CIFAR_MEAN, CIFAR_STD)
        self.proc_images = {} # Saved results of image processing to be done on the first epoch
        self.epoch = 0

        self.aug = aug or {}
        for k in self.aug.keys():
            assert k in ['flip', 'translate'], 'Unrecognized key: %s' % k

        self.batch_size = batch_size
        self.drop_last = train
        self.shuffle = train

    def __len__(self):
        return len(self.images)//self.batch_size if self.drop_last else ceil(len(self.images)/self.batch_size)

    def __iter__(self):

        if self.epoch == 0:
            images = self.proc_images['norm'] = self.normalize(self.images)
            # Pre-flip images in order to do every-other epoch flipping scheme
            if self.aug.get('flip', False):
                images = self.proc_images['flip'] = batch_flip_lr(images)
            # Pre-pad images to save time when doing random translation
            pad = self.aug.get('translate', 0)
            if pad > 0:
                self.proc_images['pad'] = F.pad(images, (pad,)*4, 'reflect')

        if self.aug.get('translate', 0) > 0:
            images = batch_crop(self.proc_images['pad'], self.images.shape[-2])
        elif self.aug.get('flip', False):
            images = self.proc_images['flip']
        else:
            images = self.proc_images['norm']
        # Flip all images together every other epoch. This increases diversity relative to random flipping
        if self.aug.get('flip', False):
            if self.epoch % 2 == 1:
                images = images.flip(-1)

        self.epoch += 1

        indices = (torch.randperm if self.shuffle else torch.arange)(len(images), device=images.device)
        for i in range(len(self)):
            idxs = indices[i*self.batch_size:(i+1)*self.batch_size]
            yield (images[idxs], self.labels[idxs])

#############################################
#            Network Definition             #
#############################################

# note the use of low BatchNorm stats momentum
class BatchNorm(nn.BatchNorm2d):
    def __init__(self, num_features, momentum=0.6, eps=1e-12):
        super().__init__(num_features, eps=eps, momentum=1-momentum)
        self.weight.requires_grad = False
        # Note that PyTorch already initializes the weights to one and bias to zero

class Conv(nn.Conv2d):
    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels, out_channels, kernel_size=3, padding='same', bias=False)

    def reset_parameters(self):
        super().reset_parameters()
        w = self.weight.data
        torch.nn.init.dirac_(w[:w.size(1)])

class ConvGroup(nn.Module):
    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.conv1 = Conv(channels_in,  channels_out)
        self.pool = nn.MaxPool2d(2)
        self.norm1 = BatchNorm(channels_out)
        self.conv2 = Conv(channels_out, channels_out)
        self.norm2 = BatchNorm(channels_out)
        self.activ = nn.GELU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.norm1(x)
        x = self.activ(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.activ(x)
        return x

class CifarNet(nn.Module):
    def __init__(self):
        super().__init__()
        widths = dict(block1=64, block2=256, block3=256)
        whiten_kernel_size = 2
        whiten_width = 2 * 3 * whiten_kernel_size**2
        self.whiten = nn.Conv2d(3, whiten_width, whiten_kernel_size, padding=0, bias=True)
        self.whiten.weight.requires_grad = False
        self.layers = nn.Sequential(
            nn.GELU(),
            ConvGroup(whiten_width,     widths['block1']),
            ConvGroup(widths['block1'], widths['block2']),
            ConvGroup(widths['block2'], widths['block3']),
            nn.MaxPool2d(3),
        )
        self.head = nn.Linear(widths['block3'], 10, bias=False)
        for mod in self.modules():
            if isinstance(mod, BatchNorm):
                mod.float()
            else:
                mod.half()

    def reset(self):
        for m in model.modules():
            if type(m) in (nn.Conv2d, Conv, BatchNorm, nn.Linear):
                m.reset_parameters()
        w = self.head.weight.data
        w *= 1 / w.std()

    def init_whiten(self, train_images, eps=5e-4):
        c, (h, w) = train_images.shape[1], self.whiten.weight.shape[2:]
        patches = train_images.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).float()
        patches_flat = patches.view(len(patches), -1)
        est_patch_covariance = (patches_flat.T @ patches_flat) / len(patches_flat)
        eigenvalues, eigenvectors = torch.linalg.eigh(est_patch_covariance, UPLO='U')
        eigenvectors_scaled = eigenvectors.T.reshape(-1,c,h,w) / torch.sqrt(eigenvalues.view(-1,1,1,1) + eps)
        self.whiten.weight.data[:] = torch.cat((eigenvectors_scaled, -eigenvectors_scaled))

    def forward(self, x, whiten_bias_grad=True):
        b = self.whiten.bias
        x = F.conv2d(x, self.whiten.weight, b if whiten_bias_grad else b.detach())
        x = self.layers(x)
        x = x.view(len(x), -1)
        return self.head(x) / x.size(-1)

############################################
#                 Logging                  #
############################################

def print_columns(columns_list, is_head=False, is_final_entry=False):
    print_string = ''
    for col in columns_list:
        print_string += '|  %s  ' % col
    print_string += '|'
    if is_head:
        print('-'*len(print_string))
    print(print_string)
    if is_head or is_final_entry:
        print('-'*len(print_string))

logging_columns_list = ['run   ', 'epoch', 'train_acc', 'val_acc', 'tta_val_acc', 'time_seconds']
def print_training_details(variables, is_final_entry):
    formatted = []
    for col in logging_columns_list:
        var = variables.get(col.strip(), None)
        if type(var) in (int, str):
            res = str(var)
        elif type(var) is float:
            res = '{:0.4f}'.format(var)
        else:
            assert var is None
            res = ''
        formatted.append(res.rjust(len(col)))
    print_columns(formatted, is_final_entry=is_final_entry)

############################################
#               Evaluation                 #
############################################

def infer(model, loader, tta_level=0):

    # Test-time augmentation strategy (for tta_level=2):
    # 1. Flip/mirror the image left-to-right (50% of the time).
    # 2. Translate the image by one pixel either up-and-left or down-and-right (50% of the time,
    #    i.e. both happen 25% of the time).
    #
    # This creates 6 views per image (left/right times the two translations and no-translation),
    # which we evaluate and then weight according to the given probabilities.

    def infer_basic(inputs, net):
        return net(inputs).clone()

    def infer_mirror(inputs, net):
        return 0.5 * net(inputs) + 0.5 * net(inputs.flip(-1))

    def infer_mirror_translate(inputs, net):
        logits = infer_mirror(inputs, net)
        pad = 1
        padded_inputs = F.pad(inputs, (pad,)*4, 'reflect')
        inputs_translate_list = [
            padded_inputs[:, :, 0:32, 0:32],
            padded_inputs[:, :, 2:34, 2:34],
        ]
        logits_translate_list = [infer_mirror(inputs_translate, net)
                                 for inputs_translate in inputs_translate_list]
        logits_translate = torch.stack(logits_translate_list).mean(0)
        return 0.5 * logits + 0.5 * logits_translate

    model.eval()
    test_images = loader.normalize(loader.images)
    infer_fn = [infer_basic, infer_mirror, infer_mirror_translate][tta_level]
    with torch.no_grad():
        return torch.cat([infer_fn(inputs, model) for inputs in test_images.split(2000)])

def evaluate(model, loader, tta_level=0):
    logits = infer(model, loader, tta_level)
    return (logits.argmax(1) == loader.labels).float().mean().item()

############################################
#                Training                  #
############################################

def main(run, model):

    batch_size = 2000
    bias_lr = 0.053
    head_lr = 0.67
    wd = 2e-6 * batch_size

    test_loader = CifarLoader('cifar10', train=False, batch_size=512)
    train_loader = CifarLoader('cifar10', train=True, batch_size=batch_size, aug=dict(flip=True, translate=2))
    if run == 'warmup':
        # The only purpose of the first run is to warmup the compiled model, so we can use dummy data
        train_loader.labels = torch.randint(0, 10, size=(len(train_loader.labels),), device=train_loader.labels.device)
    total_train_steps = ceil(8 * len(train_loader))
    whiten_bias_train_steps = ceil(3 * len(train_loader))

    # Create optimizers and learning rate schedulers
    filter_params = [p for p in model.parameters() if len(p.shape) == 4 and p.requires_grad]
    norm_biases = [p for n, p in model.named_parameters() if 'norm' in n and p.requires_grad]
    param_configs = [dict(params=[model.whiten.bias], lr=bias_lr, weight_decay=wd/bias_lr),
                     dict(params=norm_biases, lr=bias_lr, weight_decay=wd/bias_lr),
                     dict(params=[model.head.weight], lr=head_lr, weight_decay=wd/head_lr)]
    optimizer1 = torch.optim.SGD(param_configs, momentum=0.85, nesterov=True, fused=True)
    optimizer2 = Muon(filter_params, lr=0.24, momentum=0.6, nesterov=True)
    optimizers = [optimizer1, optimizer2]
    for opt in optimizers:
        for group in opt.param_groups:
            group["initial_lr"] = group["lr"]

    # For accurately timing GPU code
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    time_seconds = 0.0
    def start_timer():
        starter.record()
    def stop_timer():
        ender.record()
        torch.cuda.synchronize()
        nonlocal time_seconds
        time_seconds += 1e-3 * starter.elapsed_time(ender)

    model.reset()
    step = 0

    # Initialize the whitening layer using training images
    start_timer()
    train_images = train_loader.normalize(train_loader.images[:5000])
    model.init_whiten(train_images)
    stop_timer()

    for epoch in range(ceil(total_train_steps / len(train_loader))):

        ####################
        #     Training     #
        ####################

        start_timer()
        model.train()
        for inputs, labels in train_loader:
            outputs = model(inputs, whiten_bias_grad=(step < whiten_bias_train_steps))
            F.cross_entropy(outputs, labels, label_smoothing=0.2, reduction='sum').backward()
            for group in optimizer1.param_groups[:1]:
                group["lr"] = group["initial_lr"] * (1 - step / whiten_bias_train_steps)
            for group in optimizer1.param_groups[1:]+optimizer2.param_groups:
                group["lr"] = group["initial_lr"] * (1 - step / total_train_steps)
            for opt in optimizers:
                opt.step()
            model.zero_grad(set_to_none=True)
            step += 1
            if step >= total_train_steps:
                break
        stop_timer()

        ####################
        #    Evaluation    #
        ####################

        # Save the accuracy and loss from the last training batch of the epoch
        train_acc = (outputs.detach().argmax(1) == labels).float().mean().item()
        val_acc = evaluate(model, test_loader, tta_level=0)
        print_training_details(locals(), is_final_entry=False)
        run = None # Only print the run number once

    ####################
    #  TTA Evaluation  #
    ####################

    start_timer()
    tta_val_acc = evaluate(model, test_loader, tta_level=0)
    stop_timer()
    epoch = 'eval'
    print_training_details(locals(), is_final_entry=True)

    return tta_val_acc

  where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model


In [None]:
# import torch
torch.cuda.is_available()

True

In [None]:
if __name__ == "__main__":

    # We re-use the compiled model between runs to save the non-data-dependent compilation time
    model = CifarNet().cuda().to(memory_format=torch.channels_last)
    model.compile(mode='max-autotune')

    print_columns(logging_columns_list, is_head=True)
    main('warmup', model)
    accs = torch.tensor([main(run, model) for run in range(200)])
    print('Mean: %.4f    Std: %.4f' % (accs.mean(), accs.std()))

    log_dir = os.path.join('logs', str(uuid.uuid4()))
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, 'log.pt')
    torch.save(dict(code=code, accs=accs), log_path)
    print(os.path.abspath(log_path))

---------------------------------------------------------------------------------
|  run     |  epoch  |  train_acc  |  val_acc  |  tta_val_acc  |  time_seconds  |
---------------------------------------------------------------------------------


100%|██████████| 170M/170M [00:04<00:00, 40.4MB/s]
W1119 05:57:30.700000 686 torch/_inductor/utils.py:1436] [0/0] Not enough SMs to use max_autotune_gemm mode


|  warmup  |      0  |     0.0980  |   0.1028  |               |      105.6305  |
|          |      1  |     0.1020  |   0.1011  |               |      108.1593  |
|          |      2  |     0.1015  |   0.0478  |               |      110.6868  |
|          |      3  |     0.1065  |   0.0955  |               |      156.2594  |
|          |      4  |     0.1085  |   0.1273  |               |      158.7978  |
|          |      5  |     0.1020  |   0.1214  |               |      161.3617  |
|          |      6  |     0.1040  |   0.1127  |               |      163.9526  |
|          |      7  |     0.1030  |   0.0945  |               |      166.5874  |
|          |   eval  |     0.1030  |   0.0945  |       0.0945  |      166.7657  |
---------------------------------------------------------------------------------
|       0  |      0  |     0.7370  |   0.6457  |               |        2.9034  |
|          |      1  |     0.8155  |   0.7358  |               |        5.7790  |
|          |    

KeyboardInterrupt: 

In [None]:
"""
questions i have :
1. how exactly do we make sure that we don't run OOM on cuda, how do different dtypes used while training (bfloat16, fp32, fp16, mxfp4, int4, int8)
   impact the memory and what exactly goes / loads on memory while using torch.compile()
2. using reduction="mean" (default torch behaviour), or reduction="sum", how do we correctly log val and train losses in both, is there a generalized way to do
   how does the last batch_size affect the logging if the number of samples is not a multiple of batch_size ?
3. why do we go with batch_sizes which are normally are a multiple of 16 (32, 128, 512 ... ?), is this related to warps in gpus ?
"""