In [None]:
# -----------------------------------------------------------------------------
# Muon optimizer

def zeropower_via_svd(G, steps=None):
    """
    Orthogonalize G exactly through its singular value decomposition.

    Computes the reduced SVD G = U S V^H and returns U V^H, i.e. G with every singular value
    replaced by one.

    Arguments:
        G: The matrix to orthogonalize.
        steps: Unused; accepted so this backend can be called the same way as the iterative
            one (both are invoked as `backend(g, steps=...)`).
    """
    # torch.linalg.svd supersedes the deprecated Tensor.svd(). It returns V^H directly, so no
    # transpose is needed; full_matrices=False keeps the reduced factorization Tensor.svd() gave.
    U, _, Vh = torch.linalg.svd(G, full_matrices=False)
    return U @ Vh

@torch.compile
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
    r"""
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.

    Arguments:
        G: The 2D matrix to orthogonalize. It is left unmodified.
        steps: Number of quintic iterations to run.
        eps: Added to the norm before dividing, so an all-zero G does not divide by zero.

    Returns:
        A bfloat16 tensor with the same shape as G.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750,  2.0315)
    X = G.bfloat16()
    # Divide out of place: when G is already bfloat16 the cast above returns G itself, and an
    # in-place `/=` would silently rescale the caller's tensor (e.g. a parameter's .grad).
    X = X / (X.norm() + eps) # ensure top singular value <= 1
    # Work on the wide orientation so that A = X @ X.T is the smaller of the two Gram matrices.
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = A @ X
        X = a * X + b * B + c * A @ B
    if transposed:
        X = X.T
    return X

# Registry mapping each backend name accepted by Muon(backend=...) to its implementation.
zeropower_backends = {
    'svd': zeropower_via_svd,
    'newtonschulz5': zeropower_via_newtonschulz5,
}

class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer assumes that all parameters passed in are 2D.
    - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
    parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - We believe it is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
    - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
        backend_steps: The number of iteration steps to use in the backend, if it is iterative.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True,
                 backend='newtonschulz5', backend_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
        super().__init__(params, defaults)

    def step(self):
        """
        Apply one Muon update to every parameter of every group.

        The work is sharded across processes: the parameter at index i of a group is handled by
        the process whose RANK equals i modulo WORLD_SIZE (both read from the environment). That
        process updates the parameter's momentum buffer, orthogonalizes the update and writes it
        into its slice of a flat bfloat16 buffer; all other slices stay zero. A summing all-reduce
        then gives every process the complete set of updates, which are applied as p -= lr * update.

        Relies on module-level `os` and `dist` (torch.distributed) being in scope, and allocates
        the exchange buffer on device 'cuda'.

        Raises:
            AssertionError: if a parameter owned by this process has no gradient.
            KeyError: if WORLD_SIZE or RANK is missing from the environment.
        """
        # Neither value depends on the parameter, so parse them once per step instead of once
        # per parameter inside the loop.
        world_size = int(os.environ['WORLD_SIZE'])
        rank = int(os.environ['RANK'])

        for group in self.param_groups:

            lr = group['lr']
            momentum = group['momentum']
            zeropower_backend = zeropower_backends[group['backend']]

            # generate weight updates in distributed fashion
            total_params = sum(p.numel() for p in group['params'])
            updates_flat = torch.zeros(total_params, device='cuda', dtype=torch.bfloat16)
            curr_idx = 0
            for i, p in enumerate(group['params']):
                # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs
                if i % world_size == rank:
                    g = p.grad
                    assert g is not None
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = torch.zeros_like(g)
                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(g)
                    if group['nesterov']:
                        # Nesterov look-ahead: gradient plus momentum times the updated buffer
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_backend(g, steps=group['backend_steps'])
                    # scale by sqrt(rows / cols) when the matrix has more rows than columns
                    g *= max(1, g.size(0)/g.size(1))**0.5
                    updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten()
                # the offset advances for every parameter so all ranks agree on the layout
                curr_idx += p.numel()

            # sync updates across devices. we are not memory-constrained so can do this simple deserialization
            dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            # deserialize and apply updates
            curr_idx = 0
            for p in group['params']:
                g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data)
                p.data.add_(g, alpha=-lr)
                curr_idx += p.numel()


In [None]:
# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader

def _peek_data_shard(filename):
    """
    Read only the header of a token shard and return the token count it claims.

    The header is 256 int32 values: element 0 is the magic number 20240520, element 1 the
    format version (must be 1) and element 2 the number of tokens.

    Arguments:
        filename: Path of the .bin shard to inspect.

    Returns:
        The token count stored in the header (a numpy int32 scalar).

    Raises:
        SystemExit: with code 1, after printing hints, if the header is too short to hold the
            three fields or the magic number does not match.
        AssertionError: if the version field is not 1.
    """
    # only reads the header, returns header data
    with open(filename, "rb") as f:
        # first read the header, which is 256 int32 integers (4 bytes each)
        header = np.frombuffer(f.read(256*4), dtype=np.int32)
    # A header with fewer than three fields (e.g. an empty file) cannot be valid either; report
    # it with the same hints rather than failing with an IndexError below.
    if header.size < 3 or header[0] != 20240520:
        print("ERROR: magic number mismatch in the data .bin file!")
        print("---> HINT: Are you passing in a correct file with --input_bin?")
        print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
        print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
        # The builtin `exit` is injected by the `site` module for interactive sessions and is
        # not meant for programs; raising SystemExit terminates with the same exit code.
        raise SystemExit(1)
    assert header[1] == 1, "unsupported version"
    ntok = header[2] # number of tokens (claimed)
    return ntok # for now just return the number of tokens

def _load_data_shard(filename):
    """
    Load every token of a shard file into a uint16 numpy array.

    The file starts with a header of 256 int32 values (magic number, version, token count)
    followed by the tokens themselves as uint16.

    Raises:
        AssertionError: if the magic number or version is wrong, or if the number of tokens
            read differs from the count stored in the header.
    """
    header_nbytes = 256 * 4  # 256 int32 values, 4 bytes each
    with open(filename, "rb") as shard:
        header = np.frombuffer(shard.read(header_nbytes), dtype=np.int32)
        assert header[0] == 20240520, "magic number mismatch in the data .bin file"
        assert header[1] == 1, "unsupported version"
        claimed_ntok = header[2]  # number of tokens according to the header
        payload = shard.read()  # everything after the header is token data
    tokens = np.frombuffer(payload, dtype=np.uint16)
    assert len(tokens) == claimed_ntok, "number of tokens read does not match header?"
    return tokens

class DistributedDataLoader:
    """
    Serve (B, T) batches of token ids from a sorted set of shard files to one process.

    Within a shard, consecutive calls advance by B*T*num_processes tokens and each process
    starts at an offset of process_rank*B*T, so the processes read disjoint inputs. When the
    current shard cannot supply another full step, the loader moves on to the next shard,
    wrapping around after the last one.
    """

    def __init__(self, filename_pattern, B, T, process_rank, num_processes):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.B = B
        self.T = T

        # glob files that match the pattern
        self.files = sorted(glob.glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"

        # validate every shard header and add up the claimed token counts; each shard must
        # hold at least one full step for all processes plus one extra token for the targets
        min_shard_tokens = num_processes * B * T + 1
        running_total = 0
        for shard_path in self.files:
            shard_ntok = _peek_data_shard(shard_path)
            assert shard_ntok >= min_shard_tokens
            running_total += int(shard_ntok)
        self.ntok_total = running_total

        # kick things off
        self.reset()

    def _rewind_and_load(self):
        # place this process at its starting offset and load the currently selected shard
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def reset(self):
        """Go back to the start of the first shard."""
        self.current_shard = 0
        self._rewind_and_load()

    def advance(self): # advance to next data shard
        """Move to the start of the next shard, wrapping around after the last one."""
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self._rewind_and_load()

    def next_batch(self):
        """
        Return the next (inputs, targets) pair as int64 tensors of shape (B, T), moved to CUDA.

        Targets are the inputs shifted left by one token.
        """
        B, T = self.B, self.T
        start = self.current_position
        # one token more than B*T so that every input position has a target
        window = self.tokens[start : start+B*T+1]
        window = torch.tensor(window.astype(np.int32), dtype=torch.long)
        x = window[:-1].view(B, T) # inputs
        y = window[1:].view(B, T) # targets
        # advance current position and load next shard if necessary
        stride = B * T * self.num_processes
        self.current_position += stride
        if self.current_position + (stride + 1) > len(self.tokens):
            self.advance()
        return x.cuda(), y.cuda()

In [4]:
# import os
# import sys
# with open(sys.argv[0]) as f:
#     code = f.read() # read the code of this file ASAP, for logging

# import random
# import datetime
# import time

# import glob
from dataclasses import dataclass

# import numpy as np
import math
import torch
from torch import nn
import torch.nn.functional as F
# import torch.distributed as dist
# import torch._inductor.config as config
# from torch.nn.parallel import DistributedDataParallel as DDP

from lib.lorentz.manifold import CustomLorentz
from lib.geoopt import ManifoldParameter
# from lib.geoopt.optim import RiemannianAdam


# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the GPT-2 model

class Rotary(torch.nn.Module):
    """
    Produce the cos/sin tables used for rotary position embeddings.

    The inverse frequencies are stored as a plain tensor attribute (not a registered buffer).
    The tables are cached and only rebuilt when the sequence length changes.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.inv_freq = 1.0 / (base ** exponents)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        """
        Return (cos, sin), each a bfloat16 tensor of shape (1, seq_len, 1, dim // 2), where
        seq_len is x.shape[1].
        """
        seq_len = x.shape[1]
        if self.seq_len_cached != seq_len:
            # cache miss: rebuild both tables for the new sequence length
            self.seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            angles = torch.outer(positions, self.inv_freq).to(x.device)
            self.cos_cached = angles.cos().bfloat16()
            self.sin_cached = angles.sin().bfloat16()
        # add singleton batch and head axes so the tables broadcast against (B, T, H, D/2)
        cos = self.cos_cached[None, :, None, :]
        sin = self.sin_cached[None, :, None, :]
        return cos, sin

