# NLP Selection Camp - Inter IIT
# Neel B. Rambhia (22B1298)

## Install and import necessary libraries

In [None]:
!pip install -q torch torchvision torchaudio
!pip install -q transformers datasets tqdm matplotlib nltk

import os
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

  You can safely remove it manually.
  You can safely remove it manually.
  You can safely remove it manually.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gensim 4.3.0 requires FuzzyTM>=0.4.0, which is not installed.
opencv-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 1.26.4 which is incompatible.
qiskit-transpiler-service 0.4.3 requires qiskit~=1.0, but you have qiskit 2.1.1 which is incompatible.
streamlit 1.30.0 requires packaging<24,>=16.8, but you have packaging 25.0 which is incompatible.
streamlit 1.30.0 requires pillow<11,>=7.1.0, but you have pillow 11.3.0 which is incompatible.
streamlit 1.30.0 requires protobuf<5,>=3.20, but you have protobuf 6.32.1 which is incompatible.
streamlit 1.30.0 requires rich<14,>=10.14.0, but you have rich 14.2.0 which is incompatible.


In [1]:
import os, random, math, time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm.auto import tqdm
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')
from nltk.tokenize import sent_tokenize
import matplotlib.pyplot as plt
# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


Device: cuda


In [2]:
def set_seed(seed=42):
  random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
  if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
set_seed(42)

## Importing data from Wikitext Dataset

In [6]:
raw = load_dataset("wikitext", "wikitext-2-v1")
train_raw = raw["train"]["text"]
valid_raw = raw["validation"]["text"]

def lines_to_sentences(lines):
  sents = []
  for line in lines:
      if not line or line.strip()=="":
        continue
      for s in sent_tokenize(line):
        s = s.strip()
        sents.append(s)
  return sents

train_sents = lines_to_sentences(train_raw)
valid_sents = lines_to_sentences(valid_raw)
print("Train sentences:", len(train_sents), "\n Valid sentences:", len(valid_sents))

Train sentences: 87073 
 Valid sentences: 9041


Example sentences from Dataset

In [7]:
train_sents[:8]

