In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import time
from collections import defaultdict

In [4]:
# Define Permuted MNIST Dataset
class PermutedMNIST(Dataset):
    def __init__(self, root, train=True, transform=None, permutations=None):
        self.mnist_dataset = torchvision.datasets.MNIST(root=root, train=train, transform=transforms.ToTensor(), download=True)
        self.transform = transform
        self.permutations = permutations
        self.train = train

    def __len__(self):
        return len(self.mnist_dataset)

    def __getitem__(self, idx):
        image, label = self.mnist_dataset[idx]
        if self.permutations is not None:
            image = image.view(-1)[self.permutations].view(image.shape)
        if self.transform:
            image = self.transform(image)
        return image, label

# Setup Permuted MNIST Tasks
num_tasks = 5
input_size = 28 * 28  # Flattened MNIST image
permutations = [torch.randperm(input_size) for _ in range(num_tasks)]

# Load Permuted MNIST Datasets for each task
train_tasks = [PermutedMNIST(root="./data", train=True, permutations=permutations[i]) for i in range(num_tasks)]
test_tasks = [PermutedMNIST(root="./data", train=False, permutations=permutations[i]) for i in range(num_tasks)]

# Function to create DataLoaders for each task
def get_task_data(task_idx, batch_size=64):
    train_loader = DataLoader(train_tasks[task_idx], batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_tasks[task_idx], batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [10]:
from skopt import gp_minimize
from skopt.space import Real
import numpy as np

def find_optimal_si_lambda_task1_focus(
    model_class, input_size, hidden_size, output_size,
    num_tasks=5, epochs_per_task=1, n_calls=10
):
    """
    Find an optimal lambda value for Synaptic Intelligence using Bayesian optimization,
    focused exclusively on maximizing Task 1 performance after Task 5.
    """
    import torch.nn as nn
    import torch.optim as optim

    # Define the search space (log scale)
    search_space = [Real(0.1, 100.0, "log-uniform", name="lambda")]

    # Objective function: Only cares about Task 1 performance at the very end
    def objective_function(params):
        current_lambda = params[0]
        print(f"\nTrying SI lambda = {current_lambda:.2f}")

        # Initialize a new model
        model = model_class(input_size, hidden_size, output_size)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.01)

        # Storage for SI parameters
        omega_sum = None
        initial_params = {}

        # Save initial model parameters
        for name, param in model.named_parameters():
            initial_params[name] = param.data.clone()

        # Train on each task sequentially
        for task_idx in range(num_tasks):
            # Train on current task using SI
            task_loss, task_acc, omega_curr = train_task(
                model,
                task_idx,
                criterion,
                optimizer,
                omega_sum=omega_sum,
                initial_params=initial_params if task_idx > 0 else None,
                si_lambda=current_lambda,
                epochs=epochs_per_task
            )

            # Update accumulated importance weights (omega_sum)
            if omega_sum is None:
                omega_sum = omega_curr
            else:
                for name in omega_sum:
                    omega_sum[name] += omega_curr[name]

            # Save new initial parameters for the next task
            for name, param in model.named_parameters():
                initial_params[name] = param.data.clone()

        # Evaluate on all tasks
        print("\nEvaluating on all tasks:")
        task_accuracies = evaluate_all_tasks(model, num_tasks)

        # Extract **only Task 1 accuracy** at the end
        task1_performance = task_accuracies[0]
        print(f"Lambda {current_lambda:.2f} | Task 1 Final Accuracy: {task1_performance:.2f}%")

        # Return **negative Task 1 performance** for minimization
        return -task1_performance

    # Run Bayesian optimization
    result = gp_minimize(
        objective_function,
        search_space,
        n_calls=n_calls,
        random_state=42,
        verbose=True
    )

    # Get the best lambda specifically for Task 1 performance
    best_lambda = result.x[0]
    best_performance = -result.fun

    print(f"\nBest SI lambda found: {best_lambda:.2f} with Task 1 final accuracy: {best_performance:.2f}%")
    return best_lambda


