# 02 - BERT MLM from Scratch

A from-scratch PyTorch implementation of BERT's masked language modeling (MLM) pre-training objective.
1. **Simple Word-Level Tokenizer** - With special tokens [PAD], [UNK], [CLS], [SEP], [MASK]
- Ablation: mask probability, 0.15 vs 0.20
- Ablation: number of encoder layers, 2 vs 4

In [None]:
import math
import random
import time
from typing import Optional, Tuple, List, Dict, Any
from dataclasses import dataclass
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm

# Print the installed PyTorch version.
print(f"PyTorch version: {torch.__version__}")

In [None]:
def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random number generators.

    Args:
        seed: Value handed unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    
# Seed all random number generators once before anything else runs.
set_seed(42)

# Deterministic settings
# Request deterministic cuDNN kernels and disable cuDNN's benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [None]:
# Use CUDA when it is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# When a GPU is present, report the name and total memory (decimal GB) of device 0.
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 4. Configuration

In [None]:
@dataclass
class BERTConfig:
    """Hyperparameters for the BERT MLM model, the masking procedure and training.

    Instances are passed to the model components and to the training helpers.
    ``vocab_size`` starts as a placeholder and is overwritten after the
    tokenizer's vocabulary has been built.
    """
    # Model architecture
    d_model: int = 256           # Hidden size
    n_heads: int = 8             # Number of attention heads
    n_layers: int = 4            # Number of encoder layers
    d_ff: int = 512              # Feed-forward intermediate size
    max_seq_len: int = 128       # Maximum sequence length
    vocab_size: int = 10000      # Will be updated after building tokenizer
    n_segments: int = 2          # Number of segment types (sentence A/B)
    dropout: float = 0.1         # Dropout rate
    
    # MLM settings
    mask_prob: float = 0.15      # Probability of masking a token
    
    # Training
    batch_size: int = 32
    learning_rate: float = 1e-4
    n_epochs: int = 10
    warmup_steps: int = 100      # NOTE(review): not read by any code visible here — confirm whether warmup is intended
    
# Shared baseline configuration used by the cells that follow.
config = BERTConfig()


## 5. Tokenizer Implementation

We implement a simple word-level tokenizer. Subword tokenization (WordPiece in the original BERT, BPE (Byte Pair Encoding) in many later models) is more common in practice, but word-level tokenization is simpler to implement and sufficient for demonstration.

### Trade-offs:
- **Word-level** (our choice):
  - ✅ Simple to implement
  - ✅ Words are semantically meaningful units
  - ❌ Large vocabulary for good coverage
  - ❌ Cannot handle OOV (out-of-vocabulary) words well
  
- **BPE / WordPiece**:
  - ✅ Smaller vocabulary
  - ✅ Better OOV handling (breaks into subwords)
  - ✅ Language-agnostic
  - ❌ More complex to implement
  - ❌ Subwords may not be semantically meaningful

