# Preventing Forgetting in Continual Learning

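This notebook is a template for comparing custom architectures and loss functions that mitigate catastrophic forgetting in continual learning scenarios. The baseline experiment runs end to end; the knowledge distillation, EWC, and LwF experiments are stubs to be filled in.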

## Overview

Continual learning (also known as lifelong learning or incremental learning) aims to enable machine learning models to learn from a continuous stream of data, acquiring new knowledge while retaining previously learned information.

**Key Challenges:**
- **Catastrophic Forgetting**: Neural networks tend to forget previously learned tasks when trained on new tasks
- **Stability-Plasticity Dilemma**: Balancing the ability to learn new information while preserving old knowledge

**Approaches Covered:**
1. Custom CNN and ResNet architectures
2. Knowledge Distillation Loss
3. Elastic Weight Consolidation (EWC)
4. Learning without Forgetting (LwF)
5. iCaRL (Incremental Classifier and Representation Learning), listed for reference only; no iCaRL code is imported in this notebook

## 1. Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Import custom modules
from models import CustomCNN, CustomResNet, create_resnet18
from losses import KnowledgeDistillationLoss, ElasticWeightConsolidationLoss, LwFLoss, compute_fisher_information
from experiments import ContinualLearningBenchmark, create_task_split

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Load and Prepare Dataset

We'll use CIFAR-10 as our benchmark dataset, split into multiple tasks.

In [None]:
# Data transformations
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Load CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train
)

test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Classes: {train_dataset.classes}")

## 3. Create Task Splits

Split the dataset into sequential tasks. With CIFAR-10's 10 classes and `num_tasks = 5`, each task covers 2 classes.

In [None]:
# Configuration
num_tasks = 5
batch_size = 128

# Create task splits
train_task_datasets = create_task_split(train_dataset, num_tasks, task_type='class_incremental')
test_task_datasets = create_task_split(test_dataset, num_tasks, task_type='class_incremental')

# Create data loaders for each task
train_loaders = [
    DataLoader(task_data, batch_size=batch_size, shuffle=True, num_workers=2)
    for task_data in train_task_datasets
]

test_loaders = [
    DataLoader(task_data, batch_size=batch_size, shuffle=False, num_workers=2)
    for task_data in test_task_datasets
]

print(f"Number of tasks: {num_tasks}")
for i, loader in enumerate(train_loaders):
    print(f"Task {i}: {len(loader.dataset)} training samples")
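
`create_task_split` comes from the local `experiments` module, whose implementation is not shown in this notebook. As a rough sketch of what a class-incremental split might do, the hypothetical helper below groups sample indices by label and assigns contiguous class ranges to tasks. The function name and the contiguous-class assumption are illustrative only and may differ from the actual module.

```python
from collections import defaultdict

def split_indices_by_class(labels, num_tasks):
    """Group sample indices into tasks of contiguous class ranges.

    Hypothetical helper: assumes hashable, sortable labels and that the
    number of classes divides evenly by num_tasks.
    """
    classes = sorted(set(labels))
    if len(classes) % num_tasks != 0:
        raise ValueError("number of classes must be divisible by num_tasks")
    classes_per_task = len(classes) // num_tasks

    # Map each class to the task that owns it.
    class_to_task = {c: i // classes_per_task for i, c in enumerate(classes)}

    task_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        task_indices[class_to_task[label]].append(idx)
    return [task_indices[t] for t in range(num_tasks)]

# Toy example: 10 classes, 3 samples each, split into 5 tasks.
toy_labels = [c for c in range(10) for _ in range(3)]
toy_tasks = split_indices_by_class(toy_labels, num_tasks=5)
print([len(t) for t in toy_tasks])
```

For CIFAR-10, the labels would be `train_dataset.targets`, and each index list could be wrapped in `torch.utils.data.Subset(train_dataset, indices)` to obtain a per-task dataset.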

## 4. Baseline: Training without Continual Learning

First, let's establish a baseline by training sequentially without any anti-forgetting techniques.

In [None]:
# Create model
baseline_model = CustomCNN(num_classes=10, input_channels=3)

# Create benchmark
baseline_benchmark = ContinualLearningBenchmark(
    model=baseline_model,
    device=device,
    num_tasks=num_tasks
)

# Run benchmark
print("\n" + "="*60)
print("BASELINE: Training without continual learning techniques")
print("="*60)

baseline_results = baseline_benchmark.run_benchmark(
    train_loaders=train_loaders,
    test_loaders=test_loaders,
    epochs_per_task=10,
    learning_rate=0.001
)

# Print performance matrix
baseline_benchmark.print_performance_matrix()
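
`ContinualLearningBenchmark.run_benchmark` is defined in the local `experiments` module and is not shown here. The sketch below illustrates, under that caveat, how a naive sequential baseline could fill a performance matrix: train on one task at a time, then evaluate on every task. The names `evaluate`, `run_sequential`, and `make_task` are hypothetical, and the synthetic data only stands in for the CIFAR-10 task loaders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader):
    """Return the accuracy (in percent) of model on loader."""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            correct += (model(x).argmax(dim=1) == y).sum().item()
            total += y.size(0)
    return 100.0 * correct / total

def run_sequential(model, train_loaders, test_loaders, epochs_per_task=5, lr=0.1):
    """Train on each task in turn; row i holds accuracies after finishing task i."""
    num_tasks = len(train_loaders)
    matrix = torch.zeros(num_tasks, num_tasks)
    criterion = nn.CrossEntropyLoss()
    for i, train_loader in enumerate(train_loaders):
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        model.train()
        for _ in range(epochs_per_task):
            for x, y in train_loader:
                optimizer.zero_grad()
                criterion(model(x), y).backward()
                optimizer.step()
        for j, test_loader in enumerate(test_loaders):
            matrix[i, j] = evaluate(model, test_loader)
    return matrix

def make_task(classes, n=64):
    """Synthetic 4-dimensional data whose mean is shifted along the class axis."""
    y = torch.tensor(classes).repeat_interleave(n)
    x = torch.randn(len(y), 4) + 3.0 * F.one_hot(y, 4).float()
    return TensorDataset(x, y)

torch.manual_seed(0)
# Two class-incremental tasks: classes {0, 1}, then classes {2, 3}.
tasks = [make_task([0, 1]), make_task([2, 3])]
# For brevity the same loaders are reused for training and evaluation.
loaders = [DataLoader(t, batch_size=32, shuffle=True) for t in tasks]
matrix = run_sequential(nn.Linear(4, 4), loaders, loaders)
print(matrix)
```

Comparing `matrix[0, 0]` with `matrix[1, 0]` shows how much accuracy on the first task changes once the second task has been learned.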

## 5. Experiment 1: Custom ResNet with Knowledge Distillation

Train a custom ResNet-18 on each new task while distilling from a frozen copy of the model saved before that task, so that its predictions on earlier classes are preserved.

In [None]:
# TODO: Implement knowledge distillation experiment
# This is a template - customize based on your needs

print("\n" + "="*60)
print("EXPERIMENT 1: ResNet with Knowledge Distillation")
print("="*60)

# Create model
resnet_model = create_resnet18(num_classes=10, input_channels=3)

# TODO: Implement training loop with knowledge distillation
# Hint: Store old model, use KnowledgeDistillationLoss

print("Implementation pending - add your custom code here")
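
The signature of `KnowledgeDistillationLoss` in the local `losses` module is not shown, so the following is only a sketch of the standard distillation term it is assumed to resemble: the KL divergence between temperature-softened teacher and student distributions, scaled by the squared temperature. The function name `distillation_loss` and the default `temperature` are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL divergence between temperature-softened teacher and student outputs."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)
    # Scale by T^2 so gradient magnitudes stay comparable across temperatures.
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2

torch.manual_seed(0)
teacher_logits = torch.randn(4, 10)                      # frozen old model outputs
student_logits = torch.randn(4, 10, requires_grad=True)  # current model outputs

kd_same = distillation_loss(teacher_logits, teacher_logits)
kd_diff = distillation_loss(student_logits, teacher_logits)
kd_diff.backward()
print(f"identical outputs: {kd_same.item():.4f}, different outputs: {kd_diff.item():.4f}")
```

In a continual learning loop, `teacher_logits` would typically come from a frozen copy of the model taken before the new task (for example via `copy.deepcopy(model).eval()`), and this term would be added to the cross-entropy loss on the new task's labels.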

## 6. Experiment 2: Elastic Weight Consolidation (EWC)

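EWC adds a quadratic penalty that discourages changes to parameters with high Fisher information on earlier tasks, while leaving less important parameters free to adapt.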

In [None]:
# TODO: Implement EWC experiment

print("\n" + "="*60)
print("EXPERIMENT 2: Elastic Weight Consolidation")
print("="*60)

# Create model
ewc_model = CustomCNN(num_classes=10, input_channels=3)

# TODO: Implement EWC training
# Hint: Compute Fisher information after each task
# Use ElasticWeightConsolidationLoss

print("Implementation pending - add your custom code here")
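
`compute_fisher_information` and `ElasticWeightConsolidationLoss` live in the local `losses` module and their interfaces are not shown. The sketch below is a self-contained illustration of the two pieces they are assumed to cover: a diagonal empirical Fisher estimate (using ground-truth labels) and the quadratic EWC penalty. The names `diagonal_fisher`, `ewc_penalty`, and `ewc_lambda` are hypothetical, and a small linear layer stands in for `CustomCNN`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def diagonal_fisher(model, inputs, targets):
    """Estimate the diagonal empirical Fisher from per-sample squared gradients."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.eval()
    for x, y in zip(inputs, targets):
        model.zero_grad()
        loss = F.cross_entropy(model(x.unsqueeze(0)), y.unsqueeze(0))
        loss.backward()
        for n, p in model.named_parameters():
            fisher[n] += p.grad.detach() ** 2 / len(inputs)
    return fisher

def ewc_penalty(model, fisher, old_params, ewc_lambda=100.0):
    """Quadratic penalty (lambda / 2) * sum_i F_i * (theta_i - theta_old_i)^2."""
    penalty = sum(
        (fisher[n] * (p - old_params[n]) ** 2).sum()
        for n, p in model.named_parameters()
    )
    return 0.5 * ewc_lambda * penalty

torch.manual_seed(0)
toy_model = nn.Linear(8, 3)  # stand-in for CustomCNN
toy_x, toy_y = torch.randn(16, 8), torch.randint(0, 3, (16,))

# After finishing a task: estimate importance and snapshot the weights.
fisher = diagonal_fisher(toy_model, toy_x, toy_y)
old_params = {n: p.detach().clone() for n, p in toy_model.named_parameters()}

penalty_before = ewc_penalty(toy_model, fisher, old_params)
with torch.no_grad():
    for p in toy_model.parameters():
        p.add_(0.1)  # simulate parameter drift while learning a new task
penalty_after = ewc_penalty(toy_model, fisher, old_params)
print(f"penalty at old weights: {penalty_before.item():.4f}, after drift: {penalty_after.item():.4f}")
```

During training on a new task, the penalty would be added to the task loss before calling `backward()`, with `ewc_lambda` controlling the stability-plasticity trade-off.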

## 7. Experiment 3: Learning without Forgetting (LwF)

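LwF records the old model's outputs on the new task's data and distills them into the updated model, so performance on old tasks is maintained without storing any old-task data.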

In [None]:
# TODO: Implement LwF experiment

print("\n" + "="*60)
print("EXPERIMENT 3: Learning without Forgetting")
print("="*60)

# Create model
lwf_model = create_resnet18(num_classes=10, input_channels=3)

# TODO: Implement LwF training
# Hint: Use LwFLoss with old task indices

print("Implementation pending - add your custom code here")
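
The interface of `LwFLoss` is not shown in this notebook, so the following is a sketch of one common formulation rather than the module's actual behavior: cross-entropy on the current task's labels plus a distillation term restricted to the old-class outputs. The original paper uses a modified cross-entropy for the distillation term; a KL-based version is used here for simplicity. The names `lwf_loss`, `old_class_indices`, and `distill_weight` are assumptions.

```python
import torch
import torch.nn.functional as F

def lwf_loss(new_logits, old_logits, targets, old_class_indices,
             temperature=2.0, distill_weight=1.0):
    """Cross-entropy on the current labels plus distillation on old-class outputs."""
    ce = F.cross_entropy(new_logits, targets)
    # Restrict both models' outputs to the classes learned in earlier tasks.
    log_p_new = F.log_softmax(new_logits[:, old_class_indices] / temperature, dim=1)
    p_old = F.softmax(old_logits[:, old_class_indices] / temperature, dim=1)
    distill = F.kl_div(log_p_new, p_old, reduction="batchmean") * temperature ** 2
    return ce + distill_weight * distill

torch.manual_seed(0)
old_class_indices = [0, 1]                 # classes from task 0
targets = torch.tensor([2, 3, 2, 3])       # labels from task 1
old_logits = torch.randn(4, 10)            # outputs of the frozen pre-task-1 model
new_logits = torch.randn(4, 10, requires_grad=True)

loss_with_distill = lwf_loss(new_logits, old_logits, targets, old_class_indices)
loss_ce_only = lwf_loss(new_logits, old_logits, targets, old_class_indices,
                        distill_weight=0.0)
loss_with_distill.backward()
print(f"CE only: {loss_ce_only.item():.4f}, CE + distillation: {loss_with_distill.item():.4f}")
```

Note that `old_logits` are computed on the new task's inputs, which is what lets this approach avoid keeping old-task samples.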

## 8. Visualization and Comparison

Compare the performance of different methods.

In [None]:
# TODO: Add visualization code

# Example: Plot performance matrix as heatmap
def plot_performance_matrix(performance_matrix, title):
    # Row i holds accuracies measured after training on task i; column j is the task evaluated.
    # Derive the task count from the matrix so the function does not rely on the global num_tasks.
    n_tasks = len(performance_matrix)
    plt.figure(figsize=(10, 8))
    sns.heatmap(performance_matrix, annot=True, fmt='.1f', cmap='YlOrRd',
                xticklabels=[f'Task {i}' for i in range(n_tasks)],
                yticklabels=[f'After Task {i}' for i in range(n_tasks)])
    plt.title(title)
    plt.xlabel('Task Evaluated')
    plt.ylabel('Training Stage')
    plt.tight_layout()
    plt.show()

# Plot baseline results
plot_performance_matrix(baseline_results['performance_matrix'],
                        'Baseline Performance Matrix')

## 9. Metrics Summary

Compare metrics across different methods.

In [None]:
# TODO: Create comparison table/plot

import pandas as pd

# Example metrics comparison
metrics_data = {
    'Method': ['Baseline'],
    'Average Accuracy': [baseline_results['final_metrics']['average_accuracy']],
    'Forgetting': [baseline_results['final_metrics']['forgetting']],
    'Backward Transfer': [baseline_results['final_metrics']['backward_transfer']]
}

# TODO: Add other methods' metrics

metrics_df = pd.DataFrame(metrics_data)
print("\nMetrics Comparison:")
print(metrics_df.to_string(index=False))
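
The `average_accuracy`, `forgetting`, and `backward_transfer` values are produced by the benchmark class, and how it computes them is not shown here. As a point of comparison, the sketch below derives the three metrics from a performance matrix using definitions commonly found in the continual learning literature; the helper name `continual_metrics` and the toy matrix are illustrative assumptions.

```python
import numpy as np

def continual_metrics(acc):
    """Compute summary metrics from acc[i, j] = accuracy on task j after training on task i."""
    acc = np.asarray(acc, dtype=float)
    last = acc.shape[0] - 1
    # Average accuracy: mean over all tasks after the final training stage.
    average_accuracy = acc[last].mean()
    # Forgetting: best earlier accuracy on each old task minus its final accuracy.
    forgetting = np.mean([acc[:last, j].max() - acc[last, j] for j in range(last)])
    # Backward transfer: final accuracy minus accuracy right after learning the task.
    backward_transfer = np.mean([acc[last, j] - acc[j, j] for j in range(last)])
    return {
        "average_accuracy": float(average_accuracy),
        "forgetting": float(forgetting),
        "backward_transfer": float(backward_transfer),
    }

# Toy 3-task matrix (made-up numbers, for illustration only).
toy_matrix = [
    [90.0,  0.0,  0.0],
    [60.0, 88.0,  0.0],
    [40.0, 70.0, 85.0],
]
toy_metrics = continual_metrics(toy_matrix)
print(toy_metrics)
# average_accuracy = 65.0, forgetting = 34.0, backward_transfer = -34.0
```

A negative backward transfer means that learning later tasks reduced accuracy on earlier ones, which is the signature of forgetting in the baseline.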

## 10. Conclusions and Next Steps

### Key Findings
- TODO: Add your observations
- Compare baseline vs. continual learning methods
- Analyze forgetting patterns

### Future Work
1. Experiment with different architectures
2. Try other continual learning methods (e.g., PackNet, Progressive Neural Networks)
3. Test on different datasets (ImageNet, MNIST, etc.)
4. Optimize hyperparameters
5. Implement memory replay strategies
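
As a starting point for the memory replay item above, here is a minimal sketch of a fixed-size replay buffer filled by reservoir sampling, which keeps a uniform random subset of everything seen so far. The class name `ReservoirBuffer` and its methods are hypothetical and not part of this project's modules.

```python
import random

class ReservoirBuffer:
    """Fixed-size replay memory filled by reservoir sampling."""

    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self.samples = []
        self.num_seen = 0
        self._rng = random.Random(seed)

    def add(self, sample):
        """Offer one sample; it is kept with probability capacity / num_seen."""
        self.num_seen += 1
        if len(self.samples) < self.capacity:
            self.samples.append(sample)
        else:
            slot = self._rng.randrange(self.num_seen)
            if slot < self.capacity:
                self.samples[slot] = sample

    def sample(self, batch_size):
        """Draw a replay batch without replacement."""
        return self._rng.sample(self.samples, min(batch_size, len(self.samples)))

buffer = ReservoirBuffer(capacity=20)
for task_id in range(5):
    for i in range(100):
        buffer.add((task_id, i))  # stand-in for an (image, label) pair

replay_batch = buffer.sample(8)
print(len(buffer.samples), len(replay_batch))
```

During training on a new task, each mini-batch could be concatenated with a batch drawn from the buffer so that earlier tasks keep contributing to the loss.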

### References
- [Learning without Forgetting](https://arxiv.org/abs/1606.09282)
- [Overcoming Catastrophic Forgetting in Neural Networks (EWC)](https://arxiv.org/abs/1612.00796)
- [iCaRL: Incremental Classifier and Representation Learning](https://arxiv.org/abs/1611.07725)
- [A Continual Learning Survey: Defying Forgetting in Classification Tasks](https://arxiv.org/abs/1909.08383)

## Appendix: Model Architecture Details

In [None]:
# Print model architectures
# Fresh, untrained instances are used here under separate names so that
# models created in the experiment cells (e.g. resnet_model) are not overwritten.
print("Custom CNN Architecture:")
print("="*60)
cnn_arch = CustomCNN(num_classes=10)
print(cnn_arch)

print("\n" + "="*60)
print("Custom ResNet-18 Architecture:")
print("="*60)
resnet_arch = create_resnet18(num_classes=10)
print(resnet_arch)

# Count trainable parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nCustom CNN parameters: {count_parameters(cnn_arch):,}")
print(f"ResNet-18 parameters: {count_parameters(resnet_arch):,}")