In [36]:
import torch

def get_softmax_uncertaintiy_distance(pred):
    """
    pred: BS,C,H,W before softmax is applied! 
    
    distance between sec and first element
    # 1 if fully uncertain -> 2th best pixel estimate = 1th best pixel estimate, for all pixels
    # 0 if absolutly confident for all pixels. One class probability 1, 2th best class probability 0, for all pixels
    """
    BS,C,H,W = pred.shape
    argm1 = torch.argmax(pred, 1)
    soft1 = torch.nn.functional.softmax(pred, dim=1)

    onehot_argm1 = torch.nn.functional.one_hot(argm1, num_classes=C).permute(0,3,1,2).type(torch.bool)
    ten2 = pred.clone()
    ten2[ onehot_argm1 ] = 0

    argm2 = torch.argmax(ten2, 1)
    onehot_argm2 = torch.nn.functional.one_hot(argm2, num_classes=C).permute(0,3,1,2).type(torch.bool)
    res = [] 

    soft1 = soft1.permute(0,2,3,1)
    onehot_argm1 = onehot_argm1.permute(0,2,3,1)
    onehot_argm2 = onehot_argm2.permute(0,2,3,1)

    for b in range(BS):
        res_ = soft1[b][onehot_argm1[b]] - soft1[b][onehot_argm2[b]]
        res.append( res_.mean() )
        # print(res_.min(),res_.max())

    return torch.tensor(res, dtype=pred.dtype, device=pred.device)


def test():
    BS,C,H,W = 16,40,300,320
    pred = torch.rand( ( BS,C,H,W) )
    res = get_softmax_uncertaintiy_distance(pred)
    print(res, "should be very low")
    
    pred = torch.rand( ( BS,C,H,W) ) /1000
    pred[:,0,:,:] = 10
    res = get_softmax_uncertaintiy_distance(pred)
    print(res, 'should be nearly 0')
    
    pred[:,1,:,:] = 8
    res = get_softmax_uncertaintiy_distance(pred)
    print(res, 'should be betweem 0-1')
    

test()
    

tensor([0.0009, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009,
        0.0009, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009]) should be very low
tensor([0.9982, 0.9982, 0.9982, 0.9982, 0.9982, 0.9982, 0.9982, 0.9982, 0.9982,
        0.9982, 0.9982, 0.9982, 0.9982, 0.9982, 0.9982, 0.9982]) should be nearly 0
tensor([0.7604, 0.7604, 0.7604, 0.7604, 0.7604, 0.7604, 0.7604, 0.7604, 0.7604,
        0.7604, 0.7604, 0.7604, 0.7604, 0.7604, 0.7604, 0.7604]) should be betweem 0-1


In [46]:
ten = torch.rand( ( BS,C,H,W) )
def get_softmax_uncertaintiy_max(pred):
    """
    pred: BS,C,H,W before softmax is applied! 
    
    (1 - max( softmax(pred))) mean over batch size
    
    # 1 if fully uncertain
    # 0 if absolutly confident for all pixels
    """
    BS,C,H,W = pred.shape
    
    argm1 = torch.argmax(pred, 1)
    soft1 = torch.nn.functional.softmax(pred, dim=1)
    onehot_argm1 = torch.nn.functional.one_hot(argm1, num_classes=C).permute(0,3,1,2).type(torch.bool)
    
    soft1 = soft1.permute(0,2,3,1)
    onehot_argm1 = onehot_argm1.permute(0,2,3,1)
    
    res = []
    for b in range(BS):
        res_ = soft1[b][onehot_argm1[b]]
        res.append( torch.mean(res_)  )
#     print(res)
    return torch.tensor(res, dtype=pred.dtype, device=pred.device)

def test():
    BS,C,H,W = 16,40,300,320
    pred = torch.rand( ( BS,C,H,W) )
    res = get_softmax_uncertaintiy_max(pred)
    print(res, "should be very low")
    
    pred = torch.rand( ( BS,C,H,W) ) /1000
    pred[:,0,:,:] = 10
    res = get_softmax_uncertaintiy_max(pred)
    print(res, 'should be nearly 0')
    
    pred[:,1,:,:] = 8
    res = get_softmax_uncertaintiy_max(pred)
    print(res, 'should be betweem 0-1')
    

test()

tensor([0.0387, 0.0387, 0.0387, 0.0387, 0.0387, 0.0387, 0.0387, 0.0387, 0.0387,
        0.0387, 0.0387, 0.0387, 0.0387, 0.0387, 0.0387, 0.0387]) should be very low
tensor([0.9982, 0.9982, 0.9982, 0.9982, 0.9982, 0.9982, 0.9982, 0.9982, 0.9982,
        0.9982, 0.9982, 0.9982, 0.9982, 0.9982, 0.9982, 0.9982]) should be nearly 0
tensor([0.8795, 0.8795, 0.8795, 0.8795, 0.8795, 0.8795, 0.8795, 0.8795, 0.8795,
        0.8795, 0.8795, 0.8795, 0.8795, 0.8795, 0.8795, 0.8795]) should be betweem 0-1


In [None]:
onehot_sel = torch.nn.functional.one_hot(sel, num_classes=C).permute(0,3,1,2)
ten2[ onehot_sel ] = 0

#ten2[soft] = 0

In [56]:
ten = torch.rand( ( 100 ) )
res = torch.topk( ten, 2 )
res.values

res.indices.shape
print( res.indices )
#torch.topk(input, buffer_size, dim=None, largest=True, sorted=True, *, out=None)

tensor([70, 98])
