# Solutions: Lab 2.3.5 - BERT Fine-tuning

This notebook contains solutions to the exercises from notebook 05.

---

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW  # Use PyTorch's AdamW (not deprecated transformers.AdamW)

try:
    from transformers import BertModel, BertTokenizer, get_linear_schedule_with_warmup
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False
    print("Please install transformers: pip install transformers")

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

## Exercise 1: Multi-class Classification

**Task:** Modify the binary classifier to handle multiple classes (e.g., emotion detection).

In [None]:
class BertMultiClassifier(nn.Module):
    """
    BERT for multi-class classification.

    Modifications from binary classifier:
    - Output layer has num_classes instead of 1
    - Uses CrossEntropyLoss instead of BCEWithLogitsLoss
    - No activation on the output: CrossEntropyLoss applies log-softmax
      internally, so forward() returns raw logits (apply softmax yourself
      only when you need probabilities at inference time)

    Args:
        num_classes: Number of output classes.
        dropout: Dropout probability used in the classification head.
        freeze_bert: If True, every BERT parameter gets requires_grad=False,
            so only the classification head is trained.
        model_name: Pretrained checkpoint passed to BertModel.from_pretrained.
    """

    def __init__(self, num_classes, dropout=0.3, freeze_bert=False,
                 model_name='bert-base-uncased'):
        super().__init__()

        self.bert = BertModel.from_pretrained(model_name)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        # Take the encoder width from the checkpoint's config instead of
        # hard-coding 768, so larger checkpoints (e.g. hidden_size=1024)
        # also produce a correctly sized head.
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes)  # num_classes outputs
        )

    def forward(self, input_ids, attention_mask):
        """Encode the batch with BERT and return one logit per class for each example."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.pooler_output
        return self.classifier(pooled)  # Returns logits for each class

# Example usage for emotion classification (6 classes)
if HAS_TRANSFORMERS:
    EMOTION_LABELS = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
    num_classes = len(EMOTION_LABELS)

    model = BertMultiClassifier(num_classes=num_classes, freeze_bert=True).to(device)
    print(f"Model for {num_classes}-class classification created")
    print(f"Labels: {EMOTION_LABELS}")

    # Loss function for multi-class: takes the raw logits returned by the
    # model together with integer class-index targets.
    criterion = nn.CrossEntropyLoss()

    # Example forward pass. NOTE: the classification head has not been trained
    # yet, so this prediction is essentially arbitrary until fine-tuning.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    text = "I am so happy today!"
    encoding = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
    # Inputs must live on the same device as the model.
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # nn.Module starts in training mode; eval() disables the head's Dropout so
    # repeated runs on the same text give the same prediction.
    model.eval()
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        probs = torch.softmax(logits, dim=-1)
        pred = probs.argmax(dim=-1).item()

    print(f"\nExample prediction for '{text}':")
    print(f"Predicted: {EMOTION_LABELS[pred]} ({probs[0, pred]:.2%} confidence)")
    print("\nAll probabilities:")
    for i, label in enumerate(EMOTION_LABELS):
        print(f"  {label}: {probs[0, i]:.2%}")

## Exercise 2: Learning Rate Scheduling

**Task:** Implement different learning rate schedules and compare their effects.

In [None]:
import matplotlib.pyplot as plt

def _record_lrs(optimizer, scheduler, num_steps):
    """Step `scheduler` `num_steps` times; return the first group's LR seen before each step."""
    lrs = []
    for _ in range(num_steps):
        lrs.append(optimizer.param_groups[0]['lr'])
        # The dummy parameters never receive gradients, so optimizer.step()
        # leaves them untouched; calling it before scheduler.step() follows
        # the order PyTorch expects and avoids its ordering warning.
        optimizer.step()
        scheduler.step()
    return lrs


def visualize_lr_schedules(num_epochs=3, steps_per_epoch=100, base_lr=2e-5, warmup_ratio=0.1):
    """
    Compare different learning rate schedules.

    Plots the per-step learning rate of four schedules over
    ``num_epochs * steps_per_epoch`` steps and prints short recommendations.

    Args:
        num_epochs: Number of epochs to simulate (dashed lines mark epoch boundaries).
        steps_per_epoch: Steps per epoch; also used as the StepLR step size.
        base_lr: Peak learning rate shared by every schedule.
        warmup_ratio: Fraction of total steps used for linear warmup.
    """
    total_steps = num_epochs * steps_per_epoch
    warmup_steps = int(warmup_ratio * total_steps)  # 10% warmup by default

    # Dummy model: the schedulers only need an optimizer's param_groups.
    model = nn.Linear(10, 2)

    def new_optimizer():
        # A fresh optimizer per schedule so the schedulers do not interfere.
        return torch.optim.AdamW(model.parameters(), lr=base_lr)

    schedules = {}

    # 1. Linear warmup + decay
    optimizer1 = new_optimizer()
    schedules['Linear warmup + decay'] = _record_lrs(
        optimizer1,
        get_linear_schedule_with_warmup(
            optimizer1,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        ),
        total_steps,
    )

    # 2. Cosine annealing
    optimizer2 = new_optimizer()
    schedules['Cosine annealing'] = _record_lrs(
        optimizer2,
        torch.optim.lr_scheduler.CosineAnnealingLR(optimizer2, T_max=total_steps),
        total_steps,
    )

    # 3. Step decay: halve the LR at every epoch boundary
    optimizer3 = new_optimizer()
    schedules['Step decay'] = _record_lrs(
        optimizer3,
        torch.optim.lr_scheduler.StepLR(optimizer3, step_size=steps_per_epoch, gamma=0.5),
        total_steps,
    )

    # 4. Constant LR (baseline)
    schedules['Constant'] = [base_lr] * total_steps

    # Plot
    plt.figure(figsize=(12, 5))

    for name, lrs in schedules.items():
        plt.plot(lrs, label=name, linewidth=2)

    # Add epoch markers
    for i in range(1, num_epochs):
        plt.axvline(x=i * steps_per_epoch, color='gray', linestyle='--', alpha=0.5)

    plt.xlabel('Training Step')
    plt.ylabel('Learning Rate')
    plt.title('Learning Rate Schedules Comparison')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    print("Recommendations:")
    print("- Linear warmup + decay: Standard for BERT fine-tuning")
    print("- Cosine annealing: Good for longer training")
    print("- Step decay: Simple but less smooth")
    print("- Constant: Usually not recommended for fine-tuning")

# get_linear_schedule_with_warmup comes from transformers, so only run when it is installed.
if HAS_TRANSFORMERS:
    visualize_lr_schedules()

## Exercise 3: Layer-wise Learning Rate Decay

**Task:** Implement LLRD (Layer-wise Learning Rate Decay) where earlier layers get lower learning rates.

In [None]:
def get_llrd_params(model, base_lr=2e-5, decay_factor=0.9):
    """
    Get parameter groups with layer-wise learning rate decay.

    With N = number of encoder layers (12 for BERT-base, 24 for BERT-large),
    earlier layers get lower learning rates:
    - Embeddings: base_lr * decay_factor^N
    - Layer i:    base_lr * decay_factor^(N - 1 - i)  (top layer gets base_lr)
    - Pooler + classifier: base_lr

    Args:
        model: BERT-based model exposing ``model.bert`` (with ``.embeddings``,
            ``.encoder.layer`` and optionally ``.pooler``) and ``model.classifier``.
        base_lr: Learning rate for top layers
        decay_factor: How much to reduce LR for each layer down

    Returns:
        List of dicts with 'params', 'lr' and 'name' keys, ordered from
        embeddings to classifier, suitable for a torch.optim optimizer.
    """
    bert = model.bert
    encoder_layers = bert.encoder.layer
    # Derive the depth from the model instead of assuming 12, so every
    # encoder layer gets a group for any BERT size.
    num_layers = len(encoder_layers)

    # Embeddings get the lowest LR
    param_groups = [{
        'params': list(bert.embeddings.parameters()),
        'lr': base_lr * (decay_factor ** num_layers),
        'name': 'embeddings'
    }]

    # Encoder layers: the LR grows toward the top of the stack
    for layer_idx, layer in enumerate(encoder_layers):
        param_groups.append({
            'params': list(layer.parameters()),
            'lr': base_lr * (decay_factor ** (num_layers - 1 - layer_idx)),
            'name': f'layer_{layer_idx}'
        })

    # Pooler (if the model has one) and classifier get full LR
    head_params = []
    pooler = getattr(bert, 'pooler', None)
    if pooler is not None:
        head_params.extend(pooler.parameters())
    head_params.extend(model.classifier.parameters())
    param_groups.append({
        'params': head_params,
        'lr': base_lr,
        'name': 'classifier'
    })

    return param_groups

if HAS_TRANSFORMERS:
    # Build a fully trainable classifier and split its parameters into LLRD groups.
    model = BertMultiClassifier(num_classes=6, freeze_bert=False)
    param_groups = get_llrd_params(model, base_lr=2e-5, decay_factor=0.9)

    # Report the learning rate assigned to each group.
    print("Layer-wise Learning Rate Decay:")
    print("=" * 50)
    for group_name, group_lr in ((g['name'], g['lr']) for g in param_groups):
        print(f"  {group_name}: lr = {group_lr:.2e}")

    # Per-group 'lr' entries override the optimizer default; weight decay is shared.
    optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)
    print(f"\nOptimizer created with {len(param_groups)} parameter groups")

    # Bar chart of the learning rate per group, embeddings first.
    names, lrs = zip(*[(g['name'], g['lr']) for g in param_groups])
    positions = range(len(lrs))

    plt.figure(figsize=(12, 5))
    plt.bar(positions, lrs)
    plt.xticks(positions, names, rotation=45, ha='right')
    plt.ylabel('Learning Rate')
    plt.title('Layer-wise Learning Rate Decay (LLRD)')
    plt.tight_layout()
    plt.show()

---

## Key Takeaways

1. **Multi-class classification** uses CrossEntropyLoss and outputs logits for each class
2. **Learning rate scheduling** is crucial - linear warmup + decay is standard for BERT
3. **LLRD** gives lower learning rates to earlier layers, preserving pre-trained features
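4. **Freezing** BERT (the whole encoder via `freeze_bert=True`, or only its earlier layers) speeds up training and reduces memory use; how much accuracy is retained depends on the task, so compare against full fine-tuning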

---