In [None]:
"""
In this script we will implement Dark Experience Replay ++ (DER++) from https://arxiv.org/pdf/2004.07211 .

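DER++ is a simple rehearsal method that uses a replay buffer to store past experiences.
Alongside each stored input and label, the buffer also keeps the logits the network produced for it.
In addition to the Cross-Entropy loss of Experience Replay on buffer samples, DER++ adds a Mean Squared Error
loss that keeps the network's current logits on buffer samples close to the stored logits (logit distillation).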
"""

from argparse import ArgumentParser
from mammoth_lite import register_model, ContinualModel, load_runner, train, Buffer, add_rehearsal_args

In [None]:
from torch.nn import functional as F

@register_model('der')  # Register this model with the name 'der'
class DarkExperienceReplay(ContinualModel):
    """Dark Experience Replay++ (DER++), https://arxiv.org/pdf/2004.07211 .

    The replay buffer stores inputs, labels and the logits the network produced
    for them. Each step minimizes (Eq. 6 of the paper)

        CE(f(x), y) + alpha * MSE(f(x'), z') + beta * CE(f(x''), y'')

    where (x, y) is the current batch, (x', z') is a buffer batch with its stored
    logits and (x'', y'') is an independently sampled buffer batch with labels.
    """
    COMPATIBILITY = ['class-il', 'task-il']

    @staticmethod
    def get_parser(parser: ArgumentParser) -> ArgumentParser:
        """Add the DER++ command-line arguments to ``parser`` and return it."""
        # shared rehearsal options (presumably --buffer_size / --minibatch_size)
        add_rehearsal_args(parser)

        parser.add_argument('--alpha', type=float, required=True,
                            help='Weight of the MSE loss between current and stored buffer logits.')
        parser.add_argument('--beta', type=float, required=True,
                            help='Weight of the cross-entropy loss on buffer samples.')

        return parser

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # initialize the replay buffer with the size defined in the command line arguments
        self.buffer = Buffer(buffer_size=self.args.buffer_size)

    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        """Run one DER++ optimization step on a batch and store the batch in the buffer.

        Implements Eq. 6 of https://arxiv.org/pdf/2004.07211 :
        CE on the current batch + alpha * MSE on buffer logits + beta * CE on buffer labels.

        Args:
            inputs: batch fed to the network (presumably augmented).
            labels: targets for ``inputs``.
            not_aug_inputs: same batch, presumably without augmentation; this is what is stored.
            epoch: unused; kept for interface compatibility.

        Returns:
            The scalar value of the total loss.
        """
        self.opt.zero_grad()

        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)

        if len(self.buffer) > 0:
            # NOTE(review): get_data/add_data keywords follow the Mammoth Buffer API -- confirm for mammoth_lite.
            # beta term: cross-entropy on a batch of stored inputs/labels
            buffer_inputs, buffer_labels, _ = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform, device=self.device)
            ce_buffer_loss = self.loss(self.net(buffer_inputs), buffer_labels)
            loss = loss + self.args.beta * ce_buffer_loss

            # alpha term: keep current logits close to the logits recorded when the
            # samples were inserted, on an independently drawn buffer batch
            buffer_inputs, _, buffer_logits = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform, device=self.device)
            mse_buffer_loss = F.mse_loss(self.net(buffer_inputs), buffer_logits)
            loss = loss + self.args.alpha * mse_buffer_loss

        loss.backward()
        self.opt.step()

        # store the current inputs, labels and LOGITS; logits are detached so the
        # buffer does not keep the autograd graph alive
        self.buffer.add_data(examples=not_aug_inputs, labels=labels, logits=outputs.detach())

        return loss.item()

In [None]:
# Hyper-parameters for DER++ on Sequential CIFAR-10.
# alpha weights the MSE (logit) term, beta the buffer cross-entropy term.
args = dict(
    lr=0.1,
    n_epochs=1,
    batch_size=32,
    buffer_size=500,
    minibatch_size=32,
    alpha=0.3,
    beta=0.5,
)

model, dataset = load_runner('der', 'seq-cifar10', args)
train(model, dataset)