# Online Softmax

github: xiaodongguaAIGC

- softmax
- Safe Softmax
- online softmax
- block online softmax
- multi block online softmax
- batch online softmax
- multi block batch online softmax

In [1]:
import torch
import torch.nn.functional as F

In [2]:
# Shared 1-D example input: six values reused by the softmax variants below.
X = torch.tensor(
    [-0.3, 0.2, 0.5, 0.7, 0.1, 0.8],
)

## Softmax By Torch

In [4]:
# Reference answer: PyTorch's built-in softmax along dimension 0 of X.
X_softmax = F.softmax(X, 0)
print(X_softmax)

tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])


## Softmax by Hand

In [5]:
# Naive softmax: exponentiate every element, then divide by the total.
# No max is subtracted first, so this is the numerically unsafe form.
X_exp_sum = torch.exp(X).sum()
X_softmax_hand = X.exp() / X_exp_sum
print(X_softmax_hand)

tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])


## Safe Softmax by Hand

In [6]:
# Safe softmax: shift every element by the global max before exponentiating,
# so the largest exponent is exp(0) = 1. The shift cancels in the ratio.
X_max = torch.max(X)
X_exp_sum_sub_max = (X - X_max).exp().sum()
X_safe_softmax_hand = (X - X_max).exp() / X_exp_sum_sub_max
print(X_safe_softmax_hand)

tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])


## Online Softmax

In [None]:
# Treat X as a stream: X_pre holds everything seen so far, X[-1] is the
# newly arriving element.
X_pre = X[:-1]  # all elements except the last
print('input x')
print(X)
print(X_pre)
print(X[-1])

# State after the prefix (step t-1): running max and running sum l_N,
# where the sum is taken relative to that running max.
X_max_pre = torch.max(X_pre)
X_sum_pre = (X_pre - X_max_pre).exp().sum()

# Step t: fold in the new element.
# The max may grow, so the old sum is first rescaled to the new max:
#   l_(N+1) = l_N * exp(max_pre - max_cur) + exp(x_new - max_cur)
X_max_cur = torch.maximum(X_max_pre, X[-1])
X_sum_cur = X_sum_pre * (X_max_pre - X_max_cur).exp() + (X[-1] - X_max_cur).exp()

# Finally, normalise the full input with the updated max and sum.
X_online_softmax = (X - X_max_cur).exp() / X_sum_cur
print('online softmax result: ', X_online_softmax)

input x
tensor([-0.3000,  0.2000,  0.5000,  0.7000,  0.1000,  0.8000])
tensor([-0.3000,  0.2000,  0.5000,  0.7000,  0.1000])
tensor(0.8000)
online softmax result:  tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])


## Block Online Softmax

In [8]:
# Cut X into consecutive chunks of 3 elements along dimension 0.
X_block = torch.split(X, 3, dim=0)
print(X)
print(X_block)

tensor([-0.3000,  0.2000,  0.5000,  0.7000,  0.1000,  0.8000])
(tensor([-0.3000,  0.2000,  0.5000]), tensor([0.7000, 0.1000, 0.8000]))


In [9]:
# we parallel calculate  different block max & sum
X_block_0_max = X_block[0].max()
X_block_0_sum = torch.exp(X_block[0] - X_block_0_max).sum()

X_block_1_max = X_block[1].max()
X_block_1_sum = torch.exp(X_block[1] - X_block_1_max).sum()

In [10]:
# parallel online block update max & sum
X_max_global = torch.max(X_block_0_max, X_block_1_max) 
L_global = (X_block_0_sum * torch.exp(X_block_0_max - X_max_global) \
            + X_block_1_sum * torch.exp(X_block_1_max - X_max_global)) # block sum

X_block_online_softmax_parallel = torch.exp(X - X_max_global) / L_global
print(X_block_online_softmax_parallel)

tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])


In [11]:
# online block update max & sum
# updated version for multi-block, simpler version
X_block_1_max_update = torch.max(X_block_0_max, X_block_1_max) 
X_block_1_sum_update = X_block_0_sum * torch.exp(X_block_0_max - X_block_1_max_update) \
                     + torch.exp(X_block[1] - X_block_1_max_update).sum() # block sum

X_block_online_softmax = torch.exp(X - X_block_1_max_update) / X_block_1_sum_update
print(X_block_online_softmax)

tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])


## Multi Block Online Softmax