def apply_rotary_emb(x, cos, sin):
    """
    Rotate the two halves of x's last dimension by the angles encoded in cos and sin.

    x must be 4-dimensional; cos and sin must broadcast against each half of its last
    dimension. The result has the same dtype as x.
    """
    assert x.ndim == 4 # multihead attention
    half = x.shape[3] // 2
    first, second = x[..., :half], x[..., half:]
    rotated = (
        first * cos + second * sin,
        first * (-sin) + second * cos,
    )
    return torch.cat(rotated, dim=3).type_as(x)

class CausalSelfAttention(nn.Module):
    """
    Multi-head causal self-attention with RMS-normalized, rotary-embedded queries and keys.

    All four projections are bias-free; the output projection starts at zero.
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        # query / key / value projections
        self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
        self.rotary = Rotary(self.head_dim)

    def forward(self, x):
        """Map x of shape (B, T, n_embd) to an attention output of the same shape."""
        B, T, _ = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        per_head = (B, T, self.n_head, self.head_dim)
        q = self.c_q(x).view(per_head)
        k = self.c_k(x).view(per_head)
        v = self.c_v(x).view(per_head)
        cos, sin = self.rotary(q)
        # QK norm suggested by @Grad62304977, applied before the rotary embedding
        q = apply_rotary_emb(F.rms_norm(q, (q.size(-1),)), cos, sin)
        k = apply_rotary_emb(F.rms_norm(k, (k.size(-1),)), cos, sin)
        # attention expects (B, n_head, T, head_dim)
        attended = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
        )
        # re-assemble all head outputs side by side
        merged = attended.transpose(1, 2).contiguous().view_as(x)
        return self.c_proj(merged)

class MLP(nn.Module):
    """
    Bias-free feed-forward block: expand to 4 * n_embd, apply squared ReLU, project back.

    The output projection starts at zero.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.c_fc = nn.Linear(width, 4 * width, bias=False)
        self.c_proj = nn.Linear(4 * width, width, bias=False)
        self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977

    def forward(self, x):
        # squared ReLU: https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        hidden = F.relu(self.c_fc(x)).square()
        return self.c_proj(hidden)

