In [None]:
!pip install avalanche-lib medmnist torch torchvision --quiet

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import transforms
from torch.utils.data import Dataset

import medmnist
from medmnist import PathMNIST

from avalanche.benchmarks import nc_benchmark
from avalanche.training import EWC
from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import (
    accuracy_metrics,
    forgetting_metrics
)
from avalanche.logging import InteractiveLogger


In [None]:
# Global configuration: RNG seed and compute device.
# SEED matches the class-order seed given to nc_benchmark, so weight
# initialisation (and any other draws from torch's global RNG) is repeatable
# under "Restart & Run All". CUDA kernels may still be non-deterministic.
SEED = 123
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [None]:
# Dataset wrapper
class MedMNISTWrapper(Dataset):
    """Wrap a MedMNIST split so every label is a plain Python ``int``.

    It also exposes a ``targets`` list with one class id per sample. Avalanche
    benchmark generators such as ``nc_benchmark`` presumably read this list to
    split the data by class.

    Only single-label datasets are supported. Multi-label arrays raise
    ``ValueError`` instead of failing later with an opaque error.

    Args:
        medmnist_dataset: a MedMNIST split. It must provide ``labels``
            (an array-like with ``ndim``/``shape``), ``__len__`` and
            ``__getitem__`` returning ``(image, label)``.
    """

    def __init__(self, medmnist_dataset):
        self.dataset = medmnist_dataset
        labels = medmnist_dataset.labels
        # Single-label MedMNIST targets are expected as an (N, 1) array. Drop
        # the trailing axis explicitly: squeeze() would also collapse N == 1
        # to a 0-d array, which cannot be iterated.
        if labels.ndim == 2 and labels.shape[1] == 1:
            labels = labels[:, 0]
        if labels.ndim != 1:
            raise ValueError(
                "MedMNISTWrapper expects one label per sample, got labels "
                f"with shape {labels.shape}"
            )
        self.targets = [int(label) for label in labels]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        x, y = self.dataset[index]
        # y arrives as a 1-element array. int() on an array with ndim > 0 is
        # deprecated in NumPy (the warning pointing at this line in the
        # training output). .item() extracts the scalar explicitly and also
        # accepts 0-d arrays, NumPy scalars and tensors.
        if hasattr(y, "item"):
            y = y.item()
        return x, int(y)

In [None]:
# Load PathMNIST (train and test splits).
# Only ToTensor() is applied here; no normalisation or augmentation.
transform = transforms.Compose([transforms.ToTensor()])

# Build the two raw splits in order: train first, then test.
train_raw, test_raw = (
    PathMNIST(split=split_name, transform=transform, download=True)
    for split_name in ("train", "test")
)

# Wrap each split so Avalanche receives int labels and a `targets` list.
train_dataset, test_dataset = MedMNISTWrapper(train_raw), MedMNISTWrapper(test_raw)


In [None]:
# Class-incremental benchmark: the classes are dealt into 3 experiences in an
# order shuffled with a fixed seed. A single shared head is used (no task labels).
benchmark = nc_benchmark(
    train_dataset,
    test_dataset,
    seed=123,
    shuffle=True,
    n_experiences=3,
    task_labels=False,
)


In [None]:
# Show how the classes were assigned to the training experiences.
print("Total classes:", benchmark.n_classes)
for exp in benchmark.train_stream:
    print(f"Experience {exp.current_experience}: {exp.classes_in_this_experience}")


Total classes: 9
Experience 0: [2, 4, 6]
Experience 1: [0, 3, 7]
Experience 2: [8, 1, 5]


In [None]:
# CNN model
class CNN(nn.Module):
    """Small two-stage convolutional classifier with a single linear head.

    Each stage is Conv3x3 (padding=1) -> BatchNorm -> ReLU -> MaxPool2d(2).
    The spatial size is therefore floor-divided by 2 twice before the head.

    Args:
        num_classes: number of output logits.
        in_channels: channels of the input images. Defaults to 3, the value
            that was originally hard-coded.
        input_size: height/width of the square input images. The default of
            28 reproduces the original hard-coded ``64 * 7 * 7`` head size.
            Pass the real size when using larger images.
    """

    def __init__(self, num_classes, in_channels=3, input_size=28):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Two MaxPool2d(2) layers floor-divide each spatial dim by 4 overall.
        feature_side = input_size // 4
        self.classifier = nn.Linear(64 * feature_side * feature_side, num_classes)

    def forward(self, x):
        x = self.features(x)
        # flatten() also works on non-contiguous tensors, unlike view().
        x = torch.flatten(x, 1)
        return self.classifier(x)

In [None]:
# Model, optimiser and loss for the continual-learning strategy.
# The commented-out SGD alternative was removed. Swap the optimiser here if
# you want to experiment with it.
LEARNING_RATE = 1e-3

model = CNN(num_classes=benchmark.n_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()


In [None]:
# Metrics to collect: accuracy per epoch / experience / whole stream, plus
# forgetting per experience / whole stream, shown by the interactive logger.
eval_plugin = EvaluationPlugin(
    accuracy_metrics(experience=True, stream=True, epoch=True),
    forgetting_metrics(stream=True, experience=True),
    loggers=[InteractiveLogger()],
)

In [None]:
# EWC strategy: standard training with the given criterion plus an EWC
# regularisation term scaled by ewc_lambda.
# NOTE(review): the saved training log shows 5 epochs per experience, while
# train_epochs=8 here. The stored outputs look stale; re-run to refresh them.
cl_strategy = EWC(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    ewc_lambda=5000,  # strength of the EWC regularisation
    train_epochs=8,
    train_mb_size=64,
    eval_mb_size=128,
    evaluator=eval_plugin,
    device=device,
)


In [None]:
# Train on the experiences one after another. After each one, evaluate on the
# full test stream so accuracy on earlier experiences (forgetting) is tracked.
for experience in benchmark.train_stream:
    print(f"Training on experience {experience.current_experience}")
    cl_strategy.train(experience)

    print("Evaluating...")
    cl_strategy.eval(benchmark.test_stream)

Training on experience 0
-- >> Start of training phase << --
0it [00:00, ?it/s]

  return x, int(y)


100%|██████████| 411/411 [00:08<00:00, 50.00it/s]
Epoch 0 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8136
100%|██████████| 411/411 [00:06<00:00, 66.54it/s]
Epoch 1 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8834
100%|██████████| 411/411 [00:06<00:00, 59.67it/s]
Epoch 2 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9106
100%|██████████| 411/411 [00:06<00:00, 61.15it/s]
Epoch 3 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9273
100%|██████████| 411/411 [00:06<00:00, 64.01it/s]
Epoch 4 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9371
-- >> End of training phase << --
Evaluating...
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 17/17 [00:00<00:00, 30.87it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.3835
-- Starting eval on experience 1 (Task 0) from test stream --
100%|██████████|

In [None]:
import json

METRICS_PATH = "ewc_metrics.json"


def _to_jsonable(value):
    """Fallback encoder for ``json.dump``.

    json calls this only for objects it cannot encode natively. Tensors and
    NumPy arrays/scalars are converted with ``tolist()``. Anything else still
    raises ``TypeError`` rather than being silently stringified.
    """
    if hasattr(value, "tolist"):  # torch.Tensor, numpy.ndarray, numpy scalars
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


# All metrics collected by the evaluator. Presumably this maps each full
# metric name to (x steps, values) lists; confirm against the Avalanche docs.
metrics = cl_strategy.evaluator.get_all_metrics()

with open(METRICS_PATH, "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=2, default=_to_jsonable)
