In [6]:

import torch
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
import numpy as np
from avalanche.training import Naive
from avalanche.evaluation.metrics import accuracy_metrics
from avalanche.training.plugins import EvaluationPlugin
from avalanche.logging import InteractiveLogger

from helper import (
    MultiHeadMLP, 
    compute_empirical_fisher_information,
    save_important_weights,
    evaluate_task_with_specific_weights,
    create_split_mnist_benchmark
)


def main():
    # Setup
    print("="*60)
    print("TASK INCREMENTAL CONTINUAL LEARNING WITH FISHER INFORMATION")
    print("="*60)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Create benchmark
    benchmark = create_split_mnist_benchmark()
    
    # Create multi-head model  
    model = MultiHeadMLP(
        input_size=784,
        hidden_size=512, 
        #num_classes_per_task=2,
        num_tasks=5
    ).to(device)
    
    print(f"\nModel Architecture:")
    print(model)
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Storage for important weights paths
    important_weights_paths = {}
    
    # Setup basic logger for training
    interactive_logger = InteractiveLogger()
    
    print(f"\n{'='*60}")
    print("PHASE 1: SEQUENTIAL TRAINING ON ALL TASKS")  
    print(f"{'='*60}")
    
    # Training phase - train on each task sequentially
    for task_id, experience in enumerate(benchmark.train_stream):
        print(f"\n{'─'*40}")
        print(f"Training on Task {task_id}")
        print(f"Classes: {experience.classes_in_this_experience}")
        print(f"{'─'*40}")
        
        # Setup evaluation plugin for this task
        eval_plugin = EvaluationPlugin(
            accuracy_metrics(epoch=True, experience=True),
            loggers=[interactive_logger]
        )
        
        # Create strategy for this specific task
        cl_strategy = Naive(
            model,
            SGD(model.parameters(), lr=0.001, momentum=0.9),
            CrossEntropyLoss(),
            train_mb_size=500,
            train_epochs=10,  # Increased epochs for better learning
            eval_mb_size=100,
            device=device,
            evaluator=eval_plugin
        )
        
        # Override forward method to use specific task head
        def task_specific_forward(self, x):
            return self.model(x, task_id=task_id)
        
        # Monkey patch the forward method (hack for avalanche)
        original_forward = cl_strategy.model.forward
        cl_strategy.model.forward = lambda x: original_forward(x, task_id=task_id)
        
        # Train on current task
        print(f"Training task {task_id}...")
        cl_strategy.train(experience)
        
        # Restore original forward method
        cl_strategy.model.forward = original_forward
        
        print(f"Task {task_id} training completed!")
        
        # Compute Fisher Information for this task
        print(f"Computing Fisher Information for task {task_id}...")
        fisher_dict = compute_empirical_fisher_information(
            model=model,
            dataset=experience.dataset, 
            device=device,
            num_samples=200,
            task_id=task_id
        )
        
        # Save important weights for this task
        save_path = f"important_weights_task_{task_id}.pt"
        save_important_weights(
            model=model,
            fisher_dict=fisher_dict,
            top_n_percent=0,
            save_path=save_path
        )
        important_weights_paths[task_id] = save_path
        
        print(f"Task {task_id} completed and weights saved!")
    
    print(f"\n{'='*60}")
    print("PHASE 2: TASK-SPECIFIC INFERENCE WITH IMPORTANT WEIGHTS")
    print(f"{'='*60}")
    
    # Evaluation phase - test each task with its specific important weights
    final_accuracies = {}
    
    print(f"\nEvaluating all tasks with their respective important weights...")
    
    for task_id in range(5):
        print(f"\n{'─'*30}")
        print(f"Evaluating Task {task_id}")
        print(f"{'─'*30}")
        
        # Get test experience for this task
        test_experience = benchmark.test_stream[task_id]
        
        # Evaluate with task-specific important weights
        accuracy = evaluate_task_with_specific_weights(
            model=model,
            test_experience=test_experience,
            task_id=task_id,
            important_weights_path=important_weights_paths[task_id],
            device=device
        )
        
        final_accuracies[f"Task_{task_id}"] = accuracy
        print(f"Task {task_id} Accuracy: {accuracy:.2f}%")
    
    print(f"\n{'='*60}")
    print("FINAL RESULTS")
    print(f"{'='*60}")
    
    # Print final results summary
    print(f"\nFinal Test Accuracies:")
    print("-" * 30)
    
    total_accuracy = 0
    for task_id in range(5):
        acc = final_accuracies[f"Task_{task_id}"]
        print(f"Task {task_id}: {acc:.2f}%")
        total_accuracy += acc
    
    average_accuracy = total_accuracy / 5
    print("-" * 30)
    print(f"Average Accuracy: {average_accuracy:.2f}%")
    
    # Save final results
    results_dict = {
        'final_accuracies': final_accuracies,
        'average_accuracy': average_accuracy,
        'method': 'Task-Specific Important Weights'
    }
    
    torch.save(results_dict, "final_results.pt")
    print(f"\nResults saved to 'final_results.pt'")
    
    # Save final model
    torch.save(model.state_dict(), "final_multihead_model.pth") 
    print(f"Final model saved to 'final_multihead_model.pth'")
    
    print(f"\n{'='*60}")
    print("EXPERIMENT COMPLETED!")
    print(f"{'='*60}")
    
    return results_dict