['= Valkyria Chronicles III =',
 'Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit .',
 'Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable .',
 'Released in January 2011 in Japan , it is the third game in the Valkyria series .',
 '<unk> the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven " .',
 'The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II .',
 'While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more <unk> for series

## Importing tokenizer

In [8]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
max_seq_len = 128
print("Vocab size:", tokenizer.vocab_size)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Vocab size: 30522


## Tokenizing the dataset and building pairs of positive and negative samples for NSP

In [9]:
class WikiSentencePairDataset(Dataset):
  def __init__(self, sents, tokenizer, max_seq_len=128, mask_prob=0.15):
    self.tokenizer = tokenizer
    self.max_seq_len = max_seq_len
    self.mask_prob = mask_prob

    # tokenize sentences (no special tokens)
    self.tokens = [tokenizer.encode(s, add_special_tokens=False) for s in sents]
    self.vocab_size = tokenizer.vocab_size

    # build pairs: positive and a negative sample
    self.pairs = []
    n = len(self.tokens)
    for i in range(n-1):
      self.pairs.append((i, i+1, 1))
      j = random.randrange(0, n)
      while j == i+1:
        j = random.randrange(0, n)
      self.pairs.append((i, j, 0))

  def __len__(self):
    return len(self.pairs)

  def mask_tokens(self, input_ids):
    labels = [-100]*len(input_ids)
    for i, tok in enumerate(input_ids):
      if tok in (self.tokenizer.cls_token_id, self.tokenizer.sep_token_id, self.tokenizer.pad_token_id):
          continue
      if random.random() < self.mask_prob:
          prob = random.random()
          labels[i] = input_ids[i]
          if prob < 0.8:
              input_ids[i] = self.tokenizer.mask_token_id
          elif prob < 0.9:
              input_ids[i] = random.randrange(0, self.vocab_size)
    return input_ids, labels

  def __getitem__(self, idx):
    a_idx, b_idx, is_next = self.pairs[idx]
    a = self.tokens[a_idx].copy()
    b = self.tokens[b_idx].copy()
    max_tokens = self.max_seq_len - 3
    while len(a) + len(b) > max_tokens:
        if len(a) > len(b): a.pop()
        else: b.pop()
    input_ids = [self.tokenizer.cls_token_id] + a + [self.tokenizer.sep_token_id] + b + [self.tokenizer.sep_token_id]
    token_type_ids = [0]*(1+len(a)+1) + [1]*(len(b)+1)
    orig_input_ids = input_ids.copy()
    input_ids_masked, mlm_labels = self.mask_tokens(input_ids)
    return {
        "input_ids": torch.tensor(input_ids_masked, dtype=torch.long),
        "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
        "mlm_labels": torch.tensor(mlm_labels, dtype=torch.long),
        "is_next": torch.tensor(is_next, dtype=torch.long),
        "orig_input_ids": orig_input_ids,
        "length": len(input_ids_masked)
    }

In [10]:
def collate_fn(batch, pad_id=0):
    max_len = max(x["length"] for x in batch)
    input_ids = []
    token_type_ids = []
    mlm_labels = []
    is_next = []
    attention_mask = []
    orig_input_ids = []
    for x in batch:
        pad = max_len - x["length"]
        input_ids.append(F.pad(x["input_ids"], (0, pad), value=pad_id))
        token_type_ids.append(F.pad(x["token_type_ids"], (0, pad), value=0))
        mlm_labels.append(F.pad(x["mlm_labels"], (0, pad), value=-100))
        attention_mask.append(torch.cat([torch.ones(x["length"], dtype=torch.long), torch.zeros(pad, dtype=torch.long)]))
        is_next.append(x["is_next"])
        orig_input_ids.append(x["orig_input_ids"] + [pad_id]*pad)
    return {
        "input_ids": torch.stack(input_ids),
        "token_type_ids": torch.stack(token_type_ids),
        "attention_mask": torch.stack(attention_mask),
        "mlm_labels": torch.stack(mlm_labels),
        "is_next": torch.stack(is_next),
        "orig_input_ids": orig_input_ids
    }

# Example:
train_ds = WikiSentencePairDataset(train_sents, tokenizer, max_seq_len=max_seq_len)
valid_ds = WikiSentencePairDataset(valid_sents, tokenizer, max_seq_len=max_seq_len)

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, collate_fn=lambda b: collate_fn(b, tokenizer.pad_token_id))
valid_loader = DataLoader(valid_ds, batch_size=8, shuffle=False, collate_fn=lambda b: collate_fn(b, tokenizer.pad_token_id))

## Building the encoder only Multi layer transformer with self attention

In [11]:
class TransformerBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, ff_size, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout)
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ff = nn.Sequential(nn.Linear(hidden_size, ff_size), nn.GELU(), nn.Linear(ff_size, hidden_size))
        self.ln2 = nn.LayerNorm(hidden_size)
        self.drop = nn.Dropout(dropout)
    def forward(self, x, key_padding_mask=None):
        attn_out, _ = self.attn(x, x, x, key_padding_mask=key_padding_mask)
        x = x + self.drop(attn_out)
        x = self.ln1(x)
        ff_out = self.ff(x)
        x = x + self.drop(ff_out)
        x = self.ln2(x)
        return x

class MiniBert(nn.Module):
    def __init__(self, vocab_size, max_pos, type_vocab_size, hidden_size=256, num_layers=4, num_heads=4, ff_size=1024, dropout=0.1):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_pos, hidden_size)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
        self.layernorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([TransformerBlock(hidden_size, num_heads, ff_size, dropout) for _ in range(num_layers)])

        # MLM head
        self.mlm_dense = nn.Linear(hidden_size, hidden_size)
        self.mlm_act = nn.GELU()
        self.mlm_layernorm = nn.LayerNorm(hidden_size)

        # decoder/projection for MLM
        self.mlm_decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.mlm_bias = nn.Parameter(torch.zeros(vocab_size))

        # NSP head
        self.nsp_classifier = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh(), nn.Linear(hidden_size, 2))
        self.mlm_decoder.weight = self.token_embeddings.weight

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        b, seq = input_ids.shape
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        if attention_mask is None:
            attention_mask = (input_ids != tokenizer.pad_token_id).long()
        pos_ids = torch.arange(seq, device=input_ids.device).unsqueeze(0).expand(b, -1)
        x = self.token_embeddings(input_ids) + self.position_embeddings(pos_ids) + self.token_type_embeddings(token_type_ids)
        x = self.layernorm(x)
        x = self.dropout(x)
        key_padding_mask = (attention_mask == 0)  # True for pad tokens
        for blk in self.blocks:
            x = blk(x, key_padding_mask=key_padding_mask)
        # MLM
        mlm_h = self.mlm_dense(x)
        mlm_h = self.mlm_act(mlm_h)
        mlm_h = self.mlm_layernorm(mlm_h)
        logits = self.mlm_decoder(mlm_h) + self.mlm_bias
        cls = x[:, 0, :]
        nsp_logits = self.nsp_classifier(cls)
        return logits, nsp_logits

