Skip to content

Commit

Permalink
[PHI] Migrate batch_norm (#47652)
Browse files Browse the repository at this point in the history
* init changes

* bnorm

* method signature

* change order

* bnorm

* removed unused args
  • Loading branch information
Silv3S committed Nov 7, 2022
1 parent 9db507f commit 2337e60
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(gelu, CPU, ALL_LAYOUT);

// NOTE(review): batch_norm is listed below with both a fluid-style MKLDNN
// device-kernel reference and a phi-style OneDNN kernel declaration.
// Presumably only one of the two is meant to remain after the PHI migration
// -- confirm against the registrations in the kernel sources.
USE_OP_ITSELF(batch_norm);
USE_OP_DEVICE_KERNEL(batch_norm, MKLDNN);
PD_DECLARE_KERNEL(batch_norm, OneDNN, ONEDNN);
USE_OP_ITSELF(conv2d_transpose);
USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN);
USE_OP_ITSELF(elementwise_add);
Expand Down
118 changes: 0 additions & 118 deletions paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,38 +35,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
dnnl::batch_normalization_forward,
dnnl::batch_normalization_backward> {
public:
// Creates the forward batch-norm primitive descriptor for input `x`.
//
// ctx           - execution context; supplies the place handed to the base
//                 handler, the "epsilon" attribute and the optional
//                 "fuse_with_relu" attribute (treated as false when absent).
// mkldnn_engine - engine handed to the base handler.
// x             - input tensor; its mem_desc() describes the source data.
// global_stats  - true: the use_global_stats flag is set and the primitive
//                 is created with prop_kind::forward_scoring;
//                 false: prop_kind::forward_training is used.
// test_mode     - fuse_norm_relu is requested only when this flag and the
//                 "fuse_with_relu" attribute are both true.
BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
                       const dnnl::engine mkldnn_engine,
                       const Tensor *x,
                       const bool global_stats,
                       const bool test_mode)
    : platform::MKLDNNHandlerNoCachingT<T,
                                        dnnl::batch_normalization_forward,
                                        dnnl::batch_normalization_backward>(
          mkldnn_engine, ctx.GetPlace()) {
  const float epsilon = ctx.Attr<float>("epsilon");
  // Short-circuit keeps the attribute lookup guarded by HasAttr().
  const bool fuse_with_relu =
      ctx.HasAttr("fuse_with_relu") && ctx.Attr<bool>("fuse_with_relu");

  // Flags are added by bitwise OR operation
  auto flags = dnnl::normalization_flags::use_scale_shift;  // 001
  if (global_stats) {
    flags |= dnnl::normalization_flags::use_global_stats;  // 010
  }
  if (fuse_with_relu && test_mode) {
    flags |= dnnl::normalization_flags::fuse_norm_relu;  // 100
  }

  this->AcquireForwardPrimitiveDescriptor(
      global_stats ? dnnl::prop_kind::forward_scoring
                   : dnnl::prop_kind::forward_training,
      x->mem_desc(),
      epsilon,
      flags);
}

BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
const dnnl::engine mkldnn_engine,
const Tensor *in_x,
Expand Down Expand Up @@ -157,88 +125,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
}
};

