# Training a GRU-based sequence-to-sequence model for neural machine translation

## Importing modules

In [1]:
!pip install transformers
!pip install evaluate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.27.1-py3-none-any.whl (6.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.7/6.7 MB[0m [31m53.0 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.13.2-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.2/199.2 KB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m76.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.13.2 tokenizers-0.13.2 transformers-4.27.1
Looking in indexes: https://pypi.org/simple, https://us

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import random
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from transformers import BartModel
import torch
from torch import nn
from torch import optim
from torch.utils.data import random_split
import random
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import evaluate

## Downloading the dataset

In [3]:
!wget http://www.manythings.org/anki/ita-eng.zip
!unzip ita-eng.zip
!rm ita-eng.zip
!mkdir dataset
!mv ita.txt dataset

--2023-03-16 09:50:51--  http://www.manythings.org/anki/ita-eng.zip
Resolving www.manythings.org (www.manythings.org)... 173.254.30.110
Connecting to www.manythings.org (www.manythings.org)|173.254.30.110|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7981351 (7.6M) [application/zip]
Saving to: ‘ita-eng.zip’


2023-03-16 09:50:52 (17.4 MB/s) - ‘ita-eng.zip’ saved [7981351/7981351]

Archive:  ita-eng.zip
  inflating: ita.txt                 
  inflating: _about.txt              


# Defining some settings

In [4]:
# Root directory of the run; every other path is derived from it.
DIR_PATH = "."
DATASET_PATH = os.path.join(DIR_PATH, "dataset")
IMAGE_PATH = os.path.join(DIR_PATH, "images")
CHECKPOINT_DIR = os.path.join(DIR_PATH, "checkpoints")

# Create the output folders from Python (instead of `!mkdir`) so that the cell
# can be re-run without failing on directories that already exist.
for _output_dir in (IMAGE_PATH, CHECKPOINT_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# Defining utilities

In [5]:
def count_parameters(model):
    """Print the number of trainable parameters of `model`.

    Only parameters with ``requires_grad`` set are counted. Nothing is returned.
    """
    n_params = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            n_params += parameter.numel()
    print(f'The model has {n_params} trainable parameters')


def plot_curves(curve_1, label_1, curve_2=None, label_2=None, fig_name="figure", show=False):
    """Plot one or two labelled curves on the current figure and save it.

    - curve_1 / label_1: values and legend label of the first curve
    - curve_2 / label_2: optional second curve, drawn only when `curve_2` is given
    - fig_name: path handed to `plt.savefig`
    - show: when true, the figure is also displayed

    The current figure is cleared at the end so that the next call starts
    from an empty canvas.
    """
    curves = [(curve_1, label_1)]
    if curve_2 is not None:
        curves.append((curve_2, label_2))

    for values, legend_label in curves:
        plt.plot(values, label=legend_label)

    plt.legend()
    plt.savefig(f"{fig_name}")

    if show:
        plt.show()

    plt.clf()

    
def plot_attention_mask(attention_mask, source_tokens, target_tokens):
    """Show attention weights as a grayscale image (one row per target token).

    - attention_mask: attention weights; dimension 1 is squeezed away and the
      columns are cut at the first "[PAD]" source token
      (presumably the (steps, 1, src_len) tensor returned by Seq2Seq.generate)
    - source_tokens: source tokens used as labels on the top axis
    - target_tokens: target tokens used as labels on the vertical axis
    """
    # Number of source positions to show: everything before the first padding token.
    if "[PAD]" in source_tokens:
        n_visible = source_tokens.index("[PAD]")
    else:
        n_visible = len(source_tokens)
    visible_tokens = source_tokens[:n_visible]

    weights = attention_mask.squeeze(1)[:, :n_visible]

    plt.xticks(ticks=list(range(len(visible_tokens))), labels=visible_tokens, rotation=45)
    # Put the source labels above the image instead of below it.
    plt.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
    plt.yticks(ticks=list(range(len(target_tokens))), labels=target_tokens)
    plt.imshow(weights, cmap='gray', vmin=0, vmax=1)
    plt.show()

## Definition of the dataset class

In [6]:
class AnkiDataset(Dataset):
    """Sentence pairs read from a tab-separated file and tokenized on access.

    Every line of the file is expected to hold at least two tab-separated
    fields; the first two are kept as the (source, destination) sentences.

    Arguments:
    - data_path: path of the tab-separated file
    - tokenizer_src / tokenizer_dst: callables used to tokenize the first / second field
    - src_max_length / dst_max_length: length the tokenized sentences are padded or truncated to
    """

    def __init__(self, data_path, tokenizer_src, tokenizer_dst, src_max_length, dst_max_length) -> None:
        super().__init__()
        self.tokenizer_src = tokenizer_src
        self.tokenizer_dst = tokenizer_dst
        self.src_max_length = src_max_length
        self.dst_max_length = dst_max_length
        self.data = self.get_data(data_path)

    def __len__(self):
        """Number of sentence pairs."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the tokenized (src, dst) pair at `index`.

        Both elements are the tokenizer outputs with the leading batch
        dimension removed from every field.
        """
        src_text, dst_text = self.data[index]

        src = self._encode(self.tokenizer_src, src_text, self.src_max_length)
        dst = self._encode(self.tokenizer_dst, dst_text, self.dst_max_length)

        return (src, dst)

    @staticmethod
    def _encode(tokenizer, text, max_length):
        """Tokenize one sentence to exactly `max_length` tokens and drop the batch dimension."""
        # `padding="max_length"` already selects the padding strategy; the
        # deprecated `pad_to_max_length` flag is therefore not passed any more.
        encoded = tokenizer(
            text,
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors='pt'
        )

        # Each tokenizer output is unpacked using its own keys, so the two
        # tokenizers are not required to return the same fields.
        for key in list(encoded.keys()):
            encoded[key] = encoded[key][0]

        return encoded

    def get_data(self, data_path="./ita.txt"):
        """Read the file and return a list of (first field, second field) tuples.

        The file is decoded as UTF-8 regardless of the platform default.
        Lines with fewer than two tab-separated fields (for instance blank
        lines) are skipped, because they cannot be unpacked into a pair.
        """
        sentences = []
        with open(data_path, "r", encoding="utf-8") as dataset:
            for line in dataset:
                fields = line.rstrip("\n").split("\t")
                if len(fields) < 2:
                    continue
                sentences.append((fields[0], fields[1]))

        return sentences

## Code of the model

In [7]:
'''
Implementation of the sequence2sequence model as described in the paper:
Neural Machine Translation by Jointly Learning to Align and Translate
(https://arxiv.org/abs/1409.0473)
'''

'''
Encoder of the sequence to sequence model implemented using a bidirectional GRU
'''
class Encoder(nn.Module):
    """Bidirectional GRU encoder.

    Arguments:
    - vocab_dim: size of the source vocabulary
    - enc_hidden_dim: size of the token embeddings and of each GRU direction
    - dec_hidden_dim: size the final encoder state is projected to
    - n_layers: number of stacked GRU layers
    """

    def __init__(self, vocab_dim, enc_hidden_dim, dec_hidden_dim, n_layers) -> None:
        super().__init__()
        self.input_dim = vocab_dim
        self.enc_hidden_dim = enc_hidden_dim
        self.dec_hidden_dim = dec_hidden_dim
        self.n_layers = n_layers

        self.embedder = nn.Embedding(vocab_dim, enc_hidden_dim)
        self.encoder = nn.GRU(
            input_size=enc_hidden_dim,
            hidden_size=enc_hidden_dim,
            num_layers=n_layers,
            bidirectional=True
        )
        self.linear = nn.Linear(2 * enc_hidden_dim, dec_hidden_dim)

    def forward(self, x):
        """Encode a batch of token ids.

        Input:
        - x: token ids of size (length, batch_size)

        Output:
        - out: GRU outputs for every position, size (length, batch_size, 2*enc_hidden_dim)
        - hidden: projection of the final GRU states, size (batch_size, dec_hidden_dim)
        """
        # (length, batch_size) -> (length, batch_size, enc_hidden_dim)
        embedded = self.embedder(x)

        # out:          (length, batch_size, 2*enc_hidden_dim)
        # layer_states: (2*n_layers, batch_size, enc_hidden_dim)
        out, layer_states = self.encoder(embedded)

        # The last two entries are the two directions of the top layer; joined
        # they form a (batch_size, 2*enc_hidden_dim) summary of the sentence.
        top_forward = layer_states[-2]
        top_backward = layer_states[-1]
        summary = torch.cat([top_forward, top_backward], dim=1)

        # Linear map followed by tanh: (batch_size, dec_hidden_dim)
        projected = torch.tanh(self.linear(summary))

        return out, projected


class AttentionLayer(nn.Module):
    """Additive attention over the encoder outputs.

    For the decoder state s_{i-1} and every encoder output h_j the layer computes
        e_{i,j}     = v^T * tanh(W * [s_{i-1} | h_j])
        alpha_{i,j} = exp(e_{i,j}) / sum_k(exp(e_{i,k}))
    and returns the alpha-weighted sum of the encoder outputs.
    """

    def __init__(self, enc_hidden_dim, dec_hidden_dim) -> None:
        super().__init__()
        self.enc_hidden_dim = enc_hidden_dim
        self.dec_hidden_dim = dec_hidden_dim

        self.score = nn.Linear(2 * enc_hidden_dim + dec_hidden_dim, dec_hidden_dim)
        self.v = nn.Linear(dec_hidden_dim, 1, bias=False)

    def forward(self, dec_hidden, enc_output, mask):
        """Compute the context vector and the attention weights.

        Input:
        - dec_hidden: previous decoder state, size (batch_size, dec_hidden_dim)
        - enc_output: encoder outputs, size (length, batch_size, 2*enc_hidden_dim)
        - mask: size (batch_size, length); positions equal to 0 receive no attention

        Output:
        - attention: weighted sum of the encoder outputs, size (batch_size, 2*enc_hidden_dim)
        - alpha: attention weights, size (batch_size, length)
        """
        n_positions = enc_output.size(0)

        # (length, batch_size, 2*enc_hidden_dim) -> (batch_size, length, 2*enc_hidden_dim)
        encoder_states = enc_output.transpose(0, 1)

        # One copy of the decoder state per source position:
        # (batch_size, length, dec_hidden_dim)
        queries = dec_hidden.unsqueeze(1).repeat(1, n_positions, 1)

        # tanh(W * [s_{i-1} | h_j]): (batch_size, length, dec_hidden_dim)
        features = torch.tanh(self.score(torch.cat((queries, encoder_states), dim=2)))

        # v^T * features: (batch_size, length, 1) -> (batch_size, length)
        scores = self.v(features).squeeze(2)

        # A very negative score makes the softmax weight of masked positions vanish.
        scores = scores.masked_fill(mask == 0, -1e10)

        # Normalise over the source positions: (batch_size, length)
        alpha = F.softmax(scores, dim=1)

        # (batch_size, 1, length) x (batch_size, length, 2*enc_hidden_dim)
        #   -> (batch_size, 1, 2*enc_hidden_dim) -> (batch_size, 2*enc_hidden_dim)
        context = torch.bmm(alpha.unsqueeze(1), encoder_states).squeeze(1)

        return context, alpha



'''
Decoder of the sequence to sequence model implemented using a GRU
'''
class Decoder(nn.Module):
    """Single-step GRU decoder with attention over the encoder outputs.

    Arguments:
    - vocab_dim: size of the target vocabulary
    - enc_hidden_dim: hidden size of each encoder direction
    - dec_hidden_dim: size of the target embeddings and of the decoder state
    - n_layers: number of GRU layers
      NOTE(review): forward passes the state as a single layer
      (`unsqueeze(0)`), so this looks like it assumes n_layers == 1 - confirm.
    """

    def __init__(self, vocab_dim, enc_hidden_dim, dec_hidden_dim, n_layers) -> None:
        super().__init__()
        self.input_dim = vocab_dim
        self.enc_hidden_dim = enc_hidden_dim
        self.dec_hidden_dim = dec_hidden_dim
        self.n_layers = n_layers

        self.embedder = nn.Embedding(vocab_dim, dec_hidden_dim)
        self.decoder = nn.GRU(
            2 * enc_hidden_dim + dec_hidden_dim,
            dec_hidden_dim,
            n_layers,
            bidirectional=False
        )
        self.attention = AttentionLayer(enc_hidden_dim, dec_hidden_dim)
        self.fc_out = nn.Linear(2 * (enc_hidden_dim + dec_hidden_dim), vocab_dim)

    def forward(self, input, dec_hidden, enc_output, mask):
        """Decode one time step.

        Input:
        - input: current input token of every sentence, size (batch_size)
        - dec_hidden: previous decoder state, size (batch_size, dec_hidden_dim)
        - enc_output: encoder outputs, size (length, batch_size, 2*enc_hidden_dim)
        - mask: attention mask forwarded to the attention layer

        Output:
        - logits: scores over the target vocabulary, size (batch_size, vocab_dim)
        - hidden: next decoder state, size (batch_size, dec_hidden_dim)
        - alpha: attention weights returned by the attention layer
        """
        # (batch_size) -> (batch_size, dec_hidden_dim)
        embedded = self.embedder(input)

        # context: (batch_size, 2*enc_hidden_dim)
        context, alpha = self.attention(dec_hidden, enc_output, mask)

        # The GRU consumes a length-1 sequence:
        # (1, batch_size, 2*enc_hidden_dim + dec_hidden_dim)
        step_input = torch.cat((embedded, context), dim=1).unsqueeze(0)

        # The previous state gets a leading layer axis: (1, batch_size, dec_hidden_dim)
        previous_state = dec_hidden.unsqueeze(0)

        output, hidden = self.decoder(step_input, previous_state)

        # Drop the leading axis again: (batch_size, dec_hidden_dim)
        output = output.squeeze(0)
        hidden = hidden.squeeze(0)

        # GRU output, context and embedding together feed the output layer:
        # (batch_size, 2*enc_hidden_dim + 2*dec_hidden_dim) -> (batch_size, vocab_dim)
        logits = self.fc_out(torch.cat((output, context, embedded), dim=1))

        return logits, hidden, alpha



class Seq2Seq(nn.Module):
    """Encoder-decoder translation model with attention.

    Arguments:
    - enc_vocab_dim / dec_vocab_dim: sizes of the source / target vocabularies
    - enc_hidden_dim / dec_hidden_dim: hidden sizes of encoder / decoder
    - enc_n_layers / dec_n_layers: number of GRU layers of encoder / decoder
    - pad_idx: id of the padding token in the source sentences (used for the attention mask)
    - start_idx: id of the token generation starts from
    - end_idx: id of the token that stops generation
    - teacher_forcing_ratio: probability of feeding the ground-truth token
      instead of the model prediction at each step of `forward`
    - device: stored on the instance; not used by the methods of this class
    """

    def __init__(self,
                 enc_vocab_dim,
                 dec_vocab_dim,
                 enc_hidden_dim,
                 dec_hidden_dim,
                 enc_n_layers,
                 dec_n_layers,
                 pad_idx,
                 start_idx,
                 end_idx,
                 teacher_forcing_ratio,
                 device
                ) -> None:
        super(Seq2Seq, self).__init__()
        self.enc_vocab_dim = enc_vocab_dim
        self.dec_vocab_dim = dec_vocab_dim
        self.enc_hidden_dim = enc_hidden_dim
        self.dec_hidden_dim = dec_hidden_dim
        self.enc_n_layers = enc_n_layers
        self.dec_n_layers = dec_n_layers
        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.device = device
        self.pad_idx = pad_idx
        self.start_idx = start_idx
        self.end_idx = end_idx

        self.encoder = Encoder(enc_vocab_dim, enc_hidden_dim, dec_hidden_dim, enc_n_layers)
        self.decoder = Decoder(dec_vocab_dim, enc_hidden_dim, dec_hidden_dim, dec_n_layers)

    def create_mask(self, source):
        """Return a (batch_size, src_len) boolean mask that is False on padding tokens.

        `source` has size (src_len, batch_size).
        """
        mask = (source != self.pad_idx).permute(1, 0)
        return mask

    def forward(self, source, target):
        """Forward pass of the seq2seq model.

        Input:
        - source: the input sentences, a tensor of size (src_len, batch_size)
        - target: the target sentences, a tensor of size (dst_len, batch_size)

        Output:
        - outputs: the logits for each token of each sentence, a tensor of size
                   (dst_len, batch_size, dec_vocab_dim); position 0 is never
                   predicted and stays zero
        """
        target_len = target.shape[0]
        batch_size = target.shape[1]

        # enc_output: (src_len, batch_size, 2*enc_hidden_dim)
        # hidden:     (batch_size, dec_hidden_dim)
        enc_output, hidden = self.encoder(source)

        # Allocated on the device of the inputs so that the logits can be
        # stored without a device mismatch when the model is not on the CPU.
        outputs = torch.zeros(target_len, batch_size, self.dec_vocab_dim, device=source.device)

        # First token of every target sentence, size (batch_size)
        target_token = target[0, :]

        mask = self.create_mask(source)

        for t in range(1, target_len):
            # logits: (batch_size, dec_vocab_dim), hidden: (batch_size, dec_hidden_dim)
            logits, hidden, _ = self.decoder(target_token, hidden, enc_output, mask)
            outputs[t] = logits

            # Decide whether the next input is the ground truth or the prediction.
            teacher_force = random.uniform(0, 1) < self.teacher_forcing_ratio
            top1 = logits.argmax(1)
            target_token = target[t, :] if teacher_force else top1

        return outputs

    def generate(self, source, max_len=50):
        """Greedy generation for a single sentence.

        Input:
        - source: token ids of size (1, src_len)
        - max_len: maximum number of tokens to generate

        Output:
        - predicted: list of generated token ids (the end token is included
          when it is produced)
        - attention: attention weights of every generated token, a CPU tensor
          of size (len(predicted), 1, src_len)

        The model is switched to evaluation mode as a side effect.
        """
        source = source.permute(1, 0)

        self.eval()

        src_len = source.shape[0]
        n_steps = max(max_len, 0)

        predicted = []
        # Kept on the CPU (as before) so that it can be plotted directly.
        attention_matrix = torch.zeros(n_steps, 1, src_len)

        with torch.no_grad():
            enc_output, hidden = self.encoder(source)
            mask = self.create_mask(source)

            target_token = torch.tensor([self.start_idx], dtype=torch.long, device=source.device)

            for step in range(n_steps):
                logits, hidden, attention = self.decoder(target_token, hidden, enc_output, mask)

                # The prediction becomes the next decoder input; without this
                # the decoder would be fed the start token at every step.
                target_token = logits.argmax(1)

                pred_token = target_token.item()
                predicted.append(pred_token)
                attention_matrix[step] = attention

                if pred_token == self.end_idx:
                    break

        # Slicing by the number of generated tokens also covers max_len == 0.
        return predicted, attention_matrix[:len(predicted)]

## Defining the trainer code

In [8]:
class Trainer:
    """Generic training driver: data loading, epoch loop, checkpoints, plots and BLEU.

    Subclasses provide `train_step`, `val_step` and `test_step`, each of
    which returns the average loss over the given loader.

    Arguments:
    - model: the model to train
    - src_tokenizer / dst_tokenizer: tokenizers of the source / target sentences
    - config: dictionary of settings. Keys read here: "src_max_length",
      "dst_max_length", "batch_size", "max_epochs", "seed" and, optionally,
      "model_name" (defaults to the lower-cased class name of the model),
      "learning_rate" (defaults to 0.1) and "metric_max_batches" (defaults to
      no limit on the batches used for the BLEU evaluation)
    """

    def __init__(self, model, src_tokenizer, dst_tokenizer, config) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model
        self.src_tokenizer = src_tokenizer
        self.dst_tokenizer = dst_tokenizer
        self.config = config

        # Padding positions of the target do not contribute to the loss.
        pad_token = dst_tokenizer.pad_token
        pad_token_idx = dst_tokenizer.convert_tokens_to_ids([pad_token])[0]
        self.criterion = nn.CrossEntropyLoss(ignore_index=pad_token_idx)
        self.optimizer = optim.Adam(self.model.parameters(), lr=config.get("learning_rate", 0.1))

        self.metric = evaluate.load("bleu")

        if "model_name" in config:
            self.model_name = config["model_name"]
        else:
            self.model_name = self.model.__class__.__name__.lower()

        # Epoch of the best validation loss; set by `train_loop`.
        self.best_epoch = None

    def set_seeds(self, seed):
        """Seed the torch, `random` and numpy random number generators."""
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

    def get_data_loader(self, batch_size, val_split=0.2, test_split=0.1):
        """Build the train / validation / test loaders from a random split of the dataset.

        `val_split` and `test_split` are the fractions of the dataset used for
        validation and test; the remainder is used for training.
        """
        data_set = AnkiDataset(
            f"{DATASET_PATH}/ita.txt",
            self.src_tokenizer,
            self.dst_tokenizer,
            self.config["src_max_length"],
            self.config["dst_max_length"]
        )

        n = len(data_set)
        val_size = int(n * val_split)
        test_size = int(n * test_split)
        train_size = n - val_size - test_size

        train_set, val_set, test_set = random_split(data_set, [train_size, val_size, test_size])

        train_loader = DataLoader(train_set, batch_size=batch_size)
        val_loader = DataLoader(val_set, batch_size=batch_size)
        test_loader = DataLoader(test_set, batch_size=batch_size)

        return train_loader, val_loader, test_loader

    def generate_learning_curvers(self, train_losses, val_losses):
        """Save the loss curves, once for all epochs and once cut at the best epoch."""
        for prefix, last_epoch in (("", None), ("best_", self.best_epoch)):
            train_curve = train_losses[:last_epoch]
            val_curve = val_losses[:last_epoch]

            plot_curves(
                curve_1=train_curve,
                curve_2=val_curve,
                label_1="Train loss",
                label_2="Validation loss",
                fig_name=f"{IMAGE_PATH}/{prefix}loss_model_{self.model_name}"
            )

            plot_curves(
                curve_1=train_curve,
                label_1="Train loss",
                fig_name=f"{IMAGE_PATH}/{prefix}train_loss_model_{self.model_name}"
            )

            plot_curves(
                curve_1=val_curve,
                label_1="Val loss",
                fig_name=f"{IMAGE_PATH}/{prefix}val_loss_model_{self.model_name}"
            )

    def train(self, generate_fun):
        """Train the model, then report the test loss and the BLEU scores.

        - generate_fun: callable mapping a (1, src_len) tensor of token ids to
          a (predicted ids, attention) pair; used for the BLEU evaluation
        """
        seed = self.config["seed"]
        self.set_seeds(seed)

        batch_size = self.config["batch_size"]
        # self.model.to(self.device)

        train_loader, val_loader, test_loader = self.get_data_loader(batch_size, 0.2, 0.1)

        self.train_loop(train_loader, val_loader)
        self.model.eval()

        print("Evaluating model on the test set")
        test_loss = self.test_step(test_loader)
        print(f"Test loss: {test_loss}")

        # Greedy decoding of every sentence is slow, so the number of batches
        # used for BLEU can optionally be capped from the configuration.
        max_batches = self.config.get("metric_max_batches")
        metric_kwargs = {} if max_batches is None else {"max_batches": max_batches}

        train_score = self.metric_evaluation(train_loader, generate_fun, **metric_kwargs)
        val_score = self.metric_evaluation(val_loader, generate_fun, **metric_kwargs)
        test_score = self.metric_evaluation(test_loader, generate_fun, **metric_kwargs)

        print(f"Average train set BLEU score: {train_score}")
        print(f"Average validation set BLEU score: {val_score}")
        print(f"Average test set BLEU score: {test_score}")

    def _checkpoint_path(self, epoch):
        """Path of the checkpoint written for `epoch`."""
        return f"{CHECKPOINT_DIR}/model_{self.model_name}_{epoch}_checkpoint.pt"

    def _load_best_checkpoint(self):
        """Restore the weights saved at the epoch with the best validation loss."""
        self.model.load_state_dict(torch.load(self._checkpoint_path(self.best_epoch)))

    def train_loop(self, train_loader, val_loader):
        """Run all epochs, keep only the checkpoint with the lowest validation loss and plot the losses."""
        epochs = self.config["max_epochs"]

        train_losses = []
        val_losses = []

        best_val_loss = float("inf")
        best_loss_epoch = None

        for epoch in range(1, epochs + 1):
            self.model.train()
            print(f"Training epoch {epoch}/{epochs}")
            train_loss = self.train_step(train_loader, epoch)
            self.model.eval()
            print(f"Validation epoch {epoch}/{epochs}")
            val_loss = self.val_step(val_loader, epoch)

            if val_loss < best_val_loss:
                # Only one checkpoint is kept: drop the previous best one.
                if best_loss_epoch is not None:
                    previous_best = self._checkpoint_path(best_loss_epoch)
                    if os.path.exists(previous_best):
                        os.remove(previous_best)
                best_val_loss = val_loss
                best_loss_epoch = epoch
                torch.save(self.model.state_dict(), self._checkpoint_path(epoch))

            train_losses.append(train_loss)
            val_losses.append(val_loss)

            print(f"Epoch {epoch} train loss: {train_loss}, val_loss: {val_loss}")

        self.best_epoch = best_loss_epoch

        self.generate_learning_curvers(train_losses, val_losses)

    def train_step(self, train_loader, epoch=None):
        """Run one training epoch and return its average loss (provided by subclasses)."""
        raise NotImplementedError("Subclasses must implement train_step(train_loader, epoch)")

    def val_step(self, val_loader, epoch=None):
        """Evaluate on the validation set and return the average loss (provided by subclasses)."""
        raise NotImplementedError("Subclasses must implement val_step(val_loader, epoch)")

    def test_step(self, test_loader):
        """Evaluate on the test set and return the average loss (provided by subclasses)."""
        raise NotImplementedError("Subclasses must implement test_step(test_loader)")

    def metric_evaluation(self, data_loader, generate_fun, max_batches=None):
        """Average sentence-level BLEU of the best checkpoint over `data_loader`.

        - data_loader: yields (inputs, targets) batches exposing `input_ids`
        - generate_fun: callable mapping a (1, src_len) tensor of token ids to
          a (predicted ids, attention) pair
        - max_batches: optional cap on the number of batches that are scored

        Returns 0.0 when no sentence was scored.
        """
        self._load_best_checkpoint()
        self.model.eval()

        total_score = 0.0
        n_sentences = 0

        for step, batch in enumerate(data_loader):
            if max_batches is not None and step >= max_batches:
                break

            inputs, targets = batch

            for input_ids, target_ids in zip(inputs.input_ids, targets.input_ids):
                pred_ids, _ = generate_fun(input_ids.unsqueeze(0))

                # Predictions are target-language ids, so both sentences are
                # decoded with the target tokenizer.
                pred_sentence = self.dst_tokenizer.decode(pred_ids, skip_special_tokens=True)
                target_sentence = self.dst_tokenizer.decode(target_ids, skip_special_tokens=True)

                n_sentences += 1

                # An empty prediction or reference shares no n-gram with the
                # other side: it counts as 0 and the metric is not called.
                if not pred_sentence.strip() or not target_sentence.strip():
                    continue

                result = self.metric.compute(predictions=[pred_sentence], references=[target_sentence])
                total_score += result["bleu"]

        if n_sentences == 0:
            return 0.0

        return total_score / n_sentences


class Seq2SeqTrainer(Trainer):
    """Trainer for the Seq2Seq model: cross-entropy over the predicted target tokens."""

    def __init__(self, model, src_tokenizer, dst_tokenizer, config) -> None:
        # The base class already builds the criterion (padding ignored) and
        # the Adam optimizer, so nothing else has to be created here.
        super(Seq2SeqTrainer, self).__init__(model, src_tokenizer, dst_tokenizer, config)

    def _batch_loss(self, batch):
        """Forward one (inputs, targets) batch through the model and return its loss."""
        inputs, targets = batch

        # Batches are moved to wherever the model lives (a no-op on the CPU).
        device = next(self.model.parameters()).device

        # reshape input tensors from (batch_size, length) to (length, batch_size)
        input_ids = inputs.input_ids.permute(1, 0).to(device)
        target_ids = targets.input_ids.permute(1, 0).to(device)

        output = self.model(input_ids, target_ids)
        output_dim = output.shape[-1]

        # Position 0 is only the first decoder input and is never predicted,
        # so it is left out of both the logits and the targets.
        output = output[1:].reshape(-1, output_dim)
        target_ids = target_ids[1:].reshape(-1).to(output.device)

        return self.criterion(output, target_ids)

    def _evaluate(self, loader, epoch=None):
        """Average loss over `loader` without gradients and without updating the weights."""
        total_loss = 0
        n = len(loader)

        with torch.no_grad(), tqdm(total=n) as pbar:
            for step, batch in enumerate(loader):
                total_loss += self._batch_loss(batch).item()

                if epoch is not None and (step+1) % 10 == 0:
                    print(f"Epoch {epoch}, samples {step+1}/{n} validation loss: {total_loss/(step+1)}")

                pbar.update(1)

        return total_loss / n if n else 0.0

    def train_step(self, train_loader, epoch):
        """Train for one epoch and return the average training loss."""
        total_loss = 0
        n = len(train_loader)

        with tqdm(total=n) as pbar:
            for step, batch in enumerate(train_loader):
                self.optimizer.zero_grad()

                loss = self._batch_loss(batch)
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()

                if (step+1) % 10 == 0:
                    print(f"Epoch {epoch}, samples {step+1}/{n} train loss: {total_loss/(step+1)}")

                pbar.update(1)

        return total_loss / n if n else 0.0

    def val_step(self, val_loader, epoch):
        """Return the average validation loss; the weights are not updated."""
        return self._evaluate(val_loader, epoch)

    def test_step(self, test_loader):
        """Return the average test loss of the best checkpoint; the weights are not updated."""
        # Score the weights of the best validation epoch when they are available.
        best_epoch = getattr(self, "best_epoch", None)
        if best_epoch is not None:
            checkpoint = f"{CHECKPOINT_DIR}/model_{self.model_name}_{best_epoch}_checkpoint.pt"
            if os.path.exists(checkpoint):
                self.model.load_state_dict(torch.load(checkpoint))

        return self._evaluate(test_loader)

# Define and train the model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): AnkiDataset tokenizes the first column of the file with
# `src_tokenizer` and the second with `dst_tokenizer`; its docstring lists the
# sentences as "english and italian" - confirm that the column order matches
# the Italian source / English target tokenizers chosen here.
src_tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-italian-cased")
dst_tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

config = {
    "src_max_length": 183,
    "dst_max_length": 208,
    # Taken from the tokenizers so the embedding tables always cover every token id.
    "src_vocab_size": len(src_tokenizer),
    "dst_vocab_size": len(dst_tokenizer),
    "enc_hidden_dim": 256,
    "dec_hidden_dim": 256,
    "max_epochs": 1,
    "batch_size": 8,
    "seed": 7
}


model = Seq2Seq(
    enc_vocab_dim=config["src_vocab_size"],
    dec_vocab_dim=config["dst_vocab_size"],
    enc_hidden_dim=config["enc_hidden_dim"],
    dec_hidden_dim=config["dec_hidden_dim"],
    enc_n_layers=2,
    dec_n_layers=1,
    pad_idx=src_tokenizer.pad_token_id,
    # The tokenized targets start with [CLS] and end with [SEP]: training feeds
    # the first target token to the decoder, so generation has to start from
    # the same token and stop at the token that closes a sentence.
    start_idx=dst_tokenizer.cls_token_id,
    end_idx=dst_tokenizer.sep_token_id,
    teacher_forcing_ratio=0.5,
    device=device
)

trainer = Seq2SeqTrainer(model, src_tokenizer, dst_tokenizer, config)

trainer.train(lambda x: model.generate(x, max_len=100))

Downloading (…)okenizer_config.json:   0%|          | 0.00/59.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/433 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/235k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

Training epoch 1/1


  0%|          | 0/31751 [00:00<?, ?it/s]