In [110]:
import torch
import numpy as np
from dice_loss import *
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix,precision_score,recall_score

In [468]:
import torch
from ND_Crossentropy import CrossentropyND, TopKLoss, WeightedCrossEntropyLoss
from torch import nn
from torch.autograd import Variable
from torch import einsum
import numpy as np

def softmax_helper(x):
    """
    Numerically stable softmax over dimension 1 (the channel / class axis).

    Adapted from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py

    :param x: tensor whose dimension 1 is normalized.
    :return: tensor of the same shape whose entries along dimension 1 sum to 1.
    """
    # Shift by the per-position maximum so exp() cannot overflow; keepdim lets the
    # result broadcast back over dimension 1.
    shifted = x - x.max(dim=1, keepdim=True).values
    exps = torch.exp(shifted)
    return exps / exps.sum(dim=1, keepdim=True)

def sum_tensor(inp, axes, keepdim=False):
    """
    Sum ``inp`` over each of the (deduplicated) ``axes``.

    Adapted from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/tensor_utilities.py

    :param inp: tensor to reduce.
    :param axes: iterable of axis indices; duplicates are ignored.
    :param keepdim: if True, reduced axes are kept with size 1.
    :return: the reduced tensor.
    """
    unique_axes = [int(ax) for ax in np.unique(axes).astype(int)]
    # With keepdim the rank never changes, so ascending order is safe. Without it,
    # reduce from the highest axis down so the remaining indices stay valid.
    ordered_axes = unique_axes if keepdim else unique_axes[::-1]
    for ax in ordered_axes:
        inp = inp.sum(ax, keepdim=keepdim)
    return inp

