# Regularization Methods

Most regularization methods have two components:
- a function that computes the regularization loss
- a function that updates the internal state (teacher model, weight importance, ...) after each experience

We are going to see this with two examples: Synaptic Intelligence and Learning without Forgetting.

First, let's install Avalanche. You can skip this step if you have installed it already.

In [None]:
!pip install avalanche-lib==0.3.1

## Setup Benchmark and Model
We use Permuted MNIST, a task-agnostic domain-incremental benchmark. 

In [4]:
from avalanche.benchmarks import PermutedMNIST
from avalanche.benchmarks.utils.data_loader import GroupBalancedDataLoader
from avalanche.models import SimpleMLP
from torch.optim import SGD
from torch.nn import CrossEntropyLoss

# Benchmark: Permuted MNIST split into 5 experiences, without task labels.
benchmark = PermutedMNIST(n_experiences=5, return_task_id=False)

# Model, optimizer and loss shared by the examples below.
model = SimpleMLP()
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = CrossEntropyLoss()

# Synaptic Intelligence

In [11]:
import warnings
from fnmatch import fnmatch
from typing import Sequence, Any, Set, List, Tuple, Dict, Union, TYPE_CHECKING

import numpy as np
import torch
from torch import Tensor
from torch.nn import Module
from torch.nn.modules.batchnorm import _NormBase

from avalanche.training.plugins.ewc import EwcDataType, ParamDict
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
from avalanche.training.utils import get_layers_and_params, ParamData
from avalanche.training.templates import SupervisedTemplate

SynDataType = Dict[str, Dict[str, Union[ParamData, Tensor]]]


