In [1]:
# """
# EWC: 
# "Overcoming catastrophic forgetting in neural networks" by Kirkpatrick et al. (2017).
# https://www.pnas.org/content/114/13/3521
# """
###################################################################################################################
# """
# "Learning without Forgetting" by Li et al. (2016).
# http://arxiv.org/abs/1606.09282
# Since the experimental setup of the paper is quite outdated and not
# easily reproducible, this experiment is based on
# "Three scenarios for continual learning" by van de Ven et al. (2018).
# https://arxiv.org/pdf/1904.07734.pdf

# Please note that the performance of LwF on Permuted MNIST is below the one achieved
# by Naive with the same configuration. This is compatible with the results presented
# by van de Ven et al. (2018).
# """
###################################################################################################################
# """
# "Continual Learning Through Synaptic Intelligence" by Zenke et al. (2017).
# http://proceedings.mlr.press/v70/zenke17a.html
# """

In [None]:
# !pip install avalanche-lib

In [2]:
import avalanche as avl
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, SGD
# from avalanche.evaluation import metrics as metrics

In [3]:
from types import SimpleNamespace
import torch
import numpy as np
import random

def set_seed(seed):
    """Seed the torch, torch.cuda, numpy and stdlib random generators.

    Args:
        seed: Integer seed applied to every generator. When ``None`` the
            function does nothing.

    When CUDA is available, cuDNN is additionally switched to its
    deterministic setting with benchmarking turned off.
    """
    if seed is not None:
        # Same seed for every RNG source, in a fixed order.
        seeders = (
            torch.manual_seed,
            torch.cuda.manual_seed,
            np.random.seed,
            random.seed,
        )
        for apply_seed in seeders:
            apply_seed(seed)

        if torch.cuda.is_available():
            cudnn = torch.backends.cudnn
            cudnn.deterministic = True
            cudnn.enabled = True
            cudnn.benchmark = False


def create_default_args(args_dict, additional_args=None):
    """Turn a dict of settings into an attribute-style namespace.

    Args:
        args_dict: Mapping of setting name -> value; every entry becomes an
            attribute of the returned namespace.
        additional_args: Optional mapping applied after ``args_dict``; its
            entries overwrite same-named defaults and add new ones.

    Returns:
        A ``SimpleNamespace`` holding the merged settings.
    """
    namespace = SimpleNamespace()
    attributes = vars(namespace)
    attributes.update(args_dict.items())
    if additional_args is not None:
        # Overrides go last so they win over the defaults.
        attributes.update(additional_args.items())
    return namespace

