In [1]:
import torch
import matplotlib.pyplot as plt

In [None]:
def cross_entropy_dirichlet(prediction: torch.Tensor, target: torch.Tensor):
    """Target-weighted ``digamma(S) - digamma(alpha)`` summed over dim 1.

    The concentration is ``alpha = prediction + 1`` and ``S`` is its sum along
    dim 1.  For ``p ~ Dir(alpha)`` the quantity ``digamma(S) - digamma(alpha_k)``
    equals ``E[-log p_k]``, so with a one-hot ``target`` this is the expected
    cross-entropy under that Dirichlet.

    Args:
        prediction: Tensor whose dim 1 is reduced; 1 is added to it to form
            the concentration (presumably non-negative evidence -- the only
            caller in this file applies ``relu`` first).
        target: Weights multiplied element-wise with the per-entry term;
            must broadcast against ``prediction``.

    Returns:
        Tensor with dim 1 summed out.
    """
    concentration = prediction + 1
    dirichlet_strength = concentration.sum(dim=1, keepdim=True)
    # keepdim=True above lets the strength broadcast back over dim 1 here.
    expected_neg_log = torch.digamma(dirichlet_strength) - torch.digamma(concentration)
    weighted = target * expected_neg_log
    return weighted.sum(1)


def mse_dirichlet(prediction: torch.Tensor, target: torch.Tensor):
    """Element-wise squared difference between ``target`` and ``prediction``.

    Evaluated in the expanded form ``t**2 - 2*t*p + p**2``.  No reduction is
    applied, so the result has the broadcast shape of the two inputs.

    NOTE(review): despite the name, no Dirichlet concentration is built here
    and ``prediction`` is used as-is -- confirm this is the intended loss.
    """
    cross_term = 2 * target * prediction
    return torch.pow(target, 2) - cross_term + torch.pow(prediction, 2)


def KL_divergence_dirichlet(prediction: torch.Tensor, target: torch.Tensor):
    """KL divergence from the uniform Dirichlet, reduced over dim 1.

    Builds ``alpha = prediction + 1`` and the adjusted concentration
    ``alpha_tilde = target + (1 - target) * alpha`` (entries where ``target``
    is 1 are reset to 1), then evaluates
    ``KL(Dir(alpha_tilde) || Dir(1, ..., 1))``::

        lgamma(sum_k a_k) - lgamma(K) - sum_k lgamma(a_k)
        + sum_k (a_k - 1) * (digamma(a_k) - digamma(sum_k a_k))

    where ``K = prediction.size(1)``.

    Args:
        prediction: Tensor whose dim 1 is treated as the class axis; 1 is
            added to it to form the concentration.
        target: Tensor broadcastable against ``prediction``; presumably
            one-hot along dim 1 -- verify against the caller.

    Returns:
        Tensor with dim 1 summed out.
    """
    alpha = prediction + 1
    approx_alpha = target + (1 - target) * alpha

    # Create the class count with the dtype and device of the concentrations.
    # A bare ``torch.tensor(int)`` is an int64 CPU tensor, so lgamma(K) would
    # always be evaluated in the default float dtype; for float64 inputs that
    # leaves a ~1e-9 error (KL of the uniform Dirichlet is then not 0).
    n_class = approx_alpha.new_tensor(prediction.size(1))

    # Sum once with keepdim so it broadcasts in the digamma term below.
    alpha_sum = approx_alpha.sum(dim=1, keepdim=True)

    log_normaliser = torch.lgamma(alpha_sum.squeeze(1)) - (
        torch.lgamma(n_class) + torch.lgamma(approx_alpha).sum(dim=1)
    )
    digamma_term = (
        (approx_alpha - 1)
        * (torch.digamma(approx_alpha) - torch.digamma(alpha_sum))
    ).sum(dim=1)
    return log_normaliser + digamma_term


def overall_loss(
    prediction: torch.Tensor,
    target: torch.Tensor,
    lambda_t: torch.Tensor,
):
    """Combine the Dirichlet cross-entropy with a weighted KL term.

    ``prediction`` is passed through ``relu`` before either term is computed,
    so negative entries contribute nothing.

    Args:
        prediction: Raw network output; dim 1 is the axis both terms reduce.
        target: Target tensor forwarded unchanged to both terms.
        lambda_t: Weight multiplied onto the KL term (presumably an annealing
            coefficient -- confirm with the training loop).

    Returns:
        ``cross_entropy + lambda_t * KL``, with dim 1 reduced away.
    """
    evidence = torch.relu(prediction)
    total = cross_entropy_dirichlet(evidence, target)
    kl_term = KL_divergence_dirichlet(evidence, target)
    # Accumulate in place so the result keeps the cross-entropy term's dtype and shape.
    total += lambda_t * kl_term
    return total


# Smoke test on constant inputs: dim 1 (size 3) is the axis the loss reduces,
# so the displayed result has one value per remaining position.
prediction = torch.ones(2, 3, 4, 4)
target = torch.ones_like(prediction)
overall_loss(prediction, target, lambda_t=torch.tensor(1))

# To collapse the result to a single scalar, append `.mean()` to the call above.

torch.Size([2, 4, 4])

In [3]:
# Mean of standard-normal noise over dims 0, 2 and 3, leaving the size-3 dim 1.
# No seed is set, so the displayed values differ on every run.
torch.randn(2, 3, 4, 4).mean(dim=(0, 2, 3))

tensor([ 0.2029, -0.3967, -0.1321])