# [Sentence-BERT](https://arxiv.org/pdf/1908.10084.pdf)

[Reference Code](https://www.pinecone.io/learn/series/nlp/train-sentence-transformers-softmax/)

In [1]:
import os
import math
import re
from   random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Set GPU device
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

os.environ['http_proxy']  = 'http://192.41.170.23:3128'
os.environ['https_proxy'] = 'http://192.41.170.23:3128'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## 1. Data

### Train, Test, Validation 

In [2]:
import datasets

# Load SNLI and MNLI datasets
snli = datasets.load_dataset('snli')
mnli = datasets.load_dataset('glue', 'mnli')

# View features of the train split for both datasets
mnli['train'].features, snli['train'].features

({'premise': Value(dtype='string', id=None),
  'hypothesis': Value(dtype='string', id=None),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None),
  'idx': Value(dtype='int32', id=None)},
 {'premise': Value(dtype='string', id=None),
  'hypothesis': Value(dtype='string', id=None),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)})

In [3]:
# List of datasets to remove 'idx' column from
mnli.column_names.keys()

dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])

In [4]:
# Remove 'idx' column from each dataset
for column_names in mnli.column_names.keys():
    mnli[column_names] = mnli[column_names].remove_columns('idx')

In [5]:
mnli.column_names.keys()

dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])

In [6]:
import numpy as np
np.unique(mnli['train']['label']), np.unique(snli['train']['label'])
#snli also have -1

(array([0, 1, 2]), array([-1,  0,  1,  2]))

In [7]:
# there are -1 values in the label feature, these are where no class could be decided so we remove
snli = snli.filter(
    lambda x: 0 if x['label'] == -1 else 1
)

In [8]:
import numpy as np
np.unique(mnli['train']['label']), np.unique(snli['train']['label'])
#snli also have -1

(array([0, 1, 2]), array([0, 1, 2]))

In [9]:
# Assuming you have your two DatasetDict objects named snli and mnli
from datasets import DatasetDict
# Merge the two DatasetDict objects
raw_dataset = DatasetDict({
    'train': datasets.concatenate_datasets([snli['train'], mnli['train']]).shuffle(seed=55).select(list(range(1000))),
    'test': datasets.concatenate_datasets([snli['test'], mnli['test_mismatched']]).shuffle(seed=55).select(list(range(100))),
    'validation': datasets.concatenate_datasets([snli['validation'], mnli['validation_mismatched']]).shuffle(seed=55).select(list(range(1000)))
})
# Remove .select(list(range(1000))) in order to use full dataset
# Now, merged_dataset_dict contains the combined datasets from snli and mnli
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 100
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1000
    })
})

## 2. Preprocessing

In [10]:
from transformers import BertTokenizer

# Load pre-trained BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [11]:
def preprocess_function(examples):
    max_seq_length = 128
    padding = 'max_length'
    # Tokenize the premise
    premise_result = tokenizer(
        examples['premise'], padding=padding, max_length=max_seq_length, truncation=True)
    #num_rows, max_seq_length
    # Tokenize the hypothesis
    hypothesis_result = tokenizer(
        examples['hypothesis'], padding=padding, max_length=max_seq_length, truncation=True)
    #num_rows, max_seq_length
    # Extract labels
    labels = examples["label"]
    #num_rows
    return {
        "premise_input_ids": premise_result["input_ids"],
        "premise_attention_mask": premise_result["attention_mask"],
        "hypothesis_input_ids": hypothesis_result["input_ids"],
        "hypothesis_attention_mask": hypothesis_result["attention_mask"],
        "labels" : labels
    }

tokenized_datasets = raw_dataset.map(
    preprocess_function,
    batched=True,
)

tokenized_datasets = tokenized_datasets.remove_columns(['premise','hypothesis','label'])
tokenized_datasets.set_format("torch")

In [12]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 100
    })
    validation: Dataset({
        features: ['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 1000
    })
})

## 3. Data loader

In [13]:
from torch.utils.data import DataLoader

# initialize the dataloader
batch_size = 16
train_dataloader = DataLoader(
    tokenized_datasets['train'], 
    batch_size=batch_size, 
    shuffle=True
)
eval_dataloader = DataLoader(
    tokenized_datasets['validation'], 
    batch_size=batch_size
)
test_dataloader = DataLoader(
    tokenized_datasets['test'], 
    batch_size=batch_size
)

