In [None]:
import os
import sys
import torch

from src.utils import device
from src.tokenizer import Tokenizer
from src.data_loader import DataLoader
from src.transformer import Transformer

# Add the parent directory to the import path so the `common` package imported
# right below can be resolved. The path is normalized and only added when it is
# not already present, so re-running this cell does not pile up duplicate
# entries in sys.path.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from common.src.utils import load_txt_file

### Constants and hyperparameters

In [None]:
# Filesystem locations, both resolved relative to the current working directory:
#   dataset_path   -> <cwd>/dataset/input.txt
#   tokenizer_path -> <cwd>/checkpoints/tokenizer.json
dataset_path, tokenizer_path = (
    os.path.join(os.getcwd(), folder, filename)
    for folder, filename in (
        ('dataset', 'input.txt'),
        ('checkpoints', 'tokenizer.json'),
    )
)

In [None]:
# Hyperparameters

# --- Data ---
train_val_split = 0.9  # fraction of the data used for training (the remaining 10% is for validation)
batch_size = 64  # number of samples per batch
block_size = 256  # sequence length, i.e. the context window

# --- Model architecture ---
n_embed = 384  # dimensionality of the token embeddings
num_attention_heads = 6  # attention heads in the multi-head attention mechanism
num_transformer_blocks = 6  # transformer blocks in the model
dropout = 0.2  # dropout rate

# --- Optimization ---
learning_rate = 1e-3  # learning rate for the optimizer
epochs = 500  # number of epochs to train the model for
eval_iters = 10  # number of iterations used to evaluate the model

### Initializations

In [None]:
# Fix PyTorch's random seed so that runs are reproducible.
# The trailing semicolon suppresses the notebook's display of the return value.
torch.random.manual_seed(1337);

### Data loading

In [None]:
# Instantiate the tokenizer (project class imported from src.tokenizer)
tokenizer = Tokenizer()

# Load the state of the tokenizer from the file at `tokenizer_path`
# NOTE(review): presumably this restores a previously saved vocabulary — confirm
# against Tokenizer.load_state; the file must already exist at that path.
tokenizer.load_state(tokenizer_path)

In [None]:
# Load the text file located at `dataset_path` using the shared project helper
text = load_txt_file(dataset_path)

# Encode the text using the tokenizer
# NOTE(review): assumes `encode` returns a flat sequence of integer token ids — confirm
encoded_text = tokenizer.encode(text)

# Convert the encoded data to a tensor of 64-bit integers (torch.long)
data = torch.tensor(encoded_text, dtype=torch.long)

In [None]:
# Build the project's data handler from the encoded tensor and the train/validation split ratio
data_handler = DataLoader(data=data, train_val_split=train_val_split)

### Building the model

In [None]:
# Build the transformer language model from the tokenizer's vocabulary size
# and the hyperparameters defined earlier in the notebook
language_model = Transformer(
    vocab_size=tokenizer.vocab_size,  # type: ignore
    n_embed=n_embed,
    n_heads=num_attention_heads,
    block_size=block_size,
    n_transformer_blocks=num_transformer_blocks,
    dropout=dropout,
)

### Training the model

In [None]:
# Run training through the model's own `train_model` method, feeding it the
# data handler and the optimization hyperparameters
language_model.train_model(
    data_loader=data_handler,
    epochs=epochs,
    lr=learning_rate,
    batch_size=batch_size,
    eval_iters=eval_iters,
)

### Inference

In [None]:
# Generate some text from the trained model.
# The starting context is a single token with id 0, shaped (1, 1), created on `device`.
context = torch.zeros((1, 1), dtype=torch.long, device=device)

# Generation needs no gradients: disabling autograd avoids tracking operations
# for each of the 100 sampled steps.
# NOTE(review): if `generate` does not itself put the model in evaluation mode,
# dropout may still be active while sampling — confirm in Transformer.generate.
with torch.no_grad():
    generated_tokens = language_model.generate(context, max_new_tokens=100)

# Decode and display the generated text
print(tokenizer.decode(generated_tokens.squeeze().tolist()))