**Cross entropy loss**

The cross-entropy loss (also called log loss) measures classification error when the model outputs, for each class, a predicted probability between 0 and 1. The loss increases as the predicted probability for the true class decreases. In other words, predicting a low probability for the class whose actual label is 1 results in a high loss value. A perfect model, which assigns probability 1 to the true class, has a log loss of 0.
The cross-entropy loss for binary classification is calculated as:

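Binary cross-entropy loss (BCELoss), where \( y \in \{0, 1\} \) is the true label and \( p \) is the predicted probability that the label is 1: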

$$
\text{BCELoss} = -\big(y \log(p) + (1 - y) \log(1 - p)\big)
$$
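As a quick numeric check of the formula, here is a minimal pure-Python sketch (standard library only, not PyTorch's implementation). Here `y` is the true label (0 or 1) and `p` the predicted probability of label 1; the clamp constant `eps` is an assumption added only so that `log(0)` is never evaluated.

```python
import math

def binary_cross_entropy(y: float, p: float, eps: float = 1e-12) -> float:
    """BCE for one sample: y is the true label (0 or 1), p is the predicted P(label = 1)."""
    # Assumed safeguard: keep p strictly inside (0, 1) so math.log never receives 0
    p = min(max(p, eps), 1.0 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# True label is 1: the loss grows as the predicted probability for class 1 falls
for p in (0.99, 0.9, 0.5, 0.1):
    print(f"y=1, p={p}: loss={binary_cross_entropy(1, p):.4f}")
```

With `y = 1` the printed loss rises from about 0.01 at `p = 0.99` to about 2.30 at `p = 0.1`, and a prediction of `p = 1.0` gives a loss that is zero up to the clamp.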

For more than two classes (\( C > 2 \)), the multiclass cross-entropy loss is:

$$
L = -\sum_{i=1}^{N} \sum_{j=1}^{C} y_{ij} \log(\hat{y}_{ij})
$$

Where:
- \( L \): The total loss.
- \( N \): Number of samples.
- \( C \): Number of classes.
- \( y_{ij} \): Ground truth indicator (1 if sample \( i \) belongs to class \( j \), 0 otherwise).
- \( \hat{y}_{ij} \): Predicted probability for sample \( i \) and class \( j \), typically output from a softmax function.
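Because each row of \( y_{ij} \) is one-hot, only the true-class term of the inner sum is non-zero, so each sample contributes \( -\log \) of its predicted probability for the true class. A minimal pure-Python sketch of the summed loss follows; the two-sample, three-class data is hypothetical and the rows of `y_prob` are assumed to already be probabilities.

```python
import math

def multiclass_cross_entropy(y_true, y_prob):
    """Summed cross-entropy L = -sum_i sum_j y_ij * log(yhat_ij).

    y_true: N x C one-hot rows; y_prob: N x C predicted probabilities (rows sum to 1).
    """
    total = 0.0
    for y_row, p_row in zip(y_true, y_prob):
        for y_ij, p_ij in zip(y_row, p_row):
            if y_ij:  # terms with y_ij == 0 contribute nothing, so skip their log
                total -= y_ij * math.log(p_ij)
    return total

# Hypothetical example: N = 2 samples, C = 3 classes
y_true = [[1, 0, 0],
          [0, 0, 1]]
y_prob = [[0.7, 0.2, 0.1],
          [0.1, 0.3, 0.6]]
loss = multiclass_cross_entropy(y_true, y_prob)  # equals -(log 0.7 + log 0.6)
print(loss)
```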

The **batch-averaged** version divides the total loss by \( N \), so its scale does not depend on the batch size:

$$
L = -\frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{C} y_{ij} \log(\hat{y}_{ij})
$$
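In practice labels are often stored as class indices rather than one-hot rows; an index \( t \) simply selects the single term with \( y_{it} = 1 \). A minimal sketch of the batch-averaged loss under that assumption (hypothetical data, probabilities assumed valid) also shows that repeating the batch leaves the mean unchanged, whereas the summed loss would double.

```python
import math

def cross_entropy_mean(targets, y_prob):
    """Batch-averaged cross-entropy: -(1/N) * sum_i log(yhat_i[target_i])."""
    # An integer label t picks out the only non-zero term (y_it = 1) of the inner sum
    n = len(targets)
    return -sum(math.log(p_row[t]) for t, p_row in zip(targets, y_prob)) / n

# Hypothetical batch: N = 2 samples, C = 3 classes
targets = [0, 2]
y_prob = [[0.7, 0.2, 0.1],
          [0.1, 0.3, 0.6]]

mean_loss = cross_entropy_mean(targets, y_prob)
doubled = cross_entropy_mean(targets * 2, y_prob * 2)  # same samples repeated twice
print(mean_loss, doubled)  # same value up to floating-point rounding
```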


```python
import torch

# Use Apple's Metal Performance Shaders (MPS) backend if this machine supports it
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)
else:
    print("MPS device not found.")
```

Output:

```
tensor([1.], device='mps:0')
```


```python
import torch
import torch.nn as nn

# Raw, unnormalized scores (logits) from the model: one row per sample, one column per class.
# nn.CrossEntropyLoss applies log-softmax internally, so these rows do not need to sum to 1.
# Sample 0 scores class 0 highest (0.95); sample 1 scores class 3 highest (0.94).
predictions = torch.tensor([[0.95, 0.25, 0.25, 0.25], [0.1, 0.2, 0.3, 0.94]], requires_grad=True)
targets = torch.tensor([0, 3])  # True class index for each sample

# Define the loss function (averages over the batch by default, reduction='mean')
loss_fn = nn.CrossEntropyLoss()

# Calculate the loss
loss = loss_fn(predictions, targets)

print(f'Cross-Entropy Loss: {loss.item()}')

# Backward pass: fills predictions.grad with d(loss)/d(predictions)
loss.backward()
```

Output:

```
Cross-Entropy Loss: 0.9012950658798218
```
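`nn.CrossEntropyLoss` takes raw scores (logits), applies log-softmax to each row, and averages \( -\log \) of the true-class probability over the batch. The pure-Python sketch below (an illustration, not PyTorch's actual implementation) recomputes the loss from the same scores and labels, and also computes the gradient \( (\mathrm{softmax}(z) - \text{one-hot}) / N \), which is what `loss.backward()` stores in `predictions.grad` for the default mean reduction.

```python
import math

def log_softmax(scores):
    """Numerically stable log-softmax of one row of raw scores (logits)."""
    m = max(scores)
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum_exp for s in scores]

def cross_entropy_from_logits(logits, targets):
    """Mean over the batch of -log softmax(row)[target]."""
    losses = [-log_softmax(row)[t] for row, t in zip(logits, targets)]
    return sum(losses) / len(losses)

def cross_entropy_grad(logits, targets):
    """Gradient of the mean loss with respect to the logits: (softmax - one_hot) / N."""
    n = len(targets)
    grads = []
    for row, t in zip(logits, targets):
        probs = [math.exp(v) for v in log_softmax(row)]
        grads.append([(p - (1.0 if j == t else 0.0)) / n for j, p in enumerate(probs)])
    return grads

# Same raw scores and labels as the PyTorch cell above
logits = [[0.95, 0.25, 0.25, 0.25], [0.1, 0.2, 0.3, 0.94]]
targets = [0, 3]
print(cross_entropy_from_logits(logits, targets))
print(cross_entropy_grad(logits, targets))
```

The printed loss matches the PyTorch output above to about six decimal places (PyTorch computes in float32 here). Each gradient row sums to zero, with a negative entry only at the true class.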
