In [2]:
import pynanoflann
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
from sklearn.metrics import *

In [3]:
x=torch.rand((1,6,100))
y=torch.randint(0,6,(1,100))

In [4]:
def calculate_weight(ratio,alpha1=0.3,alpha2=0.3,alpha3=0.4):
    weight= 0.3+0.3*(1/(ratio+0.02))+0.4*(1/ratio)
    return weight

In [54]:
class IoULoss(nn.Module):
    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
                 square=False):
        """
        paper: https://link.springer.com/chapter/10.1007/978-3-319-50835-1_22
        
        """
        super(IoULoss, self).__init__()

        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth

    def forward(self, pred, true, loss_mask=None):
        batch=pred.size(0)
        iou=0
        
        pred=F.softmax(pred,dim=1)
        pred=pred.permute(0,2,1) # (B,N,C)
        for i in range(batch):
            x,y=pred[i],true[i]
            pred_label=torch.argmax(x,dim=1)
            n=len(torch.unique(y))
            #sprint(x.shape,y.shape)
            shp_x = x.shape

            if self.batch_dice:
                axes = [0] + list(range(2, len(shp_x)))
            else:
                axes = list(range(2, len(shp_x)))

            if self.apply_nonlin is not None:
                x = self.apply_nonlin(x)

            freq=torch.einsum("ch->h", F.one_hot(y)).type(torch.float32)
            ratio=freq/torch.sum(freq)
            w=calculate_weight(ratio)
            numerator=0
            denominator=0
            for j in range(n):
                #Find poisiton for class j
                pos_j=(true==i).nonzero()[:,1]
                pred_j=pred[:,pos_j]
                pred_label_j=pred_j.argmax(dim=-1) # The prediction Label
                # Find where correct predictions occurred
                tp_pos_j=((y==j)*(pred_label==j)).nonzero()
                tp_j=torch.sum(x[tp_pos_j[0]][:,j])
                fp_j=torch.sum(x[:,j])
                num_class=len((y==j).nonzero())
                weight_i=w[i]
                numerator+=tp_j*weight_i
                denominator+=(num_class+fp_j-tp_j)*weight_i

            batch_iou = numerator/denominator
        iou = batch_iou/batch

        return 1-iou
        

In [55]:
IoULoss()(x,y)

tensor(0.9999)

In [31]:
FPCELossV2()(x,y)

tensor([[5.2519, 1.1801, 1.3364, 1.2659, 1.3229, 1.3610],
        [1.2468, 4.8841, 1.1254, 1.2659, 1.1923, 1.3109],
        [1.2468, 1.2137, 4.2471, 1.1433, 1.3766, 1.1879],
        [1.1629, 1.2137, 1.3364, 9.8615, 1.2316, 1.2250],
        [1.2961, 1.2903, 1.2388, 1.2659, 8.1032, 1.1879],
        [1.2468, 1.2903, 1.1596, 1.3437, 1.1566, 8.4000]], dtype=torch.float64)


tensor(0.8732, dtype=torch.float64)

In [32]:
label=x.argmax(axis=1)

In [33]:
cm=confusion_matrix(y[0],label[0])

In [34]:
cm

array([[2, 2, 5, 2, 5, 6],
       [3, 3, 0, 2, 2, 5],
       [3, 3, 3, 0, 6, 2],
       [1, 3, 5, 0, 3, 3],
       [4, 5, 3, 2, 1, 2],
       [3, 5, 1, 3, 1, 1]])

In [36]:
cm/cm.sum()

  1/(cm/cm.sum())


array([[ 50.        ,  50.        ,  20.        ,  50.        ,
         20.        ,  16.66666667],
       [ 33.33333333,  33.33333333,          inf,  50.        ,
         50.        ,  20.        ],
       [ 33.33333333,  33.33333333,  33.33333333,          inf,
         16.66666667,  50.        ],
       [100.        ,  33.33333333,  20.        ,          inf,
         33.33333333,  33.33333333],
       [ 25.        ,  20.        ,  33.33333333,  50.        ,
        100.        ,  50.        ],
       [ 33.33333333,  20.        , 100.        ,  33.33333333,
        100.        , 100.        ]])

In [46]:
cm/(cm.sum(axis=1)).reshape(-1,1)

array([[0.2       , 0.2       , 0.13333333, 0.13333333, 0.2       ,
        0.13333333],
       [0.125     , 0.        , 0.375     , 0.        , 0.25      ,
        0.25      ],
       [0.05      , 0.15      , 0.15      , 0.25      , 0.25      ,
        0.15      ],
       [0.05882353, 0.29411765, 0.11764706, 0.17647059, 0.17647059,
        0.17647059],
       [0.07142857, 0.21428571, 0.21428571, 0.21428571, 0.21428571,
        0.07142857],
       [0.22222222, 0.22222222, 0.16666667, 0.22222222, 0.05555556,
        0.11111111]])

In [27]:
x[:,:,y[0]].shape

torch.Size([1, 10000, 10000])

In [4]:
x=np.array(x)

In [4]:
def findKNN(points,queries,num_neighbors):
    idx,dist=[],[]

    B=points.shape[0]

    for i in range(B):
        nn=pynanoflann.KDTree(n_neighbors=num_neighbors,metric="L2")
        pts=np.array(points[i])
        qs=np.array(queries[i])
        nn.fit(pts)
        tempDIST,tempIDX=nn.kneighbors(qs)
        idx.append(tempIDX.astype(np.int64))
        dist.append(tempDIST.astype(np.float32))
    
    idx=torch.from_numpy(np.array(idx))
    dist=torch.from_numpy(np.array(dist))
    return idx,dist

