In [255]:
import math
import time
from dataclasses import dataclass
from datetime import datetime
from functools import partial

import tiktoken
import mlx.core as mx
import mlx.nn as nn
import mlx.nn.losses as F
import mlx.optimizers as optim
from mlx.utils import tree_flatten

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_heads = config.n_head
        self.n_embd = config.n_embd
        self.causal_mask = CausalSelfAttention.create_additive_causal_mask(config.block_size, dtype=config.dtype)

        self.query_proj = nn.Linear(self.n_embd, self.n_embd)
        self.key_proj = nn.Linear(self.n_embd, self.n_embd)
        self.value_proj = nn.Linear(self.n_embd, self.n_embd)
        self.out_proj = nn.Linear(self.n_embd, self.n_embd)

    def __call__(self, x):
        B, T, C = x.shape
        # calculate query, key, value for all heads
        q = self.query_proj(x) # (B, T, C) -> (B, T, C)
        k = self.key_proj(x) # (B, T, C) -> (B, T, C)
        v = self.value_proj(x) # (B, T, C) -> (B, T, C)

        # reshape query, key, value to batch over n_batches x n_heads
        #   - this way we can compute attention for all heads at once (i.e. multi-head attention) with a single matrix multiply
        #   - nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        q = mx.unflatten(q, -1, (self.n_heads, -1)).transpose(0, 2, 1, 3) # (B, T, C) -> (B, T, nh, hs) -> (B, nh, T, hs)
        k = mx.unflatten(k, -1, (self.n_heads, -1)).transpose(0, 2, 1, 3) # (B, T, C) -> (B, T, nh, hs) -> (B, nh, T, hs)
        v = mx.unflatten(v, -1, (self.n_heads, -1)).transpose(0, 2, 1, 3) # (B, T, C) -> (B, T, nh, hs) -> (B, nh, T, hs)

        # causal flash attention
        scale = math.sqrt(1 / q.shape[-1])
        output = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=self.causal_mask[:T, :T]) # 3x(B, nh, T, hs) -> (B, nh, T, hs)

        # re-assemble all head outputs side by side and project out
        output = output.transpose(0, 2, 1, 3).flatten(-2, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
        return self.out_proj(output) # (B, T, C) -> (B, T, C)

    @staticmethod
    def create_additive_causal_mask(N: int, dtype = mx.float32):
        indices = mx.arange(N)
        mask = indices[:, None] < indices[None]
        mask = mask.astype(dtype) * mx.finfo(dtype).min
        return mask


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def __call__(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def __call__(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


@dataclass
class GPTConfig:
    block_size: int = 1024 # max sequence length
    vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 byte tokes + 1<|endoftext|>
    n_layer: int = 12 # number of layers
    n_head: int = 12 # number of heads
    n_embd: int = 768 # embedding dimension
    dtype = mx.bfloat16
    # NOTE: head_size = n_embd / n_head = 64  # embedding dimension of each attention head


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = [Block(config) for _ in range(config.n_layer)],
            ln_f = nn.LayerNorm(config.n_embd),
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # weight sharing scheme (refer to [1] in play.ipynb)
        self.transformer['wte'].weight = self.lm_head.weight

    def __call__(self, idx):
        # idx is of shape (B, T)
        B, T = idx.shape
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        # forward the token and position embeddings
        pos = mx.arange(0, T, dtype=mx.int32)  # shape (T)
        pos_emb = self.transformer['wpe'](pos)  # position embeddings of shape (T, n_embd)
        tok_emb = self.transformer['wte'](idx)  # token embeddings of shape (B, T, n_embd)
        x = tok_emb + pos_emb  # (B, T, n_embd) + (T, n_embd) -> (B, T, n_embd)
        # forward the blocks of the transformer
        for block in self.transformer['h']:
            x = block(x)
        # forward the final layernorm and the classifier
        x = self.transformer['ln_f'](x)
        return self.lm_head(x)  # (B, T, vocab_size)

In [4]:
class DataLoaderLite:
    def __init__(self, path, batch_shape):
        self.B = batch_shape[0]
        self.T = batch_shape[1]

        # at init load tokens from disk and store them in memory
        with open(path, 'r') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        tokens = enc.encode(text)
        self.tokens = mx.array(tokens)
        print(f"loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self.tokens) // (self.B * self.T)} batches")

        # state
        self.current_position = 0

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = (buf[:-1]).reshape((B, T)) # inputs
        y = (buf[1:]).reshape((B, T)) # targets
        # advance the position in the tensor
        self.current_position += B * T
        # if loading the next batch would be out of bounds, reset
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y

In [5]:
n_batch = 10

gpt_config = GPTConfig()
train_loader = DataLoaderLite('res/tinyshakespeare.txt', (16, gpt_config.block_size))

model = GPT(gpt_config)
model.set_dtype(gpt_config.dtype)
mx.eval(model.parameters())
nparams = sum(x.size for k, x in tree_flatten(model.parameters()) if "embedding" not in k)
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")

optimizer = optim.AdamW(learning_rate=3e-4, betas=[0.9, 0.95], eps=1e-8, weight_decay=0.1)

loaded 338025 tokens
1 epoch = 20 batches
Training a transformer with 167.484 M parameters


In [256]:
def loss_fn(model, x, y, reduce=True):
    logits = model(x)
    losses = F.cross_entropy(logits, y)
    return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))

state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(inputs, targets):
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, inputs, targets)
    optimizer.update(model, grads)
    return loss

