In [None]:
from argparse import ArgumentParser
from mammoth_lite import register_model, ContinualModel, load_runner, train, add_rehearsal_args, Buffer, ContinualDataset, MammothBackbone

# Filling the buffer

In [None]:
import torch
import numpy as np

@torch.no_grad()
def fill_balanced_buffer(buffer: "Buffer", dataset: "ContinualDataset", t_idx: int) -> None:
    """
    Rebuild the memory buffer so that every class seen so far keeps the same
    number of examples (no herding).

    Past classes already stored are truncated to
    ``buffer.buffer_size // n_seen_classes`` examples each; then up to that many
    examples of each class of the current task are added.

    The examples stored are the third element of each batch yielded by
    ``dataset.train_loader`` (the non-augmented input). This is the same kind of
    tensor that ``fill_buffer_with_herding`` stores and that ``icarl_replay``
    converts back with ``refold_transform``. The raw ``dataset.data`` array is
    not used, because it is in a different layout.

    Args:
        buffer: the memory buffer
        dataset: the dataset from which to take the examples
        t_idx: the index of the task that just ended

    Note:
        As a side effect, the train set of ``dataset`` is filtered in place, so
        that it keeps only samples whose label is >= the first class of task
        ``t_idx``.
    """
    n_seen_classes = dataset.N_CLASSES_PER_TASK * (t_idx + 1)
    n_past_classes = dataset.N_CLASSES_PER_TASK * t_idx
    samples_per_class = buffer.buffer_size // n_seen_classes

    # Drop the replayed buffer samples (merged in by `icarl_replay`) from the train set
    train_set = dataset.train_loader.dataset
    keep = train_set.targets >= n_past_classes
    train_set.targets = train_set.targets[keep]
    train_set.data = train_set.data[keep]

    if t_idx > 0:
        # 1) Subsample the prior classes. Only the first `len(buffer)` slots are
        # stored examples (the rest of the storage must not be read as data).
        # Copy them so that `reset()` cannot affect the snapshot.
        n_stored = len(buffer)
        buf_x = buffer.examples[:n_stored].clone()
        buf_y = buffer.labels[:n_stored].clone()

        buffer.reset()  # clear the buffer before refilling it

        # Keep at most `samples_per_class` examples of each prior class
        for cls in buf_y.unique():
            cls_mask = buf_y == cls
            buffer.add_data(examples=buf_x[cls_mask][:samples_per_class],
                            labels=buf_y[cls_mask][:samples_per_class])

    # 2) Collect the non-augmented samples of the current task's classes
    examples, labels = [], []
    for data in dataset.train_loader:
        y, not_aug_x = data[1], data[2]
        in_task = (y >= n_past_classes) & (y < n_seen_classes)
        if not in_task.any():
            continue
        examples.append(not_aug_x[in_task].cpu())
        labels.append(y[in_task].cpu())

    # 3) Add up to `samples_per_class` examples of each current class
    if examples:
        examples, labels = torch.cat(examples), torch.cat(labels)
        for cls in labels.unique():
            cls_mask = labels == cls
            buffer.add_data(examples=examples[cls_mask][:samples_per_class],
                            labels=labels[cls_mask][:samples_per_class])

    # Sanity checks: the per-class budget must never exceed the buffer capacity
    assert len(buffer.examples) <= buffer.buffer_size, f"buffer overflowed its maximum size: {len(buffer)} > {buffer.buffer_size}"
    assert buffer.num_seen_examples <= buffer.buffer_size, f"buffer has been overfilled, there is probably an error: {buffer.num_seen_examples} > {buffer.buffer_size}"

## Bonus: fill the buffer with Herding

Herding greedily selects, one example at a time, the example that keeps the mean feature of the selected examples closest to the average feature representation of the class.

In [None]:
import torch

