In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F

import pandas as pd
from sentence_transformers import SentenceTransformer, InputExample, models, util

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import seed_everything

seed_everything(42, workers=True)

import numpy as np

from transformers import get_linear_schedule_with_warmup


import wandb

In [None]:
# Experiment identity. `run_rame` feeds both the run name and the save path below.
project = "Bi Encoder Fine Tuning with Pytorch Lightning"
run_rame = "run_01"

# Every tunable setting for this run, tracked by Weights & Biases.
run_settings = dict(
    csv_path='~/train_for_fine_tune BI.csv',
    validation_size=2000,
    shuffle_train=True,
    optimizer=AdamW,
    learning_rate=1.5e-5,
    weight_decay=0.01,
    warmup_function='Linear',
    # Used as a fraction of the total training steps when the scheduler is built.
    warmup_steps=0.1,
    batch_size=512,
    max_epochs=100,
    early_stop=False,
    patience=10,
    run_name=run_rame,
    model_name='~/models/BERT after MLM+NSP',
    model_save_path='~/models/' + run_rame,
)
wandb.init(project=project, config=run_settings)

# The rest of the notebook reads its settings through `config`.
config = wandb.config
wandb_logger = WandbLogger(name=config.run_name, project=project)

<p>The code defines a class named <code>MatchDataset</code>, a subclass of <code>torch.utils.data.Dataset</code>.<br>
The constructor stores the training data, a <code>DataFrame</code>, in the instance variable <code>data</code>.<br>
The <code>__len__()</code> method returns the number of rows, and the <code>__getitem__()</code> method takes an index and returns a single training example as an instance of the <code>InputExample</code> class.<br>
Each example contains two texts, a normalized address and an unnormalized address, with a label that is always 1: every row is treated as a matching pair, and the training and validation steps ignore this label and build their own targets.</p>


In [None]:
class MatchDataset(torch.utils.data.Dataset):
    """Dataset view over a DataFrame of address pairs.

    Each item pairs a row's ``NormalizedAddress`` with its
    ``UnnormalizedAddress`` in an ``InputExample`` whose label is always 1.
    """

    def __init__(self, data):
        # DataFrame of address pairs; rows are accessed by position in __getitem__.
        self.data = data

    def __len__(self):
        """Return the number of rows in the wrapped data."""
        return len(self.data)

    def __getitem__(self, idx):
        """Build the ``InputExample`` for the row at position ``idx``."""
        row = self.data.iloc[idx]
        pair = [row['NormalizedAddress'], row['UnnormalizedAddress']]
        return InputExample(texts=pair, label=1)

The given code defines a PyTorch Lightning module called <code>MatchModel</code>. The constructor of the module takes in several arguments including a <code>bi_encoder</code> object, <code>csv_path</code>, <code>batch_size</code>, <code>num_epochs</code>, and <code>num_warmup_steps</code>.<br>

The <code>setup</code> method reads a CSV file from <code>csv_path</code> and creates training and validation datasets from it using <code>MatchDataset</code>.<br>

The module defines two data loaders, one for training data and another for validation data using <code>DataLoader</code> and the collate function <code>self.bi_encoder.smart_batching_collate</code>.<br>

The <code>forward</code> method passes the input through the <code>bi_encoder</code>.<br>

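The <code>training_step</code> and <code>validation_step</code> methods perform a forward pass of the input through the <code>bi_encoder</code>, calculate the cosine similarity matrix between the embeddings of all normalized and all unnormalized addresses in the batch, scale it by 20, and compute the multiple negatives ranking loss: a cross-entropy loss in which, for each row, the matching pair on the diagonal is the correct class and the other addresses in the batch act as negatives.<br>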

The <code>validation_epoch_end</code> method calculates the average validation loss across all the batches in the validation set and logs it. It also checks whether the current validation loss is better than the best validation loss so far and saves the model if it is.<br>

The <code>configure_optimizers</code> method defines an optimizer <code>(AdamW)</code> and a learning rate scheduler (<code>get_linear_schedule_with_warmup</code> or <code>CosineAnnealingLR</code>). The learning rate scheduler is used to adjust the learning rate during training.<br>

In [None]:
class MatchModel(LightningModule):
    """Lightning module that fine-tunes a bi-encoder on pairs of addresses.

    Both texts of every pair are embedded by ``bi_encoder``. The scaled
    cosine-similarity matrix between the two sets of embeddings is scored with
    cross-entropy against the diagonal, so the other rows of the batch act as
    negatives for each pair.

    Args:
        bi_encoder: Model used for embedding, collating and saving. It must
            provide ``smart_batching_collate``, ``save`` and ``parameters``,
            and its output must contain a ``'sentence_embedding'`` entry.
        csv_path: CSV file read in ``setup``.
        batch_size: Batch size of both data loaders.
        num_epochs: Stored as ``self.epochs``; not read inside this class.
        num_warmup_steps: Multiplied by the total number of optimizer steps to
            get the warmup length, i.e. it is used as a fraction (e.g. 0.1).

    The remaining settings (validation size, shuffling, learning rate, weight
    decay, scheduler type and save path) are read from the notebook-level
    ``config`` object.
    """

    # Factor applied to the cosine similarities before the cross-entropy loss.
    SIMILARITY_SCALE = 20
    # Worker processes used by each data loader.
    NUM_WORKERS = 40

    def __init__(self, bi_encoder , csv_path, batch_size, num_epochs, num_warmup_steps):
        super().__init__()

        self.bi_encoder = bi_encoder
        self.csv_path = csv_path
        self.batch_size = batch_size
        self.epochs = num_epochs
        self.warmup_steps = num_warmup_steps

        # NOTE(review): act_function is not used anywhere in this class; kept so
        # that the module's attributes stay unchanged.
        self.act_function = torch.nn.Sigmoid()
        self.loss_function = torch.nn.CrossEntropyLoss()

        self.best_val_loss = 999 # Init with big value so that 1st epoch is always saved

    def setup(self, stage = None):
        """Read the CSV and hold out its last ``config.validation_size`` rows for validation."""
        df = pd.read_csv(self.csv_path)

        # Split on an absolute position. A negative-slice split (`df.iloc[:-n]`)
        # would leave the training set empty when n == 0.
        split_at = max(len(df) - config.validation_size, 0)
        train_data = df.iloc[:split_at]
        val_data = df.iloc[split_at:]

        self.train_dataset = MatchDataset(train_data)
        self.val_dataset = MatchDataset(val_data)

    def train_dataloader(self):
        """Return the training loader; shuffling follows ``config.shuffle_train``."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=config.shuffle_train,
            collate_fn=self.bi_encoder.smart_batching_collate,
            num_workers=self.NUM_WORKERS,
        )

    def val_dataloader(self):
        """Return the validation loader (never shuffled)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=self.bi_encoder.smart_batching_collate,
            num_workers=self.NUM_WORKERS,
        )

    def forward(self, input):
        """Run ``input`` through the bi-encoder and return its output."""
        return self.bi_encoder(input)

    def _ranking_loss(self, batch):
        """Compute the in-batch ranking loss shared by training and validation.

        ``batch`` is unpacked into the model inputs of the two texts and a
        label entry that is ignored; the targets are rebuilt here instead.
        """
        features, _ = batch
        first_embedding, second_embedding = [
            self.forward(text_features)['sentence_embedding'] for text_features in features
        ]

        scores = util.cos_sim(first_embedding, second_embedding) * self.SIMILARITY_SCALE
        # Row i must rank column i highest, so the targets are simply 0..n-1,
        # created directly on the device that holds the scores.
        target = torch.arange(scores.shape[0], dtype=torch.long, device=scores.device)

        return self.loss_function(scores, target)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss of one batch."""
        loss = self._ranking_loss(batch)

        self.log('train_loss', loss, prog_bar=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss of one batch."""
        loss = self._ranking_loss(batch)

        # One value per validation batch; they are averaged in validation_epoch_end.
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss and save the bi-encoder when it improves."""
        if not outputs:
            # Nothing was validated, so there is no loss to average or compare.
            return

        # Average the per-batch losses over the whole validation pass.
        current_epoch_val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # Log the metric.
        self.log('val_loss', current_epoch_val_loss, prog_bar=True, sync_dist=True)

        # The pre-training sanity check only sees a few batches of the untrained
        # model; it must neither save a checkpoint nor set the best loss.
        if self.trainer.sanity_checking:
            return

        # Save the model when this epoch's loss beats the best one seen so far.
        print('\nVerifico se quero savar o modelo: ')
        print('Loss deste epoch = {}, e melhor loss do momento = {}'.format(current_epoch_val_loss, self.best_val_loss))
        if current_epoch_val_loss < self.best_val_loss:
            self.best_val_loss = current_epoch_val_loss
            print('Savei o modelo')
            self.bi_encoder.save(config.model_save_path)

    def configure_optimizers(self):
        """Build the AdamW optimizer and a learning-rate scheduler stepped every batch.

        ``config.warmup_function == "Linear"`` selects linear warmup and decay;
        any other value selects cosine annealing without warmup.
        """
        optimizer = AdamW(
            self.bi_encoder.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )
        # Per the original author's note, this estimate takes gradient
        # accumulation into account.
        total_steps = self.trainer.estimated_stepping_batches

        if config.warmup_function == "Linear":
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=int(self.warmup_steps * total_steps),
                num_training_steps=total_steps,
            )
        else:
            scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=5e-7)
        return dict(optimizer=optimizer, lr_scheduler=dict(scheduler=scheduler, interval='step'))


<p>The first line creates an instance of the <code>LearningRateMonitor</code> class and assigns it to the variable <code>lr_monitor</code>. This callback monitors the learning rate of the optimizer and logs it at a specified interval during training.</p>
<p>The second line creates an instance of the <code>EarlyStopping</code> class and assigns it to the variable <code>early_stopping_callback</code>. This callback monitors the validation loss and stops training early if the loss does not improve for a certain number of epochs specified by the <code>patience</code> argument. The <code>monitor</code> argument specifies which metric to monitor, in this case, the validation loss. The <code>mode</code> argument specifies whether to minimize or maximize the monitored metric, in this case, we want to minimize the validation loss.</p>

In [None]:
# Records the learning rate at every training step.
lr_monitor = LearningRateMonitor(logging_interval='step')

# Watches `val_loss` (lower is better), with the patience taken from the run config.
early_stopping_callback = EarlyStopping(
    mode='min',
    monitor='val_loss',
    patience=config.patience,
)

The code creates an instance of the `SentenceTransformer` class using a pre-trained transformer-based language model specified in `config.model_name`.<br>
It first creates a `Transformer` model using `config.model_name` with a maximum sequence length of 128.<br>
It then creates a `Pooling` model to obtain a fixed-size sentence embedding from the transformer output.<br>
Finally, a fully connected `Dense` model with `nn.Tanh()` activation function is used to further transform the sentence embeddings. These three models are passed as a list to the `SentenceTransformer` class constructor, which creates the final model.<br>


The code also sets the device for running the model to "cuda" if a GPU is available, otherwise "cpu".


In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Three stacked modules: transformer -> pooling -> dense projection to 512 dims with tanh.
word_embedding_model = models.Transformer(config.model_name, max_seq_length=128)
pooling_model = models.Pooling(
    word_embedding_model.get_word_embedding_dimension()
)
dense_model = models.Dense(
    in_features=pooling_model.get_sentence_embedding_dimension(),
    out_features=512,
    activation_function=nn.Tanh(),
)

encoder_modules = [word_embedding_model, pooling_model, dense_model]
model = SentenceTransformer(modules=encoder_modules, device=device)


<div>
<p>The code initializes a <code>MatchModel</code> with the <code>SentenceTransformer</code> model created above and a <code>Trainer</code> with various arguments:</p>
<ul>
<li><code>accelerator</code> is set to 'gpu'</li>
<li><code>devices</code> is set to 1</li>
<li><code>callbacks</code> is set to a list containing <code>lr_monitor</code> and <code>early_stopping_callback</code> if <code>config.early_stop</code> is <code>True</code>, otherwise only <code>lr_monitor</code> is included</li>
<li><code>check_val_every_n_epoch</code> is set to 1</li>
<li><code>max_epochs</code> is set to <code>config.max_epochs</code></li>
<li><code>enable_checkpointing</code> is set to <code>False</code></li>
<li><code>accumulate_grad_batches</code> is set to 1</li>
<li><code>logger</code> is set to <code>wandb_logger</code></li>
<li><code>log_every_n_steps</code> is set to 1</li>
<li><code>deterministic</code> is set to <code>True</code></li>
<li><code>precision</code> is set to 16</li>
</ul>
<p>Finally, the <code>trainer</code> is used to fit <code>my_model</code>.</p>
</div>

In [None]:
# Wrap the SentenceTransformer in the Lightning module, using settings from the run config.
my_model = MatchModel(
    bi_encoder=model,
    csv_path=config.csv_path,
    batch_size=config.batch_size,
    num_epochs=config.max_epochs,
    num_warmup_steps=config.warmup_steps,
)

# The learning-rate monitor is always active; early stopping only when the config enables it.
active_callbacks = [lr_monitor]
if config.early_stop == True:
    active_callbacks.append(early_stopping_callback)

trainer = Trainer(
    accelerator='gpu',
    devices=1,
    callbacks=active_callbacks,
    check_val_every_n_epoch=1,
    max_epochs=config.max_epochs,
    logger=wandb_logger,
    log_every_n_steps=1,
    # Lightning checkpointing is off; the model is saved by MatchModel itself.
    enable_checkpointing=False,
    deterministic=True,
    precision=16,
)

trainer.fit(my_model)

In [None]:
# Send a W&B alert announcing that the run has ended, then close the run.
finish_title = f"Finish {config.run_name}"
wandb.alert(title=finish_title, text="Run is over.")

wandb.finish()