# Task 3: Evaluation of Fine-Tuned Sentence-BERT

In [1]:

import torch
import torch.nn as nn
import re
import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from tqdm.auto import tqdm

# Select the compute device once (GPU if available, otherwise CPU).
# Later cells pass `device` to `.to(...)` / `map_location` for the model and batches.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu


In [2]:
# Load Task 1 artefacts: the checkpoint bundles the vocabulary and the model configuration.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.

ckpt = torch.load("artefacts/bert_mlm.pt", map_location="cpu")

config = ckpt["config"]      # contains architecture details
word2id = ckpt["word2id"]    # word -> index mapping
id2word = ckpt["id2word"]    # index -> word mapping

# Special-token IDs and sizes reused by the tokenization and model cells below
PAD_ID = word2id["[PAD]"]
UNK_ID = word2id["[UNK]"]
MAX_LEN = config["max_len"]     # maximum sequence length
H = config["d_model"]           # hidden size of BERT encoder

print("Loaded Task 1 config and vocabulary")
print("Max length:", MAX_LEN)
print("Hidden size:", H)



# load encoder-only model (exported from Task 1)

# this is the TorchScript encoder saved as bert_encoder.pt;
# it is loaded straight onto `device` and switched to eval mode
encoder = torch.jit.load("artefacts/bert_encoder.pt", map_location=device)
encoder.eval()

print("Encoder loaded successfully.")

Loaded Task 1 config and vocabulary
Max length: 1000
Hidden size: 256
Encoder loaded successfully.


### Tokenization

In [3]:
def clean_text(s: str) -> str:
    """
    Normalize raw text the same way Task 1 did, so tokens match the vocabulary.

    Steps: lowercase, delete the characters . , ! and -, collapse every run
    of whitespace into a single space, and trim both ends.
    """
    lowered = s.lower()
    without_punct = re.sub(r"[.,!\-]", "", lowered)
    single_spaced = re.sub(r"\s+", " ", without_punct)
    return single_spaced.strip()


def encode_sentence(sentence: str, max_len: int):
    """
    Turn a sentence into fixed-length model inputs.

    The sentence is cleaned with `clean_text`, split on whitespace, and each
    word is looked up in `word2id` (falling back to UNK_ID). The result is
    truncated to `max_len` tokens and right-padded with PAD_ID.

    Returns:
        (input_ids, attention_mask): two lists, where the mask holds 1 for
        real tokens and 0 for padding.
    """
    words = clean_text(sentence).split()

    # Map words to IDs first, then truncate to the length budget
    token_ids = [word2id.get(word, UNK_ID) for word in words][:max_len]
    n_real = len(token_ids)

    # Number of PAD positions needed to reach max_len (never negative)
    n_pad = max(max_len - n_real, 0)

    input_ids = token_ids + [PAD_ID] * n_pad
    attention_mask = [1] * n_real + [0] * n_pad
    return input_ids, attention_mask


In [4]:
def mean_pooling(token_embeddings, attention_mask):
    """
    Average the token embeddings of each sequence, ignoring PAD positions.

    `attention_mask` holds 1 for real tokens and 0 for PAD; it is broadcast
    over the embedding dimension so masked positions contribute nothing.
    """
    weights = attention_mask.unsqueeze(-1).float()               # [B,S,1]
    masked_total = torch.sum(token_embeddings * weights, dim=1)  # sum over sequence
    # Clamp the token count so an all-PAD row divides by 1e-9 instead of 0
    n_valid = torch.clamp(weights.sum(dim=1), min=1e-9)
    return masked_total / n_valid


In [5]:
class SBERTSoftmax(nn.Module):
    """
    Sentence-BERT style classifier for NLI pairs.

    Both sentences go through the same encoder, are mean-pooled into sentence
    vectors u and v, and the concatenation (u, v, |u - v|) is fed to a linear
    layer that outputs logits for the 3 NLI classes.
    """

    def __init__(self, encoder, hidden_size):
        super().__init__()
        self.encoder = encoder
        # The classifier sees three vectors of size hidden_size side by side
        self.classifier = nn.Linear(hidden_size * 3, 3)  # 3 NLI classes

    def encode(self, input_ids, attention_mask):
        """
        Map a batch of token IDs to one pooled embedding per sentence.
        """
        # Every sentence is encoded on its own, so all tokens share segment 0
        single_segment = torch.zeros_like(input_ids)
        token_states = self.encoder(input_ids, single_segment)   # [B,S,H]
        return mean_pooling(token_states, attention_mask)        # [B,H]

    def forward(self, prem_ids, prem_attn, hyp_ids, hyp_attn):
        """
        Compute class logits for a batch of (premise, hypothesis) pairs.
        """
        premise_vec = self.encode(prem_ids, prem_attn)
        hypothesis_vec = self.encode(hyp_ids, hyp_attn)

        # SBERT pair features: both embeddings plus their element-wise distance
        distance = (premise_vec - hypothesis_vec).abs()
        pair_features = torch.cat((premise_vec, hypothesis_vec, distance), dim=1)

        return self.classifier(pair_features)


# Instantiate model around the Task 1 encoder and move it to the selected device
sbert = SBERTSoftmax(encoder, H).to(device)

# Load fine-tuned weights from Task 2.
# The checkpoint is read onto CPU; the weights live under the "sbert_state_dict" key.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
sbert_ckpt = torch.load("artefacts/sbert_softmax_snli.pt", map_location="cpu")
sbert.load_state_dict(sbert_ckpt["sbert_state_dict"])

# Evaluation only: switch the model to eval mode
sbert.eval()

print("Fine-tuned SBERT model loaded.")

Fine-tuned SBERT model loaded.


In [None]:

# Evaluation subset settings, named once so both splits always use the same values
EVAL_SEED = 42            # shuffle seed, fixed for reproducible subsets
EVAL_SUBSET_SIZE = 3000   # maximum number of examples evaluated per split


def _eval_subset(split, size=EVAL_SUBSET_SIZE, seed=EVAL_SEED):
    """
    Return a reproducibly shuffled subset of `split` with at most `size` examples.

    The size is clamped to the split length so a smaller split does not make
    `select` fail with an out-of-range index.
    """
    return split.shuffle(seed=seed).select(range(min(size, len(split))))


snli = load_dataset("snli")
# Drop examples labelled -1 so every remaining label is one of the 3 classes
snli = snli.filter(lambda x: x["label"] != -1)

# use subsets for faster evaluation
val_ds = _eval_subset(snli["validation"])
test_ds = _eval_subset(snli["test"])



In [7]:
def collate_fn(batch):
    """
    Convert a list of SNLI examples into a tuple of LongTensors.

    Each example's "premise" and "hypothesis" are encoded with
    `encode_sentence` (padded/truncated to MAX_LEN).

    Returns:
        (prem_ids, prem_attn, hyp_ids, hyp_attn, labels)
    """
    premises = [encode_sentence(example["premise"], MAX_LEN) for example in batch]
    hypotheses = [encode_sentence(example["hypothesis"], MAX_LEN) for example in batch]

    def as_long(rows):
        # All model inputs and labels are integer tensors
        return torch.tensor(rows, dtype=torch.long)

    return (
        as_long([ids for ids, _ in premises]),
        as_long([attn for _, attn in premises]),
        as_long([ids for ids, _ in hypotheses]),
        as_long([attn for _, attn in hypotheses]),
        as_long([example["label"] for example in batch]),
    )

# Evaluation loaders: shuffle=False keeps batch order fixed, and collate_fn
# tokenizes each batch of raw examples on the fly.
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [8]:
label_names = ["entailment", "neutral", "contradiction"]

def evaluate(loader, split_name="VAL"):
    """
    Run the fine-tuned SBERT classifier over `loader` and print its metrics.

    Prints the accuracy, a per-class precision/recall/F1 report, and the
    confusion matrix for the split.

    Args:
        loader: iterable yielding (prem_ids, prem_attn, hyp_ids, hyp_attn, labels)
            batches of tensors, in the layout produced by `collate_fn`.
        split_name: name shown in the progress bar and in the printed accuracy line.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    all_preds = []
    all_labels = []

    sbert.eval()

    with torch.no_grad():
        for prem_ids, prem_attn, hyp_ids, hyp_attn, labels in tqdm(loader, desc=f"Evaluating {split_name}"):

            prem_ids = prem_ids.to(device)
            prem_attn = prem_attn.to(device)
            hyp_ids = hyp_ids.to(device)
            hyp_attn = hyp_attn.to(device)

            logits = sbert(prem_ids, prem_attn, hyp_ids, hyp_attn)

            preds = torch.argmax(logits, dim=1)

            all_preds.append(preds.cpu().numpy())
            # Labels are only compared on the host, so they never need to visit the device
            all_labels.append(labels.cpu().numpy())

    if not all_preds:
        raise ValueError(f"{split_name} loader produced no batches; nothing to evaluate.")

    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_labels)

    # Pin the class IDs so the report rows and matrix axes always line up with
    # label_names, even if some class is absent from this split's labels and predictions
    class_ids = list(range(len(label_names)))

    # Accuracy
    acc = accuracy_score(y_true, y_pred)
    print(f"\n{split_name} Accuracy: {acc:.4f}\n")

    # Precision / Recall / F1
    print(classification_report(y_true, y_pred,
                                labels=class_ids,
                                target_names=label_names,
                                digits=4))

    # Confusion Matrix (rows: true class, columns: predicted class)
    print("Confusion Matrix:\n", confusion_matrix(y_true, y_pred, labels=class_ids))


# Run evaluation on both held-out subsets; each call prints its own metrics
evaluate(val_loader, "Validation")
evaluate(test_loader, "Test")

Evaluating Validation: 100%|██████████| 94/94 [03:11<00:00,  2.04s/it]



Validation Accuracy: 0.4817

               precision    recall  f1-score   support

   entailment     0.6076    0.3391    0.4353      1032
      neutral     0.4944    0.4596    0.4763       953
contradiction     0.4272    0.6473    0.5147      1015

     accuracy                         0.4817      3000
    macro avg     0.5097    0.4820    0.4755      3000
 weighted avg     0.5106    0.4817    0.4752      3000

Confusion Matrix:
 [[350 215 467]
 [101 438 414]
 [125 233 657]]


Evaluating Test: 100%|██████████| 94/94 [03:25<00:00,  2.19s/it]


Test Accuracy: 0.4877

               precision    recall  f1-score   support

   entailment     0.5939    0.3424    0.4344      1025
      neutral     0.5188    0.4605    0.4879       988
contradiction     0.4289    0.6657    0.5216       987

     accuracy                         0.4877      3000
    macro avg     0.5139    0.4895    0.4813      3000
 weighted avg     0.5149    0.4877    0.4807      3000

Confusion Matrix:
 [[351 210 464]
 [122 455 411]
 [118 212 657]]





### Analysis

The fine-tuned Sentence-BERT model achieved an accuracy of 48.17% on the validation set and 48.77% on the test set. This performance is significantly above the random baseline of approximately 33% for a three-class classification task (entailment, neutral, contradiction), indicating that the model successfully learned meaningful sentence representations.

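The model achieved its highest recall on the contradiction class (above 0.64 on both validation and test sets), but this came with the lowest precision of the three classes (≈0.43): the confusion matrices show that roughly 45% of entailment examples and over 40% of neutral examples were predicted as contradiction, so the model is biased towards that class rather than genuinely better at it. Correspondingly, recall for entailment was low (≈0.34), suggesting that the model struggles to detect semantic entailment relationships accurately.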

The moderate overall performance can be attributed to several factors. First, the BERT model in Task 1 was pretrained on a relatively small subset of BookCorpus, limiting vocabulary coverage. Consequently, many SNLI tokens were mapped to the [UNK] token, reducing semantic richness in the embeddings. Second, the fine-tuning phase used a limited number of epochs and training samples compared to the original SBERT paper.

Despite these limitations, the experimental results demonstrate that the custom BERT encoder successfully learned transferable sentence representations, which were effectively adapted using the SoftmaxLoss objective for natural language inference.