# Call the function to find the optimal SI lambda
input_size = 28 * 28
hidden_size = 256
output_size = 10

optimal_si_lambda = find_optimal_si_lambda_task1_focus(
    model_class=SimpleNN,
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
    num_tasks=5,
    epochs_per_task=1,  # Keep it efficient
    n_calls=10  # Minimum search effort
)

print(f"Optimal SI lambda value for Task 1 performance: {optimal_si_lambda}")

Iteration No: 1 started. Evaluating function at random point.

Trying SI lambda = 24.53
Task 1, Epoch 1/1, Loss: 1.1940, Accuracy: 75.44%
Task 2, Epoch 1/1, Loss: 0.7927, Accuracy: 81.73%
Task 3, Epoch 1/1, Loss: 0.6902, Accuracy: 82.47%
Task 4, Epoch 1/1, Loss: 0.6393, Accuracy: 82.72%
Task 5, Epoch 1/1, Loss: 0.6498, Accuracy: 81.56%

Evaluating on all tasks:
Task 1 Accuracy: 76.60%
Task 2 Accuracy: 84.84%
Task 3 Accuracy: 83.86%
Task 4 Accuracy: 87.05%
Task 5 Accuracy: 88.91%
Lambda 24.53 | Task 1 Final Accuracy: 76.60%
Iteration No: 1 ended. Evaluation done at random point.
Time taken: 164.0642
Function value obtained: -76.6000
Current minimum: -76.6000
Iteration No: 2 started. Evaluating function at random point.

Trying SI lambda = 0.36
Task 1, Epoch 1/1, Loss: 1.2074, Accuracy: 74.91%
Task 2, Epoch 1/1, Loss: 0.7836, Accuracy: 82.21%
Task 3, Epoch 1/1, Loss: 0.6922, Accuracy: 82.33%
Task 4, Epoch 1/1, Loss: 0.6410, Accuracy: 82.78%
Task 5, Epoch 1/1, Loss: 0.6426, Accuracy: 81.5

In [11]:
# Define a simple neural network
class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Function to initialize synaptic intelligence trackers
def initialize_si_trackers(model):
    # Initialize parameter importance (omega)
    omega = {}
    # Initialize parameter change accumulator (path integral)
    path_integral = {}
    # Initialize old parameter values
    old_params = {}
    
    for name, param in model.named_parameters():
        omega[name] = torch.zeros_like(param.data)
        path_integral[name] = torch.zeros_like(param.data)
        old_params[name] = param.data.clone()
    
    return omega, path_integral, old_params

# Function to update synaptic intelligence trackers
def update_si_trackers(model, path_integral, old_params):
    for name, param in model.named_parameters():
        # Compute parameter change
        delta = param.data - old_params[name]
        # Update path integral with parameter change * gradient
        if param.grad is not None:
            path_integral[name] += -param.grad * delta
        # Update old parameter values
        old_params[name] = param.data.clone()
    
    return path_integral, old_params

# Function to compute the synaptic intelligence omega (importance)
def compute_omega(model, path_integral, old_params, epsilon=0.1):
    omega_new = {}
    
    for name, param in model.named_parameters():
        # Compute parameter change
        delta = param.data - old_params[name]
        # Compute importance (omega) based on path integral and parameter change
        # Add epsilon to avoid division by zero
        delta_norm = torch.norm(delta)
        if delta_norm > 0:
            omega_new[name] = path_integral[name] / (delta.pow(2) + epsilon)
        else:
            omega_new[name] = torch.zeros_like(param.data)
    
    return omega_new

# Function to compute the SI penalty
def si_penalty(model, omega_sum, initial_params):
    penalty = 0
    for name, param in model.named_parameters():
        # Compute squared parameter change from old value
        delta = (param.data - initial_params[name]).pow(2)
        # Compute the penalty based on importance and parameter change
        penalty += (omega_sum[name] * delta).sum()
    
    return penalty

