# Computation of an online softmax

A naive implementation of `softmax` requires several passes over the whole input vector.
Loading the vector from global memory (GPU DRAM) at each step is by far the bottleneck of the operation.

The `softmax` `triton` tutorial assumes that the whole vector is small enough to be loaded in shared memory (`SRAM`).
That assumption removes the issue, but it does not help when it does not hold: what should we do when the vector is too large for the `SRAM`?

In a `transformer` model, the `softmax` is applied to each row of a matrix of shape `(sequence length, sequence length)`.

We will start with a naive approach and optimize it until we reach the flash attention approach.


## Problem setup

We name the axes as in GEMM: M, N and K.
For more information: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html

In [1]:
import numpy as np
from scipy.special import softmax

np.random.seed(456)

M, N, K = 12, 4, 16  # for simplification M and K are multiples of N
block_size_M, block_size_K = N, N

long_input_vec = np.random.random((M, K))
small_vec = np.random.random((K, N))

## Softmax computation

### Safe softmax

To avoid `FP16` or `FP32` overflow in the `softmax` computation, it is usual to subtract from the input vector its maximum value.
This operation does not change the final output; it only improves numerical stability.
This is sometimes called `safe softmax` computation.
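
As a minimal sketch of why the subtraction matters, the snippet below compares both variants on a hand-picked `float32` vector (the values are an assumption chosen only to trigger the overflow, they are not part of the setup above):

```python
import numpy as np

# Hypothetical input, chosen so that exp() overflows in float32.
x = np.array([1000.0, 1001.0, 1002.0], dtype=np.float32)

# Naive softmax: exp(1000) overflows to inf in float32, and inf / inf gives nan.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(x) / np.sum(np.exp(x))

# Safe softmax: after the shift the largest exponent is exp(0) = 1, nothing overflows.
shifted = x - np.max(x)
safe = np.exp(shifted) / np.sum(np.exp(shifted))

print(naive)  # every entry is nan
print(safe)   # a valid probability vector
```

The shift is free from a mathematical point of view: the factor `exp(-max)` appears in both the numerator and the denominator and cancels out.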

### Memory bottleneck

Computing `safe softmax` in `PyTorch` requires multiple passes over the whole input vector if done manually:

* one pass to find the maximum value
* one pass to apply exponential operation to each value (numerator) and sum them (denominator)
* one pass to perform the division `numerator / denominator`

> Because of the eager execution model, step 2 requires 2 passes in `PyTorch`.

In [2]:
expected_softmax = softmax(long_input_vec, axis=1)
expected_attention = expected_softmax @ small_vec

# 1st read
row_max = np.max(long_input_vec, axis=1)[:, None]
# 2nd read
input_safe = long_input_vec - row_max
softmax_numerator = np.exp(input_safe)
# 3rd read
softmax_denominator = np.sum(softmax_numerator, axis=1)[:, None]
# 4th read
naive_softmax = softmax_numerator / softmax_denominator
# final matmul (another read / write)
matmul_result = naive_softmax @ small_vec

assert np.allclose(naive_softmax, expected_softmax)
assert np.allclose(matmul_result, expected_attention)

## Online softmax

In their paper [Online normalizer calculation for softmax](https://arxiv.org/pdf/1805.02867.pdf), *M. Milakov et al.* describe an approach that makes parallelization possible by computing `softmax` progressively.
In short, we load the input vector in small blocks (sized to fit in the `SRAM`) and compute 2 statistics in a single pass:

* the maximum value
* the denominator

The difficulty is that computing the denominator normally requires knowing the maximum value of the whole vector beforehand.
Here, our knowledge of the maximum value may evolve at each step (we may meet a value bigger than our previous maximum).
When that happens, we just adjust the result computed at the previous steps.

The adjustment procedure is based on rules of exponentiation: when multiplying a base raised to one exponent by the same base raised to another exponent, the exponents add.
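
Written out for the denominator, as a sketch: the symbols below are notation introduced here for illustration, with $m$ the maximum of the first $k$ values, $d = \sum_{j \le k} e^{x_j - m}$ the denominator computed so far, and $m' = \max(m, x_{k+1})$ the maximum after reading one more value.

$$
d' = \sum_{j \le k+1} e^{x_j - m'} = e^{m - m'} \sum_{j \le k} e^{x_j - m} + e^{x_{k+1} - m'} = d \, e^{m - m'} + e^{x_{k+1} - m'}
$$

The factor $e^{m - m'}$ is the only correction needed: it equals 1 when the maximum did not change, and is smaller than 1 otherwise.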

In [3]:
online_softmax_simple = np.zeros_like(long_input_vec)
rows, cols = long_input_vec.shape

for row in range(rows):
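    # start from -inf so that the running max is always a value of the row, even when all inputs are negative
    max_row = -np.inf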
    softmax_denominator = 0.
    for col in range(cols):
        val = long_input_vec[row, col]
        old_max_row = max_row
        max_row = max(old_max_row, val)
        # np.exp(old_max_row - max_row) is the adjustment factor of our previous softmax_denominator,
        # after this multiplication it is as if we had subtracted max_row from all values instead of old_max_row
        softmax_denominator = softmax_denominator * np.exp(old_max_row - max_row) + np.exp(val - max_row)

    # leverage our 2 statistics
    online_softmax_simple[row, :] = np.exp(long_input_vec[row, :] - max_row) / softmax_denominator

assert np.allclose(online_softmax_simple, expected_softmax)
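
The loop above reads one value at a time. The same update also works when the input is read block by block, which is closer to what the text describes. The function below is a minimal sketch under that assumption: `online_softmax_blocked` and its `block_size` parameter are hypothetical names introduced here, and the function handles a single 1D vector.

```python
import numpy as np


def online_softmax_blocked(x, block_size):
    """Softmax of a 1D array, collecting the statistics block by block (illustrative helper)."""
    running_max = -np.inf
    running_denominator = 0.0
    # first pass: collect the 2 statistics, one block at a time
    for start in range(0, x.shape[0], block_size):
        block = x[start:start + block_size]
        new_max = max(running_max, np.max(block))
        # rescale the past denominator, then add the contribution of the current block
        running_denominator = (
            running_denominator * np.exp(running_max - new_max)
            + np.sum(np.exp(block - new_max))
        )
        running_max = new_max
    # second pass: produce the output from the 2 statistics
    return np.exp(x - running_max) / running_denominator


rng = np.random.default_rng(0)
x = rng.normal(size=16) * 10.0  # random data, includes negative values
reference = np.exp(x - x.max()) / np.sum(np.exp(x - x.max()))
assert np.allclose(online_softmax_blocked(x, block_size=4), reference)
```

The block size only changes how the input is traversed, not the result, so it can be chosen to match the amount of fast memory available.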

## Flash attention trick

`Online softmax` limits the computation to 2 passes.
In [`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`](https://arxiv.org/pdf/2205.14135.pdf), there is only one pass: the `softmax` itself is never materialized, which is the key to the low memory footprint.

The trick is that in the attention mechanism, the `softmax` is followed by a `matmul` with the `V` matrix.
The `V` matrix has few columns (<= 128), so the output has few columns too.

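Therefore, we do not need to store the whole softmax in memory: we can compute it progressively, do the `matmul` block by block, and adjust the output of the `matmul` afterwards. This works because the adjustment is a per-row scalar factor, and scaling a row of the softmax block before the `matmul` gives the same result as scaling the matching row of the output after it.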
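
The adjustment of the `matmul` output relies on a simple property of row scaling, which can be checked numerically. The snippet below is a minimal illustration with made-up shapes and random data; `P`, `V` and `scale` are placeholder names introduced here.

```python
import numpy as np

rng = np.random.default_rng(0)
P = rng.random((4, 6))      # stands for a block of softmax numerators
V = rng.random((6, 3))      # stands for the matching block of V
scale = rng.random((4, 1))  # one adjustment factor per row

# scaling the rows before or after the matmul gives the same output
scaled_before = (scale * P) @ V
scaled_after = scale * (P @ V)
assert np.allclose(scaled_before, scaled_after)
```

This is why a correction discovered late (a new row maximum) can be applied to the small accumulated output instead of the large softmax matrix.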

We can leverage the same approach as explained in the chained matmul tutorial: we keep the output of the `matmul` in `SRAM`, avoiding a read / write from global memory for each block. That's the trick.
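
For one block of rows $i$ and one block of columns $j$, the update can be sketched as follows. The notation is an assumption chosen to mirror the variable names of the code cell that follows: $\tilde{m}_{ij}$, $\tilde{P}_{ij}$ and $\tilde{\ell}_{ij}$ are the row max, numerator and denominator computed from the current block alone, while $m_i$, $\ell_i$ and $O_i$ are the statistics and output accumulated over the previous blocks.

$$
m_i^{new} = \max(m_i, \tilde{m}_{ij}), \qquad
\ell_i^{new} = e^{m_i - m_i^{new}} \, \ell_i + e^{\tilde{m}_{ij} - m_i^{new}} \, \tilde{\ell}_{ij}
$$

$$
O_i \leftarrow \frac{\ell_i \, e^{m_i - m_i^{new}} \, O_i + e^{\tilde{m}_{ij} - m_i^{new}} \, \tilde{P}_{ij} V_j}{\ell_i^{new}}
$$

Multiplying the old $O_i$ by $\ell_i$ undoes its previous normalization, the exponential factors bring both terms to the new maximum, and the division by $\ell_i^{new}$ normalizes the result again.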

In [4]:
online_attention = np.zeros((M, N))

for block_start_M in range(0, M, block_size_M):
    block_end_M = block_start_M + block_size_M
    # init some variables required at the row level
    # ("line N" in the comments below appears to refer to the line numbers of Algorithm 1 in the FlashAttention paper)
    # line 4, mi will store the row max (computed progressively, block after block)
    mi = np.full((block_size_M, 1), -np.inf)
    # line 4, li will store the denominator of the softmax
    li = np.zeros((block_size_M, 1))
    # load from global memory, Oi contains the matmul result
    Oi = online_attention[block_start_M:block_end_M, :]
    for block_start_K in range(0, K, block_size_K):
        block_end_K = block_start_K + block_size_K
        # load a block from input tensor
        block = long_input_vec[block_start_M:block_end_M, block_start_K:block_end_K]
        # line 6, load a block from matmul input tensor
        Vj = small_vec[block_start_K:block_end_K, :]
        # line 10, find row max of the block (and only the block)
        mij_hat = np.max(block, axis=1)[:, None]
        # line 10, compute the softmax numerator as if we only had the data from this block (and nothing before or after)
        pij_hat = np.exp(block - mij_hat)
        # line 10, compute the denominator as if we only had the data from this block (and nothing before or after)
        lij_hat = np.sum(pij_hat, axis=1)[:, None]

        # line 11, find row max over the current block and all the previous ones we have visited
        mi_new = np.maximum(mi, mij_hat)

        # line 11, adjusting factor leveraging the rule of exponentiation
        li_new = np.exp(mi - mi_new) * li + np.exp(mij_hat - mi_new) * lij_hat

        # line 12, first part before the "+" is the adjustment of the past blocks
        # second part after the plus is the incorporation of the information from the current block and the matmul (done in steps)
        Oi = (li * np.exp(mi - mi_new) * Oi / li_new) + (np.exp(mij_hat - mi_new) * pij_hat / li_new) @ Vj

        # line 13
        mi = mi_new
        li = li_new
    # save to global memory
    online_attention[block_start_M:block_end_M, :] = Oi


assert np.allclose(online_attention, expected_attention)
