## Bluey-MERDifold Run Jobs to test Model Architecture

## Environment set up
Set up the environment from pyproject.toml or uv.lock.
We probably don't need to clone the repo from GitHub.
Optionally mount Google Drive and point the checkpoints directory at a folder inside it, so checkpoints persist across sessions.
Authorize Weights & Biases (wandb) before launching any runs.


## msign function

In [None]:
import torch

# Per-iteration (a, b, c) coefficients consumed by `msign`: iteration i applies
# X <- a*X + b*(X X^T) X + c*(X X^T)^2 X with the i-th tuple.
ABC_LIST: list[tuple[float, float, float]] = [
    (8.28721201814563, -23.595886519098837, 17.300387312530933),
    (4.107059111542203, -2.9478499167379106, 0.5448431082926601),
    (3.9486908534822946, -2.908902115962949, 0.5518191394370137),
    (3.3184196573706015, -2.488488024314874, 0.51004894012372),
    (2.300652019954817, -1.6689039845747493, 0.4188073119525673),
    (1.891301407787398, -1.2679958271945868, 0.37680408948524835),
    (1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
    (1.875, -1.25, 0.375),
]

# safety factor for numerical stability (but exclude last polynomial)
ABC_LIST_STABLE: list[tuple[float, float, float]] = [
    (a / 1.01, b / 1.01**3, c / 1.01**5) for (a, b, c) in ABC_LIST[:-1]
] + [ABC_LIST[-1]]


@torch.no_grad()
def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
    """
    Polar Express algorithm for the matrix sign function:
    https://arxiv.org/abs/2505.16932

    Args:
        G: tensor with at least two dimensions; the last two are treated as
            the matrix dimensions. ``G`` is never modified.
        steps: number of polynomial iterations. Iterations beyond the length
            of ``ABC_LIST_STABLE`` reuse its last coefficient tuple.

    Returns:
        A float32 tensor with the same shape as ``G``. The iteration itself
        runs in bfloat16, and any NaN/inf produced along the way (e.g. for an
        all-zero input) is replaced via ``torch.nan_to_num``.
    """
    assert G.ndim >= 2
    should_transpose: bool = G.size(-2) > G.size(-1)

    # Work on the wide orientation so that x @ x.mT is the smaller Gram matrix.
    x = G.bfloat16()
    if should_transpose:
        x = x.mT

    # Out-of-place on purpose: `G.bfloat16()` returns `G` itself when `G` is
    # already bfloat16 (and `.mT` is a view), so an in-place division here
    # would silently rescale the caller's tensor.
    x = x / (x.norm(dim=(-2, -1), keepdim=True) * 1.01)

    last = len(ABC_LIST_STABLE) - 1
    for step in range(steps):
        a, b, c = ABC_LIST_STABLE[min(step, last)]
        s = x @ x.mT
        # goal is to compute x = a x + b S x + c S^2 x
        # we can break this up into: x = (a I + (b I + c S) S) x
        y = c * s
        y.diagonal(dim1=-2, dim2=-1).add_(b)
        y = y @ s
        y.diagonal(dim1=-2, dim2=-1).add_(a)
        x = y @ x

    if should_transpose:
        x = x.mT
    x = torch.nan_to_num(x)
    return x.float()

## ManifoldMuonW

In [None]:
import torch
from torch.optim import Optimizer
from optimizers.msign import msign
import math
import torch

def manifold_muon_step(
    W: torch.Tensor,
    G: torch.Tensor,
    lr: float,
    alpha: float = 0.01,
    steps: int = 50,
    tol: float = 1e-6,
) -> torch.Tensor:
    """One manifold Muon update step keeping W on a Stiefel-like manifold.

    Runs up to ``steps`` dual-ascent iterations on a symmetric multiplier
    ``Lambda`` to find an update direction ``A = msign(G + 2 W Lambda)`` whose
    tangent-space violation ``W^T A + A^T W`` is small, then takes the primal
    step ``W - lr * A`` and retracts the result with ``msign``.

    Args:
        W: current weight matrix (last two dims are the matrix dims).
        G: gradient (or momentum) with the same shape as ``W``.
        lr: primal step size.
        alpha: base dual-ascent step size, linearly annealed over ``steps``.
        steps: maximum number of dual iterations. With ``steps <= 0`` the
            direction is computed once from the initial ``Lambda``.
        tol: stop early once the RMS entry of the violation drops below this.

    Returns:
        The updated weights, in the same orientation as the input ``W``.
    """
    # Work on the tall orientation; `.mT` only swaps the last two dims, which
    # matches the `transpose(-2, -1)` used to undo it at the end.
    was_wide = W.size(-2) < W.size(-1)
    if was_wide:
        W = W.mT
        G = G.mT

    # Dual variable initialization
    Lambda = -0.25 * (W.mT @ G + G.mT @ W)

    A = None
    for k in range(steps):
        # Candidate direction in ambient space
        A = msign(G + 2 * W @ Lambda)

        # Measure tangent-space violation
        H = W.mT @ A + A.mT @ W
        if torch.norm(H) / math.sqrt(H.numel()) < tol:
            break

        # Dual ascent step with simple annealing
        Lambda = Lambda - alpha * (1.0 - k / steps) * H

    if A is None:
        # steps <= 0: the loop never ran, so fall back to the direction given
        # by the initial multiplier instead of failing on an unbound name.
        A = msign(G + 2 * W @ Lambda)

    # Primal descent step, followed by retraction onto the manifold
    new_W = msign(W - lr * A)

    return new_W.mT if was_wide else new_W

def manifold_muon_ADMM_step(
    W: torch.Tensor,
    G: torch.Tensor,
    lr: float,
    alpha: float = 0.01,
    steps: int = 50,
    rho: float = 4.0,
    tol: float = 1e-6,
) -> torch.Tensor:
    """Manifold Muon step whose dual problem is solved with ADMM.

    The dual objective is || G + W @ (L + L.mT) ||_* (c.f. the blog). ADMM is
    run for exactly ``steps`` iterations over the multiplier ``Lambda``, the
    slack ``X`` and the scaled dual ``Omega``; the final ``Lambda`` gives the
    update direction ``A = msign(G + 2 W Lambda)``, followed by the primal step
    ``W - lr * A`` and an ``msign`` retraction.

    Args:
        W: current weight matrix (last two dims are the matrix dims).
        G: gradient (or momentum) with the same shape as ``W``.
        lr: primal step size.
        alpha: unused; accepted so the signature matches ``manifold_muon_step``.
        steps: number of ADMM iterations (there is no early stopping).
        rho: ADMM penalty parameter.
        tol: unused; accepted so the signature matches ``manifold_muon_step``.

    Returns:
        The updated weights, in the same orientation as the input ``W``.
    """
    # Ensure that W and G are both tall matrices
    should_transpose = W.size(-2) < W.size(-1)
    if should_transpose:
        W = W.mT
        G = G.mT
    # Initialize the lagrangian, slack, and dual variable
    Lambda = -0.25 * (W.mT @ G + G.mT @ W)
    X = G + 2 * W @ Lambda
    Omega = torch.zeros_like(X)
    # Loop-invariant pieces of the singular-value-thresholding update
    eye = torch.eye(X.size(-1), device=X.device, dtype=X.dtype)
    shift = 1 / rho**2 * eye
    # Solve the dual problem with ADMM to find the update direction A
    for _ in range(steps):
        # Update for Lambda (orthonormal least-squares solve)
        P = W.mT @ (1 / rho * Omega + X - G)
        Lambda = 0.25 * (P + P.mT)
        # Update for X (singular value thresholding)
        B = G + 2 * W @ Lambda - 1 / rho * Omega
        P_pos = 0.5 * (eye + msign(B.mT @ B - shift))
        X = (B - 1 / rho * msign(B)) @ P_pos
        # Update for Omega (dual ascent)
        Omega = Omega + rho * (X - 2 * W @ Lambda - G)
    # Calculate A from final ADMM solution
    # (at convergence, G + 2 * W @ Lambda \approx X)
    A = msign(G + 2 * W @ Lambda)
    # Descend on the primal problem, then retract to the manifold
    new_W = msign(W - lr * A)
    # Restore the shape of the solution and return
    return new_W.mT if should_transpose else new_W

class ManifoldMuonW(Optimizer):
    """
    Hybrid manifold-Muon / AdamW optimizer.

    A parameter takes the manifold path when its group's ``'manifold'`` flag
    is truthy (the flag defaults to True when the group does not set it) and
    the parameter has at least two dimensions. The manifold path calls
    ``manifold_muon_ADMM_step`` or ``manifold_muon_step`` depending on the
    group's ``ADMM`` flag, optionally feeding it an EMA of the gradient
    (``mm_use_momentum``). Every other parameter gets a plain AdamW update.

    Decoupled weight decay is applied to all parameters before either update,
    and the AdamW moment buffers are maintained for every parameter even when
    the manifold path does not read them.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas=(0.95, 0.95),     # [0]: first-moment / Muon momentum decay; [1]: AdamW second moment
        weight_decay: float = 0.0,
        eps: float = 1e-8,
        mm_steps: int = 50,
        mm_alpha: float = 0.01,
        mm_tol: float = 1e-6,
        ADMM: bool = True,
        mm_use_momentum: bool = False,
    ):
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not eps >= 0.0:
            raise ValueError(f"Invalid epsilon: {eps}")
        for idx in (0, 1):
            if not 0.0 <= betas[idx] < 1.0:
                raise ValueError(f"Invalid beta{idx + 1}: {betas[idx]}")

        defaults = {
            "lr": lr,
            "betas": betas,
            "weight_decay": weight_decay,
            "eps": eps,
            "mm_steps": mm_steps,
            "mm_alpha": mm_alpha,
            "mm_tol": mm_tol,
            "ADMM": ADMM,
            "mm_use_momentum": mm_use_momentum,
        }
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step; returns the closure's loss, if any."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            beta1, beta2 = group["betas"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            solver_kwargs = {
                "lr": lr,
                "alpha": group["mm_alpha"],
                "steps": group["mm_steps"],
                "tol": group["mm_tol"],
            }
            with_momentum = group.get("mm_use_momentum", False)
            use_admm = group.get("ADMM", True)
            manifold_group = group.get("manifold", True)

            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue

                # Decoupled weight decay
                if weight_decay != 0.0:
                    p.data.mul_(1.0 - lr * weight_decay)

                # Lazily create per-parameter buffers on first use
                state = self.state[p]
                if not state:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)
                    state["muon_m"] = torch.zeros_like(p)
                state["step"] += 1

                first_moment = state["exp_avg"]
                second_moment = state["exp_avg_sq"]

                # AdamW moments are tracked for every parameter
                first_moment.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                second_moment.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                if not (manifold_group and p.ndim >= 2):
                    # ---- AdamW update ----
                    t = state["step"]
                    bias_correction1 = 1.0 - beta1 ** t
                    bias_correction2 = 1.0 - beta2 ** t
                    denom = (second_moment.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                    p.data.addcdiv_(first_moment, denom, value=-(lr / bias_correction1))
                    continue

                # ---- Manifold Muon update ----
                if with_momentum:
                    # simple EMA of the gradient
                    momentum_buf = state["muon_m"]
                    momentum_buf.lerp_(grad, 1.0 - beta1)
                    direction = momentum_buf
                else:
                    # no momentum: feed the raw gradient to the solver
                    direction = grad

                solver = manifold_muon_ADMM_step if use_admm else manifold_muon_step
                p.data.copy_(solver(p.data, direction, **solver_kwargs))

        return loss


## dataset

In [None]:
import torch


def get_batch(
    batch_size: int = 8,
    num_pairs: int = 5,     # T
    xy_size: int = 5,       # D
    device=None
):
    """
    Sample a fresh batch of linear-regression problems and tokenize them.

    For every batch element:
        X ~ N(0, 1)          shape (B, T, D)
        W ~ N(0, 1/D)        shape (B, D, D)
        Y = X @ W            shape (B, T, D)

    Each (x_t, y_t) pair occupies two consecutive sequence slots, 2t and
    2t + 1. A per-sample random bit decides whether the pair is laid out as
    (x, y) or as (y, x); the same bit is used for all pairs of that sample.

    Token vector layout (length 2 * (D + 1)):
        [x_flag, y_flag, x_1..x_D, y_1..y_D]
    An x token sets x_flag and the x slots; a y token sets y_flag and the
    y slots; everything else stays zero.

    Returns:
        tokens: (B, 2T, 2 * (D + 1))
        X:      (B, T, D)
        Y:      (B, T, D)
        W:      (B, D, D)
        x_pos:  (B, T) sequence index of each x token
    """
    n_batch, n_pairs, dim = batch_size, num_pairs, xy_size

    # Keep the RNG call order: X, then W, then the per-sample flip bit.
    X = torch.randn(n_batch, n_pairs, dim, device=device)
    W = torch.randn(n_batch, dim, dim, device=device) / (dim ** 0.5)
    Y = torch.einsum("btd,bdk->btk", X, W)
    flip = torch.randint(0, 2, (n_batch,), device=device).unsqueeze(1)

    # Pair t starts at slot 2t; flip == 1 swaps the x and y slots of each pair.
    pair_start = (2 * torch.arange(n_pairs, device=device)).unsqueeze(0)
    x_pos = pair_start + flip
    y_pos = pair_start + 1 - flip

    rows = torch.arange(n_batch, device=device).unsqueeze(1)
    x_slots = slice(2, 2 + dim)
    y_slots = slice(2 + dim, 2 + 2 * dim)

    tokens = torch.zeros(n_batch, 2 * n_pairs, 2 * (dim + 1), device=device)
    tokens[rows, x_pos, 0] = 1.0
    tokens[rows, x_pos, x_slots] = X
    tokens[rows, y_pos, 1] = 1.0
    tokens[rows, y_pos, y_slots] = Y

    # x_pos is returned because the model's y predictions are read at the x tokens
    return tokens, X, Y, W, x_pos

## sweep

In [None]:
import hashlib
from itertools import product 
import argparse
import json
import os


# Sweep grid for AdamW: every combination of the listed values becomes one job.
# "batch_size" is split out of the optimizer kwargs by `main` and stored as a
# separate top-level field of the job spec.
HYPERPARAM_GRID_ADAMW = {
    "lr": [3e-4, 1e-3],
    "beta1": [0.9],
    "beta2": [0.98],
    "weight_decay": [0.0, 0.1],
    "batch_size": [64, 256],
}

# Sweep grid shared by the Muon-family optimizers (see OPTIMIZER_GRID_REGISTRY).
# Same keys as the AdamW grid, with different lr / weight_decay values.
HYPERPARAM_GRID_MUON = {
    "lr": [1e-3, 3e-3],
    "beta1": [0.9],
    "beta2": [0.98],
    "weight_decay": [0.0, 0.05],
    "batch_size": [64, 256],
}

""" 
HYPERPARAM_GRID_ADAMW = {
    "lr": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2],
    "beta1": [0.85, 0.9, 0.95],
    "beta2": [0.95, 0.98, 0.999],
    "weight_decay": [0.0, 0.01, 0.1, 0.2],
    "batch_size": [32, 64, 128, 256, 512, 1024],
    
} 

