# Training prompt tuned T5 model for Neural Machine translation

## Importing modules

In [34]:
!pip install transformers
!pip install evaluate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [35]:
import torch
from torch import nn
import torch.nn.functional as F
import random
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from transformers import BartModel
import torch
from torch import nn
from torch import optim
from torch.utils.data import random_split
import random
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import evaluate
from transformers import T5Model, BartModel, T5ForConditionalGeneration
import torch
from torch import nn
import torch.nn.functional as F
from transformers import T5Model
from transformers.utils import logging
logger = logging.get_logger(__name__)

## Dowloading the dataset

In [36]:
!wget http://www.manythings.org/anki/ita-eng.zip
!unzip ita-eng.zip
!rm ita-eng.zip
!mkdir dataset
!mv ita.txt dataset

--2023-03-26 13:12:35--  http://www.manythings.org/anki/ita-eng.zip
Resolving www.manythings.org (www.manythings.org)... 173.254.30.110
Connecting to www.manythings.org (www.manythings.org)|173.254.30.110|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7981351 (7.6M) [application/zip]
Saving to: ‘ita-eng.zip’


2023-03-26 13:12:36 (11.6 MB/s) - ‘ita-eng.zip’ saved [7981351/7981351]

Archive:  ita-eng.zip
  inflating: ita.txt                 
replace _about.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
  inflating: _about.txt              
mkdir: cannot create directory ‘dataset’: File exists


## Defining some settings

In [37]:
!mkdir images
!mkdir checkpoints

DIR_PATH= "."
DATASET_PATH = os.path.join(DIR_PATH, "./dataset")
IMAGE_PATH = os.path.join(DIR_PATH, "./images")
CHECKPOINT_DIR = os.path.join(DIR_PATH, "./checkpoints")

mkdir: cannot create directory ‘images’: File exists
mkdir: cannot create directory ‘checkpoints’: File exists


## Defining utilities

In [38]:
def count_parameters(model):
    n_params =  sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'The model has {n_params} trainable parameters')


def plot_curves(curve_1, label_1, curve_2=None, label_2=None, fig_name="figure", show=False):

    plt.plot(curve_1, label = label_1)
    if curve_2 is not None:
        plt.plot(curve_2, label = label_2)
    plt.legend()
    plt.savefig(f"{fig_name}")

    if show:
        plt.show()

    plt.clf()

    
def plot_attention_mask(attention_mask, source_tokens, target_tokens):

    skip_tokens = len(source_tokens) if "[PAD]" not in source_tokens else source_tokens.index("[PAD]")
    source_tokens = source_tokens[:skip_tokens]

    attention_mask = attention_mask.squeeze(1)

    attention_mask = attention_mask[:, :skip_tokens]

    plt.xticks(ticks=[x for x in range(len(source_tokens))], labels=source_tokens, rotation=45)
    plt.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
    plt.yticks(ticks=[x for x in range(len(target_tokens))], labels=target_tokens)
    plt.imshow(attention_mask, cmap='gray', vmin=0, vmax=1)
    plt.show()

## Definition of the dataset class

In [39]:
class AnkiDataset(Dataset):

    def __init__(self,
                 data_path,
                 tokenizer_src,
                 tokenizer_dst,
                 src_max_length,
                 dst_max_length,
                 subsample=False,
                 frac=1.0,
                 seed=42
                ) -> None:
        super().__init__()
        self.tokenizer_src = tokenizer_src
        self.tokenizer_dst = tokenizer_dst
        self.src_max_length = src_max_length
        self.dst_max_length = dst_max_length
        self.seed = seed
        self.frac = frac
        self.subsample = subsample
        random.seed(self.seed)
        self.data = self.get_data(data_path)


    def __len__(self):
        return len(self.data)
    

    def __getitem__(self, index):
        
        src, dst = self.data[index]

        src = self.tokenizer_src(src, max_length=self.src_max_length, pad_to_max_length=True, truncation=True, padding="max_length", return_tensors='pt')
        dst = self.tokenizer_dst(dst, max_length=self.dst_max_length, pad_to_max_length=True, truncation=True, padding="max_length", return_tensors='pt')
            
        for key in src.keys():
            src[key] = src[key][0]
            dst[key] = dst[key][0]

        return (src, dst)
        


    '''
    Takes in input the path of the datasets and it returnes a list where each element of
    the list is a list of the elment containing the english and italian sentence
    '''
    def get_data(self, data_path="./../dataset/ita.txt"):

        with open(data_path, "r") as dataset:
            sentences = [tuple(sentence.split("\t")[:2]) for sentence in dataset.readlines()]
            
        if self.subsample == True:
            k = int(len(sentences)*self.frac)
            sentences = random.sample(sentences, k)

        return sentences

