From b62d0f81092c9d79cb0bc461e003c0e16173c5ce Mon Sep 17 00:00:00 2001 From: ming1753 Date: Fri, 7 Nov 2025 00:54:41 +0800 Subject: [PATCH] [Feature] Optim PaddleOCR-VL --- custom_ops/gpu_ops/cpp_extensions.cc | 15 ++ .../gpu_ops/fused_neox_rope_embedding.cu | 140 ++++++++++++++++++ custom_ops/gpu_ops/gelu_tanh.cu | 106 +++++++++++++ custom_ops/setup_ops.py | 2 + docs/best_practices/PaddleOCR-VL-0.9B.md | 34 ++--- docs/zh/best_practices/PaddleOCR-VL-0.9B.md | 32 ++-- docs/zh/usage/kunlunxin_xpu_deployment.md | 4 +- .../paddleocr_vl_processor.py | 6 +- .../models/paddleocr_vl/paddleocr_vl.py | 26 ++-- .../models/paddleocr_vl/siglip.py | 84 +++-------- .../models/paddleocr_vl/siglip_ops.py | 74 +++++++++ .../test_fused_neox_rope_embedding.py | 88 +++++++++++ tests/operators/test_gelu_tanh.py | 42 ++++++ 13 files changed, 540 insertions(+), 113 deletions(-) create mode 100644 custom_ops/gpu_ops/fused_neox_rope_embedding.cu create mode 100644 custom_ops/gpu_ops/gelu_tanh.cu create mode 100644 fastdeploy/model_executor/models/paddleocr_vl/siglip_ops.py create mode 100644 tests/operators/test_fused_neox_rope_embedding.py create mode 100644 tests/operators/test_gelu_tanh.py diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 212860d5446..08b9ab18a89 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -1046,6 +1046,15 @@ void SpeculateGetTargetLogits(const paddle::Tensor& target_logits, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& accept_num); +std::vector FusedNeoxRopeEmbedding( + const paddle::Tensor& qkv, + const paddle::Tensor& cos_emb, + const paddle::Tensor& sin_emb, + const int num_heads, + const int head_dim); + +std::vector GeluTanh(paddle::Tensor& input); + PYBIND11_MODULE(fastdeploy_ops, m) { m.def("get_expert_token_num", &GetExpertTokenNum, @@ -1631,4 +1640,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("speculate_get_target_logits", &SpeculateGetTargetLogits, "speculate_get_target_logits function"); + + m.def("fused_neox_rope_embedding", + &FusedNeoxRopeEmbedding, + "fused_neox_rope_embedding function"); + + m.def("gelu_tanh", &GeluTanh, "gelu_tanh function"); } diff --git a/custom_ops/gpu_ops/fused_neox_rope_embedding.cu b/custom_ops/gpu_ops/fused_neox_rope_embedding.cu new file mode 100644 index 00000000000..e8020ba8d81 --- /dev/null +++ b/custom_ops/gpu_ops/fused_neox_rope_embedding.cu @@ -0,0 +1,140 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" + +template +__global__ void FusedNeoxRopeEmbeddingKernel(const T *__restrict__ qkv, + const float *__restrict__ cos_emb, + const float *__restrict__ sin_emb, + T *__restrict__ q, + T *__restrict__ k, + T *__restrict__ v, + const int64_t elem_cnt, + const int num_head, + const int last_dim) { + using LoadT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadT left_vec; + LoadT right_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int half_lastdim = last_dim / 2; + const int hidden_size = num_head * half_lastdim; + const int full_hidden_size = num_head * last_dim; + const int offset = 3 * hidden_size; + for (int64_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_idx = linear_index / offset; + const int bias = linear_index % offset; + const int qkv_id = bias / hidden_size; + const int qkv_bias = bias % hidden_size; + const int hi = qkv_bias / half_lastdim; + const int h_bias = qkv_bias % half_lastdim; + const int base_idx_left = token_idx * 3 * full_hidden_size + + qkv_id * full_hidden_size + hi * last_dim + + h_bias; + const int base_idx_right = base_idx_left + half_lastdim; + const int emb_idx = token_idx * last_dim + h_bias; + const int base_split_idx_left = + token_idx * full_hidden_size + hi * last_dim + h_bias; + const int base_split_idx_right = base_split_idx_left + half_lastdim; + + // q,k,v output + T *out_p = nullptr; + if (qkv_id == 0) { + out_p = q; + } else if (qkv_id == 1) { + out_p = k; + } else { + out_p = v; + } + + Load(&qkv[base_idx_left], &left_vec); + Load(&qkv[base_idx_right], &right_vec); + // do rope + if (qkv_id < 2) { + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + float input_left = static_cast(left_vec[i]); + float input_right = static_cast(right_vec[i]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + left_vec[i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + right_vec[i] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + + int cur_idx_1 = base_split_idx_left + i; + int cur_idx_2 = base_split_idx_right + i; + } + } + Store(left_vec, &out_p[base_split_idx_left]); + Store(right_vec, &out_p[base_split_idx_right]); + } +} + +std::vector FusedNeoxRopeEmbedding( + const paddle::Tensor &qkv, + const paddle::Tensor &cos_emb, + const paddle::Tensor &sin_emb, + const int num_heads, + const int head_dim) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + const auto &qkv_dims = qkv.dims(); + const int token_num = qkv_dims.size() == 2 ? qkv_dims[0] : qkv_dims[1]; + + auto stream = qkv.stream(); + paddle::Tensor q = GetEmptyTensor( + {token_num, num_heads, head_dim}, qkv.dtype(), qkv.place()); + paddle::Tensor k = GetEmptyTensor( + {token_num, num_heads, head_dim}, qkv.dtype(), qkv.place()); + paddle::Tensor v = GetEmptyTensor( + {token_num, num_heads, head_dim}, qkv.dtype(), qkv.place()); + + int64_t elem_nums = token_num * num_heads * head_dim * 3 / 2; + constexpr int PackSize = 4; + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks<128>(pack_num, &grid_size); + + FusedNeoxRopeEmbeddingKernel + <<>>( + reinterpret_cast(qkv.data()), + cos_emb.data(), + sin_emb.data(), + reinterpret_cast(q.data()), + reinterpret_cast(k.data()), + reinterpret_cast(v.data()), + elem_nums, + num_heads, + head_dim); + return {q, k, v}; +} + +PD_BUILD_STATIC_OP(fused_neox_rope_embedding) + .Inputs({"qkv", "cos_emb", "sin_emb"}) + .Outputs({"q", "k", "v"}) + .Attrs({"num_heads: int", "head_dim: int"}) + .SetKernelFn(PD_KERNEL(FusedNeoxRopeEmbedding)); diff --git a/custom_ops/gpu_ops/gelu_tanh.cu b/custom_ops/gpu_ops/gelu_tanh.cu new file mode 100644 index 00000000000..6bf29a8434b --- /dev/null +++ b/custom_ops/gpu_ops/gelu_tanh.cu @@ -0,0 +1,106 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" + +__forceinline__ __device__ float tanh_ptx(float x) { + float y; + asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +__device__ __forceinline__ float gelu_tanh_func(const float& val) { + const float cdf = + 0.5f * (1.0f + tanh_ptx((0.7978845608028654f * + (val + 0.044715f * val * val * val)))); + return val * cdf; +} + +template +__global__ void gelu_tanh_kernel(T* __restrict__ out, + const T* __restrict__ input, + const int d) { + constexpr uint32_t kVecSize = 16 / sizeof(T); + const int64_t token_idx = blockIdx.x; + const int64_t thread_idx = threadIdx.x; + const int64_t stride = blockDim.x; + const int64_t offset = token_idx * d; + using vec_t = AlignedVector; +#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \ + (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + +#pragma unroll 1 + for (uint32_t idx = thread_idx; idx < d / kVecSize; idx += stride) { + vec_t x_vec; + Load(input + offset + idx * kVecSize, &x_vec); +#pragma unroll + for (uint32_t i = 0; i < kVecSize; ++i) { + x_vec[i] = static_cast(gelu_tanh_func(static_cast(x_vec[i]))); + } + Store(x_vec, out + token_idx * d + idx * kVecSize); + } + + const int64_t remaining_offset = d - d % (stride * kVecSize); + // process the remaining elements +#pragma unroll 1 + for (int64_t idx = thread_idx; idx < d % (stride * kVecSize); idx += stride) { + float x = static_cast(input[offset + remaining_offset + idx]); + out[token_idx * d + remaining_offset + idx] = + static_cast(gelu_tanh_func(x)); + } + +#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \ + (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +std::vector GeluTanh(paddle::Tensor& input) { + int d = input.dims()[1]; + int64_t num_tokens = input.dims()[0]; + cudaStream_t stream = input.stream(); + + paddle::Tensor output = + GetEmptyTensor(input.dims(), input.dtype(), input.place()); + + DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, { + uint32_t vec_size = 16 / sizeof(scalar_t); + cudaLaunchConfig_t config; + config.gridDim = num_tokens; + config.blockDim = std::min(d / vec_size, 1024U); + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = false; + config.numAttrs = 1; + config.attrs = attrs; + + cudaLaunchKernelEx(&config, + gelu_tanh_kernel, + output.data(), + input.data(), + d); + }); + + return {output}; +} + +PD_BUILD_STATIC_OP(gelu_tanh) + .Inputs({"input"}) + .Outputs({"output"}) + .SetKernelFn(PD_KERNEL(GeluTanh)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 4dbcb309953..0911afcb6b4 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -305,6 +305,8 @@ def find_end_files(directory, end_str): "gpu_ops/merge_prefill_decode_output.cu", "gpu_ops/limit_thinking_content_length_v1.cu", "gpu_ops/limit_thinking_content_length_v2.cu", + "gpu_ops/fused_neox_rope_embedding.cu", + "gpu_ops/gelu_tanh.cu", ] # pd_disaggregation diff --git a/docs/best_practices/PaddleOCR-VL-0.9B.md b/docs/best_practices/PaddleOCR-VL-0.9B.md index 1535484d137..d15d731b574 100644 --- a/docs/best_practices/PaddleOCR-VL-0.9B.md +++ b/docs/best_practices/PaddleOCR-VL-0.9B.md @@ -5,8 +5,8 @@ ## 1. Environment Preparation ### 1.1 Support Status Recommended Hardware Configuration: -- GPU Memory: 12GB or more -- Shared Memory: 2GB or more +- GPU Memory: 8GB or more +- Shared Memory: 4GB or more ### 1.2 Install Fastdeploy @@ -18,38 +18,38 @@ Installation process reference documentation [FastDeploy GPU Install](../get_sta ```shell python -m fastdeploy.entrypoints.openai.api_server \ --model PaddlePaddle/PaddleOCR-VL \ - --port 8180 \ - --metrics-port 8181 \ - --engine-worker-queue-port 8182 \ + --port 8185 \ + --metrics-port 8186 \ + --engine-worker-queue-port 8187 \ --max-model-len 16384 \ --max-num-batched-tokens 16384 \ - --gpu-memory-utilization 0.9 \ - --max-num-seqs 128 + --gpu-memory-utilization 0.8 \ + --max-num-seqs 256 ``` **Example 2:** Deploying a 16K Context Service on a Single RTX 4090 GPU ```shell python -m fastdeploy.entrypoints.openai.api_server \ --model PaddlePaddle/PaddleOCR-VL \ - --port 8180 \ - --metrics-port 8181 \ - --engine-worker-queue-port 8182 \ + --port 8185 \ + --metrics-port 8186 \ + --engine-worker-queue-port 8187 \ --max-model-len 16384 \ --max-num-batched-tokens 16384 \ - --gpu-memory-utilization 0.8 \ - --max-num-seqs 196 + --gpu-memory-utilization 0.7 \ + --max-num-seqs 256 ``` **Example 3:** Deploying a 16K Context Service on a Single A100 GPU ```shell python -m fastdeploy.entrypoints.openai.api_server \ --model PaddlePaddle/PaddleOCR-VL \ - --port 8180 \ - --metrics-port 8181 \ - --engine-worker-queue-port 8182 \ + --port 8185 \ + --metrics-port 8186 \ + --engine-worker-queue-port 8187 \ --max-model-len 16384 \ --max-num-batched-tokens 16384 \ - --gpu-memory-utilization 0.8 \ + --gpu-memory-utilization 0.7 \ --max-num-seqs 256 ``` @@ -71,7 +71,7 @@ An example is a set of configurations that can run stably while also delivering > **Available GPU memory ratio during initialization** - **Parameters:** `--gpu-memory-utilization` - **Description:** Controls the available GPU memory for FastDeploy service initialization. The default value is 0.9, meaning 10% of the memory is reserved for backup. -- **Recommendation:** It is recommended to use 0.8. If an "out of memory" error occurs during stress testing, you may attempt to reduce this value. +- **Recommendation:** It is recommended to use 0.7. If an "out of memory" error occurs during stress testing, you may attempt to reduce this value. #### 2.2.2 Chunked Prefill - **Parameters:** `--max-num-batched-tokens` diff --git a/docs/zh/best_practices/PaddleOCR-VL-0.9B.md b/docs/zh/best_practices/PaddleOCR-VL-0.9B.md index 1999d6f81ce..134c0aa4938 100644 --- a/docs/zh/best_practices/PaddleOCR-VL-0.9B.md +++ b/docs/zh/best_practices/PaddleOCR-VL-0.9B.md @@ -5,8 +5,8 @@ ## 一、环境准备 ### 1.1 支持情况 推荐硬件配置: -- 显存:12GB显存及以上 -- 共享内存:2G及以上 +- 显存:8GB显存及以上 +- 共享内存:4G及以上 ### 1.2 安装fastdeploy @@ -18,12 +18,12 @@ ```shell python -m fastdeploy.entrypoints.openai.api_server \ --model PaddlePaddle/PaddleOCR-VL \ - --port 8180 \ - --metrics-port 8181 \ - --engine-worker-queue-port 8182 \ + --port 8185 \ + --metrics-port 8186 \ + --engine-worker-queue-port 8187 \ --max-model-len 16384 \ --max-num-batched-tokens 16384 \ - --gpu-memory-utilization 0.9 \ + --gpu-memory-utilization 0.8 \ --max-num-seqs 128 ``` @@ -31,25 +31,25 @@ python -m fastdeploy.entrypoints.openai.api_server \ ```shell python -m fastdeploy.entrypoints.openai.api_server \ --model PaddlePaddle/PaddleOCR-VL \ - --port 8180 \ - --metrics-port 8181 \ - --engine-worker-queue-port 8182 \ + --port 8185 \ + --metrics-port 8186 \ + --engine-worker-queue-port 8187 \ --max-model-len 16384 \ --max-num-batched-tokens 16384 \ - --gpu-memory-utilization 0.8 \ - --max-num-seqs 196 + --gpu-memory-utilization 0.7 \ + --max-num-seqs 256 ``` **示例3:** A100上单卡部署16K上下文的服务 ```shell python -m fastdeploy.entrypoints.openai.api_server \ --model PaddlePaddle/PaddleOCR-VL \ - --port 8180 \ - --metrics-port 8181 \ - --engine-worker-queue-port 8182 \ + --port 8185 \ + --metrics-port 8186 \ + --engine-worker-queue-port 8187 \ --max-model-len 16384 \ --max-num-batched-tokens 16384 \ - --gpu-memory-utilization 0.8 \ + --gpu-memory-utilization 0.7 \ --max-num-seqs 256 ``` @@ -72,7 +72,7 @@ python -m fastdeploy.entrypoints.openai.api_server \ > **初始化时可用的显存比例** - **参数:** `--gpu-memory-utilization` - **用处:** 用于控制 FastDeploy 初始化服务的可用显存,默认0.9,即预留10%的显存备用。 -- **推荐:** 推荐使用0.8。如果服务压测时提示显存不足,可以尝试调低该值。 +- **推荐:** 推荐使用0.7。如果服务压测时提示显存不足,可以尝试调低该值。 #### 2.2.2 Chunked Prefill - **参数:** `--max-num-batched-tokens` diff --git a/docs/zh/usage/kunlunxin_xpu_deployment.md b/docs/zh/usage/kunlunxin_xpu_deployment.md index 7224a3b6f19..26e021b2981 100644 --- a/docs/zh/usage/kunlunxin_xpu_deployment.md +++ b/docs/zh/usage/kunlunxin_xpu_deployment.md @@ -197,7 +197,7 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \ -d '{ "messages": [ {"role": "user", "content": [ - {"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/PPOCRVL/dataset/ocr_v5_eval/handwrite_ch_rec_val/中文手写古籍_000054_crop_32.jpg"}}, + {"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg"}}, {"type": "text", "text": "OCR:"} ]} ], @@ -216,7 +216,7 @@ response = client.chat.completions.create( model="default", messages=[ {"role": "user", "content": [ - {"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/PPOCRVL/dataset/ocr_v5_eval/handwrite_ch_rec_val/中文手写古籍_000054_crop_32.jpg"}}, + {"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg"}}, {"type": "text", "text": "OCR:"} ] }, diff --git a/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py b/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py index 2e9e680c0b5..a5335fd0c39 100644 --- a/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py +++ b/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py @@ -22,6 +22,8 @@ from .process import DataProcessor +_SAMPLING_EPS = 1e-5 + class PaddleOCRVLProcessor(TextProcessor): """ @@ -61,7 +63,6 @@ def __init__( tool_parser_obj: Tool parser instance """ super().__init__(model_name_or_path, reasoning_parser_obj, tool_parser_obj) - data_processor_logger.info(f"model_name_or_path: {model_name_or_path}") processor_kwargs = self._parse_processor_kwargs(mm_processor_kwargs) self.processor = DataProcessor( @@ -252,6 +253,9 @@ def process_request_dict(self, request, max_model_len=None): if request.get("max_tokens") is None: request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"])) # Ensure at least 1 token + if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS: + request["top_p"] = _SAMPLING_EPS + return request def append_generated_tokens(self, multimodal_inputs, generated_token_ids): diff --git a/fastdeploy/model_executor/models/paddleocr_vl/paddleocr_vl.py b/fastdeploy/model_executor/models/paddleocr_vl/paddleocr_vl.py index 52359a8e31a..2517e277350 100644 --- a/fastdeploy/model_executor/models/paddleocr_vl/paddleocr_vl.py +++ b/fastdeploy/model_executor/models/paddleocr_vl/paddleocr_vl.py @@ -25,7 +25,6 @@ from paddleformers.transformers.configuration_utils import PretrainedConfig from paddleformers.utils.log import logger -from fastdeploy import envs from fastdeploy.config import FDConfig from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.graph_optimization.decorator import ( @@ -154,12 +153,8 @@ def __init__(self, fd_config): ) # Persistent buffers for CUDA graphs. - if envs.FD_ENABLE_MAX_PREFILL: - max_length = fd_config.scheduler_config.max_num_seqs * fd_config.model_config.max_model_len - else: - max_length = fd_config.model_config.max_model_len - self._input_embeddings = paddle.zeros( - [max_length, fd_config.model_config.hidden_size], + self._decoder_input_embeddings = paddle.zeros( + [fd_config.scheduler_config.max_num_seqs, fd_config.model_config.hidden_size], dtype=fd_config.model_config.dtype, ) @@ -265,12 +260,19 @@ def forward( input_embeddings = self.get_input_embeddings( ids_remove_padding=ids_remove_padding, image_features=image_features ) - self._input_embeddings.copy_(input_embeddings, False) - hidden_states = self.model( - input_embeddings=self._input_embeddings, - forward_meta=forward_meta, - ) + if forward_meta.step_use_cudagraph: + self._decoder_input_embeddings.copy_(input_embeddings, False) + + hidden_states = self.model( + input_embeddings=self._decoder_input_embeddings, + forward_meta=forward_meta, + ) + else: + hidden_states = self.model( + input_embeddings=input_embeddings, + forward_meta=forward_meta, + ) return hidden_states diff --git a/fastdeploy/model_executor/models/paddleocr_vl/siglip.py b/fastdeploy/model_executor/models/paddleocr_vl/siglip.py index 4d2b2c0bfeb..d712a8e3317 100644 --- a/fastdeploy/model_executor/models/paddleocr_vl/siglip.py +++ b/fastdeploy/model_executor/models/paddleocr_vl/siglip.py @@ -21,39 +21,13 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F -from paddleformers.transformers.activations import ACT2FN from paddleformers.transformers.model_utils import PretrainedModel from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.utils import slice_fn from .config import PaddleOCRVisionConfig - - -def rotate_half(x): - Dh = x.shape[-1] - x1 = x[..., : Dh // 2] - x2 = x[..., Dh // 2 :] - return paddle.concat([-x2, x1], axis=-1) - - -def _ensure_cos_sin_dim(cos, sin, dim_needed): - last = cos.shape[-1] - if last == dim_needed: - return cos, sin - elif last * 2 == dim_needed: - cos = paddle.concat([cos, cos], axis=-1) - sin = paddle.concat([sin, sin], axis=-1) - return cos, sin - else: - raise ValueError(f"Unexpected cos/sin last-dim: {last}, expected {dim_needed} or {dim_needed//2}") - - -def apply_rotary_pos_emb_vision(x, cos, sin): - orig_dtype = x.dtype - x = x.astype("float32") - x_embed = (x * cos) + (rotate_half(x) * sin) - return x_embed.astype(orig_dtype) +from .siglip_ops import get_activation_fn, neox_rope_embedding class SiglipAttention(nn.Layer): @@ -147,29 +121,12 @@ def forward( output_attentions: Optional[bool] = False, cu_seqlens: Optional[List[paddle.Tensor]] = None, max_seqlen: Optional[paddle.Tensor] = None, - rope_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, # (cos, sin) + cos_emb: Optional[paddle.Tensor] = None, # (cos, sin) + sin_emb: Optional[paddle.Tensor] = None, # (cos, sin) ): B, seq_length, D = hidden_states.shape - - qkv = ( - self.qkv_proj(hidden_states) - .reshape( - [ - seq_length, - 3, - self.num_heads, - -1, - ] - ) - .transpose(perm=[1, 0, 2, 3]) - ) - q, k, v = qkv.unbind(axis=0) - cos, sin = rope_emb - - # -------- - q = apply_rotary_pos_emb_vision(q, cos, sin) - k = apply_rotary_pos_emb_vision(k, cos, sin) - + qkv = self.qkv_proj(hidden_states) + q, k, v = neox_rope_embedding(qkv, cos_emb, sin_emb, self.num_heads, self.head_dim) attn_output = self.flash_attn_func( q, k, @@ -181,11 +138,9 @@ def forward( causal=False, **self.flash_attn_kwargs, )[0] - # -------- attn_output = attn_output.reshape((seq_length, -1)) attn_output = self.out_proj(attn_output) - return attn_output @@ -327,11 +282,7 @@ class SiglipMLP(nn.Layer): def __init__(self, config): super().__init__() self.config = config - if config.hidden_act == "gelu_pytorch_tanh": - config.hidden_act = "gelu_new" - - self.activation_fn = ACT2FN[config.hidden_act] - + self.activation_fn = get_activation_fn(config.hidden_act) self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc1.weight.weight_loader = self.weight_loader self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) @@ -353,7 +304,7 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) + hidden_states = self.activation_fn(hidden_states[0]) hidden_states = self.fc2(hidden_states) return hidden_states @@ -375,7 +326,8 @@ def forward( output_attentions=False, cu_seqlens=None, max_seqlen=None, - rope_emb=None, + cos_emb=None, + sin_emb=None, ): residual = hidden_states @@ -388,7 +340,8 @@ def forward( output_attentions=output_attentions, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, - rope_emb=rope_emb, + cos_emb=cos_emb, + sin_emb=sin_emb, ) hs_post_attn = residual + x @@ -545,13 +498,13 @@ def forward( rope_emb = rope_emb_max_grid[pids].flatten(1) rope_emb = rope_emb.tile((1, 2)) - cos = rope_emb.cos().astype("float32") - sin = rope_emb.sin().astype("float32") - cos = cos.unsqueeze(-2) - sin = sin.unsqueeze(-2) - rope_emb = (cos, sin) + cos_emb = rope_emb.cos().astype("float32") + sin_emb = rope_emb.sin().astype("float32") + cos_emb = cos_emb.unsqueeze(-2) + sin_emb = sin_emb.unsqueeze(-2) else: - rope_emb = None + cos_emb = None + sin_emb = None window_indices, cu_seqlens_within_windows = None, None @@ -588,7 +541,8 @@ def forward( output_attentions=output_attentions, cu_seqlens=attn_cu_seqlens, max_seqlen=max_seqlen, - rope_emb=rope_emb, + cos_emb=cos_emb, + sin_emb=sin_emb, ) hidden_states = layer_outputs[0] diff --git a/fastdeploy/model_executor/models/paddleocr_vl/siglip_ops.py b/fastdeploy/model_executor/models/paddleocr_vl/siglip_ops.py new file mode 100644 index 00000000000..d84898f656c --- /dev/null +++ b/fastdeploy/model_executor/models/paddleocr_vl/siglip_ops.py @@ -0,0 +1,74 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import List + +import paddle +from paddleformers.transformers.activations import ACT2FN + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import fused_neox_rope_embedding, gelu_tanh + + +def rotate_half(x): + Dh = x.shape[-1] + x1 = x[..., : Dh // 2] + x2 = x[..., Dh // 2 :] + return paddle.concat([-x2, x1], axis=-1) + + +def apply_rotary_pos_emb_vision(x, cos, sin): + orig_dtype = x.dtype + x = x.astype("float32") + x_embed = (x * cos) + (rotate_half(x) * sin) + return x_embed.astype(orig_dtype) + + +def native_neox_rope_embedding(qkv, cos, sin, num_heads): + B, seq_length, D = qkv.shape + qkv = qkv.reshape( + [ + seq_length, + 3, + num_heads, + -1, + ] + ).transpose(perm=[1, 0, 2, 3]) + q, k, v = qkv.unbind(axis=0) + q = apply_rotary_pos_emb_vision(q, cos, sin) + k = apply_rotary_pos_emb_vision(k, cos, sin) + return q, k, v + + +def neox_rope_embedding( + qkv: paddle.Tensor, cos_emb: paddle.Tensor, sin_emb: paddle.Tensor, num_heads: int, head_dim: int +) -> List[paddle.Tensor]: + if current_platform.is_cuda(): + return fused_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads, head_dim) + else: + return native_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads) + + +def get_activation_fn(hidden_act: str): + if hidden_act == "gelu_pytorch_tanh": + if current_platform.is_cuda(): + return gelu_tanh + else: + return ACT2FN["gelu_new"] + else: + return ACT2FN[hidden_act] diff --git a/tests/operators/test_fused_neox_rope_embedding.py b/tests/operators/test_fused_neox_rope_embedding.py new file mode 100644 index 00000000000..53a5bcf3cd9 --- /dev/null +++ b/tests/operators/test_fused_neox_rope_embedding.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.gpu import fused_neox_rope_embedding + + +def rotate_half(x): + Dh = x.shape[-1] + x1 = x[..., : Dh // 2] + x2 = x[..., Dh // 2 :] + return paddle.concat([-x2, x1], axis=-1) + + +def apply_rotary_pos_emb_vision(x, cos, sin): + orig_dtype = x.dtype + x = x.astype("float32") + x_embed = (x * cos) + (rotate_half(x) * sin) + return x_embed.astype(orig_dtype) + + +class TestFusedNeoxRopeEmbedding(unittest.TestCase): + def setUp(self): + paddle.set_device("gpu") + np.random.seed(42) + + def native_neox_rope_embedding(self, qkv, cos, sin, num_heads): + seq_length = qkv.shape[0] + qkv = qkv.reshape( + [ + seq_length, + 3, + num_heads, + -1, + ] + ).transpose(perm=[1, 0, 2, 3]) + q, k, v = qkv.unbind(axis=0) + q = apply_rotary_pos_emb_vision(q, cos, sin) + k = apply_rotary_pos_emb_vision(k, cos, sin) + return q, k, v + + def test_fused_neox_rope_embedding(self): + token_num = 1024 + hidden_size = 2048 + head_dim = 128 + num_heads = hidden_size // head_dim + qkv = paddle.randn([token_num, 3 * hidden_size]).astype("bfloat16") + cos_emb = paddle.rand([token_num, head_dim // 2]).tile((1, 2)).unsqueeze(1) + sin_emb = paddle.rand([token_num, head_dim // 2]).tile((1, 2)).unsqueeze(1) + q, k, v = fused_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads, head_dim) + q_base, k_base, v_base = self.native_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads) + np.testing.assert_allclose( + q.cast("float32").numpy(), + q_base.cast("float32").numpy(), + rtol=1e-02, + atol=1e-02, + ) + np.testing.assert_allclose( + k.cast("float32").numpy(), + k_base.cast("float32").numpy(), + rtol=1e-02, + atol=1e-02, + ) + np.testing.assert_allclose( + v.cast("float32").numpy(), + v_base.cast("float32").numpy(), + rtol=1e-02, + atol=1e-02, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/operators/test_gelu_tanh.py b/tests/operators/test_gelu_tanh.py new file mode 100644 index 00000000000..061ae6adb69 --- /dev/null +++ b/tests/operators/test_gelu_tanh.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import paddle +from paddleformers.transformers.activations import ACT2FN + +from fastdeploy.model_executor.ops.gpu import gelu_tanh + + +class TestGeluTanh(unittest.TestCase): + def setUp(self): + paddle.set_device("gpu") + np.random.seed(42) + + def test_gelu_tanh(self): + x = paddle.randn(2048, 4096) + y0 = ACT2FN["gelu_new"](x) + y1 = gelu_tanh(x) + np.testing.assert_allclose( + y0.cast("float32").numpy(), + y1.cast("float32").numpy(), + rtol=1e-04, + atol=1e-04, + ) + + +if __name__ == "__main__": + unittest.main()