# Metrics

> Definition of metrics used during training.

In [None]:
#| default_exp pipeline.metrics

In [None]:
#| export
from genQC.imports import *

In [None]:
#| export
class Metric(abc.ABC):
    """Abstract base class for metrics that accumulate state over updates.

    Subclasses must implement `_eval` (the value computed from one
    input/target pair) and `result` (the value aggregated so far).
    The `empty` flag records whether `update_state` has been called
    since the last `reset_state`.
    """

    def __init__(self, name: str, device):
        """Store the display name and target device, then reset the state.

        Args:
            name: Label used by `__repr__`.
            device: Anything accepted by `torch.device` (e.g. ``"cpu"``).
        """
        self.name = name
        self.device = torch.device(device)
        # Subclasses may allocate their accumulators in `reset_state`, so it
        # is called only after `self.device` has been set.
        self.reset_state()

    def __repr__(self):
        """Return ``"<name>=<result>"`` using the current `result()`."""
        return f"{self.name}={self.result()}"

    def update_state(self, inp, tar=None):
        """Flag the metric as non-empty; subclasses add the accumulation."""
        self.empty = False

    def reset_state(self):
        """Flag the metric as empty; subclasses also clear accumulators."""
        self.empty = True

    @abc.abstractmethod
    def _eval(self, inp, tar):
        """Compute the value contributed by one `inp`/`tar` pair."""

    @abc.abstractmethod
    def result(self):
        """Return the metric value aggregated so far."""

In [None]:
#| export
class Mean(Metric):
    """Mean metric, used for loss.

    Keeps a running weighted sum and a running total weight on
    `self.device`; `result` is their ratio.
    """

    def __init__(self, name: str, device):
        """Create the metric; accumulators are allocated by `reset_state`."""
        super().__init__(name, device)

    @torch.inference_mode()
    def update_state(self, inp: torch.Tensor, tar: torch.Tensor = None, weight: float = 1):
        """Accumulate the value(s) produced by `_eval(inp, tar)`.

        Args:
            inp: Input passed to `_eval`. For this class `_eval` returns it
                unchanged, so it is the value being averaged. Tensors and
                anything `torch.as_tensor` accepts (e.g. a Python float)
                are supported.
            tar: Optional target passed to `_eval`.
            weight: Factor applied to every element of the evaluated value.
                The total weight grows by `weight` times the number of
                evaluated elements.
        """
        super().update_state(inp, tar)
        # `as_tensor` returns tensors unchanged and converts plain numbers,
        # so scalar losses such as `loss.item()` can be accumulated too.
        val = torch.as_tensor(self._eval(inp, tar))
        # Reduce first and move only the resulting scalar, so the in-place
        # add works even when `val` lives on a device other than `self.device`.
        self.weighted_sum += torch.sum(val * weight).to(self.device)
        self.weight       += weight * torch.numel(val)

    @torch.inference_mode()
    def reset_state(self):
        """Clear the accumulators and mark the metric as empty."""
        super().reset_state()
        self.weighted_sum = torch.tensor(0.0, device=self.device)
        self.weight       = torch.tensor(0.0, device=self.device)

    def _eval(self, inp, tar):
        """Return `inp` unchanged; `tar` is ignored."""
        return inp

    @torch.inference_mode()
    def result(self):
        """Return the weighted mean as a CPU tensor.

        The value is ``nan`` (0/0) when nothing has been accumulated since
        the last reset.
        """
        return (self.weighted_sum/self.weight).cpu()

In [None]:
#| export
class Accuracy(Mean):
    """Accuracy metric: fraction of positions where `inp` equals `tar`."""

    @torch.inference_mode()
    def _eval(self, inp, tar):
        """Return the mean of the element-wise equality of `inp` and `tar`.

        The result is a single scalar per call, so each `update_state` call
        contributes one value to the running mean.
        """
        matches = inp == tar
        return matches.float().mean()

Example usage:

In [None]:
# A fresh metric has no accumulated state yet.
a = Accuracy("mean", "cpu")
print(a, a.empty)

# Feed two (prediction, target) batches and show the running value after each.
batches = [
    (torch.Tensor([3,2,2,1]), torch.Tensor([1,2,2,1])),
    (torch.Tensor([1,2,2,3]), torch.Tensor([1,2,2,3])),
]
for inp, tar in batches:
    a.update_state(inp, tar)
    print(a, a.empty)

# Resetting returns the metric to its empty state.
a.reset_state()
print(a, a.empty)

mean=nan True
mean=0.75 False
mean=0.875 False
mean=nan True


# Export -

In [None]:
#| hide
# Run nbdev's export step, which writes cells marked `#| export` to the
# module named by `#| default_exp`.
import nbdev; nbdev.nbdev_export()