In [None]:
import os
import re
import math
import time
import shutil
import warnings
import pickle
import dill
from collections import Counter, OrderedDict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler
from torchtext.vocab import vocab
from torch.nn import Transformer
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

# Common Task 1. Dataset preprocessing

### Declaring Global Variables

In [None]:
# Special tokens & corresponding ids (the position in `special_symbols` is the token id)
BOS_IDX, PAD_IDX, EOS_IDX, UNK_IDX , SEP_IDX = 0, 1, 2, 3, 4
special_symbols = ['<s>', '<pad>', '</s>', '<unk>', '<sep>'] 

# Directory where data and model checkpoints will be stored
root_dir = "./"

# Device for training (e.g., "cuda" for GPU, "cpu" for CPU)
device = "cuda"

# Epochs at which to save model checkpoints during training
save_at_epochs = []

# Total number of epochs for training
epochs = 40

# Seed for reproducibility (also used as `random_state` when shuffling/sampling rows)
seed = 42

# Whether to use half precision (FP16) for training
use_half_precision = True

# Whether to shuffle training data during each epoch
train_shuffle = True

# Whether to shuffle test data
test_shuffle = False

# Batch size for training
training_batch_size = 64

# Batch size for testing
test_batch_size = 128

# Number of worker processes for data loading
num_workers = 4

# Placeholder: replaced by a `Tokenizer` instance once the equations are loaded
tokenizer = None

# Placeholders for the training, testing, and validation data.
# NOTE(review): only `train` is read later (Predictor stores it as `self.df`) and
# nothing visible assigns these -- confirm whether they are still needed.
train = None
test = None
valid = None

# Size of vocabulary for source and target sequences (filled in after the vocabularies are built)
src_voc_size = None
tgt_voc_size = None

# Whether to use pinned memory for data loading (faster on GPU)
pin_memory = True

# Learning rate for optimizer
optimizer_lr = 0.0001

# Gradient clipping threshold (set to -1 to disable)
clip_grad_norm = -1

# Name of the sequence-to-sequence model architecture
model_name = "seqtoseq_basic"

# Dimensionality of token embeddings (the transformer's d_model)
embedding_size = 512

# Dimensionality of the transformer's feed-forward sub-layers (passed to Model as dim_feedforward)
hidden_dim = 3072

# Number of attention heads in the transformer model
nhead = 8

# Number of encoder layers in the transformer model
num_encoder_layers = 4

# Number of decoder layers in the transformer model
num_decoder_layers = 4

# Dropout probability for regularization
dropout = 0.1

# Current epoch number (used for resuming training)
curr_epoch = 0




In [None]:
# Load the first N feature rows of every equation file into one DataFrame
from itertools import islice

data_directory = 'Data/Feynman_with_units'
N = 10  # number of feature rows per equation
data = []

# Sort the listing so the row order does not depend on the file system
for filename in sorted(os.listdir(data_directory)):
    file_path = os.path.join(data_directory, filename)
    if not os.path.isfile(file_path):
        continue
    with open(file_path, 'r', encoding='utf-8') as f:
        # Stream only the first N lines; the data files can be very large and
        # reading them completely just to keep N rows wastes time and memory.
        for line in islice(f, N):
            data.append((filename, line.rstrip('\n')))

df = pd.DataFrame(data, columns=['Filename', 'features'])
del data
print(df)

In [None]:
# Load the equation table, keeping only the join key (Filename) and the target formula (Formula)
eq_df = pd.read_csv("Data/FeynmanEquations.csv")[['Filename','Formula']]

In [None]:
# Preview the equation table (notebook cell output)
eq_df

In [None]:
# Attach each equation's formula to its feature rows with an inner join on Filename.
# The join key is dropped afterwards; the later cells only read `Formula` and `features`.
df = pd.merge(eq_df,df,on="Filename",how='inner').drop(columns=['Filename'])
del eq_df  # the equation table is no longer needed once merged

## Tokenization explanation

**Input Sequence Tokenization:**
For the input sequences, tokenization is straightforward: character-level tokenization is employed, where each character is treated as a token. This approach, combined with positional embeddings, provides the model with a robust representation of the diverse combinations present within these sequences. Experimentally, I found that character-level tokenization outperforms subword tokenization based on n-gram frequency by a slight margin.


For example : The input sequence `1.70 -0.093` will be tokenized as <mark>['&lt;s>', '1', '.', '7', '0', '&lt;sep>', '-', '0', '.', '0', '9', '3', '&lt;/s>']
</mark>


**Output/Target Sequence Tokenization:**
For the output/target sequences, several steps are followed to obtain the final tokens. Firstly, the tokens are preprocessed to ensure a consistent representation of the equations. Subsequently, tokenization is performed in such a way that each variable, operator, and parenthesis is assigned a separate token. This results in final tokens that carry both physical and semantic meaning. Moreover, these tokens have been observed to perform slightly better than subword tokens based on n-gram frequency or character-level tokens.


For example : The output sequence `exp(-theta**2/2)` will be tokenized to <mark>['&lt;s>', 'exp', ' ', '(', ' ', '-', ' ', 'theta', '^', '2', ' ', '/', ' ', '2', ')', '&lt;/s>']</mark>

<br>

### Class explanation