In [None]:
class BERTTokenizer:
    """Word-level tokenizer with BERT-style special tokens.

    Text is lowercased, every character that is neither alphanumeric nor
    whitespace is treated as a separator, and the remaining words are mapped
    to integer ids. After ``fit`` the five special tokens occupy ids 0-4 in
    the order [PAD], [UNK], [CLS], [SEP], [MASK]; regular words follow,
    ordered by descending frequency and then alphabetically.

    Args:
        min_freq: Minimum corpus frequency for a word to enter the vocabulary.
        max_vocab_size: Optional cap on the total vocabulary size, special
            tokens included. ``None`` (or 0) disables the cap.
    """

    PAD_TOKEN = "[PAD]"
    UNK_TOKEN = "[UNK]"
    CLS_TOKEN = "[CLS]"
    SEP_TOKEN = "[SEP]"
    MASK_TOKEN = "[MASK]"

    def __init__(self, min_freq: int = 2, max_vocab_size: Optional[int] = None):
        self.min_freq = min_freq
        self.max_vocab_size = max_vocab_size

        # Both maps stay empty until fit() is called.
        self.word2idx: Dict[str, int] = {}
        self.idx2word: Dict[int, str] = {}

        # Order matters: a special token's position here becomes its id.
        self.special_tokens = [
            self.PAD_TOKEN, self.UNK_TOKEN, self.CLS_TOKEN,
            self.SEP_TOKEN, self.MASK_TOKEN
        ]

    # The id properties below raise KeyError until fit() has been called.

    @property
    def pad_idx(self) -> int:
        """Id of the [PAD] token."""
        return self.word2idx[self.PAD_TOKEN]

    @property
    def unk_idx(self) -> int:
        """Id of the [UNK] token."""
        return self.word2idx[self.UNK_TOKEN]

    @property
    def cls_idx(self) -> int:
        """Id of the [CLS] token."""
        return self.word2idx[self.CLS_TOKEN]

    @property
    def sep_idx(self) -> int:
        """Id of the [SEP] token."""
        return self.word2idx[self.SEP_TOKEN]

    @property
    def mask_idx(self) -> int:
        """Id of the [MASK] token."""
        return self.word2idx[self.MASK_TOKEN]

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary, special tokens included."""
        return len(self.word2idx)

    def fit(self, texts: List[str]) -> None:
        """Build the vocabulary from ``texts``, replacing any previous one.

        Args:
            texts: Raw strings to count words from.
        """
        # Count word frequencies
        word_freq = Counter()
        for text in texts:
            word_freq.update(self._tokenize(text))

        # Start from empty maps so that re-fitting cannot leave stale words
        # behind or let word2idx and idx2word disagree with each other.
        self.word2idx = {}
        self.idx2word = {}

        # Add special tokens first
        for i, token in enumerate(self.special_tokens):
            self.word2idx[token] = i
            self.idx2word[i] = token

        # Add words that meet the frequency threshold; the (-count, word)
        # sort key makes id assignment deterministic.
        idx = len(self.special_tokens)
        sorted_words = sorted(word_freq.items(), key=lambda x: (-x[1], x[0]))

        for word, freq in sorted_words:
            if freq < self.min_freq:
                continue
            if self.max_vocab_size and idx >= self.max_vocab_size:
                break

            self.word2idx[word] = idx
            self.idx2word[idx] = word
            idx += 1

        print(f"Vocabulary built: {self.vocab_size} tokens")
        print(f"  - Special tokens: {len(self.special_tokens)}")
        print(f"  - Regular tokens: {self.vocab_size - len(self.special_tokens)}")

    def _tokenize(self, text: str) -> List[str]:
        """Lowercase ``text`` and split it into alphanumeric words."""
        text = text.lower()
        # Every character that is neither alphanumeric nor whitespace becomes a space.
        text = ''.join(c if c.isalnum() or c.isspace() else ' ' for c in text)
        return text.split()

    def encode(
        self,
        text: str,
        max_length: Optional[int] = None,
        add_special_tokens: bool = True,
        padding: bool = False
    ) -> Dict[str, List[int]]:
        """Convert ``text`` into model inputs.

        Args:
            text: Raw input string.
            max_length: When set (non-zero), longer sequences are cut to this
                length; it is also the padding target.
            add_special_tokens: Wrap the sequence in [CLS] ... [SEP].
            padding: Pad with [PAD] up to ``max_length`` (ignored when
                ``max_length`` is not set).

        Returns:
            Dict with ``input_ids``, ``attention_mask`` (1 for real tokens,
            0 for padding) and ``token_type_ids`` (all 0).
        """
        words = self._tokenize(text)

        # Convert to indices; out-of-vocabulary words map to [UNK].
        unk = self.unk_idx
        token_ids = [self.word2idx.get(w, unk) for w in words]

        # Add special tokens
        if add_special_tokens:
            token_ids = [self.cls_idx] + token_ids + [self.sep_idx]

        # Truncate if necessary
        if max_length and len(token_ids) > max_length:
            token_ids = token_ids[:max_length]
            if add_special_tokens:
                token_ids[-1] = self.sep_idx  # Ensure [SEP] at end

        attention_mask = [1] * len(token_ids)

        # Segment IDs (all 0 for single sentence)
        token_type_ids = [0] * len(token_ids)

        # Pad if necessary
        if padding and max_length:
            pad_len = max_length - len(token_ids)
            token_ids = token_ids + [self.pad_idx] * pad_len
            attention_mask = attention_mask + [0] * pad_len
            token_type_ids = token_type_ids + [0] * pad_len

        return {
            'input_ids': token_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids
        }

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """Map ids back to a space-joined string.

        Unknown ids decode to "[UNK]". With ``skip_special_tokens`` every
        special token (including [UNK]) is dropped from the output.
        """
        words = []
        for idx in token_ids:
            word = self.idx2word.get(idx, self.UNK_TOKEN)
            if skip_special_tokens and word in self.special_tokens:
                continue
            words.append(word)
        return ' '.join(words)

    def get_random_token(self) -> int:
        """Return the id of a uniformly sampled regular (non-special) token.

        Raises:
            ValueError: If the vocabulary contains no regular tokens.
        """
        first_regular = len(self.special_tokens)
        if self.vocab_size <= first_regular:
            # randint would fail with an opaque "empty range" error here.
            raise ValueError(
                "Vocabulary has no regular tokens to sample from; "
                "call fit() on a non-empty corpus first."
            )
        return random.randint(first_regular, self.vocab_size - 1)

In [None]:
# Sample corpus for training: 30 short English sentences.
SAMPLE_TEXTS = [
    "The quick brown fox jumps over the lazy dog.",
    "A journey of a thousand miles begins with a single step.",
    "To be or not to be that is the question.",
    "All that glitters is not gold.",
    "The only thing we have to fear is fear itself.",
    "In the beginning was the word and the word was with god.",
    "It was the best of times it was the worst of times.",
    "Call me ishmael some years ago never mind how long precisely.",
    "It is a truth universally acknowledged that a single man in possession of a good fortune must be in want of a wife.",
    "Happy families are all alike every unhappy family is unhappy in its own way.",
    "The sun rose slowly over the mountains casting long shadows across the valley below.",
    "She walked through the forest listening to the birds singing their morning songs.",
    "The old man sat by the fire remembering the days of his youth.",
    "The city streets were busy with people hurrying to their destinations.",
    "A gentle breeze blew through the open window bringing the scent of flowers.",
    "The children played in the garden while their parents watched from the porch.",
    "He picked up the book and began to read losing himself in the story.",
    "The stars twinkled in the night sky like diamonds scattered across velvet.",
    "She smiled at the memory of their first meeting so many years ago.",
    "The waves crashed against the shore creating a soothing rhythm.",
    "Deep in the forest there lived a wise old owl who knew many secrets.",
    "The train departed from the station carrying passengers to distant lands.",
    "Music filled the air as the orchestra began their evening performance.",
    "The scientist worked late into the night trying to solve the puzzle.",
    "Rain began to fall gently at first then harder until it became a downpour.",
    "The ancient temple stood silent witness to centuries of human history.",
    "Birds migrated south as the leaves began to change colors in autumn.",
    "The chef prepared an exquisite meal using fresh ingredients from the garden.",
    "Lightning flashed across the sky followed by a thunderous roar.",
    "The artist captured the essence of beauty in every brushstroke.",
]

# Expand dataset by repetition (each sentence appears 100 times).
TEXTS = SAMPLE_TEXTS * 100
print(f"Total texts: {len(TEXTS)}")

# Build tokenizer
# Every word in TEXTS occurs at least 100 times, so min_freq=2 filters nothing out.
tokenizer = BERTTokenizer(min_freq=2)
tokenizer.fit(TEXTS)

# Update config so the model's embedding and output layers match the real vocabulary.
config.vocab_size = tokenizer.vocab_size

In [None]:
# Test tokenizer: encode one sentence padded to 20 ids, then decode it back.
test_text = "The quick brown fox jumps over the lazy dog."
encoded = tokenizer.encode(test_text, max_length=20, padding=True)

print(f"Original: {test_text}")
print(f"Token IDs: {encoded['input_ids']}")
print(f"Attention mask: {encoded['attention_mask']}")
# decode() skips special tokens by default, so [CLS]/[SEP]/[PAD] do not appear here.
print(f"Decoded: {tokenizer.decode(encoded['input_ids'])}")

## 7. MLM Dataset with Dynamic Masking

In [None]:
class MLMDataset(Dataset):
    """Masked-language-modeling dataset that re-samples its masks on every access.

    Each text is encoded once (with special tokens, padded to ``max_length``).
    ``__getitem__`` then draws a fresh random corruption, so the same index
    can yield different masks on different calls.

    Args:
        texts: Raw strings to encode.
        tokenizer: Fitted tokenizer used for encoding and for sampling
            random replacement tokens.
        max_length: Truncation and padding length handed to the tokenizer.
        mask_prob: Per-token probability of being selected for prediction.
    """

    def __init__(
        self,
        texts: List[str],
        tokenizer: BERTTokenizer,
        max_length: int,
        mask_prob: float = 0.15
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mask_prob = mask_prob

        # Encode everything up front; only the masking is done lazily.
        self.examples = [
            tokenizer.encode(
                text,
                max_length=max_length,
                add_special_tokens=True,
                padding=True
            )
            for text in texts
        ]

        print(f"Created {len(self.examples)} examples")

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one freshly masked example as a dict of tensors.

        ``labels`` holds the original id at every selected position and -100
        (the cross-entropy ignore value) everywhere else;
        ``masked_positions`` marks the selected positions.
        """
        record = self.examples[idx]
        tok = self.tokenizer

        # Work on copies so the cached encoding is never modified.
        input_ids = list(record['input_ids'])
        attention_mask = list(record['attention_mask'])
        token_type_ids = list(record['token_type_ids'])

        mask_id = tok.mask_idx
        never_masked = {tok.pad_idx, tok.cls_idx, tok.sep_idx, mask_id}

        seq_len = len(input_ids)
        labels = [-100] * seq_len
        masked_positions = [False] * seq_len

        for pos, original_id in enumerate(record['input_ids']):
            # Special tokens and padding are never prediction targets.
            if original_id in never_masked:
                continue
            # Select this position with probability mask_prob.
            if random.random() >= self.mask_prob:
                continue

            labels[pos] = original_id
            masked_positions[pos] = True

            # 80% -> [MASK], 10% -> random regular token, 10% -> left unchanged.
            roll = random.random()
            if roll < 0.8:
                input_ids[pos] = mask_id
            elif roll < 0.9:
                input_ids[pos] = tok.get_random_token()

        as_long = dict(dtype=torch.long)
        return {
            'input_ids': torch.tensor(input_ids, **as_long),
            'attention_mask': torch.tensor(attention_mask, **as_long),
            'token_type_ids': torch.tensor(token_type_ids, **as_long),
            'labels': torch.tensor(labels, **as_long),
            'masked_positions': torch.tensor(masked_positions, dtype=torch.bool)
        }

