In [99]:
import torch
import math

In [100]:
def standard_attention(Q, K, V):
    '''
    Reference (non-tiled) scaled dot-product attention.

        Q: B x N x d   (leading batch dimension is optional)
        K: B x N x d
        V: B x N x d

    Returns softmax(Q @ K^T / sqrt(d)) @ V, shape B x N x d.
    '''
    # sqrt(d) as a float32 scalar tensor; dividing the logits by it keeps their
    # variance independent of the head dimension.
    temperature = torch.tensor(Q.shape[-1], dtype=torch.float32).sqrt()

    # Row i of `weights` is a distribution over the keys for query i, so the
    # softmax is taken along the last axis; the matmul with V then forms a
    # weighted average of V's rows for every query.
    weights = torch.softmax(torch.matmul(Q, K.transpose(-2, -1)) / temperature, dim=-1)  # B x N x N
    return torch.matmul(weights, V)  # B x N x d

In [101]:
def flash_attention(Q, K, V, M):
    '''
    Tiled scaled dot-product attention following FlashAttention (Algorithm 1).

    Computes softmax(Q @ K^T / sqrt(d)) @ V without building the full
    N_q x N_k score matrix. K/V are streamed in tiles of Bc rows (outer loop)
    and Q in tiles of Br rows (inner loop). Each output row keeps a running
    max `m` and a running softmax denominator `l`, so partial results are
    rescaled whenever a later tile contains a larger logit (online softmax).
    This keeps exp() from overflowing on large logits.

        Q: ... x N_q x d
        K: ... x N_k x d
        V: ... x N_k x d_v
        M: memory budget in bytes, assuming 4-byte floats. Tile sizes are
           Bc = ceil(M / 4d) and Br = min(Bc, d).

    Leading (batch) dimensions are carried through untouched. They are
    assumed to be identical for Q, K and V (no broadcasting between them).

    Returns: ... x N_q x d_v tensor with Q's dtype and device.
    Raises: ValueError if M <= 0 (tile sizes would be zero).
    '''
    if M <= 0:
        raise ValueError(f"memory budget M must be positive, got {M}")

    d = Q.shape[-1]
    n_q = Q.shape[-2]
    n_k = K.shape[-2]

    # Block sizes: each row of d float32 values takes 4*d bytes.
    d_size = 4 * d
    Bc = math.ceil(M / d_size)
    Br = min(Bc, d)
    Tr = math.ceil(n_q / Br)  # number of Q row tiles
    Tc = math.ceil(n_k / Bc)  # number of K/V row tiles

    scale = 1.0 / math.sqrt(d)
    stats_shape = Q.shape[:-1]  # one running statistic per query row
    out = torch.zeros(Q.shape[:-1] + (V.shape[-1],), dtype=Q.dtype, device=Q.device)
    l = torch.zeros(stats_shape, dtype=Q.dtype, device=Q.device)  # running denominators
    m = torch.full(stats_shape, float('-inf'), dtype=Q.dtype, device=Q.device)  # running maxima

    for j in range(Tc):
        cols = slice(j * Bc, (j + 1) * Bc)
        # Slice along the sequence dim (-2), not dim 0, so batched inputs work.
        Kj, Vj = K[..., cols, :], V[..., cols, :]
        for i in range(Tr):
            rows = slice(i * Br, (i + 1) * Br)
            Qi = Q[..., rows, :]
            mi, li, Oi = m[..., rows], l[..., rows], out[..., rows, :]

            Sij = (Qi @ Kj.transpose(-2, -1)) * scale        # ... x Br x Bc
            mij = Sij.amax(dim=-1)                           # block row max
            Pij = torch.exp(Sij - mij.unsqueeze(-1))         # shifted: entries <= 1, no overflow
            lij = Pij.sum(dim=-1)

            m_new = torch.maximum(mi, mij)
            # Rescale factors. On the first visit mi = -inf, so alpha = 0 and
            # the (zero) previous state contributes nothing.
            alpha = torch.exp(mi - m_new)
            beta = torch.exp(mij - m_new)
            l_new = alpha * li + beta * lij

            # Undo the previous normalisation, add this tile's contribution,
            # then renormalise with the updated denominator.
            out[..., rows, :] = (
                (alpha * li).unsqueeze(-1) * Oi + beta.unsqueeze(-1) * (Pij @ Vj)
            ) / l_new.unsqueeze(-1)
            m[..., rows] = m_new
            l[..., rows] = l_new

    return out


In [102]:
# Toy problem size: sequence length N = 128, head dimension d = 64.
SEQ_LEN, HEAD_DIM = 128, 64

# Seed the RNG so the random inputs (and every output below) are reproducible
# under Restart & Run All.
torch.manual_seed(0)
Q = torch.randn(SEQ_LEN, HEAD_DIM, dtype=torch.float32)
K = torch.randn(SEQ_LEN, HEAD_DIM, dtype=torch.float32)
V = torch.randn(SEQ_LEN, HEAD_DIM, dtype=torch.float32)

In [103]:
# Memory budget (bytes) for flash_attention's tiling. With d = 64 float32
# values per row (256 bytes), this gives Br = Bc = 4-row tiles, so the 128-row
# inputs are processed as 32 x 32 tile pairs.
SRAM_BYTES = 1024

std_attn_out = standard_attention(Q, K, V)
fl_out = flash_attention(Q, K, V, M=SRAM_BYTES)

In [104]:
# Output of the tiled flash_attention implementation (compare with std_attn_out).
fl_out

tensor([[ 0.0201, -0.1788, -0.0363,  ..., -0.0315,  0.0576,  0.0255],
        [ 0.2961, -0.1752, -0.1401,  ..., -0.0553,  0.0707, -0.1922],
        [ 0.0422,  0.0714,  0.0121,  ...,  0.0510, -0.1946, -0.0109],
        ...,
        [ 0.0184, -0.0160, -0.1022,  ..., -0.1712, -0.1380, -0.0208],
        [-0.1606, -0.1271,  0.0710,  ..., -0.0724,  0.0618, -0.1609],
        [ 0.0693, -0.2189,  0.0981,  ...,  0.1996, -0.0690,  0.0424]])

In [105]:
# Output of the reference standard_attention implementation.
std_attn_out

tensor([[ 0.0201, -0.1788, -0.0363,  ..., -0.0315,  0.0576,  0.0255],
        [ 0.2961, -0.1752, -0.1401,  ..., -0.0553,  0.0707, -0.1922],
        [ 0.0422,  0.0714,  0.0121,  ...,  0.0510, -0.1946, -0.0109],
        ...,
        [ 0.0184, -0.0160, -0.1022,  ..., -0.1712, -0.1380, -0.0208],
        [-0.1606, -0.1271,  0.0710,  ..., -0.0724,  0.0618, -0.1609],
        [ 0.0693, -0.2189,  0.0981,  ...,  0.1996, -0.0690,  0.0424]])

In [106]:
# The tiled result must agree with the reference to float32 tolerance. Assert it
# so Restart & Run All fails loudly on a regression, then display the flag.
attn_match = torch.allclose(fl_out, std_attn_out, rtol=1e-5, atol=1e-5)
assert attn_match, "flash_attention output diverges from standard_attention"
attn_match

True