## Code of the model

In [40]:
'''
super class that defines the behavior of the T5 model with the soft-prompts
'''
class T5PromptTuningMixin:

    '''
    wrapper of the from_pretrained class method to include the loading of the soft-prompts
    '''
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        encoder_soft_prompt_path = None,
        decoder_soft_prompt_path = None,
        encoder_n_tokens = None,
        decoder_n_tokens = None,
        encoder_hidden_dim = None,
        decoder_hidden_dim = None,
        initialize_from_vocab = True,
        random_range = 0.5,
        device=None,
        **kwargs,
    ):
        
        # getting the T5 model
        model = super().from_pretrained(pretrained_model_name_or_path, **kwargs)

        # freezing all the parameters of the pretrained T5 model
        for param in model.parameters():
            param.requires_grad = False

        '''
        load the encoder soft prompts if the path is provided otheriwise they
        are randomly initialized
        '''
        if encoder_soft_prompt_path is not None:
            model.set_encoder_soft_prompts(encoder_soft_prompt_path)
        else:
            model.initialize_encoder_soft_prompts(encoder_n_tokens, encoder_hidden_dim, random_range)

        '''
        load the encoder soft prompts if the path is provided otheriwise they
        are randomly initialized
        '''
        if decoder_soft_prompt_path is not None:
            model.set_decoder_soft_prompts(decoder_soft_prompt_path)
        else:
            model.initialize_decoder_soft_prompts(decoder_n_tokens, decoder_hidden_dim, random_range)

        model.encoder_n_tokens = encoder_n_tokens
        model.decoder_n_tokens = decoder_n_tokens

        enc_emb_size = model.encoder.get_input_embeddings().weight.shape[1]
        dec_emb_size = model.decoder.get_input_embeddings().weight.shape[1]

        encoder_emb_generator = nn.Sequential(
            nn.Linear(encoder_hidden_dim, encoder_hidden_dim),
            nn.ReLU(),
            nn.Linear(encoder_hidden_dim, enc_emb_size),
            nn.Tanh()
        )

        model.encoder_emb_generator = encoder_emb_generator

        decoder_emb_generator = nn.Sequential(
            nn.Linear(decoder_hidden_dim, decoder_hidden_dim),
            nn.ReLU(),
            nn.Linear(decoder_hidden_dim, dec_emb_size),
            nn.Tanh()
        ) 

        model.decoder_emb_generator = decoder_emb_generator

        model.encoder_input_tokens = torch.arange(encoder_n_tokens).long().to(device)
        model.decoder_input_tokens = torch.arange(decoder_n_tokens).long().to(device)

        return model
    

    def initialize_encoder_soft_prompts(self, n_tokens, hidden_dim, random_range=0.5):
        self.n_tokens = n_tokens
        self.encoder_soft_prompt = nn.Embedding(n_tokens, hidden_dim)
        # init_prompt_value = torch.FloatTensor(2, 10).uniform_(-random_range, random_range)
        # self.encoder_soft_prompt.weight = nn.parameter.Parameter(init_prompt_value)


    def set_encoder_soft_prompts(self, soft_prompt_path):
        self.encoder_soft_prompt = torch.load(soft_prompt_path, map_location=torch.device("cpu"))
        self.n_tokens = self.encoder_soft_prompt.num_embeddings


    def initialize_decoder_soft_prompts(self, n_tokens, hidden_dim, random_range=0.5):
        self.n_tokens = n_tokens
        self.decoder_soft_prompt = nn.Embedding(n_tokens, hidden_dim)
        # init_prompt_value = torch.FloatTensor(2, 10).uniform_(-random_range, random_range)
        # self.decoder_soft_prompt.weight = nn.parameter.Parameter(init_prompt_value)


    def set_decoder_soft_prompts(self, soft_prompt_path):
        self.decoder_soft_prompt = torch.load(soft_prompt_path, map_location=torch.device("cpu"))
        self.n_tokens = self.decoder_soft_prompt.num_embeddings


    def concatenate_encoder_soft_prompts(self, input_ids):
        inputs_emb = self.encoder_soft_prompt(self.encoder_input_tokens)
        soft_prompts = self.encoder_emb_generator(inputs_emb)

        embeddings = self.encoder.embed_tokens(input_ids)

        soft_prompts = soft_prompts.repeat(embeddings.size(0), 1, 1)

        inputs_concat = torch.cat([soft_prompts, embeddings], dim=1)
        return inputs_concat
    

    def concatenate_decoder_soft_prompts(self, input_ids):
        inputs_emb = self.decoder_soft_prompt(self.decoder_input_tokens)
        soft_prompts = self.decoder_emb_generator(inputs_emb)
        
        embeddings = self.decoder.embed_tokens(input_ids)

        soft_prompts = soft_prompts.repeat(embeddings.size(0), 1, 1)

        inputs_concat = torch.cat([soft_prompts, embeddings], dim=1)
        return inputs_concat


    def extend_attention_mask(self, attention_mask):
        batch_size = attention_mask.shape[0]
        soft_prompts_mask = torch.full((batch_size, self.n_tokens), 1, dtype=torch.long)
        extended_mask = torch.concat([soft_prompts_mask, attention_mask], dim=1)
        return extended_mask
    

    def extend_labels(self, labels, ignore_index=-100):
        batch_size = labels.shape[0]
        soft_prompts_indices = torch.full((batch_size, self.decoder_n_tokens), ignore_index).to(self.device)
        extended_labels = torch.concat([soft_prompts_indices, labels], dim=1)
        return extended_labels


    '''
    forward pass of the T5 prompt tuning model

    Input (only the relevants):
    - input_ids: the inputs tokens of the encoder (batch_size, src_len)
    - attention_mask: the attention mask of the encoder (batch_size, src_len)
    - decoder_input_ids: the inputs tokens of the decoder (batch_size, dst_len)
    - decoder_attention_mask: the attention mask of the decoder (batch_size, dst_len)

    Output:
    - logits: 
    - encoder_last_hidden_state: 
    '''
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        decoder_attention_mask=None,
        inputs_embeds=None,
        decoder_input_ids=None,
        decoder_inputs_embeds=None,
        encoder_outputs=None,
        use_cache=None,
        labels=None,
        return_dict=None,
        *args,
        **kwargs
    ):
        
        if input_ids is not None:
            '''
            if input_ids are passed their embedding is concatenated to the
            encoder soft prompts to generate input_embeds, a tensor
            of size (batch_size, enc_n_tokens + seq_len, enc_hidden_dim)
            '''
            inputs_embeds = self.concatenate_encoder_soft_prompts(input_ids)
            input_ids = None

        if decoder_input_ids is not None:
            '''
            if decoder_input_ids are passed thier embedding is concatenated to the
            decoder soft prompts to generate decoder_input_embeds, a tensor
            of size (batch_size, dec_n_tokens + dst_len, dec_hidden_dim)
            '''
            decoder_inputs_embeds = self.concatenate_decoder_soft_prompts(decoder_input_ids)
            decoder_input_ids = None

        if attention_mask is not None and inputs_embeds is not None:
            '''
            if attention_mask is passed it is extended to include also the encoder
            soft prompts, generating a tensor of size (batch_size, enc_n_tokens + seq_len)
            '''
            attention_mask = self.extend_attention_mask(attention_mask)

        if decoder_attention_mask is not None:
            '''
            if decoder_attention_mask is passed it is extended to include also the decoder
            soft prompts, generating a tensor of size (batch_size, dec_n_tokens + dst_len)
            '''
            decoder_attention_mask = self.extend_attention_mask(decoder_attention_mask)


        if labels is not None:
            '''
            if labels is passed then it is extended to include the also the embeddings
            '''
            labels = self.extend_labels(labels)
            
        '''
        we pass the encoder and decoder embeddings to the forward layer of T5
        '''
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            decoder_input_ids=decoder_input_ids,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            use_cache=use_cache,
            return_dict=return_dict,
            *args,
            **kwargs
        )


    def generate(self, *args, **kwargs):

        kwargs['inputs_embeds'] = self.concatenate_encoder_soft_prompts(kwargs['input_ids']).to(self.device)
        kwargs['attention_mask']=self.extend_attention_mask(torch.ones([1,kwargs['inputs_embeds'].shape[1]-self.n_tokens]).long()).to(self.device)

        del kwargs['input_ids']

        return super().generate(*args, **kwargs)

