In [None]:
import os
import sys
import torch

sys.path.append("..")

from src.utils import device
from src.architectures.gpt import GPT2, GPTConfig, Tokenizer, DataLoader

### Constants and hyperparameters

In [2]:
# Constants
# Location of the training corpus: <current working directory>/dataset/divina_commedia.txt
dataset_path = os.path.join(
    os.getcwd(),
    "dataset",
    "divina_commedia.txt",
)

In [3]:
# Hyperparameters

# --- Optimisation ---
learning_rate = 0.0003  # Learning rate handed to the optimizer
epochs = 500  # Number of training epochs

# --- Batching ---
batch_size = 1024  # Batch size for training

# Micro batch size used for gradient accumulation: when a full batch is too large
# to fit into memory it is processed as smaller micro batches whose gradients are
# accumulated before the weights are updated.
# NOTE(review): the exact unit (items per micro batch vs. number of micro batches)
# is defined by GPT2.fit, which is not visible here -- confirm against it.
micro_batch_size = 4

# --- Data ---
sequence_length = 32  # Number of tokens in each training sequence
train_val_split = 0.1  # Fraction of the data held out for validation

### Initializations

In [4]:
# Seed PyTorch's random number generator so that runs are reproducible
torch.manual_seed(1337)

# Let float32 matrix multiplications use reduced internal precision
# ('high' instead of the default) in exchange for better performance
torch.set_float32_matmul_precision("high")

### Data loading

In [5]:
# Create the tokenizer, configured with the 'gpt2' identifier
tokenizer = Tokenizer("gpt2")

In [6]:
# Build the data loader from the text file; it receives the tokenizer, the
# train/validation split ratio and the device to work with
data_loader = DataLoader(
    txt_file=dataset_path,
    tokenizer=tokenizer,
    train_val_split=train_val_split,
    device=device,
)

# Report how many tokens ended up in each split
print("Training set size: ", len(data_loader.train_tokens))
print("Validation set size: ", len(data_loader.val_tokens))

Training set size:  106392
Validation set size:  11821


In [7]:
# Model configuration.
# vocab_size is 50304 rather than the 50257 entries of the standard gpt2 tokenizer:
# the vocabulary is padded with extra tokens so that its size is a multiple of 8,
# which improves performance when training in 16-bit floating point.
model_config = GPTConfig(
    context_size=1024,
    vocab_size=50304,
    n_blocks=12,
    n_heads=12,
    n_embed=768,
)

### Building the model

In [8]:
# Instantiate GPT-2 from the configuration, cast it to bfloat16 for improved
# performance, then move it to the selected device (the GPU if available)
gpt2 = GPT2(model_config).to(torch.bfloat16).to(device)

# Wrap the model with torch.compile to optimize performance.
# NOTE(review): torch.compile returns a wrapper around the module; custom methods
# such as fit/generate are presumably resolved on the wrapped original module --
# confirm that they actually execute the compiled forward pass.
gpt2 = torch.compile(gpt2)

### Training the model

In [None]:
# Train the model with the hyperparameters defined above; progress
# (step/epoch duration, loss and val_loss) is logged once per epoch
gpt2.fit(  # type: ignore
    data_loader=data_loader,
    epochs=epochs,
    lr=learning_rate,
    batch_size=batch_size,
    micro_batch_size=micro_batch_size,
    sequence_length=sequence_length,
)

Epoch 1/500 | Avg step duration: 530.82 ms/step | Epoch duration: 2728.07 ms/epoch --> loss: 10.1875 - val_loss: 9.1977
Epoch 2/500 | Avg step duration: 426.94 ms/step | Epoch duration: 2414.28 ms/epoch --> loss: 9.0000 - val_loss: 8.4579
Epoch 3/500 | Avg step duration: 424.04 ms/step | Epoch duration: 2406.24 ms/epoch --> loss: 8.1875 - val_loss: 7.7096
Epoch 4/500 | Avg step duration: 424.81 ms/step | Epoch duration: 2426.40 ms/epoch --> loss: 7.4062 - val_loss: 7.2082
Epoch 5/500 | Avg step duration: 426.11 ms/step | Epoch duration: 2413.45 ms/epoch --> loss: 6.9062 - val_loss: 6.6579
Epoch 6/500 | Avg step duration: 424.55 ms/step | Epoch duration: 2408.31 ms/epoch --> loss: 6.5000 - val_loss: 6.3913
Epoch 7/500 | Avg step duration: 429.19 ms/step | Epoch duration: 2437.44 ms/epoch --> loss: 6.2500 - val_loss: 6.2836
Epoch 8/500 | Avg step duration: 423.94 ms/step | Epoch duration: 2406.37 ms/epoch --> loss: 6.2500 - val_loss: 6.2313
Epoch 9/500 | Avg step duration: 423.57 ms/step

In [None]:
# Build the initial context: a single token with id 0, shaped (batch=1, sequence=1)
# and allocated directly on the target device. No prompt is encoded with the
# tokenizer here; generation simply starts from token id 0, whose decoding is the
# first thing that appears in the printed text.
context = torch.zeros((1, 1), dtype=torch.long, device=device)

# Let the model extend the context by 200 new tokens
generated_ids = gpt2.generate(context, max_new_tokens=200) # type: ignore

# Drop the batch dimension, turn the token ids into a plain list,
# then decode them back to text and display the result
print(tokenizer.decode(generated_ids.squeeze().tolist()))

!
Sangue; e quei, quindi mi rechi
dall'altra, dietro, e lei mifova, ed qu torn tal color li, per che quant quel quel da come tut poai quinatt di poatt cost disom fu comel guel sole chel dalle cella drve' venanta quell'ho se quello suca da un chello non credo
inciute leatt, e non quella son mella so tu is il fu fu farper tempo per questel quelle son se tut te son la la queta prop te sentinci guardia li qual far te mia passai son questun' tu dimindi, e son vindi ved fu nai torn ben dimai se un chiedi', quella guardai non volora cai pur non dalla per foratt' qual pensai fu son pos sin
