In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import time
from collections import defaultdict

In [2]:
# Define Permuted MNIST Dataset
class PermutedMNIST(Dataset):
    """MNIST split whose images all have their pixels shuffled by one fixed permutation.

    Args:
        root: directory handed to ``torchvision.datasets.MNIST`` (downloaded if missing).
        train: load the training split when True, the test split otherwise.
        transform: optional callable applied to each image *after* the permutation.
        permutations: index tensor over the flattened image, or None to leave images as-is.
    """

    def __init__(self, root, train=True, transform=None, permutations=None):
        # ToTensor always runs inside MNIST so __getitem__ receives tensors it can re-index.
        to_tensor = transforms.ToTensor()
        self.mnist_dataset = torchvision.datasets.MNIST(root=root, train=train, transform=to_tensor, download=True)
        self.transform = transform
        self.permutations = permutations
        self.train = train

    def __len__(self):
        return len(self.mnist_dataset)

    def __getitem__(self, idx):
        image, label = self.mnist_dataset[idx]
        perm = self.permutations
        if perm is not None:
            original_shape = image.shape
            # Flatten, reorder pixels, then restore the original shape.
            image = image.view(-1)[perm].view(original_shape)
        if self.transform:
            image = self.transform(image)
        return image, label

# Setup Permuted MNIST Tasks
num_tasks = 5
input_size = 28 * 28  # Flattened MNIST image
# Seed the permutations so the task definitions are identical after every kernel restart.
# A dedicated generator keeps them independent of any other RNG use in the notebook.
PERMUTATION_SEED = 0
_permutation_rng = torch.Generator().manual_seed(PERMUTATION_SEED)
permutations = [torch.randperm(input_size, generator=_permutation_rng) for _ in range(num_tasks)]

# Load Permuted MNIST Datasets for each task (one train/test pair per permutation).
# NOTE(review): each PermutedMNIST loads its own MNIST copy; sharing one base dataset would save memory.
train_tasks = [PermutedMNIST(root="./data", train=True, permutations=perm) for perm in permutations]
test_tasks = [PermutedMNIST(root="./data", train=False, permutations=perm) for perm in permutations]

# Function to create DataLoaders for each task
def get_task_data(task_idx, batch_size=64):
    """Return ``(train_loader, test_loader)`` for task ``task_idx``.

    Reads the module-level ``train_tasks`` / ``test_tasks`` lists; only the training
    loader shuffles. New DataLoader objects are built on every call.
    """
    train_split = train_tasks[task_idx]
    test_split = test_tasks[task_idx]
    return (
        DataLoader(train_split, batch_size=batch_size, shuffle=True),
        DataLoader(test_split, batch_size=batch_size, shuffle=False),
    )

In [6]:
from skopt import gp_minimize
from skopt.space import Real
import numpy as np

