# Experimental

Experiments before they're integrated into the main codebase.

In [30]:
import os
from typing import Literal

import mlx.core as mx
import requests

In [8]:
# Current working directory at the time this cell runs; the data cache lives beneath it.
PWD = os.getcwd()
DATA_DIR = os.path.join(PWD, "data")
# Tiny Shakespeare text file from the karpathy/char-rnn repository.
INPUT_DATA_URL = (
    "https://raw.githubusercontent.com/karpathy/char-rnn/master"
    "/data/tinyshakespeare/input.txt"
)

In [9]:
def fetch_input_data() -> str:
    """Return the input text, downloading it on first use.

    The text is cached at ``DATA_DIR/input.txt``; once that file exists,
    later calls read it instead of going to the network.

    Returns:
        The full contents of the cached input file.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server responds with an HTTP error status.
        OSError: If the cache file cannot be written or read.
    """
    input_file_path = os.path.join(DATA_DIR, "input.txt")
    if not os.path.exists(input_file_path):
        # Download *before* touching the filesystem: a failed request must
        # not leave an empty file behind that later looks like a valid cache.
        response = requests.get(INPUT_DATA_URL, timeout=30)
        response.raise_for_status()  # never cache an HTTP error page
        # DATA_DIR may not exist yet; open() would fail without it.
        os.makedirs(DATA_DIR, exist_ok=True)
        with open(input_file_path, "w", encoding="utf-8") as f:
            f.write(response.text)

    with open(input_file_path, "r", encoding="utf-8") as f:
        return f.read()

In [17]:
# Load the full input text (downloaded on first run, read from the cached file afterwards).
input_text = fetch_input_data()

In [18]:
# Size of the input text in characters.
print(len(input_text))

1115394


In [19]:
# The vocabulary is every distinct character in the text, in sorted order.
chars = sorted(set(input_text))
vocab_size = len(chars)
# Show the vocabulary as one string, then its size on the next line.
print("".join(chars), vocab_size, sep="\n")


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [13]:
# Character <-> integer-id lookup tables; an id is the character's index in `chars`.
string_to_int = {c: i for i, c in enumerate(chars)}
int_to_string = dict(enumerate(chars))


def encode(s: str) -> list[int]:
    """Convert a string into a list of integer ids, one per character.

    Raises:
        KeyError: If `s` contains a character that is not in `string_to_int`.
    """
    return [string_to_int[c] for c in s]


def decode(l: list[int]) -> str:
    """Convert a sequence of integer ids back into the string they encode.

    Raises:
        KeyError: If `l` contains an id that is not in `int_to_string`.
    """
    return "".join(int_to_string[i] for i in l)

In [14]:
# Round-trip check: the ids for "hello", then the text recovered from those ids.
print(encode("hello"), decode(encode("hello")), sep="\n")

[46, 43, 50, 50, 53]
hello


In [20]:
# Encode the whole input text as a single 1-D array of int64 ids.
data = mx.array(encode(input_text), dtype=mx.int64)
print(data.shape, data.dtype)
# Peek at the first 1000 ids (the printed repr elides the middle).
print(data[:1000])

(1115394,) mlx.core.int64
array([18, 47, 56, ..., 8, 0, 0], dtype=int64)


In [21]:
# Train/validation split: the first 90% of the ids are used for training,
# the remaining 10% are held out for validation.
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]

In [25]:
# Context length used for the walkthrough below.
block_size = 8
# Display the first block_size + 1 ids of the training data.
train_data[:block_size + 1]

array([18, 47, 56, ..., 15, 47, 58], dtype=int64)

In [29]:
# One window of block_size ids, and the same window advanced by one position.
x, y = train_data[:block_size], train_data[1:block_size + 1]
# Each prefix of x is a context; its target is the id that follows that prefix.
for t in range(block_size):
    context, target = x[:t + 1], y[t].item()
    print(f"when input is {context} the target is {target}")

when input is array([18], dtype=int64) the target is 47
when input is array([18, 47], dtype=int64) the target is 56
when input is array([18, 47, 56], dtype=int64) the target is 57
when input is array([18, 47, 56, 57], dtype=int64) the target is 58
when input is array([18, 47, 56, 57, 58], dtype=int64) the target is 1
when input is array([18, 47, 56, 57, 58, 1], dtype=int64) the target is 15
when input is array([18, 47, 56, ..., 58, 1, 15], dtype=int64) the target is 47
when input is array([18, 47, 56, ..., 1, 15, 47], dtype=int64) the target is 58


In [63]:
# Seed MLX's random number generator with a fixed value.
mx.random.seed(1337)
# Number of independent sequences to train on in parallel.
batch_size = 4
# Maximum context length for predictions.
block_size = 8

def get_batch(split: Literal["train", "val"]) -> tuple[mx.array, mx.array]:
    """Sample a random batch of input windows and their next-token targets.

    Args:
        split: Which dataset to draw from: "train" uses `train_data`,
            "val" uses `val_data`.

    Returns:
        A pair `(x, y)`, each a stack of `batch_size` windows of `block_size`
        ids. `y` is `x` advanced by one position, so `y[b, t]` is the id
        that follows `x[b, t]` in the source data.

    Raises:
        ValueError: If `split` is neither "train" nor "val".
    """
    if split not in ("train", "val"):
        # Fail loudly: without this check any other value (e.g. a typo such
        # as "test") would silently fall through to the validation data.
        raise ValueError(f'split must be "train" or "val", got {split!r}')
    data = train_data if split == "train" else val_data
    # The upper bound leaves room for the target window, which ends one
    # position later than the input window.
    ix = mx.random.randint(0, len(data) - block_size, [batch_size])
    # Extract the start offsets as Python ints once, rather than calling
    # .item() again for every slice bound.
    starts = [i.item() for i in ix]
    # gets `batch_size` blocks stacked
    x = mx.stack([data[s:s + block_size] for s in starts])
    # the same blocks shifted by one, so all targets are built in one go
    y = mx.stack([data[s + 1:s + block_size + 1] for s in starts])
    return x, y

In [66]:
xb, yb = get_batch("train")

# Print each batch as: label, shape, contents.
print("inputs:", xb.shape, xb, sep="\n")
print("targets:", yb.shape, yb, sep="\n")

# Each row of yb is the matching row of xb shifted forward by one position.

inputs:
(4, 8)
array([[53, 1, 51, ..., 43, 50, 44],
       [32, 53, 1, ..., 39, 58, 1],
       [53, 59, 1, ..., 50, 50, 1],
       [23, 17, 10, ..., 39, 52, 1]], dtype=int64)
targets:
(4, 8)
array([[1, 51, 63, ..., 50, 44, 0],
       [53, 1, 61, ..., 58, 1, 61],
       [59, 1, 58, ..., 50, 1, 51],
       [17, 10, 0, ..., 52, 1, 52]], dtype=int64)