@torch.no_grad()
def fill_buffer_with_herding(buffer: "Buffer", dataset: "ContinualDataset", t_idx: int, net: "MammothBackbone") -> None:
    """
    Rebuild the memory buffer, choosing the current task's exemplars **with Herding**.

    Past classes already stored are truncated to
    ``buffer.buffer_size // n_seen_classes`` examples each. When the buffer was
    filled by this function, each class is stored in selection order, so the
    truncation keeps the earliest-selected (most representative) examples.
    For each class of the current task, examples are then picked greedily so
    that the running mean of the selected features stays as close as possible
    to the class mean feature.

    Args:
        buffer: the memory buffer
        dataset: the dataset from which to take the examples
        t_idx: the index of the task that just ended
        net: the backbone used to extract features (called with
            ``returnt='features'``); its train/eval mode is restored on exit,
            also when an error is raised

    Note:
        As a side effect, the train set of ``dataset`` is filtered in place, so
        that it keeps only samples whose label is >= the first class of task
        ``t_idx``.
    """
    was_training = net.training
    net.eval()
    try:
        device = next(net.parameters()).device

        n_seen_classes = dataset.N_CLASSES_PER_TASK * (t_idx + 1)
        n_past_classes = dataset.N_CLASSES_PER_TASK * t_idx
        samples_per_class = buffer.buffer_size // n_seen_classes

        # Drop the replayed buffer samples (merged in by `icarl_replay`) from the train set
        train_set = dataset.train_loader.dataset
        keep = train_set.targets >= n_past_classes
        train_set.targets = train_set.targets[keep]
        train_set.data = train_set.data[keep]

        if t_idx > 0:
            # 1) Subsample the prior classes. Only the first `len(buffer)` slots
            # are stored examples (the rest of the storage must not be read as
            # data). Copy them so that `reset()` cannot affect the snapshot.
            n_stored = len(buffer)
            buf_x = buffer.examples[:n_stored].clone()
            buf_y = buffer.labels[:n_stored].clone()

            buffer.reset()  # clear the buffer before refilling it

            # Keep at most `samples_per_class` examples of each prior class
            for cls in buf_y.unique():
                cls_mask = buf_y == cls
                buffer.add_data(examples=buf_x[cls_mask][:samples_per_class],
                                labels=buf_y[cls_mask][:samples_per_class])

        # 2) Extract the features of every non-augmented sample of the current task
        norm_trans = dataset.get_normalization_transform()
        examples, labels, features = [], [], []
        for data in dataset.train_loader:
            y, not_norm_x = data[1], data[2]
            in_task = (y >= n_past_classes) & (y < n_seen_classes)
            if not in_task.any():
                continue
            y, not_norm_x = y[in_task], not_norm_x[in_task]
            examples.append(not_norm_x.cpu())
            labels.append(y.cpu())

            inputs = not_norm_x.to(device)
            if norm_trans is not None:
                inputs = norm_trans(inputs)
            feats = net(inputs, returnt='features')
            # Flatten so that distances are computed over the whole feature vector
            features.append(feats.reshape(feats.size(0), -1).cpu())

        if examples:
            examples, labels, features = torch.cat(examples), torch.cat(labels), torch.cat(features)

            # 3) Herding: for each class, greedily select the example that keeps
            # the mean of the selected features closest to the class mean feature
            for cls in labels.unique():
                cls_mask = labels == cls
                cls_x, cls_y, cls_feats = examples[cls_mask], labels[cls_mask], features[cls_mask]

                mean_feat = cls_feats.mean(0, keepdim=True)
                running_sum = torch.zeros_like(mean_feat)
                taken = torch.zeros(cls_feats.shape[0], dtype=torch.bool)

                n_to_select = min(samples_per_class, cls_feats.shape[0])
                for i in range(n_to_select):
                    # Mean of the selection if each candidate were added next
                    running_mean = (cls_feats + running_sum) / (i + 1)
                    cost = (mean_feat - running_mean).norm(2, 1)
                    # An example already selected can never be picked again
                    cost = cost.masked_fill(taken, float('inf'))
                    idx_min = cost.argmin().item()

                    buffer.add_data(examples=cls_x[idx_min:idx_min + 1],
                                    labels=cls_y[idx_min:idx_min + 1])

                    running_sum += cls_feats[idx_min:idx_min + 1]
                    taken[idx_min] = True

        # Sanity checks: the per-class budget must never exceed the buffer capacity
        assert len(buffer.examples) <= buffer.buffer_size, f"buffer overflowed its maximum size: {len(buffer)} > {buffer.buffer_size}"
        assert buffer.num_seen_examples <= buffer.buffer_size, f"buffer has been overfilled, there is probably an error: {buffer.num_seen_examples} > {buffer.buffer_size}"
    finally:
        net.train(was_training)


