# Softmax, Stable Softmax and Log-Softmax

### What & Why (brief)

Softmax turns a vector of scores **z** into probabilities:

$$
\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_j e^{z_j}}
$$

---

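**Numerical issue:**  
\( e^z \) overflows for large \( z \). The largest finite float64 is about \( e^{709.78} \approx 1.8 \times 10^{308} \) (for float32, about \( e^{88.72} \)). Beyond that, `exp` returns `inf` and the normalization becomes `inf / inf = NaN`.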

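**Fix:** subtract the max before exponentiating. The probabilities don't change, because the common factor \( e^{-\max(z)} \) cancels between numerator and denominator. The largest exponent becomes \( e^0 = 1 \), so nothing overflows: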

$$
\text{softmax}(z_i) = \frac{e^{z_i - \max(z)}}{\sum_j e^{z_j - \max(z)}}
$$
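
Here is a quick numerical check of that claim. It is a sketch assuming NumPy. Subtracting the same constant \( c \) from every score multiplies numerator and denominator by \( e^{-c} \), so any \( c \) gives the same probabilities, not just the max. `softmax_shifted` is a hypothetical helper written only for this check.

```python
import numpy as np

def softmax_shifted(z, c):
    # Hypothetical helper: softmax computed after subtracting an arbitrary constant c
    e = np.exp(z - c)
    return e / e.sum()

z = np.array([1.0, 2.0, 3.0])

p_plain = softmax_shifted(z, 0.0)      # plain formula, no shift
p_max = softmax_shifted(z, z.max())    # the max-shift used for stability
p_other = softmax_shifted(z, -5.0)     # any other constant gives the same result

print(np.allclose(p_plain, p_max), np.allclose(p_plain, p_other))  # True True
```

The max is the preferred constant only because it guarantees every exponent is \( \le 0 \).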

---

**Log-Softmax:**  

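$$
\log \text{softmax}(z_i) = z_i - \log \left(\sum_j e^{z_j}\right)
$$

The second term is the *log-sum-exp* (LSE). For stability it is computed with the same max shift:

$$
\log \sum_j e^{z_j} = \max(z) + \log \sum_j e^{z_j - \max(z)}
$$

Computing log-softmax directly, rather than `log(softmax(z))`, also avoids taking the log of a probability that has underflowed to 0.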
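
The following NumPy sketch shows why the direct form matters. It compares `log(softmax(z))` with the log-sum-exp formula on the scores `[0, 1000]`, which are just an illustrative choice. The stable softmax itself is fine, but its first probability, \( e^{-1000} \), underflows to `0.0`. Taking the log afterwards therefore yields `-inf`, while the direct form keeps the exact value \( -1000 \).

```python
import numpy as np

z = np.array([0.0, 1000.0])   # illustrative scores with a large gap

# Route 1: stable softmax, then log. exp(-1000) underflows to 0.0.
e = np.exp(z - z.max())
p = e / e.sum()
with np.errstate(divide="ignore"):   # silence the log(0) warning
    log_p_via_softmax = np.log(p)    # first entry is -inf

# Route 2: log-softmax via log-sum-exp; the tiny term only appears inside a sum
m = z.max()
lse = m + np.log(np.sum(np.exp(z - m)))
log_p_direct = z - lse               # first entry is -1000.0

print(log_p_via_softmax)
print(log_p_direct)
```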


In [1]:
```python
import numpy as np

def softmax_naive(x):
    # Direct formula: overflows for large x (exp -> inf, inf / inf -> nan)
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x)

def stable_softmax(x, axis=-1):
    # Shift by the max for stability; keepdims=True keeps shapes broadcastable
    x_shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x_shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def log_softmax(x, axis=-1):
    # Shifted scores minus the log-sum-exp of the shifted scores
    x_shifted = x - np.max(x, axis=axis, keepdims=True)
    return x_shifted - np.log(np.sum(np.exp(x_shifted), axis=axis, keepdims=True))
```

In [2]:
```python
x = np.array([1.0, 2.0, 3.0])

print(f'Softmax Naive: {softmax_naive(x)}\n')
print(f'Stable Softmax: {stable_softmax(x)}\n')
print(f'Log Softmax: {log_softmax(x)}')
```

```text
Softmax Naive: [0.09003057 0.24472847 0.66524096]

Stable Softmax: [0.09003057 0.24472847 0.66524096]

Log Softmax: [-2.40760596 -1.40760596 -0.40760596]
```

In [3]:
```python
x_big = np.array([1000.0, 1001.0, 1002.0])

print("Softmax naive:", softmax_naive(x_big))   # overflows (RuntimeWarning) -> NaN
print("Softmax stable:", stable_softmax(x_big)) # works fine
print("Log-Softmax:", log_softmax(x_big))       # works fine
```

```text
Softmax naive: [nan nan nan]
Softmax stable: [0.09003057 0.24472847 0.66524096]
Log-Softmax: [-2.40760596 -1.40760596 -0.40760596]
```
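
The `axis`/`keepdims` arguments are what let `stable_softmax` and `log_softmax` work row-wise on a batch. This is a minimal sketch that repeats the two definitions so the block runs on its own. It applies them to a 2-D array of scores (one row per sample) and checks three things: each row sums to 1, `exp(log_softmax)` recovers the probabilities, and the two rows agree because they differ only by a constant shift.

```python
import numpy as np

def stable_softmax(x, axis=-1):
    # keepdims=True keeps the reduced axis (shape (N, 1) here), so it broadcasts against x
    x_shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x_shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def log_softmax(x, axis=-1):
    x_shifted = x - np.max(x, axis=axis, keepdims=True)
    return x_shifted - np.log(np.sum(np.exp(x_shifted), axis=axis, keepdims=True))

batch = np.array([[1.0, 2.0, 3.0],
                  [1000.0, 1001.0, 1002.0]])   # shape (2, 3): 2 samples, 3 classes

probs = stable_softmax(batch, axis=1)
log_probs = log_softmax(batch, axis=1)

print(probs.sum(axis=1))                       # approximately [1. 1.]
print(np.allclose(np.exp(log_probs), probs))   # True
print(np.allclose(probs[0], probs[1]))         # True
```

Without `keepdims=True`, `np.max(batch, axis=1)` would have shape `(2,)`, which cannot broadcast against `(2, 3)` and raises a `ValueError`. For a square batch it would broadcast silently along the wrong axis.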


# PyTorch Version

In [4]:
```python
import torch

def softmax_naive(x, dim=-1):
    # Direct formula: torch.exp overflows to inf for large inputs
    exp_x = torch.exp(x)
    return exp_x / torch.sum(exp_x, dim=dim, keepdim=True)

def softmax_stable(x, dim=-1):
    # torch.max with dim returns (values, indices); keep only the values
    x_shifted = x - torch.max(x, dim=dim, keepdim=True).values
    exp_x = torch.exp(x_shifted)
    return exp_x / torch.sum(exp_x, dim=dim, keepdim=True)

def log_softmax(x, dim=-1):
    # Shifted scores minus the log-sum-exp of the shifted scores
    x_shifted = x - torch.max(x, dim=dim, keepdim=True).values
    log_sum_exp = torch.log(torch.sum(torch.exp(x_shifted), dim=dim, keepdim=True))
    return x_shifted - log_sum_exp
```

In [5]:
```python
# Example tensor: row 0 has small scores, row 1 has large ones
t = torch.tensor([[1.0, 2.0, 3.0],
                  [1000.0, 1001.0, 1002.0]])

print("Naive Softmax:\n", softmax_naive(t, dim=1))     # row 1 overflows -> nan
print("Stable Softmax:\n", softmax_stable(t, dim=1))   # safe
print("Log-Softmax:\n", log_softmax(t, dim=1))         # safe

# Compare to PyTorch built-ins
print("torch.softmax:\n", torch.softmax(t, dim=1))
print("torch.log_softmax:\n", torch.log_softmax(t, dim=1))
```

```text
Naive Softmax:
 tensor([[0.0900, 0.2447, 0.6652],
        [   nan,    nan,    nan]])
Stable Softmax:
 tensor([[0.0900, 0.2447, 0.6652],
        [0.0900, 0.2447, 0.6652]])
Log-Softmax:
 tensor([[-2.4076, -1.4076, -0.4076],
        [-2.4076, -1.4076, -0.4076]])
torch.softmax:
 tensor([[0.0900, 0.2447, 0.6652],
        [0.0900, 0.2447, 0.6652]])
torch.log_softmax:
 tensor([[-2.4076, -1.4076, -0.4076],
        [-2.4076, -1.4076, -0.4076]])
```
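
As a cross-check, PyTorch also provides `torch.logsumexp`, so log-softmax can be sketched in one line as \( z - \text{LSE}(z) \). `log_softmax_lse` below is a hypothetical helper, not a PyTorch API. The sketch uses float64 on purpose. In float32, subtracting a log-sum-exp of about 1002 from scores near 1000 loses precision at roughly the \( 10^{-5} \) level. Subtracting the max first avoids that, which is why the functions above do it.

```python
import torch

def log_softmax_lse(x, dim=-1):
    # Hypothetical helper: log-softmax via PyTorch's built-in log-sum-exp
    return x - torch.logsumexp(x, dim=dim, keepdim=True)

# float64 keeps the comparison tight for the large-valued second row
t = torch.tensor([[1.0, 2.0, 3.0],
                  [1000.0, 1001.0, 1002.0]], dtype=torch.float64)

ours = log_softmax_lse(t, dim=1)
builtin = torch.log_softmax(t, dim=1)

print(torch.allclose(ours, builtin))                        # True
print(torch.allclose(ours.exp(), torch.softmax(t, dim=1)))  # True
```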


## Using `CrossEntropyLoss` in PyTorch

In neural networks, we usually output **raw logits** from the model and pass them directly into 
`torch.nn.CrossEntropyLoss`. 

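This is because `CrossEntropyLoss` internally:
1. Applies **log-softmax** (numerically stable) to the logits.
2. Picks the log-probability of the correct class for each sample.
3. Averages (default, `reduction='mean'`) or sums (`reduction='sum'`) the negative log-likelihoods → cross-entropy.

👉 So we don't need to apply `torch.softmax` or `torch.log_softmax` ourselves. Applying `torch.softmax` first is actually a bug: the loss would then apply log-softmax to probabilities and return a different value.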
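
The following is a minimal sketch of both points, using small dummy logits. `CrossEntropyLoss` on logits should equal `NLLLoss` applied to `log_softmax` output, which is the two-stage split described above. Feeding `torch.softmax` probabilities into the loss gives a different number, because softmax is then applied twice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Dummy logits (2 samples, 3 classes) and true class indices
logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 0.3]])
labels = torch.tensor([0, 2])

ce = nn.CrossEntropyLoss()(logits, labels)

# The same loss split into two stages: log-softmax, then negative log-likelihood
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), labels)

# Common mistake: softmax first, so the loss effectively applies softmax twice
wrong = nn.CrossEntropyLoss()(torch.softmax(logits, dim=1), labels)

print(torch.allclose(ce, nll))   # True
print(ce.item(), wrong.item())   # two different values
```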


In [6]:
```python
import torch
import torch.nn as nn

# Dummy logits (batch_size=2, num_classes=3)
logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 0.3]])

# Dummy labels (true class indices)
labels = torch.tensor([0, 2])

# Define criterion
criterion = nn.CrossEntropyLoss()

# Compute loss
loss = criterion(logits, labels)
print("Cross-Entropy Loss:", loss.item())
```

```text
Cross-Entropy Loss: 1.4185397624969482
```


## What happens inside `CrossEntropyLoss`

- **Step 1:** Applies `log_softmax(logits)`, computed stably via the max-shift (log-sum-exp) trick.  
- **Step 2:** Selects the log-probability of the correct class for each sample.  
- **Step 3:** Averages (or sums) the negative log-probabilities.  
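
The three steps can be reproduced by hand. This sketch uses the same dummy logits and labels as the `CrossEntropyLoss` example. Advanced indexing with `torch.arange` plays the role of Step 2; `log_probs.gather(1, labels.unsqueeze(1))` would work too.

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 0.3]])
labels = torch.tensor([0, 2])

# Step 1: stable log-probabilities for every class
log_probs = torch.log_softmax(logits, dim=1)             # shape (2, 3)

# Step 2: pick the log-probability of the true class in each row
picked = log_probs[torch.arange(len(labels)), labels]    # shape (2,)

# Step 3: negate and reduce (mean by default, or sum)
manual_mean = -picked.mean()
manual_sum = -picked.sum()

print(manual_mean.item())  # about 1.4185, matching the CrossEntropyLoss output
print(torch.allclose(manual_mean, nn.CrossEntropyLoss()(logits, labels)))                 # True
print(torch.allclose(manual_sum, nn.CrossEntropyLoss(reduction="sum")(logits, labels)))  # True
```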

---

### The formula

$$
\text{Loss} = -\frac{1}{N} \sum_{i=1}^N \log 
\left( \frac{e^{z_{i,y_i}}}{\sum_j e^{z_{i,j}}} \right)
$$

where:  

- \( z_{i,j} \) = logit for sample *i*, class *j*  
- \( y_i \) = true label for sample *i*  
- \( N \) = number of samples in the batch
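
The formula maps almost line-for-line onto NumPy. `cross_entropy` below is a hypothetical helper, not a library function. It computes the log-sum-exp with the max shift, so it stays finite even for very large logits.

```python
import numpy as np

def cross_entropy(z, y):
    # Hypothetical helper: mean cross-entropy from logits z (N, C) and integer labels y (N,)
    m = z.max(axis=1, keepdims=True)
    lse = m[:, 0] + np.log(np.exp(z - m).sum(axis=1))   # log sum_j e^{z_ij}, stabilized
    log_p_true = z[np.arange(len(y)), y] - lse          # log(e^{z_{i,y_i}} / sum_j e^{z_ij})
    return -log_p_true.mean()                           # -(1/N) * sum over samples

z = np.array([[2.0, 1.0, 0.1],
              [0.5, 2.5, 0.3]])
y = np.array([0, 2])

print(cross_entropy(z, y))         # about 1.41854
print(cross_entropy(z + 1000, y))  # same value: the loss is shift-invariant per row
```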
