# Session 9: BERT Fine-tuning and Classification Tasks

## 📚 Learning Objectives
By the end of this session, you will be able to:
- Understand BERT architecture and pre-training objectives
- Fine-tune BERT for text classification tasks
- Implement custom datasets and data loaders
- Apply transfer learning principles to NLP
- Evaluate fine-tuned models and interpret results
- Optimize training parameters and avoid overfitting

## 🎯 Session Overview
1. **BERT Deep Dive** - Architecture, pre-training, and variants
2. **Fine-tuning Setup** - Data preparation and model configuration
3. **Classification Tasks** - Sentiment analysis and text classification
4. **Training Pipeline** - Loss functions, optimizers, and training loops
5. **Model Evaluation** - Metrics, validation, and error analysis
6. **Advanced Techniques** - Learning rate scheduling and regularization

---

In [None]:
# Required imports for BERT fine-tuning
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Transformers and datasets
# AdamW comes from PyTorch: the transformers copy was deprecated and recent
# transformers releases no longer export it. Importing it from transformers
# made the import below fail and wrongly flag the whole library as missing.
from torch.optim import AdamW

try:
    from transformers import (
        BertTokenizer, BertForSequenceClassification,
        get_linear_schedule_with_warmup,
        TrainingArguments, Trainer
    )
    transformers_available = True
    print("✅ Transformers library loaded successfully")
except ImportError:
    transformers_available = False
    print("❌ Transformers library not available")

try:
    from datasets import Dataset as HFDataset
    datasets_available = True
    print("✅ Datasets library available")
except ImportError:
    datasets_available = False
    print("❌ Datasets library not available")

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"🔧 Using device: {device}")

# Seed NumPy and PyTorch so their random draws repeat from run to run.
np.random.seed(42)
torch.manual_seed(42)

print("📦 Setup completed!")

## Section 1: Understanding BERT Architecture

### 1.1 BERT Key Concepts

**Pre-training Objectives:**
1. **Masked Language Modeling (MLM)**: Predict masked tokens
2. **Next Sentence Prediction (NSP)**: Determine if two sentences are consecutive

**Architecture Features:**
- **Bidirectional**: Each token attends to context on both its left and its right in a single pass, rather than reading strictly left-to-right
- **Transformer Encoder**: Stack of 12 (base) or 24 (large) layers
- **Multi-Head Attention**: 12 (base) or 16 (large) attention heads
- **Hidden Size**: 768 (base) or 1024 (large) dimensions

### 1.2 Fine-tuning Approach
- Add a small task-specific head on top of BERT (for classification, typically a linear layer over the pooled `[CLS]` representation)
- Fine-tune all parameters end-to-end
- Use lower learning rates than pre-training (typically 2e-5 to 5e-5), so the pre-trained weights are only gently adjusted

In [None]:
# Custom Dataset Class for Text Classification

class TextClassificationDataset(Dataset):
    """Map-style dataset that tokenizes (text, label) pairs on the fly for BERT.

    Each item is a dict with ``input_ids`` and ``attention_mask`` flattened to
    1-D tensors and ``labels`` as a scalar ``torch.long`` tensor. With
    ``padding='max_length'`` the token tensors hold ``max_length`` entries
    (assuming the tokenizer returns a single row for a single text).
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        """Store the data and tokenizer settings.

        Args:
            texts: sequence of texts; non-string entries are converted with ``str``.
            labels: integer class ids, one per text.
            tokenizer: callable with the Hugging Face tokenizer call signature
                (``truncation``, ``padding``, ``max_length``, ``return_tensors``).
            max_length: length every example is padded/truncated to.

        Raises:
            ValueError: if ``texts`` and ``labels`` differ in length.
        """
        # Copy into lists so ``idx`` is always positional: indexing a pandas
        # Series directly is label-based and breaks after a shuffle/split.
        self.texts = list(texts)
        self.labels = list(labels)
        if len(self.texts) != len(self.labels):
            raise ValueError(
                "texts and labels must have the same length "
                f"(got {len(self.texts)} and {len(self.labels)})"
            )
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of (text, label) pairs."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize the ``idx``-th text and return model-ready tensors."""
        text = str(self.texts[idx])
        label = self.labels[idx]

        # Pad/truncate to a fixed length so the default collate can stack items.
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # return_tensors='pt' adds a leading batch dimension; flatten drops it.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

