In [None]:
import torch

In [1]:
# Regression
def regression_heteroscedastic_loss(true, mean, log_var, metric="rmse"):
    '''
    Regression loss with a learned, per-output uncertainty term: the residual
    is weighted by exp(-log_var) and log_var itself is added as a penalty.

    ARGUMENTS:
    true: true values. Tensor (batch_size x number of outputs)
    mean: predictions. Tensor (batch_size x number of outputs)
    log_var: Logarithms of uncertainty estimates. Tensor (batch_size x number of outputs)
    metric: "mae" or "rmse". None / empty falls back to "rmse" (the default).
            Note that the "rmse" branch uses the squared error without a
            square root.

    OUTPUTS:
    loss: Tensor (0)

    RAISES:
    ValueError: if metric is neither "mae" nor "rmse" (nor empty).
    '''
    if not metric:   # None / "" -> default is rmse
        metric = "rmse"

    # exp(-log_var) == 1 / variance
    precision = torch.exp(-log_var)
    if metric == "mae":
        # sqrt(2 * precision) * |residual| + log_var / 2, summed over outputs
        per_sample = torch.sum((2 * precision) ** .5 * torch.abs(true - mean) + log_var / 2, 1)
    elif metric == "rmse":
        # precision * residual^2 + log_var, summed over outputs
        per_sample = torch.sum(precision * (true - mean) ** 2 + log_var, 1)
    else:
        # Fail loudly instead of printing and silently returning None.
        raise ValueError("Metric has to be 'rmse' or 'mae'")

    # Average over the batch dimension.
    return torch.mean(per_sample, 0)

def regression_homoscedastic_loss(true, mean, metric="rmse"):
    '''
    Plain regression loss without an uncertainty term: the per-output error is
    summed over the outputs and averaged over the batch.

    ARGUMENTS:
    true: true values. Tensor (batch_size x number of outputs)
    mean: predictions. Tensor (batch_size x number of outputs)
    metric: "mae" or "rmse". None / empty falls back to "rmse" (the default).
            Note that the "rmse" branch uses the squared error without a
            square root.

    OUTPUTS:
    loss: Tensor (0)

    RAISES:
    ValueError: if metric is neither "mae" nor "rmse" (nor empty).
    '''
    if not metric:   # None / "" -> default is rmse
        metric = "rmse"

    residual = true - mean
    if metric == "mae":
        per_sample = torch.sum(torch.abs(residual), 1)
    elif metric == "rmse":
        per_sample = torch.sum(residual ** 2, 1)
    else:
        # Fail loudly instead of printing and silently returning None.
        raise ValueError("Metric has to be 'rmse' or 'mae'")

    # Average over the batch dimension.
    return torch.mean(per_sample, 0)

In [1]:
# Classification
# Adapted from https://github.com/kyle-dorman/bayesian-neural-network-blogpost
def classif_heterosc_loss_dorman(pred, true, logvar, T):
    '''
    Heteroscedastic classification loss: the cross-entropy of the clean logits
    is scaled by how much Gaussian noise on the logits (with the predicted
    variance) changes the cross-entropy, averaged over T noise draws, plus a
    penalty that keeps the predicted variance from growing.

    ARGUMENTS:
    pred: predicted logits. Tensor (batch_size x number of classes)
    true: true class indices (cast with .long()). Tensor (batch_size)
    logvar: Logarithms of the predicted variance, one value per sample.
            Tensor (batch_size) or (batch_size x 1); a tensor shaped like
            pred also broadcasts and gives per-class noise.
    T: number of noisy forward passes through the softmax function. Integer >= 1

    OUTPUTS:
    loss: Tensor (0)

    RAISES:
    ValueError: if T is smaller than 1.
    '''
    if T < 1:
        raise ValueError("T has to be a positive integer")

    F = torch.nn.functional
    target = true.long()

    # Per-sample cross-entropy of the undistorted logits.
    ce_clean = F.cross_entropy(pred, target, reduction='none')

    # One row per sample so the std broadcasts across the class dimension.
    variance = torch.exp(logvar).reshape(pred.shape[0], -1)
    std = torch.sqrt(variance)

    # exp(variance) - 1 is zero at zero variance and grows with it.
    # Computed without a hard-coded device so the loss also runs on CPU.
    variance_depressor = torch.mean(torch.exp(variance) - 1)

    cum_elu = torch.zeros_like(ce_clean)
    for _ in range(T):
        # Independent noise per sample and class, scaled by that sample's std.
        noisy_pred = pred + std * torch.randn_like(pred)
        # Cross-entropy of the *noisy* logits (not the clean ones).
        ce_noisy = F.cross_entropy(noisy_pred, target, reduction='none')
        cum_elu = cum_elu + F.elu(ce_noisy - ce_clean, alpha=1.0)

    # Average the ELU term over the T draws (divide by T, not the last index).
    return torch.mean(ce_clean * (1 + cum_elu / T)) + variance_depressor