In [1]:
import torch
from utils import PicabooLMPretainingDataset, configure_adamw_optimizer
from tokenizer import Tokenizer
from model import PicabooLMParams, PicabooLM
from trainer import TrainerParams, Trainer
from torch.utils.data import DataLoader
from dataclasses import asdict

# Tokenize the dataset

In [2]:
# Load the tokenizer saved under models/ and encode the whole corpus.
# The corpus marks boundaries between texts with "########"; each marker is
# swapped for the end-of-text token before encoding.
tokenizer = Tokenizer.load("models/")
eod_token = "<|endoftext|>"
doc_separator = "########"
with open("datasets/combined.txt", "r", encoding="utf-8") as corpus_file:
    text = corpus_file.read().replace(doc_separator, eod_token)
tokens = tokenizer.encode(text=text)
# Show the corpus size in tokens and peek at the first few ids.
len(tokens), tokens[:10]

(894783, [1068, 692, 649, 101, 358, 115, 296, 258, 429, 111])

# Model

In [3]:
# Architecture hyperparameters. head_dim is derived from d_model and num_heads
# so it stays consistent if either one is changed.
SEED = 1337
D_MODEL = 256
NUM_HEADS = 8
assert D_MODEL % NUM_HEADS == 0, "d_model must be divisible by num_heads"

torch.manual_seed(SEED)  # seed torch's global RNG so runs are reproducible
model_params = PicabooLMParams(
    context_length=1024,
    vocab_size=2048,
    num_blocks=8,
    num_heads=NUM_HEADS,
    d_model=D_MODEL,
    head_dim=D_MODEL // NUM_HEADS,
    device="cpu",
)
model = PicabooLM(params=model_params)

# Sanity check: sample a continuation from the untrained model.
# Before training this is expected to be gibberish.
was_training = model.training
model.eval()  # NOTE(review): disables dropout while sampling in case generate() does not — confirm in model.py
with torch.no_grad():  # inference only; no autograd graph needed
    output_tokens = model.generate(torch.tensor([tokenizer.encode("he is")]), max_new_tokens=100)
model.train(was_training)  # restore the mode the model was in before sampling
sample_completion = tokenizer.decode(output_tokens.tolist()[0])
print(asdict(model_params))
f"sample completion: {sample_completion}"

number of parameters: 7.08M
{'context_length': 1024, 'vocab_size': 2048, 'num_blocks': 8, 'num_heads': 8, 'd_model': 256, 'head_dim': 32, 'dropout_rate': 0.1, 'device': 'cpu', 'bias': False}


'sample completion: he isshffeeeringarldsally ` death rose probably dou never His Heart seemedab lea ranri creatureiallyokeear closed thoseneingld direct’ked wateralekapped neededved gr dang murph endove\x12 white slowse takeO read movieals Jakany br ple1es fleinken z guard wo onlyart Even faceL used angumm focreenels being storyific front fear u am\x1f amadia waved control hopalek wra stre short pointed down read'

# Train

In [4]:
# Split the token stream: the first 99% for training, the remaining 1% for validation.
train_test_split_pct = 0.99
train_test_split_idx = int(len(tokens) * train_test_split_pct)
train_tokens = tokens[:train_test_split_idx]
# Validation starts exactly at the split index. A `+1` here would silently
# drop one token that belongs to neither split.
val_tokens = tokens[train_test_split_idx:]
assert len(train_tokens) + len(val_tokens) == len(tokens), "train/val split lost tokens"

microbatch_size = 2
gradient_accumulation_steps = 2
epochs = 2
# Sequences consumed per optimizer step.
effective_batch_size = microbatch_size * gradient_accumulation_steps
training_dataset = PicabooLMPretainingDataset(dataset=train_tokens, context_size=model_params.context_length)
validation_dataset = PicabooLMPretainingDataset(dataset=val_tokens, context_size=model_params.context_length)
training_dataloader = DataLoader(dataset=training_dataset, batch_size=microbatch_size, shuffle=True)
validation_dataloader = DataLoader(dataset=validation_dataset, batch_size=microbatch_size, shuffle=False)

# Optimizer steps over the whole run. Assumes len(training_dataset) counts
# training sequences. NOTE(review): DataLoader keeps a final partial batch
# (drop_last=False); confirm Trainer's step counting agrees with this figure.
total_steps = len(training_dataset) * epochs // effective_batch_size
# LR schedule boundaries:
#   [0, warmup_steps]          warmup (first 10% of steps)
#   (warmup_steps, max_steps]  cosine decay (from the 10% mark to the 90% mark, i.e. 80% of steps)
#   (max_steps, total_steps]   final 10%, presumably held at the minimum LR — confirm in trainer.py
warmup_steps = int(total_steps * 0.1)
max_steps = int(total_steps * 0.9)
print(f"total_steps: {total_steps}")
print(f"warmup: 0 - {warmup_steps}")
print(f"decay: {warmup_steps+1} - {max_steps}")
print(f"finalizing: {max_steps+1} - {total_steps}")

