In [None]:
import random
import torch

In [None]:
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything, create_kv_caches_with_random

  from .autonotebook import tqdm as notebook_tqdm
2024-10-18 14:33:59,953	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


### Test with random inputs

Test the Palu kernel on random inputs.

- The output should contain random-looking values.

In [None]:
# Width of a float32 element in bytes.
FLOAT32_BYTES = torch.finfo(torch.float32).bits // 8

# The shared-memory-derived limit depends on the compute capability
# (512 is subtracted as a buffer); a small fixed value is used here instead.
# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
MAX_SEQ_LEN = 16

# A large NUM_BLOCKS may exhaust GPU memory; reduce it if that happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512

# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
if is_hip():
    DTYPES = [torch.half, torch.bfloat16]
else:
    DTYPES = [torch.half, torch.bfloat16, torch.float]

NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]

# One device when exactly one GPU is visible, otherwise the first two.
CUDA_DEVICES = (
    ["cuda:0"] if torch.cuda.device_count() == 1 else ["cuda:0", "cuda:1"]
)

In [None]:
# Parameters of the single hand-picked case exercised by the first test.
seed = 42
kv_cache_dtype = "auto"
block_size = 32
num_seqs = 4
num_query_heads = 16
head_size = 128
# (query heads, KV heads); unpacked in the setup cell further down.
num_heads = (16, 8)

In [None]:
# Reduced head size used when building the Palu KV caches below:
# one eighth of the full head size.
palu_head_size = head_size // 8

In [None]:
# Seed the RNGs first so the query, sequence lengths and block tables are
# reproducible on "Restart & Run All".
# NOTE(review): assumes vllm's seed_everything seeds both Python's `random`
# and torch -- confirm against vllm.utils.
seed_everything(seed)

dtype = torch.half
device = "cuda:0"
# Tensors created below without an explicit device land on `device`.
torch.set_default_device(device)

# Attention scale 1/sqrt(head_size); passed to the kernel and also used as the
# bound of the uniform query initialisation.
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
query.uniform_(-scale, scale)

# Query heads must split evenly across the KV heads.
assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads

# ALiBi is not exercised here; `alibi_slopes` is set to None in a later cell.

# Random sequence lengths in [1, MAX_SEQ_LEN]; the last one is pinned to the
# maximum so that max_seq_len == MAX_SEQ_LEN.
seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
seq_lens[-1] = MAX_SEQ_LEN
max_seq_len = max(seq_lens)
seq_lens = torch.tensor(seq_lens, dtype=torch.int)

# Create the block tables: each sequence gets enough randomly chosen block ids
# to cover max_seq_len tokens (ceil division by block_size).
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables_lst = [
    [random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)]
    for _ in range(num_seqs)
]
block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

# Create the KV caches for a single layer, using the reduced Palu head size.
key_caches, value_caches = create_kv_caches_with_random(
    NUM_BLOCKS, block_size, 1,
    num_kv_heads, palu_head_size,
    kv_cache_dtype, dtype, seed,
    device)
key_cache, value_cache = key_caches[0], value_caches[0]

# Using default kv_scale
k_scale = v_scale = 1.0

In [None]:
# Inspect the shapes of the key and value caches created in the setup cell.
key_cache.shape, value_cache.shape

(torch.Size([4321, 8, 2, 32, 8]), torch.Size([4321, 8, 16, 32]))

In [None]:
# Key up-projection weights, one matrix per KV head:
# [num_kv_heads, palu_head_size, head_size]
# Filled with ones, and created directly in the target dtype/device using the
# shared palu_head_size rather than re-deriving head_size // 8.
palu_k_up_proj = torch.ones(num_kv_heads, palu_head_size, head_size,
                            dtype=dtype, device=device)

In [None]:
# Guard: the up-projection tensor must be laid out contiguously in memory.
# NOTE(review): presumably required because the kernel reads it via a raw
# pointer -- confirm against the op's implementation.
assert palu_k_up_proj.is_contiguous()

In [None]:
# No ALiBi slopes are passed to the kernel in this test.
alibi_slopes = None

In [None]:
# Output buffer for the kernel: [num_seqs, num_query_heads, palu_head_size].
# dtype and device are taken from `query`; without an explicit dtype,
# torch.empty falls back to the default dtype (float32), which does not match
# the float16 query.
# NOTE(review): vLLM paged-attention kernels are expected to write the output
# using the query's element type -- confirm for the Palu kernel.
output = torch.empty(num_seqs, num_query_heads, palu_head_size,
                     dtype=query.dtype, device=query.device)
# Contents are uninitialised memory until the kernel has run.
output

