In [7]:

import numpy as np

# Ground-truth labels, one-hot encoded: sample i belongs to class i+1, so the
# label matrix is exactly the 5x5 identity (integer dtype, one row per sample).
y_true = np.eye(5, dtype=int)

# Raw model outputs (logits): one row per sample, one column per class.
logits = np.array(
    [
        [2.0, 1.0, 0.1, 3.5, 0.8],  # Sample 1
        [0.5, 2.5, 0.2, 1.7, 2.3],  # Sample 2
        [0.1, 0.2, 3.0, 4.1, 0.9],  # Sample 3
        [1.5, 2.2, 0.7, 3.8, 1.0],  # Sample 4
        [3.2, 1.1, 0.4, 2.6, 0.5],  # Sample 5
    ]
)

# Numerically stable softmax
def softmax(x):
    """Softmax along the last axis of ``x``.

    The per-slice maximum is subtracted before exponentiating so that large
    logits do not overflow ``np.exp``; the shift cancels in the final ratio,
    so the result is mathematically unchanged.

    Args:
        x: Array-like of logits; normalisation is performed along the last axis.

    Returns:
        An array with the same shape as ``x`` whose entries along the last axis
        are non-negative and sum to 1.
    """
    slice_max = np.max(x, axis=-1, keepdims=True)
    unnormalised = np.exp(x - slice_max)
    partition = np.sum(unnormalised, axis=-1, keepdims=True)
    return unnormalised / partition

# Compute softmax probabilities
probs = softmax(logits)
print("Softmax Probabilities:\n", probs)

# Compute log(probs) directly as a log-softmax via the log-sum-exp identity
#   log softmax(z)_k = (z_k - max z) - log(sum_j exp(z_j - max z))
# rather than np.log(probs). If a probability underflows to 0.0 (large logit
# gaps), np.log returns -inf, and the masked product 0 * -inf in the next step
# is nan, which would turn that sample's loss (and the dataset average) into nan.
shifted_logits = logits - np.max(logits, axis=-1, keepdims=True)
log_probs = shifted_logits - np.log(np.sum(np.exp(shifted_logits), axis=-1, keepdims=True))
assert np.allclose(np.exp(log_probs), probs)  # consistency check against the softmax above
print("\nLog of Softmax Probabilities:\n", log_probs)

# Compute y_true * log(probs); with one-hot labels only the true-class entry survives
y_true_log_probs = y_true * log_probs
print("\ny_true * log(probs):\n", y_true_log_probs)

# Compute CE loss for each sample: -sum_k y_k * log(p_k) over the class (last) axis
ce_loss = -np.sum(y_true_log_probs, axis=-1)
print("\nCE Loss per Sample:", ce_loss)

# Average CE loss for the dataset
avg_ce_loss = np.mean(ce_loss)
print("\nAverage CE Loss for the Dataset:", avg_ce_loss)

Softmax Probabilities:
 [[0.15872181 0.05839049 0.0237398  0.71134181 0.04780609]
 [0.05405511 0.39941624 0.04004501 0.17946928 0.32701436]
 [0.01296966 0.01433369 0.23571254 0.70811959 0.02886452]
 [0.07120583 0.14339092 0.03199484 0.71021989 0.04318852]
 [0.55577674 0.06805843 0.03379682 0.30501674 0.03735126]]

Log of Softmax Probabilities:
 [[-1.84060223 -2.84060223 -3.74060223 -0.34060223 -3.04060223]
 [-2.9177512  -0.9177512  -3.2177512  -1.7177512  -1.1177512 ]
 [-4.34514228 -4.24514228 -1.44514228 -0.34514228 -3.54514228]
 [-2.64218065 -1.94218065 -3.44218065 -0.34218065 -3.14218065]
 [-0.58738861 -2.68738861 -3.38738861 -1.18738861 -3.28738861]]

y_true * log(probs):
 [[-1.84060223 -0.         -0.         -0.         -0.        ]
 [-0.         -0.9177512  -0.         -0.         -0.        ]
 [-0.         -0.         -1.44514228 -0.         -0.        ]
 [-0.         -0.         -0.         -0.34218065 -0.        ]
 [-0.         -0.         -0.         -0.         -3.28738861]