# Cross-Entropy from Logits (Stable) + Gradients

## Multiclass Cross-Entropy (CE) Loss

**1) Softmax (convert logits to probabilities)**  
$$
p_i = \frac{e^{z_i}}{\sum_{j=1}^{C} e^{z_j}}
$$  
Turns raw scores (*logits*) into probabilities that sum to 1.
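
As a quick illustration, the sketch below applies the formula above to a small logit vector and checks two properties: the outputs sum to 1, and adding the same constant to every logit leaves the probabilities unchanged. The helper name `naive_softmax` and the sample values are assumptions made for this example only.

```python
import numpy as np

def naive_softmax(z):
    # Direct translation of p_i = exp(z_i) / sum_j exp(z_j); hypothetical helper
    e = np.exp(z)
    return e / e.sum()

z = np.array([2.0, 0.5, -1.0, 0.0])  # illustrative logits
p = naive_softmax(z)
p_shifted = naive_softmax(z + 10.0)  # same constant added to every logit

print("sum of probabilities:", p.sum())
print("unchanged by a constant shift:", np.allclose(p, p_shifted))
```

The shift invariance is what makes it safe to subtract the maximum logit before exponentiating.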

---

**2) Cross-Entropy (one-hot labels)**  
$$
CE(p, y) = -\sum_{i=1}^{C} y_i \log p_i = -\log p_{y^*}
$$  
The loss is large when the true class \(y^*\) receives low probability.
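
A small sketch of the one-hot identity, using made-up probabilities: the full sum over classes collapses to the negative log of the true-class probability, and a confidently wrong prediction is penalized far more than a confidently right one. All values here are illustrative assumptions.

```python
import numpy as np

p = np.array([0.7, 0.2, 0.1])   # illustrative predicted probabilities
y = np.array([1.0, 0.0, 0.0])   # one-hot label, true class is index 0

ce_full = -np.sum(y * np.log(p))   # full sum over classes
ce_short = -np.log(p[0])           # only the true-class term survives

# Same prediction, but now the true class is the least likely one (index 2)
ce_wrong = -np.log(p[2])

print(ce_full, ce_short, ce_wrong)
```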

---

**3) Stable CE directly from logits (log-sum-exp trick)**  
$$
CE(z, y^*) = \log \left( \sum_{j=1}^{C} e^{z_j} \right) - z_{y^*}
$$  
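Exponentiating large logits can overflow, so subtract the maximum logit \(m\) before exponentiating; the shift cancels exactly:  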
$$
LSE(z) = m + \log \left( \sum_{j=1}^{C} e^{z_j - m} \right), \quad m = \max_j z_j
$$  
$$
CE(z, y^*) = LSE(z) - z_{y^*}
$$
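
To see why the shift matters, the sketch below compares the direct log-sum-exp with the shifted version on deliberately large logits. The logit values are an assumption chosen to force overflow in float64; `np.errstate` is used only to silence the expected overflow warning.

```python
import numpy as np

z = np.array([1000.0, 999.0, 998.0])  # illustrative logits, large enough to overflow exp()

# Direct formula: exp(1000) overflows to inf, so the log is inf as well
with np.errstate(over="ignore"):
    lse_naive = np.log(np.sum(np.exp(z)))

# Shifted formula: the largest exponent is exp(0) = 1, so nothing overflows
m = np.max(z)
lse_stable = m + np.log(np.sum(np.exp(z - m)))

print("naive :", lse_naive)
print("stable:", lse_stable)
```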

---

**4) Batch Loss (N samples)**  
$$
\mathcal{L} = \frac{1}{N} \sum_{n=1}^{N} \big(-\log p^{(n)}_{y^{*(n)}}\big)
$$
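
When labels are stored as integer class indices rather than one-hot vectors, the batch loss can be computed by indexing the log-probabilities directly. The following is a minimal sketch under that assumption; it reuses the logits and labels from the example cells later in this notebook, and the variable names are chosen for illustration.

```python
import numpy as np

logits = np.array([[2.0, 0.5, -1.0, 0.0],
                   [0.3, 1.2, 0.1, -0.5],
                   [3.0, 2.0, 0.0, -1.0]])
y_idx = np.array([0, 1, 0])  # true class index for each sample

# Stable log-softmax: log p = z - LSE(z)
m = np.max(logits, axis=-1, keepdims=True)
lse = m + np.log(np.sum(np.exp(logits - m), axis=-1, keepdims=True))
log_p = logits - lse

# Select the true-class log-probability in each row, then average over the batch
n = logits.shape[0]
per_sample = -log_p[np.arange(n), y_idx]
batch_loss = per_sample.mean()

print("per-sample losses:", per_sample)
print("batch loss:", batch_loss)
```

Indexing avoids building the one-hot matrix, which can matter when the number of classes is large.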

---

**5) Gradient w.r.t. logits**  
$$
\frac{\partial CE}{\partial z_i} = p_i - y_i
$$  
For a batch (mean reduction):  
$$
\frac{\partial \mathcal{L}}{\partial z_{n,i}} = \frac{1}{N}\big(p_{n,i} - y_{n,i}\big)
$$
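
A short derivation of the per-sample result, sketched from the logit form of the loss in section 3 and assuming the labels satisfy \(\sum_k y_k = 1\) (true for one-hot labels):

$$
CE(z, y) = -\sum_{k=1}^{C} y_k \log p_k = \log \left( \sum_{j=1}^{C} e^{z_j} \right) - \sum_{k=1}^{C} y_k z_k
$$
$$
\frac{\partial CE}{\partial z_i} = \frac{e^{z_i}}{\sum_{j=1}^{C} e^{z_j}} - y_i = p_i - y_i
$$

Averaging over \(N\) samples scales each per-sample gradient by \(1/N\), which gives the batch formula.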


In [1]:
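import numpy as np

def softmax_logits(logits, axis=-1):
    """Numerically stable softmax along `axis`."""
    # Subtracting the max leaves the result unchanged but keeps exp() from overflowing
    z = logits - np.max(logits, axis=axis, keepdims=True)
    exp_z = np.exp(z)
    return exp_z / np.sum(exp_z, axis=axis, keepdims=True)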

In [2]:
def cross_entropy_from_probs(p, y, eps=1e-7, reduction='mean'):
    """Cross-entropy from probabilities `p` and one-hot labels `y`."""
    p_clamped = np.clip(p, eps, 1.0)  # avoid log(0)
    loss = -np.sum(y * np.log(p_clamped), axis=-1)  # per-sample loss
    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    else:
        return loss

In [14]:
def cross_entropy_from_logits(logits, y, reduction='mean'):
    """Stable cross-entropy computed directly from logits and one-hot labels `y`."""
    # LSE(z) = m + log(sum(exp(z - m))), with m = max(z)
    max_z = np.max(logits, axis=-1, keepdims=True)
    lse = np.log(np.sum(np.exp(logits - max_z), axis=-1, keepdims=True)) + max_z
    # CE = LSE - z_true
    z_true = np.sum(y * logits, axis=-1, keepdims=True)
    loss = (lse - z_true).squeeze(-1)

    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    else:
        return loss

In [4]:
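def grad_logits_cross_entropy(logits, y, reduction='mean'):
    """Gradient of the cross-entropy loss w.r.t. the logits: p - y."""
    p = softmax_logits(logits, axis=-1)
    grad = p - y
    if reduction == 'mean':
        # Mean reduction scales each sample's gradient by 1/N
        N = logits.shape[0]
        grad = grad / N
    return grad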

In [22]:
logits = np.array([[2.0, 0.5, -1.0, 0.0],
                   [0.3, 1.2, 0.1, -0.5],
                   [3.0, 2.0, 0.0, -1.0]])

y_idx = np.array([0, 1, 0])

# One-hot labels
y = np.eye(4)[y_idx]
print(y)

[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [1. 0. 0. 0.]]


In [19]:
# Probabilities and loss
p = softmax_logits(logits)
loss_probs = cross_entropy_from_probs(p, y, reduction="mean")
loss_logits = cross_entropy_from_logits(logits, y, reduction="mean")

print("probs (rows sum to 1):\n", p)
print("\nCE from probs:", loss_probs)
print("\nCE from logits (stable):", loss_logits)

probs (rows sum to 1):
 [[0.71009992 0.15844471 0.03535379 0.09610157]
 [0.21152101 0.52025773 0.17317875 0.09504251]
 [0.69638749 0.25618664 0.03467109 0.01275478]]

CE from probs: 0.45254319508077606

CE from logits (stable): 0.45254319508077595


In [20]:
# Gradient wrt logits
g = grad_logits_cross_entropy(logits, y, reduction="mean")
print("grad shape:", g.shape)
print("row sums of grad (mean reduction):", np.sum(g, axis=1))

grad shape: (3, 4)
row sums of grad (mean reduction): [2.77555756e-17 0.00000000e+00 1.56125113e-17]
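
The row sums being zero is a necessary property of the gradient (each row of `p` and of `y` sums to 1), but it does not confirm the individual entries. One way to check them is a central finite-difference comparison. The sketch below is self-contained and redefines the loss and gradient under hypothetical names (`mean_ce`, `analytic_grad`); the step size `h` is an assumption, not a tuned value.

```python
import numpy as np

def mean_ce(logits, y):
    # Stable mean cross-entropy from logits via log-sum-exp
    m = np.max(logits, axis=-1, keepdims=True)
    lse = m + np.log(np.sum(np.exp(logits - m), axis=-1, keepdims=True))
    return np.mean(lse.squeeze(-1) - np.sum(y * logits, axis=-1))

def analytic_grad(logits, y):
    # (softmax(logits) - y) / N, matching mean reduction
    m = np.max(logits, axis=-1, keepdims=True)
    e = np.exp(logits - m)
    p = e / np.sum(e, axis=-1, keepdims=True)
    return (p - y) / logits.shape[0]

logits = np.array([[2.0, 0.5, -1.0, 0.0],
                   [0.3, 1.2, 0.1, -0.5],
                   [3.0, 2.0, 0.0, -1.0]])
y = np.eye(4)[np.array([0, 1, 0])]

h = 1e-5  # finite-difference step (assumed)
numeric = np.zeros_like(logits)
for idx in np.ndindex(*logits.shape):
    up = logits.copy()
    down = logits.copy()
    up[idx] += h
    down[idx] -= h
    # Central difference approximates dL/dz at this single entry
    numeric[idx] = (mean_ce(up, y) - mean_ce(down, y)) / (2 * h)

max_err = np.max(np.abs(numeric - analytic_grad(logits, y)))
print("max abs difference:", max_err)
```

A maximum difference many orders of magnitude below the gradient entries themselves suggests the analytic formula and the loss agree.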
