# Experimental

Experiments before they're integrated into the main codebase.

In [108]:
import os
import requests
from typing import Literal, Optional

In [8]:
PWD = os.getcwd()
DATA_DIR = os.path.join(PWD, "data")
INPUT_DATA_URL = ("https://raw.githubusercontent.com/karpathy/char-rnn/"
                  "master/data/tinyshakespeare/input.txt")

In [9]:
def fetch_input_data() -> str:
    """Gets the input data, caching it for easy access."""
    input_file_path = os.path.join(DATA_DIR, "input.txt")
    if not os.path.exists(input_file_path):
        with open(input_file_path, "w", encoding="utf-8") as f:
            f.write(requests.get(INPUT_DATA_URL).text)
    
    with open(input_file_path, "r", encoding="utf-8") as f:
        return f.read()

In [17]:
input_text = fetch_input_data()

In [18]:
print(len(input_text))

1115394


In [19]:
chars = sorted(list(set(input_text)))
vocab_size = len(chars)
print("".join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [13]:
string_to_int = {c: i for i, c in enumerate(chars)}
int_to_string = {i: c for i, c in enumerate(chars)}

encode = lambda s: [string_to_int[c] for c in s]
decode = lambda l: "".join([int_to_string[i] for i in l])

In [14]:
print(encode("hello"))
print(decode(encode("hello")))

[46, 43, 50, 50, 53]
hello


In [20]:
import mlx.core as mx
data = mx.array(encode(input_text), dtype=mx.int64)
print(data.shape, data.dtype)
print(data[:1000])

(1115394,) mlx.core.int64
array([18, 47, 56, ..., 8, 0, 0], dtype=int64)


In [21]:
# Split the data
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [25]:
block_size = 8
train_data[:block_size + 1]

array([18, 47, 56, ..., 15, 47, 58], dtype=int64)

In [29]:
x = train_data[:block_size]
y = train_data[1:block_size + 1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t].item()
    print(f"when input is {context} the target is {target}")

when input is array([18], dtype=int64) the target is 47
when input is array([18, 47], dtype=int64) the target is 56
when input is array([18, 47, 56], dtype=int64) the target is 57
when input is array([18, 47, 56, 57], dtype=int64) the target is 58
when input is array([18, 47, 56, 57, 58], dtype=int64) the target is 1
when input is array([18, 47, 56, 57, 58, 1], dtype=int64) the target is 15
when input is array([18, 47, 56, ..., 58, 1, 15], dtype=int64) the target is 47
when input is array([18, 47, 56, ..., 1, 15, 47], dtype=int64) the target is 58


In [63]:
mx.random.seed(1337)
batch_size = 4 # number of independent sequences to train on in parallel
block_size = 8 # maximum context length for predictions

def get_batch(split: Literal["train", "val"]) -> tuple[mx.array, mx.array]:
    data = train_data if split == "train" else val_data
    ix = mx.random.randint(0, len(data) - block_size, [batch_size])
    # gets `batch_size` blocks stacked
    x = mx.stack([data[i.item():i.item() + block_size] for i in ix])
    # it's shifted to compute the target vectorized
    y = mx.stack([data[i.item() + 1:i.item() + block_size + 1] for i in ix])
    return x, y

In [67]:
xb, yb = get_batch("train")

print("inputs:")
print(xb.shape)
print(xb)
print("targets:")
print(yb.shape)
print(yb)

# We can observe they match but yb is shifted by one
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t].item()
        print(f"when input is {context} the target is {target}")

inputs:
(4, 8)
array([[56, 53, 51, ..., 47, 57, 1],
       [43, 50, 53, ..., 63, 1, 57],
       [6, 1, 58, ..., 58, 1, 61],
       [0, 5, 32, ..., 1, 39, 1]], dtype=int64)
targets:
(4, 8)
array([[53, 51, 1, ..., 57, 1, 48],
       [50, 53, 1, ..., 1, 57, 54],
       [1, 58, 46, ..., 1, 61, 43],
       [5, 32, 47, ..., 39, 1, 60]], dtype=int64)
when input is array([56], dtype=int64) the target is 53
when input is array([56, 53], dtype=int64) the target is 51
when input is array([56, 53, 51], dtype=int64) the target is 1
when input is array([56, 53, 51, 1], dtype=int64) the target is 46
when input is array([56, 53, 51, 1, 46], dtype=int64) the target is 47
when input is array([56, 53, 51, 1, 46, 47], dtype=int64) the target is 57
when input is array([56, 53, 51, ..., 46, 47, 57], dtype=int64) the target is 1
when input is array([56, 53, 51, ..., 47, 57, 1], dtype=int64) the target is 48
when input is array([43], dtype=int64) the target is 50
when input is array([43, 50], dtype=int64) the

In [140]:
import mlx.nn as nn

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size: int):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, vocab_size)
    
    def __call__(self, idx: mx.array, targets: Optional[mx.array] = None
                 ) -> tuple[mx.array, Optional[mx.array]]:
        logits = self.token_embedding(idx)
        if targets is None:
            return logits, None

        loss = nn.losses.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, idx: mx.array, max_new_tokens: int) -> mx.array:
        for _ in range(max_new_tokens):
            logits, _ = self(idx)
            idx_next = mx.random.categorical(logits, axis=-1)
            idx = mx.concatenate([idx, idx_next], axis=1)
        return idx

In [141]:
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
assert loss is not None, "loss should be computed when targets are provided"
print(logits.shape, loss.shape)

(4, 8, 65) (4, 8)


In [142]:
decode(m.generate(mx.ones((1, 1), dtype=mx.int64), 2).reshape(-1).tolist())

' HZZ'