In [23]:
import sys
import jax
import jax.numpy as jnp
from tqdm import tqdm

sys.path.append('../')

from dataset import TinyShakespeare
from jax_utils import print_param_names
from layers import cross_entropy_loss

In [24]:
jax.devices()

[MetalDevice(id=0, process_index=0)]

In [25]:
seed = 1212
rnd_key = jax.random.PRNGKey(seed)

### Test the model

Define the autoregressive loss function

## Load Dataset

In [26]:
dataset = TinyShakespeare(rnd_key, batch_size=16, seq_len=32)

Loading Tiny Shakespeare dataset...
Loaded Tiny Shakespeare dataset


In [27]:
dataset.data.shape

(34848, 32)

## Create Model

Configs

In [28]:
d_model = 128
num_heads = 8
num_layers = 3
d_ff = 512
batch_size = 128
n_vocab = dataset.n_tokens
seq_len = 32
n_epochs = 32

In [29]:
from layers import create_autoregressive_transformer

transformer_model, params = create_autoregressive_transformer(rnd_key, num_layers, num_heads, 
                                                              d_model, d_ff, n_vocab, 
                                                              fast=False, lambda_pe= 1 / d_model ** 0.5)
def transformer_loss(params, x):
    output = transformer_model(params, x)
    x_shape = x.shape
    # print("x_shape", x_shape)
    # To make sure the output has the same shape as x
    x = x.reshape(*x_shape, -1)
    # Vmap to apply the loss function along the sequence
    return jax.vmap(cross_entropy_loss, in_axes=[0, 0])(output[:-1], x[1:]).mean()

In [30]:
model = jax.jit(jax.vmap(transformer_model, in_axes=(None, 0), out_axes=0))
# Vmap over the batch axis
batched_loss = jax.jit(jax.vmap(transformer_loss, in_axes=(None, 0), out_axes=0))

def get_loss(params, seq):
    return batched_loss(params, seq).mean()

grad_loss_fn = jax.value_and_grad(get_loss, argnums=0)

Let's create the model

In [31]:
x = dataset.data[0:16]
x_shape = x.shape
output = model(params, x)
loss, grad = grad_loss_fn(params, x)
loss

Array(8.417043, dtype=float32)

## Test functions

In [10]:
# Test attention
from layers import attention
d_k = d_model // num_heads
rng, q_key, k_key, v_key = jax.random.split(rnd_key, 4)
q = jax.random.normal(q_key, (batch_size, seq_len, num_heads, d_k))
k = jax.random.normal(k_key, (batch_size, seq_len, num_heads, d_k))
v = jax.random.normal(v_key, (batch_size, seq_len, num_heads, d_k))

# Move the heads to the batch dimension
q, k, v = map(lambda x: x.transpose((0, 2, 1, 3)), (q, k, v))
# Repeat along the batch dimension
output = jax.vmap(attention, in_axes=(0, 0, 0, None), out_axes=(0))(q, k, v, None)
assert output.shape == (batch_size, num_heads, seq_len, d_k)

In [11]:
dataset.n_tokens

65

## Training loop

In [32]:
from typing import Callable

def sample(model: Callable, params: dict, seq: jnp.ndarray, length: int = 20):
    """
    ### Sample

    The starting sequence is given by `seq` and we greedily sample `length` tokens
    """
    for i in range(length):
        # Sample the highest probability token
        idx = jnp.argmax(model(params, seq)[-1])
        # Add it to the sequence
        seq = jnp.concatenate((seq, idx[None]))
        # print(seq)

    # Return the sampled sequence
    return seq

def evaluate_model(model, params):
    prompt = [dataset.stoi[c] for c in 'It is']
    sampled = sample(model, params, jnp.array(prompt))[len(prompt):]
    sampled = ''.join([dataset.itos[i] for i in sampled])
    print(sampled)

In [33]:
import optax
from optim import Adam

learning_rate = 0.001
optimizer = optax.adamw(learning_rate)
opt_state = optimizer.init(params)
# Create optimizer
# optimizer = Adam(params)

for epoch in range(n_epochs):
    losses = []
    for i, batch in tqdm(enumerate(dataset)):
        loss, grads = grad_loss_fn(params, batch)
        losses.append(loss)
        # Update parameters
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        # params = optimizer.step(params, grads)
        if i==0 or (i+1)%1000==0:
            print(f"{i+1}: Loss: {loss:.2f}")
    print(f'Epoch {epoch} loss: {jnp.mean(loss)}')
    evaluate_model(transformer_model, params)

