In [1]:
import torch

In [None]:
"""
Architecture Overview:
1. Embedding: Token IDs -> Vectors (wte)
2. Stack of Blocks (Repeated L times):
   - RMSNorm
   - Attention (Mixing info between tokens)
   - RMSNorm
   - MLP (Processing info within a token)
3. Final Norm 
4. LMHead: Vectors -> Logits (Probabilities)
"""

import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
import math

@dataclass
class GPTConfig:
    """
    Hyperparameters for the model.
    """
    # ┌─────────────────────────────────────────────────────────┐
    # │           321M CONVERSATIONAL MODEL                     │
    # ├─────────────────────────────────────────────────────────┤
    # │  hidden_dim:        1024                                │
    # │  layers:            24                                  │
    # │  heads:             8                                   │
    # │  head_dim:          128                                 │
    # │  mlp_ratio:         3x                                  │
    # │  vocab_size:        32K                                 │
    # │  context_length:    1024                                │
    # │  embedding:         tied (input = output projection)    │
    # │  activation:        relu squared                        │
    # │  position encoding: RoPE                                │
    # ├─────────────────────────────────────────────────────────┤
    # │  TOTAL PARAMETERS:  243,269,632                         │
    # └─────────────────────────────────────────────────────────┘
    # No KV cache
    # No GQA

    hidden_dim: int = 1024 # hidden dimension
    n_layers: int = 24 # May need to reduce to 22 or 20
    n_heads: int = 8 # head dimension = hidden_dim / n_heads = 128
    mlp_ratio: int = 3
    vocab_size: int = 32*1024
    sequence_len: int = 1024


def norm(x):
    """
    RMSNorm (Root Mean Square Layer Normalization).
    Used to stabilize training by normalizing activation magnitudes.
    """
    # Purely functional rmsnorm with no learnable params
    return F.rms_norm(x, (x.size(-1),))


def apply_rotatory_positional_encoding(x, cos, sin):
    """
    Applies Rotary Positional Embeddings (RoPE).
    Rotates the query and key vectors to encode relative positions.
    """
    assert x.ndim == 4  # multihead attention
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
    y1 = x1 * cos + x2 * sin # rotate pairs of dims
    y2 = x1 * (-sin) + x2 * cos
    out = torch.cat([y1, y2], 3) # re-assemble
    out = out.to(x.dtype) # ensure input/output dtypes match
    return out

In [41]:
config = GPTConfig()
config

GPTConfig(hidden_dim=1024, n_layers=24, n_heads=8, mlp_ratio=3, vocab_size=32768, sequence_len=1024)

