In [None]:
from copy import deepcopy
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from capymoa.ann import Perceptron
from capymoa.base import BatchClassifier
from capymoa.classifier import Finetune
from capymoa.instance import Instance
from capymoa.ocl.ann import WNPerceptron
from capymoa.ocl.base import TaskBoundaryAware
from capymoa.ocl.datasets import SplitMNIST
from capymoa.ocl.evaluation import ocl_train_eval_loop
from capymoa.ocl.strategy import ExperienceReplay
from capymoa.stream import Schema
from matplotlib import pyplot as plt
from torch import Tensor, nn

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Figure styling. "DejaVu Sans" ships with matplotlib and is listed as an
# explicit fallback so machines without Noto Sans still render (without
# findfont fallback warnings).
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = ["Noto Sans", "DejaVu Sans"]
plt.rcParams["font.size"] = 9

# Figure sizes in inches, expressed in TeX points (72.27 pt per inch).
pt = 1 / 72.27
figsize_169 = (455 * pt, 256 * pt)  # approximately a 16:9 area
figsize = (figsize_169[0], 0.45 * figsize_169[0])  # full width, shorter height

# Continual-learning benchmark and its stream schema, shared by every cell below.
scenario = SplitMNIST()
schema = scenario.schema
print(device)

In [None]:
from torch.nn.functional import kl_div, log_softmax


def hinton_kdistill_loss(
    teacher_logits: Tensor, student_logits: Tensor, temperature: float
) -> Tensor:
    r"""Knowledge distillation loss function from Hinton et al. (2015) [1]_ [3]_.

    A type of response-based distillation that forces the student to mimic the
    output class probabilities of the teacher model. The loss is calculated as:

    .. math::
        L_{KD} = t^2 \times \mathrm{KL}\left(
            \text{softmax}(\mathbf{a} / t) \,\|\,
            \text{softmax}(\mathbf{b} / t)\right)

    where:

    * :math:`t` is the temperature for the softmax function.
    * :math:`\mathbf{a}` is the output (logits) of the teacher model for an
      input :math:`\mathbf{x}`.
    * :math:`\mathbf{b}` is the output (logits) of the student model for the
      same input :math:`\mathbf{x}`.

    The divergence is summed over classes and averaged over the batch.

    The original paper uses the cross-entropy loss between the soft targets and
    soft predictions. Cross entropy loss can be defined as the Kullback-Leibler
    (KL) divergence plus the entropy of the target distribution [2]_. The entropy
    of the target distribution is irrelevant to optimisation since it is a
    constant. It is removed in this implementation, which is nice since the loss
    will equal zero when the student model matches the teacher model.

    .. [1] Geoffrey Hinton, Oriol Vinyals, Jeff Dean (2015) Distilling the
        Knowledge in a Neural Network

    .. [2] https://en.wikipedia.org/wiki/Cross-entropy

    .. [3] https://intellabs.github.io/distiller/knowledge_distillation.html

    :param teacher_logits: The output from the teacher model in shape
        ``(batch_size, n_classes)``
    :param student_logits: The output from the student model in shape
        ``(batch_size, n_classes)``
    :param temperature: The temperature for the softmax function; must be > 0.
    :return: The knowledge distillation loss as a scalar tensor.
    :raises ValueError: If either logits tensor is not 2-D, if their shapes
        differ, or if ``temperature`` is not positive.
    """
    if teacher_logits.ndim != 2 or student_logits.ndim != 2:
        raise ValueError(
            "Teacher and student logits must have two dimensions, "
            f"but got {teacher_logits.ndim} and {student_logits.ndim}"
        )
    if teacher_logits.shape != student_logits.shape:
        raise ValueError(
            "Teacher and student logits must have the same shape, "
            f"but got {teacher_logits.shape} and {student_logits.shape}"
        )
    # ``not > 0`` also rejects NaN; a zero/negative temperature would give inf/NaN.
    if not temperature > 0:
        raise ValueError(f"Temperature must be positive, but got {temperature}")

    # Both distributions are kept in log space so ``kl_div`` can use
    # ``log_target=True``, which avoids exponentiating and re-logging the targets.
    soft_predictions = log_softmax(student_logits / temperature, dim=1)
    soft_targets = log_softmax(teacher_logits / temperature, dim=1)
    # ``kl_div(input, target)`` computes KL(target || input), i.e. KL(teacher || student).
    # "batchmean" divides the summed divergence by the batch size, which matches the
    # mathematical definition; the default "mean" would also divide by n_classes.
    return (
        kl_div(soft_predictions, soft_targets, log_target=True, reduction="batchmean")
        * temperature**2
    )

In [None]:
class LWF(BatchClassifier, TaskBoundaryAware):
    """Learning without Forgetting (LwF) classifier for online continual learning.

    Trains ``model`` with cross-entropy on each batch plus a knowledge
    distillation term (weighted by ``lambda_``) that keeps the model's outputs
    close to a frozen copy of itself (the *teacher*). The teacher is snapshotted
    at the start of every training task after the first.

    :param schema: Stream schema, forwarded to :class:`BatchClassifier`.
    :param model: The network to train; moved to ``device``.
    :param lambda_: Weight of the distillation loss relative to cross-entropy.
    :param batch_size: Batch size forwarded to :class:`BatchClassifier`.
    :param random_seed: Seed forwarded to :class:`BatchClassifier`.
    :param lr: Learning rate of the Adam optimizer.
    :param device: Device on which the model and the teacher run.
    :param temperature: Softmax temperature of the distillation loss; must be > 0.
    :raises ValueError: If ``temperature`` is not positive.
    """

    def __init__(
        self,
        schema: Schema,
        model: nn.Module,
        lambda_: float,
        batch_size: int = 128,
        random_seed: int = 1,
        lr: float = 0.01,
        device: torch.device = device,
        temperature: float = 2.0,
    ) -> None:
        if not temperature > 0:
            raise ValueError(f"Temperature must be positive, but got {temperature}")
        super().__init__(schema, batch_size, random_seed)
        self.lambda_ = lambda_
        self.device = device
        self.lr = lr
        self.temperature = temperature
        self.model = model.to(self.device)
        # Frozen snapshot of ``model`` used as the distillation target; ``None``
        # while training the first task.
        self.teacher: Optional[nn.Module] = None
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()

    def batch_train(self, x: np.ndarray, y: np.ndarray) -> None:
        """Train on one batch given as NumPy arrays.

        :param x: Features; converted to a ``float32`` tensor on ``self.device``.
        :param y: Class indices; converted to an ``int64`` tensor on ``self.device``.
        """
        self.torch_batch_train(
            torch.from_numpy(x).float().to(self.device),
            torch.from_numpy(y).long().to(self.device),
        )

    def torch_batch_train(self, x: Tensor, y: Tensor):
        """Take one optimizer step on ``CE(model(x), y) + lambda_ * KD``."""
        self.model.train()
        self.optimizer.zero_grad()
        y_hat = self.model(x)

        ce_loss = self.criterion(y_hat, y)
        kd_loss = self.kd_loss(x, y_hat)
        loss = ce_loss + self.lambda_ * kd_loss
        loss.backward()
        self.optimizer.step()

    def kd_loss(self, x: Tensor, student_logits: Tensor) -> Tensor:
        """Distillation loss against the teacher, or zero when there is no teacher.

        :param x: The batch the student logits were computed from.
        :param student_logits: The current model's outputs for ``x``.
        """
        if self.teacher is None:
            return torch.scalar_tensor(0.0, device=self.device)
        # The teacher is a fixed target: no gradients flow into it.
        with torch.no_grad():
            teacher_logits = self.teacher(x)
        return hinton_kdistill_loss(
            teacher_logits=teacher_logits,
            student_logits=student_logits,
            temperature=self.temperature,
        )

    def set_train_task(self, train_task_id: int):
        """Prepare for training on task ``train_task_id``.

        Always starts a fresh optimizer. For every task after the first, the
        current model is snapshotted as the frozen teacher.
        """
        # Adam maintains running moment estimates, so reinitialize the optimizer
        # when the task changes.
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        if train_task_id == 0:
            # Nothing to distil from yet; also drop any teacher left over from an
            # earlier pass over the tasks.
            self.teacher = None
            return
        self.teacher = deepcopy(self.model)
        self.teacher.eval()
        # Freeze the snapshot explicitly; it is only ever used as a target.
        self.teacher.requires_grad_(False)

    @torch.no_grad()
    def predict_proba(self, instance: Instance) -> np.ndarray:
        """Return softmax class probabilities for a single instance.

        The input is reshaped to a batch of one, so the returned array has shape
        ``(1, n_outputs)`` where ``n_outputs`` is the model's output width.
        """
        x = torch.from_numpy(instance.x).float().to(self.device).view(1, -1)
        self.model.eval()

        y_hat = self.model(x)
        return torch.softmax(y_hat, dim=1).cpu().numpy()

    def __str__(self) -> str:
        return f"LWF(lambda_={self.lambda_}, model={self.model})"

In [None]:
# LWF wrapped in experience replay. The second ExperienceReplay argument is
# presumably the replay buffer size (cf. the "n=100" plot labels).
torch.manual_seed(0)
lwf = LWF(schema, Perceptron(schema, 128), lambda_=0.5, lr=1e-3)
learner = ExperienceReplay(lwf, 100)
lwf_results = ocl_train_eval_loop(
    learner, scenario.train_streams, scenario.test_streams, progress_bar=True
)

# Baseline: fine-tuning with experience replay. Re-seed so this run's
# initialisation does not depend on how much randomness the LWF run consumed.
# NOTE(review): this baseline uses WNPerceptron while LWF uses Perceptron, and
# the positional `100` passed to Finetune should be checked against its
# signature — confirm both are intended.
torch.manual_seed(0)
er = ExperienceReplay(
    Finetune(schema, WNPerceptron(schema, 128), 100, device=device), 100
)
er_results = ocl_train_eval_loop(
    er, scenario.train_streams, scenario.test_streams, progress_bar=True
)

In [None]:
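# Sweep over the distillation weight lambda_ (commented out; uncomment to run).
# Uses separate names so it does not overwrite the results plotted below.
# for lambda_ in [0.1, 0.5, 1.0]:
#     torch.manual_seed(0)
#     sweep_lwf = LWF(schema, Perceptron(schema, 128), lambda_=lambda_, lr=1e-3)
#     sweep_learner = ExperienceReplay(sweep_lwf, 100)
#     sweep_results = ocl_train_eval_loop(sweep_learner, scenario.train_streams, scenario.test_streams, progress_bar=False)
#     print(lambda_, sweep_results.accuracy_all_avg)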

In [None]:
%load_ext autoreload
%autoreload 2
from matplotlib import pyplot as plt

from plot import plot_multiple

# Compare LWF against the ER baseline on the same axes.
fig, ax = plt.subplots(figsize=figsize, layout="constrained")
plot_multiple(
    [
        (r"LWF n=100", lwf_results),
        (r"ER n=100", er_results),
    ],
    ax,
    # acc_all=True,
    acc_seen=True,
    acc_online=True,
)
ax.set_title("SplitMNIST10/5")
# savefig does not create missing directories, so make sure "fig/" exists.
Path("fig").mkdir(parents=True, exist_ok=True)
fig.savefig("fig/lwf.pdf", bbox_inches="tight")