In [None]:
import random
import torch

In [None]:
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything, create_kv_caches_with_random

  from .autonotebook import tqdm as notebook_tqdm
2024-10-17 16:15:34,303	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [None]:
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# The natural upper bound on sequence length depends on the GPU's shared
# memory (hence on compute capability), minus 512 floats kept as a buffer:
# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# Pinned to a small value for debugging. NOTE: with block sizes of 16 or 32
# this gives a single KV-cache block per sequence, so multi-block paths are
# not exercised.
MAX_SEQ_LEN = 16
# There may not be enough GPU memory for a large NUM_BLOCKS; reduce it if the
# KV-cache allocation fails.
NUM_BLOCKS = 4321  # Arbitrary value for testing
PARTITION_SIZE = 512
# fp32 is only included on non-HIP builds (per the original note, the xformers
# flshattF / tritonflashattF ops support only fp16 and bf16).
DTYPES = ([torch.half, torch.bfloat16]
          if is_hip() else [torch.half, torch.bfloat16, torch.float])
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
# (num_query_heads, num_kv_heads) pairs. Arbitrary values for testing.
NUM_HEADS = [(40, 40), (64, 8)]

# NOTE(review): the ROCm FlashAttention forward linked below only supports
# head_size <= 128, yet this list includes 192 and 256. Those sizes presumably
# only apply to backends without that limit; confirm.
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
# Up to two visible GPUs; empty when no CUDA device is available, so no
# non-existent "cuda:N" names are produced.
CUDA_DEVICES = [f"cuda:{i}" for i in range(min(torch.cuda.device_count(), 2))]

In [None]:
# The single configuration exercised by the cells below.
num_heads = (16, 8)  # (num_query_heads, num_kv_heads)
# Derive the head counts from `num_heads` so they cannot drift apart.
num_query_heads, num_kv_heads = num_heads
num_seqs, head_size = 4, 128
block_size = 32
kv_cache_dtype = "auto"
seed = 42

In [None]:
# Build the inputs shared by both kernel calls below: query, sequence lengths,
# block tables and a single-layer paged KV cache.

# Seed the RNGs so the random sequence lengths / block tables (Python `random`)
# and the query (torch) are reproducible across kernel restarts. This assumes
# vllm.utils.seed_everything seeds both `random` and torch; confirm.
seed_everything(seed)

dtype = torch.half
device = "cuda:0"
# Tensors created below without an explicit device are placed on `device`.
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
query.uniform_(-scale, scale)

# Each KV head is shared by num_queries_per_kv query heads.
assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
# ALiBi is not exercised here. To test it, pass
# torch.randn(num_query_heads, dtype=torch.float) as alibi_slopes instead of None.

# Random lengths in [1, MAX_SEQ_LEN]. The last sequence is forced to the
# maximum so that max_seq_len == MAX_SEQ_LEN.
seq_lens_lst = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
seq_lens_lst[-1] = MAX_SEQ_LEN
max_seq_len = max(seq_lens_lst)
seq_lens = torch.tensor(seq_lens_lst, dtype=torch.int)

# Block tables: every sequence gets ceil(max_seq_len / block_size) random
# physical block ids. Ids are drawn independently and may repeat.
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables_lst = [
    [random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)]
    for _ in range(num_seqs)
]
block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

# Create the KV caches. The `1` is presumably num_layers; one cache per layer
# is returned, and layer 0 is used.
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, block_size, 1,
                                                        num_kv_heads, head_size,
                                                        kv_cache_dtype, dtype, seed,
                                                        device)
key_cache, value_cache = key_caches[0], value_caches[0]

# Use unit KV-cache scales.
k_scale = v_scale = 1.0

