In [None]:
!pip install avalanche-lib medmnist

Collecting avalanche-lib
  Downloading avalanche_lib-0.6.0-py3-none-any.whl.metadata (12 kB)
Collecting medmnist
  Downloading medmnist-3.0.2-py3-none-any.whl.metadata (14 kB)
Collecting gputil (from avalanche-lib)
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pytorchcv (from avalanche-lib)
  Downloading pytorchcv-0.0.74-py3-none-any.whl.metadata (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.2/134.2 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics (from avalanche-lib)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting qpsolvers[open_source_solvers] (from avalanche-lib)
  Downloading qpsolvers-4.8.2-py3-none-any.whl.metadata (12 kB)
Collecting fire (from medmnist)
  Downloading fire-0.7.1-py3-none-any.whl.metadata (5.8 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics->avalanche-lib)
  Downloading lightning_utilities-0.15.2-py3-none-any.wh

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from medmnist import PathMNIST, INFO
from avalanche.benchmarks import nc_benchmark
from avalanche.models import MultiHeadClassifier, SimpleCNN
from avalanche.training.supervised import Naive, Cumulative
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics, forgetting_metrics
from avalanche.logging import InteractiveLogger
from avalanche.training.plugins import EvaluationPlugin


In [None]:
# --- 1. DATA PREPARATION ---
import numpy as np

data_flag = 'pathmnist'
info = INFO[data_flag]
n_classes = len(info['label']) # PathMNIST has 9 classes
n_channels = info['n_channels'] # 3 (RGB)

# Normalisation statistics are spelled out once per channel instead of
# relying on a single value being broadcast over all channels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5] * n_channels, std=[.5] * n_channels)
])

# Load each MedMNIST split exactly once (download=True fetches the archive
# if it is not available locally yet).
train_dataset = PathMNIST(split='train', transform=transform, download=True)
test_dataset = PathMNIST(split='test', transform=transform, download=True)

for dataset in (train_dataset, test_dataset):
    # MedMNIST labels are usually shape (N, 1); flatten to (N,) and cast to int.
    dataset.labels = dataset.labels.flatten().astype(int)

    # Avalanche-specific: expose the labels as a plain-list `targets` attribute.
    dataset.targets = dataset.labels.tolist()

    # Sanity check: every label must be a valid class index for this dataset.
    assert 0 <= min(dataset.targets) and max(dataset.targets) < n_classes, \
        "label outside the expected class range"


100%|██████████| 206M/206M [04:01<00:00, 852kB/s]


In [None]:
# --- 2. CREATE THE BENCHMARK ---
# The 9 classes are split over 3 experiences (3 classes each); the class
# order is shuffled with a fixed seed.
# task_labels=True makes every experience carry its own task label, which
# the multi-head model needs in order to route samples to a head.
benchmark_config = dict(
    n_experiences=3,
    task_labels=True,
    shuffle=True,
    seed=1234,
)
benchmark = nc_benchmark(train_dataset, test_dataset, **benchmark_config)


In [None]:
# --- 3. DEFINE THE DYNAMIC MODEL ---
from avalanche.models import MultiTaskModule

class MedMultiHeadCNN(MultiTaskModule): # Inherit from MultiTaskModule
    """Two-layer convolutional feature extractor with a multi-head classifier.

    Args:
        in_channels: Number of channels of the input images. When ``None``
            (the default) the notebook-level ``n_channels`` value is used.
        input_size: Height and width of the (square) input images. Only used
            to infer the size of the flattened feature vector. Defaults to 28.
    """

    def __init__(self, in_channels=None, input_size=28):
        super().__init__()
        if in_channels is None:
            # Backward-compatible default: the value read from the dataset info.
            in_channels = n_channels

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten()
        )

        # Infer the flattened feature size by pushing one dummy image through
        # the extractor. Only the output shape matters, so a zero tensor is
        # used; this avoids drawing from the global random number generator.
        with torch.no_grad():
            dummy_input = torch.zeros(1, in_channels, input_size, input_size)
            self.feature_dim = self.features(dummy_input).shape[1]

        self.classifier = MultiHeadClassifier(in_features=self.feature_dim)

    def forward(self, x, task_labels):
        """Extract features once for the whole batch, then classify.

        Args:
            x: Batch of input images.
            task_labels: Task label(s) of the batch; handed unchanged to the
                multi-head classifier, which does the per-task routing.

        Returns:
            The output of the multi-head classifier.
        """
        x = self.features(x)
        # MultiTaskModule expects the output to be routed correctly
        return self.classifier(x, task_labels)

    def forward_single_task(self, x, task_label):
        """Per-task hook of MultiTaskModule: classify a batch of one task.

        Args:
            x: Batch of input images that all belong to the same task.
            task_label: The task label shared by the whole batch.

        Returns:
            The output of the multi-head classifier for that task.
        """
        return self.classifier(self.features(x), task_label)

model = MedMultiHeadCNN()

In [None]:
# --- 4. SET UP THE STRATEGY ---
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Evaluation: accuracy and forgetting, reported per experience and over the
# whole stream, printed through the interactive logger.
eval_plugin = EvaluationPlugin(
    accuracy_metrics(experience=True, stream=True),
    forgetting_metrics(experience=True, stream=True),
    loggers=[InteractiveLogger()],
)

# Optimisation: Adam on all model parameters with a cross-entropy loss.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# 'Naive' trains only on the current task; 'Cumulative' (train on current +
# all previous data) is the alternative strategy.
strategy = Naive(
    model,
    optimizer,
    criterion,
    train_mb_size=128,
    train_epochs=8,
    eval_mb_size=128,
    device=device,
    evaluator=eval_plugin,
)




In [None]:
# --- 5. TRAINING LOOP ---
# Train on the experiences one after another; after each one, evaluate on the
# complete test stream.
print("Starting Training Loop...")
train_stream = benchmark.train_stream
test_stream = benchmark.test_stream

for experience in train_stream:
    exp_index = experience.current_experience
    exp_classes = experience.classes_in_this_experience
    print(f"Current Experience: {exp_index}")
    print(f"Classes in this task: {exp_classes}")

    # Avalanche handles the model.adaptation() and task_labels internally
    strategy.train(experience)

    print("Training completed. Evaluating on all tasks...")
    strategy.eval(test_stream)

Starting Training Loop...
Current Experience: 0
Classes in this task: [2, 4, 6]
-- >> Start of training phase << --
100%|██████████| 206/206 [00:09<00:00, 21.69it/s]
Epoch 0 ended.
100%|██████████| 206/206 [00:07<00:00, 26.10it/s]
Epoch 1 ended.
100%|██████████| 206/206 [00:08<00:00, 25.52it/s]
Epoch 2 ended.
-- >> End of training phase << --
Training completed. Evaluating on all tasks...
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 17/17 [00:00<00:00, 20.18it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.9669
-- Starting eval on experience 1 (Task 1) from test stream --
100%|██████████| 21/21 [00:00<00:00, 26.49it/s]
> Eval on experience 1 (Task 1) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task001/Exp001 = 0.0000
-- Starting eval on experience 2 (Task 2) from test stream --
100%|██████████| 19/19 [00:00<00:00, 29.42it/s]
> Eval on ex

In [None]:
import json

# Collect everything the evaluator recorded so far and persist it as JSON
# in the current working directory.
metrics = strategy.evaluator.get_all_metrics()

with open("dynamic_metrics.json", mode="w") as metrics_file:
    json.dump(metrics, fp=metrics_file)