Now that we have a module that represents a GPT2 model with a relevant config file, we can train it ! We will use tinystories to train it as per the tutorial here: https://huggingface.co/blog/sachithgunasekara/nanojaxgpt

In [1]:
import os
import jax.numpy as np
from GPT2 import GPTConfig
import numpy

data_dir = "dataset"
config = GPTConfig()

def get_batch(split: str):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == 'train':
        data = numpy.memmap(os.path.join(data_dir, 'train.bin'), dtype=numpy.uint16, mode='r')
    else:
        data = numpy.memmap(os.path.join(data_dir, 'validation.bin'), dtype=numpy.uint16, mode='r')

    ix = numpy.random.randint(len(data) - config.block_size, size=(8,))
    x = np.stack([np.array(data[i:i + config.block_size], dtype=np.int64) for i in ix])
    y = np.stack([np.array(data[i + 1:i + 1 + config.block_size], dtype=np.int64) for i in ix])

    return x, y

In [2]:
import optax, jax
import equinox as eqx

learning_rate = 1e-5
warmup_iters = 10
init_from = "scratch"
lr_decay_iters = 20
iter_num = 0
min_lr = 1e-6

lr_scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=learning_rate,
    warmup_steps=warmup_iters if init_from == 'scratch' else 0,
    decay_steps=lr_decay_iters - iter_num,
    end_value=min_lr,
)

optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=learning_rate)

@eqx.filter_jit
def loss(model, x, y):
    logits = jax.vmap(model)(x)

    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y)
    return jax.numpy.mean(loss)

def make_step(model, optimizer_state, x, y):
    losses, grads = eqx.filter_value_and_grad(loss)(model, x, y)
    updates, optimizer_state = optimizer.update(grads, optimizer_state, model)
    model = eqx.apply_updates(model, updates)
    return model, optimizer_state, losses

def estimate_loss(model):
    out = {}
    model = eqx.nn.inference_mode(model)
    for split in ['train', 'val']:
        losses = jax.numpy.zeros(10)
        for k in range(10):
            x, y = get_batch(split)
            loss = loss(model, jax.lax.stop_gradient(x), y)
            losses = losses.at[k].set(loss.item())
        out[split] = jax.numpy.mean(losses)
    return out


In [None]:
from GPT2 import GPT
import wandb
# init a new model from scratch
print("Initializing a new model from scratch")
# determine the vocab size we'll use for from-scratch training
key = jax.random.PRNGKey(69)

gptconf = GPTConfig()
model = GPT(gptconf, key)
# convert_model_to_dtype()


optimizer_state = optimizer.init(eqx.filter(model, eqx.is_array))

for local_iter_num in range(100):
    x, y = get_batch("train")

    model, optimizer_state, loss = make_step(model, optimizer_state, x, y)
    print(f"Loss: {loss}")

In [21]:
import datasets
from transformers import AutoTokenizer

dataset = datasets.load_dataset("roneneldan/TinyStories")

model_name = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)


In [24]:
dataset["train"] = dataset["train"].select([i for i in range(0, 1000)])

In [None]:
def tokenize(example):
    # print(example)
    return {"tokenized": [tokenizer.tokenize(x) for x in example["text"]]}

tokenized_data = dataset.map(tokenize, remove_columns=["text"], batched=True, batch_size=8)
tokenized_data

In [None]:
for i in range(0, len(dataset), 8):
    data = tokenized_data["validation"].select([i for i in range(i, i + 8)])["tokenized"]
    print(data)
    break


In [2]:
import jax
import optax
import jax.numpy as np
import equinox as eqx
from GPT2 import GPTConfig, GPT

# Loss function with vmap to calculate loss for the entire batch
def loss(model, x, y):
    logits = jax.vmap(model)(x)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y)
    return jax.numpy.mean(loss)

# Optimization step function
def make_step(model, optimizer_state, x, y):
    losses, grads = eqx.filter_value_and_grad(loss)(model, x, y)
    updates, optimizer_state = optimizer.update(grads, optimizer_state, model)
    model = eqx.apply_updates(model, updates)
    return model, optimizer_state, losses

# Set up GPT configuration and data
config = GPTConfig()
data = jax.numpy.cos(jax.numpy.arange(0, 100000) * (jax.numpy.pi)/1000)
data = jax.numpy.ceil(data * config.vocab_size).astype(jax.numpy.int16)

batch_size = 10
input_size = 100

# Initialize model and optimizer
key = jax.random.PRNGKey(79)
model = GPT(config, key)
trainable = eqx.filter(model, eqx.is_array)
optimizer = optax.adamw(learning_rate=1e-5)
optimizer_state = optimizer.init(trainable)

# Training loop
for i in range(0, len(data) - input_size - 1, batch_size):
    batch = jax.numpy.array([data[j: j+input_size] for j in range(i, i+batch_size)])
    batch_y = jax.numpy.array([data[j+1: j+1+input_size] for j in range(i, i+batch_size)])
    
    # Perform a training step
    model, optimizer_state, loss = make_step(model, optimizer_state, batch, batch_y)
    print(f"Loss: {loss}")

  self.attn = CausalSelfAttention(config, key=key1)
  output = _arange(start, stop=stop, step=step, dtype=dtype)


ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [None]:

import matplotlib.pyplot as plt

f, a = plt.subplots(1)
a.plot(data)
f.show()