if __name__ == "__main__":
    results = main()
    print(f"\nExperiment finished with average accuracy: {results['average_accuracy']:.2f}%")


TASK INCREMENTAL CONTINUAL LEARNING WITH FISHER INFORMATION
Using device: cuda

Model Architecture:
MultiHeadMLP(
  (backbone): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=512, bias=True)
    (5): ReLU()
  )
  (heads): ModuleList(
    (0-4): 5 x Linear(in_features=512, out_features=10, bias=True)
  )
)
Total parameters: 952,882

PHASE 1: SEQUENTIAL TRAINING ON ALL TASKS

────────────────────────────────────────
Training on Task 0
Classes: [0, 1]
────────────────────────────────────────
Training task 0...
-- >> Start of training phase << --
100%|██████████| 26/26 [00:01<00:00, 22.82it/s]
Epoch 0 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.7000
100%|██████████| 26/26 [00:00<00:00, 27.09it/s]
Epoch 1 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9919
100%|██████████| 26/26 [00:00<00:00, 26.52

In [9]:

import torch
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
import numpy as np
from avalanche.training import Naive
from avalanche.evaluation.metrics import accuracy_metrics
from avalanche.training.plugins import EvaluationPlugin
from avalanche.logging import InteractiveLogger

from helper_simple import (
    SimpleMLP, 
    compute_empirical_fisher_information,
    save_important_weights,
    evaluate_task_with_specific_weights,
    create_split_mnist_benchmark
)


def main():
    # Setup
    print("="*60)
    print("TASK INCREMENTAL CONTINUAL LEARNING WITH FISHER INFORMATION")
    print("Using Simple MLP (No Multi-Head)")
    print("="*60)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Create benchmark
    benchmark = create_split_mnist_benchmark()
    
    # Create simple MLP model  
    model = SimpleMLP(
        input_size=784,
        hidden_size=512, 
        num_classes=10
    ).to(device)
    
    print(f"\nModel Architecture:")
    print(model)
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Storage for important weights paths
    important_weights_paths = {}
    
    # Setup basic logger for training
    interactive_logger = InteractiveLogger()
    
    print(f"\n{'='*60}")
    print("PHASE 1: SEQUENTIAL TRAINING ON ALL TASKS")  
    print(f"{'='*60}")
    
    # Training phase - train on each task sequentially
    for task_id, experience in enumerate(benchmark.train_stream):
        print(f"\n{'─'*40}")
        print(f"Training on Task {task_id}")
        print(f"Classes: {experience.classes_in_this_experience}")
        print(f"Dataset size: {len(experience.dataset)}")
        print(f"{'─'*40}")
        
        # Setup evaluation plugin for this task
        eval_plugin = EvaluationPlugin(
            accuracy_metrics(epoch=True, experience=True),
            loggers=[interactive_logger]
        )
        
        # Create strategy for this specific task
        cl_strategy = Naive(
            model,
            SGD(model.parameters(), lr=0.001, momentum=0.9),
            CrossEntropyLoss(),
            train_mb_size=500,
            train_epochs=3,  # Reduced epochs to prevent overfitting
            eval_mb_size=100,
            device=device,
            evaluator=eval_plugin
        )
        
        # Train on current task (no need to modify forward method for simple MLP)
        print(f"Training task {task_id}...")
        cl_strategy.train(experience)
        
        print(f"Task {task_id} training completed!")
        
        # Compute Fisher Information for this task
        print(f"Computing Fisher Information for task {task_id}...")
        fisher_dict = compute_empirical_fisher_information(
            model=model,
            dataset=experience.dataset, 
            device=device,
            num_samples=200
        )
        
        # Save important weights for this task
        save_path = f"important_weights_task_{task_id}.pt"
        save_important_weights(
            model=model,
            fisher_dict=fisher_dict,
            top_n_percent=10.0,  # Save top 10% important weights
            save_path=save_path
        )
        important_weights_paths[task_id] = save_path
        
        print(f"Task {task_id} completed and weights saved!")
    
    print(f"\n{'='*60}")
    print("PHASE 2: TASK-SPECIFIC INFERENCE WITH IMPORTANT WEIGHTS")
    print(f"{'='*60}")
    
    # Evaluation phase - test each task with its specific important weights
    final_accuracies = {}
    
    print(f"\nEvaluating all tasks with their respective important weights...")
    
    for task_id in range(5):
        print(f"\n{'─'*30}")
        print(f"Evaluating Task {task_id}")
        print(f"{'─'*30}")
        
        # Get test experience for this task
        test_experience = benchmark.test_stream[task_id]
        
        # Evaluate with task-specific important weights
        accuracy = evaluate_task_with_specific_weights(
            model=model,
            test_experience=test_experience,
            important_weights_path=important_weights_paths[task_id],
            device=device
        )
        
        final_accuracies[f"Task_{task_id}"] = accuracy
        print(f"Task {task_id} Accuracy: {accuracy:.2f}%")
    
    print(f"\n{'='*60}")
    print("FINAL RESULTS")
    print(f"{'='*60}")
    
    # Print final results summary
    print(f"\nFinal Test Accuracies:")
    print("-" * 30)
    
    total_accuracy = 0
    for task_id in range(5):
        acc = final_accuracies[f"Task_{task_id}"]
        print(f"Task {task_id}: {acc:.2f}%")
        total_accuracy += acc
    
    average_accuracy = total_accuracy / 5
    print("-" * 30)
    print(f"Average Accuracy: {average_accuracy:.2f}%")
    
    # Additional analysis: Evaluate without weight restoration (baseline)
    print(f"\n{'='*40}")
    print("BASELINE: EVALUATION WITHOUT WEIGHT RESTORATION")
    print(f"{'='*40}")
    
    baseline_accuracies = {}
    for task_id in range(5):
        test_experience = benchmark.test_stream[task_id]
        
        model.eval()
        correct = 0
        total = 0
        
        from torch.utils.data import DataLoader
        test_loader = DataLoader(test_experience.dataset, batch_size=100, shuffle=False, num_workers=0)
        
        with torch.no_grad():
            for batch in test_loader:
                if isinstance(batch, (list, tuple)):
                    inputs, targets = batch[0], batch[1]
                else:
                    continue
                    
                inputs = inputs.to(device)
                targets = targets.to(device)
                
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        
        baseline_accuracy = 100.0 * correct / total
        baseline_accuracies[f"Task_{task_id}"] = baseline_accuracy
        print(f"Task {task_id} Baseline Accuracy: {baseline_accuracy:.2f}%")
    
    baseline_avg = sum(baseline_accuracies.values()) / len(baseline_accuracies)
    print(f"Baseline Average Accuracy: {baseline_avg:.2f}%")
    
    # Compare methods
    print(f"\n{'='*40}")
    print("COMPARISON")
    print(f"{'='*40}")
    print(f"With Fisher Weight Restoration: {average_accuracy:.2f}%")
    print(f"Without Weight Restoration:     {baseline_avg:.2f}%")
    print(f"Improvement: {average_accuracy - baseline_avg:.2f}%")
    
    # Save final results
    results_dict = {
        'final_accuracies': final_accuracies,
        'baseline_accuracies': baseline_accuracies,
        'average_accuracy': average_accuracy,
        'baseline_average': baseline_avg,
        'improvement': average_accuracy - baseline_avg,
        'method': 'Task-Specific Important Weights (Simple MLP)'
    }
    
    torch.save(results_dict, "final_results_simple_mlp.pt")
    print(f"\nResults saved to 'final_results_simple_mlp.pt'")
    
    # Save final model
    torch.save(model.state_dict(), "final_simple_mlp_model.pth") 
    print(f"Final model saved to 'final_simple_mlp_model.pth'")
    
    print(f"\n{'='*60}")
    print("EXPERIMENT COMPLETED!")
    print(f"{'='*60}")
    
    return results_dict


if __name__ == "__main__":
    results = main()
    print(f"\nExperiment finished:")
    print(f"  Fisher Method: {results['average_accuracy']:.2f}%")
    print(f"  Baseline:      {results['baseline_average']:.2f}%")
    print(f"  Improvement:   {results['improvement']:.2f}%")


TASK INCREMENTAL CONTINUAL LEARNING WITH FISHER INFORMATION
Using Simple MLP (No Multi-Head)
Using device: cuda

Model Architecture:
SimpleMLP(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=512, bias=True)
    (5): ReLU()
    (6): Linear(in_features=512, out_features=10, bias=True)
  )
)
Total parameters: 932,362