In [None]:
# Check the cache layouts the kernels will index:
#   value_cache: [num_blocks, num_kv_heads, head_size, block_size]
#   key_cache:   [num_blocks, num_kv_heads, head_size // x, block_size, x]
# Here x is a packing factor folded into the last dim (8 for fp16 in the saved
# run). The factor for other cache dtypes is not checked here.
assert value_cache.shape == (NUM_BLOCKS, num_kv_heads, head_size, block_size)
assert key_cache.dim() == 5
assert key_cache.shape[:2] == (NUM_BLOCKS, num_kv_heads)
assert key_cache.shape[3] == block_size
assert key_cache.shape[2] * key_cache.shape[4] == head_size
key_cache.shape, value_cache.shape

(torch.Size([4321, 8, 16, 32, 8]), torch.Size([4321, 8, 128, 32]))

In [None]:
# PALU key up-projection: one [palu_head_size, head_size] matrix per KV head,
# i.e. shape [num_kv_heads, palu_head_size, head_size], using head_size // 8
# latent dimensions. All-ones is a placeholder that smoke-tests the data path;
# it is not a realistic projection.
# NOTE(review): the key cache was created with the full head_size, not
# palu_head_size. Confirm which key layout paged_attention_mlrd_palu_v1 reads.
palu_head_size = head_size // 8
# Allocate directly in the target dtype/device instead of ones(...).to(...),
# which materializes a default-dtype tensor first.
palu_k_up_proj = torch.ones(num_kv_heads, palu_head_size, head_size,
                            dtype=dtype, device=device)

In [None]:
# Summarize instead of dumping the (truncated) tensor. Shape, dtype and device
# are what the kernel call depends on.
palu_k_up_proj.shape, palu_k_up_proj.dtype, palu_k_up_proj.device

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        ...,

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1., 

In [None]:
# ALiBi positional biases are not exercised: both kernel calls receive None.
alibi_slopes: "torch.Tensor | None" = None

In [None]:
# Pre-fill with NaN rather than torch.empty_like. empty_like returns
# uninitialized memory, which cannot be told apart from values a kernel
# actually wrote.
output = torch.full_like(query, float("nan"))

In [None]:
# No kernel has run yet, so this shows only the freshly allocated buffer's
# initial contents (uninitialized memory if it came from torch.empty_like).
output

