In [1]:
import jax
import optax

# import local modules
from toylib_projects.tinystories import data
from toylib_projects.tinystories import decoder_only_model
from toylib_projects.tinystories import experiment

In [2]:
config = experiment.Config(
    model_config=decoder_only_model.ModelConfig(
        vocab_size=50257,  # GPT-2 tokenizer vocab size
    ),
    training_config=experiment.TrainingConfig(
    ),
)

In [None]:
# Dataloader
dataset = data.BatchedTokenizedHFDataset(bos_token=1000, batch_size=128, seq_len=512, tokenizer_batch_size=8)

# Model
model = decoder_only_model.DecoderOnlyTransformer(config=config.model_config, key=jax.random.PRNGKey(0))

# Logger
logger = experiment.TensorBoardLogger(config, output_path="./tensorboard_logs")

# Optimizer
optimizer = optax.adam(learning_rate=config.training_config.learning_rate)

Resolving data files:   0%|          | 0/1823 [00:00<?, ?it/s]

In [4]:
# Visualize a few batches
for ix, batch in enumerate(dataset):
    print(batch)
    if ix == 1:
        break

Token indices sequence length is longer than the specified maximum sequence length for this model (1820 > 1024). Running this sequence through the model will result in indexing errors


{'inputs': Array([[ 1000, 25586,   434, ...,   262,  1988,   583],
       [ 5680,   286,  2472, ...,  7585,   481,   307],
       [ 2622,    11,   290, ...,   416,  6193,    11],
       ...,
       [  307, 29738,   422, ...,   286,   640,    13],
       [ 4900,  4713, 13701, ...,    11, 46823,   393],
       [32099,   351,  1342, ..., 12175, 38631,   351]], dtype=uint16), 'targets': Array([[25586,   434,  1222, ...,  1988,   583,  5680],
       [  286,  2472, 14018, ...,   481,   307,  2622],
       [   11,   290,   612, ...,  6193,    11,  2975],
       ...,
       [29738,   422,  1088, ...,   640,    13,  4900],
       [ 4713, 13701,   389, ..., 46823,   393, 32099],
       [  351,  1342, 16325, ..., 38631,   351,  3623]], dtype=uint16)}
{'inputs': Array([[ 2649,  3037,   784, ...,   761,   393,  3241],
       [  290,  1337,    13, ..., 34722,  1243,   588],
       [10012,   198,  5195, ...,   262,  5290,    13],
       ...,
       [   13,  6762,   430, ...,  8829,  5275, 10691],
   

In [5]:
def log_metrics(logger: experiment.Logger, step: int, loss_val: float, updates):
    leaves, _ = jax.tree_util.tree_flatten(updates)
    metrics = {
        "train/loss": float(loss_val),
        "train/learning_rate": config.training_config.learning_rate,
        "gradients/0/mean": leaves[0].mean(),
        "gradients/1/mean": leaves[1].mean(),
        "gradients/2/mean": leaves[2].mean(),
    }
    logger.log(step=step, metrics=metrics)

In [6]:
# Optimizer
opt_state = optimizer.init(model)

# Value and gradient
loss_and_grad_fn = jax.jit(jax.value_and_grad(decoder_only_model.train_step))

step = 0

In [None]:
# Training loop
for epoch in range(config.training_config.num_epochs):
    for batch in dataset:
        inputs, targets = batch['inputs'], batch['targets']
        mask = jax.numpy.ones_like(inputs)

        # Compute loss and gradients
        loss_val, grads = loss_and_grad_fn(model, inputs, mask, targets)

        # Apply gradients
        updates, opt_state = optimizer.update(grads, opt_state)
        model = optax.apply_updates(model, updates)

        # Log metrics
        log_metrics(logger, step, loss_val, updates)

        # Increment step
        step += 1

In [None]:
model