# Online Softmax

github: xiaodongguaAIGC

- softmax
- Safe Softmax
- online softmax
- block online softmax
- batch online softmax

In [3]:
import torch
import torch.nn.functional as F

In [4]:
# Toy 1-D input of six logits.
X = torch.tensor((-0.3, 0.2, 0.5, 0.7, 0.1, 0.8))

## Softmax By Torch

In [5]:
# Reference result: PyTorch's built-in softmax along dim 0 of X.
X_softmax = F.softmax(input=X, dim=0)
print(X_softmax)

## Softmax by Hand

In [6]:
# Naive softmax: exponentiate every entry, then divide by the total.
# The max is not subtracted before exp() in this version.
X_exp_sum = torch.sum(torch.exp(X))
X_softmax_hand = X.exp() / X_exp_sum
print(X_softmax_hand)

## Safe Softmax by Hand

In [7]:
# Safe softmax: shift by the maximum before exponentiating, so every
# exponent is <= 0.
X_max = torch.max(X)
X_exp_sum_sub_max = (X - X_max).exp().sum()
X_safe_softmax_hand = (X - X_max).exp() / X_exp_sum_sub_max
print(X_safe_softmax_hand)

## Online Softmax

In [8]:
# Treat the last element as the newly arrived value and everything before
# it as the data that has already been processed.
X_pre = X[:-1]
print('input x')
print(X)
print(X_pre)
print(X[-1])

# State after step t-1: running max and running sum of shifted exponentials.
X_max_pre = torch.max(X_pre)
X_sum_pre = (X_pre - X_max_pre).exp().sum()

# Step t: fold the new value X[-1] into the running state. The old sum is
# rescaled from the old max to the new max, then the new term is added.
X_max_cur = torch.max(X_max_pre, X[-1])
X_sum_cur = torch.exp(X[-1] - X_max_cur) + X_sum_pre * torch.exp(X_max_pre - X_max_cur)

# Normalise the whole vector with the final max and sum.
X_online_softmax = (X - X_max_cur).exp() / X_sum_cur
print('online softmax result: ', X_online_softmax)

## Block Online Softmax

In [9]:
# Cut X into consecutive chunks of 3 elements along dim 0.
X_block = X.split(3, dim=0)
print(X)
print(X_block)

In [10]:
# Per-block statistics, each computed independently of the other block:
# the local max and the sum of exponentials shifted by that local max.
X_block_0_max = torch.max(X_block[0])
X_block_0_sum = (X_block[0] - X_block_0_max).exp().sum()

X_block_1_max = torch.max(X_block[1])
X_block_1_sum = (X_block[1] - X_block_1_max).exp().sum()

In [11]:
# Online merge of the two blocks using only their (max, sum) statistics.
# Each block sum was taken relative to its own local max, so both sums are
# rescaled to the shared max before being added; the raw block data is not
# touched again.
X_block_1_max_update = torch.max(X_block_0_max, X_block_1_max) # max over both blocks
X_block_1_sum_update = X_block_0_sum * torch.exp(X_block_0_max - X_block_1_max_update) \
                     + X_block_1_sum * torch.exp(X_block_1_max - X_block_1_max_update)

# Normalise the full vector with the merged max and sum.
X_block_online_softmax = torch.exp(X - X_block_1_max_update) / X_block_1_sum_update
print(X_block_online_softmax)

## Multi Block Online Softmax

In [25]:
# Cut X into consecutive chunks of 2 elements along dim 0.
X_block = X.split(2, dim=0)
print(X)
print(X_block)

In [26]:
# Per-block statistics for the first three blocks, each independent of the
# others: M holds the local maxima, L the sums of exponentials shifted by
# the matching local max.
M = [X_block[k].max() for k in range(3)]
L = [(X_block[k] - M[k]).exp().sum() for k in range(3)]

# Keep the individually named statistics available as well.
X_block_0_max, X_block_1_max, X_block_2_max = M
X_block_0_sum, X_block_1_sum, X_block_2_sum = L

print(M)
print(L)

In [27]:
# Online multi-block merge of the per-block (max, sum) statistics in M and L.

# Start from the identity of the merge: running max = -inf, running sum = 0.
# Seeding the max with 0.0 would clamp it at >= 0, so for inputs that are
# all very negative exp(x - max) underflows and the result becomes 0 / 0.
M_old = torch.tensor([float('-inf')])
L_old = torch.tensor([0.0])

for m_blk, l_blk in zip(M, L):
    M_new = torch.max(m_blk, M_old)
    # Rescale the running sum and the block sum to the new shared max,
    # using only the precomputed block statistics.
    L_new = L_old * torch.exp(M_old - M_new) \
            + l_blk * torch.exp(m_blk - M_new)
    M_old = M_new
    L_old = L_new

# Normalise the full vector with the final max and sum.
X_multi_block_online_softmax = torch.exp(X - M_old) / L_old
print(X_multi_block_online_softmax)
print(X_multi_block_online_softmax.sum())

## Batch Online Softmax

### Batch Online Softmax by Torch

In [10]:
# Random batch of shape (4, 6); softmax is taken along dim 1 (per row).
X_batch = torch.randn((4, 6))
print(X_batch)
X_batch_softmax = F.softmax(input=X_batch, dim=1)
print(X_batch_softmax)
# Sanity check: the probabilities of every row sum to 1.
X_batch_softmax_evaluete = torch.sum(X_batch_softmax, dim=1)
print(X_batch_softmax_evaluete)

### Batch Online Softmax by Hand

In [11]:
# Split every row of X_batch at the column midpoint d // 2.
b, d = X_batch.size()
print(b, d // 2)

X_batch_block_0, X_batch_block_1 = X_batch[:, :d // 2], X_batch[:, d // 2:]

print(X_batch)
print(X_batch_block_0)
print(X_batch_block_1)

In [12]:
# Per-block, per-row statistics, computed independently for each block.
# keepdim=True keeps the reduced dim so the results broadcast against the
# blocks they came from.
X_batch_0_max, _ = torch.max(X_batch_block_0, dim=1, keepdim=True)
X_batch_0_sum = (X_batch_block_0 - X_batch_0_max).exp().sum(dim=1, keepdim=True)

X_batch_1_max, _ = torch.max(X_batch_block_1, dim=1, keepdim=True)
X_batch_1_sum = (X_batch_block_1 - X_batch_1_max).exp().sum(dim=1, keepdim=True)

print(X_batch_0_max)
print(X_batch_0_sum)

In [13]:
# Online merge of the two column blocks, done independently for every row,
# using only the per-block (max, sum) statistics. Each block sum was taken
# relative to its own row max, so both are rescaled to the shared row max
# before being added; the raw block data is not touched again.

X_batch_1_max_update = torch.maximum(X_batch_0_max, X_batch_1_max) # element-wise maximum
X_batch_1_sum_update = X_batch_0_sum * torch.exp(X_batch_0_max - X_batch_1_max_update) \
                     + X_batch_1_sum * torch.exp(X_batch_1_max - X_batch_1_max_update)

# Normalise the full batch with the merged per-row max and sum.
X_batch_online_softmax = torch.exp(X_batch - X_batch_1_max_update) / X_batch_1_sum_update
print(X_batch_online_softmax)

In [14]:
# Built-in row-wise softmax of the same batch, printed for comparison.
X_batch_softmax_torch = F.softmax(input=X_batch, dim=1)
print(X_batch_softmax_torch)