# Day 29: Mitigating Catastrophic Forgetting

In this notebook, we'll explore techniques to mitigate catastrophic forgetting when fine-tuning large language models. Catastrophic forgetting occurs when a model loses previously learned knowledge during fine-tuning on a new task.

## Overview

1. Setup and dependencies
2. Loading and preparing datasets
3. Demonstrating catastrophic forgetting
4. Mitigating catastrophic forgetting with LoRA
5. Comparing results
6. Other techniques (EWC, replay methods, knowledge distillation)

## 1. Setup and Dependencies

In [None]:
!pip install -q transformers datasets peft evaluate accelerate torch

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)
from peft import (
    get_peft_model,
    LoraConfig,
    TaskType
)
import evaluate
import copy

# Seed the global random number generators so that randomly initialised weights
# (e.g. freshly created classification heads) are the same on every
# Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Loading and Preparing Datasets

We'll use two different tasks to demonstrate catastrophic forgetting: sentiment analysis (SST-2) and natural language inference (MNLI).

In [None]:
# Pretrained checkpoint shared by every model built in this notebook
base_model_name = "roberta-base"

# Tokenizer matching that checkpoint
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Task 1: Sentiment Analysis (SST-2) -- keep the first 1000 training and
# 200 validation examples so the demonstration stays small
sst2_dataset = load_dataset("glue", "sst2")
sst2_train, sst2_eval = (
    sst2_dataset[split].select(range(size))
    for split, size in (("train", 1000), ("validation", 200))
)

# Task 2: Natural Language Inference (MNLI) -- same subset sizes, evaluated on
# the "validation_matched" split
mnli_dataset = load_dataset("glue", "mnli")
mnli_train, mnli_eval = (
    mnli_dataset[split].select(range(size))
    for split, size in (("train", 1000), ("validation_matched", 200))
)

# Report the subset sizes and the label names of each task
for header, train_split, eval_split, full_dataset in (
    ("SST-2 dataset (Sentiment Analysis):", sst2_train, sst2_eval, sst2_dataset),
    ("\nMNLI dataset (Natural Language Inference):", mnli_train, mnli_eval, mnli_dataset),
):
    print(header)
    print(f"  Train: {len(train_split)} examples")
    print(f"  Validation: {len(eval_split)} examples")
    print(f"  Labels: {full_dataset['train'].features['label'].names}")

In [None]:
# Tokenize SST-2 dataset
def tokenize_sst2(examples):
    """Tokenize a batch of SST-2 examples.

    Args:
        examples: Batch from ``Dataset.map(..., batched=True)``; its
            ``"sentence"`` entry is passed to the tokenizer.

    Returns:
        The tokenizer output for the batch, truncated to at most 128 tokens.
    """
    # No padding here: the DataCollatorWithPadding handed to the Trainer pads
    # each batch to its own longest sequence, so padding every example to 128
    # tokens up front only adds wasted computation.
    return tokenizer(examples["sentence"], truncation=True, max_length=128)

# Apply the SST-2 tokenizer to the training and evaluation subsets, in that order
tokenized_sst2_train, tokenized_sst2_eval = (
    split.map(tokenize_sst2, batched=True) for split in (sst2_train, sst2_eval)
)

# Tokenize MNLI dataset
def tokenize_mnli(examples):
    """Tokenize a batch of MNLI premise/hypothesis pairs.

    Args:
        examples: Batch from ``Dataset.map(..., batched=True)``; its
            ``"premise"`` and ``"hypothesis"`` entries are passed to the
            tokenizer as a text pair.

    Returns:
        The tokenizer output for the batch, truncated to at most 128 tokens.
    """
    # No padding here: the DataCollatorWithPadding handed to the Trainer pads
    # each batch to its own longest sequence, so padding every example to 128
    # tokens up front only adds wasted computation.
    return tokenizer(
        examples["premise"],
        examples["hypothesis"],
        truncation=True,
        max_length=128
    )

# Apply the MNLI tokenizer to the training and evaluation subsets, in that order
tokenized_mnli_train, tokenized_mnli_eval = (
    split.map(tokenize_mnli, batched=True) for split in (mnli_train, mnli_eval)
)

# Single collator, reused by every Trainer below, that pads batches with this tokenizer
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

## 3. Demonstrating Catastrophic Forgetting

We'll first train a model on Task 1 (SST-2), then fine-tune it on Task 2 (MNLI), and observe how performance on Task 1 degrades.

In [None]:
# Define metrics for evaluation
# Loaded once at module level; compute_metrics below reuses this same metric
# object on every call.
accuracy_metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Compute accuracy from a ``(predictions, labels)`` pair.

    Args:
        eval_pred: Two-item iterable: per-class scores (class dimension on
            axis 1) and the reference labels.

    Returns:
        Whatever ``accuracy_metric.compute`` returns for the arg-max predictions.
    """
    scores, references = eval_pred
    # The highest-scoring class along axis 1 is the predicted label
    predicted_labels = np.argmax(scores, axis=1)
    return accuracy_metric.compute(predictions=predicted_labels, references=references)

In [None]:
# Function to train on a task
def train_on_task(model, train_dataset, eval_dataset, num_labels, output_dir, num_epochs=3):
    """Fine-tune ``model`` on one task with the Hugging Face ``Trainer`` and evaluate it.

    Args:
        model: Model to train; the same object is handed to the Trainer and
            returned afterwards.
        train_dataset: Tokenized training split.
        eval_dataset: Tokenized split evaluated after every epoch and once
            more after training finishes.
        num_labels: Not used by this function (the caller already builds the
            model with the right head size); kept so existing calls keep working.
        output_dir: Directory that receives the per-epoch checkpoints.
        num_epochs: Number of training epochs.

    Returns:
        Tuple ``(model, eval_results)`` where ``eval_results`` is the dict
        returned by ``trainer.evaluate()``.
    """
    import inspect  # local import: only needed for the version checks below

    # Newer transformers releases renamed `evaluation_strategy` to
    # `eval_strategy` and the Trainer's `tokenizer` argument to
    # `processing_class`. The notebook installs an unpinned version, so pick
    # whichever keyword the installed release actually accepts.
    training_arg_names = inspect.signature(TrainingArguments.__init__).parameters
    if "eval_strategy" in training_arg_names:
        eval_strategy_key = "eval_strategy"
    else:
        eval_strategy_key = "evaluation_strategy"

    trainer_arg_names = inspect.signature(Trainer.__init__).parameters
    if "processing_class" in trainer_arg_names:
        tokenizer_key = "processing_class"
    else:
        tokenizer_key = "tokenizer"

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=3e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=32,
        num_train_epochs=num_epochs,
        weight_decay=0.01,
        # Evaluate and checkpoint on the same schedule, which
        # load_best_model_at_end requires
        save_strategy="epoch",
        load_best_model_at_end=True,
        push_to_hub=False,
        report_to="none",
        **{eval_strategy_key: "epoch"},
    )

    # Create the trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        **{tokenizer_key: tokenizer},
    )

    # Train the model
    trainer.train()

    # Evaluate the model
    eval_results = trainer.evaluate()

    return model, eval_results

In [None]:
# Stage 1 of the forgetting experiment: fit a fresh model on Task 1 (SST-2)
print("Training on Task 1 (SST-2)...")

# SST-2 is binary classification, hence the two-label head
model_task1 = AutoModelForSequenceClassification.from_pretrained(
    base_model_name, num_labels=2, return_dict=True
)
model_task1 = model_task1.to(device)

model_task1, task1_results = train_on_task(
    model=model_task1,
    train_dataset=tokenized_sst2_train,
    eval_dataset=tokenized_sst2_eval,
    num_labels=2,
    output_dir="./results/task1",
    num_epochs=3,
)

print(f"Task 1 (SST-2) accuracy: {task1_results['eval_accuracy']:.4f}")

# Keep an independent snapshot of the trained Task 1 model; its classification
# head is reused later when Task 1 is re-evaluated
task1_model_copy = copy.deepcopy(model_task1)

In [None]:
# Fine-tune on Task 2 (MNLI) using the Task 1 model
print("\nFine-tuning on Task 2 (MNLI)...")

# The Task 1 head has 2 outputs but MNLI needs 3, so build a fresh 3-way model
# and transplant every Task 1 weight except the classification head into it.
model_task2 = AutoModelForSequenceClassification.from_pretrained(
    base_model_name,
    num_labels=3,  # MNLI has 3 classes
    return_dict=True
).to(device)

# Gather the Task 1 weights outside the classification head once, then load
# them in a single call. load_state_dict copies without tracking gradients and
# avoids rebuilding model_task2.state_dict() for every parameter.
task1_body_weights = {
    name: param.detach()
    for name, param in model_task1.named_parameters()
    if "classifier" not in name  # Skip the classification head
}
# strict=False because the classifier entries are intentionally left out
task2_load_report = model_task2.load_state_dict(task1_body_weights, strict=False)
if task2_load_report.unexpected_keys:
    # Same failure type the per-parameter lookup would have produced
    raise KeyError(f"Parameters missing from the Task 2 model: {task2_load_report.unexpected_keys}")

model_task2, task2_results = train_on_task(
    model_task2,
    tokenized_mnli_train,
    tokenized_mnli_eval,
    num_labels=3,
    output_dir="./results/task2",
    num_epochs=3
)

print(f"Task 2 (MNLI) accuracy: {task2_results['eval_accuracy']:.4f}")

In [None]:
# Evaluate Task 1 performance after fine-tuning on Task 2
# We need to create a new model with the Task 1 classification head
model_task1_after = AutoModelForSequenceClassification.from_pretrained(
    base_model_name,
    num_labels=2,  # Binary classification for SST-2
    return_dict=True
).to(device)

# Assemble the weights to evaluate: everything except the classification head
# comes from the Task 2 model, the head itself from the saved Task 1 snapshot.
task1_after_weights = {
    name: param.detach()
    for name, param in model_task2.named_parameters()
    if "classifier" not in name
}
task1_after_weights.update(
    (name, param.detach())
    for name, param in task1_model_copy.named_parameters()
    if "classifier" in name
)
# One load_state_dict call copies without tracking gradients and avoids
# rebuilding the state dict for every parameter. strict=False tolerates
# entries that are not parameters and therefore are not in the dict above.
task1_after_load_report = model_task1_after.load_state_dict(task1_after_weights, strict=False)
if task1_after_load_report.unexpected_keys:
    # Same failure type the per-parameter lookup would have produced
    raise KeyError(
        f"Parameters missing from the evaluation model: {task1_after_load_report.unexpected_keys}"
    )

# Evaluate on Task 1
# Explicit arguments keep this evaluation consistent with train_on_task:
# no experiment-tracker reporting and the same evaluation batch size. The
# tokenizer is not passed because the collator already handles padding.
trainer = Trainer(
    model=model_task1_after,
    args=TrainingArguments(
        output_dir="./results/task1_after_eval",
        per_device_eval_batch_size=32,
        report_to="none",
    ),
    eval_dataset=tokenized_sst2_eval,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

task1_after_results = trainer.evaluate()
print(f"\nTask 1 (SST-2) accuracy after fine-tuning on Task 2: {task1_after_results['eval_accuracy']:.4f}")
print(f"Performance drop: {task1_results['eval_accuracy'] - task1_after_results['eval_accuracy']:.4f}")

## 4. Mitigating Catastrophic Forgetting with LoRA

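Now, let's use LoRA to mitigate catastrophic forgetting. With LoRA the pretrained weights stay frozen and each task trains its own small set of low-rank adapter weights, so training on Task 2 cannot overwrite what was learned for Task 1.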

In [None]:
# LoRA variant of stage 1: train an adapter for Task 1 (SST-2)
print("Training on Task 1 (SST-2) with LoRA...")

# Fresh pretrained model with a two-label head for SST-2
base_model_task1 = AutoModelForSequenceClassification.from_pretrained(
    base_model_name, num_labels=2, return_dict=True
)

# Rank-8 LoRA adapters attached to the modules named query / key / value
lora_config_task1 = LoraConfig(
    bias="none",
    lora_alpha=16,
    lora_dropout=0.1,
    r=8,
    target_modules=["query", "key", "value"],
    task_type=TaskType.SEQ_CLS,
)

# Wrap the base model with the adapter, then move the wrapped model to the device
lora_model_task1 = get_peft_model(base_model_task1, lora_config_task1)
lora_model_task1 = lora_model_task1.to(device)

# Same training routine as the standard fine-tuning run
lora_model_task1, lora_task1_results = train_on_task(
    model=lora_model_task1,
    train_dataset=tokenized_sst2_train,
    eval_dataset=tokenized_sst2_eval,
    num_labels=2,
    output_dir="./results/lora_task1",
    num_epochs=3,
)

print(f"Task 1 (SST-2) accuracy with LoRA: {lora_task1_results['eval_accuracy']:.4f}")

In [None]:
# LoRA variant of stage 2: train a separate adapter for Task 2 (MNLI)
print("\nFine-tuning on Task 2 (MNLI) with a new LoRA adapter...")

# Fresh pretrained model with a three-label head for MNLI
base_model_task2 = AutoModelForSequenceClassification.from_pretrained(
    base_model_name, num_labels=3, return_dict=True
)

# Same adapter hyperparameters as the Task 1 configuration
lora_config_task2 = LoraConfig(
    bias="none",
    lora_alpha=16,
    lora_dropout=0.1,
    r=8,
    target_modules=["query", "key", "value"],
    task_type=TaskType.SEQ_CLS,
)

# Wrap the base model with the adapter, then move the wrapped model to the device
lora_model_task2 = get_peft_model(base_model_task2, lora_config_task2)
lora_model_task2 = lora_model_task2.to(device)

# Same training routine as the standard fine-tuning run
lora_model_task2, lora_task2_results = train_on_task(
    model=lora_model_task2,
    train_dataset=tokenized_mnli_train,
    eval_dataset=tokenized_mnli_eval,
    num_labels=3,
    output_dir="./results/lora_task2",
    num_epochs=3,
)

print(f"Task 2 (MNLI) accuracy with LoRA: {lora_task2_results['eval_accuracy']:.4f}")

In [None]:
# Re-evaluate Task 1 with its LoRA adapter
# The Task 2 adapter was trained on a separately loaded base model, so this
# re-evaluates the Task 1 LoRA model exactly as Task 1 training left it.
# Explicit arguments keep this evaluation consistent with train_on_task:
# no experiment-tracker reporting and the same evaluation batch size. The
# tokenizer is not passed because the collator already handles padding.
trainer = Trainer(
    model=lora_model_task1,
    args=TrainingArguments(
        output_dir="./results/lora_task1_after_eval",
        per_device_eval_batch_size=32,
        report_to="none",
    ),
    eval_dataset=tokenized_sst2_eval,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

lora_task1_after_results = trainer.evaluate()
print(f"\nTask 1 (SST-2) accuracy with LoRA after training Task 2: {lora_task1_after_results['eval_accuracy']:.4f}")
print(f"Performance change: {lora_task1_after_results['eval_accuracy'] - lora_task1_results['eval_accuracy']:.4f}")

## 5. Comparing Results

Let's compare the results of the standard fine-tuning approach and the LoRA approach.

In [None]:
# Gather the accuracies of both approaches into one nested dict:
# outer key = approach, inner keys = training stage / derived drop
results = {
    approach: {
        "Task 1 (Initial)": initial["eval_accuracy"],
        "Task 2": second_task["eval_accuracy"],
        "Task 1 (After Task 2)": revisited["eval_accuracy"],
        # Positive values mean Task 1 accuracy fell after Task 2 training
        "Task 1 Performance Drop": initial["eval_accuracy"] - revisited["eval_accuracy"],
    }
    for approach, initial, second_task, revisited in (
        ("Standard Fine-tuning", task1_results, task2_results, task1_after_results),
        ("LoRA", lora_task1_results, lora_task2_results, lora_task1_after_results),
    )
}

# One column per approach, one row per metric
results_df = pd.DataFrame(results)
print("Comparison of Standard Fine-tuning vs. LoRA:")
print(results_df)

In [None]:
# Grouped bar chart: Task 1 accuracy before vs. after Task 2, per approach
fig, ax = plt.subplots(figsize=(10, 6))

# Two groups on the x axis (before / after); each holds one bar per approach
x = np.arange(2)
width = 0.35

# Left bar of each group: standard fine-tuning
ax.bar(
    x - width/2,
    [results["Standard Fine-tuning"]["Task 1 (Initial)"], results["Standard Fine-tuning"]["Task 1 (After Task 2)"]],
    width,
    label="Standard Fine-tuning",
)
# Right bar of each group: LoRA
ax.bar(
    x + width/2,
    [results["LoRA"]["Task 1 (Initial)"], results["LoRA"]["Task 1 (After Task 2)"]],
    width,
    label="LoRA",
)

ax.set_xlabel("Training Stage")
ax.set_ylabel("Task 1 Accuracy")
ax.set_title("Impact of Task 2 Training on Task 1 Performance")
ax.set_xticks(x)
ax.set_xticklabels(["Before Task 2", "After Task 2"])
ax.legend()
ax.grid(axis="y", linestyle="--", alpha=0.7)

fig.tight_layout()
plt.show()

## 6. Other Techniques to Mitigate Catastrophic Forgetting

Let's briefly discuss other techniques that can be used to mitigate catastrophic forgetting.

### 6.1 Elastic Weight Consolidation (EWC)

EWC adds a regularization term that penalizes changes to parameters that are important for the previous task.

In [None]:
# Pseudocode for EWC implementation
# NOTE: everything between the triple quotes is a plain string literal, so none
# of it is executed. Names such as `calculate_fisher_information`, `task1_data`
# and `task2_loss` are placeholders that the cells shown here do not define.
# The sketch indexes `old_model[name]`, so it presumably expects a mapping from
# parameter name to the saved Task 1 tensor -- confirm before turning it into code.
'''
# Calculate Fisher Information Matrix for Task 1
fisher_information = calculate_fisher_information(model, task1_data)

# Define EWC loss
def ewc_loss(model, old_model, fisher, lambda_ewc):
    loss = 0
    for name, param in model.named_parameters():
        if name in fisher:
            loss += lambda_ewc * torch.sum(fisher[name] * (param - old_model[name])**2)
    return loss

# Add EWC loss to the standard loss during Task 2 training
total_loss = task2_loss + ewc_loss(model, task1_model, fisher_information, lambda_ewc=1000)
'''

### 6.2 Replay Methods

Replay methods involve mixing in samples from previous tasks during training on new tasks.

In [None]:
# Pseudocode for replay implementation
# NOTE: everything between the triple quotes is a plain string literal, so none
# of it is executed. `concatenate_datasets`, `indices`, `task1_dataset`,
# `task2_dataset` and `train_model` are placeholders that the cells shown here
# do not import or define.
'''
# Create a combined dataset with samples from both tasks
combined_dataset = concatenate_datasets([task1_dataset.select(indices), task2_dataset])

# Train on the combined dataset
model = train_model(model, combined_dataset)
'''

### 6.3 Knowledge Distillation

Knowledge distillation involves using the outputs of the original model as soft targets when training on the new task.

In [None]:
# Pseudocode for knowledge distillation implementation
# NOTE: everything between the triple quotes is a plain string literal, so none
# of it is executed. `task1_model`, `task2_inputs`, `task2_loss`, `alpha` and
# `model_outputs` are placeholders that the cells shown here do not define.
# The sketch softens both logit sets with the same temperature and scales the
# KL term by temperature squared.
'''
# Get predictions from the original model on Task 2 data
with torch.no_grad():
    task1_model_outputs = task1_model(task2_inputs)

# Define distillation loss
def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
    soft_predictions = F.log_softmax(student_logits / temperature, dim=-1)
    return F.kl_div(soft_predictions, soft_targets, reduction="batchmean") * (temperature ** 2)

# Combine task loss and distillation loss
total_loss = task2_loss + alpha * distillation_loss(model_outputs, task1_model_outputs)
'''

## Conclusion

In this notebook, we've explored catastrophic forgetting in large language models and techniques to mitigate it:

1. We demonstrated catastrophic forgetting by fine-tuning a model on two sequential tasks and observing the performance drop on the first task.

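2. We showed how Parameter-Efficient Fine-Tuning (PEFT) with LoRA can mitigate catastrophic forgetting: the pretrained weights stay frozen and each task trains its own small set of adapter weights, so learning a new task leaves the adapter of an earlier task untouched.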

3. We discussed other techniques like Elastic Weight Consolidation (EWC), replay methods, and knowledge distillation.

These techniques are essential for developing models that can learn multiple tasks sequentially without forgetting previously learned knowledge. By carefully selecting the right approach for your specific use case, you can create more versatile and robust language models.