In [1]:
import torch

In [2]:
"""
Architecture Overview:
1. Embedding: Token IDs -> Vectors (wte)
2. Stack of Blocks (Repeated L times):
   - RMSNorm
   - Attention (Mixing info between tokens)
   - RMSNorm
   - MLP (Processing info within a token)
3. Final Norm
4. LMHead: Vectors -> Logits (Probabilities)
"""

import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
import math

@dataclass
class GPTConfig:
    """
    Hyperparameters for the model.
    """
    # ┌─────────────────────────────────────────────────────────┐
    # │           321M CONVERSATIONAL MODEL                     │
    # ├─────────────────────────────────────────────────────────┤
    # │  hidden_dim:        1024                                │
    # │  layers:            24                                  │
    # │  heads:             8                                   │
    # │  head_dim:          128                                 │
    # │  mlp_ratio:         3x                                  │
    # │  vocab_size:        32K                                 │
    # │  context_length:    1024                                │
    # │  embedding:         tied (input = output projection)    │
    # │  activation:        relu squared                        │
    # │  position encoding: RoPE                                │
    # ├─────────────────────────────────────────────────────────┤
    # │  TOTAL PARAMETERS:  243,269,632                         │
    # └─────────────────────────────────────────────────────────┘
    # No KV cache
    # No GQA

    hidden_dim: int = 512 # hidden dimension
    n_layers: int = 5 # May need to reduce to 22 or 20
    n_heads: int = 4 # head dimension = hidden_dim / n_heads = 128
    mlp_ratio: int = 3
    vocab_size: int = 32*1024
    # vocab_size: int = 50257
    sequence_len: int = 256


def norm(x):
    """
    RMSNorm (Root Mean Square Layer Normalization).
    Used to stabilize training by normalizing activation magnitudes.
    """
    # Purely functional rmsnorm with no learnable params
    return F.rms_norm(x, (x.size(-1),))


def apply_rotatory_positional_encoding(x, cos, sin):
    """
    Applies Rotary Positional Embeddings (RoPE).
    Rotates the query and key vectors to encode relative positions.
    """
    assert x.ndim == 4  # multihead attention
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
    y1 = x1 * cos + x2 * sin # rotate pairs of dims
    y2 = x1 * (-sin) + x2 * cos
    out = torch.cat([y1, y2], 3) # re-assemble
    out = out.to(x.dtype) # ensure input/output dtypes match
    return out

device_type = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
config = GPTConfig()
config

GPTConfig(hidden_dim=512, n_layers=5, n_heads=4, mlp_ratio=3, vocab_size=32768, sequence_len=256)

