# compare original with flashinfer

In [1]:
import os
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.6;8.9"

import torch
import flashinfer

from transformers import AutoConfig, AutoTokenizer

from model.qwen25 import SpeechQwenModel, SpeechQwenForCausalLM
from model.w2v2 import SpeechEncoderW2V2RoPE, W2V2RoPECache, LayerCache
from model.patches.patch_w2v2 import patch_w2v2
from model.patches.patch_qwen25 import patch_qwen25

from model.flashinfer.engine import init_paged_kv_cache
from model.flashinfer.sqwen import SpeechQwenFastModel, SpeechQwenFastForCausalLM

[2025-04-21 19:09:39,288] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)


# flashinfer causal

In [1]:
import os
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.6;8.9"

import torch
import flashinfer

In [2]:
from transformers.models.qwen2.modeling_qwen2 import Qwen2RotaryEmbedding, apply_rotary_pos_emb, repeat_kv
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
qwen_cfg = Qwen2Config.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
rotary_emb = Qwen2RotaryEmbedding(config=qwen_cfg).to("cuda:0")

In [3]:
cqkv = torch.load('cor_qkv.pt')
nqkv = torch.load('new_qkv.pt')

  cqkv = torch.load('cor_qkv.pt')
  nqkv = torch.load('new_qkv.pt')


In [4]:
q, k, v = cqkv

In [5]:
q = q.view(q.size(0), q.size(1), qwen_cfg.num_attention_heads, -1)
k = k.view(k.size(0), k.size(1), qwen_cfg.num_key_value_heads, -1)
v = v.view(v.size(0), v.size(1), qwen_cfg.num_key_value_heads, -1)

In [234]:
q = torch.rand_like(q) * 10
k = torch.rand_like(k) * 10
v = torch.rand_like(v) * 10

In [12]:
qo_len = q.size(1)
kv_len = k.size(1)
num_qo_heads = qwen_cfg.num_attention_heads
num_kv_heads = qwen_cfg.num_key_value_heads
head_dim = qwen_cfg.hidden_size // qwen_cfg.num_attention_heads
o = flashinfer.single_prefill_with_kv_cache(
    q[0], 
    k[0], 
    v[0],
    pos_encoding_mode='ROPE_LLAMA',
    rope_scale=1.0,
    rope_theta=qwen_cfg.rope_theta,
    causal=True
)

2025-04-21 23:51:35,253 - INFO - flashinfer.jit: Loading JIT ops: single_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_head_dim_128_posenc_1_use_swa_False_use_logits_cap_False_f16qk_False
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
2025-04-21 23:51:48,857 - INFO - flashinfer.jit: Finished loading JIT ops: single_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_head_dim_128_posenc_1_use_swa_False_use_logits_cap_False_f16qk_False


In [13]:
q_ = q.transpose(1, 2)
k_ = k.transpose(1, 2)
v_ = v.transpose(1, 2)

