# In [None]:  (Jupyter cell marker left over from the notebook export)
# bert_encoder_from_scratch_with_pooling.py
# Full PyTorch implementation of BERT-base style encoder-only Transformer with:
# - Word2Vec token embeddings
# - Learned positional & segment embeddings
# - MLM (15% mask; 80/10/10 replacement rule) -- PAD excluded from random replacements
# - NSP head
# - Mask-aware mean pooling for retrieval (added)
#
# Requirements:
#   pip install torch torchvision torchaudio
#   pip install gensim
#
# Run: python bert_encoder_from_scratch_with_pooling.py

import random
import math
import os
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# If you need to install gensim: uncomment below
# import sys
# !{sys.executable} -m pip install gensim

from gensim.models import Word2Vec

# -------------------------
# Config (change for experiments)
# -------------------------
# Default minimum corpus frequency for a token to enter the vocabulary (see build_vocab).
VOCAB_MIN_FREQ = 1
# Default maximum length of a packed "[CLS] A [SEP] B [SEP]" example; main() also passes it
# as the size of the learned position-embedding table.
MAX_SEQ_LEN = 128
HIDDEN_SIZE = 768       # BERT-base hidden size
NUM_LAYERS = 12         # BERT-base number of encoder layers
NUM_HEADS = 12          # BERT-base attention heads
FFN_DIM = 3072          # BERT-base intermediate size
# Dropout probability handed to the model (embeddings, attention and residual branches).
DROPOUT = 0.1
WORD2VEC_SIZE = HIDDEN_SIZE  # Same width as the encoder so Word2Vec vectors copy straight into the token embedding table
# Default context window and minimum token count used when training Word2Vec.
WORD2VEC_WINDOW = 5
WORD2VEC_MIN_COUNT = 1
# Default probability that an eligible (non-special) token is selected for MLM prediction.
MLM_MASK_PROB = 0.15
# Mini-batch size used by the training DataLoader in main().
BATCH_SIZE = 8
# Use the GPU when PyTorch reports one, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Learning rate passed to AdamW in main().
LEARNING_RATE = 2e-5

# -------------------------
# Special tokens
# -------------------------
# String forms of the reserved tokens.
PAD_TOKEN = "[PAD]"
CLS_TOKEN = "[CLS]"
SEP_TOKEN = "[SEP]"
MASK_TOKEN = "[MASK]"
UNK_TOKEN = "[UNK]"

# Order matters: build_vocab assigns ids in this order, so the special tokens occupy
# ids 0..len(SPECIAL_TOKENS)-1 (PAD is 0). The MLM masking and the pooling code rely on
# this by treating every id below len(SPECIAL_TOKENS) as a special token.
SPECIAL_TOKENS = [PAD_TOKEN, CLS_TOKEN, SEP_TOKEN, MASK_TOKEN, UNK_TOKEN]

# -------------------------
# Simple whitespace tokenizer & vocab build (replace if desired)
# -------------------------
def build_vocab(sentences: List[str], min_freq: int = VOCAB_MIN_FREQ):
    """Build a whitespace-token vocabulary with the special tokens at the front.

    Args:
        sentences: Raw text lines; each is stripped and split on whitespace.
        min_freq: Minimum number of occurrences a token needs to be included.

    Returns:
        A ``(stoi, itos)`` pair: ``stoi`` maps token -> id and ``itos`` is the
        list of tokens indexed by id. The entries of ``SPECIAL_TOKENS`` take
        the first ids; corpus tokens follow in order of first appearance.
    """
    from collections import Counter

    frequencies = Counter(
        word for line in sentences for word in line.strip().split()
    )

    # Reserved tokens always get the lowest ids.
    itos = list(SPECIAL_TOKENS)
    stoi = {special: position for position, special in enumerate(itos)}

    for word, freq in frequencies.items():
        # Skip rare tokens and anything that collides with a reserved token.
        if freq < min_freq or word in stoi:
            continue
        stoi[word] = len(itos)
        itos.append(word)

    return stoi, itos

