# Following along https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing but with JAX/Flax

In [123]:
!mkdir data
!curl https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -o data/tinyshakespeare

  pid, fd = os.forkpty()


mkdir: data: File exists
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1089k  100 1089k    0     0  2433k      0 --:--:-- --:--:-- --:--:-- 2436k


In [124]:
with open('data/tinyshakespeare') as f:
    text = f.read()

print('Corpus size: ' + str(len(text)))
print(text[:1000])

vocab = list(set(text))
vocab_size = len(vocab)
print('Vocabulary size: ' + str(len(vocab)))
print(''.join(sorted(vocab)))

Corpus size: 1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in 

In [125]:
itos = {i:s for i,s in enumerate(vocab)}
stoi = {s:i for i,s in enumerate(vocab)}

encode = lambda x: [stoi[s] for s in x]
decode = lambda x: ''.join([itos[i] for i in x])

print(decode(encode('hello world')))

hello world


In [126]:
import jax.numpy as jnp

data = jnp.array(encode(text), dtype=jnp.int32)
print(data.dtype)
print(data.shape)
print(data[:100])

int32
(1115394,)
[47  5 64 18 57 45 13  5 57  5 12 53 40 61 54 55 53 38  4 64 53 45 27 53
 45 50 64  4 62 53 53 14 45 42 40 23 45 38 43 64 57  6 53 64  3 45  6 53
 42 64 45 41 53 45 18 50 53 42  9  0 54 54 29 24 24 61 54 49 50 53 42  9
  3 45 18 50 53 42  9  0 54 54 47  5 64 18 57 45 13  5 57  5 12 53 40 61
 54 48  4 43]


In [127]:
train_data = data[: int(.9 * len(data))]
val_data = data[int(.9 * len(data)):]

In [128]:
import jax

batch_size = 4
block_size = 8

dynamic_slice_vmap = jax.vmap(jax.lax.dynamic_slice, in_axes=(None, 0, None))


@jax.jit
def get_batch(data, key):
    ix = jax.random.randint(key, shape=(batch_size, 1), minval=0, maxval=len(data) - block_size)
    x = dynamic_slice_vmap(data, ix, (block_size,))
    y = dynamic_slice_vmap(data, ix + 1, (block_size,))
    return x, y


key = jax.random.key(1337)
print(get_batch(train_data, key))
xb, yb = get_batch(train_data, key)

(Array([[45, 41, 42, 40, 57, 24, 53, 54],
       [16, 36, 56, 61, 54, 56,  3, 45],
       [ 5, 41,  0, 54, 55, 43, 57, 45],
       [45, 18, 43, 62,  6, 45, 24, 53]], dtype=int32), Array([[41, 42, 40, 57, 24, 53, 54,  4],
       [36, 56, 61, 54, 56,  3, 45, 24],
       [41,  0, 54, 55, 43, 57, 45,  6],
       [18, 43, 62,  6, 45, 24, 53, 40]], dtype=int32))


In [129]:
from flax import nnx
import optax

class BigramLanguageModel(nnx.Module):
    def __init__(self, vocab_size, rngs: nnx.Rngs):
        self.rngs = rngs
        self.token_embedding_table = nnx.Embed(num_embeddings=vocab_size, features=vocab_size, rngs=rngs)

    def __call__(self, x):
        logits = self.token_embedding_table(x)
        return logits
    
    def generate(self, x, length): # x has the shape (batch_size, block_size)
        for i in range(length):
            logits = self(x)
            next_token = jax.random.categorical(self.rngs.next(), logits[:, -1])
            x = jnp.concatenate([x, next_token[:, None]], axis=1)
        return x
    

key = jax.random.key(1337)
rngs = nnx.Rngs(key)
model = BigramLanguageModel(vocab_size, rngs)

In [130]:
from flax import nnx
import optax

