In [1]:
import torch
import flashinfer

In [2]:
# Attention / paged-KV-cache geometry used throughout the demo.
num_layers = 32
num_qo_heads = 64
num_kv_heads = 16
head_dim = 128
max_num_pages = 128
page_size = 16

# Seed the global RNG so the random Q / KV tensors below are identical on every
# Restart & Run All.
torch.manual_seed(0)

# allocate 256MB workspace buffer (256 * 1024 * 1024 bytes of uint8)
workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
# "NHD" is the KV layout string handed to the wrapper; the KV cache built below
# orders its trailing dims as (page_size, num_kv_heads, head_dim) to match.
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, "NHD"
)

batch_size = 7
nnz_qo = 100
# Per the flashinfer API docs, the indptr tensors have batch_size + 1 entries and
# delimit each request's slice of the query rows / KV page indices.
qo_indptr = torch.tensor(
    [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
)
# Build the page-index table directly as int32 on the GPU (no CPU round trip).
paged_kv_indices = torch.arange(max_num_pages, dtype=torch.int32, device="cuda:0")
paged_kv_indptr = torch.tensor(
    [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
)
# 1 <= paged_kv_last_page_len <= page_size
paged_kv_last_page_len = torch.tensor(
    [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
)

# Cheap sanity checks tying the hand-written index tensors to batch_size.
# numel() reads tensor metadata only, so these do not force a GPU sync.
assert qo_indptr.numel() == batch_size + 1
assert paged_kv_indptr.numel() == batch_size + 1
assert paged_kv_last_page_len.numel() == batch_size

# Random queries for every layer, created directly on the GPU in bfloat16
# (same construction as the KV cache below).
q_at_layer = torch.randn(
    num_layers, nnz_qo, num_qo_heads, head_dim, dtype=torch.bfloat16, device="cuda:0"
)
# Random paged KV cache for every layer; the size-2 axis holds the K and V halves.
kv_cache_at_layer = torch.randn(
    num_layers, max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda:0"
)

In [3]:
# Plan once up front: the wrapper receives the batch's index tensors and the
# attention configuration, and every per-layer run() call below reuses that plan.
prefill_wrapper.plan(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    causal=False,
    pos_encoding_mode="ROPE_LLAMA",
    rope_scale=1.0,
    rope_theta=10000.0,
    q_data_type=torch.bfloat16,
)

# Run batch prefill attention layer by layer, collecting one output per layer.
outputs = []
for i in range(num_layers):
    q, kv_cache = q_at_layer[i], kv_cache_at_layer[i]
    o = prefill_wrapper.run(q, kv_cache)
    outputs.append(o)

# Show the shape of the first layer's output as the cell result.
outputs[0].shape

2025-04-16 14:32:59,062 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_1_use_swa_False_use_logits_cap_False_f16qk_False
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
2025-04-16 14:33:51,675 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_1_use_swa_False_use_logits_cap_False_f16qk_False


torch.Size([100, 64, 128])

In [4]:
# Display the attention output of the first layer (element 0 of `outputs`).
outputs[0]

tensor([[[ 0.0903,  0.1738,  0.0884,  ..., -0.0850,  0.0481, -0.0009],
         [ 0.0295,  0.0403,  0.1914,  ...,  0.0554,  0.1206, -0.0459],
         [-0.2500,  0.0981,  0.2559,  ...,  0.1895,  0.1216,  0.1406],
         ...,
         [-0.0364, -0.0089, -0.0583,  ...,  0.1318, -0.0413, -0.1611],
         [-0.0083, -0.1572, -0.0869,  ...,  0.0510,  0.0256, -0.0977],
         [-0.0251,  0.0055,  0.0630,  ...,  0.1572,  0.0187, -0.1367]],

        [[ 0.0708, -0.0203,  0.0625,  ...,  0.1157,  0.1455, -0.0522],
         [ 0.0669,  0.2061,  0.0796,  ...,  0.0442, -0.1318,  0.0918],
         [-0.0156,  0.0757,  0.0251,  ...,  0.1108,  0.0103, -0.0420],
         ...,
         [-0.1934, -0.0188, -0.0586,  ...,  0.0991,  0.1035, -0.0718],
         [-0.1128,  0.0354, -0.0425,  ...,  0.0942,  0.1777, -0.1216],
         [-0.0325, -0.0654,  0.0015,  ...,  0.0293, -0.0459, -0.1436]],

        [[-0.1055,  0.1387,  0.0439,  ...,  0.1680,  0.0525,  0.1602],
         [-0.1260,  0.0864,  0.2109,  ..., -0

# Qwen: inspect the Qwen2.5-7B-Instruct model configuration

In [3]:
from transformers import AutoConfig

# Hugging Face model id whose configuration is loaded for inspection.
qwen_model_id = "Qwen/Qwen2.5-7B-Instruct"
config = AutoConfig.from_pretrained(qwen_model_id)

config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

In [4]:
# Display the loaded configuration object as the cell result.
config

Qwen2Config {
  "_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 3584,
  "initializer_range": 0.02,
  "intermediate_size": 18944,
  "max_position_embeddings": 32768,
  "max_window_layers": 28,
  "model_type": "qwen2",
  "num_attention_heads": 28,
  "num_hidden_layers": 28,
  "num_key_value_heads": 4,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.47.0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 152064
}