In [2]:
"""**Define Performance Metrics**"""



import re

import torch.nn as nn





class BaseObject(nn.Module):
    """``nn.Module`` that exposes a human-readable ``__name__``.

    Args:
        name: explicit display name. When ``None``, ``__name__`` is derived
            from the class name by converting CamelCase to snake_case.
    """

    def __init__(self, name=None):
        super().__init__()
        self._name = name

    @property
    def __name__(self):
        if self._name is not None:
            return self._name
        # Pass 1 inserts "_" before each Capitalized word; pass 2 inserts "_"
        # between a lowercase letter/digit and a following capital letter.
        split_words = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", self.__class__.__name__)
        snake = re.sub("([a-z0-9])([A-Z])", r"\1_\2", split_words)
        return snake.lower()





class Metric(BaseObject):
    """Marker base class for evaluation metrics; adds nothing to ``BaseObject``."""





class Loss(BaseObject):
    """Base class for losses that can be combined with ``+`` and scaled with ``*``.

    ``loss_a + loss_b`` builds a ``SumOfLosses`` and ``k * loss`` / ``loss * k``
    builds a ``MultipliedLoss``; both are themselves ``Loss`` subclasses.
    """

    def __add__(self, other):
        if isinstance(other, Loss):
            return SumOfLosses(self, other)
        raise ValueError("Loss should be inherited from `Loss` class")

    def __radd__(self, other):
        # Builtin ``sum([l1, l2, ...])`` starts from the int 0, so treat a zero
        # start value as the additive identity instead of rejecting it.
        if isinstance(other, (int, float)) and other == 0:
            return self
        return self.__add__(other)

    def __mul__(self, value):
        if isinstance(value, (int, float)):
            return MultipliedLoss(self, value)
        raise ValueError(
            "Loss can only be multiplied by an int or float; got {}".format(type(value).__name__)
        )

    def __rmul__(self, other):
        return self.__mul__(other)





class SumOfLosses(Loss):
    """Loss equal to the sum of two component losses evaluated on the same inputs.

    Args:
        l1: first component loss module.
        l2: second component loss module.
    """

    def __init__(self, l1, l2):
        name = "{} + {}".format(l1.__name__, l2.__name__)
        super().__init__(name=name)
        self.l1 = l1
        self.l2 = l2

    def forward(self, *inputs):
        # Implemented as ``forward`` (instead of overriding ``__call__``) so a
        # SumOfLosses nested inside another composite can itself be evaluated.
        # Components are invoked via ``__call__`` so composite components and
        # module hooks both work.
        return self.l1(*inputs) + self.l2(*inputs)





class MultipliedLoss(Loss):
    """Loss equal to a component loss scaled by a constant multiplier.

    Args:
        loss: the component loss module.
        multiplier: scale factor applied to the component's value.
    """

    def __init__(self, loss, multiplier):
        # Parenthesize summed names so the display name keeps the right
        # precedence, e.g. "2 * (dice_loss + jaccard_loss)".
        if "+" in loss.__name__:
            name = "{} * ({})".format(multiplier, loss.__name__)
        else:
            name = "{} * {}".format(multiplier, loss.__name__)
        super().__init__(name=name)
        self.loss = loss
        self.multiplier = multiplier

    def forward(self, *inputs):
        # Defined as ``forward`` rather than by overriding ``__call__`` so this
        # loss still works when nested inside another composite. The component
        # is invoked through ``__call__`` for the same reason.
        return self.multiplier * self.loss(*inputs)

