In [None]:
# """
# EWC:
# "Overcoming catastrophic forgetting in neural networks" by Kirkpatrick et al. (2017).
# https://www.pnas.org/content/114/13/3521
# """
###################################################################################################################
# """
# LwF:
# "Learning without Forgetting" by Li et al. (2016).
# http://arxiv.org/abs/1606.09282
# Since the experimental setup of the original paper is outdated and not
# easily reproducible, this experiment is based on
# "Three scenarios for continual learning" by van de Ven et al. (2018).
# https://arxiv.org/pdf/1904.07734.pdf
#
# Note that the performance of LwF on Permuted MNIST is below that achieved
# by the Naive baseline with the same configuration. This is consistent with
# the results reported by van de Ven et al. (2018).
# """
###################################################################################################################
# """
# SI:
# "Continual Learning Through Synaptic Intelligence" by Zenke et al. (2017).
# http://proceedings.mlr.press/v70/zenke17a.html
# """

In [None]:
# !pip install avalanche-lib

In [None]:
import avalanche as avl
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, SGD
# from avalanche.evaluation import metrics as metrics

In [None]:
from types import SimpleNamespace
import torch
import numpy as np
import random

def set_seed(seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch for reproducibility.

    Args:
        seed: integer seed, or None to leave every RNG untouched.

    When CUDA is available, cuDNN is also switched to deterministic mode
    (and autotuning is disabled) so GPU runs are repeatable as well.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        if torch.cuda.is_available():
            cudnn = torch.backends.cudnn
            cudnn.deterministic = True
            cudnn.enabled = True
            # Autotuning picks kernels non-deterministically; turn it off.
            cudnn.benchmark = False


def create_default_args(args_dict, additional_args=None):
    """Turn a config dict into an attribute-style namespace.

    Args:
        args_dict: base configuration mapping (key -> value).
        additional_args: optional mapping whose entries are applied after
            `args_dict`, overriding any keys they share.

    Returns:
        A `SimpleNamespace` exposing every key as an attribute.
    """
    merged = dict(args_dict)
    if additional_args is not None:
        merged.update(additional_args)
    namespace = SimpleNamespace()
    vars(namespace).update(merged)
    return namespace

In [None]:
def _make_benchmark(args):
    """Build the benchmark named by `args.benchmark`.

    Returns:
        (benchmark, num_classes, input_size)

    Raises:
        ValueError: if `args.benchmark` is not supported.
    """
    if args.benchmark == 'pmnist':
        # PermutedMNIST draws its pixel permutations from an RNG seeded by its
        # own `seed` argument (default None), so the global seeding done in
        # `set_seed` is not enough: pass the seed explicitly.
        benchmark = avl.benchmarks.PermutedMNIST(
            10,
            return_task_id=args.multitask,
            seed=args.seed,
        )
        return benchmark, 10, 784
    if args.benchmark == 'smnist':
        benchmark = avl.benchmarks.SplitMNIST(
            n_experiences=5,  # continual learning; 5 experiences, 2 digits/classes per experience
            return_task_id=args.multitask,  # set to True for multi-task experiments
            seed=args.seed,  # explicit seed for a reproducible class order
        )
        return benchmark, 10, 784
    # add other benchmarks as needed
    raise ValueError(f"Unsupported benchmark: {args.benchmark!r}")


def _make_model(args, num_classes, input_size):
    """Build the network named by `args.model`, made multi-head if `args.multitask`.

    Raises:
        ValueError: if `args.model` is not supported.
    """
    if args.model == 'mlp':
        model = avl.models.SimpleMLP(
            num_classes=num_classes,
            input_size=input_size,
            hidden_size=args.hidden_size,
            hidden_layers=args.hidden_layers,
            drop_rate=args.dropout,
        )
    # add other models as needed
    else:
        raise ValueError(f"Unsupported model: {args.model!r}")

    if args.multitask:
        # Replace the 'classifier' layer with a task-aware (multi-head) one.
        model = avl.models.as_multitask(model, 'classifier')
    return model


def _make_evaluator():
    """Build the evaluation plugin (accuracy, forgetting, timing) with an interactive logger."""
    interactive_logger = avl.logging.InteractiveLogger()
    return avl.training.plugins.EvaluationPlugin(
        avl.evaluation.metrics.accuracy_metrics(epoch=True, experience=True, stream=True),
        avl.evaluation.metrics.forgetting_metrics(experience=True, stream=True),
        # avl.evaluation.metrics.bwt_metrics(experience=True, stream=True),
        # avl.evaluation.metrics.ram_usage_metrics(every=1, minibatch=True, epoch=True, experience=True, stream=True),
        avl.evaluation.metrics.timing_metrics(minibatch=False, epoch=True, epoch_running=True, experience=True, stream=True),
        loggers=[interactive_logger])
    # add other evaluation metrics as needed


def _make_strategy(args, model, criterion, device, evaluator):
    """Build the continual-learning strategy named by `args.strategy`.

    The optimizer is created here (SGD for EWC, Adam for LwF and SI) so it
    wraps the parameters of the final, possibly multi-head, `model`.
    Strategy 'lwf' requires the `LwFCEPenalty` class (defined in a later
    cell of this notebook) to have been executed first.

    Raises:
        ValueError: if `args.strategy` is not supported.
    """
    common = dict(
        train_mb_size=args.train_mb_size,
        train_epochs=args.epochs,
        eval_mb_size=args.eval_mb_size,
        device=device,
        evaluator=evaluator,
    )
    if args.strategy == 'ewc':
        return avl.training.EWC(
            model,
            SGD(model.parameters(), lr=args.learning_rate),
            criterion,
            ewc_lambda=args.ewc_lambda,
            mode=args.ewc_mode,
            decay_factor=args.ewc_decay,
            **common,
        )
    if args.strategy == 'lwf':
        return LwFCEPenalty(
            model,
            Adam(model.parameters(), lr=args.learning_rate),
            criterion,
            alpha=args.lwf_alpha,
            temperature=args.lwf_temperature,
            **common,
        )
    if args.strategy == 'si':
        return avl.training.SynapticIntelligence(
            model,
            Adam(model.parameters(), lr=args.learning_rate),
            criterion,
            si_lambda=args.si_lambda,
            eps=args.si_eps,
            **common,
        )
    # add other learning strategies as needed
    raise ValueError(f"Unsupported strategy: {args.strategy!r}")


def run(args):
    """Train a continual-learning strategy experience by experience and evaluate after each.

    Args:
        args: namespace (e.g. from `create_default_args`) providing `seed`,
            `benchmark` ('pmnist' | 'smnist'), `multitask`, `model` ('mlp'),
            `hidden_size`, `hidden_layers`, `dropout`, `learning_rate`,
            `train_mb_size`, `eval_mb_size`, `epochs`, `strategy`
            ('ewc' | 'lwf' | 'si') and that strategy's hyper-parameters
            (`ewc_lambda`/`ewc_mode`/`ewc_decay`, `lwf_alpha`/`lwf_temperature`,
            or `si_lambda`/`si_eps`).

    Returns:
        A list with one entry per training experience: the value returned by
        `cl_strategy.eval` on the whole test stream after training on that
        experience.

    Raises:
        ValueError: if the benchmark, model or strategy name is not supported.
    """
    set_seed(args.seed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    benchmark, num_classes, input_size = _make_benchmark(args)
    model = _make_model(args, num_classes, input_size)
    criterion = CrossEntropyLoss()
    evaluator = _make_evaluator()
    cl_strategy = _make_strategy(args, model, criterion, device, evaluator)

    print("Starting experiment...")
    results = []
    for experience in benchmark.train_stream:
        print("Start training on experience ", experience.current_experience)
        cl_strategy.train(experience)
        print("Computing accuracy on the test set")
        results.append(cl_strategy.eval(benchmark.test_stream[:]))

    return results

In [None]:
# EWC experiments

In [None]:
# EWC on Permuted MNIST, single-head (multitask=False).
args_ewc_pmnist = create_default_args(dict(
    # benchmark
    benchmark='pmnist',
    multitask=False,
    # model and optimisation
    model='mlp',
    learning_rate=0.001,
    train_mb_size=128,
    eval_mb_size=128,
    hidden_size=500,
    hidden_layers=2,
    epochs=1,
    dropout=0,
    seed=0,
    # strategy; 'online' is the alternative ewc_mode.
    # ewc_decay is forwarded to EWC as `decay_factor`.
    strategy='ewc',
    ewc_lambda=1,
    ewc_mode='separate',
    ewc_decay=None,
))
results_ewc_pmnist = run(args_ewc_pmnist)

In [None]:
# EWC on Split MNIST, multi-head (task ids from the benchmark, multi-head classifier).
args_ewc_smnist = create_default_args(dict(
    # benchmark
    benchmark='smnist',
    multitask=True,
    # model and optimisation
    model='mlp',
    learning_rate=0.001,
    train_mb_size=128,
    eval_mb_size=128,
    hidden_size=500,
    hidden_layers=2,
    epochs=1,
    dropout=0,
    seed=0,
    # strategy; 'online' is the alternative ewc_mode.
    # ewc_decay is forwarded to EWC as `decay_factor`.
    strategy='ewc',
    ewc_lambda=1,
    ewc_mode='separate',
    ewc_decay=None,
))
results_ewc_smnist = run(args_ewc_smnist)

In [None]:
# LwF experiments

In [None]:
class LwFCEPenalty(avl.training.LwF):
    """LwF variant whose cross-entropy term shrinks as more experiences are seen.

    Implements the loss weighting of "Three scenarios for continual learning"
    by van de Ven et al. (2018), https://arxiv.org/pdf/1904.07734.pdf:

        L_tot = (1 / n_exp_so_far) * L_cross_entropy
                + alpha[current_exp] * L_distillation

    Only the cross-entropy rescaling happens here, just before the parent's
    `_before_backward` runs. The distillation term is presumably added by the
    parent LwF machinery at that point (not visible here — confirm against the
    installed Avalanche version).
    """

    def _before_backward(self, **kwargs):
        # train_exp_counter is assumed 0-based (hence +1): n_seen counts the
        # experiences trained so far, including the current one.
        n_seen = self.clock.train_exp_counter + 1
        self.loss *= 1.0 / n_seen
        super()._before_backward(**kwargs)

In [None]:
# LwF on Permuted MNIST (10 experiences), single-head (multitask=False).
args_lwf_pmnist = create_default_args(dict(
    # benchmark
    benchmark='pmnist',
    multitask=False,
    # model and optimisation
    model='mlp',
    learning_rate=0.001,
    train_mb_size=128,
    eval_mb_size=128,
    hidden_size=500,
    hidden_layers=2,
    epochs=1,
    dropout=0,
    seed=0,
    # strategy
    strategy='lwf',
    # One distillation weight per experience: alpha_t = 1 - 1/t for the t-th
    # experience (t = 1..10), i.e. [0.0, 0.5, 0.667, ..., 0.9].
    lwf_alpha=[1 - 1.0 / t for t in range(1, 11)],
    lwf_temperature=2,  # temperature of the softmax used in distillation
))

results_lwf_pmnist = run(args_lwf_pmnist)

In [None]:
# LwF on Split MNIST (5 experiences), multi-head (multitask=True).
args_lwf_smnist = create_default_args(dict(
    # benchmark
    benchmark='smnist',
    multitask=True,
    # model and optimisation
    model='mlp',
    learning_rate=0.001,
    train_mb_size=128,
    eval_mb_size=128,
    hidden_size=500,
    hidden_layers=2,
    epochs=1,
    dropout=0,
    seed=0,
    # strategy
    strategy='lwf',
    # One distillation weight per experience: alpha_t = 1 - 1/t for the t-th
    # experience (t = 1..5), i.e. [0.0, 0.5, 0.667, 0.75, 0.8].
    lwf_alpha=[1 - 1.0 / t for t in range(1, 6)],
    lwf_temperature=2,  # temperature of the softmax used in distillation
))

results_lwf_smnist = run(args_lwf_smnist)

In [None]:
# SI experiments

In [None]:
# Synaptic Intelligence on Permuted MNIST, single-head (multitask=False).
args_si_pmnist = create_default_args(dict(
    # benchmark
    benchmark='pmnist',
    multitask=False,
    # model and optimisation
    model='mlp',
    learning_rate=0.001,
    train_mb_size=128,
    eval_mb_size=128,
    hidden_size=500,
    hidden_layers=2,
    epochs=1,
    dropout=0,
    seed=0,
    # strategy: si_lambda and si_eps are forwarded as `si_lambda` and `eps`.
    strategy='si',
    si_lambda=0.1,
    si_eps=0.1,
))

results_si_pmnist = run(args_si_pmnist)

In [None]:
# Synaptic Intelligence on Split MNIST, multi-head (multitask=True).
args_si_smnist = create_default_args(dict(
    # benchmark
    benchmark='smnist',
    multitask=True,
    # model and optimisation
    model='mlp',
    learning_rate=0.001,
    train_mb_size=128,
    eval_mb_size=128,
    hidden_size=500,
    hidden_layers=2,
    epochs=1,
    dropout=0,
    seed=0,
    # strategy: si_lambda and si_eps are forwarded as `si_lambda` and `eps`.
    strategy='si',
    si_lambda=0.1,
    si_eps=0.1,
))

results_si_smnist = run(args_si_smnist)