In [None]:
import torch
from capymoa.base import BatchClassifier
from capymoa.ocl.base import TestTaskAware
from capymoa.stream import Schema
from torch import Tensor, nn

from plot import plot_multiple

In [None]:
# Wrapper for evaluation loop
from capymoa.base import Classifier
from capymoa.ocl.datasets import _BuiltInCIScenario
from capymoa.ocl.evaluation import OCLMetrics, ocl_train_eval_loop


def run(
    scenario: _BuiltInCIScenario,
    learner: Classifier,
    batch_size: int = 128,
    eval_window_batches: int = 3,
) -> OCLMetrics:
    """Train ``learner`` on ``scenario`` with ``ocl_train_eval_loop`` and return its metrics.

    :param scenario: Built-in scenario that provides the train and test loaders.
    :param learner: The learner to train and evaluate.
    :param batch_size: Mini-batch size used for both the train and the test loaders.
    :param eval_window_batches: The ``eval_window_size`` handed to the loop is
        ``batch_size * eval_window_batches``, so it stays tied to the batch size.
    :return: The metrics object produced by ``ocl_train_eval_loop``.
    """
    return ocl_train_eval_loop(
        learner,
        scenario.train_loaders(batch_size),
        scenario.test_loaders(batch_size),
        progress_bar=True,
        continual_evaluations=1,
        eval_window_size=batch_size * eval_window_batches,
    )

In [None]:
from typing import Dict, List

import numpy as np
from torch.utils.data import DataLoader, TensorDataset