class SynapticIntelligencePlugin(SupervisedPlugin):
    """Synaptic Intelligence plugin.

    This is the Synaptic Intelligence PyTorch implementation of the
    algorithm described in the paper
    "Continuous Learning in Single-Incremental-Task Scenarios"
    (https://arxiv.org/abs/1806.08568)

    The original implementation has been proposed in the paper
    "Continual Learning Through Synaptic Intelligence"
    (https://arxiv.org/abs/1703.04200).

    This plugin can be attached to existing strategies to achieve a
    regularization effect.

    This plugin will require the strategy `loss` field to be set before the
    `before_backward` callback is invoked. The loss Tensor will be updated to
    achieve the S.I. regularization effect.
    """

    def __init__(
        self,
        si_lambda: Union[float, Sequence[float]],
        eps: float = 0.0000001,
        excluded_parameters: Union[Sequence[str], None] = None,
        device: Any = "as_strategy",
    ):
        """Creates an instance of the Synaptic Intelligence plugin.

        :param si_lambda: Synaptic Intelligence lambda term.
            If list, one lambda for each experience. If the list has less
            elements than the number of experiences, last lambda will be
            used for the remaining experiences.
        :param eps: Synaptic Intelligence damping parameter.
        :param excluded_parameters: Names (or `fnmatch`-style patterns) of
            the parameters that must not be regularized. A ".*" variant of
            each name is added automatically, and parameters of batch-norm
            layers are always excluded. Defaults to None (no user-defined
            exclusions).
        :param device: The device to use to run the S.I. experiences.
            Defaults to "as_strategy", which means that the `device` field of
            the strategy will be used. Using a different device may lead to a
            performance drop due to the required data transfer.
        """

        super().__init__()

        if excluded_parameters is None:
            excluded_parameters = []
        self.si_lambda = (
            si_lambda if isinstance(si_lambda, (list, tuple)) else [si_lambda]
        )
        self.eps: float = eps
        self.excluded_parameters: Set[str] = set(excluded_parameters)

        # The first dictionary contains the params at loss minimum while the
        # second one contains the parameter importance.
        self.ewc_data: EwcDataType = (dict(), dict())

        # Per-parameter buffers used to accumulate the S.I. trajectory.
        self.syn_data: SynDataType = {
            "old_theta": dict(),
            "new_theta": dict(),
            "grad": dict(),
            "trajectory": dict(),
            "cum_trajectory": dict(),
        }

        self._device = device

    def before_training_exp(self, strategy: "SupervisedTemplate", **kwargs):
        """Allocate the buffers for new parameters and reset the trajectory."""
        super().before_training_exp(strategy, **kwargs)
        SynapticIntelligencePlugin.create_syn_data(
            strategy.model,
            self.ewc_data,
            self.syn_data,
            self.excluded_parameters,
        )

        SynapticIntelligencePlugin.init_batch(
            strategy.model,
            self.ewc_data,
            self.syn_data,
            self.excluded_parameters,
        )

    def before_backward(self, strategy: "SupervisedTemplate", **kwargs):
        """Add the S.I. penalty to `strategy.loss`."""
        super().before_backward(strategy, **kwargs)

        exp_id = strategy.clock.train_exp_counter
        try:
            si_lamb = self.si_lambda[exp_id]
        except IndexError:  # less than one lambda per experience, take last
            si_lamb = self.si_lambda[-1]

        syn_loss = SynapticIntelligencePlugin.compute_ewc_loss(
            strategy.model,
            self.ewc_data,
            self.excluded_parameters,
            lambd=si_lamb,
            device=self.device(strategy),
        )

        # `compute_ewc_loss` returns None when no parameter is regularized.
        if syn_loss is not None:
            strategy.loss += syn_loss.to(strategy.device)

    def before_training_iteration(
        self, strategy: "SupervisedTemplate", **kwargs
    ):
        """Snapshot the weights before the optimizer step."""
        super().before_training_iteration(strategy, **kwargs)
        SynapticIntelligencePlugin.pre_update(
            strategy.model, self.syn_data, self.excluded_parameters
        )

    def after_training_iteration(
        self, strategy: "SupervisedTemplate", **kwargs
    ):
        """Accumulate `grad * (new_theta - old_theta)` into the trajectory."""
        super().after_training_iteration(strategy, **kwargs)
        SynapticIntelligencePlugin.post_update(
            strategy.model, self.syn_data, self.excluded_parameters
        )

    def after_training_exp(self, strategy: "SupervisedTemplate", **kwargs):
        """Turn the accumulated trajectory into parameter importances."""
        super().after_training_exp(strategy, **kwargs)
        SynapticIntelligencePlugin.update_ewc_data(
            strategy.model,
            self.ewc_data,
            self.syn_data,
            0.001,  # clip_to
            self.excluded_parameters,
            1,  # c
            eps=self.eps,
        )

    def device(self, strategy: "SupervisedTemplate"):
        """Return the device on which the S.I. loss must be computed."""
        if self._device == "as_strategy":
            return strategy.device

        return self._device

    @staticmethod
    @torch.no_grad()
    def create_syn_data(
        model: Module,
        ewc_data: EwcDataType,
        syn_data: SynDataType,
        excluded_parameters: Set[str],
    ):
        """Create (or expand) the flat per-parameter buffers.

        A buffer is created for every regularized parameter that is not yet
        tracked; buffers of parameters whose number of elements changed are
        expanded.
        """
        params = SynapticIntelligencePlugin.allowed_parameters(
            model, excluded_parameters
        )

        for param_name, param in params:
            # All buffers are stored flattened.
            flat_shape = param.flatten().shape

            if param_name not in ewc_data[0]:
                # new parameter
                ewc_data[0][param_name] = ParamData(param_name, flat_shape)
                ewc_data[1][param_name] = ParamData(
                    f"imp_{param_name}", flat_shape)
                syn_data["old_theta"][param_name] = ParamData(
                    f"old_theta_{param_name}", flat_shape)
                syn_data["new_theta"][param_name] = ParamData(
                    f"new_theta_{param_name}", flat_shape)
                syn_data["grad"][param_name] = ParamData(
                    f"grad{param_name}", flat_shape)
                syn_data["trajectory"][param_name] = ParamData(
                    f"trajectory_{param_name}", flat_shape)
                syn_data["cum_trajectory"][param_name] = ParamData(
                    f"cum_trajectory_{param_name}", flat_shape)
            elif ewc_data[0][param_name].shape != flat_shape:
                # parameter expansion
                # The stored shape is flat, so it must be compared with the
                # flattened shape: comparing with `param.shape` would flag
                # every multi-dimensional parameter as changed.
                ewc_data[0][param_name].expand(flat_shape)
                ewc_data[1][param_name].expand(flat_shape)
                syn_data["old_theta"][param_name].expand(flat_shape)
                syn_data["new_theta"][param_name].expand(flat_shape)
                syn_data["grad"][param_name].expand(flat_shape)
                syn_data["trajectory"][param_name].expand(flat_shape)
                syn_data["cum_trajectory"][param_name].expand(flat_shape)

    @staticmethod
    @torch.no_grad()
    def extract_weights(
        model: Module, target: ParamDict, excluded_parameters: Set[str]
    ):
        """Copy the flattened, detached CPU weights of `model` into `target`."""
        params = SynapticIntelligencePlugin.allowed_parameters(
            model, excluded_parameters
        )

        for name, param in params:
            target[name].data = param.detach().cpu().flatten()

    @staticmethod
    @torch.no_grad()
    def extract_grad(model, target: ParamDict, excluded_parameters: Set[str]):
        """Copy the flattened, detached CPU gradients of `model` into `target`.

        A parameter whose `.grad` is None is stored as a zero gradient, so
        that it contributes nothing to the trajectory instead of raising an
        `AttributeError`.
        """
        params = SynapticIntelligencePlugin.allowed_parameters(
            model, excluded_parameters
        )

        # Store the gradients into target
        for name, param in params:
            grad = param.grad
            if grad is None:
                grad = torch.zeros_like(param)
            target[name].data = grad.detach().cpu().flatten()

    @staticmethod
    @torch.no_grad()
    def init_batch(
        model,
        ewc_data: EwcDataType,
        syn_data: SynDataType,
        excluded_parameters: Set[str],
    ):
        """Store the initial weights and zero the per-experience trajectory."""
        # Keep initial weights
        SynapticIntelligencePlugin.extract_weights(
            model, ewc_data[0], excluded_parameters
        )
        for param_trajectory in syn_data["trajectory"].values():
            param_trajectory.data.fill_(0.0)

    @staticmethod
    @torch.no_grad()
    def pre_update(model, syn_data: SynDataType, excluded_parameters: Set[str]):
        """Save the current weights as `old_theta`."""
        SynapticIntelligencePlugin.extract_weights(
            model, syn_data["old_theta"], excluded_parameters
        )

    @staticmethod
    @torch.no_grad()
    def post_update(
        model, syn_data: SynDataType, excluded_parameters: Set[str]
    ):
        """Save the updated weights and gradients, then update the trajectory.

        For every tracked parameter,
        `trajectory += grad * (new_theta - old_theta)`.
        """
        SynapticIntelligencePlugin.extract_weights(
            model, syn_data["new_theta"], excluded_parameters
        )
        SynapticIntelligencePlugin.extract_grad(
            model, syn_data["grad"], excluded_parameters
        )

        for param_name in syn_data["trajectory"]:
            syn_data["trajectory"][param_name].data += syn_data["grad"][
                param_name
            ].data * (
                syn_data["new_theta"][param_name].data
                - syn_data["old_theta"][param_name].data
            )

    @staticmethod
    def compute_ewc_loss(
        model,
        ewc_data: EwcDataType,
        excluded_parameters: Set[str],
        device,
        lambd=0.0,
    ):
        """Compute the quadratic penalty over the regularized parameters.

        For every parameter the term is
        `lambd / 2 * dot(importance, (weights - reference) ** 2)`.

        :return: The summed penalty, or None if no parameter is regularized.
        """
        params = SynapticIntelligencePlugin.allowed_parameters(
            model, excluded_parameters
        )

        loss = None
        for name, param in params:
            weights = param.to(device).flatten()  # Flat, not detached
            ewc_data0 = ewc_data[0][name].data.to(device)  # Flat, detached
            ewc_data1 = ewc_data[1][name].data.to(device)  # Flat, detached
            syn_loss: Tensor = torch.dot(
                ewc_data1, (weights - ewc_data0) ** 2
            ) * (lambd / 2)

            if loss is None:
                loss = syn_loss
            else:
                loss += syn_loss

        return loss

    @staticmethod
    @torch.no_grad()
    def update_ewc_data(
        net,
        ewc_data: EwcDataType,
        syn_data: SynDataType,
        clip_to: float,
        excluded_parameters: Set[str],
        c=0.0015,
        eps: float = 0.0000001,
    ):
        """Update importances and reference weights at the end of an experience.

        :param net: The model whose current weights become the new reference.
        :param ewc_data: Tuple (reference weights, importances); both
            dictionaries are updated in place.
        :param syn_data: The S.I. buffers; `new_theta` and `cum_trajectory`
            are updated in place.
        :param clip_to: Upper bound applied to the importances.
        :param excluded_parameters: Parameters that are not regularized.
        :param c: Scaling factor applied to the trajectory before it is
            accumulated.
        :param eps: Damping term added to the squared weight change.
        """
        SynapticIntelligencePlugin.extract_weights(
            net, syn_data["new_theta"], excluded_parameters
        )

        for param_name in syn_data["cum_trajectory"]:
            # Total weight change since the beginning of the experience.
            delta = (
                syn_data["new_theta"][param_name].data
                - ewc_data[0][param_name].data
            )
            syn_data["cum_trajectory"][param_name].data += (
                c * syn_data["trajectory"][param_name].data
                / (delta ** 2 + eps)
            )

        # change sign here because the Ewc regularization
        # in Caffe (theta - thetaold) is inverted w.r.t. syn equation [4]
        # (thetaold - theta)
        for param_name in syn_data["cum_trajectory"]:
            # Unary minus already allocates a fresh tensor.
            ewc_data[1][param_name].data = (
                -syn_data["cum_trajectory"][param_name].data
            )

        for param_name in ewc_data[1]:
            ewc_data[1][param_name].data = torch.clamp(
                ewc_data[1][param_name].data, max=clip_to
            )
            ewc_data[0][param_name].data = \
                syn_data["new_theta"][param_name].data.clone()

    @staticmethod
    def explode_excluded_parameters(excluded: Set[str]) -> Set[str]:
        """
        Explodes a list of excluded parameters by adding a generic final ".*"
        wildcard at its end.

        :param excluded: The original set of excluded parameters.

        :return: The set of excluded parameters in which ".*" patterns have been
            added.
        """
        result = set()
        for x in excluded:
            result.add(x)
            if not x.endswith("*"):
                result.add(x + ".*")
        return result

    @staticmethod
    def not_excluded_parameters(
        model: Module, excluded_parameters: Set[str]
    ) -> Sequence[Tuple[str, Tensor]]:
        """Return the (name, parameter) pairs that match no exclusion pattern.

        Parameters of batch-norm layers are always excluded.
        """
        # Add wildcards ".*" to all excluded parameter names
        # (this builds a new set, the caller's set is left untouched)
        excluded_parameters = (
            SynapticIntelligencePlugin.explode_excluded_parameters(
                excluded_parameters
            )
        )

        for lp in get_layers_and_params(model):
            if isinstance(lp.layer, _NormBase):
                # Exclude batch norm parameters
                excluded_parameters.add(lp.parameter_name)

        result: List[Tuple[str, Tensor]] = []
        for name, param in model.named_parameters():
            excluded = any(
                fnmatch(name, pattern) for pattern in excluded_parameters
            )
            if not excluded:
                result.append((name, param))

        return result

    @staticmethod
    def allowed_parameters(
        model: Module, excluded_parameters: Set[str]
    ) -> List[Tuple[str, Tensor]]:
        """Return the non-excluded parameters that require a gradient."""
        allow_list = SynapticIntelligencePlugin.not_excluded_parameters(
            model, excluded_parameters
        )

        return [
            (name, param) for name, param in allow_list if param.requires_grad
        ]


