# Regularisation

In this notebook, we implement L2-, EWC-, and SI-based regularization for
continual learning.

In [None]:
from pathlib import Path
from typing import Iterable

import torch
from capymoa.ocl.ann import WNPerceptron
from matplotlib import pyplot as plt
from torch import Tensor

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Font settings shared by every figure produced in this notebook.
plt.rcParams.update(
    {
        "font.family": "sans-serif",
        "font.sans-serif": ["Noto Sans"],
        "font.size": 9,
    }
)

# NOTE(review): 1 / 72.27 presumably converts TeX points to the inches that
# matplotlib figure sizes use -- confirm against the target document width.
pt = 1 / 72.27
figsize_169 = tuple(points * pt for points in (455, 256))
# Same width as `figsize_169`, but with a height of 0.45 times the width.
figsize = (figsize_169[0], figsize_169[0] * 0.45)

# Passed as `continual_evaluations` to every training/evaluation run below.
cl_evals = 5

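$$
L_{reg} = \lambda \sum_{i} (\theta_i - \theta_i^*)^2
$$
where $\theta_i$ are the current parameters, $\theta_i^*$ are the parameters
anchored at the start of the current task, and $\lambda$ controls the strength
of the penalty.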

In [None]:
def l2_reg(params: Iterable[Tensor], anchor_params: Iterable[Tensor]) -> Tensor:
    l2 = torch.scalar_tensor(0.0)
    for param, anchor_param in zip(params, anchor_params, strict=True):
        assert param.shape == anchor_param.shape
        l2 += ((param - anchor_param) ** 2).sum()
    return l2

In [None]:
from typing import List, Optional

import numpy as np
import torch
from capymoa.base import BatchClassifier
from capymoa.instance import Instance
from capymoa.ocl.base import TaskBoundaryAware
from capymoa.stream._stream import Schema
from torch import nn


class L2RegularizedCL(BatchClassifier, TaskBoundaryAware):
    """Neural classifier whose loss adds an L2 pull towards anchored weights.

    Every mini-batch minimises the cross-entropy plus ``lambda_`` times the
    value returned by :meth:`regularize`. The anchor is (re)set whenever
    :meth:`set_train_task` is called.
    """

    def __init__(
        self,
        schema: Schema,
        model: nn.Module,
        lambda_: float,
        batch_size: int = 128,
        random_seed: int = 1,
        lr: float = 0.01,
        device: torch.device = device,
    ) -> None:
        """Move ``model`` to ``device`` and set up the optimizer and loss.

        Args:
            schema: Stream schema forwarded to the base classifier.
            model: Network that maps a batch of features to class logits.
            lambda_: Weight of the regularisation term in the loss.
            batch_size: Forwarded to the base classifier.
            random_seed: Forwarded to the base classifier.
            lr: Learning rate of the Adam optimizer.
            device: Device the model and the training batches are moved to.
        """
        super().__init__(schema, batch_size, random_seed)
        self.device = device
        self.lr = lr
        self.lambda_ = lambda_
        self.model = model.to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = self._make_optimizer()
        # No anchor until `set_train_task` installs one.
        self.anchor_params: Optional[List[Tensor]] = None
        self.test_task_id = 0

    def _make_optimizer(self) -> torch.optim.Adam:
        """Build a fresh Adam optimizer over the model's parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def batch_train(self, x: np.ndarray, y: np.ndarray) -> None:
        """Convert a NumPy batch to tensors on ``self.device`` and train on it."""
        features = torch.from_numpy(x).float().to(self.device)
        labels = torch.from_numpy(y).long().to(self.device)
        self.torch_batch_train(features, labels)

    def torch_batch_train(self, x: Tensor, y: Tensor):
        """Take one optimizer step on cross-entropy plus the weighted penalty."""
        self.model.train()
        self.optimizer.zero_grad()

        logits = self.model(x)
        total_loss = self.criterion(logits, y) + self.lambda_ * self.regularize()

        total_loss.backward()
        self.optimizer.step()

    def regularize(self) -> Tensor:
        """Return the L2 distance to the anchor, or zero when none is set."""
        if self.anchor_params is None:
            return torch.scalar_tensor(0.0)
        return l2_reg(self.model.parameters(), self.anchor_params)

    @torch.no_grad()
    def predict_proba(self, instance: Instance) -> np.ndarray:
        """Return softmax probabilities for one instance.

        The instance is fed to the model as a single-row batch, so the
        returned array keeps that leading dimension of size one.
        """
        self.model.eval()
        features = torch.from_numpy(instance.x).float().to(self.device).view(1, -1)
        logits = self.model(features)
        return logits.softmax(dim=1).cpu().numpy()

    def set_train_task(self, train_task_id: int):
        """Reset the optimizer and refresh the anchor for a new training task."""
        # Adam keeps running moment estimates, so start from a clean optimizer
        # whenever the task changes.
        self.optimizer = self._make_optimizer()

        # Anchoring the first task is detrimental, so task 0 trains unanchored;
        # later tasks are anchored to a snapshot of the current weights.
        if train_task_id == 0:
            self.anchor_params = None
            return
        self.anchor_params = [
            weights.detach().clone() for weights in self.model.parameters()
        ]

    def __str__(self) -> str:
        return f"L2RegularizedCL(lambda_={self.lambda_}, lr={self.optimizer.defaults['lr']})"

In [None]:
from capymoa.ocl.datasets import SplitMNIST
from capymoa.ocl.evaluation import ocl_train_eval_loop

# Seed PyTorch before any network is created.
torch.manual_seed(0)

scenario = SplitMNIST()
schema = scenario.schema


def _run_l2_experiment(learner):
    """Train and evaluate ``learner`` on the streams of ``scenario``."""
    return ocl_train_eval_loop(
        learner,
        scenario.train_streams,
        scenario.test_streams,
        progress_bar=True,
        continual_evaluations=cl_evals,
    )


# All three learners are built before any of them is trained; they differ only
# in the regularisation strength (0 disables the penalty and acts as control).
l2_off = L2RegularizedCL(schema, WNPerceptron(schema, 128), lambda_=0)
l2_low = L2RegularizedCL(schema, WNPerceptron(schema, 128), lambda_=1)
l2_high = L2RegularizedCL(schema, WNPerceptron(schema, 128), lambda_=5)

l2_off_results = _run_l2_experiment(l2_off)
l2_low_results = _run_l2_experiment(l2_low)
l2_high_results = _run_l2_experiment(l2_high)

In [None]:
%load_ext autoreload
%autoreload 2
from matplotlib import pyplot as plt

from plot import plot_multiple

# One (legend label, results) pair per L2 run, in the order they are drawn.
l2_series = [
    (r"Control $\lambda=0$", l2_off_results),
    (r"L2RegularizedCL $\lambda=1$", l2_low_results),
    (r"L2RegularizedCL $\lambda=5$", l2_high_results),
]

fig, ax = plt.subplots(figsize=(10, 5))
# `acc_all` is deliberately not passed; only the seen-task and online
# accuracies are requested.
plot_multiple(l2_series, ax, acc_seen=True, acc_online=True)

From the figure we can conclude that L2 regularization:

- is slightly better than no regularization.
- is not very effective for continual learning.
- can have a negative impact on online accuracy if the regularization
  strength is too high.

# EWC

$$
L_{EWC} = \frac{\lambda}{2} \sum_{i} F_i (\theta_i - \theta_i^*)^2
$$
where $F_i$ is the $i$-th diagonal entry of the Fisher information matrix. The
Fisher information matrix measures how important a parameter is for the task.

* Kirkpatrick, James, Razvan Pascanu, Neil Rabinowitz, Joel Veness, Guillaume Desjardins, Andrei A. Rusu, Kieran Milan, et al. “Overcoming Catastrophic Forgetting in Neural Networks.” Proceedings of the National Academy of Sciences 114, no. 13 (March 28, 2017): 3521–26. https://doi.org/10.1073/pnas.1611835114.


In [None]:
def weighted_l2_reg(
    params: Iterable[Tensor],
    anchor_params: Iterable[Tensor],
    fisher_diags: Iterable[Tensor],
) -> Tensor:
    l2 = torch.scalar_tensor(0.0)
    for param, anchor_param, fisher_diag in zip(
        params, anchor_params, fisher_diags, strict=True
    ):
        assert param.shape == anchor_param.shape
        l2 += (fisher_diag * (param - anchor_param) ** 2).sum()
    return l2 / 2

In [None]:
from capymoa.stream import Schema
from torch import Tensor
from torch.nn.modules import Module


class EWC(L2RegularizedCL):
    """Elastic Weight Consolidation built on :class:`L2RegularizedCL`.

    Recently trained samples are kept in a fixed-size ring buffer. At every
    task change after the first, the diagonal of the (empirical) Fisher
    information is estimated from the buffered samples and added to the
    running total in ``fim_diags``; that total weights the squared distance to
    the anchored parameters.
    """

    def __init__(
        self,
        schema: Schema,
        model: Module,
        lambda_: float,
        fim_buffer: int = 1000,
        batch_size: int = 128,
        random_seed: int = 1,
        lr: float = 0.01,
        device: torch.device = device,
    ) -> None:
        """Set up the base learner and an empty sample buffer.

        Args:
            schema: Stream schema; its attribute count sizes the buffer rows.
            model: Network that maps a batch of features to class logits.
            lambda_: Weight of the regularisation term in the loss.
            fim_buffer: Capacity (in samples) of the ring buffer used to
                estimate the Fisher information.
            batch_size: Forwarded to the base classifier.
            random_seed: Forwarded to the base classifier.
            lr: Learning rate of the Adam optimizer.
            device: Device for the model and the sample buffer.
        """
        super().__init__(schema, model, lambda_, batch_size, random_seed, lr, device)
        self._buffer_x: Tensor = torch.zeros(
            (fim_buffer, schema.get_num_attributes()), device=device
        )
        self._buffer_y: Tensor = torch.zeros((fim_buffer,), device=device).long()
        self._buffer_index = 0
        self._buffer_size = fim_buffer
        # Number of slots that hold real samples (at most `_buffer_size`).
        self._buffer_count = 0
        self.fim_diags: Optional[List[Tensor]] = None

    def torch_batch_train(self, x: Tensor, y: Tensor):
        """Train on the batch, then store its samples in the ring buffer."""
        super().torch_batch_train(x, y)
        # Use the buffer as a ring buffer: once full, the oldest samples are
        # overwritten first.
        for i in range(x.shape[0]):
            self._buffer_x[self._buffer_index] = x[i]
            self._buffer_y[self._buffer_index] = y[i]
            self._buffer_index += 1
            if self._buffer_index >= self._buffer_size:
                self._buffer_index = 0
        self._buffer_count = min(self._buffer_count + x.shape[0], self._buffer_size)

    @torch.enable_grad()
    def compute_fisher(self) -> List[Tensor]:
        """Estimate the diagonal Fisher information from the buffered samples.

        For every buffered sample the squared gradient of the loss with respect
        to each parameter is accumulated, then averaged over the number of
        samples used. Only slots that were actually written are considered, so
        a partially filled buffer does not contribute zero-valued placeholder
        samples.

        Returns:
            One tensor per model parameter, shaped like that parameter. All
            zeros when no sample has been buffered yet.
        """
        self.model.eval()
        fisher_diags = [torch.zeros_like(param) for param in self.model.parameters()]

        n_samples = self._buffer_count
        if n_samples == 0:
            return fisher_diags

        # Before the first wrap-around the filled slots are exactly the first
        # `n_samples` rows; afterwards `n_samples` equals the buffer size.
        for x, y in zip(self._buffer_x[:n_samples], self._buffer_y[:n_samples]):
            x = x.unsqueeze(0)
            y = y.unsqueeze(0)
            self.model.zero_grad()
            y_hat = self.model(x)
            loss = self.criterion(y_hat, y)
            loss.backward()

            for fisher_diag, param in zip(
                fisher_diags, self.model.parameters(), strict=True
            ):
                if param.grad is None:
                    continue
                fisher_diag.add_(param.grad.detach() ** 2)

        for fisher_diag in fisher_diags:
            fisher_diag.div_(n_samples)

        # Do not leave the last sample's gradients on the parameters.
        self.model.zero_grad()
        return fisher_diags

    def set_train_task(self, train_task_id: int):
        """Refresh optimizer and anchor, then accumulate Fisher information.

        Nothing is accumulated for task 0, which trains without an anchor.
        """
        super().set_train_task(train_task_id)
        if train_task_id == 0:
            return

        print(f"Computing Fisher information matrix for task {train_task_id}")
        fim_diags = self.compute_fisher()
        if self.fim_diags is None:
            self.fim_diags = fim_diags
        else:
            for accumulated, new in zip(self.fim_diags, fim_diags, strict=True):
                accumulated.add_(new)

    def regularize(self) -> Tensor:
        """Return the Fisher-weighted penalty, or zero before the first anchor."""
        if self.anchor_params is not None and self.fim_diags is not None:
            return weighted_l2_reg(
                self.model.parameters(), self.anchor_params, self.fim_diags
            )
        else:
            return torch.scalar_tensor(0.0)

In [None]:
# Re-seed PyTorch before the EWC network is created.
torch.manual_seed(0)

# EWC uses a much larger penalty weight and a smaller learning rate than the
# L2 runs above (which rely on the default lr of 0.01).
ewc_model = WNPerceptron(schema, 128)
ewc = EWC(schema, ewc_model, lambda_=10_000, lr=0.001)

ewc_results = ocl_train_eval_loop(
    ewc,
    scenario.train_streams,
    scenario.test_streams,
    progress_bar=True,
    continual_evaluations=cl_evals,
)

In [None]:
%load_ext autoreload
%autoreload 2
from matplotlib import pyplot as plt

from plot import plot_multiple

# Compare the control, both L2 runs and EWC in a single figure.
fig, ax = plt.subplots(figsize=figsize, layout="constrained")
plot_multiple(
    [
        (r"Control $\lambda=0$", l2_off_results),
        (r"L2 $\lambda=1$", l2_low_results),
        (r"L2 $\lambda=5$", l2_high_results),
        # The EWC learner was built with lambda_=10_000, so label it as such.
        (r"EWC $\lambda=10^4$", ewc_results),
    ],
    ax,
    # `acc_all` is deliberately not passed.
    acc_seen=True,
    acc_online=True,
)
ax.set_title("SplitMNIST10/5")

# Create the output directory first so a fresh checkout can run this cell, and
# save the figure object explicitly rather than whatever figure is "current".
output_path = Path("fig/regularization.pdf")
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, bbox_inches="tight")