In [1]:
import torch

In [29]:
class Metrics:
    """Accumulate confusion counts for integer label tensors and derive per-class scores.

    Entry ``[b, i, j]`` of a confusion matrix counts the elements of sample ``b``
    whose target label equals ``i`` and whose predicted label equals ``j``
    (rows = targets, columns = predictions). Elements whose labels equal none of
    ``0 .. num_classes - 1`` are not counted.

    ``macro()`` and ``micro()`` aggregate over classes ``1 .. num_classes - 1``
    only; class 0 is skipped (treated as background).

    A score that is undefined (0 / 0, e.g. a class absent from both preds and
    targets) is reported as 1.0.

    Args:
        num_classes: Number of classes; must be greater than 0.
        metric: One of ``'dice'``, ``'iou'`` or ``'accuracy'``. ``'accuracy'`` is the
            per-class accuracy ``tp / (tp + fn)``: the fraction of elements whose
            target is the class that were also predicted as that class.

    Raises:
        ValueError: If ``num_classes`` is not greater than 0 or ``metric`` is unknown.
    """

    def __init__(self, num_classes, metric='dice'):
        if num_classes <= 0:
            raise ValueError(f"num_classes {num_classes} must be an integer greater than 0.")
        if metric not in ['dice', 'iou', 'accuracy']:
            raise ValueError(f"metric ({metric}) must be one of 'dice', 'iou' or 'accuracy'.")

        self.num_classes = num_classes
        self.metric = metric
        # Running total of every batch passed to add_to_confusion_matrix; shape (1, C, C).
        self.confusion = torch.zeros(1, self.num_classes, self.num_classes, dtype=torch.int64)

    def reset(self):
        """Discard all accumulated counts (a new tensor, so earlier references are untouched)."""
        self.confusion = torch.zeros_like(self.confusion)

    ################## Confusion Matrix
    def get_confusion_matrix(self, preds, targets):
        """Return the confusion matrix of ``preds`` vs ``targets``, or the accumulated one.

        Args:
            preds: Predicted label tensor, or None.
            targets: Target label tensor with the same shape as ``preds``, or None.
                The first dimension of both is treated as the batch dimension.

        Returns:
            If both arguments are None: ``self.confusion`` itself (shape (1, C, C),
            not a copy). Otherwise a new int64 tensor of shape (B, C, C) with
            ``B = targets.shape[0]``.

        Raises:
            ValueError: If only one of the two is given, or their shapes differ.
        """
        if preds is None and targets is None:
            return self.confusion
        if preds is None or targets is None:
            raise ValueError('Either both preds and targets should be given or none of them.')
        if preds.shape != targets.shape:
            raise ValueError(f"Size of prediction {list(preds.shape)} must match size of targets {list(targets.shape)}")

        batch_size = targets.shape[0]
        # Flatten all non-batch dims once instead of once per (i, j) pair.
        preds_flat = preds.reshape(batch_size, -1)
        targets_flat = targets.reshape(batch_size, -1)

        conf = torch.zeros(batch_size, self.num_classes, self.num_classes, dtype=torch.int64)
        for i in range(self.num_classes):
            is_target_i = targets_flat == i  # independent of j, so computed once per row
            for j in range(self.num_classes):
                conf[:, i, j] = torch.logical_and(is_target_i, preds_flat == j).sum(dim=-1)
        return conf

    def add_to_confusion_matrix(self, preds, targets):
        """Add the counts of one batch of ``preds``/``targets`` to ``self.confusion``.

        Raises:
            ValueError: If either argument is None (passing both as None would
                otherwise add the running total to itself), or their shapes differ.
        """
        if preds is None or targets is None:
            raise ValueError('Both preds and targets must be given to update the confusion matrix.')
        self.confusion += self.get_confusion_matrix(preds, targets).sum(dim=0, keepdim=True)

    ################## Metrics
    @staticmethod
    def _tp_fp_fn(conf):
        """Split a (..., C, C) confusion matrix into per-class tp, fp and fn of shape (..., C).

        Rows are targets and columns are predictions, so a row sum counts elements
        whose target is the class and a column sum counts elements predicted as it.
        """
        tp = torch.diagonal(conf, dim1=-2, dim2=-1)
        fn = torch.sum(conf, dim=-1) - tp  # target is the class, prediction is not
        fp = torch.sum(conf, dim=-2) - tp  # prediction is the class, target is not
        return tp, fp, fn

    def macro_per_class(self, preds=None, targets=None):
        """Per-class score computed per sample, then averaged over the batch.

        With no arguments the accumulated confusion matrix (a single "sample") is used.

        Returns:
            Float tensor of shape (C,).
        """
        tp, fp, fn = self._tp_fp_fn(self.get_confusion_matrix(preds, targets))
        return torch.nan_to_num(self._get_score(tp, fp, fn), nan=1.0).mean(dim=0)

    def macro(self, preds=None, targets=None):
        """Mean of ``macro_per_class`` over classes 1..C-1 (class 0 is skipped)."""
        return torch.mean(self.macro_per_class(preds, targets)[1:]).item()

    def micro_per_class(self, preds=None, targets=None):
        """Per-class score from tp/fp/fn counts summed over the whole batch.

        Returns:
            Float tensor of shape (C,).
        """
        tp, fp, fn = self._tp_fp_fn(self.get_confusion_matrix(preds, targets))
        # Reduce over the batch only after the per-sample subtraction; subtracting a
        # batch-summed tp from per-sample sums would be wrong whenever B > 1.
        tp, fp, fn = tp.sum(dim=0), fp.sum(dim=0), fn.sum(dim=0)
        return torch.nan_to_num(self._get_score(tp, fp, fn), nan=1.0)

    def micro(self, preds=None, targets=None):
        """Single score from counts summed over the batch and over classes 1..C-1."""
        tp, fp, fn = self._tp_fp_fn(self.get_confusion_matrix(preds, targets))
        tp = tp[:, 1:].sum()
        fp = fp[:, 1:].sum()
        fn = fn[:, 1:].sum()
        return torch.nan_to_num(self._get_score(tp, fp, fn), nan=1.0).item()

    def _get_score(self, tp, fp, fn):
        """Apply ``self.metric`` elementwise to true-positive, false-positive and false-negative counts.

        A zero denominator yields NaN; the public methods map it to 1.0.
        """
        if self.metric == 'dice':
            return 2 * tp / (2 * tp + fp + fn)
        if self.metric == 'iou':
            return tp / (tp + fp + fn)
        if self.metric == 'accuracy':
            # Per-class accuracy: share of the class's target elements predicted correctly.
            return tp / (tp + fn)
        raise ValueError(f"metric ({self.metric}) must be one of 'dice', 'iou' or 'accuracy'.")