class Activation(nn.Module):
    """Module wrapping an activation selected by name or built from a callable.

    Args:
        name: ``None``/``'identity'``, ``'sigmoid'``, ``'softmax2d'``,
            ``'softmax'``, ``'logsoftmax'``, ``'argmax'``, ``'argmax2d'``, or a
            callable that is invoked as ``name(**params)``.
        **params: keyword arguments forwarded to the chosen constructor
            (not forwarded for ``'sigmoid'``).

    Raises:
        ValueError: if ``name`` is not ``None``, a known string, or callable.
    """

    def __init__(self, name, **params):
        super().__init__()
        # Builders are lambdas so each constructor (including ArgMax) is only
        # looked up and called for the selected entry.
        builders = {
            'identity': lambda: nn.Identity(**params),
            'sigmoid': lambda: nn.Sigmoid(),
            'softmax2d': lambda: nn.Softmax(dim=1, **params),
            'softmax': lambda: nn.Softmax(**params),
            'logsoftmax': lambda: nn.LogSoftmax(**params),
            # NOTE(review): ArgMax is not defined in this section of the file;
            # confirm it is defined or imported elsewhere before using these.
            'argmax': lambda: ArgMax(**params),
            'argmax2d': lambda: ArgMax(dim=1, **params),
        }
        key = 'identity' if name is None else name
        if isinstance(key, str) and key in builders:
            self.activation = builders[key]()
        elif callable(name):
            self.activation = name(**params)
        else:
            raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/None; got {}'.format(name))

    def forward(self, x):
        return self.activation(x)



import torch





def _take_channels(*xs, ignore_channels=None):

    if ignore_channels is None:

        return xs

    else:

        channels = [channel for channel in range(xs[0].shape[1]) if channel not in ignore_channels]

        xs = [torch.index_select(x, dim=1, index=torch.tensor(channels).to(x.device)) for x in xs]

        return xs





def _threshold(x, threshold=None):

    if threshold is not None:

        return (x > threshold).type(x.dtype)

    else:

        return x