# Create dataset from the repeated corpus using the baseline sequence length and mask rate.
mlm_dataset = MLMDataset(TEXTS, tokenizer, config.max_seq_len, config.mask_prob)
# Shuffled mini-batches for training.
mlm_dataloader = DataLoader(mlm_dataset, batch_size=config.batch_size, shuffle=True)

In [None]:
# Sanity check: Verify masking ratio
# Draw 100 random examples and compare the fraction of selected tokens
# against config.mask_prob.
print("Verifying masking ratio...")

total_maskable = 0
total_masked = 0

for _ in range(100):
    sample = mlm_dataset[random.randint(0, len(mlm_dataset)-1)]
    # Count maskable tokens (non-special, non-padding):
    # selected positions plus unselected real tokens that are not [CLS]/[SEP].
    maskable = (sample['labels'] != -100).sum().item() + \
               ((sample['attention_mask'] == 1) & (sample['labels'] == -100) & 
                (sample['input_ids'] != tokenizer.cls_idx) & 
                (sample['input_ids'] != tokenizer.sep_idx)).sum().item()
    # Selected positions are exactly those whose label is not the ignore value.
    masked = (sample['labels'] != -100).sum().item()
    
    total_maskable += maskable
    total_masked += masked

observed_ratio = total_masked / total_maskable if total_maskable > 0 else 0
print(f"Expected mask ratio: {config.mask_prob:.2%}")
print(f"Observed mask ratio: {observed_ratio:.2%}")
# Allow a 5-percentage-point deviation for sampling noise.
assert abs(observed_ratio - config.mask_prob) < 0.05, "Masking ratio significantly off!"
print("✓ Masking ratio is within expected range")