# Training

In [16]:
from avalanche.training import Naive

device = 'cpu'

# Plain fine-tuning (Naive) regularized by the Synaptic Intelligence plugin.
si_plugin = SynapticIntelligencePlugin(si_lambda=0.0001)
cl_strategy = Naive(
    model,
    optimizer,
    criterion,
    train_mb_size=128,
    train_epochs=4,
    eval_mb_size=128,
    device=device,
    plugins=[si_plugin],
)

# Train on each experience in turn and evaluate on the full test stream
# after every experience.
print("Starting experiment...")
results = []
for experience in benchmark.train_stream:
    print("Start of experience: ", experience.current_experience)
    print("Current Classes: ", experience.classes_in_this_experience)

    cl_strategy.train(experience)
    print("Training completed")

    print("Computing accuracy on the whole test set")
    eval_metrics = cl_strategy.eval(benchmark.test_stream)
    results.append(eval_metrics)

Starting experiment...
Start of experience:  0
Current Classes:  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
-- >> Start of training phase << --
100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:21<00:00, 21.39it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 0.8542
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.7657
100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:22<00:00, 21.16it/s]
Epoch 1 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 0.4157
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8808
100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:23<00:00, 20.05it/s]
Epoch 2 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 0.3462
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8985
100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.62it/s]
Epoch 3