'''
Defining the T5 model with prompt tuning superclassing T5PromptTuningUtils and 
T5ForConditionalGeneration (which adds the head for producing the logits)
'''
class T5PromptTuning(T5PromptTuningMixin, T5ForConditionalGeneration):

    def __init__(self, config) -> None:
        super(T5PromptTuning, self).__init__(config)



'''
super class that defines the behavior of the T5 model with the soft-prompts
'''
class T5PromptTuningMixinSimple:

    '''
    wrapper of the from_pretrained class method to include the loading of the soft-prompts
    '''
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        encoder_soft_prompt_path = None,
        decoder_soft_prompt_path = None,
        encoder_n_tokens = None,
        decoder_n_tokens = None,
        initialize_from_vocab = True,
        random_range = 0.5,
        device=None,
        **kwargs,
    ):
        
        # getting the T5 model
        model = super().from_pretrained(pretrained_model_name_or_path, **kwargs)

        # freezing all the parameters of the pretrained T5 model
        for param in model.parameters():
            param.requires_grad = False

        '''
        load the encoder soft prompts if the path is provided otheriwise they
        are randomly initialized
        '''
        if encoder_soft_prompt_path is not None:
            model.set_encoder_soft_prompts(encoder_soft_prompt_path)
        else:
            model.initialize_encoder_soft_prompts(encoder_n_tokens, random_range)

        '''
        load the encoder soft prompts if the path is provided otheriwise they
        are randomly initialized
        '''
        if decoder_soft_prompt_path is not None:
            model.set_decoder_soft_prompts(decoder_soft_prompt_path)
        else:
            model.initialize_decoder_soft_prompts(decoder_n_tokens, random_range)

        model.encoder_n_tokens = encoder_n_tokens
        model.decoder_n_tokens = decoder_n_tokens 

        model.encoder_input_tokens = torch.arange(encoder_n_tokens).long().to(device)
        model.decoder_input_tokens = torch.arange(decoder_n_tokens).long().to(device)

        return model
    

    def initialize_encoder_soft_prompts(self, n_tokens, random_range=0.5):
        self.n_tokens = n_tokens
        self.encoder_soft_prompt = nn.Embedding(n_tokens, self.config.d_model)
        # init_prompt_value = torch.FloatTensor(2, 10).uniform_(-random_range, random_range)
        # self.encoder_soft_prompt.weight = nn.parameter.Parameter(init_prompt_value)


    def set_encoder_soft_prompts(self, soft_prompt_path):
        self.encoder_soft_prompt = torch.load(soft_prompt_path, map_location=torch.device("cpu"))
        self.n_tokens = self.encoder_soft_prompt.num_embeddings


    def initialize_decoder_soft_prompts(self, n_tokens, random_range=0.5):
        self.n_tokens = n_tokens
        self.decoder_soft_prompt = nn.Embedding(n_tokens, self.config.d_model)
        # init_prompt_value = torch.FloatTensor(2, 10).uniform_(-random_range, random_range)
        # self.decoder_soft_prompt.weight = nn.parameter.Parameter(init_prompt_value)


    def set_decoder_soft_prompts(self, soft_prompt_path):
        self.decoder_soft_prompt = torch.load(soft_prompt_path, map_location=torch.device("cpu"))
        self.n_tokens = self.decoder_soft_prompt.num_embeddings


    def concatenate_encoder_soft_prompts(self, input_ids):
        soft_prompts = self.encoder_soft_prompt(self.encoder_input_tokens)
        embeddings = self.encoder.embed_tokens(input_ids)
        soft_prompts = soft_prompts.repeat(embeddings.size(0), 1, 1)

        inputs_concat = torch.cat([soft_prompts, embeddings], dim=1)
        return inputs_concat
    

    def concatenate_decoder_soft_prompts(self, input_ids):
        soft_prompts = self.decoder_soft_prompt(self.decoder_input_tokens)
        embeddings = self.decoder.embed_tokens(input_ids)

        soft_prompts = soft_prompts.repeat(embeddings.size(0), 1, 1)

        inputs_concat = torch.cat([soft_prompts, embeddings], dim=1)
        return inputs_concat


    def extend_attention_mask(self, attention_mask):
        batch_size = attention_mask.shape[0]
        soft_prompts_mask = torch.full((batch_size, self.n_tokens), 1, dtype=torch.long)
        extended_mask = torch.concat([soft_prompts_mask, attention_mask], dim=1)
        return extended_mask
    

    def extend_labels(self, labels, ignore_index=-100):
        batch_size = labels.shape[0]
        soft_prompts_indices = torch.full((batch_size, self.decoder_n_tokens), ignore_index).to(self.device)
        extended_labels = torch.concat([soft_prompts_indices, labels], dim=1)
        return extended_labels


    '''
    forward pass of the T5 prompt tuning model

    Input (only the relevants):
    - input_ids: the inputs tokens of the encoder (batch_size, src_len)
    - attention_mask: the attention mask of the encoder (batch_size, src_len)
    - decoder_input_ids: the inputs tokens of the decoder (batch_size, dst_len)
    - decoder_attention_mask: the attention mask of the decoder (batch_size, dst_len)

    Output:
    - logits: 
    - encoder_last_hidden_state: 
    '''
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        decoder_attention_mask=None,
        inputs_embeds=None,
        decoder_input_ids=None,
        decoder_inputs_embeds=None,
        encoder_outputs=None,
        use_cache=None,
        labels=None,
        return_dict=None,
        *args,
        **kwargs
    ):
        
        if input_ids is not None:
            '''
            if input_ids are passed their embedding is concatenated to the
            encoder soft prompts to generate input_embeds, a tensor
            of size (batch_size, enc_n_tokens + seq_len, enc_hidden_dim)
            '''
            inputs_embeds = self.concatenate_encoder_soft_prompts(input_ids)
            input_ids = None

        if decoder_input_ids is not None:
            '''
            if decoder_input_ids are passed thier embedding is concatenated to the
            decoder soft prompts to generate decoder_input_embeds, a tensor
            of size (batch_size, dec_n_tokens + dst_len, dec_hidden_dim)
            '''
            decoder_inputs_embeds = self.concatenate_decoder_soft_prompts(decoder_input_ids)
            decoder_input_ids = None

        if attention_mask is not None and inputs_embeds is not None:
            '''
            if attention_mask is passed it is extended to include also the encoder
            soft prompts, generating a tensor of size (batch_size, enc_n_tokens + seq_len)
            '''
            attention_mask = self.extend_attention_mask(attention_mask)

        if decoder_attention_mask is not None:
            '''
            if decoder_attention_mask is passed it is extended to include also the decoder
            soft prompts, generating a tensor of size (batch_size, dec_n_tokens + dst_len)
            '''
            decoder_attention_mask = self.extend_attention_mask(decoder_attention_mask)


        if labels is not None:
            '''
            if labels is passed then it is extended to include the also the embeddings
            '''
            labels = self.extend_labels(labels)
            
        '''
        we pass the encoder and decoder embeddings to the forward layer of T5
        '''
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            decoder_input_ids=decoder_input_ids,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            use_cache=use_cache,
            return_dict=return_dict,
            *args,
            **kwargs
        )


    def generate(self, *args, **kwargs):

        kwargs['inputs_embeds'] = self.concatenate_encoder_soft_prompts(kwargs['input_ids']).to(self.device)
        kwargs['attention_mask']=self.extend_attention_mask(torch.ones([1,kwargs['inputs_embeds'].shape[1]-self.n_tokens]).long()).to(self.device)

        del kwargs['input_ids']

        return super().generate(*args, **kwargs)

