Skip to content

Commit

Permalink
Merge pull request tensorflow#17004 from Intel-tensorflow/relu_bn_fix2
Browse files Browse the repository at this point in the history
MKL: CIFAR-10 divergence fix and batchnorm unit test fix
  • Loading branch information
tatianashp committed Mar 1, 2018
2 parents 0196b4d + 53f3b6b commit 0d662ba
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 35 deletions.
96 changes: 65 additions & 31 deletions tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1110,19 +1110,12 @@ class MklFusedBatchNormGradOp : public OpKernel {
return;
}

if (dnn_shape_src.IsMklTensor())
depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
else
ExtractParams(context);

memory::format format_m;
if (dnn_shape_src.IsMklTensor()) {
if (dnn_shape_src.IsTensorInNCHWFormat())
format_m = memory::format::nchw;
else
format_m = memory::format::nhwc;
depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
} else if (dnn_shape_diff_dst.IsMklTensor()) {
depth_ = dnn_shape_diff_dst.DimSize(MklDnnDims::Dim_C);
} else {
format_m = TFDataFormatToMklDnnDataFormat(tensor_format_);
ExtractParams(context);
}

MklDnnData<T> src(&cpu_engine);
Expand All @@ -1146,20 +1139,20 @@ class MklFusedBatchNormGradOp : public OpKernel {
diff_dst_dims =
TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), tensor_format_);

// set src and diff_dst primitives
// set src and diff_dst primitives according to input layout
memory::desc src_md({}, memory::data_undef, memory::format_undef);
memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);
if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
if (dnn_shape_src.IsMklTensor()) {
src_md = dnn_shape_src.GetMklLayout();
diff_dst_md = src_md;
} else {
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
src_md = diff_dst_md;
}
if (dnn_shape_src.IsMklTensor()) {
src_md = dnn_shape_src.GetMklLayout();
} else {
src_md = memory::desc(src_dims, MklDnnType<T>(), format_m);
diff_dst_md = src_md;
src_md = memory::desc(src_dims, MklDnnType<T>(),
TFDataFormatToMklDnnDataFormat(tensor_format_));
}
if (dnn_shape_diff_dst.IsMklTensor()) {
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
} else {
diff_dst_md = memory::desc(diff_dst_dims, MklDnnType<T>(),
TFDataFormatToMklDnnDataFormat(tensor_format_));
}
src.SetUsrMem(src_md, &src_tensor);
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
Expand Down Expand Up @@ -1211,28 +1204,64 @@ class MklFusedBatchNormGradOp : public OpKernel {
// allocate diff_src tensor
MklDnnShape dnn_shape_diff_src;
TensorShape tf_shape_diff_src;
if (dnn_shape_src.IsMklTensor()) {

// MKL-DNN's BN primitive not provide API to fetch internal format
// set common_md as OpMem
// src and diff_dst will reorder to common_md
// diff_src will set as common_md
memory::desc common_md({}, memory::data_undef, memory::format_undef);
if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
if (dnn_shape_src.IsMklTensor()) {
common_md = dnn_shape_src.GetMklLayout();
} else {
common_md = dnn_shape_diff_dst.GetMklLayout();
}
} else {
common_md = memory::desc(src_dims, MklDnnType<T>(),
TFDataFormatToMklDnnDataFormat(tensor_format_));
}
// if any of src and diff_dst as mkl layout,
// then we set diff_src as mkl layout
if (dnn_shape_src.IsMklTensor() ||
dnn_shape_diff_dst.IsMklTensor()) {
dnn_shape_diff_src.SetMklTensor(true);
auto diff_src_pd = bnrm_fwd_pd.dst_primitive_desc();
// set diff_src's mkl layout as common_md
auto diff_src_pd = memory::primitive_desc(common_md, cpu_engine);
dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
dnn_shape_diff_src.SetElemType(MklDnnType<T>());
dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(), src_dims,
format_m);
dnn_shape_diff_src.SetTfDimOrder(dnn_shape_src.GetDimension(),
tensor_format_);
if (dnn_shape_src.IsMklTensor()) {
dnn_shape_diff_src.SetTfLayout(
dnn_shape_src.GetDimension(),
src_dims,
dnn_shape_src.GetTfDataFormat());
dnn_shape_diff_src.SetTfDimOrder(
dnn_shape_src.GetDimension(),
tensor_format_);
} else {
dnn_shape_diff_src.SetTfLayout(
dnn_shape_diff_dst.GetDimension(),
src_dims,
dnn_shape_diff_dst.GetTfDataFormat());
dnn_shape_diff_src.SetTfDimOrder(
dnn_shape_diff_dst.GetDimension(),
tensor_format_);
}
tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
} else {
dnn_shape_diff_src.SetMklTensor(false);
// both src and diff_dst are TensorFlow layout,
// so it is OK to get TensorFlow shape.
tf_shape_diff_src = src_tensor.shape();
}
AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor,
tf_shape_diff_src, dnn_shape_diff_src);

