In [None]:
# bert_encoder_from_scratch_with_pooling_multitype.py
# Modified version that supports two input types:
# - 'C' = Chunk (uses MLM + NSP)
# - 'Q' = Query (uses MLM only)

import random
import math
import os
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from gensim.models import Word2Vec

# -------------------------
# Config
# -------------------------
VOCAB_MIN_FREQ = 1  # default min_freq for build_vocab: tokens seen fewer times are left out of the vocab
# Default max_seq_len of BertPretrainingDataset (total tokens incl. [CLS]/[SEP]).
# NOTE(review): BertEncoderModel defaults max_position_embeddings to 512, which is
# smaller than this value — confirm the model is constructed with a matching size.
MAX_SEQ_LEN = 1024
HIDDEN_SIZE = 768  # default hidden width of BertEncoderModel; also the embedding-matrix width in main()
NUM_LAYERS = 12  # default number of encoder layers
NUM_HEADS = 12  # default number of attention heads per layer
FFN_DIM = 3072  # default inner width of each layer's feed-forward block
DROPOUT = 0.1  # dropout probability
WORD2VEC_SIZE = HIDDEN_SIZE  # Word2Vec vectors share the model width so they can seed the token embeddings
WORD2VEC_WINDOW = 5  # default Word2Vec context window
WORD2VEC_MIN_COUNT = 1  # default Word2Vec min_count
MLM_MASK_PROB = 0.15  # default per-token probability of being selected as an MLM target
BATCH_SIZE = 8  # DataLoader batch size used in main()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer GPU when available
LEARNING_RATE = 2e-5  # AdamW learning rate used in main()

# -------------------------
# Special tokens
# -------------------------
PAD_TOKEN = "[PAD]"
CLS_TOKEN = "[CLS]"
SEP_TOKEN = "[SEP]"
MASK_TOKEN = "[MASK]"
UNK_TOKEN = "[UNK]"
# Order matters: build_vocab assigns ids 0..4 in this order, so [PAD] gets id 0
# (BertEncoderModel's default pad_token_id), and the MLM masking code treats
# every id below len(SPECIAL_TOKENS) as a special token that is never masked.
SPECIAL_TOKENS = [PAD_TOKEN, CLS_TOKEN, SEP_TOKEN, MASK_TOKEN, UNK_TOKEN]

# -------------------------
# Utility: Vocab builder
# -------------------------
def build_vocab(sentences: List[str], min_freq: int = VOCAB_MIN_FREQ):
    """Build token<->id lookups from whitespace-tokenized sentences.

    Special tokens take the first ids, in ``SPECIAL_TOKENS`` order. Corpus
    tokens follow in order of first appearance, provided they occur at least
    ``min_freq`` times and are not already a special token.

    Returns:
        (stoi, itos): a token -> id dict and the matching id -> token list.
    """
    from collections import Counter

    frequencies = Counter(word for line in sentences for word in line.strip().split())

    itos = list(SPECIAL_TOKENS)
    stoi = {token: position for position, token in enumerate(itos)}
    for word, count in frequencies.items():
        if count < min_freq or word in stoi:
            continue
        stoi[word] = len(itos)
        itos.append(word)
    return stoi, itos

# -------------------------
# Train or load Word2Vec
# -------------------------
def train_word2vec(sentences: List[str], vector_size=WORD2VEC_SIZE, window=WORD2VEC_WINDOW, min_count=WORD2VEC_MIN_COUNT, epochs=5):
    """Train a gensim Word2Vec model on whitespace-tokenized sentences.

    Each sentence is stripped and split on whitespace, the same tokenization
    ``build_vocab`` uses. The trained model is returned as-is.
    """
    token_lists = []
    for line in sentences:
        token_lists.append(line.strip().split())
    return Word2Vec(
        sentences=token_lists,
        vector_size=vector_size,
        window=window,
        min_count=min_count,
        epochs=epochs,
        sg=0,  # 0 selects CBOW rather than skip-gram
    )

def build_embedding_matrix(w2v: Word2Vec, itos: List[str], hidden_size: int):
    """Create a (len(itos), hidden_size) float32 embedding tensor seeded from Word2Vec.

    Rows start as N(0, 0.02) noise. Any token present in ``w2v.wv`` has its row
    overwritten with the Word2Vec vector, truncated or zero-padded on the right
    to ``hidden_size``. The ``PAD_TOKEN`` row is always zeroed.
    """
    table = np.random.normal(scale=0.02, size=(len(itos), hidden_size)).astype(np.float32)
    vectors = w2v.wv
    for row, token in enumerate(itos):
        if token not in vectors:
            continue  # keep the random initialization (e.g. special tokens)
        vec = vectors[token]
        width = vec.shape[0]
        if width > hidden_size:
            vec = vec[:hidden_size]
        elif width < hidden_size:
            vec = np.pad(vec, (0, hidden_size - width))
        table[row] = vec
    # Padding must contribute nothing, whatever Word2Vec or the noise produced.
    table[itos.index(PAD_TOKEN)] = np.zeros(hidden_size, dtype=np.float32)
    return torch.tensor(table)

# -------------------------
# Dataset (supports queries and chunks)
# -------------------------
class BertPretrainingDataset(Dataset):
    """Dataset of BERT pre-training examples built from queries and chunks.

    Each entry of ``data`` is ``(text, discriminator)``:

    * ``'Q'`` (query): encoded as ``[CLS] text [SEP]`` with all-zero segment
      ids and the dummy NSP label ``-100`` (MLM only).
    * ``'C'`` (chunk): split into sentences on ``'.'``; one sentence pair is
      sampled at random on every access and encoded as
      ``[CLS] A [SEP] B [SEP]`` with NSP label 1 (B follows A) or 0.

    A chunk with fewer than two sentences cannot form a meaningful NSP pair
    (duplicating the sentence would give the same input both labels at
    random), so it is encoded like a query and reported with
    ``batch_type == "Q"`` so that training loops treat it as MLM-only.
    """

    def __init__(self, data: List[Tuple[str, str]], stoi: dict, max_seq_len=MAX_SEQ_LEN):
        """
        data: list of tuples [(text, discriminator)], where discriminator ∈ {'Q', 'C'}
        stoi: token -> id mapping; must contain the special tokens
        max_seq_len: maximum encoded length, special tokens included
        """
        self.stoi = stoi
        self.max_seq_len = max_seq_len
        self.data = data

    def __len__(self):
        return len(self.data)

    def _tokenize_to_ids(self, text: str) -> List[int]:
        """Whitespace-tokenize ``text`` and map tokens to ids (unknown -> [UNK])."""
        unk_id = self.stoi[UNK_TOKEN]
        return [self.stoi.get(t, unk_id) for t in text.strip().split()]

    def _single_segment_item(self, text: str) -> dict:
        """Encode ``text`` as one segment: ``[CLS] text [SEP]``, MLM only."""
        ids = self._tokenize_to_ids(text)[:self.max_seq_len - 2]  # room for [CLS] and [SEP]
        input_ids = [self.stoi[CLS_TOKEN]] + ids + [self.stoi[SEP_TOKEN]]
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "token_type_ids": torch.tensor([0] * len(input_ids), dtype=torch.long),
            "nsp_label": torch.tensor(-100, dtype=torch.long),  # dummy: no NSP target
            "batch_type": "Q"
        }

    def _pair_item(self, sent_a: str, sent_b: str, nsp_label: int) -> dict:
        """Encode a sentence pair as ``[CLS] A [SEP] B [SEP]`` with segment ids 0/1."""
        ids_a = self._tokenize_to_ids(sent_a)
        ids_b = self._tokenize_to_ids(sent_b)
        # Trim the longer side one token at a time until the pair fits,
        # leaving room for [CLS] and the two [SEP] tokens.
        while len(ids_a) + len(ids_b) > self.max_seq_len - 3:
            if len(ids_a) > len(ids_b):
                ids_a.pop()
            else:
                ids_b.pop()
        input_ids = [self.stoi[CLS_TOKEN]] + ids_a + [self.stoi[SEP_TOKEN]] + ids_b + [self.stoi[SEP_TOKEN]]
        token_type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
            "nsp_label": torch.tensor(nsp_label, dtype=torch.long),
            "batch_type": "C"
        }

    def __getitem__(self, idx):
        """Return the encoded example for entry ``idx``.

        Raises:
            ValueError: if the entry's discriminator is neither 'Q' nor 'C'.
        """
        text, dtype = self.data[idx]

        # -------------------------------
        # Case 1: Query (MLM only)
        # -------------------------------
        if dtype == 'Q':
            return self._single_segment_item(text)

        # -------------------------------
        # Case 2: Chunk (MLM + NSP)
        # -------------------------------
        if dtype == 'C':
            # simple sentence split
            sents = [s.strip() for s in text.strip().split('.') if s.strip()]
            if len(sents) < 2:
                # No sentence pair exists: fall back to an MLM-only example.
                return self._single_segment_item(text)
            idx_s = random.randint(0, len(sents) - 2)
            if random.random() < 0.5:
                return self._pair_item(sents[idx_s], sents[idx_s + 1], 1)
            # Negative: any sentence of the chunk except the true successor.
            # NOTE(review): this can pick sentence A itself (always, for
            # two-sentence chunks); sampling B from another chunk would give
            # harder negatives.
            candidates = [j for j in range(len(sents)) if j != idx_s + 1]
            return self._pair_item(sents[idx_s], sents[random.choice(candidates)], 0)

        # Previously this fell through and returned None, which only failed
        # later inside the collate function with an unrelated error.
        raise ValueError(f"Unknown discriminator {dtype!r} at index {idx}; expected 'Q' or 'C'")

def collate_fn(batch, pad_id):
    """Right-pad a list of dataset items to a common length and stack them.

    Returns a dict with ``input_ids`` (padded with ``pad_id``),
    ``token_type_ids`` (padded with 0), ``attention_mask`` (1 where the padded
    ids differ from ``pad_id``, else 0), ``nsp_labels`` (one per item) and
    ``batch_type`` (the per-item type strings, in order).
    """
    nsp_labels = torch.stack([item["nsp_label"] for item in batch]).long()
    kinds = [item["batch_type"] for item in batch]

    width = max(item["input_ids"].size(0) for item in batch)
    id_rows, type_rows, mask_rows = [], [], []
    for item in batch:
        ids = item["input_ids"]
        missing = width - ids.size(0)
        padded_ids = F.pad(ids, (0, missing), value=pad_id)
        id_rows.append(padded_ids)
        type_rows.append(F.pad(item["token_type_ids"], (0, missing), value=0))
        mask_rows.append(padded_ids.ne(pad_id).long())

    collated = {}
    collated["input_ids"] = torch.stack(id_rows)
    collated["token_type_ids"] = torch.stack(type_rows)
    collated["attention_mask"] = torch.stack(mask_rows)
    collated["nsp_labels"] = nsp_labels
    collated["batch_type"] = kinds
    return collated

# -------------------------
# MLM Masking
# -------------------------
def create_mlm_labels_and_masked_input(input_ids, pad_id, mask_token_id, vocab_size, mask_prob=MLM_MASK_PROB):
    """Apply BERT-style masking to a batch and build the MLM targets.

    Every token that is neither padding nor a special token (id below
    ``len(SPECIAL_TOKENS)``) is selected with probability ``mask_prob``.
    Of the selected positions, 80% are replaced by ``mask_token_id``, 10% by
    a random non-special token id, and 10% are left unchanged.

    Args:
        input_ids: 2-D integer tensor of token ids (batch, sequence).
        pad_id: id of the padding token.
        mask_token_id: id written at masked positions.
        vocab_size: exclusive upper bound for random replacement ids.
        mask_prob: selection probability per eligible token.

    Returns:
        (input_ids_masked, mlm_labels): the corrupted copy of ``input_ids`` and
        a same-shaped label tensor holding the original id at selected
        positions and -100 everywhere else. Both live on ``input_ids.device``.
    """
    device = input_ids.device
    special_upper = len(SPECIAL_TOKENS)

    # Allocate on the input's device: a CPU probability matrix cannot be
    # indexed with the CUDA boolean masks derived from a CUDA ``input_ids``.
    prob_matrix = torch.full(input_ids.shape, mask_prob, dtype=torch.float, device=device)
    never_masked = (input_ids == pad_id) | (input_ids < special_upper)
    prob_matrix.masked_fill_(never_masked, 0.0)
    masked_positions = torch.bernoulli(prob_matrix).bool()

    mlm_labels = torch.full_like(input_ids, -100)
    mlm_labels[masked_positions] = input_ids[masked_positions]

    input_ids_masked = input_ids.clone()
    # One uniform draw per position decides between [MASK] / random / keep.
    rand_for_replace = torch.rand_like(input_ids, dtype=torch.float)
    mask_replace = masked_positions & (rand_for_replace < 0.8)
    random_replace = masked_positions & (rand_for_replace >= 0.8) & (rand_for_replace < 0.9)
    input_ids_masked[mask_replace] = mask_token_id
    if random_replace.any():
        count = int(random_replace.sum().item())
        rand_tokens = torch.randint(special_upper, vocab_size, (count,), device=device)
        input_ids_masked[random_replace] = rand_tokens
    return input_ids_masked, mlm_labels

# -------------------------
# Transformer encoder
# -------------------------
class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention, then a GELU feed-forward network.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, hidden_size, num_heads, ffn_dim, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True
        )
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        expand = nn.Linear(hidden_size, ffn_dim)
        project = nn.Linear(ffn_dim, hidden_size)
        self.ffn = nn.Sequential(expand, nn.GELU(), project)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run the block on ``x``; ``mask`` is 0 at positions attention must ignore."""
        ignore_keys = mask.eq(0)
        attended = self.self_attn(x, x, x, key_padding_mask=ignore_keys)[0]
        hidden = self.ln1(x + self.dropout(attended))
        return self.ln2(hidden + self.dropout(self.ffn(hidden)))

class BertEncoderModel(nn.Module):
    """BERT-style encoder with tied-weight MLM head and an NSP classifier.

    Input embeddings are the sum of token, learned position and segment
    embeddings, followed by LayerNorm and dropout, then a stack of
    ``TransformerEncoderLayer`` blocks.

    Args:
        vocab_size: number of token ids.
        hidden_size: model width.
        num_layers: number of encoder layers.
        num_heads: attention heads per layer.
        ffn_dim: inner width of each feed-forward block.
        max_position_embeddings: longest sequence the model can encode.
        pad_token_id: id of the padding token.
        embedding_weights: optional tensor copied into the token embedding.
        dropout: dropout probability for the embeddings and every layer.
    """

    def __init__(self, vocab_size, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, ffn_dim=FFN_DIM, max_position_embeddings=512, pad_token_id=0, embedding_weights=None, dropout=DROPOUT):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)
        if embedding_weights is not None:
            self.token_embeddings.weight.data.copy_(embedding_weights)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.segment_embeddings = nn.Embedding(2, hidden_size)
        self.emb_ln = nn.LayerNorm(hidden_size)
        self.emb_dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([TransformerEncoderLayer(hidden_size, num_heads, ffn_dim, dropout) for _ in range(num_layers)])
        self.nsp_classifier = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh(), nn.Linear(hidden_size, 2))
        # Output bias of the MLM head; its weight matrix is the token embedding.
        self.mlm_bias = nn.Parameter(torch.zeros(vocab_size))

    def encode(self, ids, tt=None, mask=None):
        """Return per-token hidden states for ``ids`` (batch, sequence).

        ``tt`` defaults to all-zero segment ids and ``mask`` to
        ``ids != pad_token_id``.

        Raises:
            IndexError: if the sequence is longer than the position table.
        """
        seq_len = ids.size(1)
        max_positions = self.position_embeddings.num_embeddings
        if seq_len > max_positions:
            # Fail early with a readable message; the embedding lookup would
            # otherwise raise an opaque index error (a device-side assert on CUDA).
            raise IndexError(
                f"Sequence length {seq_len} exceeds max_position_embeddings={max_positions}"
            )
        if tt is None:
            tt = torch.zeros_like(ids)
        if mask is None:
            mask = (ids != self.pad_token_id).long()
        pos = torch.arange(seq_len, device=ids.device).unsqueeze(0)
        x = self.token_embeddings(ids) + self.position_embeddings(pos) + self.segment_embeddings(tt)
        x = self.emb_dropout(self.emb_ln(x))
        for layer in self.layers:
            x = layer(x, mask)
        return x

    def forward(self, ids, tt=None, mask=None):
        """Return ``(mlm_logits, nsp_logits)``.

        ``mlm_logits`` scores every position against the vocabulary using the
        token-embedding matrix as output weights; ``nsp_logits`` holds two
        scores per sequence computed from the first ([CLS]) position.
        """
        seq_out = self.encode(ids, tt, mask)
        pooled = seq_out[:, 0]
        nsp_logits = self.nsp_classifier(pooled)
        mlm_logits = F.linear(seq_out, self.token_embeddings.weight, self.mlm_bias)
        return mlm_logits, nsp_logits

# -------------------------
# Training
# -------------------------
def main():
    """Run one demo epoch of MLM (+ NSP where available) pre-training.

    Builds the vocabulary, Word2Vec-initialized embeddings, dataset and model
    from a tiny in-line corpus, then trains for a single epoch and prints the
    loss of every batch.
    """
    # Dummy sample with types
    corpus = [
        ("the quick brown fox jumps over the lazy dog. the dog did not mind.", "C"),
        ("i love machine learning and transformers.", "Q"),
        ("deep learning enables summarization and translation. it is powerful.", "C"),
        ("best restaurants near me", "Q")
    ]
    texts = [text for text, _ in corpus]
    stoi, itos = build_vocab(texts)
    vocab_size = len(itos)
    w2v = train_word2vec(texts)
    emb = build_embedding_matrix(w2v, itos, HIDDEN_SIZE)
    pad_id = stoi[PAD_TOKEN]
    mask_id = stoi[MASK_TOKEN]
    ds = BertPretrainingDataset(corpus, stoi)
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=lambda b: collate_fn(b, pad_id))
    # The position table must cover the longest sequence the dataset can emit
    # (MAX_SEQ_LEN); the model's own default of 512 is too small for that.
    model = BertEncoderModel(
        vocab_size,
        max_position_embeddings=MAX_SEQ_LEN,
        pad_token_id=pad_id,
        embedding_weights=emb,
    ).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    ignore_label = -100  # marks "no target" for both MLM positions and NSP rows
    mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=ignore_label)
    nsp_loss_fct = nn.CrossEntropyLoss()
    model.train()
    for _ in range(1):
        for batch in dl:
            # Mask on the CPU tensors coming out of the DataLoader, then move
            # everything to the training device once.
            ids_masked, mlm_labels = create_mlm_labels_and_masked_input(batch["input_ids"], pad_id, mask_id, vocab_size)
            ids_masked, mlm_labels = ids_masked.to(DEVICE), mlm_labels.to(DEVICE)
            tts = batch["token_type_ids"].to(DEVICE)
            mask = batch["attention_mask"].to(DEVICE)
            nsp_labels = batch["nsp_labels"].to(DEVICE)

            # Rows with a real NSP label; queries carry the dummy label.
            nsp_rows = nsp_labels != ignore_label
            has_nsp = bool(nsp_rows.any())
            has_mlm = bool((mlm_labels != ignore_label).any())
            if not (has_mlm or has_nsp):
                # Nothing to learn from: a cross-entropy averaged over zero
                # targets is NaN and would poison the weights.
                print("Skipping batch without MLM or NSP targets.")
                continue

            mlm_logits, nsp_logits = model(ids_masked, tts, mask)
            zero = torch.tensor(0.0, device=DEVICE)
            if has_mlm:
                mlm_loss = mlm_loss_fct(mlm_logits.view(-1, vocab_size), mlm_labels.view(-1))
            else:
                mlm_loss = zero
            # Compute NSP on the chunk rows of the batch, so mixed query/chunk
            # batches still train the NSP head.
            if has_nsp:
                nsp_loss = nsp_loss_fct(nsp_logits[nsp_rows], nsp_labels[nsp_rows])
            else:
                nsp_loss = zero
            loss = mlm_loss + nsp_loss
            opt.zero_grad()
            loss.backward()
            opt.step()
            print(f"Loss {loss.item():.4f} (MLM {mlm_loss.item():.4f}, NSP {nsp_loss.item():.4f})")
    print("Training done.")

# Run the demo training loop only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()

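The code above samples only one random sentence pair per chunk each time an item is fetched, so it does not cover all positive and negative sentence pairs for NSP. The code below instead precomputes every positive (consecutive) and every negative (non-adjacent) sentence pair of each chunk and uses all of them as training samples.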

In [None]:
# bert_encoder_from_scratch_with_pooling_multitype_allpairs.py
# Modified version that supports:
# - 'C' = Chunk (uses MLM + NSP) → uses ALL possible positive & negative pairs
# - 'Q' = Query (uses MLM only)

import random
import math
import os
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from gensim.models import Word2Vec

# -------------------------
# Config
# -------------------------
VOCAB_MIN_FREQ = 1  # default min_freq for build_vocab: tokens seen fewer times are left out of the vocab
# Default max_seq_len of BertPretrainingDataset (total tokens incl. [CLS]/[SEP]).
# NOTE(review): BertEncoderModel defaults max_position_embeddings to 512, which is
# smaller than this value — confirm the model is constructed with a matching size.
MAX_SEQ_LEN = 1024
HIDDEN_SIZE = 768  # default hidden width of BertEncoderModel; also the embedding-matrix width in main()
NUM_LAYERS = 12  # default number of encoder layers
NUM_HEADS = 12  # default number of attention heads per layer
FFN_DIM = 3072  # default inner width of each layer's feed-forward block
DROPOUT = 0.1  # dropout probability
WORD2VEC_SIZE = HIDDEN_SIZE  # Word2Vec vectors share the model width so they can seed the token embeddings
WORD2VEC_WINDOW = 5  # default Word2Vec context window
WORD2VEC_MIN_COUNT = 1  # default Word2Vec min_count
MLM_MASK_PROB = 0.15  # default per-token probability of being selected as an MLM target
BATCH_SIZE = 8  # DataLoader batch size used in main()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer GPU when available
LEARNING_RATE = 2e-5  # AdamW learning rate used in main()

# -------------------------
# Special tokens
# -------------------------
PAD_TOKEN = "[PAD]"
CLS_TOKEN = "[CLS]"
SEP_TOKEN = "[SEP]"
MASK_TOKEN = "[MASK]"
UNK_TOKEN = "[UNK]"
# Order matters: build_vocab assigns ids 0..4 in this order, so [PAD] gets id 0
# (BertEncoderModel's default pad_token_id), and the MLM masking code treats
# every id below len(SPECIAL_TOKENS) as a special token that is never masked.
SPECIAL_TOKENS = [PAD_TOKEN, CLS_TOKEN, SEP_TOKEN, MASK_TOKEN, UNK_TOKEN]

# -------------------------
# Utility: Vocab builder
# -------------------------
def build_vocab(sentences: List[str], min_freq: int = VOCAB_MIN_FREQ):
    """Build token<->id lookups from whitespace-tokenized sentences.

    Special tokens take the first ids, in ``SPECIAL_TOKENS`` order. Corpus
    tokens follow in order of first appearance, provided they occur at least
    ``min_freq`` times and are not already a special token.

    Returns:
        (stoi, itos): a token -> id dict and the matching id -> token list.
    """
    from collections import Counter

    frequencies = Counter(word for line in sentences for word in line.strip().split())

    itos = list(SPECIAL_TOKENS)
    stoi = {token: position for position, token in enumerate(itos)}
    for word, count in frequencies.items():
        if count < min_freq or word in stoi:
            continue
        stoi[word] = len(itos)
        itos.append(word)
    return stoi, itos

# -------------------------
# Train or load Word2Vec
# -------------------------
def train_word2vec(sentences: List[str], vector_size=WORD2VEC_SIZE, window=WORD2VEC_WINDOW, min_count=WORD2VEC_MIN_COUNT, epochs=5):
    """Train a gensim Word2Vec model on whitespace-tokenized sentences.

    Each sentence is stripped and split on whitespace, the same tokenization
    ``build_vocab`` uses. The trained model is returned as-is.
    """
    token_lists = []
    for line in sentences:
        token_lists.append(line.strip().split())
    return Word2Vec(
        sentences=token_lists,
        vector_size=vector_size,
        window=window,
        min_count=min_count,
        epochs=epochs,
        sg=0,  # 0 selects CBOW rather than skip-gram
    )

def build_embedding_matrix(w2v: Word2Vec, itos: List[str], hidden_size: int):
    """Create a (len(itos), hidden_size) float32 embedding tensor seeded from Word2Vec.

    Rows start as N(0, 0.02) noise. Any token present in ``w2v.wv`` has its row
    overwritten with the Word2Vec vector, truncated or zero-padded on the right
    to ``hidden_size``. The ``PAD_TOKEN`` row is always zeroed.
    """
    table = np.random.normal(scale=0.02, size=(len(itos), hidden_size)).astype(np.float32)
    vectors = w2v.wv
    for row, token in enumerate(itos):
        if token not in vectors:
            continue  # keep the random initialization (e.g. special tokens)
        vec = vectors[token]
        width = vec.shape[0]
        if width > hidden_size:
            vec = vec[:hidden_size]
        elif width < hidden_size:
            vec = np.pad(vec, (0, hidden_size - width))
        table[row] = vec
    # Padding must contribute nothing, whatever Word2Vec or the noise produced.
    table[itos.index(PAD_TOKEN)] = np.zeros(hidden_size, dtype=np.float32)
    return torch.tensor(table)

# -------------------------
# Dataset (supports queries and chunks)
# -------------------------
class BertPretrainingDataset(Dataset):
    """Dataset that expands every chunk into all of its NSP sentence pairs.

    Samples are precomputed in ``__init__`` as
    ``(sentence_a, type, sentence_b, nsp_label)`` tuples:

    * a ``'Q'`` entry becomes one MLM-only sample ``(text, 'Q', None, None)``;
    * a ``'C'`` entry is split into sentences on ``'.'`` and contributes one
      positive pair (label 1) per consecutive sentence pair and one negative
      pair (label 0) for every ordered pair of sentences whose positions
      differ by more than one.

    NOTE(review): a chunk with exactly two sentences (or a single sentence,
    which is duplicated) has no positions more than one apart, so it yields
    positive pairs only.
    """

    def __init__(self, data: List[Tuple[str, str]], stoi: dict, max_seq_len=MAX_SEQ_LEN):
        """
        data: list of tuples [(text, discriminator)], where discriminator ∈ {'Q', 'C'}
        stoi: token -> id mapping; must contain the special tokens
        max_seq_len: maximum encoded length, special tokens included
        """
        self.stoi = stoi
        self.max_seq_len = max_seq_len
        self.samples = []

        for text, kind in data:
            if kind == "Q":
                # Single-sentence query (MLM only)
                self.samples.append((text, kind, None, None))
                continue
            if kind != "C":
                continue  # entries of any other type contribute no samples

            sentences = [part.strip() for part in text.strip().split('.') if part.strip()]
            if len(sentences) < 2:
                sentences = sentences * 2  # duplicate if only one sentence

            # Positive pairs: each sentence with its successor
            self.samples.extend(
                (first, "C", second, 1) for first, second in zip(sentences, sentences[1:])
            )
            # Negative pairs: every ordered pair that is not adjacent (or identical)
            self.samples.extend(
                (first, "C", second, 0)
                for i, first in enumerate(sentences)
                for j, second in enumerate(sentences)
                if abs(i - j) > 1
            )

    def __len__(self):
        return len(self.samples)

    def _tokenize_to_ids(self, text: str) -> List[int]:
        """Whitespace-tokenize ``text`` and map tokens to ids (unknown -> [UNK])."""
        ids = []
        for token in text.strip().split():
            ids.append(self.stoi.get(token, self.stoi[UNK_TOKEN]))
        return ids

    def __getitem__(self, idx):
        """Encode the precomputed sample ``idx`` into tensors for the collate function."""
        first, kind, second, label = self.samples[idx]

        if kind == 'Q':
            # Query: [CLS] text [SEP], single segment, dummy NSP label
            body = self._tokenize_to_ids(first)[:self.max_seq_len - 2]
            input_ids = [self.stoi[CLS_TOKEN], *body, self.stoi[SEP_TOKEN]]
            token_type_ids = [0 for _ in input_ids]
            label = -100  # dummy
        elif kind == 'C':
            # Chunk pair: [CLS] A [SEP] B [SEP], segments 0 and 1
            ids_a = self._tokenize_to_ids(first)
            ids_b = self._tokenize_to_ids(second)
            budget = self.max_seq_len - 3  # room for [CLS] and two [SEP]
            while len(ids_a) + len(ids_b) > budget:
                # Always shorten the longer side (B on ties)
                longer = ids_a if len(ids_a) > len(ids_b) else ids_b
                longer.pop()
            input_ids = [self.stoi[CLS_TOKEN], *ids_a, self.stoi[SEP_TOKEN], *ids_b, self.stoi[SEP_TOKEN]]
            token_type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
        else:
            return None

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
            "nsp_label": torch.tensor(label, dtype=torch.long),
            "batch_type": "Q" if kind == 'Q' else "C"
        }

def collate_fn(batch, pad_id):
    """Right-pad a list of dataset items to a common length and stack them.

    Returns a dict with ``input_ids`` (padded with ``pad_id``),
    ``token_type_ids`` (padded with 0), ``attention_mask`` (1 where the padded
    ids differ from ``pad_id``, else 0), ``nsp_labels`` (one per item) and
    ``batch_type`` (the per-item type strings, in order).
    """
    nsp_labels = torch.stack([item["nsp_label"] for item in batch]).long()
    kinds = [item["batch_type"] for item in batch]

    width = max(item["input_ids"].size(0) for item in batch)
    id_rows, type_rows, mask_rows = [], [], []
    for item in batch:
        ids = item["input_ids"]
        missing = width - ids.size(0)
        padded_ids = F.pad(ids, (0, missing), value=pad_id)
        id_rows.append(padded_ids)
        type_rows.append(F.pad(item["token_type_ids"], (0, missing), value=0))
        mask_rows.append(padded_ids.ne(pad_id).long())

    collated = {}
    collated["input_ids"] = torch.stack(id_rows)
    collated["token_type_ids"] = torch.stack(type_rows)
    collated["attention_mask"] = torch.stack(mask_rows)
    collated["nsp_labels"] = nsp_labels
    collated["batch_type"] = kinds
    return collated

# -------------------------
# MLM Masking
# -------------------------
def create_mlm_labels_and_masked_input(input_ids, pad_id, mask_token_id, vocab_size, mask_prob=MLM_MASK_PROB):
    """Apply BERT-style masking to a batch and build the MLM targets.

    Every token that is neither padding nor a special token (id below
    ``len(SPECIAL_TOKENS)``) is selected with probability ``mask_prob``.
    Of the selected positions, 80% are replaced by ``mask_token_id``, 10% by
    a random non-special token id, and 10% are left unchanged.

    Args:
        input_ids: 2-D integer tensor of token ids (batch, sequence).
        pad_id: id of the padding token.
        mask_token_id: id written at masked positions.
        vocab_size: exclusive upper bound for random replacement ids.
        mask_prob: selection probability per eligible token.

    Returns:
        (input_ids_masked, mlm_labels): the corrupted copy of ``input_ids`` and
        a same-shaped label tensor holding the original id at selected
        positions and -100 everywhere else. Both live on ``input_ids.device``.
    """
    device = input_ids.device
    special_upper = len(SPECIAL_TOKENS)

    # Allocate on the input's device: a CPU probability matrix cannot be
    # indexed with the CUDA boolean masks derived from a CUDA ``input_ids``.
    prob_matrix = torch.full(input_ids.shape, mask_prob, dtype=torch.float, device=device)
    never_masked = (input_ids == pad_id) | (input_ids < special_upper)
    prob_matrix.masked_fill_(never_masked, 0.0)
    masked_positions = torch.bernoulli(prob_matrix).bool()

    mlm_labels = torch.full_like(input_ids, -100)
    mlm_labels[masked_positions] = input_ids[masked_positions]

    input_ids_masked = input_ids.clone()
    # One uniform draw per position decides between [MASK] / random / keep.
    rand_for_replace = torch.rand_like(input_ids, dtype=torch.float)
    mask_replace = masked_positions & (rand_for_replace < 0.8)
    random_replace = masked_positions & (rand_for_replace >= 0.8) & (rand_for_replace < 0.9)
    input_ids_masked[mask_replace] = mask_token_id
    if random_replace.any():
        count = int(random_replace.sum().item())
        rand_tokens = torch.randint(special_upper, vocab_size, (count,), device=device)
        input_ids_masked[random_replace] = rand_tokens
    return input_ids_masked, mlm_labels

# -------------------------
# Transformer encoder
# -------------------------
class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention, then a GELU feed-forward network.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, hidden_size, num_heads, ffn_dim, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True
        )
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        expand = nn.Linear(hidden_size, ffn_dim)
        project = nn.Linear(ffn_dim, hidden_size)
        self.ffn = nn.Sequential(expand, nn.GELU(), project)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run the block on ``x``; ``mask`` is 0 at positions attention must ignore."""
        ignore_keys = mask.eq(0)
        attended = self.self_attn(x, x, x, key_padding_mask=ignore_keys)[0]
        hidden = self.ln1(x + self.dropout(attended))
        return self.ln2(hidden + self.dropout(self.ffn(hidden)))

class BertEncoderModel(nn.Module):
    """BERT-style encoder with tied-weight MLM head and an NSP classifier.

    Input embeddings are the sum of token, learned position and segment
    embeddings, followed by LayerNorm and dropout, then a stack of
    ``TransformerEncoderLayer`` blocks.

    Args:
        vocab_size: number of token ids.
        hidden_size: model width.
        num_layers: number of encoder layers.
        num_heads: attention heads per layer.
        ffn_dim: inner width of each feed-forward block.
        max_position_embeddings: longest sequence the model can encode.
        pad_token_id: id of the padding token.
        embedding_weights: optional tensor copied into the token embedding.
        dropout: dropout probability for the embeddings and every layer.
    """

    def __init__(self, vocab_size, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, ffn_dim=FFN_DIM, max_position_embeddings=512, pad_token_id=0, embedding_weights=None, dropout=DROPOUT):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)
        if embedding_weights is not None:
            self.token_embeddings.weight.data.copy_(embedding_weights)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.segment_embeddings = nn.Embedding(2, hidden_size)
        self.emb_ln = nn.LayerNorm(hidden_size)
        self.emb_dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([TransformerEncoderLayer(hidden_size, num_heads, ffn_dim, dropout) for _ in range(num_layers)])
        self.nsp_classifier = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh(), nn.Linear(hidden_size, 2))
        # Output bias of the MLM head; its weight matrix is the token embedding.
        self.mlm_bias = nn.Parameter(torch.zeros(vocab_size))

    def encode(self, ids, tt=None, mask=None):
        """Return per-token hidden states for ``ids`` (batch, sequence).

        ``tt`` defaults to all-zero segment ids and ``mask`` to
        ``ids != pad_token_id``.

        Raises:
            IndexError: if the sequence is longer than the position table.
        """
        seq_len = ids.size(1)
        max_positions = self.position_embeddings.num_embeddings
        if seq_len > max_positions:
            # Fail early with a readable message; the embedding lookup would
            # otherwise raise an opaque index error (a device-side assert on CUDA).
            raise IndexError(
                f"Sequence length {seq_len} exceeds max_position_embeddings={max_positions}"
            )
        if tt is None:
            tt = torch.zeros_like(ids)
        if mask is None:
            mask = (ids != self.pad_token_id).long()
        pos = torch.arange(seq_len, device=ids.device).unsqueeze(0)
        x = self.token_embeddings(ids) + self.position_embeddings(pos) + self.segment_embeddings(tt)
        x = self.emb_dropout(self.emb_ln(x))
        for layer in self.layers:
            x = layer(x, mask)
        return x

    def forward(self, ids, tt=None, mask=None):
        """Return ``(mlm_logits, nsp_logits)``.

        ``mlm_logits`` scores every position against the vocabulary using the
        token-embedding matrix as output weights; ``nsp_logits`` holds two
        scores per sequence computed from the first ([CLS]) position.
        """
        seq_out = self.encode(ids, tt, mask)
        pooled = seq_out[:, 0]
        nsp_logits = self.nsp_classifier(pooled)
        mlm_logits = F.linear(seq_out, self.token_embeddings.weight, self.mlm_bias)
        return mlm_logits, nsp_logits

# -------------------------
# Training
# -------------------------
def main():
    """Run one demo epoch of MLM (+ NSP where available) pre-training.

    Builds the vocabulary, Word2Vec-initialized embeddings, dataset and model
    from a tiny in-line corpus, then trains for a single epoch and prints the
    loss of every batch.
    """
    corpus = [
        ("the quick brown fox jumps over the lazy dog. the dog did not mind.", "C"),
        ("i love machine learning and transformers.", "Q"),
        ("deep learning enables summarization and translation. it is powerful.", "C"),
        ("best restaurants near me", "Q")
    ]
    texts = [text for text, _ in corpus]
    stoi, itos = build_vocab(texts)
    vocab_size = len(itos)
    w2v = train_word2vec(texts)
    emb = build_embedding_matrix(w2v, itos, HIDDEN_SIZE)
    pad_id = stoi[PAD_TOKEN]
    mask_id = stoi[MASK_TOKEN]
    ds = BertPretrainingDataset(corpus, stoi)
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=lambda b: collate_fn(b, pad_id))
    # The position table must cover the longest sequence the dataset can emit
    # (MAX_SEQ_LEN); the model's own default of 512 is too small for that.
    model = BertEncoderModel(
        vocab_size,
        max_position_embeddings=MAX_SEQ_LEN,
        pad_token_id=pad_id,
        embedding_weights=emb,
    ).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    ignore_label = -100  # marks "no target" for both MLM positions and NSP rows
    mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=ignore_label)
    nsp_loss_fct = nn.CrossEntropyLoss()
    model.train()
    for _ in range(1):
        for batch in dl:
            # Mask on the CPU tensors coming out of the DataLoader, then move
            # everything to the training device once.
            ids_masked, mlm_labels = create_mlm_labels_and_masked_input(batch["input_ids"], pad_id, mask_id, vocab_size)
            ids_masked, mlm_labels = ids_masked.to(DEVICE), mlm_labels.to(DEVICE)
            tts = batch["token_type_ids"].to(DEVICE)
            mask = batch["attention_mask"].to(DEVICE)
            nsp_labels = batch["nsp_labels"].to(DEVICE)

            # Rows with a real NSP label; queries carry the dummy label.
            nsp_rows = nsp_labels != ignore_label
            has_nsp = bool(nsp_rows.any())
            has_mlm = bool((mlm_labels != ignore_label).any())
            if not (has_mlm or has_nsp):
                # Nothing to learn from: a cross-entropy averaged over zero
                # targets is NaN and would poison the weights.
                print("Skipping batch without MLM or NSP targets.")
                continue

            mlm_logits, nsp_logits = model(ids_masked, tts, mask)
            zero = torch.tensor(0.0, device=DEVICE)
            if has_mlm:
                mlm_loss = mlm_loss_fct(mlm_logits.view(-1, vocab_size), mlm_labels.view(-1))
            else:
                mlm_loss = zero
            # Compute NSP on the chunk rows of the batch, so mixed query/chunk
            # batches still train the NSP head.
            if has_nsp:
                nsp_loss = nsp_loss_fct(nsp_logits[nsp_rows], nsp_labels[nsp_rows])
            else:
                nsp_loss = zero
            loss = mlm_loss + nsp_loss
            opt.zero_grad()
            loss.backward()
            opt.step()
            print(f"Loss {loss.item():.4f} (MLM {mlm_loss.item():.4f}, NSP {nsp_loss.item():.4f})")
    print("Training done.")

# Run the demo training loop only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