tensor([[[1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         [1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         [1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         ...,
         [1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         [1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         [1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00]],

        [[1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         [1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         [1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         ...,
         [1.8750e+00, 1.8750e+00, 1.8750e+00,  ..., 1.0510e-43,
          1.219

In [None]:
# %%timeit -n 10
# Run the Palu paged-attention (v1) op on the random inputs prepared above.
# No return value is captured; the result is inspected through `output` in the
# following cells, so the op presumably fills `output` in place.
ops.paged_attention_mlrd_palu_v1(
    output,
    query,
    key_cache,  # created with head size palu_head_size
    palu_k_up_proj,  # [num_kv_heads, palu_head_size, head_size]
    value_cache,  # created with head size palu_head_size
    num_kv_heads,
    scale,
    block_tables,
    seq_lens,
    block_size,
    max_seq_len,
    alibi_slopes,  # None in this test
    kv_cache_dtype,
    k_scale,
    v_scale,
)

In [None]:
# Inspect the output buffer after the kernel call.
output

tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.000

In [None]:
# Shape check: the buffer was allocated as (num_seqs, num_query_heads, palu_head_size).
output.shape

torch.Size([4, 16, 16])

In [None]:
# Inspect the randomly generated block tables and sequence lengths, plus the query shape.
block_tables, seq_lens, query.shape

(tensor([[3668],
         [ 185],
         [2209],
         [2856]], device='cuda:0', dtype=torch.int32),
 tensor([ 9,  3, 10, 16], device='cuda:0', dtype=torch.int32),
 torch.Size([4, 16, 128]))

### Test Palu paged attention against standard paged attention

In [None]:
# Minimal case: one sequence backed by a single cache block.
num_seqs = 1
num_blocks = num_seqs
# Number of query heads. This rebinds `num_heads`, which held a
# (query heads, KV heads) tuple in the first test.
num_heads = 64

num_kv_heads = 8
head_size = 128
palu_head_size = head_size // 4
# Recompute the attention scale from this cell's head_size so the test does not
# silently depend on the value left behind by the first test.
scale = float(1.0 / (head_size**0.5))
# Size of the innermost key-cache axis (see the key_cache shape below).
x = 8
block_size = 32

# The packed key layout splits each head dimension into groups of x elements.
assert palu_head_size % x == 0 and head_size % x == 0

# key cache:   [num_blocks, num_kv_heads, palu_head_size // x, block_size, x]
key_cache = torch.randn(num_blocks, num_kv_heads, palu_head_size//x, block_size, x,
                        device=device, dtype=torch.half)
# value cache: [num_blocks, num_kv_heads, palu_head_size, block_size]
value_cache = torch.randn(num_blocks, num_kv_heads, palu_head_size, block_size,
                          device=device, dtype=torch.half)

In [None]:
# Inspect the shapes of the hand-built key and value caches.
key_cache.shape, value_cache.shape

(torch.Size([1, 8, 4, 32, 8]), torch.Size([1, 8, 32, 32]))

In [None]:
# Inputs for the single-sequence case: a random fp16 query, a 1x1 block table
# pointing at block 0, and a sequence length of 4 tokens.
query = torch.randn((num_seqs, num_heads, head_size), dtype=torch.half, device=device)
block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)
seq_lens = torch.full((1,), 4, dtype=torch.int32, device=device)

In [None]:
# Random key up-projection weights: [num_kv_heads, palu_head_size, head_size].
palu_k_up_proj = torch.randn(
    (num_kv_heads, palu_head_size, head_size),
    dtype=torch.half,
    device=device,
)

In [None]:
# Remaining scalar arguments for the kernel calls in this test.
kv_cache_dtype = "auto"
alibi_slopes = None
# Matches the single entry of seq_lens defined above.
max_seq_len = 4
k_scale = 1.0
v_scale = 1.0

In [None]:
# %%timeit -n 10
# Output buffer for the Palu kernel: [num_seqs, num_heads, palu_head_size].
# dtype and device follow `query`; without an explicit dtype, torch.empty
# would allocate float32, which does not match the float16 query.
# NOTE(review): vLLM paged-attention kernels are expected to write the output
# using the query's element type -- confirm for the Palu kernel.
test_output = torch.empty(num_seqs, num_heads, palu_head_size,
                          dtype=query.dtype, device=query.device)
# Run the Palu paged-attention op; the result is read back from `test_output`.
ops.paged_attention_mlrd_palu_v1(
    test_output,
    query,
    key_cache,
    palu_k_up_proj,
    value_cache,
    num_kv_heads,
    scale,
    block_tables,
    seq_lens,
    block_size,
    max_seq_len,
    alibi_slopes,
    kv_cache_dtype,
    k_scale,
    v_scale,
)

In [None]:
# Shape check: the buffer was allocated as (num_seqs, num_heads, palu_head_size).
test_output.shape

torch.Size([1, 64, 32])

In [None]:
# FIXME: all output is zeros.
# Sanity check: evaluates to True only when every element of test_output is zero.
torch.all(test_output==0)

tensor(False, device='cuda:0')

In [None]:
# Inspect the Palu kernel result.
test_output

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 1.2741e-02,  1.4058e-09,  7.8694e-07,  ...,  1.0390e-02,
          -1.0547e-03,  1.4940e-13],
         [ 4.3596e-08,  7.5044e-13,  2.0900e-04,  ..., -3.6811e-03,
          -1.1726e-05, -1.2618e-02],
         [-1.2024e-05, -3.2733e-07,  1.4920e-06,  ...,  6.0168e-08,
          -1.8650e-12, -7.8139e-07]]], device='cuda:0')

In [None]:
# Reference path: manually up-project the compressed key cache to the full
# head size so it can be fed to the regular paged-attention op.

# key_cache layout: [num_blocks, num_kv_heads, palu_head_size//x, block_size, x]
# Move block_size in front of the two latent sub-axes, keeping them in
# (palu_head_size//x, x) order, then merge them into one palu_head_size axis.
# The merge order must mirror the (head_size//x, x) split used when packing the
# result back below; merging as (x, palu_head_size//x) would scramble the
# latent dimension.
key_cache_tmp = key_cache.permute(0, 1, 3, 2, 4).reshape(
    num_blocks, num_kv_heads, block_size, palu_head_size)
# -> [num_blocks*block_size, num_kv_heads, palu_head_size]
key_cache_tmp = key_cache_tmp.permute(0, 2, 1, 3).reshape(
    num_blocks * block_size, num_kv_heads, palu_head_size)

# bmm: num_kv_heads, num_blocks*block_size, palu_head_size @ num_kv_heads, palu_head_size, head_size
# -> num_kv_heads, num_blocks*block_size, head_size
# permute: -> num_blocks*block_size, num_kv_heads, head_size
key_cache_up = torch.bmm(key_cache_tmp.permute(1, 0, 2), palu_k_up_proj).permute(1, 0, 2)
key_cache_up = key_cache_up.reshape(num_blocks, block_size, num_kv_heads, head_size // x, x)

# to original shape: num_blocks, num_kv_heads, head_size//x, block_size, x
# permute() only returns a strided view; materialise it so the memory layout
# matches the shape.
# NOTE(review): the attention op is assumed to index the cache as a densely
# packed buffer -- confirm against the kernel.
key_cache_up = key_cache_up.permute(0, 2, 3, 1, 4).contiguous()

In [None]:
# %%timeit -n 10
# Baseline: regular paged-attention (v1) on the manually up-projected keys.
# The buffer is allocated with the query's dtype and device; without an
# explicit dtype, torch.empty would allocate float32, which does not match the
# float16 query.
# NOTE(review): upstream vLLM's paged_attention_v1 derives one head size from
# the query and applies it to keys, values and output. Here the query uses
# head_size while value_cache and base_output use palu_head_size -- confirm
# that this build's kernel supports a narrower value/output head size.
base_output = torch.empty(num_seqs, num_heads, palu_head_size,
                          dtype=query.dtype, device=query.device)
ops.paged_attention_v1(
    base_output,
    query,
    key_cache_up,
    value_cache,
    num_kv_heads,
    scale,
    block_tables,
    seq_lens,
    block_size,
    max_seq_len,
    alibi_slopes,
    kv_cache_dtype,
    k_scale,
    v_scale,
)

In [None]:
# Inspect the baseline result.
base_output

tensor([[[-1.4245e-03, -3.1543e-03, -1.5170e+00,  ...,  3.1979e-01,
          -4.6233e-10,  9.9317e-04],
         [-4.6809e-02, -3.9276e-05,  1.2437e-01,  ..., -3.6342e-03,
           1.2656e-04, -2.8240e+00],
         [ 9.7393e-05, -3.8031e-06, -4.9628e-07,  ..., -2.2337e+01,
           5.0419e-03, -5.3989e-02],
         ...,
         [-1.1876e-05, -7.1635e-03, -4.7026e-05,  ...,  3.5705e-04,
          -3.4625e-05, -7.1183e-01],
         [-5.6910e-05, -1.0234e+01, -4.1433e-04,  ..., -7.5913e-03,
          -5.5052e-06,  1.3659e-09],
         [-1.8383e-06, -5.0255e-03,  9.9629e-03,  ...,  3.5924e-13,
          -1.9391e-07,  5.5802e-06]]], device='cuda:0')

In [None]:
# Compare the Palu kernel against the baseline with a tolerance instead of
# bitwise equality: the two paths perform the fp16 key up-projection and the
# reductions in different orders, so exact equality is not expected.
# NOTE(review): the tolerances are a starting point for fp16 -- tighten or
# relax them once the kernel is known to be correct.
torch.testing.assert_close(test_output, base_output, atol=1e-2, rtol=1e-2)

AssertionError: 