# -------------------------
# Train Word2Vec and build the aligned embedding matrix
# -------------------------
def train_word2vec(sentences: List[str], vector_size=WORD2VEC_SIZE, window=WORD2VEC_WINDOW, min_count=WORD2VEC_MIN_COUNT, epochs=5):
    """Train a gensim Word2Vec model on whitespace-tokenized sentences.

    Args:
        sentences: Raw text lines; each is stripped and split on whitespace.
        vector_size: Dimensionality of the learned word vectors.
        window: Context window size passed to gensim.
        min_count: Tokens seen fewer times than this are ignored by gensim.
        epochs: Number of training passes over the corpus.

    Returns:
        The trained ``Word2Vec`` model.
    """
    corpus = [line.strip().split() for line in sentences]
    training_options = {
        "vector_size": vector_size,
        "window": window,
        "min_count": min_count,
        "epochs": epochs,
        "sg": 0,  # gensim: 0 selects CBOW, 1 would select skip-gram
    }
    return Word2Vec(sentences=corpus, **training_options)

def build_embedding_matrix(w2v: Word2Vec, itos: List[str], hidden_size: int):
    """Create a token-embedding table aligned with ``itos`` from Word2Vec vectors.

    Every row starts as N(0, 0.02) noise. Rows whose token is known to ``w2v``
    are overwritten with its vector; a vector wider than ``hidden_size`` is
    truncated and a narrower one is right-padded with zeros (no learned
    projection is applied). Tokens missing from ``w2v`` (e.g. special tokens)
    keep their random row, except ``PAD_TOKEN`` whose row is set to zeros.

    Args:
        w2v: Trained Word2Vec model.
        itos: Vocabulary as a list of tokens indexed by id; must contain ``PAD_TOKEN``.
        hidden_size: Width of each embedding row.

    Returns:
        A float32 tensor of shape ``(len(itos), hidden_size)``.
    """
    table = np.random.normal(scale=0.02, size=(len(itos), hidden_size)).astype(np.float32)

    for row, token in enumerate(itos):
        if token not in w2v.wv:
            continue  # keep the random initialisation for this token
        vector = w2v.wv[token]
        width = vector.shape[0]
        if width > hidden_size:
            vector = vector[:hidden_size]
        elif width < hidden_size:
            vector = np.pad(vector, (0, hidden_size - width))
        table[row] = vector

    # Padding must not contribute any signal.
    table[itos.index(PAD_TOKEN)] = np.zeros(hidden_size, dtype=np.float32)
    return torch.tensor(table)