class GDumb(BatchClassifier, TestTaskAware):
    """
    GDumb: A Simple Approach for Online Class-Incremental Learning
    https://arxiv.org/abs/2005.12797

    GDumb does not take gradient steps on the stream. :meth:`batch_train` only
    maintains a class-balanced coreset of at most ``coreset_size`` examples.
    :meth:`gdumb_fit` restores the model to its construction-time weights and
    trains it from scratch on that coreset; it is triggered from
    :meth:`on_test_task` when ``task_id == 0``.

    :param schema: Stream schema; ``schema.get_num_classes()`` sizes the buffer.
    :param model: Network mapping a batch of inputs to class logits.
    :param epochs: Number of passes over the coreset in every refit.
    :param batch_size: Mini-batch size used while fitting on the coreset.
    :param coreset_size: Maximum number of examples held in the coreset.
    :param lr: Adam learning rate.
    :param device: Torch device the model is trained and evaluated on.
    """

    def __init__(
        self,
        schema: Schema,
        model: nn.Module,
        epochs: int,
        batch_size: int,
        coreset_size: int,
        lr: float = 0.001,
        device: str = "cpu",
    ):
        super().__init__(schema)
        self.schema = schema
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr
        self.device = torch.device(device)
        self.model = model.to(self.device)
        # ``state_dict()`` returns tensors that share storage with the live
        # parameters; clone them, otherwise every later "reset" would simply
        # reload the most recently trained weights.
        self.original_state_dict = {
            name: tensor.detach().clone()
            for name, tensor in self.model.state_dict().items()
        }
        # One list of stored examples per class label.
        self.buffer: Dict[int, List[Tensor]] = {
            k: [] for k in range(schema.get_num_classes())
        }
        self.coreset_size = coreset_size
        # Labels observed so far; their count sets the balanced per-class quota.
        self._seen_classes = set()
        self.loss_func = nn.CrossEntropyLoss()

    @property
    def count(self) -> int:
        """Total number of examples currently stored in the coreset."""
        return sum(self.class_counts)

    @property
    def class_counts(self) -> List[int]:
        """Number of stored examples per class, indexed by class label."""
        return [len(v) for v in self.buffer.values()]

    def batch_train(self, x: Tensor, y: Tensor) -> None:
        """Update the class-balanced coreset with a batch of examples.

        While the coreset has room every example is stored. Once it is full, an
        example is admitted only if its class holds fewer than
        ``coreset_size / |seen classes|`` examples; a random example of the
        largest class is then evicted to make room (GDumb's greedy sampler).
        No model update happens here.
        """
        for xi, yi in zip(x, y):
            label = int(yi.item())
            self._seen_classes.add(label)
            # ``.cpu()`` is a no-op view for CPU inputs; clone so the stored
            # sample neither pins the whole batch in memory nor aliases it.
            sample = xi.detach().cpu().clone()

            if self.count < self.coreset_size:
                # Room left in the coreset for this example.
                self.buffer[label].append(sample)
                continue

            quota = self.coreset_size / len(self._seen_classes)
            if len(self.buffer[label]) >= quota:
                # Class already has its balanced share; drop the example.
                continue

            # Coreset is full: evict a random example from the majority class.
            # That class is guaranteed non-empty and different from ``label``
            # because ``label`` is below the average class size.
            counts = self.class_counts
            replace_class = int(np.argmax(counts))
            replace_idx = np.random.randint(counts[replace_class])
            del self.buffer[replace_class][replace_idx]
            self.buffer[label].append(sample)

    @torch.no_grad()
    def batch_predict_proba(self, x: Tensor) -> Tensor:
        """Return softmax class probabilities for a batch.

        Runs without autograd and in eval mode, because this is inference only
        and layers such as dropout/batch-norm (if present) must not act as in
        training.
        """
        self.model.eval()
        return self.model(x.to(self.device)).softmax(dim=1)

    def gdumb_fit(self) -> None:
        """Re-initialise the model and train it from scratch on the coreset.

        Weights are restored to their construction-time values and a new Adam
        optimizer is created, so nothing carries over between refits. If the
        coreset is still empty the model is only reset.
        """
        self.model.load_state_dict(self.original_state_dict)
        self.model.to(self.device)

        stored = [
            (label, sample)
            for label, samples in self.buffer.items()
            for sample in samples
        ]
        if not stored:
            # Nothing to fit yet; torch.stack would raise on an empty list.
            return

        x = torch.stack([sample for _, sample in stored])
        y = torch.tensor([label for label, _ in stored], dtype=torch.long)
        loader = DataLoader(
            TensorDataset(x, y), batch_size=self.batch_size, shuffle=True
        )

        self.model.train()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        for _ in range(self.epochs):
            for batch_x, batch_y in loader:
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)
                optimizer.zero_grad()
                loss = self.loss_func(self.model(batch_x), batch_y)
                loss.backward()
                optimizer.step()

    def on_test_task(self, task_id: int) -> None:
        """Refit on the coreset once per evaluation sweep, at test task 0."""
        if task_id == 0:
            self.gdumb_fit()

In [None]:
from capymoa.ann import LeNet5
from capymoa.ocl.datasets import SplitFashionMNIST

# Scenario shared by every experiment below: its ``schema`` configures the
# learners and ``run`` draws train/test loaders from it. Presumably the
# class-incremental Fashion-MNIST split — confirm against capymoa docs.
scenario = SplitFashionMNIST()

In [None]:
def new_gdumb(n: int) -> GDumb:
    """Build a fresh GDumb learner with a LeNet-5 backbone and an ``n``-example coreset.

    Uses CUDA when it is available and falls back to CPU otherwise, so the
    notebook also runs on machines without a GPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return GDumb(
        scenario.schema,
        # Output layer sized from the scenario rather than a hard-coded 10.
        model=LeNet5(scenario.schema.get_num_classes()),
        epochs=10,
        batch_size=32,
        coreset_size=n,
        lr=0.01,
        device=device,
    )


# Seed torch (model initialisation, coreset DataLoader shuffling) and numpy
# (coreset eviction) so Restart & Run All reproduces the comparison, up to
# any nondeterminism in GPU kernels.
torch.manual_seed(0)
np.random.seed(0)

results = {}
for coreset_size in (100, 1000, 10000):
    results[f"GDumb $n={coreset_size}$"] = run(scenario, new_gdumb(coreset_size))

In [None]:
# Compare the coreset sizes on accuracy over seen classes and online accuracy.
plot_multiple(
    results,
    acc_seen=True,
    acc_online=True,
)