total_steps: 432
warmup: 0 - 43
decay: 44 - 388
finalizing: 389 - 432


Before running the training cell, start an MLflow tracking server in a separate terminal: `mlflow server --host 127.0.0.1 --port 8000`

In [None]:
# Optimizer and learning-rate hyperparameters.
max_lr = 1e-3
min_lr = max_lr * 0.1  # lower bound of the LR schedule: 10% of the peak LR
b1, b2, eps, weight_decay = 0.9, 0.95, 1e-8, 0.1  # AdamW betas, epsilon, weight decay
# Train on the device the model was configured for, so the two settings cannot drift apart.
device = model_params.device
optimizer = configure_adamw_optimizer(
    model=model,
    weight_decay=weight_decay,
    learning_rate=max_lr,
    betas=(b1, b2),
    eps=eps,
    device_type=device,
)
loss_fn = torch.nn.CrossEntropyLoss()
save_every = 1  # checkpoint frequency; its unit (epochs vs. steps) is defined by Trainer — TODO confirm
checkpoints_path = "models/"

trainer_params = TrainerParams(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_dataloader=training_dataloader,
    val_dataloader=validation_dataloader,
    device=device,
    epochs=epochs,
    batch_size=effective_batch_size,  # microbatch_size * gradient_accumulation_steps
    save_every=save_every,
    checkpoints_path=checkpoints_path,
    gradient_accumulation_steps=gradient_accumulation_steps,
    # LR schedule step boundaries computed in the data/schedule cell.
    total_steps=total_steps,
    max_steps=max_steps,
    warmup_steps=warmup_steps,
    max_learning_rate=max_lr,
    min_learning_rate=min_lr,
    tokenizer=tokenizer,
)
trainer = Trainer(params=trainer_params)
trainer.train()

Step 1/432 | loss: 7.70 | norm: 4.9528e+00 | dt: 17872.16 ms | lr: 2.33e-05 | tokens/s: 229.18
Step 2/432 | loss: 7.59 | norm: 4.4285e+00 | dt: 16858.97 ms | lr: 4.65e-05 | tokens/s: 242.96
Step 3/432 | loss: 7.47 | norm: 2.5289e+00 | dt: 16529.77 ms | lr: 6.98e-05 | tokens/s: 247.80
Step 4/432 | loss: 7.39 | norm: 2.5241e+00 | dt: 15891.32 ms | lr: 9.30e-05 | tokens/s: 257.75
Step 5/432 | loss: 7.31 | norm: 1.9744e+00 | dt: 15357.22 ms | lr: 1.16e-04 | tokens/s: 266.71
Step 6/432 | loss: 7.25 | norm: 1.5482e+00 | dt: 17102.38 ms | lr: 1.40e-04 | tokens/s: 239.50
Step 7/432 | loss: 7.20 | norm: 1.4156e+00 | dt: 16695.18 ms | lr: 1.63e-04 | tokens/s: 245.34
Step 8/432 | loss: 7.15 | norm: 1.4023e+00 | dt: 15429.81 ms | lr: 1.86e-04 | tokens/s: 265.46
Step 9/432 | loss: 7.09 | norm: 1.3482e+00 | dt: 16668.02 ms | lr: 2.09e-04 | tokens/s: 245.74
Step 10/432 | loss: 7.04 | norm: 1.3221e+00 | dt: 17565.45 ms | lr: 2.33e-04 | tokens/s: 233.19
Step 11/432 | loss: 7.00 | norm: 1.2694e+00 | dt:

In [None]:
# Sample from the trained model with the same prompt used before training,
# so the two completions can be compared.
model = trainer.model
prompt_tokens = torch.tensor([tokenizer.encode(text="he is")])  # shape (1, T): a batch of one prompt, T tokens long
was_training = model.training
model.eval()  # NOTE(review): disables dropout while sampling in case generate() does not — confirm in model.py
with torch.no_grad():  # inference only; no autograd graph needed
    output_tokens = model.generate(prompt_tokens, max_new_tokens=100)
model.train(was_training)  # restore the mode the model was in before sampling
tokenizer.decode(output_tokens.tolist()[0])

'he is turn ste ste tru vill crching toirt. Fro someone?”Lurned watching dump of the lateringching to make a smallake. But a wholeasize his to her lH more if could’t youthough, and front of gold one had un under the rest a whole hurtar So package. ourom worn B last time,” he was nuggets\n�I a to looked him thought of themed it wornside my  pl'