# [Sentence-BERT](https://arxiv.org/pdf/1908.10084.pdf)

[Reference Code](https://www.pinecone.io/learn/series/nlp/train-sentence-transformers-softmax/)

In [1]:
import os
import math
import re
from   random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Set GPU device
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# # os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# os.environ['http_proxy']  = 'http://192.41.170.23:3128'
# os.environ['https_proxy'] = 'http://192.41.170.23:3128'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## 1. Data

### Train, Test, Validation 

In [2]:
import datasets
snli = datasets.load_dataset('snli')
mnli = datasets.load_dataset('glue', 'mnli')
mnli['train'].features, snli['train'].features

({'premise': Value(dtype='string', id=None),
  'hypothesis': Value(dtype='string', id=None),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None),
  'idx': Value(dtype='int32', id=None)},
 {'premise': Value(dtype='string', id=None),
  'hypothesis': Value(dtype='string', id=None),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)})

In [3]:
# List of datasets to remove 'idx' column from
mnli.column_names.keys()

dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])

In [4]:
# Remove 'idx' column from each dataset
for column_names in mnli.column_names.keys():
    mnli[column_names] = mnli[column_names].remove_columns('idx')

In [5]:
mnli.column_names.keys()

dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])

In [6]:
import numpy as np
np.unique(mnli['train']['label']), np.unique(snli['train']['label'])
#snli also have -1

(array([0, 1, 2]), array([-1,  0,  1,  2]))

In [7]:
# there are -1 values in the label feature, these are where no class could be decided so we remove
snli = snli.filter(
    lambda x: 0 if x['label'] == -1 else 1
)

In [8]:
import numpy as np
np.unique(mnli['train']['label']), np.unique(snli['train']['label'])
#snli also have -1

(array([0, 1, 2]), array([0, 1, 2]))

In [9]:
# Assuming you have your two DatasetDict objects named snli and mnli
from datasets import DatasetDict
# Merge the two DatasetDict objects
raw_dataset = DatasetDict({
    'train': datasets.concatenate_datasets([snli['train'], mnli['train']]).shuffle(seed=55).select(list(range(1000))),
    'test': datasets.concatenate_datasets([snli['test'], mnli['test_mismatched']]).shuffle(seed=55).select(list(range(100))),
    'validation': datasets.concatenate_datasets([snli['validation'], mnli['validation_mismatched']]).shuffle(seed=55).select(list(range(1000)))
})
#remove .select(list(range(1000))) in order to use full dataset
# Now, merged_dataset_dict contains the combined datasets from snli and mnli
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 100
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1000
    })
})

## 2. Preprocessing

In [10]:
# pip install transformers

In [11]:
# from transformers import BertTokenizer

# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
import pickle

# Load word2id
with open("models/word2id.pkl", "rb") as word2id_file:
    word2id = pickle.load(word2id_file)

# Load id2word
with open("models/id2word.pkl", "rb") as id2word_file:
    id2word = pickle.load(id2word_file)

# Load token_list
with open("models/token_list.pkl", "rb") as token_list_file:
    token_list = pickle.load(token_list_file)


In [13]:
# word2id

In [14]:
max_len = 128
max_mask = 5

In [15]:
class Tokenizer:
    def __init__(self, word2id):
        if not isinstance(word2id, dict):
            raise ValueError("word2id must be a dictionary")
        self.word2id = word2id
        self.id2word = {v: k for k, v in self.word2id.items()}
        self.vocab_size = len(self.word2id)
        self.max_len = max_len

    def encode(self, sentences):
        output = {}
        output['input_ids'] = []
        output['attention_mask'] = []
        for sentence in sentences:
            input_ids = [self.word2id.get(word, self.word2id['[UNK]']) for word in sentence.split()]
            n_pad = self.max_len - len(input_ids)
            input_ids.extend([0] * n_pad)
            att_mask = [1 if idx != 0 else 0 for idx in input_ids]  # Create attention mask
            output['input_ids'].append(torch.tensor(input_ids))  # Convert to tensor
            output['attention_mask'].append(torch.tensor(att_mask))  # Convert to tensor
        return output

    def decode(self, ids):
        return ' '.join([self.id2word.get(idx.item(), '[UNK]') for idx in ids])

In [16]:
tokenizer = Tokenizer(word2id)

