In [3]:
import numpy as np

# Ground truth labels (one-hot encoded). Sample i belongs to class i, so the
# label matrix is the 5x5 identity: row 0 -> class 1, ..., row 4 -> class 5.
y_true = np.eye(5, dtype=int)

# Model predictions (logits): one row of five raw class scores per sample
logits = np.array(
    [
        [2.0, 1.0, 0.1, 0.5, 0.3],  # Sample 1
        [0.5, 2.5, 0.2, 1.2, 0.8],  # Sample 2
        [0.1, 0.2, 3.0, 0.7, 1.5],  # Sample 3
        [1.3, 0.4, 0.6, 2.8, 0.9],  # Sample 4
        [0.7, 1.1, 0.3, 0.4, 2.2],  # Sample 5
    ]
)

# Softmax function
def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating. Softmax
    is invariant to adding a constant to every entry along the axis, so the
    result is unchanged, but ``np.exp`` can no longer overflow for large
    logits.

    Args:
        x: Array-like of logits.
        axis: Axis along which the output probabilities sum to 1.
            Defaults to the last axis.

    Returns:
        An ndarray with the same shape as ``x`` whose slices along ``axis``
        each sum to 1.
    """
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))  # Numerical stability
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

# Compute softmax probabilities
probs = softmax(logits)
print("Softmax Probabilities:\n", probs)

# Compute log(probs) directly as a log-softmax:
#   log softmax(z)_i = (z_i - max(z)) - log(sum_j exp(z_j - max(z)))
# Calling np.log on the softmax output instead yields -inf (plus a
# divide-by-zero warning) for any probability that underflows to 0.0.
# In the one-hot product below, 0 * -inf evaluates to nan and poisons the loss.
shifted = logits - np.max(logits, axis=-1, keepdims=True)
log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
print("\nLog of Softmax Probabilities:\n", log_probs)

# Compute y_true * log(probs); with one-hot labels only the true-class entry
# of each row is non-zero
y_true_log_probs = y_true * log_probs
print("\ny_true * log(probs):\n", y_true_log_probs)

# Compute CE loss for each sample: -sum over classes of y_c * log(p_c)
ce_loss = -np.sum(y_true_log_probs, axis=-1)
print("\nCE Loss per Sample:", ce_loss)

# Average CE loss for the dataset
avg_ce_loss = np.mean(ce_loss)
print("\nAverage CE Loss for the Dataset:", avg_ce_loss)

Softmax Probabilities:
 [[0.51995003 0.19127893 0.07776821 0.11601653 0.0949863 ]
 [0.08004171 0.5914327  0.05929636 0.16118422 0.10804501]
 [0.03823122 0.04225203 0.6948197  0.06966182 0.15503523]
 [0.14174015 0.05762724 0.07038607 0.63523527 0.09501126]
 [0.11926553 0.17792327 0.07994608 0.08835408 0.53451104]]

Log of Softmax Probabilities:
 [[-0.65402257 -1.65402257 -2.55402257 -2.15402257 -2.35402257]
 [-2.52520738 -0.52520738 -2.82520738 -1.82520738 -2.22520738]
 [-3.26410289 -3.16410289 -0.36410289 -2.66410289 -1.86410289]
 [-1.95375984 -2.85375984 -2.65375984 -0.45375984 -2.35375984]
 [-2.1264029  -1.7264029  -2.5264029  -2.4264029  -0.6264029 ]]

y_true * log(probs):
 [[-0.65402257 -0.         -0.         -0.         -0.        ]
 [-0.         -0.52520738 -0.         -0.         -0.        ]
 [-0.         -0.         -0.36410289 -0.         -0.        ]
 [-0.         -0.         -0.         -0.45375984 -0.        ]
 [-0.         -0.         -0.         -0.         -0.6264029 ]