'''
Defining the T5 model with prompt tuning superclassing T5PromptTuningUtils and 
T5ForConditionalGeneration (which adds the head for producing the logits)
'''
class T5PromptTuningSimple(T5PromptTuningMixinSimple, T5ForConditionalGeneration):

    def __init__(self, config) -> None:
        super(T5PromptTuningSimple, self).__init__(config)

## Defining the trainer superclass

In [41]:
class Trainer:

    def __init__(self, model, src_tokenizer, dst_tokenizer, config) -> None:

        self.device = config["device"]
        self.model = model.to(self.device)
        self.src_tokenizer = src_tokenizer
        self.dst_tokenizer = dst_tokenizer
        self.config = config

        pad_token = dst_tokenizer.pad_token
        pad_token_idx = dst_tokenizer.convert_tokens_to_ids([pad_token])[0]
        self.criterion = nn.CrossEntropyLoss(ignore_index=pad_token_idx)
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.1)
        self.pad_token_idx = pad_token_idx

        self.metric = evaluate.load("bleu")

        if "model_name" in config:
            self.model_name = config["model_name"]
        else:
            self.model_name = self.model.__class__.__name__.lower()


    
    def set_seeds(self, seed):
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)


    def get_data_loader(self, batch_size, val_split=0.2, test_split=0.1):
        
        data_set = AnkiDataset(
            f"{DATASET_PATH}/ita.txt",
            self.src_tokenizer,
            self.dst_tokenizer,
            self.config["src_max_length"],
            self.config["dst_max_length"]
        )


        n = len(data_set)

        val_size = int(n*val_split)
        test_size = int(n*test_split)
        train_size = n - val_size - test_size


        train_set, val_set, test_set = random_split(data_set, [train_size, val_size, test_size])

        train_loader = DataLoader(
                    train_set,
                    batch_size = batch_size
                )
        
        val_loader = DataLoader(
                    val_set,
                    batch_size=batch_size
                )
        
        test_loader = DataLoader(
                    test_set,
                    batch_size = batch_size
                )
        
        return train_loader, val_loader, test_loader


    def generate_learning_curvers(self, train_losses, val_losses):

        plot_curves(
            curve_1=train_losses,
            curve_2=val_losses,
            label_1="Train loss",
            label_2="Validation loss",
            fig_name=f"{IMAGE_PATH}/loss_model_{self.model_name}"
        )

        plot_curves(
            curve_1=train_losses[:self.best_epoch],
            curve_2=val_losses[:self.best_epoch],
            label_1="Train loss",
            label_2="Validation loss",
            fig_name=f"{IMAGE_PATH}/best_loss_model_{self.model_name}"
        )

        plot_curves(
            curve_1=train_losses,
            label_1="Train loss",
            fig_name=f"{IMAGE_PATH}/train_loss_model_{self.model_name}"
        )

        plot_curves(
            curve_1=train_losses[:self.best_epoch],
            label_1="Train loss",
            fig_name=f"{IMAGE_PATH}/best_train_loss_model_{self.model_name}"
        )

        plot_curves(
            curve_1=val_losses,
            label_1="Val loss",
            fig_name=f"{IMAGE_PATH}/val_loss_model_{self.model_name}"
        )

        plot_curves(
            curve_1=val_losses[:self.best_epoch],
            label_1="Val loss",
            fig_name=f"{IMAGE_PATH}/best_val_loss_model_{self.model_name}"
        )


    def train(self, generate_fun):
        
        seed = self.config["seed"]
        self.set_seeds(seed)

        batch_size = self.config["batch_size"]
        # self.model.to(self.device)

        train_loader, val_loader, test_loader = self.get_data_loader(batch_size, 0.2, 0.1)

        self.train_loop(train_loader, val_loader)
        self.model.eval()
        test_loss = self.test_step(test_loader)
        print("Evaluating model on the test set")
        print(f"Test loss: {test_loss}")

        # evaluate bleu score
        train_score = self.metric_evaluation(train_loader, generate_fun)
        val_score = self.metric_evaluation(val_loader, generate_fun)
        test_score = self.metric_evaluation(test_loader, generate_fun)

        print(f"Average train set BLEU score: {train_score}")
        print(f"Average validation set BLEU score: {val_score}")
        print(f"Average test set BLEU score: {test_score}")


    def train_loop(self, train_loader, val_loader):

        epochs = self.config["max_epochs"]
        batch_size = self.config["batch_size"]

        train_losses = []
        val_losses = []

        best_val_loss = float("inf")
        best_loss_epoch = None

        for epoch in range(1, epochs+1):
            self.model.train()
            print(f"Training epoch {epoch}/{epochs}")
            train_loss = self.train_step(train_loader, epoch)
            self.model.eval()
            print(f"Validation epoch {epoch}/{epochs}")
            val_loss = self.val_step(val_loader, epoch)

            if val_loss < best_val_loss:
                if best_loss_epoch != None:
                    os.system(f"rm {CHECKPOINT_DIR}/model_{self.model_name}_{best_loss_epoch}_checkpoint.pt")
                best_val_loss = val_loss
                best_loss_epoch = epoch
                torch.save(self.model.state_dict(), f"{CHECKPOINT_DIR}/model_{self.model_name}_{epoch}_checkpoint.pt")

            train_losses.append(train_loss)
            val_losses.append(val_loss)

            print(f"Epoch {epoch} train loss: {train_loss}, val_loss: {val_loss}")

        self.best_epoch = best_loss_epoch

        self.generate_learning_curvers(train_losses, val_losses)
        


    def train_step(self, train_loader, epoch):

        total_loss = 0
        n = len(train_loader)

        with tqdm(total=n) as pbar:
            for step, batch in enumerate(train_loader):

                self.optimizer.zero_grad()
                inputs, targets = batch

                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                input_ids = inputs.input_ids
                target_ids = targets.input_ids

                output = self.model(input_ids=input_ids, decoder_input_ids=target_ids)

                logits = output.logits

                logits_dim = logits.shape[-1]

                logits = logits[1:].view(-1, logits_dim)
                target_ids = target_ids[1:].reshape(-1)

                loss = self.criterion(logits, target_ids)
                
                loss.backward()

                self.optimizer.step()

                total_loss += loss.item()

                if (step+1) % 50 == 0:
                    print(f"\nEpoch {epoch}, samples {step+1}/{n} train loss: {total_loss/(step+1)}")

                pbar.update(1)

                
        avg_loss = total_loss / n

        return avg_loss
            
    
    def val_step(self, val_loader, epoch):

        total_loss = 0
        n = len(val_loader)

        with tqdm(total=n) as pbar:
            for step, batch in enumerate(val_loader):

                inputs, targets = batch

                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                input_ids = inputs.input_ids
                target_ids = targets.input_ids

                output = self.model(input_ids=input_ids, decoder_input_ids=target_ids)
                logits = output.logits

                logits_dim = logits.shape[-1]

                logits = logits[1:].view(-1, logits_dim)
                target_ids = target_ids[1:].reshape(-1)

                loss = self.criterion(logits, target_ids)

                total_loss += loss.item()

                if (step+1) % 50 == 0:
                    print(f"\nEpoch {epoch}, samples {step+1}/{n} train loss: {total_loss/(step+1)}")

                pbar.update(1)

        avg_loss = total_loss / n

        return avg_loss


    
    def test_step(self, test_loader):

        self.model.load_state_dict(torch.load(f"{CHECKPOINT_DIR}/model_{self.model_name}_{self.best_epoch}_checkpoint.pt"))
        
        total_loss = 0
        n = len(test_loader)

        with tqdm(total=n) as pbar:
            for step, batch in enumerate(test_loader):
                
                inputs, targets = batch

                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                input_ids = inputs.input_ids
                target_ids = targets.input_ids

                output = self.model(input_ids=input_ids, decoder_input_ids=target_ids)
                logits = output.logits

                logits_dim = logits.shape[-1]

                logits = logits[1:].view(-1, logits_dim)
                target_ids = target_ids[1:].reshape(-1)

                loss = self.criterion(logits, target_ids)

                total_loss += loss.item()

                pbar.update(1)

        avg_loss = total_loss / n

        return avg_loss
            

    

    def metric_evaluation(self, data_loader, generate_fun):
        
        self.model.load_state_dict(torch.load(f"{CHECKPOINT_DIR}/model_{self.model_name}_{self.best_epoch}_checkpoint.pt"))
        self.model.eval()


        score = 0

        for step, batch in enumerate(data_loader):

            self.optimizer.zero_grad()

            inputs, targets = batch

            inputs = inputs.to(device)
            targets = targets.to(device)

            for i in range(len(inputs.input_ids)):

                input_ids = inputs.input_ids[i]
                target_ids = targets.input_ids[i]

                output = generate_fun(input_ids.unsqueeze(0))

                if type(output) == tuple:
                    pred_ids, attention = output
                else:
                    pred_ids = output[0]

                pred_sentence = self.src_tokenizer.decode(pred_ids, skip_special_tokens=True)
                target_sentence = self.dst_tokenizer.decode(target_ids, skip_special_tokens=True)

                result = self.metric.compute(predictions=[pred_sentence], references=[target_sentence])
                score += result["bleu"]

            score /= len(data_loader)

            return score

