In [1]:
# online SoftMax 2-pass
import torch


def online_softmax(x):
    """Softmax of a 1-D tensor in two passes, using a running max and denominator.

    Pass 1 scans ``x`` once, maintaining the running maximum ``m`` and the
    running denominator ``d = sum_k exp(x[k] - m)``.  Whenever ``m`` grows, the
    part of ``d`` accumulated so far is rescaled by ``exp(m_pre - m)`` so that
    it is expressed relative to the new maximum.
    Pass 2 normalises: ``a[i] = exp(x[i] - m) / d``.

    Args:
        x: 1-D floating-point tensor.  Entries are expected to be finite; a
            leading ``-inf`` entry would evaluate ``exp(-inf - (-inf))`` and
            produce ``nan``.

    Returns:
        Tuple ``(a, m, d)``: the softmax of ``x``, the final running maximum
        and the final running denominator (``m`` and ``d`` are 0-d tensors).
    """
    # Start from a true -inf instead of a finite sentinel such as -1000: with a
    # finite sentinel, inputs that are all smaller than it leave m stuck at the
    # sentinel, every exp() underflows to 0 and the result becomes 0/0 = nan.
    m = x.new_full((), float('-inf'))
    d = x.new_zeros(())

    # Pass 1: running maximum and rescaled running denominator.
    for x_i in x:
        m_pre = m
        m = torch.max(m, x_i)
        # On the first step exp(-inf - m) == 0, so the empty sum contributes nothing.
        d = d * (m_pre - m).exp() + (x_i - m).exp()

    # Pass 2: normalise every element with the final maximum and denominator.
    a = (x - m).exp() / d
    return a, m, d


N = 6
x = torch.randn(N)
print('x:', x)

a, m, d = online_softmax(x)

# Sanity check against the library implementation.
assert torch.allclose(a, torch.softmax(x, dim=0)), 'online softmax differs from torch.softmax'

print('online softmax a:',a)
print(torch.sum(a))

x: tensor([ 0.1438,  0.8986, -1.4740,  1.1223, -1.6086,  1.5352])
online softmax a: tensor([0.0982, 0.2090, 0.0195, 0.2613, 0.0170, 0.3949])
tensor(1.)


In [None]:
# flash attention

NEG_INF = -1e10     # large negative stand-in for -inf; initial value of the running row maximum
EPSILON = 1e-10     # guard against division by zero (not used in this cell)

Q_LEN = 6
K_LEN = 8
Q_BLOCK_SIZE = 3
KV_BLOCK_SIZE = 4

# Ceiling division: torch.split keeps a shorter trailing block when a length is
# not a multiple of the block size, and that block has to be processed too.
Tr = (Q_LEN + Q_BLOCK_SIZE - 1) // Q_BLOCK_SIZE      # number of Q blocks (inner loop)
Tc = (K_LEN + KV_BLOCK_SIZE - 1) // KV_BLOCK_SIZE    # number of K/V blocks (outer loop)


def flash_attention_block_update(Qi, Kj, Vj, Oi, li, mi):
    """Fold one K/V block into the running attention state of one Q block.

    Scores are the raw products ``Qi @ Kj^T`` (no ``1/sqrt(d)`` scaling is
    applied).  The running output is kept normalised, so after the last K/V
    block it equals ``softmax(Qi @ K^T) @ V`` for the rows of ``Qi``.

    Args:
        Qi: query block, shape ``(..., q, d)``.
        Kj: key block, shape ``(..., k, d)``.
        Vj: value block, shape ``(..., k, d_v)``.
        Oi: running normalised output for ``Qi``, shape ``(..., q, d_v)``.
        li: running softmax denominator per query row, shape ``(..., q, 1)``.
        mi: running maximum score per query row, shape ``(..., q, 1)``.

    Returns:
        Tuple ``(Oi_new, li_new, mi_new)`` with the same shapes as ``Oi``,
        ``li`` and ``mi``.
    """
    S_ij = Qi @ Kj.transpose(-1, -2)                            # (..., q, k) scores
    # Row-wise maximum over the keys of this block; torch.max also returns the
    # argmax indices, which are not needed here.
    m_block_ij, _ = torch.max(S_ij, dim=-1, keepdim=True)       # (..., q, 1)
    P_ij = torch.exp(S_ij - m_block_ij)                         # (..., q, k)
    l_block_ij = torch.sum(P_ij, dim=-1, keepdim=True)          # (..., q, 1)
    mi_new = torch.maximum(m_block_ij, mi)
    P_ij_Vj = P_ij @ Vj                                         # (..., q, d_v)

    # Re-express both the old state and the new block relative to the new maximum.
    old_scale = torch.exp(mi - mi_new)
    new_scale = torch.exp(m_block_ij - mi_new)
    li_new = old_scale * li + new_scale * l_block_ij
    # Oi is already divided by li, so multiply it back before renormalising by li_new.
    Oi_new = (li / li_new) * old_scale * Oi + (new_scale / li_new) * P_ij_Vj
    return Oi_new, li_new, mi_new


