# Computation of an online softmax

Computing a softmax normally requires the whole vector to be in memory.
This makes parallelization difficult when on-chip SRAM is limited.

In their paper [Online normalizer calculation for softmax](https://arxiv.org/pdf/1805.02867.pdf), M. Milakov and N. Gimelshein show an approach that makes parallelization possible: the softmax is computed in small blocks, and the running result is progressively corrected as new data arrives.

Existing implementation: https://github.com/jenkspt/online-softmax-jax (not tested)

## The original limitations

Softmax computation requires two quantities that depend on the whole vector:
- the denominator, which is the sum of the exponentials of all vector elements;
- the maximum value of the vector: to avoid overflow with `FP16` or `FP32` numbers, it is usual to subtract it from each element before applying the element-wise exponential.
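
To illustrate the overflow issue, here is a minimal sketch. The `logits` values are hypothetical and chosen only so that `exp` overflows in `float32`; they are not part of the notebook's data.

```python
import numpy as np

# Hypothetical inputs: exp(100) and exp(1000) both exceed the float32 range.
logits = np.array([10.0, 100.0, 1000.0], dtype=np.float32)

# Naive softmax: the exponentials overflow to inf, and inf / inf gives nan.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(logits) / np.sum(np.exp(logits))

# Shifted softmax: the largest exponent is exactly 0, so exp never overflows.
shifted = logits - np.max(logits)
stable = np.exp(shifted) / np.sum(np.exp(shifted))

print(naive)
print(stable)
```

Subtracting the maximum does not change the mathematical result, because the same factor cancels between numerator and denominator.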

## Problem setup


In [1]:
import numpy as np
from scipy.special import softmax

np.random.seed(456)

vec_len = 10
block_size = 2

data = np.random.random(vec_len)


In [2]:
data_small = data - np.max(data)
numerator = np.exp(data_small)
denominator = np.sum(numerator)
custom_softmax = numerator/denominator
assert np.allclose(custom_softmax, softmax(data))
print(custom_softmax)

[0.07322196 0.06720894 0.12500862 0.12815786 0.106737   0.1044651
 0.1384406  0.12197998 0.06843227 0.06634768]


Direct implementation of the paper's algorithm, element by element, without vectorization:

In [3]:
import math
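m = -math.inf  # running maximum; starts at -inf so that it always ends up equal to max(data)
d = 0.  # running normalizer: sum of exponentials relative to the running maximum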
online_softmax = np.zeros_like(data)

for j in data:
    old_m = m
    m = max(old_m, j)
    d = d * math.exp(old_m-m) + math.exp(j-m)

for index, j in enumerate(data):
    online_softmax[index] = math.exp(j - m)/ d

assert np.allclose(online_softmax, softmax(data))
print("d", d)
print("m", m)

d 7.223314665650643
m 0.8857019031149136
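
Why the update of `d` in the first loop is correct can be checked by induction. The notation below ($x_j$ for the elements, $m_j$ and $d_j$ for the values of `m` and `d` after $j$ elements) is introduced here for illustration. The loop computes:

$$
m_j = \max(m_{j-1}, x_j), \qquad d_j = d_{j-1}\, e^{m_{j-1} - m_j} + e^{x_j - m_j}
$$

Assume $d_{j-1} = \sum_{i=1}^{j-1} e^{x_i - m_{j-1}}$, which holds trivially for $d_0 = 0$ (an empty sum). Then:

$$
d_j = \sum_{i=1}^{j-1} e^{x_i - m_{j-1}}\, e^{m_{j-1} - m_j} + e^{x_j - m_j} = \sum_{i=1}^{j} e^{x_i - m_j}
$$

So after the last element, `d` is the softmax denominator expressed relative to the final value of `m`, even though that value was not known while the earlier terms were accumulated.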


Blockwise version: the maximum and the normalizer are updated one block at a time.

In [4]:
block_max = -np.inf
normalizer = 0.

for block_start in range(0, vec_len, block_size):
    block_end = block_start + block_size
    block_data = data[block_start:block_end]
    previous_block_max = block_max
    block_max = max(np.max(block_data), block_max)
    normalizer = normalizer * np.exp(previous_block_max - block_max) + np.sum(np.exp(block_data - block_max))

assert block_max == np.max(data)
assert np.allclose(normalizer, d), f"{normalizer}, {d}"
print("d",  normalizer)
print("m", block_max)

d 7.223314665650642
m 0.8857019031149136
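
The same rescaling can also combine partial results that were computed independently, which is what makes a parallel reduction possible. The sketch below is an illustration under that assumption: `partial_state` and `merge_states` are hypothetical helper names, and `values` is separate sample data, not the notebook's `data`.

```python
import numpy as np

def partial_state(block):
    """Return (max, normalizer) for one block, computed in isolation."""
    block_max = np.max(block)
    return block_max, np.sum(np.exp(block - block_max))

def merge_states(state_a, state_b):
    """Combine two (max, normalizer) pairs into the pair for their union."""
    max_a, norm_a = state_a
    max_b, norm_b = state_b
    new_max = max(max_a, max_b)
    # Each normalizer is rescaled from its own maximum to the shared one.
    new_norm = norm_a * np.exp(max_a - new_max) + norm_b * np.exp(max_b - new_max)
    return new_max, new_norm

rng = np.random.default_rng(0)
values = rng.normal(size=12)
blocks = np.split(values, 4)

# Each block could be processed by a different worker.
states = [partial_state(block) for block in blocks]

# Merge as a tree: pairs first, then the two intermediate results.
left = merge_states(states[0], states[1])
right = merge_states(states[2], states[3])
merged_max, merged_norm = merge_states(left, right)
```

Because the merge gives the same result whatever the grouping (up to floating-point rounding), the blocks do not have to be visited sequentially as in the loop above.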


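Full blockwise softmax: each block's output is written as soon as the block is seen, and the outputs already written are rescaled whenever the maximum or the normalizer changes.

In [5]: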
data_max = -np.inf  # running maximum over the blocks seen so far
l = 0.  # running normalizer (softmax denominator) relative to data_max
softmax_result = np.zeros_like(data)

for block_start in range(0, vec_len, block_size):
    block_end = block_start + block_size

    block_data = data[block_start:block_end]

    old_max = data_max
    block_data_max = np.max(block_data)
    data_max = max(block_data_max, old_max)

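    # block-local numerators and normalizer, relative to the block's own maximum
    block_data_f = np.exp(block_data - block_data_max)
    block_data_l = np.sum(block_data_f)
    # factor that moves block-local values onto the running maximum
    block_data_normalizer = np.exp(block_data_max - data_max)
    block_data_f_norm = block_data_f * block_data_normalizer

    # factor that moves values computed under the old maximum onto the new one
    previous_block_data_normalizer = np.exp(old_max - data_max)

    previous_l = l
    l = previous_l * previous_block_data_normalizer + block_data_l * block_data_normalizer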

    softmax_result[block_start:block_end] = block_data_f_norm / l

    # fix the earlier outputs: rescale their numerators to the new maximum (previous_block_data_normalizer),
    # then cancel the old denominator (previous_l) and divide by the new one (l)
    softmax_result[:block_start] = softmax_result[:block_start] * previous_block_data_normalizer * previous_l / l

assert np.allclose(softmax_result, softmax(data))
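
As a compact recap, the blockwise procedure can be wrapped in a function. This is a sketch, not the paper's reference code: `blockwise_softmax` is a hypothetical name, and the sample input is chosen so that it contains negative values and its length is not a multiple of the block size.

```python
import numpy as np

def blockwise_softmax(x, block_size):
    """Softmax of a 1-D array, visiting it one block at a time."""
    x = np.asarray(x, dtype=np.float64)
    running_max = -np.inf
    running_sum = 0.0
    out = np.empty_like(x)
    for start in range(0, x.size, block_size):
        block = x[start:start + block_size]
        new_max = max(running_max, np.max(block))
        # Rescale the old sum to the new maximum, then add this block's terms.
        rescale = np.exp(running_max - new_max)
        new_sum = running_sum * rescale + np.sum(np.exp(block - new_max))
        # Earlier outputs were normalized with the old maximum and old sum.
        out[:start] *= rescale * running_sum / new_sum
        out[start:start + block_size] = np.exp(block - new_max) / new_sum
        running_max, running_sum = new_max, new_sum
    return out

# Illustrative input: 7 elements with block_size=3, so the last block is shorter.
sample = np.array([-3.0, 0.5, 2.0, -1.0, 4.0, 1.5, 0.0])
result = blockwise_softmax(sample, block_size=3)

# Reference: the usual two-pass softmax over the whole vector.
reference = np.exp(sample - sample.max()) / np.sum(np.exp(sample - sample.max()))
```

Rewriting the earlier outputs at every step keeps the sketch close to the cell above. An implementation that cares about memory traffic would more likely keep only the running maximum and sum, and normalize once at the end.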