class Block(nn.Module):
    """Transformer block: attention then MLP, each applied to an RMS-normalized input and added back."""

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        # pre-norm residual update, first through attention and then through the MLP
        for sublayer in (self.attn, self.mlp):
            x = x + sublayer(F.rms_norm(x, (x.size(-1),)))
        return x

class LorentzMLR(nn.Module):
    """ Multinomial logistic regression (MLR) in the Lorentz model

    Holds one scalar `a` and one vector `z` (of length num_features - 1) per class. The
    curvature-related value is read from `manifold.k`.
    """
    def __init__(
            self,
            manifold: CustomLorentz,
            num_features: int,
            num_classes: int
        ):
        super().__init__()

        self.manifold = manifold

        self.a = torch.nn.Parameter(torch.zeros(num_classes,))
        # placeholder of shape (num_classes, num_features - 1) with a leading 1 so that z is
        # not all-zero; init_weights() below overwrites it
        z_placeholder = F.pad(torch.zeros(num_classes, num_features-2), pad=(1,0), value=1)
        self.z = torch.nn.Parameter(z_placeholder)

        self.init_weights()

    def forward(self, x):
        """
        Compute per-class logits for x of shape (B, T, num_features); the result has shape
        (B, T, num_classes). Component 0 of the last axis of x is handled separately from
        the remaining num_features - 1 components.
        """
        # Hyperplane parameters
        inv_sqrt_k = 1 / self.manifold.k.sqrt()  # scalar
        z_norm = torch.norm(self.z, dim=-1)  # (num_classes,)
        w_t = torch.sinh(inv_sqrt_k * self.a) * z_norm  # (num_classes,)
        w_s = torch.cosh(inv_sqrt_k * self.a).unsqueeze(-1) * self.z  # (num_classes, num_features - 1)

        beta = torch.sqrt(-w_t**2 + torch.norm(w_s, dim=-1)**2)  # (num_classes,)
        beta_row = beta.view(1, 1, -1)  # broadcastable over (B, T, num_classes)

        x_first = x.narrow(-1, 0, 1)  # (B, T, 1)
        x_rest = x.narrow(-1, 1, x.shape[-1]-1)  # (B, T, num_features - 1)
        z_dot_x = torch.matmul(x_rest, self.z.T)  # (B, T, num_classes)
        alpha = -x_first * w_t.view(1, 1, -1) + torch.cosh(inv_sqrt_k * self.a).view(1, 1, -1) * z_dot_x  # (B, T, num_classes)

        # unsigned magnitude d, then restore the sign of alpha and scale by beta
        asinh_arg = inv_sqrt_k * alpha / beta_row
        d = self.manifold.k.sqrt() * torch.abs(torch.asinh(asinh_arg))  # (B, T, num_classes)

        return torch.sign(alpha) * beta_row * d  # (B, T, num_classes)

    def init_weights(self):
        """Draw z and a uniformly from [-b, b] with b = 1 / sqrt(z.size(1))."""
        bound = 1. / math.sqrt(self.z.size(1))
        nn.init.uniform_(self.z, -bound, bound)
        nn.init.uniform_(self.a, -bound, bound)