# -------------------------
# Dataset for NSP + MLM
# -------------------------
class BertPretrainingDataset(Dataset):
    """Sentence-pair dataset for next-sentence prediction (NSP).

    Paragraphs are split into sentences on '.', and all sentences are kept in
    one flat list. Item ``idx`` pairs sentence ``idx`` with either the sentence
    that follows it in that list (label 1) or a randomly drawn sentence
    (label 0). MLM masking is NOT applied here; it is done per batch elsewhere.

    Note: because the list is flat, the last sentence of one paragraph and the
    first sentence of the next are treated as consecutive.
    """

    def __init__(self, paragraphs: List[str], stoi: dict, max_seq_len=MAX_SEQ_LEN, short_seq_prob=0.1):
        """
        Args:
            paragraphs: List of paragraph strings; split naively on '.' into sentences.
                Real pipelines should use a more robust sentence splitter.
            stoi: Token -> id mapping; must contain the CLS, SEP and UNK tokens.
            max_seq_len: Maximum length of a packed example, including the three
                special tokens ([CLS], [SEP], [SEP]).
            short_seq_prob: Stored on the instance but not used by this class.

        Raises:
            ValueError: If ``max_seq_len`` cannot even hold the three special tokens.
        """
        if max_seq_len < 3:
            # Without this guard truncation would pop from empty lists in __getitem__.
            raise ValueError(
                f"max_seq_len must be at least 3 to fit [CLS] and two [SEP] tokens, got {max_seq_len}"
            )
        self.stoi = stoi
        self.itos = {v: k for k, v in stoi.items()}
        self.max_seq_len = max_seq_len
        self.short_seq_prob = short_seq_prob

        # Flat list of non-empty sentences across all paragraphs.
        self.sentences = [
            piece.strip()
            for paragraph in paragraphs
            for piece in paragraph.strip().split('.')
            if piece.strip()
        ]

    def __len__(self):
        # One example per sentence that has a successor; never report zero so that
        # a tiny corpus still yields a (negative-only) example.
        return max(1, len(self.sentences) - 1)

    def _tokenize_to_ids(self, text: str) -> List[int]:
        """Whitespace-tokenize ``text`` and map tokens to ids (unknown -> UNK)."""
        unk_id = self.stoi[UNK_TOKEN]
        return [self.stoi.get(tok, unk_id) for tok in text.strip().split()]

    def __getitem__(self, idx):
        """Build one ``[CLS] A [SEP] B [SEP]`` example.

        Returns:
            Dict with ``input_ids`` and ``token_type_ids`` (1-D long tensors of
            equal length) and ``nsp_label`` (0-D long tensor; 1 = B follows A).

        Raises:
            IndexError: If the dataset has no sentences or ``idx`` is out of range.
        """
        num_sents = len(self.sentences)
        if num_sents == 0:
            raise IndexError("cannot build an example: the dataset contains no sentences")
        if idx < 0:
            # Normalise so that "idx + 1" below really is the following sentence.
            idx += num_sents

        # 50% of the time use the true next sentence (label 1), else a random one (label 0).
        is_next = random.random() < 0.5
        sent_a = self.sentences[idx]
        next_idx = idx + 1

        if is_next and next_idx < num_sents:
            sent_b = self.sentences[next_idx]
            label_nsp = 1
        else:
            # Also reached when sentence A has no successor (e.g. one-sentence corpus),
            # in which case only a negative example can be produced.
            rand_idx = random.randint(0, num_sents - 1)
            # Avoid accidentally picking the actual next sentence when there is room to.
            if rand_idx == next_idx and num_sents > 2:
                rand_idx = (rand_idx + 2) % num_sents
            sent_b = self.sentences[rand_idx]
            label_nsp = 0

        ids_a = self._tokenize_to_ids(sent_a)
        ids_b = self._tokenize_to_ids(sent_b)

        # Reserve 3 positions for [CLS], [SEP], [SEP]; trim the longer side from the end.
        max_len_for_pair = self.max_seq_len - 3
        while len(ids_a) + len(ids_b) > max_len_for_pair:
            if len(ids_a) > len(ids_b):
                ids_a.pop()
            else:
                ids_b.pop()

        cls_id = self.stoi[CLS_TOKEN]
        sep_id = self.stoi[SEP_TOKEN]
        input_ids = [cls_id] + ids_a + [sep_id] + ids_b + [sep_id]
        # Segment 0 covers [CLS] + A + first [SEP]; segment 1 covers B + final [SEP].
        token_type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
            "nsp_label": torch.tensor(label_nsp, dtype=torch.long)
        }