In [12]:
def run(args):
    """Train one continual-learning strategy on one benchmark and evaluate it.

    After training on each experience of the benchmark's train stream, the
    strategy is evaluated on the complete test stream.

    Args:
        args: Namespace-like object (e.g. built by ``create_default_args``).
            Fields read:
              * ``seed``: forwarded to ``set_seed`` and to the Permuted MNIST
                benchmark.
              * ``benchmark``: ``'pmnist'`` or ``'smnist'``.
              * ``multitask``: when truthy, SplitMNIST is asked for task ids
                and the model's ``classifier`` is wrapped via
                ``avl.models.as_multitask``.
              * ``model``: ``'mlp'``; with ``hidden_size``, ``hidden_layers``
                and ``dropout``.
              * ``strategy``: ``'ewc'``, ``'lwf'`` or ``'si'``; with
                ``learning_rate``, ``train_mb_size``, ``eval_mb_size``,
                ``epochs`` and the strategy-specific fields
                (``ewc_lambda``/``ewc_mode``/``ewc_decay``,
                ``lwf_alpha``/``lwf_temperature``, ``si_lambda``/``si_eps``).

    Returns:
        list: one element per training experience, each being whatever
        ``cl_strategy.eval`` returned for the full test stream.

    Raises:
        ValueError: if ``args.benchmark``, ``args.model`` or ``args.strategy``
            is not one of the supported values.
    """
    set_seed(args.seed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if args.benchmark == 'pmnist':
        # Forward the experiment seed so the pixel permutations are the same
        # on every run. NOTE(review): relies on PermutedMNIST's keyword-only
        # `seed` argument - confirm against the installed Avalanche version.
        benchmark = avl.benchmarks.PermutedMNIST(10, seed=args.seed)
        num_classes = 10
        input_size = 784
    elif args.benchmark == 'smnist':
        benchmark = avl.benchmarks.SplitMNIST(
            n_experiences=5,  # continual learning; 5 experiences, 2 digits/classes per experience
            return_task_id=args.multitask,  # set to True for multi-task experiments
        )
        num_classes = 10
        input_size = 784
    # add other benchmarks as needed
    else:
        raise ValueError(
            f"Unsupported benchmark {args.benchmark!r}; expected 'pmnist' or 'smnist'."
        )

    if args.model == 'mlp':
        model = avl.models.SimpleMLP(
            num_classes=num_classes,
            input_size=input_size,
            hidden_size=args.hidden_size,
            hidden_layers=args.hidden_layers,
            drop_rate=args.dropout,
        )
    # add other models as needed
    else:
        raise ValueError(f"Unsupported model {args.model!r}; expected 'mlp'.")

    if args.multitask:
        model = avl.models.as_multitask(model, 'classifier')

    criterion = CrossEntropyLoss()

    interactive_logger = avl.logging.InteractiveLogger()

    # Accuracy, forgetting and timing metrics; bwt_metrics / ram_usage_metrics
    # from avl.evaluation.metrics can be appended here as needed.
    evaluation_plugin = avl.training.plugins.EvaluationPlugin(
        avl.evaluation.metrics.accuracy_metrics(epoch=True, experience=True, stream=True),
        avl.evaluation.metrics.forgetting_metrics(experience=True, stream=True),
        avl.evaluation.metrics.timing_metrics(
            minibatch=False, epoch=True, epoch_running=True, experience=True, stream=True
        ),
        loggers=[interactive_logger],
    )

    # Keyword arguments shared by every strategy below.
    common_kwargs = dict(
        train_mb_size=args.train_mb_size,
        train_epochs=args.epochs,
        eval_mb_size=args.eval_mb_size,
        device=device,
        evaluator=evaluation_plugin,
    )

    if args.strategy == 'ewc':
        cl_strategy = avl.training.EWC(
            model,
            SGD(model.parameters(), lr=args.learning_rate),
            criterion,
            ewc_lambda=args.ewc_lambda,
            mode=args.ewc_mode,
            decay_factor=args.ewc_decay,
            **common_kwargs,
        )
    elif args.strategy == 'lwf':
        # LwFCEPenalty is defined in a later notebook cell; it only has to
        # exist by the time run() is called with strategy='lwf'.
        cl_strategy = LwFCEPenalty(
            model,
            Adam(model.parameters(), lr=args.learning_rate),
            criterion,
            alpha=args.lwf_alpha,
            temperature=args.lwf_temperature,
            **common_kwargs,
        )
    elif args.strategy == 'si':
        cl_strategy = avl.training.SynapticIntelligence(
            model,
            Adam(model.parameters(), lr=args.learning_rate),
            criterion,
            si_lambda=args.si_lambda,
            eps=args.si_eps,
            **common_kwargs,
        )
    # add other learning strategies as needed
    else:
        raise ValueError(
            f"Unsupported strategy {args.strategy!r}; expected 'ewc', 'lwf' or 'si'."
        )

    print("Starting experiment...")
    results = []
    for experience in benchmark.train_stream:
        print("Start training on experience ", experience.current_experience)
        cl_strategy.train(experience)
        print("Computing accuracy on the test set")
        # Evaluate on every test experience, including ones not trained on yet.
        results.append(cl_strategy.eval(benchmark.test_stream[:]))

    return results

In [5]:
# EWC experiments

In [6]:
# EWC on Permuted MNIST, multitask disabled.
args_ewc_pmnist = create_default_args({
    # --- benchmark ---
    "benchmark": "pmnist",
    "multitask": False,

    # --- model and optimisation ---
    "model": "mlp",
    "learning_rate": 0.001,
    "train_mb_size": 128,
    "eval_mb_size": 128,
    "hidden_size": 500,
    "hidden_layers": 2,
    "epochs": 5,
    "dropout": 0,
    "seed": 0,

    # --- strategy ('ewc_mode' can alternatively be set to 'online') ---
    "strategy": "ewc",
    "ewc_lambda": 1,
    "ewc_mode": "separate",
    "ewc_decay": None,
})
results_ewc_pmnist = run(args_ewc_pmnist)

Starting experiment...
Start training on experience  0
-- >> Start of training phase << --


  "No benchmark provided to the evaluation plugin. "


100%|█████████████████████████████████████████| 469/469 [00:38<00:00, 12.23it/s]
Epoch 0 ended.
	RunningTime_Epoch/train_phase/train_stream/Task000 = 0.0002
	Time_Epoch/train_phase/train_stream/Task000 = 38.3421
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.3730
100%|█████████████████████████████████████████| 469/469 [00:36<00:00, 13.02it/s]
Epoch 1 ended.
	RunningTime_Epoch/train_phase/train_stream/Task000 = 0.0002
	Time_Epoch/train_phase/train_stream/Task000 = 36.0096
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.6493
100%|█████████████████████████████████████████| 469/469 [00:36<00:00, 12.97it/s]
Epoch 2 ended.
	RunningTime_Epoch/train_phase/train_stream/Task000 = 0.0002
	Time_Epoch/train_phase/train_stream/Task000 = 36.1483
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.7214
100%|█████████████████████████████████████████| 469/469 [00:35<00:00, 13.06it/s]
Epoch 3 ended.
	RunningTime_Epoch/train_phase/train_stream/Task000 = 0.0002
	Time_Epoch/train_phase/train_str

100%|███████████████████████████████████████████| 79/79 [00:05<00:00, 14.39it/s]
> Eval on experience 4 (Task 4) from test stream ended.
	Time_Exp/eval_phase/test_stream/Task004/Exp004 = 5.4912
	Top1_Acc_Exp/eval_phase/test_stream/Task004/Exp004 = 0.1051
-- Starting eval on experience 5 (Task 5) from test stream --
100%|███████████████████████████████████████████| 79/79 [00:05<00:00, 14.29it/s]
> Eval on experience 5 (Task 5) from test stream ended.
	Time_Exp/eval_phase/test_stream/Task005/Exp005 = 5.5285
	Top1_Acc_Exp/eval_phase/test_stream/Task005/Exp005 = 0.1171
-- Starting eval on experience 6 (Task 6) from test stream --
100%|███████████████████████████████████████████| 79/79 [00:05<00:00, 14.49it/s]
> Eval on experience 6 (Task 6) from test stream ended.
	Time_Exp/eval_phase/test_stream/Task006/Exp006 = 5.4528
	Top1_Acc_Exp/eval_phase/test_stream/Task006/Exp006 = 0.1057
-- Starting eval on experience 7 (Task 7) from test stream --
100%|███████████████████████████████████████████|

100%|█████████████████████████████████████████| 469/469 [00:37<00:00, 12.64it/s]
Epoch 0 ended.
	RunningTime_Epoch/train_phase/train_stream/Task003 = 0.0002
	Time_Epoch/train_phase/train_stream/Task003 = 37.1006
	Top1_Acc_Epoch/train_phase/train_stream/Task003 = 0.6235
100%|█████████████████████████████████████████| 469/469 [00:37<00:00, 12.67it/s]
Epoch 1 ended.
	RunningTime_Epoch/train_phase/train_stream/Task003 = 0.0002
	Time_Epoch/train_phase/train_stream/Task003 = 37.0088
	Top1_Acc_Epoch/train_phase/train_stream/Task003 = 0.8072
100%|█████████████████████████████████████████| 469/469 [00:39<00:00, 11.90it/s]
Epoch 2 ended.
	RunningTime_Epoch/train_phase/train_stream/Task003 = 0.0002
	Time_Epoch/train_phase/train_stream/Task003 = 39.3996
	Top1_Acc_Epoch/train_phase/train_stream/Task003 = 0.8438
100%|█████████████████████████████████████████| 469/469 [00:36<00:00, 12.70it/s]
Epoch 3 ended.
	RunningTime_Epoch/train_phase/train_stream/Task003 = 0.0002
	Time_Epoch/train_phase/train_str

-- Starting eval on experience 3 (Task 3) from test stream --
100%|███████████████████████████████████████████| 79/79 [00:06<00:00, 13.12it/s]
> Eval on experience 3 (Task 3) from test stream ended.
	ExperienceForgetting/eval_phase/test_stream/Task003/Exp003 = 0.0278
	Time_Exp/eval_phase/test_stream/Task003/Exp003 = 6.0212
	Top1_Acc_Exp/eval_phase/test_stream/Task003/Exp003 = 0.8512
-- Starting eval on experience 4 (Task 4) from test stream --
100%|███████████████████████████████████████████| 79/79 [00:05<00:00, 13.83it/s]
> Eval on experience 4 (Task 4) from test stream ended.
	Time_Exp/eval_phase/test_stream/Task004/Exp004 = 5.7116
	Top1_Acc_Exp/eval_phase/test_stream/Task004/Exp004 = 0.8868
-- Starting eval on experience 5 (Task 5) from test stream --
100%|███████████████████████████████████████████| 79/79 [00:05<00:00, 13.62it/s]
> Eval on experience 5 (Task 5) from test stream ended.
	Time_Exp/eval_phase/test_stream/Task005/Exp005 = 5.7990
	Top1_Acc_Exp/eval_phase/test_stream/Task

	Top1_Acc_Stream/eval_phase/test_stream/Task001 = 0.8077
	Top1_Acc_Stream/eval_phase/test_stream/Task002 = 0.8399
	Top1_Acc_Stream/eval_phase/test_stream/Task003 = 0.8507
	Top1_Acc_Stream/eval_phase/test_stream/Task004 = 0.8658
	Top1_Acc_Stream/eval_phase/test_stream/Task005 = 0.8935
	Top1_Acc_Stream/eval_phase/test_stream/Task006 = 0.1122
	Top1_Acc_Stream/eval_phase/test_stream/Task007 = 0.1523
	Top1_Acc_Stream/eval_phase/test_stream/Task008 = 0.1051
	Top1_Acc_Stream/eval_phase/test_stream/Task009 = 0.1120
Start training on experience  6
-- >> Start of training phase << --
100%|█████████████████████████████████████████| 469/469 [00:39<00:00, 11.77it/s]
Epoch 0 ended.
	RunningTime_Epoch/train_phase/train_stream/Task006 = 0.0002
	Time_Epoch/train_phase/train_stream/Task006 = 39.8444
	Top1_Acc_Epoch/train_phase/train_stream/Task006 = 0.6395
100%|█████████████████████████████████████████| 469/469 [00:39<00:00, 11.94it/s]
Epoch 1 ended.
	RunningTime_Epoch/train_phase/train_stream/Task006 =

-- Starting eval on experience 1 (Task 1) from test stream --
100%|███████████████████████████████████████████| 79/79 [00:05<00:00, 15.04it/s]
> Eval on experience 1 (Task 1) from test stream ended.
	ExperienceForgetting/eval_phase/test_stream/Task001/Exp001 = 0.1231
	Time_Exp/eval_phase/test_stream/Task001/Exp001 = 5.2543
	Top1_Acc_Exp/eval_phase/test_stream/Task001/Exp001 = 0.7416
-- Starting eval on experience 2 (Task 2) from test stream --
100%|███████████████████████████████████████████| 79/79 [00:05<00:00, 15.03it/s]
> Eval on experience 2 (Task 2) from test stream ended.
	ExperienceForgetting/eval_phase/test_stream/Task002/Exp002 = 0.1114
	Time_Exp/eval_phase/test_stream/Task002/Exp002 = 5.2577
	Top1_Acc_Exp/eval_phase/test_stream/Task002/Exp002 = 0.7683
-- Starting eval on experience 3 (Task 3) from test stream --
100%|███████████████████████████████████████████| 79/79 [00:05<00:00, 15.10it/s]
> Eval on experience 3 (Task 3) from test stream ended.
	ExperienceForgetting/eval_ph

100%|███████████████████████████████████████████| 79/79 [00:05<00:00, 15.27it/s]
> Eval on experience 7 (Task 7) from test stream ended.
	ExperienceForgetting/eval_phase/test_stream/Task007/Exp007 = 0.0110
	Time_Exp/eval_phase/test_stream/Task007/Exp007 = 5.1745
	Top1_Acc_Exp/eval_phase/test_stream/Task007/Exp007 = 0.8848
-- Starting eval on experience 8 (Task 8) from test stream --
100%|███████████████████████████████████████████| 79/79 [00:05<00:00, 15.12it/s]
> Eval on experience 8 (Task 8) from test stream ended.
	Time_Exp/eval_phase/test_stream/Task008/Exp008 = 5.2257
	Top1_Acc_Exp/eval_phase/test_stream/Task008/Exp008 = 0.8966
-- Starting eval on experience 9 (Task 9) from test stream --
100%|███████████████████████████████████████████| 79/79 [00:05<00:00, 15.24it/s]
> Eval on experience 9 (Task 9) from test stream ended.
	Time_Exp/eval_phase/test_stream/Task009/Exp009 = 5.1835
	Top1_Acc_Exp/eval_phase/test_stream/Task009/Exp009 = 0.1379
-- >> End of eval phase << --
	StreamForge

In [7]:
# EWC on Split MNIST, multitask enabled.
args_ewc_smnist = create_default_args({
    # --- benchmark ---
    "benchmark": "smnist",
    "multitask": True,

    # --- model and optimisation ---
    "model": "mlp",
    "learning_rate": 0.001,
    "train_mb_size": 128,
    "eval_mb_size": 128,
    "hidden_size": 500,
    "hidden_layers": 2,
    "epochs": 5,
    "dropout": 0,
    "seed": 0,

    # --- strategy ('ewc_mode' can alternatively be set to 'online') ---
    "strategy": "ewc",
    "ewc_lambda": 1,
    "ewc_mode": "separate",
    "ewc_decay": None,
})
results_ewc_smnist = run(args_ewc_smnist)

Starting experiment...
Start training on experience  0
-- >> Start of training phase << --
100%|███████████████████████████████████████████| 99/99 [00:03<00:00, 27.29it/s]
Epoch 0 ended.
	RunningTime_Epoch/train_phase/train_stream/Task000 = 0.0004
	Time_Epoch/train_phase/train_stream/Task000 = 3.6274
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8166
100%|███████████████████████████████████████████| 99/99 [00:03<00:00, 27.82it/s]
Epoch 1 ended.
	RunningTime_Epoch/train_phase/train_stream/Task000 = 0.0004
	Time_Epoch/train_phase/train_stream/Task000 = 3.5592
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9746
100%|███████████████████████████████████████████| 99/99 [00:03<00:00, 27.16it/s]
Epoch 2 ended.
	RunningTime_Epoch/train_phase/train_stream/Task000 = 0.0004
	Time_Epoch/train_phase/train_stream/Task000 = 3.6450
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9816
100%|███████████████████████████████████████████| 99/99 [00:03<00:00, 27.27it/s]
Epoch 3 ended.
	Runni

	Top1_Acc_Epoch/train_phase/train_stream/Task002 = 0.9511
100%|███████████████████████████████████████████| 95/95 [00:03<00:00, 25.11it/s]
Epoch 4 ended.
	RunningTime_Epoch/train_phase/train_stream/Task002 = 0.0004
	Time_Epoch/train_phase/train_stream/Task002 = 3.7830
	Top1_Acc_Epoch/train_phase/train_stream/Task002 = 0.9549
-- >> End of training phase << --
Computing accuracy on the test set
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|███████████████████████████████████████████| 17/17 [00:00<00:00, 31.49it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	ExperienceForgetting/eval_phase/test_stream/Task000/Exp000 = -0.0047
	Time_Exp/eval_phase/test_stream/Task000/Exp000 = 0.5411
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.9934
-- Starting eval on experience 1 (Task 1) from test stream --
100%|███████████████████████████████████████████| 15/15 [00:00<00:00, 30.11it/s]
> Eval on experience 1 (Task 1) from test 

-- Starting eval on experience 1 (Task 1) from test stream --
100%|███████████████████████████████████████████| 15/15 [00:00<00:00, 29.79it/s]
> Eval on experience 1 (Task 1) from test stream ended.
	ExperienceForgetting/eval_phase/test_stream/Task001/Exp001 = 0.0198
	Time_Exp/eval_phase/test_stream/Task001/Exp001 = 0.5037
	Top1_Acc_Exp/eval_phase/test_stream/Task001/Exp001 = 0.9536
-- Starting eval on experience 2 (Task 2) from test stream --
100%|███████████████████████████████████████████| 16/16 [00:00<00:00, 29.71it/s]
> Eval on experience 2 (Task 2) from test stream ended.
	ExperienceForgetting/eval_phase/test_stream/Task002/Exp002 = 0.0005
	Time_Exp/eval_phase/test_stream/Task002/Exp002 = 0.5379
	Top1_Acc_Exp/eval_phase/test_stream/Task002/Exp002 = 0.9658
-- Starting eval on experience 3 (Task 3) from test stream --
100%|███████████████████████████████████████████| 16/16 [00:00<00:00, 30.67it/s]
> Eval on experience 3 (Task 3) from test stream ended.
	ExperienceForgetting/eval_ph

In [8]:
# LWF experiments

In [9]:
class LwFCEPenalty(avl.training.LwF):
    """LwF variant whose cross-entropy term is down-weighted over time.

    Follows "Three scenarios for continual learning" by
    van de Ven et. al. (2018), https://arxiv.org/pdf/1904.07734.pdf.
    The intended total loss is

        L_tot = (1 / n_exp_so_far) * L_cross_entropy
                + alpha[current_exp] * L_distillation
    """

    def _before_backward(self, **kwargs):
        """Rescale the current loss, then run the parent's hook."""
        # train_exp_counter is incremented by one so the divisor is never 0.
        # NOTE(review): presumably the distillation term is added by the
        # parent hook afterwards and is therefore not rescaled - confirm
        # against the installed Avalanche version.
        n_exp_so_far = self.clock.train_exp_counter + 1
        self.loss *= 1.0 / n_exp_so_far
        super()._before_backward(**kwargs)

In [None]:
# LwF on Permuted MNIST, multitask disabled.
args_lwf_pmnist = create_default_args({
    # --- benchmark ---
    "benchmark": "pmnist",
    "multitask": False,

    # --- model and optimisation ---
    "model": "mlp",
    "learning_rate": 0.001,
    "train_mb_size": 128,
    "eval_mb_size": 128,
    "hidden_size": 500,
    "hidden_layers": 2,
    "epochs": 5,
    "dropout": 0,
    "seed": 0,

    # --- strategy ---
    "strategy": "lwf",
    # Penalty hyperparameter for LwF. It can be either a list with multiple
    # elements (one alpha per experience) or a list of one element (same
    # alpha for all experiences). Here: 0 followed by 1 - 1/i for i = 2..10.
    "lwf_alpha": [0.] + [1 - 1. / i for i in range(2, 11)],
    # Temperature for softmax used in distillation.
    "lwf_temperature": 2,
})

results_lwf_pmnist = run(args_lwf_pmnist)

Starting experiment...
Start training on experience  0
-- >> Start of training phase << --
100%|█████████████████████████████████████████| 469/469 [00:34<00:00, 13.42it/s]
Epoch 0 ended.
	RunningTime_Epoch/train_phase/train_stream/Task000 = 0.0002
	Time_Epoch/train_phase/train_stream/Task000 = 34.9426
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9335
100%|█████████████████████████████████████████| 469/469 [00:35<00:00, 13.34it/s]
Epoch 1 ended.
	RunningTime_Epoch/train_phase/train_stream/Task000 = 0.0002
	Time_Epoch/train_phase/train_stream/Task000 = 35.1577
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9736
100%|█████████████████████████████████████████| 469/469 [00:35<00:00, 13.38it/s]
Epoch 2 ended.
	RunningTime_Epoch/train_phase/train_stream/Task000 = 0.0002
	Time_Epoch/train_phase/train_stream/Task000 = 35.0638
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9817
100%|█████████████████████████████████████████| 469/469 [00:34<00:00, 13.54it/s]
Epoch 3 ended.
	Ru

	Top1_Acc_Exp/eval_phase/test_stream/Task003/Exp003 = 0.0610
-- Starting eval on experience 4 (Task 4) from test stream --
100%|███████████████████████████████████████████| 79/79 [00:05<00:00, 14.89it/s]
> Eval on experience 4 (Task 4) from test stream ended.
	Time_Exp/eval_phase/test_stream/Task004/Exp004 = 5.3053
	Top1_Acc_Exp/eval_phase/test_stream/Task004/Exp004 = 0.1131
-- Starting eval on experience 5 (Task 5) from test stream --
100%|███████████████████████████████████████████| 79/79 [00:05<00:00, 14.86it/s]
> Eval on experience 5 (Task 5) from test stream ended.
	Time_Exp/eval_phase/test_stream/Task005/Exp005 = 5.3157
	Top1_Acc_Exp/eval_phase/test_stream/Task005/Exp005 = 0.1388
-- Starting eval on experience 6 (Task 6) from test stream --
100%|███████████████████████████████████████████| 79/79 [00:05<00:00, 14.86it/s]
> Eval on experience 6 (Task 6) from test stream ended.
	Time_Exp/eval_phase/test_stream/Task006/Exp006 = 5.3180
	Top1_Acc_Exp/eval_phase/test_stream/Task006/Exp0

Start training on experience  3
-- >> Start of training phase << --
100%|█████████████████████████████████████████| 469/469 [00:35<00:00, 13.32it/s]
Epoch 0 ended.
	RunningTime_Epoch/train_phase/train_stream/Task003 = 0.0002
	Time_Epoch/train_phase/train_stream/Task003 = 35.2220
	Top1_Acc_Epoch/train_phase/train_stream/Task003 = 0.9060
100%|█████████████████████████████████████████| 469/469 [00:35<00:00, 13.40it/s]
Epoch 1 ended.
	RunningTime_Epoch/train_phase/train_stream/Task003 = 0.0002
	Time_Epoch/train_phase/train_stream/Task003 = 35.0050
	Top1_Acc_Epoch/train_phase/train_stream/Task003 = 0.9670
100%|█████████████████████████████████████████| 469/469 [00:34<00:00, 13.52it/s]
Epoch 2 ended.
	RunningTime_Epoch/train_phase/train_stream/Task003 = 0.0002
	Time_Epoch/train_phase/train_stream/Task003 = 34.6999
	Top1_Acc_Epoch/train_phase/train_stream/Task003 = 0.9792
 63%|█████████████████████████▊               | 295/469 [00:22<00:13, 13.15it/s]

In [None]:
# LwF on Split MNIST, multitask enabled.
args_lwf_smnist = create_default_args({
    # --- benchmark ---
    "benchmark": "smnist",
    "multitask": True,

    # --- model and optimisation ---
    "model": "mlp",
    "learning_rate": 0.001,
    "train_mb_size": 128,
    "eval_mb_size": 128,
    "hidden_size": 500,
    "hidden_layers": 2,
    "epochs": 5,
    "dropout": 0,
    "seed": 0,

    # --- strategy ---
    "strategy": "lwf",
    # Penalty hyperparameter for LwF. It can be either a list with multiple
    # elements (one alpha per experience) or a list of one element (same
    # alpha for all experiences). Here: 0 followed by 1 - 1/i for i = 2..5.
    "lwf_alpha": [0.] + [1 - 1. / i for i in range(2, 6)],
    # Temperature for softmax used in distillation.
    "lwf_temperature": 2,
})

results_lwf_smnist = run(args_lwf_smnist)

In [None]:
# SI experiments

In [None]:
# Synaptic Intelligence on Permuted MNIST, multitask disabled.
args_si_pmnist = create_default_args({
    # --- benchmark ---
    "benchmark": "pmnist",
    "multitask": False,

    # --- model and optimisation ---
    "model": "mlp",
    "learning_rate": 0.001,
    "train_mb_size": 128,
    "eval_mb_size": 128,
    "hidden_size": 500,
    "hidden_layers": 2,
    "epochs": 5,
    "dropout": 0,
    "seed": 0,

    # --- strategy ---
    "strategy": "si",
    "si_lambda": 0.1,
    "si_eps": 0.1,
})

results_si_pmnist = run(args_si_pmnist)

In [None]:
# Synaptic Intelligence on Split MNIST, multitask enabled.
args_si_smnist = create_default_args({
    # --- benchmark ---
    "benchmark": "smnist",
    "multitask": True,

    # --- model and optimisation ---
    "model": "mlp",
    "learning_rate": 0.001,
    "train_mb_size": 128,
    "eval_mb_size": 128,
    "hidden_size": 500,
    "hidden_layers": 2,
    "epochs": 5,
    "dropout": 0,
    "seed": 0,

    # --- strategy ---
    "strategy": "si",
    "si_lambda": 0.1,
    "si_eps": 0.1,
})

results_si_smnist = run(args_si_smnist)