# Packages and variables

In [86]:
# Load packages
from sentence_transformers import models, losses, datasets, SentencesDataset
from sentence_transformers import SentenceTransformer, util, InputExample
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from datasets import load_dataset
from transformers import BertTokenizer
from collections import OrderedDict
import torch.optim as optim
import torch
from torch import nn, Tensor
from tqdm import tqdm
import transformers
from typing import Iterable, Dict

from metrics import *
from utils import *

In [None]:
# Hyperparameters and runtime configuration
model_name = "bert-base-uncased"
train_batch_size = 20
max_seq_length = 250
num_epochs = 1

# Run on the GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tokenizer and sentence-embedding model, both built from the same checkpoint name
tokenizer = BertTokenizer.from_pretrained(model_name)
model = SentenceTransformer(model_name)
model = model.to(device)

# Embedding width, read from the model's module at index 1
num_features = model[1].word_embedding_dimension

# Load and prepare dataset

In [None]:
# Fetch the MultiNLI dataset through the Hugging Face `datasets` loader
dataset = load_dataset("multi_nli")

# Copy the training-split columns used below into a dataframe for easier handling
_train_split = dataset["train"]
df = pd.DataFrame(
    {
        column: _train_split[column]
        for column in ("premise", "hypothesis", "genre", "label")
    }
)

In [None]:
# The tokenizer emits warnings about the truncation strategy; silence them.
# (Original note: all sentence pairs are expected to fit within max_seq_length.)
transformers.logging.set_verbosity_error()

# Each (genre, NLI label) combination needs its own class id for AdaCos:
# class id = genre offset + NLI label. Offsets are spaced three apart, one
# slot per NLI label value.
genre_label_offsets = {
    "telephone": 0,
    "government": 3,
    "travel": 6,
    "fiction": 9,
    "slate": 12,
}

labels = []
input_ids = []
attention_masks = []

# Iterate over plain column values instead of `iterrows()` (which builds a
# Series per row), and give tqdm the row count so it can show real progress.
rows = zip(df["premise"], df["hypothesis"], df["genre"], df["label"])
for premise, hypothesis, genre, nli_label in tqdm(rows, total=len(df)):
    # Fail loudly on an unknown genre rather than silently reusing the offset
    # computed for the previous row.
    if genre not in genre_label_offsets:
        raise ValueError(f"Unexpected genre {genre!r}: no AdaCos label offset defined")

    lab = int(nli_label) + genre_label_offsets[genre]

    # Encode premise and hypothesis as one sequence pair, padded/truncated to
    # max_seq_length so every row yields tensors of identical shape.
    encoded_data = tokenizer(
        premise,
        hypothesis,
        add_special_tokens=True,
        return_attention_mask=True,
        padding="max_length",
        max_length=max_seq_length,
        return_tensors='pt',
        truncation=True,
    )

    input_ids.append(encoded_data['input_ids'])
    attention_masks.append(encoded_data['attention_mask'])
    labels.append(lab)

# Concatenate the per-row tensors into single dataset-wide tensors
input_ids = torch.cat(input_ids)
attention_masks = torch.cat(attention_masks)
labels = torch.LongTensor(labels)

train_dataset = TensorDataset(input_ids, attention_masks, labels)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)

# Train AdaCos model

In [None]:
# Define training loop
def train(train_loader, model, metrics_fc, criterion, optimizer, optimizer2):
    """Run one training epoch over ``train_loader`` and return averaged metrics.

    Args:
        train_loader: iterable of batches indexed as
            ``(input_ids, attention_mask, labels)``.
        model: embedding model; called with a dict holding ``input_ids`` and
            ``attention_mask`` and expected to return a mapping with a
            ``"sentence_embedding"`` entry.
        metrics_fc: metric layer called as ``metrics_fc(feature, target)``;
            its output is passed to ``criterion``.
        criterion: loss function called as ``criterion(output, target)``.
        optimizer: first optimizer stepped after the backward pass.
        optimizer2: second optimizer stepped after the backward pass.

    Returns:
        OrderedDict with the running averages ``'loss'`` and ``'acc1'``.
    """
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()

    model.train()
    metrics_fc.train()

    for batch in tqdm(train_loader, total=len(train_loader)):
        # Clear the gradients held by BOTH optimizers. Zeroing only one would
        # let the other optimizer's parameter gradients accumulate from batch
        # to batch.
        optimizer.zero_grad()
        optimizer2.zero_grad()

        inputs = {
            'input_ids': batch[0].to(device),
            'attention_mask': batch[1].to(device),
        }

        feature = model(inputs)["sentence_embedding"]
        target = torch.as_tensor(batch[2], dtype=torch.long).to(device)

        output = metrics_fc(feature, target)
        loss = criterion(output, target)
        acc1, = accuracy(output, target, topk=(1,))

        # Weight the running averages by the number of samples in the batch
        batch_size = len(batch[0])
        loss_meter.update(loss.item(), batch_size)
        acc1_meter.update(acc1.item(), batch_size)

        # compute gradient and do optimizing step
        loss.backward()
        optimizer.step()
        optimizer2.step()

    log = OrderedDict([
        ('loss', loss_meter.avg),
        ('acc1', acc1_meter.avg),
    ])

    return log

In [None]:
# Define the losses needed for adacos
# `from metrics import *` does not bind the module name `metrics` itself, so
# import the module explicitly before referring to `metrics.AdaCos`.
import metrics

adacos = metrics.AdaCos(num_features, num_classes=15).to(device)
criterion = nn.CrossEntropyLoss().to(device)

# Prepare optimizers for model and adacos layers
optimizer_m = optim.AdamW(model.parameters(), lr=1e-5)
optimizer = optim.SGD(adacos.parameters(), lr=1e-5, momentum=0.9, weight_decay=1e-4)

# Train the model
for epoch in range(num_epochs):
    print(f"epoch{epoch+1}")
    # Keep the returned metrics: they are reported right below.
    train_log = train(train_dataloader, model, adacos, criterion, optimizer, optimizer_m)

    epoch_summary = pd.Series([
        epoch,
        train_log['loss'],
        train_log['acc1'],
    ], index=['epoch', 'loss', 'acc1'])
    print(epoch_summary)

model.save('model_adacos_mnli')

# Prepare data for pairwise training
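For clarity, the dataset preprocessing for pairwise training is done in a separate step here, but it could be combined with the AdaCos preprocessing above in a single pass over the data.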

In [None]:
# Pairwise cosine loss expects labels to be opposite for opposite classes (positive, negative)
# for example for telephone the labels will be 1 for positive, -1 for negative
# for government the labels will be 2 for positive and -2 for negative and so on
# neutral classes are mapped to values whose absolute value is not shared by any other class
#
# With k the genre rank below: NLI label 1 (neutral) becomes 10 + k, and any
# other NLI label becomes -(label - 1) * k.
pairwise_genre_ranks = {
    "telephone": 1,
    "government": 2,
    "travel": 3,
    "fiction": 4,
    "slate": 5,
}

train_examples = []
# Iterate over plain column values instead of `iterrows()`, and give tqdm the
# row count so it can show real progress.
rows = zip(df["premise"], df["hypothesis"], df["genre"], df["label"])
for premise, hypothesis, genre, nli_label in tqdm(rows, total=len(df)):
    # Fail loudly on an unknown genre rather than silently reusing the label
    # computed for the previous row.
    if genre not in pairwise_genre_ranks:
        raise ValueError(f"Unexpected genre {genre!r}: no pairwise label defined")

    rank = pairwise_genre_ranks[genre]
    nli_label = int(nli_label)
    if nli_label != 1:
        lab = -(nli_label - 1) * rank
    else:
        lab = 10 + rank

    # Join premise and hypothesis into a single text around the tokenizer's separator token
    text = premise + " " + tokenizer.sep_token + " " + hypothesis
    train_examples.append(InputExample(texts=[text], label=lab))

train_dataset = SentencesDataset(train_examples, model)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)

# Train Pairwise

In [88]:
# define pairwise loss function
class Pairwise_Cosine_Loss(nn.Module):
    """Batch-wise cosine loss driven by signed integer labels.

    Within a batch, two samples with equal labels are pulled towards a cosine
    similarity of 1 and two samples with opposite-sign labels towards -1. For
    every other pair the target is 0, and its error is ignored while it stays
    below the tolerance ``t``.
    """

    def __init__(self, model: SentenceTransformer, t=0.3):
        """Store the embedding model and the tolerance applied to neutral pairs."""
        super().__init__()
        self.sentence_embedder = model
        self.t = t

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        """Embed the first feature set and return its pairwise loss against ``labels``."""
        first_input = sentence_features[0]
        embeddings = self.sentence_embedder(first_input)['sentence_embedding']
        return self.pairwise_loss(labels, embeddings)

    def pairwise_loss(self, labels, embeddings):
        """Average the absolute cosine-similarity error over all pairs in the batch.

        Args:
            labels: labels of the batch, of size (batch_size,)
            embeddings: tensor of shape (batch_size, embed_dim)
        Returns:
            Scalar tensor: summed pairwise error divided by batch_size ** 2
        """
        # Cosine similarity between every pair of embeddings in the batch
        similarity = util.pytorch_cos_sim(embeddings, embeddings)

        # Pair masks: same label, exactly opposite label, or neither
        row_labels = labels.unsqueeze(0)
        col_labels = labels.unsqueeze(1)
        same_label = row_labels == col_labels
        opposite_label = row_labels == -col_labels
        unrelated = ~(same_label | opposite_label)

        # Target similarity per pair: +1 same label, -1 opposite label, 0 otherwise
        target = same_label.float() - opposite_label.float()
        errors = (similarity - target).abs()

        # Neutral pairs whose error is already below the tolerance contribute nothing
        tolerated = unrelated & (errors < self.t)
        errors = errors.masked_fill(tolerated, 0)

        # Mean over all batch_size ** 2 pairs
        return errors.sum() / (len(labels) ** 2)

In [None]:
# Optionally restart from the saved AdaCos checkpoint instead of the in-memory model
# model = SentenceTransformer(f'model_adacos_mnli')

# Pairwise loss module wrapping the model being fine-tuned
train_loss = Pairwise_Cosine_Loss(model)

# Fine-tune with the single (dataloader, loss) objective
pairwise_objectives = [(train_dataloader, train_loss)]
model.fit(
    train_objectives=pairwise_objectives,
    epochs=num_epochs,
    show_progress_bar=True,
    optimizer_params={'lr': 1e-05},
)

# Persist the fine-tuned model
model.save("model_pairwise_mnli")