In [None]:
%load_ext autoreload
%autoreload 2
import torch
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
from avalanche.models import SimpleMLP, IncrementalClassifier
from avalanche.training.strategies import Naive, CWRStar, Replay, GDumb, Cumulative, LwF, GEM, AGEM, EWC, CoPE
from avalanche.benchmarks.classic import SplitMNIST
from avalanche.training.strategies import BaseStrategy
from avalanche.training.plugins import ReplayPlugin, EWCPlugin, GEMPlugin, GDumbPlugin
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
from avalanche.benchmarks.classic import SplitMNIST, SplitCIFAR10, SplitCIFAR100
from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics, \
    loss_metrics, timing_metrics, cpu_usage_metrics, confusion_matrix_metrics, disk_usage_metrics,ExperienceForgetting
from avalanche.models import SimpleMLP
from avalanche.logging import InteractiveLogger, TextLogger, TensorboardLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.strategies import Naive
from pl_bolts.models.self_supervised import SwAV
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim import SGD
from torchvision import transforms
from avalanche.training.strategies.icarl import ICaRL
import numpy as np
from avalanche.benchmarks.classic.ccifar100 import SplitCIFAR100
from avalanche.models import IcarlNet, make_icarl_net, initialize_icarl_net
from avalanche.training.plugins.lr_scheduling import LRSchedulerPlugin

from NeoCL.models.pretrained import PretrainedIncrementalClassifier, SSLIcarl
from NeoCL.plugins.sparse_ewc import SparseEWCPlugin
from NeoCL.strategies.utils import get_average_metric, create_default_args


# Augmentation for ICaRL exemplars (passed as buffer_transform when the strategy is built)
def icarl_cifar100_augment_data(img, padding=4):
    """Randomly crop (after zero-padding) and randomly flip an image tensor.

    The image is zero-padded by ``padding`` pixels on each spatial side. A
    crop with the image's own height and width is then taken at a uniformly
    random offset. With probability 1/2 the crop is mirrored along the last
    (width) axis.

    Args:
        img: Tensor of shape (C, H, W). It must be convertible with
            ``.numpy()``, i.e. a CPU tensor that does not require grad.
            Any H and W are supported.
        padding: Number of zero pixels added on each side before cropping.
            The default of 4 reproduces the usual 32x32 CIFAR augmentation.

    Returns:
        A float32 tensor with the same (C, H, W) shape as ``img``.
    """
    img = img.numpy()
    _, height, width = img.shape
    padded = np.pad(img, ((0, 0), (padding, padding), (padding, padding)),
                    mode='constant')
    # Offsets are drawn in [0, 2 * padding] so the crop always fits in `padded`.
    top, left = np.random.randint(0, high=2 * padding + 1, size=(2,))
    cropped = padded[:, top:top + height, left:left + width]
    # Flip when the coin comes up 0 (same draw order and convention as the
    # crop-then-flip scheme used for reproducibility under a fixed seed).
    if np.random.randint(2) == 0:
        cropped = cropped[:, :, ::-1]
    # Materialise a contiguous float32 copy: the flip produces negative strides.
    return torch.tensor(np.ascontiguousarray(cropped, dtype=np.float32))
# Class order handed to SplitCIFAR100 (via args.fixed_class_order): a fixed
# permutation of the 100 CIFAR-100 label ids, written ten ids per row.
fixed_class_order = [
    87,  0, 52, 58, 44, 91, 68, 97, 51, 15,
    94, 92, 10, 72, 49, 78, 61, 14,  8, 86,
    84, 96, 18, 24, 32, 45, 88, 11,  4, 67,
    69, 66, 77, 47, 79, 93, 29, 50, 57, 83,
    17, 81, 41, 12, 37, 59, 25, 20, 80, 73,
     1, 28,  6, 46, 62, 82, 53,  9, 31, 75,
    38, 63, 33, 74, 27, 22, 36,  3, 16, 21,
    60, 19, 70, 90, 89, 43,  5, 42, 65, 76,
    40, 30, 23, 85,  2, 95, 56, 48, 71, 64,
    98, 13, 99,  7, 34, 55, 54, 26, 35, 39,
]
# config (NOTE: memory_size==k, the iCaRL exemplar budget)
# NOTE(review): 'train_mb_size' (256) is not read anywhere in this cell; the
# ICaRL strategy below is built with train_mb_size=args.batch_size (128).
# Confirm which training minibatch size is intended.
args = create_default_args({'cuda': 0, 'batch_size': 128, 'nb_exp': 10,
                            'memory_size': 2000, 'epochs': 70, 'lr_base': 2.,
                            'lr_milestones': [49, 63], 'lr_factor': 5.,
                            'wght_decay': 0.00001, 'train_mb_size': 256,
                            'fixed_class_order': fixed_class_order, 'seed': 2222})
# Use GPU number args.cuda when CUDA is available and args.cuda >= 0,
# otherwise fall back to the CPU.
device = torch.device(f"cuda:{args.cuda}"
                      if torch.cuda.is_available() and
                         args.cuda >= 0 else "cpu")