class Lookback(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.rngs = rngs
    
    def __call__(self, x):
        B, T, C = x.shape
        tril = jnp.tril(jnp.ones((T, T)))
        attn = jnp.zeros((T, T))
        attn = jnp.where(tril == 0, float('-inf'), attn)
        attn = jax.nn.softmax(attn)
        return attn @ x

class NgramLanguageModel(nnx.Module):
    def __init__(self, vocab_size, n_embed, rngs: nnx.Rngs):
        self.rngs = rngs
        self.token_embedding_table = nnx.Embed(num_embeddings=vocab_size, features=n_embed, rngs=rngs)
        self.position_embedding_table = nnx.Embed(num_embeddings=block_size, features=n_embed, rngs=rngs)
        self.lookback = Lookback(rngs)
        self.lm_head = nnx.Linear(n_embed, vocab_size, rngs=rngs)

    def __call__(self, x):
        B, T = x.shape
        x = self.token_embedding_table(x) + self.position_embedding_table(jnp.arange(T))
        x = self.lookback(x)
        logits = self.lm_head(x)
        return logits
    
    def generate(self, x, length):
        for i in range(length):
            logits = self(x[:, -block_size:])
            next_token = jax.random.categorical(self.rngs.next(), logits[:, -1])
            x = jnp.concatenate([x, next_token[:, None]], axis=1)
        return x
    

key = jax.random.key(1337)
rngs = nnx.Rngs(key)
model = NgramLanguageModel(vocab_size, 32, rngs)

In [131]:
from flax import nnx
import optax

class SelfAttentionHead(nnx.Module):
    def __init__(self, n_embed, head_dim, rngs: nnx.Rngs):
        self.rngs = rngs
        self.query = nnx.Linear(n_embed, head_dim, rngs=rngs, use_bias=False)
        self.key = nnx.Linear(n_embed, head_dim, rngs=rngs, use_bias=False)
        self.value = nnx.Linear(n_embed, head_dim, rngs=rngs, use_bias=False)
        self.tril = jnp.tril(jnp.ones((block_size, block_size)))
    
    def __call__(self, x):
        B, T, C = x.shape
        q = self.query(x)
        k = self.key(x)

        attn = jnp.einsum('btd,bTd->btT', q, k) / jnp.sqrt(16)

        attn = jnp.where(self.tril[:T, :T] == 0, float('-inf'), attn)
        attn = jax.nn.softmax(attn)

        v = self.value(x)
        return attn @ v
    
class MultiHeadAttention(nnx.Module):
    def __init__(self, n_embed, n_head, head_dim, rngs: nnx.Rngs):
        self.rngs = rngs
        self.heads = [SelfAttentionHead(n_embed, head_dim, rngs) for _ in range(n_head)]
        self.proj = nnx.Linear(n_embed, n_embed, rngs=rngs)

    def __call__(self, x):
        x = jnp.concatenate([head(x) for head in self.heads], axis=-1)
        return self.proj(x)
    
class FeedForward(nnx.Module):
    def __init__(self, n_embed, rngs: nnx.Rngs):
        self.rngs = rngs
        self.fc1 = nnx.Linear(n_embed, 4 * n_embed, rngs=rngs)
        self.fc2 = nnx.Linear(4 * n_embed, n_embed, rngs=rngs)
    
    def __call__(self, x):
        return self.fc2(jax.nn.relu(self.fc1(x)))

class Block(nnx.Module):
    def __init__(self, n_embed, n_head, rngs: nnx.Rngs):
        self.rngs = rngs
        self.sa_heads = MultiHeadAttention(n_embed, n_head, n_embed // n_head, rngs=rngs)
        self.ffwd = FeedForward(n_embed, rngs=rngs)
        self.ln1 = nnx.LayerNorm(n_embed, rngs=rngs)
        self.ln2 = nnx.LayerNorm(n_embed, rngs=rngs)

    def __call__(self, x):
        x = x + self.sa_heads(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class GPT(nnx.Module):
    def __init__(self, vocab_size, n_embed, n_head, n_blocks, rngs: nnx.Rngs):
        self.rngs = rngs
        self.token_embedding_table = nnx.Embed(num_embeddings=vocab_size, features=n_embed, rngs=rngs)
        self.position_embedding_table = nnx.Embed(num_embeddings=block_size, features=n_embed, rngs=rngs)
        self.blocks = nnx.Sequential(*[Block(n_embed, n_head, rngs) for _ in range(n_blocks)])
        self.lm_head = nnx.Linear(n_embed, vocab_size, rngs=rngs)

    def __call__(self, x):
        B, T = x.shape
        x = self.token_embedding_table(x) + self.position_embedding_table(jnp.arange(T))
        x = self.blocks(x)
        logits = self.lm_head(x)
        return logits
    
    def generate(self, x, length):
        for i in range(length):
            logits = self(x[:, -block_size:])
            next_token = jax.random.categorical(self.rngs.next(), logits[:, -1])
            x = jnp.concatenate([x, next_token[:, None]], axis=1)
        return x
    

key = jax.random.key(1337)
rngs = nnx.Rngs(key)
model = GPT(vocab_size, 32, 4, 4, rngs)

In [132]:
def loss(model, x, targets):
        logits = model(x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()

xb, yb = get_batch(train_data, key)
print(loss(model, xb, yb))

7.938052


In [133]:
batch_size = 32

import tqdm

@nnx.jit
def train_step(model, optimizer, xb, yb):
    grads = (nnx.grad(loss))(model, xb, yb)
    optimizer.update(grads)

def train(key, model):
    optimizer = nnx.Optimizer(model, optax.adam(1e-3))
    for i in tqdm.trange(10000):
        key, subkey = jax.random.split(key)
        xb, yb = get_batch(train_data, subkey)
        train_step(model, optimizer, xb, yb)
train(key, model)
print(loss(model, xb, yb))

val_xb, val_yb = get_batch(val_data, key)
print(loss(model, val_xb, val_yb))

100%|██████████| 10000/10000 [09:39<00:00, 17.27it/s]


2.1941688
2.198675


In [134]:
print([decode(row.tolist()) for row in model.generate(jnp.zeros((1, 1), dtype=jnp.int32), 500)][0])

.

Come anly, thenee to tpastrusions for, low rentre, thrubly by himge farsfilf uny you, nooul and my face
If shown him thy, bee ace thish han himfurs maqieso;
Where wardly,
Your thy lears lale my arde artcigh moow
thisstalk'd whatys, wheld om-by brood of on it frumptlath caugh thow ims courtrn the stek but armxtell's Atcinls?

SHUSTEN Rhe youd could
The methy fand-do! our athed heach; all jowour shall, harthen:
If our infircen Mencir,
Whave the shatilted
Ming; brelr have ongterwith bewsore thred
