# Transformer Experiment

## src

### Configs

In [None]:
# ---------------------------------------------------------------------------
# Dataset configuration: input data locations and output folders
# ---------------------------------------------------------------------------
TRAIN_DATASETS_PATH = './data/train_data'
NORMALIZE_TRAIN_DATA_PATH = './data/norm_train_data'
TEST_DATASETS_PATH = './data/test_data'
NORMALIZE_TEST_DATA_PATH = './data/norm_test_data'
CONTRACTIONS_PATH = './data/contractions.json'  # JSON map: contraction -> expansion
GENERATE_PATH = './generated'
MODELS_PATH = './checkpoints'
PLOTS_PATH = './plots'
LOGS_PATH = './logs'

# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
BATCH_SIZE = 128
BLOCK_SIZE = 100

# Hyper-parameters for the LSTM model
LSTM_CONFIGS = {
    'embedding_dim': 64,
    'hidden_dim': 64,
    'num_layers': 2,
    'dropout': 0.4,
}

# Hyper-parameters for the Transformer model
TRANSFORMER_CONFIGS = {
    'embedding_dim': 64,
    'n_head': 2,
    'n_encoders': 2,
    'n_decoders': 2,
    'dim_feedforward': 64,
    'dropout': 0.2,
}

LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.01
EPOCHS = 40
GAMMA = 0.98  # presumably a learning-rate decay factor — confirm against the training loop
NUM_TRAIN_DATA = 4


### Utils

In [None]:
import re
import json
from typing import Dict



def load_contractions(file_path: str) -> Dict[str, str]:
    '''
    Load contractions from a JSON file.

    The file is decoded as UTF-8 explicitly, so the result does not depend on the
    platform's default (locale) encoding.

    Args:
        file_path (str): The path to the JSON file containing contractions.

    Returns:
        Dict[str, str]: A dictionary mapping contractions to their expanded forms.
    '''
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def expand_contractions(text: str, contractions_dict: Dict[str, str]) -> str:
    '''
    Expand contractions in a given text using a provided contractions dictionary.

    Matching is case-insensitive. When contractions share a prefix (e.g. "can't" and
    "can't've") the longest one wins: the alternation is built longest-first, and a
    regex alternation takes the first branch that matches.

    Args:
        text (str): The input text containing contractions.
        contractions_dict (Dict[str, str]): A dictionary mapping contractions to their expanded forms.

    Returns:
        str: The text with contractions expanded. The text is returned unchanged when
            there is no non-empty contraction to look for.
    '''
    # Empty keys would make the pattern match the empty string at every position.
    keys = sorted((key for key in contractions_dict if key), key=len, reverse=True)
    if not keys:
        return text

    # Compile the regular expression pattern for matching contractions (longest first)
    contractions_pattern = re.compile(
        '|'.join(re.escape(key) for key in keys),
        flags=re.IGNORECASE
    )

    # Case-insensitive lookup table for dictionaries whose keys are not all lowercase
    lowered = {key.lower(): value for key, value in contractions_dict.items()}

    # Function to replace each contraction with its expanded form
    def replace(match):
        key = match.group(0).lower()
        # An exact lowercase key takes precedence; otherwise fall back to the
        # case-folded table so keys such as "I'm" still resolve.
        if key in contractions_dict:
            return contractions_dict[key]
        return lowered[key]

    # Substitute contractions in the text using the replace function
    return contractions_pattern.sub(replace, text)


In [None]:
import torch


def get_device() -> torch.device:
    '''
    Choose the device computations should run on.

    Returns:
        torch.device: ``cuda`` when CUDA is available, otherwise ``cpu``.
    '''
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


In [None]:
import torch
from torch import nn

from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix


def accuracy_fn(model_logits: torch.Tensor, labels: torch.Tensor) -> float:
    '''
    Compute the accuracy of a classification model.

    Args:
        model_logits (torch.Tensor): The logits or outputs from the model, of shape (N, C),
            where N is the number of samples and C is the number of classes.
        labels (torch.Tensor): The true labels, of shape (N,), where each value is in the range [0, C-1].

    Returns:
        float: The fraction of correctly classified samples, in the range [0, 1]
            (multiply by 100 for a percentage).
    '''
    # softmax is monotonic, so the arg-max of the raw logits is already the predicted class.
    preds = model_logits.argmax(dim=1)

    return torch.sum(preds == labels).item() / len(labels)

def get_perplexity(model_logits: torch.Tensor, labels: torch.Tensor) -> float:
    '''
    Calculate the perplexity of a language model.

    Perplexity is the exponential of the mean cross-entropy between the logits
    and the target indices.

    Args:
        model_logits (torch.Tensor): The logits or outputs from the model.
        labels (torch.Tensor): The true labels (cast to ``torch.long`` before use).

    Returns:
        float: The perplexity of the model.
    '''
    targets = labels.to(torch.long)
    mean_nll = nn.functional.cross_entropy(model_logits, targets)
    return torch.exp(mean_nll).item()

def get_precision(model_logits: torch.Tensor, labels: torch.Tensor) -> float:
    '''
    Compute the support-weighted precision of a classification model.

    Args:
        model_logits (torch.Tensor): The logits or outputs from the model; the predicted
            class is the arg-max over dimension 1.
        labels (torch.Tensor): The true labels (cast to int64 before scoring).

    Returns:
        float: The precision of the model, averaged over classes weighted by support.
            An ill-defined per-class precision (class never predicted) counts as 0,
            which is scikit-learn's default value, but set explicitly to avoid a
            warning on every call.
    '''
    # softmax is monotonic, so the arg-max of the raw logits gives the same prediction.
    preds = model_logits.argmax(dim=1).detach().cpu().numpy()
    labels = labels.detach().cpu().numpy().astype('int64')

    return precision_score(labels, preds, average='weighted', zero_division=0)

def get_recall(model_logits: torch.Tensor, labels: torch.Tensor) -> float:
    '''
    Compute the support-weighted recall of a classification model.

    Args:
        model_logits (torch.Tensor): The logits or outputs from the model; the predicted
            class is the arg-max over dimension 1.
        labels (torch.Tensor): The true labels (cast to int64 before scoring).

    Returns:
        float: The recall of the model, averaged over classes weighted by support.
            Ill-defined per-class values count as 0 (scikit-learn's default value,
            set explicitly to avoid a warning on every call).
    '''
    # softmax is monotonic, so the arg-max of the raw logits gives the same prediction.
    preds = model_logits.argmax(dim=1).detach().cpu().numpy()
    labels = labels.detach().cpu().numpy().astype('int64')

    return recall_score(labels, preds, average='weighted', zero_division=0)

def get_f1_score(model_logits: torch.Tensor, labels: torch.Tensor) -> float:
    '''
    Compute the support-weighted F1 score of a classification model.

    Args:
        model_logits (torch.Tensor): The logits or outputs from the model; the predicted
            class is the arg-max over dimension 1.
        labels (torch.Tensor): The true labels (cast to int64 before scoring).

    Returns:
        float: The F1 score of the model, averaged over classes weighted by support.
            Ill-defined per-class values count as 0 (scikit-learn's default value,
            set explicitly to avoid a warning on every call).
    '''
    # softmax is monotonic, so the arg-max of the raw logits gives the same prediction.
    preds = model_logits.argmax(dim=1).detach().cpu().numpy()
    labels = labels.detach().cpu().numpy().astype('int64')

    return f1_score(labels, preds, average='weighted', zero_division=0)

def get_specificity(model_logits: torch.Tensor, labels: torch.Tensor) -> float:
    '''
    Compute the macro-averaged specificity of a classification model.

    For every class in the confusion matrix the one-vs-rest specificity
    TN / (TN + FP) is computed (0 when TN + FP is 0). The unweighted mean over
    classes is returned.

    Args:
        model_logits (torch.Tensor): The logits or outputs from the model; the predicted
            class is the arg-max over dimension 1.
        labels (torch.Tensor): The true labels (cast to int64, like the other metrics).

    Returns:
        float: The specificity of the model, or 0.0 when the confusion matrix is empty.
    '''
    import numpy as np

    # softmax is monotonic, so the arg-max of the raw logits gives the same prediction.
    preds = model_logits.argmax(dim=1).detach().cpu().numpy()
    labels = labels.detach().cpu().numpy().astype('int64')

    # Rows are true classes, columns are predicted classes
    cm = confusion_matrix(labels, preds)
    if cm.size == 0:
        return 0.0

    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp  # predicted as the class but belonging to another
    fn = cm.sum(axis=1) - tp  # belonging to the class but predicted as another
    tn = cm.sum() - (tp + fp + fn)

    denom = tn + fp
    per_class = np.zeros(len(cm), dtype=np.float64)
    # Classes with TN + FP == 0 keep a specificity of 0 instead of dividing by zero
    np.divide(tn, denom, out=per_class, where=denom > 0)

    return float(per_class.mean())


In [None]:
import logging