In [4]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Causal Self Attention.

    1. Projects input to Q, K, V.
    2. Applies RoPE to Q, K for position info.
    3. Computes attention scores (Q @ K) to see how much each token cares about others. Aggregates values (V) based on scores.
    4. Projects output to mix information across heads.
    """
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.hidden_dim = config.hidden_dim
        self.head_dim = config.hidden_dim // config.n_heads

        # Linear projections for Query, Key, Value
        self.key = nn.Linear(self.hidden_dim, self.head_dim * self.n_heads, bias=False)
        self.query = nn.Linear(self.hidden_dim, self.head_dim * self.n_heads, bias=False)
        self.value = nn.Linear(self.hidden_dim, self.head_dim * self.n_heads, bias=False)

        # Output projection ("o"): mixes results from all heads back into n_embd
        self.proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)

    def forward(self, x: torch.Tensor, cos_sin: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()

        # 1. Projects input to Q, K, V.
        # reshape to (B, T, n_heads, head_dim)
        k = self.key(x).view(B, T, self.n_heads, self.head_dim)
        q = self.query(x).view(B, T, self.n_heads, self.head_dim)
        v = self.value(x).view(B, T, self.n_heads, self.head_dim)

        # 2. Applies RoPE to Q, K for position info.
        cos, sin = cos_sin
        k, q = apply_rotatory_positional_encoding(k, cos, sin), apply_rotatory_positional_encoding(q, cos, sin)

        # 3. Computes attention scores (Q @ K) to see how much each token cares about others.
        q, k = norm(q), norm(k) # QK norm

        # make head be batch dim, i.e. (B, T, n_heads, head_dim) -> (B, n_heads, T, head_dim)
        # We are making the n_heads into a batch dimension so pytorch treats it as batches and
        # applies the attention function on each head separately in parallel
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Re-assemble the heads side by side and project back to residual stream
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # 4. Projects output to mix information across heads.
        y = self.proj(y)
        return y



In [5]:
# attn = MultiHeadAttention(config)
# attn

In [6]:
# for param in attn.parameters():
#     print(type(param), param.size())

In [7]:
%pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [8]:
from torchinfo import summary

# summary(attn, input_size=(1, config.sequence_len, config.hidden_dim), dtypes=[torch.float32])

In [9]:
class FeedForward(nn.Module):
    """
    Feed Forward Network (MLP).
    Processes each token independently (no mixing between tokens).
    Structure: Expand -> ReLU^2 -> Contract
    """
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.proj_up = nn.Linear(config.hidden_dim, config.hidden_dim * config.mlp_ratio, bias=False)
        self.proj_down = nn.Linear(config.hidden_dim * config.mlp_ratio, config.hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj_up(x)
        x = F.relu(x).square()
        # TODO: Check if swiglu is better for 3 x hidden_dim -
        # gelu and silu are alternatives but difference seems marginal so sticking with relu^2
        x = self.proj_down(x)
        return x

In [10]:
# ff = FeedForward(config)
# ff

In [11]:
# summary(ff, input_size=(1, config.sequence_len, config.hidden_dim), dtypes=[torch.float32])

In [12]:
class TransformerBlock(nn.Module):
    """
    A single Transformer Block.
    Contains:
    1. Attention (Communication)
    2. MLP (Computation)
    Both use Residual Connections (x + ...) and Pre-Norm.
    """
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.attn = MultiHeadAttention(config)
        self.ff = FeedForward(config)

    def forward(self, x: torch.Tensor, cos_sin: torch.Tensor) -> torch.Tensor:
        # Attention with residual connection
        x = x + self.attn(norm(x), cos_sin)
        # MLP with residual connection
        x = x + self.ff(norm(x))
        return x


In [13]:
# block = TransformerBlock(GPTConfig())
# block

In [14]:
# summary(block, input_size=(1, config.sequence_len, config.hidden_dim), dtypes=[torch.float32])

In [None]:
class GPT(nn.Module):
    """
    The full GPT model.
    Contains:
    1. Token Embedding
    2. Transformer Blocks (stacked)
    3. Final Normalization
    4. LM Head - Tied weights with token embedding
    """
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
        self.lm_head.weight = self.token_embedding.weight

        self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute
        # Why 10x? This provides a generous buffer for inference/generation, allowing the model
        # to generate sequences longer than its training length without recomputing embeddings.
        # Note: While the embeddings support 10x length, the model's quality degrades beyond ~1.5-2x
        # the training length due to unseen attention patterns. This buffer is for convenience,
        # not an expectation of good performance at 10x length. Memory cost is negligible.

        head_dim = config.hidden_dim // config.n_heads
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
        self.register_buffer("sin", sin, persistent=False)

    def forward(self, idx, targets=None, loss_reduction="mean") -> torch.Tensor:
        T = idx.shape[1]
        cos_sin = self.cos[:, :T], self.sin[:, :T] # truncate cache to current sequence length
        x = self.token_embedding(idx)
        x = norm(x)
        for block in self.blocks:
            x = block(x, cos_sin)
        x = norm(x)

        softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
        logits = self.lm_head(x)
        logits = logits.float() # switch to fp32 for logit softcap and loss computation
        logits = softcap * torch.tanh(logits / softcap) # squash the logits

        if targets is not None:
            # training: given the targets, compute and return the loss
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
            return loss
        else:
            # inference: just return the logits directly
            return logits

    def init_weights(self):
        """
        Initialize the full model in this one function for maximum clarity.

        embedding:     normal, std=1.0
        for each block:
            attn.c_q:        uniform, std=1/sqrt(n_embd)
            attn.c_k:        uniform, std=1/sqrt(n_embd)
            attn.c_v:        uniform, std=1/sqrt(n_embd)
            attn.c_proj:     zeros
            mlp.c_fc:        uniform, std=1/sqrt(n_embd)
            mlp.c_proj:      zeros
        """

        """
        Explanation:
        The initialization logic deviates from PyTorch defaults (Kaiming defaults) to improve training 
        stability and convergence for deep Transformers.
        
        Key Differences:
        1. Zero Initialization for Output Projections (c_proj):
           - Function: Sets the weights of the final linear layer in each block to zero.
           - Why: This ensures that at initialization, the residual blocks contribute nothing to the 
             residual stream (y = x + 0). The model effectively starts as an identity function, allowing 
             unimpeded gradient flow from top to bottom. This prevents vanishing/exploding gradients 
             and provides a stable starting point for the model to gradually learn features.

        2. Zero Initialization for LM Head:
           - Function: Sets the classifier weights to zero.
           - Why: Ensures all logits are initially zero, leading to a uniform probability distribution (1/V) 
             for the next token. This minimizes the initial loss to exactly log(V) and prevents the model 
             from starting with random biases towards arbitrary tokens.

        Custom initialization for Linear and Embedding layers.
        
        1. Controlled Variance (Linear Layers):
           - Formula: std = 1 / sqrt(fan_in) * min(1, sqrt(fan_out / fan_in))
           - Why: Standard Kaiming init often leads to activation variance that grows with depth in 
             Transformers. This custom initialization (ref: https://arxiv.org/pdf/2310.17813) stabilizes 
             activation variance across layers, specifically accounting for the network width.

        2. Unit Variance (Embeddings):
           - Function: Normal distribution with std=1.0.
           - Why: Ensures strong initial signal strength before it enters the first normalization layer.
        """
        # Embedding
        torch.nn.init.normal_(self.token_embedding.weight, mean=0.0, std=1.0)

        # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
        n_embd = self.config.hidden_dim
        s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal

        # Zero out the output projections of the blocks
        for block in self.blocks:
            torch.nn.init.zeros_(block.ff.proj_down.weight)
            torch.nn.init.zeros_(block.attn.proj.weight)
            torch.nn.init.uniform_(block.attn.query.weight, -s, s) # weights use Uniform to avoid outliers
            torch.nn.init.uniform_(block.attn.key.weight, -s, s)
            torch.nn.init.uniform_(block.attn.value.weight, -s, s)
            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)

        # init the rotary embeddings
        head_dim = self.config.hidden_dim // self.config.n_heads
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin

        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
        if self.token_embedding.weight.device.type == "cuda":
            self.token_embedding.to(dtype=torch.bfloat16)

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        # autodetect the device from model embeddings
        if device is None:
            device = self.token_embedding.weight.device
        # stride the channels
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        # stride the time steps
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        # calculate the rotation frequencies at each (time, channel) pair
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
        cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
        return cos, sin

    def setup_optimizers(self, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
        """
        Sets up the optimizers.
        Uses AdamW for embeddings/head and Muon for internal linear layers.

        Detailed Explanation of Hybrid Strategy:
        ----------------------------------------
        We use two different optimizers because different parts of the Transformer have different
        geometric properties and optimization landscapes.

        1. Muon (for internal 2D matrices):
           - Applied to: Attention projections (c_q, c_k, c_v, c_proj) and MLP weights (c_fc, c_proj).
           - Mechanism: Muon forces weight *updates* to be orthogonal. In linear algebra, orthogonal
             transformations (like rotation or reflection) preserve the magnitude (norm) of the vector
             they act on.
           - Benefit: Deep networks suffer from vanishing/exploding gradients because signals get
             scaled up or down at every layer. By forcing updates to be orthogonal, Muon ensures
             signals propagate through the network without exploding in magnitude, allowing for
             much faster and more stable training of deep layers.

        2. AdamW (for embeddings & head):
           - Applied to: Token embeddings (wte) and the final output head (lm_head).
           - Reason: These parameters are not dense 2D matrices in the same sense (embeddings are
             lookup tables). The concept of "orthogonal updates" is mathematically ill-defined or
             harmful for vectors/lookups. AdamW is ideal here as it adapts learning rates per-parameter
             based on update frequency (handling the sparse nature of token updates).

        Do they conflict?
        No. Both optimizers step in directions derived from the same global loss gradient, so they
        optimize the same function. The risk is learning speed mismatch (one part learning faster
        than the other), which we handle by manually scaling the AdamW learning rate below.
        """
        model_dim = self.config.hidden_dim
        # ddp, rank, local_rank, world_size = get_dist_info()
        # Separate out all parameters into 3 groups (matrix, embedding, lm_head)
        matrix_params = list(self.blocks.parameters())
        embedding_params = list(self.token_embedding.parameters())
        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params)

        # Create the AdamW optimizer for the embedding
        # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        # if rank == 0:
        print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
        adam_groups = [
            dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
        ]
        adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
        AdamWFactory = partial(torch.optim.AdamW, fused=True)
        adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)

        # Create the Muon optimizer for the linear layers
        muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
        muon_optimizer = Muon(matrix_params, **muon_kwargs)

        # Combine the two optimizers into one list
        optimizers = [adamw_optimizer, muon_optimizer]
        for opt in optimizers:
            for group in opt.param_groups:
                group["initial_lr"] = group["lr"]
        return optimizers

    def estimate_flops(self):
        """ Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
        nparams = sum(p.numel() for p in self.parameters())
        nparams_embedding = self.token_embedding.weight.numel()
        l, h, q, t = self.config.n_layer, self.config.n_head, self.config.hidden_dim // self.config.n_head, self.config.sequence_len
        num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
        return num_flops_per_token
    
    def get_device(self):
        return self.token_embedding.weight.device
    
    @torch.inference_mode()
    def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
        """
        Naive autoregressive streaming inference.
        To make it super simple, let's assume:
        - batch size is 1
        - ids and the yielded tokens are simple Python lists and ints
        """
        assert isinstance(tokens, list)
        device = self.get_device()
        rng = None
        if temperature > 0:
            rng = torch.Generator(device=device)
            rng.manual_seed(seed)
        ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
        for _ in range(max_tokens):
            logits = self.forward(ids) # (B, T, vocab_size)
            logits = logits[:, -1, :] # (B, vocab_size)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            if temperature > 0:
                logits = logits / temperature
                probs = F.softmax(logits, dim=-1)
                next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
            else:
                next_ids = torch.argmax(logits, dim=-1, keepdim=True)
            ids = torch.cat((ids, next_ids), dim=1)
            token = next_ids.item()
            yield token

In [16]:
gpt = GPT(config)

In [17]:
gpt = gpt.to('cuda')

In [18]:
summary(gpt, input_size=(1, config.sequence_len), dtypes=[torch.long])

Layer (type:depth-idx)                   Output Shape              Param #
GPT                                      [1, 256, 32768]           --
├─Embedding: 1-1                         [1, 256, 512]             16,777,216
├─ModuleList: 1-2                        --                        --
│    └─TransformerBlock: 2-1             [1, 256, 512]             --
│    │    └─MultiHeadAttention: 3-1      [1, 256, 512]             1,048,576
│    │    └─FeedForward: 3-2             [1, 256, 512]             1,572,864
│    └─TransformerBlock: 2-2             [1, 256, 512]             --
│    │    └─MultiHeadAttention: 3-3      [1, 256, 512]             1,048,576
│    │    └─FeedForward: 3-4             [1, 256, 512]             1,572,864
│    └─TransformerBlock: 2-3             [1, 256, 512]             --
│    │    └─MultiHeadAttention: 3-5      [1, 256, 512]             1,048,576
│    │    └─FeedForward: 3-6             [1, 256, 512]             1,572,864
│    └─TransformerBlock: 2-4       

In [19]:
# The weight tying is working correctly — torchinfo just doesn't detect shared parameters by default.
# It counts each layer's parameters independently.

# This counts UNIQUE parameters (correct count with tying)
real_params = sum(p.numel() for p in gpt.parameters())
print(f"Actual unique parameters: {real_params:,}")

Actual unique parameters: 29,884,416


In [20]:
# These should all be True
print("Same object:", gpt.lm_head.weight is gpt.token_embedding.weight)
print("Same memory:", gpt.lm_head.weight.data_ptr() == gpt.token_embedding.weight.data_ptr())

Same object: True
Same memory: True


In [21]:
gpt.init_weights()

In [22]:
def check_model_dtypes(model):
    print(f"{'Layer Name':<40} | {'Type':<15} | {'Dtype'}")
    print("-" * 70)

    # Check Parameters
    for name, param in model.named_parameters():
        print(f"{name:<40} | Parameter       | {param.dtype}")

    # Check Buffers (like RoPE cos/sin)
    for name, buf in model.named_buffers():
        print(f"{name:<40} | Buffer          | {buf.dtype}")
check_model_dtypes(gpt)

Layer Name                               | Type            | Dtype
----------------------------------------------------------------------
token_embedding.weight                   | Parameter       | torch.bfloat16
blocks.0.attn.key.weight                 | Parameter       | torch.float32
blocks.0.attn.query.weight               | Parameter       | torch.float32
blocks.0.attn.value.weight               | Parameter       | torch.float32
blocks.0.attn.proj.weight                | Parameter       | torch.float32
blocks.0.ff.proj_up.weight               | Parameter       | torch.float32
blocks.0.ff.proj_down.weight             | Parameter       | torch.float32
blocks.1.attn.key.weight                 | Parameter       | torch.float32
blocks.1.attn.query.weight               | Parameter       | torch.float32
blocks.1.attn.value.weight               | Parameter       | torch.float32
blocks.1.attn.proj.weight                | Parameter       | torch.float32
blocks.1.ff.proj_up.weight          

In [23]:
# # sample input for gpt model
# sample_input = torch.ones((1, 1), dtype=torch.int64)
# sample_input

In [24]:
# gpt.eval()
# with torch.no_grad(): # Good practice for inference to save memory
#     op = gpt(sample_input)

# print(f"Output shape: {op.shape}") # Should be (1, 1, vocab_size)
# print(f"Max logit: {op.max().item():.4f}")
# print(f"Predicted token ID: {op.argmax().item()}")

Predicted token is always same as the input value. Why?
1. Weight Tying: We have set self.lm_head.weight = self.token_embedding.weight.
2. Zero-Init Blocks: init_weights function sets the output projection of every Transformer block to zero.
 - This means the blocks (Attention and MLP) contribute nothing to the residual stream at initialization.
 - The model effectively acts as an identity function for the embeddings: Embedding(token) -> Norm -> Logits.
3. Self-Similarity: Since the output head uses the same weights as the embedding, it calculates the dot product of the token's embedding vector with all other embedding vectors.
 - A vector's dot product with itself ($v \cdot v$) is almost always much higher than with other random vectors ($v \cdot w$).
 - Therefore, the model assigns the highest probability to the token that was input.

In [25]:
# # loss
# sample_input2 = torch.ones((1, 1), dtype=torch.int64)*100
# sample_input2 = sample_input2.to(device_type)
# print(f'loss when taget = input: {gpt(sample_input2, sample_input2)}')
# print(f'loss when target != input: {gpt(sample_input2, sample_input2*2)}')

In [95]:
# orig_model = gpt # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
# torch.compile: optimizing the model execution graph (JIT compilation)
gpt = torch.compile(gpt, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
# num_params = sum(p.numel() for p in model.parameters())
# print(f"Number of parameters: {num_params:,}")

In [27]:
from torch import Tensor

@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750,  2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT
    return X


class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer should not be used for the embedding layer, the final fully connected layer,
    or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iteration steps to use.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        params: list[Tensor] = [*params]
        param_groups = []
        for size in {p.numel() for p in params}:
            group = dict(params=[p for p in params if p.numel() == size])
            param_groups.append(group)
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            for p in params:
                g = p.grad
                assert g is not None
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf: Tensor = state["momentum_buffer"]
                buf.lerp_(g, 1 - group["momentum"])
                g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
                p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)

In [28]:
embedding_lr = 0.2
weight_decay = 0.0
matrix_lr = 0.02
from functools import partial

optimizers = gpt.setup_optimizers(embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers

Scaling the LR for the AdamW parameters ∝1/√(512/768) = 1.224745


In [29]:
# works on gpu
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"Cached: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

Allocated: 0.09 GB
Cached: 0.24 GB


In [112]:
input_file_path = '/content/reddit_shard_00000.txt'
output_file_path = 'reddit_small.txt'
bytes_to_read = 1024*1024*50 # 50 MB

try:
    with open(input_file_path, 'r', encoding='utf-8') as infile:
        content = infile.read(bytes_to_read)

    with open(output_file_path, 'w', encoding='utf-8') as outfile:
        outfile.write(content)

    print(f"Successfully saved the first {bytes_to_read / 1024} KB of '{input_file_path}' to '{output_file_path}'.")
except FileNotFoundError:
    print(f"Error: The file '{input_file_path}' was not found.")
except Exception as e:
    print(f"An error occurred: {e}")

Successfully saved the first 51200.0 KB of '/content/reddit_shard_00000.txt' to 'reddit_small.txt'.


In [113]:
with open('reddit_small.txt', 'r', encoding='utf-8') as f:
    text = f.read()
    print(text[:10240])

The one feature the iPad is really missing. I don't care about the lack of camera. I never use the one on my MacBook, and even if I did the angle would be terrible on the iPad.

I don't care if third party apps can't run in the background. I don't listen to streaming music.

I don't care that the App Store is a closed system. I can jailbreak for myself and I think the closed system works better for most users.

The one feature I want is User Accounts and a Guest Account. If this device is meant to be a coffee table computer, it needs to be able to accomadate multiple users.
Dear Sydney Reddit'ers, Would you like any changes made to the style of this subreddit? I was going to subtly edit the style of the Sydney subreddit but then I found this post and realised that people have very strong opinions about how their reddit should look. 



So before I make any changes do you have any opinions or suggestions?
I skipped bail, ran away, and never got caught. AM(A)A. Long/short story, I went t

In [114]:
len(text)

23043723

In [115]:
# import tiktoken
# enc = tiktoken.get_encoding('gpt2')
# tokenizer
import pickle
import numpy as np
import tiktoken

with open("/content/tokenizer.pkl", "rb") as f:
    enc = pickle.load(f)
print(enc)
tokens = enc.encode("hello world")
print(tokens)


<Encoding 'rustbpe'>
[13726, 111, 1170]


In [116]:
input_tokens = enc.encode(text)
print(len(input_tokens))

5699090


In [117]:
B = 32

start = 0

def load_next_batch():
    global start
    end = (start + B*config.sequence_len + 1) % len(input_tokens)
    if end < start:
        start = 0
        end = B*config.sequence_len + 1
    buf = torch.tensor(input_tokens[start:end])
    start = end

    x = buf[:-1].view(B, config.sequence_len)
    y = buf[1:].view(B, config.sequence_len)

    return x, y

def get_gpu_stats():
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"Cached: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    print()

In [118]:
def tiktokenize_tokens(tokens):
  for i, token in enumerate(tokens):
    print(f"\033[{i%4+41}m{enc.decode([token])}\033[0m", end="")
  print()

x_val, y_val = load_next_batch()
x_val = x_val.to(device_type)
y_val = y_val.to(device_type)

print(x_val.shape)
print(y_val.shape)

tiktokenize_tokens(x_val[0])
print("-"*50)
tiktokenize_tokens(y_val[0])

start = 0

torch.Size([32, 256])
torch.Size([32, 256])
[41mThe[0m[42m one[0m[43m feature[0m[44m the[0m[41m iPad[0m[42m is[0m[43m really[0m[44m missing[0m[41m.[0m[42m I[0m[43m don[0m[44m't[0m[41m care[0m[42m about[0m[43m the[0m[44m lack[0m[41m of[0m[42m camera[0m[43m.[0m[44m I[0m[41m never[0m[42m use[0m[43m the[0m[44m one[0m[41m on[0m[42m my[0m[43m Mac[0m[44mBook[0m[41m,[0m[42m and[0m[43m even[0m[44m if[0m[41m I[0m[42m did[0m[43m the[0m[44m angle[0m[41m would[0m[42m be[0m[43m terrible[0m[44m on[0m[41m the[0m[42m iPad[0m[43m.[0m[44m
[0m[41m
[0m[42mI[0m[43m don[0m[44m't[0m[41m care[0m[42m if[0m[43m third[0m[44m party[0m[41m apps[0m[42m can[0m[43m't[0m[44m run[0m[41m in[0m[42m the[0m[43m background[0m[44m.[0m[41m I[0m[42m don[0m[43m't[0m[44m listen[0m[41m to[0m[42m streaming[0m[43m music[0m[44m.[0m[41m
[0m[42m
[0m[43mI[0m[44m don[0m[41m't[0m[42m care[0m[

In [53]:
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

optimizer = torch.optim.AdamW(gpt.parameters(), lr=3e-4)

In [None]:
import time

# test if this makes it faster --> No
# torch.backends.cudnn.conv.fp32_precision = 'tf32'
# torch.backends.cuda.matmul.fp32_precision = 'ieee'

def training_loop(iter=500):

  # validation set
  sample = enc.encode("The Project Gutenberg eBook of")
  buf = torch.tensor(sample)
  x_sample_base = buf.view(1, -1)
  x_sample_base = x_sample_base.to(device_type)

  # train
  t0 = time.time()
  for i in range(iter):
    optimizer.zero_grad()
    x, y = load_next_batch()
    x = x.to(device_type)
    y = y.to(device_type)

    with autocast_ctx:
        loss = gpt(x, y)
    loss.backward()
    optimizer.step()
    if i % 10 == 0:
        t1 = time.time()
        gpt.eval()
        with autocast_ctx:
          val_loss = gpt(x_val, y_val)
        gpt.train()

        print(f"step {i}, validation loss: {val_loss.item()}, loss: {loss} average time over last 10 steps = {(t1-t0)/10}")
        t0 = t1

        if i%100 == 0:

            gpt.eval()
            x_sample = x_sample_base
            for _ in range(25):
                with torch.no_grad(): # Good practice for inference to save memory
                    with autocast_ctx:
                        op = gpt(x_sample)
                    op = op[:,-1,:]

                    next_ids = torch.argmax(op, dim=-1, keepdim=True)
                    x_sample = torch.cat((x_sample, next_ids), dim=1)
                    next_token = next_ids.item()
                    # print(f"{enc.decode([next_token])}", end = "")
                    # print("\n", x_sample.shape)
            tiktokenize_tokens(x_sample[0])
            gpt.train()

            get_gpu_stats()


training_loop(5000)

step 0, validation loss: 6.024102687835693, average time over last 10 steps = 0.0023022890090942383
[41mThe[0m[42m Project[0m[43m Gutenberg[0m[44m eBook[0m[41m of[0m[42m the[0m[43m Project[0m[44m Gutenberg[0m[41m™[0m[42m electronic[0m[43m works[0m[44m in[0m[41m the[0m[42m Project[0m[43m Gutenberg[0m[44m™[0m[41m electronic[0m[42m works[0m[43m in[0m[44m the[0m[41m Project[0m[42m Gutenberg[0m[43m™[0m[44m electronic[0m[41m works[0m[42m in[0m[43m the[0m[44m Project[0m[41m Gutenberg[0m[42m™[0m
Allocated: 5.15 GB
Cached: 12.61 GB

step 10, validation loss: 6.1195573806762695, average time over last 10 steps = 0.8296464443206787
step 20, validation loss: 6.17494010925293, average time over last 10 steps = 0.8212828874588013
step 30, validation loss: 5.7937846183776855, average time over last 10 steps = 0.838221263885498
step 40, validation loss: 5.8871636390686035, average time over last 10 steps = 0.8483931303024292
step 50, validati

In [68]:
def test_model_generations(sample_text = "The Project Gutenberg eBook of", len_gen = 10):
  sample = enc.encode(sample_text)
  buf = torch.tensor(sample)
  sample = buf.view(1, -1).to(device_type)

  for _ in range(len_gen):
    with torch.no_grad(): # Good practice for inference to save memory
        with autocast_ctx:
            op = gpt(sample)
        op = op[:,-1,:]

        next_ids = torch.argmax(op, dim=-1, keepdim=True)
        sample = torch.cat((sample, next_ids), dim=1)
        next_token = next_ids.item()
  tiktokenize_tokens(sample[0])
  print()

test_model_generations(sample_text="Project Gutenberg is", len_gen=25)
test_model_generations(sample_text="He was the main", len_gen=50)
test_model_generations(sample_text="The place is well known for its", len_gen=50)


[41mProject[0m[42m Gutenberg[0m[43m is[0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m

[41mHe[0m[42m was[0m[43m the[0m[44m main[0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41m [0m[42m [0m

[41mThe[0m[42m place[0m[43m is[0m[44m well[0m[41m known[0m[42m for[0m[43m its[0m[44m [0m[41m [0m[42m [0m[43m [0m[44m [0m[41

In [102]:
print(len(text), start)

52428800 0


In [105]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [106]:
from datetime import datetime

# Get current timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Construct the filename with timestamp
filename = f"gpt_weights_{timestamp}.pth"

# save gpt model weights to drive
torch.save(gpt.state_dict(), f'/content/drive/MyDrive/{filename}')

## Multi optimizers

In [71]:
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon) - Muon needs different LR scaling than Adam

grad_clip = 1.0 # gradient clipping value (0.0 = disabled): prevents gradient explosions
warmup_ratio = 0.0 # ratio of iterations for LR warmup: start slow then ramp up
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown: cosine decay at the end
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
num_iterations = 10000 # explicit number of steps of the optimization (-1 = disable)

# Learning rate scheduler (Warmup -> Constant -> Warmdown/Cos Decay)
def get_lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac


def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum

In [120]:
# TODO: check if compile helps?? --> Got stuck for 6 minutes. I ran out of patience and interrupted
# @torch.compile
def training_loop_multi_optimizer(iter = 10):

  x_val, y_val = load_next_batch()
  x_val = x_val.to(device_type)
  y_val = y_val.to(device_type)


  model = gpt

  iter+=1
  t0 = time.time()
  for i in range(iter): # Use 'i' as the loop counter
    model.zero_grad(set_to_none=True) # Zero gradients at the beginning of the current iteration

    x, y = load_next_batch()
    x = x.to(device_type)
    y = y.to(device_type)

    with autocast_ctx:
      loss = model(x, y) # Forward pass

    loss.backward() # Backward pass

    # Calculate learning rate and momentum based on current iteration 'i'
    lrm = get_lr_multiplier(i+500)

    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm

    muon_momentum = get_muon_momentum(i+500)

    # print(f"muon_momentum: {muon_momentum}, lrm: {lrm}")

    for group in muon_optimizer.param_groups:
        group["momentum"] = muon_momentum

    for opt in optimizers:
        opt.step() # Update weights


    # print(f"step {i}, loss: {loss.item()}")

    if i % 10 == 0:
        t1 = time.time()
        gpt.eval()
        with autocast_ctx:
          val_loss = gpt(x_val, y_val)
        gpt.train()

        print(f"step {i}, validation loss: {val_loss.item()}, loss: {loss} average time over last 10 steps = {(t1-t0)/10}")
        t0 = t1

    if i%50 == 0:
      get_gpu_stats()

      test_model_generations(sample_text="Project Gutenberg is", len_gen=25)
      test_model_generations(sample_text="He was the main", len_gen=50)
      test_model_generations(sample_text="The place is well known for its", len_gen=25)


training_loop_multi_optimizer(1000)


step 0, validation loss: 17.582365036010742, loss: 17.37175178527832 average time over last 10 steps = 0.08408477306365966
Allocated: 6.82 GB
Cached: 11.41 GB

[41mProject[0m[42m Gutenberg[0m[43m is[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m

[41mHe[0m[42m was[0m[43m the[0m[44m main[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m

In [121]:
training_loop_multi_optimizer(2000)

step 0, validation loss: 15.820544242858887, loss: 15.656527519226074 average time over last 10 steps = 0.08732683658599853
Allocated: 6.82 GB
Cached: 12.41 GB

[41mProject[0m[42m Gutenberg[0m[43m is[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m

[41mHe[0m[42m was[0m[43m the[0m[44m main[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0m[42m
[0m[43m
[0m[44m
[0m[41m
[0

In [122]:
# Get current timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Construct the filename with timestamp
filename = f"gpt_weights_{timestamp}.pth"

# save gpt model weights to drive
torch.save(gpt.state_dict(), f'/content/drive/MyDrive/{filename}')

In [124]:

import gc

print(f"before\nAllocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"Cached: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

gc.collect()
# clear cuda cache
torch.cuda.empty_cache()
# works on gpu
print(f"After\nAllocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"Cached: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

before
Allocated: 6.83 GB
Cached: 12.41 GB
After
Allocated: 3.64 GB
Cached: 6.55 GB


In [127]:
training_loop_multi_optimizer(2000)
# Get current timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Construct the filename with timestamp
filename = f"gpt_weights_{timestamp}.pth"

# save gpt model weights to drive
torch.save(gpt.state_dict(), f'/content/drive/MyDrive/{filename}')

step 0, validation loss: 11.120105743408203, loss: 11.24026107788086 average time over last 10 steps = 0.08760862350463867
Allocated: 6.83 GB
Cached: 12.58 GB

[41mProject[0m[42m Gutenberg[0m[43m is[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m

[41mHe[0m[42m was[0m[43m the[0m[44m main[0m[41m [0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m[42m	[0m[43m	[0m[44m	[0m[41m	[0m