## 8. BERT Model Components

### 8.1 Embeddings

BERT uses three types of embeddings that are summed together:
1. **Token embeddings**: Learned embedding for each vocabulary token
2. **Position embeddings**: Learned embedding for each position
3. **Segment embeddings**: Learned embedding for sentence A vs B (for NSP task)

In [None]:
class BERTEmbeddings(nn.Module):
    """Input embeddings: token + learned position + segment, then LayerNorm and dropout.

    Args:
        config: Supplies ``vocab_size``, ``d_model``, ``max_seq_len``,
            ``n_segments`` and ``dropout``.
    """

    def __init__(self, config: BERTConfig):
        super().__init__()
        hidden = config.d_model

        # Token table; row 0 is treated as the padding embedding.
        self.token_embeddings = nn.Embedding(config.vocab_size, hidden, padding_idx=0)
        # One learned vector per position (not sinusoidal).
        self.position_embeddings = nn.Embedding(config.max_seq_len, hidden)
        # One learned vector per segment id (sentence A/B).
        self.segment_embeddings = nn.Embedding(config.n_segments, hidden)

        self.layer_norm = nn.LayerNorm(hidden, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

        # Position indices 0..max_seq_len-1 with shape (1, max_seq_len);
        # kept as a buffer so it moves with the module across devices.
        all_positions = torch.arange(config.max_seq_len).unsqueeze(0)
        self.register_buffer('position_ids', all_positions)

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Embed ``input_ids``.

        Args:
            input_ids: Token ids; dimension 1 is used as the sequence length.
            token_type_ids: Segment ids shaped like ``input_ids``; all zeros
                when omitted.

        Returns:
            Normalized, dropped-out sum of the three embeddings.
        """
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Use only the first seq_len position indices.
        positions = self.position_ids[:, :input_ids.size(1)]

        summed = (
            self.token_embeddings(input_ids)
            + self.position_embeddings(positions)
            + self.segment_embeddings(token_type_ids)
        )
        return self.dropout(self.layer_norm(summed))

# Test embeddings: random ids for a (2, 20) batch must map to (2, 20, d_model).
embed = BERTEmbeddings(config)
test_ids = torch.randint(0, config.vocab_size, (2, 20))
test_seg = torch.zeros_like(test_ids)
out = embed(test_ids, test_seg)
assert out.shape == (2, 20, config.d_model)
print(f"Embeddings test passed. Output shape: {out.shape}")

### 8.2 Multi-Head Self-Attention

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product self-attention with ``n_heads`` parallel heads.

    Args:
        config: Supplies ``d_model``, ``n_heads`` and ``dropout``;
            ``d_model`` must be divisible by ``n_heads``.
    """

    def __init__(self, config: BERTConfig):
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.d_head = config.d_model // config.n_heads

        # Linear projections
        self.query = nn.Linear(config.d_model, config.d_model)
        self.key = nn.Linear(config.d_model, config.d_model)
        self.value = nn.Linear(config.d_model, config.d_model)
        self.output = nn.Linear(config.d_model, config.d_model)

        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape (batch, seq_len, d_model).
            attention_mask: Optional (batch, seq_len) tensor; key positions
                where it equals 0 receive no attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Reshape to (batch, n_heads, seq_len, d_head)
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_head)

        # Apply attention mask
        if attention_mask is not None:
            # Expand mask to (batch, 1, 1, seq_len)
            mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # Fill with the dtype's most negative finite value rather than
            # -inf: masked keys still get zero weight after softmax, but a
            # row whose keys are all masked no longer turns into NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Softmax and dropout
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, V)

        # Merge the heads back into (batch, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )

        # Output projection
        return self.output(attn_output)

