# [Sentence-BERT](https://arxiv.org/pdf/1908.10084.pdf)

[Reference Code](https://www.pinecone.io/learn/series/nlp/train-sentence-transformers-softmax/)

In [1]:
import os
import math
import re
from   random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from Bert import BERT
# Set GPU device
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

os.environ['http_proxy']  = 'http://192.41.170.23:3128'
os.environ['https_proxy'] = 'http://192.41.170.23:3128'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## 1. Data

### Train, Test, Validation 

In [2]:
import datasets
snli = datasets.load_dataset('snli')
mnli = datasets.load_dataset('glue', 'mnli')
mnli['train'].features, snli['train'].features

({'premise': Value(dtype='string', id=None),
  'hypothesis': Value(dtype='string', id=None),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None),
  'idx': Value(dtype='int32', id=None)},
 {'premise': Value(dtype='string', id=None),
  'hypothesis': Value(dtype='string', id=None),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)})

In [3]:
# List of datasets to remove 'idx' column from
mnli.column_names.keys()

dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])

In [4]:
# Remove 'idx' column from each dataset
for column_names in mnli.column_names.keys():
    mnli[column_names] = mnli[column_names].remove_columns('idx')

In [5]:
mnli.column_names.keys()

dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])

In [6]:
import numpy as np
np.unique(mnli['train']['label']), np.unique(snli['train']['label'])
#snli also have -1

(array([0, 1, 2]), array([-1,  0,  1,  2]))

In [7]:
# there are -1 values in the label feature, these are where no class could be decided so we remove
snli = snli.filter(
    lambda x: 0 if x['label'] == -1 else 1
)

In [8]:
import numpy as np
np.unique(mnli['train']['label']), np.unique(snli['train']['label'])
#snli also have -1

(array([0, 1, 2]), array([0, 1, 2]))

In [9]:
# Assuming you have your two DatasetDict objects named snli and mnli
from datasets import DatasetDict
# Merge the two DatasetDict objects
raw_dataset = DatasetDict({
    'train': datasets.concatenate_datasets([snli['train'], mnli['train']]).shuffle(seed=55).select(list(range(1000))),
    'test': datasets.concatenate_datasets([snli['test'], mnli['test_mismatched']]).shuffle(seed=55).select(list(range(100))),
    'validation': datasets.concatenate_datasets([snli['validation'], mnli['validation_mismatched']]).shuffle(seed=55).select(list(range(1000)))
})
#remove .select(list(range(1000))) in order to use full dataset
# Now, merged_dataset_dict contains the combined datasets from snli and mnli
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 100
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1000
    })
})

## 2. Preprocessing

In [10]:
import json

with open("./model/word2id.json", "r") as f:
    word2id = json.load(f)

In [11]:
def custom_tokenizer(text, max_seq_length=128):
    tokens = text.lower().split()
    token_ids = [word2id.get(word, word2id["[MASK]"]) for word in tokens]   
    token_ids = token_ids[:max_seq_length]
    attention_mask = [1] * len(token_ids)
    padding_length = max_seq_length - len(token_ids)
    token_ids.extend([word2id["[PAD]"]] * padding_length)
    attention_mask.extend([0] * padding_length) 

    return {
        "input_ids": token_ids,
        "attention_mask": attention_mask
    }


In [12]:
def preprocess_function(examples):
    max_seq_length = 128

    premise_result = [custom_tokenizer(text, max_seq_length) for text in examples["premise"]]
    hypothesis_result = [custom_tokenizer(text, max_seq_length) for text in examples["hypothesis"]]

    labels = examples["label"]

    return {
        "premise_input_ids": [res["input_ids"] for res in premise_result],
        "premise_attention_mask": [res["attention_mask"] for res in premise_result],
        "hypothesis_input_ids": [res["input_ids"] for res in hypothesis_result],
        "hypothesis_attention_mask": [res["attention_mask"] for res in hypothesis_result],
        "labels": labels
    }

tokenized_datasets = raw_dataset.map(preprocess_function, batched=True)