In [30]:
# Smoke test: accumulate confusion counts over random label maps.
# A dedicated seeded generator makes the run reproducible without touching global RNG state.
NUM_CLASSES = 8
NUM_BATCHES = 14
LABEL_SHAPE = (1, 1000, 1000)  # first dim is the batch dim Metrics iterates over
SEED = 0

generator = torch.Generator().manual_seed(SEED)
metrics = Metrics(num_classes=NUM_CLASSES)
for _ in range(NUM_BATCHES):
    preds = torch.randint(0, NUM_CLASSES, LABEL_SHAPE, generator=generator)
    targets = torch.randint(0, NUM_CLASSES, LABEL_SHAPE, generator=generator)
    metrics.add_to_confusion_matrix(preds, targets)

metrics.macro(), metrics.micro()

In [41]:
# Inspect the counts a micro-averaged score is built from, on a fixed 8-class
# confusion matrix of shape (1, 8, 8). Presumably produced by
# Metrics.get_confusion_matrix (rows = target labels, columns = predicted labels) — confirm.
a = torch.tensor([[[10782481,      679,    40832,    61909,   228086,   120091,    21495,
               730],
         [   28594,      154,    18388,       33,     3295,    21984,        0,
                 0],
         [   65064,       19,   170986,       94,    21471,    41892,        8,
                 0],
         [  146078,       40,     7141,    10607,    51881,     4913,        0,
                 0],
         [  750121,      506,    36937,    10215,   448677,    36089,        8,
                 0],
         [  339919,      173,    50220,      597,    25641,   230601,      645,
                 0],
         [  114539,        9,     2121,        4,      208,    44541,    38076,
                 0],
         [   12433,        1,     1100,       47,       86,     6680,      861,
                 0]]])

tp = torch.diagonal(a, dim1=-2, dim2=-1)  # per-class true positives, shape (1, 8)
# Totals over classes 1..7 only (class 0 is excluded, mirroring the [:, 1:] slice in Metrics.micro).
fn = (torch.sum(a, dim=-1) - tp)[:, 1:].sum()  # row sums: target is the class, prediction is not
fp = (torch.sum(a, dim=-2) - tp)[:, 1:].sum()  # column sums: predicted as the class, target is not
tp

tensor([[10782481,      154,   170986,    10607,   448677,   230601,    38076,
                0]])