In [17]:
def preprocess_function(examples):
    padding = max_len
    # Tokenize the premise
    premise_result = tokenizer.encode(
        examples['premise'])
    #num_rows, max_seq_length
    # Tokenize the hypothesis
    hypothesis_result = tokenizer.encode(
        examples['hypothesis'])
    #num_rows, max_seq_length
    # Extract labels
    labels = examples["label"]
    #num_rows
    return {
        "premise_input_ids": premise_result["input_ids"],
        "premise_attention_mask": premise_result["attention_mask"],
        "hypothesis_input_ids": hypothesis_result["input_ids"],
        "hypothesis_attention_mask": hypothesis_result["attention_mask"],
        "labels" : labels
    }

tokenized_datasets = raw_dataset.map(
    preprocess_function,
    batched=True,
)

tokenized_datasets = tokenized_datasets.remove_columns(['premise','hypothesis','label'])
tokenized_datasets.set_format("torch")

In [18]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 100
    })
    validation: Dataset({
        features: ['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 1000
    })
})

In [19]:
# tokenized_datasets['train'][0]

In [20]:
# sentences = raw_dataset['train']['premise']
# text = [x.lower() for x in sentences] #lower case
# text = [re.sub("[.,!?\\-]", '', x) for x in text] #clean all symbols
# # text

In [21]:
# for sentence in text:
#     print(sentence, "_____")
#     words = sentence.split()
#     print(words)
#     break

In [22]:
# from tqdm.auto import tqdm

# # Combine everything into one to make vocab
# word_list = list(set(" ".join(text).split()))
# word2id = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}  # special tokens

# # Limit the word list to fit vocab_size
# # word_list = word_list[:23068 - 4]  # Reserving space for the 4 special tokens

# # Create the word2id in a single pass
# for i, w in tqdm(enumerate(word_list), desc="Creating word2id"):
#     word2id[w] = i + 4  # because 0-3 are already occupied

# # Precompute the id2word mapping (this can be done once after word2id is fully populated)
# id2word = {v: k for k, v in word2id.items()}
# vocab_size = len(word2id)
# vocab_size

In [23]:
# # vocab_size = len(word2id)
# vocab_size = 23068

# # List of all tokens for the whole text
# token_list = []

# # Process sentences more efficiently
# for sentence in tqdm(text, desc="Processing sentences"):
#     token_list.append([word2id[word] for word in sentence.split()])

# # Now token_list contains the tokenized sentences

In [24]:
# sentences[:2]

In [25]:
# token_list[:2]

In [26]:
# for tokens in token_list[0]:
#     print(id2word[tokens])

In [27]:
# tokenized_datasets = token_list

## 3. Data loader

In [28]:
from torch.utils.data import DataLoader

# initialize the dataloader
batch_size = 6
train_dataloader = DataLoader(
    tokenized_datasets['train'], 
    batch_size=batch_size, 
    shuffle=True
)
eval_dataloader = DataLoader(
    tokenized_datasets['validation'], 
    batch_size=batch_size
)
test_dataloader = DataLoader(
    tokenized_datasets['test'], 
    batch_size=batch_size
)

In [29]:
for batch in train_dataloader:
    print(batch['premise_input_ids'].shape)
    print(batch['premise_attention_mask'].shape)
    print(batch['hypothesis_input_ids'].shape)
    print(batch['hypothesis_attention_mask'].shape)
    print(batch['labels'].shape)
    break

# the shape is [32, 128] due to batch_size is 32 and max_seq_length is 128 from setting

torch.Size([6, 128])
torch.Size([6, 128])
torch.Size([6, 128])
torch.Size([6, 128])
torch.Size([6])


In [30]:
# batch_size = 6
# max_mask   = 5  # max masked tokens when 15% exceed, it will only be max_pred
# max_len    = 1000 # maximum of length to be padded; 

In [31]:
# def make_batch():
#     batch = []
#     positive = negative = 0  #count of batch size;  we want to have half batch that are positive pairs (i.e., next sentence pairs)
#     while positive != batch_size/2 or negative != batch_size/2:
        
#         #randomly choose two sentence so we can put [SEP]
#         tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))
#         #retrieve the two sentences
#         tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]

