In [5]:
import torch
from fla.ops.linear_attn.naive import naive_chunk_linear_attn as fla_chunk_attn

In [90]:
B = 8
L = 256
d = 128

def chunk_attn(S_prev, Q, K, V):
    # Q, K, V : (B, C, d)
    O_intra = torch.tril(Q @ K.transpose(1, 2)) @ V
    O_inter = Q @ S_prev
    O = O_inter + O_intra
    return O

def attn(Q, K, V):
    chunk_size = 16
    assert L%chunk_size==0
    n = L//chunk_size

    S_prev = torch.zeros(B, d, d)
    Q = Q.view(B, n, chunk_size, d)
    K = K.view(B, n, chunk_size, d)
    V = V.view(B, n, chunk_size, d)
    O = torch.zeros(B, n, chunk_size, d)
    
    for c in range(n):
        O[:, c] = chunk_attn(S_prev, Q[:, c], K[:, c], V[:, c])
        S_prev = S_prev + K[:, c].transpose(1, 2) @ V[:, c]

    O = O.view(B, L, d)
    return O

In [91]:
Q, K, V = torch.randn(B, L, d), torch.randn(B, L, d), torch.randn(B, L, d)

In [92]:
own_res = attn(Q, K, V)
fla_res = fla_chunk_attn(Q.unsqueeze(1), K.unsqueeze(1), V.unsqueeze(1), scale=1).squeeze(1)

In [94]:
torch.allclose(own_res, fla_res, atol=0.01)

True

In [99]:
torch.mean(own_res-fla_res)

tensor(4.2725e-08)