# Computation of an online softmax

A naive implementation of `softmax` makes several passes over the whole input: one to find the maximum, one to sum the exponentials, and one to normalize.
Each pass reads the data from global memory (GPU DRAM), so memory access becomes the bottleneck.

The `softmax` tutorial from `triton` assumes that the whole vector is small enough to be loaded into SRAM.
That removes the data access bottleneck, but it doesn't tell us what to do when the vector is too large.

In a `transformer` model, `softmax` is applied to a matrix of shape `(sequence length, sequence length)`, so each row grows with the sequence length.

In their paper [Online normalizer calculation for softmax](https://arxiv.org/pdf/1805.02867.pdf), *M. Milakov et al.* describe an approach that makes parallelization possible: softmax is computed in small blocks, and the running statistics are updated progressively as new data arrives. It's a key part of the solution we will develop below.
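As a reading aid, the single-pass update at the heart of the paper can be written as a recurrence over the elements $x_1, \dots, x_V$ of one row. The symbols $m_j$ (running max) and $d_j$ (running denominator) are notation chosen here for illustration:

$$
\begin{aligned}
m_0 &= -\infty, \qquad d_0 = 0 \\
m_j &= \max(m_{j-1},\, x_j) \\
d_j &= d_{j-1}\, e^{\,m_{j-1} - m_j} + e^{\,x_j - m_j} \\
\mathrm{softmax}(x)_i &= \frac{e^{\,x_i - m_V}}{d_V}
\end{aligned}
$$

The factor $e^{\,m_{j-1} - m_j}$ rescales the sum accumulated so far whenever a new element raises the max, which is why the max and the denominator can be computed in the same pass.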

FWIW, an existing implementation of `online softmax`: https://github.com/jenkspt/online-softmax-jax (not tested)

## The original limitations

Computing softmax requires two quantities that depend on the whole vector:
- the denominator, which is the sum of the exponentials of all vector elements;
- the maximum value of the vector: to avoid overflow with `FP16` or `FP32` numbers, it is usual to subtract it from each element before applying the element-wise exponential.
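The overflow mentioned in the second point is easy to reproduce. The sketch below uses a hypothetical three-element `FP16` vector (not data from this notebook) and compares the unshifted exponential with the max-subtracted one:

```python
import numpy as np

# Hypothetical toy input: exp(12) is about 162755, above the FP16 maximum of 65504.
x = np.array([10.0, 11.0, 12.0], dtype=np.float16)

with np.errstate(over="ignore", invalid="ignore"):
    naive_num = np.exp(x)                # last element overflows to inf
    naive = naive_num / naive_num.sum()  # inf / inf gives nan

safe_num = np.exp(x - x.max())           # largest term is exp(0) = 1, no overflow
safe = safe_num / safe_num.sum()

print("naive:", naive)
print("safe: ", safe)
```

The unshifted version produces a `nan`, while the shifted one stays finite and sums to 1 up to `FP16` rounding.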

## Problem setup


In [1]:
import numpy as np
from scipy.special import softmax

np.random.seed(456)

M, N, K = 12, 4, 16  # for simplification M and K are multiples of N
block_size_M, block_size_K = N, N

long_input_vec = np.random.random((M, K))
small_vec = np.random.random((K, N))

To avoid overflow in the `softmax` computation, it is usual to subtract the maximum value of the vector from each element.
Mathematically, this operation has no effect on the final result.
This is sometimes called `safe softmax`.
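The reason the shift leaves the result unchanged: for any constant $m$ (here the row max), the factor $e^{-m}$ appears in both the numerator and the denominator and cancels out:

$$
\frac{e^{\,x_i - m}}{\sum_j e^{\,x_j - m}}
= \frac{e^{-m}\, e^{\,x_i}}{e^{-m} \sum_j e^{\,x_j}}
= \frac{e^{\,x_i}}{\sum_j e^{\,x_j}}
$$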

In [2]:
expected_softmax = softmax(long_input_vec, axis=1)
expected_attention = expected_softmax @ small_vec

input_safe = long_input_vec - np.max(long_input_vec, axis=1)[:, None]

softmax_numerator = np.exp(input_safe)
softmax_denominator = np.sum(softmax_numerator, axis=1)[:, None]
naive_softmax = softmax_numerator / softmax_denominator
matmul_result = naive_softmax @ small_vec

assert np.allclose(naive_softmax, expected_softmax)
assert np.allclose(matmul_result, expected_attention)

Direct implementation of the paper's algorithm, one element at a time and without vectorization:

In [3]:
online_softmax_simple = np.zeros_like(long_input_vec)
rows, cols = long_input_vec.shape

for row in range(rows):
    max_row = -np.inf  # running max; starting at -inf keeps rows with only negative values safe
    softmax_denominator = 0.
    for col in range(cols):
        val = long_input_vec[row, col]
        old_max_row = max_row
        max_row = max(old_max_row, val)
        softmax_denominator = softmax_denominator * np.exp(old_max_row - max_row) + np.exp(val - max_row)

    for index, j in enumerate(long_input_vec[row, :]):
        online_softmax_simple[row, index] = np.exp(j - max_row) / softmax_denominator

assert np.allclose(online_softmax_simple, expected_softmax)
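The same rescaling trick is what makes block-wise processing possible: two chunks can be summarized independently and their summaries combined afterwards. Below is a minimal sketch of that idea; the helper names `partial_state` and `merge_states` and the test vector are hypothetical, chosen only for illustration.

```python
import numpy as np

def partial_state(chunk):
    """Return (max, denominator) for one chunk, computed on its own."""
    m = np.max(chunk)
    d = np.sum(np.exp(chunk - m))
    return m, d

def merge_states(state_a, state_b):
    """Combine two (max, denominator) pairs into the pair of the concatenated chunks."""
    m_a, d_a = state_a
    m_b, d_b = state_b
    m = max(m_a, m_b)
    # Rescale each partial denominator from its local max to the shared max.
    d = d_a * np.exp(m_a - m) + d_b * np.exp(m_b - m)
    return m, d

rng = np.random.default_rng(0)
x = rng.normal(size=16) * 10  # hypothetical vector with negative and large values

m, d = merge_states(partial_state(x[:8]), partial_state(x[8:]))
merged_softmax = np.exp(x - m) / d

reference = np.exp(x - x.max()) / np.sum(np.exp(x - x.max()))
assert np.allclose(merged_softmax, reference)
```

Because the merge only needs the two `(max, denominator)` pairs, the chunks themselves never have to be in fast memory at the same time.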

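Block-wise version, fused with the matrix multiplication by `small_vec`: each row block keeps a running max `mi`, a running denominator `li` and a running output `Oi`, all of which are rescaled whenever a new column block raises the max.

In [4]: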
online_softmax = np.zeros_like(long_input_vec)
online_attention = np.zeros((M, N))

for block_start_M in range(0, M, block_size_M):
    block_end_M = block_start_M + block_size_M
    # init some variables required at the row level
    # line 4, mi will store the row max (computed progressively, block after block)
    mi = np.full((block_size_M, 1), -np.inf)
    # line 4, li will store the denominator of the softmax
    li = np.zeros((block_size_M, 1))
    # Oi contains the matmul result
    Oi = online_attention[block_start_M:block_end_M, :]
    for block_start_K in range(0, K, block_size_K):
        block_end_K = block_start_K + block_size_K
        block = long_input_vec[block_start_M:block_end_M, block_start_K:block_end_K]
        # line 6
        Vj = small_vec[block_start_K:block_end_K, :]
        # line 10, the row max of the block
        mij_hat = np.max(block, axis=1)[:, None]
        # line 10
        pij_hat = np.exp(block - mij_hat)
        # line 10
        lij_hat = np.sum(pij_hat, axis=1)[:, None]

        # line 11, the row max so far
        mi_new = np.maximum(mi, mij_hat)

        # line 11
        li_new = np.exp(mi - mi_new) * li + np.exp(mij_hat - mi_new) * lij_hat

        # line 12
        Oi = (li * np.exp(mi - mi_new) * Oi / li_new) + (np.exp(mij_hat - mi_new) * pij_hat / li_new) @ Vj
        online_attention[block_start_M:block_end_M, :] = Oi

        # line 13
        mi = mi_new
        li = li_new


assert np.allclose(online_attention, expected_attention)
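To check the update rule outside the toy setup above, the column-block loop can be wrapped in a function. This is a sketch under assumptions: the function name `blockwise_softmax_matmul`, its parameters and the test inputs are hypothetical, it processes all rows at once instead of tiling them, and it relies on slicing to accept a last block shorter than `block_size`.

```python
import numpy as np

def blockwise_softmax_matmul(scores, values, block_size):
    """Compute softmax(scores, axis=1) @ values, reading one column block at a time."""
    n_rows, n_cols = scores.shape
    m = np.full((n_rows, 1), -np.inf)          # running row max
    l = np.zeros((n_rows, 1))                  # running softmax denominator
    out = np.zeros((n_rows, values.shape[1]))  # running output, normalized at every step
    for start in range(0, n_cols, block_size):
        block = scores[:, start:start + block_size]  # may be shorter on the last step
        v = values[start:start + block_size, :]
        m_block = np.max(block, axis=1, keepdims=True)
        p_block = np.exp(block - m_block)
        l_block = np.sum(p_block, axis=1, keepdims=True)
        m_new = np.maximum(m, m_block)
        l_new = np.exp(m - m_new) * l + np.exp(m_block - m_new) * l_block
        # Undo the old normalization, rescale to the new max, add the new block, renormalize.
        out = (l * np.exp(m - m_new) * out + np.exp(m_block - m_new) * (p_block @ v)) / l_new
        m, l = m_new, l_new
    return out

rng = np.random.default_rng(0)
scores = rng.normal(size=(6, 10)) * 50  # hypothetical inputs: large magnitude, both signs
values = rng.normal(size=(10, 3))

shifted = np.exp(scores - scores.max(axis=1, keepdims=True))
reference = (shifted / shifted.sum(axis=1, keepdims=True)) @ values

result = blockwise_softmax_matmul(scores, values, block_size=4)  # 10 is not a multiple of 4
assert np.allclose(result, reference)
```

The inputs are deliberately scaled up so that a run without the max subtraction would be at risk of overflow; the block-wise result still matches the safe reference.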