#         #1. token embedding - append CLS and SEP
#         input_ids = [word2id['[CLS]']] + tokens_a + [word2id['[SEP]']] + tokens_b + [word2id['[SEP]']]

#         #2. segment embedding - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
#         segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

#         #3. mask language modeling
#         #masked 15%, but should be at least 1 but does not exceed max_mask
#         n_pred =  min(max_mask, max(1, int(round(len(input_ids) * 0.15))))
#         #get the pos that excludes CLS and SEP and shuffle them
#         cand_maked_pos = [i for i, token in enumerate(input_ids) if token != word2id['[CLS]'] and token != word2id['[SEP]']]
#         shuffle(cand_maked_pos)
#         masked_tokens, masked_pos = [], []
#         #simply loop and change the input_ids to [MASK]
#         for pos in cand_maked_pos[:n_pred]:
#             masked_pos.append(pos)  #remember the position
#             masked_tokens.append(input_ids[pos]) #remember the tokens
#             #80% replace with a [MASK], but 10% will replace with a random token
#             if random() < 0.1:  # 10%
#                 index = randint(0, vocab_size - 1) # random index in vocabulary
#                 input_ids[pos] = word2id[id2word[index]] # replace
#             elif random() < 0.9:  # 80%
#                 input_ids[pos] = word2id['[MASK]'] # make mask
#             else:  #10% do nothing
#                 pass

#         # pad the input_ids and segment ids until the max len
#         n_pad = max_len - len(input_ids)
#         input_ids.extend([0] * n_pad)
#         segment_ids.extend([0] * n_pad)

#         # pad the masked_tokens and masked_pos to make sure the lenth is max_mask
#         if max_mask > n_pred:
#             n_pad = max_mask - n_pred
#             masked_tokens.extend([0] * n_pad)
#             masked_pos.extend([0] * n_pad)

#         #check if first sentence is really comes before the second sentence
#         #also make sure positive is exactly half the batch size
#         if tokens_a_index + 1 == tokens_b_index and positive < batch_size / 2:
#             batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext
#             positive += 1
#         elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
#             batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext
#             negative += 1
            
#     return batch

In [32]:
# batchs = make_batch()

In [34]:
# # #we can deconstruct using map and zip
# input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batchs))
# input_ids.shape, segment_ids.shape, masked_tokens.shape, masked_pos.shape, isNext.shape

In [35]:
# train_dataloader

## 4. Model

In [36]:
# # start from a pretrained bert-base-uncased model
# from transformers import BertTokenizer, BertModel
# model = BertModel.from_pretrained('bert-base-uncased')
# model.to(device)

In [37]:
# from transformers import BertTokenizer, BertModel
# model = BertModel.from_pretrained('bert-base-uncased')
# model.load_state_dict(torch.load("bert_model.pth"),strict=False)
# model.to(device)

In [38]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, max_len, n_segments, d_model, device):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(max_len, d_model)      # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)
        self.device = device

    def forward(self, x, seg):
        #x, seg: (bs, len)
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long).to(self.device)
        pos = pos.unsqueeze(0).expand_as(x)  # (len,) -> (bs, len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

In [39]:
def get_attn_pad_mask(seq_q, seq_k, device):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1).to(device)  # batch_size x 1 x len_k(=len_q), one is masking
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k

In [40]:
# print(get_attn_pad_mask(input_ids, input_ids, device).shape)

In [41]:
class EncoderLayer(nn.Module):
    def __init__(self, n_heads, d_model, d_ff, d_k, device):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention(n_heads, d_model, d_k, device)
        self.pos_ffn       = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]
        return enc_outputs, attn

In [42]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k, device):
        super(ScaledDotProductAttention, self).__init__()
        self.scale = torch.sqrt(torch.FloatTensor([d_k])).to(device)

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / self.scale # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn 

In [43]:
n_layers = 6    # number of Encoder of Encoder Layer
n_heads  = 8    # number of heads in Multi-Head Attention
d_model  = 768  # Embedding Size
d_ff = 768 * 4  # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 2

In [44]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, d_k, device):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_k
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, self.d_v * n_heads)
        self.device = device
    def forward(self, Q, K, V, attn_mask):
        # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]

        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        context, attn = ScaledDotProductAttention(self.d_k, self.device)(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v) # context: [batch_size x len_q x n_heads * d_v]
        output = nn.Linear(self.n_heads * self.d_v, self.d_model, device=self.device)(context)
        return nn.LayerNorm(self.d_model, device=self.device)(output + residual), attn # output: [batch_size x len_q x d_model]

