## Loss investigation
An investigation into how the cross-entropy and sigmoid focal loss functions vary with the predicted class probability.

In [None]:
import torch
import matplotlib.pyplot as plt
from torch.nn import CrossEntropyLoss
from torchvision.ops import sigmoid_focal_loss

In [None]:
# Floating-point dtype and device used when building the example tensors.
DTYPE, DEVICE = torch.float32, torch.device("cpu")

### Cross entropy loss

In [None]:
# Example model output (logits), shape: [batch_size, num_classes] = [201, 2].
# Column 1 mirrors column 0, so the logit difference sweeps from -4 to 4 and
# the softmax probability of class 0 sweeps from about 0.02 to 0.98.
logits = torch.stack(
    (
        torch.linspace(-2, 2, 201, dtype=DTYPE, device=DEVICE),
        torch.linspace(2, -2, 201, dtype=DTYPE, device=DEVICE),
    ),
    dim=1,
)

# Ground truth labels (class indices), shape: [batch_size].
# Created on the same device as the logits; a device mismatch would make the
# loss call fail whenever DEVICE is not the CPU.
target = torch.zeros((logits.shape[0],), dtype=torch.long, device=DEVICE)

# Per-sample loss (no reduction) so every point of the sweep can be plotted.
criterion = CrossEntropyLoss(reduction="none")

# Call the module rather than `.forward` so any registered hooks also run.
loss = criterion(logits, target)

In [None]:
# Softmax across the class dimension turns each row of logits into a
# probability distribution over the two classes.
probs = logits.softmax(dim=1)

class0_prob = probs[:, 0].cpu().numpy()
loss_values = loss.cpu().numpy()

plt.plot(class0_prob, loss_values)
plt.grid()
plt.xlabel("Class 0 Probability")
plt.ylabel("Loss")
plt.xlim(0, 1)
plt.ylim(0, 5)
plt.title("Cross entropy loss\nGround truth class 0")

In [None]:
# Loss when the ground truth is class 1, computed explicitly. The previous
# version plotted the class-0-target loss against probs[:, 1], which only
# matched these labels by the symmetry of the two-class example.
loss_class1 = criterion(logits, torch.ones_like(target))

plt.plot(probs[:, 0].cpu().numpy(), loss_class1.cpu().numpy())
plt.grid()
plt.xlabel("Class 0 Probability")
plt.ylabel("Loss")
plt.xlim([0, 1])
plt.ylim([0, 5])
plt.title("Cross entropy loss\nGround truth class 1")

### Sigmoid focal loss

In [None]:
# Example model output (logits), shape: [batch_size, num_classes] = [201, 2].
# Column 1 mirrors column 0, so sigmoid(column 0) sweeps from about 0.02
# to 0.98.
logits = torch.stack(
    (
        torch.linspace(-4, 4, 201, dtype=DTYPE, device=DEVICE),
        torch.linspace(4, -4, 201, dtype=DTYPE, device=DEVICE),
    ),
    dim=1,
)

# Binary (multi-label) targets with the same shape as the logits:
# class 0 is always positive (1) and class 1 is always negative (0).
targets = torch.stack(
    (
        torch.ones((201,), dtype=DTYPE, device=DEVICE),
        torch.zeros((201,), dtype=DTYPE, device=DEVICE),
    ),
    dim=1,
)

# Compute the element-wise loss (same shape as the logits). alpha=-1 disables
# the positive/negative weighting, and gamma=0 removes the (1 - p_t) ** gamma
# modulating factor. Together they reduce the focal loss to plain binary
# cross entropy with logits.
loss = sigmoid_focal_loss(
    logits, targets, alpha=-1, gamma=0, reduction="none"
)

In [None]:
# Independent per-class probabilities (sigmoid, not softmax).
probs = logits.sigmoid()

# `loss` holds one value per (sample, class). Plot only the class-0 column,
# whose target is 1, against the class-0 probability. Previously both
# columns were drawn as two overlapping lines.
plt.plot(probs[:, 0].cpu().numpy(), loss[:, 0].cpu().numpy())
plt.grid()
plt.xlabel("Class 0 Probability")
plt.ylabel("Loss")
plt.xlim([0, 1])
plt.ylim([0, 5])
plt.title("Sigmoid focal loss\nGround truth class 0")

In [None]:
import torch

In [None]:
# Tensor whose values determine the ordering.
sort_by = torch.rand((6,))

# Tensor to reorder according to `sort_by`.
data = torch.rand((6,))

# Permutation that sorts `sort_by` in descending order. `argsort` returns only
# the indices, so no sorted-values tensor is built and then discarded.
indices = torch.argsort(sort_by, descending=True)

# Gather with the permutation: sorted_data[i] == data[indices[i]].
sorted_data = data[indices]

print(sort_by)
print(indices)
print(data)
print(sorted_data)

In [None]:
# Display the sorting permutation computed in the previous cell; the notebook
# echoes the value of a cell's final expression.
indices