In [53]:
#### Accuracy functions ####

# Function to normalize a single image
def normalize_2(x):
    """Min-max normalize a tensor to (approximately) the range [0, 1].

    The minimum and maximum are taken over *all* elements of ``x`` (it is
    flattened first), so a batched input gets one global min/max rather than
    per-image normalization. The 1e-8 in the denominator keeps a constant
    input from dividing by zero; such an input maps to all zeros.

    Args:
        x: tensor to normalize; intended shape [1, height, width], but any
            non-empty shape works, including non-contiguous tensors
            (e.g. the result of ``transpose``/``permute``).

    Returns:
        Tensor with the same shape as ``x``. If ``x`` contains NaN the
        result does too, and the NaN count is printed.
    """
    x_shape = x.shape
    # reshape (not view): view raises on non-contiguous inputs.
    flat = x.reshape(1, -1)
    flat = flat - torch.min(flat)
    flat = flat / (torch.max(flat) + 1e-8)
    n_nan = torch.sum(torch.isnan(flat))
    if n_nan > 0:
        print("nan of tmp", n_nan)
    return flat.reshape(x_shape)

# Function to calculate balanced accuracy
def balanced_accuracy( logits, y_hot, Do_print=None):
    """Return the balanced accuracy of two-class predictions.

    Balanced accuracy is the mean per-class recall, (TP/P + TN/N) / 2, with
    column 0 as the positive class and column 1 as the negative class.
    Predictions are ``torch.round`` of each score (halves round to even, so
    0.5 counts as 0).

    Args:
        logits: predicted scores, expected [batch_size, num_samples, 2]; it is
            flattened to [-1, 2], so any layout with 2 trailing classes works.
        y_hot: one-hot targets with the same layout as ``logits``.
        Do_print: if not None, print the confusion counts.

    Returns:
        0-dim tensor. A class with no samples (P == 0 or N == 0) is left out
        of the average instead of yielding NaN; if neither class has samples
        the result is NaN.
    """
    logits = logits.cpu().detach().reshape(-1, 2)
    y_hot = y_hot.cpu().detach().reshape(-1, 2)
    pred = torch.round(logits)

    TP = torch.sum(y_hot[:, 0] * pred[:, 0])   # True positive
    FP = torch.sum(y_hot[:, 1] * pred[:, 0])   # False positive
    FN = torch.sum(y_hot[:, 0] * pred[:, 1])   # False negative
    TN = torch.sum(y_hot[:, 1] * pred[:, 1])   # True negative

    P = TP + FN
    N = FP + TN

    if Do_print is not None:
        print("TP: ",TP)
        print("FP: ",FP)
        print("TN: ",TN)
        print("FN: ",FN)
        print("P: ",P)
        print("N: ",N)

    # Skip a class with no samples: its recall is 0/0 and would make the
    # whole score NaN.
    recalls = [hits / total for hits, total in ((TP, P), (TN, N)) if total > 0]
    if not recalls:
        return torch.tensor(float("nan"))
    return torch.stack(recalls).mean()

# Function to calculate balanced binary crossentropy
def balanced_binary_cross_entropy( logits, y_hot):
    """Return the class-balanced binary cross-entropy of a batch.

    Each batch item is assigned to one class: class 0 if the mean of its
    column-0 targets over the samples is non-zero, class 1 otherwise. The
    per-item loss is the mean element-wise BCE over that item's samples and
    classes, which is what ``binary_cross_entropy(logits[i], y_hot[i])``
    gives. The result is the average, over the classes present in the batch,
    of the mean per-item loss within each class. Both classes therefore
    contribute equally however imbalanced the batch is. This equals
    weighting class c by 1 / (2 * fraction_of_items_in_c) and dividing by
    the batch size.

    Args:
        logits: predicted probabilities in [0, 1], shape
            [batch_size, num_samples, no_classes].
        y_hot: float one-hot targets, same shape as ``logits``.

    Returns:
        0-dim tensor on ``logits``' device (differentiable w.r.t. ``logits``).
        If only one class is present, the result is that class's mean loss.
        If the batch is empty, the result is NaN.
    """
    # Element-wise BCE, then the mean per batch item. flatten(1) keeps an
    # empty batch valid.
    per_element = torch.nn.functional.binary_cross_entropy(
        logits, y_hot, reduction="none")
    per_item = per_element.flatten(start_dim=1).mean(dim=1)

    # Item belongs to class 0 when any (mean non-zero) column-0 target is set.
    is_class0 = y_hot[..., 0].flatten(start_dim=1).mean(dim=1) != 0

    class_means = [per_item[mask].mean()
                   for mask in (is_class0, ~is_class0) if mask.any()]
    if not class_means:
        return torch.tensor(float("nan"), device=logits.device)
    return torch.stack(class_means).mean()

def balanced_accuracy_test( logits, y_hot, Do_print=None):
    """Return balanced accuracy together with its confusion counts.

    Column 0 is the positive class and column 1 the negative class.
    Predictions are ``torch.round`` of each score (halves round to even, so
    0.5 counts as 0).

    Args:
        logits: predicted scores, expected [batch_size, num_samples, 2]; it is
            flattened to [-1, 2].
        y_hot: one-hot targets with the same layout as ``logits``.
        Do_print: if not None, print the confusion counts.

    Returns:
        Tuple ``(acc, TP, FP, FN, TN, P, N)`` of 0-dim tensors. ``acc`` is
        the mean recall over the classes that have samples. It is NaN only
        when neither class has any.
    """
    logits = logits.cpu().detach().reshape(-1, 2)
    y_hot = y_hot.cpu().detach().reshape(-1, 2)
    pred = torch.round(logits)

    TP = torch.sum(y_hot[:, 0] * pred[:, 0])   # True positive
    FP = torch.sum(y_hot[:, 1] * pred[:, 0])   # False positive
    FN = torch.sum(y_hot[:, 0] * pred[:, 1])   # False negative
    TN = torch.sum(y_hot[:, 1] * pred[:, 1])   # True negative

    P = TP + FN
    N = FP + TN

    if Do_print is not None:
        print("TP: ",TP)
        print("FP: ",FP)
        print("TN: ",TN)
        print("FN: ",FN)
        print("P: ",P)
        print("N: ",N)

    # A class with no samples has recall 0/0; leave it out instead of
    # letting it turn the score into NaN.
    recalls = [hits / total for hits, total in ((TP, P), (TN, N)) if total > 0]
    if recalls:
        acc = torch.stack(recalls).mean()
    else:
        acc = torch.tensor(float("nan"))

    return acc, TP, FP, FN, TN, P, N