In [45]:
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)
        return self.fc2(F.gelu(self.fc1(x)))

In [46]:
class BERT(nn.Module):
    def __init__(self, n_layers, n_heads, d_model, d_ff, d_k, n_segments, vocab_size, max_len, device):
        super(BERT, self).__init__()
        self.params = {'n_layers': n_layers, 'n_heads': n_heads, 'd_model': d_model,
                       'd_ff': d_ff, 'd_k': d_k, 'n_segments': n_segments,
                       'vocab_size': vocab_size, 'max_len': max_len}
        self.embedding = Embedding(vocab_size, max_len, n_segments, d_model, device)
        self.layers = nn.ModuleList([EncoderLayer(n_heads, d_model, d_ff, d_k, device) for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        # decoder is shared with embedding layer
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
        self.device = device

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.device)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)
        # output : [batch_size, len, d_model], attn : [batch_size, n_heads, d_mode, d_model]
        
        # 1. predict next sentence
        # it will be decided by first token(CLS)
        h_pooled   = self.activ(self.fc(output[:, 0])) # [batch_size, d_model]
        logits_nsp = self.classifier(h_pooled) # [batch_size, 2]

        # 2. predict the masked token
        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1)) # [batch_size, max_pred, d_model]
        h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model]
        h_masked  = self.norm(F.gelu(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias # [batch_size, max_pred, n_vocab]

        return logits_lm, logits_nsp, output
    
    def get_last_hidden_state(self, input_ids, segment_ids):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.device)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)

        return output

Training

In [47]:
from tqdm.auto import tqdm

n_layers = 6    # number of Encoder of Encoder Layer
n_heads  = 12    # number of heads in Multi-Head Attention
d_model  = 768  # Embedding Size 
d_ff = d_model * 4  # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 2
vocab_size = 6965
max_len = 128

num_epoch = 100
model = BERT(
    n_layers, 
    n_heads, 
    d_model, 
    d_ff, 
    d_k, 
    n_segments, 
    vocab_size, 
    max_len, 
    device
).to(device)  # Move model to GPU

In [48]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [49]:
# checkpoint = torch.load("bert_model.pth")
# for name, param in checkpoint.items():
#     print(f"{name}: {param.shape}")

# # Print the shapes of parameters in the current model
# for name, param in model.named_parameters():
#     print(f"{name}: {param.shape}")

In [None]:
# Load the model's state dict (weights)
pretrained_weights = torch.load('models/bert_model.pth')

# Load pre-trained weights into your custom model
model_dict = model.state_dict()

# Filter out the weights that do not match (if any)
pretrained_weights = {k: v for k, v in pretrained_weights.items() if k in model_dict and v.size() == model_dict[k].size()}

# Update the model's state_dict
model_dict.update(pretrained_weights)
model.load_state_dict(model_dict)

# Move the model to the appropriate device (CPU or GPU)
model.to(device)

