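Building on the FlashAttention-1 implementation, we now implement FlashAttention-2. The main changes are that the loops are swapped (the Q tiles are iterated in the outer loop and the K/V tiles in the inner loop) and that the online softmax and the output O are computed differently. Inside the inner loop the unnormalized O is only rescaled by $\exp(m_{\text{old}} - m_{\text{new}})$, and it is divided by the softmax denominator $l$ only once, after the inner loop has finished.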

In [None]:
import torch

torch.manual_seed(456)

N, d = 16, 8
Q_mat = torch.rand((N, d))
K_mat = torch.rand((N, d))
V_mat = torch.rand((N, d))

# Reference result computed the naive way (no 1/sqrt(d) scaling, same as the tiled version below).
expected_softmax = torch.softmax(Q_mat @ K_mat.T, dim=1)
expected_attention = expected_softmax @ V_mat

# Tile sizes: Br rows of Q per outer-loop tile, Bc rows of K/V per inner-loop tile.
# In a real kernel they are chosen so that every tile-level op fits in SRAM.
Br = 4
Bc = d


def flash_attention2(Q, K, V, Br, Bc):
    """Compute softmax(Q @ K.T) @ V tile by tile, following the FlashAttention-2 forward pass.

    The outer loop walks over row tiles of Q (Br rows each) and the inner loop
    walks over row tiles of K and V (Bc rows each). Each Q tile keeps a running
    row max ``mi`` and a running softmax denominator ``li``. It rescales its
    unnormalized output ``Oi`` whenever the max grows, and divides by ``li``
    only once, after the inner loop.

    No 1/sqrt(d) scaling is applied, matching the reference computation in this cell.
    The "line N" comments refer to Algorithm 1 (forward pass) of the
    FlashAttention-2 paper.

    Args:
        Q: (n_q, d) query matrix.
        K: (n_kv, d) key matrix.
        V: (n_kv, d_v) value matrix.
        Br: Q rows per outer tile; n_q does not have to be a multiple of it.
        Bc: K/V rows per inner tile; n_kv does not have to be a multiple of it.

    Returns:
        A tuple ``(O, L)``. ``O`` is the (n_q, d_v) attention output. ``L`` is
        the (n_q, 1) row-wise logsumexp of ``Q @ K.T`` (line 13).
    """
    n_q = Q.shape[0]
    d_v = V.shape[1]

    # O and L stand for global memory (HBM); everything created inside the loops
    # stands for on-chip SRAM. Unlike FlashAttention-1, no global l / m buffers are
    # needed, because each Q tile completes its softmax statistics on its own.
    O = Q.new_zeros((n_q, d_v))
    L = Q.new_zeros((n_q, 1))

    for block_start_Br in range(0, n_q, Br):
        block_end_Br = block_start_Br + Br  # slicing clamps this on the last tile
        # line 4, load a tile of Q from HBM
        Qi = Q[block_start_Br:block_end_Br, :]
        # The last tile has fewer than Br rows when n_q % Br != 0, so size the
        # per-tile state from the tile itself rather than from Br.
        rows = Qi.shape[0]
        # line 5, initialize Oi, li and mi
        Oi = Q.new_zeros((rows, d_v))
        li = Q.new_zeros((rows, 1))
        mi = Q.new_full((rows, 1), -torch.inf)

        for block_start_Bc in range(0, K.shape[0], Bc):
            block_end_Bc = block_start_Bc + Bc
            # line 7, load a tile of K and V from HBM
            Kj = K[block_start_Bc:block_end_Bc, :]
            Vj = V[block_start_Bc:block_end_Bc, :]

            # line 8, QK^T at the tile level
            Sij = Qi @ Kj.T

            # line 9, running row max over every K tile visited so far
            mi_new = torch.maximum(mi, Sij.amax(dim=1, keepdim=True))
            # line 9, softmax numerator of this tile, relative to the new max
            Pij_hat = torch.exp(Sij - mi_new)
            # exp(m_old - m_new) moves previously accumulated sums onto the new max.
            # On the first tile mi is -inf, so this factor is 0 and the zero-initialized
            # state drops out.
            rescale = torch.exp(mi - mi_new)
            # line 9, running softmax denominator
            li = rescale * li + Pij_hat.sum(dim=1, keepdim=True)

            # line 10, unnormalized output: only rescale here; divide by li after the loop
            Oi = rescale * Oi + Pij_hat @ Vj

            mi = mi_new

        # line 12, normalize once per Q tile
        Oi = Oi / li
        # line 13, row-wise logsumexp of S for this tile
        Li = mi + torch.log(li)

        # lines 14-15, write the tile results back to HBM
        O[block_start_Br:block_end_Br, :] = Oi
        L[block_start_Br:block_end_Br, :] = Li

    return O, L


O, L = flash_attention2(Q_mat, K_mat, V_mat, Br, Bc)
assert torch.allclose(O, expected_attention)
assert torch.allclose(L[:, 0], torch.logsumexp(Q_mat @ K_mat.T, dim=1))