PHASE 1: SEQUENTIAL TRAINING ON ALL TASKS

────────────────────────────────────────
Training on Task 0
Classes: [0, 1]
Dataset size: 12665
────────────────────────────────────────
Training task 0...
-- >> Start of training phase << --
100%|██████████| 26/26 [00:01<00:00, 23.83it/s]
Epoch 0 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.5823
100%|██████████| 26/26 [00:00<00:00, 27.44it/s]
Epoch 1 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9911
100%|██████████| 26/26 [00:0

In [4]:

from avalanche.benchmarks.classic import SplitMNIST
from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics,\
    loss_metrics, timing_metrics, cpu_usage_metrics, StreamConfusionMatrix,\
    disk_usage_metrics, gpu_usage_metrics
from avalanche.models import SimpleMLP
from avalanche.logging import InteractiveLogger, TextLogger, TensorboardLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training import Naive
from wandb_logger import WandBLogger

from torch.optim import SGD
from torch.nn import CrossEntropyLoss
import torch 

from fisher_info import compute_empirical_fisher_information,plot_fisher_information_detailed,plot_fisher_information_heatmap


benchmark = SplitMNIST(n_experiences=5,return_task_id=False,seed=42,shuffle=False)   #creates the benchmark
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

model = SimpleMLP(num_classes=benchmark.n_classes).to(device)   #creates the model
print(model)

tb_logger = WandBLogger(project_name="avalanche_tut_!", run_name="run_1") #WandB logger
text_logger = TextLogger(open('log.txt', 'a'))
interactive_logger = InteractiveLogger()


eval_plugin = EvaluationPlugin(
    accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    #loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    #timing_metrics(epoch=True),
    #cpu_usage_metrics(experience=True),
    #forgetting_metrics(experience=True, stream=True),
    #StreamConfusionMatrix(num_classes=benchmark.n_classes, save_image=False),
    #disk_usage_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    loggers=[interactive_logger, text_logger, tb_logger]
)

cl_strategy = Naive(
    model, SGD(model.parameters(), lr=0.001, momentum=0.9),
    CrossEntropyLoss(), train_mb_size=500, train_epochs=1, eval_mb_size=100,device=device,
    evaluator=eval_plugin)


print('Starting experiment...')
results = []
fisher_dicts=[]
weights_per_experience = []
for experience in benchmark.train_stream:
    print("Start of experience: ", experience.current_experience)
    print("Current Classes: ", experience.classes_in_this_experience)
    res = cl_strategy.train(experience, num_workers=4)
    print('Training completed')


    
    print('Computing Fisher Information')
    fisher_dict = compute_empirical_fisher_information(
        model=model,
        dataset=experience.dataset,
        device=device,
        num_samples=200  # Limit to 200 samples for testing
    ) 
    fisher_dicts.append(fisher_dict)
    #plot_fisher_information_detailed(fisher_dict,model=model,save_path=f'fisher_detailed_exp{experience.current_experience}.png')   
    #print(fisher_dict)
    weights_per_experience.append(save_top_n_percent_weights(model, fisher_dict, top_n_percent=10.0, save_path=f"important_weights_exp{experience.current_experience}.pt"))
    
    print('Computing accuracy on the whole test set')
    results.append(cl_strategy.eval(benchmark.test_stream, num_workers=4))
    for name, param in model.named_parameters():
        if name in fisher_dict:
            print(f"Param: {name}, Fisher Info Mean: {fisher_dict[name].mean().item():.6f}")
        else :
            print(f"Param: {name} not found in Fisher Info dictionary.")

torch.save(model.state_dict(), "final_splitmnist_model.pth")
print("Model saved as final_splitmnist_model.pth")


Using device: cuda
SimpleMLP(
  (features): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
  )
  (classifier): Linear(in_features=512, out_features=10, bias=True)
)