start = datetime.now()
num_epochs = 1000
for i in range(num_epochs):
    t0 = time.time()
    x, y = train_loader.next_batch()

    loss = step(x, y)
    mx.eval(state)

    t1 = time.time()
    dt = (t1 - t0) * 1000  # time difference in milliseconds
    tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
    iterations_per_sec = 1 / (t1 - t0)
    print(f"{datetime.now()} - step {i}, loss: {loss:.4f}, dt: {dt:.2f}ms, tok/sec: {tokens_per_sec:.2f} tokens/sec")

end = datetime.now()
print(f"total time: {end - start}")
print(f"average tokens/sec: {(train_loader.B * train_loader.T * num_epochs) / (end.timestamp() - start.timestamp())}")

2025-03-20 20:21:55.071798 - step 0, loss: 0.3066, dt: 4822.96ms, tok/sec: 3397.08 tokens/sec
2025-03-20 20:21:56.897179 - step 1, loss: 0.2832, dt: 1781.57ms, tok/sec: 9196.38 tokens/sec
2025-03-20 20:21:58.716430 - step 2, loss: 0.3809, dt: 1782.89ms, tok/sec: 9189.56 tokens/sec
2025-03-20 20:22:00.542471 - step 3, loss: 0.3809, dt: 1789.15ms, tok/sec: 9157.40 tokens/sec
2025-03-20 20:22:02.360690 - step 4, loss: 0.3262, dt: 1777.10ms, tok/sec: 9219.50 tokens/sec
2025-03-20 20:22:04.139384 - step 5, loss: 0.3379, dt: 1778.26ms, tok/sec: 9213.50 tokens/sec
2025-03-20 20:22:05.919853 - step 6, loss: 0.3047, dt: 1780.05ms, tok/sec: 9204.24 tokens/sec
2025-03-20 20:22:07.697025 - step 7, loss: 0.3555, dt: 1776.77ms, tok/sec: 9221.21 tokens/sec
2025-03-20 20:22:09.474194 - step 8, loss: 0.3379, dt: 1776.74ms, tok/sec: 9221.41 tokens/sec
2025-03-20 20:22:11.266750 - step 9, loss: 0.3027, dt: 1792.14ms, tok/sec: 9142.13 tokens/sec
2025-03-20 20:22:13.075549 - step 10, loss: 0.2988, dt: 1778

KeyboardInterrupt: 

