# NLP A4


## Set up the dataset

We implement a BERT-style encoder with masked language modeling (MLM) 
and next sentence prediction (NSP), and pretrain it on a subset of a public corpus 
(BookCorpus / English Wikipedia).

**References**

[1] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (Devlin et al., 2018).  
[2] BookCorpus dataset: https://huggingface.co/datasets/bookcorpus  
[3] English Wikipedia dataset: https://huggingface.co/datasets/legacy-datasets/wikipedia

In [1]:
# Import modules
import os
import math
import re
from random import random, shuffle, randint
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from datasets import load_dataset
from tqdm.auto import tqdm

In [2]:
# Check if cuda is ready
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

### Code: Load the corpus (using a 100k BookCorpus subset as a sample)

In [3]:
# Load bookcorpus as plain text
raw_dataset = load_dataset("bookcorpus", "plain_text", trust_remote_code=True)
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 74004228
    })
})

In [4]:
# Extract text list from split train
texts = raw_dataset["train"]["text"]

# Drop null and very short sentences, then lowercase and strip unsupported characters
clean_texts = []
for t in texts:
    if t is None:
        continue
    t = t.strip()
    if len(t) < 10:
        continue
    t = t.lower()
    t = re.sub(r"[^a-z0-9 ,.!?'-]+", " ", t)
    t = re.sub(r"\s+", " ", t).strip()
    if t:
        clean_texts.append(t)

len(clean_texts)


72044730

In [5]:
# Randomly pick a 100k subset (lower this if the GPU is not strong enough)
np.random.seed(42)
max_samples = min(100_000, len(clean_texts))
indices = np.random.choice(len(clean_texts), size=max_samples, replace=False)
corpus = [clean_texts[i] for i in indices]
len(corpus)


100000

### Code: Build a word-level vocab

In [6]:
# Build vocab from corpus
special_tokens = ["[PAD]", "[CLS]", "[SEP]", "[MASK]", "[UNK]"]
word2id = {tok: i for i, tok in enumerate(special_tokens)}
start_idx = len(special_tokens)

all_words = set(" ".join(corpus).split())
for i, w in enumerate(all_words):
    word2id[w] = i + start_idx

id2word = {i: w for w, i in word2id.items()}
vocab_size = len(word2id)

vocab_size, list(word2id.items())[:10]


(44623,
 [('[PAD]', 0),
  ('[CLS]', 1),
  ('[SEP]', 2),
  ('[MASK]', 3),
  ('[UNK]', 4),
  ('knuckle', 5),
  ('whimper', 6),
  ('alrighty', 7),
  ('exhalations', 8),
  ('married', 9)])
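
Iterating over a Python `set` of strings is not guaranteed to give the same order in different interpreter sessions, because string hashing is randomized by default. Word ids built this way can therefore differ between runs unless the saved mapping is reloaded. A minimal sketch of one way to make the ids reproducible, using a hypothetical helper `build_vocab` and a toy corpus (neither is part of the notebook):

```python
def build_vocab(sentences, special_tokens):
    """Assign ids to special tokens first, then to corpus words in sorted order."""
    word2id = {tok: i for i, tok in enumerate(special_tokens)}
    words = sorted(set(" ".join(sentences).split()))
    for w in words:
        if w not in word2id:  # skip a word that collides with a special token
            word2id[w] = len(word2id)
    return word2id


# Toy corpus, for illustration only
toy_corpus = ["the cat sat", "the dog sat down"]
toy_specials = ["[PAD]", "[CLS]", "[SEP]", "[MASK]", "[UNK]"]
toy_vocab = build_vocab(toy_corpus, toy_specials)
print(toy_vocab)
```

Sorting costs little at this vocabulary size and keeps `word2id` stable across sessions, which matters when a checkpoint is later loaded without the saved JSON mapping.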

In [51]:
# Save word2id as json format
import json
os.makedirs("../dataset", exist_ok=True)
with open("../dataset/word2id.json", "w", encoding="utf-8") as f:
    json.dump(word2id, f)

In [7]:
# Convert each sentence to a list of ids (CLS/SEP not included)
tokenized_sentences = []
for sent in corpus:
    ids = [word2id[w] for w in sent.split() if w in word2id]
    if len(ids) > 0:
        tokenized_sentences.append(ids)

len(tokenized_sentences)

100000

### Code: Function that builds an MLM+NSP batch

In [8]:
# Set up hyperparameters 
batch_size = 16          # Adjust depends on GPU capacity
max_mask = 20            # Max number of masked tokens per sequence
max_len = 128            # Max sequence length in Transformer
n_segments = 2

pad_id = word2id["[PAD]"]
cls_id = word2id["[CLS]"]
sep_id = word2id["[SEP]"]
mask_id = word2id["[MASK]"]

In [9]:
def make_batch(sentences, batch_size, max_len, max_mask):
    """
    sentences: list[list[int]] (each element is one sentence's token ids)
    return:
        input_ids: [B, L]
        segment_ids: [B, L]
        masked_tokens: [B, max_mask]
        masked_pos: [B, max_mask]
        is_next: [B]
    """
    batch = []
    positive = 0
    negative = 0

    num_sent = len(sentences)

    while positive < batch_size // 2 or negative < batch_size // 2:
        # Randomly pick two sentences
        idx_a = randint(0, num_sent - 2)
        idx_b = randint(0, num_sent - 1)

        tokens_a = sentences[idx_a]
        # For a positive pair, B is the sentence right after A
        if random() < 0.5:
            tokens_b = sentences[idx_a + 1]
            is_next = True
        else:
            tokens_b = sentences[idx_b]
            is_next = False

        # Truncate A/B so their combined length is at most max_len-3 (room for CLS and two SEPs)
        max_total = max_len - 3
        if len(tokens_a) + len(tokens_b) > max_total:
            # Split the budget in proportion to the original lengths
            len_a = max(1, int(max_total * len(tokens_a) / (len(tokens_a) + len(tokens_b))))
            len_b = max_total - len_a
            tokens_a = tokens_a[:len_a]
            tokens_b = tokens_b[:len_b]

        # token embedding: [CLS] A [SEP] B [SEP]
        input_ids = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]

        # segment embedding
        seg_a_len = 1 + len(tokens_a) + 1
        seg_b_len = len(input_ids) - seg_a_len
        segment_ids = [0] * seg_a_len + [1] * seg_b_len

        # Sanity check: the sequence must not exceed max_len
        assert len(input_ids) <= max_len
        assert len(segment_ids) == len(input_ids)

        # Build the masked LM targets
        n_pred = max(1, int(round(len(input_ids) * 0.15)))
        n_pred = min(n_pred, max_mask)

        cand_pos = [i for i, tok in enumerate(input_ids)
                    if tok not in (cls_id, sep_id)]
        shuffle(cand_pos)
        masked_tokens, masked_pos = [], []

        for pos in cand_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])

            r = random()
            if r < 0.8:
                input_ids[pos] = mask_id       # 80% → [MASK]
            elif r < 0.9:
                # 10% → random non-special token (never [PAD], [CLS], [SEP], [MASK], [UNK])
                rand_tok = randint(len(special_tokens), vocab_size - 1)
                input_ids[pos] = rand_tok
            else:
                pass                           # 10% → unchanged

        # Pad the sequence so every example has exactly max_len tokens
        if len(input_ids) < max_len:
            pad_n = max_len - len(input_ids)
            input_ids += [pad_id] * pad_n
            segment_ids += [0] * pad_n

        # padding masked tokens/pos
        if len(masked_tokens) < max_mask:
            pad_n = max_mask - len(masked_tokens)
            masked_tokens += [0] * pad_n
            masked_pos += [0] * pad_n

        if is_next and positive < batch_size // 2:
            batch.append((input_ids, segment_ids, masked_tokens, masked_pos, 1))
            positive += 1
        elif (not is_next) and negative < batch_size // 2:
            batch.append((input_ids, segment_ids, masked_tokens, masked_pos, 0))
            negative += 1

    input_ids, segment_ids, masked_tokens, masked_pos, is_next = map(
        torch.LongTensor, zip(*batch)
    )
    return input_ids, segment_ids, masked_tokens, masked_pos, is_next


In [10]:
# Test building one batch to check shape 
test_input_ids, test_segment_ids, test_masked_tokens, test_masked_pos, test_is_next = \
    make_batch(tokenized_sentences, batch_size=8, max_len=max_len, max_mask=max_mask)

test_input_ids.shape, test_segment_ids.shape, test_masked_tokens.shape, test_masked_pos.shape, test_is_next.shape

(torch.Size([8, 128]),
 torch.Size([8, 128]),
 torch.Size([8, 20]),
 torch.Size([8, 20]),
 torch.Size([8]))
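
The 80/10/10 masking rule inside `make_batch` can be looked at in isolation. The sketch below is a standalone illustration, not the notebook's function: `mask_tokens`, its `n_special` argument and the toy ids are assumptions made for the example, and it draws random replacements only from non-special ids.

```python
from random import Random


def mask_tokens(token_ids, candidate_pos, n_pred, mask_id, vocab_size, n_special, rng):
    """Pick n_pred candidate positions and corrupt them with the 80/10/10 rule."""
    ids = list(token_ids)
    positions = list(candidate_pos)
    rng.shuffle(positions)
    masked_pos, masked_tokens = [], []
    for pos in positions[:n_pred]:
        masked_pos.append(pos)
        masked_tokens.append(ids[pos])  # the original token is the prediction target
        r = rng.random()
        if r < 0.8:
            ids[pos] = mask_id  # 80%: replace with [MASK]
        elif r < 0.9:
            ids[pos] = rng.randint(n_special, vocab_size - 1)  # 10%: random word
        # remaining 10%: leave the token unchanged
    return ids, masked_pos, masked_tokens


rng = Random(0)
toy = [1, 10, 11, 12, 2, 13, 14, 2]  # [CLS] a b c [SEP] d e [SEP]
cand = [i for i, t in enumerate(toy) if t not in (1, 2)]
new_ids, pos, labels = mask_tokens(
    toy, cand, n_pred=2, mask_id=3, vocab_size=20, n_special=5, rng=rng
)
print(new_ids, pos, labels)
```

Whatever the random draws are, `labels` always holds the original tokens at `pos`, and positions outside `pos` are left untouched.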

### Code: BERT model (copied/adapted from BERT.ipynb and made slightly smaller)

In [11]:
# Set up hyperparameters
n_layers = 2    # encoder layers
n_heads = 4
d_model = 256
d_ff = d_model * 4
d_k = 64
d_v = 64

In [12]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_len, n_segments):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.seg_embed = nn.Embedding(n_segments, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        bsz, seq_len = x.size()
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)
        emb = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(emb)

In [13]:
def get_attn_pad_mask(seq_q, seq_k, pad_id=0):
    """
    seq_q: [B, Lq], seq_k: [B, Lk]
    return: [B, Lq, Lk]
    """
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.eq(pad_id).unsqueeze(1)  # [B, 1, Lk]
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [B, Lq, Lk]
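
To see what this mask looks like, the function can be run on a toy batch. The sketch repeats the function so it runs on its own; the ids in `toy_ids` are made up for illustration.

```python
import torch


def get_attn_pad_mask(seq_q, seq_k, pad_id=0):
    # True wherever the key token is padding; broadcast over all query positions
    batch_size, len_q = seq_q.size()
    _, len_k = seq_k.size()
    pad_attn_mask = seq_k.eq(pad_id).unsqueeze(1)  # [B, 1, Lk]
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [B, Lq, Lk]


# One sequence: [CLS] w1 w2 [SEP] [PAD] [PAD]
toy_ids = torch.tensor([[1, 7, 8, 2, 0, 0]])
mask = get_attn_pad_mask(toy_ids, toy_ids)
print(mask.shape)
print(mask[0, 0].tolist())
```

Every row of the mask is identical: the mask hides padded keys from all queries, but it does not remove the padded query rows themselves, which is why padding is excluded again at pooling time.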


In [14]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask):
        # Q: [B, H, Lq, d_k], K: [B, H, Lk, d_k], V: [B, H, Lk, d_v]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(Q.size(-1))
        scores.masked_fill_(attn_mask, -1e9)
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        return context, attn
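
A quick way to check the attention step is to run it on random tensors and confirm that masked keys receive zero weight. This is a functional sketch of the same computation (it uses the out-of-place `masked_fill`); the shapes and the choice of which key is "padding" are assumptions for the example.

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V, attn_mask):
    # Q, K: [B, H, L, d_k]; V: [B, H, L, d_v]; attn_mask: [B, H, Lq, Lk] (True = hide)
    scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(Q.size(-1))
    scores = scores.masked_fill(attn_mask, -1e9)
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, V), attn


torch.manual_seed(0)
B, H, L, d = 1, 2, 4, 8
Q = torch.randn(B, H, L, d)
K = torch.randn(B, H, L, d)
V = torch.randn(B, H, L, d)

attn_mask = torch.zeros(B, H, L, L, dtype=torch.bool)
attn_mask[..., -1] = True  # pretend the last key position is padding

context, attn = scaled_dot_product_attention(Q, K, V, attn_mask)
print(context.shape, attn[0, 0])
```

Each attention row still sums to 1, with the weight that would have gone to the hidden key redistributed over the visible ones.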


In [15]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, d_k, d_v, n_heads):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.n_heads = n_heads

        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.fc = nn.Linear(n_heads * d_v, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.attn = None
        self.sdpa = ScaledDotProductAttention()

    def forward(self, Q, K, V, attn_mask):
        residual = Q
        batch_size = Q.size(0)

        # [B, L, H*d_k] -> [B, H, L, d_k]
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)

        # attn_mask: [B, Lq, Lk] -> [B, H, Lq, Lk]
        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)

        context, attn = self.sdpa(q_s, k_s, v_s, attn_mask)
        self.attn = attn

        # [B, H, L, d_v] -> [B, L, H*d_v]
        context = context.transpose(1, 2).contiguous().view(
            batch_size, -1, self.n_heads * self.d_v
        )
        output = self.fc(context)
        return self.layer_norm(output + residual)


In [16]:
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        residual = x
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return self.layer_norm(x + residual)


In [17]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, d_k, d_v, d_ff, n_heads):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        out = self.self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        out = self.pos_ffn(out)
        return out


In [18]:
class BERT(nn.Module):
    def __init__(self, vocab_size, d_model, d_k, d_v, d_ff, n_layers, n_heads, max_len, n_segments):
        super().__init__()
        self.embedding = Embedding(vocab_size, d_model, max_len, n_segments)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, d_k, d_v, d_ff, n_heads)
            for _ in range(n_layers)
        ])

        # NSP head
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.classifier = nn.Linear(d_model, 2)

        # MLM head
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

        # decoder share weight with token embedding
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

    # NEW: encoder-only method
    def encode(self, input_ids, segment_ids):
        # Compute encoder hidden states given token ids and segment ids.
        # This is used by SentenceEncoder in Task 2.
        x = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, pad_id=pad_id)
        for layer in self.layers:
            x = layer(x, enc_self_attn_mask)
        return x  # [B, L, d_model]
    
    def forward(self, input_ids, segment_ids, masked_pos):
        # Use encode() to get contextual token representations
        x = self.encode(input_ids, segment_ids)

        # NSP head using [CLS] token at position 0
        pooled = self.activ(self.fc(x[:, 0]))
        logits_nsp = self.classifier(pooled)

        # MLM head using masked positions
        masked_pos = masked_pos.unsqueeze(-1).expand(-1, -1, x.size(-1))
        h_masked = torch.gather(x, 1, masked_pos)
        h_masked = self.norm(F.gelu(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias

        return logits_lm, logits_nsp


In [19]:
model = BERT(
    vocab_size=vocab_size,
    d_model=d_model,
    d_k=d_k,
    d_v=d_v,
    d_ff=d_ff,
    n_layers=n_layers,
    n_heads=n_heads,
    max_len=max_len,
    n_segments=n_segments
).to(device)

sum(p.numel() for p in model.parameters()) / 1e6  # approximate parameter count (millions)


13.214033
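
Most of these 13.2M parameters sit in the token embedding matrix (44623 × 256 ≈ 11.4M), and because the MLM decoder shares that matrix it is counted only once. A small sketch of the tying pattern on a toy module; `TiedHead` and its sizes are hypothetical and only mirror the two layers involved.

```python
import torch
import torch.nn as nn


class TiedHead(nn.Module):
    """Toy module: an embedding and an output projection sharing one weight."""

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)          # weight: [V, D]
        self.decoder = nn.Linear(d_model, vocab_size, bias=False)   # weight: [V, D]
        self.decoder.weight = self.tok_embed.weight                 # share the tensor


toy = TiedHead(vocab_size=100, d_model=16)
n_params = sum(p.numel() for p in toy.parameters())
print(n_params)

# Projecting hidden states back to vocabulary logits reuses the embedding matrix
hidden = torch.randn(2, 5, 16)
logits = toy.decoder(hidden)
print(logits.shape)
```

`parameters()` yields a shared tensor once, so the count reflects a single `[V, D]` matrix rather than two.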

### Code: training loop + save weights

In [20]:
num_epochs = 3         # This affects the total training time
lr = 1e-4

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

In [21]:
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    num_steps = 1000   # cap the steps per epoch so training does not run too long

    for step in tqdm(range(num_steps)):
        input_ids, segment_ids, masked_tokens, masked_pos, is_next = make_batch(
            tokenized_sentences,
            batch_size=batch_size,
            max_len=max_len,
            max_mask=max_mask
        )

        input_ids = input_ids.to(device)
        segment_ids = segment_ids.to(device)
        masked_tokens = masked_tokens.to(device)
        masked_pos = masked_pos.to(device)
        is_next = is_next.to(device)

        optimizer.zero_grad()
        logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)

        # MLM loss
        # logits_lm: [B, max_mask, vocab_size]
        loss_lm = criterion(
            logits_lm.view(-1, vocab_size),
            masked_tokens.view(-1)
        )

        # NSP loss
        loss_nsp = criterion(logits_nsp, is_next)

        loss = loss_lm + loss_nsp
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / num_steps
    print(f"Epoch {epoch+1:02d} | loss = {avg_loss:.4f}")


Epoch 01 | loss = 6.3761


Epoch 02 | loss = 3.5728


Epoch 03 | loss = 2.9685
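
One detail worth checking in the MLM loss: `make_batch` fills unused slots of `masked_tokens` with 0, which is also the id of `[PAD]`, so a plain `CrossEntropyLoss` treats those slots as real targets. A possible way to leave them out is the `ignore_index` argument, sketched here on toy tensors (the sizes and targets are assumptions, and this is not what the run above used):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_vocab_size, toy_pad_id = 10, 0

logits = torch.randn(1, 4, toy_vocab_size)            # [B, max_mask, V]
targets = torch.tensor([[5, 7, toy_pad_id, toy_pad_id]])  # last two slots are padding

flat_logits = logits.view(-1, toy_vocab_size)
flat_targets = targets.view(-1)

# Averages over all four slots, padding included
loss_all = nn.CrossEntropyLoss()(flat_logits, flat_targets)

# Averages only over the two real masked tokens
loss_real = nn.CrossEntropyLoss(ignore_index=toy_pad_id)(flat_logits, flat_targets)

# Reference: compute the loss on the two real slots directly
loss_manual = nn.CrossEntropyLoss()(logits[0, :2], targets[0, :2])
print(loss_all.item(), loss_real.item(), loss_manual.item())
```

This works here because a genuine masked target is never `[PAD]`: only corpus words are candidates for masking.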


In [22]:
# Save weights for next Task
os.makedirs("../model", exist_ok=True)
save_path = "../model/bert_pretrained_from_scratch.pt"
torch.save(model.state_dict(), save_path)
save_path

'../model/bert_pretrained_from_scratch.pt'

## Sentence Embedding with Sentence-BERT

In this section, we reuse the BERT encoder trained in Task 1 as a sentence encoder, 
build a Siamese architecture like Sentence-BERT, and train it on NLI data 
(SNLI + MNLI) with the softmax classification objective as described in the assignment [1][2].

[1] Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks (Reimers & Gurevych, EMNLP 2019).  
[2] SNLI / MNLI datasets from HuggingFace Hub.


In [23]:
# Import extra libraries for Task 2 (NLI datasets, dataloaders)
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict

### Load SNLI and MNLI and prepare DatasetDict

In [24]:
# Load SNLI and MNLI datasets from HuggingFace
# SNLI: natural language inference with three labels (entailment, neutral, contradiction),
#       plus some -1 labels (no gold label) that we will filter out
# MNLI: multi-genre NLI with the same three labels; the GLUE test splits are unlabeled (label = -1)

snli = load_dataset("snli")
mnli = load_dataset("glue", "mnli")

In [25]:
# Remove invalid labels (-1) from SNLI (no gold label could be decided)
def filter_valid_snli(example):
    # Keep only rows where label is 0, 1, or 2
    return example["label"] != -1

snli = snli.filter(filter_valid_snli)

In [26]:
# Remove 'idx' column from each MNLI split and keep only label + text pairs
for split_name in list(mnli.keys()):
    if "idx" in mnli[split_name].column_names:
        mnli[split_name] = mnli[split_name].remove_columns("idx")

In [27]:
# Build a merged DatasetDict: concatenate SNLI and MNLI train/validation/test sets
# We shuffle and optionally subsample for faster training.
from datasets import concatenate_datasets

max_train = 10000   # You can increase this if you have more compute
max_val = 2000
max_test = 2000

rawdataset = DatasetDict({
    "train": concatenate_datasets([
        snli["train"],
        mnli["train"]
    ]).shuffle(seed=55).select(range(max_train)),
    "validation": concatenate_datasets([
        snli["validation"],
        mnli["validation_matched"]
    ]).shuffle(seed=55).select(range(max_val)),
    "test": concatenate_datasets([
        snli["test"],
        mnli["test_matched"]
    ]).shuffle(seed=55).select(range(max_test)),
})

rawdataset

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 2000
    })
})
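
The GLUE test splits are distributed without gold labels (every label is -1), so the merged `test` split built from `mnli["test_matched"]` contains unlabeled rows; the unique-label check later in this notebook prints -1 among the labels. A plain-Python sketch of the kind of predicate that would drop them, on made-up rows (`keep_labeled` and `toy_test` are illustrative; with a `datasets.Dataset` the same predicate could be passed to `.filter`):

```python
def keep_labeled(example):
    # Keep only rows with a gold label in {0, 1, 2}
    return example["label"] != -1


# Hypothetical rows in the same column layout as the merged dataset
toy_test = [
    {"premise": "a man sleeps", "hypothesis": "a man is awake", "label": 2},
    {"premise": "a dog runs", "hypothesis": "an animal moves", "label": -1},
    {"premise": "two kids play", "hypothesis": "children are outside", "label": 1},
]

labeled = [ex for ex in toy_test if keep_labeled(ex)]
print(len(toy_test), "->", len(labeled))
```

Filtering before `select(range(max_test))` would keep the split size at the requested number of labeled examples.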

### Simple whitespace tokenizer

In [28]:
# Simple whitespace tokenizer reusing the vocabulary from Task 1
# We will map unknown words to a special [UNK] id if needed.

# Create UNK token if it does not exist
if "[UNK]" not in word2id:
    unk_id = len(word2id)
    word2id["[UNK]"] = unk_id
    id2word[unk_id] = "[UNK]"
else:
    unk_id = word2id["[UNK]"]

vocab_size = len(word2id)

# Define a helper to tokenize and numericalize a sentence using the existing vocab
def encode_sentence(text, max_len):
    # Lowercase and basic cleaning to be consistent with Task 1 preprocessing
    text = text.lower()
    text = re.sub(r"[^a-z0-9 ,.!?'-]+", " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    tokens = []
    for w in text.split():
        if w in word2id:
            tokens.append(word2id[w])
        else:
            tokens.append(unk_id)
    # Truncate to max_len (we will still add CLS/SEP at sentence encoder level)
    if len(tokens) > max_len - 2:  # leave room for CLS/SEP if needed
        tokens = tokens[:max_len - 2]
    return tokens
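
A short run on a toy vocabulary shows how this tokenizer behaves on unseen words. The helper `encode_with_unk` and the vocabulary below are assumptions for illustration; the cleaning steps copy the ones above.

```python
import re

# Hypothetical miniature vocabulary
toy_word2id = {
    "[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "[MASK]": 3, "[UNK]": 4,
    "a": 5, "dog": 6, "runs": 7,
}


def encode_with_unk(text, word2id, max_len):
    """Lowercase, clean, split on whitespace, and map unknown words to [UNK]."""
    text = text.lower()
    text = re.sub(r"[^a-z0-9 ,.!?'-]+", " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    unk = word2id["[UNK]"]
    ids = [word2id.get(w, unk) for w in text.split()]
    return ids[: max_len - 2]  # leave room for [CLS] and [SEP]


print(encode_with_unk("A dog runs fast!", toy_word2id, max_len=8))
print(encode_with_unk("A dog runs.", toy_word2id, max_len=8))
```

In the second call `runs.` becomes `[UNK]` even though `runs` is in the vocabulary: whitespace splitting keeps punctuation attached to the word, which raises the unknown-token rate on NLI text.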

### Preprocess NLI dataset ids + label

In [29]:
# Preprocess function to convert premise/hypothesis strings to token id sequences
# using the vocabulary from Task 1. We keep the raw sequences (without CLS/SEP)
# because we will let the sentence encoder add those later if needed.

def preprocess_nli(example):
    premise = example["premise"] if "premise" in example else example["sentence1"]
    hypothesis = example["hypothesis"] if "hypothesis" in example else example["sentence2"]
    label = example["label"]
    
    input_ids_a = encode_sentence(premise, max_len)
    input_ids_b = encode_sentence(hypothesis, max_len)
    
    return {
        "input_ids_a": input_ids_a,
        "input_ids_b": input_ids_b,
        "label": label,
    }

tokenized_nli = rawdataset.map(preprocess_nli, remove_columns=rawdataset["train"].column_names)
tokenized_nli

DatasetDict({
    train: Dataset({
        features: ['label', 'input_ids_a', 'input_ids_b'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['label', 'input_ids_a', 'input_ids_b'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['label', 'input_ids_a', 'input_ids_b'],
        num_rows: 2000
    })
})

In [30]:
# # Set PyTorch format for DataLoader and let HuggingFace handle variable-length lists.
# tokenized_nli.set_format(
#     type="torch",
#     columns=["input_ids_a", "input_ids_b", "label"],
# )

### collate_fn

In [31]:
# Collate function to pad variable-length sequences in a batch
# We will pad both premises and hypotheses to the maximum length in the batch.

def collate_nli(batch):
    # Extract sequences and labels from the list of examples.
    # Each example has Python lists for 'input_ids_a' and 'input_ids_b',
    # and an integer label.

    # Make sure we are always working on Python lists.
    input_ids_a = [list(item["input_ids_a"]) for item in batch]
    input_ids_b = [list(item["input_ids_b"]) for item in batch]
    labels = torch.tensor([int(item["label"]) for item in batch], dtype=torch.long)

    # Compute max lengths for A and B in this batch
    max_len_a = max(len(x) for x in input_ids_a)
    max_len_b = max(len(x) for x in input_ids_b)

    # Pad sequences with [PAD] id
    padded_a = []
    padded_b = []
    for a, b in zip(input_ids_a, input_ids_b):
        # Pad premise
        pad_len_a = max_len_a - len(a)
        padded_a.append(a + [pad_id] * pad_len_a)
        # Pad hypothesis
        pad_len_b = max_len_b - len(b)
        padded_b.append(b + [pad_id] * pad_len_b)

    # Convert to tensors
    input_ids_a = torch.tensor(padded_a, dtype=torch.long)
    input_ids_b = torch.tensor(padded_b, dtype=torch.long)

    return {
        "input_ids_a": input_ids_a,
        "input_ids_b": input_ids_b,
        "labels": labels,
    }

In [32]:
# Build DataLoaders for train/validation/test
batch_size_sbert = 32

train_loader = DataLoader(
    tokenized_nli["train"],
    batch_size=batch_size_sbert,
    shuffle=True,
    collate_fn=collate_nli,
)

val_loader = DataLoader(
    tokenized_nli["validation"],
    batch_size=batch_size_sbert,
    shuffle=False,
    collate_fn=collate_nli,
)

test_loader = DataLoader(
    tokenized_nli["test"],
    batch_size=batch_size_sbert,
    shuffle=False,
    collate_fn=collate_nli,
)

### Build sentence encoder 

In [33]:
# We reuse the BERT encoder implementation from Task 1.
# Here we define a helper that takes a padded batch of sentence token ids and
# wraps each row with [CLS] and [SEP]. Feeding the result into BERT and
# mean-pooling (excluding padding) happens in SentenceEncoder below.

def build_sentence_batch(input_ids_batch):
    # input_ids_batch: [B, L] (token ids without CLS/SEP)
    batch_size, seq_len = input_ids_batch.size()
    
    # Add CLS and SEP to each sentence: [CLS] sentence [SEP]
    cls = torch.full((batch_size, 1), cls_id, dtype=torch.long, device=input_ids_batch.device)
    sep = torch.full((batch_size, 1), sep_id, dtype=torch.long, device=input_ids_batch.device)

    # Concatenate CLS + sentence + SEP along sequence dimension
    # Resulting shape: [B, L+2]
    input_ids = torch.cat([cls, input_ids_batch, sep], dim=1)

    # Build segment ids (all zeros because we only have one sentence)
    segment_ids = torch.zeros_like(input_ids, dtype=torch.long)

    # Note: the input is already padded, so [SEP] is appended after any trailing
    # [PAD] tokens. [PAD] positions are excluded later by the attention mask and mean pooling.
    return input_ids, segment_ids

In [34]:
# Mean pooling over non-PAD token positions. This is similar in spirit
# to the pooling used in the SBERT reference code, but operates on our custom BERT.

def mean_pool(hidden_states, input_ids):
    # hidden_states: [B, L, D]
    # input_ids: [B, L] (with PAD tokens)
    # Build mask where input_ids != PAD
    mask = (input_ids != pad_id).unsqueeze(-1)  # [B, L, 1]
    masked_hidden = hidden_states * mask
    # Sum and then divide by the number of non-PAD tokens
    sum_hidden = masked_hidden.sum(dim=1)                # [B, D]
    lengths = mask.sum(dim=1).clamp(min=1)               # [B, 1]
    mean_hidden = sum_hidden / lengths                   # [B, D]
    return mean_hidden
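
A tiny worked example makes the pooling rule concrete. The function is repeated so the sketch runs on its own, and the numbers in `hidden` and `ids` are made up; the padded position carries large values on purpose to show that it is ignored.

```python
import torch


def mean_pool(hidden_states, input_ids, pad_id=0):
    mask = (input_ids != pad_id).unsqueeze(-1)       # [B, L, 1]
    summed = (hidden_states * mask).sum(dim=1)       # [B, D]
    lengths = mask.sum(dim=1).clamp(min=1)           # [B, 1]
    return summed / lengths


# One sequence of length 3 with hidden size 2; the last token is [PAD]
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
ids = torch.tensor([[5, 6, 0]])

pooled = mean_pool(hidden, ids)
print(pooled)  # tensor([[2., 3.]])
```

Only the two real tokens are averaged: (1 + 3) / 2 = 2 and (2 + 4) / 2 = 3.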

In [35]:
# Wrapper around the BERT model to get sentence embeddings directly.
# We reuse the same BERT class and only use the encoder outputs.

class SentenceEncoder(nn.Module):
    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model

    def forward(self, input_ids_batch):
        # Build full input_ids and segment_ids with CLS/SEP
        input_ids, segment_ids = build_sentence_batch(input_ids_batch)
        
        # Pass through the BERT encoder to obtain hidden states
        logits_lm, logits_nsp = self.bert(input_ids, segment_ids, masked_pos=torch.zeros(input_ids.size(0), 1, dtype=torch.long, device=input_ids.device))
        
        # The BERT forward returns MLM and NSP logits, not hidden states.
        # To get hidden states, we need to slightly modify BERT or define a method that exposes them.
        # For simplicity, we will define a new function in BERT that returns hidden states.

        raise NotImplementedError("Please modify BERT class to expose hidden states.")

### Modify BERT class to return hidden state

In [36]:
# Add this method inside the BERT class definition from Task 1

def encode(self, input_ids, segment_ids):
    # Compute encoder hidden states given token ids and segment ids.
    # This method does not compute MLM or NSP logits; it only returns the final hidden states.
    x = self.embedding(input_ids, segment_ids)
    enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, pad_id=pad_id)
    for layer in self.layers:
        x = layer(x, enc_self_attn_mask)
    return x  # [B, L, d_model]

In [37]:
def forward(self, input_ids, segment_ids, masked_pos):
    # Encode input sequence to obtain contextual token representations
    x = self.encode(input_ids, segment_ids)

    # NSP head using [CLS] token at position 0
    pooled = self.activ(self.fc(x[:, 0]))
    logits_nsp = self.classifier(pooled)

    # MLM head using masked positions
    masked_pos = masked_pos.unsqueeze(-1).expand(-1, -1, x.size(-1))
    h_masked = torch.gather(x, 1, masked_pos)
    h_masked = self.norm(F.gelu(self.linear(h_masked)))
    logits_lm = self.decoder(h_masked) + self.decoder_bias

    return logits_lm, logits_nsp

### SentenceEncoder encode()

In [38]:
# SentenceEncoder now uses the encode() method to get hidden states

class SentenceEncoder(nn.Module):
    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model

    def forward(self, input_ids_batch):
        # Build full sentence batch with CLS and SEP tokens
        input_ids, segment_ids = build_sentence_batch(input_ids_batch)
        input_ids = input_ids.to(device)
        segment_ids = segment_ids.to(device)

        # Use BERT encoder to get contextualized token representations
        hidden_states = self.bert.encode(input_ids, segment_ids)  # [B, L, D]

        # Compute mean-pooled sentence embedding
        sentence_embeddings = mean_pool(hidden_states, input_ids)  # [B, D]

        return sentence_embeddings

In [39]:
# Create a fresh BERT model with the same architecture as in Task 1
bert_for_sbert = BERT(
    vocab_size=vocab_size,
    d_model=d_model,
    d_k=d_k,
    d_v=d_v,
    d_ff=d_ff,
    n_layers=n_layers,
    n_heads=n_heads,
    max_len=max_len,
    n_segments=n_segments
)

# Load pretrained weights from Task 1
task1_ckpt_path = "../model/bert_pretrained_from_scratch.pt"
state_dict = torch.load(task1_ckpt_path, map_location="cpu")
bert_for_sbert.load_state_dict(state_dict)

bert_for_sbert = bert_for_sbert.to(device)

In [40]:
# Build sentence encoder wrapper
sentence_encoder = SentenceEncoder(bert_for_sbert).to(device)

### Softmax classification head (u, v, |u−v| concat)

In [41]:
# According to the assignment and the SBERT paper, we build a classifier
# on top of the concatenation [u; v; |u - v|], where u and v are the sentence embeddings.
# If the embedding dimension is d_model, the input to the classifier has size 3 * d_model.

class SBERTClassifier(nn.Module):
    def __init__(self, encoder, hidden_dim):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(3 * hidden_dim, 3)  # three NLI labels

    def forward(self, input_ids_a, input_ids_b):
        # Compute sentence embeddings u and v for premise and hypothesis
        u = self.encoder(input_ids_a)  # [B, D]
        v = self.encoder(input_ids_b)  # [B, D]

        # Compute absolute difference |u - v|
        uv_abs = torch.abs(u - v)

        # Concatenate u, v, and |u - v| along the last dimension
        x = torch.cat([u, v, uv_abs], dim=-1)  # [B, 3D]

        # Compute logits for three classes (entailment, neutral, contradiction)
        logits = self.classifier(x)

        return logits
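
The shape bookkeeping of the classification objective can be checked without the encoder. In this sketch random vectors stand in for the sentence embeddings `u` and `v`; the batch size, dimension and the `head` layer are assumptions for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, D = 4, 8

# Stand-ins for the pooled premise and hypothesis embeddings
u = torch.randn(B, D)
v = torch.randn(B, D)

# Feature vector [u; v; |u - v|] with size 3 * D
features = torch.cat([u, v, torch.abs(u - v)], dim=-1)

# Linear softmax head over the three NLI labels
head = nn.Linear(3 * D, 3)
logits = head(features)
probs = logits.softmax(dim=-1)
print(features.shape, logits.shape)
```

The `|u - v|` block is the same whichever sentence comes first, while the `u` and `v` blocks keep the premise/hypothesis order visible to the classifier.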

In [42]:
# Initialize SBERT classifier with the sentence encoder and dimension d_model
sbert_model = SBERTClassifier(sentence_encoder, hidden_dim=d_model).to(device)

### Training loop

In [43]:
# Set up optimizer and loss function for NLI classification
# We use cross-entropy loss as in the SBERT classification objective.

lr_sbert = 2e-5
num_epochs_sbert = 3  # You can change this depending on time and resources

criterion_nli = nn.CrossEntropyLoss()
optimizer_sbert = optim.Adam(sbert_model.parameters(), lr=lr_sbert)

In [44]:
# Training loop for Sentence-BERT classifier on NLI data
# For each batch, we compute sentence embeddings for premise and hypothesis,
# build the [u; v; |u - v|] representation, and optimize the softmax loss.

for epoch in range(num_epochs_sbert):
    sbert_model.train()
    total_loss = 0.0
    total_steps = 0

    for batch in tqdm(train_loader, desc=f"SBERT Training Epoch {epoch+1}"):
        # Move all batch tensors to the active device
        input_ids_a = batch["input_ids_a"].to(device)
        input_ids_b = batch["input_ids_b"].to(device)
        labels = batch["labels"].to(device)

        # Reset gradients at the start of each step
        optimizer_sbert.zero_grad()

        # Forward pass: compute logits for NLI labels
        logits = sbert_model(input_ids_a, input_ids_b)

        # Compute cross-entropy loss between predicted logits and true labels
        loss = criterion_nli(logits, labels)

        # Backpropagate and update model parameters
        loss.backward()
        optimizer_sbert.step()

        total_loss += loss.item()
        total_steps += 1

    avg_loss = total_loss / max(1, total_steps)
    print(f"Epoch {epoch+1:02d} | SBERT training loss = {avg_loss:.4f}")

Epoch 01 | SBERT training loss = 1.1026


Epoch 02 | SBERT training loss = 1.0822


Epoch 03 | SBERT training loss = 1.0711


### Validation accuracy

In [45]:
# Evaluate SBERT classifier on the validation set to monitor performance.
# We compute simple accuracy over the three NLI labels.

def evaluate_sbert(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in data_loader:
            input_ids_a = batch["input_ids_a"].to(device)
            input_ids_b = batch["input_ids_b"].to(device)
            labels = batch["labels"].to(device)

            logits = model(input_ids_a, input_ids_b)
            preds = logits.argmax(dim=-1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / max(1, total)
    return accuracy

val_acc = evaluate_sbert(sbert_model, val_loader)
print(f"Validation accuracy after SBERT training: {val_acc:.4f}")

Validation accuracy after SBERT training: 0.4125


### Save sentence encoder

In [46]:
# Save the trained SBERT sentence encoder weights for later use (Task 3 and Task 4).
# We save the full SBERT classifier, but you can also save only the encoder if preferred.

os.makedirs("../model", exist_ok=True)
sbert_ckpt_path = "../model/sbert_sentence_encoder.pt"
torch.save(sbert_model.state_dict(), sbert_ckpt_path)
sbert_ckpt_path

'../model/sbert_sentence_encoder.pt'

## Evaluation and Analysis

In this section, we evaluate the Sentence-BERT model from Task 2 on the NLI task 
and produce a classification report (precision, recall, F1, support) on a held-out 
test set. We then discuss limitations and potential improvements [1].

In [None]:
# Quick check of the unique labels present in the test set.
# No forward pass is needed here; we only read the label tensors.
# `np` is already imported at the top of the notebook.
all_labels = []
for batch in test_loader:
    all_labels.extend(batch["labels"].tolist())

print("Unique labels in test set:", np.unique(all_labels))

Unique labels in test set: [-1  0  1  2]


In [49]:
# Import the sklearn metrics that we will use for the detailed classification report.
from sklearn.metrics import classification_report

# Evaluate the SBERT classifier on the test set and collect all predictions and labels.
# We then compute a classification report only over the valid NLI labels (0, 1, 2).
def get_classification_report(model, data_loader):
    # Switch the model to evaluation mode (this disables dropout; gradients are handled by no_grad).
    model.eval()
    all_preds = []
    all_labels = []

    # Disable gradient tracking since we are only doing inference.
    with torch.no_grad():
        # Iterate over all batches in the test DataLoader.
        for batch in data_loader:
            # Move the batch tensors to the active device (CPU/GPU).
            input_ids_a = batch["input_ids_a"].to(device)
            input_ids_b = batch["input_ids_b"].to(device)
            labels = batch["labels"].to(device)

            # Forward pass through the SBERT classifier to obtain logits.
            logits = model(input_ids_a, input_ids_b)

            # Take the argmax over the class dimension to get predicted labels.
            preds = logits.argmax(dim=-1)

            # Store predictions and true labels on CPU for later metrics calculation.
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Convert to numpy arrays for easier masking and inspection.
    all_preds_arr = np.array(all_preds)
    all_labels_arr = np.array(all_labels)

    # Keep only examples with valid labels {0, 1, 2} to match the three NLI classes.
    valid_mask = np.isin(all_labels_arr, [0, 1, 2])
    all_preds_arr = all_preds_arr[valid_mask]
    all_labels_arr = all_labels_arr[valid_mask]

    # Define human-readable names for the three NLI classes.
    target_names = ["entailment", "neutral", "contradiction"]
    labels = [0, 1, 2]

    # Compute the classification report string using sklearn, specifying labels explicitly.
    report_str = classification_report(
        all_labels_arr,
        all_preds_arr,
        labels=labels,
        target_names=target_names,
        digits=2
    )
    return report_str

In [50]:
test_report = get_classification_report(sbert_model, test_loader)
print(test_report)

               precision    recall  f1-score   support

   entailment       0.44      0.79      0.57       365
      neutral       0.46      0.39      0.42       309
contradiction       0.42      0.11      0.17       329

     accuracy                           0.44      1003
    macro avg       0.44      0.43      0.39      1003
 weighted avg       0.44      0.44      0.39      1003



### Classification report on NLI test set

```text
               precision    recall  f1-score   support

   entailment       0.44      0.79      0.57       365
      neutral       0.46      0.39      0.42       309
contradiction       0.42      0.11      0.17       329

     accuracy                           0.44      1003
    macro avg       0.44      0.43      0.39      1003
 weighted avg       0.44      0.44      0.39      1003
```

- **Entailment** has relatively high recall (0.79), so the model finds most entailment pairs, but its precision is only 0.44: more than half of the pairs predicted as entailment belong to another class.  
- **Neutral** has precision 0.46 and recall 0.39, indicating that the model struggles to separate neutral pairs from the other two classes.  
- **Contradiction** has very low recall (0.11): the model misses most contradiction pairs. Its precision (0.42) is similar to the other classes, which indicates that the classifier rarely predicts this label at all.

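These results are only modestly above chance. The test set is roughly balanced, so always predicting the majority class (entailment, 365 of 1003 examples) would already give about 0.36 accuracy, and the training loss (from 1.10 to 1.07 over three epochs) stayed close to ln 3 ≈ 1.10, the cross-entropy of a uniform three-class prediction. This is expected for a small custom BERT (2 layers, 256 hidden size) pretrained on only ~100k sentences and fine-tuned on a reduced subset of SNLI+MNLI rather than the full datasets. 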
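
The claim that the classifier rarely predicts contradiction can be checked from the report itself. The sketch below is a rough back-of-the-envelope calculation, not part of the original notebook: it assumes true positives ≈ recall × support and predictions per class ≈ true positives / precision. Because the report is rounded to two decimals, the counts are approximate.

```python
import math

# Values copied from the classification report (rounded to 2 decimals).
report = {
    "entailment":    {"precision": 0.44, "recall": 0.79, "support": 365},
    "neutral":       {"precision": 0.46, "recall": 0.39, "support": 309},
    "contradiction": {"precision": 0.42, "recall": 0.11, "support": 329},
}

total = sum(m["support"] for m in report.values())
true_pos = {}
predicted = {}
for name, m in report.items():
    # Correctly classified examples of this class.
    true_pos[name] = m["recall"] * m["support"]
    # How often the model predicted this class, right or wrong.
    predicted[name] = true_pos[name] / m["precision"]
    print(f"{name:>13}: ~{true_pos[name]:.0f} correct, ~{predicted[name]:.0f} predicted")

# Reference points for interpreting the accuracy and the training loss.
approx_accuracy = sum(true_pos.values()) / total
majority_baseline = max(m["support"] for m in report.values()) / total
uniform_loss = math.log(3)  # cross-entropy of a uniform 3-class prediction

print(f"approx. accuracy:  {approx_accuracy:.2f}")
print(f"majority baseline: {majority_baseline:.2f}")
print(f"uniform CE loss:   {uniform_loss:.4f}")
```

Under these approximations, roughly two thirds of all predictions are entailment and fewer than 100 of the 1003 are contradiction, which is consistent with the low contradiction recall.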

## Datasets and preprocessing

- Pretraining corpus (Task 1):  
  We sampled 100k sentences from the BookCorpus dataset via the HuggingFace Hub 
  (`bookcorpus`, `plain_text`, `trust_remote_code=True`). Sentences were lowercased, 
  cleaned with a simple regex, and tokenized by whitespace into a custom word-level 
  vocabulary [1][3].

- NLI datasets (Task 2 and 3):  
  We combined the SNLI dataset and the MNLI dataset from the GLUE benchmark using 
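  the HuggingFace `datasets` library. For SNLI, examples with label `-1` were 
  removed because no gold label was available. For MNLI, the `idx` column was 
  removed and only the text and label fields were kept. The sampled test split 
  still contains some `-1` labels, so these are masked out before the 
  classification report is computed [1][3].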

- Train/validation/test splits:  
  From the concatenated SNLI + MNLI DatasetDict, we sampled a subset for training, 
  validation, and testing (e.g., 10k train, 2k validation, 2k test) using 
  `shuffle(seed=55).select(range(N))` for efficiency in this assignment setting [3].

## Hyperparameters

- BERT pretraining (Task 1):  
  - Number of encoder layers: 2  
  - Hidden size (`d_model`): 256  
  - Number of attention heads: 4  
  - Feed-forward size (`d_ff`): 4 × 256 = 1024  
  - Maximum sequence length: 128  
  - Batch size: (e.g., 16)  
  - Learning rate: 1e-4 using Adam  
  - Pretraining steps per epoch: 1000 batches of randomly sampled sentence pairs  
  - Pretraining objective: combined masked language modeling (MLM) and 
    next sentence prediction (NSP).

- Sentence-BERT fine-tuning (Task 2):  
  - Base encoder: BERT model from Task 1 (weights loaded from 
    `bert_pretrained_from_scratch.pt`)  
  - Batch size: 32  
  - Optimizer: Adam with learning rate 2e-5  
  - Number of epochs: 3  
  - Sentence pooling: mean pooling over non-PAD tokens after the encoder  
  - Classification head: single linear layer on top of the concatenation 
    `[u; v; |u − v|]` with output dimension 3 (entailment, neutral, contradiction) [1][3].
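
The classification head described in the last bullet can be sketched as follows. The class name `SoftmaxHead` and the random inputs are illustrative assumptions; the notebook's own `sbert_model` may be organised differently. With `d_model = 256`, the concatenated feature vector has 3 × 256 = 768 dimensions.

```python
import torch
import torch.nn as nn


class SoftmaxHead(nn.Module):  # hypothetical name, for illustration only
    """Linear classifier over the SBERT feature vector [u; v; |u - v|]."""

    def __init__(self, d_model=256, num_labels=3):
        super().__init__()
        self.classifier = nn.Linear(3 * d_model, num_labels)

    def forward(self, u, v):
        # u, v: (batch, d_model) sentence embeddings for premise and hypothesis
        features = torch.cat([u, v, torch.abs(u - v)], dim=-1)
        return self.classifier(features)  # (batch, num_labels) logits


torch.manual_seed(0)
head = SoftmaxHead()
u = torch.randn(32, 256)  # stand-in premise embeddings
v = torch.randn(32, 256)  # stand-in hypothesis embeddings
labels = torch.randint(0, 3, (32,))

logits = head(u, v)
loss = nn.CrossEntropyLoss()(logits, labels)
print(logits.shape, loss.item())
```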

## Modifications to the original models

- BERT architecture:  
  The original classroom BERT implementation (from `BERT.ipynb`) was adapted to a 
  smaller configuration (2 layers, 256-dimensional hidden states, 4 attention heads) 
  to reduce computational cost while preserving the key architectural components 
  (token/position/segment embeddings, multi-head self-attention, feed-forward layers, 
  masked LM head, NSP head) [2][5].

- Additional encoder interface:  
  We added an `encode(input_ids, segment_ids)` method to the BERT class that returns 
  the final encoder hidden states without computing MLM or NSP logits, enabling reuse 
  of the encoder as a sentence encoder in Sentence-BERT [2][5].

- Sentence-BERT wrapper:  
  We wrapped the BERT encoder in a `SentenceEncoder` module that adds `[CLS]` and 
  `[SEP]` tokens, runs the encoder, and applies mean pooling over non-PAD positions 
  to produce fixed-size sentence embeddings, following the design of Sentence-BERT [3][5].
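
Mean pooling over non-PAD positions can be sketched as below. The function name `mean_pool` and the assumption that the PAD token has id 0 are illustrative; the notebook's `SentenceEncoder` may use a different id or interface.

```python
import torch


def mean_pool(hidden_states, input_ids, pad_id=0):
    """Average encoder outputs over non-PAD tokens.

    hidden_states: (batch, seq_len, d_model)
    input_ids:     (batch, seq_len)
    pad_id:        assumed PAD token id (hypothetical default)
    """
    # 1.0 for real tokens, 0.0 for padding; shape (batch, seq_len, 1)
    mask = (input_ids != pad_id).unsqueeze(-1).to(hidden_states.dtype)
    summed = (hidden_states * mask).sum(dim=1)
    # Clamp so an all-PAD row does not cause division by zero.
    counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / counts


input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 0, 0, 0]])
hidden = torch.ones(2, 5, 4)
hidden[0, 0] = 3.0        # first real token of sentence 0
hidden[:, 3:, :] = 100.0  # large values at PAD positions, which must be ignored

pooled = mean_pool(hidden, input_ids)
print(pooled)
```

The PAD positions carry deliberately large values here, so any leak of padding into the average would be immediately visible in the pooled output.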

## Limitations and potential improvements

- Limited pretraining data and capacity:  
  The BERT model was pretrained on only a 100k-sentence subset of BookCorpus with 
  a small 2-layer encoder. By comparison, the original BERT-base has 12 layers and 
  was pretrained on roughly 3.3B words (BooksCorpus with 800M words plus English 
  Wikipedia with 2,500M words). This gap limits the quality of the learned 
  representations and likely caps the final NLI accuracy [1][5].

- Simple tokenization:  
  We used a custom whitespace-based word-level tokenizer rather than a subword 
  tokenizer (WordPiece/BPE), which leads to a large vocabulary, no sharing of rare 
  subword units, and more out-of-vocabulary issues compared to standard BERT tokenizers [4][5].

- Training objective and data size:  
  For NLI, we trained only the simple `[u; v; |u − v|]` classifier with cross-entropy 
  loss on a relatively small subset of SNLI and MNLI. Scaling up the number of training 
  examples, adding learning-rate schedules or regularization, and exploring alternative 
  objectives (e.g., contrastive losses from SimCSE) could further improve sentence 
  embeddings and NLI performance [1][3][5].

- Future improvements:  
  Potential extensions include using a subword tokenizer from HuggingFace, increasing 
  the depth and width of the encoder, using the full SNLI/MNLI datasets, and adding 
  validation-based early stopping or more advanced optimization strategies.
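
As one example of these extensions, validation-based early stopping can be sketched with a small helper. The `EarlyStopping` class, its `patience` and `min_delta` parameters, and the accuracy values are hypothetical; in the notebook, the metric would come from `evaluate_sbert(sbert_model, val_loader)` after each epoch.

```python
class EarlyStopping:
    """Stop training when the validation metric stops improving."""

    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience    # epochs to wait without improvement
        self.min_delta = min_delta  # minimum gain that counts as improvement
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_metric):
        """Record one epoch's metric; return True if training should stop."""
        if val_metric > self.best + self.min_delta:
            self.best = val_metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
val_history = [0.40, 0.41, 0.41, 0.40, 0.42]  # made-up validation accuracies

stopped_at = None
for epoch, acc in enumerate(val_history, start=1):
    if stopper.step(acc):
        stopped_at = epoch
        break

print(f"stopped at epoch {stopped_at}, best accuracy {stopper.best:.2f}")
```

In practice the model weights would also be checkpointed whenever `best` improves, so the best epoch can be restored after stopping.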