In [1]:
import torch
import torch.nn.functional as F

torch.manual_seed(1337)

# ----------------------------------------------------------------------
# Cross Entropy Indices Example
#
# Walks from raw logits to the cross-entropy loss with class-index targets,
# checking every hand-written formula against the PyTorch built-in.
num_classes = 3
num_examples = 5
input = torch.randn(size=(num_examples, num_classes)) # (5, 3)
print(input)

# Softmax
## s[i, j] is the normalized score (probability) of class j for the ith example
s = F.softmax(input, dim=-1) # (5, 3)
print(f"softmax: {s}")

manual_softmax = torch.exp(input) / torch.sum(torch.exp(input), dim=-1, keepdim=True)
assert torch.allclose(s, manual_softmax)

# Log Softmax
## Computed once here and reused by every check below
log_probs = F.log_softmax(input, dim=-1) # (5, 3)
assert torch.allclose(log_probs, torch.log(s))
## Log softmax is translation (shift) invariant along the class dimension:
## adding the same constant to every logit of a row does not change the result
assert torch.allclose(log_probs, F.log_softmax(input + 1.0, dim=-1), atol=1e-5)
## log_softmax(x) is itself x shifted by -logsumexp(x) per row, so applying
## log softmax a second time changes nothing (idempotence)
assert torch.allclose(log_probs, F.log_softmax(log_probs, dim=-1))

# Cross Entropy
## target[i] is the index of the correct class for the ith example
target = torch.randint(num_classes, size=(num_examples,), dtype=torch.int64) # (5,)
print(f"target: {target}")

## Log-probability assigned to the correct class of each example
example_indices = torch.arange(0, num_examples)
target_log_probs = log_probs[example_indices, target] # (5,)

## Sum
cross_entropy = F.cross_entropy(input, target, reduction='sum')
print(f"cross_entropy: {cross_entropy}")

### In PyTorch, cross entropy == log softmax + nll
### nll_loss applies no log itself: it expects log-probabilities as its input
half_manual_cross_entropy = F.nll_loss(log_probs, target, reduction="sum")

### Definition: negative sum, over all examples i, of the log-probability
### log_probs[i, target[i]] assigned to the correct class
manual_cross_entropy = -torch.sum(target_log_probs)
assert torch.allclose(cross_entropy, half_manual_cross_entropy)
assert torch.allclose(cross_entropy, manual_cross_entropy)

## Mean (default)
mean_cross_entropy = F.cross_entropy(input, target, reduction='mean')
half_manual_mean_cross_entropy = F.nll_loss(log_probs, target, reduction='mean')
manual_mean_cross_entropy = -torch.mean(target_log_probs)
print(f"mean Cross Entropy (pytorch): {mean_cross_entropy}")
print(f"mean Cross Entropy (log softmax + nll): {half_manual_mean_cross_entropy}")
print(f"mean Cross Entropy (manual): {manual_mean_cross_entropy}")
### The three routes must agree, exactly as in the sum case
assert torch.allclose(mean_cross_entropy, half_manual_mean_cross_entropy)
assert torch.allclose(mean_cross_entropy, manual_mean_cross_entropy)



tensor([[-2.0260, -2.0655, -1.2054],
        [-0.9122, -1.2502,  0.8032],
        [-0.2071,  0.0544,  0.1378],
        [-0.3889,  0.5133,  0.3319],
        [ 0.6300,  0.5815, -0.0282]])
softmax: tensor([[0.2362, 0.2271, 0.5367],
        [0.1375, 0.0981, 0.7644],
        [0.2695, 0.3500, 0.3805],
        [0.1811, 0.4465, 0.3724],
        [0.4048, 0.3856, 0.2096]])
target: tensor([0, 2, 0, 2, 1])
cross_entropy: 4.963563919067383
mean Cross Entropy (pytorch): 0.9927127957344055
mean Cross Entropy (log softmax + nll): 0.9927127957344055
mean Cross Entropy (manual): 0.9927127957344055


In [2]:
import torch
import torch.nn.functional as F

torch.manual_seed(1337)

# -------------------------------------------------------------------------
# Cross Entropy Prob Example
#
# Each target row is a full probability distribution over the classes
# (soft labels) instead of a single class index.
input = torch.randn(3, 5)
target = torch.softmax(torch.randn(3, 5), dim=-1)

# weighted_log_probs[i] = sum_j target[i, j] * log_softmax(input)[i, j];
# the manual losses below negate it and reduce over the examples.
log_probs = F.log_softmax(input, dim=-1)
weighted_log_probs = torch.sum(log_probs * target, dim=-1)  # (3,)

## Sum
cross_entropy = F.cross_entropy(input, target, reduction='sum')
print(cross_entropy)
### Manual
print(-torch.sum(weighted_log_probs))

## Mean (default)
print(F.cross_entropy(input, target, reduction='mean'))
### Manual
print(-torch.mean(weighted_log_probs))


tensor(4.5964)
tensor(4.5964)
tensor(1.5321)
tensor(1.5321)


In [4]:
import torch
import torch.nn.functional as F
import torch.nn as nn

torch.manual_seed(1337)

# ----------------------------------------------------------------------
# Cross Entropy Indices Example
#
# Checks that the module form (nn.CrossEntropyLoss) and the functional form
# (F.cross_entropy) return the same loss for the same logits and targets.
num_classes = 3
num_examples = 5
input = torch.randn(num_examples, num_classes)  # (5, 3)
print(input)

# Cross Entropy
## target[i] is the index of the correct class for the ith example
target = torch.randint(0, num_classes, (num_examples,), dtype=torch.int64)  # (5,)
print(f"target: {target}")

## Both calls use the default settings (mean reduction)
loss = nn.CrossEntropyLoss()
module_loss = loss(input, target)
functional_loss = F.cross_entropy(input, target)
assert torch.allclose(module_loss, functional_loss)


tensor([[-2.0260, -2.0655, -1.2054],
        [-0.9122, -1.2502,  0.8032],
        [-0.2071,  0.0544,  0.1378],
        [-0.3889,  0.5133,  0.3319],
        [ 0.6300,  0.5815, -0.0282]])
target: tensor([0, 2, 0, 2, 1])