def configure_logger(name=__name__, log_file='app.log', level=logging.INFO) -> logging.Logger:
    '''
    Configure (once) and return a named logger that writes to the console and to a file.

    Handlers are attached only the first time a given ``name`` is configured. Later calls
    return the same logger without adding duplicate handlers. The logger does not
    propagate records to its parent loggers.

    Args:
        name (str): The name of the logger.
        log_file (str): The file to log messages to. Missing parent directories are
            created, and the file is written as UTF-8.
        level (int): The logging level.

    Returns:
        logging.Logger: The logger object.
    '''
    import os

    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False  # do not pass logs to the default logger

    # Check if the logger already has handlers to prevent adding multiple handlers
    if logger.handlers:
        return logger

    # Create the formatter objects for the handlers
    file_formatter = logging.Formatter('%(asctime)s \t %(filename)s \t %(levelname)s \t %(message)s')
    stdout_formatter = logging.Formatter('%(levelname)s \t %(message)s')

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(level)
    console_handler.setFormatter(stdout_formatter)

    # FileHandler raises if the directory does not exist, so create it first
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    file_handler = logging.FileHandler(log_file, encoding='utf-8')
    file_handler.setLevel(level)
    file_handler.setFormatter(file_formatter)

    # Add the handlers to the logger
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    return logger


In [None]:
import torch


def temperature_sampling(model_logits: torch.Tensor, temperature: float = 1.0) -> int:
    '''
    Sample the next token index from model logits after temperature scaling.

    Args:
        model_logits (torch.Tensor): Logits (raw predictions) for the next-token prediction,
            of shape [vocab_size] or [sequence_length, vocab_size]. In the 2-D case only the
            last row (the prediction after the final position) is used.
        temperature (float): Strictly positive value the logits are divided by before the
            softmax. Higher values make the probability distribution flatter (more random),
            while lower values make it sharper (more deterministic). Default is 1.0.

    Returns:
        int: The sampled token index from the probability distribution.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    '''
    if temperature <= 0:
        raise ValueError(f'temperature must be > 0, got {temperature}')

    # For a [sequence_length, vocab_size] input keep only the last position's logits
    logits = model_logits[-1] if model_logits.dim() == 2 else model_logits

    # Scale by temperature to change the steepness of the distribution, then normalize
    probs = torch.softmax(logits / temperature, dim=-1)

    # Sample one index from the probability distribution
    return torch.multinomial(probs, num_samples=1).item()

def create_tgt(prev_batch: torch.Tensor, x_batch: torch.Tensor) -> torch.Tensor:
    '''
    Build a target sequence by shifting ``x_batch`` one step to the right.

    Each row of the result starts with the corresponding entry of ``prev_batch``,
    followed by that row of ``x_batch`` without its last element. The result has the
    same shape, dtype and device as ``x_batch``, and ``x_batch`` itself is not modified.

    Args:
        prev_batch (torch.Tensor): The token to put in front of each row (for example the
            token preceding each sequence, or a start-of-sequence token).
        x_batch (torch.Tensor): The current batch of input sequences, indexed as [row, position].

    Returns:
        torch.Tensor: For each row ``i``: ``[prev_batch[i], x_batch[i, 0], ..., x_batch[i, -2]]``.
    '''
    # Rolling right by one places x[:, :-1] in columns 1.. of a fresh tensor; the
    # wrapped-around last element lands in column 0 and is overwritten below.
    shifted = torch.roll(x_batch, shifts=1, dims=1)
    shifted[:, 0] = prev_batch
    return shifted


In [None]:
from src.utils.log import configure_logger

import torch
from torch import nn
from pathlib import Path
from os import remove


# Get the logger for this module
logger = configure_logger(__name__)


def save_model(model: nn.Module, path: str, stops=False) -> None:
    '''
    Save a PyTorch model's ``state_dict`` to a specified path.

    Args:
        model (nn.Module): The PyTorch model to be saved.
        path (str): Destination file; must end with ``.pt`` or ``.pth``. Missing parent
            directories are created.
        stops (bool, optional): If True and a file already exists at ``path``, the existing
            file is kept and nothing is saved. If False, the existing file is deleted and
            replaced. Defaults to False.

    Returns:
        None. An invalid extension is logged as an error and nothing is saved; no
        exception is raised.
    '''
    # pathlib splits the path portably (works for absolute and OS-specific separators)
    file_path = Path(path)
    target_path = file_path.parent
    model_name = file_path.name

    if not model_name.endswith(('.pth', '.pt')):
        logger.error('Wrong extension: Expecting `.pt` or `.pth`.')
        return

    # Create the directory that the model is going to be saved in if it does not exist
    target_path.mkdir(parents=True, exist_ok=True)

    # If path already exists
    if file_path.is_file():
        logger.info(f'Model `{model_name}` already exists on `{target_path}`.')
        if stops:
            return
        logger.info(f'Deleting `{path}`.')
        remove(path)

    # Saving the Model to the given path
    logger.info(f'Saving Model `{model_name}` to `{target_path}`.')
    torch.save(model.state_dict(), path)

    logger.info(f'Model Successfully Saved to `{path}`.')


def load_model(model_class: nn.Module, model_path: str, device: torch.device=torch.device('cpu'), model_device: bool=False, **kwargs) -> nn.Module:
    '''
    Instantiate ``model_class``, load a saved ``state_dict`` into it and return it in eval mode.

    Parameters:
        model_class: The model class to instantiate (called with ``**kwargs``).
        model_path (str): Path to the saved state dict file (e.g., 'model.pth').
        device (torch.device): Device the weights are mapped to and the model is moved to.
            Default is CPU.
        model_device (bool): If True, ``device=device`` is also passed to the model
            constructor. Default is False.
        **kwargs: Additional arguments required to initialize the model class.

    Returns:
        nn.Module: The loaded model, placed on ``device`` and in evaluation mode.
    '''
    # Initialize the model
    if model_device:
        model = model_class(device=device, **kwargs)
    else:
        model = model_class(**kwargs)

    # NOTE(security): torch.load unpickles the file, which can execute arbitrary code.
    # Only load trusted checkpoints; since a plain state_dict is expected here, consider
    # passing weights_only=True on torch versions that support it.
    state_dict = torch.load(model_path, map_location=torch.device(device))

    # Load the parameters into the model
    model.load_state_dict(state_dict)

    # load_state_dict copies values into the existing parameters without moving them,
    # so put the module on the requested device explicitly.
    model.to(device)

    # Set the model to evaluation mode
    model.eval()

    logger.info('Model succesfully loaded.')

    return model


In [None]:
import json

import os
from typing import List, Union, Dict, Any