// Forward kernel of the fluid `batch_norm` operator implemented on top of
// the oneDNN (MKLDNN) batch-normalization primitive.
template <typename T>
class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  // Normalizes input "X" into output "Y".
  //
  // When global statistics are in effect (inference without trainable
  // statistics, or "use_global_stats" set) the "Mean"/"Variance" inputs are
  // fed to the primitive. Otherwise the primitive writes the batch statistics
  // into "SavedMean"/"SavedVariance" and the running statistics in
  // "MeanOut"/"VarianceOut" are blended with them using "momentum".
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto &onednn_engine =
        ctx.template device_context<MKLDNNDeviceContext>().GetEngine();

    // Operator attributes that decide which statistics are used.
    const bool is_test = ctx.Attr<bool>("is_test");
    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
    const bool trainable_stats = ctx.Attr<bool>("trainable_statistics");
    const bool inference_only = is_test && (!trainable_stats);
    const bool use_given_stats = inference_only || use_global_stats;

    // Inputs.
    const auto *input = ctx.Input<phi::DenseTensor>("X");
    const auto *gamma = ctx.Input<phi::DenseTensor>("Scale");
    const auto *beta = ctx.Input<phi::DenseTensor>("Bias");

    // Outputs.
    auto *output = ctx.Output<phi::DenseTensor>("Y");
    auto *saved_mean = ctx.Output<phi::DenseTensor>("SavedMean");
    auto *saved_variance = ctx.Output<phi::DenseTensor>("SavedVariance");

    BatchNormMKLDNNHandler<T> bn_handler(
        ctx, onednn_engine, input, use_given_stats, inference_only);

    auto src_mem = bn_handler.AcquireSrcMemory(input);
    auto scaleshift_mem = bn_handler.AcquireScaleShiftMemory(gamma, beta);
    auto dst_mem = bn_handler.AcquireDstMemory(output);
    auto bn_primitive = bn_handler.AcquireForwardPrimitive();

    std::shared_ptr<memory> mean_mem;
    std::shared_ptr<memory> variance_mem;

    if (use_given_stats) {
      // Statistics are read from the input tensors.
      const auto *given_mean = ctx.Input<phi::DenseTensor>("Mean");
      const auto *given_variance = ctx.Input<phi::DenseTensor>("Variance");

      mean_mem = bn_handler.AcquireMeanMemory(given_mean);
      variance_mem = bn_handler.AcquireVarianceMemory(given_variance);
    } else {
      // Statistics are computed by the primitive and stored in the outputs.
      mean_mem = bn_handler.AcquireMeanMemory(saved_mean);
      variance_mem = bn_handler.AcquireVarianceMemory(saved_variance);
    }

    output->set_mem_desc(dst_mem->get_desc());

    auto &stream = platform::MKLDNNDeviceContext::tls().get_stream();
    bn_primitive->execute(stream,
                          {{DNNL_ARG_SRC, *src_mem},
                           {DNNL_ARG_SCALE_SHIFT, *scaleshift_mem},
                           {DNNL_ARG_MEAN, *mean_mem},
                           {DNNL_ARG_VARIANCE, *variance_mem},
                           {DNNL_ARG_DST, *dst_mem}});
    stream.wait();

    // With externally supplied statistics there is nothing left to update.
    if (use_given_stats) {
      return;
    }

    auto *running_mean = ctx.Output<phi::DenseTensor>("MeanOut");
    auto *running_variance = ctx.Output<phi::DenseTensor>("VarianceOut");
    const float momentum = ctx.Attr<float>("momentum");

    const unsigned int C = phi::vectorize(gamma->dims())[0];

    // The primitive only produces statistics of the current batch, so the
    // momentum-weighted running statistics are computed here with Eigen.
    EigenVectorArrayMap<T> batch_mean_arr(
        saved_mean->mutable_data<T>(ctx.GetPlace()), C);
    EigenVectorArrayMap<T> batch_variance_arr(
        saved_variance->mutable_data<T>(ctx.GetPlace()), C);

    EigenVectorArrayMap<T> running_mean_arr(
        running_mean->mutable_data<T>(ctx.GetPlace()), C);
    EigenVectorArrayMap<T> running_variance_arr(
        running_variance->mutable_data<T>(ctx.GetPlace()), C);

    running_mean_arr =
        running_mean_arr * momentum + batch_mean_arr * (1. - momentum);
    running_variance_arr =
        running_variance_arr * momentum + batch_variance_arr * (1. - momentum);
  }
};

template <typename T>
class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public:
Expand Down Expand Up @@ -308,10 +194,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
} // namespace paddle

namespace ops = paddle::operators;
// Registers the float instantiation of BatchNormMKLDNNOpKernel as the
// batch_norm kernel for the MKLDNN library on CPUPlace.
REGISTER_OP_KERNEL(batch_norm,
MKLDNN,
::paddle::platform::CPUPlace,
ops::BatchNormMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(batch_norm_grad,
MKLDNN,
::paddle::platform::CPUPlace,
Expand Down
90 changes: 90 additions & 0 deletions paddle/phi/backends/onednn/onednn_reuse.h
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,96 @@ class ClipOneDNNHandler
}
};