## Prompt Tuning Trainer

In [42]:
class PromptTuningTrainer(Trainer):


    def __init__(self, model, src_tokenizer, dst_tokenizer, config) -> None:
        super(PromptTuningTrainer, self).__init__(model, src_tokenizer, dst_tokenizer, config)

    
    def train_step(self, train_loader, epoch):

        total_loss = 0
        n = len(train_loader)

        with tqdm(total=n) as pbar:
            for step, batch in enumerate(train_loader):

                self.optimizer.zero_grad()
                inputs, targets = batch

                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                input_ids = inputs.input_ids
                target_ids = targets.input_ids

                output = self.model(input_ids=input_ids, decoder_input_ids=target_ids)

                target_ids = self.model.extend_labels(target_ids, self.pad_token_idx)

                logits = output.logits

                logits_dim = logits.shape[-1]

                logits = logits[1:].view(-1, logits_dim)
                target_ids = target_ids[1:].reshape(-1)

                loss = self.criterion(logits, target_ids)
                
                loss.backward()

                self.optimizer.step()

                total_loss += loss.item()

                if (step+1) % 100 == 0:
                    print(f"\nEpoch {epoch}, samples {step+1}/{n} train loss: {total_loss/(step+1)}")

                pbar.update(1)

                
        avg_loss = total_loss / n

        return avg_loss
            
    
    def val_step(self, val_loader, epoch):

        total_loss = 0
        n = len(val_loader)

        with tqdm(total=n) as pbar:
            for step, batch in enumerate(val_loader):

                inputs, targets = batch

                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                input_ids = inputs.input_ids
                target_ids = targets.input_ids

                output = self.model(input_ids=input_ids, decoder_input_ids=target_ids)

                target_ids = self.model.extend_labels(target_ids, self.pad_token_idx)

                logits = output.logits

                logits_dim = logits.shape[-1]

                logits = logits[1:].view(-1, logits_dim)
                target_ids = target_ids[1:].reshape(-1)

                loss = self.criterion(logits, target_ids)

                total_loss += loss.item()

                if (step+1) % 100 == 0:
                    print(f"\nEpoch {epoch}, samples {step+1}/{n} train loss: {total_loss/(step+1)}")

                pbar.update(1)

        avg_loss = total_loss / n

        return avg_loss


    
    def test_step(self, test_loader):

        self.model.load_state_dict(torch.load(f"{CHECKPOINT_DIR}/model_{self.model_name}_{self.best_epoch}_checkpoint.pt"))
        
        total_loss = 0
        n = len(test_loader)

        with tqdm(total=n) as pbar:
            for step, batch in enumerate(test_loader):
                
                inputs, targets = batch

                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                input_ids = inputs.input_ids
                target_ids = targets.input_ids

                output = self.model(input_ids=input_ids, decoder_input_ids=target_ids)

                target_ids = self.model.extend_labels(target_ids, self.pad_token_idx)

                logits = output.logits

                logits_dim = logits.shape[-1]

                logits = logits[1:].view(-1, logits_dim)
                target_ids = target_ids[1:].reshape(-1)

                loss = self.criterion(logits, target_ids)

                total_loss += loss.item()

                pbar.update(1)


        avg_loss = total_loss / n

        return avg_loss