def find_best_l2_lambda_task1_focus(
    model_class, input_size, hidden_size, output_size,
    num_tasks=5, epochs_per_task=1, n_calls=15, initial_trials=(0.001, 0.01, 0.1, 1.0),
    n_initial_points=None, seed=42,
):
    """
    Tune the L2 penalty strength for Task 1 retention after sequential training.

    Each candidate lambda trains a fresh ``model_class`` network on tasks 1..``num_tasks`` in
    order (via ``train_task``, snapshotting parameters after every task). It is scored by the
    test accuracy on Task 1 after the final task has been trained.

    The manual "anchor" lambdas in ``initial_trials`` are evaluated first. They are then handed
    to ``gp_minimize`` as already-observed points (``x0``/``y0``), so the Gaussian process is
    fitted on them rather than ignoring them. The returned lambda is the best over anchors and
    Bayesian proposals combined.

    Args:
        model_class: called as ``model_class(input_size, hidden_size, output_size)``.
        input_size, hidden_size, output_size: forwarded to ``model_class``.
        num_tasks: number of tasks trained sequentially in each trial.
        epochs_per_task: epochs per task, forwarded to ``train_task``.
        n_calls: objective evaluations made by ``gp_minimize`` in addition to the anchors.
        initial_trials: anchor lambdas; each must lie in the search range [1e-4, 10].
        n_initial_points: random (non-GP) proposals made before the GP drives the search.
            Defaults to ``max(1, 10 - len(initial_trials))``, because the anchors already act as
            initial observations. It is capped at ``n_calls``.
        seed: if not None, every trial runs under ``torch.manual_seed(seed)`` inside
            ``torch.random.fork_rng()``. All lambdas then share the same weight init and data
            order, so score differences come from lambda rather than RNG noise, and the global
            RNG state is left untouched. ``None`` reproduces the old unseeded behaviour.

    Returns:
        float: the lambda with the highest final Task 1 accuracy.

    Raises:
        ValueError: if an anchor lambda lies outside the search range.
    """
    import torch
    import torch.nn as nn
    import torch.optim as optim

    lambda_low, lambda_high = 0.0001, 10.0
    # Log-uniform search space (note: different range from the EWC tuner).
    search_space = [Real(lambda_low, lambda_high, "log-uniform", name="lambda")]

    anchor_lambdas = [float(value) for value in initial_trials]
    out_of_range = [v for v in anchor_lambdas if not lambda_low <= v <= lambda_high]
    if out_of_range:
        # gp_minimize rejects x0 points outside the space; fail before any expensive training.
        raise ValueError(
            f"Anchor lambdas {out_of_range} are outside the search range "
            f"[{lambda_low}, {lambda_high}]"
        )

    def task1_accuracy_after_all_tasks(l2_lambda):
        """Train a fresh model on every task with ``l2_lambda``; return final Task 1 accuracy (%)."""
        model = model_class(input_size, hidden_size, output_size)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.01)
        param_snapshots = []  # parameters after each finished task, used as L2 anchors
        for task_idx in range(num_tasks):
            train_task(
                model,
                task_idx,
                criterion,
                optimizer,
                initial_params_list=param_snapshots if task_idx > 0 else None,
                l2_lambda=l2_lambda,
                epochs=epochs_per_task,
            )
            param_snapshots.append(store_initial_params(model))
        print("\nEvaluating on all tasks:")
        return evaluate_all_tasks(model, num_tasks)[0]

    def objective_function(params):
        """gp_minimize objective: negative final Task 1 accuracy for lambda ``params[0]``."""
        current_lambda = float(params[0])
        print(f"\nTrying L2 lambda = {current_lambda:.6f}")
        if seed is None:
            task1_performance = task1_accuracy_after_all_tasks(current_lambda)
        else:
            with torch.random.fork_rng():
                torch.manual_seed(seed)
                task1_performance = task1_accuracy_after_all_tasks(current_lambda)
        print(f"Lambda {current_lambda:.6f} | Task 1 Final Accuracy: {task1_performance:.2f}%")
        # gp_minimize minimizes, so return the negated accuracy.
        return -task1_performance

    # Evaluate the manual anchors up front; their scores seed gp_minimize below.
    anchor_scores = []
    for initial_lambda in anchor_lambdas:
        print(f"\n🔍 Testing anchor lambda = {initial_lambda}")
        anchor_scores.append(objective_function([initial_lambda]))

    if n_initial_points is None:
        n_initial_points = max(1, 10 - len(anchor_lambdas))
    # gp_minimize requires n_calls >= n_initial_points when y0 is supplied.
    n_initial_points = min(n_initial_points, n_calls)

    result = gp_minimize(
        objective_function,
        search_space,
        n_calls=n_calls,
        n_initial_points=n_initial_points,
        x0=[[value] for value in anchor_lambdas] or None,
        y0=anchor_scores or None,
        random_state=42,
        verbose=True,
    )

    # `result` covers anchors and proposals. Points are stored in the order they were told,
    # so indices below len(anchor_lambdas) are anchors.
    best_lambda = float(result.x[0])
    best_task1_performance = -result.fun
    if int(np.argmin(result.func_vals)) < len(anchor_lambdas):
        print("\n⚠️ Bayesian optimization did not beat the best manual anchor; using that anchor.")

    print(f"\n✅ Best L2 lambda found: {best_lambda:.6f} with Task 1 final accuracy: {best_task1_performance:.2f}%")
    return best_lambda


# Tune the L2 lambda for Task 1 retention.
# NOTE(review): this cell calls SimpleNN, train_task, store_initial_params and
# evaluate_all_tasks, which are defined in the later cell In [11]. On a fresh kernel
# "Run All" raises NameError here; move that cell above this one.
input_size, hidden_size, output_size = 28 * 28, 256, 10

