# GPT-Dev

Experiments before they're integrated into the main codebase.

In [1]:
import os
import requests
from typing import Literal

In [2]:
PWD = os.getcwd()
DATA_DIR = os.path.join(PWD, "data")
INPUT_DATA_URL = ("https://raw.githubusercontent.com/karpathy/char-rnn/"
                  "master/data/tinyshakespeare/input.txt")

In [3]:
def fetch_input_data() -> str:
    """Gets the input data, caching it for easy access."""
    input_file_path = os.path.join(DATA_DIR, "input.txt")
    if not os.path.exists(input_file_path):
        with open(input_file_path, "w", encoding="utf-8") as f:
            f.write(requests.get(INPUT_DATA_URL).text)
    
    with open(input_file_path, "r", encoding="utf-8") as f:
        return f.read()

In [4]:
input_text = fetch_input_data()

In [5]:
print(len(input_text))

1115394


In [6]:
chars = sorted(list(set(input_text)))
vocab_size = len(chars)
print("".join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [7]:
string_to_int = {c: i for i, c in enumerate(chars)}
int_to_string = {i: c for i, c in enumerate(chars)}

encode = lambda s: [string_to_int[c] for c in s]
decode = lambda l: "".join([int_to_string[i] for i in l])

In [8]:
print(encode("hello"))
print(decode(encode("hello")))

[46, 43, 50, 50, 53]
hello


In [9]:
import mlx.core as mx
data = mx.array(encode(input_text), dtype=mx.int64)
print(data.shape, data.dtype)
print(data[:1000])

(1115394,) mlx.core.int64
array([18, 47, 56, ..., 8, 0, 0], dtype=int64)


In [10]:
# Split the data
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [11]:
block_size = 8
train_data[:block_size + 1]

array([18, 47, 56, ..., 15, 47, 58], dtype=int64)

In [12]:
x = train_data[:block_size]
y = train_data[1:block_size + 1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t].item()
    print(f"when input is {context} the target is {target}")

when input is array([18], dtype=int64) the target is 47
when input is array([18, 47], dtype=int64) the target is 56
when input is array([18, 47, 56], dtype=int64) the target is 57
when input is array([18, 47, 56, 57], dtype=int64) the target is 58
when input is array([18, 47, 56, 57, 58], dtype=int64) the target is 1
when input is array([18, 47, 56, 57, 58, 1], dtype=int64) the target is 15
when input is array([18, 47, 56, ..., 58, 1, 15], dtype=int64) the target is 47
when input is array([18, 47, 56, ..., 1, 15, 47], dtype=int64) the target is 58


In [13]:
mx.random.seed(1337)
batch_size = 4 # number of independent sequences to train on in parallel
block_size = 8 # maximum context length for predictions

def get_batch(split: Literal["train", "val"]) -> tuple[mx.array, mx.array]:
    data = train_data if split == "train" else val_data
    ix = mx.random.randint(0, len(data) - block_size, [batch_size])
    # gets `batch_size` blocks stacked
    x = mx.stack([data[i.item():i.item() + block_size] for i in ix])
    # it's shifted to compute the target vectorized
    y = mx.stack([data[i.item() + 1:i.item() + block_size + 1] for i in ix])
    return x, y

In [14]:
xb, yb = get_batch("train")

print("inputs:")
print(xb.shape)
print(xb)
print("targets:")
print(yb.shape)
print(yb)

# We can observe they match but yb is shifted by one
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t].item()
        print(f"when input is {context} the target is {target}")

inputs:
(4, 8)
array([[47, 45, 52, ..., 47, 52, 1],
       [57, 0, 16, ..., 1, 57, 43],
       [40, 56, 53, ..., 11, 1, 40],
       [1, 46, 39, ..., 1, 61, 39]], dtype=int64)
targets:
(4, 8)
array([[45, 52, 57, ..., 52, 1, 45],
       [0, 16, 47, ..., 57, 43, 43],
       [56, 53, 49, ..., 1, 40, 43],
       [46, 39, 60, ..., 61, 39, 56]], dtype=int64)