# Test attention: with an all-ones mask the output must keep the input shape.
attn = MultiHeadSelfAttention(config)
x = torch.randn(2, 20, config.d_model)
mask = torch.ones(2, 20)
out = attn(x, mask)
assert out.shape == x.shape
print(f"Multi-Head Attention test passed. Output shape: {out.shape}")

### 8.3 Feed-Forward Network

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Dropout -> Linear.

    Args:
        config: Supplies ``d_model``, ``d_ff`` and ``dropout``.
    """

    def __init__(self, config: BERTConfig):
        super().__init__()
        self.linear1 = nn.Linear(config.d_model, config.d_ff)
        self.linear2 = nn.Linear(config.d_ff, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand to ``d_ff``, apply GELU and dropout, project back to ``d_model``."""
        expanded = self.linear1(x)
        # BERT uses GELU rather than ReLU.
        activated = F.gelu(expanded)
        return self.linear2(self.dropout(activated))

# Test FFN: the output must have the same shape as the input.
ffn = FeedForward(config)
x = torch.randn(2, 20, config.d_model)
out = ffn(x)
assert out.shape == x.shape
print(f"FFN test passed. Output shape: {out.shape}")

### 8.4 Encoder Layer

In [None]:
class BERTEncoderLayer(nn.Module):
    """One post-norm encoder layer: self-attention and feed-forward sub-layers.

    Each sub-layer output goes through dropout, is added to its input and is
    then layer-normalized.

    Args:
        config: Supplies the sub-layer sizes and the dropout rate.
    """

    def __init__(self, config: BERTConfig):
        super().__init__()
        self.attention = MultiHeadSelfAttention(config)
        self.feed_forward = FeedForward(config)
        self.norm1 = nn.LayerNorm(config.d_model, eps=1e-12)
        self.norm2 = nn.LayerNorm(config.d_model, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run both sub-layers; the result has the same shape as ``x``."""
        # Attention sub-layer: residual connection, then norm.
        x = self.norm1(x + self.dropout(self.attention(x, attention_mask)))
        # Feed-forward sub-layer: residual connection, then norm.
        return self.norm2(x + self.dropout(self.feed_forward(x)))

# Test encoder layer (no attention mask): the output must keep the input shape.
enc_layer = BERTEncoderLayer(config)
x = torch.randn(2, 20, config.d_model)
out = enc_layer(x)
assert out.shape == x.shape
print(f"Encoder layer test passed. Output shape: {out.shape}")

### 8.5 MLM Head

In [None]:
class MLMHead(nn.Module):
    """Prediction head mapping hidden states to vocabulary logits.

    Pipeline: Linear -> GELU -> LayerNorm -> Linear decoder + per-token bias.
    The decoder is created without its own bias because ``self.bias``
    already provides one per vocabulary entry; keeping both would add two
    interchangeable bias vectors to the same logits.

    Args:
        config: Supplies ``d_model`` and ``vocab_size``.
    """

    def __init__(self, config: BERTConfig):
        super().__init__()
        self.dense = nn.Linear(config.d_model, config.d_model)
        self.layer_norm = nn.LayerNorm(config.d_model, eps=1e-12)
        # bias=False: the single output bias lives in self.bias below.
        self.decoder = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Bias for each vocabulary token
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (..., vocab_size) for ``hidden_states`` of shape (..., d_model)."""
        x = self.dense(hidden_states)
        x = F.gelu(x)
        x = self.layer_norm(x)
        logits = self.decoder(x) + self.bias
        return logits

# Test MLM head: the last dimension must become the vocabulary size.
mlm_head = MLMHead(config)
x = torch.randn(2, 20, config.d_model)
out = mlm_head(x)
assert out.shape == (2, 20, config.vocab_size)
print(f"MLM head test passed. Output shape: {out.shape}")

## 9. Full BERT Model for MLM

In [None]:
class BERTForMLM(nn.Module):
    """BERT encoder with a masked-language-modeling head.

    Embeddings feed a stack of ``config.n_layers`` encoder layers; the final
    hidden states are turned into vocabulary logits by the MLM head.

    Args:
        config: Model hyperparameters; also stored as ``self.config``.
    """

    def __init__(self, config: BERTConfig):
        super().__init__()
        self.config = config

        self.embeddings = BERTEmbeddings(config)
        self.encoder_layers = nn.ModuleList(
            [BERTEncoderLayer(config) for _ in range(config.n_layers)]
        )
        self.mlm_head = MLMHead(config)

        # Re-initialize every sub-module.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize one sub-module.

        LayerNorm gets weight 1 and bias 0. Linear and Embedding weights are
        drawn from N(0, 0.02); Linear biases and the Embedding padding row
        are zeroed. Other module types are left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return MLM logits for every position.

        Args:
            input_ids: Token ids.
            attention_mask: Optional mask forwarded to every encoder layer.
            token_type_ids: Optional segment ids forwarded to the embeddings.
        """
        hidden = self.embeddings(input_ids, token_type_ids)
        for encoder_layer in self.encoder_layers:
            hidden = encoder_layer(hidden, attention_mask)
        return self.mlm_head(hidden)

# Create and test model
model = BERTForMLM(config).to(device)

# Test forward pass on the first batch from the training loader.
batch = next(iter(mlm_dataloader))
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
token_type_ids = batch['token_type_ids'].to(device)

logits = model(input_ids, attention_mask, token_type_ids)
# Expect one logit vector per position: (batch, seq_len, vocab_size).
assert logits.shape == (config.batch_size, config.max_seq_len, config.vocab_size)
print(f"Model test passed. Output shape: {logits.shape}")

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")

In [None]:
# Sanity check 1: Only masked positions contribute to loss
print("Verifying loss computation...")

batch = next(iter(mlm_dataloader))
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
token_type_ids = batch['token_type_ids'].to(device)
labels = batch['labels'].to(device)

# Forward pass
logits = model(input_ids, attention_mask, token_type_ids)

# Compute loss with ignore_index=-100 (positions labelled -100 are skipped)
loss = F.cross_entropy(
    logits.view(-1, config.vocab_size),
    labels.view(-1),
    ignore_index=-100
)

# Verify masked positions only
num_masked = (labels != -100).sum().item()
print(f"Number of masked positions: {num_masked}")
print(f"Loss: {loss.item():.4f}")

# Manually verify: average the per-position cross-entropy over labelled
# positions only, using the same logits as above.
manual_loss = 0.0
count = 0
for i in range(labels.size(0)):
    for j in range(labels.size(1)):
        if labels[i, j] != -100:
            manual_loss += F.cross_entropy(
                logits[i, j].unsqueeze(0),
                labels[i, j].unsqueeze(0)
            ).item()
            count += 1

manual_loss /= count
print(f"Manual loss: {manual_loss:.4f}")
assert abs(loss.item() - manual_loss) < 0.01, "Loss mismatch!"
print("✓ Loss is correctly computed only for masked positions")

In [None]:
# Sanity check 2: Gradient flow
# Relies on `loss` and `model` left in scope by the previous cells.
print("\nVerifying gradient flow...")

model.zero_grad()
loss.backward()

# Check that all parameters have gradients and that none of them is NaN
grad_norms = []
for name, param in model.named_parameters():
    if param.requires_grad:
        assert param.grad is not None, f"No gradient for {name}"
        assert not torch.isnan(param.grad).any(), f"NaN gradient for {name}"
        grad_norms.append((name, param.grad.norm().item()))

# Show the first five parameters in registration order.
print("Sample gradient norms:")
for name, norm in grad_norms[:5]:
    print(f"  {name}: {norm:.6f}")
print("✓ All gradients are valid")

## 11. Training Loop

In [None]:
def compute_mlm_accuracy(
    logits: torch.Tensor,
    labels: torch.Tensor
) -> float:
    """Return the fraction of labelled positions whose argmax prediction is correct.

    Args:
        logits: Scores whose last dimension ranges over the vocabulary.
        labels: Target ids; positions equal to -100 are ignored.

    Returns:
        Accuracy in [0, 1], or 0.0 when no position is labelled.
    """
    scored = labels.ne(-100)
    n_scored = scored.sum()
    if n_scored == 0:
        return 0.0

    hits = logits.argmax(dim=-1).eq(labels).logical_and(scored).sum()
    return (hits.float() / n_scored.float()).item()


def train_bert_epoch(
    model: BERTForMLM,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device
) -> Tuple[float, float]:
    """Train ``model`` for one pass over ``dataloader``.

    Each batch takes one optimizer step on the MLM cross-entropy (labels
    equal to -100 are ignored), with the gradient norm clipped to 1.0.

    Returns:
        ``(mean_loss, mean_accuracy)``, both unweighted averages over batches.
    """
    model.train()
    loss_sum, acc_sum, steps = 0.0, 0.0, 0
    batch_keys = ('input_ids', 'attention_mask', 'token_type_ids', 'labels')

    for batch in tqdm(dataloader, desc="Training", leave=False):
        input_ids, attention_mask, token_type_ids, labels = (
            batch[key].to(device) for key in batch_keys
        )

        optimizer.zero_grad()

        logits = model(input_ids, attention_mask, token_type_ids)

        # Flatten to (batch * seq_len, vocab) against (batch * seq_len,) targets.
        flat_logits = logits.view(-1, model.config.vocab_size)
        loss = F.cross_entropy(flat_logits, labels.view(-1), ignore_index=-100)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_sum += loss.item()
        acc_sum += compute_mlm_accuracy(logits, labels)
        steps += 1

    return loss_sum / steps, acc_sum / steps


def evaluate_bert(
    model: BERTForMLM,
    dataloader: DataLoader,
    device: torch.device
) -> Tuple[float, float]:
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Runs in eval mode with gradient tracking disabled.

    Returns:
        ``(mean_loss, mean_accuracy)``, both unweighted averages over batches.
    """
    model.eval()
    loss_sum, acc_sum, steps = 0.0, 0.0, 0
    batch_keys = ('input_ids', 'attention_mask', 'token_type_ids', 'labels')

    with torch.no_grad():
        for batch in dataloader:
            input_ids, attention_mask, token_type_ids, labels = (
                batch[key].to(device) for key in batch_keys
            )

            logits = model(input_ids, attention_mask, token_type_ids)

            # Same MLM loss as in training: labels equal to -100 are ignored.
            flat_logits = logits.view(-1, model.config.vocab_size)
            loss = F.cross_entropy(flat_logits, labels.view(-1), ignore_index=-100)

            loss_sum += loss.item()
            acc_sum += compute_mlm_accuracy(logits, labels)
            steps += 1

    return loss_sum / steps, acc_sum / steps

In [None]:
def train_bert_model(
    config: BERTConfig,
    dataloader: DataLoader,
    device: torch.device,
    model_name: str = "BERT"
) -> Tuple[BERTForMLM, List[float], List[float]]:
    """Build a fresh ``BERTForMLM`` and train it for ``config.n_epochs`` epochs.

    Uses AdamW with cosine learning-rate annealing, stepped once per epoch.
    No warmup phase is applied and ``config.warmup_steps`` is not read here.

    Args:
        config: Model and training hyperparameters.
        dataloader: Batches of masked examples.
        device: Device the model is trained on.
        model_name: Label printed in the progress banner.

    Returns:
        ``(model, losses, accuracies)`` with one loss and one accuracy per epoch.
    """
    model = BERTForMLM(config).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=0.01
    )

    # Cosine annealing spanning the whole run (T_max = number of epochs).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.n_epochs
    )

    losses: List[float] = []
    accuracies: List[float] = []

    print(f"\n{'='*60}")
    print(f"Training {model_name}")
    print(f"Layers: {config.n_layers}, Mask prob: {config.mask_prob}")
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"{'='*60}")

    start_time = time.time()

    for epoch in range(config.n_epochs):
        epoch_start = time.time()

        train_loss, train_acc = train_bert_epoch(model, dataloader, optimizer, device)
        scheduler.step()

        losses.append(train_loss)
        accuracies.append(train_acc)

        epoch_time = time.time() - epoch_start

        print(f"Epoch {epoch+1}/{config.n_epochs} | "
              f"Loss: {train_loss:.4f} | "
              f"MLM Acc: {train_acc:.2%} | "
              f"Time: {epoch_time:.1f}s")

    total_time = time.time() - start_time
    print(f"\nTotal training time: {total_time:.1f}s")

    return model, losses, accuracies

