# Day 14: Cross-Entropy, Softmax & Classification Gradients

**Building LLMs from Scratch** — Following Andrej Karpathy's makemore lectures.

---

## 1. Introduction

**Softmax** converts logits (unnormalized scores) into a probability distribution over classes. **Cross-entropy** measures how well predicted probabilities match the true target. Together they form the standard loss for classification.

**The math:**
- Softmax: $p_i = \frac{e^{z_i}}{\sum_j e^{z_j}}$
- Cross-entropy (for target class $y$): $\mathcal{L} = -\log(p_y)$

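**Numerical stability:** Raw $e^{z_i}$ overflows for large $z_i$. We subtract $\max(z)$ first: $p_i = \frac{e^{z_i - \max(z)}}{\sum_j e^{z_j - \max(z)}}$. This is mathematically equivalent, because the common factor $e^{-\max(z)}$ cancels between numerator and denominator. It is also numerically stable, because the largest exponent is now $e^0 = 1$.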

**The gradient:** When softmax is followed by cross-entropy, the gradient w.r.t. logits simplifies beautifully to $\frac{\partial \mathcal{L}}{\partial z_i} = p_i - \mathbb{1}[i = y]$. That is: **prediction minus truth**.

## 2. Naive Softmax

Implement `softmax_naive(logits)` as `exp(logits) / sum(exp(logits))`. This works for small values. For large logits, `exp` overflows to `inf`, and `inf / inf` evaluates to `nan`.

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


def softmax_naive(logits):
    """Naive softmax: exp(logits) / sum(exp(logits)). Overflows for large logits."""
    exp_logits = torch.exp(logits)
    return exp_logits / exp_logits.sum()


# Works for small values
logits_small = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
probs_small = softmax_naive(logits_small)
print("Small logits [1, 2, 3]:")
print(f"  softmax_naive = {probs_small.tolist()}")
print(f"  sum = {probs_small.sum().item():.6f}")

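# Overflows for large values: in float32, exp(1000) = inf and inf / inf = nan.
# PyTorch does not raise here; it silently returns NaNs, so check for them explicitly.
logits_large = torch.tensor([1000.0, 1001.0, 1002.0], dtype=torch.float32)
probs_large = softmax_naive(logits_large)
print("\nLarge logits [1000, 1001, 1002]:")
print(f"  softmax_naive = {probs_large.tolist()}")
print(f"  contains NaN? {torch.isnan(probs_large).any().item()}")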
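Why does 1000 break? `exp` overflows once its argument exceeds the natural log of the largest finite value of the dtype. The small check below uses `torch.finfo` to show where that limit sits for float32 and float64. It also shows why the naive softmax returns NaN instead of raising an error. This is a sketch, and the probe values 88 and 89 are only illustrative choices.

```python
import math
import torch

# exp(x) overflows once x exceeds ln(largest finite value of the dtype)
f32_max = torch.finfo(torch.float32).max
f32_threshold = math.log(f32_max)                         # about 88.72
f64_threshold = math.log(torch.finfo(torch.float64).max)  # about 709.78
print(f"float32 overflow threshold: {f32_threshold:.2f}")
print(f"float64 overflow threshold: {f64_threshold:.2f}")

exp_x = torch.exp(torch.tensor([88.0, 89.0], dtype=torch.float32))
print(exp_x)  # first entry finite (~1.65e38), second entry inf

# inf / inf is NaN, which is what the naive softmax produces for large logits
big = torch.exp(torch.tensor([1000.0, 1001.0], dtype=torch.float32))
naive = big / big.sum()
print(naive)
```

Switching to float64 only moves the cliff from about 88.7 to about 709.8. Subtracting the max removes it altogether.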

## 3. Stable Softmax

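Implement `softmax_stable(logits)`: compute `logits - logits.max()` out of place (rather than `logits -= logits.max()`, which would modify the caller's tensor), then exponentiate and normalize. This handles large values. Note that `logits.max()` is a global max, so this version is meant for a single 1-D vector of logits.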

In [None]:
def softmax_stable(logits):
    """Numerically stable softmax: subtract max before exp."""
    logits = logits - logits.max()
    exp_logits = torch.exp(logits)
    return exp_logits / exp_logits.sum()


# Same result for small values
probs_stable_small = softmax_stable(logits_small)
print("Small logits [1, 2, 3]:")
print(f"  softmax_stable = {probs_stable_small.tolist()}")
print(f"  Match naive? {torch.allclose(probs_small, probs_stable_small)}")

# Now handles large values!
probs_stable_large = softmax_stable(logits_large)
print("\nLarge logits [1000, 1001, 1002]:")
print(f"  softmax_stable = {probs_stable_large.tolist()}")
print(f"  sum = {probs_stable_large.sum().item():.6f}")
print("  (After subtracting max, we get [-2, -1, 0] → same relative probs as [1, 2, 3])")
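`softmax_stable` takes a single global max, which is fine for one vector of logits. For a batch of shape `(N, C)`, the max has to be taken per row. The sketch below (`softmax_stable_batched` is a hypothetical name) does this with `max(dim=..., keepdim=True)` and checks the result against `torch.softmax`. It also confirms that adding a constant to every logit changes nothing, and shows what goes wrong with one global max for the whole batch.

```python
import torch

def softmax_stable_batched(logits, dim=-1):
    """Stable softmax along `dim`, subtracting each row's own max."""
    # keepdim=True keeps the reduced axis so the subtraction broadcasts per row
    shifted = logits - logits.max(dim=dim, keepdim=True).values
    exp_shifted = torch.exp(shifted)
    return exp_shifted / exp_shifted.sum(dim=dim, keepdim=True)


batch = torch.tensor([[1.0, 2.0, 3.0],
                      [1000.0, 1001.0, 1002.0],
                      [-5.0, 0.0, 5.0]])
ours = softmax_stable_batched(batch)
print("matches torch.softmax?", torch.allclose(ours, torch.softmax(batch, dim=-1)))
print("rows 0 and 1 identical?", torch.equal(ours[0], ours[1]))

# Shift invariance: adding the same constant to every logit leaves softmax unchanged
print("shift invariant?", torch.allclose(softmax_stable_batched(batch + 123.0), ours))

# Pitfall: one global max (1002) makes rows 0 and 2 underflow to 0 / 0 = nan
exp_global = torch.exp(batch - batch.max())
bad = exp_global / exp_global.sum(dim=-1, keepdim=True)
print(bad)
```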

## 4. Cross-Entropy Loss

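Implement it manually as `loss = -log(softmax(logits)[target_class])` and check it against `F.cross_entropy`. `F.cross_entropy` takes raw logits (not probabilities) of shape `(N, C)` and integer class indices of shape `(N,)`. By default it averages the loss over the batch. This is why a single example needs `unsqueeze(0)` and why the manual batch losses are compared via `.mean()`.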

In [None]:
def cross_entropy_manual(logits, target):
    """Cross-entropy: -log(softmax[target])."""
    probs = softmax_stable(logits)
    return -torch.log(probs[target])


# Single example
logits = torch.tensor([2.0, 1.0, 0.1], dtype=torch.float32)
target = 0  # true class is index 0

loss_manual = cross_entropy_manual(logits, target)
loss_torch = F.cross_entropy(logits.unsqueeze(0), torch.tensor([target]))

print(f"Logits: {logits.tolist()}")
print(f"Target class: {target}")
print(f"Manual CE loss: {loss_manual.item():.6f}")
print(f"F.cross_entropy: {loss_torch.item():.6f}")
print(f"Match? {torch.allclose(loss_manual, loss_torch)}")

# Batch version
logits_batch = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]], dtype=torch.float32)
targets_batch = torch.tensor([0, 1])

losses_manual = torch.stack([cross_entropy_manual(logits_batch[i], targets_batch[i]) for i in range(len(targets_batch))])
loss_batch_torch = F.cross_entropy(logits_batch, targets_batch)

print(f"\nBatch: manual mean = {losses_manual.mean().item():.6f}, F.cross_entropy = {loss_batch_torch.item():.6f}")
print(f"Match? {torch.allclose(losses_manual.mean(), loss_batch_torch)}")
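`cross_entropy_manual` takes the `log` of a probability, so it fails once that probability underflows to 0. Working in log space avoids this, because $-\log p_y = \log \sum_j e^{z_j} - z_y$ and no probability is ever formed. `F.cross_entropy` likewise works from log-probabilities (log-softmax followed by negative log-likelihood). Below is a minimal batched sketch of that idea (`cross_entropy_logsumexp` is a hypothetical name) built on `torch.logsumexp` and `gather`, followed by a confidently wrong example where the probability route breaks.

```python
import torch
import torch.nn.functional as F

def cross_entropy_logsumexp(logits, targets):
    """Mean cross-entropy for logits of shape (N, C), computed in log space.

    -log softmax(z)[y] = logsumexp(z) - z[y], so probabilities are never
    formed and cannot underflow to 0 before the log.
    """
    # Pick out z[y] for each row; gather needs the targets shaped (N, 1)
    target_logits = logits.gather(1, targets.unsqueeze(1)).squeeze(1)
    per_example = torch.logsumexp(logits, dim=1) - target_logits
    return per_example.mean()


logits_batch = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
targets_batch = torch.tensor([0, 1])
print(cross_entropy_logsumexp(logits_batch, targets_batch).item(),
      F.cross_entropy(logits_batch, targets_batch).item())

# Confidently wrong: p[0] = e^-200 underflows to exactly 0 in float32
extreme = torch.tensor([[0.0, 200.0]])
wrong = torch.tensor([0])
loss_via_probs = -torch.log(torch.softmax(extreme, dim=1)[0, 0])
loss_lse = cross_entropy_logsumexp(extreme, wrong)
print("via probabilities:", loss_via_probs.item())  # inf
print("via logsumexp:", loss_lse.item())            # 200.0
```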

## 5. Temperature Scaling

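Implement `softmax(logits / T)` for several temperatures $T > 0$. A small $T$ stretches the gaps between logits, so the distribution becomes peaked (confident); as $T \to 0$ it approaches a one-hot vector on the largest logit. A large $T$ shrinks the gaps and pushes the distribution toward uniform (uncertain). $T = 1$ recovers the ordinary softmax.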

In [None]:
def softmax_temperature(logits, T=1.0):
    """Softmax with temperature: softmax(logits / T)."""
    return softmax_stable(logits / T)


logits = torch.tensor([2.0, 1.0, 0.5, 0.1], dtype=torch.float32)
temperatures = [0.1, 0.5, 1.0, 2.0, 5.0]

fig, axes = plt.subplots(1, 5, figsize=(14, 4), sharey=True)

for ax, T in zip(axes, temperatures):
    probs = softmax_temperature(logits, T)
    ax.bar(range(len(probs)), probs.tolist(), color='steelblue', edgecolor='black')
    ax.set_xlabel('Class')
    ax.set_title(f'T = {T}')
    ax.set_xticks(range(len(probs)))

axes[0].set_ylabel('Probability')  # y-axis is shared, so label only the first panel

plt.suptitle('Temperature Scaling: Low T = peaked, High T = uniform')
plt.tight_layout()
plt.show()

print("Low T (0.1): nearly one-hot, very confident")
print("High T (5.0): nearly uniform, uncertain")
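To put a number on "peaked" versus "uniform", we can measure the entropy $H(p) = -\sum_i p_i \log p_i$ of the tempered distribution. The minimal sketch below reuses the same four logits and adds two extreme temperatures (0.01 and 100, chosen only for illustration). Entropy should rise with $T$, from near 0 (essentially one-hot) toward the uniform maximum $\log 4$.

```python
import math
import torch

def entropy(p):
    """Shannon entropy in nats; the clamp avoids log(0), and 0 * log(tiny) is 0."""
    return -(p * torch.log(p.clamp_min(1e-30))).sum()


logits = torch.tensor([2.0, 1.0, 0.5, 0.1])
temperatures = [0.01, 0.1, 0.5, 1.0, 2.0, 5.0, 100.0]
entropies = []
for T in temperatures:
    p = torch.softmax(logits / T, dim=-1)
    entropies.append(entropy(p).item())
    print(f"T = {T:>6}: max p = {p.max().item():.4f}, entropy = {entropies[-1]:.4f} nats")

print(f"uniform maximum ln(4) = {math.log(4):.4f} nats")
```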

## 6. The Gradient of Cross-Entropy + Softmax

Implement `dlogits = softmax(logits) - one_hot(target)` and verify it against PyTorch autograd, first for a single example and then for a batch.
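The derivation is short once the loss is written directly in terms of the logits. We use the softmax and cross-entropy definitions from the introduction (logits $z$, target class $y$):

$$
\begin{aligned}
\mathcal{L} &= -\log p_y = -\log \frac{e^{z_y}}{\sum_j e^{z_j}} = -z_y + \log \sum_j e^{z_j} \\
\frac{\partial \mathcal{L}}{\partial z_i} &= -\mathbb{1}[i = y] + \frac{e^{z_i}}{\sum_j e^{z_j}} = p_i - \mathbb{1}[i = y]
\end{aligned}
$$

For a batch of $N$ examples, `F.cross_entropy` uses `reduction='mean'` by default, so $\mathcal{L}_{\text{batch}} = \frac{1}{N}\sum_{n=1}^{N} \mathcal{L}_n$. Each row of the batch gradient is therefore $\frac{1}{N}\left(p^{(n)} - \text{one\_hot}(y_n)\right)$.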

In [None]:
def cross_entropy_softmax_gradient(logits, target):
    """
    Gradient of cross-entropy + softmax w.r.t. logits:
    dlogits = softmax(logits) - one_hot(target)
    """
    probs = softmax_stable(logits)
    one_hot = torch.zeros_like(logits)
    one_hot[target] = 1.0
    return probs - one_hot


# Manual gradient vs autograd
logits = torch.tensor([2.0, 1.0, 0.1], dtype=torch.float32, requires_grad=True)
target = 0

loss = F.cross_entropy(logits.unsqueeze(0), torch.tensor([target]))
loss.backward()

grad_autograd = logits.grad
grad_manual = cross_entropy_softmax_gradient(logits.detach(), target)

print("Logits:", logits.detach().tolist())
print("Target:", target)
print("Gradient (manual):", grad_manual.tolist())
print("Gradient (autograd):", grad_autograd.tolist())
print("Match?", torch.allclose(grad_manual, grad_autograd))

# Batch version
logits_batch = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]], dtype=torch.float32, requires_grad=True)
targets_batch = torch.tensor([0, 1])

loss_batch = F.cross_entropy(logits_batch, targets_batch)
loss_batch.backward()

grad_autograd_batch = logits_batch.grad
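# F.cross_entropy averages over the batch (reduction='mean'),
# so each per-example gradient must also be divided by the batch size N.
n = logits_batch.shape[0]
grad_manual_batch = torch.stack([
    cross_entropy_softmax_gradient(logits_batch[i].detach(), targets_batch[i].item())
    for i in range(n)
]) / n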

print("\nBatch gradients match?", torch.allclose(grad_manual_batch, grad_autograd_batch))
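The cell above builds the batch gradient one row at a time. The vectorized sketch below (`cross_entropy_grad_batched` is a hypothetical name) forms all rows at once with `F.one_hot`. It includes the division by $N$ that comes from `F.cross_entropy`'s default mean reduction, and checks the result against autograd on random data (the seed and shapes are arbitrary).

```python
import torch
import torch.nn.functional as F

def cross_entropy_grad_batched(logits, targets):
    """Gradient of F.cross_entropy(logits, targets) (mean reduction) w.r.t. logits.

    Each row is softmax(z) - one_hot(y); averaging over N examples divides by N.
    """
    n, num_classes = logits.shape
    probs = torch.softmax(logits, dim=1)
    one_hot = F.one_hot(targets, num_classes=num_classes).to(logits.dtype)
    return (probs - one_hot) / n


torch.manual_seed(0)
logits = torch.randn(8, 5, requires_grad=True)
targets = torch.randint(0, 5, (8,))
F.cross_entropy(logits, targets).backward()

grad_manual = cross_entropy_grad_batched(logits.detach(), targets)
print("max abs diff:", (grad_manual - logits.grad).abs().max().item())
print("match?", torch.allclose(grad_manual, logits.grad, atol=1e-6))
```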

## 7. Why This Gradient is Beautiful

The gradient of cross-entropy + softmax w.r.t. logits is simply:

$$\frac{\partial \mathcal{L}}{\partial z_i} = p_i - \mathbb{1}[i = y]$$

That is: **prediction minus truth**.

- For the **target class** $y$: gradient is $p_y - 1$ (negative). Gradient descent *increases* logit $z_y$ → higher $p_y$ → correct.
- For **non-target classes** $i \neq y$: gradient is $p_i - 0 = p_i$ (positive). Gradient descent *decreases* logits $z_i$ → lower $p_i$ → correct.
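Both bullets can be checked numerically. The small sketch below reuses the logits `[2.0, 1.0, 0.1]` from earlier but sets the target to class 1, so the model starts out wrong. The target entry of the gradient should be negative and the others positive. The entries also sum to zero, because $\sum_i p_i = 1$ and the one-hot vector also sums to 1. Finally, one gradient-descent step (step size 1.0, an arbitrary choice) should raise $p_y$ in this example.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.1])
target = 1  # the model currently favors class 0, so this prediction is wrong
probs = torch.softmax(logits, dim=0)
one_hot = torch.zeros_like(probs)
one_hot[target] = 1.0
grad = probs - one_hot  # prediction minus truth

others = torch.arange(len(logits)) != target  # boolean mask of non-target classes
print("grad:", grad.tolist())
print("target entry negative?", bool(grad[target] < 0))
print("other entries positive?", bool((grad[others] > 0).all()))
print("sum of gradient:", grad.sum().item())  # ~0 up to float rounding

# One gradient-descent step with step size 1.0 moves probability mass toward the target
new_probs = torch.softmax(logits - 1.0 * grad, dim=0)
print(f"p[target]: {probs[target].item():.4f} -> {new_probs[target].item():.4f}")
```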

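The size of the push tracks the error. On the target logit the gradient is $p_y - 1$. If the model is very wrong ($p_y$ small), the gradient is close to $-1$ and the update is large. If the model is already right ($p_y \approx 1$), the gradient is close to $0$ and the logits barely move. Each wrong class is pushed down by exactly the probability $p_i$ it was wrongly given. This is what makes the loss so effective for training: gradient descent pushes predictions toward the targets in proportion to how far off they are.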
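The self-correcting behavior shows up in a short simulation. This minimal sketch assumes a single example with 3 classes, a uniform starting prediction, and plain gradient descent with an illustrative step size of 1.0, using the closed-form gradient directly. We expect $p_y$ to climb toward 1 while the target gradient magnitude $1 - p_y$ shrinks, so the steps get smaller as the prediction improves.

```python
import torch

logits = torch.zeros(3)  # uniform start: p = [1/3, 1/3, 1/3]
target = 0
lr = 1.0                 # illustrative step size
one_hot = torch.zeros(3)
one_hot[target] = 1.0

p_history, grad_history = [], []
for step in range(20):
    probs = torch.softmax(logits, dim=0)
    grad = probs - one_hot                          # closed-form dL/dz
    p_history.append(probs[target].item())
    grad_history.append(grad[target].abs().item())  # equals 1 - p_y
    logits = logits - lr * grad                     # plain gradient descent step

for step in (0, 1, 5, 19):
    print(f"step {step:2d}: p_y = {p_history[step]:.4f}, |dL/dz_y| = {grad_history[step]:.4f}")
```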

---

**Building LLMs from Scratch** — [Day 14: Cross-Entropy, Softmax & Classification Gradients](https://omkarray.com/llm-day14.html) | [← Prev](llm_day13.ipynb) | [Next →](llm_day15.ipynb)