In [257]:
num_return_sequences = 5
max_length = 30

# encode prefix tokens
enc = tiktoken.get_encoding('gpt2')
# tokens = enc.encode("Hello, I'm a language model,")
tokens = enc.encode("Second Citizen:")
tokens = mx.array(tokens, dtype=mx.int32)  # (8 tokens,)
x = mx.repeat(mx.expand_dims(tokens, axis=0), num_return_sequences, axis=0)  # (5 rows, 8 tokens)

# generate! right now x is (B, T) where B = 5, T = 8
while x.shape[1] < max_length:
    # forward the model to get the logits
    logits = model(x)  # (B, T, vocab_size)
    # take the logits at the last position
    logits = logits[:, -1, :]  # (B, vocab_size)

    # # get the probabilities
    # probs = nn.softmax(logits, axis=-1)
    # # do top-k sampling of 50 (huggingface pipeline default)
    # k = 50  # Number of top elements
    # # Get the sorted indices in descending order
    # topk_indices = mx.argsort(probs, axis=-1)[:, -k:] # (B, 50)
    # # Use the indices to gather the top K values
    # topk_probs = mx.take_along_axis(probs, indices=topk_indices, axis=-1) # (B, 50)

    # select a token from the top probabilities
    ix = mx.random.categorical(logits, num_samples=1)  # (B, 1)
    # append to the sequence
    x = mx.concatenate([x, ix], axis=1)

# print the generated text
for i in range(x.shape[0]):
    tokens = x[i, :max_length].tolist()
    decoded = enc.decode(tokens)
    print(">", decoded)

> Second Citizen:
Here anchors! the hair
Exceptity and we receive
All: that! my hand not true twenty times from my knees for
> Second Citizen:
Before we calls me! another! Was of theOUGH,,
under traitor, hear not obedient be Men them:
words
> Second Citizen:
They did us down before but theps.

Third Gentleman: he hath't? he was another floods?
That!--
> Second Citizen:
Before we proceed yourlander to live, my power till
And I desperately we must quite forth hat or from his ruin and my
> Second Citizen:
Before I meant, sir; ye could a
The strokes, which pilgr' service'st, few Rome, to are faults


In [None]:
# ----------------------------------------------------------------------------------
# Debugging
# ----------------------------------------------------------------------------------

In [246]:
# encode prefix tokens
enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode("Second Citizen:\n")
tokens = mx.array(tokens, dtype=mx.int32)
x = mx.expand_dims(tokens, axis=0)
print(x)

array([[12211, 22307, 25, 198]], dtype=int32)


In [247]:
logits = model(x)  # (B, T, vocab_size)
print(logits[0, 2])
logits = logits[:, -1, :]  # (B, vocab_size)
print(logits)
probs = nn.softmax(logits, axis=-1)
print(probs)

array([2.4375, -16.75, -16.75, ..., -16.875, -5.90625, -16.75], dtype=bfloat16)
array([[1.58594, -6.09375, -6.34375, ..., -6.25, -1.85938, -6.28125]], dtype=bfloat16)
array([[9.53674e-05, 4.23752e-08, 3.28291e-08, ..., 3.74857e-08, 2.96533e-06, 3.74857e-08]], dtype=bfloat16)


In [248]:
print(mx.sort(probs, axis=-1))

array([[2.1304e-08, 2.2701e-08, 2.2701e-08, ..., 0.0654297, 0.12207, 0.177734]], dtype=bfloat16)


In [249]:
k = 50  # Number of top elements
# Get the sorted indices in descending order
topk_indices = mx.argsort(probs, axis=-1)[:, -k:] # (B, 50)
# Use the indices to gather the top K values
topk_probs = mx.take_along_axis(probs, indices=topk_indices, axis=-1) # (B, 50)
print(topk_indices)
print(topk_probs)
print(enc.decode(topk_indices[0].tolist()))
print("\\n =", enc.encode("\n"))

