# Scratch tests: one-hot targets, evidence helpers, and the Dice Bayes-risk loss

In [1]:
import torch

In [2]:
def get_one_hot_target(K, target):
    """Expand a label map of shape ``[b, <dims>]`` into ``[b, K, <dims>]``.

    Args:
        K: number of classes (size of the new dim 1).
        target: for ``K == 2`` a binary mask (bool, integer or float; values
            strictly between 0 and 1 are kept as soft labels). For any other
            ``K`` an integer class-index map with values in ``[0, K)``.

    Returns:
        A contiguous tensor of the default floating dtype on ``target.device``.
        For ``K == 2`` channel 0 holds ``1 - target`` and channel 1 holds
        ``target``.

    Raises:
        RuntimeError: from ``torch.nn.functional.one_hot`` when ``K != 2`` and
            a label falls outside ``[0, K)``.
    """
    dtype = torch.get_default_dtype()

    if K == 2:
        # `1 - mask` is not defined for bool tensors, so promote those first.
        if target.dtype == torch.bool:
            target = target.to(dtype)
        # Allocate directly on the target's device instead of CPU + copy.
        one_hot = torch.zeros(
            (target.shape[0], K, *target.shape[1:]),
            dtype=dtype,
            device=target.device,
        )
        one_hot[:, 0] = 1 - target
        one_hot[:, 1] = target
        return one_hot

    # General case: filling only channels 0 and 1 would be wrong for K > 2,
    # so build a real one-hot encoding and move the class axis to dim 1.
    encoded = torch.nn.functional.one_hot(target.long(), num_classes=K)
    return encoded.movedim(-1, 1).to(dtype).contiguous()

In [3]:
# Random boolean mask: True where a standard-normal draw exceeds 0.5.
a = torch.randn(12, 5).gt(0.5)

In [5]:
# Cast the mask to float32 so arithmetic such as `1 - a` is defined.
a = a.float()

In [7]:
# Display the mask produced by the cells above.
a

tensor([[0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 1., 1., 1.],
        [0., 1., 1., 0., 1.],
        [1., 0., 1., 0., 0.],
        [1., 1., 0., 0., 1.],
        [0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0.],
        [1., 0., 1., 1., 1.],
        [1., 1., 0., 0., 1.],
        [0., 0., 0., 1., 1.]])

In [8]:
# Two-class expansion of the mask: channel 0 is `1 - a`, channel 1 is `a`.
get_one_hot_target(K=2, target=a)

tensor([[[1., 0., 1., 1., 1.],
         [0., 1., 0., 0., 0.]],

        [[1., 1., 0., 1., 1.],
         [0., 0., 1., 0., 0.]],

        [[1., 0., 1., 1., 1.],
         [0., 1., 0., 0., 0.]],

        [[1., 0., 0., 0., 0.],
         [0., 1., 1., 1., 1.]],

        [[1., 0., 0., 1., 0.],
         [0., 1., 1., 0., 1.]],

        [[0., 1., 0., 1., 1.],
         [1., 0., 1., 0., 0.]],

        [[0., 0., 1., 1., 0.],
         [1., 1., 0., 0., 1.]],

        [[1., 1., 1., 0., 1.],
         [0., 0., 0., 1., 0.]],

        [[1., 0., 1., 1., 1.],
         [0., 1., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.],
         [1., 0., 1., 1., 1.]],

        [[0., 0., 1., 1., 0.],
         [1., 1., 0., 0., 1.]],

        [[1., 1., 1., 0., 0.],
         [0., 0., 0., 1., 1.]]])

In [11]:
def relu_evidence(logits):
    """Map logits to non-negative evidence by zeroing every negative entry."""
    evidence = torch.relu(logits)
    return evidence

def exp_evidence(logits):
    """Map logits to positive evidence via ``exp``.

    Logits are limited to ``[-10, 10]`` before exponentiation, which bounds
    the result to ``[exp(-10), exp(10)]``.
    """
    bounded = torch.clamp(logits, min=-10, max=10)
    return torch.exp(bounded)

def softplus_evidence(logits):
    """Map logits to positive evidence with softplus (library default settings)."""
    softplus = torch.nn.functional.softplus
    return softplus(logits)


def get_S(alpha):
    """Sum ``alpha`` over the class axis (dim 1), keeping that axis with size 1.

    For an input laid out as ``[b, c, <dims>]`` this gives one total per
    position, shaped ``[b, 1, <dims>]`` so it broadcasts against ``alpha``.
    """
    return alpha.sum(dim=1, keepdim=True)

def get_bk(evidence, S):
    """Return ``evidence`` divided elementwise by ``S``.

    NOTE(review): presumably the per-class belief mass, with ``S`` the
    strength returned by ``get_S`` -- confirm against the caller.
    """
    belief = evidence / S
    return belief

def get_uncert(K, S):
    """Return ``K / S``.

    NOTE(review): presumably the uncertainty mass for ``K`` classes given
    strength ``S`` -- confirm against the caller.
    """
    uncertainty = K / S
    return uncertainty

def get_alpha(evidence):
    """Turn evidence into concentration values as ``(evidence + 1) ** 2``.

    NOTE(review): the square is part of the existing behaviour; many
    evidential formulations use ``evidence + 1`` without it -- confirm that
    the squared form is intended.
    """
    shifted = evidence + 1.
    return shifted ** 2