tokenized_datasets = tokenized_datasets.remove_columns(['premise','hypothesis','label'])
tokenized_datasets.set_format("torch")

In [13]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 100
    })
    validation: Dataset({
        features: ['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 1000
    })
})

## 3. Data loader

In [14]:
from torch.utils.data import DataLoader

# initialize the dataloader
batch_size = 32
train_dataloader = DataLoader(
    tokenized_datasets['train'], 
    batch_size=batch_size, 
    shuffle=True
)
eval_dataloader = DataLoader(
    tokenized_datasets['validation'], 
    batch_size=batch_size
)
test_dataloader = DataLoader(
    tokenized_datasets['test'], 
    batch_size=batch_size
)

In [15]:
for batch in train_dataloader:
    print(batch['premise_input_ids'].shape)
    print(batch['premise_attention_mask'].shape)
    print(batch['hypothesis_input_ids'].shape)
    print(batch['hypothesis_attention_mask'].shape)
    print(batch['labels'].shape)
    break

torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32])


## 4. Model

In [16]:
from tqdm.auto import tqdm

n_layers = 12    # number of Encoder of Encoder Layer
n_heads  = 12    # number of heads in Multi-Head Attention
d_model  = 768  # Embedding Size
d_ff = d_model * 4  # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 2
num_epoch = 4
max_len    = 1000
vocab_size = len(word2id)
model = BERT(
    n_layers, 
    n_heads, 
    d_model, 
    d_ff, 
    d_k, 
    n_segments, 
    vocab_size, 
    max_len, 
    device
).to(device)  # Move model to GPU

model.load_state_dict(torch.load("./model/bert_model.pth", map_location=device))

model.eval()


