In [None]:
import timm

In [None]:
from capymoa.stream.preprocessing.pipeline import ClassifierPipeline, Transformer
from capymoa.instance import LabeledInstance, Instance
from capymoa.stream import Schema
import torch


class PyTorchTransformer(Transformer):
    """Stream transformer that replaces each image instance with ViT features.

    Every incoming instance is reshaped to ``self.shape`` (one 3x32x32 image),
    passed through timm's evaluation transform for the model and then through a
    ``vit_base_patch16_224.augreg2_in21k_ft_in1k`` backbone created with
    ``num_classes=0`` (no classification head). The resulting feature vector is
    wrapped in a new :class:`LabeledInstance` that keeps the original label.

    :param schema: Schema of the incoming image stream; its number of classes
        and dataset name are reused for the output schema.
    :param device: Device to run the backbone on. ``None`` (default) selects
        CUDA when available and falls back to the CPU otherwise.
    :param pretrained: Load the pretrained weights for the model tag. Defaults
        to ``True`` because features from a randomly initialised backbone carry
        no useful signal; pass ``False`` to skip the weight download.
    """

    def __init__(
        self,
        schema: Schema,
        device: torch.device | str | None = None,
        pretrained: bool = True,
    ):
        # Do not assume a GPU is present: a hard-coded "cuda" device fails on
        # CPU-only machines.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        #: Shape every flat input vector is reshaped to: (batch, C, H, W).
        self.shape = (1, 3, 32, 32)

        self.model = timm.create_model(
            "vit_base_patch16_224.augreg2_in21k_ft_in1k",
            pretrained=pretrained,
            num_classes=0,
        )
        data_config = timm.data.resolve_model_data_config(self.model)
        # NOTE(review): this transform normalises with the model's mean/std;
        # confirm the incoming pixel values are in the range it expects.
        self.transforms = timm.data.create_transform(**data_config, is_training=False)
        # Use the same device object for the model and for the inputs.
        self.model = self.model.eval().to(self.device)
        # Ask the model for its feature width instead of hard-coding 768.
        self.schema = Schema.from_basic_classify(
            self.model.num_features,
            schema.get_num_classes(),
            f"ViT({schema.dataset_name})",
        )

    @torch.no_grad()
    def transform_instance(self, instance: LabeledInstance) -> LabeledInstance:
        """Return a new labeled instance holding the ViT features of ``instance``.

        :param instance: Instance whose ``x`` can be reshaped to ``self.shape``.
        :return: Instance over :meth:`get_schema` with a 1-D feature vector and
            the same ``y_index`` as the input.
        """
        # Cast to float32 so the input dtype matches the model weights even if
        # the stream delivers float64 features.
        image = (
            torch.from_numpy(instance.x)
            .to(self.device, torch.float32)
            .reshape(self.shape)
        )
        features = self.model(self.transforms(image))
        # The model returns (1, num_features); the schema describes a flat
        # vector of num_features attributes, so drop the batch dimension.
        features = features.reshape(-1).cpu().numpy()
        return LabeledInstance.from_array(self.schema, features, instance.y_index)

    def get_schema(self) -> Schema:
        """Return the schema of the transformed (feature) stream."""
        return self.schema

    def restart(self):
        """Nothing to reset: no per-stream state is kept between instances."""
        pass

In [None]:
from capymoa.base import BatchClassifier
import numpy as np
import torch