def __load_notebook(notebook_path: str) -> Dict[str, Any]:
    '''
    Load an existing Jupyter notebook, or build an empty nbformat-4 notebook.

    Args:
        notebook_path (str): Path to the .ipynb file to load.

    Returns:
        dict: The parsed notebook if the file exists. Otherwise, a new notebook dict with
            no cells and Python 3 kernel metadata (nothing is written to disk here).
    '''
    if os.path.exists(notebook_path):
        # .ipynb files are UTF-8 encoded JSON; don't depend on the locale encoding
        with open(notebook_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    # The file doesn't exist: create an empty notebook structure
    return {
        'cells': [],
        'metadata': {
            'kernelspec': {
                'display_name': 'Python 3',
                'language': 'python',
                'name': 'python3'
            },
            'language_info': {
                'codemirror_mode': {
                    'name': 'ipython',
                    'version': 3
                },
                'file_extension': '.py',
                'mimetype': 'text/x-python',
                'name': 'python',
                'nbconvert_exporter': 'python',
                'pygments_lexer': 'ipython3',
                'version': '3.x'
            }
        },
        'nbformat': 4,
        'nbformat_minor': 2
    }


def py_to_ipynb(
        py_files: List[str],
        output_ipynb: str,
        comment: Union[str, None]=None
    ) -> None:
    '''
    Append Python (.py) files as code cells to a Jupyter Notebook (.ipynb).

    The notebook at ``output_ipynb`` is loaded if it exists (otherwise a new one is
    started) and is rewritten in place. Paths that do not end in ``.py``, or that end in
    ``__.py`` (e.g. ``__init__.py``), are skipped.

    Args:
        py_files (List[str]): Paths to the .py files, added as cells in the given order.
        output_ipynb (str): Path to the output .ipynb file.
        comment (str, optional): Markdown text added as a cell before the code cells.
    '''
    # Load the existing notebook or create a new one
    notebook = __load_notebook(output_ipynb)

    # Add a Markdown cell for the comment if provided
    if comment:
        notebook['cells'].append({
            'cell_type': 'markdown',
            'metadata': {},
            'source': comment
        })

    # Add one code cell per eligible .py file
    for py_file in py_files:
        if py_file.endswith('__.py') or not py_file.endswith('.py'):
            continue

        # Python source files are UTF-8 by default (PEP 3120); don't use the locale encoding
        with open(py_file, 'r', encoding='utf-8') as f:
            source_code = f.read()

        notebook['cells'].append({
            'cell_type': 'code',
            'metadata': {},
            'source': source_code.splitlines(True),  # keep line endings to preserve formatting
            'outputs': [],
            'execution_count': None
        })

    # Write the notebook to a file
    with open(output_ipynb, 'w', encoding='utf-8') as f:
        json.dump(notebook, f, indent=4)


In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, Subset

import numpy as np
import math

from multiprocessing import cpu_count
from typing import Tuple


def get_loaders(
        dataset: Dataset,
        batch_size: int,
        train_pro: float,
        drop_last: bool=False,
        offset: int = 0,
        step: int = 1
    ) -> Tuple[DataLoader, DataLoader, int]:
    '''
    Split the dataset into training and validation sets and build a DataLoader for each.

    The indices are cut into ``len(dataset) // batch_size`` contiguous splits (at least
    one). The list of splits is rotated so that split ``offset`` comes first. The first
    ``int(num_splits * train_pro)`` splits are used for training and the rest for
    validation. Calling again with the returned offset therefore cycles through
    different train/validation partitions.

    Args:
        dataset (Dataset): The dataset to be split and loaded.
        batch_size (int): Number of samples per batch.
        train_pro (float): Proportion of splits to use for training (between 0 and 1).
        drop_last (bool, optional): Whether to drop the last incomplete batch if the dataset
                                    size is not divisible by batch size. Default is False.
        offset (int, optional): Index of the split the training set starts at (taken modulo
                                the number of splits). Default is 0.
        step (int, optional): Amount to increment the offset after each function call,
                              useful for iterating through different splits. Default is 1.

    Returns:
        Tuple[DataLoader, DataLoader, int]:
            - `train_loader`: DataLoader for the training set (shuffled).
            - `valid_loader`: DataLoader for the validation set (may be empty).
            - `offset`: The updated offset value for the next split.

    Raises:
        ValueError: If the dataset is empty.
    '''
    if len(dataset) == 0:
        raise ValueError('Cannot split an empty dataset.')

    # Use NumPy for efficient index handling
    indices = np.arange(len(dataset))
    # A dataset smaller than one batch still gets a single split
    # (np.array_split rejects 0 sections).
    num_splits = max(1, len(dataset) // batch_size)
    splits = np.array_split(indices, num_splits)

    train_splits_size = int(num_splits * train_pro)

    # Rotate the splits so that split `offset` comes first; this is the wrap-around
    # partition: train = the next `train_splits_size` splits, valid = the remainder.
    offset %= num_splits
    rotated = splits[offset:] + splits[:offset]
    train_splits = rotated[:train_splits_size]
    valid_splits = rotated[train_splits_size:]

    # Flatten the lists of splits to index arrays (np.concatenate rejects an empty list)
    empty = np.array([], dtype=indices.dtype)
    train_indices = np.concatenate(train_splits) if train_splits else empty
    valid_indices = np.concatenate(valid_splits) if valid_splits else empty

    # Creating DataLoaders; RandomSampler rejects an empty dataset, so only shuffle a non-empty train set
    train_loader = DataLoader(Subset(dataset, train_indices), batch_size=batch_size, shuffle=len(train_indices) > 0, drop_last=drop_last, num_workers=cpu_count(), pin_memory=True)
    valid_loader = DataLoader(Subset(dataset, valid_indices), batch_size=batch_size, shuffle=False, drop_last=drop_last, num_workers=cpu_count(), pin_memory=True)

    # Update the offset
    offset = (offset + step) % num_splits

    return train_loader, valid_loader, offset


class PositionalEncoding(nn.Module):
    '''
    Fixed sinusoidal positional encodings.

    ``forward`` returns the encodings for the first ``x.size(1)`` positions, with shape
    ``(1, x.size(1), d_model)``. It does NOT add them to ``x``; the caller does that.
    '''

    def __init__(self, block_size: int, d_model: int) -> None:
        '''
        Args:
            block_size (int): Number of positions precomputed (the longest sequence supported).
            d_model (int): Encoding dimension; both even and odd values are supported.
        '''
        super().__init__()

        # Positional encoding matrix
        pe = torch.zeros(block_size, d_model)

        # Position indices, shape (block_size, 1)
        position = torch.arange(0, block_size, dtype=torch.float).unsqueeze(1)

        # Frequencies 1 / 10000^(2i / d_model), one per even dimension index 2i
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # Sine on even dimensions, cosine on odd dimensions. With an odd d_model there is
        # one fewer odd column than even columns, so only d_model // 2 frequencies are used.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Buffers are tensors that are not updated during training but move with .to()
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Args:
            x (torch.Tensor): Input whose dimension 1 is the sequence length
                (presumably (batch, seq_len, d_model) — TODO confirm with the caller).

        Returns:
            torch.Tensor: The encodings for positions ``0 .. x.size(1) - 1``, shape
                ``(1, x.size(1), d_model)`` (fewer positions if ``x.size(1) > block_size``).
        '''
        return self.pe[:, :x.size(1), :]


In [None]:
from src.utils.log import configure_logger

import matplotlib.pyplot as plt

import os
from typing import List, Tuple, Optional


# Get the logger for this module
logger = configure_logger(__name__)


def plot_loss(
        loss_list: List[float],
        type: str = 'eval',
        c: str = 'g',
        figsize: Tuple[int, int] = (6, 4),
        fontsize: int = 14,
        save_path: Optional[str] = None
    ) -> None:
    '''
    Plots the training or evaluation loss over epochs and optionally saves the plot to a specified path.

    Args:
        loss_list (List[float]): List of loss values to be plotted, one per epoch.
        type (str, optional): Type of loss. Default is 'eval'. Accepted values are 'train' or 'eval'.
        c (str, optional): Color of the plot. Default is 'g' (green).
        figsize (Tuple, optional): Size of the figure (width, height) in inches. Default is (6, 4).
        fontsize (int, optional): Font size of the title. Default is 14.
        save_path (Optional[str], optional): Path to save the plot image. If None, the plot is not saved.
            Missing parent directories are created. Default is None.

    Returns:
        None. An invalid ``type`` is logged as an error and nothing is plotted
        (no exception is raised).
    '''
    if type not in ['train', 'eval']:
        logger.error(f'Invalid loss type: Got `{type}`. Only `train`, `eval` are accepted')
        return

    title = "Validation Loss" if type == 'eval' else "Training Loss"

    plt.figure(figsize=figsize)
    plt.plot(range(len(loss_list)), loss_list, c=c, label=title)
    plt.title(title, fontsize=fontsize)
    plt.xlabel("Epochs")
    plt.ylabel("Loss")

    # Draw the grid before saving so the saved image matches the displayed one
    plt.grid(color='gray', linestyle='--', linewidth=0.5)

    if save_path:
        # Create the directory if needed (a bare filename has no directory part)
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path)
        logger.info(f"Plot saved to `{save_path}`.")

    plt.show()


def plot_losses(
        train_loss: List[float],
        eval_loss: List[float],
        c: Optional[List[str]] = None,
        fig_size: Tuple[int, int] = (6, 4),
        font_size: int = 11,
        save_path: Optional[str] = None
    ) -> None:
    '''
    Plots both training and evaluation losses over epochs and optionally saves the plot to a specified path.

    Args:
        train_loss (List[float]): List of training loss values to be plotted.
        eval_loss (List[float]): List of evaluation loss values to be plotted.
        c (List[str], optional): Colors for the plots: ``c[0]`` for the training loss and
            ``c[1]`` for the validation loss. Default (None) is ['g', 'b'] (green training,
            blue validation).
        fig_size (Tuple[int, int], optional): Size of the figure (width, height) in inches. Default is (6, 4).
        font_size (int, optional): Font size of the legend (the title uses 14). Default is 11.
        save_path (Optional[str], optional): Path to save the plot image. If None, the plot is not saved.
            Missing parent directories are created. Default is None.

    Returns:
        None
    '''
    # Avoid a shared mutable default argument
    colors = ['g', 'b'] if c is None else c

    plt.figure(figsize=fig_size)

    plt.plot(range(len(train_loss)), train_loss, c=colors[0], label="Training Loss")
    plt.plot(range(len(eval_loss)), eval_loss, c=colors[1], label="Validation Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Loss Curves", fontsize=14)
    plt.legend(fontsize=font_size)

    # Draw the grid before saving so the saved image matches the displayed one
    plt.grid(color='gray', linestyle='--', linewidth=0.5)

    if save_path:
        # Create the directory if needed (a bare filename has no directory part)
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path)
        logger.info(f"Plot saved to `{save_path}`.")

    plt.show()


### Dataset Class

In [None]:
from src.utils import configure_logger
from src.utils.data import load_contractions, expand_contractions

import torch
from torch.utils.data import Dataset

import os
import re
import string
from tqdm import tqdm
from typing import Dict, List, Tuple, Union, Iterable


class TransformerShakespeareDataset(Dataset):
    '''
    Character-level dataset for text generation using Shakespeare's works.

    Building the dataset has three stages:
    - Normalize: raw ``.txt`` files in ``dataset_path`` are lower-cased, runs of spaces are
      collapsed, contractions are expanded, and characters outside the vocabulary are
      dropped. The results are written to ``norm_dataset_path``.
    - Encode: the normalized files are concatenated in file-name order and encoded
      character by character.
    - Sample: the encoding is cut into overlapping samples
      ``(prev_token, input_sequence, target_token)``, where ``input_sequence`` holds
      ``block_size`` consecutive token ids, ``prev_token`` is the id just before it, and
      ``target_token`` is the id right after it.
    '''

    # Get the logger as a class attribute
    logger = configure_logger(__name__)

    def __init__(self,
            dataset_path: str,
            norm_dataset_path: str,
            constractions_path: str,
            block_size: int,
            to_tensors: bool=False,
            write_norm: bool=True,
            device: torch.device = torch.device('cpu')
        ) -> None:
        '''
        Initializes a TransformerShakespeareDataset object.

        Args:
            dataset_path (str): Path to the directory containing the raw dataset files (created if missing).
            norm_dataset_path (str): Path to save the normalized dataset files (created if missing).
            constractions_path (str): Path to the JSON file containing contractions for expansion.
            block_size (int): Length of the input sequence of every sample.
            to_tensors (bool, optional): If True, input sequences are returned as int64 PyTorch
                tensors (default: False).
            write_norm (bool, optional): If True, (re)writes the normalized files before loading;
                otherwise the files already in ``norm_dataset_path`` are used (default: True).
            device (torch.device, optional): Stored on the instance and shown by ``__str__``
                (default: 'cpu'). Tensors returned by ``__getitem__`` are not moved to it.
        '''
        super().__init__()

        os.makedirs(dataset_path, exist_ok=True)
        os.makedirs(norm_dataset_path, exist_ok=True)

        self.dataset_path = dataset_path
        self.norm_dataset_path = norm_dataset_path
        self.constractions_path = constractions_path
        self.block_size = block_size
        self.to_tensors = to_tensors
        self.device = device

        if write_norm:
            self._normalize_data()

        self.vocab = self._create_vocab()
        self.vocab_size = len(self.vocab)

        # Create mapping from characters to integers
        self.char_to_int = {char: i for i, char in enumerate(self.vocab)}
        self.int_to_char = {i: char for i, char in enumerate(self.vocab)}

        self.samples = self._create_samples()

    def __getitem__(
            self,
            index: Union[int, slice]
        ) -> Union[Tuple[int, Union[torch.Tensor, List[int]], int], List[Tuple[int, Union[torch.Tensor, List[int]], int]]]:
        '''
        Retrieves a sample from the dataset by index, or a list of samples by slice.

        Args:
            index (int or slice): Index or slice to retrieve from the dataset.

        Returns:
            A tuple ``(prev_token, input_sequence, target_token)`` for an integer index, or a
            list of such tuples for a slice. ``input_sequence`` is an int64 torch.Tensor if
            ``to_tensors`` is True, otherwise a list of integers.
        '''
        if isinstance(index, slice):
            return [self._format_sample(sample) for sample in self.samples[index]]
        return self._format_sample(self.samples[index])

    def _format_sample(self, sample: Tuple[int, List[int], int]) -> Tuple[int, Union[torch.Tensor, List[int]], int]:
        '''
        Return a stored sample, converting its input sequence to a tensor if ``to_tensors`` is set.
        '''
        prev_token, input_sequence, target_token = sample
        if self.to_tensors:
            input_sequence = torch.tensor(input_sequence, dtype=torch.int64)
        return (prev_token, input_sequence, target_token)

    def __len__(self) -> int:
        '''
        Returns the total number of samples in the dataset.

        Returns:
            int: Total number of samples in the dataset.
        '''
        return len(self.samples)

    def __str__(self) -> str:
        '''
        Returns a string representation of the dataset object.

        Returns:
            str: String representation of the dataset object.
        '''
        return f'TransformerShakespeareDataset(vocab_size: {self.vocab_size}, block_size: {self.block_size}, to_tensors: {self.to_tensors}, device: {self.device})'

    def _encode(self, s: str) -> List[int]:
        '''
        Encode a string into a list of integers using the vocabulary.

        Raises:
            KeyError: If ``s`` contains a character that is not in the vocabulary.
        '''
        return [self.char_to_int[c] for c in s]

    def _decode(self, l: List[int]) -> str:
        '''
        Decode a list of integers into a string using the vocabulary.
        '''
        return ''.join([self.int_to_char[i] for i in l])

    @staticmethod
    def _create_vocab() -> List[str]:
        '''
        Generate the vocabulary of the dataset.

        It contains all lowercase English letters, all digits (0-9), all ASCII punctuation
        (``string.punctuation``), the whitespace characters ' ', '\\n' and '\\t', and the
        '\\0' character used as the <sos> token.

        Returns:
            List[str]: The vocabulary as a sorted list of single characters. '\\0' sorts
                first, so the <sos> token has index 0.
        '''
        vocab = list(string.ascii_lowercase)
        vocab += list(string.digits)
        vocab += list(string.punctuation)
        vocab += [' ', '\n', '\t']
        vocab.append('\0')  # <sos> token (goes at index 0)

        return sorted(set(vocab))

    @staticmethod
    def _normalize(content: str, contractions_dict: Dict[str, str]) -> str:
        '''
        Normalize the given text.

        The text is lower-cased, runs of spaces are collapsed, and contractions are
        expanded. Characters that are not in the vocabulary are then removed, so the
        result can always be encoded.

        Args:
            content (str): The string to be normalized.
            contractions_dict (Dict[str, str]): Dictionary of contractions for expansion.

        Returns:
            str: The normalized string.
        '''
        content = content.lower()
        content = re.sub(r' +', ' ', content)  # Replace multiple spaces
        # Expansions come from the contractions file and may contain upper-case letters
        content = expand_contractions(content, contractions_dict=contractions_dict).lower()

        # Drop characters the vocabulary cannot encode (e.g. non-ASCII letters or quotes),
        # which would otherwise make `_encode` raise KeyError.
        vocab = set(TransformerShakespeareDataset._create_vocab())
        return ''.join(ch for ch in content if ch in vocab)

    def _proc_file(self, file: str, contractions_dict: Dict[str, str]) -> str:
        '''
        Read a raw dataset file and return its normalized content.

        Args:
            file (str): Name of the file inside ``dataset_path``.
            contractions_dict (Dict[str, str]): Dictionary of contractions for expansion.

        Returns:
            str: Normalized content of the file, or '' if the file is not a .txt file.
        '''
        file_path = os.path.join(self.dataset_path, file)

        # Only identifying .txt files as datasets
        if not file_path.endswith('.txt'):
            return ''

        with open(file_path, encoding='utf-8') as f:
            content = f.read()

        return TransformerShakespeareDataset._normalize(content, contractions_dict)

    def _write_norm_file(self, text: str, file_name: str) -> None:
        '''
        Writes normalized text to a file in ``norm_dataset_path`` (UTF-8).

        Args:
            text (str): Normalized text to be written.
            file_name (str): Name of the file to write the text to.
        '''
        with open(os.path.join(self.norm_dataset_path, file_name), 'w', encoding='utf-8') as f:
            f.write(text)

    def _normalize_data(self) -> None:
        '''
        Normalize every .txt file of ``dataset_path`` into ``norm_dataset_path``.
        '''
        contractions_dict = load_contractions(self.constractions_path)

        # Sorted so that processing (and logging) order is deterministic
        for file_name in sorted(os.listdir(self.dataset_path)):
            TransformerShakespeareDataset.logger.info(f'Processing dataset: {os.path.join(self.dataset_path, file_name)}.')

            # Decide on the extension itself, so an empty .txt file is not reported as non-.txt
            if not file_name.endswith('.txt'):
                TransformerShakespeareDataset.logger.info(f'File {file_name} is not a .txt file, so it\'s been ignored.')
                continue

            # Normalize the content of the file
            norm_text = self._proc_file(file_name, contractions_dict)

            self._write_norm_file(norm_text, file_name=file_name)
            TransformerShakespeareDataset.logger.info(f'Normalized dataset succesfully saved to: {os.path.join(self.norm_dataset_path, file_name)}')

    def _load_texts(self) -> str:
        '''
        Load and concatenate the text content of all files in the normalized data path.

        Files are read in sorted file-name order, so sample order is reproducible;
        sub-directories are skipped.

        Returns:
            str: A single string containing the concatenated text content of all files.
        '''
        texts = []
        for file_name in sorted(os.listdir(self.norm_dataset_path)):
            file_path = os.path.join(self.norm_dataset_path, file_name)
            if not os.path.isfile(file_path):
                continue
            with open(file_path, 'r', encoding='utf-8') as file:
                texts.append(file.read())

        return ''.join(texts)

    def _create_samples(self) -> List[Tuple[int, List[int], int]]:
        '''
        Creates training samples from the concatenated text data.

        Returns:
            List[Tuple[int, List[int], int]]: List of samples, each a tuple
                ``(prev_token, input_sequence, target_token)``. ``input_sequence`` is a list of
                ``block_size`` encoded characters, ``prev_token`` is the character before it,
                and ``target_token`` is the next character in the sequence.
        '''
        # Load the content of those normalized datasets
        text = self._load_texts()

        # Tokenizing the entire dataset
        data = self._encode(text)

        samples: List[Tuple[int, List[int], int]] = []
        # Start at 1 so that every sample has a previous token; stop so the target stays in range
        for i in tqdm(range(1, len(data) - self.block_size), ascii=True, desc='Creating Samples'):
            prev_token = data[i-1]                     # previous token of the sequence
            input_sequence = data[i:i+self.block_size] # the sequence that will be passed into the model
            target_token = data[i+self.block_size]     # next token of the sequence
            samples.append((prev_token, input_sequence, target_token))

        TransformerShakespeareDataset.logger.info('Samples created succesfully.')

        return samples

    def encode(self, input_string: str) -> List[int]:
        '''
        Normalize the given string the same way as the dataset files, then encode it.

        Args:
            input_string (str): The string to be encoded

        Returns:
            List[int]: The encoded version of the string
        '''
        contractions_dict = load_contractions(self.constractions_path)

        norm_input = self._normalize(input_string, contractions_dict)

        return self._encode(norm_input)

    def decode(self, output_list: List[int]) -> str:
        '''
        Decode a list of integers and convert them into a string.

        Args:
            output_list (List[int]): The encoded list of integers

        Returns:
            str: The decoded string
        '''
        return self._decode(output_list)


### Evaluator Class

In [None]:
from src.utils.evaluation import (
    accuracy_fn,
    get_precision,
    get_recall,
    get_specificity,
    get_f1_score,
    get_perplexity
)
from src.utils.log import configure_logger
from src.utils.models import create_tgt

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import multiprocessing
from multiprocessing.managers import BaseManager
from tqdm import tqdm
from typing import Callable, Union, List, Dict


class TransformerEvaluator:
    '''
    A class to evaluate a PyTorch model on a test dataset using various metrics.

    Predictions for the whole test set are collected first; every metric is then
    computed once on the concatenated predictions and labels.
    '''

    # Initialize the logger as a class attribute
    logger = configure_logger(__name__)

    def __init__(self,
            model: nn.Module,
            test_ds: Dataset,
            batch_size: int,
            cretirion: nn.Module,
            device: torch.device=torch.device('cpu')
        ) -> None:
        '''
        Initializes the TransformerEvaluator with the model, test dataset, loss function, and device.

        Args:
            model (nn.Module): The neural network model to be evaluated. It is moved to `device`.
            test_ds (Dataset): Dataset yielding `(prev_token, input_sequence, target_token)` samples.
            batch_size (int): The batch size for creating the DataLoader.
            cretirion (nn.Module): Loss function used for evaluation. The parameter keeps this
                spelling for backward compatibility with existing callers.
            device (torch.device, optional): Device to run the evaluation on (CPU or GPU). Defaults to CPU.
        '''
        self.model = model.to(device, non_blocking=True)
        self.test_ds = test_ds
        self.batch_size = batch_size
        self.cretirion = cretirion
        self.device = device

    def _create_DataLoader(self) -> DataLoader:
        '''
        Create a DataLoader over the test dataset.

        Samples stay in order, and the last (possibly smaller) batch is kept. Metrics are
        computed on the concatenated predictions, so a partial batch is harmless, whereas
        dropping it would silently exclude test samples.
        '''
        return DataLoader(
            dataset=self.test_ds,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=multiprocessing.cpu_count(),
            pin_memory=True
        )

    def _define_metrics(self) -> Dict[str, Callable[[torch.Tensor, torch.Tensor], Union[float, List[List[float]]]]]:
        '''
        Map every result key to the function computing it from `(y_pred, y_true)`.
        '''
        return {
            'Loss': self.cretirion,
            'perplexity': get_perplexity,
            'accuracy': accuracy_fn,
            'precision': get_precision,
            'recall': get_recall,
            'specificity': get_specificity,
            'f1_score': get_f1_score,
        }

    def evaluate(self) -> Dict[str, Union[float, List[List[float]]]]:
        '''
        Evaluates the model on the test dataset using various metrics.

        Metrics are computed sequentially in this process. Running them in child processes
        required pickling a nested function, which fails under the "spawn" start method, and
        it silently left a metric at 0.0 whenever its process crashed.

        Returns:
            Dict[str, Union[float, List[List[float]]]]: A dictionary containing evaluation metrics.
                - "Loss": The loss value.
                - "perplexity": The perplexity value.
                - "accuracy": Accuracy of the model.
                - "precision": Precision of the model.
                - "recall": Recall of the model.
                - "specificity": Specificity of the model.
                - "f1_score": F1 score of the model.

        Raises:
            ValueError: If the test dataset yields no batches.
        '''
        TransformerEvaluator.logger.info('Start Evaluation Process.')

        test_dl = self._create_DataLoader()
        pred_chunks: List[torch.Tensor] = []
        true_chunks: List[torch.Tensor] = []

        self.model.eval()
        with torch.inference_mode():
            for prev_batch, x_batch, y_batch in tqdm(test_dl, ascii=True, desc='Producing Predictions'):
                # Move the model inputs to the model's device. `prev_batch` is moved as well,
                # as the trainer does: `create_tgt` receives both tensors, which presumably
                # must share a device.
                prev_batch = prev_batch.to(self.device, non_blocking=True)
                x_batch = x_batch.to(self.device, non_blocking=True)

                # Produce the target tensor and generate logits
                tgt = create_tgt(prev_batch, x_batch)
                y_logits = self.model(x_batch, tgt)

                # Keep results as CPU float32 tensors instead of nested Python lists.
                # This uses far less memory and avoids per-element conversions.
                pred_chunks.append(y_logits.float().cpu())
                true_chunks.append(y_batch.cpu())

        if not pred_chunks:
            raise ValueError('The test dataset produced no batches; nothing to evaluate.')

        # Concatenate outside inference mode so the metric functions receive ordinary tensors
        y_pred = torch.cat(pred_chunks)
        y_true = torch.cat(true_chunks).to(torch.float32)

        results: Dict[str, Union[float, List[List[float]]]] = {}
        for metric_name, metric_fn in tqdm(self._define_metrics().items(), ascii=True, desc='Calculating Metrics'):
            if metric_name == 'Loss':
                # The loss criterion expects integer class indices as targets
                results[metric_name] = metric_fn(y_pred, y_true.to(torch.long)).item()
            else:
                results[metric_name] = metric_fn(y_pred, y_true)

        TransformerEvaluator.logger.info("Evaluation Process Completed Successfully.")

        return results


### Trainer Class

In [None]:
from src.utils.save import save_model
from src.utils.log import configure_logger
from src.utils.training import get_loaders
from src.utils.models import create_tgt

import torch
from torch import nn
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset, DataLoader, random_split

from multiprocessing import cpu_count

from tqdm import tqdm
from timeit import default_timer as timer
from typing import Callable, Tuple, Dict, Union, List


class TransformerTrainer:
    '''
    A class to handle the training of a PyTorch model.

    Each dataset sample is expected to be a `(prev_token, input_sequence, target_token)`
    triple. The decoder input is built from the first two with `create_tgt`.
    '''

    # Initialize the logger as a class attribute
    logger = configure_logger(__name__)

    def __init__(self,
            model: nn.Module,
            dataset: Dataset,
            batch_size: int,
            criterion : Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
            eval_fn: Callable[[torch.Tensor, torch.Tensor], float],
            opt: torch.optim.Optimizer,
            scheduler: Union[LRScheduler, None]=None,
            train_prop: float=0.8,
            step: int=1,
            device: torch.device=torch.device('cpu'),
        ) -> None:
        '''
        Initializes the TransformerTrainer with the model, dataset, loss function, evaluation function, optimizer and scheduler.

        Args:
            model (nn.Module): The neural network model to be trained. It is moved to `device`.
            dataset (Dataset): The dataset used to train and validate the model.
            batch_size (int): The batch size of the model's DataLoaders.
            criterion (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): Loss function.
            eval_fn (Callable[[torch.Tensor, torch.Tensor], float]): Evaluation function.
            opt (torch.optim.Optimizer): Optimizer for updating the model parameters.
            scheduler (Union[LRScheduler, None], optional): Learning-rate scheduler, stepped once per epoch. Defaults to None.
            train_prop (float): The proportion of the dataset used to train the model. Defaults to 0.8 (80%).
            step (int): The step of the cross-validation batches (forwarded to `get_loaders`).
            device (torch.device): The device the model will be trained on. Defaults to CPU.
        '''
        self.model = model.to(device, non_blocking=True)
        self.dataset = dataset
        self.batch_size = batch_size
        self.criterion  = criterion
        self.eval_fn = eval_fn
        self.opt = opt
        self.scheduler = scheduler
        self.train_prop = train_prop
        self.step = step
        self.device = device

    def _get_loaders(self) -> Tuple[DataLoader, DataLoader]:
        '''
        Randomly split the dataset into training/validation subsets and wrap each in a DataLoader.

        Returns:
            Tuple[DataLoader, DataLoader]: The shuffled training loader and the ordered validation loader.
        '''
        train_ds, valid_ds = random_split(self.dataset, [self.train_prop, 1 - self.train_prop])

        train_dl = DataLoader(train_ds, self.batch_size, shuffle=True, num_workers=cpu_count(), pin_memory=True)
        valid_dl = DataLoader(valid_ds, self.batch_size, shuffle=False, num_workers=cpu_count(), pin_memory=True)

        return train_dl, valid_dl

    def _get_loaders_cv(self, offset: Union[List[int], int]) -> Tuple[DataLoader, DataLoader, int]:
        '''
        Build the DataLoaders of the current cross-validation fold with `get_loaders`.

        Args:
            offset (Union[List[int], int]): Fold position returned by the previous call (0 initially).

        Returns:
            Tuple[DataLoader, DataLoader, int]: Training loader, validation loader and the
                offset to pass on the next call (semantics defined by `get_loaders`).
        '''
        train_dl, valid_dl, offset = get_loaders(
            dataset=self.dataset,
            batch_size=self.batch_size,
            train_pro=self.train_prop,
            drop_last=True,
            offset=offset,
            step=self.step
        )

        return train_dl, valid_dl, offset

    def _process_data_loaders(self, dl: DataLoader) -> Tuple[float, float]:
        '''
        Process batches from a DataLoader for either training or validation.

        This method iterates over batches of data from a given DataLoader (`dl`) and computes
        the loss and evaluation metric for each batch. When the model is in training mode,
        it also performs backpropagation and an optimizer step.

        Args:
            dl (DataLoader): DataLoader containing batches of data.

        Returns:
            Tuple[float, float]: The average batch loss and evaluation score across all
                batches in the DataLoader. Both are NaN when the DataLoader is empty.
        '''
        phase = 'Training Step' if self.model.training else 'Validation Step'

        # An empty loader (e.g. a zero-sized validation split) would otherwise divide by zero
        n_batches = len(dl)
        if n_batches == 0:
            TransformerTrainer.logger.info(f'    {phase}: the DataLoader is empty, reporting NaN loss and evaluation score.')
            return float('nan'), float('nan')

        # Running sums of the per-batch loss and evaluation score
        batch_loss, batch_eval = 0.0, 0.0

        for prev_batch, x_batch, y_batch in tqdm(dl, ascii=True, desc=f'             {phase}'):
            # Moving batches to device
            prev_batch = prev_batch.to(self.device, non_blocking=True) # batch containing all the previous tokens of the sequence
            x_batch = x_batch.to(self.device, non_blocking=True)       # batch containing the sequence to be passed into the model
            y_batch = y_batch.to(self.device, non_blocking=True)       # batch containing the next token of the sequence

            # Creating the inputs of the embedding layer of the decoder, which consist of the previous tokens in the sequence
            tgt = create_tgt(prev_batch, x_batch)

            # Generating predictions (forward pass)
            model_logits = self.model(x_batch, tgt)

            # Calculate loss and evaluation score
            loss = self.criterion(model_logits, y_batch)
            batch_loss += loss.item()
            batch_eval += self.eval_fn(model_logits, y_batch)

            # Backward pass and optimizer step (only for training)
            if self.model.training:
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

        return batch_loss / n_batches, batch_eval / n_batches

    def _training_step(self, train_dl: Union[DataLoader, List[DataLoader]]) -> Tuple[float, float]:
        '''
        Performs a single training pass over the training DataLoader.

        The model is put in training mode for the pass and left in eval mode afterwards.

        Args:
            train_dl (DataLoader): The training dataloader that will fit the model.

        Returns:
            Tuple[float, float]: The average training loss and evaluation score for the epoch.
        '''
        self.model.train()

        train_loss, train_eval = self._process_data_loaders(train_dl)

        self.model.eval()

        return train_loss, train_eval

    def _validation_step(self, valid_dl: Union[DataLoader, List[DataLoader]]) -> Tuple[float, float]:
        '''
        Performs a single validation pass (eval mode, no autograd) over the validation DataLoader.

        Args:
            valid_dl (Dataloader): The validation dataloader to evaluate the model.

        Returns:
            Tuple[float, float]: The average validation loss and evaluation score for the epoch.
        '''
        self.model.eval()

        with torch.inference_mode():
            valid_loss, valid_eval = self._process_data_loaders(valid_dl)

        return valid_loss, valid_eval

    def fit(self,
            epochs: int,
            save_per: Union[int, None]=None,
            save_path: Union[str, None]=None,
            save_best: bool=False,
            cross_validate: bool=False,
        ) -> Dict[str, Union[List[float], str, int]]:
        '''
        Trains the model for a specified number of epochs and optionally saves checkpoints.

        Args:
            epochs (int): The number of epochs to train the model for.
            save_per (Union[int, None], optional): Frequency (in epochs) to save model checkpoints. Defaults to None.
            save_path (Union[str, None], optional): The directory the checkpoints are saved in. Defaults to None.
            save_best (bool, optional): Whether to save the model whenever the validation loss improves.
                Requires `save_path`. Defaults to False.
            cross_validate (bool, optional): Whether to use cross-validation. Default is False.

        Returns:
            Dict[str, Union[List[float], str, int]]: A dictionary containing training statistics and metadata.
                It includes:
                - 'train_loss': List of training losses for each epoch.
                - 'train_eval': List of training evaluation scores for each epoch.
                - 'valid_loss': List of validation losses for each epoch.
                - 'valid_eval': List of validation evaluation scores for each epoch.
                - 'model_name': Name of the model class.
                - 'loss_fn': Name of the loss function class.
                - 'eval_fn': Name of the evaluation function.
                - 'optimizer': Name of the optimizer class.
                - 'scheduler': Name of the scheduler class ('NoneType' when no scheduler is used).
                - 'device': Type of device the model is on.
                - 'epochs': Total number of epochs trained.
                - 'total_time': Total time taken for training and evaluation.
                - 'save_path': Saved path of the checkpoints.
                - 'cross_validate': Whether cross-validation is being used.
        '''
        start_time = timer()
        train_losses, train_evals = [], []
        valid_losses, valid_evals = [], []
        best_valid_loss = float('inf')

        # The variable of cross validation (position of the current fold)
        offset = 0

        TransformerTrainer.logger.info('Start Training Process.')

        # Without a save_path there is nowhere to write the best model; previously it was
        # written to a literal 'None/...' path.
        if save_best and not save_path:
            TransformerTrainer.logger.info('save_best=True but no save_path was given; the best model will not be saved.')

        # Get the loaders if the user doesn't want to use cross validation
        if not cross_validate:
            TransformerTrainer.logger.info('Creating training and validation DataLoaders.')
            train_dl, valid_dl = self._get_loaders()
            TransformerTrainer.logger.info('Dataloaders created succesfully.')

        for epoch in range(1, epochs + 1):
            TransformerTrainer.logger.info(f'-> Epoch: {epoch}/{epochs}')

            # Get the loaders if the user want to use cross validation
            if cross_validate:
                TransformerTrainer.logger.info('    Creating training and validation DataLoaders (cross-validation step)')
                train_dl, valid_dl, offset = self._get_loaders_cv(offset)
                TransformerTrainer.logger.info('    Dataloaders created succesfully.')

            # Training and Evaluating the Model
            train_loss, train_eval = self._training_step(train_dl)
            valid_loss, valid_eval = self._validation_step(valid_dl)

            # Log the results
            TransformerTrainer.logger.info('')
            TransformerTrainer.logger.info(f'    Results (lr={self.opt.param_groups[0]["lr"]:.6f}):')
            TransformerTrainer.logger.info(f'    Train Loss:       {train_loss:.4f}')
            TransformerTrainer.logger.info(f'    Train Eval Score: {train_eval:.4f}')
            TransformerTrainer.logger.info(f'    Valid Loss:       {valid_loss:.4f}')
            TransformerTrainer.logger.info(f'    Valid Eval Score: {valid_eval:.4f}')

            train_losses.append(train_loss)
            train_evals.append(train_eval)
            valid_losses.append(valid_loss)
            valid_evals.append(valid_eval)

            # Step the scheduler (once per epoch)
            if self.scheduler is not None:
                self.scheduler.step()
                TransformerTrainer.logger.info('')
                TransformerTrainer.logger.info(f'    Scheduling step executed succesfully (new lr={self.opt.param_groups[0]["lr"]:.6f})')

            # Save the best model based on validation loss (a NaN loss never counts as an improvement)
            if save_best and save_path and valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                save_model(self.model, f'{save_path}/{self.model.__class__.__name__}_best.pth')

            # Saving the model periodically
            if save_per and save_path and (epoch % save_per == 0):
                save_model(self.model, f'{save_path}/{self.model.__class__.__name__}_checkpoint_{epoch}.pth')

            TransformerTrainer.logger.info(('-' * 100))

        # After training, clear the CUDA cache
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()

        TransformerTrainer.logger.info('Training Process Completed Successfully.')

        return {
            'train_loss': train_losses,
            'train_eval': train_evals,
            'valid_loss': valid_losses,
            'valid_eval': valid_evals,
            'model_name': self.model.__class__.__name__,
            'loss_fn': self.criterion.__class__.__name__,
            # Callables such as functools.partial or nn.Module instances have no __name__
            'eval_fn': getattr(self.eval_fn, '__name__', self.eval_fn.__class__.__name__),
            'optimizer': self.opt.__class__.__name__,
            'scheduler': self.scheduler.__class__.__name__,
            'device': self.device.type,
            'epochs': epochs,
            'total_time': timer() - start_time,
            'save_path': save_path,
            'cross_validate': cross_validate
        }


### Models

In [None]:
from src.utils.training import PositionalEncoding
from src.utils.models import temperature_sampling, create_tgt

import torch
import torch.nn as nn

from typing import List


class LSTMCharModel(nn.Module):
    '''
    Character-level bidirectional LSTM that predicts the token following a window
    of `block_size` tokens.

    The LSTM outputs of all `block_size` positions are flattened into a single linear
    layer, so inputs must have sequence length exactly `block_size`.
    '''

    def __init__(self,
            block_size: int,
            vocab_size: int,
            embedding_dim: int,
            hidden_dim: int,
            num_layers: int,
            dropout: float
        ) -> None:
        '''
        Args:
            block_size (int): Length of the input window.
            vocab_size (int): Number of distinct tokens (size of the output logits).
            embedding_dim (int): Size of the token embeddings.
            hidden_dim (int): Hidden size of each LSTM direction.
            num_layers (int): Number of stacked LSTM layers.
            dropout (float): Dropout probability. It is applied to the LSTM output and,
                when `num_layers > 1`, between the stacked LSTM layers.
        '''
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # nn.LSTM only applies its internal dropout between stacked layers. With a single
        # layer it has no effect and only triggers a UserWarning, so pass 0.0 in that case.
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
            bidirectional=True
        )

        self.dropout = nn.Dropout(dropout)

        # Both directions are concatenated, hence 2 * hidden_dim features per position
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2*hidden_dim*block_size, vocab_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Args:
            x (torch.Tensor): Token ids of shape `(batch, block_size)`.

        Returns:
            torch.Tensor: Logits of shape `(batch, vocab_size)`.
        '''
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.dropout(x)

        return self.fc(x)


class TransformerCharModel(nn.Module):
    '''
    Character-level encoder-decoder Transformer that predicts the token following a
    window of `block_size` tokens.

    The encoder reads the window (`src`). The decoder reads the target sequence (`tgt`,
    built with `create_tgt`) under a look-ahead mask. The flattened decoder output is
    mapped to `vocab_size` logits by a small MLP head. Because the head flattens
    `block_size` positions, inputs must have sequence length exactly `block_size`.
    '''

    def __init__(self,
            vocab_size: int,
            block_size: int,
            embedding_dim: int,
            n_head: int,
            n_encoders: int,
            n_decoders: int,
            dim_feedforward: int,
            dropout: float,
            device: torch.device
        ) -> None:
        '''
        Args:
            vocab_size (int): Number of distinct tokens.
            block_size (int): Length of the input window.
            embedding_dim (int): Model width (`d_model`) of the encoder/decoder layers.
            n_head (int): Number of attention heads.
            n_encoders (int): Number of encoder layers.
            n_decoders (int): Number of decoder layers.
            dim_feedforward (int): Hidden size of the feed-forward sublayers.
            dropout (float): Dropout probability used throughout the model.
            device (torch.device): Stored as `self.device` for backward compatibility. Tensors
                created inside the model follow the device of its inputs.
        '''
        super().__init__()
        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.device = device

        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.positional_embeddings = PositionalEncoding(block_size, embedding_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            norm_first=True,
            activation='gelu',
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_encoders, enable_nested_tensor=False)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embedding_dim,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            norm_first=True,
            activation='gelu',
            batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_decoders)

        self.dropout = nn.Dropout(dropout)
        # A single LayerNorm shared by the src/tgt embeddings and the decoder output
        self.norm = nn.LayerNorm(embedding_dim)

        self.classification_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=embedding_dim*block_size, out_features=128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=128, out_features=vocab_size)
        )

        self._init_weights()

    def _init_weights(self):
        '''Xavier-uniform initialize every parameter with more than one dimension.'''
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _generate_square_subsequent_mask(self, size: int, device: 'torch.device | None' = None) -> torch.Tensor:
        '''
        Create an additive look-ahead mask for a target sequence of length `size`.

        Entry `(i, j)` is `-inf` when `j > i` (a future position) and `0.0` otherwise.

        Args:
            size (int): Target sequence length.
            device (torch.device | None): Device to build the mask on. Defaults to the default device.

        Returns:
            torch.Tensor: A float mask of shape `(size, size)`.
        '''
        return torch.triu(torch.full((size, size), float('-inf'), device=device), diagonal=1)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        '''
        - The `src` provides contextual information to the encoder, which the decoder uses
            along with the partially generated `tgt` sequence to produce the next token.
        - Teacher Forcing: During training, the tgt sequence is used to guide the model in
            generating the correct sequence. This technique, known as teacher forcing, helps
            the model learn to generate sequences efficiently.

        Args:
            src (torch.Tensor): Token ids of shape `(batch, block_size)`.
            tgt (torch.Tensor): Decoder token ids of shape `(batch, block_size)`.

        Returns:
            torch.Tensor: Logits of shape `(batch, vocab_size)`.
        '''
        # Get the embeddings for the input tokens
        src = self.embeddings(src) + self.positional_embeddings(src)
        tgt = self.embeddings(tgt) + self.positional_embeddings(tgt)

        # Normalization and dropout
        src = self.norm(self.dropout(src))
        tgt = self.norm(self.dropout(tgt))

        # Look-ahead mask, built directly on the device of the (embedded) target so it
        # follows the model wherever it has been moved
        tgt_mask = self._generate_square_subsequent_mask(tgt.size(1), device=tgt.device)

        # Transformer encoder and decoder
        memory = self.transformer_encoder(src)
        output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask)

        # Normalization, dropout, and classification head
        output = self.norm(self.dropout(output))
        return self.classification_head(output)

    def generate(self, first_token: torch.Tensor, initial_tokens: torch.Tensor, max_length: int, temperature: float) -> List[int]:
        '''
        Autoregressively generate `max_length` new tokens after a seed sequence.

        Generation runs in eval mode (dropout disabled) without autograd. The model's previous
        training/eval mode is restored afterwards. The caller's tensors are not modified.

        Args:
            first_token (torch.Tensor): The token preceding `initial_tokens` (shape `(1,)`
                as built by the generation script).
            initial_tokens (torch.Tensor): Seed tokens of shape `(1, seq_len)`. The head
                requires windows of exactly `block_size` tokens, so `seq_len` is expected to
                be at least `block_size`.
            max_length (int): The number of new tokens to generate.
            temperature (float): A temperature parameter controlling the randomness of predictions during sampling.

        Returns:
            List[int]: The first token, the seed tokens and the generated token IDs, in order.
        '''
        generated_tokens = first_token.tolist()
        generated_tokens.extend(initial_tokens.squeeze(dim=0).tolist())

        # Work on a copy: the previous implementation overwrote the caller's `first_token`
        prev_token = first_token.clone()
        tokens = initial_tokens

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_length):
                    # The model sees the last `block_size` tokens...
                    tokens_cond = tokens[:, -self.block_size:]
                    # ...and the decoder needs the token right before that window
                    if tokens.size(1) > self.block_size:
                        prev_token = tokens[:, -self.block_size - 1].reshape(prev_token.shape)
                    tgt = create_tgt(prev_token, tokens_cond)

                    # Forward pass through the model and sample the next token
                    model_logits = self(tokens_cond, tgt)
                    next_token = temperature_sampling(model_logits.squeeze(dim=0), temperature)

                    # Append the generated token to the output and to the running sequence
                    generated_tokens.append(next_token)
                    next_tensor = torch.tensor([[next_token]], dtype=tokens.dtype, device=tokens.device)
                    tokens = torch.cat((tokens, next_tensor), dim=1)
        finally:
            self.train(was_training)

        return generated_tokens


