# BERT: Part 3

## Pre-training - Masked Language Modeling and Next Sentence Prediction

---

**Paper:** [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805)

---

This is where BERT gets its power. The pre-training tasks are what teach BERT to understand language before it ever sees your specific task.

Let's look at exactly how this works.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch
np.random.seed(42)

---

## Task 1: Masked Language Modeling (MLM)

The core idea: randomly hide some words, ask the model to predict them.

```
Original:  "The cat sat on the mat"
Masked:    "The [MASK] sat on the mat"
Target:    Predict "cat"
```

This forces the model to use context from both directions: the words to the left of the gap and the words to the right.

### The Masking Strategy: 15% with a Twist

BERT selects **15%** of the tokens in each sequence for prediction. But here's the clever part - it doesn't always replace them with [MASK].

For the selected 15%:
- **80%** are replaced with [MASK]
- **10%** are replaced with a random token
- **10%** are kept unchanged
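
Multiplying the two rates gives the share of *all* tokens that lands in each bucket. The short calculation below is only a sketch of that arithmetic; the 512-token sequence length and the variable names are illustrative choices, not something the split itself prescribes.

```python
# Share of ALL tokens in each bucket = selection rate x branch rate
select_rate = 0.15
branch_rates = {"[MASK]": 0.80, "random token": 0.10, "unchanged": 0.10}

overall = {name: select_rate * rate for name, rate in branch_rates.items()}

seq_len = 512  # illustrative sequence length (an assumption for this example)
for name, frac in overall.items():
    print(f"{name:>12}: {frac:.1%} of all tokens, "
          f"about {frac * seq_len:.1f} per {seq_len}-token sequence")
```

So only about 12% of the input tokens are actually replaced with [MASK]; the remaining 3% are predicted too, but look like ordinary tokens in the input.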

**Why this weird split?**

In [None]:
# Masking strategy visualization
fig, ax = plt.subplots(figsize=(12, 8))
ax.set_xlim(0, 12)
ax.set_ylim(0, 8)
ax.axis('off')

ax.text(6, 7.5, 'BERT Masking Strategy', fontsize=14, ha='center', fontweight='bold')

# Original sentence
ax.text(1, 6.5, 'Original:', fontsize=11, fontweight='bold')
ax.text(3.5, 6.5, '"The cat sat on the fluffy mat"', fontsize=11)

# Selected for masking
ax.text(1, 5.5, 'Selected (15%):', fontsize=11, fontweight='bold')
ax.text(3.5, 5.5, '"cat" and "fluffy" chosen randomly', fontsize=11, color='#e74c3c')

# Three branches
branches = [
    (2, 4, '80% → [MASK]', '"The [MASK] sat on the [MASK] mat"', '#e74c3c'),
    (6, 4, '10% → Random', '"The dog sat on the purple mat"', '#f39c12'),
    (10, 4, '10% → Keep', '"The cat sat on the fluffy mat"', '#27ae60'),
]

for x, y, label, example, color in branches:
    box = FancyBboxPatch((x-1.5, y-0.8), 3, 1.2, boxstyle="round,pad=0.05",
                          facecolor=color, edgecolor='none', alpha=0.8)
    ax.add_patch(box)
    ax.text(x, y, label, fontsize=10, ha='center', va='center', 
            color='white', fontweight='bold')
    ax.text(x, y-1.5, example, fontsize=8, ha='center', color=color)

# Arrows
ax.annotate('', xy=(2, 4.7), xytext=(4, 5.3), arrowprops=dict(arrowstyle='->', color='#333'))
ax.annotate('', xy=(6, 4.7), xytext=(6, 5.3), arrowprops=dict(arrowstyle='->', color='#333'))
ax.annotate('', xy=(10, 4.7), xytext=(8, 5.3), arrowprops=dict(arrowstyle='->', color='#333'))

# Explanation
ax.text(6, 1.5, 'Why this split?', fontsize=11, ha='center', fontweight='bold')
explanations = [
    '• 80% [MASK]: Main signal - learn to predict from context',
    '• 10% Random: Teaches model that ANY position might need correction',
    '• 10% Keep: Prevents model from only learning when it sees [MASK]',
]
for i, exp in enumerate(explanations):
    ax.text(2, 1 - i*0.4, exp, fontsize=9, ha='left')

plt.tight_layout()
plt.show()

### Why Not Just Use [MASK]?

The problem: the [MASK] token only exists during pre-training. It never appears in fine-tuning data or at inference, so pre-training and downstream use are mismatched.

If we only trained with [MASK], the model might:
- Only pay attention when it sees [MASK]
- Not learn to produce good representations for normal tokens

The 10% random and 10% unchanged cases force the model to:
- Maintain good representations for ALL positions
- Not rely on the presence of [MASK] as a signal

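Because the model cannot tell which tokens it will be asked to predict, or which ones have been swapped, it has to keep a useful contextual representation at every position. This is a small but important detail that helps BERT work well after pre-training.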

In [None]:
# Implement the masking procedure
def create_mlm_examples(tokens, vocab_size=30000, mask_token_id=103, 
                         mask_prob=0.15, mask_token_prob=0.8, random_prob=0.1):
    """
    Create masked language modeling examples.
    
    Args:
        tokens: List of token IDs
        vocab_size: Size of vocabulary (for random replacement)
        mask_token_id: ID of [MASK] token
        mask_prob: Probability of selecting a token for masking
        mask_token_prob: Of selected, probability of replacing with [MASK]
        random_prob: Of selected, probability of replacing with random token
    
    Returns:
        masked_tokens: Tokens with masking applied
        labels: Original tokens for masked positions, -100 for others
        mask_positions: Which positions were selected
    """
    tokens = np.array(tokens)
    masked_tokens = tokens.copy()
    labels = np.full_like(tokens, -100)  # -100 = ignore in loss
    
    # Select each position independently with probability mask_prob (15% on average)
    mask_positions = np.random.random(len(tokens)) < mask_prob
    
    # Don't mask special tokens (positions 0 and last are usually [CLS] and [SEP])
    mask_positions[0] = False
    mask_positions[-1] = False
    
    for i in np.where(mask_positions)[0]:
        labels[i] = tokens[i]  # Save original for prediction target
        
        rand = np.random.random()
        if rand < mask_token_prob:  # 80% -> [MASK]
            masked_tokens[i] = mask_token_id
        elif rand < mask_token_prob + random_prob:  # 10% -> random
            masked_tokens[i] = np.random.randint(1000, vocab_size)  # Avoid special tokens
        # else: 10% -> keep original (do nothing)
    
    return masked_tokens, labels, mask_positions

# Demo
# Simulated token IDs for "[CLS] The cat sat on the mat [SEP]"
# (In reality, these would come from the tokenizer)
token_ids = [101, 1996, 4937, 2068, 2006, 1996, 13523, 102]  # Example IDs
token_names = ['[CLS]', 'The', 'cat', 'sat', 'on', 'the', 'mat', '[SEP]']

print("Original tokens:", token_names)
print("Token IDs:", token_ids)
print()

# Run masking multiple times to show variation
print("Masking examples (15% chance per token):")
print("-" * 50)
for i in range(5):
    masked, labels, positions = create_mlm_examples(token_ids)
    
    result = []
    for j, (orig, mask, label) in enumerate(zip(token_names, masked, labels)):
        if label != -100:  # This position was selected
            if mask == 103:
                result.append('[MASK]')
            elif mask != token_ids[j]:
                result.append(f'[RAND:{mask}]')
            else:
                result.append(f'{orig}*')  # Kept same
        else:
            result.append(orig)
    
    print(f"Example {i+1}: {' '.join(result)}")

### How MLM Prediction Works

For each selected position, BERT predicts the original token:

1. Get the final hidden state at that position (768 dimensions for BERT-Base)
2. Pass it through a prediction head (in its simplest form, a linear layer to vocabulary size + softmax)
3. Output: probability distribution over the vocabulary
4. Loss: cross-entropy between that distribution and the original token
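
The four steps above can be sketched in NumPy. This is a minimal illustration, not BERT's actual head: the hidden states are random stand-ins for the encoder output, `W` and `b` are hypothetical untrained weights, and the vocabulary is shrunk to 1,000 entries to keep it light. Real implementations typically add a small transform layer and tie the output weights to the input embeddings; both are skipped here.

```python
import numpy as np

rng = np.random.default_rng(0)

seq_len, hidden_size, vocab_size = 8, 768, 1000   # tiny vocab, for the demo only
hidden_states = rng.normal(size=(seq_len, hidden_size))     # stand-in for BERT's output
W = rng.normal(scale=0.02, size=(hidden_size, vocab_size))  # hypothetical head weights
b = np.zeros(vocab_size)

# -100 marks positions that were not selected, so they are skipped in the loss
labels = np.array([-100, -100, 417, -100, -100, 23, -100, -100])

def mlm_loss(hidden_states, labels, W, b):
    """Mean cross-entropy over the selected positions only."""
    selected = labels != -100
    logits = hidden_states[selected] @ W + b               # (n_selected, vocab_size)
    logits = logits - logits.max(axis=-1, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    target_log_probs = log_probs[np.arange(selected.sum()), labels[selected]]
    return -target_log_probs.mean()

loss = mlm_loss(hidden_states, labels, W, b)
print(f"MLM loss with untrained weights: {loss:.3f}")
print(f"Uniform-guess baseline log(vocab_size): {np.log(vocab_size):.3f}")
```

With untrained weights the loss sits near `log(vocab_size)`, which is what uniform guessing would give. Training pushes it down.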

In [None]:
# MLM prediction visualization
fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(0, 14)
ax.set_ylim(0, 8)
ax.axis('off')

ax.text(7, 7.5, 'MLM Prediction at Masked Position', fontsize=13, ha='center', fontweight='bold')

# Input with mask
tokens = ['[CLS]', 'The', '[MASK]', 'sat', 'on', 'mat', '[SEP]']
for i, token in enumerate(tokens):
    color = '#e74c3c' if token == '[MASK]' else '#3498db'
    ax.text(2 + i*1.5, 6, token, fontsize=10, ha='center',
            bbox=dict(boxstyle='round,pad=0.2', facecolor=color, edgecolor='none', alpha=0.7),
            color='white')

# BERT
bert_box = FancyBboxPatch((1, 4), 11, 1.2, boxstyle="round,pad=0.05",
                           facecolor='#2c3e50', edgecolor='none', alpha=0.9)
ax.add_patch(bert_box)
ax.text(6.5, 4.6, 'BERT (12 layers)', fontsize=11, ha='center', va='center',
        color='white', fontweight='bold')

# Arrows to BERT
for i in range(7):
    ax.annotate('', xy=(2 + i*1.5, 5.2), xytext=(2 + i*1.5, 5.6),
                arrowprops=dict(arrowstyle='->', color='#333', lw=1))

# Hidden state for masked position
ax.annotate('', xy=(5, 3.2), xytext=(5, 4),
            arrowprops=dict(arrowstyle='->', color='#e74c3c', lw=2))

hidden_box = FancyBboxPatch((3.5, 2.2), 3, 0.8, boxstyle="round,pad=0.05",
                             facecolor='#e74c3c', edgecolor='none', alpha=0.8)
ax.add_patch(hidden_box)
ax.text(5, 2.6, 'Hidden state (768 dims)', fontsize=9, ha='center', va='center',
        color='white', fontweight='bold')

# Prediction head
ax.annotate('', xy=(5, 1.4), xytext=(5, 2.2),
            arrowprops=dict(arrowstyle='->', color='#333', lw=1.5))

pred_box = FancyBboxPatch((3, 0.4), 4, 0.8, boxstyle="round,pad=0.05",
                           facecolor='#27ae60', edgecolor='none', alpha=0.8)
ax.add_patch(pred_box)
ax.text(5, 0.8, 'Linear + Softmax → P(cat) = 0.87', fontsize=9, ha='center', va='center',
        color='white', fontweight='bold')

# Other positions (no prediction needed)
ax.text(10, 2.6, 'Other positions:\nno prediction\n(labels = -100)', fontsize=9, 
        ha='center', color='#666')

plt.tight_layout()
plt.show()

---

## Task 2: Next Sentence Prediction (NSP)

MLM teaches word-level understanding. NSP is meant to add sentence-level understanding: how two pieces of text relate to each other.

### The Task

Given two sentences A and B, predict: does B actually follow A in the original text?

```
Positive example (IsNext):
  A: "The man went to the store."
  B: "He bought some milk."
  Label: IsNext

Negative example (NotNext):
  A: "The man went to the store."
  B: "Penguins live in Antarctica."
  Label: NotNext
```

Training data: 50% of the time B is the sentence that actually follows A, 50% of the time B is a random sentence from the corpus.
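
Building such pairs from raw text might look like the sketch below. The function name `make_nsp_pairs` and the toy documents are made up for illustration. Drawing the negative from a *different* document is an assumption of this sketch, so it needs at least two documents.

```python
import random

def make_nsp_pairs(documents, rng):
    """Build (sentence_a, sentence_b, label) triples.

    documents: list of documents, each an ordered list of sentences.
    Each sentence A gets its true successor about half the time (IsNext),
    otherwise a sentence drawn from another document (NotNext).
    """
    pairs = []
    for doc_idx, doc in enumerate(documents):
        for i in range(len(doc) - 1):
            sentence_a = doc[i]
            if rng.random() < 0.5:
                pairs.append((sentence_a, doc[i + 1], "IsNext"))
            else:
                # Sample from a different document so B cannot be the true next sentence
                other_idx = rng.choice([j for j in range(len(documents)) if j != doc_idx])
                pairs.append((sentence_a, rng.choice(documents[other_idx]), "NotNext"))
    return pairs

documents = [
    ["The man went to the store.", "He bought some milk.", "Then he walked home."],
    ["Penguins live in Antarctica.", "They cannot fly.", "They are strong swimmers."],
]

for a, b, label in make_nsp_pairs(documents, random.Random(0)):
    print(f"{label:>7} | A: {a} | B: {b}")
```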

In [None]:
# NSP visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Positive example
ax = axes[0]
ax.set_xlim(0, 10)
ax.set_ylim(0, 6)
ax.axis('off')
ax.set_title('Positive: IsNext (50%)', fontsize=12, fontweight='bold', color='#27ae60')

ax.text(5, 5, 'Sentence A:', fontsize=10, ha='center', fontweight='bold')
ax.text(5, 4.3, '"The man went to the store."', fontsize=10, ha='center',
        bbox=dict(boxstyle='round,pad=0.3', facecolor='#3498db', edgecolor='none', alpha=0.7),
        color='white')

ax.text(5, 3.2, 'Sentence B:', fontsize=10, ha='center', fontweight='bold')
ax.text(5, 2.5, '"He bought some milk."', fontsize=10, ha='center',
        bbox=dict(boxstyle='round,pad=0.3', facecolor='#27ae60', edgecolor='none', alpha=0.7),
        color='white')

ax.text(5, 1.2, 'B actually follows A in original text', fontsize=9, ha='center', style='italic')
ax.text(5, 0.6, 'Label: IsNext', fontsize=11, ha='center', fontweight='bold', color='#27ae60')

# Negative example
ax = axes[1]
ax.set_xlim(0, 10)
ax.set_ylim(0, 6)
ax.axis('off')
ax.set_title('Negative: NotNext (50%)', fontsize=12, fontweight='bold', color='#e74c3c')

ax.text(5, 5, 'Sentence A:', fontsize=10, ha='center', fontweight='bold')
ax.text(5, 4.3, '"The man went to the store."', fontsize=10, ha='center',
        bbox=dict(boxstyle='round,pad=0.3', facecolor='#3498db', edgecolor='none', alpha=0.7),
        color='white')

ax.text(5, 3.2, 'Sentence B:', fontsize=10, ha='center', fontweight='bold')
ax.text(5, 2.5, '"Penguins live in Antarctica."', fontsize=10, ha='center',
        bbox=dict(boxstyle='round,pad=0.3', facecolor='#e74c3c', edgecolor='none', alpha=0.7),
        color='white')

ax.text(5, 1.2, 'B is a random sentence (not related)', fontsize=9, ha='center', style='italic')
ax.text(5, 0.6, 'Label: NotNext', fontsize=11, ha='center', fontweight='bold', color='#e74c3c')

plt.tight_layout()
plt.show()

### How NSP Prediction Works

NSP uses the [CLS] token representation:

1. Take the hidden state of [CLS] (which has "seen" both sentences)
2. Pass through a linear layer → 2 classes
3. Softmax → probability of IsNext vs NotNext
4. Cross-entropy loss over the two classes
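
As a rough sketch of these steps: the [CLS] vector below is random, `W_nsp` and `b_nsp` are hypothetical untrained weights, and the label order is a convention picked for this example. Released implementations also pass [CLS] through a small pooler layer (dense + tanh) before the classifier, which is left out here.

```python
import numpy as np

rng = np.random.default_rng(0)

hidden_size = 768
cls_hidden = rng.normal(size=hidden_size)               # stand-in for the final [CLS] vector
W_nsp = rng.normal(scale=0.02, size=(hidden_size, 2))   # hypothetical Linear(768 -> 2) weights
b_nsp = np.zeros(2)

IS_NEXT, NOT_NEXT = 0, 1   # label order is an assumption of this sketch

def nsp_probs(cls_hidden, W, b):
    """Softmax over the two classes."""
    logits = cls_hidden @ W + b
    logits = logits - logits.max()   # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum()

probs = nsp_probs(cls_hidden, W_nsp, b_nsp)
loss = -np.log(probs[IS_NEXT])   # cross-entropy when the true label is IsNext
print(f"P(IsNext) = {probs[IS_NEXT]:.3f}, P(NotNext) = {probs[NOT_NEXT]:.3f}, loss = {loss:.3f}")
```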

In [None]:
# NSP prediction visualization
fig, ax = plt.subplots(figsize=(12, 7))
ax.set_xlim(0, 12)
ax.set_ylim(0, 7)
ax.axis('off')

ax.text(6, 6.5, 'NSP Uses [CLS] Token', fontsize=13, ha='center', fontweight='bold')

# Input
tokens = ['[CLS]', 'Sent', 'A', '...', '[SEP]', 'Sent', 'B', '...', '[SEP]']
colors = ['#9b59b6', '#3498db', '#3498db', '#3498db', '#9b59b6', 
          '#27ae60', '#27ae60', '#27ae60', '#9b59b6']
for i, (token, color) in enumerate(zip(tokens, colors)):
    ax.text(1.5 + i*1.1, 5.5, token, fontsize=9, ha='center',
            bbox=dict(boxstyle='round,pad=0.15', facecolor=color, edgecolor='none', alpha=0.7),
            color='white')

# BERT
bert_box = FancyBboxPatch((1, 3.5), 10, 1.2, boxstyle="round,pad=0.05",
                           facecolor='#2c3e50', edgecolor='none', alpha=0.9)
ax.add_patch(bert_box)
ax.text(6, 4.1, 'BERT', fontsize=11, ha='center', va='center', color='white', fontweight='bold')

# Arrows down
for i in range(9):
    ax.annotate('', xy=(1.5 + i*1.1, 4.7), xytext=(1.5 + i*1.1, 5.2),
                arrowprops=dict(arrowstyle='->', color='#333', lw=0.8))

# Only [CLS] output is used for NSP
ax.annotate('', xy=(1.5, 2.7), xytext=(1.5, 3.5),
            arrowprops=dict(arrowstyle='->', color='#9b59b6', lw=2))

cls_hidden = FancyBboxPatch((0.5, 1.8), 2, 0.7, boxstyle="round,pad=0.05",
                             facecolor='#9b59b6', edgecolor='none', alpha=0.8)
ax.add_patch(cls_hidden)
ax.text(1.5, 2.15, '[CLS] hidden', fontsize=9, ha='center', va='center', color='white', fontweight='bold')

# Classifier
ax.annotate('', xy=(3.5, 2.15), xytext=(2.5, 2.15),
            arrowprops=dict(arrowstyle='->', color='#333', lw=1.5))

clf_box = FancyBboxPatch((3.5, 1.8), 2.5, 0.7, boxstyle="round,pad=0.05",
                          facecolor='#f39c12', edgecolor='none', alpha=0.8)
ax.add_patch(clf_box)
ax.text(4.75, 2.15, 'Linear (768→2)', fontsize=9, ha='center', va='center', color='white', fontweight='bold')

# Output
ax.annotate('', xy=(7, 2.15), xytext=(6, 2.15),
            arrowprops=dict(arrowstyle='->', color='#333', lw=1.5))

ax.text(8.5, 2.5, 'IsNext: 0.92', fontsize=10, ha='center', color='#27ae60', fontweight='bold')
ax.text(8.5, 1.8, 'NotNext: 0.08', fontsize=10, ha='center', color='#e74c3c')

# Note
ax.text(6, 0.7, '[CLS] aggregates information from entire sequence via self-attention', 
        fontsize=9, ha='center', style='italic', color='#666')

plt.tight_layout()
plt.show()

### Is NSP Actually Useful?

This is interesting: later work (RoBERTa, 2019) found that NSP doesn't help much. In its experiments, removing NSP matched or slightly improved downstream results.

Why did BERT include it?
- The authors believed sentence relationships mattered
- Tasks like QA and NLI involve sentence pairs

Why might it not help?
- The task is too easy: a random B usually comes from a different topic, so topic alone often gives the answer away
- MLM already captures enough

Modern BERT variants often drop NSP. But for understanding the original paper, it's important to know it was there.

---

## Training Data

BERT was pre-trained on:

| Dataset | Size |
|---------|------|
| BooksCorpus | 800M words |
| English Wikipedia | 2,500M words |
| **Total** | **3,300M words** |

### Why These Datasets?

**BooksCorpus:** Long, coherent text. Good for learning sentence relationships.

**Wikipedia:** Factual, diverse topics. Good for learning world knowledge.

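Both are relatively clean text, so little preprocessing is needed. For Wikipedia, the paper keeps only the text passages and drops lists, tables, and headers. It also stresses using document-level corpora rather than shuffled sentences, so that long contiguous sequences can be extracted.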

In [None]:
# Training data visualization
fig, ax = plt.subplots(figsize=(10, 5))

datasets = ['BooksCorpus', 'Wikipedia']
sizes = [800, 2500]
colors = ['#3498db', '#27ae60']

bars = ax.barh(datasets, sizes, color=colors, edgecolor='white', linewidth=2)
ax.set_xlabel('Words (millions)', fontsize=11)
ax.set_title('BERT Pre-training Data', fontsize=12, fontweight='bold')
ax.set_xlim(0, 3500)

for bar, size in zip(bars, sizes):
    ax.text(size + 50, bar.get_y() + bar.get_height()/2, 
            f'{size}M', va='center', fontsize=11, fontweight='bold')

ax.axvline(x=3300, color='#e74c3c', linestyle='--', lw=2)
ax.text(3350, 0.5, 'Total: 3.3B words', fontsize=10, color='#e74c3c', fontweight='bold')

ax.grid(True, alpha=0.3, axis='x')
plt.tight_layout()
plt.show()

---

## Training Details

### Compute Requirements

| Model | Hardware | Training Time |
|-------|----------|---------------|
| BERT-Base | 16 TPU chips | 4 days |
| BERT-Large | 64 TPU chips | 4 days |

This was expensive in 2018. Today you'd use pre-trained weights.

### Hyperparameters

```
Batch size: 256 sequences
Max sequence length: 512 tokens
Learning rate: 1e-4 (warmup, then linear decay)
Optimizer: Adam (β1=0.9, β2=0.999, L2 weight decay 0.01)
Training steps: 1,000,000
Warmup steps: 10,000
```
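
The paper describes learning-rate warmup over the first 10,000 steps followed by linear decay. One way to write such a schedule is sketched below. The linear shape of the warmup, the decay reaching exactly zero at the last step, and the function name `learning_rate` are assumptions of this sketch rather than details confirmed by the paper.

```python
def learning_rate(step, peak_lr=1e-4, warmup_steps=10_000, total_steps=1_000_000):
    """Linear warmup to peak_lr, then linear decay to zero at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    remaining = max(total_steps - step, 0)
    return peak_lr * remaining / (total_steps - warmup_steps)

for step in (0, 5_000, 10_000, 505_000, 1_000_000):
    print(f"step {step:>9,}: lr = {learning_rate(step):.2e}")
```

The rate climbs to its peak at step 10,000, is back to half the peak by step 505,000, and reaches zero at the final step.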

### Sequence Length Trick

An interesting detail: BERT trains with shorter sequences first.

- First 90% of training steps: max_length = 128
- Last 10% of training steps: max_length = 512

Why? Self-attention cost grows quadratically with sequence length, so long sequences are disproportionately expensive. Training mostly at length 128 cuts the cost considerably, and the final 10% at length 512 is there to learn the positional embeddings for longer inputs.
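
To get a feel for the saving, here is the arithmetic for the attention scores alone. Self-attention compares every token with every other token, so the number of scores grows with the square of the length. This sketch ignores the rest of the network (which scales roughly linearly), so treat it as an upper-end illustration, not a measured speedup.

```python
short_len, long_len = 128, 512

# n * n attention scores per head, per sequence
scores_short = short_len ** 2
scores_long = long_len ** 2

per_sequence_ratio = scores_long / scores_short                  # cost per sequence
per_token_ratio = per_sequence_ratio / (long_len / short_len)    # cost per token processed

print(f"Attention scores per sequence: {scores_short:,} vs {scores_long:,}")
print(f"Length 512 needs {per_sequence_ratio:.0f}x the attention work per sequence,")
print(f"or {per_token_ratio:.0f}x per token processed")
```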

---

## The Complete Pre-training Loss

BERT's total loss is the sum of both tasks:

$$\mathcal{L}_{total} = \mathcal{L}_{MLM} + \mathcal{L}_{NSP}$$

Both are cross-entropy losses:
- MLM: averaged over the positions selected for masking
- NSP: one two-way classification per sentence pair, averaged over the batch
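
Put together for a batch, the sum could be computed roughly as follows. The logits here are random placeholders for the two heads' outputs, the sizes are toy values, and the helper names `cross_entropy` and `pretraining_loss` are made up for this sketch.

```python
import numpy as np

def cross_entropy(logits, targets):
    """Per-row cross-entropy from raw logits (one row per prediction)."""
    logits = logits - logits.max(axis=-1, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets]

def pretraining_loss(mlm_logits, mlm_labels, nsp_logits, nsp_labels):
    """L_total = L_MLM + L_NSP for one batch.

    mlm_logits: (batch, seq_len, vocab)   mlm_labels: (batch, seq_len), -100 = ignore
    nsp_logits: (batch, 2)                nsp_labels: (batch,)
    """
    selected = mlm_labels != -100
    mlm = cross_entropy(mlm_logits[selected], mlm_labels[selected]).mean()
    nsp = cross_entropy(nsp_logits, nsp_labels).mean()
    return mlm + nsp, mlm, nsp

rng = np.random.default_rng(0)
batch, seq_len, vocab = 2, 6, 50   # toy sizes
mlm_logits = rng.normal(size=(batch, seq_len, vocab))
mlm_labels = np.full((batch, seq_len), -100)
mlm_labels[0, 2] = 7     # one selected position in each example
mlm_labels[1, 4] = 31
nsp_logits = rng.normal(size=(batch, 2))
nsp_labels = np.array([0, 1])

total, mlm, nsp = pretraining_loss(mlm_logits, mlm_labels, nsp_logits, nsp_labels)
print(f"L_MLM = {mlm:.3f}, L_NSP = {nsp:.3f}, L_total = {total:.3f}")
```

Note that the two terms are added with equal weight; there is no balancing coefficient in the formula.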

In [None]:
# Combined loss visualization
fig, ax = plt.subplots(figsize=(12, 8))
ax.set_xlim(0, 12)
ax.set_ylim(0, 8)
ax.axis('off')

ax.text(6, 7.5, 'BERT Pre-training: Combined Loss', fontsize=14, ha='center', fontweight='bold')

# Input
input_box = FancyBboxPatch((2, 5.5), 8, 1, boxstyle="round,pad=0.05",
                            facecolor='#ecf0f1', edgecolor='#bdc3c7', linewidth=2)
ax.add_patch(input_box)
ax.text(6, 6, '[CLS] The [MASK] sat ... [SEP] He was ... [SEP]', fontsize=10, ha='center')

# BERT
bert_box = FancyBboxPatch((2, 3.5), 8, 1.2, boxstyle="round,pad=0.05",
                           facecolor='#2c3e50', edgecolor='none', alpha=0.9)
ax.add_patch(bert_box)
ax.text(6, 4.1, 'BERT', fontsize=12, ha='center', va='center', color='white', fontweight='bold')

ax.annotate('', xy=(6, 4.7), xytext=(6, 5.5),
            arrowprops=dict(arrowstyle='->', color='#333', lw=2))

# Two loss branches
# MLM
mlm_box = FancyBboxPatch((0.5, 1.5), 4.5, 1.2, boxstyle="round,pad=0.05",
                          facecolor='#e74c3c', edgecolor='none', alpha=0.8)
ax.add_patch(mlm_box)
ax.text(2.75, 2.1, 'MLM Loss\nPredict masked tokens', fontsize=9, ha='center', va='center',
        color='white', fontweight='bold')

ax.annotate('', xy=(2.75, 2.7), xytext=(4.5, 3.5),
            arrowprops=dict(arrowstyle='->', color='#333', lw=1.5))

# NSP
nsp_box = FancyBboxPatch((7, 1.5), 4.5, 1.2, boxstyle="round,pad=0.05",
                          facecolor='#3498db', edgecolor='none', alpha=0.8)
ax.add_patch(nsp_box)
ax.text(9.25, 2.1, 'NSP Loss\nPredict if B follows A', fontsize=9, ha='center', va='center',
        color='white', fontweight='bold')

ax.annotate('', xy=(9.25, 2.7), xytext=(7.5, 3.5),
            arrowprops=dict(arrowstyle='->', color='#333', lw=1.5))

# Total loss
ax.text(6, 0.5, 'Total Loss = L_MLM + L_NSP', fontsize=12, ha='center', fontweight='bold',
        bbox=dict(boxstyle='round', facecolor='#f39c12', edgecolor='none', alpha=0.8),
        color='white')

ax.annotate('', xy=(4, 0.7), xytext=(2.75, 1.5),
            arrowprops=dict(arrowstyle='->', color='#333', lw=1))
ax.annotate('', xy=(8, 0.7), xytext=(9.25, 1.5),
            arrowprops=dict(arrowstyle='->', color='#333', lw=1))

plt.tight_layout()
plt.show()

---

## Summary: Pre-training

| Aspect | Details |
|--------|--------|
| **Task 1: MLM** | Predict the 15% of tokens selected for masking |
| Masking strategy | 80% [MASK], 10% random, 10% unchanged |
| **Task 2: NSP** | Binary: does B follow A? |
| NSP pairs | 50% IsNext, 50% NotNext |
| **Data** | Wikipedia + BooksCorpus (3.3B words) |
| **Compute** | 4 days on 16 (Base) or 64 (Large) TPU chips |

The key insight: these tasks are **self-supervised**. No human labels needed - the supervision comes from the text itself.

---

## What's Next: Part 4

Pre-training gives us a general language understanding model. In Part 4, we'll see how to use it:

- **Fine-tuning** on specific tasks
- Different task types (classification, NER, QA)
- How to add task-specific heads