In [43]:
model = T5PromptTuning.from_pretrained(
    "t5-small",
    encoder_soft_prompt_path = None,
    decoder_soft_prompt_path = None,
    encoder_n_tokens = 20,
    decoder_n_tokens = 40,
    encoder_hidden_dim=32,
    decoder_hidden_dim=32
)
count_parameters(model)

The model has 37824 trainable parameters


## Define and train the model

In [44]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = {
    "src_max_length": 183,
    "dst_max_length": 208,
    "src_vocab_size": 31102,
    "dst_vocab_size": 28996,
    "enc_hidden_dim": 8,
    "dec_hidden_dim": 8,
    "max_epochs": 1,
    "batch_size": 8,
    "seed": 7,
    "device": device
}


src_tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-italian-cased")
dst_tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

model = T5PromptTuning.from_pretrained(
    "t5-small",
    encoder_soft_prompt_path = None,
    decoder_soft_prompt_path = None,
    encoder_n_tokens = 40,
    decoder_n_tokens = 40,
    encoder_hidden_dim=64,
    decoder_hidden_dim=64,
    device=device
)

trainer = PromptTuningTrainer(model, src_tokenizer, dst_tokenizer, config)

generate_fun = lambda x: model.generate(
    input_ids=x, 
    decoder_input_ids=torch.zeros([1,1]).long().to(config["device"]), 
    max_length=200,
    num_beams=5,
    early_stopping=True,
)

trainer.train(generate_fun)

Training epoch 1/1


  0%|          | 100/31751 [00:20<1:45:02,  5.02it/s]


Epoch 1, samples 100/31751 train loss: 7.0314547729492185


  1%|          | 201/31751 [00:40<1:39:08,  5.30it/s]


Epoch 1, samples 200/31751 train loss: 6.892015600204468


  1%|          | 275/31751 [00:54<1:43:34,  5.06it/s]


KeyboardInterrupt: ignored