In [None]:
def icarl_replay(self: "ContinualModel", dataset: "ContinualDataset", current_task: int):
    """
    Merge the replay buffer into the training set of the current task.

    The stored exemplars are converted back to the storage layout of
    ``dataset.train_loader.dataset.data`` and appended to it, and their labels
    are appended to ``targets``. The data loader therefore yields both
    current-task samples and replayed exemplars. Nothing happens on the first
    task.

    Args:
        self: the model instance (only its ``buffer`` is read)
        dataset: the dataset whose train set is extended in place
        current_task: the current task index
    """
    if current_task <= 0:
        return

    train_set = dataset.train_loader.dataset
    n_stored = len(self.buffer)

    def refold_transform(x):
        # Assumes buffer examples are (N, C, H, W) floats in [0, 1]; returns
        # (N, H, W, C) uint8. Round instead of truncating, so that k/255 maps
        # back to exactly k despite float error; clamp to avoid uint8 wrap-around.
        return (x.cpu() * 255).round().clamp(0, 255).permute([0, 2, 3, 1]).numpy().astype(np.uint8)

    # REDUCE AND MERGE TRAINING SET
    train_set.targets = np.concatenate([
        train_set.targets,
        self.buffer.labels[:n_stored].cpu().numpy(),
    ])

    replay_data = refold_transform(self.buffer.examples[:n_stored])
    if isinstance(train_set.data, torch.Tensor):
        # torch.cat only accepts tensors: convert the refolded numpy array first
        train_set.data = torch.cat([
            train_set.data,
            torch.from_numpy(replay_data).to(train_set.data.device),
        ])
    else:
        train_set.data = np.concatenate([train_set.data, replay_data])

# iCaRL model

In [None]:
from copy import deepcopy
import torch
from torch.nn import functional as F