// oneDNN handler for batch normalization.
//
// The constructor creates the forward primitive descriptor; the Acquire*
// helpers wrap tensors (or handler-allocated buffers) in dnnl::memory
// objects matching the descriptors of that primitive.
template <typename T>
class BatchNormOneDNNHandler
    : public OneDNNHandlerNoCachingT<T,
                                     dnnl::batch_normalization_forward,
                                     dnnl::batch_normalization_backward> {
 public:
  // engine         - oneDNN engine handed to the base handler.
  // cpu_place      - place handed to the base handler.
  // x              - input tensor; its mem_desc() describes the source data.
  // epsilon        - epsilon forwarded to the primitive descriptor.
  // fuse_with_relu - request fused ReLU; honoured only when test_mode is true.
  // global_stats   - true: use_global_stats flag + prop_kind::forward_scoring;
  //                  false: prop_kind::forward_training.
  // test_mode      - see fuse_with_relu.
  BatchNormOneDNNHandler(const dnnl::engine engine,
                         Place cpu_place,
                         const DenseTensor* x,
                         const float epsilon,
                         const bool fuse_with_relu,
                         const bool global_stats,
                         const bool test_mode)
      : OneDNNHandlerNoCachingT<T,
                                dnnl::batch_normalization_forward,
                                dnnl::batch_normalization_backward>(engine,
                                                                    cpu_place) {
    // Flags are added by bitwise OR operation
    auto flags = dnnl::normalization_flags::use_scale_shift;  // 001
    if (global_stats)
      flags |= dnnl::normalization_flags::use_global_stats;  // 010
    if (fuse_with_relu && test_mode)
      flags |= dnnl::normalization_flags::fuse_norm_relu;  // 100

    this->AcquireForwardPrimitiveDescriptor(
        global_stats ? dnnl::prop_kind::forward_scoring
                     : dnnl::prop_kind::forward_training,
        x->mem_desc(),
        epsilon,
        flags);
  }

  // Returns a memory object described by the primitive's weights descriptor
  // that holds the C scale values followed by the C shift values, where C is
  // the length of the (rank-1) scale tensor.
  // Raises InvalidArgument when `scale` is not rank 1.
  std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(
      const DenseTensor* scale, const DenseTensor* shift) {
    auto scale_tz = phi::vectorize(scale->dims());
    // Check the rank BEFORE reading scale_tz[0]: for a rank-0 scale the
    // vector is empty and indexing it would be undefined behaviour, so the
    // enforce must come first to be able to report the error.
    PADDLE_ENFORCE_EQ(
        scale_tz.size(),
        1,
        phi::errors::InvalidArgument(
            "Dims of scale tensor must be 1, but received scale's size is %d",
            scale_tz.size()));
    const unsigned int C = scale_tz[0];

    auto scaleshift_memory =
        this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());

    // MKLDNN requires a single piece of memory for scale and shift/bias data
    auto mem_data_handle =
        reinterpret_cast<T*>(scaleshift_memory->get_data_handle());
    std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
    std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C);
    return scaleshift_memory;
  }

  // Wraps `diff_scaleshift_data` in a memory object described by the
  // backward primitive's diff-weights descriptor.
  std::shared_ptr<dnnl::memory> AcquireDiffScaleShiftMemory(
      T* diff_scaleshift_data) {
    return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(),
                                            diff_scaleshift_data);
  }

  // Read-only variant: wraps the existing data of `mean`.
  std::shared_ptr<dnnl::memory> AcquireMeanMemory(
      const phi::DenseTensor* mean) {
    const T* mean_data = mean->data<T>();
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
                                            to_void_cast<T>(mean_data));
  }

  // Writable variant: requests a buffer of the size given by the primitive's
  // mean descriptor from `mean` and wraps it.
  std::shared_ptr<dnnl::memory> AcquireMeanMemory(phi::DenseTensor* mean) {
    T* mean_data = mean->mutable_data<T>(this->place_,
                                         this->fwd_pd_->mean_desc().get_size());
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
                                            mean_data);
  }

  // Read-only variant: wraps the existing data of `variance`.
  std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
      const phi::DenseTensor* variance) {
    const T* variance_data = variance->data<T>();
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
                                            to_void_cast<T>(variance_data));
  }

  // Writable variant: requests a buffer of the size given by the primitive's
  // variance descriptor from `variance` and wraps it.
  std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
      phi::DenseTensor* variance) {
    T* variance_data = variance->mutable_data<T>(
        this->place_, this->fwd_pd_->variance_desc().get_size());
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
                                            variance_data);
  }
};

template <typename T>
class PoolingOneDNNHandler
: public OneDNNHandlerNoCachingT<T,
Expand Down
146 changes: 146 additions & 0 deletions paddle/phi/kernels/onednn/batch_norm_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/batch_norm_kernel.h"

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"