# SwAV encoder checkpoint pre-trained on ImageNet (PyTorch Lightning Bolts).
# load_from_checkpoint is given a URL, so this presumably needs network access
# on first run — TODO confirm caching behaviour.
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar'
encoder = SwAV.load_from_checkpoint(weight_path, strict=True)
# Wrap the SSL encoder for iCaRL. SSLIcarl must expose `.feature_extractor`
# and `.classifier`, both handed to ICaRL below; embedding_size=2048 is
# assumed to match the SwAV encoder's output width — verify.
model = SSLIcarl(encoder,embedding_size=2048,num_classes=100).to(device)
# TensorBoard event files go to ../logs relative to the working directory.
tb_logger = TensorboardLogger('../logs')
# Accuracy and loss at minibatch/epoch/experience/stream granularity plus
# per-experience forgetting, reported only to the TensorBoard logger.
eval_plugin = EvaluationPlugin(
    accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    ExperienceForgetting(),
    loggers=[tb_logger])
# CIFAR-100 split into args.nb_exp experiences following fixed_class_order.
# NOTE(review): dataset_root is a machine-specific absolute path.
benchmark = SplitCIFAR100(n_experiences=args.nb_exp, seed=args.seed,
              fixed_class_order=args.fixed_class_order, dataset_root='/share/datasets/')


# SGD (momentum 0.9, weight decay args.wght_decay). MultiStepLR multiplies the
# LR by 1/args.lr_factor at each milestone in args.lr_milestones; milestones
# count scheduler steps — presumably one per epoch via LRSchedulerPlugin,
# confirm against the Avalanche version in use.
optim = SGD(model.parameters(), lr=args.lr_base,
            weight_decay=args.wght_decay, momentum=0.9)
sched = LRSchedulerPlugin(
    MultiStepLR(optim, args.lr_milestones, gamma=1.0 / args.lr_factor))

# iCaRL on top of the pre-trained encoder. buffer_transform presumably
# augments replay-buffer exemplars; the same batch size is used for training
# and evaluation.
strategy = ICaRL(
    model.feature_extractor, model.classifier, optim,
    args.memory_size,
    buffer_transform=transforms.Compose([icarl_cifar100_augment_data]),
    fixed_memory=True, train_mb_size=args.batch_size,
    train_epochs=args.epochs, eval_mb_size=args.batch_size,
    plugins=[sched], device=device, evaluator=eval_plugin
)

# Train on each experience in order. After every step, evaluate on all test
# experiences seen so far and print the running average incremental accuracy
# (the mean of the per-step stream accuracies collected in dict_iCaRL_aia).
print('Starting experiment...')
dict_iCaRL_aia = {}
for i, experience in enumerate(benchmark.train_stream):
    print("Start training on experience ", experience.current_experience)
    strategy.train(experience, num_workers=4)
    print("End training on experience ", experience.current_experience)

    print('Computing accuracy on the test set')
    seen_test_stream = benchmark.test_stream[:i + 1]
    results = strategy.eval(seen_test_stream, num_workers=4)
    stream_acc = results['Top1_Acc_Stream/eval_phase/test_stream/Task000']
    dict_iCaRL_aia[f'Top1_Acc_Stream/Exp{i}'] = stream_acc

    avg_ia = get_average_metric(dict_iCaRL_aia)
    print("dict_iCaRL_aia= ", dict_iCaRL_aia)
    print(f"scifar100-batch=10 Average Incremental Accuracy: {avg_ia:.5f}")
    



Files already downloaded and verified
Files already downloaded and verified
Starting experiment...
Start training on experience  0
End training on experience  0
Computing accuracy on the test set
dict_iCaRL_aia=  {'Top1_Acc_Stream/Exp0': 0.923}
scifar100-batch=10 Average Incremental Accuracy: 0.92300
Start training on experience  1
End training on experience  1
Computing accuracy on the test set
dict_iCaRL_aia=  {'Top1_Acc_Stream/Exp0': 0.923, 'Top1_Acc_Stream/Exp1': 0.85}
scifar100-batch=10 Average Incremental Accuracy: 0.88650
Start training on experience  2
End training on experience  2
Computing accuracy on the test set
dict_iCaRL_aia=  {'Top1_Acc_Stream/Exp0': 0.923, 'Top1_Acc_Stream/Exp1': 0.85, 'Top1_Acc_Stream/Exp2': 0.7843333333333333}
scifar100-batch=10 Average Incremental Accuracy: 0.85244
Start training on experience  3


In [None]:
       # Dict to iCaRL Evaluation Protocol: Average Incremental Accuracy
        dict_iCaRL_aia = {}
        # ___________________________________________train and eval
        for i, exp in enumerate(benchmark.train_stream):
            strategy.train(exp, num_workers=4)
            res = strategy.eval(benchmark.test_stream[:i + 1], num_workers=4)
            dict_iCaRL_aia['Top1_Acc_Stream/Exp'+str(i)] = res['Top1_Acc_Stream/eval_phase/test_stream/Task000']

        avg_ia = get_average_metric(dict_iCaRL_aia)
        target_acc = get_target_result('iCaRL', 'scifar100')
        print("dict_iCaRL_aia= ", dict_iCaRL_aia)
        print(f"scifar100-batch=10 Average Incremental Accuracy: {avg_ia:.5f}")