In [1]:
import torch
import random
import torchvision
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torchvision import transforms
from torch.optim.lr_scheduler import MultiStepLR

from avalanche.models import IcarlNet
from avalanche.training.supervised import GEM, AGEM
from avalanche.logging import InteractiveLogger, WandBLogger
from avalanche.benchmarks.classic import SplitCIFAR10
from avalanche.benchmarks.datasets import CIFAR10
from avalanche.benchmarks.generators import nc_benchmark
from avalanche.benchmarks.utils import AvalancheDataset
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.plugins.lr_scheduling import LRSchedulerPlugin
from avalanche.evaluation.metrics import ExperienceAccuracy, ExperienceLoss, ExperienceForgetting, ExperienceCPUUsage, ExperienceMaxGPU, ExperienceMaxRAM, ExperienceTime, EpochAccuracy

In [2]:
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# torch.cuda.manual_seed_all(seed) # if use multi-GPU
cudnn.deterministic = True  # 연산 처리 속도 감소 -> 모델과 코드를 배포해야 하는 연구 후반 단계에 사용
cudnn.benchmark = False

In [4]:
transforms_group = dict(
       train=(
       transforms.Compose(
              [
              transforms.ToTensor(),
              ]
       ),
       None,
       ),
       eval=(
       transforms.Compose(
              [
              transforms.ToTensor(),
              ]
       ),
       None,
       )
)

train_set = CIFAR10('/home/data/cifar10', train=True, download=True)
test_set = CIFAR10('/home/data/cifar10', train=False, download=True)

train_set = AvalancheDataset(train_set, transform_groups=transforms_group, initial_transform_group="train")
test_set = AvalancheDataset(test_set, transform_groups=transforms_group, initial_transform_group="eval")

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /home/data/cifar10/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting /home/data/cifar10/cifar-10-python.tar.gz to /home/data/cifar10
Files already downloaded and verified


In [6]:
interactive_logger = InteractiveLogger()
wandb_logger = WandBLogger(run_name="AGEM-CIFAR10")
eval_plugin = EvaluationPlugin(
    EpochAccuracy(),
    ExperienceAccuracy(),
    ExperienceLoss(),
    ExperienceForgetting(),
    ExperienceCPUUsage(),
    ExperienceMaxGPU(gpu_id=0),
    ExperienceMaxRAM(),
    ExperienceTime(),
    loggers=[interactive_logger, wandb_logger])

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…



In [9]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

num_class = 10
incremental = 5
fixed_class_order = [4, 1, 7, 5, 3, 9, 0, 8, 6, 2]

scenario = nc_benchmark(train_dataset=train_set,
                        test_dataset=test_set,
                        n_experiences=incremental,
                        task_labels=True,
                        seed=seed,
                        shuffle=False,
                        fixed_class_order=fixed_class_order,
                        )

# model = torchvision.models.resnet18(pretrained=False, num_classes=num_class)
model = torchvision.models.resnet18(pretrained=False)
model.to(device)

optimizer = optim.SGD(model.parameters(), lr=1e-1)
criterion = torch.nn.CrossEntropyLoss()

## A-GEM

In [None]:
train_batch = 256
eval_batch = 128
epoch = 70

strategies = AGEM(model, optimizer, criterion, patterns_per_exp=256, sample_size=256, train_epochs=epoch, device=device, train_mb_size=10, evaluator=eval_plugin)  # criterion = ICaRLLossPlugin()

In [None]:

for i, exp in enumerate(scenario.train_stream):
    eval_exps = [e for e in scenario.test_stream][: i + 1]
    strategies.train(exp)
    strategies.eval(eval_exps)

-- >> Start of training phase << --
1343it [00:37, 35.70it/s]                          
Epoch 0 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8148
100%|██████████| 1000/1000 [00:18<00:00, 54.82it/s]
Epoch 1 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8775
 70%|███████   | 703/1000 [00:13<00:05, 52.39it/s]