1. **Initialization**: The `Tokenizer` class is initialized with a list of equations (`eqns`) to tokenize. Regular expressions are defined to identify different components of the equations, including identifiers, numbers, operators, and spaces.

2. **Token Extraction (_extract)**: The `_extract` method is responsible for extracting tokens from an equation and updating a dictionary with token counts. It uses regular expressions to find all occurrences of tokens within the equation.

3. **Preprocessing (preprocess_eqn)**: The `preprocess_eqn` method preprocesses the equation by replacing `**` with `^` and adding spaces around operators. This step ensures uniformity in tokenization and helps in separating different components of the equation.

4. **Building Vocabulary (build_tgt_vocab and build_src_vocab)**: The `build_tgt_vocab` and `build_src_vocab` methods build vocabularies for target and source sequences, respectively. They iterate through each equation, preprocess it, extract tokens using regular expressions, and update an ordered dictionary with token counts. Finally, a vocabulary is created using the `torchtext.vocab.vocab` class.

5. **Source Tokenization (src_tokenize)**: The `src_tokenize` method tokenizes source (feature) rows. It wraps each digit and minus sign in semicolon separators, maps every run of whitespace to '&lt;sep>', and splits the result on the semicolons. Special tokens '&lt;s>' and '&lt;/s>' are placed at the beginning and end of the token list, respectively.

6. **Target Tokenization (tgt_tokenize)**: The `tgt_tokenize` method tokenizes target equations. It adds separators around identifiers and operators, splits the equation into tokens, & adds special tokens '&lt;s>' and '&lt;/s>' at the beginning and end of the token list, respectively.

7. **Source Decoding (src_decode)**: The `src_decode` method decodes tokens of a source equation into a string representation. It removes special tokens '&lt;s>' and '&lt;/s>', replaces '&lt;sep>' with a space, and removes any remaining semicolons. Additionally, it removes spaces after minus signs.

8. **Target Decoding (tgt_decode)**: The `tgt_decode` method decodes tokens of a target equation into a string representation. It removes special tokens '&lt;s>' and '&lt;/s>' and replaces '&lt;sep>' with a space.

In [None]:
class Tokenizer:
    """
    Class for tokenizing equations and building vocabularies for source and target sequences.

    Source rows (numeric feature strings) are split so that digits and minus signs
    become individual tokens and whitespace runs become '<sep>'. Target equations are
    normalised by `preprocess_eqn` and then split into identifiers, operators,
    whitespace and the number literals left between them.
    """

    def __init__(self, eqns):
        """
        Initializes the Tokenizer with equations.
        Args:
            eqns (list): List of equations to tokenize.
        """
        self.eqns = eqns
        self.identifier = r'[a-zA-Z_][a-zA-Z_0-9]*'
        self.number = r'[0-9]+(?:\.[0-9]*)?'
        self.operator = r'\^|[-+*/=<>()]'
        self.space = r'[ \t]+'

    @staticmethod
    def _extract(eqn, ord_dict, pattern):
        """
        Extracts tokens from the equation and updates the dictionary with token counts.

        Args:
            eqn (str): Preprocessed equation.
            ord_dict (dict): Running token -> count mapping; updated in place.
            pattern (re.Pattern): Compiled pattern whose matches are the tokens.

        Returns:
            dict: The same `ord_dict`, for convenience.
        """
        # Counter keeps first-occurrence order, so new tokens are inserted in the
        # order they appear in the equation.
        for token, count in Counter(pattern.findall(eqn)).items():
            ord_dict[token] = ord_dict.get(token, 0) + count
        return ord_dict

    @staticmethod
    def add_separators(match):
        """
        Adds separators around operators in the equation for easier splitting.
        """
        return f";{match.group(0)};"

    @staticmethod
    def _wrap_tokens(expr):
        """
        Splits a ';'-delimited string into tokens and adds the '<s>' / '</s>' markers.

        Empty fragments are discarded instead of overwriting the first and last
        fragment, so a real token at either end of the string is never lost.
        """
        return ['<s>', *(tok for tok in expr.split(';') if tok), '</s>']

    @staticmethod
    def preprocess_eqn(data):
        """
        Preprocesses the equation by replacing '**' with '^' and adding spaces around operators.

        Args:
            data (str): Raw equation string.

        Returns:
            str: Normalised equation with surrounding whitespace stripped.
        """
        data = re.sub(r'\*\*', '^', data)
        # Order matters: operators are padded first, then one pass merges double
        # spaces and '( (' is glued back together.
        for old, new in (('{', ' {'), ('*', ' * '), ('-', ' - '), ('+', ' + '),
                         ('(', ' ('), ('/', ' / '), ('  ', ' '), ('( (', '((')):
            data = data.replace(old, new)
        return data.strip()

    def build_tgt_vocab(self):
        """
        Builds vocabulary for target sequences.

        Returns:
            Vocab: torchtext vocabulary with '<s>', '<pad>', '</s>', '<unk>' first and
            unknown tokens mapped to UNK_IDX.
        """
        ordered_dict = OrderedDict()
        # Compile once; the pattern order fixes the order in which tokens get ids.
        patterns = [re.compile(exp)
                    for exp in (self.space, self.identifier, self.number, self.operator)]
        for eqn in self.eqns:
            eqn = self.preprocess_eqn(eqn)
            for pattern in patterns:
                self._extract(eqn, ordered_dict, pattern)
        voc = vocab(ordered_dict, specials=special_symbols[:-1], special_first=True)
        voc.set_default_index(UNK_IDX)
        return voc

    def build_src_vocab(self):
        """
        Builds vocabulary for source sequences.

        Returns:
            Vocab: torchtext vocabulary holding all special symbols, the ten digits,
            '-' and '.', with unknown tokens mapped to UNK_IDX.
        """
        ordered_dict = OrderedDict((str(digit), 1) for digit in range(10))
        ordered_dict['-'] = 1
        ordered_dict['.'] = 1
        voc = vocab(ordered_dict, specials=special_symbols, special_first=True)
        voc.set_default_index(UNK_IDX)
        return voc

    def src_tokenize(self, expr):
        """
        Tokenizes source equations.

        Args:
            expr (str): Whitespace-separated numeric feature row.

        Returns:
            list: Tokens framed by '<s>' and '</s>'.
        """
        # Leading/trailing whitespace carries no information and must not become '<sep>'
        expr = expr.strip()
        expr = re.sub(r'\d', r';\g<0>;', expr)
        # Isolate the separator so it cannot fuse with a neighbouring '.'
        expr = re.sub(r'\s+', ';<sep>;', expr)
        expr = expr.replace('-', ';-;')
        return self._wrap_tokens(expr)

    def tgt_tokenize(self, expr):
        """
        Tokenizes target equations.

        Args:
            expr (str): Raw equation string.

        Returns:
            list: Tokens framed by '<s>' and '</s>'.
        """
        expr = self.preprocess_eqn(expr)
        # Numbers are not wrapped explicitly; they are what remains between the
        # separators added around whitespace, identifiers and operators.
        for exp in (self.space, self.identifier, self.operator):
            expr = re.sub(exp, self.add_separators, expr)
        return self._wrap_tokens(expr)

    def src_decode(self, tokens):
        """
        Decodes source tokens into equation string.

        The first and last token (expected to be '<s>' and '</s>') are dropped.
        NOTE(review): tokens are joined with single spaces, so digits stay
        space-separated and this is not the inverse of `src_tokenize` -- confirm
        whether that is intended.
        """
        expr = ' '.join(tokens[1:-1])  # Exclude <s> and </s> tokens
        expr = expr.replace('<sep>', ' ')
        expr = expr.replace(';', '')  # Remove any remaining semicolons
        expr = expr.replace('- ', '-')  # Remove space after minus sign
        return expr

    def tgt_decode(self, tokens):
        """
        Decodes target tokens into equation string.

        The first and last token (expected to be '<s>' and '</s>') are dropped and
        the rest is concatenated without separators.
        """
        expr = ''.join(tokens[1:-1])  # Exclude <s> and </s> tokens
        expr = expr.replace('<sep>', ' ')
        return expr
    

