### Sentence Embedding with Sentence BERT

In [1]:
# load the trained model from task 1
import re
import torch 
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [3]:
# Load checkpoint (vocab + config)
pretrained_model = torch.load("artefacts/bert_mlm.pt", map_location="cpu")
config = pretrained_model["config"]
word2id = pretrained_model["word2id"]
id2word = pretrained_model["id2word"]

PAD_ID = word2id["[PAD]"]
UNK_ID = word2id["[UNK]"]
MAX_LEN = config["max_len"]
H = config["d_model"]

print("Loaded bert_mlm.pt")
print("MAX_LEN:", MAX_LEN, "H:", H, "vocab_size:", config["vocab_size"])


Loaded bert_mlm.pt
MAX_LEN: 1000 H: 256 vocab_size: 20889


In [4]:
encoder_path = "artefacts/bert_encoder.pt"

encoder = torch.jit.load(encoder_path, map_location="cpu")
encoder.eval()

print("Loaded bert_encoder.pt")

Loaded bert_encoder.pt


## Tokenization

In [5]:
# preprocessing the same as task 1 
def clean_text(s: str) -> str:
    # lower case the character
    s = s.lower()
    # remove punctuation
    s = re.sub(r"[.,!\-]", "", s)   
    # remove white spaces 
    s = re.sub(r"\s+", " ", s).strip()
    return s

def encode_sentence(sentence: str, max_len: int):
    """
    sentence -> (input_ids, attention_mask)
    input_ids: padded to max_len
    attention_mask: 1 for real tokens, 0 for PAD
    """
    sentence = clean_text(sentence)
    tokens = sentence.split()

    ids = [word2id.get(w, UNK_ID) for w in tokens][:max_len]
    attn = [1] * len(ids)

    while len(ids) < max_len:
        ids.append(PAD_ID)
        attn.append(0)  

    return ids, attn

In [6]:
# quick check on the corpus
test_sent = "A man is playing basketball outdoors."
ids, attn = encode_sentence(test_sent, MAX_LEN)

unk_count = sum(1 for i in ids if i == UNK_ID)
real_tokens = sum(attn)

print("Real tokens:", real_tokens, "UNK tokens:", unk_count)
print("UNK ratio:", unk_count / max(real_tokens, 1))


Real tokens: 6 UNK tokens: 0
UNK ratio: 0.0


In [7]:
snli = load_dataset("snli")
snli = snli.filter(lambda x: x["label"] != -1)

# Start smaller for speed; increase later
train_ds = snli["train"].shuffle(seed=42).select(range(20000))
val_ds   = snli["validation"].shuffle(seed=42).select(range(3000))

print("SNLI:", len(train_ds), len(val_ds))


SNLI: 20000 3000


In [8]:
def collate_fn(batch):
    prem_ids, prem_attn = [], []
    hyp_ids, hyp_attn = [], []
    labels = []

    for x in batch:
        p_ids, p_att = encode_sentence(x["premise"], MAX_LEN)
        h_ids, h_att = encode_sentence(x["hypothesis"], MAX_LEN)

        prem_ids.append(p_ids); prem_attn.append(p_att)
        hyp_ids.append(h_ids);  hyp_attn.append(h_att)
        labels.append(x["label"])  # 0 entailment, 1 neutral, 2 contradiction

    return (
        torch.tensor(prem_ids, dtype=torch.long),
        torch.tensor(prem_attn, dtype=torch.long),
        torch.tensor(hyp_ids, dtype=torch.long),
        torch.tensor(hyp_attn, dtype=torch.long),
        torch.tensor(labels, dtype=torch.long),
    )

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds, batch_size=16, shuffle=False, collate_fn=collate_fn)


In [9]:
def mean_pooling(token_embeddings, attention_mask):
    # token_embeddings: [B,S,H]
    # attention_mask: [B,S]
    mask = attention_mask.unsqueeze(-1).float()         # [B,S,1]
    summed = (token_embeddings * mask).sum(dim=1)       # [B,H]
    count = mask.sum(dim=1).clamp(min=1e-9)             # [B,1]
    return summed / count                               # [B,H]

class SBERTSoftmax(nn.Module):
    """
    Features = [u, v, |u-v|], then softmax classifier for NLI (3 classes)
    """
    def __init__(self, encoder, hidden_size):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(hidden_size * 3, 3)

    def encode(self, input_ids, attention_mask):
        segment_ids = torch.zeros_like(input_ids)  # [B,S]
        hidden = self.encoder(input_ids, segment_ids)         # [B,S,H]
        sent_emb = mean_pooling(hidden, attention_mask)       # [B,H]
        return sent_emb

    def forward(self, prem_ids, prem_attn, hyp_ids, hyp_attn):
        u = self.encode(prem_ids, prem_attn)
        v = self.encode(hyp_ids, hyp_attn)
        feats = torch.cat([u, v, torch.abs(u - v)], dim=1)
        logits = self.classifier(feats)
        return logits

sbert = SBERTSoftmax(encoder, H).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(sbert.parameters(), lr=2e-5)  # paper-style LR


In [10]:
from tqdm.auto import tqdm

epochs = 1

for ep in range(epochs):
    sbert.train()
    total_loss = 0.0

    for prem_ids, prem_attn, hyp_ids, hyp_attn, labels in tqdm(train_loader):
        prem_ids = prem_ids.to(device)
        prem_attn = prem_attn.to(device)
        hyp_ids = hyp_ids.to(device)
        hyp_attn = hyp_attn.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = sbert(prem_ids, prem_attn, hyp_ids, hyp_attn)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {ep+1} avg loss: {total_loss/len(train_loader):.4f}")


100%|██████████| 1250/1250 [2:20:45<00:00,  6.76s/it]   

Epoch 1 avg loss: 1.0328





In [12]:
torch.save({
    "sbert_state_dict": sbert.state_dict(),
    "config": config,
    "word2id": word2id,
    "id2word": id2word,
    },
    "artefacts/sbert_softmax_snli.pt")

print("Sbert-softmax_snli model is saved")

Sbert-softmax_snli model is saved
