In [6]:
import os
import numpy as np
import matplotlib.pyplot as plt
from nnfs.datasets import spiral_data
import numpy as np
import nnfs
import time
import torch
import torch.nn.functional as F

# Seed NumPy's global RNG so random draws are repeatable across runs
np.random.seed(42)
# NOTE(review): nnfs.init() runs after the seed call above; presumably it
# applies its own NumPy seed/dtype defaults, which would override the 42 —
# confirm against the nnfs docs and reorder if so. torch's RNG is not seeded
# in this cell.
nnfs.init()

Categorical cross-entropy loss function

In [4]:
# Argmax by hand: enumerate pairs each value with its index, and the key makes
# max compare on the value, so the result is an (index, value) pair
max(enumerate((1,2,3,10,4,5)), key=lambda x: x[1])

(3, 10)

In [7]:
def categorical_cross_entropy(y_true, y_pred, epsilon=1e-15):
    """
    Calculate categorical cross-entropy loss

    Args:
        y_true: Ground truth. Either probabilities / one-hot encoded labels
            that broadcast against y_pred, or a 1-D integer tensor holding
            one class index per sample (same convention as
            torch.nn.functional.cross_entropy: integer dtype means indices).
            An integer tensor is only read as class indices when its length
            equals the batch size of a 2-D y_pred.
        y_pred: Predicted probabilities from softmax, one row per sample
        epsilon: Smallest probability allowed before taking the log; the
            predictions are clamped to [epsilon, 1 - epsilon]

    Returns:
        Scalar tensor holding the mean loss across all samples
    """
    # A bound the dtype cannot represent (e.g. 1e-15 in float16) would round
    # to zero and silently disable the clamp, so never go below the dtype's
    # smallest normal value. For float32/float64 this leaves epsilon as given.
    if torch.is_floating_point(y_pred):
        epsilon = max(epsilon, torch.finfo(y_pred.dtype).tiny)

    # Ensure numerical stability by keeping probabilities away from log(0)
    y_pred = torch.clamp(y_pred, epsilon, 1 - epsilon)
    log_probs = torch.log(y_pred)

    integer_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    targets_are_indices = (
        y_true.dtype in integer_dtypes
        and y_true.dim() == 1
        and y_pred.dim() == 2
        and y_true.shape[0] == y_pred.shape[0]
    )

    if targets_are_indices:
        # Pick the log-probability of the correct class for every sample
        index = y_true.long().unsqueeze(1)
        loss = -log_probs.gather(1, index).squeeze(1)
    else:
        # Dense targets: weight each class log-probability by its target mass
        loss = -torch.sum(y_true * log_probs, dim=1)

    # Return mean loss across all samples
    return torch.mean(loss)

# Example usage
# Size of the toy problem
batch_size = 3
num_classes = 4

# Predicted probabilities (output from softmax), one row per sample
y_pred = torch.tensor([
    [0.1, 0.7, 0.1, 0.1],    # sample 0
    [0.8, 0.1, 0.05, 0.05],  # sample 1
    [0.2, 0.2, 0.5, 0.1],    # sample 2
])

# Correct class of each sample as an index: sample 0 -> class 1,
# sample 1 -> class 0, sample 2 -> class 2
targets = torch.tensor([1, 0, 2])

# One-hot encoded ground truth: set a 1.0 at (row, correct class) for every row
y_true = torch.zeros(batch_size, num_classes)
y_true[torch.arange(batch_size), targets] = 1.0

# Loss from the hand-written implementation
loss = categorical_cross_entropy(y_true, y_pred)
print(f"Categorical Cross-Entropy Loss: {loss.item()}")

# Reference value from PyTorch's built-in function.
# F.cross_entropy takes class indices (not one-hot rows) and raw logits
# (not softmax probabilities), so the log of the probabilities is passed as
# logits: softmax over log-probabilities recovers the probabilities themselves.
logits = torch.log(y_pred)
pytorch_loss = F.cross_entropy(logits, targets)
print(f"PyTorch's Cross-Entropy Loss: {pytorch_loss.item()}")

Categorical Cross-Entropy Loss: 0.42432188987731934
PyTorch's Cross-Entropy Loss: 0.42432186007499695