def collate_fn(batch, pad_id):
    """Right-pad a list of examples to the longest sequence in the batch.

    Args:
        batch: List of dicts with 1-D tensors ``input_ids`` and
            ``token_type_ids`` and a 0-D tensor ``nsp_label``.
        pad_id: Id used to fill padded ``input_ids`` positions.

    Returns:
        Dict with ``input_ids``, ``token_type_ids`` and ``attention_mask``
        stacked to ``(batch, max_len)`` and ``nsp_labels`` of shape ``(batch,)``.
        ``attention_mask`` is 1 wherever the padded ``input_ids`` differ from
        ``pad_id`` and 0 elsewhere; padded segment ids are 0.
    """
    id_seqs = [example["input_ids"] for example in batch]
    segment_seqs = [example["token_type_ids"] for example in batch]
    nsp_labels = torch.stack([example["nsp_label"] for example in batch]).long()

    longest = max(seq.size(0) for seq in id_seqs)

    id_rows, segment_rows = [], []
    for seq, segments in zip(id_seqs, segment_seqs):
        # Both tensors are padded by the amount the id sequence falls short.
        shortfall = (0, longest - seq.size(0))
        id_rows.append(F.pad(seq, shortfall, value=pad_id))
        segment_rows.append(F.pad(segments, shortfall, value=0))

    input_ids = torch.stack(id_rows)
    return {
        "input_ids": input_ids,
        "token_type_ids": torch.stack(segment_rows),
        "attention_mask": (input_ids != pad_id).long(),
        "nsp_labels": nsp_labels
    }

# -------------------------
# MLM masking function (15% mask with 80/10/10)
# -------------------------
def create_mlm_labels_and_masked_input(input_ids: torch.Tensor, pad_id: int, mask_token_id: int, vocab_size: int, mask_prob=MLM_MASK_PROB):
    """Apply BERT-style MLM corruption (80% [MASK] / 10% random / 10% unchanged).

    Positions holding a special token are never selected. "Special" means
    ``pad_id``, ``mask_token_id``, or any id below ``len(SPECIAL_TOKENS)``
    (the vocabulary places the special tokens first, so this also covers
    [CLS], [SEP] and [UNK]).

    Args:
        input_ids: Long tensor of token ids, shape (batch, seq_len). Not modified.
        pad_id: Id of the padding token.
        mask_token_id: Id written at positions chosen for [MASK] replacement.
        vocab_size: Vocabulary size; random replacements are drawn below it.
        mask_prob: Probability that an eligible position is selected.

    Returns:
        ``(input_ids_masked, mlm_labels)``, both shaped like ``input_ids`` and
        on the same device. ``mlm_labels`` holds the original id at selected
        positions and -100 (the loss ignore index) everywhere else.
    """
    special_upper = len(SPECIAL_TOKENS)

    # Build the selection probabilities on the same device as input_ids: indexing
    # a CPU tensor with a boolean mask that lives on the GPU raises an error.
    prob_matrix = torch.full(input_ids.shape, mask_prob, dtype=torch.float, device=input_ids.device)
    prob_matrix[input_ids == pad_id] = 0.0
    prob_matrix[input_ids == mask_token_id] = 0.0
    prob_matrix[input_ids < special_upper] = 0.0

    masked_positions = torch.bernoulli(prob_matrix).bool()

    # Labels: original token at selected positions, ignore index elsewhere.
    mlm_labels = torch.full_like(input_ids, fill_value=-100)
    mlm_labels[masked_positions] = input_ids[masked_positions]

    input_ids_masked = input_ids.clone()

    # One uniform draw per position decides which of the three treatments applies.
    rand_for_replace = torch.rand_like(input_ids, dtype=torch.float)
    mask_replace = masked_positions & (rand_for_replace < 0.8)
    random_replace = masked_positions & (rand_for_replace >= 0.8) & (rand_for_replace < 0.9)
    # The remaining 10% of selected positions keep their original token.

    input_ids_masked[mask_replace] = mask_token_id

    if random_replace.any():
        count = int(random_replace.sum().item())
        if special_upper < vocab_size:
            # Draw from the non-special region so PAD/CLS/SEP/MASK/UNK are never injected.
            rand_tokens = torch.randint(low=special_upper, high=vocab_size, size=(count,), dtype=torch.long, device=input_ids.device)
        else:
            # Defensive fallback for a vocabulary with no non-special tokens:
            # sample anywhere but steer accidental PAD draws to a neighbouring id.
            rand_tokens = torch.randint(low=0, high=vocab_size, size=(count,), dtype=torch.long, device=input_ids.device)
            if vocab_size > 1:
                rand_tokens[rand_tokens == pad_id] = (pad_id + 1) % vocab_size
        input_ids_masked[random_replace] = rand_tokens

    return input_ids_masked, mlm_labels

