In [2]:
# Cell 1: Install compatible PyTorch + Transformers + Datasets for Colab (CUDA 12)

!pip install --quiet torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 \
    transformers==4.45.0 datasets==3.0.1 tqdm --upgrade

# Check everything is fine
import torch
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.9/9.9 MB[0m [31m82.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m471.6/471.6 kB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m177.6/177.6 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m65.4 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.6.1 which is incompatible.[0m[31m
[0mPyTorch version: 2.8.0+cu126
CUDA available: True


In [3]:
# Cell 2: imports + device + deterministic seed
import os
import random
import math
from typing import List, Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import BertTokenizerFast
from datasets import load_dataset

from tqdm.auto import tqdm

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# reproducibility
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)


Device: cuda


In [4]:
import torch
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")


PyTorch version: 2.8.0+cu126
CUDA available: True
CUDA device name: Tesla T4


In [5]:
# Cell 3: Mini-Plus (≈30 min training)
config = {
    "pretrained_tokenizer": "bert-base-uncased",
    "max_len": 128,
    "batch_size": 12,
    "num_epochs": 5,
    "lr": 1.5e-4,
    "hidden_size": 384,
    "num_heads": 6,
    "ffn_dim": 1536,
    "num_layers": 4,
    "dropout": 0.1,
    "mask_prob": 0.15,
    "print_every_n_steps": 300
}
print("Config:", config)


Config: {'pretrained_tokenizer': 'bert-base-uncased', 'max_len': 128, 'batch_size': 12, 'num_epochs': 5, 'lr': 0.00015, 'hidden_size': 384, 'num_heads': 6, 'ffn_dim': 1536, 'num_layers': 4, 'dropout': 0.1, 'mask_prob': 0.15, 'print_every_n_steps': 300}


In [6]:
# Cell 4: load wikitext-2 and extract sentences
dataset = load_dataset("wikitext", "wikitext-2-v1")

# quick function to split text into sentences (simple)
def split_paragraph_to_sentences(paragraph: str) -> List[str]:
    # keep simple: split on newline and periods; filter tiny sentences
    if paragraph is None:
        return []
    parts = []
    for line in paragraph.split("\n"):
        line = line.strip()
        if not line:
            continue
        # split by ". " but preserve abbreviations poorly - good enough for this exercise
        for s in line.split(". "):
            s = s.strip()
            if len(s) > 3:
                # ensure punctuation ends with period for clarity
                if not s.endswith("."):
                    s = s + "."
                parts.append(s)
    return parts

# build a list of sentences from train split
sentences = []
for item in tqdm(dataset["train"], desc="Extracting sentences"):
    txt = item["text"]
    sents = split_paragraph_to_sentences(txt)
    sentences.extend(sents)

print("Total sentences extracted (train):", len(sentences))
# keep some cap so preprocessing and memory are reasonable
MAX_SENTENCES = 200000  # cap to avoid huge memory usage
if len(sentences) > MAX_SENTENCES:
    sentences = sentences[:MAX_SENTENCES]
    print("Capped sentences to:", len(sentences))


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

wikitext-2-v1/test-00000-of-00001.parque(…):   0%|          | 0.00/685k [00:00<?, ?B/s]

wikitext-2-v1/train-00000-of-00001.parqu(…):   0%|          | 0.00/6.07M [00:00<?, ?B/s]

wikitext-2-v1/validation-00000-of-00001.(…):   0%|          | 0.00/618k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Extracting sentences:   0%|          | 0/36718 [00:00<?, ?it/s]

Total sentences extracted (train): 85764


In [7]:
# Cell 5: tokenizer and utilities
tokenizer = BertTokenizerFast.from_pretrained(config["pretrained_tokenizer"])
CLS_ID = tokenizer.cls_token_id
SEP_ID = tokenizer.sep_token_id
MASK_ID = tokenizer.mask_token_id
VOCAB_SIZE = tokenizer.vocab_size

print("Tokenizer vocab size:", VOCAB_SIZE)
print("[CLS],[SEP],[MASK] ids:", CLS_ID, SEP_ID, MASK_ID)

# Create positive and negative sentence pairs
# For each consecutive sentence pair (i, i+1) -> label 1 (is_next)
# Create negative by pairing sentence i with random sentence from corpus -> label 0
def build_sentence_pairs(sent_list: List[str]) -> List[Tuple[str, str, int]]:
    pairs = []
    n = len(sent_list)
    for i in range(n - 1):
        a = sent_list[i]
        b = sent_list[i + 1]
        pairs.append((a, b, 1))  # positive
        # negative sample: choose a random sentence not equal to b
        rand_idx = random.randint(0, n - 1)
        # ensure not same as true next
        if rand_idx == i + 1:
            rand_idx = (rand_idx + 7) % n
        pairs.append((a, sent_list[rand_idx], 0))
    return pairs

pairs = build_sentence_pairs(sentences)
print("Total sentence pairs:", len(pairs))
pairs = pairs[:80000]
print("Using pairs:", len(pairs))



tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Tokenizer vocab size: 30522
[CLS],[SEP],[MASK] ids: 101 102 103
Total sentence pairs: 171526
Using pairs: 80000




In [8]:
# Cell 6: PyTorch Dataset creating input_ids, token_type_ids, attention_mask, mlm_labels, nsp_label
class BertMiniDataset(Dataset):
    def __init__(self, pairs, tokenizer, max_len=128, mask_prob=0.15):
        self.pairs = pairs
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.mask_prob = mask_prob

    def __len__(self):
        return len(self.pairs)

    def _encode_pair(self, sent_a: str, sent_b: str):
        # tokenizer will add special tokens if requested; we will add manually to control token_type_ids
        # Use tokenizer.encode_plus for convenience
        enc = self.tokenizer.encode_plus(
            sent_a,
            sent_b,
            add_special_tokens=True,
            truncation=True,
            max_length=self.max_len,
            padding="max_length",
            return_attention_mask=True,
            return_token_type_ids=True,
        )
        return enc

    def _mask_tokens(self, input_ids):
        """
        Implement BERT's masking strategy:
         - 15% tokens selected for possible masking
         - For selected tokens: 80% -> [MASK], 10% -> random token, 10% -> unchanged
        Return masked_input_ids, mlm_labels (with -100 for tokens not to predict)
        """
        input_ids = input_ids.clone()
        labels = torch.full(input_ids.shape, -100, dtype=torch.long)  # -100 ignored in loss
        # we don't want to mask special tokens (CLS, SEP, PAD)
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
            for val in input_ids.tolist()
        ]
        special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        for i in range(input_ids.size(0)):
            # choose maskable positions
            maskable = ~special_tokens_mask[i]
            probs = torch.rand_like(input_ids[i].float())
            mask_pos = (probs < self.mask_prob) & maskable
            # ensure at least one token masked? not necessary
            labels[i][mask_pos] = input_ids[i][mask_pos]

            # now apply replacement rules
            for idx in torch.nonzero(mask_pos, as_tuple=False).view(-1).tolist():
                rand = random.random()
                if rand < 0.8:
                    # 80% replace with [MASK]
                    input_ids[i, idx] = MASK_ID
                elif rand < 0.9:
                    # 10% replace with random token (not special)
                    input_ids[i, idx] = random.randint(0, VOCAB_SIZE - 1)
                else:
                    # 10% keep original
                    pass
        return input_ids, labels

    def __getitem__(self, idx):
        a, b, nsp_label = self.pairs[idx]
        enc = self._encode_pair(a, b)
        input_ids = torch.tensor(enc["input_ids"], dtype=torch.long)
        token_type_ids = torch.tensor(enc["token_type_ids"], dtype=torch.long)
        attention_mask = torch.tensor(enc["attention_mask"], dtype=torch.long)

        # make mlm labels and masked input_ids
        masked_input_ids, mlm_labels = self._mask_tokens(input_ids.unsqueeze(0))
        masked_input_ids = masked_input_ids.squeeze(0)
        mlm_labels = mlm_labels.squeeze(0)

        return {
            "input_ids": masked_input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask,
            "mlm_labels": mlm_labels,
            "nsp_label": torch.tensor(nsp_label, dtype=torch.long),
            # also keep original (for pretty printing) - not used in training
            "orig_input_ids": torch.tensor(enc["input_ids"], dtype=torch.long)
        }

# create dataset & dataloader
full_dataset = BertMiniDataset(pairs, tokenizer, max_len=config["max_len"], mask_prob=config["mask_prob"])
train_loader = DataLoader(full_dataset, batch_size=config["batch_size"], shuffle=True, drop_last=True)
print("Created DataLoader with batches:", len(train_loader))


Created DataLoader with batches: 6666


In [9]:
# Cell 7: attention, feed-forward, encoder layer
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(hidden_size, hidden_size * 3)
        self.out = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        # x: (B, T, hidden)
        B, T, H = x.size()
        qkv = self.qkv(x)  # (B, T, 3H)
        q, k, v = qkv.chunk(3, dim=-1)
        # reshape for multi-head: (B, num_heads, T, head_dim)
        def split_heads(tensor):
            return tensor.view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        q = split_heads(q)
        k = split_heads(k)
        v = split_heads(v)

        # scaled dot-product
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (B, nh, T, T)
        if attn_mask is not None:
            # attn_mask: (B, T) where 1 = keep, 0 = pad; convert to (B, 1, 1, T) additive mask
            mask = attn_mask.unsqueeze(1).unsqueeze(2)  # (B,1,1,T)
            attn_scores = attn_scores.masked_fill(mask == 0, -1e4)  # use smaller negative for fp16

        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_probs = self.dropout(attn_probs)

        context = torch.matmul(attn_probs, v)  # (B, nh, T, head_dim)
        context = context.permute(0, 2, 1, 3).contiguous().view(B, T, H)  # (B, T, H)
        out = self.out(context)
        return out


class FeedForward(nn.Module):
    def __init__(self, hidden_size, ffn_dim, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(hidden_size, ffn_dim)
        self.linear2 = nn.Linear(ffn_dim, hidden_size)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.linear1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.dropout(x)
        return x


class EncoderLayer(nn.Module):
    def __init__(self, hidden_size, num_heads, ffn_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadSelfAttention(hidden_size, num_heads, dropout)
        self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.ff = FeedForward(hidden_size, ffn_dim, dropout)
        self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        # Self-attention + residual
        attn_out = self.attn(x, attn_mask=attn_mask)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)

        # Feed-forward + residual
        ff_out = self.ff(x)
        x = x + self.dropout(ff_out)
        x = self.norm2(x)
        return x


In [10]:
# Cell 8: Mini-BERT model
class MiniBert(nn.Module):
    def __init__(self, vocab_size, hidden_size=256, num_heads=4, ffn_dim=512,
                 num_layers=4, max_len=128, type_vocab_size=2, dropout=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_len, hidden_size)
        self.segment_embeddings = nn.Embedding(type_vocab_size, hidden_size)
        self.layernorm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList([
            EncoderLayer(hidden_size, num_heads, ffn_dim, dropout)
            for _ in range(num_layers)
        ])

        # MLM head: project hidden states back to vocab logits
        # use an intermediate dense + activation as BERT does (tied weights optional)
        self.mlm_dense = nn.Linear(hidden_size, hidden_size)
        self.mlm_act = nn.GELU()
        self.mlm_layernorm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.mlm_decoder = nn.Linear(hidden_size, vocab_size, bias=True)

        # NSP head: simple classifier on [CLS] embedding
        self.nsp_classifier = nn.Linear(hidden_size, 2)

        # tie mlm_decoder weight with token_embeddings if wanted
        self.mlm_decoder.weight = self.token_embeddings.weight

        self.max_len = max_len

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        B, T = input_ids.size()
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        if attention_mask is None:
            attention_mask = (input_ids != tokenizer.pad_token_id).long()

        # Embeddings
        pos_ids = torch.arange(0, T, device=input_ids.device).unsqueeze(0).expand(B, T)
        x = self.token_embeddings(input_ids) + self.position_embeddings(pos_ids) + self.segment_embeddings(token_type_ids)
        x = self.layernorm(x)
        x = self.dropout(x)

        # encoder stack
        for layer in self.layers:
            x = layer(x, attn_mask=attention_mask)

        # MLM head (apply only to sequence output)
        mlm_hidden = self.mlm_dense(x)
        mlm_hidden = self.mlm_act(mlm_hidden)
        mlm_hidden = self.mlm_layernorm(mlm_hidden)
        mlm_logits = self.mlm_decoder(mlm_hidden)  # (B, T, vocab)

        # NSP head uses [CLS] token (index 0)
        cls_hidden = x[:, 0, :]  # (B, hidden)
        nsp_logits = self.nsp_classifier(cls_hidden)  # (B, 2)

        return mlm_logits, nsp_logits


In [11]:
# Cell 9: instantiate model and training components
model = MiniBert(
    vocab_size=VOCAB_SIZE,
    hidden_size=config["hidden_size"],
    num_heads=config["num_heads"],
    ffn_dim=config["ffn_dim"],
    num_layers=config["num_layers"],
    max_len=config["max_len"],
    dropout=config["dropout"]
).to(device)

# losses
mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)  # ignore positions with -100
nsp_loss_fct = nn.CrossEntropyLoss()

# optimizer & scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])
total_steps = len(train_loader) * config["num_epochs"]
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.0, total_iters=total_steps)

print(model)


MiniBert(
  (token_embeddings): Embedding(30522, 384)
  (position_embeddings): Embedding(128, 384)
  (segment_embeddings): Embedding(2, 384)
  (layernorm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0-3): 4 x EncoderLayer(
      (attn): MultiHeadSelfAttention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (out): Linear(in_features=384, out_features=384, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (ff): FeedForward(
        (linear1): Linear(in_features=384, out_features=1536, bias=True)
        (linear2): Linear(in_features=1536, out_features=384, bias=True)
        (act): GELU(approximate='none')
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )

In [12]:
# Cell 10: prepare a small validation pairs dataset and loader
# Build sentence list for validation split
val_sentences = []
for item in dataset["validation"]:
    txt = item["text"]
    val_sentences.extend(split_paragraph_to_sentences(txt))
val_sentences = [s for s in val_sentences if len(s) > 3]
print("Validation sentences:", len(val_sentences))

val_pairs = build_sentence_pairs(val_sentences)
val_pairs = val_pairs[:5000]  # small validation subset

val_dataset = BertMiniDataset(val_pairs, tokenizer, max_len=config["max_len"], mask_prob=config["mask_prob"])
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
print("Validation batches:", len(val_loader))


Validation sentences: 9080
Validation batches: 417


In [13]:
# Cell 11: full training loop with mixed precision + periodic validation
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
global_step = 0

def evaluate(model, loader):
    """Compute MLM and NSP accuracy on validation loader."""
    model.eval()
    total_nsp, correct_nsp = 0, 0
    total_mlm, correct_mlm = 0, 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            mlm_labels = batch["mlm_labels"].to(device)
            nsp_label = batch["nsp_label"].to(device)

            mlm_logits, nsp_logits = model(input_ids, token_type_ids, attention_mask)

            # NSP accuracy
            preds_nsp = torch.argmax(nsp_logits, dim=-1)
            correct_nsp += (preds_nsp == nsp_label).sum().item()
            total_nsp += nsp_label.size(0)

            # MLM accuracy (only masked positions)
            mlm_preds = torch.argmax(mlm_logits, dim=-1)
            mask_positions = mlm_labels != -100
            total_mlm += mask_positions.sum().item()
            correct_mlm += (mlm_preds[mask_positions] == mlm_labels[mask_positions]).sum().item()

    model.train()
    mlm_acc = correct_mlm / total_mlm if total_mlm else 0.0
    nsp_acc = correct_nsp / total_nsp if total_nsp else 0.0
    return mlm_acc, nsp_acc


for epoch in range(config["num_epochs"]):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']}")
    running_loss = 0.0

    for step, batch in enumerate(pbar):
        input_ids = batch["input_ids"].to(device)
        token_type_ids = batch["token_type_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        mlm_labels = batch["mlm_labels"].to(device)
        nsp_label = batch["nsp_label"].to(device)

        optimizer.zero_grad()

        # --- forward + loss under mixed precision ---
        with autocast():
            mlm_logits, nsp_logits = model(input_ids, token_type_ids, attention_mask)
            mlm_loss = mlm_loss_fct(mlm_logits.view(-1, VOCAB_SIZE), mlm_labels.view(-1))
            nsp_loss = nsp_loss_fct(nsp_logits, nsp_label)
            loss = mlm_loss + nsp_loss

        # --- backward & optimization ---
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        running_loss += loss.item()
        global_step += 1

        # --- periodic validation ---
        if global_step % config["print_every_n_steps"] == 0:
            mlm_acc, nsp_acc = evaluate(model, val_loader)
            pbar.set_postfix({
                "loss": f"{running_loss/(step+1):.4f}",
                "val_mlm_acc": f"{mlm_acc:.4f}",
                "val_nsp_acc": f"{nsp_acc:.4f}"
            })

    # --- end of epoch summary ---
    val_mlm_acc, val_nsp_acc = evaluate(model, val_loader)
    print(f"\nEpoch {epoch+1} finished.")
    print(f"Validation MLM acc: {val_mlm_acc:.4f}, NSP acc: {val_nsp_acc:.4f}\n")


  scaler = GradScaler()


Epoch 1/5:   0%|          | 0/6666 [00:00<?, ?it/s]

  with autocast():



Epoch 1 finished.
Validation MLM acc: 0.2335, NSP acc: 0.5318



Epoch 2/5:   0%|          | 0/6666 [00:00<?, ?it/s]


Epoch 2 finished.
Validation MLM acc: 0.3638, NSP acc: 0.5284



Epoch 3/5:   0%|          | 0/6666 [00:00<?, ?it/s]


Epoch 3 finished.
Validation MLM acc: 0.3974, NSP acc: 0.5528



Epoch 4/5:   0%|          | 0/6666 [00:00<?, ?it/s]


Epoch 4 finished.
Validation MLM acc: 0.4091, NSP acc: 0.5604



Epoch 5/5:   0%|          | 0/6666 [00:00<?, ?it/s]


Epoch 5 finished.
Validation MLM acc: 0.4219, NSP acc: 0.5642



In [16]:
# Cell 12: Show sample masked predictions (pretty print)
def show_predictions(model, dataset, n=5):
    model.eval()
    for i in range(n):
        item = dataset[i]
        input_ids = item["input_ids"].unsqueeze(0).to(device)
        orig_ids = item["orig_input_ids"].tolist()
        mlm_labels = item["mlm_labels"]
        token_type_ids = item["token_type_ids"].unsqueeze(0).to(device)
        attention_mask = item["attention_mask"].unsqueeze(0).to(device)

        with torch.no_grad():
            mlm_logits, nsp_logits = model(input_ids, token_type_ids, attention_mask)
        preds = torch.argmax(mlm_logits, dim=-1).squeeze(0).cpu().tolist()

        # decode tokens for display
        tokens_masked = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0).cpu().tolist())
        tokens_pred = tokenizer.convert_ids_to_tokens(preds)
        tokens_orig = tokenizer.convert_ids_to_tokens(orig_ids)

        # build strings highlighting masked positions
        display_lines = []
        for idx, (tk_masked, tk_pred, tk_orig, label_id) in enumerate(zip(tokens_masked, tokens_pred, tokens_orig, mlm_labels.tolist())):
            if label_id != -100:
                display_lines.append(f"pos {idx} | masked: {tk_masked:>12} | pred: {tk_pred:>12} | orig: {tk_orig:>12}")
        print("\nSample", i+1)
        print("\n".join(display_lines))
        # show the input text approx (decode full orig sequence)
        print("Original text (approx):", tokenizer.decode(orig_ids, skip_special_tokens=True))
        print("Masked input (approx):", tokenizer.decode(input_ids.squeeze(0).cpu().tolist(), skip_special_tokens=True))
        # NSP prediction
        nsp_pred = torch.argmax(nsp_logits, dim=-1).item()
        print("NSP label:", item["nsp_label"].item(), "NSP pred:", nsp_pred)
    model.train()

# Show 5 samples from validation dataset
show_predictions(model, val_dataset, n=5)



Sample 1
pos 10 | masked:       [MASK] | pred:          the | orig:           ho
pos 19 | masked:     european | pred:     european | orig:     european
pos 27 | masked:      species | pred:      species | orig:      species
Original text (approx): = homarus gammarus =. homarus gammarus, known as the european lobster or common lobster, is a species of < unk > lobster from the eastern atlantic ocean, mediterranean sea and parts of the black sea.
Masked input (approx): = homarus gammarus =.marus gammarus, known as the european lobster or common lobster, is a species of < unk > lobster from the eastern atlantic ocean, mediterranean sea and parts of the black sea.
NSP label: 1 NSP pred: 0

Sample 2
pos 4 | masked:       [MASK] | pred:          and | orig:         ##us
pos 5 | masked:        gamma | pred:            = | orig:        gamma
pos 10 | masked:       [MASK] | pred:          the | orig:          the
pos 14 | masked:       [MASK] | pred:          the | orig:     expected
pos 15 | 

In [17]:
# Cell 13: save & final evaluation
ckpt_path = "mini_bert_checkpoint.pt"
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "config": config
}, ckpt_path)
print("Saved checkpoint to", ckpt_path)

final_mlm, final_nsp = evaluate(model, val_loader)
print(f"Final Validation MLM acc: {final_mlm:.4f}, NSP acc: {final_nsp:.4f}")


Saved checkpoint to mini_bert_checkpoint.pt
Final Validation MLM acc: 0.4119, NSP acc: 0.5648


# README — Mini BERT (from-scratch) — Submission

## Overview
This notebook implements a mini BERT encoder (Transformer encoder stack) from scratch in PyTorch and trains it jointly on Masked Language Modeling (MLM) and Next Sentence Prediction (NSP) using WikiText-2 dataset.

## Files to submit
- `mini_bert_notebook.ipynb` (this Colab notebook)
- `mini_bert_checkpoint.pt` (optional small checkpoint)
- `README.md` (explain model hyperparameters and results)

## Model details
- Layers: 4 encoder layers
- Hidden size: 256
- Heads: 4
- FFN dim: 512
- Max seq len: 128

## Training
- Dataset: WikiText-2 (wikitext-2-v1)
- Objectives: MLM (15% mask) + NSP (binary)
- Tokenizer: `bert-base-uncased` (HuggingFace; tokenizer only)

## How to run
1. Open in Google Colab.
2. Make sure GPU runtime is selected.
3. Run cells top-to-bottom.

## Results (example)
- Final validation MLM accuracy and NSP accuracy printed at the end of the run.
- Sample masked predictions printed for qualitative check.

## Notes and improvements
- Increase `num_epochs`, `batch_size`, and `hidden_size` for better performance (requires more GPU memory/time).
- Could add more sophisticated sentence splitting and negative sampling strategies for NSP.
- Consider longer pretraining or using larger corpus for stronger representations.