optimal_l2_lambda = find_best_l2_lambda_task1_focus(
    SimpleNN, input_size, hidden_size, output_size,
    num_tasks=5,
    epochs_per_task=1,
    n_calls=10,
    initial_trials=[0.001, 0.01, 0.05, 0.1, 1.0],
)

print(f"Optimal L2 lambda value for Task 1 performance: {optimal_l2_lambda}")


🔍 Testing anchor lambda = 0.001

Trying L2 lambda = 0.001000
Task 1, Epoch 1/1, Loss: 1.2138, Accuracy: 76.25%
Task 2, Epoch 1/1, Loss: 0.8167, Accuracy: 80.67%
Task 3, Epoch 1/1, Loss: 0.7203, Accuracy: 81.53%
Task 4, Epoch 1/1, Loss: 0.6726, Accuracy: 82.43%
Task 5, Epoch 1/1, Loss: 0.6704, Accuracy: 82.16%

Evaluating on all tasks:
Task 1 Accuracy: 78.15%
Task 2 Accuracy: 81.90%
Task 3 Accuracy: 82.32%
Task 4 Accuracy: 86.70%
Task 5 Accuracy: 89.14%
Lambda 0.001000 | Task 1 Final Accuracy: 78.15%

🔍 Testing anchor lambda = 0.01

Trying L2 lambda = 0.010000
Task 1, Epoch 1/1, Loss: 1.2275, Accuracy: 74.39%
Task 2, Epoch 1/1, Loss: 0.8334, Accuracy: 81.39%
Task 3, Epoch 1/1, Loss: 0.8145, Accuracy: 81.33%
Task 4, Epoch 1/1, Loss: 0.8515, Accuracy: 82.31%
Task 5, Epoch 1/1, Loss: 0.9242, Accuracy: 81.95%

Evaluating on all tasks:
Task 1 Accuracy: 83.78%
Task 2 Accuracy: 82.38%
Task 3 Accuracy: 81.99%
Task 4 Accuracy: 83.26%
Task 5 Accuracy: 88.41%
Lambda 0.010000 | Task 1 Final Accura

