In [None]:
import os
import torch

from src.torch.utils import device
from src.utils import load_txt_file
from src.tokenizer import Tokenizer
from src.torch.data_loader import DataLoader
from src.torch.transformer import Transformer

### Constants and hyperparameters

In [11]:
# Constants
# Both paths are resolved against the current working directory, so the notebook
# must be started from the directory that contains `dataset/` and `checkpoints/`.
_working_dir = os.getcwd()
dataset_path = os.path.join(_working_dir, 'dataset', 'input.txt')  # raw text corpus to train on
tokenizer_path = os.path.join(_working_dir, 'checkpoints', 'tokenizer.json')  # previously saved tokenizer state

In [None]:
# Hyperparameters

# --- Data ---
train_val_split = 0.9        # fraction of the data used for training; the remaining 10% is held out for validation
batch_size = 64              # number of sequences per batch
block_size = 256             # context window: the sequence length the model attends over

# --- Optimisation ---
learning_rate = 1e-3         # optimizer learning rate
epochs = 500                 # training "epochs" as counted by `train_model` (one loss line is logged per epoch)
eval_iters = 10              # iterations used when evaluating the model (presumably batches averaged per loss estimate)

# --- Architecture ---
n_embed = 384                # dimensionality of the token embeddings
num_attention_heads = 6      # heads in each multi-head attention layer
num_transformer_blocks = 6   # number of stacked transformer blocks
dropout = 0.2                # dropout probability

### Initializations

In [13]:
# Seed PyTorch's global random number generator so that runs are reproducible.
# The trailing semicolon keeps the returned torch.Generator out of the cell output.
random_seed = 1337
torch.manual_seed(random_seed);

### Data loading

In [14]:
# Restore the tokenizer from the JSON state saved at `tokenizer_path`
# (no training happens here; the vocabulary comes entirely from the saved file)
tokenizer = Tokenizer()
tokenizer.load_state(tokenizer_path)

In [15]:
# Read the raw corpus, turn it into token ids, and pack those ids into an
# int64 (torch.long) tensor for the data loader
# (assumes `encode` returns a flat sequence of ids — TODO confirm)
text = load_txt_file(dataset_path)
encoded_text = tokenizer.encode(text)
data = torch.tensor(encoded_text, dtype=torch.long)

In [16]:
# Hand the token tensor to the project's DataLoader together with the train/validation ratio
data_handler = DataLoader(data=data, train_val_split=train_val_split)

### Building the model

In [17]:
# Collect the architecture settings from the hyperparameter cell, then build the model.
# The constructor reports which device the model was moved to.
model_kwargs = dict(
    vocab_size=tokenizer.vocab_size,  # type: ignore
    n_embed=n_embed,
    n_heads=num_attention_heads,
    block_size=block_size,
    n_transformer_blocks=num_transformer_blocks,
    dropout=dropout,
)
language_model = Transformer(**model_kwargs)

Model moved to device: mps


### Training the model

In [18]:
# Run the training loop; the train and validation losses for each epoch are printed below the cell
language_model.train_model(
    data_loader=data_handler,
    epochs=epochs,
    lr=learning_rate,
    batch_size=batch_size,
    eval_iters=eval_iters,
)

Epoch 1/500 - Train Loss: 7.1033, Val Loss: 7.0970
Epoch 2/500 - Train Loss: 6.3920, Val Loss: 6.4655
Epoch 3/500 - Train Loss: 6.0905, Val Loss: 6.0725
Epoch 4/500 - Train Loss: 5.7916, Val Loss: 5.8165
Epoch 5/500 - Train Loss: 5.5507, Val Loss: 5.5572
Epoch 6/500 - Train Loss: 5.3630, Val Loss: 5.3522
Epoch 7/500 - Train Loss: 5.2500, Val Loss: 5.2770
Epoch 8/500 - Train Loss: 5.2101, Val Loss: 5.2534
Epoch 9/500 - Train Loss: 5.2063, Val Loss: 5.2511
Epoch 10/500 - Train Loss: 5.1816, Val Loss: 5.2306
Epoch 11/500 - Train Loss: 5.1752, Val Loss: 5.2044
Epoch 12/500 - Train Loss: 5.1322, Val Loss: 5.1698
Epoch 13/500 - Train Loss: 5.1153, Val Loss: 5.1428
Epoch 14/500 - Train Loss: 5.0602, Val Loss: 5.1285
Epoch 15/500 - Train Loss: 5.0454, Val Loss: 5.0358
Epoch 16/500 - Train Loss: 4.9656, Val Loss: 5.0266
Epoch 17/500 - Train Loss: 4.9001, Val Loss: 4.9540
Epoch 18/500 - Train Loss: 4.8673, Val Loss: 4.9235
Epoch 19/500 - Train Loss: 4.7798, Val Loss: 4.8656
Epoch 20/500 - Train 

### Inference

In [19]:
# Generate some text from the trained model.
# Sampling should run in inference mode. The model was built with dropout = 0.2, and dropout
# would otherwise remain active while generating. The previous train/eval mode is restored
# afterwards so that re-running the training cell still trains with dropout. The isinstance
# guard keeps this cell working even if Transformer is not a torch.nn.Module.
is_torch_module = isinstance(language_model, torch.nn.Module)
was_training = is_torch_module and language_model.training
if is_torch_module:
    language_model.eval()

# Start from a single token with id 0 on `device`
# (assumes id 0 is an acceptable start token for this tokenizer — TODO confirm)
context = torch.zeros((1, 1), dtype=torch.long, device=device)

# Sampling needs no gradients; disabling autograd avoids building an unused graph
with torch.no_grad():
    generated = language_model.generate(context, max_new_tokens=100)

if is_torch_module:
    language_model.train(was_training)

# Decode and display the generated text.
# `flatten()` always yields a flat list of token ids (this assumes a batch of one). By contrast,
# `squeeze()` would collapse a length-1 result into a bare int rather than a list.
print(tokenizer.decode(generated.flatten().tolist()))

 Tart thou hast scolder in a done:
Say she shall save but it your ships,
And to think not be bound Go, young low incheeks of mine,
Before I have not to my sick go, not your passing revenge
Would not re