KeyboardInterrupt: 

# Learning without Forgetting

In [18]:
import copy
from collections import defaultdict
from typing import List

import torch
import torch.nn.functional as F

from avalanche.models import MultiTaskModule, avalanche_forward
from avalanche.training.regularization import RegularizationMethod


def cross_entropy_with_oh_targets(outputs, targets, eps=1e-5):
    """Return the mean cross-entropy between logits and one-hot/soft targets.

    :param outputs: unnormalized scores (logits); the softmax is taken over
        dimension 1.
    :param targets: target distribution with the same shape as `outputs`.
        Targets can also be soft targets but they must sum to 1.
    :param eps: unused; kept only for backward compatibility with existing
        callers.
    :return: scalar tensor with the cross-entropy averaged over dimension 0.
    """
    # log_softmax is computed in a numerically stable way: taking the log of
    # a softmax that underflowed to 0 yields -inf, and 0 * -inf is NaN even
    # for a perfectly classified sample.
    log_probs = torch.nn.functional.log_softmax(outputs, dim=1)
    ce = -(targets * log_probs).sum(1)
    return ce.mean()


class LearningWithoutForgetting(RegularizationMethod):
    """Learning Without Forgetting.

    The method applies knowledge distillation to mitigate forgetting.
    The teacher is the model checkpoint after the last experience.
    """

    def __init__(self, alpha=1, temperature=2):
        """
        :param alpha: distillation hyperparameter. It can be either a float
                number or a list containing alpha for each experience.
        :param temperature: softmax temperature for distillation
        """
        self.alpha = alpha
        self.temperature = temperature
        # teacher model; None until `update` is called for the first time
        self.prev_model = None
        # count number of experiences (used to select alpha when it is a list)
        self.expcount = 0
        # In Avalanche, targets of different experiences are not ordered.
        # As a result, some units may be allocated even though their
        # corresponding class has never been seen by the model.
        # Knowledge distillation uses only units corresponding
        # to old classes.
        self.prev_classes_by_task = defaultdict(set)

    def _distillation_loss(self, out, prev_out, active_units):
        """Compute distillation loss between output of the current model
        and output of the previous (saved) model.

        :param out: logits of the current model.
        :param prev_out: logits of the previous model.
        :param active_units: indices of the output units to distill.
        """
        # we compute the loss only on the previously active units.
        au = list(active_units)

        # some people use the crossentropy instead of the KL
        # They are equivalent. We compute
        # kl_div(log_p_curr, p_prev) = p_prev * (log (p_prev / p_curr)) =
        #   p_prev * log(p_prev) - p_prev * log(p_curr).
        # Now, the first term is constant (we don't optimize the teacher),
        # so optimizing the crossentropy and kl-div are equivalent.
        log_p = torch.log_softmax(out[:, au] / self.temperature, dim=1)
        q = torch.softmax(prev_out[:, au] / self.temperature, dim=1)
        res = torch.nn.functional.kl_div(log_p, q, reduction="batchmean")
        return res

    def _lwf_penalty(self, out, x, curr_model):
        """
        Compute weighted distillation loss.

        :param out: output of the current model for `x`.
        :param x: input mini-batch.
        :param curr_model: current model.
        :return: 0 if there is no previous model yet, otherwise the sum of
            the distillation losses over the previously seen tasks.
        """
        if self.prev_model is None:
            return 0

        if isinstance(self.prev_model, MultiTaskModule):
            # output from previous output heads.
            with torch.no_grad():
                y_prev = avalanche_forward(self.prev_model, x, None)
            y_prev = {k: v for k, v in y_prev.items()}
            # in a multitask scenario we need to compute the output
            # from all the heads, so we need to call forward again.
            # TODO: can we avoid this?
            y_curr = avalanche_forward(curr_model, x, None)
            y_curr = {k: v for k, v in y_curr.items()}
        else:  # no task labels. Single task LwF
            with torch.no_grad():
                y_prev = {0: self.prev_model(x)}
            y_curr = {0: out}

        dist_loss = 0
        for task_id in y_prev.keys():
            # compute kd only for previous heads and only for seen units.
            if task_id in self.prev_classes_by_task:
                yp = y_prev[task_id]
                yc = y_curr[task_id]
                au = self.prev_classes_by_task[task_id]
                dist_loss += self._distillation_loss(yc, yp, au)
        return dist_loss

    def __call__(self, mb_x, mb_pred, model):
        """
        Add distillation loss

        :param mb_x: input mini-batch.
        :param mb_pred: output of the current model for `mb_x`.
        :param model: current model.
        :return: the distillation penalty multiplied by alpha.
        """
        alpha = (
            self.alpha[self.expcount]
            if isinstance(self.alpha, (list, tuple))
            else self.alpha
        )
        return alpha * self._lwf_penalty(mb_pred, mb_x, model)

    def update(self, experience, model):
        """Save a copy of the model after each experience and
        update self.prev_classes_by_task to include the newly learned classes.

        :param experience: current experience
        :param model: current model
        """

        self.expcount += 1
        self.prev_model = copy.deepcopy(model)
        # The copy inherits the training mode of `model`. The teacher must
        # produce deterministic targets and must not update its own state
        # (dropout, batch-norm statistics), so switch it to evaluation mode.
        self.prev_model.eval()
        task_ids = experience.dataset.targets_task_labels.uniques

        for task_id in task_ids:
            task_data = experience.dataset.task_set[task_id]
            pc = set(task_data.targets.uniques)
            # `prev_classes_by_task` is a defaultdict(set): a missing task
            # starts from the empty set.
            self.prev_classes_by_task[task_id].update(pc)

