In [None]:
import os
import sys
import torch

sys.path.append("..")

from src.utils import device
from src.architectures.transformer import Tokenizer, Transformer, DataLoader

### Constants and hyperparameters

In [None]:
# Filesystem locations, both resolved against the current working directory
dataset_path = os.path.join(
    os.getcwd(),
    'dataset',
    'divina_commedia.txt',
)
tokenizer_path = os.path.join(
    os.getcwd(),
    'checkpoints',
    'tokenizer.json',
)

In [None]:
# Hyperparameters, grouped by the part of the pipeline that consumes them

# -- Data handling --
# Fraction of the data used for training (90%); the remaining 10% is for validation
train_val_split = 0.9

# -- Model architecture --
# Sequence length, i.e. the context window
block_size = 512
# Dimensionality of the token embeddings
n_embed = 384
# Number of attention heads in the multi-head attention mechanism
num_attention_heads = 8
# Number of transformer blocks in the model
num_transformer_blocks = 8
# Dropout rate
dropout = 0.2

# -- Training loop --
# Number of samples in each batch
batch_size = 32
# Learning rate for the optimizer
learning_rate = 0.001
# Number of steps to train the model for
training_steps = 500
# Number of iterations used to evaluate the model
eval_iters = 1

### Initializations

In [None]:
# Seed PyTorch's random number generator so that repeated runs are reproducible.
# The trailing semicolon keeps the notebook from echoing the call's return value.
torch.manual_seed(1337);

### Data loading

In [None]:
def load_txt_file(path: str) -> str:
    """
    Read a UTF-8 encoded text file and hand back everything in it as one string.

    Parameters:
    - path (str): Location of the text file to read.

    Returns:
    - str: The full contents of the file.

    Raises:
    - FileNotFoundError: If nothing exists at `path`.
    """

    # Normal case first: the path exists, so read it in a single call
    if os.path.exists(path):
        with open(path, mode='r', encoding='utf-8') as handle:
            contents = handle.read()
        return contents

    # Nothing exists at that path: fail with an explicit message
    raise FileNotFoundError(f'The file "{path}" does not exist.')

In [None]:
# Instantiate the tokenizer (project class imported from src.architectures.transformer)
tokenizer = Tokenizer()

# Load the previously saved state of the tokenizer from the file at `tokenizer_path`
tokenizer.load(tokenizer_path)

# Extract the vocabulary size; it is passed to the Transformer constructor further down
vocab_size = tokenizer.get_vocab_size()

In [None]:
# Read the whole dataset file into memory as a single string
text = load_txt_file(dataset_path)

# Encode the text using the tokenizer
# NOTE(review): presumably this returns a flat sequence of integer token ids — confirm against Tokenizer.encode
encoded_text = tokenizer.encode(text)

# Convert the encoded text to a tensor of 64-bit integers (torch.long)
data = torch.tensor(encoded_text, dtype=torch.long)

In [None]:
# Wrap the encoded data in the project's DataLoader (the one imported from
# src.architectures.transformer), giving it the train/validation split
# fraction and the target device
data_handler = DataLoader(data=data, train_val_split=train_val_split, device=device)

### Building the model

In [None]:
# Build the Transformer language model from the vocabulary size,
# the architecture hyperparameters and the target device
language_model = Transformer(
    vocab_size=vocab_size,
    n_embed=n_embed,
    n_heads=num_attention_heads,
    block_size=block_size,
    n_transformer_blocks=num_transformer_blocks,
    dropout=dropout,
    device=device,
)

### Training the model

In [None]:
# Fit the language model, drawing its data from the data handler and
# using the training hyperparameters
language_model.fit(
    data_loader=data_handler,
    steps=training_steps,
    lr=learning_rate,
    batch_size=batch_size,
    eval_iters=eval_iters,
)

### Inference

In [None]:
# Build the starting context for generation: a 1x1 tensor of zeros (torch.long) on the target device
# NOTE(review): presumably 0 is a valid token id to start from — confirm against the tokenizer's vocabulary
context = torch.zeros((1, 1), dtype=torch.long, device=device)

# Iterate over the tokens generated by the transformer
# NOTE(review): with stream=True the call is consumed as an iterable; presumably it yields
# tokens one at a time as they are produced, up to max_new_tokens — confirm against Transformer.generate
for token in language_model.generate(context, max_new_tokens=200, stream=True):
    # Decode the token: .item() extracts the single value, and decode() is given a one-element list
    decoded_token = tokenizer.decode([int(token.item())])

    # Print the decoded token without a trailing newline, flushing so the text appears immediately
    print(decoded_token, end='', flush=True)