1it [00:01,  1.31s/it]

1: Loss: 8.82


1000it [08:09,  2.04it/s]

1000: Loss: 3.14


2000it [16:22,  2.08it/s]

2000: Loss: 3.03


2178it [17:48,  2.04it/s]


Epoch 0 loss: 3.096646308898926
 t t t t t t t t t t


1it [00:00,  1.10it/s]

1: Loss: 3.11


1000it [07:48,  2.15it/s]

1000: Loss: 2.73


2000it [15:33,  2.08it/s]

2000: Loss: 2.75


2178it [16:56,  2.14it/s]


Epoch 1 loss: 2.736133575439453
 the the the the the


1it [00:00,  1.84it/s]

1: Loss: 2.56


1000it [07:44,  2.13it/s]

1000: Loss: 2.64


2000it [15:39,  2.15it/s]

2000: Loss: 2.67


2178it [17:01,  2.13it/s]


Epoch 2 loss: 2.6039538383483887
 anou I ie.











1it [00:00,  1.59it/s]

1: Loss: 2.58


369it [02:54,  2.12it/s]


KeyboardInterrupt: 

## Generate text using the model

'ssssssssGddddddddd d'

In [85]:
output = transformer_model(params, jnp.array(prompt))
output.shape

(5, 65)

In [86]:
output.shape[-1]

65

In [87]:
jnp.argmax(output[-1])

Array(57, dtype=int32)

In [28]:
jnp.array([prompt]).shape

(1, 3)

In [14]:
print_param_names(params)

embedding
embedding.emb
embedding.pos
layers
layers.layer_0
layers.layer_0.ff
layers.layer_0.ff.ff1
layers.layer_0.ff.ff1.b
layers.layer_0.ff.ff1.w
layers.layer_0.ff.ff2
layers.layer_0.ff.ff2.b
layers.layer_0.ff.ff2.w
layers.layer_0.heads
layers.layer_0.heads.w_k
layers.layer_0.heads.w_k.b
layers.layer_0.heads.w_k.w
layers.layer_0.heads.w_q
layers.layer_0.heads.w_q.b
layers.layer_0.heads.w_q.w
layers.layer_0.heads.w_v
layers.layer_0.heads.w_v.b
layers.layer_0.heads.w_v.w
layers.layer_0.ln1
layers.layer_0.ln1.bias
layers.layer_0.ln1.gain
layers.layer_0.ln2
layers.layer_0.ln2.bias
layers.layer_0.ln2.gain
layers.layer_0.output
layers.layer_0.output.b
layers.layer_0.output.w
layers.layer_1
layers.layer_1.ff
layers.layer_1.ff.ff1
layers.layer_1.ff.ff1.b
layers.layer_1.ff.ff1.w
layers.layer_1.ff.ff2
layers.layer_1.ff.ff2.b
layers.layer_1.ff.ff2.w
layers.layer_1.heads
layers.layer_1.heads.w_k
layers.layer_1.heads.w_k.b
layers.layer_1.heads.w_k.w
layers.layer_1.heads.w_q
layers.layer_1.heads.w

## Alternative model designs

In [29]:
def mlp(params):    
    def forward(x):
        return x @ params
    return forward


def mlp2(d_in, d_out):
    # Initialize parameters
    params = jax.random.normal(rnd_key, (d_in, d_out))
    def forward(x, params=params):
        return x @ params
    # There is no other method to get the parameters
    return forward, params

Test the model that takes parameters as an input

In [30]:
n = 10
d_in, d_out = 8, 16
x = jnp.ones((n, d_in))
rng, rn2 = jax.random.split(rnd_key)
params = jax.random.normal(rnd_key, (d_in, d_out))

model = mlp(params)
out = model(x)

Test the second model definition

In [31]:
model, params = mlp2(d_in, d_out)
out2 = model(x)

In [28]:
assert jnp.all(out == out2)

## Nested `vmap`s

In [14]:
def sum(x, y):
    return x + y

x = jnp.arange(10)
y = jnp.arange(8, 10)

jax.vmap(jax.vmap(sum, in_axes=(None, 0)), in_axes=(0, None))(x, y)

Array([[ 8,  9],
       [ 9, 10],
       [10, 11],
       [11, 12],
       [12, 13],
       [13, 14],
       [14, 15],
       [15, 16],
       [16, 17],
       [17, 18]], dtype=int32)