when input is array([47], dtype=int64) the target is 45
when input is array([47, 45], dtype=int64) the target is 52
when input is array([47, 45, 52], dtype=int64) the target is 57
when input is array([47, 45, 52, 57], dtype=int64) the target is 1
when input is array([47, 45, 52, 57, 1], dtype=int64) the target is 47
when input is array([47, 45, 52, 57, 1, 47], dtype=int64) the target is 52
when input is array([47, 45, 52, ..., 1, 47, 52], dtype=int64) the target is 1
when input is array([47, 45, 52, ..., 47, 52, 1], dtype=int64) the target is 45
when input is array([57], dtype=int64) the target is 0
when input is array([57, 0], dtype=int6

In [15]:
import mlx.nn as nn

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size: int):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, vocab_size)
    
    def __call__(self, idx: mx.array) -> mx.array:
        return self.token_embedding(idx)
    
    def generate(self, idx: mx.array, max_new_tokens: int) -> mx.array:
        for _ in range(max_new_tokens):
            logits = self(idx)
            logits = logits[:, -1, :]
            idx_next = mx.random.categorical(logits, num_samples=1, axis=-1)
            idx = mx.concatenate([idx, idx_next], axis=1)
        # this is actually going to return 101 tokens since the input counts
        return idx


def loss_fn(model: nn.Module, x: mx.array, y: mx.array) -> mx.array:
    return mx.mean(nn.losses.cross_entropy(model(x), y))

In [16]:
model = BigramLanguageModel(vocab_size)
logits = model(xb)
loss = loss_fn(model, xb, yb)
print(logits.shape, loss.shape)

(4, 8, 65) ()


In [17]:
input = mx.zeros((1, 1), dtype=mx.int64)
decode(model.generate(input, 100).reshape(-1).tolist())

"\n.P.e'wn,CZsvq gP-f$fvW3aypokkuSEz?Paw:YCj?M;x\npctpxMvdJMlTZrmCZhPRjYRJUfTgldWbqlwXxc CHIWuAFYEBlwJrb"

In [18]:
import mlx.optimizers as optim
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
optimizer = optim.AdamW(learning_rate=1e-3)

num_epochs = 10000
for epoch in range(num_epochs):
    xb, yb = get_batch("train")

    loss, grads = loss_and_grad_fn(model, xb, yb)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)

print(loss_fn(model, *get_batch("train")).item())

2.4868011474609375


In [19]:
input = mx.zeros((1, 1), dtype=mx.int64)
print(decode(model.generate(input, 100).reshape(-1).tolist()))


RI d tloul ilie om toour t
IEmbe d, hthithot whars shieiststh'stet schontoumy mced bliserved isty HE


## Mathematical Trick in self-attention

In [20]:
mx.random.seed(1337)
B, T, C = shape = [4, 8, 2]  # Batch, Time, Channels
x = mx.random.normal(shape)
x.shape

(4, 8, 2)

In [21]:
# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = mx.zeros(shape)
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]  # -> [t, C]
        xbow[b, t] = mx.mean(xprev, axis=0)
xbow

array([[[1.20038, -0.102716],
        [1.32946, -0.714395],
        [0.647708, -0.289021],
        ...,
        [0.098629, -0.212],
        [0.0613602, -0.206279],
        [-0.0196391, -0.192657]],
       [[-0.88286, 1.90818],
        [0.272445, 1.62297],
        [-0.025046, 1.02757],
        ...,
        [0.144306, 0.739237],
        [0.14616, 0.550131],
        [-0.16644, 0.507876]],
       [[0.0548329, -0.871395],
        [0.0152365, -0.370004],
        [-0.421962, -0.2805],
        ...,
        [-0.432742, 0.136636],
        [-0.419489, 0.203147],
        [-0.342733, 0.326026]],
       [[-0.977191, 0.433173],
        [-0.394541, 0.021305],
        [-0.432306, -0.325967],
        ...,
        [0.00372714, -0.381902],
        [0.0797167, -0.157575],
        [0.1951, -0.193353]]], dtype=float32)