## Train

In [None]:
from src import config
from src.dataset import TransformerShakespeareDataset
from src.models import TransformerCharModel
from src.training import TransformerTrainer
from src.utils import (
    configure_logger,
    accuracy_fn,
    plot_losses,
    get_device
)

from torch import nn, optim
from torch.optim.lr_scheduler import ExponentialLR

import os


# Module-level logger shared by every function of the training script
logger = configure_logger(__name__)

# Device picked once at import time and reused by the model, dataset and trainer
device = get_device()


def main() -> None:
    '''
    Main function to train the deep learning model.

    Builds the training dataset and a `TransformerCharModel`, then trains it with
    `TransformerTrainer` (AdamW + exponential learning-rate decay). Checkpoints are
    saved under `config.MODELS_PATH` and the loss curves under `config.PLOTS_PATH`.
    '''

    # Initialize the dataset object
    dataset = TransformerShakespeareDataset(
        dataset_path=config.TRAIN_DATASETS_PATH,
        norm_dataset_path=config.NORMALIZE_TRAIN_DATA_PATH,
        constractions_path=config.CONTRACTIONS_PATH,
        block_size=config.BLOCK_SIZE,
        to_tensors=True,
        write_norm=False,
        device=device
    )

    # Saving the vocab size because we'll use it for evaluating the model
    vocab_size = dataset.vocab_size

    # Instantiate the model
    model = TransformerCharModel(block_size=config.BLOCK_SIZE, vocab_size=vocab_size, device=device, **config.TRANSFORMER_CONFIGS).to(device, non_blocking=True)

    logger.info(f'The model is created and placed on the device: {device.type}')

    loss_fn = nn.CrossEntropyLoss()

    opt = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)
    scheduler = ExponentialLR(opt, gamma=config.GAMMA)

    # Instantiate the trainer
    trainer = TransformerTrainer(
        model = model,
        dataset = dataset,
        batch_size = config.BATCH_SIZE,
        criterion = loss_fn,
        eval_fn = accuracy_fn,
        opt = opt,
        scheduler=scheduler,
        device = device,
    )

    logger.info('The trainer is created.')

    # Make sure the output directories exist before checkpoints and plots are written,
    # as the evaluation and generation scripts do for their own outputs
    os.makedirs(config.MODELS_PATH, exist_ok=True)
    os.makedirs(config.PLOTS_PATH, exist_ok=True)

    # Train the model; with save_per == EPOCHS only the final-epoch checkpoint (plus the best one) is saved
    train_res = trainer.fit(
        epochs = config.EPOCHS,
        save_per = config.EPOCHS,
        save_path = config.MODELS_PATH,
        cross_validate = False,
        save_best=True
    )

    logger.info(f'Training Results: {train_res}')

    # Plot the losses and save them
    plot_losses(train_res['train_loss'], train_res['valid_loss'], save_path=os.path.join(config.PLOTS_PATH, 'losses.png'))


