In [36]:
%load_ext autoreload
%autoreload 2

import jax
import jax.numpy as jnp
import optax

import configs
import optimizers
import tokenizers

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [45]:
training_config = configs.TrainingConfig(
    optimizer_config=optimizers.OptimizerConfig(
        optimizer=optimizers.Optimizer.ADAM, learning_rate=1e-1
    ),
    num_train_steps=10,
)
tokenizer_config = tokenizers.TokenizerConfig(model="gpt2")
model_config = configs.TransformerConfig(
    vocab_size=128, d_model=32, sequence_length=512, num_blocks=2, num_heads=4, ffw_multiplier=4
)

In [46]:
text = (
    "Lorem ipsum dolor sit amet. Et consectetur fuga id eius ducimus non quam corrupti "
    "aut temporibus animi qui odit odit est voluptas voluptatem vel doloribus nemo. "
    "Qui iste neque est iusto omnis hic soluta rerum est adipisci dolores? "
    "Ut veritatis sequi sed dolore dolorum quo doloremque dignissimos aut laborum voluptates. "
    "Qui tempore asperiores et voluptas dignissimos non doloribus impedit ut expedita "
    "reprehenderit quo amet temporibus. Qui nisi odio ut necessitatibus maxime ea molestiae "
    "optio ut reiciendis quis? Et veniam beatae ut omnis fuga aut veritatis quod rem quidem "
    "distinctio et dolorem aliquam rem sint sunt vel nihil dolores. Sed sunt autem aut sunt "
    "rerum rem eaque quas."
)
tokenizer = tokenizer_config.make()
tokens = tokenizer.encode(text)
assert tokenizer.decode(tokens) == text
print(len(tokens), tokens[:10])
tokens += [0] * (model_config.sequence_length - len(tokens))
x = jnp.array(tokens).reshape(1, -1)
x.shape

228 [43, 29625, 220, 2419, 388, 288, 45621, 1650, 716, 316]


(1, 512)

In [52]:
key = jax.random.PRNGKey(0)
model = model_config.make()
fake_values = jax.random.randint(
    key,
    shape=(1, model_config.sequence_length),
    minval=0,
    maxval=model_config.vocab_size,
)
params = model.init(key, fake_values)
print(jax.tree_util.tree_map(jnp.shape, params))

{'params': {'blocks_0': {'attn': {'key': {'bias': (4, 8), 'kernel': (32, 4, 8)}, 'out': {'bias': (32,), 'kernel': (4, 8, 32)}, 'query': {'bias': (4, 8), 'kernel': (32, 4, 8)}, 'value': {'bias': (4, 8), 'kernel': (32, 4, 8)}}, 'ffw': {'layer1': {'kernel': (32, 128)}, 'layer2': {'kernel': (128, 32)}}, 'norm1': {'bias': (32,), 'scale': (32,)}, 'norm2': {'bias': (32,), 'scale': (32,)}}, 'blocks_1': {'attn': {'key': {'bias': (4, 8), 'kernel': (32, 4, 8)}, 'out': {'bias': (32,), 'kernel': (4, 8, 32)}, 'query': {'bias': (4, 8), 'kernel': (32, 4, 8)}, 'value': {'bias': (4, 8), 'kernel': (32, 4, 8)}}, 'ffw': {'layer1': {'kernel': (32, 128)}, 'layer2': {'kernel': (128, 32)}}, 'norm1': {'bias': (32,), 'scale': (32,)}, 'norm2': {'bias': (32,), 'scale': (32,)}}, 'position_embedder': {'embedding': (512, 32)}, 'token_embedder': {'embedding': (128, 32)}}}


In [57]:
model.apply(params, x)

Array([[[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        ...,
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan]]], dtype=float32)

In [55]:
optimizer = training_config.optimizer_config.make()
opt_state = optimizer.init(params)

def loss_fn(params, x): 
    # This does not offset the tokens for next token prediction!
    return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(model.apply(params, x), x))
grad_fn = jax.value_and_grad(loss_fn)

for idx in range(training_config.num_train_steps):
    loss, gradients = grad_fn(params, x)
    print(f"Step {idx}: loss = {loss}")
    updates, opt_state = optimizer.update(gradients, opt_state)
    params = optax.apply_updates(params, updates)

Step 0: loss = nan
Step 1: loss = nan
Step 2: loss = nan
Step 3: loss = nan
Step 4: loss = nan
Step 5: loss = nan
Step 6: loss = nan
Step 7: loss = nan
Step 8: loss = nan
Step 9: loss = nan