namespace phi {

// Eigen view over an existing buffer, interpreted as a dynamically sized
// column array of T.
template <typename T>
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;

template <typename T, typename Context>
void BatchNormKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &mean,
const DenseTensor &variance,
const DenseTensor &scale,
const DenseTensor &bias,
bool is_test,
float momentum,
float epsilon,
const std::string &data_layout,
bool use_global_stats,
bool trainable_statistics,
DenseTensor *y,
DenseTensor *mean_out,
DenseTensor *variance_out,
DenseTensor *saved_mean,
DenseTensor *saved_variance,
DenseTensor *reserve_space) {
const bool test_mode = is_test && (!trainable_statistics);
const bool global_stats = test_mode || use_global_stats;
const bool fuse_with_relu =
dev_ctx.HasDnnAttr("fuse_with_relu")
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("fuse_with_relu"))
: false;

funcs::BatchNormOneDNNHandler<T> handler(dev_ctx.GetEngine(),
dev_ctx.GetPlace(),
&x,
epsilon,
fuse_with_relu,
global_stats,
test_mode);

auto src_memory = handler.AcquireSrcMemory(&x);
auto scaleshift_memory = handler.AcquireScaleShiftMemory(&scale, &bias);
auto dst_memory = handler.AcquireDstMemory(y);
auto batch_norm_p = handler.AcquireForwardPrimitive();

std::shared_ptr<dnnl::memory> mean_memory;
std::shared_ptr<dnnl::memory> variance_memory;

// mean and variance can be taken either from input or output Tensor
if (global_stats) {
mean_memory = handler.AcquireMeanMemory(&mean);
variance_memory = handler.AcquireVarianceMemory(&variance);
} else {
mean_memory = handler.AcquireMeanMemory(saved_mean);
variance_memory = handler.AcquireVarianceMemory(saved_variance);
}

y->set_mem_desc(dst_memory->get_desc());

auto &astream = OneDNNContext::tls().get_stream();
batch_norm_p->execute(astream,
{{DNNL_ARG_SRC, *src_memory},
{DNNL_ARG_SCALE_SHIFT, *scaleshift_memory},
{DNNL_ARG_MEAN, *mean_memory},
{DNNL_ARG_VARIANCE, *variance_memory},
{DNNL_ARG_DST, *dst_memory}});
astream.wait();

if (!global_stats) {
const unsigned int C = phi::vectorize(scale.dims())[0];

// mkldnn only compute stats for current batch
// so we need compute momentum stats via Eigen lib
EigenVectorArrayMap<T> batch_mean_e(dev_ctx.template Alloc<T>(saved_mean),
C);
EigenVectorArrayMap<T> batch_variance_e(
dev_ctx.template Alloc<T>(saved_variance), C);

EigenVectorArrayMap<T> running_mean_e(dev_ctx.template Alloc<T>(mean_out),
C);
EigenVectorArrayMap<T> running_variance_e(
dev_ctx.template Alloc<T>(variance_out), C);

running_mean_e = running_mean_e * momentum + batch_mean_e * (1. - momentum);
running_variance_e =
running_variance_e * momentum + batch_variance_e * (1. - momentum);
}
}

// Inference entry point for batch normalization.
//
// Forwards every argument to BatchNormKernel with is_test fixed to true and
// use_global_stats / trainable_statistics fixed to false. With these values
// BatchNormKernel takes its global-statistics path: it reads `mean` and
// `variance` and does not touch the saved-statistics outputs, which is why
// null pointers are passed for saved_mean, saved_variance and reserve_space.
template <typename T, typename Context>
void BatchNormInferKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &mean,
const DenseTensor &variance,
const DenseTensor &scale,
const DenseTensor &bias,
float momentum,
float epsilon,
const std::string &data_layout,
DenseTensor *y,
DenseTensor *mean_out,
DenseTensor *variance_out) {
BatchNormKernel<T, Context>(dev_ctx,
x,
mean,
variance,
scale,
bias,
/*is_test=*/true,
momentum,
epsilon,
data_layout,
/*use_global_stats=*/false,
/*trainable_statistics=*/false,
y,
mean_out,
variance_out,
/*saved_mean*/ nullptr,
/*saved_variance*/ nullptr,
/*reserve_space=*/nullptr);
}

} // namespace phi

// Register the float instantiations of both kernels under the OneDNN
// backend / ONEDNN layout; the trailing empty braces leave the registration
// body empty.
PD_REGISTER_KERNEL(batch_norm, OneDNN, ONEDNN, phi::BatchNormKernel, float) {}
PD_REGISTER_KERNEL(
batch_norm_infer, OneDNN, ONEDNN, phi::BatchNormInferKernel, float) {}

0 comments on commit 2337e60

Please sign in to comment.