In [12]:
# Re-split X, this time into chunks of 2 elements along dimension 0.
X_block = torch.split(X, 2, dim=0)
print(X)
print(X_block)

tensor([-0.3000,  0.2000,  0.5000,  0.7000,  0.1000,  0.8000])
(tensor([-0.3000,  0.2000]), tensor([0.5000, 0.7000]), tensor([0.1000, 0.8000]))


In [30]:
# online multi-block update max & sum
M_old = torch.tensor([-100000.0])
L_old = torch.tensor([0.0])

# 在2.4我们实现了2个block的online softmax，我们可以拓展到多个块，并且使用for循环实现多block的更新
for i in range(len(X_block)):
    M = torch.max(X_block[i])
    M_new = torch.max(M, M_old) 
    
    L_new = L_old * torch.exp(M_old - M_new) \
            +  torch.exp(X_block[i] - M).sum() * torch.exp(M - M_new) 
    
    # use simplest format
    # L_new = L_old * torch.exp(M_old - M_new) \
    #         +  torch.exp(X_block[i] - M_new).sum() 
    
    M_old = M_new
    L_old = L_new

X_multi_block_online_softmax = torch.exp(X - M_old) / L_old
print(X_multi_block_online_softmax)
print(X_multi_block_online_softmax.sum())

tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])
tensor(1.0000)


## Batch Online Softmax

### Batch Online Softmax by Torch

In [31]:
# Random 4 x 6 batch; softmax is applied along dimension 1 (per row).
X_batch = torch.randn((4, 6))
print(X_batch)
X_batch_softmax = F.softmax(X_batch, 1)
print(X_batch_softmax)
# Sanity check: every row of probabilities should sum to 1.
X_batch_softmax_evaluete = torch.sum(X_batch_softmax, dim=1)
print(X_batch_softmax_evaluete)

tensor([[-0.0466,  0.9957, -0.7632,  1.5878, -0.5846,  2.0392],
        [-0.2612, -1.0416, -0.8095,  0.6656,  1.2044,  0.3641],
        [ 0.6272, -0.6707,  1.3326,  0.0887,  1.6424, -0.8036],
        [-0.5737,  1.2376,  0.1689, -0.0404, -0.2656, -0.0805]])
tensor([[0.0553, 0.1568, 0.0270, 0.2834, 0.0323, 0.4452],
        [0.0929, 0.0426, 0.0537, 0.2348, 0.4024, 0.1736],
        [0.1453, 0.0397, 0.2943, 0.0848, 0.4011, 0.0348],
        [0.0718, 0.4395, 0.1509, 0.1224, 0.0977, 0.1176]])
tensor([1.0000, 1.0000, 1.0000, 1.0000])


### Batch Online Softmax by Hand

