From edbf6c578466f7023742fe4c31c1a66735aad390 Mon Sep 17 00:00:00 2001
From: Silv3S
Date: Thu, 27 Oct 2022 12:26:37 +0200
Subject: [PATCH 1/6] init changes

---
 .../operators/mkldnn/batch_norm_mkldnn_op.cc  |  4 --
 .../phi/kernels/onednn/batch_norm_kernel.cc   | 45 +++++++++++++++++++
 2 files changed, 45 insertions(+), 4 deletions(-)
 create mode 100644 paddle/phi/kernels/onednn/batch_norm_kernel.cc

diff --git a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
index d7575f0ebf885..550be6b0d1abe 100644
--- a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
@@ -308,10 +308,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_KERNEL(batch_norm,
-                   MKLDNN,
-                   ::paddle::platform::CPUPlace,
-                   ops::BatchNormMKLDNNOpKernel<float>);
 REGISTER_OP_KERNEL(batch_norm_grad,
                    MKLDNN,
                    ::paddle::platform::CPUPlace,

diff --git a/paddle/phi/kernels/onednn/batch_norm_kernel.cc b/paddle/phi/kernels/onednn/batch_norm_kernel.cc
new file mode 100644
index 0000000000000..c01c5cf4cd2ac
--- /dev/null
+++ b/paddle/phi/kernels/onednn/batch_norm_kernel.cc
@@ -0,0 +1,45 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/batch_norm_kernel.h"
+
+#include "paddle/phi/backends/onednn/onednn_reuse.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void BatchNormKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const DenseTensor& scale,
+                     const DenseTensor& bias,
+                     const DenseTensor& mean,
+                     const DenseTensor& variance,
+                     float momentum,
+                     float epsilon,
+                     const std::string& data_layout,
+                     bool is_test,
+                     bool use_global_stats,
+                     bool trainable_statistics,
+                     bool fuse_with_relu,
+                     DenseTensor* y,
+                     DenseTensor* mean_out,
+                     DenseTensor* variance_out,
+                     DenseTensor* saved_mean,
+                     DenseTensor* saved_variance,
+                     DenseTensor* reserve_space) {}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(batch_norm, OneDNN, ONEDNN, phi::BatchNormKernel, float) {}
From e5c43573807c847987a658afa226bfa043d225a8 Mon Sep 17 00:00:00 2001
From: Silv3S
Date: Wed, 2 Nov 2022 12:55:36 +0100
Subject: [PATCH 2/6] bnorm

---
 .../mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc |   2 +-
 .../operators/mkldnn/batch_norm_mkldnn_op.cc  | 114 ------------------
 paddle/phi/backends/onednn/onednn_reuse.h     |  90 ++++++++++++++
 .../phi/kernels/onednn/batch_norm_kernel.cc   |  85 +++++++++++--
 4 files changed, 163 insertions(+), 128 deletions(-)

diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc
index e51073385bcbb..bdb2bef362be4 100644
--- a/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc
@@ -32,7 +32,7 @@ PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(gelu, CPU, ALL_LAYOUT);
 
 USE_OP_ITSELF(batch_norm);
-USE_OP_DEVICE_KERNEL(batch_norm, MKLDNN);
+PD_DECLARE_KERNEL(batch_norm, OneDNN, ONEDNN);
 USE_OP_ITSELF(conv2d_transpose);
 USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN);
 USE_OP_ITSELF(elementwise_add);

diff --git a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
index 550be6b0d1abe..4144608de4b6d 100644
--- a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
@@ -35,38 +35,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
                                    dnnl::batch_normalization_forward,
                                    dnnl::batch_normalization_backward> {
  public:
-  BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
-                         const dnnl::engine mkldnn_engine,
-                         const Tensor *x,
-                         const bool global_stats,
-                         const bool test_mode)
-      : platform::MKLDNNHandlerNoCachingT<T,
-                                          dnnl::batch_normalization_forward,
-                                          dnnl::batch_normalization_backward>(
-            mkldnn_engine, ctx.GetPlace()) {
-    const float epsilon = ctx.Attr<float>("epsilon");
-    const bool fuse_with_relu = ctx.HasAttr("fuse_with_relu")
-                                    ? ctx.Attr<bool>("fuse_with_relu")
-                                    : false;
-
-    std::vector<std::string> DataLayout_error_msg = {
-        "kNHWC", "kNCHW", "kAnyLayout", "kMKLDNN"};
-
-    // Flags are added by bitwise OR operation
-    auto flags = dnnl::normalization_flags::use_scale_shift;  // 001
-    if (global_stats)
-      flags |= dnnl::normalization_flags::use_global_stats;  // 010
-    if (fuse_with_relu && test_mode)
-      flags |= dnnl::normalization_flags::fuse_norm_relu;  // 100
-
-    this->AcquireForwardPrimitiveDescriptor(
-        global_stats == true ? dnnl::prop_kind::forward_scoring
-                             : dnnl::prop_kind::forward_training,
-        x->mem_desc(),
-        epsilon,
-        flags);
-  }
-
   BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
                          const dnnl::engine mkldnn_engine,
                          const Tensor *in_x,
@@ -157,88 +125,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
   }
 };
 
-template <typename T>
-class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto &dev_ctx =
-        ctx.template device_context<platform::MKLDNNDeviceContext>();
-    const auto &mkldnn_engine = dev_ctx.GetEngine();
-
-    const bool is_test = ctx.Attr<bool>("is_test");
-    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
-    const bool trainable_stats = ctx.Attr<bool>("trainable_statistics");
-    const bool test_mode = is_test && (!trainable_stats);
-    const bool global_stats = test_mode || use_global_stats;
-
-    const auto *x = ctx.Input<Tensor>("X");
-    const auto *scale = ctx.Input<Tensor>("Scale");
-    const auto *shift = ctx.Input<Tensor>("Bias");
-
-    auto *y = ctx.Output<Tensor>("Y");
-    auto *batch_mean = ctx.Output<Tensor>("SavedMean");
-    auto *batch_variance = ctx.Output<Tensor>("SavedVariance");
-    BatchNormMKLDNNHandler<T> handler(
-        ctx, mkldnn_engine, x, global_stats, test_mode);
-
-    auto src_memory = handler.AcquireSrcMemory(x);
-    auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift);
-    auto dst_memory = handler.AcquireDstMemory(y);
-    auto batch_norm_p = handler.AcquireForwardPrimitive();
-
-    std::shared_ptr<dnnl::memory> mean_memory;
-    std::shared_ptr<dnnl::memory> variance_memory;
-
-    if (global_stats) {
-      // mean and variance are taken from input Tensor
-      const auto *mean = ctx.Input<Tensor>("Mean");
-      const auto *variance = ctx.Input<Tensor>("Variance");
-
-      mean_memory = handler.AcquireMeanMemory(mean);
-      variance_memory = handler.AcquireVarianceMemory(variance);
-    } else {
-      // mean and variance are calculated and saved in output Tensor
-      mean_memory = handler.AcquireMeanMemory(batch_mean);
-      variance_memory = handler.AcquireVarianceMemory(batch_variance);
-    }
-
-    y->set_mem_desc(dst_memory->get_desc());
-
-    auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    batch_norm_p->execute(astream,
-                          {{DNNL_ARG_SRC, *src_memory},
-                           {DNNL_ARG_SCALE_SHIFT, *scaleshift_memory},
-                           {DNNL_ARG_MEAN, *mean_memory},
-                           {DNNL_ARG_VARIANCE, *variance_memory},
-                           {DNNL_ARG_DST, *dst_memory}});
-    astream.wait();
-
-    if (!global_stats) {
-      auto *mean_out = ctx.Output<Tensor>("MeanOut");
-      auto *variance_out = ctx.Output<Tensor>("VarianceOut");
-      const float momentum = ctx.Attr<float>("momentum");
-
-      const unsigned int C = phi::vectorize(scale->dims())[0];
-
-      // mkldnn only compute stats for current batch
-      // so we need compute momentum stats via Eigen lib
-      EigenVectorArrayMap<T> batch_mean_e(
-          batch_mean->mutable_data<T>(ctx.GetPlace()), C);
-      EigenVectorArrayMap<T> batch_variance_e(
-          batch_variance->mutable_data<T>(ctx.GetPlace()), C);
-
-      EigenVectorArrayMap<T> running_mean_e(
-          mean_out->mutable_data<T>(ctx.GetPlace()), C);
-      EigenVectorArrayMap<T> running_variance_e(
-          variance_out->mutable_data<T>(ctx.GetPlace()), C);
-
-      running_mean_e =
-          running_mean_e * momentum + batch_mean_e * (1. - momentum);
-      running_variance_e =
-          running_variance_e * momentum + batch_variance_e * (1. - momentum);
-    }
-  }
-};
-
 template <typename T>
 class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
  public:

diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h
index 4a28c4262f32d..8295b110754cc 100644
--- a/paddle/phi/backends/onednn/onednn_reuse.h
+++ b/paddle/phi/backends/onednn/onednn_reuse.h
@@ -1147,5 +1147,95 @@ class ClipOneDNNHandler
   }
 };
 
+template <typename T>
+class BatchNormOneDNNHandler
+    : public OneDNNHandlerNoCachingT<T,
+                                     dnnl::batch_normalization_forward,
+                                     dnnl::batch_normalization_backward> {
+ public:
+  BatchNormOneDNNHandler(const dnnl::engine engine,
+                         Place cpu_place,
+                         const DenseTensor* x,
+                         const float epsilon,
+                         const float fuse_with_relu,
+                         const bool global_stats,
+                         const bool test_mode)
+      : OneDNNHandlerNoCachingT<T,
+                                dnnl::batch_normalization_forward,
+                                dnnl::batch_normalization_backward>(engine,
+                                                                    cpu_place) {
+    // Flags are added by bitwise OR operation
+    auto flags = dnnl::normalization_flags::use_scale_shift;  // 001
+    if (global_stats)
+      flags |= dnnl::normalization_flags::use_global_stats;  // 010
+    if (fuse_with_relu && test_mode)
+      flags |= dnnl::normalization_flags::fuse_norm_relu;  // 100
+
+    this->AcquireForwardPrimitiveDescriptor(
+        global_stats ? dnnl::prop_kind::forward_scoring
+                     : dnnl::prop_kind::forward_training,
+        x->mem_desc(),
+        epsilon,
+        flags);
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(
+      const DenseTensor* scale, const DenseTensor* shift) {
+    auto scale_tz = phi::vectorize(scale->dims());
+    const unsigned int C = scale_tz[0];
+    PADDLE_ENFORCE_EQ(
+        scale_tz.size(),
+        1,
+        phi::errors::InvalidArgument(
+            "Dims of scale tensor must be 1, but received scale's size is %d",
+            scale_tz.size()));
+
+    auto scaleshift_memory =
+        this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
+
+    // MKLDNN requires a single piece of memory for scale and shift/bias data
+    auto mem_data_handle =
+        reinterpret_cast<T*>(scaleshift_memory->get_data_handle());
+    std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
+    std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C);
+    return scaleshift_memory;
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireDiffScaleShiftMemory(
+      T* diff_scaleshift_data) {
+    return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(),
+                                            diff_scaleshift_data);
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireMeanMemory(
+      const phi::DenseTensor* mean) {
+    const T* mean_data = mean->data<T>();
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
+                                            to_void_cast<T>(mean_data));
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireMeanMemory(phi::DenseTensor* mean) {
+    T* mean_data = mean->mutable_data<T>(
+        this->place_, this->fwd_pd_->mean_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
+                                            mean_data);
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
+      const phi::DenseTensor* variance) {
+    const T* variance_data = variance->data<T>();
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
+                                            to_void_cast<T>(variance_data));
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
+      phi::DenseTensor* variance) {
+    T* variance_data = variance->mutable_data<T>(
+        this->place_, this->fwd_pd_->variance_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
+                                            variance_data);
+  }
+};
+
 }  // namespace funcs
 }  // namespace phi

diff --git a/paddle/phi/kernels/onednn/batch_norm_kernel.cc b/paddle/phi/kernels/onednn/batch_norm_kernel.cc
index c01c5cf4cd2ac..ca68375132618 100644
--- a/paddle/phi/kernels/onednn/batch_norm_kernel.cc
+++ b/paddle/phi/kernels/onednn/batch_norm_kernel.cc
@@ -20,25 +20,84 @@
 namespace phi {
 
 template <typename T, typename Context>
-void BatchNormKernel(const Context& dev_ctx,
-                     const DenseTensor& x,
-                     const DenseTensor& scale,
-                     const DenseTensor& bias,
-                     const DenseTensor& mean,
-                     const DenseTensor& variance,
+void BatchNormKernel(const Context &dev_ctx,
+                     const DenseTensor &x,
+                     const DenseTensor &scale,
+                     const DenseTensor &bias,
+                     const DenseTensor &mean,
+                     const DenseTensor &variance,
                      float momentum,
                      float epsilon,
-                     const std::string& data_layout,
+                     const std::string &data_layout,
                      bool is_test,
                      bool use_global_stats,
                      bool trainable_statistics,
                      bool fuse_with_relu,
-                     DenseTensor* y,
-                     DenseTensor* mean_out,
-                     DenseTensor* variance_out,
-                     DenseTensor* saved_mean,
-                     DenseTensor* saved_variance,
-                     DenseTensor* reserve_space) {}
+                     DenseTensor *y,
+                     DenseTensor *mean_out,
+                     DenseTensor *variance_out,
+                     DenseTensor *saved_mean,
+                     DenseTensor *saved_variance,
+                     DenseTensor *reserve_space) {
+  const bool test_mode = is_test && (!trainable_statistics);
+  const bool global_stats = test_mode || use_global_stats;
+
+  funcs::BatchNormOneDNNHandler<T> handler(dev_ctx.GetEngine(),
+                                           dev_ctx.GetPlace(),
+                                           &x,
+                                           epsilon,
+                                           global_stats,
+                                           fuse_with_relu,
+                                           test_mode);
+
+  auto src_memory = handler.AcquireSrcMemory(&x);
+  auto scaleshift_memory = handler.AcquireScaleShiftMemory(&scale, &bias);
+  auto dst_memory = handler.AcquireDstMemory(y);
+  auto batch_norm_p = handler.AcquireForwardPrimitive();
+
+  std::shared_ptr<dnnl::memory> mean_memory;
+  std::shared_ptr<dnnl::memory> variance_memory;
+
+  // mean and variance can be taken either from input or output Tensor
+  if (global_stats) {
+    mean_memory = handler.AcquireMeanMemory(&mean);
+    variance_memory = handler.AcquireVarianceMemory(&variance);
+  } else {
+    mean_memory = handler.AcquireMeanMemory(saved_mean);
+    variance_memory = handler.AcquireVarianceMemory(saved_variance);
+  }
+
+  y->set_mem_desc(dst_memory->get_desc());
+
+  auto &astream = OneDNNContext::tls().get_stream();
+  batch_norm_p->execute(astream,
+                        {{DNNL_ARG_SRC, *src_memory},
+                         {DNNL_ARG_SCALE_SHIFT, *scaleshift_memory},
+                         {DNNL_ARG_MEAN, *mean_memory},
+                         {DNNL_ARG_VARIANCE, *variance_memory},
+                         {DNNL_ARG_DST, *dst_memory}});
+  astream.wait();
+
+  if (!global_stats) {
+    const unsigned int C = phi::vectorize(scale.dims())[0];
+
+    // mkldnn only compute stats for current batch
+    // so we need compute momentum stats via Eigen lib
+    EigenVectorArrayMap<T> batch_mean_e(dev_ctx.template Alloc<T>(scale_mean),
+                                        C);
+    EigenVectorArrayMap<T> batch_variance_e(
+        dev_ctx.template Alloc<T>(saved_variance), C);
+
+    EigenVectorArrayMap<T> running_mean_e(dev_ctx.template Alloc<T>(mean_out),
+                                          C);
+    EigenVectorArrayMap<T> running_variance_e(
+        dev_ctx.template Alloc<T>(variance_out), C);
+
+    running_mean_e = running_mean_e * momentum + batch_mean_e * (1. - momentum);
+    running_variance_e =
+        running_variance_e * momentum + batch_variance_e * (1. - momentum);
+  }
+}
 
 }  // namespace phi
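oneDNN returns only the current batch's statistics, so the kernel applies the exponential moving average itself through the Eigen maps above. A standalone sketch of that per-channel update (plain C++ with illustrative values, not Paddle code):

#include <cstdio>

int main() {
  const float momentum = 0.9f;     // typical batch_norm momentum
  float running_mean = 0.5f;       // carried over from previous batches
  const float batch_mean = 0.8f;   // stat oneDNN computed for this batch
  // running = running * momentum + batch * (1 - momentum)
  running_mean = running_mean * momentum + batch_mean * (1.f - momentum);
  std::printf("updated running mean: %f\n", running_mean);  // prints 0.530000
  return 0;
}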
- momentum); + } +} } // namespace phi From 0301a5364b9dc80436c3e693cb97c46a58a48212 Mon Sep 17 00:00:00 2001 From: Silv3S Date: Thu, 3 Nov 2022 12:14:11 +0100 Subject: [PATCH 3/6] method signature --- paddle/phi/kernels/onednn/batch_norm_kernel.cc | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/paddle/phi/kernels/onednn/batch_norm_kernel.cc b/paddle/phi/kernels/onednn/batch_norm_kernel.cc index ca68375132618..d0a0acc6d4531 100644 --- a/paddle/phi/kernels/onednn/batch_norm_kernel.cc +++ b/paddle/phi/kernels/onednn/batch_norm_kernel.cc @@ -16,23 +16,26 @@ #include "paddle/phi/backends/onednn/onednn_reuse.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { +template +using EigenVectorArrayMap = Eigen::Map>; + template void BatchNormKernel(const Context &dev_ctx, const DenseTensor &x, - const DenseTensor &scale, - const DenseTensor &bias, const DenseTensor &mean, const DenseTensor &variance, + const DenseTensor &scale, + const DenseTensor &bias, + bool is_test, float momentum, float epsilon, const std::string &data_layout, - bool is_test, bool use_global_stats, bool trainable_statistics, - bool fuse_with_relu, DenseTensor *y, DenseTensor *mean_out, DenseTensor *variance_out, @@ -41,6 +44,10 @@ void BatchNormKernel(const Context &dev_ctx, DenseTensor *reserve_space) { const bool test_mode = is_test && (!trainable_statistics); const bool global_stats = test_mode || use_global_stats; + const bool fuse_with_relu = + dev_ctx.HasDnnAttr("fuse_with_relu") + ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("fuse_with_relu")) + : false; funcs::BatchNormOneDNNHandler handler(dev_ctx.GetEngine(), dev_ctx.GetPlace(), @@ -83,7 +90,7 @@ void BatchNormKernel(const Context &dev_ctx, // mkldnn only compute stats for current batch // so we need compute momentum stats via Eigen lib - EigenVectorArrayMap batch_mean_e(dev_ctx.template Alloc(scale_mean), + EigenVectorArrayMap batch_mean_e(dev_ctx.template Alloc(saved_mean), C); EigenVectorArrayMap batch_variance_e( dev_ctx.template Alloc(saved_variance), C); From 0552b5514c295c78737259824e7b1f45cbd7ac7a Mon Sep 17 00:00:00 2001 From: Silv3S Date: Thu, 3 Nov 2022 13:55:14 +0100 Subject: [PATCH 4/6] change order --- paddle/phi/backends/onednn/onednn_reuse.h | 2 +- paddle/phi/kernels/onednn/batch_norm_kernel.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h index 9647122ad2931..d766b6746f293 100644 --- a/paddle/phi/backends/onednn/onednn_reuse.h +++ b/paddle/phi/backends/onednn/onednn_reuse.h @@ -1164,7 +1164,7 @@ class BatchNormOneDNNHandler Place cpu_place, const DenseTensor* x, const float epsilon, - const float fuse_with_relu, + const bool fuse_with_relu, const bool global_stats, const bool test_mode) : OneDNNHandlerNoCachingT Date: Fri, 4 Nov 2022 15:55:57 +0100 Subject: [PATCH 5/6] bnorm --- .../phi/kernels/onednn/batch_norm_kernel.cc | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/paddle/phi/kernels/onednn/batch_norm_kernel.cc b/paddle/phi/kernels/onednn/batch_norm_kernel.cc index 557f12283cc21..665c89de57be0 100644 --- a/paddle/phi/kernels/onednn/batch_norm_kernel.cc +++ b/paddle/phi/kernels/onednn/batch_norm_kernel.cc @@ -16,6 +16,7 @@ #include "paddle/phi/backends/onednn/onednn_reuse.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/eigen/common.h" 
From: Silv3S
Date: Fri, 4 Nov 2022 15:55:57 +0100
Subject: [PATCH 5/6] bnorm

---
 .../phi/kernels/onednn/batch_norm_kernel.cc   | 42 +++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/paddle/phi/kernels/onednn/batch_norm_kernel.cc b/paddle/phi/kernels/onednn/batch_norm_kernel.cc
index 557f12283cc21..665c89de57be0 100644
--- a/paddle/phi/kernels/onednn/batch_norm_kernel.cc
+++ b/paddle/phi/kernels/onednn/batch_norm_kernel.cc
@@ -16,6 +16,7 @@
 
 #include "paddle/phi/backends/onednn/onednn_reuse.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 
 namespace phi {
@@ -106,6 +107,47 @@ void BatchNormKernel(const Context &dev_ctx,
   }
 }
 
+template <typename T, typename Context>
+void BatchNormInferKernel(const Context &dev_ctx,
+                          const DenseTensor &x,
+                          const DenseTensor &mean,
+                          const DenseTensor &variance,
+                          const DenseTensor &scale,
+                          const DenseTensor &bias,
+                          float momentum,
+                          float epsilon,
+                          const std::string &data_layout,
+                          DenseTensor *y,
+                          DenseTensor *mean_out,
+                          DenseTensor *variance_out) {
+  // Since saved_mean and saved_variance are used regardless of whether
+  // they are in test mode, temporary variables need to be created here
+  // to be compatible
+  auto saved_mean = phi::EmptyLike<T, Context>(dev_ctx, *mean_out);
+  auto saved_variance = phi::EmptyLike<T, Context>(dev_ctx, *variance_out);
+
+  BatchNormKernel<T, Context>(dev_ctx,
+                              x,
+                              mean,
+                              variance,
+                              scale,
+                              bias,
+                              /*is_test=*/true,
+                              momentum,
+                              epsilon,
+                              data_layout,
+                              /*use_global_stats=*/false,
+                              /*trainable_statistics=*/false,
+                              y,
+                              mean_out,
+                              variance_out,
+                              &saved_mean,
+                              &saved_variance,
+                              /*reserve_space=*/nullptr);
+}
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(batch_norm, OneDNN, ONEDNN, phi::BatchNormKernel, float) {}
+PD_REGISTER_KERNEL(
+    batch_norm_infer, OneDNN, ONEDNN, phi::BatchNormInferKernel, float) {}

From 61485a50a0b710635b9ca07f539e245c2b984980 Mon Sep 17 00:00:00 2001
From: Silv3S
Date: Fri, 4 Nov 2022 19:52:29 +0100
Subject: [PATCH 6/6] removed unused args

---
 paddle/phi/kernels/onednn/batch_norm_kernel.cc | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/paddle/phi/kernels/onednn/batch_norm_kernel.cc b/paddle/phi/kernels/onednn/batch_norm_kernel.cc
index 665c89de57be0..fd943f8455605 100644
--- a/paddle/phi/kernels/onednn/batch_norm_kernel.cc
+++ b/paddle/phi/kernels/onednn/batch_norm_kernel.cc
@@ -16,7 +16,6 @@
 
 #include "paddle/phi/backends/onednn/onednn_reuse.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 
 namespace phi {
@@ -120,12 +119,6 @@ void BatchNormInferKernel(const Context &dev_ctx,
                           DenseTensor *y,
                           DenseTensor *mean_out,
                           DenseTensor *variance_out) {
-  // Since saved_mean and saved_variance are used regardless of whether
-  // they are in test mode, temporary variables need to be created here
-  // to be compatible
-  auto saved_mean = phi::EmptyLike<T, Context>(dev_ctx, *mean_out);
-  auto saved_variance = phi::EmptyLike<T, Context>(dev_ctx, *variance_out);
-
   BatchNormKernel<T, Context>(dev_ctx,
                               x,
                               mean,
@@ -141,8 +134,8 @@ void BatchNormInferKernel(const Context &dev_ctx,
                               y,
                               mean_out,
                               variance_out,
-                              &saved_mean,
-                              &saved_variance,
+                              /*saved_mean*/ nullptr,
+                              /*saved_variance*/ nullptr,
                               /*reserve_space=*/nullptr);
 }
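Passing nullptr for saved_mean and saved_variance is safe here because BatchNormInferKernel forwards is_test=true and trainable_statistics=false, which makes global_stats true, and only the !global_stats paths in BatchNormKernel ever touch those outputs. A minimal standalone sketch of that control flow (stand-in types, not Paddle code):

#include <cassert>
#include <cstdio>

struct DenseTensorStub {};  // stand-in for phi::DenseTensor

// Mirrors the flag logic at the top of BatchNormKernel.
static void batch_norm_like(bool is_test,
                            bool trainable_statistics,
                            bool use_global_stats,
                            DenseTensorStub *saved_mean) {
  const bool test_mode = is_test && (!trainable_statistics);
  const bool global_stats = test_mode || use_global_stats;
  if (!global_stats) {
    assert(saved_mean != nullptr);  // only the training path writes saved_mean
    std::puts("training path: saved_mean written");
  } else {
    std::puts("inference path: saved_mean never dereferenced");
  }
}

int main() {
  // The argument combination BatchNormInferKernel forwards in PATCH 6.
  batch_norm_like(/*is_test=*/true,
                  /*trainable_statistics=*/false,
                  /*use_global_stats=*/false,
                  /*saved_mean=*/nullptr);
  return 0;
}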