def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
    """
    Soft true-positive, false-positive and false-negative counts (adapted from nnU-Net).

    net_output must be (b, c, x, y(, z)).
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or a one hot
    encoding (b, c, x, y(, z)). Dimension 1 is the class axis in both.

    :param net_output: predicted class scores; dimension 1 indexes the classes.
    :param gt: ground truth as described above. It is processed under no_grad.
    :param axes: axes to sum over; defaults to every axis from 2 onwards.
    :param mask: optional tensor of shape (b, 1, x, y(, z)), 1 for valid pixels and 0 for
        invalid pixels; its first channel is multiplied into every class channel.
    :param square: if True then tp, fp and fn are squared element-wise before summation.
    :return: tuple ``(tp, fp, fn)``, each summed over ``axes``.
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            # Label map without a channel axis: insert one at dim 1. reshape (unlike
            # view) also accepts non-contiguous inputs.
            gt = gt.reshape((shp_y[0], 1, *shp_y[1:]))

        if all(i == j for i, j in zip(net_output.shape, gt.shape)):
            # Same shape as the prediction: gt is presumably already a one hot encoding.
            y_onehot = gt
        else:
            gt = gt.long()
            # Allocate the one-hot target directly on net_output's device so this works
            # for any backend (not only CUDA) and avoids a CPU -> device copy.
            y_onehot = torch.zeros(shp_x, device=net_output.device)
            y_onehot.scatter_(1, gt, 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot

    if mask is not None:
        def _apply_mask(t):
            # Multiply the mask's first channel into every class channel of t.
            return torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(t, dim=1)), dim=1)

        tp, fp, fn = _apply_mask(tp), _apply_mask(fp), _apply_mask(fn)

    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2

    tp = sum_tensor(tp, axes, keepdim=False)
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)

    return tp, fp, fn
    
class IoULoss(nn.Module):
    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
                 square=False):
        """
        Soft (differentiable) IoU / Jaccard loss, averaged over the batch.

        paper: https://link.springer.com/chapter/10.1007/978-3-319-50835-1_22

        :param apply_nonlin: optional callable applied to each sample's prediction before
            the soft counts are computed; None uses the predictions as given.
        :param batch_dice: if True, axis 0 of each sample is also passed to
            get_tp_fp_fn as a reduction axis.
        :param do_bg: stored for API compatibility but not used by forward(); the
            background class is always included. TODO: implement or remove.
        :param smooth: additive smoothing in both numerator and denominator.
        :param square: forwarded to get_tp_fp_fn (square tp/fp/fn before summation).
        """
        super(IoULoss, self).__init__()

        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth

    def forward(self, pred, true, loss_mask=None):
        """
        Compute ``1 - mean IoU`` over the batch.

        Each sample ``pred[i]`` / ``true[i]`` is passed to get_tp_fp_fn separately, so
        dimension 1 of ``pred[i]`` is the class axis. The soft counts are then summed
        over dimension 0, the per-sample IoU ``(tp + smooth) / (tp + fp + fn + smooth)``
        is averaged, and those per-sample means are averaged over the batch.

        :param pred: predictions, indexed per sample as ``pred[i]``.
        :param true: targets, indexed per sample as ``true[i]`` (label map or one-hot,
            see get_tp_fp_fn).
        :param loss_mask: forwarded unchanged to get_tp_fp_fn for every sample.
            NOTE(review): it is not indexed per sample — confirm the intended shape.
        :return: tuple ``(1 - iou, tp, fp, fn)``; tp/fp/fn are the reduced soft counts
            of the LAST sample only (kept for inspection).
        :raises ValueError: if ``pred`` has an empty batch dimension.
        """
        batch = pred.size(0)
        if batch == 0:
            raise ValueError("IoULoss.forward: pred has an empty batch dimension")

        iou_sum = 0.
        for i in range(batch):
            x, y = pred[i], true[i]
            shp_x = x.shape

            if self.batch_dice:
                axes = [0] + list(range(2, len(shp_x)))
            else:
                axes = list(range(2, len(shp_x)))

            if self.apply_nonlin is not None:
                x = self.apply_nonlin(x)

            tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
            tp = tp.sum(dim=0)
            fp = fp.sum(dim=0)
            fn = fn.sum(dim=0)

            sample_iou = (tp + self.smooth) / (tp + fp + fn + self.smooth)
            # Accumulate every sample so the result is a true mean over the batch.
            iou_sum = iou_sum + sample_iou.mean()

        iou = iou_sum / batch

        return 1 - iou, tp, fp, fn

In [469]:
# Synthetic example: one sample of N_POINTS points with an N_CLASSES-way softmax
# prediction and random integer labels. Seeding makes this cell (and every output that
# depends on it) reproducible under Restart & Run All.
N_POINTS, N_CLASSES = 100, 6
torch.manual_seed(0)
pred = F.softmax(torch.randn(1, N_POINTS, N_CLASSES), dim=2)
true = torch.randint(0, N_CLASSES, (1, N_POINTS))

In [458]:
# Replace the random labels with the argmax of pred, so the following cells compare
# pred against labels that agree with its most likely class everywhere.
true = pred.argmax(dim=2)

In [459]:
# One-hot encode the labels to use as a "perfect" prediction. Pin num_classes to the
# current prediction's class dimension: without it F.one_hot infers max(label) + 1
# classes and silently drops trailing classes that never occur. .float() makes the
# result a floating-point tensor like the softmax it replaces.
pred = F.one_hot(true, num_classes=pred.shape[-1]).float()

In [470]:
# Evaluate the soft IoU loss with default settings. The call returns
# (1 - IoU, tp, fp, fn); for the (1, N, C) pred used in this notebook, tp/fp/fn are
# the per-class soft counts of the last (here: only) sample.
iou=IoULoss()
iouVal,tp,fp,fn=iou(pred,true)

tensor([0.1020, 0.1190, 0.0961, 0.1094, 0.0967, 0.1164])


In [467]:
# Loss value returned by IoULoss: 1 - IoU.
iouVal

tensor(0.8711)

In [445]:
# Soft per-class false negatives from IoULoss: sum over points of (1 - p) * one-hot target.
fn

tensor([ 9.4910, 16.9072, 16.2748, 13.9113, 15.8984,  9.7754])

In [446]:
# Hard confusion matrix of the argmax predictions (rows = true class, columns =
# predicted class). Passing `labels` fixes it to C x C in class-index order, so rows and
# columns stay aligned with IoULoss's per-class outputs even if a class never occurs.
cm = confusion_matrix(true[0], torch.argmax(pred, dim=2)[0], labels=np.arange(pred.shape[-1]))

In [447]:
# Per-class precision (average=None) of the argmax predictions against the labels.
precision_score(y_true=true[0], y_pred=pred[0].argmax(dim=1), average=None)

array([0.05555556, 0.25      , 0.2       , 0.27777778, 0.2       ,
       0.05555556])

In [448]:
# Hard per-class counts taken directly from the confusion matrix
# (rows = true class, columns = predicted class):
#   tp = diagonal, fp = column sum - tp, fn = row sum - tp.
# Deriving fp/fn from 1/precision and 1/recall instead would give 0 * inf = nan for any
# class with tp == 0, and would depend on sklearn using the same label set as `cm`.
tp = np.diagonal(cm)
fp = cm.sum(axis=0) - tp
fn = cm.sum(axis=1) - tp


In [449]:
# Hard per-class false-negative counts from the previous cell, for comparison with the
# soft fn returned by IoULoss.
fn

array([10., 17., 16., 13., 16., 11.])

In [450]:
# Confusion matrix: rows are true classes, columns are predicted classes.
cm

array([[1, 3, 3, 2, 1, 1],
       [7, 4, 5, 2, 0, 3],
       [1, 4, 3, 3, 4, 4],
       [2, 2, 2, 5, 4, 3],
       [4, 2, 0, 4, 3, 6],
       [3, 1, 2, 2, 3, 1]])

In [321]:
# Mean over classes of the per-class IoU (Jaccard index) TP / (TP + FP + FN).
# NOTE(review): uses whichever tp/fp/fn are currently defined in the kernel — the hard
# counts from the confusion-matrix cell or the soft counts returned by IoULoss.
np.mean((tp)/(tp+fp+fn))

1.0

In [422]:
# Hard confusion matrix of the argmax predictions (rows = true class, columns =
# predicted class). Passing `labels` fixes it to C x C in class-index order, so rows and
# columns stay aligned with IoULoss's per-class outputs even if a class never occurs.
cm = confusion_matrix(true[0], torch.argmax(pred, dim=2)[0], labels=np.arange(pred.shape[-1]))

In [423]:
# Confusion matrix: rows are true classes, columns are predicted classes.
cm

array([[3, 2, 2, 1, 4, 1],
       [0, 6, 1, 2, 4, 8],
       [1, 1, 1, 3, 3, 4],
       [3, 1, 3, 6, 4, 1],
       [2, 4, 0, 4, 3, 8],
       [3, 3, 2, 3, 2, 1]])