@register_model('icarl')
class iCaRL(ContinualModel):
    """
    iCaRL (Incremental Classifier and Representation Learning).

    * Training (`observe` / `get_loss`): binary cross-entropy on one-hot
      targets. From the second task on, the targets of previous classes are
      replaced by the sigmoid outputs of the network copied at the end of the
      previous task.
    * Memory (`end_task` / `begin_task`): at the end of each task the buffer is
      refilled (with herding when `use_herding` is set). At the start of the
      next task it is merged into that task's training set.
    * Inference (`forward`): nearest-mean-of-exemplars over class-mean features
      computed on the buffer.
    """
    COMPATIBILITY = ['class-il', 'task-il']

    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        """Register the rehearsal arguments and the iCaRL-specific options."""
        add_rehearsal_args(parser)
        parser.add_argument('--opt_wd', type=float, default=1e-5,
                            help='Optimizer weight decay')
        parser.add_argument('--use_herding', type=int, default=1, choices=[0, 1],
                            help='Use herding to fill the buffer')
        return parser

    def __init__(self, backbone, loss, args, transform, dataset=None):
        super().__init__(backbone, loss, args, transform, dataset=dataset)

        # Instantiate buffer
        self.buffer = Buffer(self.args.buffer_size)

        # One-hot lookup table over the classes of all tasks
        self.eye = torch.eye(self.dataset.N_CLASSES_PER_TASK *
                             self.dataset.N_TASKS).to(self.device)

        self.class_means = None  # cached class-mean features; None means "recompute on next forward"
        self.old_net = None  # copy of the network taken in `end_task`
        self.current_task = 0

    def forward(self, x):
        """
        Classify `x` with the nearest-mean-of-exemplars rule.

        Returns the negated squared Euclidean distance between the features of
        each input and each class mean, so the highest score is the closest class.
        """
        if self.class_means is None:
            with torch.no_grad():
                self.compute_class_means()
                self.class_means = self.class_means.squeeze()

        # Compute the features
        feats = self.net(x, returnt='features')

        feats = feats.view(feats.size(0), -1)

        # Compute the nearest-mean-of-exemplars prediction (Eq. 2 of the iCaRL paper)
        pred = (self.class_means.unsqueeze(0) - feats.unsqueeze(1)).pow(2).sum(2)
        return -pred

    @torch.no_grad()
    def compute_class_means(self) -> None:
        """
        Compute the mean feature vector of every class stored in the buffer.

        Only the first ``len(self.buffer)`` slots are used. Slots of the storage
        beyond that are not stored exemplars and must not contribute to any
        mean, nor add a class. The result is saved in ``self.class_means`` with
        one flattened row per class, ordered by ascending label. The network's
        train/eval mode is restored on exit, also when an error is raised.
        """
        was_training = self.net.training
        self.net.eval()
        try:
            transform = self.dataset.get_normalization_transform()
            n_stored = len(self.buffer)
            examples = self.buffer.examples[:n_stored]
            labels = self.buffer.labels[:n_stored]

            class_means = []
            for cls in labels.unique():
                class_x = examples[labels == cls]

                # Extract features in mini-batches of `batch_size`
                batch_feats = []
                for start in range(0, len(class_x), self.args.batch_size):
                    batch = class_x[start:start + self.args.batch_size]
                    # Apply the normalization transform sample by sample
                    batch = torch.stack([transform(x) for x in batch.cpu()]).to(self.device)
                    batch_feats.append(self.net(batch, returnt='features'))

                # Mean over all examples of the class
                class_means.append(torch.cat(batch_feats).mean(0).flatten())
            self.class_means = torch.stack(class_means)
        finally:
            self.net.train(was_training)

    def end_task(self, dataset) -> None:
        """Snapshot the network, refill the buffer, and advance the task counter."""
        # Save the current model as the old model (used for distillation targets)
        self.net.eval()
        self.old_net = deepcopy(self.net)

        # Fill the buffer with examples from the current task
        with torch.no_grad():
            if self.args.use_herding:
                fill_buffer_with_herding(self.buffer, dataset, self.current_task, net=self.net)
            else:
                fill_balanced_buffer(self.buffer, dataset, self.current_task)
        # The buffer changed: class means must be recomputed
        self.class_means = None

        self.current_task += 1

    def begin_task(self, dataset):
        """Merge the buffer into the new task's training set and switch to training mode."""
        # Concatenate the buffer with the current task data
        icarl_replay(self, dataset, self.current_task)
        self.net.train()

    def get_loss(self, inputs: torch.Tensor, labels: torch.Tensor,
                 task_idx: int, logits: torch.Tensor) -> torch.Tensor:
        """
        Binary cross-entropy loss, pretty much the same as LwF.MC.

        Args:
            inputs: the input batch
            labels: integer class labels of the batch
            task_idx: the current task index
            logits: sigmoid outputs of the old network (ignored on the first task)

        Returns:
            the scalar loss
        """
        previous_classes = task_idx * self.dataset.N_CLASSES_PER_TASK

        # Compute the outputs of the current model
        outputs = self.net(inputs)
        if task_idx == 0:
            # If this is the first task, we do not have any previous classes
            targets = self.eye[labels]
            # Compute the loss as binary cross-entropy
            loss = F.binary_cross_entropy_with_logits(outputs, targets)
            assert loss >= 0
        else:
            # If this is not the first task, we have previous classes
            targets = self.eye[labels]
            # Old-model outputs act as soft targets for previous classes; one-hot for the rest
            comb_targets = torch.cat((logits[:, :previous_classes], targets[:, previous_classes:]), dim=1)
            # Compute the loss as binary cross-entropy
            loss = F.binary_cross_entropy_with_logits(outputs, comb_targets)
            assert loss >= 0

        return loss

    def observe(self, inputs, labels, not_aug_inputs, logits=None, epoch=None):
        """Perform one optimization step on a batch and return the loss value."""
        # The weights are about to change: invalidate the cached class means
        self.class_means = None
        if self.current_task > 0:
            with torch.no_grad():
                # Compute the output for the old model
                logits = torch.sigmoid(self.old_net(inputs))

        self.opt.zero_grad()
        loss = self.get_loss(inputs, labels, self.current_task, logits)

        # Add weight decay to the loss
        if self.args.opt_wd > 0:
            loss = loss + torch.sum(self.net.get_params() ** 2) * self.args.opt_wd
        loss.backward()

        self.opt.step()

        return loss.item()


In [None]:
"""
Now we can use the `load_runner` function to load our custom model.
"""
args = {
    'lr': 0.1, 
    'n_epochs': 1,
    'batch_size': 32,
    'buffer_size': 500,
    'opt_wd': 1e-5,
    'use_herding': 1,
    }

model, dataset = load_runner('icarl','seq-cifar10',args)
train(model, dataset)