In [22]:
wei = mx.tril(mx.ones((T, T)))
wei = wei / wei.sum(1, keepdims=True)  # normalize
xbow2 = wei @ x
mx.allclose(xbow, xbow2)

array(True, dtype=bool)

In [23]:
tril = mx.tril(mx.ones((T, T)))
wei = mx.zeros((T, T))
wei = mx.where(tril == 0, float("-inf"), wei) # this replaced masked_fill
wei = mx.softmax(wei, axis=-1)
xbow3 = wei @ x
mx.allclose(xbow2, xbow3)

array(True, dtype=bool)

In [24]:
# version 4: self-attention
mx.random.seed(1337)
B, T, C = shape = [4, 8, 32]  # Batch, Time, Channels
x = mx.random.normal(shape)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)  # [B, T, head_size]
q = query(x)  # [B, T, head_size]
wei = q @ mx.transpose(k, axes=[0, -1, -2]) # [B, T, head_size] @ [B, head_size, T] -> [B, T, T]

tril = mx.tril(mx.ones((T, T)))
wei = mx.where(tril == 0, float("-inf"), wei)
wei = mx.softmax(wei, axis=-1)

v = value(x)
out = wei @ v

out.shape

(4, 8, 16)

In [25]:
wei[0]

array([[1, 0, 0, ..., 0, 0, 0],
       [0.856269, 0.143731, 0, ..., 0, 0, 0],
       [0.431043, 0.00779574, 0.561161, ..., 0, 0, 0],
       ...,
       [0.219786, 0.260875, 0.325269, ..., 0.00784806, 0, 0],
       [0.390585, 0.0546818, 0.0550628, ..., 0.0648744, 0.143865, 0],
       [0.00486556, 0.0906292, 0.136023, ..., 0.183027, 0.0706174, 0.0119941]], dtype=float32)

---

**To understand the code above**, I find it useful to consider what various
parts of the code are doing.  

**Channels**  
_`C` represents the number of channels._  

Channels serve as a way to encode all sorts of information.
It does not matter what they encode, likely does not have a human
interpretation.  
These representations are one aspects of what the network is learning.
For example, a channel may encode "being a consonant".  

**Keys and Queries**  
Assuming we have a channel that represents "being a consonant".  

_A query represents what we're looking for._  
For example, a query may represent looking for a consonant at any position,
which would be encoded as a uniform distribution over all tokens in the
consonant channel in the query.  
A query could also be looking for consonants at specific positions/regions,
which woudl be encoded as higher values at those positions.  

_The key follows a similar logic, but represents what each token is,_
_in the channel._  
A consonant would have a higher value at its position, in the key vector,
in the consonant channel.  

The product between keys and queries represents the affinity between what
we're looking for and what we have.  
A _consonant query_ with a high value at the first position, will have a strong
affinity if the _consonant key_ also has a higher value at the first position.  

**Values**
_Values represent what gets communicated to the next layer._  

Values are a simple linear transformation of the input.  
Linear transformations allow the network to transform the inputs into a more
useful representation (from a network perspective).  
For example, maybe not all tokens are equally useful to the output, learning
what to consider or not is helpful.  

It can also be useful to transform the inputs into different dimensions.  
This linear transformation is what transforms the input into the
representation at the next level.  

**Outputs**
_Combines `values`, `keys` and `queries` to represent what is relevant to the_
_next layer._  

`keys` and `queries` are combined to create an affinity of what is important
for the current token (`> wei`). `values` represent what to communicate to the
next layer, but has to be combined with `wei` to "select" only relevant
information.  

---

Notes:
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across batch dimension is of course processed completely independently and never "talk" to each other
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additional divides `wei` by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. Illustration below