In [None]:
# Split every equation's block of N rows into train / test / validation parts
fyn_eqs = df.Formula.tolist()


def _split_rows(rows, train_frac=0.9, test_frac=0.05):
    """Shuffle one equation's rows and cut them into (train, test, valid) pieces."""
    rows = rows.sample(frac=1, random_state=seed)
    train_len = int(train_frac * len(rows))
    test_len = int(test_frac * len(rows))  # whatever is left after train + test goes to valid
    return (rows.iloc[:train_len],
            rows.iloc[train_len:train_len + test_len],
            rows.iloc[train_len + test_len:])


# Walk the frame in consecutive chunks of N rows (one chunk per equation file).
# Counting chunks from the frame length instead of the number of distinct formulas
# keeps every chunk even when two files share the same formula.
pieces = [_split_rows(df.iloc[start:start + N]) for start in range(0, len(df), N)]

if pieces:
    # One concat per split instead of growing a frame inside the loop
    df_train, df_test, df_valid = (pd.concat(parts, ignore_index=True)
                                   for parts in zip(*pieces))
else:
    df_train, df_test, df_valid = (df.iloc[0:0].copy() for _ in range(3))

# Delete unnecessary variable
del pieces

if df_test.empty or df_valid.empty:
    # e.g. N = 10 gives int(0.05 * 10) == 0 test rows per equation
    warnings.warn(
        f"The per-equation split of N={N} rows produced {len(df_test)} test and "
        f"{len(df_valid)} validation rows; increase N or adjust the split ratios."
    )

# Defining target & source vocabulary sizes & id-to-string mapping for target sequence
tokenizer = Tokenizer(fyn_eqs)
v = tokenizer.build_tgt_vocab()
itos = {index: token for token, index in v.get_stoi().items()}
src_voc_size = len(tokenizer.build_src_vocab())
tgt_voc_size = len(v)  # reuse the vocabulary built above instead of rebuilding it

# Delete the original DataFrame
del df, fyn_eqs, v

In [None]:
# # Uncomment & execute to save the tokenizer

# #Saving our tokenizer
# with open('tokenizer.pkl', 'wb') as f:
#     dill.dump(tokenizer, f)

# Common Task 2: Train/Evaluate Transformer model

## Defining helper functions

