In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftCrossEntropyLoss(nn.Module):
    """Computes the CrossEntropyLoss while accepting soft (float) targets.

    For each example i the loss is

        loss_i = sum_y target[i, y] * weight[y] * CE(input[i], y)

    where CE(input[i], y) = -log_softmax(input[i])[y] is the ordinary
    hard-label cross-entropy for class y (weight[y] is 1 when no weights are
    given). With one-hot targets and no weights this equals
    nn.CrossEntropyLoss.

    Args:
        weight: an optional [K_t] floating-point tensor of relative weights to
            assign to each class. It is moved to the input's device/dtype at
            call time.
        size_average: if True (and reduce is True), return the mean of the
            per-example losses; otherwise return their sum.
        reduce: if False, return the unreduced [n] tensor of per-example
            losses and ignore size_average.

    Note:
        With class weights and size_average=True the mean is taken over the n
        examples. nn.CrossEntropyLoss(weight=..., reduction='mean') instead
        divides by the summed weights of the target classes, so the two
        differ whenever weights are given.

    Accepts:
        input: An [n, K_t] float tensor of prediction logits (not probabilities)
        target: An [n, K_t] float tensor of target probabilities
    """
    def __init__(self, weight=None, size_average=True, reduce=True):
        super().__init__()
        # Accept any floating-point tensor; isinstance(w, torch.FloatTensor)
        # would reject CUDA and float64 tensors.
        assert weight is None or (torch.is_tensor(weight) and weight.is_floating_point())
        self.weight = weight
        self.reduce = reduce
        self.size_average = size_average and reduce

    def forward(self, input, target):
        # Unpacking also enforces that input is 2-D.
        _, K_t = input.shape
        # Without this check a [n, 1] or [K_t] target would silently broadcast.
        assert target.shape == input.shape, "target must have the same [n, K_t] shape as input"
        # F.cross_entropy(input, y) for a constant hard label y is exactly
        # -log_softmax(input)[:, y]. One log_softmax therefore yields the
        # hard-label loss for every class at once, with no per-class loop.
        class_losses = -F.log_softmax(input, dim=1)
        if self.weight is not None:
            # Follow the input's device/dtype so GPU (or float64) inputs work.
            weight = self.weight.to(device=input.device, dtype=class_losses.dtype).reshape(-1)
            assert weight.numel() == K_t, "weight must have one entry per class"
            class_losses = class_losses * weight
        # Expected loss under the target distribution, one value per example.
        cum_losses = (target.to(class_losses.dtype) * class_losses).sum(dim=1)
        if not self.reduce:
            return cum_losses
        elif self.size_average:
            return cum_losses.mean()
        else:
            return cum_losses.sum()
        
def hard_to_soft(Y_h, k):
    """Converts a 1D tensor of hard labels into a 2D tensor of soft labels.

    Each 1-indexed label j becomes a one-hot row with 1.0 in column j-1.

    Args:
        Y_h: an [N], or [N,1] tensor of hard (int) labels in 1..k
        k: the target cardinality of the soft label matrix

    Returns:
        An [N, k] float tensor (on the CPU) of one-hot rows.
    """
    Y_h = Y_h.squeeze()
    # squeeze() collapses a single label ([1] or [1, 1]) to a 0-dim tensor;
    # restore the batch dimension so N == 1 works like any other N.
    if Y_h.dim() == 0:
        Y_h = Y_h.unsqueeze(0)
    assert(Y_h.dim() == 1)
    assert((Y_h >= 1).all())
    assert((Y_h <= k).all())
    N = Y_h.shape[0]
    Y_s = torch.zeros(N, k, dtype=torch.float)
    # Set all ones in a single indexed assignment instead of a per-row loop.
    # Indices are moved to the CPU because Y_s is allocated there.
    cols = Y_h.long().cpu() - 1
    Y_s[torch.arange(N), cols] = 1.0
    return Y_s

### Minimum example: weighted `nn.CrossEntropyLoss` vs. `SoftCrossEntropyLoss`

