# Experimental

Experiments before they're integrated into the main codebase.

In [2]:
import os
import requests
from typing import Literal

In [3]:
# The dataset is cached in a `data/` folder under the current working directory.
PWD = os.getcwd()
DATA_DIR = os.path.join(PWD, "data")
# Raw text file from the karpathy/char-rnn repository (tinyshakespeare).
INPUT_DATA_URL = (
    "https://raw.githubusercontent.com/karpathy/char-rnn/"
    "master/data/tinyshakespeare/input.txt"
)

In [4]:
def fetch_input_data() -> str:
    """Return the input text, downloading and caching it on first use.

    The text is stored at ``DATA_DIR/input.txt``. When that file is missing it
    is fetched from ``INPUT_DATA_URL``; later calls read the cached copy.

    Returns:
        The full contents of the cached input file.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an error status.
        OSError: If the cache directory or file cannot be created or read.
    """
    input_file_path = os.path.join(DATA_DIR, "input.txt")
    if not os.path.exists(input_file_path):
        # Download before touching the filesystem so a failed request cannot
        # leave an empty or partial cache file that later runs would trust.
        response = requests.get(INPUT_DATA_URL, timeout=30)
        response.raise_for_status()
        # The cache directory may not exist yet.
        os.makedirs(DATA_DIR, exist_ok=True)
        with open(input_file_path, "w", encoding="utf-8") as f:
            f.write(response.text)

    with open(input_file_path, "r", encoding="utf-8") as f:
        return f.read()

In [5]:
input_text = fetch_input_data()

In [6]:
print(len(input_text))

1115394


In [7]:
# Vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(input_text))
vocab_size = len(chars)
print("".join(chars), vocab_size, sep="\n")


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [8]:
# Lookup tables between characters and their integer ids (index in `chars`).
string_to_int = {c: i for i, c in enumerate(chars)}
int_to_string = dict(enumerate(chars))


def encode(s: str) -> list[int]:
    """Map a string to the list of integer ids of its characters.

    Raises:
        KeyError: If `s` contains a character that is not in `chars`.
    """
    return [string_to_int[c] for c in s]


def decode(l: list[int]) -> str:
    """Map a sequence of integer ids back to the string they spell.

    Raises:
        KeyError: If `l` contains an id that is not in `int_to_string`.
    """
    return "".join(int_to_string[i] for i in l)

In [9]:
print(encode("hello"))
print(decode(encode("hello")))

[46, 43, 50, 50, 53]
hello


In [10]:
import mlx.core as mx
# Encode the whole corpus once into a flat array of token ids.
data = mx.array(encode(input_text), dtype=mx.int64)
print(data.shape, data.dtype)
# Peek at the first 1000 token ids.
print(data[:1000])

(1115394,) mlx.core.int64
array([18, 47, 56, ..., 8, 0, 0], dtype=int64)


In [11]:
# Train on the first 90% of the tokens; hold out the remainder for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [12]:
# Context length used by the examples below.
block_size = 8
# One token more than the context length, so each position has a following token.
train_data[:block_size + 1]

array([18, 47, 56, ..., 15, 47, 58], dtype=int64)

In [13]:
# y is x shifted by one token, so y[t] is the token that follows x[:t + 1].
x, y = train_data[:block_size], train_data[1:block_size + 1]
targets = y.tolist()
for t, target in enumerate(targets):
    context = x[:t + 1]
    print(f"when input is {context} the target is {target}")

when input is array([18], dtype=int64) the target is 47
when input is array([18, 47], dtype=int64) the target is 56
when input is array([18, 47, 56], dtype=int64) the target is 57
when input is array([18, 47, 56, 57], dtype=int64) the target is 58
when input is array([18, 47, 56, 57, 58], dtype=int64) the target is 1
when input is array([18, 47, 56, 57, 58, 1], dtype=int64) the target is 15
when input is array([18, 47, 56, ..., 58, 1, 15], dtype=int64) the target is 47
when input is array([18, 47, 56, ..., 1, 15, 47], dtype=int64) the target is 58


In [14]:
# Seed MLX's global RNG so the randomly sampled batches are repeatable.
mx.random.seed(1337)
batch_size = 4 # number of independent sequences to train on in parallel
block_size = 8 # maximum context length for predictions

