856 changes: 856 additions & 0 deletions csrc/gpu/cache_kernels.cu

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions csrc/gpu/helper.h
@@ -155,12 +155,14 @@ class PDTraits<paddle::DataType::BFLOAT16> {
typedef paddle::bfloat16 data_t;
};

#ifndef PADDLE_WITH_HIP
template <>
class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
public:
typedef __nv_fp8_e4m3 DataType;
typedef paddle::float8_e4m3fn data_t;
};
#endif

template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
3 changes: 2 additions & 1 deletion csrc/setup_cuda.py
@@ -118,6 +118,7 @@ def get_gencode_flags():
"./gpu/speculate_decoding_kernels/ngram_match.cc",
"./gpu/speculate_decoding_kernels/speculate_save_output.cc",
"./gpu/speculate_decoding_kernels/speculate_get_output.cc",
"./gpu/cache_kernels.cu",
]
sources += find_end_files("./gpu/speculate_decoding_kernels", ".cu")
sources += find_end_files("./gpu/moe/fused_moe/cutlass_kernels/moe_gemm/", ".cu")
@@ -126,7 +127,7 @@ def get_gencode_flags():
nvcc_compile_args = gencode_flags
update_git_submodule()
nvcc_compile_args += [
"-O3",
"-O1",
"-DNDEBUG",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
4 changes: 4 additions & 0 deletions csrc/setup_hip.py
@@ -55,6 +55,9 @@ def update_git_submodule():
"./gpu/flash_attn_bwd.cc",
"./gpu/update_inputs_v2.cu",
"./gpu/set_preids_token_penalty_multi_scores.cu",
"./gpu/get_position_ids_and_mask_encoder_batch.cu",
"./gpu/fused_rotary_position_encoding.cu",
"./gpu/cache_kernels.cu",
],
extra_compile_args={
"cxx": ["-O3"],
@@ -67,6 +70,7 @@ def update_git_submodule():
"-U__HIP_NO_BFLOAT16_CONVERSIONS__",
"-U__HIP_NO_BFLOAT162_OPERATORS__",
"-U__HIP_NO_BFLOAT162_CONVERSIONS__",
"-Igpu",
"-Ithird_party/cutlass/include",
"-Ithird_party/nlohmann_json/single_include",
],
137 changes: 125 additions & 12 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -3551,18 +3551,131 @@ def compute_mla_absorb(
fmha_out_prefill = paddle.nn.functional.pad(fmha_out_prefill, (0, 192 - 128))
fmha_out_prefill = paddle.squeeze(fmha_out_prefill, axis=0)
else:
fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
query,
key,
value,
kwargs.get("cu_seqlens_q", None),
kwargs.get("cu_seqlens_k", None),
kwargs.get("max_enc_len_this_time", -1),
kwargs.get("max_enc_len_this_time", -1),
self.softmax_scale,
causal=True,
training=False,
)[0]
# if paddle.is_compiled_with_rocm():
if True:
from paddlenlp.ops.triton_ops.paged_attn import PagedAttention, compute_slot_mapping, generate_slot_mapping
"""
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
block_tables = [batch_size, max_blocks_per_seq]
"""
query_lens = kwargs.get("seq_lens_this_time", None).numpy().tolist()
query_lens = [item for sublist in query_lens for item in sublist]
# seq_lens = kwargs.get("seq_lens", None)
seq_lens = kwargs.get("seq_lens_this_time", None).numpy().tolist() # seq_lens_this_time
seq_lens = [item for sublist in seq_lens for item in sublist]
seq_lens_tensor = kwargs.get("seq_lens_this_time", None)

block_tables = kwargs.get("block_tables", None)
batch_size = block_tables.shape[0]
block_size = kwargs.get("block_size", 64)

# max_query_len = max(query_lens)
max_query_len = kwargs.get("max_input_length", None)
query_start_loc = kwargs.get("cu_seqlens_q", None)
# query_start_loc = paddle.to_tensor(list(accumulate(query_lens, initial=0)), dtype=paddle.int64)
# seq_start_loc = paddle.to_tensor(list(accumulate(seq_lens, initial=0)), dtype=paddle.int64)
# seq_lens_tensor = paddle.to_tensor(seq_lens, dtype=paddle.int64)
# context_lens_tensor = paddle.zeros([batch_size], dtype='int64')
context_lens_tensor = paddle.full([batch_size], seq_lens[0] - 1, dtype='int64')
alibi_slopes = None
sliding_window = None
kv_cache_dtype = "auto"
k_scale = paddle.to_tensor(1.0, dtype="bfloat16")
v_scale = paddle.to_tensor(1.0, dtype="bfloat16")

# slot_mapping = compute_slot_mapping(
# seq_lens=seq_lens,
# query_lens=query_lens,
# context_lens=context_lens_tensor.cpu().numpy().tolist(),
# block_tables=block_tables,
# block_size=block_size,
# sliding_window=sliding_window
# )
# slot_mapping_tensor = paddle.to_tensor(slot_mapping, dtype=paddle.int64)
# print(f"slot_mapping:{slot_mapping_tensor}")

slot_mapping_test = generate_slot_mapping(block_tables, seq_lens, block_size)

# query = query.reshape([-1, self.num_heads, self.head_dim])
# if key is not None:
# key = key.reshape([-1, self.num_heads, self.head_dim])
# value = value.reshape([-1, self.num_heads, self.head_dim])
# else:
# assert value is None

# if isinstance(caches, list):
# kv_cache_tensor = paddle.concat(caches)
# key_cache, value_cache = PagedAttention.split_kv_cache(kv_cache_tensor, self.kv_num_heads, self.head_dim)
# print(f"latent_cache:{latent_cache}")
num_blocks = kwargs.get("kv_num_blocks", None)
block_total_size = block_size * self.kv_num_heads * self.head_dim
kv_cache_tensor = paddle.zeros([2, num_blocks, block_total_size], dtype='bfloat16')
key_cache, value_cache = PagedAttention.split_kv_cache(kv_cache_tensor, self.kv_num_heads, self.head_dim)
# print(f"before key_cache:{key_cache}")
# print(f"before value_cache:{value_cache}")

PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
slot_mapping_test,
kv_cache_dtype,
k_scale,
v_scale,
)
fmha_out_prefill = PagedAttention.forward_prefix(
query=query,
key=key,
value=value,
kv_cache_dtype=kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_tables=block_tables,
query_start_loc=query_start_loc,
seq_lens_tensor=seq_lens_tensor,
context_lens=context_lens_tensor,
max_query_len=max_query_len,
alibi_slopes=alibi_slopes,
sliding_window=sliding_window,
k_scale=k_scale,
v_scale=v_scale,
)
# else:
fmha_out_prefill_test = paddle.nn.functional.flash_attention.flash_attn_unpadded(
query,
key,
value,
kwargs.get("cu_seqlens_q", None),
kwargs.get("cu_seqlens_k", None),
kwargs.get("max_enc_len_this_time", -1),
kwargs.get("max_enc_len_this_time", -1),
self.softmax_scale,
causal=True,
training=False,
)[0]

# Debug check: quantify how far the paged-attention prefill output deviates from the flash_attn_unpadded reference.
print(f"fmha_out_prefill shape:{fmha_out_prefill.shape}")
print(f"fmha_out_prefill_test shape:{fmha_out_prefill_test.shape}")

tolerance = 1e-5
total_elements = paddle.numel(fmha_out_prefill)

diff_prefill_test = paddle.abs(fmha_out_prefill - fmha_out_prefill_test)

percent_diff_test = (paddle.sum(diff_prefill_test > tolerance) / total_elements) * 100

max_diff_test = paddle.max(diff_prefill_test)
min_diff_test = paddle.min(diff_prefill_test)

print(f"fmha_out_prefill fmha_out_prefill_test diff: {percent_diff_test.numpy()}%")

print(f"fmha_out_prefill fmha_out_prefill_test max diff: {max_diff_test.numpy()}")
print(f"fmha_out_prefill fmha_out_prefill_test min diff: {min_diff_test.numpy()}")
print("-" * 50)

fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
fmha_out_prefill = fmha_out_prefill[:, :, : self.config.mla_config.v_head_dim]
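Note on the kv_cache layout described in the docstring above: the [2, num_blocks, block_size * num_kv_heads * head_size] buffer is the same shape the new paged_attn.py helpers below produce and consume. The following is a minimal sketch with made-up sizes (num_blocks, block_size, num_kv_heads, and head_size here are illustrative, not values from this PR) that mirrors split_kv_cache to show how the buffer is carved into key and value caches; it only needs paddle.

import paddle

# Illustrative sizes; the real values come from the model config and the block manager.
num_blocks, block_size, num_kv_heads, head_size = 10, 64, 8, 128

# Layout from PagedAttention.get_kv_cache_shape: [2, num_blocks, block_size * num_kv_heads * head_size]
kv_cache = paddle.zeros([2, num_blocks, block_size * num_kv_heads * head_size], dtype="bfloat16")

# Mirror split_kv_cache: the key cache packs x = 16 // element_size values of the head dim together.
x = 16 // kv_cache.element_size()  # 8 for bfloat16
key_cache = kv_cache[0].reshape([num_blocks, num_kv_heads, head_size // x, -1, x])
value_cache = kv_cache[1].reshape([num_blocks, num_kv_heads, head_size, -1])

print(key_cache.shape)    # [10, 8, 16, 64, 8] -- the -1 resolves to block_size
print(value_cache.shape)  # [10, 8, 128, 64]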
200 changes: 200 additions & 0 deletions paddlenlp/ops/triton_ops/paged_attn.py
@@ -0,0 +1,200 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/paged_attn.py

from typing import List, Optional, Tuple

import paddle
import numpy as np

from paddlenlp.ops.triton_ops.prefix_prefill import context_attention_fwd


class PagedAttention:
    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [32, 64, 80, 96, 112, 120, 128, 192, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def split_kv_cache(
        kv_cache: paddle.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        # Pack 16 bytes of the head dimension together for the key cache layout.
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.reshape([num_blocks, num_kv_heads, head_size // x, -1, x])

        value_cache = kv_cache[1]
        value_cache = value_cache.reshape([num_blocks, num_kv_heads, head_size, -1])
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: paddle.Tensor,
        value: paddle.Tensor,
        key_cache: paddle.Tensor,
        value_cache: paddle.Tensor,
        slot_mapping: paddle.Tensor,
        kv_cache_dtype: str,
        k_scale: paddle.Tensor,
        v_scale: paddle.Tensor,
    ) -> None:
        from paddlenlp_ops import reshape_and_cache

        reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            k_scale,
            v_scale,
            kv_cache_dtype,
        )

    @staticmethod
    def forward_prefix(
        query,
        key,
        value,
        kv_cache_dtype: str,
        key_cache,
        value_cache,
        block_tables,
        query_start_loc,
        seq_lens_tensor,
        context_lens,
        max_query_len: int,
        alibi_slopes,
        sliding_window: Optional[int],
        k_scale,
        v_scale,
    ):
        output = paddle.empty_like(query)
        context_attention_fwd(
            query,
            key,
            value,
            output,
            kv_cache_dtype,
            key_cache,
            value_cache,
            block_tables,
            # query_start_loc is (batch_size + 1,)
            query_start_loc,
            seq_lens_tensor,
            context_lens,
            max_query_len,
            k_scale,
            v_scale,
            alibi_slopes,
            sliding_window,
        )
        return output


def compute_slot_mapping(
    seq_lens: List[int],
    query_lens: List[int],
    context_lens: List[int],
    block_tables: List[List[int]],
    block_size: int,
    sliding_window: Optional[int],
) -> List[int]:
    """Map each token position to a flat slot index in the paged KV cache; PAD_SLOT_ID marks unmapped positions."""
    PAD_SLOT_ID = -1
    _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
    slot_mapping = []

    for i in range(len(seq_lens)):
        seq_len = seq_lens[i]
        query_len = query_lens[i]
        context_len = context_lens[i]
        if i < len(block_tables):
            block_table = block_tables[i]
        else:
            slot_mapping.append(PAD_SLOT_ID)
            continue

        is_profile_run = block_table is None
        if is_profile_run:
            slot_mapping.extend([PAD_SLOT_ID] * seq_len)
            continue

        is_prompt = query_len > 1

        start_idx = 0
        if is_prompt and sliding_window is not None:
            start_idx = max(0, query_len - sliding_window)

        padding_mask_len = max(0, start_idx - context_len)
        slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)

        range_start = max(start_idx, context_len)
        range_end = seq_len
        numel = range_end - range_start

        if numel <= 0:
            continue

        if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
            for j in range(range_start, range_end):
                block_idx = j // block_size
                if block_idx >= len(block_table):
                    slot_mapping.append(PAD_SLOT_ID)
                else:
                    block_number = block_table[block_idx]
                    slot = block_number * block_size + (j % block_size)
                    slot_mapping.append(slot)
        else:
            # Vectorized path for long ranges.
            j_indices = np.arange(range_start, range_end)
            block_indices = j_indices // block_size
            valid_mask = block_indices < len(block_table)

            block_numbers = np.where(
                valid_mask,
                np.array(block_table)[block_indices],
                -1,
            )
            offsets = j_indices % block_size
            slots = block_numbers * block_size + offsets
            slot_mapping.extend(slots.astype(np.int64).tolist())

    return slot_mapping


def generate_slot_mapping(
    block_tables: paddle.Tensor,  # [bsz, max_blocks_per_seq]
    seq_lens: List[int],
    block_size: int,
) -> paddle.Tensor:
    """Build a slot mapping for full sequences directly from a dense block-table tensor."""
    bsz, max_blocks_per_seq = block_tables.shape
    slot_mapping = []
    for bi in range(bsz):
        seq_len = seq_lens[bi]
        blocks = block_tables[bi]
        for pos in range(seq_len):
            block_idx = pos // block_size
            physical_block = blocks[block_idx] if block_idx < len(blocks) else -1
            offset = pos % block_size
            slot = physical_block * block_size + offset if physical_block >= 0 else -1
            slot_mapping.append(slot)
    return paddle.to_tensor(slot_mapping, dtype=paddle.int64)
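Both compute_slot_mapping and generate_slot_mapping above boil down to the same indexing rule: a token at position pos of a sequence is written to slot block_table[pos // block_size] * block_size + pos % block_size, with -1 (PAD_SLOT_ID) for positions that have no physical block. Below is a dependency-free sketch of that rule with made-up block tables; the function name flat_slots and the numbers are illustrative only, not part of the PR.

from typing import List

def flat_slots(block_table: List[int], seq_len: int, block_size: int) -> List[int]:
    # Same per-token rule as the helpers above: pick the physical block for the
    # position, then offset within that block; -1 marks an unmapped position.
    slots = []
    for pos in range(seq_len):
        block_idx = pos // block_size
        if block_idx >= len(block_table):
            slots.append(-1)
            continue
        slots.append(block_table[block_idx] * block_size + pos % block_size)
    return slots

# Two sequences sharing a paged cache with block_size = 4:
print(flat_slots([3, 7], seq_len=6, block_size=4))  # [12, 13, 14, 15, 28, 29]
print(flat_slots([5, 2], seq_len=3, block_size=4))  # [20, 21, 22]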