In [14]:
for batch in train_dataloader:
    print(batch['premise_input_ids'].shape)
    print(batch['premise_attention_mask'].shape)
    print(batch['hypothesis_input_ids'].shape)
    print(batch['hypothesis_attention_mask'].shape)
    print(batch['labels'].shape)
    break

torch.Size([16, 128])
torch.Size([16, 128])
torch.Size([16, 128])
torch.Size([16, 128])
torch.Size([16])


## 4. Model

In [16]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, max_len, n_segments, d_model, device):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.seg_embed = nn.Embedding(n_segments, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.device = device

    def forward(self, x, seg):
        
        # Check for invalid indices in x (input_ids)
        if (x < 0).any() or (x >= self.tok_embed.num_embeddings).any():
            raise ValueError("input_ids contains invalid indices.")
        
        # Check for invalid indices in seg (segment_ids)
        if (seg < 0).any() or (seg >= self.seg_embed.num_embeddings).any():
            raise ValueError("segment_ids contains invalid indices.")

        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long).to(self.device)
        pos = pos.unsqueeze(0).expand_as(x)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

In [17]:
def get_attn_pad_mask(seq_q, seq_k, device, attention_mask=None):
    if attention_mask is not None:
        # Use the provided attention_mask and convert to boolean
        pad_attn_mask = attention_mask.unsqueeze(1).expand(-1, seq_q.size(1), -1).to(device)
        pad_attn_mask = ~pad_attn_mask.bool()  # Invert to match PyTorch's convention
    else:
        # Fallback to the original behavior: mask where seq_k is 0 (padding)
        pad_attn_mask = seq_k.data.eq(0).unsqueeze(1).to(device)
        pad_attn_mask = pad_attn_mask.expand(seq_q.size(0), seq_q.size(1), seq_k.size(1))
    return pad_attn_mask

In [18]:
class EncoderLayer(nn.Module):
    def __init__(self, n_heads, d_model, d_ff, d_k, device):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention(n_heads, d_model, d_k, device)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn

In [19]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k, device):
        super(ScaledDotProductAttention, self).__init__()
        self.scale = torch.sqrt(torch.FloatTensor([d_k])).to(device)

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / self.scale
        scores.masked_fill_(attn_mask, -1e9)  # attn_mask must be boolean
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn

In [20]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, d_k, device):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_k
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, self.d_v * n_heads)
        self.device = device

    def forward(self, Q, K, V, attn_mask):
        residual, batch_size = Q, Q.size(0)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1,2)
        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)
        context, attn = ScaledDotProductAttention(self.d_k, self.device)(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)
        output = nn.Linear(self.n_heads * self.d_v, self.d_model, device=self.device)(context)
        return nn.LayerNorm(self.d_model, device=self.device)(output + residual), attn

In [21]:
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))

In [22]:
class BERT(nn.Module):
    def __init__(self, n_layers, n_heads, d_model, d_ff, d_k, n_segments, vocab_size, max_len, device):
        super(BERT, self).__init__()
        self.params = {
            'n_layers': n_layers, 'n_heads': n_heads, 'd_model': d_model,
            'd_ff': d_ff, 'd_k': d_k, 'n_segments': n_segments,
            'vocab_size': vocab_size, 'max_len': max_len
        }
        self.embedding = Embedding(vocab_size, max_len, n_segments, d_model, device)
        self.layers = nn.ModuleList([EncoderLayer(n_heads, d_model, d_ff, d_k, device) for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
        self.device = device

    def forward(self, input_ids, segment_ids=None, masked_pos=None, attention_mask=None):
        # If segment_ids is not provided, create a default one (all zeros)
        if segment_ids is None:
            segment_ids = torch.zeros_like(input_ids).to(self.device)
        
        # If attention_mask is not provided, create a default one
        if attention_mask is None:
            attention_mask = (input_ids != 0).float().to(self.device)
        
        # Get embeddings
        output = self.embedding(input_ids, segment_ids)
        
        # Create attention mask using get_attn_pad_mask
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.device, attention_mask)
        
        # Pass through encoder layers
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)
        
        # Next Sentence Prediction (NSP)
        h_pooled = self.activ(self.fc(output[:, 0]))  # Use the [CLS] token
        logits_nsp = self.classifier(h_pooled)
        
        # Masked Language Modeling (MLM)
        if masked_pos is not None:
            masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))
            h_masked = torch.gather(output, 1, masked_pos)
            h_masked = self.norm(F.gelu(self.linear(h_masked)))
            logits_lm = self.decoder(h_masked) + self.decoder_bias
        else:
            logits_lm = None
        
        # Return last hidden state along with logits
        return output, logits_lm, logits_nsp

    def get_last_hidden_state(self, input_ids, segment_ids=None, attention_mask=None):
        # If attention_mask is not provided, create a default one
        if attention_mask is None:
            attention_mask = (input_ids != 0).float().to(self.device)
        
        # Get embeddings
        output = self.embedding(input_ids, segment_ids)
        
        # Create attention mask using get_attn_pad_mask
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.device, attention_mask)
        
        # Pass through encoder layers
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)
        
        return output

