In [None]:
"""
In this script we will implement Dark Experience Replay ++ (DER++) from https://arxiv.org/pdf/2004.07211 .

DER++ is a simple rehearsal method that uses a replay buffer to store past experiences.
In addition to the Cross-Entropy loss of Experience Replay, it also uses an MSE distillation loss
that matches the network's current outputs on replayed samples to the logits stored in the buffer.
"""

from argparse import ArgumentParser
from mammoth_lite import register_model, ContinualModel, load_runner, train, Buffer, add_rehearsal_args

In [None]:
from torch.nn import functional as F

@register_model('der')
class DarkExperienceReplay(ContinualModel):
    """Rehearsal model that adds two replay terms to the loss of each batch.

    On top of ``self.loss`` on the incoming batch, every step (once the buffer
    holds data) adds:

    * ``beta`` times ``self.loss`` between the network outputs on a replayed
      sample and the labels stored with it;
    * ``alpha`` times the MSE between the network outputs on a second,
      separately drawn replayed sample and the outputs stored with it.
    """

    COMPATIBILITY = ['class-il', 'task-il']

    @staticmethod
    def get_parser(parser: ArgumentParser):
        """Add the rehearsal options and the two loss coefficients to ``parser``.

        Both ``--alpha`` and ``--beta`` are required floats. The parser that was
        passed in is returned.
        """
        add_rehearsal_args(parser)

        # (flag, help text) for each required float coefficient, in registration order.
        coefficient_flags = (
            ('--alpha', 'MSE distillation loss coefficient.'),
            ('--beta', 'CE loss coefficient.'),
        )
        for flag, description in coefficient_flags:
            parser.add_argument(flag, type=float, required=True, help=description)

        return parser

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base class, then create the replay buffer."""
        super().__init__(*args, **kwargs)

        # Buffer capacity comes from the parsed arguments (``buffer_size``).
        capacity = self.args.buffer_size
        self.buffer = Buffer(buffer_size=capacity)

    def _forward_replay_sample(self):
        """Draw ``minibatch_size`` items from the buffer and run them through the net.

        Returns a ``(network_outputs, stored_labels, stored_logits)`` tuple, where
        the last two are the second and third values returned by ``get_data``.
        """
        examples, stored_labels, stored_logits = self.buffer.get_data(
            size=self.args.minibatch_size, device=self.device)
        return self.net(examples), stored_labels, stored_logits

    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        """Perform one optimisation step and return the total loss as a float.

        Args:
            inputs: batch fed to the network.
            labels: targets for ``inputs``.
            not_aug_inputs: version of the batch that is written to the buffer
                (presumably the non-augmented inputs, going by the name).
            epoch: unused here.
        """
        self.opt.zero_grad()

        stream_logits = self.net(inputs)
        total_loss = self.loss(stream_logits, labels)

        if len(self.buffer) > 0:
            # First draw: label-based replay term weighted by beta.
            replay_logits, replay_labels, _ = self._forward_replay_sample()
            label_term = self.loss(replay_logits, replay_labels)
            total_loss = total_loss + self.args.beta * label_term

            # Second, separate draw: match current outputs to the stored ones,
            # weighted by alpha.
            replay_logits, _, stored_logits = self._forward_replay_sample()
            logit_term = F.mse_loss(replay_logits, stored_logits)
            total_loss = total_loss + self.args.alpha * logit_term

        total_loss.backward()
        self.opt.step()

        # Store the batch with the outputs computed before this update,
        # detached so no autograd graph is kept alive by the buffer.
        self.buffer.add_data(not_aug_inputs, labels, stream_logits.detach())

        return total_loss.item()

In [None]:
"""
Now we can use the `load_runner` function to load our custom model.
"""
# Run configuration passed to the runner. 'alpha' and 'beta' are the two
# loss coefficients the model declares as required arguments.
args = dict(
    lr=0.1,
    n_epochs=1,
    batch_size=32,
    buffer_size=500,
    minibatch_size=32,
    alpha=0.3,
    beta=0.5,
)

# 'der' is the name the model class above was registered under.
model, dataset = load_runner('der', 'seq-cifar10', args)
train(model, dataset)