def get_one_hot_target(K, target):
    """Expand a label map of shape ``[b, <dims>]`` into ``[b, K, <dims>]``.

    Args:
        K: number of classes (size of the new dim 1).
        target: for ``K == 2`` a binary mask (bool, integer or float; values
            strictly between 0 and 1 are kept as soft labels). For any other
            ``K`` an integer class-index map with values in ``[0, K)``.

    Returns:
        A contiguous tensor of the default floating dtype on ``target.device``.
        For ``K == 2`` channel 0 holds ``1 - target`` and channel 1 holds
        ``target``.

    Raises:
        RuntimeError: from ``torch.nn.functional.one_hot`` when ``K != 2`` and
            a label falls outside ``[0, K)``.
    """
    dtype = torch.get_default_dtype()

    if K == 2:
        # `1 - mask` is not defined for bool tensors, so promote those first.
        if target.dtype == torch.bool:
            target = target.to(dtype)
        # Allocate directly on the target's device instead of CPU + copy.
        one_hot = torch.zeros(
            (target.shape[0], K, *target.shape[1:]),
            dtype=dtype,
            device=target.device,
        )
        one_hot[:, 0] = 1 - target
        one_hot[:, 1] = target
        return one_hot

    # General case: filling only channels 0 and 1 would be wrong for K > 2,
    # so build a real one-hot encoding and move the class axis to dim 1.
    encoded = torch.nn.functional.one_hot(target.long(), num_classes=K)
    return encoded.movedim(-1, 1).to(dtype).contiguous()

def get_mean_p_hat(alpha, S):
    """Return ``alpha / S``, i.e. ``alpha`` normalised by the total ``S``."""
    p_hat = alpha / S
    return p_hat

######
def digamma(values):
    """Digamma of ``values`` with the result limited to ``[-100, 100]``.

    The clamp keeps the output finite where the raw function diverges
    (for example at 0, where it is ``-inf``).
    """
    raw = torch.digamma(values)
    return torch.clamp(raw, -100, 100)

def get_alpha_modified(alpha, one_hot_target):
    """Compute ``one_hot_target + (1 - one_hot_target) * alpha``.

    Where ``one_hot_target`` is 1 the result is 1; where it is 0 the result
    is the original ``alpha`` value.
    """
    off_target = 1 - one_hot_target
    return one_hot_target + (off_target * alpha)

In [95]:
def dice_bayes_risk(K, alpha, one_hot_target, S, empty_slice_weight):
    """Weighted soft-Dice risk built from ``alpha / S`` and its variance term.

    For every sample and class, over all flattened positions, the ratio

        sum(y * alpha / S)
        / sum(y**2 + (alpha / S)**2 + alpha * (S - alpha) / (S**2 * (S + 1)))

    is computed, the ratios are summed over classes and scaled by ``2 / K``,
    and the result is one minus the weighted batch mean of those scores.

    Args:
        K: number of classes (size of dim 1 of ``alpha``/``one_hot_target``).
        alpha: tensor reshapeable to ``(batch, K, -1)``.
        one_hot_target: tensor reshapeable to ``(batch, K, -1)``; channel 0
            is treated as background.
        S: tensor reshapeable to ``(batch, 1, -1)``.
        empty_slice_weight: weight of samples whose target has no non-zero
            entry in channels ``1..K-1``; all other samples have weight 1.
            With a weight of 1 the result is the plain batch mean.

    Returns:
        A 0-d tensor. If every sample weight is zero the result is 1 (with
        zero gradient) instead of NaN.
    """
    bs = alpha.shape[0]
    # reshape (not view) so strided / non-contiguous inputs are accepted.
    alpha = alpha.reshape(bs, K, -1)
    one_hot_target = one_hot_target.reshape(bs, K, -1)
    S = S.reshape(bs, 1, -1)

    p_hat = alpha / S
    variance = alpha * (S - alpha) / ((S ** 2) * (S + 1))
    numerator = torch.sum(one_hot_target * alpha / S, dim=2)
    denominator = torch.sum(one_hot_target ** 2 + p_hat ** 2 + variance, dim=2)
    per_sample_score = (2 / K) * (numerator / denominator).sum(dim=1)

    # A sample is "empty" when no non-background channel has any target mass
    # (for K == 2 this is exactly "channel 1 sums to zero").
    is_empty = one_hot_target[:, 1:].sum(dim=(1, 2)) == 0
    weights = torch.ones_like(per_sample_score)
    weights[is_empty] = empty_slice_weight

    # Guard the all-zero-weight case (every sample empty, weight 0): 0/0 would be NaN.
    total_weight = weights.sum().clamp_min(torch.finfo(weights.dtype).eps)

    return 1 - (weights * per_sample_score).sum() / total_weight

In [106]:
# Synthetic batch: 12 two-channel logit maps of 225x225 plus random binary labels.
# NOTE(review): no seed is set, so values printed by later cells change between runs.
logits = torch.randn(12, 2, 225, 225)
target = torch.randn(12, 225, 225).gt(0.5).long()
target[0] = 0  # first sample has no positive labels (exercises the empty-slice path)

# Run the helper chain once so each intermediate is available for inspection.
evidence = softplus_evidence(logits)
alpha = get_alpha(evidence)
K = alpha.shape[1]
S = get_S(alpha)
one_hot = get_one_hot_target(K, target)
mean_p_hat = get_mean_p_hat(alpha, S)
alpha_modified = get_alpha_modified(alpha, one_hot)

In [107]:
# Risk with every sample weighted equally (empty_slice_weight = 1).
dice_bayes_risk(K=K, alpha=alpha, one_hot_target=one_hot, S=S, empty_slice_weight=1)

tensor(0.4392)

In [112]:
# Fraction of positions where "channel 1 logit > channel 0 logit" matches the label.
logits[:, 1].gt(logits[:, 0]).eq(target).float().mean()

tensor(0.4998)