This notebook breaks down how PyTorch implements the `cross_entropy` function (the functional form of the `CrossEntropyLoss` module used for classification), and how it relates to softmax, log_softmax, and NLL (negative log-likelihood).
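For reference, here are the quantities compared below. They are written for a batch of $N$ rows of logits $x_i \in \mathbb{R}^C$ with integer class targets $t_i$. The notation is introduced here for illustration and does not come from the original notebook:

$$\operatorname{softmax}(x)_{ij} = \frac{e^{x_{ij}}}{\sum_{k=1}^{C} e^{x_{ik}}}, \qquad \log\operatorname{softmax}(x)_{ij} = x_{ij} - \log\sum_{k=1}^{C} e^{x_{ik}}$$

$$\text{NLL}(x, t) = -\frac{1}{N}\sum_{i=1}^{N} \log\operatorname{softmax}(x)_{i,t_i} = \text{cross entropy}(x, t)$$

The cells that follow compute this same value in four ways, each time averaging over the batch.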

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [20]:
batch_size, n_classes = 5, 3
x = torch.randn(batch_size, n_classes)
x.shape

torch.Size([5, 3])

In [21]:
x

tensor([[-1.3796, -1.4432,  0.8661],
        [-1.6321, -0.2053,  0.5196],
        [-0.8158,  1.2398,  1.0848],
        [ 1.5698, -0.2200, -0.1800],
        [-1.2378,  1.9793, -1.3160]])

In [22]:
target = torch.randint(n_classes, size=(batch_size,), dtype=torch.long)
target

tensor([0, 0, 1, 1, 1])

### `softmax` + `nl` (negative likelihood)

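This version mirrors the math formula most directly. `softmax` turns logits into probabilities. `nl` then takes each row's probability at its target index, applies `log`, negates it, and averages over the batch. The version is not numerically stable, though: `x.exp()` overflows to `inf` for logits above roughly 88 in float32, and the division then produces `nan`.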

In [23]:
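def softmax(x):
    # exp(x) normalized by its sum over the class (last) dimension
    return x.exp() / x.exp().sum(-1).unsqueeze(-1)

def nl(input, target):
    # pick each row's probability at its target index, take -log, average over the batch
    return -input[torch.arange(target.shape[0]), target].log().mean()

pred = softmax(x)
loss = nl(pred, target)
loss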

tensor(1.5794)
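Here is a minimal sketch that makes the instability concrete. It uses a hypothetical single-row input with one large logit, chosen for illustration and not taken from the notebook. On ordinary logits, the naive formula matches `F.softmax`. On the large logit, `exp(1000.)` is `inf` in float32, so the first entry becomes `inf / inf`:

```python
import torch
import torch.nn.functional as F

def softmax(x):
    # Naive softmax as above: exp(x) divided by its row sum
    return x.exp() / x.exp().sum(-1, keepdim=True)

# On ordinary logits the naive version agrees with PyTorch
small = torch.randn(5, 3)
assert torch.allclose(softmax(small), F.softmax(small, dim=-1), atol=1e-6)

# One large logit: exp(1000.) is inf in float32, so the row sum is inf too
big = torch.tensor([[1000.0, 0.0]])
naive = softmax(big)              # first entry is inf / inf = nan
stable = F.softmax(big, dim=-1)   # stays finite

print(naive)
print(stable)
```

A `nan` (or an underflowed `0`) in `pred` then turns the `nl` loss into `nan` (or `inf`).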

### `log_softmax` + `nll` (negative log-likelihood)

From the PyTorch documentation for `torch.nn.functional.log_softmax` (https://pytorch.org/docs/stable/nn.html?highlight=logsoftmax#torch-nn-functional):
>While mathematically equivalent to `log(softmax(x))`, doing these two operations separately is slower, and numerically unstable. This function uses an alternative formulation to compute the output and gradient correctly.

In [25]:
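def log_softmax(x):
    # log(softmax(x)) = x - log(sum(exp(x))), computed row-wise
    return x - x.exp().sum(-1).log().unsqueeze(-1)

def nll(input, target):
    # input is already log-probabilities: pick each target entry, negate, average
    return -input[torch.arange(target.shape[0]), target].mean()

pred = log_softmax(x)
loss = nll(pred, target)
loss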

tensor(1.5794)
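The `log_softmax` above still calls `x.exp()` on the raw logits, so it breaks on the same large input as the naive `softmax`. One common way to get the "alternative formulation" the docs mention is the log-sum-exp trick. It subtracts each row's maximum $m_i$ before exponentiating, using the identity $\log\sum_k e^{x_{ik}} = m_i + \log\sum_k e^{x_{ik} - m_i}$. The sketch below assumes this trick and does not claim to reproduce how PyTorch implements `F.log_softmax` internally. The helper names `naive_log_softmax`, `logsumexp_stable`, and `stable_log_softmax` are hypothetical:

```python
import torch
import torch.nn.functional as F

def naive_log_softmax(x):
    # Same as the log_softmax cell above: exp(x) is taken on the raw logits
    return x - x.exp().sum(-1, keepdim=True).log()

def logsumexp_stable(x):
    # log(sum(exp(x))) row-wise, shifted by the row max m so the largest
    # exponent is exp(0) = 1: log(sum(exp(x))) = m + log(sum(exp(x - m)))
    m = x.max(-1, keepdim=True).values
    return m + (x - m).exp().sum(-1, keepdim=True).log()

def stable_log_softmax(x):
    # log_softmax(x) = x - logsumexp(x), broadcast over the class dimension
    return x - logsumexp_stable(x)

def nll(input, target):
    # Mean over the batch of -log p(target), as in the cell above
    return -input[torch.arange(target.shape[0]), target].mean()

# On ordinary logits the shifted version agrees with PyTorch
x = torch.randn(5, 3)
target = torch.randint(3, size=(5,))
assert torch.allclose(stable_log_softmax(x), F.log_softmax(x, dim=-1), atol=1e-6)
assert torch.allclose(logsumexp_stable(x).squeeze(-1), torch.logsumexp(x, dim=-1), atol=1e-6)
assert torch.allclose(nll(stable_log_softmax(x), target),
                      F.nll_loss(F.log_softmax(x, dim=-1), target), atol=1e-6)

# On a large logit only the shifted version stays finite
big = torch.tensor([[1000.0, 0.0]])
print(naive_log_softmax(big))   # exp overflows, so every entry is -inf
print(stable_log_softmax(big))  # finite log-probabilities
print(nll(stable_log_softmax(big), torch.tensor([1])))  # loss when the target is the small logit
```

The shift costs one extra `max` per row. It guarantees that every exponent is at most 0, so `exp` can underflow harmlessly to 0 but never overflow.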

### `F.log_softmax` + `F.nll_loss`

The same two steps using PyTorch's built-in `F.log_softmax` and `F.nll_loss`.

In [26]:
pred = F.log_softmax(x, dim=-1)
loss = F.nll_loss(pred, target)
loss

tensor(1.5794)
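`F.nll_loss` averages over the batch by default (`reduction='mean'`), which is why it matches the `.mean()` in the hand-written `nll`. Passing `reduction='none'` returns the per-sample losses instead. These are just the negated log-probabilities picked out at each target index. The sketch below checks this on fresh random data, not on the notebook's `x`:

```python
import torch
import torch.nn.functional as F

x = torch.randn(5, 3)                 # logits: batch of 5, 3 classes
target = torch.randint(3, size=(5,))  # one class index per row
pred = F.log_softmax(x, dim=-1)

# Per-sample losses: -log p(target_i) for each row i
per_sample = F.nll_loss(pred, target, reduction='none')

# The same values, picked out by hand with advanced indexing...
picked = -pred[torch.arange(5), target]
# ...or with gather along the class dimension
gathered = -pred.gather(1, target.unsqueeze(1)).squeeze(1)

assert torch.allclose(per_sample, picked)
assert torch.allclose(per_sample, gathered)

# reduction='mean' (the default) and 'sum' just aggregate these values
assert torch.allclose(F.nll_loss(pred, target), per_sample.mean())
assert torch.allclose(F.nll_loss(pred, target, reduction='sum'), per_sample.sum())
```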

### `F.cross_entropy`

PyTorch's `F.cross_entropy` combines `log_softmax` and `nll_loss` in a single function that takes the raw logits directly.

In [27]:
F.cross_entropy(x, target)

tensor(1.5794)
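The `nn.CrossEntropyLoss` module mentioned at the top wraps the same computation. Here is a quick sketch, on fresh random logits (the seed is arbitrary), checking that the module, `F.cross_entropy`, and the explicit `log_softmax` + `nll_loss` composition agree up to floating-point tolerance:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed for repeatability
x = torch.randn(5, 3)
target = torch.randint(3, size=(5,))

criterion = nn.CrossEntropyLoss()  # module form; default reduction='mean'
loss_module = criterion(x, target)
loss_functional = F.cross_entropy(x, target)
loss_composed = F.nll_loss(F.log_softmax(x, dim=-1), target)

assert torch.allclose(loss_module, loss_functional)
assert torch.allclose(loss_functional, loss_composed)

# cross_entropy expects raw logits: feeding it softmax probabilities
# would apply softmax a second time and change the loss
```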

Reference:
- https://github.com/fastai/fastai_old