In [12]:
import torch
import torch.nn as nn
from os.path import expanduser
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from avalanche.benchmarks import nc_benchmark
from avalanche.training.supervised import EWC
from avalanche.models import as_multitask
from avalanche.evaluation.metrics import accuracy_metrics
from avalanche.logging import InteractiveLogger
from avalanche.training.plugins import EvaluationPlugin


class CustomMLP(nn.Module):
    """Fully connected classifier operating on flattened inputs.

    With the default arguments the network is
    784 -> 256 -> ReLU -> 128 -> ReLU -> ``num_classes``.
    All layers are stored, in order, in the ``classifier``
    ``nn.Sequential``; its last entry is always the output ``nn.Linear``.

    Args:
        num_classes: Number of output logits.
        input_size: Number of features per sample after flattening.
        hidden_sizes: Width of each hidden layer, in order. Each hidden
            layer is followed by a ReLU.
    """

    def __init__(
        self,
        num_classes=2,
        input_size=784,
        hidden_sizes=(256, 128),
    ):
        super().__init__()
        layers = []
        in_features = input_size
        # Layers are created in input-to-output order so that, with the
        # defaults, the Sequential indices (0, 2, 4 for the Linear layers)
        # match the original hand-written definition.
        for width in hidden_sizes:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        layers.append(nn.Linear(in_features, num_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` to ``(x.size(0), -1)`` and return the logits."""
        # reshape() behaves like view() for contiguous tensors but falls
        # back to a copy for non-contiguous ones (e.g. after transpose),
        # where view() would raise a RuntimeError.
        x = x.reshape(x.size(0), -1)
        return self.classifier(x)


def main(
    *,
    ewc_mode="separate",
    ewc_lambda=0.5,
    decay_factor=0.1,
    learning_rate=1e-3,
    epochs=5,
    minibatch_size=128,
    cuda_id=0,
):
    """Train a multi-task MLP with EWC on MNIST split into 5 experiences.

    The function builds a ``CustomMLP`` with 2 output logits, converts it
    with ``as_multitask`` (pointing it at the ``classifier`` attribute),
    creates a task-labelled ``nc_benchmark`` over MNIST, and trains an
    ``EWC`` strategy on each training experience in turn. After every
    training experience the whole test stream is evaluated.

    Calling ``main()`` with no arguments reproduces the original
    hard-coded configuration.

    Args:
        ewc_mode: ``"separate"`` or ``"online"``.
        ewc_lambda: Penalty hyperparameter for EWC.
        decay_factor: Decay factor; only forwarded when ``ewc_mode`` is
            ``"online"`` (``None`` is passed otherwise).
        learning_rate: Learning rate of the SGD optimizer.
        epochs: Training epochs per experience.
        minibatch_size: Training mini-batch size.
        cuda_id: GPU id to use; set to -1 for CPU. The CPU is also used
            when CUDA is not available.

    Returns:
        A list with one entry per training experience, each being the
        value returned by ``strategy.eval`` on the full test stream.

    Raises:
        ValueError: If ``ewc_mode`` is not one of the supported modes.
    """
    # Fail fast, before any dataset download, on a mistyped mode.
    if ewc_mode not in ("separate", "online"):
        raise ValueError(
            f"ewc_mode must be 'separate' or 'online', got {ewc_mode!r}"
        )

    scenario_name = "multitask_smnist"

    # Device settings
    use_cuda = cuda_id >= 0 and torch.cuda.is_available()
    device = torch.device(f"cuda:{cuda_id}" if use_cuda else "cpu")
    print(f"Using device: {device}")

    # Create model with 2 hidden layers (256 and 128)
    model = CustomMLP(num_classes=2)
    model = as_multitask(model, "classifier")
    model.to(device)

    # NOTE(review): the run captured with this cell shows Task000 training
    # accuracy staying around 0.48-0.49 over the epochs; plain SGD at the
    # default learning rate may simply be too slow here -- confirm and tune
    # if needed.
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()

    # Create benchmark (train and test splits share one download location)
    data_root = expanduser("~") + "/.avalanche/data/mnist/"
    mnist_train = MNIST(
        root=data_root,
        train=True,
        download=True,
        transform=ToTensor(),
    )
    mnist_test = MNIST(
        root=data_root,
        train=False,
        download=True,
        transform=ToTensor(),
    )
    scenario = nc_benchmark(
        mnist_train,
        mnist_test,
        n_experiences=5,
        task_labels=True,
        seed=1234,
        class_ids_from_zero_in_each_exp=True,
    )

    # Setup evaluation plugin
    interactive_logger = InteractiveLogger()
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(
            minibatch=True, epoch=True, experience=True, stream=True
        ),
        loggers=[interactive_logger],
    )

    # Keyword arguments keep each value bound to the intended parameter
    # regardless of the positional order of the strategy constructor.
    strategy = EWC(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        ewc_lambda=ewc_lambda,
        mode=ewc_mode,
        decay_factor=decay_factor if ewc_mode == "online" else None,
        train_epochs=epochs,
        device=device,
        train_mb_size=minibatch_size,
        evaluator=eval_plugin,
    )

    print("Starting experiment...")
    print(f"Scenario: {scenario_name}")
    print(f"EWC Mode: {ewc_mode}")
    print(f"EWC Lambda: {ewc_lambda}")
    print(f"Epochs per task: {epochs}")
    print("=" * 70)

    results = []
    for experience in scenario.train_stream:
        print(f"\nStart training on experience {experience.current_experience}")

        strategy.train(experience)

        print(f"End training on experience {experience.current_experience}")
        print("Computing accuracy on the test set")

        # Evaluate on every test experience, not just the one just trained.
        results.append(strategy.eval(scenario.test_stream[:]))

    print("\n" + "=" * 70)
    print("Experiment completed!")
    print("=" * 70)

    return results


# Entry point: run the experiment only when this code is executed directly
# (not on import) and keep the list returned by main() in `results`.
if __name__ == "__main__":
    results = main()


Using device: cuda:0
Starting experiment...
Scenario: multitask_smnist
EWC Mode: separate
EWC Lambda: 0.5
Epochs per task: 5

Start training on experience 0
-- >> Start of training phase << --
100%|██████████| 88/88 [00:03<00:00, 25.69it/s]
Epoch 0 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.4813
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.5276
100%|██████████| 88/88 [00:03<00:00, 27.06it/s]
Epoch 1 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.4813
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.4173
100%|██████████| 88/88 [00:03<00:00, 26.58it/s]
Epoch 2 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.4814
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.3701
100%|██████████| 88/88 [00:02<00:00, 29.96it/s]
Epoch 3 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.4899
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.4331
100%|██████████| 88/88 [00:03<00:00, 26.95it/s]
Epoch 4 ended.
	Top1_Acc_Epoch/train_phase/train_stream