# Function to train the model on a specific task with Synaptic Intelligence
def train_task(model, task_idx, criterion, optimizer, omega_sum=None, initial_params=None, si_lambda=1.0, epochs=5):
    train_loader, _ = get_task_data(task_idx)
    
    # Initialize SI trackers for current task
    omega_curr, path_integral, old_params = initialize_si_trackers(model)
    
    # For collecting metrics
    task_train_loss = []
    task_train_acc = []
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for inputs, labels in train_loader:
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            # Add SI penalty if not the first task
            if omega_sum is not None and initial_params is not None:
                si_loss = si_penalty(model, omega_sum, initial_params)
                loss += si_lambda * si_loss
            
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Update SI trackers
            path_integral, old_params = update_si_trackers(model, path_integral, old_params)
            
            running_loss += loss.item()
            
            # Calculate accuracy
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * correct / total
        
        task_train_loss.append(epoch_loss)
        task_train_acc.append(epoch_acc)
        
        print(f'Task {task_idx+1}, Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')
    
    # Compute importance (omega) after training on this task
    omega_curr = compute_omega(model, path_integral, old_params)
    
    return task_train_loss, task_train_acc, omega_curr

# Function to evaluate the model on all seen tasks
def evaluate_all_tasks(model, num_tasks):
    accuracies = []
    
    for i in range(num_tasks):
        _, test_loader = get_task_data(i)
        
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for inputs, labels in test_loader:
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        accuracy = 100 * correct / total
        accuracies.append(accuracy)
        print(f'Task {i+1} Accuracy: {accuracy:.2f}%')
    
    return accuracies

# Function to calculate forgetting metrics
def calculate_forgetting_metrics(training_history, initial_accuracies):
    forgetting_rate = {}
    
    # For each task (except the last one since we don't have measurements after it)
    for task_idx in range(len(initial_accuracies) - 1):
        forgetting = []
        
        # Calculate forgetting for the task at each subsequent evaluation point
        for eval_idx, accuracies in enumerate(training_history["task_accuracies"]):
            if task_idx <= eval_idx:  # We only have measurements for tasks we've seen
                forgetting.append(initial_accuracies[task_idx] - accuracies[task_idx])
        
        forgetting_rate[f"Task {task_idx+1}"] = forgetting
    
    return forgetting_rate

