|
| 1 | +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include <infer_ops.h> |
| 16 | +#include <functional> |
| 17 | +#include "paddle/extension.h" |
| 18 | +#include "paddle/phi/backends/xpu/enforce_xpu.h" |
| 19 | +#include "utility/debug.h" |
| 20 | +#include "utility/env.h" |
| 21 | + |
| 22 | +#ifndef PD_BUILD_STATIC_OP |
| 23 | +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) |
| 24 | +#endif |
| 25 | + |
| 26 | +XPU_DECLARE_BOOL(ENABLE_XVLLM_SDNN_INFER, false); |
| 27 | +namespace api = baidu::xpu::api; |
| 28 | + |
// Fused RMSNorm for XPU.
//
// Treats `x` as a 2-D matrix [m, n], where n is the product of the axes
// starting at `begin_norm_axis`, and normalizes each row:
//   out = rms_norm(x [+ bias] [+ residual]) * norm_weight
// When `residual` is present, the pre-normalization sum is additionally
// written to `residual_out` via a fused add + rms_norm device call.
//
// The quantization attrs (quant_scale/round_type/max_bound/min_bound) and
// `norm_bias` are accepted for signature compatibility only; both are
// rejected at runtime by the PD_CHECKs below.
//
// Returns {out, residual_out}.
// NOTE(review): when `residual` is absent, `residual_out` is allocated but
// never written, so it is returned with uninitialized contents — presumably
// callers ignore it in that case; confirm against call sites.
template <typename T>
std::vector<paddle::Tensor> RmsNormKernel(
    const paddle::Tensor& x,
    const paddle::optional<paddle::Tensor>& bias,
    const paddle::optional<paddle::Tensor>& residual,
    const paddle::Tensor& norm_weight,
    const paddle::optional<paddle::Tensor>& norm_bias,
    const float epsilon,
    const int begin_norm_axis,
    const float quant_scale,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound) {
  // XPU_T is the on-device representation of T (e.g. the fp16 wrapper type).
  using XPU_T = typename XPUTypeTrait<T>::Type;
  // Fetch the XPU device context for the current device.
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);

  int ret = -1;
  auto x_shape = x.shape();
  // Only the non-quantized path is implemented.
  PD_CHECK(quant_scale <= 0, "Quantization is not supported");
  PD_CHECK(begin_norm_axis > 0 && begin_norm_axis <= x_shape.size(),
           "begin_norm_axis check fail");
  PD_CHECK(norm_bias.get_ptr() == nullptr,
           "rms norm kernel don't support norm_bias");

  // m = rows (product of axes before begin_norm_axis),
  // n = row width (product of axes from begin_norm_axis on).
  int64_t m = std::accumulate(x_shape.begin(),
                              x_shape.begin() + begin_norm_axis,
                              static_cast<int64_t>(1),
                              std::multiplies<int64_t>());
  int64_t n = std::accumulate(x_shape.begin() + begin_norm_axis,
                              x_shape.end(),
                              static_cast<int64_t>(1),
                              std::multiplies<int64_t>());

  PD_CHECK(n == norm_weight.shape()[0],
           "The product from begin_norm_axis to the last axis of x must be "
           "equal to the norm_weight's shape[0]");
  if (bias.get_ptr()) {
    PD_CHECK(n == bias.get_ptr()->shape()[0],
             "The product from begin_norm_axis to the last axis of x must be "
             "equal to the bias's shape[0]");
  }

  paddle::Tensor out = paddle::empty(x_shape, x.dtype(), x.place());
  paddle::Tensor residual_out = paddle::empty(x_shape, x.dtype(), x.place());
  const XPU_T* x_data = reinterpret_cast<const XPU_T*>(x.data<T>());
  const XPU_T* norm_weight_data =
      reinterpret_cast<const XPU_T*>(norm_weight.data<T>());
  const XPU_T* bias_data =
      bias.get_ptr() ? reinterpret_cast<const XPU_T*>(bias.get_ptr()->data<T>())
                     : nullptr;
  const XPU_T* residual_data =
      residual.get_ptr()
          ? reinterpret_cast<const XPU_T*>(residual.get_ptr()->data<T>())
          : nullptr;
  XPU_T* out_data = reinterpret_cast<XPU_T*>(const_cast<T*>(out.data<T>()));
  XPU_T* residual_out_data = nullptr;
  if (residual_data) {
    residual_out_data =
        reinterpret_cast<XPU_T*>(const_cast<T*>(residual_out.data<T>()));
  }

  // Input handed to the normalization: starts as x and is redirected to the
  // bias-added buffer below when a bias is present.
  XPU_T* add_out_data = const_cast<XPU_T*>(x_data);
  if (bias_data) {
    // out = x + bias (bias broadcast across the m rows). `out` doubles as
    // scratch here and is overwritten by the norm further down.
    ret = api::broadcast_add(
        xpu_ctx->x_context(), x_data, bias_data, out_data, {m, n}, {n});
    PD_CHECK(ret == 0, "broadcast_add");
    add_out_data = out_data;
  }

  bool use_sdnn = FLAGS_ENABLE_XVLLM_SDNN_INFER;
  if (residual_data) {
    // Fused path: residual_out = add_out + residual,
    //             out          = rms_norm(residual_out) * norm_weight.
    ret = infer_ops::add_rms_layer_norm<XPU_T, XPU_T>(xpu_ctx->x_context(),
                                                      add_out_data,
                                                      residual_data,
                                                      out_data,
                                                      m,
                                                      n,
                                                      epsilon,
                                                      norm_weight_data,
                                                      nullptr,
                                                      nullptr,
                                                      residual_out_data,
                                                      nullptr,
                                                      use_sdnn);
    PD_CHECK(ret == 0, "add_rms_layer_norm");
  } else {
    // out = rms_norm(add_out) * norm_weight. When a bias was added above,
    // input and output alias the same buffer (in-place) — assumed to be
    // supported by the XPU API; TODO(review) confirm.
    ret = api::rms_layer_norm<XPU_T, XPU_T>(xpu_ctx->x_context(),
                                            add_out_data,
                                            out_data,
                                            m,
                                            n,
                                            epsilon,
                                            norm_weight_data,
                                            nullptr,
                                            nullptr,
                                            false);
    PD_CHECK(ret == 0, "rms_layer_norm");
  }

  return {out, residual_out};
}
| 132 | + |
| 133 | +std::vector<paddle::Tensor> RmsNorm( |
| 134 | + const paddle::Tensor& x, |
| 135 | + const paddle::optional<paddle::Tensor>& bias, |
| 136 | + const paddle::optional<paddle::Tensor>& residual, |
| 137 | + const paddle::Tensor& norm_weight, |
| 138 | + const paddle::optional<paddle::Tensor>& norm_bias, |
| 139 | + const float epsilon, |
| 140 | + const int begin_norm_axis, |
| 141 | + const float quant_scale, |
| 142 | + const int quant_round_type, |
| 143 | + const float quant_max_bound, |
| 144 | + const float quant_min_bound) { |
| 145 | + const auto x_type = x.dtype(); |
| 146 | + |
| 147 | +#define APPLY_RMS_NORM_KERNEL(TX) \ |
| 148 | + return RmsNormKernel<TX>(x, \ |
| 149 | + bias, \ |
| 150 | + residual, \ |
| 151 | + norm_weight, \ |
| 152 | + norm_bias, \ |
| 153 | + epsilon, \ |
| 154 | + begin_norm_axis, \ |
| 155 | + quant_scale, \ |
| 156 | + quant_round_type, \ |
| 157 | + quant_max_bound, \ |
| 158 | + quant_min_bound); |
| 159 | + |
| 160 | + if (x_type == paddle::DataType::BFLOAT16) { |
| 161 | + APPLY_RMS_NORM_KERNEL(paddle::bfloat16); |
| 162 | + } else if (x_type == paddle::DataType::FLOAT16) { |
| 163 | + APPLY_RMS_NORM_KERNEL(paddle::float16); |
| 164 | + } else if (x_type == paddle::DataType::FLOAT32) { |
| 165 | + APPLY_RMS_NORM_KERNEL(float); |
| 166 | + } else { |
| 167 | + PD_THROW("RmsNorm not support x_type=", static_cast<int>(x_type)); |
| 168 | + return {}; |
| 169 | + } |
| 170 | +#undef APPLY_RMS_NORM_KERNEL |
| 171 | +} |
| 172 | + |
| 173 | +std::vector<std::vector<int64_t>> RmsNormInferShape( |
| 174 | + const std::vector<int64_t>& x_shape, |
| 175 | + const paddle::optional<std::vector<int64_t>>& bias_shape, |
| 176 | + const paddle::optional<std::vector<int64_t>>& residual_shape, |
| 177 | + const std::vector<int64_t>& norm_weight_shape, |
| 178 | + const paddle::optional<std::vector<int64_t>>& norm_bias_shape, |
| 179 | + const float epsilon, |
| 180 | + const int begin_norm_axis, |
| 181 | + const float quant_scale, |
| 182 | + const int quant_round_type, |
| 183 | + const float quant_max_bound, |
| 184 | + const float quant_min_bound) { |
| 185 | + PD_CHECK(begin_norm_axis > 0 && begin_norm_axis <= x_shape.size(), |
| 186 | + "begin_norm_axis check fail"); |
| 187 | + int64_t m = std::accumulate(x_shape.begin(), |
| 188 | + x_shape.begin() + begin_norm_axis, |
| 189 | + static_cast<int64_t>(1), |
| 190 | + std::multiplies<int64_t>()); |
| 191 | + return {x_shape, x_shape, {m}}; |
| 192 | +} |
| 193 | + |
| 194 | +std::vector<paddle::DataType> RmsNormInferDtype( |
| 195 | + const paddle::DataType& x_dtype, |
| 196 | + const paddle::optional<paddle::DataType>& bias_dtype, |
| 197 | + const paddle::optional<paddle::DataType>& residual_dtype, |
| 198 | + const paddle::DataType& norm_weight_dtype, |
| 199 | + const paddle::optional<paddle::DataType>& norm_bias_dtype, |
| 200 | + const float epsilon, |
| 201 | + const int begin_norm_axis, |
| 202 | + const float quant_scale, |
| 203 | + const int quant_round_type, |
| 204 | + const float quant_max_bound, |
| 205 | + const float quant_min_bound) { |
| 206 | + // out, residual_out |
| 207 | + return {x_dtype, x_dtype}; |
| 208 | +} |
| 209 | + |
// Op registration. Input/attr order must match the RmsNorm kernel signature
// and the InferShape/InferDtype signatures above.
// NOTE(review): the second output name "residul_out" is misspelled (should
// be "residual_out"), but it is kept as-is here — callers bind outputs by
// this exact string, so renaming it is a breaking interface change that
// must be coordinated with all call sites.
PD_BUILD_STATIC_OP(fused_rms_norm_xpu)
    .Inputs({"x",
             paddle::Optional("bias"),
             paddle::Optional("residual"),
             "norm_weight",
             paddle::Optional("norm_bias")})
    .Outputs({"out", "residul_out"})
    .Attrs({"epsilon:float",
            "begin_norm_axis:int",
            "quant_scale:float",
            "quant_round_type:int",
            "quant_max_bound:float",
            "quant_min_bound:float"})
    .SetKernelFn(PD_KERNEL(RmsNorm))
    .SetInferShapeFn(PD_INFER_SHAPE(RmsNormInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(RmsNormInferDtype));
0 commit comments