BERT(
  (embedding): Embedding(
    (tok_embed): Embedding(60305, 768)
    (pos_embed): Embedding(1000, 768)
    (seg_embed): Embedding(2, 768)
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (layers): ModuleList(
    (0-11): 12 x EncoderLayer(
      (enc_self_attn): MultiHeadAttention(
        (W_Q): Linear(in_features=768, out_features=768, bias=True)
        (W_K): Linear(in_features=768, out_features=768, bias=True)
        (W_V): Linear(in_features=768, out_features=768, bias=True)
      )
      (pos_ffn): PoswiseFeedForwardNet(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
      )
    )
  )
  (fc): Linear(in_features=768, out_features=768, bias=True)
  (activ): Tanh()
  (linear): Linear(in_features=768, out_features=768, bias=True)
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
  (de

### Pooling
SBERT adds a pooling operation to the output of BERT / RoBERTa to derive a fixed sized sentence embedding

In [17]:
# define mean pooling function
def mean_pool(token_embeds, attention_mask):
    # reshape attention_mask to cover 768-dimension embeddings
    in_mask = attention_mask.unsqueeze(-1).expand(
        token_embeds.size()
    ).float()
    # perform mean-pooling but exclude padding tokens (specified by in_mask)
    pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(
        in_mask.sum(1), min=1e-9
    )
    return pool

## 5. Loss Function

## Classification Objective Function 
We concatenate the sentence embeddings $u$ and $v$ with the element-wise difference  $\lvert u - v \rvert $ and multiply the result with the trainable weight  $ W_t ∈  \mathbb{R}^{3n \times k}  $:

$ o = \text{softmax}\left(W^T \cdot \left(u, v, \lvert u - v \rvert\right)\right) $

where $n$ is the dimension of the sentence embeddings and k the number of labels. We optimize cross-entropy loss. This structure is depicted in Figure 1.

## Regression Objective Function. 
The cosine similarity between the two sentence embeddings $u$ and $v$ is computed (Figure 2). We use means quared-error loss as the objective function.

(Manhatten / Euclidean distance, semantically  similar sentences can be found.)


In [18]:
def configurations(u,v):
    # build the |u-v| tensor
    uv = torch.sub(u, v)   # batch_size,hidden_dim
    uv_abs = torch.abs(uv) # batch_size,hidden_dim
    
    # concatenate u, v, |u-v|
    x = torch.cat([u, v, uv_abs], dim=-1) # batch_size, 3*hidden_dim
    return x

def cosine_similarity(u, v):
    dot_product = np.dot(u, v)
    norm_u = np.linalg.norm(u)
    norm_v = np.linalg.norm(v)
    similarity = dot_product / (norm_u * norm_v)
    return similarity

In [19]:
classifier_head = torch.nn.Linear(768*3, 3).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
optimizer_classifier = torch.optim.Adam(classifier_head.parameters(), lr=2e-5)

criterion = nn.CrossEntropyLoss()

In [20]:
from transformers import get_linear_schedule_with_warmup

# and setup a warmup for the first ~10% steps
total_steps = int(len(raw_dataset) / batch_size)
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
		optimizer, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

# then during the training loop we update the scheduler per step
scheduler.step()

scheduler_classifier = get_linear_schedule_with_warmup(
		optimizer_classifier, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

# then during the training loop we update the scheduler per step
scheduler_classifier.step()



## 6. Training

In [21]:
from tqdm.auto import tqdm
import torch

num_epoch = 4
batch_size = 8
accuracy = 0
count = 0

for epoch in range(num_epoch):
    model.train()  
    classifier_head.train()

    accuracy = 0  # Reset accuracy for each epoch
    count = 0  # Reset count for each epoch

    for step, batch in enumerate(tqdm(train_dataloader, leave=True)):
        # zero all gradients on each new step
        optimizer.zero_grad()
        optimizer_classifier.zero_grad()
        
        # prepare batches and move all to the active device
        inputs_ids_a = batch['premise_input_ids'].to(device)
        inputs_ids_b = batch['hypothesis_input_ids'].to(device)
        attention_a = batch['premise_attention_mask'].to(device)
        attention_b = batch['hypothesis_attention_mask'].to(device)
        label = batch['labels'].to(device)

        segment_ids = torch.zeros_like(inputs_ids_a).to(device)
        masked_pos = torch.zeros_like(inputs_ids_a).to(device)

        # Extract token embeddings from BERT at last_hidden_state
        u = model.get_last_hidden_state(inputs_ids_a, segment_ids)
        v = model.get_last_hidden_state(inputs_ids_b, segment_ids)

        # Mean Pool
        u_mean_pool = mean_pool(u, attention_a)  # batch_size, hidden_dim
        v_mean_pool = mean_pool(v, attention_b)  # batch_size, hidden_dim
        
        # Build |u-v| tensor
        uv = torch.sub(u_mean_pool, v_mean_pool)   # batch_size, hidden_dim
        uv_abs = torch.abs(uv)  # batch_size, hidden_dim
        
        # Concatenate u, v, |u-v|
        x = torch.cat([u_mean_pool, v_mean_pool, uv_abs], dim=-1)  # batch_size, 3*hidden_dim
        
        # Process concatenated tensor through classifier_head
        x = classifier_head(x)  # batch_size, num_classes

        # Calculate accuracy using torch.argmax and comparison with labels
        correct = (torch.argmax(x, dim=1) == label).float()  # Compare predictions with labels
        accuracy += correct.sum().item()  # Add correct predictions to total accuracy
        count += label.size(0)  # Increment count by the number of samples in the batch

        # Calculate 'softmax-loss' between predicted and true label
        loss = criterion(x, label)
        
        # Using loss, calculate gradients and then optimize
        loss.backward()
        optimizer.step()
        optimizer_classifier.step()

        scheduler.step()  # Update learning rate scheduler
        scheduler_classifier.step()

    print(f'Epoch: {epoch + 1} | loss = {loss.item():.6f} | Accuracy = {(accuracy / count) * 100:.2f}%')


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 1 | loss = 1.150960 | Accuracy = 35.30%


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 2 | loss = 1.411711 | Accuracy = 35.30%


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 3 | loss = 0.958149 | Accuracy = 35.30%


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 4 | loss = 1.543412 | Accuracy = 35.30%


In [22]:
model.eval()
classifier_head.eval()
total_similarity = 0
with torch.no_grad():
    for step, batch in enumerate(eval_dataloader):
        # prepare batches and move all to the active device
        inputs_ids_a = batch['premise_input_ids'].to(device)
        inputs_ids_b = batch['hypothesis_input_ids'].to(device)
        attention_a = batch['premise_attention_mask'].to(device)
        attention_b = batch['hypothesis_attention_mask'].to(device)

        segment_ids = torch.zeros_like(inputs_ids_a).to(device)

        u = model.get_last_hidden_state(inputs_ids_a, segment_ids)  # batch_size, seq_len, hidden_dim
        v = model.get_last_hidden_state(inputs_ids_b, segment_ids)  # batch_size, seq_len, hidden_dim

        # get the mean pooled vectors
        u_mean_pool = mean_pool(u, attention_a).detach().cpu().numpy().reshape(-1)  # batch_size, hidden_dim
        v_mean_pool = mean_pool(v, attention_b).detach().cpu().numpy().reshape(-1)  # batch_size, hidden_dim

        # Compute cosine similarity
        similarity_score = np.dot(u_mean_pool, v_mean_pool) / (np.linalg.norm(u_mean_pool) * np.linalg.norm(v_mean_pool))
        total_similarity += similarity_score
    
average_similarity = total_similarity / len(eval_dataloader)
print(f"Average Cosine Similarity: {average_similarity:.4f}")


Average Cosine Similarity: 0.9995


In [23]:
# Save the model after training
torch.save(model.state_dict(), './model/s_bert_model.pth')
print("Model saved to s_bert_model.pth")

Model saved to s_bert_model.pth


## 7. Inference

In [24]:
import torch
from sklearn.metrics.pairwise import cosine_similarity

def calculate_similarity(model, custom_tokenizer, sentence_a, sentence_b, device):
    # Tokenize sentences using custom tokenizer
    inputs_a = custom_tokenizer(sentence_a)  
    inputs_b = custom_tokenizer(sentence_b) 

    inputs_ids_a = torch.tensor(inputs_a['input_ids']).unsqueeze(0).to(device)  # Shape: (1, seq_len)
    attention_a = torch.tensor(inputs_a['attention_mask']).unsqueeze(0).to(device)  # Shape: (1, seq_len)

    inputs_ids_b = torch.tensor(inputs_b['input_ids']).unsqueeze(0).to(device)  # Shape: (1, seq_len)
    attention_b = torch.tensor(inputs_b['attention_mask']).unsqueeze(0).to(device)  # Shape: (1, seq_len)

    # Extract the last hidden states from the model
    u = model.get_last_hidden_state(inputs_ids_a, segment_ids=torch.zeros_like(inputs_ids_a).to(device))  # batch_size, seq_len, hidden_dim
    v = model.get_last_hidden_state(inputs_ids_b, segment_ids=torch.zeros_like(inputs_ids_b).to(device))  # batch_size, seq_len, hidden_dim

    # Compute the mean-pooled vectors (using the attention mask for pooling)
    u_mean_pool = mean_pool(u, attention_a).detach().cpu().numpy().reshape(-1)  # batch_size, hidden_dim
    v_mean_pool = mean_pool(v, attention_b).detach().cpu().numpy().reshape(-1)  # batch_size, hidden_dim

    # Compute cosine similarity
    similarity_score = cosine_similarity(u_mean_pool.reshape(1, -1), v_mean_pool.reshape(1, -1))[0, 0]

    # Determine the label based on similarity score
    if similarity_score > 0.8:
        label = "Entailment"
    elif similarity_score < 0.4:
        label = "Contradiction"
    else:
        label = "Neutral"

    return similarity_score, label

# Example usage
sentence_a = 'Your contribution helped make it possible for us to provide our students with a quality education.'
sentence_b = "Your contributions were of no help with our students' education."
similarity, label = calculate_similarity(model, custom_tokenizer, sentence_a, sentence_b, device)
print(f"Cosine Similarity: {similarity:.4f}, Label: {label}")


Cosine Similarity: 0.9991, Label: Entailment
