In [12]:
import torch

# Seed the global RNG so "Restart Kernel -> Run All" reproduces the same A (and
# therefore every softmax output below). Change SEED to try a different input.
SEED = 0
torch.manual_seed(SEED)

print('torch softmax example')
A = torch.randn(2, 6)  # 2 rows of 6 logits; each row is softmax-normalized below
print(A)

torch softmax example
tensor([[-0.2735,  0.3391,  0.2336, -0.2055,  0.0382,  1.0012],
        [-0.0193,  0.2482,  0.1785, -0.7953,  0.5555,  0.8238]])


In [13]:

# Naive softmax: exponentiate every logit, then divide each row by its own sum.
# exp() overflows to inf for large logits (above roughly 88 in float32); the
# "safe" variant in the next cell guards against that.
A_exp = A.exp()
A_sum = A_exp.sum(dim=1, keepdim=True)  # keep the reduced axis so it broadcasts over each row
P = A_exp / A_sum
print(P)

tensor([[0.0951, 0.1754, 0.1579, 0.1017, 0.1298, 0.3401],
        [0.1237, 0.1616, 0.1507, 0.0569, 0.2197, 0.2874]])


In [14]:
# safe softmax implementation: subtracting each row's max makes every exponent <= 0,
# so exp() cannot overflow. The shift cancels in the ratio, so P is unchanged.
A_max = A.max(dim=1, keepdim=True).values
A_exp = (A - A_max).exp()
A_sum = A_exp.sum(dim=1, keepdim=True)
P = A_exp / A_sum
print(P)

tensor([[0.0951, 0.1754, 0.1579, 0.1017, 0.1298, 0.3401],
        [0.1237, 0.1616, 0.1507, 0.0569, 0.2197, 0.2874]])


# Online Softmax (3-pass vs 2-pass)

Below, each algorithm from the reference figure is written out as pseudocode, followed by its PyTorch implementation.

## Algorithm: 3‑pass Online Softmax

**Idea:** Compute the global max $m_N$ in pass 1, the sum of max-shifted exponentials $d_N = \sum_j e^{x_j - m_N}$ in pass 2, and the normalized probabilities in pass 3.

Let the input vector be $x_1, x_2, \dots, x_N$.

```
m_0 = -inf
d_0 = 0.0
for i = 1..N:
    m_i = max(m_{i-1}, x_i)         # pass 1: running max -> m_N is global max

for i = 1..N:
    d_i = d_{i-1} + exp(x_i - m_N)  # pass 2: sum of shifted exponentials -> d_N

for i = 1..N:
    a_i = exp(x_i - m_N) / d_N      # pass 3: probabilities
```

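This is exactly the standard numerically stable ("safe") softmax that subtracts the global max. Strictly speaking it is not *online*: pass 2 cannot start until pass 1 has seen every element, so the input is read three times.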

## Algorithm: 2‑pass Online Softmax

**Idea:** Fuse passes 1 and 2 of the 3-pass algorithm. In a single forward pass, maintain a *running* max together with a running sum that is *rescaled* whenever the max increases. Then compute the probabilities in pass 2.

```
m_0   = -inf
d'_0  = 0.0
for i = 1..N:
    m_i   = max(m_{i-1}, x_i)                         # update running max
    d'_i  = d'_{i-1} * exp(m_{i-1} - m_i) + exp(x_i - m_i)

for i = 1..N:
    a_i = exp(x_i - m_N) / d'_N
```

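The factor `exp(m_{i-1} - m_i)` rescales the partial sum, which was accumulated relative to the old max $m_{i-1}$, so that it is expressed relative to the new max $m_i$. When the max does not change, the factor is $e^0 = 1$. By induction, $d'_i = \sum_{j \le i} e^{x_j - m_i}$, so $d'_N = d_N$ and both algorithms produce identical probabilities. Caveat: while every element seen so far is $-\infty$ (e.g. a masked prefix), $m_{i-1} = m_i = -\infty$ and the exponents evaluate to $-\infty - (-\infty) = \text{NaN}$. A practical implementation must guard this case, for example by shifting by 0 instead of $m_i$ whenever $m_i = -\infty$.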

In [15]:
def online_softmax_3pass_vec(x: torch.Tensor) -> torch.Tensor:
    """Softmax of a 1-D tensor using the 3-pass (safe softmax) algorithm.

    The three explicit element-wise loops follow the algorithm step by step:
    find the global max, accumulate the max-shifted exponentials, then emit
    each probability.

    Args:
        x: 1-D tensor of logits (enforced by the assert).

    Returns:
        New tensor with the same shape, dtype and device as ``x``.
    """
    assert x.dim() == 1
    like = dict(device=x.device, dtype=x.dtype)

    # Pass 1: global maximum.
    peak = torch.tensor(-float("inf"), **like)
    for xi in x:
        peak = torch.maximum(peak, xi)

    # Pass 2: accumulate exp(x_i - max).
    denom = torch.tensor(0.0, **like)
    for xi in x:
        denom = denom + torch.exp(xi - peak)

    # Pass 3: probabilities.
    probs = torch.empty_like(x)
    for i, xi in enumerate(x):
        probs[i] = torch.exp(xi - peak) / denom
    return probs

# Apply the 1-D 3-pass softmax to each row of A, then stack the rows back into a matrix.
P = torch.stack([online_softmax_3pass_vec(row) for row in A])
print(P)

tensor([[0.0951, 0.1754, 0.1579, 0.1017, 0.1298, 0.3401],
        [0.1237, 0.1616, 0.1507, 0.0569, 0.2197, 0.2874]])


In [17]:
def online_softmax_2pass_vec(x: torch.Tensor) -> torch.Tensor:
    """Softmax of a 1-D tensor via the 2-pass online-normalizer algorithm.

    Pass 1 walks the elements once. It keeps a running max ``m`` and a running
    sum ``d`` of ``exp(x_j - m)``. Whenever the max grows, ``d`` is rescaled by
    ``exp(m_old - m_new)`` so it stays relative to the current max. Pass 2
    (vectorized) normalizes with the final ``m`` and ``d``.

    Args:
        x: 1-D tensor of logits (enforced by the assert). ``-inf`` entries
            (e.g. masked positions) are allowed anywhere, including at the
            start; they map to probability 0 when at least one entry is finite.

    Returns:
        New tensor with the same shape, dtype and device as ``x``.
    """
    assert x.dim() == 1
    m = torch.tensor(-float("inf"), device=x.device, dtype=x.dtype)
    d = torch.tensor(0.0, device=x.device, dtype=x.dtype)
    # Pass 1: maintain the running max and the rescaled running sum.
    for i in range(x.numel()):
        xi = x[i]
        m_new = torch.maximum(m, xi)
        # While every element seen so far is -inf, m_new is -inf too. Then
        # exp(m - m_new) and exp(xi - m_new) would evaluate exp(-inf + inf) = NaN
        # and poison d for good. Shift by 0 instead: both terms become
        # exp(-inf) = 0, so d stays 0 until the first finite element arrives.
        shift = torch.where(torch.isneginf(m_new), torch.zeros_like(m_new), m_new)
        d = d * torch.exp(m - shift) + torch.exp(xi - shift)
        m = m_new
    # Pass 2: probabilities from the final m_N and d_N.
    return torch.exp(x - m) / d

# Apply the 1-D 2-pass (online) softmax to each row of A, then stack the rows back into a matrix.
P = torch.stack([online_softmax_2pass_vec(row) for row in A])
print(P)

tensor([[0.0951, 0.1754, 0.1579, 0.1017, 0.1298, 0.3401],
        [0.1237, 0.1616, 0.1507, 0.0569, 0.2197, 0.2874]])


In [16]:
def online_softmax_3pass(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Softmax along ``dim`` using the 3-pass (safe softmax) algorithm.

    ``dim`` is first swapped to the last axis so each pass can walk it one
    column at a time over all rows at once. The result is swapped back before
    returning.

    Args:
        x: input tensor with at least one dimension.
        dim: axis to normalize over (default: last).

    Returns:
        Tensor with the same shape as ``x``, normalized along ``dim``.
    """
    xt = x.transpose(dim, -1).contiguous()
    *batch_shape, n_cols = xt.shape
    opts = dict(device=xt.device, dtype=xt.dtype)

    # Pass 1: per-row global max over the last axis.
    row_max = torch.full(batch_shape, -float("inf"), **opts)
    for col in range(n_cols):
        row_max = torch.maximum(row_max, xt[..., col])

    # Pass 2: per-row sum of max-shifted exponentials.
    row_sum = torch.zeros(batch_shape, **opts)
    for col in range(n_cols):
        row_sum = row_sum + torch.exp(xt[..., col] - row_max)

    # Pass 3: write the normalized probabilities column by column.
    out = torch.empty_like(xt)
    for col in range(n_cols):
        out[..., col] = torch.exp(xt[..., col] - row_max) / row_sum
    return out.transpose(-1, dim)

def online_softmax_2pass(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Softmax along ``dim`` via the 2-pass online-normalizer algorithm.

    ``dim`` is swapped to the last axis. Pass 1 then walks that axis once,
    keeping a per-row running max ``m`` and a running sum ``d`` that is
    rescaled by ``exp(m_old - m_new)`` whenever the max grows. Pass 2
    (vectorized) normalizes with the final ``m`` and ``d``, and the result is
    swapped back.

    Args:
        x: input tensor with at least one dimension. ``-inf`` entries (e.g.
            masked positions) are allowed anywhere in a row, including at the
            start; they map to probability 0 when the row has a finite entry.
        dim: axis to normalize over (default: last).

    Returns:
        Tensor with the same shape as ``x``, normalized along ``dim``.
    """
    x = x.transpose(dim, -1).contiguous()
    *prefix, N = x.shape

    m = torch.full(prefix, -float("inf"), device=x.device, dtype=x.dtype)
    d = torch.zeros(prefix, device=x.device, dtype=x.dtype)
    for i in range(N):
        xi = x[..., i]
        m_new = torch.maximum(m, xi)
        # In rows where every element seen so far is -inf, m_new is -inf and the
        # exponents would be -inf - (-inf) = NaN, poisoning d permanently.
        # Shift those rows by 0 instead: both terms become exp(-inf) = 0 and d
        # stays 0 until the row's first finite element arrives.
        shift = torch.where(torch.isneginf(m_new), torch.zeros_like(m_new), m_new)
        d = d * torch.exp(m - shift) + torch.exp(xi - shift)
        m = m_new

    y = torch.exp(x - m.unsqueeze(-1)) / d.unsqueeze(-1)
    return y.transpose(-1, dim)

p1 = online_softmax_3pass(A, dim=1)
p2 = online_softmax_2pass(A, dim=1)

print(p1)
print(p2)

# Verify both implementations against PyTorch's built-in softmax instead of
# comparing printed digits by eye; a mismatch stops Run All here.
reference = torch.softmax(A, dim=1)
assert torch.allclose(p1, reference, atol=1e-6), "3-pass softmax disagrees with torch.softmax"
assert torch.allclose(p2, reference, atol=1e-6), "2-pass softmax disagrees with torch.softmax"

tensor([[0.0951, 0.1754, 0.1579, 0.1017, 0.1298, 0.3401],
        [0.1237, 0.1616, 0.1507, 0.0569, 0.2197, 0.2874]])
tensor([[0.0951, 0.1754, 0.1579, 0.1017, 0.1298, 0.3401],
        [0.1237, 0.1616, 0.1507, 0.0569, 0.2197, 0.2874]])
