In [1]:
import os
import sys
import torch

sys.path.append("..")

from src.utils import device
from src.architectures.gpt2 import GPT2, GPTConfig, Tokenizer, DataLoader

### Constants and hyperparameters

In [2]:
# Constants
dataset_path = os.path.join(os.getcwd(), 'dataset', 'divina_commedia.txt')

In [3]:
# Hyperparameters
batch_size = 1024 # Batch size for training

# Micro batch size for gradient accumulation. This is the number of batches to accumulate gradients before backpropagating.
# This is useful when the batch size is too large to fit into memory, so we split the batch into smaller micro batches and accumulate the gradients before backpropagating
micro_batch_size = 4

epochs = 50 # Number of training epochs
sequence_length = 32 # Number of tokens in each training sequence
train_val_split = 0.1 # Percentage of training data to use for validation
learning_rate = 3e-4 # Learning rate for the optimizer

### Initializations

In [4]:
# Set the random seed for reproducibility
torch.manual_seed(1337);

# Reduce the precision for the matmul operator to improve performance
torch.set_float32_matmul_precision('high')

### Data loading

In [5]:
# Instantiate the tokenizer
tokenizer = Tokenizer('gpt2')

In [6]:
# Instantiate the data loader
data_loader = DataLoader(
    txt_file = dataset_path,
    tokenizer = tokenizer,
    train_val_split = train_val_split
)

# Print the dataset statistics
print("Training set size: ", len(data_loader.train_tokens))
print("Validation set size: ", len(data_loader.val_tokens))

Training set size:  106392
Validation set size:  11821


In [7]:
# Create the model configuration
# The vocabulary size is 50304, instead of the classic 50257 if the gpt2 tokenizer,
# because we add some padding tokens to the vocabulary in order to make the vocabulary 
# size a multiple of 8 in order to improve performance when using FP16 training.
model_config = GPTConfig(
    context_size = 1024,
    vocab_size = 50304,
    n_blocks = 12,
    n_heads = 12,
    n_embed = 768
)

### Building the model

In [8]:
# Creating the GPT-2 model
gpt2 = GPT2(model_config)

# Move the model to the GPU if available 
# and set the precision to bfloat16 for improved performance
gpt2 = gpt2.to(torch.bfloat16).to(device)

# Compile the model to optimize performance
gpt2 = torch.compile(gpt2)

### Training the model

In [9]:
# Fitting the model
gpt2.fit(
    data_loader = data_loader,
    epochs = epochs,
    lr = learning_rate,
    batch_size = batch_size,
    micro_batch_size = micro_batch_size,
    sequence_length = sequence_length
)

Epoch: 1/50 | Completion percentage: 33.33% | Step duration 1099.23 ms/step --> loss: 11.0000
Epoch: 1/50 | Completion percentage: 66.67% | Step duration 423.12 ms/step --> loss: 9.9375
Epoch: 1/50 | Completion percentage: 100.00% | Step duration 423.96 ms/step --> loss: 9.5000
Epoch 1/50 | Average step duration 648.77 ms/step | Epoch duration 3071.61 ms/epoch --> loss: 10.1875 - val_loss: 9.1977
Epoch: 2/50 | Completion percentage: 33.33% | Step duration 424.95 ms/step --> loss: 9.1875
Epoch: 2/50 | Completion percentage: 66.67% | Step duration 431.27 ms/step --> loss: 9.0000
Epoch: 2/50 | Completion percentage: 100.00% | Step duration 421.63 ms/step --> loss: 8.8125
Epoch 2/50 | Average step duration 425.95 ms/step | Epoch duration 2411.65 ms/epoch --> loss: 9.0000 - val_loss: 8.4579
Epoch: 3/50 | Completion percentage: 33.33% | Step duration 423.60 ms/step --> loss: 8.5000
Epoch: 3/50 | Completion percentage: 66.67% | Step duration 425.40 ms/step --> loss: 8.1250
Epoch: 3/50 | Compl

In [10]:
# Encode the context using the tokenizer and convert it to a tensor
context = "the state of the"
context = torch.tensor(tokenizer.encode(context), dtype=torch.long).unsqueeze(0)
context = context.to(device) # Move the tensor to the GPU if available

# Decode and display the generated text
print(tokenizer.decode(gpt2.generate(context, max_new_tokens=200).squeeze().tolist()))

the state of the e lata;
per sidi in suo ongito se sia
Non viel, e la vchi'io fu de' mi se virto mu'an quel me loto sotto me da rcper tavaltaravei sper vanna mi Delli vannancCome tunra vollo lecchi vNon quesscli vidi sidi con la vper vien le vperanduinavavaltaveattoltella vora l'appalaccaggacc'alt malavum vcorcelel terunai fostnel melli quchi vlelivascista dche terunanindiame tutandel valanalappaveel margnelgnostchi l'io in maura percche sugnerombcIadi se vidi vchi cederan