Q = torch.randn(1, 1, Q_LEN, 4, requires_grad = True).to(device='cpu')
K = torch.randn(1, 1, K_LEN, 4, requires_grad = True).to(device='cpu')
V = torch.randn(1, 1, K_LEN, 4, requires_grad = True).to(device='cpu')
O = torch.zeros_like(Q, requires_grad = True)   # O[i]: output vector for query token i
l = torch.zeros(Q.shape[:-1]).unsqueeze(-1)   # l[i]: running softmax denominator for query i
m = torch.ones(Q.shape[:-1]).unsqueeze(-1) * NEG_INF   # m[i]: running maximum attention score for query i

Q_blocks = torch.split(Q, Q_BLOCK_SIZE, dim=2)
K_blocks = torch.split(K, KV_BLOCK_SIZE, dim=2)
V_blocks = torch.split(V, KV_BLOCK_SIZE, dim=2)
O_blocks = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
l_blocks = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
m_blocks = list(torch.split(m, Q_BLOCK_SIZE, dim=2))

# The loop bounds must cover every block produced by torch.split.
assert Tr == len(Q_blocks) and Tc == len(K_blocks)

for j in range(Tc):
    Kj = K_blocks[j]    # (1, 1, <=KV_BLOCK_SIZE, 4)
    Vj = V_blocks[j]    # (1, 1, <=KV_BLOCK_SIZE, 4)
    for i in range(Tr):
        Qi = Q_blocks[i]    # (1, 1, <=Q_BLOCK_SIZE, 4)
        Oi = O_blocks[i]    # (1, 1, <=Q_BLOCK_SIZE, 4)
        li = l_blocks[i]    # (1, 1, <=Q_BLOCK_SIZE, 1)
        mi = m_blocks[i]    # (1, 1, <=Q_BLOCK_SIZE, 1)

        O_blocks[i], l_blocks[i], m_blocks[i] = flash_attention_block_update(Qi, Kj, Vj, Oi, li, mi)

        print(f'-----------Attn : Q{i}xK{j}---------')
        # Show every output block, however many there are.
        for r, O_r in enumerate(O_blocks):
            print(f'O blocks {r}:', O_r)
        print('\n')

O = torch.cat(O_blocks, dim = 2)
l = torch.cat(l_blocks, dim = 2)
m = torch.cat(m_blocks, dim = 2)

# Validate the tiled result against dense attention using the same (unscaled) scores.
O_ref = torch.softmax(Q @ K.transpose(-1, -2), dim=-1) @ V
assert torch.allclose(O, O_ref, atol=1e-5, rtol=1e-4), 'tiled attention differs from dense reference'

print(O)




-----------Attn : Q0xK0---------
O blocks 0: tensor([[[[ 0.9302, -0.4377,  1.2176, -0.1538],
          [-1.5067,  0.5400, -1.0797,  0.4081],
          [ 1.8763, -0.8075,  2.3340, -0.0710]]]], grad_fn=<AddBackward0>)
O blocks 1: tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], grad_fn=<SplitBackward0>)


-----------Attn : Q1xK0---------
O blocks 0: tensor([[[[ 0.9302, -0.4377,  1.2176, -0.1538],
          [-1.5067,  0.5400, -1.0797,  0.4081],
          [ 1.8763, -0.8075,  2.3340, -0.0710]]]], grad_fn=<AddBackward0>)
O blocks 1: tensor([[[[ 2.0871, -0.9031,  2.5503, -0.0987],
          [-1.1344,  0.4068, -0.8468,  0.1686],
          [-0.2861,  0.4992,  0.0705,  0.2249]]]], grad_fn=<AddBackward0>)


-----------Attn : Q0xK1---------
O blocks 0: tensor([[[[ 0.4608, -0.2396,  0.5418, -0.1609],
          [-1.4419,  0.5208, -1.0398,  0.3814],
          [ 0.7820,  0.7656,  0.6603, -0.7068]]]], grad_fn=<AddBackward0>)
O blocks 1: tensor([[[[ 2.0871, -0.9031,