In [None]:
# b = number of rows, d = number of columns; print b and the half-width.
b, d = X_batch.size()
print(b, d//2)

# Split the columns into two halves (only two blocks, so no loop is needed yet).
X_batch_block_0, X_batch_block_1 = X_batch[:, :d // 2], X_batch[:, d // 2:]

print(X_batch)
print(X_batch_block_0)
print(X_batch_block_1)

4 3
tensor([[-0.0466,  0.9957, -0.7632,  1.5878, -0.5846,  2.0392],
        [-0.2612, -1.0416, -0.8095,  0.6656,  1.2044,  0.3641],
        [ 0.6272, -0.6707,  1.3326,  0.0887,  1.6424, -0.8036],
        [-0.5737,  1.2376,  0.1689, -0.0404, -0.2656, -0.0805]])
tensor([[-0.0466,  0.9957, -0.7632],
        [-0.2612, -1.0416, -0.8095],
        [ 0.6272, -0.6707,  1.3326],
        [-0.5737,  1.2376,  0.1689]])
tensor([[ 1.5878, -0.5846,  2.0392],
        [ 0.6656,  1.2044,  0.3641],
        [ 0.0887,  1.6424, -0.8036],
        [-0.0404, -0.2656, -0.0805]])


In [34]:
# we parallel calculate  different block max & sum
X_batch_0_max, _ = X_batch_block_0.max(dim = 1, keepdim = True)
X_batch_0_sum = torch.exp(X_batch_block_0 - X_batch_0_max).sum(dim = 1, keepdim = True)

X_batch_1_max, _ = X_batch_block_1.max(dim = 1, keepdim = True)
X_batch_1_sum = torch.exp(X_batch_block_1 - X_batch_1_max).sum(dim = 1, keepdim = True)

print(X_batch_0_max)
print(X_batch_0_sum)

tensor([[ 0.9957],
        [-0.2612],
        [ 1.3326],
        [ 1.2376]])
tensor([[1.5249],
        [2.0361],
        [1.6288],
        [1.5069]])


In [35]:
# online batch block update max & sum
X_batch_1_max_update = torch.maximum(X_batch_0_max, X_batch_1_max) # 逐个元素找最大值
X_batch_1_sum_update = X_batch_0_sum * torch.exp(X_batch_0_max - X_batch_1_max_update) \
                     + torch.exp(X_batch_block_1 - X_batch_1_max_update).sum(dim = 1, keepdim = True) # block sum

X_batch_online_softmax = torch.exp(X_batch - X_batch_1_max_update) / X_batch_1_sum_update
print(X_batch_online_softmax)

tensor([[0.0553, 0.1568, 0.0270, 0.2834, 0.0323, 0.4452],
        [0.0929, 0.0426, 0.0537, 0.2348, 0.4024, 0.1736],
        [0.1453, 0.0397, 0.2943, 0.0848, 0.4011, 0.0348],
        [0.0718, 0.4395, 0.1509, 0.1224, 0.0977, 0.1176]])


In [36]:
# Built-in row-wise softmax, printed for comparison with the hand-written result.
X_batch_softmax_torch = F.softmax(X_batch, 1)
print(X_batch_softmax_torch)

tensor([[0.0553, 0.1568, 0.0270, 0.2834, 0.0323, 0.4452],
        [0.0929, 0.0426, 0.0537, 0.2348, 0.4024, 0.1736],
        [0.1453, 0.0397, 0.2943, 0.0848, 0.4011, 0.0348],
        [0.0718, 0.4395, 0.1509, 0.1224, 0.0977, 0.1176]])


### Multi Block Batch Online Softmax

In [38]:
# Reuse the existing X_batch and cut its columns (dim 1) into chunks of width 2.
X_blocks = torch.split(X_batch, split_size_or_sections=2, dim=1)
print(X_blocks)

(tensor([[-0.0466,  0.9957],
        [-0.2612, -1.0416],
        [ 0.6272, -0.6707],
        [-0.5737,  1.2376]]), tensor([[-0.7632,  1.5878],
        [-0.8095,  0.6656],
        [ 1.3326,  0.0887],
        [ 0.1689, -0.0404]]), tensor([[-0.5846,  2.0392],
        [ 1.2044,  0.3641],
        [ 1.6424, -0.8036],
        [-0.2656, -0.0805]]))


In [39]:
# Online softmax over any number of column blocks, done per row.
#
# Running state, one entry per row (shape (b, 1)): M_old is the max seen so
# far and L_old is the sum of exponentials taken relative to M_old. They
# start at the identities of max and sum: -inf and 0. A finite stand-in such
# as -100000 gives NaN for any row whose values are all below the stand-in.
# NOTE: a row that is entirely -inf within a block would still produce NaN
# here, because -inf - (-inf) is undefined.
b, d = X_batch.shape
M_old = torch.full((b, 1), float('-inf'))
L_old = torch.zeros((b, 1))

# The loop variable is deliberately not called X_block, so it does not
# overwrite the tuple of 1-D blocks bound to that name in earlier cells.
for block in X_blocks:
    M, _ = torch.max(block, dim=1, keepdim=True)  # row-wise max of this block
    M_new = torch.maximum(M, M_old)               # updated running max per row

    # Rescale the running sum to M_new, then add this block's exponentials,
    # which are taken directly against M_new.
    L_new = L_old * torch.exp(M_old - M_new) \
            + torch.exp(block - M_new).sum(dim=1, keepdim=True)

    M_old = M_new
    L_old = L_new

# Normalise the full batch with the final per-row max and sum.
X_blocks_batch = torch.exp(X_batch - M_old) / L_old
print(X_blocks_batch)
print(X_blocks_batch.sum(dim=1, keepdim=True))

tensor([[0.0553, 0.1568, 0.0270, 0.2834, 0.0323, 0.4452],
        [0.0929, 0.0426, 0.0537, 0.2348, 0.4024, 0.1736],
        [0.1453, 0.0397, 0.2943, 0.0848, 0.4011, 0.0348],
        [0.0718, 0.4395, 0.1509, 0.1224, 0.0977, 0.1176]])
tensor([[1.0000],
        [1.0000],
        [1.0000],
        [1.0000]])