In [None]:
# Initialize and load the BERT model from Task 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Inspect the saved model's vocabulary size
state_dict = torch.load('bert_model.pth', map_location=device)
vocab_size_saved = state_dict['embedding.tok_embed.weight'].shape[0]
print("Vocabulary size of saved model:", vocab_size_saved)

# Use the saved model's vocabulary size
vocab_size = vocab_size_saved  # Match the saved model's vocabulary size
model = BERT(
    n_layers=12,
    n_heads=12,
    d_model=768,
    d_ff=3072,
    d_k=64,
    n_segments=2,
    vocab_size=vocab_size,
    max_len=1000,
    device=device
).to(device)

# Print embedding vocabulary size
print("Embedding vocab size:", model.embedding.tok_embed.num_embeddings)

# Load the model weights
model.load_state_dict(state_dict)

# Debug: Check tokenizer vocabulary size
print("Tokenizer vocab size:", tokenizer.vocab_size)

Vocabulary size of saved model: 23068
Embedding vocab size: 23068
Tokenizer vocab size: 30522


### Pooling
SBERT adds a pooling operation to the output of BERT / RoBERTa to derive a fixed sized sentence embedding

In [24]:
# define mean pooling function
def mean_pool(token_embeds, attention_mask):
    # reshape attention_mask to cover 768-dimension embeddings
    in_mask = attention_mask.unsqueeze(-1).expand(
        token_embeds.size()
    ).float()
    # perform mean-pooling but exclude padding tokens (specified by in_mask)
    pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(
        in_mask.sum(1), min=1e-9
    )
    return pool

## 5. Loss Function

## Classification Objective Function 
We concatenate the sentence embeddings $u$ and $v$ with the element-wise difference  $\lvert u - v \rvert $ and multiply the result with the trainable weight  $ W_t ∈  \mathbb{R}^{3n \times k}  $:

$ o = \text{softmax}\left(W^T \cdot \left(u, v, \lvert u - v \rvert\right)\right) $

where $n$ is the dimension of the sentence embeddings and k the number of labels. We optimize cross-entropy loss. This structure is depicted in Figure 1.

## Regression Objective Function. 
The cosine similarity between the two sentence embeddings $u$ and $v$ is computed (Figure 2). We use means quared-error loss as the objective function.

(Manhatten / Euclidean distance, semantically  similar sentences can be found.)

In [None]:
def configurations(u,v):
    # build the |u-v| tensor
    uv = torch.sub(u, v)   # batch_size,hidden_dim
    uv_abs = torch.abs(uv) # batch_size,hidden_dim
    
    # concatenate u, v, |u-v|
    x = torch.cat([u, v, uv_abs], dim=-1) # batch_size, 3*hidden_dim
    return x

def cosine_similarity(u, v):
    dot_product = torch.sum(u * v, dim=-1)
    norm_u = torch.norm(u, dim=-1)
    norm_v = torch.norm(v, dim=-1)
    similarity = dot_product / (norm_u * norm_v + 1e-9)
    return similarity