In [42]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Causal Self Attention.
    
    1. Projects input to Q, K, V.
    2. Applies RoPE to Q, K for position info.
    3. Computes attention scores (Q @ K) to see how much each token cares about others. Aggregates values (V) based on scores.
    4. Projects output to mix information across heads.
    """
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.hidden_dim = config.hidden_dim
        self.head_dim = config.hidden_dim // config.n_heads

        # Linear projections for Query, Key, Value
        self.key = nn.Linear(self.hidden_dim, self.head_dim * self.n_heads, bias=False)
        self.query = nn.Linear(self.hidden_dim, self.head_dim * self.n_heads, bias=False)
        self.value = nn.Linear(self.hidden_dim, self.head_dim * self.n_heads, bias=False)

        # Output projection ("o"): mixes results from all heads back into n_embd
        self.proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)

    def forward(self, x: torch.Tensor, cos_sin: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()

        # 1. Projects input to Q, K, V.
        # reshape to (B, T, n_heads, head_dim)
        k = self.key(x).view(B, T, self.n_heads, self.head_dim)
        q = self.query(x).view(B, T, self.n_heads, self.head_dim)
        v = self.value(x).view(B, T, self.n_heads, self.head_dim)
        
        # 2. Applies RoPE to Q, K for position info.
        cos, sin = cos_sin
        k, q = apply_rotatory_positional_encoding(k, cos, sin), apply_rotatory_positional_encoding(q, cos, sin)

        # 3. Computes attention scores (Q @ K) to see how much each token cares about others.
        q, k = norm(q), norm(k) # QK norm

        # make head be batch dim, i.e. (B, T, n_heads, head_dim) -> (B, n_heads, T, head_dim)
        # We are making the n_heads into a batch dimension so pytorch treats it as batches and
        # applies the attention function on each head separately in parallel
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)         
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Re-assemble the heads side by side and project back to residual stream
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # 4. Projects output to mix information across heads.
        y = self.proj(y)
        return y



In [43]:
attn = MultiHeadAttention(config)
attn

MultiHeadAttention(
  (key): Linear(in_features=1024, out_features=1024, bias=False)
  (query): Linear(in_features=1024, out_features=1024, bias=False)
  (value): Linear(in_features=1024, out_features=1024, bias=False)
  (proj): Linear(in_features=1024, out_features=1024, bias=False)
)

In [44]:
for param in attn.parameters():
    print(type(param), param.size())

<class 'torch.nn.parameter.Parameter'> torch.Size([1024, 1024])
<class 'torch.nn.parameter.Parameter'> torch.Size([1024, 1024])
<class 'torch.nn.parameter.Parameter'> torch.Size([1024, 1024])
<class 'torch.nn.parameter.Parameter'> torch.Size([1024, 1024])


In [9]:
# %pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [15]:
from torchinfo import summary

summary(attn, input_size=(1, config.sequence_len, config.hidden_dim), dtypes=[torch.float32])

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []

In [45]:
class FeedForward(nn.Module):
    """
    Feed Forward Network (MLP).
    Processes each token independently (no mixing between tokens).
    Structure: Expand -> ReLU^2 -> Contract
    """
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.proj_up = nn.Linear(config.hidden_dim, config.hidden_dim * config.mlp_ratio, bias=False)
        self.proj_down = nn.Linear(config.hidden_dim * config.mlp_ratio, config.hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj_up(x)
        x = F.relu(x).square()
        # TODO: Check if swiglu is better for 3 x hidden_dim - 
        # gelu and silu are alternatives but difference seems marginal so sticking with relu^2
        x = self.proj_down(x)
        return x

In [46]:
ff = FeedForward(config)
ff

FeedForward(
  (proj_up): Linear(in_features=1024, out_features=3072, bias=False)
  (proj_down): Linear(in_features=3072, out_features=1024, bias=False)
)

In [None]:
summary(ff, input_size=(1, config.sequence_len, config.hidden_dim), dtypes=[torch.float32])

Layer (type:depth-idx)                   Output Shape              Param #
FeedForward                              [1, 1024, 1024]           --
├─Linear: 1-1                            [1, 1024, 3072]           3,145,728
├─Linear: 1-2                            [1, 1024, 1024]           3,145,728
Total params: 6,291,456
Trainable params: 6,291,456
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 6.29
Input size (MB): 4.19
Forward/backward pass size (MB): 33.55
Params size (MB): 25.17
Estimated Total Size (MB): 62.91

In [47]:
class TransformerBlock(nn.Module):
    """
    A single Transformer Block.
    Contains:
    1. Attention (Communication)
    2. MLP (Computation)
    Both use Residual Connections (x + ...) and Pre-Norm.
    """
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.attn = MultiHeadAttention(config)
        self.ff = FeedForward(config)
    
    def forward(self, x: torch.Tensor, cos_sin: torch.Tensor) -> torch.Tensor:
        # Attention with residual connection
        x = x + self.attn(norm(x), cos_sin)
        # MLP with residual connection
        x = x + self.ff(norm(x))
        return x


In [48]:
block = TransformerBlock(GPTConfig())
block

TransformerBlock(
  (attn): MultiHeadAttention(
    (key): Linear(in_features=1024, out_features=1024, bias=False)
    (query): Linear(in_features=1024, out_features=1024, bias=False)
    (value): Linear(in_features=1024, out_features=1024, bias=False)
    (proj): Linear(in_features=1024, out_features=1024, bias=False)
  )
  (ff): FeedForward(
    (proj_up): Linear(in_features=1024, out_features=3072, bias=False)
    (proj_down): Linear(in_features=3072, out_features=1024, bias=False)
  )
)

In [14]:
summary(block, input_size=(1, config.sequence_len, config.hidden_dim), dtypes=[torch.float32])

NameError: name 'summary' is not defined

In [49]:
class GPT(nn.Module):
    """
    The full GPT model.
    Contains:
    1. Token Embedding
    2. Transformer Blocks (stacked)
    3. Final Normalization
    4. LM Head - Tied weights with token embedding
    """
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
        self.lm_head.weight = self.token_embedding.weight

        self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute
        # Why 10x? This provides a generous buffer for inference/generation, allowing the model
        # to generate sequences longer than its training length without recomputing embeddings.
        # Note: While the embeddings support 10x length, the model's quality degrades beyond ~1.5-2x
        # the training length due to unseen attention patterns. This buffer is for convenience,
        # not an expectation of good performance at 10x length. Memory cost is negligible.
        
        head_dim = config.hidden_dim // config.n_heads
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
        self.register_buffer("sin", sin, persistent=False)
    
    def forward(self, idx, targets=None, loss_reduction="mean") -> torch.Tensor:
        T = idx.shape[1]
        cos_sin = self.cos[:, :T], self.sin[:, :T] # truncate cache to current sequence length
        x = self.token_embedding(idx)
        x = norm(x)
        for block in self.blocks:
            x = block(x, cos_sin)
        x = norm(x)

        softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
        logits = self.lm_head(x)
        logits = logits.float() # switch to fp32 for logit softcap and loss computation
        logits = softcap * torch.tanh(logits / softcap) # squash the logits
        
        if targets is not None:
            # training: given the targets, compute and return the loss
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
            return loss
        else:
            # inference: just return the logits directly
            return logits

    def init_weights(self):
        """
        Custom weight initialization scheme.
        
        The initialization logic deviates from PyTorch defaults (Kaiming defaults) to improve training 
        stability and convergence for deep Transformers.
        
        Key Difference:
        Zero Initialization for Output Projections:
        - Function: Sets the weights of the final linear layer in each block to zero.
        - Why: This ensures that at initialization, the residual blocks contribute nothing to the 
            residual stream (y = x + 0). The model effectively starts as an identity function, allowing 
            unimpeded gradient flow from top to bottom. This prevents vanishing/exploding gradients 
            and provides a stable starting point for the model to gradually learn features.
        """
        # Initialize the weights of the model
        self.apply(self._init_weights)

        # Zero out the output projections of the blocks
        for block in self.blocks:
            torch.nn.init.zeros_(block.ff.proj_down.weight)
            torch.nn.init.zeros_(block.attn.proj.weight)

        # init the rotary embeddings
        head_dim = self.config.hidden_dim // self.config.n_heads
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin

        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
        if self.token_embedding.weight.device.type == "cuda":
            self.token_embedding.to(dtype=torch.bfloat16)

    
    def _init_weights(self, module):
        """
        Custom initialization for Linear and Embedding layers.
        
        1. Controlled Variance (Linear Layers):
           - Formula: std = 1 / sqrt(fan_in) * min(1, sqrt(fan_out / fan_in))
           - Why: Standard Kaiming init often leads to activation variance that grows with depth in 
             Transformers. This custom initialization (ref: https://arxiv.org/pdf/2310.17813) stabilizes 
             activation variance across layers, specifically accounting for the network width.

        2. Unit Variance (Embeddings):
           - Function: Normal distribution with std=1.0.
           - Why: Ensures strong initial signal strength before it enters the first normalization layer.
        """
        if isinstance(module, nn.Linear):
            # https://arxiv.org/pdf/2310.17813 
            fan_out = module.weight.size(0)
            fan_in = module.weight.size(1)
            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)


    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        # autodetect the device from model embeddings
        if device is None:
            device = self.token_embedding.weight.device
        # stride the channels
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        # stride the time steps
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        # calculate the rotation frequencies at each (time, channel) pair
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
        cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
        return cos, sin

    def setup_optimizers(self, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
        """
        Sets up the optimizers.
        Uses AdamW for embeddings/head and Muon for internal linear layers.

        Detailed Explanation of Hybrid Strategy:
        ----------------------------------------
        We use two different optimizers because different parts of the Transformer have different 
        geometric properties and optimization landscapes.

        1. Muon (for internal 2D matrices):
           - Applied to: Attention projections (c_q, c_k, c_v, c_proj) and MLP weights (c_fc, c_proj).
           - Mechanism: Muon forces weight *updates* to be orthogonal. In linear algebra, orthogonal 
             transformations (like rotation or reflection) preserve the magnitude (norm) of the vector 
             they act on.
           - Benefit: Deep networks suffer from vanishing/exploding gradients because signals get 
             scaled up or down at every layer. By forcing updates to be orthogonal, Muon ensures 
             signals propagate through the network without exploding in magnitude, allowing for 
             much faster and more stable training of deep layers.

        2. AdamW (for embeddings & head):
           - Applied to: Token embeddings (wte) and the final output head (lm_head).
           - Reason: These parameters are not dense 2D matrices in the same sense (embeddings are 
             lookup tables). The concept of "orthogonal updates" is mathematically ill-defined or 
             harmful for vectors/lookups. AdamW is ideal here as it adapts learning rates per-parameter 
             based on update frequency (handling the sparse nature of token updates).

        Do they conflict? 
        No. Both optimizers step in directions derived from the same global loss gradient, so they 
        optimize the same function. The risk is learning speed mismatch (one part learning faster 
        than the other), which we handle by manually scaling the AdamW learning rate below.
        """
        model_dim = self.config.hidden_dim
        # ddp, rank, local_rank, world_size = get_dist_info()
        # Separate out all parameters into 3 groups (matrix, embedding, lm_head)
        matrix_params = list(self.blocks.parameters())
        embedding_params = list(self.token_embedding.parameters())
        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params)
        
        # Create the AdamW optimizer for the embedding
        # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        # if rank == 0:
        print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
        adam_groups = [
            dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
        ]
        adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
        AdamWFactory = partial(torch.optim.AdamW, fused=True)
        adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
        
        # Create the Muon optimizer for the linear layers
        muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
        muon_optimizer = Muon(matrix_params, **muon_kwargs)

        # Combine the two optimizers into one list
        optimizers = [adamw_optimizer, muon_optimizer]
        for opt in optimizers:
            for group in opt.param_groups:
                group["initial_lr"] = group["lr"]
        return optimizers

In [50]:
gpt = GPT(GPTConfig())

In [51]:
# gpt.to('cuda')

In [52]:
summary(gpt, input_size=(1, config.sequence_len), dtypes=[torch.long])

Layer (type:depth-idx)                   Output Shape              Param #
GPT                                      [1, 1024, 32768]          --
├─Embedding: 1-1                         [1, 1024, 1024]           33,554,432
├─ModuleList: 1-2                        --                        --
│    └─TransformerBlock: 2-1             [1, 1024, 1024]           --
│    │    └─MultiHeadAttention: 3-1      [1, 1024, 1024]           4,194,304
│    │    └─FeedForward: 3-2             [1, 1024, 1024]           6,291,456
│    └─TransformerBlock: 2-2             [1, 1024, 1024]           --
│    │    └─MultiHeadAttention: 3-3      [1, 1024, 1024]           4,194,304
│    │    └─FeedForward: 3-4             [1, 1024, 1024]           6,291,456
│    └─TransformerBlock: 2-3             [1, 1024, 1024]           --
│    │    └─MultiHeadAttention: 3-5      [1, 1024, 1024]           4,194,304
│    │    └─FeedForward: 3-6             [1, 1024, 1024]           6,291,456
│    └─TransformerBlock: 2-4       

In [53]:
# The weight tying is working correctly — torchinfo just doesn't detect shared parameters by default. 
# It counts each layer's parameters independently.

# This counts UNIQUE parameters (correct count with tying)
real_params = sum(p.numel() for p in gpt.parameters())
print(f"Actual unique parameters: {real_params:,}")

Actual unique parameters: 285,212,672


In [21]:
# These should all be True
print("Same object:", gpt.lm_head.weight is gpt.token_embedding.weight)
print("Same memory:", gpt.lm_head.weight.data_ptr() == gpt.token_embedding.weight.data_ptr())

Same object: True
Same memory: True


In [54]:
gpt.init_weights()

In [23]:
def check_model_dtypes(model):
    print(f"{'Layer Name':<40} | {'Type':<15} | {'Dtype'}")
    print("-" * 70)
    
    # Check Parameters
    for name, param in model.named_parameters():
        print(f"{name:<40} | Parameter       | {param.dtype}")
    
    # Check Buffers (like RoPE cos/sin)
    for name, buf in model.named_buffers():
        print(f"{name:<40} | Buffer          | {buf.dtype}")
check_model_dtypes(gpt)

Layer Name                               | Type            | Dtype
----------------------------------------------------------------------
token_embedding.weight                   | Parameter       | torch.float32
blocks.0.attn.key.weight                 | Parameter       | torch.float32
blocks.0.attn.query.weight               | Parameter       | torch.float32
blocks.0.attn.value.weight               | Parameter       | torch.float32
blocks.0.attn.proj.weight                | Parameter       | torch.float32
blocks.0.ff.proj_up.weight               | Parameter       | torch.float32
blocks.0.ff.proj_down.weight             | Parameter       | torch.float32
blocks.1.attn.key.weight                 | Parameter       | torch.float32
blocks.1.attn.query.weight               | Parameter       | torch.float32
blocks.1.attn.value.weight               | Parameter       | torch.float32
blocks.1.attn.proj.weight                | Parameter       | torch.float32
blocks.1.ff.proj_up.weight           

In [55]:
# sample input for gpt model
sample_input = torch.ones((1, 1), dtype=torch.int64)
sample_input

tensor([[1]])

In [56]:
gpt.eval()
with torch.no_grad(): # Good practice for inference to save memory
    op = gpt(sample_input)
    
print(f"Output shape: {op.shape}") # Should be (1, 1, vocab_size)
print(f"Max logit: {op.max().item():.4f}")
print(f"Predicted token ID: {op.argmax().item()}")

Output shape: torch.Size([1, 1, 32768])
Max logit: 14.5836
Predicted token ID: 1


Predicted token is always same as the input value. Why?
1. Weight Tying: We have set self.lm_head.weight = self.token_embedding.weight.
2. Zero-Init Blocks: init_weights function sets the output projection of every Transformer block to zero.
 - This means the blocks (Attention and MLP) contribute nothing to the residual stream at initialization.
 - The model effectively acts as an identity function for the embeddings: Embedding(token) -> Norm -> Logits.
3. Self-Similarity: Since the output head uses the same weights as the embedding, it calculates the dot product of the token's embedding vector with all other embedding vectors.
 - A vector's dot product with itself ($v \cdot v$) is almost always much higher than with other random vectors ($v \cdot w$).
 - Therefore, the model assigns the highest probability to the token that was input.

In [57]:
# loss
sample_input2 = torch.ones((1, 1), dtype=torch.int64)*100
print(f'loss when taget = input: {gpt(sample_input2, sample_input2)}')
print(f'loss when target != input: {gpt(sample_input2, sample_input2*2)}')

loss when taget = input: 0.025031551718711853
loss when target != input: 14.135931968688965


In [58]:
orig_model = gpt # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
# torch.compile: optimizing the model execution graph (JIT compilation)
model = torch.compile(gpt, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_params:,}")

Number of parameters: 285,212,672


In [59]:
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer should not be used for the embedding layer, the final fully connected layer,
    or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iteration steps to use.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        params: list[Tensor] = [*params]
        param_groups = []
        for size in {p.numel() for p in params}:
            group = dict(params=[p for p in params if p.numel() == size])
            param_groups.append(group)
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            for p in params:
                g = p.grad
                assert g is not None
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf: Tensor = state["momentum_buffer"]
                buf.lerp_(g, 1 - group["momentum"])
                g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
                p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)

In [60]:
embedding_lr = 0.2
weight_decay = 0.0
matrix_lr = 0.02
from functools import partial

optimizers = model.setup_optimizers(embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers

Scaling the LR for the AdamW parameters ∝1/√(1024/768) = 0.866025


In [61]:
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"Cached: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

Allocated: 0.00 GB
Cached: 0.00 GB