## 12. Training with Default Configuration

In [None]:
# Train with default config (4 layers, 15% masking)
# Re-seed first so model initialization starts from the same RNG state in every run.
set_seed(42)
model_default, losses_default, acc_default = train_bert_model(
    config, mlm_dataloader, device, "Default (4 layers, 15% mask)"
)

## 13. Ablation 1: Mask Probability (0.15 vs 0.20)

In [None]:
from dataclasses import replace

# Ablation config: identical to the baseline except for the mask probability.
# replace() copies every field, so settings not listed here (n_segments,
# warmup_steps, or anything added later) cannot silently fall back to defaults.
config_20mask = replace(config, mask_prob=0.20)  # Increased from 0.15

# Create new dataset with 20% masking; every value is read from the ablation
# config so the dataset and the reported configuration cannot drift apart.
mlm_dataset_20 = MLMDataset(
    TEXTS, tokenizer, config_20mask.max_seq_len, mask_prob=config_20mask.mask_prob
)
mlm_dataloader_20 = DataLoader(
    mlm_dataset_20, batch_size=config_20mask.batch_size, shuffle=True
)

# Re-seed so model initialization matches the baseline run.
set_seed(42)
model_20mask, losses_20mask, acc_20mask = train_bert_model(
    config_20mask, mlm_dataloader_20, device, "20% Masking"
)

