# BERT: Part 4

## Fine-tuning - Adapting BERT to Your Task

---

**Paper:** [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805)

---

Pre-training gives us a model that "understands" language. Fine-tuning teaches it to do something useful with that understanding.

The beauty of BERT: fine-tuning is simple. Same architecture, just add a small output layer.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch
np.random.seed(42)

---

## The Fine-tuning Recipe

1. **Start** with pre-trained BERT weights
2. **Add** a task-specific output layer (usually just one linear layer)
3. **Train** on your labeled data for a few epochs
4. **Done**

That's it. The whole BERT model gets updated during fine-tuning, but because it starts from good weights, you don't need much data or time.
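A back-of-the-envelope parameter count shows how small the "small output layer" really is. This sketch assumes the configuration of the released `bert-base-uncased` checkpoint (30,522-token vocabulary, 512 positions, 2 segment types, 12 layers, hidden size 768, feed-forward size 3072). It counts the pooler but not the MLM/NSP pre-training heads, which are dropped before fine-tuning.

```python
# Rough parameter count: BERT-Base (uncased) vs. a new 2-class output layer.
# Config values assumed from the released bert-base-uncased checkpoint.
VOCAB, HIDDEN, LAYERS, FFN, MAX_POS, TYPES = 30522, 768, 12, 3072, 512, 2


def linear(n_in, n_out):
    """Weights plus bias of a dense layer."""
    return n_in * n_out + n_out


layer_norm = 2 * HIDDEN  # gain + bias

# Token + position + segment embeddings, followed by a LayerNorm
embeddings = (VOCAB + MAX_POS + TYPES) * HIDDEN + layer_norm

per_layer = (
    4 * linear(HIDDEN, HIDDEN)                   # Q, K, V and attention output projections
    + layer_norm
    + linear(HIDDEN, FFN) + linear(FFN, HIDDEN)  # feed-forward block
    + layer_norm
)
pooler = linear(HIDDEN, HIDDEN)                  # dense + tanh applied to [CLS]
bert_base = embeddings + LAYERS * per_layer + pooler

head = linear(HIDDEN, 2)                         # the new task-specific layer

print(f"BERT-Base parameters: {bert_base:,}")
print(f"New head parameters:  {head:,}")
print(f"Head share of total:  {head / (bert_base + head):.6%}")
```

This comes out at about 109.5M for BERT-Base (the "110M" in the figure is this number rounded) against roughly 1.5K parameters for a two-class head.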

In [None]:
# Fine-tuning overview
fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(0, 14)
ax.set_ylim(0, 8)
ax.axis('off')

ax.text(7, 7.5, 'Fine-tuning: Add Output Layer + Train', fontsize=14, ha='center', fontweight='bold')

# Pre-trained BERT
bert_box = FancyBboxPatch((3, 2.5), 8, 3.5, boxstyle="round,pad=0.1",
                           facecolor='#3498db', edgecolor='#2980b9', linewidth=3, alpha=0.8)
ax.add_patch(bert_box)
ax.text(7, 4.25, 'Pre-trained BERT\n(110M or 340M parameters)\n\nAlready knows language!', 
        fontsize=11, ha='center', va='center', color='white', fontweight='bold')

# Task-specific head
head_box = FancyBboxPatch((5, 6.2), 4, 0.8, boxstyle="round,pad=0.05",
                           facecolor='#e74c3c', edgecolor='none', alpha=0.9)
ax.add_patch(head_box)
ax.text(7, 6.6, 'Task Head (new)', fontsize=10, ha='center', va='center', 
        color='white', fontweight='bold')

# Arrow
ax.annotate('', xy=(7, 6.2), xytext=(7, 6),
            arrowprops=dict(arrowstyle='->', color='#333', lw=2))

# Input
ax.text(7, 1.8, 'Your task data', fontsize=10, ha='center')
ax.annotate('', xy=(7, 2.5), xytext=(7, 2.1),
            arrowprops=dict(arrowstyle='->', color='#333', lw=2))

# Details on the side
details = [
    'Fine-tuning:',
    '• Learning rate: 2e-5 to 5e-5',
    '• Epochs: 2-4',
    '• Batch size: 16-32',
    '• Time: minutes to hours',
]
for i, line in enumerate(details):
    weight = 'bold' if i == 0 else 'normal'
    ax.text(0.5, 5 - i*0.5, line, fontsize=9, fontweight=weight)

ax.text(7, 0.8, 'Everything gets updated - both BERT and the new head', 
        fontsize=10, ha='center', style='italic', color='#666')

plt.tight_layout()
plt.show()

---

## Task Type 1: Sentence Classification

**Examples:** Sentiment analysis, spam detection, topic classification

**How it works:**
1. Feed sentence through BERT
2. Take the [CLS] hidden state
3. Pass through a linear layer → class probabilities
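A minimal NumPy sketch of these three steps, under stated assumptions: the hidden states are random stand-ins rather than real BERT outputs, and `W`, `b` are a freshly initialized head. It follows the paper's description (final hidden vector of [CLS], then a linear layer and softmax). Note that Hugging Face's `BertForSequenceClassification` additionally passes [CLS] through the pre-trained pooler (dense layer + tanh) before its classifier.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, seq_len, hidden_size, num_classes = 4, 7, 768, 2

# Stand-in for BERT's final-layer hidden states (random, not a real model)
hidden_states = rng.normal(size=(batch_size, seq_len, hidden_size))

# The new task head: one linear layer with small random init
# (std 0.02, matching BERT's initializer_range)
W = rng.normal(scale=0.02, size=(hidden_size, num_classes))
b = np.zeros(num_classes)


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


cls_vectors = hidden_states[:, 0, :]  # [CLS] is always the first position
logits = cls_vectors @ W + b          # shape: (batch_size, num_classes)
probs = softmax(logits)
print(probs.shape)  # (4, 2)
```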

In [None]:
# Sentence classification
fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(0, 14)
ax.set_ylim(0, 8)
ax.axis('off')

ax.text(7, 7.5, 'Sentence Classification (e.g., Sentiment)', fontsize=13, ha='center', fontweight='bold')

# Input
tokens = ['[CLS]', 'This', 'movie', 'was', 'great', '!', '[SEP]']
for i, token in enumerate(tokens):
    color = '#9b59b6' if token in ['[CLS]', '[SEP]'] else '#3498db'
    ax.text(2 + i*1.5, 6, token, fontsize=10, ha='center',
            bbox=dict(boxstyle='round,pad=0.2', facecolor=color, edgecolor='none', alpha=0.7),
            color='white')

# BERT
bert_box = FancyBboxPatch((1, 3.5), 11, 1.5, boxstyle="round,pad=0.05",
                           facecolor='#2c3e50', edgecolor='none', alpha=0.9)
ax.add_patch(bert_box)
ax.text(6.5, 4.25, 'BERT', fontsize=12, ha='center', va='center', color='white', fontweight='bold')

# Arrows down
for i in range(7):
    ax.annotate('', xy=(2 + i*1.5, 5), xytext=(2 + i*1.5, 5.6),
                arrowprops=dict(arrowstyle='->', color='#333', lw=1))

# Only [CLS] output used
ax.annotate('', xy=(2, 2.8), xytext=(2, 3.5),
            arrowprops=dict(arrowstyle='->', color='#9b59b6', lw=3))

# Classifier
clf_box = FancyBboxPatch((0.5, 1.8), 3, 0.8, boxstyle="round,pad=0.05",
                          facecolor='#e74c3c', edgecolor='none', alpha=0.9)
ax.add_patch(clf_box)
ax.text(2, 2.2, 'Linear (768→2)', fontsize=10, ha='center', va='center',
        color='white', fontweight='bold')

# Output
ax.annotate('', xy=(4.5, 2.2), xytext=(3.5, 2.2),
            arrowprops=dict(arrowstyle='->', color='#333', lw=2))

ax.text(6, 2.5, 'Positive: 0.94', fontsize=11, ha='center', color='#27ae60', fontweight='bold')
ax.text(6, 1.8, 'Negative: 0.06', fontsize=11, ha='center', color='#e74c3c')

# Note
ax.text(9, 3, 'Other token outputs\nare ignored', fontsize=9, ha='center', color='#999')

ax.text(7, 0.8, '[CLS] token aggregates sentence meaning through self-attention', 
        fontsize=10, ha='center', style='italic', color='#666')

plt.tight_layout()
plt.show()

---

## Task Type 2: Sentence Pair Classification

**Examples:** Natural language inference (NLI), paraphrase detection, semantic textual similarity (for STS-B the head outputs a single regression score instead of class probabilities)

**How it works:**
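1. Concatenate both sentences with [SEP] between them; learned segment embeddings (A for the first sentence, B for the second) tell BERT which sentence each token belongs to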
2. Feed through BERT
3. Take [CLS] → linear layer → class
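The input packing can be sketched in plain Python. This assumes the sentences are already tokenized (plain whitespace words here, for readability, instead of WordPiece). `build_pair_input` is a hypothetical helper, but its 0/1 segment IDs follow the same convention as the `token_type_ids` that BERT tokenizers produce.

```python
def build_pair_input(tokens_a, tokens_b):
    """Pack two token lists as [CLS] A [SEP] B [SEP] with segment IDs.

    Segment 0 covers [CLS], sentence A and the first [SEP];
    segment 1 covers sentence B and the final [SEP].
    """
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, segment_ids


premise = "A man is playing guitar".split()
hypothesis = "Music is being played".split()
tokens, segment_ids = build_pair_input(premise, hypothesis)
for tok, seg in zip(tokens, segment_ids):
    print(f"{tok:>8}  segment {seg}")
```

Both sentences go through BERT together, so self-attention can compare them token by token before [CLS] summarizes the pair.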

In [None]:
# Sentence pair classification (NLI example)
fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(0, 14)
ax.set_ylim(0, 8)
ax.axis('off')

ax.text(7, 7.5, 'Sentence Pair Classification (NLI)', fontsize=13, ha='center', fontweight='bold')

# Premise and hypothesis
ax.text(3.5, 6.5, 'Premise: "A man is playing guitar"', fontsize=10, ha='center',
        bbox=dict(boxstyle='round,pad=0.3', facecolor='#3498db', edgecolor='none', alpha=0.6))
ax.text(10.5, 6.5, 'Hypothesis: "Music is being played"', fontsize=10, ha='center',
        bbox=dict(boxstyle='round,pad=0.3', facecolor='#27ae60', edgecolor='none', alpha=0.6))

# Combined input
ax.text(7, 5.5, '[CLS] A man is playing guitar [SEP] Music is being played [SEP]', 
        fontsize=9, ha='center',
        bbox=dict(boxstyle='round,pad=0.3', facecolor='#ecf0f1', edgecolor='#bdc3c7'))

# BERT
bert_box = FancyBboxPatch((2, 3), 10, 1.5, boxstyle="round,pad=0.05",
                           facecolor='#2c3e50', edgecolor='none', alpha=0.9)
ax.add_patch(bert_box)
ax.text(7, 3.75, 'BERT', fontsize=12, ha='center', va='center', color='white', fontweight='bold')

ax.annotate('', xy=(7, 4.5), xytext=(7, 5.1),
            arrowprops=dict(arrowstyle='->', color='#333', lw=2))

# [CLS] output
ax.annotate('', xy=(3, 2.2), xytext=(3, 3),
            arrowprops=dict(arrowstyle='->', color='#9b59b6', lw=2))

# Classifier
clf_box = FancyBboxPatch((1.5, 1.3), 3, 0.8, boxstyle="round,pad=0.05",
                          facecolor='#e74c3c', edgecolor='none', alpha=0.9)
ax.add_patch(clf_box)
ax.text(3, 1.7, 'Linear (768→3)', fontsize=10, ha='center', va='center',
        color='white', fontweight='bold')

# Output classes
ax.annotate('', xy=(5.5, 1.7), xytext=(4.5, 1.7),
            arrowprops=dict(arrowstyle='->', color='#333', lw=2))

ax.text(8, 2.1, 'Entailment: 0.89', fontsize=10, ha='left', color='#27ae60', fontweight='bold')
ax.text(8, 1.6, 'Neutral: 0.08', fontsize=10, ha='left', color='#f39c12')
ax.text(8, 1.1, 'Contradiction: 0.03', fontsize=10, ha='left', color='#e74c3c')

ax.text(7, 0.4, 'NLI: Does hypothesis follow from premise?', fontsize=10, ha='center', style='italic', color='#666')

plt.tight_layout()
plt.show()

---

## Task Type 3: Token Classification (NER, POS Tagging)

**Examples:** Named Entity Recognition, Part-of-Speech tagging

**How it works:**
1. Feed sentence through BERT
2. Take hidden state for **each** token
3. Each token → linear layer → tag

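Unlike sentence classification, we use the output for every token, not just [CLS]. When WordPiece splits a word into several sub-tokens, the paper feeds only the first sub-token's representation to the tagger.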
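A NumPy sketch of the token-level head, assuming random stand-in hidden states and an untrained head (so the predicted tags are arbitrary). The split of "Google" into `Goo` / `##gle` and the 7-tag BIO set are hypothetical, chosen only to show how one shared linear layer scores every position and how the first sub-token of each word is selected.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical BIO tag set and WordPiece split, for illustration only
TAGS = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
words = ["John", "works", "at", "Google", "in", "NYC"]
subtokens = ["[CLS]", "John", "works", "at", "Goo", "##gle", "in", "NYC", "[SEP]"]
word_ids = [None, 0, 1, 2, 3, 3, 4, 5, None]  # which word each sub-token came from

# Stand-in for BERT's final hidden states (random, not a real model)
hidden_states = rng.normal(size=(len(subtokens), 768))

# One linear layer, shared by every position
W = rng.normal(scale=0.02, size=(768, len(TAGS)))
b = np.zeros(len(TAGS))
logits = hidden_states @ W + b  # shape: (num_subtokens, num_tags)

# Keep only the first sub-token of each word
first_subtoken = [
    i for i, w in enumerate(word_ids)
    if w is not None and (i == 0 or word_ids[i - 1] != w)
]
word_tags = [TAGS[k] for k in logits[first_subtoken].argmax(axis=-1)]

print(list(zip(words, word_tags)))  # tags are arbitrary: the head is untrained
```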

In [None]:
# Token classification (NER)
fig, ax = plt.subplots(figsize=(14, 9))
ax.set_xlim(0, 14)
ax.set_ylim(0, 9)
ax.axis('off')

ax.text(7, 8.5, 'Token Classification (Named Entity Recognition)', fontsize=13, ha='center', fontweight='bold')

# Input tokens
tokens = ['[CLS]', 'John', 'works', 'at', 'Google', 'in', 'NYC', '[SEP]']
labels = ['', 'B-PER', 'O', 'O', 'B-ORG', 'O', 'B-LOC', '']
x_positions = [1, 2.5, 4, 5.5, 7, 8.5, 10, 11.5]

for x, token in zip(x_positions, tokens):
    color = '#9b59b6' if token in ['[CLS]', '[SEP]'] else '#3498db'
    ax.text(x, 7, token, fontsize=10, ha='center',
            bbox=dict(boxstyle='round,pad=0.2', facecolor=color, edgecolor='none', alpha=0.7),
            color='white')

# BERT
bert_box = FancyBboxPatch((0.5, 4.5), 12, 1.5, boxstyle="round,pad=0.05",
                           facecolor='#2c3e50', edgecolor='none', alpha=0.9)
ax.add_patch(bert_box)
ax.text(6.5, 5.25, 'BERT', fontsize=12, ha='center', va='center', color='white', fontweight='bold')

# Arrows down to BERT
for x in x_positions:
    ax.annotate('', xy=(x, 6), xytext=(x, 6.5),
                arrowprops=dict(arrowstyle='->', color='#333', lw=1))

# Hidden states
for x in x_positions:
    ax.annotate('', xy=(x, 3.7), xytext=(x, 4.5),
                arrowprops=dict(arrowstyle='->', color='#333', lw=1))

# Linear layers (one per token, but same weights)
for x in x_positions[1:-1]:  # Skip [CLS] and [SEP]
    clf_box = FancyBboxPatch((x-0.5, 3), 1, 0.6, boxstyle="round,pad=0.02",
                              facecolor='#e74c3c', edgecolor='none', alpha=0.8)
    ax.add_patch(clf_box)

# Label for the shared linear layer, placed to the right of the boxes
ax.text(11, 3.3, 'Linear (768 → num_tags)', fontsize=9, ha='left', va='center',
        color='#e74c3c', fontweight='bold')

# Output labels
colors_map = {'B-PER': '#e74c3c', 'B-ORG': '#3498db', 'B-LOC': '#27ae60', 'O': '#95a5a6'}
for x, label in zip(x_positions[1:-1], labels[1:-1]):
    color = colors_map.get(label, '#95a5a6')
    ax.text(x, 2, label, fontsize=10, ha='center', fontweight='bold',
            bbox=dict(boxstyle='round,pad=0.2', facecolor=color, edgecolor='none', alpha=0.8),
            color='white')
    ax.annotate('', xy=(x, 2.4), xytext=(x, 3),
                arrowprops=dict(arrowstyle='->', color='#333', lw=1))

# Legend
ax.text(7, 1, 'B- = first token of an entity (PER = person, ORG = organization, LOC = location), O = outside any entity', 
        fontsize=9, ha='center', color='#666')

ax.text(7, 0.4, 'Same linear layer applied to each token (shared weights)', 
        fontsize=9, ha='center', style='italic', color='#666')

plt.tight_layout()
plt.show()

---

## Task Type 4: Question Answering (Extractive)

**Examples:** SQuAD, reading comprehension

**How it works:**
1. Input: [CLS] question [SEP] context [SEP]
2. Predict **start** and **end** positions of answer in context
3. Two linear layers: one for start, one for end

This is extractive QA - the answer is a span from the context.
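The paper scores a candidate span from position i to j as S·T_i + E·T_j, where S and E are the learned start and end vectors and T_i is the final hidden state of token i, and predicts the highest-scoring span with j ≥ i. Below is a minimal NumPy sketch of that search over context tokens. The logits are made up for illustration, and the `max_answer_len` cap is an assumption borrowed from common practice, not something the paper specifies.

```python
import numpy as np


def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (i, j) maximizing start_logits[i] + end_logits[j],
    subject to i <= j < i + max_answer_len."""
    n = len(start_logits)
    scores = start_logits[:, None] + end_logits[None, :]  # scores[i, j]
    i_idx, j_idx = np.indices((n, n))
    valid = (j_idx >= i_idx) & (j_idx < i_idx + max_answer_len)
    scores = np.where(valid, scores, -np.inf)             # rule out invalid spans
    i, j = np.unravel_index(np.argmax(scores), scores.shape)
    return int(i), int(j)


context = ["The", "Eiffel", "Tower", "is", "located", "in", "Paris", ",", "France", "."]
# Made-up logits for illustration; a fine-tuned model would produce these
start_logits = np.array([0.1, 0.3, 0.2, 0.0, 0.1, 0.4, 3.0, 0.1, 0.5, 0.0])
end_logits = np.array([0.0, 0.1, 0.6, 0.0, 0.1, 0.0, 1.2, 0.3, 2.5, 0.2])

i, j = best_span(start_logits, end_logits)
print(i, j, " ".join(context[i:j + 1]))  # 6 8 Paris , France
```

Masking with `-inf` before the argmax is what enforces j ≥ i: without it, a high end score earlier in the passage could pair with a later start.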

In [None]:
# Question answering
fig, ax = plt.subplots(figsize=(14, 10))
ax.set_xlim(0, 14)
ax.set_ylim(0, 10)
ax.axis('off')

ax.text(7, 9.5, 'Extractive Question Answering (SQuAD)', fontsize=13, ha='center', fontweight='bold')

# Question and context
ax.text(7, 8.5, 'Question: "Where is the Eiffel Tower?"', fontsize=10, ha='center',
        bbox=dict(boxstyle='round,pad=0.3', facecolor='#3498db', edgecolor='none', alpha=0.6),
        color='white')
ax.text(7, 7.7, 'Context: "The Eiffel Tower is located in Paris, France."', fontsize=10, ha='center',
        bbox=dict(boxstyle='round,pad=0.3', facecolor='#27ae60', edgecolor='none', alpha=0.6),
        color='white')

# Input format
ax.text(7, 6.7, '[CLS] Where is the Eiffel Tower ? [SEP] The Eiffel Tower is located in Paris , France . [SEP]', 
        fontsize=8, ha='center',
        bbox=dict(boxstyle='round,pad=0.2', facecolor='#ecf0f1', edgecolor='#bdc3c7'))

# BERT
bert_box = FancyBboxPatch((1, 4), 12, 1.8, boxstyle="round,pad=0.05",
                           facecolor='#2c3e50', edgecolor='none', alpha=0.9)
ax.add_patch(bert_box)
ax.text(7, 4.9, 'BERT', fontsize=12, ha='center', va='center', color='white', fontweight='bold')

ax.annotate('', xy=(7, 5.8), xytext=(7, 6.3),
            arrowprops=dict(arrowstyle='->', color='#333', lw=2))

# Token outputs
context_tokens = ['The', 'Eiffel', 'Tower', 'is', 'located', 'in', 'Paris', ',', 'France', '.']
x_positions = np.linspace(2, 12, len(context_tokens))

for i, (x, token) in enumerate(zip(x_positions, context_tokens)):
    # Highlight answer span
    if token in ['Paris', ',', 'France']:
        color = '#f39c12'
    else:
        color = '#95a5a6'
    ax.text(x, 3.3, token, fontsize=9, ha='center',
            bbox=dict(boxstyle='round,pad=0.15', facecolor=color, edgecolor='none', alpha=0.7))

# Start and end predictions
start_box = FancyBboxPatch((1.5, 1.5), 4, 1, boxstyle="round,pad=0.05",
                            facecolor='#e74c3c', edgecolor='none', alpha=0.8)
ax.add_patch(start_box)
ax.text(3.5, 2, 'Start prediction\nPosition 6 ("Paris")', fontsize=9, ha='center', va='center',
        color='white', fontweight='bold')

end_box = FancyBboxPatch((8.5, 1.5), 4, 1, boxstyle="round,pad=0.05",
                          facecolor='#9b59b6', edgecolor='none', alpha=0.8)
ax.add_patch(end_box)
ax.text(10.5, 2, 'End prediction\nPosition 8 ("France")', fontsize=9, ha='center', va='center',
        color='white', fontweight='bold')

# Arrows
ax.annotate('', xy=(3.5, 2.5), xytext=(x_positions[6], 3),
            arrowprops=dict(arrowstyle='->', color='#e74c3c', lw=2))
ax.annotate('', xy=(10.5, 2.5), xytext=(x_positions[8], 3),
            arrowprops=dict(arrowstyle='->', color='#9b59b6', lw=2))

# Answer
ax.text(7, 0.7, 'Answer: "Paris, France" (span from start to end)', fontsize=11, ha='center',
        fontweight='bold', bbox=dict(boxstyle='round', facecolor='#f39c12', edgecolor='none'),
        color='white')

plt.tight_layout()
plt.show()

---

## Fine-tuning Hyperparameters

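The paper's appendix reports that the optimal values are task-specific, but that the following ranges work well across all tasks (the best combination is picked on the dev set, and dropout stays at 0.1):

| Hyperparameter | Recommended Values |
|----------------|--------------------|
| Batch size | 16, 32 |
| Learning rate (Adam) | 2e-5, 3e-5, 5e-5 |
| Epochs | 2, 3, 4 |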
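The per-task search over this table can be sketched as an exhaustive loop over its 2 × 3 × 3 = 18 combinations. `dev_score` is a hypothetical callback standing in for "fine-tune from the pre-trained checkpoint with these settings and return the dev-set metric"; it is not part of any library.

```python
from itertools import product

# Search space from the table above
BATCH_SIZES = [16, 32]
LEARNING_RATES = [2e-5, 3e-5, 5e-5]
EPOCHS = [2, 3, 4]


def grid_search(dev_score):
    """Try every combination and keep the one with the best dev-set score.

    dev_score(batch_size, lr, epochs) is a hypothetical callback that
    fine-tunes BERT with those settings and returns a dev-set metric.
    """
    best_config, best_score = None, float("-inf")
    for bs, lr, ep in product(BATCH_SIZES, LEARNING_RATES, EPOCHS):
        score = dev_score(bs, lr, ep)
        if score > best_score:
            best_config, best_score = (bs, lr, ep), score
    return best_config, best_score


# Toy scorer for demonstration only: prefers lr=3e-5, 3 epochs, batch size 32
def toy(bs, lr, ep):
    return -abs(lr - 3e-5) * 1e5 - abs(ep - 3) - abs(bs - 32) / 100


print(grid_search(toy))
```

Each call to a real `dev_score` is a full fine-tuning run, so this costs 18 runs per task; that is affordable only because each run is short.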

### Why Such Small Learning Rates?

BERT is already well-trained. We want to **adapt** it, not **destroy** what it learned.

Large learning rates tend to overwrite that pre-trained knowledge (often called catastrophic forgetting). Small learning rates make small adjustments around the pre-trained weights.

### Why So Few Epochs?

With pre-trained weights, the model converges quickly. Training for many more epochs tends to overfit, especially on small datasets.

In [None]:
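# Fine-tuning code example
print("Fine-tuning Example (PyTorch + Hugging Face transformers)")
print("=" * 50)

code = '''
import torch
from transformers import BertForSequenceClassification, get_linear_schedule_with_warmup

# Load pre-trained BERT; the classification head on top is newly initialized
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2  # Binary classification
)

# Optimizer with a small learning rate
# (use torch.optim.AdamW; the AdamW class in transformers is deprecated)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# train_dataloader is assumed to yield dicts of tensors with
# 'input_ids', 'attention_mask' and 'labels'
num_epochs = 3  # Just 3 epochs!
num_training_steps = num_epochs * len(train_dataloader)

# Linear warmup over the first 10% of steps (a common choice), then linear decay to 0
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),
    num_training_steps=num_training_steps,
)

model.train()  # enable dropout
for epoch in range(num_epochs):
    for batch in train_dataloader:
        # Forward pass; passing labels makes the model return a cross-entropy loss
        outputs = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        loss = outputs.loss

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
'''

print(code)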

---

## Results: Fine-tuning Performance

The paper reports new state-of-the-art results on eleven NLP tasks. A few highlights:

In [None]:
# Results comparison (BERT paper: GLUE test set, Table 1; SQuAD v1.1 test F1, Table 2)
fig, ax = plt.subplots(figsize=(12, 6))

tasks = ['MNLI\n(NLI, acc.)', 'SST-2\n(Sentiment, acc.)', 'CoLA\n(Grammar, MCC)', 'SQuAD v1.1\n(QA, test F1)']
bert_scores = [86.7, 94.9, 60.5, 93.2]
# Strongest prior system per task in the paper's tables: OpenAI GPT (MNLI, CoLA),
# pre-GPT SOTA (SST-2), top SQuAD leaderboard system listed (SQuAD)
previous_best = [82.1, 93.2, 45.4, 91.7]

x = np.arange(len(tasks))
width = 0.35

bars1 = ax.bar(x - width/2, previous_best, width, label='Previous best', color='#bdc3c7')
bars2 = ax.bar(x + width/2, bert_scores, width, label='BERT-Large', color='#e74c3c')

ax.set_ylabel('Score (accuracy / MCC / F1)', fontsize=11)
ax.set_title('Fine-tuning Results: BERT vs Previous Best', fontsize=12, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(tasks, fontsize=10)
ax.legend(fontsize=10)
ax.set_ylim(0, 100)
ax.grid(True, alpha=0.3, axis='y')

# Annotate improvements, computed from the data above
for i, (prev, bert) in enumerate(zip(previous_best, bert_scores)):
    ax.annotate(f'+{bert - prev:.1f}', xy=(i + width/2, bert + 2), ha='center', fontsize=9,
                fontweight='bold', color='#27ae60')

plt.tight_layout()
plt.show()

The gains are substantial:
- **CoLA**: +15.1 points (grammaticality judgment, Matthews correlation)
- **MNLI**: +4.6 points (natural language inference)
- **SQuAD v1.1**: +1.5 F1 (question answering; this result uses a BERT-Large ensemble that was first fine-tuned on TriviaQA)

All from the same pre-trained model, just different output heads.

---

## Summary: Fine-tuning Patterns

| Task Type | Output Used | Output Layer |
|-----------|-------------|-------------|
| Sentence classification | [CLS] | Linear → num_classes |
| Sentence pair classification | [CLS] | Linear → num_classes |
| Token classification | All tokens | Linear → num_tags |
| Question answering | Context tokens | Linear → 2 (start/end) |

The pattern is the same for every task:
1. Feed the input through BERT
2. Take the relevant hidden states
3. Apply a simple linear layer (or layers)
4. Train for a few epochs

---

## What's Next: Part 5

Now we understand the theory. In Part 5, we'll:

- **Implement BERT from scratch**
- Load pre-trained weights
- Fine-tune on a real task
- See it work
