In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Custom cross-entropy loss

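$$
\text{loss}(x, \text{class}) = -\log\left(\frac{\exp(x[\text{class}])}{\sum_j \exp(x[j])}\right)
                             = -x[\text{class}] + \log\left(\sum_j \exp(x[j])\right)
$$

The reported loss is the mean of this quantity over the samples in the batch.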

In [2]:
def cross_entropy_loss(input, target, use_logit=False):
    """Mean negative log-likelihood of the target classes.

    :param input: (N, C) tensor of class scores. Raw logits when
        ``use_logit`` is True, otherwise probabilities along dim 1.
    :param target: (N,) int64 tensor of class indices in ``[0, C)``.
    :param use_logit: treat ``input`` as logits and apply a log-softmax.
    :return: scalar tensor, the loss averaged over the batch.
    """
    if use_logit:
        # log_softmax computes x[class] - logsumexp(x) directly, so it stays
        # finite where softmax() followed by log() would underflow to -inf.
        log_probs = F.log_softmax(input, dim=1)
    else:
        log_probs = torch.log(input)

    # Select each sample's target-class log-probability instead of multiplying
    # by a one-hot mask: a mask inferred from `target` has the wrong width when
    # the batch lacks the highest class, and 0 * log(0) would yield NaN.
    nll = -log_probs.gather(1, target.unsqueeze(1)).squeeze(1)

    return torch.mean(nll)


## Custom binary cross-entropy loss

$$
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]
$$

Here all weights are $w_n = 1$, and the reported loss is the mean of $L$ over all elements.

In [3]:
def binary_cross_entropy_loss(input, target, use_logit=False):
    """
    Mean binary cross-entropy: -(y_n * ln(x_n) + (1 - y_n) * ln(1 - x_n)).

    :param input: predictions, same shape as ``target``. Raw logits when
        ``use_logit`` is True, otherwise probabilities in [0, 1].
    :param target: float tensor of 0/1 (or soft) labels.
    :param use_logit: apply a sigmoid to ``input`` (computed in log space).
    :return: scalar tensor, the loss averaged over all elements.
    """
    if use_logit:
        # log(sigmoid(x)) and log(1 - sigmoid(x)) == log(sigmoid(-x)),
        # computed stably; sigmoid() then log() returns -inf once sigmoid
        # saturates to exactly 0 or 1.
        log_p = F.logsigmoid(input)
        log_not_p = F.logsigmoid(-input)
    else:
        # Clamp like nn.BCELoss does, so a prediction of exactly 0 or 1 gives a
        # finite loss and 0 * log(0) becomes 0 instead of NaN.
        log_p = torch.clamp(torch.log(input), min=-100.)
        log_not_p = torch.clamp(torch.log(1. - input), min=-100.)

    output = -(target * log_p + (1. - target) * log_not_p)

    return torch.mean(output)

In [4]:
# Fix the RNG so the random scores below are reproducible.
torch.manual_seed(2020)

# Two samples x two classes of raw (unnormalised) scores.
input = torch.randn(2, 2)

# The same labels in two encodings:
#   - one-hot float rows, fed to the BCE losses below
#   - int64 class indices, fed to the CE losses below
multi_label_target = torch.tensor([[1., 0.],
                                   [0., 1.]], dtype=torch.float32)
multi_class_target = torch.tensor([0, 1], dtype=torch.int64)

print(input)

tensor([[ 1.2372, -0.9604],
        [ 1.5415, -0.4079]])


In [5]:
# Activations: m_0 maps raw scores to probabilities for nn.BCELoss.
# NOTE(review): m_1 is not used in the cells visible here.
m_0, m_1 = nn.Sigmoid(), nn.Softmax(dim=1)

# Built-in PyTorch criteria used as the reference for the custom losses.
bce_criterion, ce_criterion = nn.BCELoss(), nn.CrossEntropyLoss()

In [6]:
# The hand-written BCE (on logits) must agree with nn.BCELoss applied to
# sigmoid(input). The cross-entropy comparison lives in the next cell.
output_0 = bce_criterion(m_0(input), multi_label_target)
output_1 = binary_cross_entropy_loss(input, multi_label_target, use_logit=True)

print(output_0)
print(output_1)
assert torch.allclose(output_0, output_1), "custom BCE disagrees with nn.BCELoss"


tensor(0.8080)
tensor(0.8080)


In [7]:
# The hand-written CE must agree with nn.CrossEntropyLoss on the same raw scores.
output_2 = ce_criterion(input, multi_class_target)
output_3 = cross_entropy_loss(input, multi_class_target, use_logit=True)

print(output_2)
print(output_3)
assert torch.allclose(output_2, output_3), "custom CE disagrees with nn.CrossEntropyLoss"

tensor(1.0939)
tensor(1.0939)
