diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h index 605b5c258ca57..7b84ba0a7daf1 100644 --- a/paddle/fluid/operators/layer_norm_op.h +++ b/paddle/fluid/operators/layer_norm_op.h @@ -22,6 +22,103 @@ limitations under the License. */ namespace paddle { namespace operators { +// Wrap RowwiseMean and ColwiseMean. +// Reuse the cpu codes and replace the gpu codes with cublas_gemv, which is +// significantly faster. Unlike the RowwiseMean and ColwiseMean, the +// implementation only considers 2D. +template +struct RowwiseMean2D { + RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx); + + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor* vec); +}; + +#ifdef PADDLE_WITH_CUDA +template +class RowwiseMean2D { + public: + RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx) + : left_(left), right_(right) { + framework::DDim ones_dim({right_}); + divisor_.mutable_data(ones_dim, dev_ctx.GetPlace()); + math::set_constant(dev_ctx, &divisor_, 1.0 / right); + } + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& input, framework::Tensor* out) { + math::gemv( + context, false, left_, right_, 1., input.data(), divisor_.data(), + 0., out->data()); + } + + private: + int left_; + int right_; + framework::Tensor divisor_; +}; +#endif + +template +class RowwiseMean2D { + public: + RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx) {} + + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& input, framework::Tensor* out) { + row_mean_(context, input, out); + } + + private: + math::RowwiseMean row_mean_; +}; + +template +struct ColwiseSum2D { + ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx); + + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor* vec); +}; + +#ifdef PADDLE_WITH_CUDA +template +class ColwiseSum2D { + public: + ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx) + : left_(left), right_(right) { + framework::DDim ones_dim({left_}); + divisor_.mutable_data(ones_dim, dev_ctx.GetPlace()); + math::set_constant(dev_ctx, &divisor_, 1.0); + } + + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& input, framework::Tensor* out) { + math::gemv( + context, true, left_, right_, 1., input.data(), divisor_.data(), + 0., out->data()); + } + + private: + int left_; + int right_; + framework::Tensor divisor_; +}; +#endif + +template +class ColwiseSum2D { + public: + ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx) {} + + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& input, framework::Tensor* out) { + col_wise_(context, input, out); + } + + private: + math::ColwiseSum col_wise_; +}; + template struct SubAndSquareFunctor { inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); } @@ -67,15 +164,15 @@ using DataLayout = framework::DataLayout; template class LayerNormKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext &ctx) const override { + void Compute(const framework::ExecutionContext& ctx) const override { const float epsilon = ctx.Attr("epsilon"); - auto *scale = ctx.Input("Scale"); - auto *bias = ctx.Input("Bias"); + auto* scale = ctx.Input("Scale"); + auto* bias = ctx.Input("Bias"); auto x = *ctx.Input("X"); - auto *y = ctx.Output("Y"); - auto *mean = ctx.Output("Mean"); - auto *var = ctx.Output("Variance"); + auto* y = ctx.Output("Y"); + auto* mean = ctx.Output("Mean"); + auto* var = ctx.Output("Variance"); const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); const auto x_dims = x.dims(); @@ -94,8 +191,8 @@ class LayerNormKernel : public framework::OpKernel { out.ShareDataWith(*y); out.Resize(matrix_shape); - auto &dev_ctx = ctx.template device_context(); - math::RowwiseMean row_mean; + auto& dev_ctx = ctx.template device_context(); + RowwiseMean2D row_mean(left, right, ctx.device_context()); // get mean row_mean(dev_ctx, x, mean); @@ -126,31 +223,32 @@ class LayerNormKernel : public framework::OpKernel { template class LayerNormGradKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext &ctx) const override { + void Compute(const framework::ExecutionContext& ctx) const override { const float epsilon = ctx.Attr("epsilon"); auto x = *ctx.Input("X"); - auto *y = ctx.Input("Y"); - auto *mean = ctx.Input("Mean"); - auto *var = ctx.Input("Variance"); - auto *scale = ctx.Input("Scale"); - auto *bias = ctx.Input("Bias"); + auto* y = ctx.Input("Y"); + auto* mean = ctx.Input("Mean"); + auto* var = ctx.Input("Variance"); + auto* scale = ctx.Input("Scale"); + auto* bias = ctx.Input("Bias"); auto d_y = *ctx.Input(framework::GradVarName("Y")); const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); // init output - auto *d_x = ctx.Output(framework::GradVarName("X")); - auto *d_scale = ctx.Output(framework::GradVarName("Scale")); - auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + auto* d_x = ctx.Output(framework::GradVarName("X")); + auto* d_scale = ctx.Output(framework::GradVarName("Scale")); + auto* d_bias = ctx.Output(framework::GradVarName("Bias")); - const auto &x_dims = x.dims(); + const auto& x_dims = x.dims(); auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); int left = static_cast(matrix_dim[0]); int right = static_cast(matrix_dim[1]); framework::DDim matrix_shape({left, right}); d_y.Resize(matrix_shape); - auto &dev_ctx = ctx.template device_context(); - math::ColwiseSum colwise_sum; + auto& dev_ctx = ctx.template device_context(); + ColwiseSum2D colwise_sum(left, right, + ctx.device_context()); Tensor temp; Tensor temp_norm; @@ -190,7 +288,8 @@ class LayerNormGradKernel : public framework::OpKernel { Tensor temp_vec; temp_vec.mutable_data(vec_shape, ctx.GetPlace()); - math::RowwiseMean row_mean; + RowwiseMean2D row_mean(left, right, + ctx.device_context()); if (d_scale) { // dy_dx