if __name__ == '__main__':
    main()


## Evaluate

In [None]:
from src import config
from src.dataset import TransformerShakespeareDataset
from src.models import TransformerCharModel
from src.evaluator import TransformerEvaluator
from src.utils.save import load_model
from src.utils.log import configure_logger
from src.utils import get_device

from torch import nn

import os
from typing import Dict


# Module-level logger shared by every function of the evaluation script
logger = configure_logger(__name__)

# Device picked once at import time for the model and the evaluator
device = get_device()

# Vocabulary size the checkpoint was trained with.
# NOTE(review): hard-coded; it must equal `dataset.vocab_size` of the training dataset
# that produced the checkpoint. Presumably the test dataset must also map characters
# to the same ids as the training dataset — verify after any data/normalization change.
VOCAB_SIZE = 72


def evaluate(model: nn.Module, dataset: TransformerShakespeareDataset, loss_fn: nn.Module, file: str) -> Dict[str, float]:
    '''
    Evaluate the model on the given dataset and write the results to a log file.

    Args:
        model (nn.Module): The model to evaluate.
        dataset (TransformerShakespeareDataset): The dataset to evaluate on.
        loss_fn (nn.Module): The loss function to use.
        file (str): Filename, inside `config.LOGS_PATH`, to save the evaluation results to.

    Returns:
        Dict[str, float]: Evaluation results as returned by `TransformerEvaluator.evaluate`.
    '''
    evaluator = TransformerEvaluator(
        model = model,
        test_ds = dataset,
        cretirion = loss_fn,
        batch_size=config.BATCH_SIZE,
        device = device
    )
    logger.info('The evaluator is created.')

    eval_res = evaluator.evaluate()

    # Create the logs directory on first use
    os.makedirs(config.LOGS_PATH, exist_ok=True)

    with open(os.path.join(config.LOGS_PATH, file), 'w', encoding='utf-8') as f:
        f.write(str(eval_res))

    return eval_res