# Main function to demonstrate Synaptic Intelligence for mitigating catastrophic forgetting
def demonstrate_synaptic_intelligence():
    # Hyperparameters
    input_size = 28 * 28  # Flattened MNIST image
    hidden_size = 256
    output_size = 10  # 10 classes for Permuted MNIST
    learning_rate = 0.01
    epochs_per_task = 5
    si_lambda = 0.2  # Synaptic Intelligence regularization strength
    
    # Initialize model
    model = SimpleNN(input_size, hidden_size, output_size)
    criterion = nn.CrossEntropyLoss()
    
    # Use SGD without weight decay (SI replaces L2 regularization)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    
    # Print model configuration
    print(f"Model Configuration:")
    print(f"- SGD with Synaptic Intelligence (lambda={si_lambda})")
    print(f"- Learning Rate: {learning_rate}")
    print(f"- Hidden Size: {hidden_size}")
    print(f"- Epochs per Task: {epochs_per_task}")
    
    # To store metrics
    training_history = {
        "task_accuracies": [],  # Performance on each task after sequential training
        "training_time": [],    # Time taken to train each task
        "learning_curves": {    # Loss and accuracy during training
            "loss": [],
            "accuracy": []
        }
    }
    
    # To compute forgetting metrics
    initial_accuracies = []  # Accuracy on each task right after learning it
    
    # Store accumulated importance weights (omega) and initial parameters
    omega_sum = None
    initial_params = {}
    
    # Save initial model parameters
    for name, param in model.named_parameters():
        initial_params[name] = param.data.clone()
    
    # Train on each task sequentially
    for task_idx in range(len(train_tasks)): # Change split_mnist_tasks to train_tasks
        print(f"\n{'='*50}")
        print(f"Training on Task {task_idx+1}") #Remove task details from print
        print(f"{'='*50}")
        
        # Measure training time
        start_time = time.time()
        
        # Train on current task using SI if not the first task
        task_loss, task_acc, omega_curr = train_task(
            model, 
            task_idx, 
            criterion, 
            optimizer, 
            omega_sum=omega_sum,
            initial_params=initial_params if task_idx > 0 else None,
            si_lambda=si_lambda,
            epochs=epochs_per_task
        )
        
        # Record training time
        end_time = time.time()
        training_time = end_time - start_time
        training_history["training_time"].append(training_time)
        
        # Save learning curves
        training_history["learning_curves"]["loss"].extend(task_loss)
        training_history["learning_curves"]["accuracy"].extend(task_acc)
        
        # Update accumulated importance weights (omega_sum)
        if omega_sum is None:
            omega_sum = omega_curr
        else:
            for name in omega_sum:
                omega_sum[name] += omega_curr[name]
        
        # Save new initial parameters for the next task
        for name, param in model.named_parameters():
            initial_params[name] = param.data.clone()
        
        # Evaluate on all tasks seen so far
        print("\nEvaluating on all tasks seen so far:")
        task_accuracies = evaluate_all_tasks(model, task_idx + 1)

        # Store the accuracy on the current task after learning it
        if task_idx == 0:
            initial_accuracies.append(task_accuracies[0])
        else:
            training_history["task_accuracies"].append(task_accuracies.copy())
            initial_accuracies.append(task_accuracies[task_idx])

    # Calculate forgetting metrics
    forgetting_rate = calculate_forgetting_metrics(training_history, initial_accuracies)

    return training_history, forgetting_rate, initial_accuracies

In [12]:
# Example usage (you can call this function in your main script)
if __name__ == "__main__":
    training_history, forgetting_rate, initial_accuracies = demonstrate_synaptic_intelligence()

Model Configuration:
- SGD with Synaptic Intelligence (lambda=0.2)
- Learning Rate: 0.01
- Hidden Size: 256
- Epochs per Task: 5

Training on Task 1
Task 1, Epoch 1/5, Loss: 1.2055, Accuracy: 75.65%
Task 1, Epoch 2/5, Loss: 0.4859, Accuracy: 87.79%
Task 1, Epoch 3/5, Loss: 0.3880, Accuracy: 89.40%
Task 1, Epoch 4/5, Loss: 0.3469, Accuracy: 90.26%
Task 1, Epoch 5/5, Loss: 0.3215, Accuracy: 90.90%

Evaluating on all tasks seen so far:
Task 1 Accuracy: 91.55%

Training on Task 2
Task 2, Epoch 1/5, Loss: 0.6660, Accuracy: 82.58%
Task 2, Epoch 2/5, Loss: 0.3784, Accuracy: 89.66%
Task 2, Epoch 3/5, Loss: 0.3303, Accuracy: 90.78%
Task 2, Epoch 4/5, Loss: 0.3030, Accuracy: 91.44%
Task 2, Epoch 5/5, Loss: 0.2835, Accuracy: 91.94%

Evaluating on all tasks seen so far:
Task 1 Accuracy: 87.99%
Task 2 Accuracy: 92.50%

Training on Task 3
Task 3, Epoch 1/5, Loss: 0.5987, Accuracy: 83.19%
Task 3, Epoch 2/5, Loss: 0.3557, Accuracy: 89.99%
Task 3, Epoch 3/5, Loss: 0.3125, Accuracy: 91.18%
Task 3, Epoch