In [None]:
def generate_eqn_mask(n: int, device: torch.device) -> torch.Tensor:
    """
    Build the additive causal (look-ahead) mask used for the target sequence.

    Args:
        n (int): Side length of the square mask.
        device (torch.device): Device on which the mask is created.

    Returns:
        torch.Tensor: Float32 tensor of shape (n, n) holding 0.0 on and below the
        diagonal and -inf above it.
    """
    blocked = torch.full((n, n), float('-inf'), dtype=torch.float32, device=device)
    # Keeping only the strict upper triangle zeroes the diagonal and everything below it
    return torch.triu(blocked, diagonal=1)

In [None]:
def create_mask(src: torch.Tensor, tgt: torch.Tensor, device: torch.device) -> tuple:
    """
    Build the attention and padding masks for one source/target batch.

    Args:
        src (torch.Tensor): Source batch; dimension 0 is used as the sequence length.
        tgt (torch.Tensor): Target batch; dimension 0 is used as the sequence length.
        device (torch.device): Device on which the attention masks are created.

    Returns:
        tuple: (src_mask, tgt_mask, src_padding_mask, tgt_padding_mask), where
        src_mask is an all-False square boolean mask, tgt_mask is the causal mask
        from `generate_eqn_mask`, and the padding masks mark PAD_IDX positions
        with dimensions 0 and 1 swapped.
    """
    src_len, tgt_len = src.shape[0], tgt.shape[0]

    # The source may attend everywhere, so nothing is masked
    src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=device)

    # The target may only attend to earlier positions
    tgt_mask = generate_eqn_mask(tgt_len, device)

    # True wherever a position holds padding
    src_padding_mask = src.eq(PAD_IDX).transpose(0, 1)
    tgt_padding_mask = tgt.eq(PAD_IDX).transpose(0, 1)

    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [None]:
def collate_fn(batch: list) -> tuple:
    """
    Turn a list of (source, target) tensor pairs into two padded batch tensors.

    Args:
        batch (list): List of (source tensor, target tensor) tuples.

    Returns:
        tuple: (padded source batch, padded target batch); shorter sequences are
        filled with PAD_IDX.
    """
    sources = [src_sample for src_sample, _ in batch]
    targets = [tgt_sample for _, tgt_sample in batch]

    padded_sources = pad_sequence(sources, padding_value=PAD_IDX)
    padded_targets = pad_sequence(targets, padding_value=PAD_IDX)
    return padded_sources, padded_targets

In [None]:
def sequence_accuracy(load_best=True, epochs=None):
    """
    Calculate the sequence accuracy on a 10% sample of the test split.

    A prediction counts as correct only when the generated token ids match the
    reference token ids exactly.

    Args:
        load_best (bool, optional): Whether to load the best model. Defaults to True.
        epochs (int, optional): Epoch number forwarded to `Predictor` to select the
            checkpoint when `load_best` is False. Defaults to None.

    Returns:
        float: Fraction of sampled rows predicted exactly; 0.0 (with a warning)
        when the sample is empty.
    """
    predictor = Predictor(load_best, epochs)
    random_df = df_test.sample(frac=0.1, random_state=seed)
    length = len(random_df)
    if length == 0:
        # Avoid dividing by zero when the test split (or its 10% sample) has no rows
        warnings.warn("sequence_accuracy: no test rows to evaluate; returning 0.0")
        return 0.0

    count = 0
    pbar = tqdm(range(length))
    pbar.set_description("Seq_Acc_Cal")
    for i in pbar:
        original_tokens, predicted_tokens = predictor.predict(random_df.iloc[i], raw_tokens=True)
        # Compare as plain lists so differing tensor dtypes/devices do not matter
        if original_tokens.tolist() == predicted_tokens.tolist():
            count += 1
        pbar.set_postfix(seq_accuracy=count / (i + 1))
    return count / length

## Defining required classes

In [None]:
# Helper modules for network architecture

class PositionalEncoding(nn.Module):
    """
    Adds fixed sinusoidal position information to token embeddings, then applies dropout.

    Args:
        emb_size (int): The embedding size.
        dropout (float): Dropout rate.
        maxlen (int, optional): Number of positions to precompute. Defaults to 5000.
    """

    def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
        super().__init__()
        positions = torch.arange(0, maxlen).reshape(maxlen, 1)
        inv_freq = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        angles = positions * inv_freq

        # Even columns carry the sine, odd columns the cosine of the same angle
        table = torch.zeros((maxlen, emb_size))
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        self.dropout = nn.Dropout(dropout)
        # Stored as a buffer of shape (maxlen, 1, emb_size): it moves with the module
        # but is not a trainable parameter.
        self.register_buffer('pos_embedding', table.unsqueeze(-2))

    def forward(self, token_embedding: Tensor):
        # Only the first `seq_len` positions are needed (dimension 0 is the sequence)
        seq_len = token_embedding.size(0)
        return self.dropout(token_embedding + self.pos_embedding[:seq_len])