BERT(
  (embedding): Embedding(
    (tok_embed): Embedding(6965, 768)
    (pos_embed): Embedding(128, 768)
    (seg_embed): Embedding(2, 768)
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (layers): ModuleList(
    (0-5): 6 x EncoderLayer(
      (enc_self_attn): MultiHeadAttention(
        (W_Q): Linear(in_features=768, out_features=768, bias=True)
        (W_K): Linear(in_features=768, out_features=768, bias=True)
        (W_V): Linear(in_features=768, out_features=768, bias=True)
      )
      (pos_ffn): PoswiseFeedForwardNet(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
      )
    )
  )
  (fc): Linear(in_features=768, out_features=768, bias=True)
  (activ): Tanh()
  (linear): Linear(in_features=768, out_features=768, bias=True)
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
  (decode

### Pooling
SBERT adds a pooling operation to the output of BERT / RoBERTa to derive a fixed sized sentence embedding

In [53]:
# define mean pooling function
def mean_pool(token_embeds, attention_mask):
    # reshape attention_mask to cover 768-dimension embeddings
    in_mask = attention_mask.unsqueeze(-1).expand(
        token_embeds.size()
    ).float()
    # perform mean-pooling but exclude padding tokens (specified by in_mask)
    pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(
        in_mask.sum(1), min=1e-9
    )
    return pool

## 5. Loss Function

## Classification Objective Function 
We concatenate the sentence embeddings $u$ and $v$ with the element-wise difference  $\lvert u - v \rvert $ and multiply the result with the trainable weight  $ W_t ∈  \mathbb{R}^{3n \times k}  $:

$ o = \text{softmax}\left(W^T \cdot \left(u, v, \lvert u - v \rvert\right)\right) $

where $n$ is the dimension of the sentence embeddings and k the number of labels. We optimize cross-entropy loss. This structure is depicted in Figure 1.

## Regression Objective Function. 
The cosine similarity between the two sentence embeddings $u$ and $v$ is computed (Figure 2). We use means quared-error loss as the objective function.

(Manhatten / Euclidean distance, semantically  similar sentences can be found.)

<img src="./figures/sbert-architecture.png" >

In [55]:
def configurations(u,v):
    # build the |u-v| tensor
    uv = torch.sub(u, v)   # batch_size,hidden_dim
    uv_abs = torch.abs(uv) # batch_size,hidden_dim
    
    # concatenate u, v, |u-v|
    x = torch.cat([u, v, uv_abs], dim=-1) # batch_size, 3*hidden_dim
    return x

def cosine_similarity(u, v):
    dot_product = np.dot(u, v)
    norm_u = np.linalg.norm(u)
    norm_v = np.linalg.norm(v)
    similarity = dot_product / (norm_u * norm_v)
    return similarity

<img src="./figures/sbert-ablation.png" width="350" height="300">

In [56]:
classifier_head = torch.nn.Linear(768*3, 3).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
optimizer_classifier = torch.optim.Adam(classifier_head.parameters(), lr=2e-5)

criterion = nn.CrossEntropyLoss()

In [57]:
from transformers import get_linear_schedule_with_warmup

# and setup a warmup for the first ~10% steps
total_steps = int(len(raw_dataset) / batch_size)
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
		optimizer, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

# then during the training loop we update the scheduler per step
scheduler.step()

scheduler_classifier = get_linear_schedule_with_warmup(
		optimizer_classifier, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

# then during the training loop we update the scheduler per step
scheduler_classifier.step()



## 6. Training

In [None]:
from tqdm.auto import tqdm

num_epoch = 5
max_mask   = 5
# need segment and masked for model input but not used in SBERT

accuracy = 0
count = 0
# 1 epoch should be enough, increase if wanted
for epoch in range(num_epoch):
    model.train()
    classifier_head.train()
    # initialize the dataloader loop with tqdm (tqdm == progress bar)
    for step, batch in enumerate(tqdm(train_dataloader, leave=True)):

        # zero all gradients on each new step
        optimizer.zero_grad()
        optimizer_classifier.zero_grad()

        # prepare batches and more all to the active device
        inputs_ids_a = batch['premise_input_ids'].to(device)
        inputs_ids_b = batch['hypothesis_input_ids'].to(device)
        attention_a = batch['premise_attention_mask'].to(device)
        attention_b = batch['hypothesis_attention_mask'].to(device)
        label = batch['labels'].to(device)

        segment_ids = torch.tensor([0] * max_len).unsqueeze(0).repeat(inputs_ids_a.shape[0], 1).to(device)
        masked_pos  = torch.tensor([0] * max_mask).unsqueeze(0).repeat(inputs_ids_a.shape[0], 1).to(device)

        # extract token embeddings from BERT at last_hidden_state
        _,_,u = model(inputs_ids_a, segment_ids, masked_pos)
        _,_,v = model(inputs_ids_b, segment_ids, masked_pos)


        u_last_hidden_state = u # all token embeddings A = batch_size, seq_len, hidden_dim
        v_last_hidden_state = v # all token embeddings B = batch_size, seq_len, hidden_dim

        # print("u_last_hidden_state shape:", u_last_hidden_state.shape)
        # print("v_last_hidden_state shape:", v_last_hidden_state.shape)

        # print("attention_a shape:", attention_a.shape)
        # print("attention_b shape:", attention_b.shape)

        # u_last_hidden_state = u.get_last_hidden_state(inputs_ids_a, segment_ids) # all token embeddings A = batch_size, seq_len, hidden_dim
        # v_last_hidden_state = v.get_last_hidden_state(inputs_ids_a, segment_ids) # all token embeddings B = batch_size, seq_len, hidden_dim
        
        # Extract token embeddings directly from the model output (without using get_last_hidden_state).
        # u_last_hidden_state, u_logits_nsp = model(inputs_ids_a, segment_ids, masked_pos)
        # v_last_hidden_state, v_logits_nsp = model(inputs_ids_b, segment_ids, masked_pos)
        
         # get the mean pooled vectors
        u_mean_pool = mean_pool(u_last_hidden_state, attention_a) # batch_size, hidden_dim
        v_mean_pool = mean_pool(v_last_hidden_state, attention_b) # batch_size, hidden_dim

        # build the |u-v| tensor
        uv = torch.sub(u_mean_pool, v_mean_pool)   # batch_size,hidden_dim
        uv_abs = torch.abs(uv) # batch_size,hidden_dim

        # concatenate u, v, |u-v|
        x = torch.cat([u_mean_pool, v_mean_pool, uv_abs], dim=-1) # batch_size, 3*hidden_dim

        # process concatenated tensor through classifier_head
        x = classifier_head(x) #batch_size, classifer
        for out, lab in zip(x, label):
            count = count + 1
            if torch.argmax(out).item() == lab.item():
                accuracy = accuracy + 1
        # calculate the 'softmax-loss' between predicted and true label
        loss = criterion(x, label)

        # using loss, calculate gradients and then optimizerize
        loss.backward()
        optimizer.step()
        optimizer_classifier.step()

        scheduler.step() # update learning rate scheduler
        scheduler_classifier.step()

    print(f'Epoch: {epoch + 1} | loss = {loss.item():.6f} | Accuracy = {(accuracy / count) * 100}%')


  0%|          | 0/167 [00:00<?, ?it/s]

Epoch: 1 | loss = 1.457031 | Accuracy = 35.3%


  0%|          | 0/167 [00:00<?, ?it/s]

Epoch: 2 | loss = 0.915018 | Accuracy = 35.3%


  0%|          | 0/167 [00:00<?, ?it/s]

Epoch: 3 | loss = 1.451059 | Accuracy = 35.3%


  0%|          | 0/167 [00:00<?, ?it/s]

Epoch: 4 | loss = 2.268982 | Accuracy = 35.3%


  0%|          | 0/167 [00:00<?, ?it/s]

Epoch: 5 | loss = 1.503545 | Accuracy = 35.3%


In [None]:
# Save the model after training
torch.save(model.state_dict(), 's-bert_model.pt')
print("Model saved to s-bert_model.pt")

Model saved to s-bert_model.pt


In [87]:
# model.eval()
# classifier_head.eval()
# total_similarity = 0
# with torch.no_grad():
#     for step, batch in enumerate(eval_dataloader):
#         # prepare batches and more all to the active device
#         inputs_ids_a = batch['premise_input_ids'].to(device)
#         inputs_ids_b = batch['hypothesis_input_ids'].to(device)
#         attention_a = batch['premise_attention_mask'].to(device)
#         attention_b = batch['hypothesis_attention_mask'].to(device)
#         label = batch['labels'].to(device)

#         segment_ids_a = torch.zeros(inputs_ids_a.size(0), inputs_ids_a.size(1), dtype=torch.long).to(device)  # All 0s for sentence A
#         segment_ids_b = torch.ones(inputs_ids_b.size(0), inputs_ids_b.size(1), dtype=torch.long).to(device)  # All 1s for sentence B

#         # extract token embeddings from BERT at last_hidden_state
#         _,_,u = model(inputs_ids_a, segment_ids_a, masked_pos)
#         _,_,v = model(inputs_ids_b, segment_ids_b, masked_pos)

#         # get the mean pooled vectors
#         u_mean_pool = mean_pool(u, attention_a).detach().cpu().numpy().reshape(u.size(0), -1)  # Reshape to (batch_size, hidden_dim)
#         v_mean_pool = mean_pool(v, attention_b).detach().cpu().numpy().reshape(v.size(0), -1)  # Reshape to (batch_size, hidden_dim)


#         similarity_score = cosine_similarity(u_mean_pool, v_mean_pool)
#         total_similarity += similarity_score
    
# average_similarity = total_similarity / len(eval_dataloader)
# print(f"Average Cosine Similarity: {average_similarity:.4f}")

In [86]:
model.eval()
classifier_head.eval()
total_similarity = 0
total_count = 0  # Track the number of comparisons

with torch.no_grad():
    for step, batch in enumerate(eval_dataloader):
        # prepare batches and move all to the active device
        inputs_ids_a = batch['premise_input_ids'].to(device)
        inputs_ids_b = batch['hypothesis_input_ids'].to(device)
        attention_a = batch['premise_attention_mask'].to(device)
        attention_b = batch['hypothesis_attention_mask'].to(device)
        label = batch['labels'].to(device)

        # Create segment ids for sentence A and sentence B
        segment_ids_a = torch.zeros(inputs_ids_a.size(0), inputs_ids_a.size(1), dtype=torch.long).to(device)  # All 0s for sentence A
        segment_ids_b = torch.ones(inputs_ids_b.size(0), inputs_ids_b.size(1), dtype=torch.long).to(device)  # All 1s for sentence B

        # Extract token embeddings from BERT at last_hidden_state
        _, _, u = model(inputs_ids_a, segment_ids_a, masked_pos)
        _, _, v = model(inputs_ids_b, segment_ids_b, masked_pos)

        # Get the mean pooled vectors and reshape them to 2D arrays
        u_mean_pool = mean_pool(u, attention_a).detach().cpu().numpy().reshape(u.size(0), -1)  # (batch_size, hidden_dim)
        v_mean_pool = mean_pool(v, attention_b).detach().cpu().numpy().reshape(v.size(0), -1)  # (batch_size, hidden_dim)

        # Check if the batch sizes match
        if u_mean_pool.shape[0] != v_mean_pool.shape[0]:
            print(f"Batch size mismatch: {u_mean_pool.shape[0]} vs {v_mean_pool.shape[0]}")
            continue  # Skip this batch if the batch sizes do not match

        # Compute the cosine similarity for each pair in the batch
        similarity_score = cosine_similarity(u_mean_pool, v_mean_pool).mean()  # Mean similarity for the batch
        total_similarity += similarity_score
        total_count += 1

# Compute the average similarity across all batches
average_similarity = total_similarity / total_count if total_count > 0 else 0
print(f"Average Cosine Similarity: {average_similarity:.4f}")


Average Cosine Similarity: 0.9933


## 7. Inference

In [80]:
import torch
from sklearn.metrics.pairwise import cosine_similarity

def calculate_similarity(model, tokenizer, sentence_a, sentence_b, device):
    # Tokenize and convert sentences to input IDs and attention masks
    inputs_a = tokenizer.encode([sentence_a])
    inputs_b = tokenizer.encode([sentence_b])

    # Move input IDs and attention masks to the active device
    inputs_ids_a = inputs_a['input_ids'][0].unsqueeze(0).to(device)
    attention_a = inputs_a['attention_mask'][0].unsqueeze(0).to(device)
    inputs_ids_b = inputs_b['input_ids'][0].unsqueeze(0).to(device)
    attention_b = inputs_b['attention_mask'][0].unsqueeze(0).to(device)

    # Extract token embeddings from BERT
    u = model(inputs_ids_a, segment_ids, masked_pos)[2]
    v = model(inputs_ids_b, segment_ids, masked_pos)[2]

    # Get the mean-pooled vectors
    u = mean_pool(u, attention_a).detach().cpu().numpy().reshape(-1)  # batch_size, hidden_dim
    v = mean_pool(v, attention_b).detach().cpu().numpy().reshape(-1)  # batch_size, hidden_dim

    # Calculate cosine similarity
    similarity_score = cosine_similarity(u.reshape(1, -1), v.reshape(1, -1))[0, 0]

    return similarity_score

# Example usage:
sentence_a = 'Your contribution helped make it possible for us to provide our students with a quality education.'
sentence_b = "Your contributions were of no help with our students' education."
similarity = calculate_similarity(model, tokenizer, sentence_a, sentence_b, device)
print(f"Cosine Similarity: {similarity:.4f}")

Cosine Similarity: 0.9914