## 14. Ablation 2: Number of Layers (2 vs 4)

In [None]:
from dataclasses import replace

# Create config with 2 layers: identical to the baseline except for depth.
# replace() copies every field, so settings not listed here (n_segments,
# warmup_steps, or anything added later) cannot silently fall back to defaults.
config_2layers = replace(config, n_layers=2)  # Reduced from 4

# Re-seed so model initialization starts from the same RNG state as the baseline run.
set_seed(42)
model_2layers, losses_2layers, acc_2layers = train_bert_model(
    config_2layers, mlm_dataloader, device, "2 Layers"
)

## 15. Results Visualization

In [None]:
# Plot training curves: loss on the left, MLM accuracy on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One entry per run: (legend label, marker, per-epoch losses, per-epoch accuracies).
plotted_runs = [
    ('4 layers, 15% mask', 'o', losses_default, acc_default),
    ('4 layers, 20% mask', 's', losses_20mask, acc_20mask),
    ('2 layers, 15% mask', '^', losses_2layers, acc_2layers),
]
for run_label, run_marker, run_losses, run_accs in plotted_runs:
    axes[0].plot(run_losses, label=run_label, marker=run_marker)
    axes[1].plot(run_accs, label=run_label, marker=run_marker)

# Shared cosmetics for both panels.
panel_text = [
    ('Training Loss', 'MLM Training Loss Comparison'),
    ('MLM Accuracy', 'MLM Accuracy Comparison'),
]
for panel, (y_label, panel_title) in zip(axes, panel_text):
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True, alpha=0.3)