class TokenEmbedding(nn.Module):
    """
    Embedding lookup whose output is scaled by sqrt(emb_size).

    Args:
        vocab_size (int): Size of the vocabulary.
        emb_size (int): The embedding size.
    """

    def __init__(self, vocab_size: int, emb_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        # Token ids are cast to int64 before the lookup
        scale = math.sqrt(self.emb_size)
        looked_up = self.embedding(tokens.long())
        return looked_up * scale

In [None]:
class Dataset(Dataset):
    """
    Custom PyTorch dataset that tokenizes feature rows and formulas on access.

    NOTE: this class deliberately keeps the name `Dataset` (shadowing the imported
    torch base class) because other code refers to it by that name.

    Args:
        df (DataFrame): DataFrame with a 'features' (source) and a 'Formula' (target) column.
    """

    def __init__(self, df):
        super(Dataset, self).__init__()
        self.tgt_vals = df['Formula']
        self.src_vals = df['features']
        self.tgt_vocab = tokenizer.build_tgt_vocab()
        self.src_vocab = tokenizer.build_src_vocab()
        self.tgt_tokenize = tokenizer.tgt_tokenize
        self.src_tokenize = tokenizer.src_tokenize

    def __len__(self):
        """
        Get the length of the dataset.

        Returns:
            int: Length of the dataset.
        """
        return len(self.src_vals)

    def __getitem__(self, idx):
        """
        Get an item from the dataset at the specified index.

        Args:
            idx (int): Position of the item, in the range [0, len(self)).

        Returns:
            tuple: Tuple containing source and target tensors (int64 token ids).
        """
        # Positional lookup: label-based `Series[idx]` only works when the frame
        # happens to carry a fresh 0..n-1 index.
        src_tokenized = self.src_tokenize(self.src_vals.iloc[idx])
        tgt_tokenized = self.tgt_tokenize(self.tgt_vals.iloc[idx])
        src_ids = self.src_vocab(src_tokenized)
        tgt_ids = self.tgt_vocab(tgt_tokenized)

        src_tensor = torch.tensor(src_ids, dtype=torch.long)
        tgt_tensor = torch.tensor(tgt_ids, dtype=torch.long)

        return src_tensor, tgt_tensor

    @staticmethod
    def get_data():
        """
        Create datasets (train, test, and valid) from the module-level split frames.

        Returns:
            dict: Dictionary containing train, test, and valid datasets.
        """
        return {
            'train': Dataset(df_train),
            'test': Dataset(df_test),
            'valid': Dataset(df_valid),
        }

In [None]:
class Model(nn.Module):
    """
    Encoder-decoder transformer with token embeddings, positional encoding and an
    output projection onto the target vocabulary.

    Args:
        num_encoder_layers (int): Number of encoder layers.
        num_decoder_layers (int): Number of decoder layers.
        emb_size (int): Size of the embedding.
        nhead (int): Number of attention heads.
        src_vocab_size (int): Size of the source vocabulary.
        tgt_vocab_size (int): Size of the target vocabulary.
        dim_feedforward (int, optional): Dimension of the feedforward network. Defaults to 512.
        dropout (float, optional): Dropout rate. Defaults to 0.1.
    """

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        # Submodules are created in a fixed order so that parameter order (and
        # therefore seeded initialisation and checkpoints) stays the same.
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        """
        Run the full encoder-decoder pass and project onto the target vocabulary.

        Args:
            src (Tensor): Source token ids.
            trg (Tensor): Target token ids.
            src_mask (Tensor): Attention mask for the source.
            tgt_mask (Tensor): Attention mask for the target.
            src_padding_mask (Tensor): Padding mask for the source.
            tgt_padding_mask (Tensor): Padding mask for the target.
            memory_key_padding_mask (Tensor): Padding mask for the encoder memory.

        Returns:
            Tensor: Output of the generator layer (scores over the target vocabulary).
        """
        embedded_src = self.positional_encoding(self.src_tok_emb(src))
        embedded_tgt = self.positional_encoding(self.tgt_tok_emb(trg))
        hidden = self.transformer(
            embedded_src,
            embedded_tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,  # no extra mask between decoder queries and encoder memory
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.generator(hidden)

    def encode(self, src: Tensor, src_mask: Tensor):
        """
        Run only the encoder stack.

        Args:
            src (Tensor): Source token ids.
            src_mask (Tensor): Attention mask for the source.

        Returns:
            Tensor: Encoder output (the decoder's memory).
        """
        embedded_src = self.positional_encoding(self.src_tok_emb(src))
        return self.transformer.encoder(embedded_src, src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """
        Run only the decoder stack against a precomputed encoder memory.

        Args:
            tgt (Tensor): Target token ids generated so far.
            memory (Tensor): Encoder output.
            tgt_mask (Tensor): Attention mask for the target.

        Returns:
            Tensor: Decoder output (before the generator projection).
        """
        embedded_tgt = self.positional_encoding(self.tgt_tok_emb(tgt))
        return self.transformer.decoder(embedded_tgt, memory, tgt_mask)

In [None]:
def get_model():
    """
    Build a `Model` from the module-level hyper-parameters and apply Xavier-uniform
    initialisation to every parameter with more than one dimension.

    Returns:
        Model: Initialized model object.
    """
    model = Model(
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        emb_size=embedding_size,
        nhead=nhead,
        src_vocab_size=src_voc_size,
        tgt_vocab_size=tgt_voc_size,
        dim_feedforward=hidden_dim,
        dropout=dropout,
    )

    # One-dimensional parameters (e.g. biases) keep their default initialisation
    weight_matrices = (param for param in model.parameters() if param.dim() > 1)
    for weight in weight_matrices:
        nn.init.xavier_uniform_(weight)

    return model


## Designing the Prediction & Training classes

In [None]:
class Predictor():
    """
    Class for generating predictions using a trained model.

    Args:
        load_best (bool, optional): Load 'model_best_tr.pth' when True, otherwise the
            per-epoch checkpoint 'model_ep{epochs+1}.pth'. Defaults to True.
        epochs (int, optional): Epoch number used to pick the per-epoch checkpoint;
            required when `load_best` is False. Defaults to None.

    Attributes:
        model (Model): Trained model for prediction.
        path (str): Path to the trained model.
        device (str): Device for inference.
        df: The module-level `train` value at construction time.
        vocab (dict): Source/target vocabularies keyed by column name.
        attrs (list): Column names of the source and target values.
        checkpoint (str): model checkpoint file name

    Raises:
        ValueError: If `load_best` is False and `epochs` is None.
    """

    def __init__(self,load_best=True,epochs=None):
        if not load_best and epochs is None:
            # Fail with a clear message instead of a TypeError from `None + 1`
            raise ValueError("`epochs` must be provided when `load_best` is False")
        self.model = get_model()
        self.checkpoint = "model_best_tr.pth" if load_best else f"model_ep{epochs+1}.pth"
        self.path = os.path.join(root_dir,self.checkpoint)
        self.device = device
        self.df = train
        state = torch.load(self.path, map_location=self.device)
        self.model.load_state_dict(state['state_dict'])
        self.model.to(self.device)
        self.vocab = {}
        self.attrs = ['features', 'Formula']

        self.vocab[self.attrs[0]] = tokenizer.build_src_vocab()
        self.vocab[self.attrs[1]] = tokenizer.build_tgt_vocab()

    def tok_to_id(self, tokenize, vocab, val):
        """
        Convert tokens to token IDs using the provided tokenizer and vocabulary.

        Args:
            tokenize (function): Tokenization function.
            vocab (function): Vocabulary function mapping a token list to ids.
            val (str): Input string.

        Returns:
            Tensor: Token IDs (dtype torch.int).
        """
        val = tokenize(val)
        token_ids = vocab(val)
        return torch.tensor(token_ids, dtype=torch.int)

    def greedy_decode(self, src, src_mask, max_len, start_symbol):
        """
        Generate a sequence using greedy decoding.

        Args:
            src (Tensor): Source input.
            src_mask (Tensor): Mask for source input.
            max_len (int): Maximum length of the generated sequence.
            start_symbol (int): Start symbol for decoding.

        Returns:
            Tensor: Generated token ids (int64), one row per generated token,
            starting with `start_symbol`.
        """
        src = src.to(self.device)
        src_mask = src_mask.to(self.device)

        # Inference only: skip autograd bookkeeping to save memory and time
        with torch.no_grad():
            memory = self.model.encode(src, src_mask)
            memory = memory.to(self.device)
            ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=self.device)
            for _ in range(max_len - 1):
                tgt_mask = generate_eqn_mask(ys.size(0), self.device).type(torch.bool)
                out = self.model.decode(ys, memory, tgt_mask)
                out = out.transpose(0, 1)
                # Score only the most recent position
                prob = self.model.generator(out[:, -1])

                _, next_word = torch.max(prob, dim=1)
                next_word = next_word.item()

                next_token = torch.full((1, 1), next_word, dtype=ys.dtype, device=ys.device)
                ys = torch.cat([ys, next_token], dim=0)
                if next_word == EOS_IDX:
                    break
        return ys

    def predict(self, test_example, raw_tokens=False):
        """
        Generate prediction for a test example.

        Args:
            test_example (dict): Test example containing input features (and the
                reference formula when `raw_tokens` is True).
            raw_tokens (bool, optional): Whether to return raw tokens. Defaults to False.

        Returns:
            str or tuple: With `raw_tokens`, a tuple (reference token ids, predicted
            token ids). Otherwise the predicted tokens mapped through `itos` and
            concatenated; every generated id is mapped, so the start/end markers
            are part of the string.
        """
        self.model.eval()
        src_sentence = test_example[self.attrs[0]]

        src = self.tok_to_id(tokenizer.src_tokenize, self.vocab[self.attrs[0]], src_sentence).view(-1, 1)
        num_tokens = src.shape[0]

        src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
        # The output length is capped at the source length plus a small margin
        tgt_tokens = self.greedy_decode(src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()

        if raw_tokens:
            original_sentence = test_example[self.attrs[1]]
            original_tokens = self.tok_to_id(tokenizer.tgt_tokenize, self.vocab[self.attrs[1]], original_sentence)
            return original_tokens, tgt_tokens

        return ''.join(itos[int(t)] for t in tgt_tokens)


In [None]:
class Trainer():
    """
    Class for training a sequence-to-sequence model.

    Args:
        resume_best (bool, optional): Load ``model_best_tr.pth`` (weights and optimizer
            state) before training starts. Defaults to False.
        save_freq (int, optional): Frequency of periodic checkpoints in epochs.
            Defaults to None (no periodic checkpoints).
        save_last (bool, optional): Save a checkpoint after the last epoch. Defaults to True.
        start_epoch (int, optional): Starting epoch number. Defaults to 0.

    Attributes:
        scaler (GradScaler): Gradient scaler for half-precision training.
        dtype (torch.dtype): Data type requested for autocast.
        use_amp (bool): Whether autocast and gradient scaling are active.
        dataloaders (dict): Dataloaders for train, validation, and test datasets.
        root_dir (str): Root directory for saving models and logs.
        device (str): Device for training.
        current_epoch (int): Current epoch number (zero-based).
        best_val_loss (float): Best validation loss seen so far.
        train_loss_list (list): List of training losses.
        valid_loss_list (list): List of validation losses.
        valid_accuracy_tok_list (list): Reserved for validation token accuracies
            (not populated by this class).
        model (Model): Model for training.
        optimizer (Optimizer): Optimizer for training.
        scheduler (Scheduler): Learning rate scheduler.
        resume_best (bool): Whether to resume from the last best saved model.
        save_freq (int): Frequency of saving in terms of epochs.
        save_last (bool): Whether to save model after complete training.
    """

    def __init__(self, resume_best=False, save_freq=None, save_last=True, start_epoch=0):
        # Mixed precision only makes sense when half precision is requested AND
        # the run is on a CUDA device; otherwise autocast and the scaler are no-ops.
        self.dtype = torch.float16 if use_half_precision else torch.float32
        self.use_amp = bool(use_half_precision) and torch.device(device).type == "cuda"
        self.scaler = GradScaler(enabled=self.use_amp)

        self.dataloaders = self._prepare_dataloaders()
        self.root_dir = root_dir
        self.device = device
        self.current_epoch = start_epoch
        self.best_val_loss = 1e6
        self.train_loss_list = []
        self.valid_loss_list = []
        self.valid_accuracy_tok_list = []
        self.model = self._prepare_model()
        self.optimizer = self._prepare_optimizer()
        self.scheduler = self._prepare_scheduler()
        self.save_freq = save_freq
        self.resume_best = resume_best
        self.save_last = save_last

    def criterion(self, y_pred, y_true):
        """
        Calculate the cross-entropy loss between predicted logits and true labels.

        Args:
            y_pred (Tensor): Predicted values.
            y_true (Tensor): True values.

        Returns:
            Tensor: Loss value.
        """
        # NOTE(review): no ignore_index is set, so every target position counts toward
        # the loss. Confirm whether PAD_IDX targets are meant to be excluded.
        loss_fn = torch.nn.CrossEntropyLoss()
        return loss_fn(y_pred, y_true)

    def __call__(self, x):
        """
        Make a forward pass through the model.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Output tensor.
        """
        # NOTE(review): the training loop calls the model with seven arguments;
        # verify that a single-argument call is supported before relying on this.
        return self.model(x)

    def _prepare_model(self):
        """
        Initialize the model and move it to the training device.

        Returns:
            Model: Initialized model.
        """
        model = get_model()
        model.to(self.device)
        return model

    def _prepare_optimizer(self):
        """
        Initialize the Adam optimizer over all model parameters.

        Returns:
            Optimizer: Initialized optimizer.
        """
        param_optimizer = list(self.model.parameters())
        return torch.optim.Adam(param_optimizer, lr=optimizer_lr, eps=1e-9)

    def _prepare_scheduler(self):
        """
        Initialize the learning rate scheduler.

        Returns:
            Scheduler: ReduceLROnPlateau scheduler (stepped once per epoch in ``fit``).
        """
        return torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', patience=2)

    def _prepare_dataloaders(self):
        """
        Prepare dataloaders for training, validation, and testing.

        Returns:
            dict: Dictionary containing train, validation, and test dataloaders.
        """
        datasets = Dataset.get_data()

        def build(split, batch_size, shuffle):
            # All three loaders share workers, pinning and collation settings.
            return torch.utils.data.DataLoader(datasets[split], batch_size=batch_size,
                                               shuffle=shuffle, num_workers=num_workers,
                                               pin_memory=pin_memory, collate_fn=collate_fn)

        return {
            'train': build('train', training_batch_size, train_shuffle),
            'valid': build('valid', test_batch_size, test_shuffle),
            'test': build('test', test_batch_size, test_shuffle),
        }

    def load_model(self, resume=False, epoch=None):
        """
        Load a model checkpoint from ``root_dir``.

        Loads ``model_best_tr.pth`` when ``epoch`` is None, otherwise
        ``model_ep{epoch+1}.pth``.

        Args:
            resume (bool, optional): Also restore the optimizer state. Defaults to False.
            epoch (int, optional): Zero-based epoch whose checkpoint should be loaded.
        """
        checkpoint_name = "model_best_tr.pth" if epoch is None else f"model_ep{epoch+1}.pth"
        file = os.path.join(self.root_dir, checkpoint_name)
        state = torch.load(file, map_location=self.device)
        self.model.load_state_dict(state['state_dict'])
        if resume:
            self.optimizer.load_state_dict(state['optimizer'])

    def _train_epoch(self):
        """
        Perform a single training epoch.

        Returns:
            float: Average training loss for the epoch (weighted by batch size).

        Raises:
            ValueError: If the training dataloader yields no batches.
        """
        self.model.train()
        pbar = tqdm(self.dataloaders['train'], total=len(self.dataloaders['train']))
        pbar.set_description(f"[{self.current_epoch+1}/{epochs}] Train")
        running_loss = 0.0
        total_samples = 0

        for src, tgt in pbar:
            src = src.to(self.device)
            tgt = tgt.to(self.device)
            bs = src.size(1)  # batch size is read from dim 1 of src

            # Decoder input drops the last position; the loss target drops the first.
            tgt_input = tgt[:-1, :]

            with torch.autocast(device_type='cuda', dtype=self.dtype, enabled=self.use_amp):
                src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, self.device)

                logits = self.model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

                # Calculate loss
                loss = self.criterion(logits.reshape(-1, logits.shape[-1]), tgt[1:, :].reshape(-1))

            running_loss += loss.item() * bs
            total_samples += bs
            pbar.set_postfix(loss=running_loss / total_samples)

            # Backward
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            if clip_grad_norm > 0:
                # Gradients are still multiplied by the loss scale at this point;
                # unscale them so the clipping threshold applies to the true norms.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()

        if total_samples == 0:
            raise ValueError("Training dataloader yielded no samples")
        return running_loss / total_samples

    def evaluate(self, phase):
        """
        Evaluate the model on the validation or test set.

        Args:
            phase (str): Phase of evaluation ('valid' or 'test'); key into ``self.dataloaders``.

        Returns:
            float: Average loss over the phase (weighted by batch size).

        Raises:
            ValueError: If the dataloader for ``phase`` yields no batches.
        """
        self.model.eval()
        pbar = tqdm(self.dataloaders[phase], total=len(self.dataloaders[phase]))
        pbar.set_description(f"[{self.current_epoch+1}/{epochs}] {phase.capitalize()}")
        running_loss = 0.0
        total_samples = 0

        with torch.no_grad():
            for src, tgt in pbar:
                src = src.to(self.device)
                tgt = tgt.to(self.device)
                bs = src.size(1)
                tgt_input = tgt[:-1, :]

                src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, self.device)
                logits = self.model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
                loss = self.criterion(logits.reshape(-1, logits.shape[-1]), tgt[1:, :].reshape(-1))

                running_loss += loss.item() * bs
                total_samples += bs

        if total_samples == 0:
            raise ValueError(f"Dataloader for phase '{phase}' yielded no samples")
        return running_loss / total_samples

    def _save_model(self, checkpoint_name):
        """
        Save the model checkpoint under ``root_dir``.

        The checkpoint stores the one-based epoch number, model weights, optimizer
        state and the loss histories.

        Args:
            checkpoint_name (str): Name of the checkpoint file.
        """
        # Make sure the target directory exists so torch.save does not fail on a fresh run.
        if self.root_dir:
            os.makedirs(self.root_dir, exist_ok=True)
        torch.save({
            "epoch": self.current_epoch + 1,
            "state_dict": self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            "train_loss_list": self.train_loss_list,
            "valid_loss_list": self.valid_loss_list,
        }, os.path.join(self.root_dir, checkpoint_name))

    def _test_seq_acc(self, load_best=True, epochs=None):
        """
        Compute sequence accuracy via ``sequence_accuracy`` and print it.

        Args:
            load_best (bool, optional): Forwarded to ``sequence_accuracy``. Defaults to True.
            epochs (int, optional): Forwarded to ``sequence_accuracy``. Defaults to None.
        """
        # NOTE(review): the best checkpoint is reloaded into self.model regardless of
        # `load_best`; confirm this is intended when evaluating the last epoch.
        self.load_model()
        test_accuracy_seq = sequence_accuracy(load_best, epochs)
        print(f"Test Accuracy: {round(test_accuracy_seq, 4)}")

    def fit(self):
        """
        Train the model.

        Runs from ``current_epoch`` up to the global ``epochs``. Each epoch trains,
        validates, steps the LR scheduler on the validation loss, and saves a
        checkpoint whenever the validation loss matches or beats the best so far.
        """
        if self.resume_best:
            self.load_model(resume=True)
        for self.current_epoch in range(self.current_epoch, epochs):
            training_loss = self._train_epoch()
            valid_loss = self.evaluate("valid")

            # ReduceLROnPlateau only lowers the learning rate when it is fed the monitored metric.
            self.scheduler.step(valid_loss)

            self.train_loss_list.append(round(training_loss, 4))
            self.valid_loss_list.append(round(valid_loss, 4))
            if self.save_freq:
                # NOTE(review): with this condition periodic checkpoints land on
                # epochs 1, 1 + save_freq, ... (one-based); confirm that is intended.
                if self.current_epoch % self.save_freq == 0:
                    self._save_model(f"model_ep{self.current_epoch + 1 }.pth")

            if valid_loss <= self.best_val_loss:
                self.best_val_loss = valid_loss
                self._save_model("model_best_tr.pth")
                self._test_seq_acc()

            print(f"Epoch {self.current_epoch + 1}/{epochs}, "
                  f"Training Loss: {training_loss:.4f}, "
                  f"Validation Loss: {valid_loss:.4f}, ")

        if self.save_last:
            self._save_model(f"model_ep{self.current_epoch + 1 }.pth")
        self._test_seq_acc(load_best=False, epochs=self.current_epoch)

In [None]:
# Build the trainer with its default options: no resume, no periodic checkpoints, final checkpoint saved.
trainer = Trainer()

In [None]:
# Run the training loop (train + validate each epoch, checkpoint on improved validation loss).
trainer.fit()