# -----------------------------------------------------------------------------
# The main GPT-2 model

@dataclass
class GPTConfig:
    """Size hyperparameters of the GPT model."""
    # number of rows in the token embedding table / number of output classes
    vocab_size: int = 50304
    # number of transformer blocks
    n_layer: int = 12
    # attention heads; with n_embd = 768 this gives head dim 128, suggested by @Grad62304977
    n_head: int = 6
    # embedding / residual stream width
    n_embd: int = 768

class GPT(nn.Module):
    """
    Decoder-only transformer whose output head is a LorentzMLR classifier over the vocabulary.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # created in this order: token embedding first, then the stack of blocks
        token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
        blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.transformer = nn.ModuleDict({"wte": token_embedding, "h": blocks})

        self.manifold = CustomLorentz(k=torch.tensor([1.0]))
        self.lm_head = LorentzMLR(
            manifold=self.manifold,
            num_features=config.n_embd,
            num_classes=config.vocab_size
        )
        # self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # self.lm_head.weight.data.zero_()

    def forward(self, idx, targets=None, return_logits=True):
        """
        Run the model on token ids `idx`.

        Returns:
            (logits, loss). With targets, logits cover every position and loss is the
            cross-entropy against targets (entries equal to -1 are ignored). Without targets,
            logits cover only the last position and loss is None. logits is None when
            return_logits is False.
        """
        # forward the GPT model itself
        x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        x = F.rms_norm(x, (x.size(-1),))
        for block in self.transformer.h:
            x = block(x)
        x = F.rms_norm(x, (x.size(-1),))

        if targets is None:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            # note: using list [-1] to preserve the time dim
            logits = self.lm_head(x[:, [-1], :]).float() # use tf32/fp32 for logits
            loss = None
        else:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x).float() # use tf32/fp32 for logits
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        # there are performance reasons why not returning logits is prudent, if not needed
        return (logits if return_logits else None), loss

# -----------------------------------------------------------------------------
# int main

@dataclass
class Hyperparameters:
    """Data, optimization and evaluation/logging settings for a training run."""
    # --- data ---
    # input .bin to train on
    input_bin: str = 'data/fineweb10B/fineweb_train_*.bin'
    # input .bin to eval validation loss on
    input_val_bin: str = 'data/fineweb10B/fineweb_val_*.bin'
    # --- optimization ---
    # batch size, in sequences, across all devices
    batch_size: int = 8 * 64
    # batch size, in sequences, per device
    device_batch_size: int = 64
    # sequence length, in tokens
    sequence_length: int = 1024
    # number of iterations to run
    num_iterations: int = 4578
    warmup_iters: int = 0
    # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule
    warmdown_iters: int = 1308
    weight_decay: float = 0
    # --- evaluation and logging ---
    # every how many steps to evaluate val loss? 0 for only at the end
    val_loss_every: int = 125
    # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
    val_tokens: int = 10_485_760
    # every how many steps to save the checkpoint? 0 for only at the end
    save_every: int = 0
    lr: float = 1e-2
args = Hyperparameters()

# DDP (distributed data parallel) setup is disabled (commented out) in this notebook; the only
# live check is that a CUDA device is available.
assert torch.cuda.is_available()
# dist.init_process_group(backend='nccl')
# ddp_rank = int(os.environ['RANK'])
# ddp_local_rank = int(os.environ['LOCAL_RANK'])
# ddp_world_size = int(os.environ['WORLD_SIZE'])
# device = f'cuda:{ddp_local_rank}'
# torch.cuda.set_device(device)
# print(f"using device: {device}")
# master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.

# convenience variables
# B, T = args.device_batch_size, args.sequence_length
# # calculate the number of steps to take in the val loop.
# assert args.val_tokens % (B * T * ddp_world_size) == 0
# val_steps = args.val_tokens // (B * T * ddp_world_size)
# # calculate the steps of gradient accumulation required to attain the desired global batch size.
# assert args.batch_size % (B * ddp_world_size) == 0
# train_accumulation_steps = args.batch_size // (B * ddp_world_size)

# load tokens
# train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
# val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
# if master_process:
#     print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
#     print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
# x, y = train_loader.next_batch()

# there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977.
# this originates from Karpathy's experiments.
num_vocab = 50304
# build the model with an explicit size (these values equal the GPTConfig defaults)
model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768))
# NOTE(review): nn.Module.cuda() returns the module itself, so `raw_model` and `model`
# presumably name the same object while compile/DDP wrapping below stays disabled — confirm.
raw_model = model.cuda()
# if hasattr(config, "coordinate_descent_tuning"):
#     config.coordinate_descent_tuning = True # suggested by @Chillee
# model = torch.compile(model)
# # here we wrap model into DDP container
# model = DDP(model, device_ids=[ddp_local_rank])
# raw_model = model.module # always contains the "raw" unwrapped model
# ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)

In [5]:
# Every parameter of the model
all_params = list(raw_model.parameters())

# Manifold-typed parameters, taken from the whole model
manifold_params = [p for p in all_params if isinstance(p, ManifoldParameter)]

# Parameters of the transformer blocks, split by dimensionality (manifold-typed ones excluded)
transformer_params = list(raw_model.transformer.h.parameters())
matrix_params = [
    p for p in transformer_params
    if not isinstance(p, ManifoldParameter) and p.ndim == 2
]
scalar_params = [
    p for p in transformer_params
    if not isinstance(p, ManifoldParameter) and p.ndim < 2
]

# The token embedding table
embedding_params = [raw_model.transformer.wte.weight]
# other_params = [p for p in all_params if p not in manifold_params + matrix_params + scalar_params + embedding_params]

In [25]:
# Print name, shape and Python type of each parameter reachable from the LM head ...
for n, p in raw_model.lm_head.named_parameters():
    print(n, p.size(), type(p))
# ... then name and shape of each parameter of the token embedding.
for n, p in raw_model.transformer.wte.named_parameters():
    print(n, p.size())

a torch.Size([50304]) <class 'torch.nn.parameter.Parameter'>
z torch.Size([50304, 767]) <class 'torch.nn.parameter.Parameter'>
manifold.k torch.Size([1]) <class 'torch.nn.parameter.Parameter'>
weight torch.Size([50304, 768])


In [44]:
import numpy as np
# Collect the shape of every model parameter as a string, then show the distinct shapes.
ass = []
for n,p in raw_model.named_parameters():
    ass.append(str(p.size()))
np.unique(ass)

array(['torch.Size([1])', 'torch.Size([3072, 768])',
       'torch.Size([50304, 767])', 'torch.Size([50304, 768])',
       'torch.Size([50304])', 'torch.Size([768, 3072])',
       'torch.Size([768, 768])'], dtype='<U24')

In [25]:
# Function to count and print parameter shapes
def print_param_info(param_list, group_name):
    """
    Print a summary of a parameter group: how many parameters it holds, their total element
    count, and the shape of each one (numbered from 1), followed by a blank line.

    Arguments:
        param_list: Sized collection of objects exposing `.numel()` and `.shape`.
        group_name: Heading printed on the first line.
    """
    report = [
        f"{group_name}:",
        f"  Number of Parameters: {len(param_list)}",
        f"  Total Elements: {sum(p.numel() for p in param_list)}",
        "  Shapes:",
    ]
    report.extend(
        f"    Param {position}: {tuple(param.shape)}"
        for position, param in enumerate(param_list, start=1)
    )
    print("\n".join(report))
    print()

# Extra: total number of elements summed over every parameter in `all_params`
total_params = sum(p.numel() for p in all_params)
print(f"Total Parameters in Model: {total_params}")

Total Parameters in Model: 162201601


In [45]:
def get_param_groups(model, lr_manifold, weight_decay_manifold):
    """
    Sort the trainable parameters of `model` into optimizer parameter groups.

    Each parameter is classified by the first rule that matches:
      1. its name contains "manifold.k"            -> k group (weight_decay 0, lr 1e-4)
      2. it is a ManifoldParameter                 -> manifold group (lr_manifold,
                                                      weight_decay_manifold)
      3. it is 2-dimensional                       -> "muon" bucket
      4. anything else                             -> "riemannian" bucket
    Within rules 3 and 4, parameters whose name contains "scale" go to a separate group with
    weight_decay 0.

    Returns:
        A list of dicts, each with a "params" list plus the options above, in the order
        muon, muon/no-decay, riemannian, riemannian/no-decay, manifold, k. Empty groups are
        omitted. The bucket names are local only; the returned dicts do not carry them.
    """
    no_decay_markers = ("scale",)
    curvature_markers = ("manifold.k",)

    def name_has(name, markers):
        return any(marker in name for marker in markers)

    buckets = {
        "muon": [],
        "muon_no_decay": [],
        "riemannian": [],
        "riemannian_no_decay": [],
        "manifold": [],
        "k": [],
    }

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if name_has(name, curvature_markers):
            bucket = "k"
        elif isinstance(param, ManifoldParameter):
            bucket = "manifold"
        else:
            bucket = "muon" if param.ndim == 2 else "riemannian"
            if name_has(name, no_decay_markers):
                bucket += "_no_decay"
        buckets[bucket].append(param)

    # (bucket, extra options) in output order; add optimizer-specific settings here
    group_options = (
        ("muon", {}),
        ("muon_no_decay", {"weight_decay": 0}),
        ("riemannian", {}),
        ("riemannian_no_decay", {"weight_decay": 0}),
        ("manifold", {"lr": lr_manifold, "weight_decay": weight_decay_manifold}),
        ("k", {"weight_decay": 0, "lr": 1e-4}),
    )

    return [
        {"params": buckets[bucket], **options}
        for bucket, options in group_options
        if buckets[bucket]
    ]

# Build the optimizer parameter groups for the model.
param_groups = get_param_groups(model, lr_manifold=0.01, weight_decay_manifold=0.001)

# NOTE(review): the dicts returned by get_param_groups only have the keys "params", "lr" and
# "weight_decay", so the `'muon' in group` / `'riemannian' in group` filters in the disabled
# code below would select no groups as written.
# # Create the Muon optimizer for matrix parameters
# muon_params = [group for group in param_groups if 'muon' in group]
# muon_optimizer = Muon(muon_params, lr=0.02, momentum=0.95, nesterov=True)

# # Create the Riemannian SGD optimizer for non-matrix parameters
# riemannian_params = [group for group in param_groups if 'riemannian' in group]
# riemannian_optimizer = RiemannianSGD(riemannian_params, lr=0.01, weight_decay=0.001)



In [69]:
# Count the total number of elements held by the first parameter group.
g = param_groups[0]
c = 0
for p in g['params']:
    c += p.numel()
c

162151296