15 changes: 15 additions & 0 deletions custom_ops/gpu_ops/cpp_extensions.cc
@@ -1046,6 +1046,15 @@ void SpeculateGetTargetLogits(const paddle::Tensor& target_logits,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& accept_num);

std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
const paddle::Tensor& qkv,
const paddle::Tensor& cos_emb,
const paddle::Tensor& sin_emb,
const int num_heads,
const int head_dim);

std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input);

PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num",
&GetExpertTokenNum,
@@ -1631,4 +1640,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("speculate_get_target_logits",
&SpeculateGetTargetLogits,
"speculate_get_target_logits function");

m.def("fused_neox_rope_embedding",
&FusedNeoxRopeEmbedding,
"fused_neox_rope_embedding function");

m.def("gelu_tanh", &GeluTanh, "gelu_tanh function");
}
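
The two new bindings are thin pass-throughs to the CUDA ops added below. The following is a minimal, hedged Python sketch of how they might be called once the extension is built; the `fastdeploy_ops` import path, the tensor shapes, and the packed QKV layout are assumptions inferred from the kernel source rather than documented API:

```python
# Hedged sketch, not documented API: import path, shapes, and dtypes are assumed.
import paddle
from fastdeploy_ops import fused_neox_rope_embedding, gelu_tanh  # assumed import path

token_num, num_heads, head_dim = 8, 16, 128

# Packed QKV: one row per token holding Q, K, V back to back (bfloat16, since the
# host wrapper instantiates the kernel for bfloat16 only).
qkv = paddle.randn([token_num, 3 * num_heads * head_dim]).astype("bfloat16")
# Per-token cos/sin tables in float32; only the first head_dim // 2 entries are read.
cos_emb = paddle.randn([token_num, head_dim], dtype="float32")
sin_emb = paddle.randn([token_num, head_dim], dtype="float32")

q, k, v = fused_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads, head_dim)
# q, k, v: [token_num, num_heads, head_dim]; Q and K are rotated, V is copied through.

x = paddle.randn([token_num, 4096]).astype("bfloat16")
y, = gelu_tanh(x)  # elementwise tanh-approximate GELU, same shape as x
```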
140 changes: 140 additions & 0 deletions custom_ops/gpu_ops/fused_neox_rope_embedding.cu
@@ -0,0 +1,140 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

template <typename T, int VecSize = 1>
__global__ void FusedNeoxRopeEmbeddingKernel(const T *__restrict__ qkv,
const float *__restrict__ cos_emb,
const float *__restrict__ sin_emb,
T *__restrict__ q,
T *__restrict__ k,
T *__restrict__ v,
const int64_t elem_cnt,
const int num_head,
const int last_dim) {
using LoadT = AlignedVector<T, VecSize>;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadT left_vec;
LoadT right_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const int half_lastdim = last_dim / 2;
const int hidden_size = num_head * half_lastdim;
const int full_hidden_size = num_head * last_dim;
const int offset = 3 * hidden_size;
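// Index math below assumes qkv is packed per token as [3, num_head, last_dim];
// each linear_index addresses one element in the first half of a head, and its
// rotate-half partner sits half_lastdim elements later (base_idx_right).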
for (int64_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int bias = linear_index % offset;
const int qkv_id = bias / hidden_size;
const int qkv_bias = bias % hidden_size;
const int hi = qkv_bias / half_lastdim;
const int h_bias = qkv_bias % half_lastdim;
const int base_idx_left = token_idx * 3 * full_hidden_size +
qkv_id * full_hidden_size + hi * last_dim +
h_bias;
const int base_idx_right = base_idx_left + half_lastdim;
const int emb_idx = token_idx * last_dim + h_bias;
const int base_split_idx_left =
token_idx * full_hidden_size + hi * last_dim + h_bias;
const int base_split_idx_right = base_split_idx_left + half_lastdim;

// q,k,v output
T *out_p = nullptr;
if (qkv_id == 0) {
out_p = q;
} else if (qkv_id == 1) {
out_p = k;
} else {
out_p = v;
}

Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
// do rope
if (qkv_id < 2) {
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
float input_left = static_cast<float>(left_vec[i]);
float input_right = static_cast<float>(right_vec[i]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
right_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);

}
}
Store<T, VecSize>(left_vec, &out_p[base_split_idx_left]);
Store<T, VecSize>(right_vec, &out_p[base_split_idx_right]);
}
}

std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
const paddle::Tensor &qkv,
const paddle::Tensor &cos_emb,
const paddle::Tensor &sin_emb,
const int num_heads,
const int head_dim) {
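// NOTE: the kernel is instantiated for bfloat16 only (PDTraits<BFLOAT16> below),
// while q/k/v are allocated with qkv.dtype().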
typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;

const auto &qkv_dims = qkv.dims();
const int token_num = qkv_dims.size() == 2 ? qkv_dims[0] : qkv_dims[1];

auto stream = qkv.stream();
paddle::Tensor q = GetEmptyTensor(
{token_num, num_heads, head_dim}, qkv.dtype(), qkv.place());
paddle::Tensor k = GetEmptyTensor(
{token_num, num_heads, head_dim}, qkv.dtype(), qkv.place());
paddle::Tensor v = GetEmptyTensor(
{token_num, num_heads, head_dim}, qkv.dtype(), qkv.place());

int64_t elem_nums = token_num * num_heads * head_dim * 3 / 2;
constexpr int PackSize = 4;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);

FusedNeoxRopeEmbeddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<const DataType_ *>(qkv.data<data_t>()),
cos_emb.data<float>(),
sin_emb.data<float>(),
reinterpret_cast<DataType_ *>(q.data<data_t>()),
reinterpret_cast<DataType_ *>(k.data<data_t>()),
reinterpret_cast<DataType_ *>(v.data<data_t>()),
elem_nums,
num_heads,
head_dim);
return {q, k, v};
}

PD_BUILD_STATIC_OP(fused_neox_rope_embedding)
.Inputs({"qkv", "cos_emb", "sin_emb"})
.Outputs({"q", "k", "v"})
.Attrs({"num_heads: int", "head_dim: int"})
.SetKernelFn(PD_KERNEL(FusedNeoxRopeEmbedding));
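
As a readability aid for the index math above, here is a hedged NumPy reference of the NeoX-style rotation the kernel appears to implement: each head is split into halves, Q and K are rotated with the per-token cos/sin tables, and V is passed through untouched. The unpacked layout and table shapes are illustrative assumptions, not the op's actual memory layout.

```python
# Hedged NumPy reference, assuming qkv already unpacked to
# [token_num, 3, num_heads, head_dim] and cos/sin given as [token_num, head_dim // 2]:
#   out[..., :half] = x[..., :half] * cos - x[..., half:] * sin
#   out[..., half:] = x[..., half:] * cos + x[..., :half] * sin
import numpy as np

def neox_rope_reference(qkv, cos, sin):
    half = qkv.shape[-1] // 2
    q, k, v = (qkv[:, i].astype(np.float32) for i in range(3))
    cos = cos[:, None, :]  # broadcast over heads
    sin = sin[:, None, :]

    def rotate(x):
        left, right = x[..., :half], x[..., half:]
        return np.concatenate([left * cos - right * sin,
                               right * cos + left * sin], axis=-1)

    return rotate(q), rotate(k), v  # V is returned unrotated, as in the kernel
```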
106 changes: 106 additions & 0 deletions custom_ops/gpu_ops/gelu_tanh.cu
@@ -0,0 +1,106 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

__forceinline__ __device__ float tanh_ptx(float x) {
float y;
asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
return y;
}

__device__ __forceinline__ float gelu_tanh_func(const float& val) {
const float cdf =
0.5f * (1.0f + tanh_ptx((0.7978845608028654f *
(val + 0.044715f * val * val * val))));
return val * cdf;
}

template <typename T>
__global__ void gelu_tanh_kernel(T* __restrict__ out,
const T* __restrict__ input,
const int d) {
constexpr uint32_t kVecSize = 16 / sizeof(T);
const int64_t token_idx = blockIdx.x;
const int64_t thread_idx = threadIdx.x;
const int64_t stride = blockDim.x;
const int64_t offset = token_idx * d;
using vec_t = AlignedVector<T, kVecSize>;
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
(__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif

#pragma unroll 1
for (uint32_t idx = thread_idx; idx < d / kVecSize; idx += stride) {
vec_t x_vec;
Load(input + offset + idx * kVecSize, &x_vec);
#pragma unroll
for (uint32_t i = 0; i < kVecSize; ++i) {
x_vec[i] = static_cast<T>(gelu_tanh_func(static_cast<float>(x_vec[i])));
}
Store(x_vec, out + token_idx * d + idx * kVecSize);
}

const int64_t remaining_offset = d - d % (stride * kVecSize);
// process the remaining elements
#pragma unroll 1
for (int64_t idx = thread_idx; idx < d % (stride * kVecSize); idx += stride) {
float x = static_cast<float>(input[offset + remaining_offset + idx]);
out[token_idx * d + remaining_offset + idx] =
static_cast<T>(gelu_tanh_func(x));
}

#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
(__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}

std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input) {
int d = input.dims()[1];
int64_t num_tokens = input.dims()[0];
cudaStream_t stream = input.stream();

paddle::Tensor output =
GetEmptyTensor(input.dims(), input.dtype(), input.place());

DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
uint32_t vec_size = 16 / sizeof(scalar_t);
cudaLaunchConfig_t config;
config.gridDim = num_tokens;
config.blockDim = std::min(d / vec_size, 1024U);
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = false;
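// NOTE: with this attribute left at false, the griddepcontrol.wait /
// launch_dependents PTX in the kernel is effectively a no-op; set the value
// to 1 to enable programmatic dependent launch on Hopper and newer GPUs.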
config.numAttrs = 1;
config.attrs = attrs;

cudaLaunchKernelEx(&config,
gelu_tanh_kernel<scalar_t>,
output.data<scalar_t>(),
input.data<scalar_t>(),
d);
});

return {output};
}

PD_BUILD_STATIC_OP(gelu_tanh)
.Inputs({"input"})
.Outputs({"output"})
.SetKernelFn(PD_KERNEL(GeluTanh));
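
For reference, `gelu_tanh_func` computes the standard tanh approximation of GELU, where 0.7978845608... is sqrt(2/pi). A hedged Python check against the closed form could look like this:

```python
# Hedged reference for the tanh-approximate GELU used by gelu_tanh_func:
#   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
import numpy as np

def gelu_tanh_reference(x):
    x = np.asarray(x, dtype=np.float32)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

# Tolerances against the CUDA op should be loose, since the kernel uses the
# PTX tanh.approx instruction rather than a full-precision tanh.
print(gelu_tanh_reference([-1.0, 0.0, 1.0]))
```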
2 changes: 2 additions & 0 deletions custom_ops/setup_ops.py
@@ -305,6 +305,8 @@ def find_end_files(directory, end_str):
"gpu_ops/merge_prefill_decode_output.cu",
"gpu_ops/limit_thinking_content_length_v1.cu",
"gpu_ops/limit_thinking_content_length_v2.cu",
"gpu_ops/fused_neox_rope_embedding.cu",
"gpu_ops/gelu_tanh.cu",
]

# pd_disaggregation
34 changes: 17 additions & 17 deletions docs/best_practices/PaddleOCR-VL-0.9B.md
@@ -5,8 +5,8 @@
## 1. Environment Preparation
### 1.1 Support Status
Recommended Hardware Configuration:
- GPU Memory: 12GB or more
- Shared Memory: 2GB or more
- GPU Memory: 8GB or more
- Shared Memory: 4GB or more

### 1.2 Install FastDeploy

@@ -18,38 +18,38 @@ Installation process reference documentation [FastDeploy GPU Install](../get_sta
```shell
python -m fastdeploy.entrypoints.openai.api_server \
--model PaddlePaddle/PaddleOCR-VL \
--port 8180 \
--metrics-port 8181 \
--engine-worker-queue-port 8182 \
--port 8185 \
--metrics-port 8186 \
--engine-worker-queue-port 8187 \
--max-model-len 16384 \
--max-num-batched-tokens 16384 \
--gpu-memory-utilization 0.9 \
--max-num-seqs 128
--gpu-memory-utilization 0.8 \
--max-num-seqs 256
```

**Example 2:** Deploying a 16K Context Service on a Single RTX 4090 GPU
```shell
python -m fastdeploy.entrypoints.openai.api_server \
--model PaddlePaddle/PaddleOCR-VL \
--port 8180 \
--metrics-port 8181 \
--engine-worker-queue-port 8182 \
--port 8185 \
--metrics-port 8186 \
--engine-worker-queue-port 8187 \
--max-model-len 16384 \
--max-num-batched-tokens 16384 \
--gpu-memory-utilization 0.8 \
--max-num-seqs 196
--gpu-memory-utilization 0.7 \
--max-num-seqs 256
```

**Example 3:** Deploying a 16K Context Service on a Single A100 GPU
```shell
python -m fastdeploy.entrypoints.openai.api_server \
--model PaddlePaddle/PaddleOCR-VL \
--port 8180 \
--metrics-port 8181 \
--engine-worker-queue-port 8182 \
--port 8185 \
--metrics-port 8186 \
--engine-worker-queue-port 8187 \
--max-model-len 16384 \
--max-num-batched-tokens 16384 \
--gpu-memory-utilization 0.8 \
--gpu-memory-utilization 0.7 \
--max-num-seqs 256
```

Expand All @@ -71,7 +71,7 @@ An example is a set of configurations that can run stably while also delivering
> **Available GPU memory ratio during initialization**
- **Parameters:** `--gpu-memory-utilization`
- **Description:** Controls the GPU memory available to the FastDeploy service at initialization. The default value is 0.9, meaning 10% of the memory is held back as headroom.
- **Recommendation:** It is recommended to use 0.8. If an "out of memory" error occurs during stress testing, you may attempt to reduce this value.
- **Recommendation:** It is recommended to use 0.7. If an "out of memory" error occurs during stress testing, you may attempt to reduce this value.

#### 2.2.2 Chunked Prefill
- **Parameters:** `--max-num-batched-tokens`