## Training

We use a custom training loop here instead of Avalanche's `Naive` strategy.

In [20]:
from avalanche.models.utils import avalanche_model_adaptation
from avalanche.models.dynamic_optimizers import reset_optimizer
from torch.utils.data import DataLoader

device = 'cpu'
num_epochs = 2

# Keep the model on the same device as the mini-batches.
model = SimpleMLP().to(device)
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = CrossEntropyLoss()
lwf = LearningWithoutForgetting(alpha=1, temperature=2)

for exp in benchmark.train_stream:
    print(f"Experience ({exp.current_experience})")
    model.train()
    avalanche_model_adaptation(model, exp)
    # Move the model again in case the adaptation step added parameters.
    model.to(device)
    reset_optimizer(optimizer, model)

    dataset = exp.dataset.train()
    # The loader does not depend on the epoch: build it once per experience.
    dl = DataLoader(dataset, batch_size=128)

    for epoch in range(num_epochs):
        # Each mini-batch is (input, target, task label); the task label is
        # not needed here.
        for x, y, _ in dl:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            # ADD LWF loss
            loss += lwf(x, output, model)
            loss.backward()
            optimizer.step()
        # Loss of the last mini-batch of the epoch.
        print(f"Train Epoch: {epoch} \tLoss: {loss.item():.6f}")

    # UPDATE LWF TEACHER
    lwf.update(exp, model)


Experience (0)
Train Epoch: 0 	Loss: 0.528850
Train Epoch: 1 	Loss: 0.409638
Experience (1)
Train Epoch: 0 	Loss: 0.966667
Train Epoch: 1 	Loss: 0.897288
Experience (2)
Train Epoch: 0 	Loss: 0.986454
Train Epoch: 1 	Loss: 0.922763
Experience (3)
Train Epoch: 0 	Loss: 0.933775
Train Epoch: 1 	Loss: 0.869438
Experience (4)
Train Epoch: 0 	Loss: 0.972159
Train Epoch: 1 	Loss: 0.892220


# Exercises
- test LwF in a task-aware scenario
- test "Stable SGD": use a high learning rate on the first experience, then decrease the learning rate and the batch size over the following experiences