In [None]:
"""
In this script we will implement Learning without Forgetting (LwF) from https://arxiv.org/abs/1606.09282.
Specifically, we will implement its multi-class variant (LwF.MC) from https://arxiv.org/pdf/1611.07725, which
often performs better than the original LwF.

LwF uses a distillation loss to retain knowledge from previous tasks while learning new ones.
It distills the outputs of a previous checkpoint of the model (the teacher) into the current model (the student).
"""

from argparse import ArgumentParser
from mammoth_lite import register_model, ContinualModel, load_runner, train

In [None]:
from copy import deepcopy
import torch
from torch.nn import functional as F

@register_model('lwfmc')
class LearningWithoutForgetting(ContinualModel):
    """Multi-class Learning without Forgetting (LwF.MC).

    At the end of each task a copy of the network is stored as the teacher.
    While training on a later task, the student is fit with a binary
    cross-entropy loss whose targets are the teacher's sigmoid outputs for the
    classes of previous tasks and one-hot labels for the classes of the
    current task.
    """
    COMPATIBILITY = ['class-il', 'task-il']

    def __init__(self, backbone, loss, args, transform, dataset=None):
        super().__init__(backbone, loss, args, transform, dataset=dataset)

        # One-hot lookup table over all classes of all tasks: row `c` is the
        # one-hot encoding of class `c`. Sized as N_CLASSES_PER_TASK * N_TASKS,
        # i.e. it assumes every task has the same number of classes.
        self.eye = torch.eye(self.dataset.N_CLASSES_PER_TASK *
                             self.dataset.N_TASKS).to(self.device)

        # Not read anywhere in this class; kept so external code that accesses it still works.
        self.class_means = None
        # Teacher network: a copy of `self.net` taken at the end of each task.
        self.old_net = None
        # Index of the task currently being trained (incremented in `end_task`).
        self.current_task = 0

    @staticmethod
    def get_parser(parser: ArgumentParser):
        """Register the model-specific command-line arguments on `parser`."""
        parser.add_argument('--wd_reg', type=float, default=0.0,
                            help='Custom weight decay regularization coefficient.')

        return parser

    def end_task(self, dataset) -> None:
        """Store a copy of the current network as the teacher and advance the task counter.

        Args:
            dataset: the continual dataset (not used here).
        """
        # `eval()` returns the module itself, so the copy is taken in eval mode;
        # the student is then switched back to training mode.
        self.old_net = deepcopy(self.net.eval())
        self.net.train()

        self.current_task += 1

    def observe(self, inputs, labels, not_aug_inputs, logits=None, epoch=None):
        """Perform one optimization step on a batch of the current task.

        From the second task on, the teacher's sigmoid outputs on `inputs`
        replace `logits` and serve as distillation targets for the previous
        classes.

        Args:
            inputs: input batch.
            labels: integer class labels for `inputs`.
            not_aug_inputs: non-augmented inputs (not used here).
            logits: optional tensor supplied by the caller; overwritten by the
                teacher's probabilities when `current_task > 0`.
            epoch: current epoch (not used here).

        Returns:
            The loss of this step as a Python float.
        """
        if self.current_task > 0:
            with torch.no_grad():
                # Sigmoid probabilities (not raw logits) are used as soft targets.
                logits = torch.sigmoid(self.old_net(inputs))

        self.opt.zero_grad()
        loss = self.get_loss(inputs, labels, self.current_task, logits)
        loss.backward()
        self.opt.step()

        return loss.item()

    def get_loss(self, inputs: torch.Tensor, labels: torch.Tensor,
                 task_idx: int, logits: torch.Tensor) -> torch.Tensor:
        """Compute the LwF.MC loss on a batch.

        Args:
            inputs: input batch fed to the current network.
            labels: integer class labels, used to index the one-hot table.
            task_idx: index of the current task.
            logits: teacher sigmoid probabilities; only read when
                `task_idx > 0` (may be None on the first task).

        Returns:
            Scalar loss: binary cross-entropy (mean reduction) over the
            classes seen so far, plus the optional `wd_reg` penalty.
        """
        previous_classes = task_idx * self.dataset.N_CLASSES_PER_TASK
        current_classes = (task_idx + 1) * self.dataset.N_CLASSES_PER_TASK

        # Only the outputs of classes seen so far (previous + current) are trained.
        outputs = self.net(inputs)[:, :current_classes]

        # One-hot targets for the current task's classes (on the first task,
        # previous_classes == 0, so this covers all classes seen so far).
        targets = self.eye[labels][:, previous_classes:current_classes]
        if task_idx > 0:
            # Distillation: the teacher's probabilities are soft targets for
            # the previous classes; the current classes keep one-hot labels.
            targets = torch.cat((logits[:, :previous_classes], targets), dim=1)

        loss = F.binary_cross_entropy_with_logits(outputs, targets)
        # Sanity check: BCE with targets in [0, 1] is non-negative, and the
        # comparison is False for NaN, so this also catches a diverged loss.
        assert loss >= 0

        if self.args.wd_reg:
            # Out-of-place add: avoid mutating a tensor that is part of the autograd graph.
            loss = loss + self.args.wd_reg * torch.sum(self.net.get_params() ** 2)

        return loss



In [None]:
"""
Load the custom `lwfmc` model on the 'seq-cifar10' benchmark via `load_runner`, then train it.
"""
# Hyper-parameters handed to the runner.
args = dict(lr=0.1, n_epochs=1, batch_size=32)

model, dataset = load_runner('lwfmc', 'seq-cifar10', args)
train(model, dataset)