def main() -> None:
    '''
    Main function to evaluate the deep learning model.

    Loads the test dataset and the checkpoint saved at the last training epoch. It then
    evaluates the model with cross-entropy loss and logs the metrics.
    '''

    # Initialize the dataset object
    dataset = TransformerShakespeareDataset(
        dataset_path=config.TEST_DATASETS_PATH,
        norm_dataset_path=config.NORMALIZE_TEST_DATA_PATH,
        constractions_path=config.CONTRACTIONS_PATH,
        block_size=config.BLOCK_SIZE,
        to_tensors=True,
        device=device
    )

    # The training script saves `<Model>_checkpoint_<epoch>.pth` with save_per=config.EPOCHS,
    # so the final-epoch checkpoint is the one that exists (a hard-coded epoch 20
    # does not match the configured number of epochs).
    model_name = f'TransformerCharModel_checkpoint_{config.EPOCHS}.pth'
    model = load_model(
        model_class=TransformerCharModel,
        model_path=os.path.join(config.MODELS_PATH, model_name),
        model_device=True,
        block_size=config.BLOCK_SIZE,
        vocab_size=VOCAB_SIZE,
        device=device,
        **config.TRANSFORMER_CONFIGS
    ).to(device, non_blocking=True)

    logger.info(f'The model is loaded and moved to device: {device.type}')

    loss_fn = nn.CrossEntropyLoss()

    # Evaluate the model
    results = evaluate(model, dataset, loss_fn, file='transformer_evaluate.txt')

    logger.info(f'Results: {results}')