class NCM(BatchClassifier):
    """Nearest class mean classifier.

    Maintains a running per-class sum and count of the feature vectors seen in
    training, and scores an instance by its Euclidean distance to every class
    mean: the closer the mean, the higher the score.

    :param schema: Schema describing the number of classes and attributes.
    :param batch_size: Batch size forwarded to :class:`BatchClassifier`.
    :param device: Device on which the statistics are stored and computed.
    """

    _dtype = torch.float32

    def __init__(
        self,
        schema: Schema,
        batch_size: int = 1,
        device: torch.device | str = torch.device("cpu"),
    ):
        super().__init__(schema, batch_size=batch_size)

        num_classes = self.schema.get_num_classes()
        num_features = self.schema.get_num_attributes()

        self._device = device
        #: Sum of features (num_classes, features)
        self.sum = torch.zeros(
            (num_classes, num_features), device=device, dtype=self._dtype
        )
        #: Number of instances (num_classes,)
        self.count = torch.zeros((num_classes,), device=device, dtype=torch.int64)
        #: Cached mean calculated from sum and count; zero for unseen classes
        self.mean = torch.zeros(
            (num_classes, num_features), device=device, dtype=self._dtype
        )

    @torch.no_grad()
    def batch_train(self, x: np.ndarray, y: np.ndarray) -> None:
        """Fold a batch into the per-class statistics and refresh the means.

        :param x: Features of shape (batch_size, features).
        :param y: Class indices of shape (batch_size,). Labels outside
            ``[0, num_classes)`` are ignored.
        """
        num_classes = self.count.shape[0]
        x_ = torch.from_numpy(x).to(self._device, self._dtype)  # (batch_size, features)
        # Labels are indices, so keep them integral rather than casting to float.
        y_ = torch.from_numpy(y).to(self._device, torch.int64)  # (batch_size,)

        # Drop out-of-range labels so they cannot index outside the buffers.
        in_range = (y_ >= 0) & (y_ < num_classes)
        x_, y_ = x_[in_range], y_[in_range]

        # Vectorised update: no Python loop over classes and no per-class
        # device-to-host synchronisation.
        self.sum.index_add_(0, y_, x_)
        self.count += torch.bincount(y_, minlength=num_classes)
        # clamp(min=1) avoids 0/0 for unseen classes; their sum is zero, so
        # their cached mean stays zero.
        self.mean = self.sum / self.count.clamp(min=1).unsqueeze(1)

    @torch.no_grad()
    def predict_proba(self, instance: Instance) -> np.ndarray:
        """Return class scores of shape (num_classes,) that sum to one.

        The scores are a softmax over the negated Euclidean distances to the
        class means, so the nearest mean receives the largest score. Classes
        that have never been trained on receive a score of zero; if no class
        has been seen yet the scores are uniform.
        """
        x = torch.from_numpy(instance.x).to(self._device, self._dtype).reshape(1, -1)
        # (1, features) against (num_classes, features) -> (1, num_classes)
        distances = torch.cdist(x, self.mean).squeeze(0)

        seen = self.count > 0
        if not bool(seen.any()):
            # Nothing learned yet: no class is preferable to another.
            uniform = torch.full_like(distances, 1.0 / distances.numel())
            return uniform.cpu().numpy()

        # An unseen class has no centroid; give it infinite distance so the
        # softmax assigns it zero probability.
        distances = distances.masked_fill(~seen, float("inf"))
        # Negate so that a smaller distance yields a larger score. Returning
        # the normalised distances themselves would rank the farthest class
        # highest.
        return torch.softmax(-distances, dim=0).cpu().numpy()

    def __str__(self):
        return "NCM"

In [None]:
from capymoa.datasets import ElectricityTiny
from capymoa.evaluation import prequential_evaluation

# Evaluate NCM on the ElectricityTiny stream with prequential evaluation,
# using the learner's default batch size and device.
stream = ElectricityTiny()
learner = NCM(stream.get_schema())

results = prequential_evaluation(stream, learner)
# Last expression of the cell, so the notebook displays the accuracy.
results.accuracy()

In [None]:
# Display the evaluation results object returned by prequential_evaluation.
results

In [None]:
from capymoa.ocl.datasets import SplitCIFAR10

# Build the SplitCIFAR10 scenario; its `schema` and `train_streams` are used
# by the feature-extraction cell below.
scenario = SplitCIFAR10()

In [None]:
# Smoke test: build the ViT feature extractor from the scenario's schema and
# push the first instance of the first training stream through it.
transformer = PyTorchTransformer(scenario.schema)
instance = transformer.transform_instance(next(scenario.train_streams[0]))