In [None]:
import numpy as np
import torch
from capymoa.base import BatchClassifier
from capymoa.stream import Schema
from torch import Tensor


class NCM(BatchClassifier):
    """Nearest class mean (NCM) classifier.

    Maintains a running feature sum and instance count for every class and
    scores an instance by its Euclidean distance to each class mean. Classes
    that have not received any training instance yet are excluded from
    prediction, so their all-zero placeholder mean can never win.
    """

    _dtype = torch.float32

    def __init__(
        self,
        schema: Schema,
        device: torch.device | str = torch.device("cpu"),
    ):
        """Allocate the per-class statistics.

        :param schema: Stream schema; supplies the number of classes and
            the number of attributes (features).
        :param device: Torch device on which the statistics are stored.
        """
        super().__init__(schema)
        n_classes = schema.get_num_classes()
        n_feats = schema.get_num_attributes()
        self._device = device
        # Running per-class feature sums, shape (n_classes, n_feats).
        self.sum = torch.zeros((n_classes, n_feats), device=device, dtype=self._dtype)
        # Number of training instances seen per class, shape (n_classes,).
        self.count = torch.zeros((n_classes,), device=device, dtype=torch.int64)
        # Per-class means (sum / count); rows of unseen classes stay at zero.
        self.mean = torch.zeros((n_classes, n_feats), device=device, dtype=self._dtype)

    @torch.no_grad()
    def batch_train(self, x: np.ndarray, y: np.ndarray) -> None:
        """Fold a labelled batch into the per-class sums, counts and means.

        :param x: Feature matrix of shape (batch_size, features).
        :param y: Class indices of shape (batch_size,). Labels outside
            ``[0, n_classes)`` are ignored.
        """
        n_classes = self.count.shape[0]
        x_ = torch.from_numpy(x).to(self._device, self._dtype)  # (batch_size, features)
        y_ = torch.from_numpy(y).to(self._device, torch.int64)  # (batch_size,)

        # Drop rows whose label does not address a known class so the scatter
        # operations below cannot index out of bounds.
        in_range = (y_ >= 0) & (y_ < n_classes)
        x_, y_ = x_[in_range], y_[in_range]

        # One vectorised scatter-add replaces a Python loop over every class
        # (and the host/device synchronisation each iteration implied).
        self.sum.index_add_(0, y_, x_)
        self.count += torch.bincount(y_, minlength=n_classes)

        # Clamping the divisor keeps unseen classes at a zero mean (0 / 1)
        # instead of producing NaN from 0 / 0.
        self.mean = self.sum / self.count.clamp(min=1).unsqueeze(1)

    @torch.no_grad()
    def batch_predict_proba(self, x: np.ndarray) -> np.ndarray:
        """Return pseudo-probabilities derived from distances to class means.

        :param x: Feature matrix of shape (batch_size, features).
        :return: Array of shape (batch_size, n_classes) whose rows sum to 1.
            Classes without training data get probability 0; if no class
            has been trained yet every class gets the same probability.
        """
        assert x.ndim == 2, "Input must be a 2D array (batch_size, features)"
        x_ = torch.from_numpy(x).to(self._device, self._dtype)
        n_classes = self.count.shape[0]

        seen = self.count > 0
        if not bool(seen.any()):
            # Nothing learned yet: there is no evidence favouring any class.
            return np.full(
                (x_.shape[0], n_classes), 1.0 / n_classes, dtype=np.float32
            )

        # Calculate distances to class means
        distances = torch.cdist(x_.unsqueeze(0), self.mean.unsqueeze(0)).squeeze(0)

        # Convert distances to pseudo-probabilities. Using the inverse weighted
        # distance method.
        inv_distances = 1 / (1 + distances)
        # The mean of an unseen class is only a zero placeholder, not a real
        # prototype, so it must not attract any probability mass.
        inv_distances = inv_distances.masked_fill(~seen, 0.0)
        probabilities = inv_distances / inv_distances.sum(dim=1, keepdim=True)
        return probabilities.cpu().numpy()

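Reference for the incremental (exponentially weighted) mean and variance updates used in the next cell: [Tony Finch, *Incremental calculation of weighted mean and variance*](https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf)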

In [None]:
def ema_mean_update(
    batch: Tensor,
    lambda_: float,
    count: float,
    mean: Tensor,
) -> tuple[float, Tensor]:
    """Update an exponentially decayed running mean with a batch of rows.

    The previous statistics are down-weighted by ``lambda_`` and every row of
    ``batch`` enters with weight 1:

    ``new_count = lambda_ * count + n``
    ``new_mean  = mean + sum(batch - mean) / new_count``

    :param batch: Tensor of shape (n, features) holding the new rows.
    :param lambda_: Decay factor applied to the previous count.
    :param count: Effective (decayed) number of rows behind ``mean``.
    :param mean: Current mean, shape (features,).
    :return: ``(new_count, new_mean)``. When the new effective count is zero
        (an empty batch with no prior weight) the mean is returned unchanged
        rather than becoming NaN from a 0 / 0 division.
    """
    batch_size: int = batch.shape[0]
    new_count = lambda_ * count + batch_size
    if new_count == 0:
        # No accumulated weight and no new rows: nothing to average.
        return new_count, mean.clone()
    delta = batch - mean
    new_mean = mean + delta.sum(dim=0) / new_count
    return new_count, new_mean


def ema_batch_update_scatter(
    batch: Tensor,
    lambda_: float,
    count: float,
    mean: Tensor,
    scatter_matrix: Tensor,
) -> tuple[float, Tensor, Tensor]:
    """Update an exponentially decayed mean and scatter matrix with a batch.

    The previous count and scatter matrix are down-weighted by ``lambda_``;
    the batch contributes ``(batch - old_mean).T @ (batch - new_mean)`` to the
    scatter matrix.

    :param batch: Tensor of shape (n, features) holding the new rows.
    :param lambda_: Decay factor applied to the previous count and scatter.
    :param count: Effective (decayed) number of rows behind the statistics.
    :param mean: Current mean, shape (features,).
    :param scatter_matrix: Current scatter matrix, shape (features, features).
    :return: ``(new_count, new_mean, new_scatter)``. When the new effective
        count is zero (an empty batch with no prior weight) the mean is
        returned unchanged rather than becoming NaN from a 0 / 0 division.
    """
    batch_size: int = batch.shape[0]
    new_count = lambda_ * count + batch_size
    delta = batch - mean
    if new_count == 0:
        # No accumulated weight and no new rows: keep the mean as it is.
        new_mean = mean.clone()
    else:
        new_mean = mean + delta.sum(dim=0) / new_count
    # Deviations from the old mean times deviations from the new mean.
    new_scatter = lambda_ * scatter_matrix + delta.T @ (batch - new_mean)
    return new_count, new_mean, new_scatter


class SLDA(BatchClassifier):
    """Streaming linear discriminant analysis classifier.

    Keeps exponentially decayed per-class means and a single scatter matrix
    shared by all classes, and predicts with the resulting linear
    discriminant functions followed by a softmax.
    """

    _dtype = torch.float32

    def __init__(
        self,
        schema: Schema,
        lambda_: float = 0.9,
        ridge: float = 1e-6,
        bias: bool = True,
        device: torch.device | str = torch.device("cpu"),
    ):
        """Allocate the running statistics.

        :param schema: Stream schema; supplies the number of classes and
            the number of attributes (features).
        :param lambda_: Decay factor applied to the previous statistics on
            each update.
        :param ridge: Value added to the diagonal of the covariance estimate
            before solving for the discriminant weights.
        :param bias: Whether to add the per-class bias (quadratic mean term
            plus log class frequency) to the scores.
        :param device: Torch device on which the statistics are stored.
        """
        super().__init__(schema)
        n_classes = schema.get_num_classes()
        n_feats = schema.get_num_attributes()
        self._device = device
        # Decayed counts are fractional whenever lambda_ < 1, so they are kept
        # in floating point; an integer tensor would truncate every update.
        self.class_counts = torch.zeros(
            (n_classes,), device=device, dtype=self._dtype
        )
        # Per-class decayed means, shape (n_classes, n_feats).
        self.class_means = torch.zeros(
            (n_classes, n_feats), device=device, dtype=self._dtype
        )

        # Effective (decayed) number of instances behind mean/scatter.
        self.count = 0
        self.bias = bias
        # Decayed mean over all instances, shape (n_feats,).
        self.mean = torch.zeros(n_feats, device=device, dtype=self._dtype)
        # Decayed scatter matrix, initialised to the identity.
        self.scatter = torch.eye(n_feats, device=device, dtype=self._dtype)
        # Diagonal regulariser added to the covariance estimate at prediction.
        self.ridge = torch.eye(n_feats, device=device, dtype=self._dtype) * ridge
        self.lambda_ = lambda_

    @torch.no_grad()
    def batch_train(self, x: np.ndarray, y: np.ndarray) -> None:
        """Fold a labelled batch into the class means and the shared scatter.

        An empty batch is ignored: it carries no information, and with no
        prior data it would turn the global mean into NaN.

        :param x: Feature matrix of shape (batch_size, features).
        :param y: Class indices of shape (batch_size,). Labels outside
            ``[0, n_classes)`` do not update any class mean.
        """
        x_ = torch.from_numpy(x).to(self._device, self._dtype)  # (batch_size, features)
        y_ = torch.from_numpy(y).to(self._device, torch.int64)  # (batch_size,)
        if x_.shape[0] == 0:
            return

        # Update class means and counts. Only labels present in the batch are
        # visited; absent classes would be skipped anyway.
        n_classes = self.class_counts.shape[0]
        for i in torch.unique(y_).tolist():
            if not 0 <= i < n_classes:
                continue
            x_masked = x_[y_ == i]
            self.class_counts[i], self.class_means[i] = ema_mean_update(
                x_masked, self.lambda_, self.class_counts[i].item(), self.class_means[i]
            )

        # Update the statistics shared by all classes.
        self.count, self.mean, self.scatter = ema_batch_update_scatter(
            x_, self.lambda_, self.count, self.mean, self.scatter
        )

    @torch.no_grad()
    def batch_predict_proba(self, x: np.ndarray) -> np.ndarray:
        """Return class probabilities from the linear discriminant scores.

        :param x: Feature matrix of shape (batch_size, features).
        :return: Array of shape (batch_size, n_classes); uniform when the
            model has not been trained yet.
        """
        x_ = torch.from_numpy(x).to(self._device, self._dtype)
        n_classes = self.schema.get_num_classes()
        if self.count == 0:
            # Return uniform probabilities if no training has been done
            return np.full((x_.shape[0], n_classes), 1.0 / n_classes, dtype=np.float32)

        covariance = self.scatter / self.count + self.ridge
        # Row k of `weights` is covariance^-1 @ class_means[k].
        weights: Tensor = torch.linalg.solve(covariance, self.class_means.T).T
        scores = x_ @ weights.T
        if self.bias:
            # Row-wise dot product equals the diagonal of
            # class_means @ weights.T without building the full
            # (n_classes, n_classes) matrix.
            quadratic = (self.class_means * weights).sum(dim=1)
            # log(0) = -inf for classes never seen, which drives their
            # softmax probability to zero.
            scores += -0.5 * quadratic + torch.log(self.class_counts / self.count)
        proba = torch.softmax(scores, dim=1)
        return proba.cpu().numpy()

In [None]:
from capymoa.ocl.datasets import SplitCIFAR100ViT

# Instantiate the SplitCIFAR100ViT stream. Later cells read its schema and
# train/test loaders, and use it for the plot title.
stream = SplitCIFAR100ViT()

In [None]:
import torch
from capymoa.classifier import Finetune
from capymoa.ocl.ann import WNPerceptron
from capymoa.ocl.evaluation import ocl_train_eval_loop
from capymoa.ocl.strategy import ExperienceReplay
from torch.optim import Adam

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Batch size shared by the train/test loaders; the evaluation window below is
# three batches wide.
BATCH_SIZE = 128

# (plot label, learner) pairs to compare on the same stream.
learners = [
    (
        "ER",
        ExperienceReplay(
            Finetune(stream.schema, WNPerceptron, lambda p: Adam(p, 0.01)), 1_000
        ),
    ),
    ("NCM", NCM(stream.schema, device=DEVICE)),
    (r"SLDA $\lambda=1.0$", SLDA(stream.schema, lambda_=1.0, device=DEVICE)),
    (r"SLDA $\lambda=0.9$", SLDA(stream.schema, lambda_=0.9, device=DEVICE)),
]

# Collected as (label, evaluation result) pairs for the plotting cell.
results = []
for label, learner in learners:
    row = ocl_train_eval_loop(
        learner,
        stream.train_loaders(BATCH_SIZE),
        stream.test_loaders(BATCH_SIZE),
        progress_bar=True,
        continual_evaluations=10,
        eval_window_size=BATCH_SIZE * 3,
    )
    results.append((label, row))

In [None]:
%load_ext autoreload
%autoreload 2
from matplotlib import pyplot as plt

from plot import figsize, plot_multiple

fig, ax = plt.subplots(figsize=figsize)
plot_multiple(
    results,
    ax,
    acc_online=True,
    acc_seen=True,
)
ax.set_title(f"{stream}")
ax.set_ylim(50, 100)

* $\lambda$ is the decay factor of the exponentially weighted running statistics (the class means and the shared covariance/scatter matrix).
* $\lambda=1$ applies no decay, so it is the least forgetful.
* $\lambda=0.9$ forgets older data faster but adapts more quickly to changes in the data.
* **bias** is important under class imbalance. The no-bias version is less robust
  during the first few iterations.