In [47]:
import torch
import flashinfer
from rotary_embedding_torch import RotaryEmbedding
from model.flashinfer.modeling_qwen2 import Qwen2RotaryEmbedding, apply_rotary_pos_emb
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config

In [53]:
num_qo_heads = 32
num_kv_heads = 32
head_dim = 128
max_num_pages = 128
page_size = 16
# allocate 128MB workspace buffer
workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, "NHD"
)
batch_size = 1
nnz_qo = 16
qo_indptr = torch.tensor(
    [0, nnz_qo], dtype=torch.int32, device="cuda:0"
)
paged_kv_indices = torch.arange(1).int().to("cuda:0")
paged_kv_indptr = torch.tensor(
    [0, 1], dtype=torch.int32, device="cuda:0"
)
# 1 <= paged_kv_last_page_len <= page_size
paged_kv_last_page_len = torch.tensor(
    [16], dtype=torch.int32, device="cuda:0"
)
q_at_layer = torch.randn(nnz_qo, num_qo_heads, head_dim, dtype=torch.bfloat16).to("cuda:0")
kv_cache_at_layer = torch.randn(
    max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda:0"
)

In [81]:
# create auxiliary data structures for batch prefill attention
prefill_wrapper.plan(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    causal=False,
    pos_encoding_mode='ROPE_LLAMA',
    rope_scale=1.0,
    rope_theta=1000000.0,
    q_data_type=torch.bfloat16,
)
o = prefill_wrapper.run(q_at_layer, kv_cache_at_layer)

In [82]:
rotary_emb = Qwen2RotaryEmbedding(head_dim).to("cuda:0")

In [83]:
q = q_at_layer.unsqueeze(0).transpose(1, 2)
k = kv_cache_at_layer[0, 0].unsqueeze(0).transpose(1, 2)
v = kv_cache_at_layer[0, 1].unsqueeze(0).transpose(1, 2)

In [84]:
# q, k = rotary_emb.rotate_queries_with_cached_keys(q, k)
cos, sin = rotary_emb(v, torch.arange(16, device="cuda:0").view(1, -1))
q, k = apply_rotary_pos_emb(q, k, cos, sin)

In [85]:
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()

attn_output = torch.nn.functional.scaled_dot_product_attention(
    q, k, v,
    is_causal=False,
)
attn_output = attn_output.transpose(1, 2).contiguous()

In [86]:
attn_output[0] - o

tensor([[[ 0.0000, -0.0010,  0.0020,  ...,  0.0039,  0.0000, -0.0002],
         [ 0.0010, -0.0020,  0.0010,  ..., -0.0020, -0.0039,  0.0000],
         [-0.0039,  0.0000,  0.0000,  ...,  0.0000, -0.0010,  0.0010],
         ...,
         [ 0.0000,  0.0000, -0.0005,  ..., -0.0012,  0.0005, -0.0010],
         [ 0.0000,  0.0000,  0.0005,  ..., -0.0015, -0.0005, -0.0005],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.0020,  0.0000,  0.0000,  ...,  0.0007,  0.0020,  0.0010],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0005,  0.0000],
         [ 0.0010, -0.0020,  0.0000,  ..., -0.0005,  0.0000, -0.0039],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0020,  0.0000,  0.0006,  ...,  0.0039,  0.0005, -0.0010],
         [ 0.0020,  0.0000,  0.0039,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0005,  0.0000,  ..., -0.0039,  0.0000,  0.0000],
         [ 0.0010,  0.0000,  0.0000,  ..., -0

# Qwen

In [3]:
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

In [4]:
config

Qwen2Config {
  "_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 3584,
  "initializer_range": 0.02,
  "intermediate_size": 18944,
  "max_position_embeddings": 32768,
  "max_window_layers": 28,
  "model_type": "qwen2",
  "num_attention_heads": 28,
  "num_hidden_layers": 28,
  "num_key_value_heads": 4,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.47.0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 152064
}