# Repeated Augmented Rehearsal (RAR)

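RAR is a simple baseline that performs multiple optimization steps per batch with augmented
rehearsal (1): each incoming batch is combined with samples replayed from a reservoir-sampled
memory, the combined batch is augmented, and the learner is trained on it several times,
drawing a new replay sample and a new augmentation on every step.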


---

1.  Zhang, Yaqian, Bernhard Pfahringer, Eibe Frank, Albert Bifet, Nick Jin Sean Lim, and
    Yunzhe Jia. “A Simple but Strong Baseline for Online Continual Learning: Repeated
    Augmented Rehearsal.” In Advances in Neural Information Processing Systems 35:
    Annual Conference on Neural Information Processing Systems 2022, NeurIPS 2022, New
    Orleans, LA, USA, November 28 - December 9, 2022, edited by Sanmi Koyejo, S.
    Mohamed, A. Agarwal, Danielle Belgrave, K. Cho, and A. Oh, 2022.
    https://doi.org/10.5555/3600270.3601344.


In [None]:
from typing import Optional

import torch
from capymoa.base import BatchClassifier
from capymoa.ocl.strategy._experience_replay import _ReservoirSampler
from torch import Tensor, nn

In [None]:
class RAR(BatchClassifier):
    """Repeated Augmented Rehearsal (RAR).

    Wraps a ``BatchClassifier`` and, for every incoming batch, runs
    ``repeats`` training steps. Each step concatenates the fresh batch with an
    equally sized sample drawn from a reservoir-sampled replay buffer (when
    the buffer is non-empty), applies ``augment`` to the combined batch, and
    trains the wrapped learner on the result. The fresh batch is added to the
    replay buffer only after all repeats for that batch have run.

    :param learner: The classifier that is actually trained; its schema is
        reused by RAR.
    :param coreset_size: Size of the replay buffer, passed to
        ``_ReservoirSampler``.
    :param augment: Module applied to every training batch (fresh + replay).
        Defaults to a new ``nn.Identity`` per instance (no augmentation).
    :param seed: Seed of the ``torch.Generator`` handed to the replay buffer.
    :param repeats: Number of training steps performed per incoming batch.
    :param device: Device on which training batches are assembled and
        augmented before being handed to ``learner``.
    """

    def __init__(
        self,
        learner: BatchClassifier,
        coreset_size: int,
        augment: Optional[nn.Module] = None,
        seed: int = 0,
        repeats: int = 1,
        device: str = "cpu",
    ) -> None:
        super().__init__(learner.schema)
        rng = torch.Generator().manual_seed(seed)
        num_features = learner.schema.get_num_attributes()

        # Build the default per instance instead of sharing a single module
        # object (created once at function definition) between all RARs.
        if augment is None:
            augment = nn.Identity()

        self.device = torch.device(device)
        self.learner = learner
        self.augment = augment.to(self.device)
        self.repeats = repeats
        self.coreset = _ReservoirSampler(coreset_size, num_features, rng=rng)

    def train_step(self, x_fresh: Tensor, y_fresh: Tensor) -> None:
        """Run one training step on an augmented fresh + replay batch.

        :param x_fresh: Features of the incoming batch.
        :param y_fresh: Labels of the incoming batch.
        """
        # Move the fresh batch to the RAR device first: ``torch.cat`` needs
        # every operand on the same device, and ``augment`` lives there too.
        x = x_fresh.to(self.device, self.x_dtype)
        y = y_fresh.to(self.device, self.y_dtype)

        if not self.coreset.is_empty:
            # Replay as many stored samples as there are fresh ones.
            x_replay, y_replay = self.coreset.sample_n(x.shape[0])
            x = torch.cat((x, x_replay.to(self.device, self.x_dtype)), dim=0)
            y = torch.cat((y, y_replay.to(self.device, self.y_dtype)), dim=0)

        # Augment fresh and replayed samples together; a new augmentation is
        # drawn on every call, i.e. on every repeat.
        x = self.augment(x)

        self.learner.batch_train(
            x.to(self.learner.device, self.learner.x_dtype),
            y.to(self.learner.device, self.learner.y_dtype),
        )

    def batch_train(self, x: Tensor, y: Tensor) -> None:
        """Train ``repeats`` times on ``x, y``, then add them to the buffer."""
        for _ in range(self.repeats):
            self.train_step(x, y)
        # The buffer is updated only after the repeats, so replay samples drawn
        # in ``train_step`` above come from earlier batches only.
        self.coreset.update(x, y)

    @torch.no_grad()
    def batch_predict_proba(self, x: Tensor) -> Tensor:
        """Delegate prediction to the wrapped learner without tracking gradients."""
        x = x.to(self.learner.device, self.learner.x_dtype)
        return self.learner.batch_predict_proba(x)

In [None]:
import kornia.augmentation as K
from capymoa.ann import LeNet5
from capymoa.ann.util import apply_weight_norm
from capymoa.classifier import Finetune
from capymoa.ocl.datasets import SplitFashionMNIST
from capymoa.ocl.evaluation import ocl_train_eval_loop

scenario = SplitFashionMNIST()


def new_rar(wn: bool = False, device: Optional[str] = None) -> RAR:
    """Build a RAR learner that fine-tunes a LeNet-5 on ``scenario``.

    :param wn: If true, wrap the final layer ``fc3`` with ``apply_weight_norm``.
    :param device: Device for the learner and the augmentation pipeline.
        Defaults to ``"cuda"`` when available, otherwise ``"cpu"``.
    :return: A RAR instance with a 1,000-sample buffer and 5 repeats per batch.
    """
    if device is None:
        # Fall back to CPU so the notebook also runs on machines without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Presumably 10 output classes and 1x28x28 inputs (FashionMNIST) — confirm
    # against LeNet5's signature.
    model = LeNet5(10, (1, 28, 28))
    if wn:
        model.fc3 = apply_weight_norm(model.fc3)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    return RAR(
        Finetune(scenario.schema, model, optimizer, device=device),
        augment=nn.Sequential(
            K.RandomHorizontalFlip(p=0.5, keepdim=True),
        ),
        device=device,
        coreset_size=1_000,
        repeats=5,
    )


# Train and evaluate RAR without and with weight normalisation on the final
# layer; each run gets a freshly built learner and fresh data loaders.
results = {}
for run_name, use_weight_norm in (("RAR", False), ("RAR_wn", True)):
    results[run_name] = ocl_train_eval_loop(
        new_rar(wn=use_weight_norm),
        scenario.train_loaders(64),
        scenario.test_loaders(64),
        progress_bar=True,
    )

In [None]:
from plot import plot_multiple

# Plot the ``acc_seen`` and ``acc_online`` curves for every run in ``results``.
plot_multiple(results, acc_seen=True, acc_online=True)