In [1]:
# Make the notebook kernel look like it is inside the k1lib.callbacks package.
# Presumably this is so the relative import ``from .callbacks import ...`` in the
# next cell resolves while developing interactively (Python falls back to
# ``__name__`` when no ``__package__``/``__spec__`` is set) — confirm.
# NOTE(review): this cell has no ``#export`` marker, so it is presumably notebook-only.
__name__ = "k1lib.callbacks"

In [2]:
#export
from .callbacks import Callback, Callbacks, Cbs
import k1lib, warnings
from typing import List, Callable
try: import torch; hasTorch = True
except Exception:
    # torch is optional: fall back to a placeholder so module-level references such as
    # the ``torch.Tensor`` class annotation still evaluate. NOTE(review): relies on
    # k1lib.Object().withAutoDeclare producing placeholder attributes on access — confirm.
    # Catching Exception (instead of a bare except) still tolerates broken installs
    # (e.g. OSError from missing shared libraries) but lets KeyboardInterrupt /
    # SystemExit raised during the import propagate.
    torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["ConfusionMatrix"] # the only public name exported by ``from ... import *``

In [3]:
#export
@k1lib.patch(Cbs)
class ConfusionMatrix(Callback):
    " "
    categories:List[str]
    """String categories for displaying the matrix. You can set this
so that it displays what you want, in case this Callback is included
automatically."""
    matrix:torch.Tensor
    """The recorded confusion matrix."""
    def __init__(self, categories:List[str]=None, condF:Callable[["ConfusionMatrix"], bool]=lambda _: True):
        """Records what categories the network is confused the most. Expected
variables in :class:`~k1lib.Learner`:

- preds: long tensor with categories id of batch before checkpoint ``endLoss``.
  Auto-included in :class:`~k1lib.callbacks.lossFunctions.accuracy.AccF` and
  :class:`~k1lib.callbacks.lossFunctions.shorts.LossNLLCross`.
- yb: integer tensor with the ground-truth category ids of the batch.

Rows of :attr:`matrix` are indexed by the ground truth (``yb``) and columns by
the prediction (``preds``).

:param categories: optional list of category names
:param condF: takes in this cb and returns whether to record at this
    particular `endLoss` checkpoint."""
        super().__init__(); self.categories = categories
        # start at len(categories), or 2x2 if none are given; _adapt() grows it on demand
        self.n = len(categories or []) or 2; self.condF = condF
        self.matrix = torch.zeros(self.n, self.n)
        self.wipeOnAdd = False # flag to wipe matrix on adding new data points
    def _adapt(self, idxs):
        """Grows the internal matrix so it can be indexed by every category id
in ``idxs``, then moves it to ``idxs``'s device.

:param idxs: integer tensor of category ids
:return: ``idxs``, unchanged"""
        if idxs.numel() > 0: # .max() of an empty tensor raises, and there is nothing to grow for
            m = int(idxs.max().item()) + 1 # +1 because max index = len() - 1
            if m > self.n:
                matrix = torch.zeros(m, m, dtype=self.matrix.dtype, device=idxs.device)
                matrix[:self.n, :self.n] = self.matrix.to(idxs.device)
                self.matrix = matrix; self.n = len(self.matrix)
        self.matrix = self.matrix.to(idxs.device); return idxs
    def startEpoch(self): self.wipeOnAdd = True
    def endLoss(self):
        """Adds this batch's (``l.yb``, ``l.preds``) pairs to the matrix if
``condF(self)`` allows it. The first recorded batch of every epoch resets the
matrix before adding."""
        if not self.condF(self): return
        if self.wipeOnAdd:
            self.matrix = torch.zeros(self.n, self.n)
            self.wipeOnAdd = False
        yb = self._adapt(self.l.yb); preds = self._adapt(self.l.preds)
        yb, preds = torch.broadcast_tensors(yb, preds)
        # ``matrix[yb, preds] += 1`` counts a (truth, pred) pair that repeats within
        # one batch only once; index_put_ with accumulate=True counts every occurrence
        ones = torch.ones(yb.shape, dtype=self.matrix.dtype, device=self.matrix.device)
        self.matrix.index_put_((yb, preds), ones, accumulate=True)
    @property
    def goodMatrix(self) -> torch.Tensor:
        """Clears all inf, nans and whatnot from the matrix, then returns it with
every row divided by that row's maximum. Rows that recorded no samples stay
all-zero instead of turning into nan."""
        n = self.n; m = self.matrix
        # shrink from the bottom-right corner until every remaining entry is finite
        while n > 0 and not bool(torch.isfinite(m).all()):
            n -= 1; m = m[:n, :n]
        if n != self.n: warnings.warn(f"Originally, the confusion matrix has {self.n} categories, now it has {n} only, after filtering, because there are some nans and infinite values.")
        if self.categories is not None:
            n = len(self.categories); m = m[:n, :n]
        if m.numel() == 0: return m # .max(dim=1) cannot reduce an empty tensor
        rowMax = m.max(dim=1).values
        # avoid 0/0 -> nan for rows that never appeared in the ground truth
        rowMax = torch.where(rowMax == 0, torch.ones_like(rowMax), rowMax)
        return m/rowMax[:,None]
    def plot(self):
        """Plots everything"""
        m = self.goodMatrix
        # default labels follow the plotted matrix, which can be smaller than self.n after filtering
        k1lib.viz.confusionMatrix(m, self.categories or list(range(len(m))))
    def __repr__(self):
        return f"""{super()._reprHead}, use...
- l.plot(): to plot everything
{super()._reprCan}"""

In [2]:
# Run the repo's export script for this notebook (callbacks/confusionMatrix).
# Judging by the log it prints, it also uninstalls and reinstalls k1lib.
# Presumably it copies the ``#export`` cells into the package — confirm.
!../../export.py callbacks/confusionMatrix

Current dir: /home/kelvin/repos/labs/k1lib, /home/kelvin/repos/labs/k1lib/k1lib/callbacks/../../export.py
rm: cannot remove '__pycache__': No such file or directory
Found existing installation: k1lib 1.3.4.2
Uninstalling k1lib-1.3.4.2:
  Successfully uninstalled k1lib-1.3.4.2
running install
running bdist_egg
running egg_info
creating k1lib.egg-info
writing k1lib.egg-info/PKG-INFO
writing dependency_links to k1lib.egg-info/dependency_links.txt
writing requirements to k1lib.egg-info/requires.txt
writing top-level names to k1lib.egg-info/top_level.txt
writing manifest file 'k1lib.egg-info/SOURCES.txt'
reading manifest file 'k1lib.egg-info/SOURCES.txt'
adding license file 'LICENSE'
writing manifest file 'k1lib.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
creating build
creating build/lib
creating build/lib/k1lib
copying k1lib/_learner.py -> build/lib/k1lib
copying k1lib/fmt.py -> build/lib/k1lib
copying k1lib/_k1a.py ->