In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [7]:
import torch
from torch import Tensor, nn
import torch.nn.functional as F

## Loss functions
Loss functions provide a quantitative measure of the current performance of a neural network. There are many to choose from, and the most appropriate one will usually depend on your task and on the form of the targets and outputs. Similar to layers in a neural network, losses are offered both as classes (inheriting from `nn.Module`) and as functions in `nn.functional`.

PyTorch provides implementations for many common losses (https://pytorch.org/docs/stable/nn.html#loss-functions), and more advanced ones can be written by the user.

In general, PyTorch losses will:
- Take an `input` argument of predictions and a `target` argument of true values. In general, the first dimension is expected to be a batch dimension.
- Have a `reduction` argument, which determines how the final value is produced. The loss is first computed for each element in isolation, then these values are either returned unreduced (`reduction='none'`; the output shape then follows the inputs or targets, e.g. (N,1) for (N,1) predictions), or reduced to the mean (`reduction='mean'`, the default) or the sum (`reduction='sum'`).
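
As a minimal sketch of the three reduction modes, here is `nn.MSELoss` applied to a few made-up values (the tensors `preds` and `targets` are illustrative and not part of the examples further down):

```python
import torch
from torch import nn

preds = torch.tensor([[1.0], [2.0], [4.0]])    # illustrative predictions, batch size 3
targets = torch.tensor([[1.0], [0.0], [2.0]])  # illustrative targets

per_item = nn.MSELoss(reduction='none')(preds, targets)   # unreduced, shape (3, 1)
mean_loss = nn.MSELoss(reduction='mean')(preds, targets)  # scalar (the default)
sum_loss = nn.MSELoss(reduction='sum')(preds, targets)    # scalar

print(per_item.shape, mean_loss.item(), sum_loss.item())
```

The squared errors here are 0, 4, and 4, so the sum is 8 and the mean is 8/3.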

The losses in PyTorch make strong assumptions about the inputs and targets (shapes, normalisation, log-space, logits, etc.), and often this isn't indicated in the name, so it is best to check the docs to see exactly what is expected.

Additionally, most losses have a `weight` argument, the effect of which varies between loss functions and doesn't always behave as expected (to a HEP person). For the class-based losses, the weights must also be provided during initialisation, rather than varying per batch. If decent weight handling is required, write your own losses by inheriting from the existing ones, or see mine: https://github.com/GilesStrong/lumin/blob/master/lumin/nn/losses/basic_weighted.py
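
One way to apply weights that change from batch to batch is to request the unreduced loss and do the weighting by hand. The following is only a sketch of that idea, not the approach taken in the linked code: the helper name `weighted_bce` and the choice to normalise by the sum of the weights are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

def weighted_bce(logits, targets, weights):
    # Hypothetical helper: per-item BCE computed from logits, scaled by per-item weights
    loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    return (loss * weights).sum() / weights.sum()  # weighted mean (an assumed convention)

logits = torch.zeros(4, 1)                          # illustrative logits
targets = torch.tensor([[0.], [1.], [1.], [0.]])    # illustrative binary targets
weights = torch.tensor([[1.], [2.], [0.5], [3.]])   # illustrative per-item weights
loss = weighted_bce(logits, targets, weights)
print(loss)
```

With all logits at zero, every item has a BCE of log(2), so the weighted mean is also log(2) regardless of the weights, which makes for a quick sanity check.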

Below are a few common losses.

### Binary classification
For classification tasks with only two classes, the DNN can have a single output with a sigmoid output activation. The binary cross entropy function can then be used to quantify performance.

In [21]:
logit = torch.rand(10,1)  # pre-activation values of the DNN output, for a batch size of 10
targs = torch.randint(0,2, size=(10,1)).float()  # random binary targets

In [22]:
loss_fn = nn.BCELoss()

In [23]:
loss_fn(torch.sigmoid(logit), targs)

tensor(0.7767)

This is the mean binary cross-entropy for our batch. We could instead get the raw BCE per element:

In [24]:
loss_fn = nn.BCELoss(reduction='none')

In [25]:
loss_fn(torch.sigmoid(logit), targs)

tensor([[0.3986],
        [0.7126],
        [0.5999],
        [0.8604],
        [1.1029],
        [0.9895],
        [0.8462],
        [1.0604],
        [0.4637],
        [0.7328]])

In the above, we took the logits and applied a sigmoid activation to them, which involves taking the exponential of the logits. The BCE then computes the natural log of the predictions. One can save time and gain numerical stability by instead computing the BCE directly from the logits:

In [26]:
loss_fn = nn.BCEWithLogitsLoss()

In [27]:
loss_fn(logit, targs)

tensor(0.7767)
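
The two routes agree for well-behaved logits, but they differ for extreme ones. As a small sketch with a made-up, confidently wrong prediction: the sigmoid saturates to exactly 1 in float32, and `nn.BCELoss` is documented to clamp its log terms at -100, so the sigmoid route reports a capped loss, while the logits route keeps the full value.

```python
import torch
from torch import nn

logit = torch.tensor([[200.0]])  # illustrative extreme logit: very confident in class 1
targ = torch.tensor([[0.0]])     # but the true class is 0

naive = nn.BCELoss()(torch.sigmoid(logit), targ)  # sigmoid(200) rounds to 1.0; log(0) is clamped to -100
stable = nn.BCEWithLogitsLoss()(logit, targ)      # computed directly from the logit

print(naive.item(), stable.item())
```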

### Multi-label classification
This is similar to binary classification, except now we are predicting which non-mutually-exclusive Boolean properties the inputs have. Again we can use sigmoids for each of the targets, and BCE for the loss.

In [28]:
logit = torch.rand(10,5)  # pre-activation values of the DNN output, for a batch size of 10 for 5 labels
targs = torch.randint(0,2, size=(10,5)).float()  # random binary targets

In [29]:
loss_fn = nn.BCELoss()

In [30]:
loss_fn(torch.sigmoid(logit), targs)

tensor(0.7102)

In [31]:
loss_fn = nn.BCELoss(reduction='none')

In [33]:
loss = loss_fn(torch.sigmoid(logit), targs)
loss  # with reduction='none', we now get the BCE per label per item in the batch

tensor([[1.1361, 0.5336, 0.3795, 0.5929, 0.4899],
        [0.6156, 0.7693, 0.6194, 0.7785, 0.8370],
        [1.1894, 0.6137, 0.7788, 0.5598, 0.7460],
        [0.8494, 0.6064, 0.4181, 0.4173, 0.4461],
        [0.6290, 0.8533, 1.2360, 0.4041, 0.7264],
        [1.0350, 0.6776, 1.1268, 0.5273, 0.4180],
        [0.4928, 0.5614, 0.4063, 0.4067, 0.7497],
        [0.3412, 0.8677, 0.4283, 0.9006, 1.1542],
        [0.6228, 1.1422, 0.4801, 0.4836, 1.1234],
        [1.2205, 0.4430, 0.8429, 1.0939, 0.7363]])

In [37]:
loss.mean(-1, keepdim=True)  # we can get the mean loss per item ourselves, though

tensor([[0.6264],
        [0.7240],
        [0.7775],
        [0.5475],
        [0.7698],
        [0.7569],
        [0.5234],
        [0.7384],
        [0.7704],
        [0.8673]])

### Multi-class classification
This extends binary classification to the case where there are more than two classes and each item belongs to one and only one of them. The loss here is the categorical cross-entropy (CCE), which works by comparing the predicted probabilities that an item belongs to each of the classes to the true class it belongs to. This requires that, per item, the predicted probabilities sum to one: the softmax activation performs this normalisation on the logits. **However**, none of the PyTorch CCE losses actually expect a softmaxed input...

In [46]:
logit = torch.rand(10,5)  # pre-activation values of the DNN output, for a batch size of 10 for 5 classes
targs = torch.randint(0,5, size=(10,))  # random targets for five classes

In [47]:
loss_fn = nn.CrossEntropyLoss()

In [48]:
loss_fn(logit, targs)  # Unlike BCELoss, the CrossEntropyLoss expects the logits. Really this should be called CrossEntropyWithLogitsLoss, but hey ho

tensor(1.7859)

Alternatively, if you do want to have a softmax output, there is the negative log-likelihood loss, which expects... the log of the softmaxed outputs.

In [49]:
loss_fn = nn.NLLLoss()

In [57]:
loss_fn(F.softmax(logit, dim=-1).log(), targs)  # the dim=-1 indicates to normalise over the last dimension

tensor(1.7859)

Or we can use the log-softmax activation function, which computes the same quantity in a single, more numerically stable step:

In [58]:
loss_fn(F.log_softmax(logit, dim=-1), targs)

tensor(1.7859)
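
All three routes give the same number because they compute the same thing: the mean, over the batch, of minus the log-probability assigned to the true class. A sketch of that computation by hand, using freshly generated random tensors (so the value will differ from the outputs above) and `gather` to pick out the true-class entries:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logit = torch.rand(10, 5)                # illustrative logits: batch of 10, 5 classes
targs = torch.randint(0, 5, size=(10,))  # illustrative class indices

log_probs = F.log_softmax(logit, dim=-1)                      # log-probability per class
picked = log_probs.gather(1, targs.unsqueeze(1)).squeeze(1)   # log-probability of the true class, per item
manual = -picked.mean()                                       # mean negative log-likelihood

print(manual, F.cross_entropy(logit, targs))  # the two values should match
```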

#### Multi-d multi-class classification
If predicting the class of 2D data, or higher, the expected tensor shape for:
 - inputs is (batch, class, x, y,...)
 - targets is (batch, x, y,...)

In [65]:
logit = torch.rand(10,5,2,3,4)  # pre-activation values of the DNN output, for a batch size of 10 for 5 classes over a cuboid
targs = torch.randint(0,5, size=(10,2,3,4))  # random targets for five classes

In [66]:
loss_fn = nn.CrossEntropyLoss()

In [67]:
loss_fn(logit, targs)  # Unlike BCELoss, the CrossEntropyLoss expects the logits. Really this should be called CrossEntropyWithLogitsLoss, but hey ho

tensor(1.6404)

In [68]:
loss_fn = nn.NLLLoss()

In [70]:
loss_fn(F.softmax(logit, dim=1).log(), targs)  # remember to normalise over the class dimension

tensor(1.6404)
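
Some models produce the class scores in the last dimension rather than the second. As a sketch, assuming a hypothetical output laid out as (batch, x, y, z, class), the class dimension can be moved into place with `movedim` before computing the loss. Using `reduction='none'` also shows that the unreduced loss has one value per position, i.e. the shape of the targets.

```python
import torch
from torch import nn

out = torch.rand(10, 2, 3, 4, 5)                 # hypothetical output with classes in the last dimension
targs = torch.randint(0, 5, size=(10, 2, 3, 4))  # illustrative class index per position

logit = out.movedim(-1, 1)                       # -> (batch, class, x, y, z) = (10, 5, 2, 3, 4)
loss = nn.CrossEntropyLoss(reduction='none')(logit, targs)

print(logit.shape, loss.shape)
```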

### Regression
Regression problems involve predicting float targets. Typically no output activation is used, so that the outputs map linearly onto (-inf, inf). In such problems, the loss should scale with the error on the prediction. Common choices are:
- squared error (p-t)**2
- absolute error |p-t|

In [71]:
logit = torch.rand(10,1)  # outputs of the DNN
targs = torch.rand(10,1)  # random target values

In [72]:
loss_fn = nn.MSELoss()  # Mean square error

In [73]:
loss_fn(logit, targs)

tensor(0.1194)

In [74]:
loss_fn = nn.L1Loss()  # L1 loss is the absolute error

In [75]:
loss_fn(logit, targs)

tensor(0.2931)
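
With the default reduction, both losses are simply the batch mean of the element-wise errors listed earlier. A quick check by hand, on freshly generated random tensors (so the values will differ from the outputs above):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
preds = torch.rand(10, 1)  # illustrative predictions
targs = torch.rand(10, 1)  # illustrative targets

mse = ((preds - targs) ** 2).mean()  # mean squared error: mean of (p-t)**2
mae = (preds - targs).abs().mean()   # mean absolute error: mean of |p-t|

print(mse, F.mse_loss(preds, targs))
print(mae, F.l1_loss(preds, targs))
```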

## Functional losses
As mentioned, functional versions of the losses exist, too, e.g.:

In [76]:
logit = torch.rand(10,1)  # outputs of the DNN
targs = torch.rand(10,1)  # random target values

In [77]:
F.mse_loss(logit, targs)

tensor(0.0680)
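
The practical difference between the two forms is where the options go: the class takes them at construction, while the function takes them on every call. A small sketch with made-up values:

```python
import torch
from torch import nn
import torch.nn.functional as F

preds = torch.tensor([[0.5], [1.5]])  # illustrative predictions
targs = torch.tensor([[1.0], [1.0]])  # illustrative targets

class_loss = nn.MSELoss(reduction='sum')(preds, targs)  # options fixed when the loss object is built
func_loss = F.mse_loss(preds, targs, reduction='sum')   # options passed per call

print(class_loss, func_loss)
```

Both squared errors are 0.25, so each form returns a sum of 0.5.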

## Custom loss function
Class-based losses inherit from `nn.Module` so making our own is quite easy. We can even inherit from existing losses that are close to what we want.
Let's make a loss that takes the squared-error on predictions and then divides it by the target:

In [78]:
class FractionalMSE(nn.MSELoss):  # Inherit from the basic MSELoss
    def __init__(self):
        super().__init__(reduction='none')  # No reduction, so that the squared-error shape matches the targets

    def forward(self, input, target):
        se = super().forward(input, target)  # Per-element squared error (unreduced)
        fse = se/target  # Fractional squared error; assumes non-zero targets
        return torch.mean(fse)  # Return the mean fractional squared error

In [83]:
logit = torch.rand(10,2)  # outputs of the DNN
targs = torch.rand(10,2)  # random target values

In [84]:
loss_fn = FractionalMSE()

In [85]:
loss_fn(logit, targs)

tensor(0.3215)
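
Dividing by the target becomes a problem when a target is zero or very close to it. One possible guard, sketched here as a plain `nn.Module` rather than by inheriting from `nn.MSELoss`, is to divide by the absolute target plus a small constant. The class name `SafeFractionalMSE`, the `eps` parameter, and the use of the absolute value are all assumptions made for illustration; they change the behaviour for negative targets compared with the loss above.

```python
import torch
from torch import nn

class SafeFractionalMSE(nn.Module):
    # Hypothetical variant: squared error divided by |target| + eps
    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps  # small constant to avoid division by zero (assumed value)

    def forward(self, input, target):
        se = (input - target) ** 2                          # per-element squared error
        return torch.mean(se / (target.abs() + self.eps))   # mean fractional squared error

preds = torch.tensor([[2.0], [0.0]])  # illustrative predictions
targs = torch.tensor([[1.0], [0.0]])  # illustrative targets, including a zero
loss = SafeFractionalMSE()(preds, targs)
print(loss)
```

Here the second item has zero error and a zero target, which would give 0/0 without the constant; with it, the loss stays finite.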