if __name__ == '__main__':
    main()


## Generate

In [None]:
from src import config
from src.dataset import TransformerShakespeareDataset
from src.models import TransformerCharModel
from src.utils import configure_logger, get_device, load_model

import torch

import os
from typing import Tuple


# Module-level logger shared by every function of the generation script
logger = configure_logger(__name__)

# Device picked once at import time; the model and the seed tensors are placed on it
device = get_device()


def process_input(dataset: TransformerShakespeareDataset, message: str) -> Tuple[torch.Tensor, torch.Tensor]:
    '''
    Encode the given prompt and turn it into the seed tensors expected by `generate`.

    Only the LAST `config.BLOCK_SIZE + 1` tokens are kept. Generation continues from the
    end of the prompt, and the model only looks at the most recent `BLOCK_SIZE` tokens.
    Of those tokens, the first becomes the "previous" token and the rest become the model
    input.

    Args:
        dataset (TransformerShakespeareDataset): The dataset object providing the `encode` method.
        message (str): The prompt to be encoded and converted to tensors.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The previous token, shape `(1,)`, and the model
            input, shape `(1, n)` with `n <= BLOCK_SIZE`, both int32 on `device`.

    Raises:
        ValueError: If the encoded prompt has fewer than two tokens.
    '''
    encoded_message = dataset.encode(message)[-(config.BLOCK_SIZE + 1):]
    if len(encoded_message) < 2:
        raise ValueError('The prompt must encode to at least two tokens (a previous token and one input token).')

    prev_token = torch.tensor(encoded_message[0], dtype=torch.int32).unsqueeze(dim=0).to(device)
    model_input = torch.tensor(encoded_message[1:], dtype=torch.int32).unsqueeze(dim=0).to(device)

    return prev_token, model_input