def create_sample_dataset():
    """Build a tiny, balanced toy corpus for binary sentiment classification.

    Returns:
        A ``(texts, labels)`` pair of parallel lists with 20 entries each.
        Label 1 marks a positive review and 0 a negative one; the entries
        alternate positive/negative.
    """
    pos, neg = 1, 0

    reviews = [
        ("I love this product! It's amazing.", pos),
        ("This is the worst experience ever.", neg),
        ("Great quality and fast delivery.", pos),
        ("Terrible customer service.", neg),
        ("Absolutely fantastic! Highly recommend.", pos),
        ("Poor quality, not worth the money.", neg),
        ("Excellent features and easy to use.", pos),
        ("Disappointing and overpriced.", neg),
        ("Outstanding performance and design.", pos),
        ("Completely broken upon arrival.", neg),
        ("This product exceeded my expectations.", pos),
        ("Not what I ordered, very frustrated.", neg),
        ("Perfect for my needs, very satisfied.", pos),
        ("Waste of money, would not recommend.", neg),
        ("Impressive build quality and features.", pos),
        ("Cheaply made and unreliable.", neg),
        ("Best purchase I've made this year.", pos),
        ("Returned it immediately, poor quality.", neg),
        ("Works perfectly, exactly as described.", pos),
        ("Defective product, very disappointing.", neg),
    ]

    texts = [review for review, _ in reviews]
    labels = [sentiment for _, sentiment in reviews]
    return texts, labels

def prepare_data_loaders(texts, labels, tokenizer, test_size=0.2, batch_size=8,
                         max_length=128, random_state=42):
    """Split data into stratified train/test sets and wrap them in DataLoaders.

    Args:
        texts: sequence of input texts.
        labels: class ids parallel to ``texts``; also used for stratification.
        tokenizer: tokenizer passed through to ``TextClassificationDataset``.
        test_size: fraction (or absolute count) of samples held out for testing.
        batch_size: batch size for both loaders.
        max_length: padded/truncated token length for every example.
        random_state: seed for the train/test split, for reproducibility.

    Returns:
        ``(train_loader, test_loader, train_dataset, test_dataset)``. Only the
        training loader shuffles.
    """
    # Stratify so both splits keep the original class proportions.
    train_texts, test_texts, train_labels, test_labels = train_test_split(
        texts, labels, test_size=test_size, random_state=random_state, stratify=labels
    )

    print(f"Training samples: {len(train_texts)}")
    print(f"Testing samples: {len(test_texts)}")

    train_dataset = TextClassificationDataset(
        train_texts, train_labels, tokenizer, max_length=max_length
    )
    test_dataset = TextClassificationDataset(
        test_texts, test_labels, tokenizer, max_length=max_length
    )

    # Shuffle only the training data; keep evaluation order deterministic.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader, train_dataset, test_dataset

# Training function for BERT fine-tuning
def train_bert_model(model, train_loader, test_loader, num_epochs=3, learning_rate=2e-5,
                     num_warmup_steps=0, max_grad_norm=1.0):
    """Fine-tune a sequence-classification model and report per-epoch metrics.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=..., labels=...)``
            whose output exposes ``.loss`` and ``.logits``.
        train_loader: sized iterable of batch dicts with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        test_loader: loader in the same format, evaluated after every epoch.
        num_epochs: number of passes over ``train_loader``.
        learning_rate: peak learning rate for AdamW.
        num_warmup_steps: scheduler steps of linear warm-up before the decay.
        max_grad_norm: threshold for gradient-norm clipping.

    Returns:
        ``(train_losses, train_accuracies)``: per-epoch mean batch loss and
        training accuracy.

    Raises:
        ValueError: if ``train_loader`` contains no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")

    # Use PyTorch's AdamW: the transformers copy is deprecated and recent
    # transformers releases no longer export it.
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

    # The scheduler is stepped once per batch, so it spans batches * epochs steps:
    # linear warm-up for num_warmup_steps, then linear decay to 0.
    total_steps = num_batches * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps
    )

    # Send every batch to whichever device the model's weights live on.
    model_device = next(model.parameters()).device

    train_losses = []
    train_accuracies = []

    for epoch in range(num_epochs):
        # Re-enter training mode every epoch: the evaluation at the end of the
        # previous epoch may have switched the model to eval mode.
        model.train()
        total_loss = 0.0
        correct_predictions = 0
        total_predictions = 0

        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 30)

        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(model_device)
            attention_mask = batch['attention_mask'].to(model_device)
            labels = batch['labels'].to(model_device)

            # Passing labels makes the model compute the loss itself.
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            loss = outputs.loss
            logits = outputs.logits

            optimizer.zero_grad()
            loss.backward()
            # Clip gradients to keep fine-tuning updates stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()

            predictions = torch.argmax(logits, dim=1)
            correct_predictions += (predictions == labels).sum().item()
            total_predictions += labels.size(0)
            total_loss += loss.item()

            if batch_idx % 2 == 0:  # Print every 2 batches
                print(f"  Batch {batch_idx}/{num_batches}, Loss: {loss.item():.4f}")

        avg_loss = total_loss / num_batches
        accuracy = correct_predictions / total_predictions

        train_losses.append(avg_loss)
        train_accuracies.append(accuracy)

        print(f"  Avg Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

        test_accuracy = evaluate_model(model, test_loader)
        print(f"  Test Accuracy: {test_accuracy:.4f}")

    return train_losses, train_accuracies

def evaluate_model(model, test_loader, target_names=('Negative', 'Positive'), show_report=True):
    """Compute accuracy on ``test_loader`` and optionally show a detailed report.

    The model is put in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, even if an error occurs.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits``.
        test_loader: iterable of batch dicts with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        target_names: display names for class ids ``0 .. len(target_names) - 1``.
        show_report: when True and the ground truth contains at least two
            classes, print a classification report and plot a confusion matrix.

    Returns:
        Fraction of correctly classified samples.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    was_training = model.training
    # Inputs follow the model: use the device of its first parameter.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    correct_predictions = 0
    total_predictions = 0
    all_predictions = []
    all_labels = []

    model.eval()
    try:
        with torch.no_grad():
            for batch in test_loader:
                input_ids = batch['input_ids'].to(model_device)
                attention_mask = batch['attention_mask'].to(model_device)
                labels = batch['labels'].to(model_device)

                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                predictions = torch.argmax(outputs.logits, dim=1)

                correct_predictions += (predictions == labels).sum().item()
                total_predictions += labels.size(0)
                all_predictions.extend(predictions.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())
    finally:
        # Restore the caller's mode rather than forcing training mode.
        model.train(was_training)

    if total_predictions == 0:
        raise ValueError("test_loader yielded no samples")
    accuracy = correct_predictions / total_predictions

    # A per-class report is only meaningful when the ground truth holds 2+ classes.
    if show_report and len(set(all_labels)) > 1:
        class_ids = list(range(len(target_names)))
        names = list(target_names)

        print("\n📊 Detailed Evaluation:")
        # Explicit labels keep rows aligned with target_names even when a class
        # is never predicted.
        print(classification_report(all_labels, all_predictions, labels=class_ids,
                                    target_names=names, zero_division=0))

        cm = confusion_matrix(all_labels, all_predictions, labels=class_ids)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=names,
                    yticklabels=names)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.show()

    return accuracy