# Show accuracy ticks as percentages.
axes[1].yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: '{:.0%}'.format(y)))

plt.tight_layout()
plt.show()

In [None]:
# Results table: final-epoch loss and accuracy plus parameter count per run.
print("\n" + "="*80)
print("RESULTS SUMMARY")
print("="*80)
print(f"{'Configuration':<30} {'Final Loss':<15} {'Final MLM Acc':<15} {'Params':<15}")
print("-"*80)

# Each row: (label, last-epoch loss, last-epoch accuracy, total parameter count).
configs_results = [
    ("4 layers, 15% mask (default)", losses_default[-1], acc_default[-1], 
     sum(p.numel() for p in model_default.parameters())),
    ("4 layers, 20% mask", losses_20mask[-1], acc_20mask[-1],
     sum(p.numel() for p in model_20mask.parameters())),
    ("2 layers, 15% mask", losses_2layers[-1], acc_2layers[-1],
     sum(p.numel() for p in model_2layers.parameters())),
]

# NOTE(review): the loop variables rebind the notebook-level names `name`, `loss`,
# `acc` and `params`; `loss` was the loss tensor from the sanity-check cell.
for name, loss, acc, params in configs_results:
    print(f"{name:<30} {loss:<15.4f} {acc:<15.2%} {params:,}")
print("="*80)

## 16. MLM Inference Demo

In [None]:
def predict_masked_tokens(
    model: BERTForMLM,
    text: str,
    tokenizer: BERTTokenizer,
    device: torch.device,
    top_k: int = 5
) -> None:
    """Print the ``top_k`` most likely tokens for every ``[MASK]`` in ``text``.

    The text around each ``[MASK]`` marker (matched case-insensitively) is
    normalized with the tokenizer's own word splitting, so punctuation and
    casing are handled exactly as they were for the training data.

    Args:
        model: Trained MLM model; switched to eval mode.
        text: Sentence containing zero or more ``[MASK]`` markers.
        tokenizer: Fitted tokenizer that produced the model's vocabulary.
        device: Device the model lives on.
        top_k: Number of candidates to show per mask, capped at the
            vocabulary size.
    """
    model.eval()

    token_ids = [tokenizer.cls_idx]
    mask_positions = []

    # Splitting on the marker keeps it from being broken up by the word
    # splitter, which treats '[' and ']' as separators.
    for piece_no, piece in enumerate(text.lower().split('[mask]')):
        if piece_no > 0:
            # A marker sat between the previous piece and this one.
            mask_positions.append(len(token_ids))
            token_ids.append(tokenizer.mask_idx)
        token_ids.extend(
            tokenizer.word2idx.get(word, tokenizer.unk_idx)
            for word in tokenizer._tokenize(piece)
        )
    token_ids.append(tokenizer.sep_idx)

    # Create tensors (a single sequence, so nothing is padded).
    input_ids = torch.tensor([token_ids], dtype=torch.long).to(device)
    attention_mask = torch.ones_like(input_ids)

    # Get predictions
    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    print(f"Input: {text}")
    print(f"\nPredictions:")

    # topk raises when k exceeds the vocabulary dimension.
    k = min(top_k, logits.size(-1))

    for pos in mask_positions:
        probs = F.softmax(logits[0, pos], dim=-1)
        top_probs, top_indices = probs.topk(k)

        print(f"\nPosition {pos}:")
        for prob, idx in zip(top_probs, top_indices):
            word = tokenizer.idx2word.get(idx.item(), '[UNK]')
            print(f"  {word}: {prob.item():.4f}")

# Test prediction: fill in one masked word in two sentences with the default-config model.
print("=" * 60)
print("MLM INFERENCE DEMO")
print("=" * 60)

predict_masked_tokens(
    model_default,
    "The quick brown [MASK] jumps over the lazy dog",
    tokenizer,
    device
)

# Visual separator between the two examples.
print("\n" + "-" * 60)

predict_masked_tokens(
    model_default,
    "The [MASK] rose slowly over the mountains",
    tokenizer,
    device
)