array([[22788, 11486, 17821, ..., 817, 1135, 8421]], dtype=uint32)
array([[0.00174713, 0.00180054, 0.00180054, ..., 0.0654297, 0.12207, 0.177734]], dtype=bfloat16)
SirYetTrueAyClGRThirdyourSecondMyThatKINGItPleaseHowFSVIfAnTheyToHereMoreWhereThetheSoY
QMadHDGoodWellNPMENButFirstAndWhyNoYouWhatIThWeBefore
\n = [198]


In [250]:
# PROBLEM: as is, the categorical sampling is basically fucking random
for i in range(10):
    ix = mx.random.categorical(topk_probs, num_samples=1)
    print(ix.item())

46
20
37
47
4
44
44
45
34
41


In [253]:
# The fix is to pass just the logits... we could just forget about the topk nonsense ()even though it seems to work
print(topk_probs)
for i in range(10):
    ix = mx.random.categorical(logits, num_samples=1)
    print(ix.item(), enc.decode([ix.item()]), probs[0][ix.item()].item())
# much better!

array([[0.00174713, 0.00180054, 0.00180054, ..., 0.0654297, 0.12207, 0.177734]], dtype=bfloat16)
2514 To 0.004730224609375
5195 Why 0.0247802734375
49370 Mist 0.0010223388671875
20840  surveyed 3.748573362827301e-08
8421 Before 0.177734375
534  your 0.0003871917724609375
1537 But 0.0205078125
4053 well 0.000705718994140625
1135 We 0.1220703125
8421 Before 0.177734375


In [191]:
ix = mx.random.categorical(logits, num_samples=1)
x = mx.concatenate([x, ix], axis=1)
print(x)

array([[12211, 22307, 25, ..., 198, 198, 198]], dtype=int64)


In [254]:
# Full run
enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode("Second Citizen:\n")
tokens = mx.array(tokens, dtype=mx.int32)
x = mx.expand_dims(tokens, axis=0)
for i in range(100):
    logits = model(x)  # (B, T, vocab_size)
    logits = logits[:, -1, :]  # (B, vocab_size)
    probs = nn.softmax(logits, axis=-1)
    ix = mx.random.categorical(2 * logits, num_samples=1)  # (B, 1)
    print(enc.decode([ix.item()]), "->", probs[0, ix.item()].item())
    x = mx.concatenate([x, ix], axis=1)

print(x.tolist())
print(enc.decode(x.tolist()[0]))

We -> 0.1220703125
 hear -> 0.080078125
 me -> 1.0
, -> 0.92578125
 not -> 0.296875
, -> 0.111328125
 see -> 0.5234375
 your -> 0.1962890625
woman -> 0.013427734375
 do -> 0.5859375
; -> 0.1640625

 -> 1.0
The -> 0.203125
 service -> 0.05419921875
, -> 0.87109375
' -> 0.1494140625
er -> 0.93359375

 -> 0.5078125
I -> 0.390625
'll -> 0.671875
 prove -> 0.0830078125
 his -> 0.9921875
 own -> 0.12890625
 heart -> 0.28125
, -> 1.0
' -> 0.93359375
er -> 0.76171875

 -> 0.984375
Like -> 0.1796875
 worth -> 0.0018157958984375
 all -> 0.265625
 are -> 0.9921875
 not -> 0.734375
 in -> 0.119140625
't -> 0.1953125

 -> 0.36328125
Even -> 0.48046875
 in -> 1.0
 your -> 0.9921875
 times -> 0.212890625
 i -> 0.60546875
' -> 1.0
er -> 0.90625
 I -> 0.69140625
 have -> 0.671875
 left -> 0.197265625
; -> 1.0

 -> 1.0
So -> 0.21875
 locks -> 0.0400390625
 o -> 0.53515625
' -> 0.984375
 the -> 0.9140625
 court -> 0.0966796875
. -> 0.921875

 -> 0.9921875
How -> 0.0322265625
 might -> 0.48046875
 have ->