In [1]:
import sys
import jax
import jax.numpy as jnp

sys.path.append('../')

from dataset import TinyShakespeare
from jax_utils import print_param_names
from transformer_functional import cross_entropy_loss

In [2]:
seed = 1212
rnd_key = jax.random.PRNGKey(seed)

### Test the model

Define the autoregressive loss function

## Load Dataset

In [3]:
dataset = TinyShakespeare(rnd_key, batch_size=16, seq_len=32)

Loading Tiny Shakespeare dataset...
Loaded Tiny Shakespeare dataset


In [4]:
dataset.data.shape

(34848, 32)

### Training loop

Configs

In [5]:
d_model = 128
num_heads = 4
num_layers = 2
d_ff = 256
n_vocab = dataset.n_tokens
n_epochs = 1

In [6]:
from layers import create_autoregressive_transformer

transformer_model, params = create_autoregressive_transformer(rnd_key, num_layers, num_heads, d_model, d_ff, n_vocab)

def transformer_loss(params, x):
    output = transformer_model(params, x)
    return jax.vmap(cross_entropy_loss, in_axes=(0, 0))(output, x).mean()

In [8]:
model = jax.jit(jax.vmap(transformer_model, in_axes=(None, 0), out_axes=0))
batched_loss = jax.jit(jax.vmap(transformer_loss, in_axes=(None, 0), out_axes=0))

def get_loss(params, seq):
    return batched_loss(params, seq).mean()

grad_loss_fn = jax.value_and_grad(get_loss)

Let's create the model

In [9]:
x = dataset.data[0:4]
output = model(params, x)
loss, grad = grad_loss_fn(params, x)
loss

TypeError: dot_general requires contracting dimensions to have the same shape, got (32,) and (128,).

In [15]:
import optax

learning_rate = 1e-3
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

for epoch in range(n_epochs):
    losses = []
    for i, batch in enumerate(dataset):
        loss, grad = grad_loss_fn(params, batch)
        losses.append(loss)
        # Update parameters
        updates, opt_state = optimizer.update(grad, opt_state)
        params = optax.apply_updates(params, updates)
        if i%100:
            print(f"{i}: Loss: {loss:.2f}")
    print(f'Epoch {epoch} loss: {jnp.mean(loss)}')

1: Loss: 8.85
2: Loss: 8.87
3: Loss: 8.66
4: Loss: 7.78
5: Loss: 8.25
6: Loss: 7.77
7: Loss: 7.60
8: Loss: 7.20
9: Loss: 6.88
10: Loss: 6.44
11: Loss: 6.67
12: Loss: 6.08
13: Loss: 6.02
14: Loss: 5.73
15: Loss: 5.43
16: Loss: 5.43
17: Loss: 5.07
18: Loss: 4.96
19: Loss: 4.81
20: Loss: 4.92
21: Loss: 4.70
22: Loss: 4.45
23: Loss: 4.37
24: Loss: 4.27
25: Loss: 4.22
26: Loss: 3.96
27: Loss: 4.10
28: Loss: 3.98
29: Loss: 3.93
30: Loss: 4.00
31: Loss: 3.70
32: Loss: 3.80
33: Loss: 3.61
34: Loss: 3.69
35: Loss: 3.48
36: Loss: 3.65
37: Loss: 3.64
38: Loss: 3.53
39: Loss: 3.58
40: Loss: 3.58
41: Loss: 3.61
42: Loss: 3.34
43: Loss: 3.35
44: Loss: 3.37
45: Loss: 3.40
46: Loss: 3.46
47: Loss: 3.42
48: Loss: 3.27
49: Loss: 3.24
50: Loss: 3.33
51: Loss: 3.21
52: Loss: 3.06
53: Loss: 3.24
54: Loss: 3.19
55: Loss: 3.11
56: Loss: 3.09
57: Loss: 3.08
58: Loss: 3.16
59: Loss: 3.18
60: Loss: 2.96
61: Loss: 3.01
62: Loss: 2.91
63: Loss: 2.91
64: Loss: 2.92
65: Loss: 2.91
66: Loss: 2.87
67: Loss: 2.96
68: 

## Alternative model designs

In [29]:
def mlp(params):    
    def forward(x):
        return x @ params
    return forward


def mlp2(d_in, d_out):
    # Initialize parameters
    params = jax.random.normal(rnd_key, (d_in, d_out))
    def forward(x, params=params):
        return x @ params
    # There is no other method to get the parameters
    return forward, params

Test the model that takes parameters as an input

In [30]:
n = 10
d_in, d_out = 8, 16
x = jnp.ones((n, d_in))
rng, rn2 = jax.random.split(rnd_key)
params = jax.random.normal(rnd_key, (d_in, d_out))

model = mlp(params)
out = model(x)

Test the second model definition

In [31]:
model, params = mlp2(d_in, d_out)
out2 = model(x)

In [28]:
assert jnp.all(out == out2)