tensor([[[ 1.5318e-05,  1.5318e-05,  1.5318e-05,  ...,  1.5318e-05,
           1.5318e-05,  1.5318e-05],
         [ 1.5318e-05,  1.5318e-05,  1.5318e-05,  ...,  1.5318e-05,
           1.5318e-05,  1.5318e-05],
         [ 1.5318e-05,  1.5318e-05,  1.5318e-05,  ...,  1.5318e-05,
           1.5318e-05,  1.5318e-05],
         ...,
         [ 1.5318e-05,  1.5318e-05,  1.5318e-05,  ...,  1.5318e-05,
           1.5318e-05,  1.5318e-05],
         [ 1.5318e-05,  1.5318e-05,  1.5318e-05,  ...,  1.5318e-05,
           1.5318e-05,  1.5318e-05],
         [ 1.5318e-05,  1.5318e-05,  1.5318e-05,  ...,  1.5318e-05,
           1.5318e-05,  1.5318e-05]],

        [[ 1.5318e-05,  1.5318e-05,  1.5318e-05,  ...,  1.5318e-05,
           1.5318e-05,  1.5318e-05],
         [ 1.5318e-05,  1.5318e-05,  1.5318e-05,  ...,  1.5318e-05,
           1.5318e-05,  1.5318e-05],
         [ 1.5318e-05,  1.5318e-05,  1.5318e-05,  ...,  1.5318e-05,
           1.5318e-05,  1.5318e-05],
         ...,
         [ 1.5318e-05,  1

In [None]:
# %%timeit -n 10
# Run the PALU paged-attention kernel on the shared inputs.
# `output` is pre-filled with NaN instead of torch.empty_like. Uninitialized
# memory may already hold zeros or plausible values, so with empty_like you
# cannot tell whether the kernel wrote its result at all.
output = torch.full_like(query, float("nan"))
ops.paged_attention_mlrd_palu_v1(
    output,
    query,
    key_cache,
    palu_k_up_proj,
    value_cache,
    num_kv_heads,
    scale,
    block_tables,
    seq_lens,
    block_size,
    max_seq_len,
    alibi_slopes,
    kv_cache_dtype,
    k_scale,
    v_scale,
)
# A surviving NaN means an element was never written, or the kernel itself
# produced NaN.
assert not torch.isnan(output).any(), (
    "paged_attention_mlrd_palu_v1 left NaN in the output "
    "(unwritten elements or NaN results)"
)

In [None]:
# Inspect the paged_attention_mlrd_palu_v1 result before the sanity check that
# follows.
output

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 

In [None]:
# Sanity-check the PALU kernel result. `not torch.all(output == 0)` alone is
# too weak: NaN/Inf (e.g. an unwritten NaN sentinel) also satisfies it.
# Check finiteness first, then check that the result is not identically zero.
assert torch.isfinite(output).all(), (
    "PALU kernel output contains NaN/Inf (possibly unwritten elements)"
)
assert output.count_nonzero() > 0, "PALU kernel output is all zeros"

AssertionError: 

In [None]:
# Shape of the PALU kernel output (expected to match `query`).
output.size()

torch.Size([4, 16, 128])

In [None]:
# Element counts for the PALU output: (non-zero, non-finite, total). These
# help distinguish a genuine all-zero result from unwritten (NaN) or garbage
# output.
(int(output.count_nonzero()), int((~torch.isfinite(output)).sum()), output.numel())

torch.Size([4, 16, 128])

In [None]:
# %%timeit -n 10
# Run the stock paged_attention_v1 kernel on the same inputs as a reference.
# `output` is pre-filled with NaN instead of torch.empty_like, so unwritten
# elements cannot hide behind leftover memory contents.
output = torch.full_like(query, float("nan"))
ops.paged_attention_v1(
    output,
    query,
    key_cache,
    value_cache,
    num_kv_heads,
    scale,
    block_tables,
    seq_lens,
    block_size,
    max_seq_len,
    alibi_slopes,
    kv_cache_dtype,
    k_scale,
    v_scale,
)
# A surviving NaN means an element was never written, or the kernel itself
# produced NaN.
assert not torch.isnan(output).any(), (
    "paged_attention_v1 left NaN in the output "
    "(unwritten elements or NaN results)"
)

In [None]:
# `output` now holds the paged_attention_v1 reference result for the same
# inputs. The earlier PALU result was overwritten when `output` was reassigned.
output

tensor([[[ 0.0054, -0.0048, -0.0015,  ...,  0.0238, -0.0044, -0.0098],
         [ 0.0054, -0.0048, -0.0015,  ...,  0.0238, -0.0044, -0.0098],
         [-0.0051,  0.0019, -0.0152,  ...,  0.0002, -0.0081,  0.0030],
         ...,
         [ 0.0134,  0.0103,  0.0134,  ..., -0.0143, -0.0222,  0.0170],
         [-0.0102, -0.0091,  0.0178,  ..., -0.0154,  0.0082, -0.0058],
         [-0.0103, -0.0092,  0.0179,  ..., -0.0153,  0.0083, -0.0057]],

        [[ 0.0057,  0.0132,  0.0216,  ...,  0.0173,  0.0064,  0.0083],
         [ 0.0057,  0.0133,  0.0216,  ...,  0.0172,  0.0063,  0.0082],
         [ 0.0148, -0.0280, -0.0303,  ...,  0.0019, -0.0021, -0.0257],
         ...,
         [ 0.0041, -0.0591, -0.0425,  ...,  0.0129,  0.0220, -0.0307],
         [ 0.0197,  0.0287, -0.0024,  ...,  0.0050,  0.0005, -0.0061],
         [ 0.0196,  0.0286, -0.0025,  ...,  0.0048,  0.0006, -0.0061]],

        [[ 0.0138,  0.0034,  0.0209,  ..., -0.0127, -0.0015,  0.0013],
         [ 0.0137,  0.0034,  0.0209,  ..., -0