def get_batch(split: Literal["train", "val"]) -> tuple[mx.array, mx.array]:
    """Sample a random batch of input/target windows from one data split.

    Args:
        split: ``"train"`` selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        ``(x, y)``: two stacks of ``batch_size`` windows, each ``block_size``
        tokens long. ``y`` holds the same windows as ``x`` shifted one position
        to the right, so ``y[b, t]`` is the token that follows ``x[b, t]``.
    """
    source = train_data if split == "train" else val_data
    # The upper bound is exclusive, so the largest start offset still leaves
    # room for the shifted target window.
    ix = mx.random.randint(0, len(source) - block_size, [batch_size])
    # Pull the start offsets to the host once instead of calling `.item()`
    # four times per offset.
    starts = ix.tolist()
    # gets `batch_size` blocks stacked
    x = mx.stack([source[i:i + block_size] for i in starts])
    # the targets are the same blocks shifted by one token
    y = mx.stack([source[i + 1:i + block_size + 1] for i in starts])
    return x, y

In [15]:
xb, yb = get_batch("train")

# Show both halves of the batch: shape first, then contents.
for heading, batch in (("inputs:", xb), ("targets:", yb)):
    print(heading)
    print(batch.shape)
    print(batch)

# Each row of yb is the matching row of xb shifted by one token.
for b in range(batch_size):
    for t in range(block_size):
        context, target = xb[b, :t + 1], yb[b, t].item()
        print(f"when input is {context} the target is {target}")

inputs:
(4, 8)
array([[47, 45, 52, ..., 47, 52, 1],
       [57, 0, 16, ..., 1, 57, 43],
       [40, 56, 53, ..., 11, 1, 40],
       [1, 46, 39, ..., 1, 61, 39]], dtype=int64)
targets:
(4, 8)
array([[45, 52, 57, ..., 52, 1, 45],
       [0, 16, 47, ..., 57, 43, 43],
       [56, 53, 49, ..., 1, 40, 43],
       [46, 39, 60, ..., 61, 39, 56]], dtype=int64)
when input is array([47], dtype=int64) the target is 45
when input is array([47, 45], dtype=int64) the target is 52
when input is array([47, 45, 52], dtype=int64) the target is 57
when input is array([47, 45, 52, 57], dtype=int64) the target is 1
when input is array([47, 45, 52, 57, 1], dtype=int64) the target is 47
when input is array([47, 45, 52, 57, 1, 47], dtype=int64) the target is 52
when input is array([47, 45, 52, ..., 1, 47, 52], dtype=int64) the target is 1
when input is array([47, 45, 52, ..., 47, 52, 1], dtype=int64) the target is 45
when input is array([57], dtype=int64) the target is 0
when input is array([57, 0], dtype=int6

In [16]:
import mlx.nn as nn

class BigramLanguageModel(nn.Module):
    """Bigram model: an embedding table maps each token id directly to logits."""

    def __init__(self, vocab_size: int):
        super().__init__()
        # One row per token id, each row `vocab_size` wide.
        self.token_embedding = nn.Embedding(vocab_size, vocab_size)

    def __call__(self, idx: mx.array) -> mx.array:
        """Look up the embedding row (used as logits) for every id in `idx`."""
        return self.token_embedding(idx)

    def generate(self, idx: mx.array, max_new_tokens: int) -> mx.array:
        """Extend `idx` by sampling `max_new_tokens` tokens, one per step.

        Each step samples from the logits at the last position and appends the
        result along axis 1. The returned array still contains the original
        `idx`, so it is `max_new_tokens` columns longer than the input.
        """
        tokens = idx
        for _ in range(max_new_tokens):
            # Only the final position's logits are used for the next token.
            last_logits = self(tokens)[:, -1, :]
            sampled = mx.random.categorical(last_logits, num_samples=1, axis=-1)
            tokens = mx.concatenate([tokens, sampled], axis=1)
        return tokens


def loss_fn(model: nn.Module, x: mx.array, y: mx.array) -> mx.array:
    """Return the mean cross-entropy of the model's outputs for `x` against `y`."""
    logits = model(x)
    per_token_loss = nn.losses.cross_entropy(logits, y)
    return mx.mean(per_token_loss)

In [17]:
# Smoke test on a freshly constructed model: run one batch through it and
# compare the shapes of the logits and of the (mean-reduced) loss.
model = BigramLanguageModel(vocab_size)
logits = model(xb)
loss = loss_fn(model, xb, yb)
print(logits.shape, loss.shape)

(4, 8, 65) ()


In [18]:
# Start generation from a single token with id 0 and sample 100 more tokens.
# NOTE(review): `input` shadows the Python builtin of the same name.
input = mx.zeros((1, 1), dtype=mx.int64)
decode(model.generate(input, 100).reshape(-1).tolist())

"\n.P.e'wn,CZsvq gP-f$fvW3aypokkuSEz?Paw:YCj?M;x\npctpxMvdJMlTZrmCZhPRjYRJUfTgldWbqlwXxc CHIWuAFYEBlwJrb"

In [19]:
import mlx.optimizers as optim
# Build a function that returns both the loss and its gradients for `model`.
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
optimizer = optim.AdamW(learning_rate=1e-3)

# NOTE: each "epoch" here is one optimizer step on one freshly sampled batch,
# not a full pass over the training data.
num_epochs = 10000
for epoch in range(num_epochs):
    xb, yb = get_batch("train")

    loss, grads = loss_and_grad_fn(model, xb, yb)
    optimizer.update(model, grads)
    # Force evaluation of the updated parameters and optimizer state each step.
    mx.eval(model.parameters(), optimizer.state)

# Loss on one more randomly sampled training batch (a single-batch estimate).
print(loss_fn(model, *get_batch("train")).item())

2.4868011474609375


In [20]:
# Sample 100 tokens from the model again, starting from a single token with id 0.
# NOTE(review): `input` shadows the Python builtin of the same name.
input = mx.zeros((1, 1), dtype=mx.int64)
print(decode(model.generate(input, 100).reshape(-1).tolist()))


RI d tloul ilie om toour t
IEmbe d, hthithot whars shieiststh'stet schontoumy mced bliserved isty HE


## Mathematical Trick in self-attention

In [22]:
# Re-seed so the toy tensor below is the same on every run.
mx.random.seed(1337)
B, T, C = shape = [4, 8, 2]  # Batch, Time, Channels
# Random toy input with the (B, T, C) shape defined above.
x = mx.random.normal(shape)
x.shape

(4, 8, 2)

In [45]:
# We want xbow[b, t] = mean_{i <= t} x[b, i]
# Reference version: explicit loops that average each position with all
# earlier positions in the same batch row.
xbow = mx.zeros(shape)
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]  # -> [t + 1, C]
        xbow[b, t] = mx.mean(xprev, axis=0)
