# Step 1: Create a dataset

In [1]:
from llm_trainer import create_dataset

# Build the training data under ./data from the "fineweb-edu-10B" dataset,
# capped at CHUNKS_LIMIT chunks of CHUNK_SIZE each (kept small for a quick demo).
create_dataset(
    save_dir="data",
    dataset="fineweb-edu-10B",
    CHUNKS_LIMIT=5,
    CHUNK_SIZE=1_000_000,
)

Resolving data files:   0%|          | 0/2110 [00:00<?, ?it/s]

Processing Chunks: 100%|██████████| 5/5 [00:04<00:00,  1.06chunk/s]


# Step 2: Define a small GPT-2 model

In [2]:
from transformers import GPT2Config, GPT2LMHeadModel, set_seed
import tiktoken

# Seed the Python / NumPy / PyTorch RNGs so the randomly initialized weights
# below are the same on every Restart & Run All.
SEED = 42
set_seed(SEED)

# Create the tokenizer first so the model's vocabulary size is taken from it
# instead of being hard-coded: if the model vocabulary were smaller than the
# tokenizer's, some token ids would have no embedding row.
tokenizer = tiktoken.get_encoding("gpt2")

# A deliberately tiny GPT-2 (4 layers, 128-dim embeddings) for a fast demo run.
gpt2_config = GPT2Config(
    vocab_size=tokenizer.n_vocab,   # 50257 for the "gpt2" encoding
    n_positions=128,                # max sequence length; must be >= the training context_window
    n_embd=128,
    n_layer=4,
    n_head=4,
    activation_function="gelu_new",
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    attn_pdrop=0.1,
)
gpt2_model = GPT2LMHeadModel(gpt2_config)

# Step 3: Create an LLMTrainer object

In [3]:
from llm_trainer import LLMTrainer

# Wrap the model and GPT-2 tokenizer in a trainer. Passing None for the
# optimizer and scheduler selects LLMTrainer's defaults: AdamW, and warm-up
# steps followed by cosine annealing.
trainer = LLMTrainer(
    model=gpt2_model,
    optimizer=None,
    scheduler=None,
    tokenizer=tokenizer,
)

# Step 4: Start training

In [4]:
# Sequence length fed to the model during training. GPT2Config.n_positions is
# the maximum sequence length the model supports, so fail fast if the training
# window exceeds it.
CONTEXT_WINDOW = 128
assert CONTEXT_WINDOW <= gpt2_model.config.n_positions, (
    "context_window must not exceed the model's n_positions"
)

# With BATCH_SIZE=256 and MINI_BATCH_SIZE=16, each optimizer step accumulates
# gradients over 256 // 16 = 16 mini-batches.
trainer.train(
    max_steps=100,
    verbose=50,                        # Sample from the model every 50 steps
    context_window=CONTEXT_WINDOW,     # Context window of the model
    data_dir="data",                   # Directory with .npy files containing tokens
    BATCH_SIZE=256,                    # Batch size
    MINI_BATCH_SIZE=16,                # Gradient accumulation is used. BATCH_SIZE = MINI_BATCH_SIZE * accumulation_steps
    logging_file="logs_training.csv",  # File to write logs of the training
    # NOTE(review): with max_steps=100, a 500-step interval likely never
    # triggers a checkpoint — confirm LLMTrainer's save semantics or lower it.
    save_each_n_steps=500,             # Save the state every 500 steps
    save_dir="checkpoints",            # Directory where to save training state (model + optimizer + dataloader)
    prompt="Once upon a time",         # The model will continue this prompt every `verbose` steps
)

Current chunk: 0
step: 0 | Loss: 10.937500 | norm: 1.7791 | lr: 6.6667e-08 | dt: 3.54s | tok/sec: 9248.27
step: 1 | Loss: 10.937500 | norm: 1.8848 | lr: 1.0000e-07 | dt: 0.24s | tok/sec: 137599.12
step: 2 | Loss: 10.937500 | norm: 1.7740 | lr: 1.3333e-07 | dt: 0.24s | tok/sec: 137815.05
step: 3 | Loss: 10.937500 | norm: 1.8929 | lr: 1.6667e-07 | dt: 0.24s | tok/sec: 138792.04
step: 4 | Loss: 10.937500 | norm: 1.7970 | lr: 2.0000e-07 | dt: 0.24s | tok/sec: 138279.00
step: 5 | Loss: 10.937500 | norm: 1.7763 | lr: 2.3333e-07 | dt: 0.24s | tok/sec: 136470.96
step: 6 | Loss: 10.937500 | norm: 1.7827 | lr: 2.6667e-07 | dt: 0.26s | tok/sec: 127019.19
step: 7 | Loss: 10.937500 | norm: 1.7580 | lr: 3.0000e-07 | dt: 0.26s | tok/sec: 127186.57
step: 8 | Loss: 10.937500 | norm: 1.7092 | lr: 3.3333e-07 | dt: 0.26s | tok/sec: 127480.55
step: 9 | Loss: 10.937500 | norm: 1.7378 | lr: 3.6667e-07 | dt: 0.26s | tok/sec: 127164.33
step: 10 | Loss: 10.937500 | norm: 1.8046 | lr: 4.0000e-07 | dt: 0.26s | to