In [10]:
import random
import torch

In [15]:
torch.ops._C.paged_attention_mlrd_palu_v1??

In [12]:
torch.ops._C.paged_attention_v1??

In [16]:
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything, create_kv_caches_with_random

In [17]:
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
MAX_SEQ_LEN = 16
# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
          ] if not is_hip() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

In [59]:
num_heads = (16, 8)
num_seqs, num_query_heads, head_size = 4, 16, 128
block_size = 32
kv_cache_dtype = "auto"
seed = 42

In [60]:
# seed_everything(seed)
dtype = torch.half
device = "cuda:0"
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
query.uniform_(-scale, scale)

assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
# alibi_slopes = None
# if use_alibi:
#     alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
seq_lens[-1] = MAX_SEQ_LEN
max_seq_len = max(seq_lens)
seq_lens = torch.tensor(seq_lens, dtype=torch.int)

# Create the block tables.
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables_lst = []
for _ in range(num_seqs):
    block_table = [
        random.randint(0, NUM_BLOCKS - 1)
        for _ in range(max_num_blocks_per_seq)
    ]
    block_tables_lst.append(block_table)

block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

# Create the KV caches.
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, block_size, 1,
                                                        num_kv_heads, head_size,
                                                        kv_cache_dtype, dtype, seed,
                                                        device)
key_cache, value_cache = key_caches[0], value_caches[0]

# Using default kv_scale
k_scale = v_scale = 1.0

In [61]:
key_cache.shape, value_cache.shape

(torch.Size([4321, 8, 16, 32, 8]), torch.Size([4321, 8, 128, 32]))

In [62]:
# [num_kv_heads, palu_head_size, head_size]
palu_k_up_proj = torch.randn(5, 16, 128).to(dtype=dtype, device=device)

In [63]:
palu_k_up_proj.shape

torch.Size([5, 16, 128])

In [66]:
alibi_slopes = None

In [78]:
%%timeit -n 10
output = torch.empty_like(query)
ops.paged_attention_mlrd_palu_v1(
    output,
    query,
    key_cache,
    palu_k_up_proj,
    value_cache,
    num_kv_heads,
    scale,
    block_tables,
    seq_lens,
    block_size,
    max_seq_len,
    alibi_slopes,
    kv_cache_dtype,
    k_scale,
    v_scale,
)

The slowest run took 4.40 times longer than the fastest. This could mean that an intermediate result is being cached.
75.9 μs ± 59.9 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [80]:
output

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 

In [79]:
%%timeit -n 10
output = torch.empty_like(query)
ops.paged_attention_v1(
    output,
    query,
    key_cache,
    value_cache,
    num_kv_heads,
    scale,
    block_tables,
    seq_lens,
    block_size,
    max_seq_len,
    alibi_slopes,
    kv_cache_dtype,
    k_scale,
    v_scale,
)

The slowest run took 4.25 times longer than the fastest. This could mean that an intermediate result is being cached.
74.1 μs ± 56.5 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [73]:
output

tensor([[[-2.2141e-02,  3.2616e-03, -2.3148e-02,  ..., -3.5370e-02,
          -1.2772e-02, -7.0190e-03],
         [-2.2217e-02,  3.1929e-03, -2.3148e-02,  ..., -3.5461e-02,
          -1.2871e-02, -6.9122e-03],
         [ 3.7415e-02, -3.6392e-03, -4.5898e-02,  ...,  6.8665e-05,
           4.7943e-02, -1.3275e-02],
         ...,
         [ 1.8997e-02, -1.2413e-02, -3.1982e-02,  ..., -5.6793e-02,
           4.4800e-02,  2.9755e-02],
         [ 1.3069e-02, -1.8433e-02,  4.1168e-02,  ...,  1.9741e-04,
           6.8817e-03,  2.7420e-02],
         [ 1.3123e-02, -1.8463e-02,  4.1351e-02,  ...,  1.0681e-04,
           6.9847e-03,  2.7435e-02]],

        [[-8.0688e-02,  1.5297e-02,  6.9214e-02,  ..., -5.4474e-02,
          -5.7587e-02,  4.4769e-02],
         [-8.0688e-02,  1.5297e-02,  6.9214e-02,  ..., -5.4474e-02,
          -5.7587e-02,  4.4769e-02],
         [ 7.6172e-02,  8.5876e-02, -8.7891e-02,  ...,  2.2278e-02,
          -4.7755e-04,  7.1350e-02],
         ...,
         [ 3.8574e-02,  1