xbow

array([[[1.20038, -0.102716],
        [1.32946, -0.714395],
        [0.647708, -0.289021],
        ...,
        [0.098629, -0.212],
        [0.0613602, -0.206279],
        [-0.0196391, -0.192657]],
       [[-0.88286, 1.90818],
        [0.272445, 1.62297],
        [-0.025046, 1.02757],
        ...,
        [0.144306, 0.739237],
        [0.14616, 0.550131],
        [-0.16644, 0.507876]],
       [[0.0548329, -0.871395],
        [0.0152365, -0.370004],
        [-0.421962, -0.2805],
        ...,
        [-0.432742, 0.136636],
        [-0.419489, 0.203147],
        [-0.342733, 0.326026]],
       [[-0.977191, 0.433173],
        [-0.394541, 0.021305],
        [-0.432306, -0.325967],
        ...,
        [0.00372714, -0.381902],
        [0.0797167, -0.157575],
        [0.1951, -0.193353]]], dtype=float32)

In [43]:
# Lower-triangular ones with each row divided by its sum: multiplying by this
# matrix averages every position with all earlier positions.
causal_mask = mx.tril(mx.ones((T, T)))
wei = causal_mask / causal_mask.sum(axis=1, keepdims=True)
xbow2 = mx.matmul(wei, x)

In [44]:
xbow2

array([[[1.20038, -0.102716],
        [1.32946, -0.714395],
        [0.647708, -0.289021],
        ...,
        [0.098629, -0.212],
        [0.0613602, -0.206279],
        [-0.0196391, -0.192657]],
       [[-0.88286, 1.90818],
        [0.272445, 1.62297],
        [-0.025046, 1.02757],
        ...,
        [0.144306, 0.739237],
        [0.14616, 0.550131],
        [-0.16644, 0.507876]],
       [[0.0548329, -0.871395],
        [0.0152365, -0.370004],
        [-0.421962, -0.2805],
        ...,
        [-0.432742, 0.136636],
        [-0.419489, 0.203147],
        [-0.342733, 0.326026]],
       [[-0.977191, 0.433173],
        [-0.394541, 0.021305],
        [-0.432306, -0.325967],
        ...,
        [0.00372712, -0.381902],
        [0.0797167, -0.157575],
        [0.1951, -0.193353]]], dtype=float32)