In [None]:
nn=pynanoflann.KDTree(n_neighbors=16,metric="L2",radius=10)
nn.fit(x[0])
idx,dist=nn.kneighbors(x[0])

In [5]:
idx,dist=findKNN(x,x,16)

In [38]:
x=np.array([[195830.,   2695.,   6747.,    704.,    677.,    540.],
       [  1365.,   4010.,    120.,      0.,      0.,    133.],
       [  2419.,    189.,  21351.,      0.,      0.,      0.],
       [   564.,      0.,      0.,    296.,      0.,      0.],
       [   302.,      0.,      0.,      0.,    812.,      0.],
       [   292.,     25.,      0.,      0.,      0.,    732.]])

In [39]:
x.sum()

239803.0

In [40]:
ratio=x/(x.sum(axis=1)).reshape(-1,1)

In [41]:
ratio

array([[0.94515741, 0.0130072 , 0.03256384, 0.0033978 , 0.00326748,
        0.00260627],
       [0.24253731, 0.71250888, 0.02132196, 0.        , 0.        ,
        0.02363184],
       [0.10096415, 0.00788848, 0.89114738, 0.        , 0.        ,
        0.        ],
       [0.65581395, 0.        , 0.        , 0.34418605, 0.        ,
        0.        ],
       [0.27109515, 0.        , 0.        , 0.        , 0.72890485,
        0.        ],
       [0.27836034, 0.02383222, 0.        , 0.        , 0.        ,
        0.69780744]])

In [9]:
1/(ratio+0.9)

array([[0.54195918, 1.09528162, 1.07231265, 1.10693208, 1.10709177,
        1.10790279],
       [0.87524494, 0.62015162, 1.0853969 , 1.11111111, 1.11111111,
        1.08268247],
       [0.99903678, 1.10145687, 0.55830135, 1.11111111, 1.11111111,
        1.11111111],
       [0.64275037, 1.11111111, 1.11111111, 0.80373832, 1.11111111,
        1.11111111],
       [0.85390158, 1.11111111, 1.11111111, 1.11111111, 0.6139094 ,
        1.11111111],
       [0.84863684, 1.08244763, 1.11111111, 1.11111111, 1.11111111,
        0.62585765]])

In [42]:
w=calculate_weight(ratio)
w

  weight= 0.3+0.3*(1/(ratio+0.02))+0.4*(1/ratio)


array([[  1.03404005,  40.14113489,  18.29090849, 130.84501473,
        135.61184547, 167.0469534 ],
       [  3.09192548,   1.27094786,  26.32006192,          inf,
                 inf,  24.10202845],
       [  6.7418761 ,  61.76400962,   1.07811473,          inf,
                 inf,          inf],
       [  1.35383823,          inf,          inf,   2.28591695,
                 inf,          inf],
       [  2.80608754,          inf,          inf,          inf,
          1.24935341,          inf],
       [  2.74248186,  23.92828012,          inf,          inf,
                 inf,   1.29116343]])

In [43]:
w[w==np.inf]=1

In [44]:
w.astype(np.int16)

array([[  1,  40,  18, 130, 135, 167],
       [  3,   1,  26,   1,   1,  24],
       [  6,  61,   1,   1,   1,   1],
       [  1,   1,   1,   2,   1,   1],
       [  2,   1,   1,   1,   1,   1],
       [  2,  23,   1,   1,   1,   1]], dtype=int16)

In [112]:
calculate_weight(1-ratio)

array([[11.60201509,  1.00318817,  1.01728109,  0.99646443,  0.99637413,
         0.99591632],
       [ 1.21394943,  2.66698529,  1.00911171,  0.99411765,  0.99411765,
         1.01077505],
       [ 1.07135013,  0.9995905 ,  6.30293456,  0.99411765,  0.99411765,
         0.99411765],
       [ 2.28591695,  0.99411765,  0.99411765,  1.35383823,  0.99411765,
         0.99411765],
       [ 1.24935341,  0.99411765,  0.99411765,  0.99411765,  2.80608754,
         0.99411765],
       [ 1.25880238,  1.01091971,  0.99411765,  0.99411765,  0.99411765,
         2.55477949]])

In [99]:
1/ratio

array([[1.05805005e+00, 7.68542285e+01, 3.07052460e+01, 2.93899291e+02,
        3.05603245e+02, 3.82992606e+02],
       [4.12445095e+00, 1.40463725e+00, 4.65619835e+01, 5.63400000e+03,
        5.63400000e+03, 4.20447761e+01],
       [9.90289256e+00, 1.26131579e+02, 1.12237729e+00, 2.39650000e+04,
        2.39650000e+04, 2.39650000e+04],
       [1.53274336e+00, 8.66000000e+02, 8.66000000e+02, 2.91582492e+00,
        8.66000000e+02, 8.66000000e+02],
       [3.69636964e+00, 1.12000000e+03, 1.12000000e+03, 1.12000000e+03,
        1.37761378e+00, 1.12000000e+03],
       [3.60068259e+00, 4.05769231e+01, 1.05500000e+03, 1.05500000e+03,
        1.05500000e+03, 1.43929059e+00]])

In [100]:
np.round(ratio,3)

array([[0.945, 0.013, 0.033, 0.003, 0.003, 0.003],
       [0.242, 0.712, 0.021, 0.   , 0.   , 0.024],
       [0.101, 0.008, 0.891, 0.   , 0.   , 0.   ],
       [0.652, 0.001, 0.001, 0.343, 0.001, 0.001],
       [0.271, 0.001, 0.001, 0.001, 0.726, 0.001],
       [0.278, 0.025, 0.001, 0.001, 0.001, 0.695]])