In [None]:
classifier_head = torch.nn.Sequential(
    torch.nn.Linear(768*3, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 3)
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
optimizer_classifier = torch.optim.Adam(classifier_head.parameters(), lr=5e-5)

criterion = nn.CrossEntropyLoss()

In [None]:
from transformers import get_linear_schedule_with_warmup

# Setup a warmup for the first ~10% steps
total_steps = int(len(raw_dataset) / batch_size)
warmup_steps = int(0.1 * total_steps)

# Scheduler for the main model optimizer with warmup
scheduler = get_linear_schedule_with_warmup(
		optimizer, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

# Scheduler for the classifier head optimizer with warmup
scheduler_classifier = get_linear_schedule_with_warmup(
		optimizer_classifier, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

## 6. Training

In [40]:
from tqdm.auto import tqdm

num_epoch = 5
for epoch in range(num_epoch):
    model.train()  
    classifier_head.train()
    for step, batch in enumerate(tqdm(train_dataloader, leave=True)):
        optimizer.zero_grad()
        optimizer_classifier.zero_grad()
        
        inputs_ids_a = batch['premise_input_ids'].to(device)
        inputs_ids_b = batch['hypothesis_input_ids'].to(device)
        attention_a = batch['premise_attention_mask'].to(device)
        attention_b = batch['hypothesis_attention_mask'].to(device)
        label = batch['labels'].to(device)
        
        inputs_ids_a = torch.clamp(inputs_ids_a, max=vocab_size - 1)
        inputs_ids_b = torch.clamp(inputs_ids_b, max=vocab_size - 1)
        
        u_last_hidden_state, _, _ = model(inputs_ids_a, attention_mask=attention_a)  
        v_last_hidden_state, _, _ = model(inputs_ids_b, attention_mask=attention_b)  

        u_mean_pool = mean_pool(u_last_hidden_state, attention_a) 
        v_mean_pool = mean_pool(v_last_hidden_state, attention_b) 
        
        x = torch.cat([u_mean_pool, v_mean_pool, torch.abs(u_mean_pool - v_mean_pool)], dim=-1) 
        
        x = classifier_head(x) 
        
        loss = criterion(x, label)
        
        loss.backward()
        optimizer.step()
        optimizer_classifier.step()

        scheduler.step() 
        scheduler_classifier.step()
        
    print(f'Epoch: {epoch + 1} | Loss: {loss.item():.6f}')

  0%|          | 0/63 [00:00<?, ?it/s]

Epoch: 1 | Loss: 1.489428


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch: 2 | Loss: 1.436978


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch: 3 | Loss: 1.223172


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch: 4 | Loss: 1.393602


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch: 5 | Loss: 1.720800


In [41]:
model.eval()
classifier_head.eval()
total_correct = 0
total_samples = 0
total_similarity = 0

with torch.no_grad():
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        inputs_ids_a = batch['premise_input_ids'].to(device)
        inputs_ids_b = batch['hypothesis_input_ids'].to(device)
        attention_a = batch['premise_attention_mask'].to(device)
        attention_b = batch['hypothesis_attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        inputs_ids_a = torch.clamp(inputs_ids_a, min=0, max=vocab_size - 1)
        inputs_ids_b = torch.clamp(inputs_ids_b, min=0, max=vocab_size - 1)
        
        u_output, _, _ = model(inputs_ids_a, attention_mask=attention_a)
        v_output, _, _ = model(inputs_ids_b, attention_mask=attention_b)
        
        u_pool = mean_pool(u_output, attention_a)
        v_pool = mean_pool(v_output, attention_b)
        
        x = torch.cat([u_pool, v_pool, torch.abs(u_pool - v_pool)], dim=-1)
        logits = classifier_head(x)
        preds = torch.argmax(logits, dim=-1)
        
        total_correct += (preds == labels).sum().item()
        total_samples += labels.size(0)
        
        # Calculate cosine similarity
        similarity_scores = cosine_similarity(u_pool, v_pool)
        total_similarity += similarity_scores.sum().item()

# Compute metrics
accuracy = total_correct / total_samples
average_similarity = total_similarity / total_samples
print(f"Validation Accuracy: {accuracy:.4f}")
print(f"Average Cosine Similarity: {average_similarity:.4f}")

Evaluating:   0%|          | 0/63 [00:00<?, ?it/s]

Validation Accuracy: 0.3420
Average Cosine Similarity: 0.9989


In [42]:
import os

# Create the 'models' directory if it doesn't exist
os.makedirs('models', exist_ok=True)

# Save the BERT model and classifier head
torch.save(model.state_dict(), 'models/best_model.pth')
torch.save(classifier_head.state_dict(), 'models/classifier_head.pth')

# Save the tokenizer
tokenizer.save_pretrained('models/tokenizer')

('models/tokenizer/tokenizer_config.json',
 'models/tokenizer/special_tokens_map.json',
 'models/tokenizer/vocab.txt',
 'models/tokenizer/added_tokens.json')

## 7. Inference

In [None]:
import torch
from sklearn.metrics.pairwise import cosine_similarity

def predict_label(model, classifier_head, tokenizer, premise, hypothesis, device):
    # Tokenize and convert sentences to input IDs and attention masks
    inputs_a = tokenizer(premise, return_tensors='pt', truncation=True, padding=True).to(device)
    inputs_b = tokenizer(hypothesis, return_tensors='pt', truncation=True, padding=True).to(device)

    # Move input IDs and attention masks to the active device
    inputs_ids_a = inputs_a['input_ids']
    attention_a = inputs_a['attention_mask']
    inputs_ids_b = inputs_b['input_ids']
    attention_b = inputs_b['attention_mask']

    # Extract token embeddings from BERT
    u = model(inputs_ids_a, attention_mask=attention_a)[0]  # all token embeddings A = batch_size, seq_len, hidden_dim
    v = model(inputs_ids_b, attention_mask=attention_b)[0]  # all token embeddings B = batch_size, seq_len, hidden_dim

    # Get the mean-pooled vectors
    u_pool = mean_pool(u, attention_a)  # batch_size, hidden_dim
    v_pool = mean_pool(v, attention_b)  # batch_size, hidden_dim

    # Concatenate embeddings for classification
    x = torch.cat([u_pool, v_pool, torch.abs(u_pool - v_pool)], dim=-1)

    # Predict label using the classifier head
    logits = classifier_head(x)
    pred_label = torch.argmax(logits, dim=-1).item()  # Get the predicted label index

    # Map label index to text
    label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}
    predicted_label_text = label_map[pred_label]

    return predicted_label_text

def calculate_similarity(model, tokenizer, premise, hypothesis, device):
    # Tokenize and convert sentences to input IDs and attention masks
    inputs_a = tokenizer(premise, return_tensors='pt', truncation=True, padding=True).to(device)
    inputs_b = tokenizer(hypothesis, return_tensors='pt', truncation=True, padding=True).to(device)

    # Move input IDs and attention masks to the active device
    inputs_ids_a = inputs_a['input_ids']
    attention_a = inputs_a['attention_mask']
    inputs_ids_b = inputs_b['input_ids']
    attention_b = inputs_b['attention_mask']

    # Extract token embeddings from BERT
    u = model(inputs_ids_a, attention_mask=attention_a)[0]  # all token embeddings A = batch_size, seq_len, hidden_dim
    v = model(inputs_ids_b, attention_mask=attention_b)[0]  # all token embeddings B = batch_size, seq_len, hidden_dim

    # Get the mean-pooled vectors
    u_pool = mean_pool(u, attention_a).detach().cpu().numpy().reshape(-1)  # batch_size, hidden_dim
    v_pool = mean_pool(v, attention_b).detach().cpu().numpy().reshape(-1)  # batch_size, hidden_dim

    # Calculate cosine similarity
    similarity_score = cosine_similarity(u_pool.reshape(1, -1), v_pool.reshape(1, -1))[0, 0]

    return similarity_score

def inference(model, classifier_head, tokenizer, premise, hypothesis, device):
    # Predict the label
    predicted_label = predict_label(model, classifier_head, tokenizer, premise, hypothesis, device)

    # Calculate cosine similarity
    similarity = calculate_similarity(model, tokenizer, premise, hypothesis, device)

    print(f"• Premise: {premise}")
    print(f"• Hypothesis: {hypothesis}")
    print(f"• Label: {predicted_label.capitalize()}")
    print(f"• Cosine Similarity: {similarity:.4f}")

# Example usage:
premise = 'A man is playing a guitar on stage.'
hypothesis = 'The man is performing music.'
inference(model, classifier_head, tokenizer, premise, hypothesis, device)

• Premise: A man is playing a guitar on stage.
• Hypothesis: The man is performing music.
• Label: Entailment
• Cosine Similarity: 0.9994
