In [194]:
import torch
from fla.ops.linear_attn.naive import naive_chunk_linear_attn as fla_chunkwise_attn

In [None]:
# Problem sizes used by the cells below: batch size, sequence length, feature width.
B, L, d = 8, 256, 128

def chunkwise_attn(Q, K, V, chunk_size=16):
    """Causal, unnormalised linear attention evaluated chunk by chunk.

    The result equals ``torch.tril(Q @ K^T) @ V``: position ``t`` attends to
    every position ``s <= t`` with weight ``Q[t] . K[s]`` (no softmax, no
    scaling). The sequence is cut into chunks of ``chunk_size`` tokens; each
    chunk combines a contribution from all earlier chunks (through a running
    ``K^T V`` state) with a masked quadratic term inside the chunk.

    Args:
        Q: queries, shape (B, L, d_k).
        K: keys, shape (B, L, d_k).
        V: values, shape (B, L, d_v).
        chunk_size: number of tokens per chunk; must divide L.

    Returns:
        Tensor of shape (B, L, d_v) with the dtype and device of ``Q``.

    Raises:
        AssertionError: if L is not a multiple of ``chunk_size``.
    """
    # Sizes come from the arguments, not from notebook-level globals, so the
    # function works for any input shape.
    batch, seq_len, dim_k = Q.shape
    dim_v = V.shape[-1]

    assert seq_len % chunk_size == 0, (
        f"sequence length {seq_len} is not a multiple of chunk_size {chunk_size}"
    )
    n = seq_len // chunk_size

    # reshape (rather than view) also accepts non-contiguous inputs.
    Q = Q.reshape(batch, n, chunk_size, dim_k)
    K = K.reshape(batch, n, chunk_size, dim_k)
    V = V.reshape(batch, n, chunk_size, dim_v)

    # new_zeros inherits dtype and device from the inputs.
    S_prev = Q.new_zeros(batch, dim_k, dim_v)  # running sum of K^T V over past chunks
    O = Q.new_zeros(batch, n, chunk_size, dim_v)

    for c in range(n):
        q_c, k_c, v_c = Q[:, c], K[:, c], V[:, c]
        k_c_t = k_c.transpose(1, 2)
        # Inter-chunk part: everything strictly before this chunk, via the state.
        inter = q_c @ S_prev
        # Intra-chunk part: lower-triangular mask keeps only positions s <= t.
        intra = torch.tril(q_c @ k_c_t) @ v_c
        O[:, c] = inter + intra
        # Fold this chunk into the state only after its output is computed.
        S_prev = S_prev + k_c_t @ v_c

    return O.reshape(batch, seq_len, dim_v)

def recurrent_attn(Q, K, V):
    """Causal, unnormalised linear attention evaluated one token at a time.

    Keeps a running state ``S = sum_{s <= t} K[s]^T V[s]`` and emits
    ``O[t] = Q[t] @ S`` after each update, so the result equals
    ``torch.tril(Q @ K^T) @ V``.

    Args:
        Q: queries, shape (B, L, d_k).
        K: keys, shape (B, L, d_k).
        V: values, shape (B, L, d_v).

    Returns:
        Tensor of shape (B, L, d_v) with the dtype and device of ``Q``.
    """
    # Sizes come from the arguments, not from notebook-level globals.
    batch, seq_len, dim_k = Q.shape
    dim_v = V.shape[-1]

    # new_zeros inherits dtype and device from the inputs.
    S = Q.new_zeros(batch, dim_k, dim_v)
    O = Q.new_zeros(batch, seq_len, dim_v)

    for t in range(seq_len):
        # List indexing keeps the time axis, so the slices stay 3-D: (B, 1, dim).
        q_t, k_t, v_t = Q[:, [t]], K[:, [t]], V[:, [t]]
        # Update the state first: token t attends to itself as well.
        S = S + k_t.transpose(1, 2) @ v_t
        O[:, [t]] = q_t @ S

    return O

In [196]:
# Draw the inputs from a locally seeded generator so a fresh "Run All" reproduces
# the same tensors (and therefore the same comparison results) every time,
# without touching the global RNG state.
_gen = torch.Generator().manual_seed(0)
Q, K, V = (torch.randn(B, L, d, generator=_gen) for _ in range(3))

In [None]:
# Evaluate both hand-written implementations on the same inputs.
own_res = chunkwise_attn(Q, K, V)
own_recc_res = recurrent_attn(Q, K, V)

# The fla reference is fed 4-D tensors: an axis of size 1 is inserted at dim 1 of
# each input (presumably the head axis -- confirm against the fla signature) and
# removed again from the output.
fla_res = fla_chunkwise_attn(*(t.unsqueeze(1) for t in (Q, K, V)), scale=1).squeeze(1)

In [198]:
# Chunkwise implementation vs. the fla reference, absolute tolerance 1e-3.
torch.allclose(own_res, fla_res, atol=1e-3)

True

In [199]:
# Recurrent implementation vs. the fla reference, absolute tolerance 1e-3.
torch.allclose(own_recc_res, fla_res, atol=1e-3)

True

In [201]:
# Unpack the three dimension sizes of Q into a, b and c.
a, b, c = Q.shape

In [206]:
# Scratch expression: define an anonymous doubling function and call it on 3.
(lambda value: 2 * value)(3)

6