## Training:

In [12]:
model = MiniBert(vocab_size=tokenizer.vocab_size, max_pos=max_seq_len, type_vocab_size=2,
                 hidden_size=256, num_layers=4, num_heads=4, ff_size=1024, dropout=0.1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
nsp_loss_fct = nn.CrossEntropyLoss()

def train_epoch(model, loader, optimizer):
    model.train()
    running = []
    for batch in tqdm(loader, desc="train"):
        input_ids = batch["input_ids"].to(device)
        token_type_ids = batch["token_type_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        mlm_labels = batch["mlm_labels"].to(device)
        is_next = batch["is_next"].to(device)

        logits, nsp_logits = model(input_ids, token_type_ids, attention_mask)
        B, S, V = logits.shape
        mlm_loss = mlm_loss_fct(logits.view(-1, V), mlm_labels.view(-1))
        nsp_loss = nsp_loss_fct(nsp_logits.view(-1, 2), is_next.view(-1))
        loss = mlm_loss + nsp_loss

        optimizer.zero_grad(); loss.backward(); optimizer.step()
        running.append(loss.item())
    return sum(running)/len(running)

def eval_model(model, loader):
    model.eval()
    total = 0; correct = 0; mlm_losses=[]
    with torch.no_grad():
        for batch in tqdm(loader, desc="eval"):
            input_ids = batch["input_ids"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            mlm_labels = batch["mlm_labels"].to(device)
            is_next = batch["is_next"].to(device)

            logits, nsp_logits = model(input_ids, token_type_ids, attention_mask)
            preds = torch.argmax(nsp_logits, dim=-1)
            correct += (preds == is_next).sum().item()
            total += is_next.numel()
            B,S,V = logits.shape
            loss = mlm_loss_fct(logits.view(-1, V), mlm_labels.view(-1))
            mlm_losses.append(loss.item())
    return correct/total, sum(mlm_losses)/len(mlm_losses)

# Training loop
epochs = 3 #I have kept 3 epochs due to lack of computational resources
for epoch in range(1, epochs+1):
    train_loss = train_epoch(model, train_loader, optimizer)
    nsp_acc, mlm_eval_loss = eval_model(model, valid_loader)
    print(f"Epoch {epoch} train_loss {train_loss:.4f} | val NSP acc {nsp_acc:.4f} | val MLM loss {mlm_eval_loss:.4f}")
    torch.save({"model_state": model.state_dict(), "optimizer": optimizer.state_dict()}, f"mini_bert_epoch{epoch}.pt")


train:   0%|          | 0/21768 [00:00<?, ?it/s]

eval:   0%|          | 0/2260 [00:00<?, ?it/s]

Epoch 1 train_loss 10.5647 | val NSP acc 0.5106 | val MLM loss 5.4686


train:   0%|          | 0/21768 [00:00<?, ?it/s]

eval:   0%|          | 0/2260 [00:00<?, ?it/s]

Epoch 2 train_loss 6.5666 | val NSP acc 0.5181 | val MLM loss 4.7738


train:   0%|          | 0/21768 [00:00<?, ?it/s]

eval:   0%|          | 0/2260 [00:00<?, ?it/s]

Epoch 3 train_loss 6.1392 | val NSP acc 0.5418 | val MLM loss 4.5647


## Testing:

In [16]:
# Top-k predictions for masked positions from one batch of validation
model.eval()
batch = next(iter(valid_loader))
input_ids = batch["input_ids"].to(device)
token_type_ids = batch["token_type_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
orig_input_ids = batch["orig_input_ids"]

with torch.no_grad():
    logits, nsp_logits = model(input_ids, token_type_ids, attention_mask)

probs = F.softmax(logits, dim=-1)

def func(tokens, masked_positions, topk_preds, labels, k=3):
    print("------------------------------------------------------------")
    print("INPUT:")
    print(" ".join(tokens))
    print()

    for pos, preds, lab in zip(masked_positions, topk_preds, labels):
        tok = tokens[pos]
        print(f"Masked index {pos}:")
        print(f"  original token: {lab}")
        print(f"  top-{k}: {preds}")
        print(f"  raw token (wordpiece): {tok}")
        print()

    recon = []
    for t in tokens:
        if t.startswith("##"):
            recon[-1] = recon[-1] + t[2:]
        else:
            recon.append(t)
    print("RECONSTRUCTED SENTENCE:")
    print(" ".join(recon))
    print("------------------------------------------------------------\n")


for i in range(min(4, input_ids.size(0))):
    labels = batch["mlm_labels"][i]
    mask_positions = (labels != -100).nonzero(as_tuple=True)[0].tolist()

    tokens = tokenizer.convert_ids_to_tokens(orig_input_ids[i])
    topk_preds = []
    true_lbls = []
    for pos in mask_positions:
        topk = torch.topk(probs[i, pos], k=5).indices.tolist()
        topk_tokens = tokenizer.convert_ids_to_tokens(topk)
        lbl = tokenizer.convert_ids_to_tokens([labels[pos].item()])[0]
        topk_preds.append(topk_tokens)
        true_lbls.append(lbl)

    print(f"\n Example {i}")
    func(tokens, mask_positions, topk_preds, true_lbls, k=3)
    pred = torch.argmax(nsp_logits[i]).item()
    truth = batch["is_next"][i].item()
    print(f"NSP prediction: {'IS_NEXT' if pred==1 else 'NOT_NEXT'} | True: {'IS_NEXT' if truth==1 else 'NOT_NEXT'}")

from sklearn.metrics import confusion_matrix
nsp_preds = torch.argmax(nsp_logits, dim=-1).cpu().numpy()
cm = confusion_matrix(batch["is_next"].numpy(), nsp_preds)
print("\nGlobal tiny-batch NSP Confusion (rows true, columns predicted):\n", cm)



 Example 0
------------------------------------------------------------
INPUT:
[CLS] = ho ##mar ##us gamma ##rus = [SEP] ho ##mar ##us gamma ##rus , known as the european lobster or common lobster , is a species of < un ##k > lobster from the eastern atlantic ocean , mediterranean sea and parts of the black sea . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

Masked index 2:
  original token: ho
  top-3: ['=', 'the', 'and', 'is', 'an']
  raw token (wordpiece): ho

Masked index 6:
  original token: ##rus
  top-3: ['=', '.', ')', ':', 'and']
  raw token (wordpiece): ##rus

Masked index 17:
  original token: the
  top-3: ['a', 'the', 'an', 'his', '"']
  raw token (wordpiece): the

Masked index 31:
  original token: >
  top-3: ['>', ',', 'christian', 'seems', 'he']
  raw token (wordpiece): >

Masked index 40:
  original token: sea
  top-3: ['sea', ',', 'production', 'season', 'believed']
  ra

## Saving the model for future use.
This can be loaded durectly for testing/ use.

In [14]:
torch.save(model.state_dict(), "mini_bert_weights.pt")
# load:
model2 = MiniBert(
    vocab_size=tokenizer.vocab_size,
    max_pos=max_seq_len,
    type_vocab_size=2,
    hidden_size=256,
    num_layers=4,
    num_heads=4,
    ff_size=1024,
    dropout=0.1
).to(device)

model2.load_state_dict(torch.load("mini_bert_weights.pt"))
model2.eval()

MiniBert(
  (token_embeddings): Embedding(30522, 256)
  (position_embeddings): Embedding(128, 256)
  (token_type_embeddings): Embedding(2, 256)
  (layernorm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-3): 4 x TransformerBlock(
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1024, out_features=256, bias=True)
      )
      (ln2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (drop): Dropout(p=0.1, inplace=False)
    )
  )
  (mlm_dense): Linear(in_features=256, out_features=256, bias=True)
  (mlm_act): GELU(approximate='none')
  (mlm_layernorm): LayerNorm((256,), eps=1e-05, e