In [5]:
from torch.utils.data import DataLoader, Subset
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from avalanche.benchmarks.classic import SplitMNIST
import copy


def network_mnist(size_first_layer, size_second_layer):
    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(784, size_first_layer)
            self.fc2 = nn.Linear(size_first_layer, size_second_layer)
            self.fc3 = nn.Linear(size_second_layer, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x  

    return MLP()


def ewc_train(model, task_number, epochs, criterion, optimizer, fisher_dict_prev, 
              parameter_dict_prev, ewc_lambda, device, train_stream):
    experience = train_stream[task_number]
    train_loader = DataLoader(experience.dataset, batch_size=64, shuffle=True)
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        total_ce_loss = 0
        total_ewc_loss = 0
        
        for images, labels, *_ in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            ce_loss = criterion(outputs, labels)
            loss = ce_loss

            # EWC regularization
            ewc_loss = 0
            if len(fisher_dict_prev) > 0:
                for i in range(task_number):
                    fisher_dict = fisher_dict_prev[i]
                    optpar_dict = parameter_dict_prev[i]
                    for name, param in model.named_parameters():
                        if name in fisher_dict:
                            fisher = fisher_dict[name]
                            optpar = optpar_dict[name]
                            ewc_loss += (fisher * (param - optpar).pow(2)).sum()
                
                loss = ce_loss + (ewc_lambda / 2) * ewc_loss

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_ce_loss += ce_loss.item()
            total_ewc_loss += ewc_loss.item() if isinstance(ewc_loss, torch.Tensor) else ewc_loss

        avg_loss = total_loss / len(train_loader)
        avg_ce = total_ce_loss / len(train_loader)
        avg_ewc = total_ewc_loss / len(train_loader)
        print(f"  Epoch {epoch+1}/{epochs} - Total Loss: {avg_loss:.4f}, "
              f"CE Loss: {avg_ce:.4f}, EWC Loss: {avg_ewc:.4f}")


def test_taskwise(model, task_number, device, test_stream):
    experience = test_stream[task_number]
    test_loader = DataLoader(experience.dataset, batch_size=64, shuffle=False)
    model.eval()

    correct, total = 0, 0
    with torch.no_grad():
        for images, labels, *_ in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    acc = 100 * correct / total
    return acc


def test(model, device, test_stream, num_tasks):
    print("\n" + "="*50)
    print("Testing on all tasks:")
    print("="*50)
    sum_acc = 0
    acc_list = []
    for i in range(num_tasks):
        acc = test_taskwise(model, i, device, test_stream)
        print(f"  Task {i}: {acc:.2f}%")
        sum_acc += acc
        acc_list.append(acc)
    avg_acc = sum_acc / num_tasks
    print(f"\n  Average Accuracy: {avg_acc:.2f}%")
    print("="*50)
    return avg_acc, acc_list


def compute_fisher_information(model, task_number, num_samples, device, train_stream):
    """
    Compute diagonal Fisher Information Matrix using empirical Fisher.
    Uses the squared gradients of the log-likelihood.
    """
    model.eval()
    experience = train_stream[task_number]
    train_loader = DataLoader(experience.dataset, batch_size=1, shuffle=True)

    # Initialize Fisher information dict
    fisher_dict = {name: torch.zeros_like(param, device=device) 
                   for name, param in model.named_parameters() if param.requires_grad}

    count = 0
    for images, labels, *_ in train_loader:
        if count >= num_samples:
            break
            
        images, labels = images.to(device), labels.to(device)

        model.zero_grad()
        outputs = model(images)
        
        # Use log probabilities for proper Fisher computation
        log_probs = F.log_softmax(outputs, dim=1)
        
        # Select the log probability of the true class
        loss = -log_probs[0, labels[0]]
        
        loss.backward()

        # Accumulate squared gradients
        for name, param in model.named_parameters():
            if param.requires_grad and param.grad is not None:
                fisher_dict[name] += param.grad.detach().pow(2)

        count += 1

    # Average over samples
    for name in fisher_dict:
        fisher_dict[name] /= count

    return fisher_dict


def main():
    # Initialize model
    model = network_mnist(256, 128)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    epochs = 5
    ewc_lambda = 500000 # Reduced from 10000, tune this based on results
    
    # Prepare benchmark
    benchmark = SplitMNIST(n_experiences=5, seed=1, shuffle=False)
    train_stream = benchmark.train_stream
    test_stream = benchmark.test_stream
    
    # Storage for EWC
    fisher_dict_prev = []
    parameter_dict_prev = []
    
    # Track accuracies
    all_accuracies = []
    
    for task in range(5):
        print(f"\n{'='*70}")
        print(f"Training on Task {task}")
        print(f"{'='*70}")
        
        # Create fresh optimizer for each task
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        
        # Train with EWC
        ewc_train(model, task, epochs, criterion, optimizer, 
                  fisher_dict_prev, parameter_dict_prev, ewc_lambda, device, train_stream)
        
        # Test on current task
        acc = test_taskwise(model, task, device, test_stream)
        print(f"\nPost-training accuracy on Task {task}: {acc:.2f}%")
        
        # Compute Fisher Information
        print(f"Computing Fisher Information for Task {task}...")
        fisher_dict = compute_fisher_information(model, task_number=task, 
                                                num_samples=500, device=device, 
                                                train_stream=train_stream)
        fisher_dict_prev.append(fisher_dict)
        
        # Store optimal parameters (only trainable parameters)
        original_weights = {name: param.clone().detach() 
                          for name, param in model.named_parameters() 
                          if param.requires_grad}
        parameter_dict_prev.append(original_weights)
        
        # Test on all tasks seen so far
        avg_acc, acc_list = test(model, device, test_stream, task + 1)
        all_accuracies.append(acc_list[:task+1])
    
    # Final results
    print("\n" + "="*70)
    print("FINAL RESULTS")
    print("="*70)
    print("\nAccuracy matrix (rows=after task, cols=task performance):")
    for i, acc_list in enumerate(all_accuracies):
        print(f"After Task {i}: {[f'{a:.2f}' for a in acc_list]}")


if __name__ == "__main__":
    main()



Training on Task 0
  Epoch 1/5 - Total Loss: 0.0488, CE Loss: 0.0488, EWC Loss: 0.0000
  Epoch 2/5 - Total Loss: 0.0047, CE Loss: 0.0047, EWC Loss: 0.0000
  Epoch 3/5 - Total Loss: 0.0025, CE Loss: 0.0025, EWC Loss: 0.0000
  Epoch 4/5 - Total Loss: 0.0011, CE Loss: 0.0011, EWC Loss: 0.0000
  Epoch 5/5 - Total Loss: 0.0024, CE Loss: 0.0024, EWC Loss: 0.0000

Post-training accuracy on Task 0: 99.95%
Computing Fisher Information for Task 0...

Testing on all tasks:
  Task 0: 99.95%

  Average Accuracy: 99.95%

Training on Task 1
  Epoch 1/5 - Total Loss: 0.6381, CE Loss: 0.4644, EWC Loss: 0.0000
  Epoch 2/5 - Total Loss: 0.0609, CE Loss: 0.0484, EWC Loss: 0.0000
  Epoch 3/5 - Total Loss: 0.0403, CE Loss: 0.0316, EWC Loss: 0.0000
  Epoch 4/5 - Total Loss: 0.0335, CE Loss: 0.0254, EWC Loss: 0.0000
  Epoch 5/5 - Total Loss: 0.0217, CE Loss: 0.0161, EWC Loss: 0.0000

Post-training accuracy on Task 1: 99.02%
Computing Fisher Information for Task 1...

Testing on all tasks:
  Task 0: 0.00%
  T