# Main fine-tuning demonstration
if transformers_available:
    print("🚀 BERT FINE-TUNING DEMONSTRATION")
    print("=" * 60)

    # Load pre-trained BERT weights and the matching tokenizer
    model_name = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2  # Binary classification (positive/negative)
    )
    model.to(device)

    print(f"✅ Loaded BERT model: {model_name}")
    print(f"   Number of parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Create sample dataset
    texts, labels = create_sample_dataset()
    print("\n📊 Dataset Statistics:")
    print(f"   Total samples: {len(texts)}")
    # Labels are 0/1, so their sum is the number of positive samples.
    print(f"   Positive samples: {sum(labels)}")
    print(f"   Negative samples: {len(labels) - sum(labels)}")

    # Prepare data loaders
    train_loader, test_loader, train_dataset, test_dataset = prepare_data_loaders(
        texts, labels, tokenizer, batch_size=4  # Small batch size for demo
    )

    # Show sample tokenization
    print("\n🔤 Sample Tokenization:")
    sample_text = texts[0]
    tokens = tokenizer.tokenize(sample_text)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)

    print(f"Original text: {sample_text}")
    print(f"Tokens: {tokens}")
    print(f"Token IDs: {token_ids}")
    # tokenize() returns only the word pieces; calling the tokenizer itself also
    # adds the special tokens ([CLS] ... [SEP]) that the model actually receives.
    model_input_ids = tokenizer(sample_text)['input_ids']
    print(f"Model input tokens: {tokenizer.convert_ids_to_tokens(model_input_ids)}")

    # Train the model
    print("\n🎯 Starting Fine-tuning:")
    train_losses, train_accuracies = train_bert_model(
        model, train_loader, test_loader,
        num_epochs=2,  # Small number for demo
        learning_rate=2e-5
    )

    # Plot training progress; epochs are numbered from 1 on the x-axis.
    epochs = list(range(1, len(train_losses) + 1))
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    ax1.plot(epochs, train_losses, 'b-o', label='Training Loss')
    ax1.set_title('Training Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_xticks(epochs)
    ax1.legend()

    ax2.plot(epochs, train_accuracies, 'r-o', label='Training Accuracy')
    ax2.set_title('Training Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.set_xticks(epochs)
    ax2.legend()

    plt.tight_layout()
    plt.show()

    # Final evaluation
    print("\n🏁 Final Model Evaluation:")
    final_accuracy = evaluate_model(model, test_loader)
    print(f"Final Test Accuracy: {final_accuracy:.4f}")

    # Leave the model in inference mode (dropout off) for any follow-up cells.
    model.eval()

else:
    print("❌ Transformers library not available for BERT fine-tuning demo")