In [1]:
import os

os.environ['TORCH_LOGS'] = 'recompiles'

In [2]:
import torch

torch.cuda.get_device_name(0)

'NVIDIA B200'

In [3]:
import random
import numpy as np
import flash_attn
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

torch._dynamo.config.cache_size_limit = 1000

flex_attention = torch.compile(flex_attention, dynamic=False)

def generate_list_sum_n(n, length=5, min_val=5):

    numbers = [min_val] * length
    remaining = n - min_val * length

    for _ in range(remaining):
        numbers[random.randint(0, length - 1)] += 1

    random.shuffle(numbers)
    return numbers

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def _offsets_to_doc_ids_tensor(offsets):
    device = offsets.device
    offsets = offsets[offsets != -1]
    counts = offsets[1:] - offsets[:-1]
    return torch.repeat_interleave(
        torch.arange(len(counts), device=device, dtype=torch.int32), counts
    )

def length_to_offsets(lengths, device):
    offsets = [0]
    offsets.extend(lengths)
    offsets = torch.tensor(offsets, device=device, dtype=torch.int32)
    offsets = torch.cumsum(offsets, dim=-1)
    return offsets

def generate_doc_mask_mod(offsets):
    
    offsets = pad_sequence(offsets, batch_first = True, padding_value = -1)
    docs = [_offsets_to_doc_ids_tensor(offsets[i]) for i in range(offsets.shape[0])]
    docs = torch.stack(docs, 0)
    
    def document_causal_mask(b, h, q_idx, kv_idx):
        causal_mask = q_idx >= kv_idx
        document_mask = docs[b, q_idx] == docs[b, kv_idx]
        return causal_mask & document_mask
    
    return document_causal_mask

In [4]:
sequence_length = 4096
query_lens = np.array(generate_list_sum_n(sequence_length, length=20, min_val=10), dtype=np.int64)

q = torch.randn(1, sequence_length, 32, 128, dtype = torch.bfloat16).cuda()
k = torch.randn(1, sequence_length, 32, 128, dtype = torch.bfloat16).cuda()
v = torch.randn(1, sequence_length, 32, 128, dtype = torch.bfloat16).cuda()

cumsum = [0] + np.cumsum(query_lens).tolist()
max_cumsum = int(np.max(cumsum))
cu_seq_lens_q = torch.tensor(cumsum, dtype=torch.int32).cuda()
max_seqlen_q = np.max(query_lens)

out_flash2 = flash_attn.flash_attn_varlen_func(
    q = q[0],
    k = k[0],
    v = v[0],
    cu_seqlens_q = cu_seq_lens_q,
    cu_seqlens_k = cu_seq_lens_q,
    max_seqlen_q = max_seqlen_q,
    max_seqlen_k = max_seqlen_q,
    causal = True
)

In [5]:
attention_mask = cu_seq_lens_q
seq_len = q.shape[1]
device = q.device
document_causal_mask = generate_doc_mask_mod(attention_mask[None])
attention_mask = create_block_mask(
    document_causal_mask, None, None, seq_len, seq_len, device, _compile = True)

In [6]:
o = flex_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), score_mod=None, block_mask=attention_mask
)

In [7]:
torch.allclose(out_flash2, o[0].transpose(0, 1), atol=0.125, rtol=0)

True

In [8]:
import time

for i in range(10):

    query_lens = np.array(generate_list_sum_n(sequence_length, length=random.randint(12, 30), min_val=10), dtype=np.int64)
    
    q = torch.randn(1, sequence_length, 32, 128, dtype = torch.bfloat16).cuda()
    k = torch.randn(1, sequence_length, 32, 128, dtype = torch.bfloat16).cuda()
    v = torch.randn(1, sequence_length, 32, 128, dtype = torch.bfloat16).cuda()
    
    cumsum = [0] + np.cumsum(query_lens).tolist()
    print(i, 'doc cumsum', cumsum)
    max_cumsum = int(np.max(cumsum))
    cu_seq_lens_q = torch.tensor(cumsum, dtype=torch.int32).cuda()
    max_seqlen_q = np.max(query_lens)

    torch.cuda.synchronize()
    t0 = time.time()

    out_flash2 = flash_attn.flash_attn_varlen_func(
        q = q[0],
        k = k[0],
        v = v[0],
        cu_seqlens_q = cu_seq_lens_q,
        cu_seqlens_k = cu_seq_lens_q,
        max_seqlen_q = max_seqlen_q,
        max_seqlen_k = max_seqlen_q,
        causal = True
    )

    torch.cuda.synchronize()
    t1 = time.time()
    dt = t1 - t0

    print(i, 'fa2', dt)

    attention_mask = cu_seq_lens_q
    seq_len = q.shape[1]
    device = q.device

    torch.cuda.synchronize()
    t0 = time.time()
    
    document_causal_mask = generate_doc_mask_mod(attention_mask[None])
    attention_mask = create_block_mask(
        document_causal_mask, None, None, seq_len, seq_len, device, _compile = True)

    o = flex_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), score_mod=None, block_mask=attention_mask
    )

    torch.cuda.synchronize()
    t1 = time.time()
    dt = t1 - t0

    print(i, 'flex', dt)
    print()

0 doc cumsum [0, 304, 592, 921, 1235, 1553, 1854, 2187, 2477, 2807, 3119, 3470, 3767, 4096]
0 fa2 0.00041866302490234375
0 flex 0.0024042129516601562

1 doc cumsum [0, 313, 654, 991, 1323, 1639, 1968, 2277, 2568, 2862, 3155, 3480, 3800, 4096]
1 fa2 0.0004990100860595703
1 flex 0.0026535987854003906

2 doc cumsum [0, 240, 482, 742, 991, 1264, 1540, 1794, 2029, 2283, 2541, 2809, 3058, 3324, 3595, 3845, 4096]
2 fa2 0.0005774497985839844
2 flex 0.0028772354125976562

3 doc cumsum [0, 163, 291, 441, 610, 753, 904, 1048, 1196, 1336, 1484, 1626, 1782, 1929, 2082, 2225, 2365, 2515, 2654, 2815, 2960, 3087, 3227, 3377, 3510, 3655, 3813, 3956, 4096]
3 fa2 0.0005552768707275391
3 flex 0.0035028457641601562

4 doc cumsum [0, 359, 675, 1012, 1374, 1704, 2049, 2391, 2741, 3084, 3388, 3759, 4096]
4 fa2 0.00047397613525390625
4 flex 0.0030066967010498047

5 doc cumsum [0, 164, 328, 479, 643, 815, 980, 1131, 1266, 1421, 1564, 1709, 1860, 1993, 2122, 2269, 2416, 2563, 2729, 2880, 3018, 3181, 3333, 3485, 