In [None]:
# Run train.py with the TinyStories config, overriding out_path and max_iters
# on the command line (250 iterations instead of the config's max_iters).
# The checkpoint at out-tinystories/model.eqx is what the next cell loads.
!python train.py config/train_tinystories.py --out_path="out-tinystories/model.eqx" --max_iters=250

In [None]:
import json
import os
from model import GPT, GPTConfig
import equinox as eqx
import jax
import jax.tree_util as jtu

def load(filename):
    """Read a checkpoint file and rebuild the model stored in it.

    The file layout is: one line of JSON (the checkpoint parameters),
    followed by the leaves of the model as written by Equinox.

    Args:
        filename: Path to the checkpoint file.

    Returns:
        A ``(model, checkpoint_params)`` tuple, where ``checkpoint_params`` is
        the decoded JSON header and ``model`` is a ``GPT`` instance whose
        leaves were filled in from the rest of the file.
    """
    with open(filename, "rb") as checkpoint_file:
        # The first line is the JSON header; everything after it is leaf data.
        header_line = checkpoint_file.readline()
        checkpoint_params = json.loads(header_line.decode())

        config = GPTConfig(**checkpoint_params["model_args"])
        # The skeleton only provides the tree structure; its freshly
        # initialised leaves are replaced by the deserialised ones.
        skeleton = GPT.create_instance(config, key=jax.random.key(1))
        restored = eqx.tree_deserialise_leaves(checkpoint_file, skeleton)

    return restored, checkpoint_params
        
        
# Load the checkpoint written by the training run in the first cell.
model, checkpoint_params = load("out-tinystories/model.eqx")

def model_surgery(path, x):
    """Re-initialise the ``wte.weight`` leaf and leave every other leaf alone.

    Intended as the callback for ``jax.tree_util.tree_map_with_path``.

    Args:
        path: Key path of the leaf within the model pytree.
        x: The leaf value; must expose ``shape`` and ``dtype`` when the path
            matches.

    Returns:
        For a leaf whose key-path string contains ``"wte.weight"``: a standard
        normal sample with the same shape and dtype as ``x``, drawn from the
        fixed key ``jax.random.key(1)``. For any other leaf: ``x`` unchanged.
    """
    if "wte.weight" in jax.tree_util.keystr(path):
        # Keep the original dtype: the default float dtype of
        # jax.random.normal may differ from the leaf's (e.g. half-precision
        # weights, or x64 mode), which would change the tree's leaf dtypes.
        return jax.random.normal(jax.random.key(1), shape=x.shape, dtype=x.dtype)
    return x

# Apply model_surgery to every leaf of the loaded model; only leaves whose
# key path contains "wte.weight" are replaced.
model = jtu.tree_map_with_path(model_surgery, model)
# Show the JSON header that was read from the checkpoint.
print(checkpoint_params)
def save(filename, hyperparams, model):
    """Write a checkpoint: one JSON header line followed by the model leaves.

    Args:
        filename: Destination path; the file is opened in binary write mode
            and overwritten if it exists.
        hyperparams: JSON-serialisable object written as the first line.
        model: Pytree whose leaves are serialised by Equinox after the header.
    """
    with open(filename, "wb") as sink:
        # The header is newline-terminated so a reader can take it with a
        # single readline() before the binary leaf data.
        header_bytes = f"{json.dumps(hyperparams)}\n".encode()
        sink.write(header_bytes)
        eqx.tree_serialise_leaves(sink, model)

# Write the modified model to a new file with the original checkpoint header,
# leaving out-tinystories/model.eqx untouched.
save("out-tinystories/model2.eqx", checkpoint_params, model)

In [None]:
# Run train.py again with init_from="resume" and out_path pointing at the
# model2.eqx file written by save() in the previous cell.
# NOTE(review): presumably train.py resumes from the checkpoint at out_path - confirm in train.py.
!python3 train.py config/train_tinystories.py --init_from="resume" --out_path="out-tinystories/model2.eqx" --max_iters=250

Overriding config with config/train_tinystories.py:
# train a miniature GPT model on the tinystories dataset
# good for debugging and playing on macbooks and such

out_path = 'out-tinystories/model.eqx'
eval_interval = 250 # keep frequent because we'll overfit
eval_iters = 20
log_interval = 5 # don't print too too often

# we expect to overfit on this small dataset, so only save when val improves
always_save_checkpoint = False

wandb_log = False # override via command line if you like
tensorboard_log = True # override via command line if you like

dataset = 'tinystories'
gradient_accumulation_steps = 1
batch_size = 32
block_size = 100 # maximum context length of 100

# baby GPT model :)
n_layer = 16
n_head = 12
n_embd = 512
dropout = 0.0

learning_rate = 1e-4 # with baby networks can afford to go a bit higher
max_iters = 500
lr_decay_iters = 500 # make equal to max_iters usually
min_lr = 1e-5 # learning_rate / 10 usually
beta2 = 0.99 # make a bit bigger because number of tokens