0,1
Top1_Acc_Epoch/train_phase/train_stream/Task000,█▃▁▂▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,█▃▂▂▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,█▃▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002,█▃▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003,█▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004,▁
Top1_Acc_MB/train_phase/train_stream/Task000,▄▇████████▁▁▁▃▇▇▇▁▁▁▁▆▇▁▁▁▃▅▆▇▁▁▁▁▁▅▇▇▇▇
Top1_Acc_Stream/eval_phase/test_stream/Task000,█▆▃▁▁
TrainingExperience,▁▃▅▆█

0,1
Top1_Acc_Epoch/train_phase/train_stream/Task000,0.32466
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,0.52151
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,0.06464
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002,0.00854
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003,0.6279
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004,0.941
Top1_Acc_MB/train_phase/train_stream/Task000,0.91333
Top1_Acc_Stream/eval_phase/test_stream/Task000,0.4364
TrainingExperience,4.0




Starting experiment...
Start of experience:  0
Current Classes:  [0, 1]
-- >> Start of training phase << --
100%|██████████| 26/26 [00:00<00:00, 47.65it/s]
Epoch 0 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8057
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.9939
-- >> End of training phase << --
Training completed
Computing Fisher Information
Computing Empirical Fisher Information for 200 samples...
Processed 100/200 samples...
Processed 200/200 samples...
Empirical Fisher Information computed successfully for 200 samples!
--- Identifying and saving top 10.0% of weights ---
Threshold for top 10.0% weights: 0.000017
Found 40,705 important weights.
Saving masks and values to 'important_weights_exp0.pt'
Computing accuracy on the whole test set
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 22/22 [00:00<00:00, 62.17it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_

In [2]:
import torch
import torch.nn as nn
import copy

# ----------------------------------
# FUNCTION 1: SAVE IMPORTANT WEIGHTS
# ----------------------------------
def save_top_n_percent_weights(model, fisher_dict, top_n_percent, save_path="important_weights.pt"):
    """
    Identifies and saves the top N% of weights based on Fisher scores.

    Args:
        model (nn.Module): The PyTorch model.
        fisher_dict (dict): A dictionary with layer names as keys and Fisher score tensors as values.
        top_n_percent (float): The percentage of top weights to save (e.g., 20.0 for top 20%).
        save_path (str): Path to save the resulting data.
    """
    print(f"--- Identifying and saving top {top_n_percent}% of weights ---")
    
    # 1. Flatten all Fisher scores into a single tensor to find the global threshold
    all_scores = torch.cat([f.view(-1) for f in fisher_dict.values()])
    #print(all_scores)
    # 2. Calculate the threshold value. For the top 20%, we need the 80th percentile.
    threshold_quantile = 1.0 - (top_n_percent / 100.0)
    threshold = torch.quantile(all_scores, threshold_quantile)
    print(f"Threshold for top {top_n_percent}% weights: {threshold:.6f}")
    important_weights_data = {}
    total_important_weights = 0

    # 3. Create masks and store the original values of the important weights
    with torch.no_grad():
        for name, param in model.named_parameters():
            #print(name)
            if name in fisher_dict:
                #print(name)
                # Create a binary mask: 1 where score > threshold, 0 otherwise
                mask = fisher_dict[name] >= threshold
                
                # Use the mask to get the actual weight values from the model
                important_values = param.data[mask]
                
                # Store both the mask and the values
                important_weights_data[name] = {
                    'mask': mask.cpu(),  # Move to CPU for saving
                    'values': important_values.cpu()
                }
                total_important_weights += important_values.numel()

    print(f"Found {total_important_weights:,} important weights.")
    print(f"Saving masks and values to '{save_path}'")
    torch.save(important_weights_data, save_path)
    return important_weights_data


# ------------------------------------
# FUNCTION 2: RESTORE IMPORTANT WEIGHTS
# ------------------------------------
def restore_important_weights(model, load_path="important_weights.pt"):
    """
    Restores the saved important weights into a new model instance.

    Args:
        model (nn.Module): The model to restore weights into (e.g., after retraining).
        load_path (str): Path to the saved important weights data.
    
    Returns:
        nn.Module: The model with important weights restored.
    """
    print(f"\n--- Restoring important weights from '{load_path}' ---")
    
    # Load the saved masks and values
    important_weights_data = torch.load(load_path)
    device = next(model.parameters()).device
    
    restored_count = 0
    
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in important_weights_data:
                data = important_weights_data[name]
                mask = data['mask'].to(device)
                saved_values = data['values'].to(device)
                
                # Use the mask to overwrite values in the current model
                param.data[mask] = saved_values
                restored_count += saved_values.numel()
                
    print(f"Restored {restored_count:,} weights into the model.")
    return model


# ----------------------------------
# EXAMPLE USAGE
# ----------------------------------
if __name__ == '__main__sjhfb':
    
    # 1. Create a dummy model and dummy Fisher scores
    model = nn.Sequential(
        nn.Linear(10, 50),
        nn.ReLU(),
        nn.Linear(50, 2)
    )
    # The Fisher dict has the same keys and tensor shapes as the model's state_dict
    fisher_dict = {name: torch.rand_like(p) for name, p in model.named_parameters() if p.requires_grad}

    # 2. Save the top 20% of weights from the original model
    save_top_n_percent_weights(model, fisher_dict, top_n_percent=20.0)

    # 3. Simulate retraining: create a new model instance with different weights
    retrained_model = copy.deepcopy(model)
    # Let's change all weights to prove the restoration works
    with torch.no_grad():
        for param in retrained_model.parameters():
            param.data.fill_(999.0) # Fill with a placeholder value

    # 4. Verify that a specific important weight has been changed
    original_weight_value = model[0].weight[0, 0].item()
    print(f"\nAn original important weight value (example): {original_weight_value:.4f}")
    print(f"Same weight in 'retrained' model before restoration: {retrained_model[0].weight[0, 0].item():.4f}")

    # 5. Restore the saved important weights into the "retrained" model
    restored_model = restore_important_weights(retrained_model)

    # 6. Final verification
    print(f"Same weight in model AFTER restoration: {restored_model[0].weight[0, 0].item():.4f}")

    # Check that only the important weights were changed and others remain 999.0
    important_data = torch.load("important_weights.pt")
    mask_for_first_layer = important_data['0.weight']['mask']
    
    # Find a location that was NOT important (where mask is False)
    unimportant_indices = (mask_for_first_layer == False).nonzero()[0]
    unimportant_val = restored_model[0].weight[unimportant_indices[0], unimportant_indices[1]].item()
    print(f"Value of an UNIMPORTANT weight after restoration: {unimportant_val}")

In [34]:
save_top_n_percent_weights(model, fisher_dicts[0], top_n_percent=10.0, save_path="important_weights_exp0.pt")

--- Identifying and saving top 10.0% of weights ---
Threshold for top 10.0% weights: 0.000029
Found 40,705 important weights.
Saving masks and values to 'important_weights_exp0.pt'


{'features.0.weight': {'mask': tensor([[False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          ...,
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False]]),
  'values': tensor([ 0.0137,  0.0003,  0.0323,  ..., -0.0181,  0.0329,  0.0316])},
 'features.0.bias': {'mask': tensor([False, False, False, False, False,  True, False, False, False, False,
          False,  True, False, False, False, False, False, False, False, False,
          False,  True, False, False, False, False, False,  True, False,  True,
          False, False, False, False, False, False, False,  True, False, False,
           True,  True, False, False, False,  True, False, False, False, False,
           True, False,  True, False, False, False, False, False, False, False,
    

In [6]:
from avalanche.benchmarks.classic import SplitMNIST
from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics,\
    loss_metrics, timing_metrics, cpu_usage_metrics, StreamConfusionMatrix,\
    disk_usage_metrics, gpu_usage_metrics
from avalanche.models import SimpleMLP
from avalanche.logging import InteractiveLogger, TextLogger, TensorboardLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training import Naive
from wandb_logger import WandBLogger

from torch.optim import SGD
from torch.nn import CrossEntropyLoss
import torch 
import copy

from fisher_info import compute_empirical_fisher_information, plot_fisher_information_detailed, plot_fisher_information_heatmap



# Create benchmark
benchmark = SplitMNIST(n_experiences=5, return_task_id=False, seed=42, shuffle=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Create model
model = SimpleMLP(num_classes=benchmark.n_classes).to(device)
print(model)

# Setup loggers
tb_logger = WandBLogger(project_name="avalanche_tut_!", run_name="run_1")
text_logger = TextLogger(open('log.txt', 'a'))
interactive_logger = InteractiveLogger()

# Setup evaluation plugin
eval_plugin = EvaluationPlugin(
    accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    loggers=[interactive_logger, text_logger, tb_logger]
)

# Setup continual learning strategy
cl_strategy = Naive(
    model, SGD(model.parameters(), lr=0.001, momentum=0.9),
    CrossEntropyLoss(), train_mb_size=500, train_epochs=10, eval_mb_size=100, device=device,
    evaluator=eval_plugin
)

# Storage for important weights data
important_weights_paths = {}  # Maps experience_id -> path to saved weights
fisher_dicts = []

print('Starting experiment...')
results = []

# Training loop
for experience in benchmark.train_stream:
    exp_id = experience.current_experience
    print(f"\n{'='*60}")
    print(f"Start of experience: {exp_id}")
    print(f"Current Classes: {experience.classes_in_this_experience}")
    print(f"{'='*60}\n")
    
    # Train on current experience
    res = cl_strategy.train(experience, num_workers=4)
    print('Training completed')
    
    # Compute Fisher Information after training
    print(f'\nComputing Fisher Information for experience {exp_id}...')
    fisher_dict = compute_empirical_fisher_information(
        model=model,
        dataset=experience.dataset,
        device=device,
        num_samples=200
    )
    fisher_dicts.append(fisher_dict)
    
    # Plot Fisher Information
    '''plot_fisher_information_detailed(
        fisher_dict, 
        model=model, 
        save_path=f'fisher_detailed_exp{exp_id}.png'
    )'''
    
    # Print Fisher statistics
    print(f"\nFisher Information Statistics for Experience {exp_id}:")
    for name, param in model.named_parameters():
        if name in fisher_dict:
            print(f"  {name}: Mean={fisher_dict[name].mean().item():.6f}, "
                  f"Max={fisher_dict[name].max().item():.6f}")
    
    # Save top 10% important weights for this experience
    save_path = f"important_weights_exp{exp_id}.pt"
    print(f"\nSaving top 10% important weights for experience {exp_id}...")
    save_top_n_percent_weights(
        model=model,
        fisher_dict=fisher_dict,
        top_n_percent=10.0,
        save_path=save_path
    )
    important_weights_paths[exp_id] = save_path
    
    # Evaluate on all test experiences seen so far
    print(f'\n{"="*60}')
    print(f"Evaluating on all test experiences (0 to {exp_id})...")
    print(f"{"="*60}\n")
    
    for test_exp_id in range(exp_id + 1):
        print(f"\n--- Evaluating on Test Experience {test_exp_id} ---")
        
        # Restore important weights for this test experience
        if test_exp_id in important_weights_paths:
            print(f"Restoring important weights from experience {test_exp_id}...")
            restore_important_weights(
                model=model,
                load_path=important_weights_paths[test_exp_id]
            )
        
        # Evaluate on the specific test experience
        test_stream_subset = [benchmark.test_stream[test_exp_id]]
        result = cl_strategy.eval(test_stream_subset, num_workers=4)
        print(f"Results on test experience {test_exp_id}: {result}")
        
        print(f"Completed evaluation on test experience {test_exp_id}")
    
    # Store overall results
    print(f'\nComputing accuracy on the whole test set (experiences 0-{exp_id})...')
    results.append(cl_strategy.eval(benchmark.test_stream[:exp_id+1], num_workers=4))

# Save final model
torch.save(model.state_dict(), "final_splitmnist_model.pth")
print("\n" + "="*60)
print("Training and evaluation completed!")
print("Model saved as final_splitmnist_model.pth")
print("="*60)

Using device: cuda
SimpleMLP(
  (features): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
  )
  (classifier): Linear(in_features=512, out_features=10, bias=True)
)


0,1
Top1_Acc_Epoch/train_phase/train_stream/Task000,█▃▁▂▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,█▄▄▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,█▄▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002,█▂▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003,█▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004,▁
Top1_Acc_MB/train_phase/train_stream/Task000,▂▅▆▇██████▁▁▃▅▆▇▇▁▁▁▁▁▁▁▁▇▇▁▁▁▁▇▇█▁▃▅▆▇▇
Top1_Acc_Stream/eval_phase/test_stream/Task000,█▆▄▁▁
TrainingExperience,▁▃▅▆█

0,1
Top1_Acc_Epoch/train_phase/train_stream/Task000,0.34822
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,0.57825
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,0.0573
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002,0.01067
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003,0.429
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004,0.94453
Top1_Acc_MB/train_phase/train_stream/Task000,0.92667
Top1_Acc_Stream/eval_phase/test_stream/Task000,0.4085
TrainingExperience,4.0




Starting experiment...

Start of experience: 0
Current Classes: [0, 1]

-- >> Start of training phase << --
100%|██████████| 26/26 [00:00<00:00, 48.76it/s]
Epoch 0 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8707
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.9939
100%|██████████| 26/26 [00:00<00:00, 43.63it/s]
Epoch 1 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9939
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.9939
100%|██████████| 26/26 [00:00<00:00, 46.21it/s]
Epoch 2 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9957
	Top1_Acc_MB/train_phase/train_stream/Task000 = 1.0000
100%|██████████| 26/26 [00:00<00:00, 46.21it/s]
Epoch 3 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9958
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.9939
100%|██████████| 26/26 [00:00<00:00, 50.53it/s]
Epoch 4 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9965
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.9939
100%|███████

In [37]:
print('All results:\n', results)

All results:
 [{'Top1_Acc_MB/train_phase/train_stream/Task000': 1.0, 'Top1_Acc_Epoch/train_phase/train_stream/Task000': 0.7966048164232136, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.9966903073286052, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.9966903073286052}, {'Top1_Acc_MB/train_phase/train_stream/Task000': 0.9213483146067416, 'Top1_Acc_Epoch/train_phase/train_stream/Task000': 0.45181570022334355, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.7546099290780142, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.8429155641087323, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.9343780607247796}, {'Top1_Acc_MB/train_phase/train_stream/Task000': 0.935361216730038, 'Top1_Acc_Epoch/train_phase/train_stream/Task000': 0.34555624611559976, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.7380614657210401, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.6221190515669043, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.210577864838393