diff_src.SetUsrMem(src_md, diff_src_tensor);
// set diff_src
diff_src.SetUsrMem(common_md, diff_src_tensor);

prop_kind pk = prop_kind::backward;
auto bnrm_bwd_desc = batch_normalization_backward::desc(
pk, diff_src.GetUsrMemDesc(), src.GetUsrMemDesc(), epsilon_,
pk, common_md, common_md, epsilon_,
/* for inference, specify use_global_stats
1. on fwd prop, use mean and variance
provided as inputs
Expand All @@ -1245,11 +1274,16 @@ class MklFusedBatchNormGradOp : public OpKernel {
auto bnrm_bwd_pd = batch_normalization_backward::primitive_desc(
bnrm_bwd_desc, cpu_engine, bnrm_fwd_pd);

std::vector<primitive> net;
src.CheckReorderToOpMem(memory::primitive_desc(common_md,
cpu_engine), &net);
diff_dst.CheckReorderToOpMem(memory::primitive_desc(common_md,
cpu_engine), &net);

auto bnrm_bwd_op = batch_normalization_backward(
bnrm_bwd_pd, src.GetOpMem(), mean.GetOpMem(), variance.GetOpMem(),
diff_dst.GetOpMem(), weights_m, diff_src.GetOpMem(), diff_weights_m);

std::vector<primitive> net;
net.push_back(bnrm_bwd_op);
stream(stream::kind::eager).submit(net).wait();

Expand Down
20 changes: 16 additions & 4 deletions tensorflow/core/kernels/mkl_relu_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,11 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
mkl_context.MklCleanup();
}



#else // INTEL_MKL_ML


template <typename Device, typename T, algorithm alg_kind>
class MklReluOpBase : public OpKernel {
public:
Expand Down Expand Up @@ -579,17 +582,26 @@ class MklReluGradOpBase : public OpKernel {
// allocate diff_src tensor
MklDnnShape dnn_shape_diff_src;
TensorShape tf_shape_diff_src;
if (dnn_shape_src.IsMklTensor()) {
if (dnn_shape_src.IsMklTensor() ||
dnn_shape_diff_dst.IsMklTensor()) {
dnn_shape_diff_src.SetMklTensor(true);
auto diff_src_pd = relu_bwd_pd.diff_src_primitive_desc();
dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
dnn_shape_diff_src.SetElemType(MklDnnType<T>());
dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(),
dnn_shape_src.GetSizesAsMklDnnDims(),
dnn_shape_src.GetTfDataFormat());
if (dnn_shape_src.IsMklTensor()) {
dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(),
dnn_shape_src.GetSizesAsMklDnnDims(),
dnn_shape_src.GetTfDataFormat());
} else {
dnn_shape_diff_src.SetTfLayout(dnn_shape_diff_dst.GetDimension(),
dnn_shape_diff_dst.GetSizesAsMklDnnDims(),
dnn_shape_diff_dst.GetTfDataFormat());
}
tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
} else {
dnn_shape_diff_src.SetMklTensor(false);
// both src and diff_dst are TensorFlow layout,
// so it is ok to get TensorFlow shape.
tf_shape_diff_src = src_tensor.shape();
}
AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
Expand Down

0 comments on commit 0d662ba

Please sign in to comment.