# BERT Fine-Tuning for Downstream Tasks

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/bert_fine_tuning.ipynb)

In this notebook, we implement from scratch the task-specific heads used for **BERT fine-tuning**, covering three common downstream tasks:

1. **Sequence Classification** (e.g., Sentiment Analysis)
2. **Token Classification** (e.g., Named Entity Recognition)
3. **Question Answering** (e.g., SQuAD)

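We use a pre-trained `BertModel` (from Hugging Face for convenience, but the architecture matches our scratch implementation) and add custom heads. The heads are randomly initialized and are not trained in this notebook, so the demo outputs below only illustrate the input/output shapes of each task — the predictions become meaningful only after fine-tuning on labeled data.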

In [None]:
!pip install torch transformers datasets scikit-learn matplotlib

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

# Pre-trained BERT base (uncased): the tokenizer and the encoder must come
# from the same checkpoint so that token ids line up with the embeddings.
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
bert = BertModel.from_pretrained(model_name)
bert = bert.to(device)

## 1. Sequence Classification (e.g., Sentiment Analysis)

**Goal:** Classify an entire sentence (e.g., Positive vs. Negative).

**Method:** Use the final hidden state of the `[CLS]` token (the first token) as the sentence representation.

Head: `Dropout` → `Linear(d_model, num_classes)`

In [None]:
class BertForSequenceClassification(nn.Module):
    """BERT encoder with a sentence-level classification head.

    The head is ``Dropout -> Linear(hidden_size, num_classes)`` applied to the
    final hidden state of the first token (``[CLS]``).

    Args:
        bert_model: Encoder. Calling it with ``input_ids`` / ``attention_mask``
            keyword arguments must return an object exposing
            ``last_hidden_state`` of shape [Batch, SeqLen, Hidden].
        num_classes: Number of output classes.
        hidden_size: Width of the encoder's hidden states. When omitted it is
            read from ``bert_model.config.hidden_size`` and falls back to 768
            if the encoder exposes no such attribute.
        dropout: Dropout probability applied before the linear layer.
    """

    def __init__(self, bert_model, num_classes, hidden_size=None, dropout=0.1):
        super().__init__()
        if hidden_size is None:
            # Follow the encoder instead of hard-coding the BERT-base width,
            # so larger/smaller checkpoints work without editing the class.
            config = getattr(bert_model, "config", None)
            hidden_size = getattr(config, "hidden_size", 768)
        self.bert = bert_model
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return classification logits of shape [Batch, num_classes].

        Args:
            input_ids: Token ids, [Batch, SeqLen].
            attention_mask: Padding mask passed through to the encoder.
            token_type_ids: Optional segment ids for sentence-pair inputs.
                Only forwarded to the encoder when provided.
        """
        encoder_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids
        outputs = self.bert(**encoder_inputs)

        # Pooler output is usually the [CLS] token processed by a linear layer + Tanh.
        # We use the raw [CLS] hidden state for transparency.
        cls_token = outputs.last_hidden_state[:, 0, :]  # [Batch, Hidden]

        return self.classifier(self.dropout(cls_token))

# Demo (the classification head is untrained, so the probabilities are arbitrary)
model_seq = BertForSequenceClassification(bert, num_classes=2).to(device)
# A newly constructed nn.Module starts in training mode, which keeps the head's
# dropout active even under torch.no_grad(); eval() makes inference deterministic.
model_seq.eval()

sentences = ["I love this movie!", "This film was terrible."]
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(device)

with torch.no_grad():
    logits = model_seq(inputs.input_ids, inputs.attention_mask)
    probs = torch.softmax(logits, dim=-1)

print("Sequence Classification (Sentiment):")
print(f"  'I love this movie!'    -> {probs[0].cpu().numpy()}")
print(f"  'This film was terrible.' -> {probs[1].cpu().numpy()}")

## 2. Token Classification (e.g., NER)

**Goal:** Classify each token in the sequence (e.g., Person, Org, Loc, O).

**Method:** Apply a classifier to **every** token's final hidden state.

Head: `Dropout` → `Linear(d_model, num_classes)` (applied per token)

In [None]:
class BertForTokenClassification(nn.Module):
    """BERT encoder with a per-token classification head.

    The head is ``Dropout -> Linear(hidden_size, num_classes)`` applied
    independently to every token's final hidden state.

    Args:
        bert_model: Encoder. Calling it with ``input_ids`` / ``attention_mask``
            keyword arguments must return an object exposing
            ``last_hidden_state`` of shape [Batch, SeqLen, Hidden].
        num_classes: Number of tag classes.
        hidden_size: Width of the encoder's hidden states. When omitted it is
            read from ``bert_model.config.hidden_size`` and falls back to 768
            if the encoder exposes no such attribute.
        dropout: Dropout probability applied before the linear layer.
    """

    def __init__(self, bert_model, num_classes, hidden_size=None, dropout=0.1):
        super().__init__()
        if hidden_size is None:
            # Follow the encoder instead of hard-coding the BERT-base width.
            config = getattr(bert_model, "config", None)
            hidden_size = getattr(config, "hidden_size", 768)
        self.bert = bert_model
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return per-token logits of shape [Batch, SeqLen, num_classes].

        Args:
            input_ids: Token ids, [Batch, SeqLen].
            attention_mask: Padding mask passed through to the encoder.
            token_type_ids: Optional segment ids for sentence-pair inputs.
                Only forwarded to the encoder when provided.
        """
        encoder_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids
        outputs = self.bert(**encoder_inputs)
        sequence_output = outputs.last_hidden_state  # [Batch, SeqLen, Hidden]

        # nn.Linear acts on the last dimension, so every token is classified
        # with the same weights.
        return self.classifier(self.dropout(sequence_output))

# Demo: Named Entity Recognition (3 classes: O, B-PER, I-PER)
# The head is untrained, so the predicted classes are arbitrary.
model_token = BertForTokenClassification(bert, num_classes=3).to(device)
# A newly constructed nn.Module starts in training mode, which keeps the head's
# dropout active even under torch.no_grad(); eval() makes inference deterministic.
model_token.eval()

text = "Hugging Face is based in New York."
inputs = tokenizer(text, return_tensors="pt").to(device)

with torch.no_grad():
    logits = model_token(inputs.input_ids, inputs.attention_mask)
    preds = torch.argmax(logits, dim=-1)[0]  # single sentence -> first batch row

tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
print("\nToken Classification (NER):")
for token, pred in zip(tokens, preds.tolist()):
    print(f"  {token:<12} -> Class {pred}")

## 3. Question Answering (e.g., SQuAD)

**Goal:** Find the *answer span* (start and end token indices) in the context.

**Method:** Predict a `start` score and an `end` score for every token.

Head: `Linear(d_model, 2)` → Splits into `start_logits` and `end_logits`.

In [None]:
class BertForQuestionAnswering(nn.Module):
    """BERT encoder with an extractive question-answering head.

    A single ``Linear(hidden_size, 2)`` produces two scores per token: one for
    being the start of the answer span and one for being its end.

    Args:
        bert_model: Encoder. Calling it with ``input_ids`` / ``attention_mask``
            keyword arguments must return an object exposing
            ``last_hidden_state`` of shape [Batch, SeqLen, Hidden].
        hidden_size: Width of the encoder's hidden states. When omitted it is
            read from ``bert_model.config.hidden_size`` and falls back to 768
            if the encoder exposes no such attribute.
    """

    def __init__(self, bert_model, hidden_size=None):
        super().__init__()
        if hidden_size is None:
            # Follow the encoder instead of hard-coding the BERT-base width.
            config = getattr(bert_model, "config", None)
            hidden_size = getattr(config, "hidden_size", 768)
        self.bert = bert_model
        # Output 2 logits per token: start_score, end_score
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return ``(start_logits, end_logits)``, each of shape [Batch, SeqLen].

        Args:
            input_ids: Token ids of the packed (question, context) pair.
            attention_mask: Padding mask passed through to the encoder.
            token_type_ids: Optional segment ids telling question tokens from
                context tokens. Only forwarded to the encoder when provided.
        """
        encoder_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids
        outputs = self.bert(**encoder_inputs)
        sequence_output = outputs.last_hidden_state

        logits = self.qa_outputs(sequence_output)  # [Batch, SeqLen, 2]
        start_logits, end_logits = logits.split(1, dim=-1)

        return start_logits.squeeze(-1), end_logits.squeeze(-1)

# Demo: Question Answering (the head is untrained, so the span is arbitrary)
model_qa = BertForQuestionAnswering(bert).to(device)
model_qa.eval()

question = "Where do I live?"
context = "My name is Sarah and I live in London."
inputs = tokenizer(question, context, return_tensors="pt").to(device)

with torch.no_grad():
    # The tokenizer marks question vs. context tokens in token_type_ids;
    # without them the encoder treats the whole pair as a single segment.
    start_logits, end_logits = model_qa(
        inputs.input_ids, inputs.attention_mask, inputs.token_type_ids
    )

# Batch of one: pick the best start, then the best end at or after it so the
# decoded span can never be empty or reversed.
start_idx = int(torch.argmax(start_logits[0]))
end_idx = start_idx + int(torch.argmax(end_logits[0, start_idx:]))

all_tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])  # kept for inspection
answer = tokenizer.decode(inputs.input_ids[0][start_idx : end_idx + 1])

print("\nQuestion Answering:")
print(f"  Question: {question}")
print(f"  Context:  {context}")
print(f"  Answer Span: tokens[{start_idx}:{end_idx+1}]")
print(f"  Predicted Answer: '{answer}'")

## Visualizing Fine-Tuning Architectures

Here is a comparison of how the same pre-trained BERT body is adapted for different tasks. Multiple Choice is listed for reference only; it is not implemented in this notebook.

In [None]:
# Summary table: one row per task, giving which encoder outputs feed the head
# and the shape of that head.
heavy_rule = "=" * 60
table_rows = [
    ("Sequence Classification", "[CLS] token", "Linear(768, K)"),
    ("Token Classification", "All tokens", "Linear(768, K) per token"),
    ("Question Answering", "All tokens", "Linear(768, 2) per token"),
    ("Multiple Choice", "[CLS] per choice", "Linear(768, 1)"),
]

print("Architecture Summary Table")
print(heavy_rule)
print(f"{'Task':<25} {'Input':<20} {'Output Head'}")
print("-" * 60)
for task_name, head_input, head_layer in table_rows:
    print(f"{task_name:<25} {head_input:<20} {head_layer}")
print(heavy_rule)