# -------------------------
# Model components
# -------------------------
class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention and a GELU feed-forward network.

    Each sub-layer is followed by dropout, a residual connection and then
    LayerNorm (normalisation is applied after the residual sum).
    """

    def __init__(self, hidden_size, num_heads, ffn_dim, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.attn_layernorm = nn.LayerNorm(hidden_size, eps=1e-12)
        expand = nn.Linear(hidden_size, ffn_dim)
        contract = nn.Linear(ffn_dim, hidden_size)
        self.ffn = nn.Sequential(expand, nn.GELU(), contract)
        self.ffn_layernorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask):
        """Transform ``x`` of shape (batch, seq_len, hidden).

        ``attention_mask`` is (batch, seq_len) with 1 for tokens to keep and 0
        for padding; it is inverted into the boolean ``key_padding_mask`` that
        ``nn.MultiheadAttention`` expects (True = ignore this key).
        """
        ignore_keys = attention_mask == 0
        attended = self.self_attn(x, x, x, key_padding_mask=ignore_keys)[0]
        hidden = self.attn_layernorm(x + self.dropout(attended))
        projected = self.ffn(hidden)
        return self.ffn_layernorm(hidden + self.dropout(projected))

class BertEncoderModel(nn.Module):
    """BERT-style encoder with an NSP head, a weight-tied MLM head and mean pooling."""

    def __init__(self, vocab_size, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, ffn_dim=FFN_DIM, max_position_embeddings=512, pad_token_id=0, embedding_weights=None, dropout=DROPOUT):
        """
        Args:
            vocab_size: Number of rows in the token embedding table.
            hidden_size: Width of every embedding and hidden state.
            num_layers: Number of stacked encoder layers.
            num_heads: Attention heads per layer.
            ffn_dim: Inner width of each layer's feed-forward network.
            max_position_embeddings: Number of learned positions (longest usable sequence).
            pad_token_id: Id of the padding token (used as ``padding_idx``).
            embedding_weights: Optional (vocab_size, hidden_size) tensor copied into
                the token embedding table, e.g. Word2Vec vectors.
            dropout: Dropout probability for embeddings and encoder layers.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.pad_token_id = pad_token_id

        # Token table, optionally seeded from pretrained vectors.
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)
        if embedding_weights is not None:
            with torch.no_grad():
                self.token_embeddings.weight.copy_(embedding_weights)

        # Learned position and segment (sentence A/B) tables.
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.segment_embeddings = nn.Embedding(2, hidden_size)

        self.emb_layernorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.emb_dropout = nn.Dropout(dropout)

        # Encoder stack.
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerEncoderLayer(hidden_size, num_heads, ffn_dim, dropout=dropout)
            )

        # NSP head: classifies the hidden state of the first token ([CLS]).
        self.nsp_classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 2),
        )

        # MLM head: logits are computed against token_embeddings.weight (weight tying),
        # so the only extra parameter is this per-token bias.
        self.mlm_bias = nn.Parameter(torch.zeros(vocab_size))

    def encode(self, input_ids, token_type_ids=None, attention_mask=None):
        """Embed the inputs and run the encoder stack.

        Missing ``token_type_ids`` default to all zeros; a missing
        ``attention_mask`` is derived as ``input_ids != pad_token_id``.

        Returns:
            sequence_output of shape (batch, seq_len, hidden).
        """
        batch_size, seq_len = input_ids.shape
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        if attention_mask is None:
            attention_mask = (input_ids != self.pad_token_id).long()

        # Positions 0..seq_len-1, shared by every row of the batch.
        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        hidden = (
            self.token_embeddings(input_ids)
            + self.position_embeddings(position_ids)
            + self.segment_embeddings(token_type_ids)
        )
        hidden = self.emb_dropout(self.emb_layernorm(hidden))

        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, attention_mask)

        return hidden

    def get_pooled_embeddings(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, token_type_ids: torch.LongTensor = None, exclude_special: bool = True, normalize: bool = True):
        """Mask-aware mean pooling over the encoder outputs.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding.
            token_type_ids: Optional (batch, seq_len) segment ids.
            exclude_special: If True, positions whose id is below
                ``len(SPECIAL_TOKENS)`` ([PAD], [CLS], [SEP], [MASK], [UNK]) are
                left out of the mean.
            normalize: If True, L2-normalise each pooled vector (suits cosine similarity).

        Returns:
            pooled: (batch, hidden).
        """
        token_states = self.encode(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)  # [B, L, H]

        # Per-position weights: 1.0 where the token takes part in the mean.
        weights = attention_mask.to(token_states.dtype)  # [B, L]

        if exclude_special:
            # Special tokens sit at the front of the vocabulary, so an id threshold finds them.
            special_upper = len(SPECIAL_TOKENS)
            is_regular = (input_ids >= special_upper).to(token_states.dtype)
            weights = weights * is_regular

        # Guard against rows with no contributing tokens.
        token_counts = torch.clamp(weights.sum(dim=1, keepdim=True), min=1e-9)  # [B, 1]

        weighted_total = torch.einsum('bld,bl->bd', token_states, weights)  # (B, H)
        pooled = weighted_total / token_counts

        return F.normalize(pooled, p=2, dim=1) if normalize else pooled

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Run the encoder and both pretraining heads.

        Args:
            input_ids: (batch, seq_len)
            token_type_ids: (batch, seq_len)
            attention_mask: (batch, seq_len), 1 for tokens and 0 for padding.

        Returns:
            mlm_logits: (batch, seq_len, vocab)
            nsp_logits: (batch, 2)
            sequence_output: (batch, seq_len, hidden)

        For retrieval embeddings use ``get_pooled_embeddings`` instead.
        """
        sequence_output = self.encode(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

        # NSP reads the first position, i.e. the [CLS] representation.
        cls_state = sequence_output[:, 0]  # (batch, hidden)
        nsp_logits = self.nsp_classifier(cls_state)  # (batch, 2)

        # MLM logits: project hidden states onto the (tied) token embedding table.
        mlm_logits = F.linear(sequence_output, self.token_embeddings.weight, self.mlm_bias)  # (batch, seq_len, vocab)

        return mlm_logits, nsp_logits, sequence_output

# -------------------------
# Example usage: Putting it all together
# -------------------------
def main():
    """End-to-end toy run: vocab -> Word2Vec -> one pretraining epoch -> save -> pooled embeddings.

    Side effects: prints progress to stdout and writes
    ``bert_encoder_toy_with_pooling.pth`` to the current working directory.
    """
    # Sample (toy) corpus; replace with your real corpus list of paragraphs.
    sample_corpus = [
        "the quick brown fox jumps over the lazy dog. the dog did not seem to mind.",
        "i love machine learning and natural language processing. transformers are powerful models.",
        "this is an example sentence for training word2vec. another example sentence is here.",
        "today is a sunny day. we will go to the park and enjoy the weather.",
        "deep learning enables many tasks such as translation, summarization and question answering."
    ]

    # Split once on '.'; the same sentence list feeds the vocabulary and Word2Vec.
    sentences = [s.strip() for p in sample_corpus for s in p.strip().split('.') if s.strip()]

    stoi, itos = build_vocab(sentences)
    vocab_size = len(itos)
    pad_id = stoi[PAD_TOKEN]
    mask_id = stoi[MASK_TOKEN]

    print("Vocab size:", vocab_size)

    # WARNING: training vector_size=768 is heavy; for a quick test set a smaller WORD2VEC_SIZE in config.
    print("Training Word2Vec (this may take a while for large vector sizes)...")
    w2v = train_word2vec(sentences, epochs=10)

    # Embedding matrix aligned with itos.
    embedding_weights = build_embedding_matrix(w2v, itos, HIDDEN_SIZE)

    def collate(examples):
        # Bind the pad id so the DataLoader can call the collate function with one argument.
        return collate_fn(examples, pad_id)

    dataset = BertPretrainingDataset(paragraphs=sample_corpus, stoi=stoi, max_seq_len=MAX_SEQ_LEN)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)

    model = BertEncoderModel(vocab_size=vocab_size, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, ffn_dim=FFN_DIM, max_position_embeddings=MAX_SEQ_LEN, pad_token_id=pad_id, embedding_weights=embedding_weights, dropout=DROPOUT)
    model.to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
    nsp_loss_fct = nn.CrossEntropyLoss()

    num_epochs = 1  # change as needed
    global_step = 0
    model.train()
    for epoch in range(num_epochs):
        for batch in dataloader:
            # Corrupt the inputs while they are still on the CPU, then transfer. The
            # masking helper builds CPU-side tensors internally, so handing it GPU
            # tensors would fail when it indexes them with a GPU mask.
            input_ids_masked, mlm_labels = create_mlm_labels_and_masked_input(batch["input_ids"], pad_id, mask_id, vocab_size, mask_prob=MLM_MASK_PROB)
            input_ids_masked = input_ids_masked.to(DEVICE)
            mlm_labels = mlm_labels.to(DEVICE)
            token_type_ids = batch["token_type_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            nsp_labels = batch["nsp_labels"].to(DEVICE)

            mlm_logits, nsp_logits, _ = model(input_ids_masked, token_type_ids=token_type_ids, attention_mask=attention_mask)

            if (mlm_labels != -100).any():
                # Flatten to (batch * seq_len, vocab) for the token-level loss.
                mlm_loss = mlm_loss_fct(mlm_logits.view(-1, vocab_size), mlm_labels.view(-1))
            else:
                # No position was selected in this batch: averaging over zero targets
                # would give NaN and poison the update, so contribute an exact zero.
                mlm_loss = mlm_logits.sum() * 0.0
            nsp_loss = nsp_loss_fct(nsp_logits.view(-1, 2), nsp_labels.view(-1))
            loss = mlm_loss + nsp_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            global_step += 1
            if global_step % 10 == 0:
                print(f"Epoch {epoch} step {global_step}: loss {loss.item():.4f} (mlm {mlm_loss.item():.4f} nsp {nsp_loss.item():.4f})")

    print("Training loop finished (toy example). Save model if desired.")
    # Example saving
    torch.save({
        "model_state_dict": model.state_dict(),
        "vocab": itos,
        "stoi": stoi
    }, "bert_encoder_toy_with_pooling.pth")
    print("Saved to bert_encoder_toy_with_pooling.pth")

    # -------------------------
    # Example: get pooled retrieval embeddings for some sentences
    # -------------------------
    model.eval()
    with torch.no_grad():
        # Take a small batch from the dataset and produce embeddings.
        sample_batch = next(iter(DataLoader(dataset, batch_size=2, collate_fn=collate)))
        ids = sample_batch["input_ids"].to(DEVICE)
        masks = sample_batch["attention_mask"].to(DEVICE)
        types = sample_batch["token_type_ids"].to(DEVICE)
        embeddings = model.get_pooled_embeddings(ids, masks, token_type_ids=types, exclude_special=True, normalize=True)
        print("Pooled embeddings shape:", embeddings.shape)  # (batch, hidden)
        print("First pooled vector (first 8 values):", embeddings[0, :8].cpu().numpy())

if __name__ == "__main__":
    main()