HYPERPARAM_GRID_MUON = {
    "lr": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2],
    "momentum": [0.9, 0.95, 0.98],  # often you’ll just fix 0.95
    "weight_decay": [0.0, 0.01, 0.1],
    "batch_size": [32, 64, 128, 256, 512, 1024],
}


"""

# Optimizers for which `main` generates job files (one sub-directory each).
OPTIMIZER_NAMES = ['AdamW', 'MuonW', "ManifoldMuonW"]

# Hyperparameter grid per optimizer name; ManifoldMuonW reuses the MuonW grid.
OPTIMIZER_GRID_REGISTRY = {
    "AdamW": HYPERPARAM_GRID_ADAMW,
    "MuonW": HYPERPARAM_GRID_MUON,
    "ManifoldMuonW": HYPERPARAM_GRID_MUON,
}

# Architecture names swept over and written to each job as "arch_name".
# NOTE(review): presumably these select the normalization variant in the model
# factory — confirm against make_model.
MODEL_ARCHS = ["rms", "standard", "none"]

def short_hparam_str(hparams: dict, max_len: int = 128) -> str:
    """
    Render a hyperparameter dict as a compact ``key value`` string.

    Each entry becomes ``f"{key}{value}"`` and entries are joined with ``_``.
    Floats below 0.01 or above 1000 are written in ``.1e`` notation; every
    other value uses ``str()``.
    Example: {'lr': 1e-3, 'wd': 0.1} -> 'lr1.0e-03_wd0.1'.

    If the result is longer than ``max_len`` it is cut to ``max_len - 7``
    characters and suffixed with ``_`` plus the first 6 hex digits of the MD5
    of the full string, so distinct configs keep distinct names.
    """
    def render(value) -> str:
        # Scientific notation only for very small / very large floats
        if isinstance(value, float) and (value < 0.01 or value > 1000):
            return f"{value:.1e}"
        return str(value)

    joined = "_".join(f"{key}{render(value)}" for key, value in hparams.items())
    if len(joined) > max_len:
        digest = hashlib.md5(joined.encode("utf-8")).hexdigest()[:6]
        return joined[: max_len - 7] + "_" + digest
    return joined


def iter_hparam_configs(hyperparam_grid: dict):
    """
    Lazily yield one config dict per point of the Cartesian grid.

    {"lr": [1e-4, 1e-3], "wd": [0.0, 0.1]} yields
        {"lr": 1e-4, "wd": 0.0}, {"lr": 1e-4, "wd": 0.1},
        {"lr": 1e-3, "wd": 0.0}, {"lr": 1e-3, "wd": 0.1}
    i.e. the last key varies fastest and keys keep the grid's order.
    """
    names = list(hyperparam_grid)
    choices = [hyperparam_grid[name] for name in names]
    yield from (dict(zip(names, point)) for point in product(*choices))


def main():
    """
    CLI entry point: write one JSON job spec per sweep point.

    For every optimizer in OPTIMIZER_NAMES and every architecture in
    MODEL_ARCHS, the optimizer's grid is expanded and each configuration is
    written to ``<output_dir>/<optimizer>/<arch>/job_<idx>.json``, where
    ``idx`` is the zero-padded position of the config within the grid.
    """
    parser = argparse.ArgumentParser(
        description="Generate hyperparameter sweep configuration files."
    )
    required_options = (
        ("--xy_size", int, "Input feature dimensionality (D)."),
        ("--num_pairs", int, "Number of (x, y) pairs per batch (T)."),
        ("--project_name", str, "WandB project name for all generated configs."),
        ("--last_k", int, "Number of recent losses to average for run summary."),
    )
    for flag, value_type, help_text in required_options:
        parser.add_argument(flag, type=value_type, required=True, help=help_text)
    parser.add_argument(
        "--output_dir",
        type=str,
        default="jobs",
        help="Directory in which to save all generated config files.",
    )
    args = parser.parse_args()

    # Fields that are identical in every generated spec
    shared_fields = {
        "xy_size": args.xy_size,
        "num_pairs": args.num_pairs,
        "project_name": args.project_name,
        "last_k": args.last_k,
    }
    root = args.output_dir

    os.makedirs(root, exist_ok=True)
    print("\n=== Generating sweep configs ===")
    for optimizer_name in OPTIMIZER_NAMES:
        opt_grid = OPTIMIZER_GRID_REGISTRY[optimizer_name]
        opt_dir = os.path.join(root, optimizer_name)
        os.makedirs(opt_dir, exist_ok=True)

        for arch_name in MODEL_ARCHS:
            arch_dir = os.path.join(opt_dir, arch_name)
            os.makedirs(arch_dir, exist_ok=True)
            print(f"\nOptimizer: {optimizer_name}, Arch: {arch_name}")

            for idx, hparams in enumerate(list(iter_hparam_configs(opt_grid))):
                # batch_size belongs to the data loader, not the optimizer
                optimizer_kwargs = {
                    name: value for name, value in hparams.items() if name != "batch_size"
                }
                spec = {
                    "run_name": f"{optimizer_name}_{arch_name}_{short_hparam_str(hparams)}",
                    "arch_name": arch_name,
                    "optimizer_name": optimizer_name,
                    "optimizer_kwargs": optimizer_kwargs,
                    "xy_size": shared_fields["xy_size"],
                    "num_pairs": shared_fields["num_pairs"],
                    "batch_size": hparams["batch_size"],
                    "project_name": shared_fields["project_name"],
                    "last_k": shared_fields["last_k"],
                }
                # job_000.json naming
                out_path = os.path.join(arch_dir, f"job_{idx:03d}.json")
                with open(out_path, "w") as f:
                    json.dump(spec, f, indent=2)
    print("\n=== Sweep generation complete ===")


if __name__ == "__main__":
    main()

## training

In [None]:
import torch
import torch.nn.functional as F
import wandb
import os
import glob
from collections import deque
import time
import wandb
from optimizers.muonW1 import MuonW
from optimizers.manifold_muonW import ManifoldMuonW
from loadtypes.config_types import OptimizerKwargs, ExperimentConfig
from model.model import make_model
from scripts.dataset import get_batch as get_ols_batch
import datetime

# Optional TPU support
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    HAS_XLA = True
except ImportError:
    HAS_XLA = False

def resolve_device_and_saver(device_str: str):
    """
    Map a device string to the device handle and device-appropriate helpers.

    Returns a ``(device, save_fn, optimizer_step_fn)`` triple:
      - "tpu" (case-insensitive): the XLA device, ``xm.save``, and a step
        function that calls ``xm.optimizer_step`` followed by ``xm.mark_step``.
        Raises RuntimeError if torch_xla could not be imported.
      - anything else: ``torch.device(device_str)``, ``torch.save``, and a step
        function that simply calls ``optimizer.step()``.
    """
    if device_str.lower() != "tpu":
        def optimizer_step_fn(optimizer):
            optimizer.step()

        return torch.device(device_str), torch.save, optimizer_step_fn

    if not HAS_XLA:
        raise RuntimeError("TPU requested but torch_xla is not installed.")

    def optimizer_step_fn(optimizer):
        xm.optimizer_step(optimizer)
        xm.mark_step()

    return xm.xla_device(), xm.save, optimizer_step_fn


def save_checkpoint(model, optimizer, step: int, path: str, scheduler=None, save_fn=torch.save):
    """
    Bundle the training state into one dict and hand it to ``save_fn(ckpt, path)``.

    The dict holds the model and optimizer state dicts, the scheduler state
    dict (None when no scheduler is given), the step counter, and the CPU RNG
    state plus the CUDA RNG state (None when CUDA is unavailable).
    """
    ckpt = {}
    ckpt["model"] = model.state_dict()
    ckpt["optimizer"] = optimizer.state_dict()
    ckpt["scheduler"] = None if not scheduler else scheduler.state_dict()
    ckpt["step"] = step
    ckpt["rng_state"] = torch.random.get_rng_state()
    if torch.cuda.is_available():
        ckpt["cuda_rng_state"] = torch.cuda.get_rng_state()
    else:
        ckpt["cuda_rng_state"] = None

    save_fn(ckpt, path)
    print(f"[checkpoint] saved to {path}")

def load_checkpoint(model, optimizer, path: str, device="cuda", scheduler=None) -> int:
    """
    Restore model, optimizer and (optionally) scheduler state from ``path``.

    Both the current key names ("model", "optimizer") and the older ones
    ("model_state", "optimizer_state") are accepted. The scheduler is restored
    only when one is passed in and the checkpoint holds a non-None
    "scheduler" entry. RNG state stored in the checkpoint is not restored.

    Returns the step stored in the checkpoint, or 0 if it has none.
    """
    ckpt = torch.load(path, map_location=device)
    print("Loaded checkpoint keys:", ckpt.keys())

    # Fall back to the legacy key names when the new ones are absent
    model_state = ckpt["model"] if "model" in ckpt else ckpt["model_state"]
    model.load_state_dict(model_state)
    optimizer_state = ckpt["optimizer"] if "optimizer" in ckpt else ckpt["optimizer_state"]
    optimizer.load_state_dict(optimizer_state)

    if scheduler is not None:
        scheduler_state = ckpt.get("scheduler")
        if scheduler_state is not None:
            scheduler.load_state_dict(scheduler_state)

    # "rng_state" / "cuda_rng_state" are intentionally left untouched here.

    print(f"[checkpoint] resumed from {path}")
    return ckpt.get("step", 0)

def find_latest_checkpoint(checkpoint_dir: str) -> str | None:
    if not os.path.isdir(checkpoint_dir):
        return None

    files = glob.glob(os.path.join(checkpoint_dir, "step_*.pt"))
    if not files:
        return None

    best_path = None
    best_step = -1

    for path in files:
        base = os.path.basename(path)  # e.g. "step_2000_20251119-154210.pt"
        if not base.startswith("step_"):
            continue
        parts = base.split("_")
        # ["step", "<step>", "<timestamp>.pt"]
        if len(parts) < 3:
            continue

        step_str = parts[1]
        try:
            step = int(step_str)
        except ValueError:
            continue

        if step > best_step:
            best_step = step
            best_path = path

    return best_path


class WarmupConstantDecayLrScheduler:
    """
    Linear warm-up, constant plateau, linear decay to zero.

    The multiplier applied to each param group's base learning rate is:
      - ``step / warmup_steps`` while ``step < warmup_steps``,
      - ``1.0`` until ``decay_start = total_steps - decay_steps``,
      - ``1 - (step - decay_start) / decay_steps`` afterwards (floored at 0).

    The multiplier for step 0 is applied in the constructor, and every call to
    ``step()`` installs the multiplier for the next optimizer step. Call
    ``step()`` once after each ``optimizer.step()``. Without the initial
    application the first optimizer step would run at the full base learning
    rate and the warm-up would only begin on the second step.

    Args:
        optimizer: object exposing ``param_groups`` (list of dicts with "lr").
            The learning rates found at construction time are the base rates.
        total_steps: planned number of optimizer steps.
        warmup_ratio: fraction of ``total_steps`` spent warming up.
        decay_ratio: fraction of ``total_steps`` spent decaying at the end.
    """

    def __init__(self, optimizer, total_steps, warmup_ratio=0.02, decay_ratio=0.10):
        self.optimizer = optimizer
        self.total_steps = total_steps
        self.warmup_steps = int(total_steps * warmup_ratio)
        self.decay_steps = int(total_steps * decay_ratio)
        self.decay_start = total_steps - self.decay_steps
        self.base_lrs = [g['lr'] for g in optimizer.param_groups]
        self.last_step = 0
        # Install the step-0 learning rate before any optimizer step runs.
        self.step()

    def state_dict(self):
        """Return the scheduler position; base learning rates are not stored."""
        return {"last_step": self.last_step}

    def load_state_dict(self, state):
        """Restore the position saved by ``state_dict``.

        The optimizer's current learning rates are left as they are; they are
        expected to come from the optimizer's own checkpoint.
        """
        self.last_step = state["last_step"]

    def _scale_at(self, step):
        """Return the learning-rate multiplier for the given step index."""
        if step < self.warmup_steps and self.warmup_steps != 0:
            return step / self.warmup_steps
        if step < self.decay_start:
            return 1.0
        # max(1, ...) guards against a zero-length decay phase
        remaining = max(1, self.total_steps - self.decay_start)
        return max(0.0, 1 - (step - self.decay_start) / remaining)

    def step(self):
        """Write the learning rate for step ``last_step`` and advance by one."""
        scale = self._scale_at(self.last_step)
        for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups):
            group["lr"] = base_lr * scale
        self.last_step += 1


def train(
    model,
    optimizer,
    logger,
    *,
    get_batch,
    batch_size=8,
    num_pairs=5,
    xy_size=5,
    num_steps=1000,
    device="cuda",
    verbose=False,
    print_interval=1000,
    checkpoint_every=20,
    checkpoint_dir=None,
    resume_from: str | None = None,
    scheduler=None,
):
    """
    Train ``model`` on freshly sampled batches for steps ``[start, num_steps)``.

    ``start`` is 0, or the step returned by ``load_checkpoint`` when
    ``resume_from`` is given. Each iteration draws a batch with ``get_batch``,
    reads the model outputs at the returned x-token indices, and minimizes the
    squared error against ``Y`` (summed over dim 1, then averaged). The
    scheduler, if any, is stepped after the optimizer.

    When ``checkpoint_dir`` is set, a checkpoint labelled ``step + 1`` is
    written whenever ``step`` is a non-zero multiple of ``checkpoint_every``,
    and uploaded as a wandb artifact if the logger exposes a ``run``.
    If ``logger`` is not None, loss and iteration time are logged every step.

    Returns the trained model.
    """
    device, save_fn, optimizer_step_fn = resolve_device_and_saver(device)
    model.to(device)
    model.train()

    if checkpoint_dir is not None:
        os.makedirs(checkpoint_dir, exist_ok=True)

    if resume_from:
        first_step = load_checkpoint(
            model, optimizer, resume_from, device=device, scheduler=scheduler
        )
    else:
        first_step = 0

    for step in range(first_step, num_steps):
        tick = time.time()

        tokens, X, Y, W, x_token_indices = get_batch(
            batch_size=batch_size,
            num_pairs=num_pairs,
            xy_size=xy_size,
            device=device,
        )
        outputs = model(tokens)
        n_batch, _n_tokens, _n_features = outputs.shape
        batch_index = torch.arange(n_batch, device=device).unsqueeze(1)
        # predictions for y are read at the positions of the x tokens
        y_pred = outputs[batch_index, x_token_indices, :]
        loss = torch.sum((y_pred - Y) ** 2, dim=1).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer_step_fn(optimizer)
        if scheduler is not None:
            scheduler.step()

        if checkpoint_dir is not None and step % checkpoint_every == 0 and step != 0:
            completed = step + 1
            timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
            ckpt_path = os.path.join(checkpoint_dir, f"step_{completed}_{timestamp}.pt")
            save_checkpoint(
                model=model,
                optimizer=optimizer,
                step=completed,
                path=ckpt_path,
                scheduler=scheduler,
                save_fn=save_fn,   # torch.save or xm.save
            )

            # mirror the checkpoint to wandb as an artifact
            if logger is not None and hasattr(logger, "run"):
                artifact = wandb.Artifact(
                    name=f"ckpt_step_{completed}",
                    type="model",
                    metadata={"step": completed, "arch": model.__class__.__name__},
                )
                artifact.add_file(ckpt_path)
                logger.run.log_artifact(artifact)

            if verbose:
                print(f"[Step {step}] Saved checkpoint to {ckpt_path}")

        if verbose and (step % print_interval == 0):
            print(f"[Step {step}] loss = {loss.item():.6f}")

        if logger is not None:
            elapsed = time.time() - tick
            logger.log({"train/loss": loss.item(), "step": step, "train/iter_sec": elapsed})

    return model


class WandbLossLogger:
    """
    Wraps a wandb.Run-like object to:
      - forward logs to wandb
      - keep a rolling window of the last K 'train/loss' values

    Args:
        run: object exposing ``log(dict)`` and ``finish()``.
        last_k: size of the rolling loss window.
    """
    def __init__(self, run, last_k: int = 50):
        self.start_time = time.time()
        self.run = run
        # deque(maxlen=...) drops the oldest loss automatically
        self.last_k = deque(maxlen=last_k)

    def log(self, metrics: dict):
        """Record ``metrics['train/loss']`` if present, then forward to the run."""
        if "train/loss" in metrics:
            self.last_k.append(metrics["train/loss"])
        self.run.log(metrics)

    def get_last_k_loss(self):
        """Return the mean of the stored losses, or NaN if none were logged.

        The window is empty when no training step ran, e.g. when a run that
        already reached its final step is resumed; NaN is returned instead of
        raising ZeroDivisionError.
        """
        if not self.last_k:
            return float("nan")
        return sum(self.last_k) / len(self.last_k)

    def finish(self):
        """Close the underlying run."""
        self.run.finish()


# Maps the "optimizer_name" field of a job config to the optimizer class that
# `run_from_config` instantiates.
OPTIMIZER_REGISTRY = {
    "AdamW": torch.optim.AdamW,
    "MuonW": MuonW,
    "ManifoldMuonW": ManifoldMuonW,
}


def run_from_config(config: ExperimentConfig):
    """
    Execute one training job described by ``config``.

    Builds the checkpoint directory
    ``<base_ckpt_dir>/<experiment_phase>/<optimizer>/<arch>/<run_name>``,
    resumes from its latest checkpoint if one exists, starts (or resumes) the
    wandb run whose id is the run name, trains, and logs the average of the
    last ``last_k`` training losses.

    Returns ``{"avg_last_k_train_loss": <float>}``.
    """
    experiment_phase: str = config["experiment_phase"]
    run_name: str = config["run_name"]
    arch_name: str = config["arch_name"]
    optimizer_name: str = config["optimizer_name"]
    optimizer_kwargs: OptimizerKwargs = config["optimizer_kwargs"]
    xy_size: int = config["xy_size"]
    num_pairs: int = config["num_pairs"]
    num_steps: int = config["num_steps"]
    batch_size: int = config["batch_size"]
    checkpoint_every: int = config["checkpoint_every"]
    device: str = config["device"]
    project_name: str = config["project_name"]
    base_ckpt_dir: str = config["base_ckpt_dir"]
    last_k: int = config["last_k"]

    hierarchy = (experiment_phase, optimizer_name, arch_name)
    ckpt_dir = os.path.join(base_ckpt_dir, *hierarchy, run_name)
    os.makedirs(ckpt_dir, exist_ok=True)
    resume_from = find_latest_checkpoint(ckpt_dir)

    run = wandb.init(
        id=run_name,
        project=project_name,
        name=run_name,  # wandb name limit
        group=f"{experiment_phase}/{optimizer_name}/{arch_name}",
        config=config,
        reinit=True,
        resume="allow",
    )

    logger = WandbLossLogger(run, last_k=last_k)
    model = make_model(arch_name)
    optimizer_class = OPTIMIZER_REGISTRY[optimizer_name]

    # Translate the flat sweep kwargs into constructor arguments; only keys
    # that are actually present are forwarded.
    opt_kwargs = {
        key: optimizer_kwargs[key]
        for key in ("lr", "weight_decay")
        if key in optimizer_kwargs
    }
    # beta1 / beta2 are combined into `betas` only when both are given
    if "beta1" in optimizer_kwargs and "beta2" in optimizer_kwargs:
        opt_kwargs["betas"] = (optimizer_kwargs["beta1"], optimizer_kwargs["beta2"])
    # Optional optimizer-specific extras
    opt_kwargs.update(
        {key: optimizer_kwargs[key] for key in ("momentum", "nesterov") if key in optimizer_kwargs}
    )

    optimizer = optimizer_class(model.parameters(), **opt_kwargs)
    scheduler = WarmupConstantDecayLrScheduler(optimizer, num_steps)

    model = train(
        model=model,
        optimizer=optimizer,
        logger=logger,
        get_batch=get_ols_batch,
        batch_size=batch_size,
        num_pairs=num_pairs,
        xy_size=xy_size,
        num_steps=num_steps,
        device=device,
        checkpoint_every=checkpoint_every,
        checkpoint_dir=ckpt_dir,
        resume_from=resume_from,
        verbose=True,
        scheduler=scheduler,
    )

    avg_last_k_loss = logger.get_last_k_loss()
    logger.log({"avg_last_k_train_loss": avg_last_k_loss})
    logger.finish()

    return {"avg_last_k_train_loss": avg_last_k_loss}


## model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RotaryEmbedding(nn.Module):
  """
  Rotary position embedding applied to query/key tensors of shape
  (batch, heads, tokens, head_dim).

  The feature dimension is split into a first and a second half; feature i of
  the first half is rotated together with feature i of the second half by the
  angle ``position * base ** (-2 i / head_dim)``.
  """

  def __init__(self, head_dim, base=10000):
    super().__init__()
    self.head_dim = head_dim
    self.base = base

  def calc_inv_freqs(self):
    """Return the head_dim // 2 inverse frequencies base ** (-2 i / head_dim)."""
    exponents = -2 * torch.arange(self.head_dim // 2) / self.head_dim
    return self.base ** exponents

  def calc_cos_sin(self, num_tokens):
    """Return (cos, sin) tables of shape (num_tokens, head_dim // 2)."""
    positions = torch.arange(num_tokens)
    angles = torch.einsum("i,j->ij", positions, self.calc_inv_freqs())
    return angles.cos(), angles.sin()

  def apply_rotary_emb(self, x, cos, sin):
    """Rotate the (first half, second half) feature pairs of x by the given tables."""
    # x: (b, h, t, d); cos / sin: (t, d/2), broadcast over batch and heads
    first, second = x.chunk(2, dim=-1)
    # Which feature slots hold the rotated values is irrelevant as long as
    # q and k use the same convention in their dot product.
    rotated = (first * cos - second * sin, first * sin + second * cos)
    return torch.cat(rotated, dim=-1)

  def forward(self, q, k):
    """Return (q, k) rotated according to their token position (dim 2)."""
    cos, sin = self.calc_cos_sin(q.shape[2])
    cos = cos.to(q.device)
    sin = sin.to(q.device)
    return self.apply_rotary_emb(q, cos, sin), self.apply_rotary_emb(k, cos, sin)
    
class RMSNorm(nn.Module):
  """
  Root-mean-square normalization over the last dimension.

  Computes ``x / sqrt(mean(x^2) + eps)``; when ``learnable`` is True the result
  is additionally multiplied by a per-feature ``scale`` parameter (initialized
  to ones). With ``learnable=False`` the module has no parameters.
  """

  def __init__(self, num_features, eps=1e-5, learnable=True):
    super().__init__()
    self.num_features = num_features
    self.eps = eps
    self.learnable = learnable
    if learnable:
      self.scale = nn.Parameter(torch.ones(num_features))

  def forward(self, x):
    mean_square = torch.mean(x**2, dim=-1, keepdim=True)
    normalized = x * torch.rsqrt(mean_square + self.eps)
    if not self.learnable:
      return normalized
    return normalized * self.scale

class LayerNorm(nn.Module):
  """
  Layer normalization over the last dimension with a learned scale.

  The input is centered, divided by the biased standard deviation (with
  ``eps`` added to the variance) and multiplied by ``scale``.

  Note: ``scale`` and ``bias`` parameters are always created, regardless of
  ``learnable``, and ``bias`` is currently not added in ``forward``.
  """

  def __init__(self, num_features, learnable=True, eps=1e-5):
    super().__init__()
    self.num_features = num_features
    self.eps = eps
    self.learnable = learnable
    self.scale = nn.Parameter(torch.ones(num_features))
    self.bias = nn.Parameter(torch.zeros(num_features))

  def forward(self, x):
    centered = x - x.mean(dim=-1, keepdim=True)
    inv_std = torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
    # bias is intentionally left out of the output
    return centered * inv_std * self.scale

class MultiHeadAttention(nn.Module):
  """
  Causal multi-head self-attention with rotary position embeddings.

  A single bias-free linear layer produces queries, keys and values; rotary
  embeddings are applied to q and k per head; a lower-triangular mask keeps
  each token from attending to later tokens. ``forward`` returns the projected
  output of shape (batch, tokens, d_model) together with the attention
  probabilities of shape (batch, heads, tokens, tokens).
  """

  def __init__(self, d_model=256, n_heads=8):
    super().__init__()
    self.d_model = d_model
    self.n_heads = n_heads
    self.d_k = d_model // n_heads
    self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
    self.out = nn.Linear(d_model, d_model, bias=False)
    self.rope = RotaryEmbedding(self.d_k)

  def sdpa(self, Q, K, V):
    """Scaled dot-product attention over (batch, heads, tokens, d_k) inputs."""
    _, _, n_tokens, _ = Q.shape
    Q, K = self.rope(Q, K)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / self.d_k ** 0.5
    # True above the diagonal, i.e. for key positions later than the query
    visible = torch.tril(torch.ones(n_tokens, n_tokens, device=Q.device, dtype=torch.bool))
    scores = scores.masked_fill(~visible, float("-inf"))
    probs = F.softmax(scores, dim=-1)
    return torch.matmul(probs, V), probs

  def split_heads(self, x):
    """(batch, tokens, d_model) -> (batch, heads, tokens, d_k)."""
    batch, n_tokens, _ = x.shape
    return x.view(batch, n_tokens, self.n_heads, self.d_k).transpose(1, 2)

  def combine_heads(self, x):
    """(batch, heads, tokens, d_k) -> (batch, tokens, d_model)."""
    batch, _, n_tokens, _ = x.shape
    return x.transpose(1, 2).contiguous().view(batch, n_tokens, self.d_model)

  def forward(self, x):
    _batch, _tokens, _dim = x.shape
    # The fused projection is laid out as [q | k | v] along the last dim
    q, k, v = (
      self.split_heads(part.contiguous())
      for part in self.qkv(x).split(self.d_model, dim=-1)
    )
    attn_out, attn_probs = self.sdpa(q, k, v)
    return self.out(self.combine_heads(attn_out)), attn_probs


class AttentionBlock(nn.Module):
  """
  Optionally pre-normalized attention sub-layer with a weighted residual.

  The output is ``attn(norm(x)) / n_layers + x * (n_layers - 1) / n_layers``,
  i.e. a convex combination of the attention branch and the input. When
  ``norm_fn`` is None the attention runs on the raw input.
  """

  def __init__(self, n_layers=15, hidden_size=256, n_heads=8, norm_fn=None):
    super().__init__()
    self.n_layers = n_layers
    self.hidden_size = hidden_size
    self.n_heads = n_heads
    self.has_norm = norm_fn is not None
    if self.has_norm:
      self.norm = norm_fn(hidden_size)
    self.mha = MultiHeadAttention(hidden_size, n_heads)

  def forward(self, x):
    branch_in = self.norm(x) if self.has_norm else x
    branch_out, _attn_probs = self.mha(branch_in)
    return branch_out / self.n_layers + x * (self.n_layers - 1) / self.n_layers


class SwiGLU(nn.Module):
  """Gated feed-forward layer with a learnable-slope swish gate.

  ``fc1`` projects to ``4 * hidden_size`` features that are split in half
  into a main path and a gate path; the gate is ``g * sigmoid(beta * g)``
  and the gated product is projected back to ``hidden_size`` by ``fc2``.
  """

  def __init__(self, hidden_size=256):
    super().__init__()
    inner_size = 2 * hidden_size
    # fc1 emits the main and gate halves in a single matmul.
    self.fc1 = nn.Linear(hidden_size, 2 * inner_size, bias=False)
    self.fc2 = nn.Linear(inner_size, hidden_size, bias=False)
    self.beta = nn.Parameter(torch.tensor(1.0))

  def forward(self, x):
    """Apply the gated projection to ``x`` along its last dimension."""
    main_half, gate_half = self.fc1(x).chunk(2, dim=-1)
    swish = gate_half * torch.sigmoid(self.beta * gate_half)
    return self.fc2(main_half * swish)


class MLP(nn.Module):
  """Optionally pre-normalised SwiGLU sub-layer with a weighted residual.

  The feed-forward output and the untouched input are mixed as
  ``ff / n_layers + x * (n_layers - 1) / n_layers``.

  Args:
    n_layers: Divisor used for the residual mixing weights.
    hidden_size: Feature width handed to the norm and the SwiGLU layer.
    norm_fn: Callable ``norm_fn(hidden_size)`` building the norm layer, or
      ``None`` to skip normalisation.
  """

  def __init__(self, n_layers=15, hidden_size=256, norm_fn=None):
    super().__init__()
    self.n_layers = n_layers
    self.hidden_size = hidden_size
    self.has_norm = norm_fn is not None
    if self.has_norm:
      self.norm = norm_fn(hidden_size)
    self.swiglu = SwiGLU(hidden_size)

  def forward(self, x):
    """Return the residual mix of ``x`` and its feed-forward output."""
    branch_in = self.norm(x) if self.has_norm else x
    branch_out = self.swiglu(branch_in)
    return branch_out / self.n_layers + x * (self.n_layers - 1) / self.n_layers


class TransformerBlock(nn.Module):
  """One transformer layer: an ``AttentionBlock`` followed by an ``MLP``.

  Both sub-layers receive the same ``n_layers``, ``hidden_size`` and
  ``norm_fn``; ``n_heads`` is only used by the attention sub-layer.
  """

  def __init__(self, n_layers=15, hidden_size=256, n_heads=8, norm_fn=None):
    super().__init__()
    self.n_layers = n_layers
    self.hidden_size = hidden_size
    self.n_heads = n_heads
    self.attn = AttentionBlock(n_layers, hidden_size, n_heads, norm_fn=norm_fn)
    self.mlp = MLP(n_layers, hidden_size, norm_fn=norm_fn)

  def forward(self, x):
    """Run ``x`` through attention and then through the feed-forward block."""
    return self.mlp(self.attn(x))


class Transformer(nn.Module):
  """Stack of ``TransformerBlock`` layers between linear input/output maps.

  Inputs with ``2 * (xy_size + 1)`` features are embedded to ``hidden_size``,
  passed through ``n_layers`` blocks, optionally normalised, and projected
  down to ``xy_size`` features.

  Args:
    hidden_size: Width of the residual stream.
    n_heads: Number of attention heads in every block.
    n_layers: Number of blocks; also forwarded to each block.
    xy_size: Determines the input width and the output width (see above).
    norm_fn: Callable ``norm_fn(dim)`` building a norm layer, or ``None`` for
      no normalisation anywhere in the model.
  """

  def __init__(self,
               hidden_size=256,
               n_heads=8,
               n_layers=15,
               xy_size=5,
               norm_fn=lambda d: RMSNorm(d, learnable=False)):
    super().__init__()
    self.hidden_size = hidden_size
    self.n_heads = n_heads
    self.n_layers = n_layers
    self.xy_size = xy_size

    layers = []
    for _ in range(n_layers):
      layers.append(TransformerBlock(n_layers, hidden_size, n_heads, norm_fn=norm_fn))
    self.blocks = nn.ModuleList(layers)

    in_features = 2 * (xy_size + 1)
    self.embedding = nn.Linear(in_features, hidden_size, bias=False)
    # The embedding deliberately does NOT use the standard Xavier init: per the
    # original author's calculation, a std of (xy_size + 1)**-0.5 is what makes
    # the activation RMS norm come out as 1.
    nn.init.normal_(self.embedding.weight, mean=0.0, std=(xy_size + 1)**-0.5)

    self.has_norm = norm_fn is not None
    if self.has_norm:
      self.norm = norm_fn(hidden_size)
    self.unembedding = nn.Linear(hidden_size, xy_size, bias=False)

  def forward(self, x):
    """Embed ``x``, apply every block in order, and project to the output."""
    hidden = self.embedding(x)
    for layer in self.blocks:
      hidden = layer(hidden)
    if self.has_norm:
      hidden = self.norm(hidden)
    return self.unembedding(hidden)


def make_model(arch_name):
    """Build the fixed-size ``Transformer`` for the named architecture.

    ``"rms"`` uses a non-learnable ``RMSNorm``, ``"standard"`` uses a learnable
    ``LayerNorm``, and any other value builds the model without normalisation.
    """
    def rms_norm(dim):
        return RMSNorm(dim, learnable=False)

    def layer_norm(dim):
        return LayerNorm(dim, learnable=True)

    if arch_name == "rms":
        norm_factory = rms_norm
    elif arch_name == "standard":
        norm_factory = layer_norm
    else:
        # Unrecognised names fall through to a norm-free model.
        norm_factory = None

    return Transformer(
        hidden_size=256,
        n_heads=8,
        n_layers=15,
        xy_size=5,
        norm_fn=norm_factory,
    )


## main

In [None]:
import json
import argparse
from typing import cast
from scripts.training import run_from_config
from loadtypes.config_types import ExperimentSpec, RunOptions, ExperimentConfig

def load_spec(path: str) -> ExperimentSpec:
    """Load only the experiment specification (static part) from JSON.

    Args:
        path: Location of the JSON spec file.

    Returns:
        The parsed JSON document, typed as ``ExperimentSpec``. ``cast`` is a
        static-typing hint only; the contents are not validated here.
    """
    # JSON files are UTF-8 by specification; pin the encoding so parsing does
    # not depend on the platform's locale default.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return cast(ExperimentSpec, data)


def main() -> None:
    """Parse the CLI, run one training job, and print a summary.

    The JSON spec named by ``--config`` is merged with the run options taken
    from the command line (run options win on key clashes) and handed to
    ``run_from_config``.
    """
    # (flag, value type, help text); every option is mandatory.
    option_table = (
        ("--config", str, "Path to JSON experiment spec file."),
        ("--phase", str, "Experiment phase (e.g. sweep, exp1, ablation1)."),
        ("--num-steps", int, "Number of steps to train on."),
        ("--device", str, "Device to run on: cpu, cuda, tpu, auto."),
        ("--ckpt-root", str, "Base checkpoint directory."),
        ("--job-id", str, "Id of the current job."),
        ("--checkpoint_every", int, "how often to checkpoint."),
    )
    parser = argparse.ArgumentParser(description="Run training from a config file.")
    for flag, value_type, help_text in option_table:
        parser.add_argument(flag, type=value_type, required=True, help=help_text)
    args = parser.parse_args()
    # NOTE(review): --job-id is required but its value is never read in this
    # function -- confirm whether it should be forwarded in run_options.

    print(f"\n=== Loading experiment spec from {args.config} ===")
    spec = load_spec(args.config)
    run_options: RunOptions = dict(
        experiment_phase=args.phase,
        device=args.device,
        base_ckpt_dir=args.ckpt_root,
        num_steps=args.num_steps,
        checkpoint_every=args.checkpoint_every,
    )
    config: ExperimentConfig = {**spec, **run_options}

    print("\n=== Starting training run ===")
    result = run_from_config(config)

    rule = "=" * 60
    print("\n" + rule)
    print("TRAINING RUN COMPLETE")
    print(rule)
    print(f"Run name:                 {spec['run_name']}")
    print(f"Experiment phase:         {run_options['experiment_phase']}")
    print(f"Optimizer:                {spec['optimizer_name']}")
    print(f"Architecture:             {spec['arch_name']}")
    print(f"Avg last-k train loss:    {result['avg_last_k_train_loss']:.6f}")
    print(f"Checkpoint directory:     {run_options['base_ckpt_dir']}")
    print(f"Checkpoint every:     {run_options['checkpoint_every']}")
    print(rule)
    print()

# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()

## run job

In [None]:
#!/bin/bash
# Run a single training job.
#
# Resolves the job's JSON config from OPTIMIZER, ARCH and JOB_ID and forwards
# the remaining arguments to main.py. Exits with status 1 when the argument
# count is not exactly 8.
set -e

if (( $# != 8 )); then
    echo "Usage: $0 OPTIMIZER ARCH JOB_ID PHASE DEVICE CKPT_ROOT NUM_STEPS CHECKPOINT_EVERY"
    echo "Example: $0 muon rms 003 sweep cuda checkpoints 3000 200"
    exit 1
fi

# Positional arguments, in the order shown in the usage line.
OPTIMIZER="$1"
ARCH="$2"
JOB_ID="$3"
PHASE="$4"
DEVICE="$5"
CKPT_ROOT="$6"
NUM_STEPS="$7"
CHECKPOINT_EVERY="$8"

# Configs live under jobs/<optimizer>/<arch>/job_<id>.json.
CONFIG="jobs/${OPTIMIZER}/${ARCH}/job_${JOB_ID}.json"

echo "Running config: $CONFIG"

# Collect the flags in an array so each value stays a single argument.
train_args=(
    --config "$CONFIG"
    --phase "$PHASE"
    --num-steps "$NUM_STEPS"
    --device "$DEVICE"
    --ckpt-root "$CKPT_ROOT"
    --checkpoint_every "$CHECKPOINT_EVERY"
    --job-id "$JOB_ID"
)
python3 main.py "${train_args[@]}"

## run jobs

In [None]:
#!/bin/bash
set -e

# Launch a contiguous range of training jobs, at most NUM_GPUS at a time.
#
# Usage:
#   ./run_jobs_local.sh OPTIMIZER ARCH START_ID END_ID PHASE DEVICE CKPT_ROOT NUM_STEPS CHECKPOINT_EVERY [NUM_GPUS]
# Example:
#   ./run_jobs_local.sh AdamW rms 0 15 sweep cuda checkpoints 3000 200 4
#
# If NUM_GPUS is omitted, defaults to 1 (serial).
#
# Scheduling: every running job owns one slot (the GPU index when DEVICE is
# cuda or auto). A new job is started only in a slot whose previous job has
# exited, so two jobs are never bound to the same GPU while another GPU idles.
#
# Failures: a failing job does not abort the sweep. The remaining jobs still
# run, each failure is reported on stderr, and the script exits with status 1
# at the end if any job failed.

if [ "$#" -lt 9 ] || [ "$#" -gt 10 ]; then
  echo "Usage: $0 OPTIMIZER ARCH START_ID END_ID PHASE DEVICE CKPT_ROOT NUM_STEPS CHECKPOINT_EVERY [NUM_GPUS]"
  exit 1
fi

OPTIMIZER=$1    # e.g. AdamW
ARCH=$2         # e.g. rms
START=$3        # e.g. 0
END=$4          # e.g. 15
PHASE=$5        # e.g. sweep
DEVICE=$6       # cpu | cuda | tpu | auto
CKPT_ROOT=$7    # e.g. checkpoints
NUM_STEPS=$8    # e.g. 3000
CHECKPOINT_EVERY=$9
NUM_GPUS=${10:-1}   # Optional; default 1

# NUM_GPUS is the slot count, so it has to be a positive integer.
case "$NUM_GPUS" in
  ''|*[!0-9]*)
    echo "NUM_GPUS must be a positive integer (got '${NUM_GPUS}')" >&2
    exit 1
    ;;
esac
NUM_GPUS=$((10#$NUM_GPUS))   # force base 10 so a value like "08" is not read as octal
if [ "$NUM_GPUS" -lt 1 ]; then
  echo "NUM_GPUS must be a positive integer (got '${NUM_GPUS}')" >&2
  exit 1
fi

echo "Using up to ${NUM_GPUS} GPU(s)"

slot_pids=()    # slot_pids[g]: PID of the job running in slot g, empty when free
slot_jobs=()    # slot_jobs[g]: JOB_ID of that job, used in failure messages
for ((g = 0; g < NUM_GPUS; g++)); do
  slot_pids[g]=""
  slot_jobs[g]=""
done
failed_jobs=0
FREE_SLOT=0

# Collect the exit status of the job in slot $1, record a failure, free the slot.
release_slot() {
  local g=$1
  if ! wait "${slot_pids[g]}"; then
    echo "=== Job ${slot_jobs[g]} (slot ${g}) failed ===" >&2
    failed_jobs=$((failed_jobs + 1))
  fi
  slot_pids[g]=""
  slot_jobs[g]=""
  return 0
}

# Free every slot whose job has already exited.
reap_finished() {
  local g
  for ((g = 0; g < NUM_GPUS; g++)); do
    if [ -z "${slot_pids[g]}" ]; then
      continue
    fi
    # kill -0 sends no signal; it only probes whether the process still exists.
    if kill -0 "${slot_pids[g]}" 2>/dev/null; then
      continue
    fi
    release_slot "$g"
  done
  return 0
}

# Block until a slot is free and store its index in FREE_SLOT.
acquire_slot() {
  local g
  while true; do
    reap_finished
    for ((g = 0; g < NUM_GPUS; g++)); do
      if [ -z "${slot_pids[g]}" ]; then
        FREE_SLOT=$g
        return 0
      fi
    done
    sleep 1
  done
}

for i in $(seq "$START" "$END"); do
  JOB_ID=$(printf "%03d" "$i")
  acquire_slot
  GPU_ID=$FREE_SLOT

  echo "=== Launching job ${JOB_ID} on GPU ${GPU_ID} (${OPTIMIZER}, ${ARCH}) ==="

  if [ "$DEVICE" = "cuda" ] || [ "$DEVICE" = "auto" ]; then
    # Bind this job to one GPU
    CUDA_VISIBLE_DEVICES=$GPU_ID \
      scripts/run_job.sh "$OPTIMIZER" "$ARCH" "$JOB_ID" "$PHASE" "cuda" "$CKPT_ROOT" "$NUM_STEPS" "$CHECKPOINT_EVERY" &
  else
    # CPU / TPU case: no CUDA_VISIBLE_DEVICES
    scripts/run_job.sh "$OPTIMIZER" "$ARCH" "$JOB_ID" "$PHASE" "$DEVICE" "$CKPT_ROOT" "$NUM_STEPS" "$CHECKPOINT_EVERY" &
  fi

  # $! is the PID of the background job launched just above.
  slot_pids[GPU_ID]=$!
  slot_jobs[GPU_ID]=$JOB_ID
done

# wait for all remaining jobs, collecting each exit status
for ((g = 0; g < NUM_GPUS; g++)); do
  if [ -n "${slot_pids[g]}" ]; then
    release_slot "$g"
  fi
done

if [ "$failed_jobs" -gt 0 ]; then
  echo "${failed_jobs} job(s) failed." >&2
  exit 1
fi
echo "All jobs finished."
