# Online softmax

> **note**: this tutorial assumes familiarity with tiled `matmul`. A dedicated tutorial is available in the `tutorial` folder of this repository.

A naive implementation of `softmax` requires several passes over the whole input vector.
Loading the vector from global memory (`GM`, i.e. the GPU DRAM) on each pass is by far the bottleneck (compared to the computation itself).

The `triton` `softmax` tutorial avoids multiple read/write operations on `GM` by assuming that the whole input vector is small enough to be loaded into shared memory (`SRAM`).

Below, we describe an approach for when this assumption doesn't hold, i.e. when the vector is too large for the `SRAM`. In a `transformer` model, the `softmax` is applied to each row of a matrix of shape `(sequence length, sequence length)`, and the `SRAM` limit for an `fp16` vector is around 128 tokens.

We will start the tutorial with a naive approach and optimize it.

## Problem setup



In [1]:
import torch

torch.random.manual_seed(456)

nb_rows, nb_cols = 4, 16

long_input_vec: torch.Tensor = torch.rand((nb_rows, nb_cols))


## Softmax computation

### Safe softmax

To avoid `FP16` or `FP32` overflow in the `softmax` computation, it is usual to subtract the maximum value of the input vector from each of its elements.
This operation does not change the final output; it only improves numerical stability.
This is sometimes called `safe softmax`.
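
As a minimal sketch of why the subtraction matters, the snippet below compares a direct computation with the safe variant on large values. The input values are an assumption chosen only to trigger an `FP32` overflow; they are not part of the tutorial data.

```python
import torch

# Illustrative input (assumption): large values that overflow exp() in fp32
x = torch.tensor([1000.0, 1001.0, 1002.0])

# Direct softmax: exp(1000.) is inf in fp32, and inf / inf gives nan
unsafe = torch.exp(x) / torch.exp(x).sum()

# Safe softmax: after subtracting the max, every exponent is <= 0
shifted = x - x.max()
safe = torch.exp(shifted) / torch.exp(shifted).sum()

print("unsafe:", unsafe)
print("safe:", safe)
```

The safe variant matches `torch.softmax` on the same input, while the direct one produces only `nan`.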

### Memory bottleneck

Computing `safe softmax` manually in `PyTorch` requires multiple passes over the whole input vector:

* one pass to find the maximum value
* one pass to apply exponential operation to each value (numerator) and sum them (denominator)
* one pass to perform the division `numerator / denominator`

> in `PyTorch`, because of the eager execution model, step 2 itself takes 2 passes (one for the exponential, one for the sum).

In [2]:
# torch softmax as a reference
expected_softmax = torch.softmax(long_input_vec, dim=1)

# 1st read: torch.max returns both values and indices, we only want the values
# [:, None] adds a trailing dimension to get a column tensor that broadcasts over each row
row_max = torch.max(long_input_vec, dim=1).values[:, None]
print("input row max\n", row_max)
# 2nd read
input_safe = long_input_vec - row_max
print("Below we reduce values amplitude, that's the safe part of safe softmax")
print("original 1st row input:\n", long_input_vec[0,:], "safe softmax input 1st row:\n", input_safe[0,:])

softmax_numerator = torch.exp(input_safe)
# 3rd read
softmax_denominator = torch.sum(softmax_numerator, dim=1)[:, None]
# 4th read
naive_softmax = softmax_numerator / softmax_denominator

assert torch.allclose(naive_softmax, expected_softmax)

input row max
 tensor([[0.9820],
        [0.8412],
        [0.9198],
        [0.9778]])
Below we reduce values amplitude, that's the safe part of safe softmax
original 1st row input:
 tensor([0.6815, 0.0039, 0.7451, 0.7946, 0.6127, 0.6803, 0.9820, 0.0019, 0.1609,
        0.5916, 0.6531, 0.8855, 0.7397, 0.0681, 0.3341, 0.3200]) safe softmax input 1st row:
 tensor([-0.3005, -0.9780, -0.2369, -0.1874, -0.3693, -0.3017,  0.0000, -0.9800,
        -0.8211, -0.3904, -0.3289, -0.0965, -0.2423, -0.9139, -0.6479, -0.6620])


## Online softmax

In their paper [Online normalizer calculation for softmax](https://arxiv.org/pdf/1805.02867.pdf), *M. Milakov and N. Gimelshein* describe an approach that makes parallelization possible by computing `softmax` progressively.
Basically, we load the input vector in small blocks (sized to fit in the `SRAM`) and compute 2 statistics in a single pass:

* the maximum value
* the denominator

What makes this notable is that computing the denominator normally requires knowing the maximum value of the whole vector beforehand.
At each step, our knowledge of the maximum value may change (we may meet a value larger than the previous maximum).
When that happens, we simply adjust the result computed at the previous step.

The adjustment relies on a rule of exponentiation: multiplying a base raised to one exponent by the same base raised to another exponent adds the exponents.
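
The rule can be written out as a short derivation. The symbols are notation introduced here for illustration, not taken from the paper: $m_{\text{old}}$ and $m_{\text{new}}$ are the running maximum before and after seeing a new value $x$, and $d_{\text{old}}$ is the denominator accumulated so far over the values $x_j$ already seen.

$$
d_{\text{old}} \cdot e^{m_{\text{old}} - m_{\text{new}}}
= \sum_j e^{x_j - m_{\text{old}}} \cdot e^{m_{\text{old}} - m_{\text{new}}}
= \sum_j e^{x_j - m_{\text{new}}}
$$

$$
d_{\text{new}} = d_{\text{old}} \cdot e^{m_{\text{old}} - m_{\text{new}}} + e^{x - m_{\text{new}}}
$$

When the maximum does not change, $m_{\text{old}} = m_{\text{new}}$ and the correction factor is simply $e^0 = 1$.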

In [3]:
online_softmax_simple = torch.zeros_like(long_input_vec)
rows, cols = long_input_vec.shape

for row in range(rows):
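    # start from -inf so that the first value always becomes the running maximum,
    # including when the row only contains negative values
    max_row = torch.tensor(float("-inf"))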
    softmax_denominator = 0.
    for col in range(cols):  # scalar level iteration
        val = long_input_vec[row, col]
        old_max_row = max_row
        max_row = max(old_max_row, val)
        # torch.exp(old_max_row - max_row) is the adjustment factor applied to the previous softmax_denominator:
        # after this multiplication, it is as if we had subtracted max_row from all values instead of old_max_row
        softmax_denominator = softmax_denominator * torch.exp(old_max_row - max_row) + torch.exp(val - max_row)

    # leverage our 2 statistics
    online_softmax_simple[row, :] = torch.exp(long_input_vec[row, :] - max_row) / softmax_denominator

assert torch.allclose(online_softmax_simple, expected_softmax)

The very same procedure extends to blocks without any change to the formula: instead of processing one scalar at a time, you process a whole vector.
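
A minimal sketch of that block-wise variant in plain `PyTorch` follows. The names `block_size`, `row_max`, `denominator` and `online_softmax_block` are assumptions made for this illustration, and the block width of 4 is arbitrary; a real kernel would size the block to fit in the `SRAM`.

```python
import torch

torch.random.manual_seed(456)
x = torch.rand((4, 16))  # (nb_rows, nb_cols), small example input
block_size = 4  # hypothetical block width, chosen only for illustration

rows, cols = x.shape
# running statistics, one per row, kept as column tensors for broadcasting
row_max = torch.full((rows, 1), float("-inf"))
denominator = torch.zeros((rows, 1))

for start in range(0, cols, block_size):
    block = x[:, start:start + block_size]  # the only read of this block in the pass
    block_max = block.max(dim=1, keepdim=True).values
    new_max = torch.maximum(row_max, block_max)
    # rescale the partial sum to the new maximum, then add this block's contribution
    block_sum = torch.exp(block - new_max).sum(dim=1, keepdim=True)
    denominator = denominator * torch.exp(row_max - new_max) + block_sum
    row_max = new_max

# leverage the 2 statistics to produce the final output
online_softmax_block = torch.exp(x - row_max) / denominator

assert torch.allclose(online_softmax_block, torch.softmax(x, dim=1))
```

The update is the same as in the scalar loop; the only difference is that each block is first reduced to its own maximum and its own sum of exponentials.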