In [None]:
import os
import sys
import torch
import tiktoken

from src.utils import *
from src.gpt2 import GPT2
from src.config import GPTConfig
from src.tokenizer import Tokenizer
from src.data_loader import DataLoader

### Constants and hyperparameters

In [None]:
# Constants
# Location of the training corpus, resolved against the current working directory,
# so the result depends on where the kernel was started.
dataset_path = os.path.join(
    os.getcwd(),
    'dataset',
    'input.txt',
)

In [None]:
# Hyperparameters (the values a reader is most likely to tune)

# Batch size used for training
batch_size = 4

# Number of training epochs
epochs = 50

# Number of tokens in each training sequence
sequence_length = 32

# Share of the training data to use for validation; handed to DataLoader.
# NOTE(review): 0.1 is presumably a fraction (i.e. 10%), not a percentage — confirm
# against the DataLoader implementation.
train_val_split = 0.1

# Learning rate for the optimizer
learning_rate = 3e-4

### Initializations

In [None]:
# Let float32 matrix multiplications trade some precision for better performance
torch.set_float32_matmul_precision('high')

# Seed PyTorch's random number generator for reproducibility.
# The trailing semicolon keeps the returned generator object out of the cell output.
torch.manual_seed(1337);

### Data loading

In [None]:
# Instantiate the project's Tokenizer with the 'gpt2' identifier
# (presumably selects the GPT-2 encoding — confirm in src.tokenizer)
tokenizer = Tokenizer('gpt2')

# Extract the vocabulary size reported by the tokenizer.
# NOTE(review): vocab_size is not referenced again in the visible cells (the model
# configuration hard-codes its own value), so this looks informational only — confirm.
vocab_size = tokenizer.vocab_size()

In [None]:
# Build the data loader from the corpus file, handing it the tokenizer and the
# train/validation split ratio configured earlier in the notebook
data_loader = DataLoader(
    txt_file=dataset_path,
    tokenizer=tokenizer,
    train_val_split=train_val_split,
)

# Report the length of the token collection held for each split
print("Training set size: ", len(data_loader.train_tokens))
print("Validation set size: ", len(data_loader.val_tokens))

In [None]:
# Create the model configuration.
# The vocabulary size is 50304 instead of the classic 50257 of the gpt2 tokenizer:
# padding tokens are added so that the vocabulary size becomes a round multiple
# (50304 is divisible by 8, 64 and 128; it is the next multiple of 64 above 50257),
# which improves performance in reduced-precision training (the model is cast to
# bfloat16 further down in this notebook).
model_config = GPTConfig(
    context_size = 1024,
    vocab_size = 50304,
    n_blocks = 12,
    n_heads = 12,
    n_embed = 768
)

### Building the model

In [None]:
# Create the GPT-2 model from the configuration built above
gpt2 = GPT2(model_config)

# Cast the model to bfloat16 for improved performance, then move it to `device`.
# NOTE(review): `device` is not defined in the visible cells; presumably it is
# provided by the wildcard `from src.utils import *` (GPU if available) — confirm.
gpt2 = gpt2.to(torch.bfloat16).to(device)

# Compile the model to optimize performance.
# NOTE(review): torch.compile returns a wrapper around the original module. Custom
# methods such as fit()/generate() are resolved on the wrapped module, so if they
# invoke the model through `self(...)` they may bypass the compiled forward pass —
# verify that training actually runs the compiled code.
gpt2 = torch.compile(gpt2)

### Training the model

In [None]:
# Train the model using the data loader and the hyperparameters defined earlier
gpt2.fit(
    data_loader=data_loader,
    epochs=epochs,
    lr=learning_rate,
    batch_size=batch_size,
    sequence_length=sequence_length,
)

In [None]:
# Text prompt that seeds the generation. It gets its own name so that `context`
# always refers to the tensor and never to the raw string.
prompt = "Hello, my name is"

# Encode the prompt with the tokenizer, convert it to a tensor of token ids and
# add a leading dimension of size 1 (unsqueeze(0))
context = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0)
# Move the tensor to `device`, the same target the model was moved to.
# NOTE(review): `device` is not defined in the visible cells; presumably it comes
# from src.utils (GPU if available) — confirm.
context = context.to(device)

# Generation is inference only: disable autograd so no graph is recorded
# (harmless if generate() already does this internally)
with torch.no_grad():
    generated = gpt2.generate(context, max_new_tokens=20)

# Decode and display the generated text. squeeze(0) removes only the leading
# size-1 dimension, so the token axis survives whatever its length.
# Assumes generate() returns a tensor with a leading batch dimension of 1 — confirm.
print(tokenizer.decode(generated.squeeze(0).tolist()))