k_ = repeat_kv(k_, num_qo_heads // num_kv_heads)
v_ = repeat_kv(v_, num_qo_heads // num_kv_heads)

position_ids = torch.arange(kv_len, device="cuda:0").view(1, -1)
cos, sin = rotary_emb(k_, position_ids)
k_, _ = apply_rotary_pos_emb(k_, k_, cos, sin)
q_position_ids = torch.arange(kv_len - qo_len, kv_len, device="cuda:0").view(1, -1)
q_cos, q_sin = rotary_emb(q_, q_position_ids)
q_, _ = apply_rotary_pos_emb(q_, q_, q_cos, q_sin)

mask = torch.tril(
    torch.full((qo_len, kv_len), True, device="cuda:0"),
    diagonal=(kv_len - qo_len),
).unsqueeze(0)

attn_output = torch.nn.functional.scaled_dot_product_attention(
    q_, k_, v_,
    attn_mask=mask,
)
attn_output = attn_output.transpose(1, 2).contiguous()

In [16]:
torch.allclose(o, attn_output, atol=1e-2, rtol=1e-2)

True

In [15]:
(o - attn_output).abs().mean()

tensor(1.9372e-05, device='cuda:0', dtype=torch.float16)

In [6]:
flashinfer.__version__

'0.2.5+cu124torch2.5'

## LLM

In [2]:
patch_qwen25()

In [3]:
dtype = torch.bfloat16

In [4]:
model = SpeechQwenForCausalLM.from_pretrained(
    "/data/user_data/siqiouya/runs/pretrained/qwen2.5-7b-instruct",
    torch_dtype=dtype,
    attn_implementation="eager",
    device_map='cuda',
).eval()

You are using a model of type qwen2 to instantiate a model of type SpeechQwen. This is not supported for all configurations of models and can yield errors.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [5]:
model_flash = SpeechQwenFastForCausalLM.from_pretrained(
    "/data/user_data/siqiouya/runs/pretrained/qwen2.5-7b-instruct",
    torch_dtype=dtype,
    attn_implementation="eager",
    device_map='cuda',
).eval()

You are using a model of type qwen2 to instantiate a model of type SpeechQwenFast. This is not supported for all configurations of models and can yield errors.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [6]:
tokenizer = AutoTokenizer.from_pretrained(
    "/data/user_data/siqiouya/runs/pretrained/qwen2.5-7b-instruct",
    padding_side="right",
    use_fast=False,
)
tokenizer.pad_token = "<|finetune_right_pad_id|>"

In [7]:
inputs = tokenizer("Hello, world! Hello, world! Hello, world! Hello, world!", return_tensors="pt").to("cuda")

In [8]:
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)

In [9]:
speech_pagetable, llm_prefill_pagetable, llm_decode_pagetable = \
    init_paged_kv_cache(
        1,
        576,
        12,
        16,
        128,
        1000,
        model.config.num_hidden_layers,
        model.config.num_key_value_heads,
        model.config.hidden_size // model.config.num_attention_heads,
        dtype=dtype,
        device_prefill='cuda:0',
        device_decode='cuda:0'
    )

In [10]:
requests = [
    {
        "input_ids": inputs.input_ids.view(-1),
        "cache": None
    }
]

In [11]:
output = super(SpeechQwenModel, model.model).forward(
    inputs_embeds=inputs_embeds,
    output_hidden_states=True,
)

In [12]:
output_flash = super(SpeechQwenFastModel, model_flash.model).forward(
    inputs_embeds=inputs_embeds.view(-1, inputs_embeds.size(-1)),
    requests=requests,
    pagetable=llm_prefill_pagetable,
    output_hidden_states=True,
)

2025-04-21 19:09:48,594 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_1_use_swa_False_use_logits_cap_False_f16qk_False
2025-04-21 19:09:48,791 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_1_use_swa_False_use_logits_cap_False_f16qk_False


In [24]:
layer_idx = -2
(output['hidden_states'][layer_idx] - output_flash[-1][layer_idx]).abs().mean()

tensor(0.2109, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)

In [16]:
layer_idx = -1
output['hidden_states'][layer_idx], output_flash[-1][layer_idx]

(tensor([[[ 0.6797, -1.2578, -2.1406,  ..., -2.7500,  0.3086, -2.0156],
          [-0.7500, -0.7891, -1.1250,  ..., -2.7031,  1.1250, -4.5312],
          [-2.5156,  1.4766, -1.0859,  ..., -0.7852,  0.5820, -0.2217],
          ...,
          [-1.1953,  0.6914, -0.1172,  ...,  1.9922,  0.2373, -1.2578],
          [-1.9219, -1.2656, -0.1934,  ..., -1.0078, -2.6250, -0.0757],
          [-1.7500,  0.3066, -0.7812,  ..., -0.5000, -0.5625, -0.5625]]],
        device='cuda:0', dtype=torch.bfloat16, grad_fn=<MulBackward0>),
 tensor([[ -0.1406,  -7.6250,  -3.8906,  ...,  -8.5625,  -0.8203, -12.4375],
         [ -1.1562,   0.1250,  -5.8125,  ...,  -7.5312,   5.0000, -13.9375],
         [ -6.7812,   5.3438,  -5.7188,  ...,  -1.2031,   2.8125,  -0.9375],
         ...,
         [ -5.5312,   5.0312,  -3.9688,  ...,   9.4375,  -0.1719,  -4.3125],
         [-10.5625,  -5.6250,  -7.1875,  ...,  -4.1250,  -8.4375,  -1.8906],
         [ -5.1875,   2.3750,  -4.5938,  ...,  -0.0312,  -0.3281,  -1.4688]],
  

In [None]:
output = model_flash.model.layers[0].self_attn.forward_vanilla(
    inputs_embeds, 
    position_ids=torch.arange(4).unsqueeze(0).to("cuda"),
)[0]



In [11]:
qo_indptr = torch.tensor([0, 4], dtype=torch.int32, device="cuda:0")
paged_kv_indptr = torch.tensor([0, 1], dtype=torch.int32, device="cuda:0")
paged_kv_indices = torch.arange(1, dtype=torch.int32, device="cuda:0")
paged_kv_last_page_len = torch.tensor([4], dtype=torch.int32, device="cuda:0")

In [12]:
output_flash = model_flash.model.layers[0].self_attn(
    inputs_embeds, 
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    llm_prefill_pagetable,
)[0]

AttributeError: 'BatchPrefillWithPagedKVCacheWrapper' object has no attribute '_cached_q_data_type'

In [23]:
((output - output_flash).abs() / output.abs())

tensor([[[0.0045, 0.0000, 0.0247,  ..., 0.0889, 0.4297, 0.1045],
         [0.4512, 0.0135, 0.0408,  ..., 0.0381, 0.1436, 0.0732],
         [0.3965, 0.0240, 0.0203,  ..., 0.0055, 2.2500, 0.0454],
         [0.1289, 0.0166, 0.0047,  ..., 0.0000, 0.0000, 0.0208]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<DivBackward0>)

In [13]:
llm_prefill_pagetable.wrapper._causal

AttributeError: 'BatchPrefillWithPagedKVCacheWrapper' object has no attribute '_causal'

## speech encoder

In [2]:
qwen_cfg = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

In [3]:
speech_batch = torch.rand(1, 15759).to(device='cuda', dtype=torch.bfloat16)

In [None]:
patch_w2v2(True)
speech_encoder_args = [
    "/compute/babel-4-1/siqiouya/wav2_vec_vox_960h_pl.pt",
    True,
    "[(1024,2,2)] * 2",
    
    48,
    572,
    4096,
    None,
    True,
    False,
]
speech_encoder = SpeechEncoderW2V2RoPE(*speech_encoder_args).to(device='cuda', dtype=torch.bfloat16).eval()

In [5]:
with torch.no_grad():
    speech_encoder.set_blocksize(48)
    cache = W2V2RoPECache(
        max_steps=speech_encoder.max_cache_size,
        layers=[LayerCache() for _ in range(speech_encoder.s_layer)]
    )
    output = speech_encoder.speech_encoder.extract_features(speech_batch, cache=cache)

In [None]:
output['x'].size()

In [None]:
speech_encoder_args = [
    "/compute/babel-4-1/siqiouya/wav2_vec_vox_960h_pl.pt",
    True,
    "[(1024,2,2)] * 2",
    
    48,
    572,
    4096,
    None,
    True,
    True,
]
speech_encoder_flash = SpeechEncoderW2V2RoPE(*speech_encoder_args).to(device='cuda', dtype=torch.bfloat16).eval()

In [8]:
speech_cfg = speech_encoder_flash.speech_encoder.cfg
speech_pagetable, llm_prefill_pagetable, llm_decode_pagetable = \
    init_paged_kv_cache(
        1,
        576,
        speech_cfg.encoder_layers,
        speech_cfg.encoder_attention_heads,
        speech_cfg.encoder_embed_dim // speech_cfg.encoder_attention_heads,
        1000,
        qwen_cfg.num_hidden_layers,
        qwen_cfg.num_key_value_heads,
        qwen_cfg.hidden_size // qwen_cfg.num_attention_heads,
        device_prefill='cuda:0',
        device_decode='cuda:0'
    )

In [9]:
requests = [
    {
        "speech": speech_batch.view(-1),
        "blocksize": 48,
        "cache": None
    }
]

In [None]:
with torch.no_grad():
    output_flash = speech_encoder_flash.speech_encoder(requests, speech_pagetable)

In [12]:
layer_results = output['layer_results']
layer_results_flash = output_flash[-1]

In [None]:
layer_idx = 6
(layer_results[layer_idx][0] - layer_results_flash[layer_idx][0]).abs() / layer_results[layer_idx][0].abs()

# flashinfer

In [1]:
import os
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.6;8.9"

import torch
import flashinfer
from rotary_embedding_torch import RotaryEmbedding
from model.flashinfer.modeling_qwen2 import Qwen2RotaryEmbedding, apply_rotary_pos_emb
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config

[2025-04-21 21:37:27,281] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [12]:
num_qo_heads = 32
num_kv_heads = 32
head_dim = 128
max_num_pages = 128
page_size = 16
# allocate 128MB workspace buffer
workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, "NHD"
)
batch_size = 1
nnz_qo = 128
qo_indptr = torch.tensor(
    [0, nnz_qo], dtype=torch.int32, device="cuda:0"
)
paged_kv_indices = torch.arange(nnz_qo // page_size).int().to("cuda:0")
paged_kv_indptr = torch.tensor(
    [0, nnz_qo // page_size], dtype=torch.int32, device="cuda:0"
)
# 1 <= paged_kv_last_page_len <= page_size
paged_kv_last_page_len = torch.tensor(
    [(nnz_qo - 1) % page_size + 1], dtype=torch.int32, device="cuda:0"
)
q_at_layer = torch.randn(nnz_qo, num_qo_heads, head_dim, dtype=torch.bfloat16).to("cuda:0")
kv_cache_at_layer = torch.randn(
    max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda:0"
)

In [13]:
qwen_cfg = Qwen2Config.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

In [14]:
# create auxiliary data structures for batch prefill attention
prefill_wrapper.plan(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    causal=True,
    pos_encoding_mode='ROPE_LLAMA',
    rope_scale=1.0,
    rope_theta=qwen_cfg.rope_theta,
    q_data_type=torch.bfloat16,
)
o = prefill_wrapper.run(q_at_layer, kv_cache_at_layer)

In [15]:
rotary_emb = Qwen2RotaryEmbedding(config=qwen_cfg).to("cuda:0")

In [23]:
kv_cache_at_layer[:nnz_qo // page_size, 0].size()

torch.Size([8, 16, 32, 128])

In [28]:
q = q_at_layer.unsqueeze(0).transpose(1, 2)
k = kv_cache_at_layer[:nnz_qo // page_size, 0].reshape(-1, num_kv_heads, head_dim).unsqueeze(0).transpose(1, 2)
v = kv_cache_at_layer[:nnz_qo // page_size, 1].reshape(-1, num_kv_heads, head_dim).unsqueeze(0).transpose(1, 2)

In [30]:
# q, k = rotary_emb.rotate_queries_with_cached_keys(q, k)
cos, sin = rotary_emb(v, torch.arange(nnz_qo, device="cuda:0").view(1, -1))
q, k = apply_rotary_pos_emb(q, k, cos, sin)

In [38]:
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()

attn_output = torch.nn.functional.scaled_dot_product_attention(
    q, k, v,
    is_causal=True,
)
attn_output = attn_output.transpose(1, 2).contiguous()

In [39]:
(attn_output[0] - o).abs().mean()

tensor(0.0006, device='cuda:0', dtype=torch.bfloat16)

# Qwen

In [None]:
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

In [None]:
config