In [14]:
import torch
from einops import rearrange, repeat
from flash_attn import flash_attn_with_kvcache, flash_attn_varlen_func
import math

In [15]:
def attention_ref(
    q,
    k,
    v,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    dtype_og = q.dtype

    d = q.shape[-1]

    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    
    attention = torch.softmax(scores, dim=-1).to(v.dtype)

    output = torch.einsum("bhts,bshd->bthd", attention, v)

    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)

In [28]:
batch_size = 1
seqlen_q = 1
seqlen_kv = 1024
num_heads = 32
head_dim = 128
device = "cuda"
dtype = torch.float16
seqlen_new = seqlen_q

# Initialize tensors
q = torch.randn(batch_size, seqlen_q, num_heads, head_dim, device=device, dtype=dtype)

k_cache = torch.randn(batch_size, seqlen_kv, num_heads, head_dim, device=device, dtype=dtype)
v_cache = torch.randn(batch_size, seqlen_kv, num_heads, head_dim, device=device, dtype=dtype)

k = 10 * torch.rand(batch_size, seqlen_new, num_heads, head_dim, device=device, dtype=dtype)
v = 15 * torch.rand(batch_size, seqlen_new, num_heads, head_dim, device=device, dtype=dtype) - 7.5

In [29]:

# Compute flash attention with kvcache
out_flashattn = flash_attn_with_kvcache(q, k_cache, v_cache)

# Reference attention computation
out_ref, _ = attention_ref(q, k_cache, v_cache)

# Print differences
print(f"flashattn vs pytorch: {(out_flashattn - out_ref).abs().mean().item()}")

flashattn vs pytorch: 2.187490463256836e-05


In [30]:
k_cache_new = torch.cat([k_cache, k], dim=1)
v_cache_new = torch.cat([v_cache, v], dim=1)

cache_seqlens = torch.tensor([seqlen_kv], dtype=torch.int32, device=device)
out_flashattn_new = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens=cache_seqlens)

out_ref, _ = attention_ref(q, k_cache, v_cache)
print(f"flashattn vs pytorch without append KV: {(out_flashattn_new - out_ref).abs().mean().item()}")   

out_ref_new, _ = attention_ref(q, k_cache_new, v_cache_new)
print(f"flashattn vs pytorch with append KV: {(out_flashattn_new - out_ref_new).abs().mean().item()}")   


flashattn vs pytorch without append KV: 2.187490463256836e-05
flashattn vs pytorch with append KV: 0.474365234375