def iou(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute the Intersection over Union (Jaccard index) of ``pr`` against ``gt``.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        eps (float): smoothing term added to numerator and denominator to
            avoid division by zero
        threshold: if not ``None``, ``pr`` is binarized as ``pr > threshold``
            before scoring
        ignore_channels: optional channel indices (dim 1) dropped from both
            tensors before scoring

    Returns:
        torch.Tensor: scalar ``(I + eps) / (sum(gt) + sum(pr) - I + eps)`` with
        ``I = sum(gt * pr)``
    """
    prediction, target = _take_channels(
        _threshold(pr, threshold=threshold), gt, ignore_channels=ignore_channels
    )
    overlap = (target * prediction).sum()
    denominator = target.sum() + prediction.sum() - overlap + eps
    return (overlap + eps) / denominator


# The Jaccard index is the same quantity as IoU.
jaccard = iou





def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute the F-beta score of ``pr`` against ``gt`` (Dice when ``beta == 1``).

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        beta (float): weight of recall relative to precision
        eps (float): smoothing term added to numerator and denominator
        threshold: if not ``None``, ``pr`` is binarized as ``pr > threshold``
        ignore_channels: optional channel indices (dim 1) to exclude

    Returns:
        torch.Tensor: scalar
        ``((1 + b^2) TP + eps) / ((1 + b^2) TP + b^2 FN + FP + eps)``
    """
    prediction, target = _take_channels(
        _threshold(pr, threshold=threshold), gt, ignore_channels=ignore_channels
    )

    true_pos = (target * prediction).sum()
    false_pos = prediction.sum() - true_pos
    false_neg = target.sum() - true_pos

    beta_sq = beta ** 2
    weighted_tp = (1 + beta_sq) * true_pos
    return (weighted_tp + eps) / (weighted_tp + beta_sq * false_neg + false_pos + eps)





def accuracy(pr, gt, threshold=0.5, ignore_channels=None):
    """Calculate the element-wise accuracy of ``pr`` against ``gt``.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor, compared element-wise with ``pr``
        threshold: if not ``None``, ``pr`` is binarized as ``pr > threshold``
            before comparison
        ignore_channels: optional channel indices (dim 1) to exclude

    Returns:
        torch.Tensor: scalar fraction of elements where ``pr == gt``
    """
    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    # Counts every exact match, i.e. true positives AND true negatives.
    correct = torch.sum(gt == pr, dtype=pr.dtype)
    # numel() also works for non-contiguous tensors, where view(-1) raises.
    score = correct / gt.numel()
    return score





def precision(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute the precision ``(TP + eps) / (TP + FP + eps)`` of ``pr`` against ``gt``.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        eps (float): smoothing term added to numerator and denominator
        threshold: if not ``None``, ``pr`` is binarized as ``pr > threshold``
        ignore_channels: optional channel indices (dim 1) to exclude

    Returns:
        torch.Tensor: scalar precision score
    """
    prediction, target = _take_channels(
        _threshold(pr, threshold=threshold), gt, ignore_channels=ignore_channels
    )
    true_pos = (target * prediction).sum()
    false_pos = prediction.sum() - true_pos
    return (true_pos + eps) / (true_pos + false_pos + eps)





def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute the recall ``(TP + eps) / (TP + FN + eps)`` of ``pr`` against ``gt``.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        eps (float): smoothing term added to numerator and denominator
        threshold: if not ``None``, ``pr`` is binarized as ``pr > threshold``
        ignore_channels: optional channel indices (dim 1) to exclude

    Returns:
        torch.Tensor: scalar recall score
    """
    prediction, target = _take_channels(
        _threshold(pr, threshold=threshold), gt, ignore_channels=ignore_channels
    )
    true_pos = (target * prediction).sum()
    false_neg = target.sum() - true_pos
    return (true_pos + eps) / (true_pos + false_neg + eps)



import torch.nn as nn



class JaccardLoss(Loss):
    """Soft Jaccard loss: ``1 - jaccard(activation(y_pr), y_gt)`` without thresholding.

    Args:
        eps: smoothing term forwarded to ``jaccard``
        activation: spec passed to ``Activation`` and applied to predictions
        ignore_channels: channel indices (dim 1) excluded from the score
        **kwargs: forwarded to the base class (e.g. ``name``)
    """

    def __init__(self, eps=1.0, activation=None, ignore_channels=None, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels

    def forward(self, y_pr, y_gt):
        activated = self.activation(y_pr)
        score = jaccard(
            activated,
            y_gt,
            eps=self.eps,
            threshold=None,
            ignore_channels=self.ignore_channels,
        )
        return 1 - score





class DiceLoss(Loss):
    """Soft Dice / F-beta loss: ``1 - f_score(activation(y_pr), y_gt)`` without thresholding.

    Args:
        eps: smoothing term forwarded to ``f_score``
        beta: F-beta weight forwarded to ``f_score`` (1.0 gives Dice)
        activation: spec passed to ``Activation`` and applied to predictions
        ignore_channels: channel indices (dim 1) excluded from the score
        **kwargs: forwarded to the base class (e.g. ``name``)
    """

    def __init__(self, eps=1.0, beta=1.0, activation=None, ignore_channels=None, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.beta = beta
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels

    def forward(self, y_pr, y_gt):
        activated = self.activation(y_pr)
        score = f_score(
            activated,
            y_gt,
            beta=self.beta,
            eps=self.eps,
            threshold=None,
            ignore_channels=self.ignore_channels,
        )
        return 1 - score





class L1Loss(nn.L1Loss, Loss):
    """``torch.nn.L1Loss`` that also gets ``Loss`` arithmetic (``+``, ``*``) and naming."""





class MSELoss(nn.MSELoss, Loss):
    """``torch.nn.MSELoss`` that also gets ``Loss`` arithmetic (``+``, ``*``) and naming."""





class CrossEntropyLoss(nn.CrossEntropyLoss, Loss):
    """``torch.nn.CrossEntropyLoss`` that also gets ``Loss`` arithmetic (``+``, ``*``) and naming."""





class NLLLoss(nn.NLLLoss, Loss):
    """``torch.nn.NLLLoss`` that also gets ``Loss`` arithmetic (``+``, ``*``) and naming."""





class BCELoss(nn.BCELoss, Loss):
    """``torch.nn.BCELoss`` that also gets ``Loss`` arithmetic (``+``, ``*``) and naming."""





class BCEWithLogitsLoss(nn.BCEWithLogitsLoss, Loss):
    """``torch.nn.BCEWithLogitsLoss`` that also gets ``Loss`` arithmetic (``+``, ``*``) and naming."""



class IoU(Metric):
    """IoU metric module: applies ``activation`` then ``iou`` with binarization.

    Args:
        eps: smoothing term forwarded to ``iou``
        threshold: binarization threshold for predictions
        activation: spec passed to ``Activation`` and applied to predictions
        ignore_channels: channel indices (dim 1) excluded from the score
        **kwargs: forwarded to the base class
    """

    # Fixed display name; this class attribute shadows the derived name.
    __name__ = "iou_score"

    def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.threshold = threshold
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels

    def forward(self, y_pr, y_gt):
        activated = self.activation(y_pr)
        return iou(
            activated,
            y_gt,
            eps=self.eps,
            threshold=self.threshold,
            ignore_channels=self.ignore_channels,
        )





class Fscore(Metric):
    """F-beta metric module: applies ``activation`` then ``f_score`` with binarization.

    Args:
        beta: F-beta weight forwarded to ``f_score``
        eps: smoothing term forwarded to ``f_score``
        threshold: binarization threshold for predictions
        activation: spec passed to ``Activation`` and applied to predictions
        ignore_channels: channel indices (dim 1) excluded from the score
        **kwargs: forwarded to the base class
    """

    def __init__(self, beta=1, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.beta = beta
        self.threshold = threshold
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels

    def forward(self, y_pr, y_gt):
        activated = self.activation(y_pr)
        return f_score(
            activated,
            y_gt,
            beta=self.beta,
            eps=self.eps,
            threshold=self.threshold,
            ignore_channels=self.ignore_channels,
        )





class Accuracy(Metric):
    """Accuracy metric module: applies ``activation`` then ``accuracy`` with binarization.

    Args:
        threshold: binarization threshold for predictions
        activation: spec passed to ``Activation`` and applied to predictions
        ignore_channels: channel indices (dim 1) excluded from the score
        **kwargs: forwarded to the base class
    """

    def __init__(self, threshold=0.5, activation=None, ignore_channels=None, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels

    def forward(self, y_pr, y_gt):
        activated = self.activation(y_pr)
        return accuracy(
            activated,
            y_gt,
            threshold=self.threshold,
            ignore_channels=self.ignore_channels,
        )





class Recall(Metric):
    """Recall metric module: applies ``activation`` then ``recall`` with binarization.

    Args:
        eps: smoothing term forwarded to ``recall``
        threshold: binarization threshold for predictions
        activation: spec passed to ``Activation`` and applied to predictions
        ignore_channels: channel indices (dim 1) excluded from the score
        **kwargs: forwarded to the base class
    """

    def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.threshold = threshold
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels

    def forward(self, y_pr, y_gt):
        activated = self.activation(y_pr)
        return recall(
            activated,
            y_gt,
            eps=self.eps,
            threshold=self.threshold,
            ignore_channels=self.ignore_channels,
        )





class Precision(Metric):
    """Precision metric module: applies ``activation`` then ``precision`` with binarization.

    Args:
        eps: smoothing term forwarded to ``precision``
        threshold: binarization threshold for predictions
        activation: spec passed to ``Activation`` and applied to predictions
        ignore_channels: channel indices (dim 1) excluded from the score
        **kwargs: forwarded to the base class
    """

    def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.threshold = threshold
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels

    def forward(self, y_pr, y_gt):
        activated = self.activation(y_pr)
        return precision(
            activated,
            y_gt,
            eps=self.eps,
            threshold=self.threshold,
            ignore_channels=self.ignore_channels,
        )



# Evaluation metrics, each binarizing predictions at 0.5 (order: IoU,
# accuracy, F-score, recall, precision).
metrics = [
    metric_cls(threshold=0.5)
    for metric_cls in (IoU, Accuracy, Fscore, Recall, Precision)
]