def main() -> None:
    '''
    Generate text with the trained model from a fixed prompt and save it to `config.GENERATE_PATH`.
    '''
    # Initialize the dataset object to get encoder-decoder and vocab_size
    dataset = TransformerShakespeareDataset(
        dataset_path=config.TRAIN_DATASETS_PATH,
        norm_dataset_path=config.NORMALIZE_TRAIN_DATA_PATH,
        constractions_path=config.CONTRACTIONS_PATH,
        block_size=config.BLOCK_SIZE,
        to_tensors=True,
        write_norm=False,
        device=device
    )

    os.makedirs(config.GENERATE_PATH, exist_ok=True)

    # Create the initial seed for the prediction
    message = '''Peter, the Great Emperor

**** ACT I ****
**** SCENE I. France. In the battlefield. ****
     Enter Peter with his sword.
Peter
 '''
    first_token, initial_tokens = process_input(dataset, message)
    logger.info(f'Input has been encoded and moved to device: {device.type}')

    # Load the final-epoch checkpoint written by the training script (save_per=config.EPOCHS)
    model_path = os.path.join(config.MODELS_PATH, f'TransformerCharModel_checkpoint_{config.EPOCHS}.pth')
    model = load_model(
        model_class=TransformerCharModel,
        model_path=model_path,
        device=device,
        model_device=True,
        block_size=config.BLOCK_SIZE,
        vocab_size=dataset.vocab_size,
        **config.TRANSFORMER_CONFIGS
    ).to(device)

    logger.info(f'The model is loaded and moved to device: {device.type}')

    # Generate predictions
    tokens = model.generate(first_token, initial_tokens, max_length=1000, temperature=1.0)
    decoded_preds = dataset.decode(tokens)

    # Compute the output path once so the log message names the file that is actually written.
    # This also avoids the Python-3.12-only nested-quote f-string.
    output_path = os.path.join(config.GENERATE_PATH, 'generation.txt')
    logger.info(f'Predictions have been made and decoded succesfully, saving them into {output_path}')

    # Saving predictions
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(decoded_preds)


if __name__ == '__main__':
    main()