In [11]:
# Define a simple neural network
class SimpleNN(nn.Module):
    """Two-layer MLP: flatten -> Linear(input, hidden) -> ReLU -> Linear(hidden, output).

    ``forward`` returns raw logits (no softmax); the notebook pairs it with ``nn.CrossEntropyLoss``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Registration order (flatten, fc1, relu, fc2) fixes parameter order and state_dict keys.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden = self.relu(self.fc1(self.flatten(x)))
        return self.fc2(hidden)

# Function to store initial parameters for L2 regularization
def store_initial_params(model):
    """Snapshot every named parameter of ``model``.

    Returns a dict mapping parameter name to a cloned, grad-free copy of its current values,
    so later updates to ``model`` do not change the snapshot.
    """
    return {name: param.data.clone() for name, param in model.named_parameters()}

# Function to train the model on a specific task with L2 regularization
def train_task(model, task_idx, criterion, optimizer, initial_params_list=None, l2_lambda=0.01, epochs=5):
    """Train ``model`` on task ``task_idx``, optionally with an L2 pull toward stored snapshots.

    Per batch the loss is ``criterion(outputs, labels)``. When ``task_idx > 0`` and
    ``initial_params_list`` is non-empty, ``(l2_lambda / 2) * sum ||param - snapshot[name]||^2``
    is added, summed over every named parameter and every snapshot. Snapshots are dicts keyed by
    parameter name, as built by ``store_initial_params``.

    Args:
        model: network to train in place.
        task_idx: 0-based task index passed to ``get_task_data``.
        criterion: loss function applied to ``(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        initial_params_list: list of parameter snapshots to stay close to, or None.
        l2_lambda: strength of the L2 pull.
        epochs: number of passes over the training loader.

    Returns:
        (losses, accuracies): per-epoch lists. The loss includes the L2 penalty term.
        Accuracy is measured on the training batches during the epoch, in percent.
    """
    train_loader, _ = get_task_data(task_idx)
    apply_penalty = bool(initial_params_list) and task_idx > 0

    epoch_losses = []
    epoch_accuracies = []

    for epoch in range(epochs):
        model.train()
        loss_total = 0.0
        n_correct = 0
        n_seen = 0

        for inputs, labels in train_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if apply_penalty:
                # Squared distance to each stored snapshot, parameter by parameter.
                penalty = sum(
                    (param - snapshot[name]).pow(2).sum()
                    for name, param in model.named_parameters()
                    for snapshot in initial_params_list
                )
                loss = loss + (l2_lambda / 2) * penalty

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_total += loss.item()

            predicted = torch.max(outputs.data, 1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

        epoch_loss = loss_total / len(train_loader)
        epoch_acc = 100 * n_correct / n_seen
        epoch_losses.append(epoch_loss)
        epoch_accuracies.append(epoch_acc)

        print(f'Task {task_idx+1}, Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')

    return epoch_losses, epoch_accuracies

# Function to evaluate the model on all seen tasks
def evaluate_all_tasks(model, num_tasks):
    """Measure test accuracy (%) of ``model`` on tasks ``0 .. num_tasks-1``.

    Prints one line per task and returns the accuracies in task order. Gradients are
    disabled during evaluation, and the model is left in eval mode (when ``num_tasks > 0``).
    """
    accuracies = []

    for task_number in range(1, num_tasks + 1):
        _, test_loader = get_task_data(task_number - 1)

        model.eval()
        hits = 0
        seen = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                predicted = torch.max(model(inputs).data, 1).indices
                seen += labels.size(0)
                hits += (predicted == labels).sum().item()

        accuracy = 100 * hits / seen
        accuracies.append(accuracy)
        print(f'Task {task_number} Accuracy: {accuracy:.2f}%')

    return accuracies

# Function to calculate forgetting metrics
def calculate_forgetting_metrics(training_history, initial_accuracies):
    """Per-task forgetting: accuracy right after learning a task minus its later accuracies.

    Index convention: ``training_history["task_accuracies"][k]`` must hold the accuracies
    measured after training 0-based task ``k + 1``. The evaluation right after the first task
    is not recorded, which is how ``demonstrate_l2_regularization`` fills it. Under that
    convention, entry ``k`` is a *later* evaluation of task ``t`` exactly when ``k >= t``.

    Args:
        training_history: dict with a ``"task_accuracies"`` list of per-evaluation accuracy lists.
        initial_accuracies: accuracy on each task right after it was learned, in task order.

    Returns:
        dict mapping ``"Task N"`` (1-based) to a list of accuracy drops, one per later evaluation.
        The last task is omitted because nothing is trained after it.
    """
    return {
        f"Task {t + 1}": [
            initial_accuracies[t] - accs[t]
            for k, accs in enumerate(training_history["task_accuracies"])
            if k >= t
        ]
        for t in range(len(initial_accuracies) - 1)
    }

# Main function to demonstrate L2 regularization for mitigating catastrophic forgetting
def demonstrate_l2_regularization(l2_lambda=0.01, learning_rate=0.01, epochs_per_task=5, hidden_size=256):
    """Train SimpleNN sequentially on every task in ``train_tasks`` with an explicit L2 penalty.

    After each task the current parameters are snapshotted with ``store_initial_params``. The
    snapshots serve as L2 anchors for all later tasks, and after each task the model is
    evaluated on every task seen so far.

    Args:
        l2_lambda: L2 penalty strength passed to ``train_task``. Exposing it lets a tuned value
            (e.g. ``optimal_l2_lambda``) replace the previous hard-coded 0.01.
        learning_rate: SGD learning rate.
        epochs_per_task: training epochs on each task.
        hidden_size: width of SimpleNN's hidden layer.

    Returns:
        (training_history, forgetting_rate, initial_accuracies):
          - training_history["task_accuracies"]: accuracies on all seen tasks after training
            task 2, 3, ...; the evaluation after task 1 is not stored
            (``calculate_forgetting_metrics`` relies on this offset).
          - training_history["training_time"]: seconds spent in ``train_task`` per task.
          - training_history["learning_curves"]: per-epoch train loss/accuracy across all tasks.
          - forgetting_rate: result of ``calculate_forgetting_metrics``.
          - initial_accuracies: test accuracy on each task right after learning it.
    """
    input_size = 28 * 28  # Flattened MNIST image
    output_size = 10      # 10 classes for Permuted MNIST

    model = SimpleNN(input_size, hidden_size, output_size)
    criterion = nn.CrossEntropyLoss()
    # Plain SGD without weight_decay: the L2 term is added explicitly inside train_task.
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    print("Model Configuration:")
    print(f"- SGD with L2 Regularization (lambda={l2_lambda})")
    print(f"- Learning Rate: {learning_rate}")
    print(f"- Hidden Size: {hidden_size}")
    print(f"- Epochs per Task: {epochs_per_task}")

    training_history = {
        "task_accuracies": [],  # Accuracies on seen tasks after training task 2, 3, ...
        "training_time": [],    # Seconds spent training each task
        "learning_curves": {    # Per-epoch loss and accuracy during training
            "loss": [],
            "accuracy": []
        }
    }
    initial_accuracies = []   # Accuracy on each task right after learning it
    initial_params_list = []  # Parameter snapshots taken after each finished task

    for task_idx in range(len(train_tasks)):
        print(f"\n{'='*50}")
        print(f"Training on Task {task_idx+1}")
        print(f"{'='*50}")

        # perf_counter is monotonic and meant for measuring intervals (time.time can jump).
        start_time = time.perf_counter()
        task_loss, task_acc = train_task(
            model,
            task_idx,
            criterion,
            optimizer,
            initial_params_list=initial_params_list if task_idx > 0 else None,
            l2_lambda=l2_lambda,
            epochs=epochs_per_task
        )
        training_history["training_time"].append(time.perf_counter() - start_time)

        training_history["learning_curves"]["loss"].extend(task_loss)
        training_history["learning_curves"]["accuracy"].extend(task_acc)

        initial_params_list.append(store_initial_params(model))

        print("\nEvaluating on all tasks seen so far:")
        task_accuracies = evaluate_all_tasks(model, task_idx + 1)

        # The post-task-1 evaluation is intentionally not stored (see Returns).
        if task_idx > 0:
            training_history["task_accuracies"].append(task_accuracies.copy())
        initial_accuracies.append(task_accuracies[task_idx])

    forgetting_rate = calculate_forgetting_metrics(training_history, initial_accuracies)

    return training_history, forgetting_rate, initial_accuracies

In [12]:
# Example usage: run the sequential-training demo with its default settings.
if __name__ == "__main__":
    (training_history,
     forgetting_rate,
     initial_accuracies) = demonstrate_l2_regularization()

Model Configuration:
- SGD with L2 Regularization (lambda=0.01)
- Learning Rate: 0.01
- Hidden Size: 256
- Epochs per Task: 5

Training on Task 1
Task 1, Epoch 1/5, Loss: 1.2179, Accuracy: 75.57%
Task 1, Epoch 2/5, Loss: 0.4913, Accuracy: 87.52%
Task 1, Epoch 3/5, Loss: 0.3914, Accuracy: 89.28%
Task 1, Epoch 4/5, Loss: 0.3502, Accuracy: 90.19%
Task 1, Epoch 5/5, Loss: 0.3253, Accuracy: 90.81%

Evaluating on all tasks seen so far:
Task 1 Accuracy: 91.59%

Training on Task 2
Task 2, Epoch 1/5, Loss: 0.7003, Accuracy: 82.60%
Task 2, Epoch 2/5, Loss: 0.4436, Accuracy: 89.43%
Task 2, Epoch 3/5, Loss: 0.4124, Accuracy: 90.39%
Task 2, Epoch 4/5, Loss: 0.3981, Accuracy: 91.02%
Task 2, Epoch 5/5, Loss: 0.3888, Accuracy: 91.41%

Evaluating on all tasks seen so far:
Task 1 Accuracy: 89.95%
Task 2 Accuracy: 92.12%

Training on Task 3
Task 3, Epoch 1/5, Loss: 0.7540, Accuracy: 81.88%
Task 3, Epoch 2/5, Loss: 0.5327, Accuracy: 89.46%
Task 3, Epoch 3/5, Loss: 0.5084, Accuracy: 90.50%
Task 3, Epoch 4/