In [20]:
# Two examples, both predicted as class 1 with near-certainty; the first is
# labeled class 0 (loss 200) and the second class 1 (loss ~0).
# Named `logits` rather than `input` to avoid shadowing the builtin.
logits = torch.tensor([
    [-100.,  100., -100.],
    [-100.,  100., -100.]
])
target = torch.tensor([0, 1], dtype=torch.long)
soft_target = hard_to_soft(target + 1, k=3)  # hard_to_soft expects 1-indexed labels
weights = torch.tensor([1, 2, 1], dtype=torch.float)

# `reduction=` replaces the deprecated `reduce=` / `size_average=` flags.
ce_noreduce = nn.CrossEntropyLoss(weight=weights, reduction='none')
ce_reduce = nn.CrossEntropyLoss(weight=weights, reduction='mean')
sce = SoftCrossEntropyLoss(weight=weights, reduce=True, size_average=True)

ce_loss = ce_noreduce(logits, target)  # [w[0] * 200, w[1] * 0] = [200, 0]
ce_loss_reduced = ce_reduce(logits, target)
sce_loss = sce(logits, soft_target)

# Weighted 'mean' divides by the summed weights of the target classes
# (w[0] + w[1] = 3): 200 / 3 = 66.67. SoftCrossEntropyLoss divides by the
# number of examples instead (200 / 2 = 100), matching ce_loss.mean().
print(ce_loss_reduced)
print(ce_loss.mean())
print(sce_loss)

tensor(66.6667)
tensor(100.)
tensor(100.)


### Minimum example 2: a class weight has no effect when that class's examples already have zero loss

In [94]:
# Three examples, all predicted as class 1 with near-certainty: the first two are
# labeled class 0 (loss 200 each) and the last class 1 (loss ~0).
# Named `logits` rather than `input` to avoid shadowing the builtin.
logits = torch.tensor([
    [-1,  1, -1],
    [-1,  1, -1],
    [-1,  1, -1],
], dtype=torch.float) * 100
target = torch.tensor([0, 0, 1], dtype=torch.long)
weights = torch.tensor([1, 10, 1], dtype=torch.float)

# `reduction='none'` replaces the deprecated `reduce=False` flag.
ce_noweights = nn.CrossEntropyLoss(weight=None, reduction='none')
ce_weights = nn.CrossEntropyLoss(weight=weights, reduction='none')

loss1 = ce_noweights(logits, target)
loss2 = ce_weights(logits, target)

# Each per-example loss is scaled by the weight of its *target* class. The 10x
# weight on class 1 only multiplies the ~0 loss of the last example, so both
# means come out equal (400 / 3).
print(loss1.mean())
print(loss2.mean())

tensor(133.3333)
tensor(133.3333)


### Other experimentation: how the weighted `'mean'` reduction normalizes

In [86]:
# Named `logits` rather than `input` to avoid shadowing the builtin.
logits = torch.tensor([
    [-1,  1, -1],
    [-1,  1, -1],
    [-1,  1, -1],
], dtype=torch.float) * 100
target = torch.tensor([0, 0, 1], dtype=torch.long)
# True where the predicted class (argmax over the logits) equals the label.
correctness = target == logits.argmax(dim=1)
weights = torch.tensor([1, 5, 1], dtype=torch.float)

# `reduction=` replaces the deprecated `reduce=` / `size_average=` flags.
ce_noreduce = nn.CrossEntropyLoss(weight=weights, reduction='none')
ce_reduce = nn.CrossEntropyLoss(weight=weights, reduction='mean')

loss1 = ce_noreduce(logits, target)
loss2 = ce_reduce(logits, target)

# Weighted 'mean' is sum_i w[target_i] * l_i / sum_i w[target_i]:
# (200 + 200 + 5 * 0) / (1 + 1 + 5) = 57.14, not the plain mean 133.33.
print(correctness)
print(loss1)
print(loss1.mean())
print(loss2)

tensor([ 0,  0,  1], dtype=torch.uint8)
tensor([ 200.,  200.,   -0.])
tensor(133.3333)
tensor(57.1429)
