In [1]:
import torch
from flash_attn import flash_attn_varlen_func

# Smoke test for the varlen kernel: 6 packed tokens forming two sequences of
# length 3. Q carries 8 heads while K/V carry 4 heads; head_dim is 64.
# Tensors are drawn in q, k, v order so the RNG stream is consumed as before.
# torch.randn already returns contiguous tensors, so no extra .contiguous() is needed.
q = torch.randn((6, 8, 64), dtype=torch.float16, device='cuda')  # [total_tokens, q_heads, head_dim]
k = torch.randn((6, 4, 64), dtype=torch.float16, device='cuda')  # [total_tokens, kv_heads, head_dim]
v = torch.randn((6, 4, 64), dtype=torch.float16, device='cuda')  # [total_tokens, kv_heads, head_dim]

# Sequence boundaries inside the packed token dimension: [0, 3) and [3, 6).
cu_seqlens = torch.tensor([0, 3, 6], device='cuda', dtype=torch.int32)

out = flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=3,  # longest query sequence in the pack
    max_seqlen_k=3,  # longest key sequence in the pack
    dropout_p=0.0,
    causal=True,
)
print("flash test ok, out.shape=", out.shape)

flash test ok, out.shape= torch.Size([6, 8, 64])


In [4]:

import torch
from flash_attn import flash_attn_varlen_func

# ---------------------------------------------------------------
# Synthetic batch configuration
# ---------------------------------------------------------------
batch_size = 4
seq_len = 512  # every sequence in the batch has this length
hidden_size = 2048
num_heads = 16
# 2048 // 16 = 128 per head (NOTE: flash-attn documents head dims up to 256).
head_dim = hidden_size // num_heads

# ---------------------------------------------------------------
# Random activations standing in for a model's hidden states
# ---------------------------------------------------------------
hidden_states = torch.randn(
    (batch_size, seq_len, hidden_size), dtype=torch.float16, device='cuda'
)
# [B, L] positions 0..L-1 broadcast over the batch; not consumed further down in this cell.
position_ids = torch.arange(seq_len, device='cuda')[None, :].expand(batch_size, seq_len)

# Packed-sequence boundaries for the varlen kernel: [0, L, 2L, ..., B*L] as int32.
cu_seqlens = torch.arange(batch_size + 1, device='cuda', dtype=torch.int32) * seq_len
max_seqlen = seq_len

# ---------------------------------------------------------------
# Stand-in Q/K/V: no projection weights, each is a head-split view
# of the same hidden_states tensor, laid out as [B, H, L, D]
# ---------------------------------------------------------------
q = hidden_states.view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
k = hidden_states.view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
v = hidden_states.view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)

# Move heads back behind the sequence axis and merge batch with sequence,
# giving the packed [B*L, H, D] layout passed to flash_attn_varlen_func.
q_flat = q.permute(0, 2, 1, 3).flatten(0, 1)
k_flat = k.permute(0, 2, 1, 3).flatten(0, 1)
v_flat = v.permute(0, 2, 1, 3).flatten(0, 1)

# ---------------------------------------------------------------
# Varlen FlashAttention forward pass (causal, no dropout)
# ---------------------------------------------------------------
output = flash_attn_varlen_func(
    # q, k, v are passed positionally, each made contiguous first
    *(packed.contiguous() for packed in (q_flat, k_flat, v_flat)),
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    dropout_p=0.0,
    causal=True,
)

# ---------------------------------------------------------------
# Report the result shape, which is [B*L, H, D]
# ---------------------------------------------------------------
print("FlashAttention output shape:", output.shape)



FlashAttention output shape: torch.Size([2048, 16, 128])
