diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index d88a34dfd9bf7e..54b0a0e2c67ea4 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -6,6 +6,10 @@ package( load("//tensorflow:tensorflow.bzl", "tf_copts") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load( + "//third_party/mkl:build_defs.bzl", + "if_mkl", +) tf_kernel_library( name = "xla_ops", @@ -151,8 +155,14 @@ tf_kernel_library( "//tensorflow/core/kernels:sparse_to_dense_op", "//tensorflow/core/kernels:stack_ops", "//tensorflow/core/kernels:training_ops", - "//tensorflow/core/kernels:transpose_op", - ], + ] + if_mkl( + [ + "//tensorflow/core/kernels:mkl_transpose_op", + ], + [ + "//tensorflow/core/kernels:transpose_op", + ], + ), ) tf_kernel_library( diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 10cbcdecc85928..092d3b494b9f42 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -654,7 +654,14 @@ cc_library( ":split_v_op", ":strided_slice_op", ":tile_ops", - ":transpose_op", + ] + if_mkl( + [ + ":mkl_transpose_op", + ], + [ + ":transpose_op", + ], + ) + [ ":unique_op", ":unpack_op", ":unravel_index_op", @@ -891,18 +898,27 @@ tf_kernel_library( deps = ARRAY_DEPS, ) -tf_kernel_library( - name = "transpose_op", - srcs = [ - "transpose_op.cc", - ] + if_mkl([ - "mkl_transpose_op.cc", - ]), - hdrs = ["transpose_op.h"], - deps = ARRAY_DEPS + if_mkl([ - "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ]), +if_mkl( + [tf_mkl_kernel_library( + name = "mkl_transpose_op", + srcs = [ + "transpose_op.cc", + "mkl_transpose_op.cc", + ], + hdrs = ["transpose_op.h"], + deps = ARRAY_DEPS + if_mkl([ + "//third_party/mkl:intel_binary_blob", + "@mkl_dnn", + ]), + )], + [tf_kernel_library( + name = "transpose_op", + srcs = [ + "transpose_op.cc", + ], + hdrs = ["transpose_op.h"], + deps = ARRAY_DEPS, + )], ) tf_kernel_library( diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h index dc028c2f1e9b5b..22203e242a1161 100644 --- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h +++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h @@ -113,10 +113,25 @@ struct GatherNdSlice { #endif generator::GatherNdSliceGenerator gather_nd_generator( slice_size, Tindices, Tparams, Tout, &error_loc); + +#ifdef INTEL_MKL + // The Eigen implementation below is not highly performant: gather_nd_generator + // does not seem to be called in parallel, leading to very poor performance. + // Additionally, since it uses a scalar (Tscratch) to invoke 'generate', it + // needs to go through redundant operations like 'reshape', 'broadcast' and + // 'sum'. The OpenMP loop below does essentially the same thing as the Eigen + // code, but is considerably more efficient. + #pragma omp parallel for + for (Eigen::DenseIndex i = 0; i < batch_size; i++) { + const Eigen::array<Eigen::DenseIndex, 1> loc = {i}; + gather_nd_generator(loc); + } +#else Tscratch.device(d) = Tscratch.reshape(reshape_dims) .broadcast(broadcast_dims) .generate(gather_nd_generator) .sum(); +#endif // error_loc() returns -1 if there's no out-of-bounds index, // otherwise it returns the location of an OOB index in Tindices.
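The gather_nd change above swaps Eigen's scalar generate/reshape/broadcast/sum chain for a per-batch OpenMP loop: each output slice is independent, so the loop parallelizes with no reduction. A minimal standalone sketch of the same pattern (illustrative buffers and names, not the TF kernel):

```cpp
#include <cstddef>
#include <cstring>

// Each output slice depends only on its own index entry, so the batch loop
// is embarrassingly parallel; no reshape/broadcast/sum machinery is needed.
void GatherSlices(const float* params, const int* indices,
                  std::ptrdiff_t batch_size, std::size_t slice_size,
                  float* out) {
#pragma omp parallel for
  for (std::ptrdiff_t i = 0; i < batch_size; ++i) {
    const float* src = params + indices[i] * slice_size;
    std::memcpy(out + i * slice_size, src, slice_size * sizeof(float));
  }
}
```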
diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc index d545d34fdfd868..47c93d9d0cc00e 100644 --- a/tensorflow/core/kernels/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc @@ -442,22 +442,21 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { void Compute(OpKernelContext* context) override { try { - auto cpu_engine = engine(engine::cpu, 0); - const Tensor& input_tensor = - MklGetInput(context, this->kInputTensorIndexInput); + const Tensor& input_tensor = MklGetInput(context, + this->kInputTensorIndexInput); MklDnnShape dnn_shape_input; GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input); this->SanityCheckInput(context, input_tensor, dnn_shape_input); if (!context->status().ok()) return; - MklDnnData dnn_data_input(&cpu_engine); - MklDnnData dnn_data_output(&cpu_engine); + MklDnnData dnn_data_input(&cpu_engine_); // initialize variables for the pooling op MklPoolParameters pool_params; // Get the input tensor and initialize the pooling parameters - this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params, - &dnn_data_input); + TensorShape input_tensor_shape = input_tensor.shape(); + this->InitMklPoolParameters(context, &pool_params, + dnn_shape_input, input_tensor_shape); OP_REQUIRES_OK(context, context->status()); // Declare output tensor @@ -467,65 +466,62 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { // If input is an empty tensor, allocate an empty output tensor and return if (input_tensor.NumElements() == 0) { - MklDnnShape output_mkl_shape; - output_mkl_shape.SetMklTensor(false); - TensorShape output_tf_shape; - if (pool_params.data_format == TensorFormat::FORMAT_NCHW) { - output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order); - } else { - memory::dims output_dims_NHWC_order; - output_dims_NHWC_order = {pool_params.tensor_in_batch, - static_cast(pool_params.out_height), - static_cast(pool_params.out_width), - pool_params.out_depth}; - output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order); - } const int kOutputIndex = 0; - AllocateOutputSetMklShape(context, kOutputIndex, &output_tensor, - output_tf_shape, output_mkl_shape); - CHECK_NOTNULL(output_tensor); + this->AllocateEmptyOutputTensor(context, kOutputIndex, &pool_params, + output_dims_mkl_order, &output_tensor); return; } - // If input is in Mkl layout, then just get the memory format from it - // directly, instead of using input data_format to AvgPool. - if (dnn_shape_input.IsMklTensor()) { - dnn_data_output.SetUsrMem( - output_dims_mkl_order, - static_cast( - dnn_data_input.GetUsrMemDesc().data.format)); + memory::dims filter_dims, strides, padding_left, padding_right; + this->PoolParamsToDims(&pool_params, &filter_dims, &strides, + &padding_left, &padding_right); + + // Get the input memory descriptor + memory::desc input_md = dnn_shape_input.IsMklTensor() + ? dnn_shape_input.GetMklLayout() + : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape, + this->data_format_tf_), + MklDnnType(), this->data_format_mkldnn_); + + // Get src/filter/stride/padding information + memory::dims src_dims = dnn_shape_input.IsMklTensor() + ? 
dnn_shape_input.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(input_tensor.shape(), + this->data_format_tf_); + + // Get an average pooling primitive from the op pool + MklPoolingFwdPrimitive *pooling_fwd = nullptr; + MklPoolingParams fwdParams(src_dims, output_dims_mkl_order, filter_dims, + strides, padding_left, padding_right, + algorithm::pooling_avg_exclude_padding); + pooling_fwd = MklPoolingFwdPrimitiveFactory::Get(fwdParams); + + // allocate output tensor + this->AllocateOutputTensor(context, *(pooling_fwd->GetPoolingFwdPd()), + output_dims_mkl_order, this->data_format_mkldnn_, &output_tensor); + CHECK_NOTNULL(output_tensor); + + OP_REQUIRES_OK(context, context->status()); + // check whether we need to reorder src + T* src_data = nullptr; + if (input_md.data.format != pooling_fwd->GetSrcMemoryFormat()) { + dnn_data_input.SetUsrMem(input_md, &input_tensor); + auto src_target_primitive_desc = memory::primitive_desc({{src_dims}, + MklDnnType(), pooling_fwd->GetSrcMemoryFormat()}, cpu_engine_); + dnn_data_input.CheckReorderToOpMem(src_target_primitive_desc); + src_data = static_cast( + dnn_data_input.GetOpMem().get_data_handle()); } else { - dnn_data_output.SetUsrMem(output_dims_mkl_order, - this->data_format_mkldnn_); + src_data = static_cast(const_cast( + input_tensor.flat().data())); } - // describe the memory layout - dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any); - - // 3. create a pooling primitive descriptor - auto pool_desc = pooling_forward::desc( - prop_kind::forward, algorithm::pooling_avg_exclude_padding, - dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(), - memory::dims({pool_params.row_stride, pool_params.col_stride}), - memory::dims({pool_params.window_rows, pool_params.window_cols}), - memory::dims({static_cast(pool_params.pad_top), - static_cast(pool_params.pad_left)}), - memory::dims({static_cast(pool_params.pad_bottom), - static_cast(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_prim_desc = - pooling_forward::primitive_desc(pool_desc, cpu_engine); - - this->AllocateOutputTensor(context, pool_prim_desc, output_dims_mkl_order, - this->data_format_mkldnn_, &output_tensor); - CHECK_NOTNULL(output_tensor); - - OP_REQUIRES_OK(context, context->status()); - dnn_data_output.SetUsrMemDataHandle(output_tensor); + T* dst_data = static_cast( + const_cast(output_tensor->flat().data())); - this->PrepareAndExecuteNet(pool_prim_desc, &dnn_data_input, - &dnn_data_output); + // execute pooling + pooling_fwd->Execute(src_data, dst_data); } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + @@ -535,9 +531,10 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { errors::Aborted("Operation received an exception:", error_msg)); } } // Compute -}; // MklAvgPoolingOp -//----------------------------------------------------------------------------- + private: + engine cpu_engine_ = engine(engine::cpu, 0); +}; // MklAvgPoolingOp template class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { @@ -547,91 +544,78 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { void Compute(OpKernelContext* context) override { try { - auto cpu_engine = engine(engine::cpu, 0); - MklDnnShape original_input_mkl_shape, input_gradient_mkl_shape; - const Tensor& tensor_in_shape = + const Tensor& orig_input_tensor = MklGetInput(context, kInputTensorIndexInputShape); - const Tensor& input_gradient_tensor = + const Tensor& 
grad_tensor = MklGetInput(context, kInputTensorIndexInputGradient); + + MklDnnShape orig_input_mkl_shape, grad_mkl_shape; GetMklShape(context, kInputTensorIndexInputShape, - &original_input_mkl_shape); + &orig_input_mkl_shape); GetMklShape(context, kInputTensorIndexInputGradient, - &input_gradient_mkl_shape); - - SanityCheckInputs(context, tensor_in_shape, input_gradient_tensor, - original_input_mkl_shape, input_gradient_mkl_shape); + &grad_mkl_shape); if (!context->status().ok()) return; // Used to allocate output_diff_src/diff_src - // and create pool_fwd mdm desc - // 0. Input("orig_input_shape: int32") //NOT a T Tensor! - // 1. Input("grad: T") - - MklDnnData input_gradient_diff_dst(&cpu_engine); - MklDnnData output_diff_src(&cpu_engine); - Tensor* output_tensor_diff_src = nullptr; - TensorShape original_input_shape; + MklDnnData grad_dnn_data(&cpu_engine_); MklPoolParameters pool_params; - memory::dims output_dims_mkl_order, original_input_dims_nchw; - // Configure the original input memory descriptor - memory::desc original_input_md = ConfigureOriginalInput( - context, tensor_in_shape, original_input_mkl_shape, - &original_input_dims_nchw, &pool_params, &original_input_shape); - - // configure the original output memory descriptor - // by definition, the shape of the original output is the same - // as the shape of the gradient diff_dst - memory::desc original_output_md = this->ConfigureOriginalOutput( - pool_params, input_gradient_mkl_shape, output_dims_mkl_order); - - memory::desc target_diff_dst_md = this->ConfigureInputGradient( - input_gradient_mkl_shape, input_gradient_tensor, - &input_gradient_diff_dst, original_output_md); - // The shape of the output diff src needs to be the same shape as the - // original input. But we will set its format to be same as the format of - // input gradient. We won't use format of original input since it will - // always be in Tensorflow layout (given that AvgPoolGrad gets shape of - // the input rather than actual input). 
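Both pooling ops above now obtain their primitives from MklPoolingFwdPrimitiveFactory::Get / MklPoolingBwdPrimitiveFactory::Get instead of rebuilding them on every Compute call. A minimal sketch of that memoizing-factory pattern, assuming a process-wide singleton keyed by the parameters that shape the primitive (names are illustrative, and thread-safety is omitted):

```cpp
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative stand-in for an expensive-to-build MKL-DNN primitive wrapper.
struct PoolingFwdPrimitive {
  explicit PoolingFwdPrimitive(std::string key) : key_(std::move(key)) {}
  std::string key_;
};

class PoolingFwdFactory {
 public:
  static PoolingFwdFactory& GetInstance() {
    static PoolingFwdFactory instance;  // one cache per process
    return instance;
  }

  // Return the cached primitive for this parameter combination,
  // creating and memoizing it on first use.
  PoolingFwdPrimitive* Get(const std::vector<int>& src_dims,
                           int stride, int pad) {
    const std::string key = CreateKey(src_dims, stride, pad);
    auto it = cache_.find(key);
    if (it == cache_.end())
      it = cache_.emplace(key, std::make_unique<PoolingFwdPrimitive>(key))
               .first;
    return it->second.get();
  }

 private:
  // Serialize every parameter that affects the primitive into the key,
  // mirroring what FactoryKeyCreator does in the real code.
  static std::string CreateKey(const std::vector<int>& src_dims,
                               int stride, int pad) {
    std::ostringstream key;
    key << "pool_fwd";
    for (int d : src_dims) key << "_" << d;
    key << "_s" << stride << "_p" << pad;
    return key.str();
  }

  std::unordered_map<std::string,
                     std::unique_ptr<PoolingFwdPrimitive>> cache_;
};
```

The real factories key on MklPoolingParams and hand back primitives through the MklPrimitive base class, but the cache shape is the same.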
- output_diff_src.SetUsrMem( - original_input_dims_nchw, - static_cast(target_diff_dst_md.data.format)); - - // Create the forward pooling primitive descriptor so we can reference it - // in the backward pooling primitive descriptor - auto pool_fwd_desc = pooling_forward::desc( - prop_kind::forward, algorithm::pooling_avg_exclude_padding, - original_input_md, original_output_md, - memory::dims({pool_params.row_stride, pool_params.col_stride}), - memory::dims({pool_params.window_rows, pool_params.window_cols}), - memory::dims({static_cast(pool_params.pad_top), - static_cast(pool_params.pad_left)}), - memory::dims({static_cast(pool_params.pad_bottom), - static_cast(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_fwd_prim_desc = - pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine); - - auto pool_bkwd_desc = pooling_backward::desc( - algorithm::pooling_avg_exclude_padding, - output_diff_src.GetUsrMemDesc(), target_diff_dst_md, - memory::dims({pool_params.row_stride, pool_params.col_stride}), - memory::dims({pool_params.window_rows, pool_params.window_cols}), - memory::dims({static_cast(pool_params.pad_top), - static_cast(pool_params.pad_left)}), - memory::dims({static_cast(pool_params.pad_bottom), - static_cast(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_bkwd_prim_desc = pooling_backward::primitive_desc( - pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc); - this->AllocateOutputTensor( - context, pool_bkwd_prim_desc, original_input_dims_nchw, - this->data_format_mkldnn_, &output_tensor_diff_src); - - output_diff_src.SetUsrMemDataHandle(output_tensor_diff_src); - - this->PrepareAndExecuteNet( - pool_bkwd_prim_desc, &input_gradient_diff_dst, &output_diff_src, - memory::primitive_desc(target_diff_dst_md, cpu_engine)); + auto shape_vec = orig_input_tensor.vec(); + TensorShape orig_input_shape; + for (int i = 0; i < orig_input_tensor.NumElements(); i++) { + orig_input_shape.AddDim(shape_vec(i)); + } + this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape, + orig_input_shape); + + memory::dims filter_dims, strides, padding_left, padding_right; + this->PoolParamsToDims(&pool_params, &filter_dims, &strides, + &padding_left, &padding_right); + + memory::dims orig_input_dims_mkl_order = + orig_input_mkl_shape.IsMklTensor() + ? orig_input_mkl_shape.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(orig_input_shape, this->data_format_tf_); + + memory::dims diff_dst_dims = grad_mkl_shape.IsMklTensor() + ? grad_mkl_shape.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(), + this->data_format_tf_); + memory::dims output_dims_mkl_order; + this->GetOutputDims(pool_params, &output_dims_mkl_order); + + MklPoolingParams bwdParams(orig_input_dims_mkl_order, + output_dims_mkl_order, filter_dims, strides, + padding_left, padding_right, algorithm::pooling_avg_exclude_padding); + MklPoolingBwdPrimitive *pooling_bwd = + MklPoolingBwdPrimitiveFactory::Get(bwdParams); + + Tensor* output_tensor = nullptr; + this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()), + orig_input_dims_mkl_order, + this->data_format_mkldnn_, &output_tensor); + // get diff_dst memory::desc + memory::desc diff_dst_md = grad_mkl_shape.IsMklTensor() + ? 
grad_mkl_shape.GetMklLayout() + : memory::desc(diff_dst_dims, MklDnnType(), + this->data_format_mkldnn_); + // Check whether we need to reorder diff_dst + T* diff_dst_data = nullptr; + if (diff_dst_md.data.format != pooling_bwd->GetDiffDstFormat()) { + auto target_diff_dst = memory::primitive_desc({{diff_dst_dims}, + MklDnnType(), pooling_bwd->GetDiffDstFormat()}, cpu_engine_); + grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor); + grad_dnn_data.CheckReorderToOpMem(target_diff_dst); + diff_dst_data = static_cast( + grad_dnn_data.GetOpMem().get_data_handle()); + } else { + diff_dst_data = static_cast(const_cast( + grad_tensor.flat().data())); + } + T* diff_src_data = static_cast( + const_cast(output_tensor->flat().data())); + + // execute pooling op + pooling_bwd->Execute(diff_dst_data, diff_src_data); } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + @@ -639,33 +623,14 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:", error_msg)); } - } // Compute + } private: // 0. Input("orig_input_shape: int32") // 1. Input("grad: T") const int kInputTensorIndexInputShape = 0; const int kInputTensorIndexInputGradient = 1; - - memory::desc ConfigureOriginalInput( - OpKernelContext* context, const Tensor& tensor_original_input_shape, - const MklDnnShape& original_input_mkl_shape, - memory::dims* original_input_dims_mkl_order, - MklPoolParameters* pool_params, TensorShape* input_tensor_shape) { - CHECK_NOTNULL(original_input_dims_mkl_order); - CHECK_NOTNULL(pool_params); - CHECK_NOTNULL(input_tensor_shape); - // For AvgPoolGrad, we only get the size of the original input because - // The original data is irrelvant. 
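The grad path above reorders an input only when its memory format differs from what the cached primitive expects (the CheckReorderToOpMem calls). A small sketch of that check-then-reorder idiom against the MKL-DNN 0.x C++ API, assuming f32 data (this is not the TF MklDnnData wrapper):

```cpp
#include <vector>
#include "mkldnn.hpp"

// Return a memory in `want_fmt`, inserting a reorder into `net` only when
// the incoming format differs (the zero-copy fast path is the common case).
mkldnn::memory MaybeReorder(const mkldnn::memory& src,
                            const mkldnn::memory::dims& dims,
                            mkldnn::memory::format want_fmt,
                            const mkldnn::engine& eng,
                            std::vector<mkldnn::primitive>* net) {
  using namespace mkldnn;
  auto have_fmt = static_cast<memory::format>(
      src.get_primitive_desc().desc().data.format);
  if (have_fmt == want_fmt) return src;  // formats match: reuse as-is

  auto want_pd = memory::primitive_desc(
      {dims, memory::data_type::f32, want_fmt}, eng);
  memory dst(want_pd);                   // allocates a reorder target
  net->push_back(reorder(src, dst));
  return dst;
}
```

The collected net would then be submitted once on an eager stream, as the removed PrepareAndExecuteNet code did.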
- auto shape_vec = tensor_original_input_shape.vec(); - for (int64 i = 0; i < tensor_original_input_shape.NumElements(); ++i) { - input_tensor_shape->AddDim(shape_vec(i)); - } - - return MklPoolingBackwardOpBase::ConfigureOriginalInput( - context, tensor_original_input_shape, original_input_mkl_shape, - original_input_dims_mkl_order, pool_params, *input_tensor_shape); - } + engine cpu_engine_ = engine(engine::cpu, 0); void SanityCheckInputs(OpKernelContext* context, const Tensor& tensor_in_shape, diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc index 3fe660cf968b4e..44268fc305c31d 100644 --- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc @@ -262,6 +262,7 @@ class MklFusedBatchNormOp : public OpKernel { } void MklCreateInputLayout(OpKernelContext* context) { + const Tensor& input = MklGetInput(context, 0); bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor(); if (input_in_mkl_format) { mkl_lt_input = @@ -544,6 +545,7 @@ class MklFusedBatchNormGradOp : public OpKernel { } void MklCreateInputLayout(OpKernelContext* context) { + const Tensor& input = MklGetInput(context, 0); bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor(); if (input_in_mkl_format) { mkl_lt_input = @@ -684,6 +686,466 @@ class MklFusedBatchNormGradOp : public OpKernel { #ifndef INTEL_MKL_ML +struct MklBatchNormFwdParams { + memory::dims src_dims; + int depth; + float eps; + bool training; + + MklBatchNormFwdParams(const memory::dims &src_dims, + int depth, float eps, bool training) : src_dims(src_dims), + depth(depth), eps(eps), training(training) { + } +}; + +template +class MklFusedBatchNormFwdPrimitive : public MklPrimitive { + public: + explicit MklFusedBatchNormFwdPrimitive( + const MklBatchNormFwdParams& fwdParams) : + cpu_engine_(engine::cpu, 0) { + context_.fwd_stream.reset( + new mkldnn::stream(mkldnn::stream::kind::eager)); + if (context_.bn_fwd == nullptr) + Setup(fwdParams); + } + + ~MklFusedBatchNormFwdPrimitive() {} + + // BatchNormalization forward execute + // src_data: input data buffer of src + // weights_data: input data buffer of weights + // dst_data: output data buffer of dst + // mean_data: input data buffer of means + // variance_data: input data buffer of variances + void Execute(const T* src_data, const T* weights_data, const T* dst_data, + const T* mean_data, const T* variance_data) { + context_.src_mem->set_data_handle( + static_cast(const_cast(src_data))); + context_.dst_mem->set_data_handle( + static_cast(const_cast(dst_data))); + + if (context_.flags & use_scale_shift) + context_.weights_mem->set_data_handle( + static_cast(const_cast(weights_data))); + + if ((context_.pkind == prop_kind::forward_training) || + (context_.flags & use_global_stats)) { + context_.mean_mem->set_data_handle( + static_cast(const_cast(mean_data))); + context_.variance_mem->set_data_handle( + static_cast(const_cast(variance_data))); + } + + // execution + context_.fwd_stream->submit(context_.fwd_primitives); + + context_.src_mem->set_data_handle(DummyData); + context_.dst_mem->set_data_handle(DummyData); + + if (context_.flags & use_scale_shift) + context_.weights_mem->set_data_handle(DummyData); + + if ((context_.pkind == prop_kind::forward_training) || + (context_.flags & use_global_stats)) { + context_.mean_mem->set_data_handle(DummyData); + context_.variance_mem->set_data_handle(DummyData); + } + } + + memory::primitive_desc GetDstPd() const { + return 
(*context_.dst_mem).get_primitive_desc(); } + + mkldnn_memory_format_t GetSrcFmt() const { + return (*context_.src_mem).get_primitive_desc().desc().data.format; + } + + mkldnn_memory_format_t GetDstFmt() const { + return (*context_.dst_mem).get_primitive_desc().desc().data.format; + } + + private: + // Primitive reuse context for BatchNorm fwd op + struct BatchNormFwdContext { + // flags indicate whether it is in training or inference mode + int64 flags; + + // algorithm + mkldnn::prop_kind pkind; + + // Mkldnn Memory + std::shared_ptr src_mem; + std::shared_ptr weights_mem; + std::shared_ptr dst_mem; + std::shared_ptr mean_mem; + std::shared_ptr variance_mem; + + // BatchNorm forward primitive + std::shared_ptr bn_fwd; + std::shared_ptr fwd_stream; + std::vector fwd_primitives; + + BatchNormFwdContext() : + flags(0), pkind(mkldnn::forward_training), src_mem(nullptr), + weights_mem(nullptr), dst_mem(nullptr), mean_mem(nullptr), + variance_mem(nullptr), bn_fwd(nullptr), fwd_stream(nullptr) { + } + }; + + void Setup(const MklBatchNormFwdParams& fwdParams) { + context_.flags = fwdParams.training ? use_scale_shift + : (use_scale_shift | use_global_stats); + context_.pkind = fwdParams.training ? prop_kind::forward_training + : prop_kind::forward_scoring; + + // memory desc + auto src_md = memory::desc({fwdParams.src_dims}, + MklDnnType(), get_desired_format(fwdParams.src_dims[1])); + + // fwd desc & primitive desc + auto fwd_desc = batch_normalization_forward::desc( + context_.pkind, src_md, fwdParams.eps, context_.flags); + auto fwd_pd = batch_normalization_forward::primitive_desc( + fwd_desc, cpu_engine_); + + // memory primitive + context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData)); + context_.dst_mem.reset(new memory(fwd_pd.dst_primitive_desc(), DummyData)); + + if (context_.flags & use_scale_shift) { + auto weights_desc = memory::desc({2, fwdParams.depth}, + MklDnnType(), memory::format::nc); + context_.weights_mem.reset(new memory({weights_desc, cpu_engine_}, + DummyData)); + } + + if (fwdParams.training || (context_.flags & use_global_stats)) { + auto mean_desc = memory::desc({1, fwdParams.depth}, + MklDnnType(), memory::format::nc); + context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData)); + + auto variance_desc = memory::desc({1, fwdParams.depth}, + MklDnnType(), memory::format::nc); + context_.variance_mem.reset(new memory({variance_desc, cpu_engine_}, + DummyData)); + } + + // BatchNorm forward primitive + if (!fwdParams.training && !(context_.flags & use_global_stats)) { + if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) { + context_.bn_fwd.reset(new batch_normalization_forward(fwd_pd, + *context_.src_mem, *context_.weights_mem, *context_.dst_mem)); + } else { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, *context_.dst_mem)); + } + } else if (context_.flags & use_global_stats) { + if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem, + (const primitive::at)*context_.variance_mem, *context_.weights_mem, + *context_.dst_mem)); + } else { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem, + (const primitive::at)*context_.variance_mem, *context_.dst_mem)); + } + } else { + if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) { + context_.bn_fwd.reset(new batch_normalization_forward(
fwd_pd, *context_.src_mem, *context_.weights_mem, *context_.dst_mem, + *context_.mean_mem, *context_.variance_mem)); + } else { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, *context_.dst_mem, + *context_.mean_mem, *context_.variance_mem)); + } + } + + context_.fwd_primitives.push_back(*context_.bn_fwd); + } + + mkldnn::memory::desc get_desc_data(const mkldnn::memory &m) const { + return m.get_primitive_desc().desc().data; + } + + struct BatchNormFwdContext context_; + engine cpu_engine_; +}; + +template +class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory { + public: + static MklFusedBatchNormFwdPrimitive* Get( + const MklBatchNormFwdParams& fwdParams) { + auto bn_fwd = static_cast*>( + MklFusedBatchNormFwdPrimitiveFactory + ::GetInstance().GetBatchNormFwd(fwdParams)); + + if (bn_fwd == nullptr) { + bn_fwd = new MklFusedBatchNormFwdPrimitive(fwdParams); + MklFusedBatchNormFwdPrimitiveFactory::GetInstance().SetBatchNormFwd( + fwdParams, bn_fwd); + } + return bn_fwd; + } + + static MklFusedBatchNormFwdPrimitiveFactory & GetInstance() { + static MklFusedBatchNormFwdPrimitiveFactory instance_; + return instance_; + } + + private: + MklFusedBatchNormFwdPrimitiveFactory() {} + ~MklFusedBatchNormFwdPrimitiveFactory() {} + + static std::string CreateKey(const MklBatchNormFwdParams& fwdParams) { + std::string prefix = "bn_fwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(fwdParams.src_dims); + key_creator.AddAsKey(fwdParams.depth); + key_creator.AddAsKey(fwdParams.eps); + key_creator.AddAsKey(fwdParams.training); + return key_creator.GetKey(); + } + + MklPrimitive* GetBatchNormFwd(const MklBatchNormFwdParams& fwdParams) { + std::string key = CreateKey(fwdParams); + return this->GetOp(key); + } + + void SetBatchNormFwd(const MklBatchNormFwdParams& fwdParams, + MklPrimitive *op) { + std::string key = CreateKey(fwdParams); + this->SetOp(key, op); + } +}; + +struct MklBatchNormBwdParams { + memory::dims src_dims; + memory::dims diff_dst_dims; + int depth; + float eps; + bool training; + + MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims, + int depth, float eps, bool training) : src_dims(src_dims), + diff_dst_dims(diff_dst_dims), depth(depth), eps(eps), + training(training) { + } +}; + + +template +class MklFusedBatchNormBwdPrimitive : public MklPrimitive { + public: + explicit MklFusedBatchNormBwdPrimitive( + const MklBatchNormBwdParams& bwdParams) : + cpu_engine_(engine::cpu, 0) { + context_.bwd_stream.reset( + new mkldnn::stream(mkldnn::stream::kind::eager)); + if (context_.bn_bwd == nullptr) + Setup(bwdParams); + } + + ~MklFusedBatchNormBwdPrimitive() {} + + // BatchNormalization backward execute + // src_data: input data buffer of src + // mean_data: input data buffer of mean + // variance_data: input data buffer of variance + // diff_dst_data: input data buffer of diff_dst + // weights_data: input data buffer of weights + // diff_src_data: output data buffer of diff_src + // diff_weights_data: output data buffer of diff_weights + void Execute(const T* src_data, const T* mean_data, const T* variance_data, + const T* diff_dst_data, const T* weights_data, + const T* diff_src_data, const T* diff_weights_data) { + context_.src_mem->set_data_handle( + static_cast(const_cast(src_data))); + context_.mean_mem->set_data_handle( + static_cast(const_cast(mean_data))); + context_.variance_mem->set_data_handle( + static_cast(const_cast(variance_data))); + 
context_.diff_dst_mem->set_data_handle( + static_cast(const_cast(diff_dst_data))); + + if (context_.flags & use_scale_shift) { + context_.weights_mem->set_data_handle( + static_cast(const_cast(weights_data))); + context_.diff_weights_mem->set_data_handle( + static_cast(const_cast(diff_weights_data))); + } + + context_.diff_src_mem->set_data_handle( + static_cast(const_cast(diff_src_data))); + + // execution + context_.bwd_stream->submit(context_.bwd_primitives); + + context_.src_mem->set_data_handle(DummyData); + context_.mean_mem->set_data_handle(DummyData); + context_.variance_mem->set_data_handle(DummyData); + context_.diff_dst_mem->set_data_handle(DummyData); + if (context_.flags & use_scale_shift) { + context_.weights_mem->set_data_handle(DummyData); + context_.diff_weights_mem->set_data_handle(DummyData); + } + context_.diff_src_mem->set_data_handle(DummyData); + } + + mkldnn_memory_format_t GetSrcFmt() { + return (*context_.src_mem).get_primitive_desc().desc().data.format; + } + + mkldnn_memory_format_t GetDiffDstFmt() { + return (*context_.diff_dst_mem).get_primitive_desc().desc().data.format; + } + + memory::primitive_desc GetDiffSrcPd() { + return (*context_.diff_src_mem).get_primitive_desc(); + } + + private: + struct BatchNormBwdContext { + // Flags to indicate whether it is training or inference + int64 flags; + + // MKLDNN memory + std::shared_ptr src_mem; + std::shared_ptr mean_mem; + std::shared_ptr variance_mem; + std::shared_ptr diff_dst_mem; + std::shared_ptr weights_mem; + std::shared_ptr diff_weights_mem; + std::shared_ptr diff_src_mem; + + // Batch Norm primitive + std::shared_ptr bn_bwd; + std::vector bwd_primitives; + std::shared_ptr bwd_stream; + + BatchNormBwdContext() : + src_mem(nullptr), mean_mem(nullptr), variance_mem(nullptr), + diff_dst_mem(nullptr), weights_mem(nullptr), diff_weights_mem(nullptr), + diff_src_mem(nullptr), bwd_stream(nullptr) { + } + }; + + void Setup(const MklBatchNormBwdParams& bwdParams) { + context_.flags = bwdParams.training ? use_scale_shift + : (use_scale_shift | use_global_stats); + + // memory desc + auto src_md = memory::desc({bwdParams.src_dims}, + MklDnnType(), get_desired_format(bwdParams.src_dims[1])); + auto diff_dst_md = memory::desc({bwdParams.diff_dst_dims}, + MklDnnType(), get_desired_format(bwdParams.diff_dst_dims[1])); + auto variance_desc = memory::desc({1, bwdParams.depth}, MklDnnType(), + memory::format::nc); + auto mean_desc = memory::desc({1, bwdParams.depth}, + MklDnnType(), memory::format::nc); + auto weights_desc = memory::desc({2, bwdParams.depth}, + MklDnnType(), memory::format::nc); + auto diff_weights_desc = weights_desc; + + // fwd desc & primitive desc + auto fwd_desc = batch_normalization_forward::desc( + prop_kind::forward_training, src_md, bwdParams.eps, + bwdParams.training + ? use_scale_shift + : (use_scale_shift | use_global_stats)); + auto fwd_pd = batch_normalization_forward::primitive_desc( + fwd_desc, cpu_engine_); + + // BatchNorm backward primitive + // + // For inference, specify use_global_stats + // 1. on fwd propagation, use mean and variance provided as inputs. + // 2. on bwd propagation, mean and variance are considered as constants. + // This reduces the amount of MKL computation. + auto bwd_desc = batch_normalization_backward::desc( + prop_kind::backward, diff_dst_md, src_md, bwdParams.eps, + bwdParams.training ? 
use_scale_shift + : (use_scale_shift | use_global_stats)); + auto bn_bwd_pd = batch_normalization_backward::primitive_desc( + bwd_desc, cpu_engine_, fwd_pd); + + // memory primitive + context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData)); + context_.diff_dst_mem.reset(new memory({diff_dst_md, cpu_engine_}, + DummyData)); + context_.variance_mem.reset(new memory({variance_desc, cpu_engine_}, + DummyData)); + context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData)); + context_.weights_mem.reset(new memory({weights_desc, cpu_engine_}, + DummyData)); + context_.diff_weights_mem.reset(new memory({diff_weights_desc, cpu_engine_}, + DummyData)); + context_.diff_src_mem.reset(new memory({src_md, cpu_engine_}, DummyData)); + + context_.bn_bwd.reset(new batch_normalization_backward( + bn_bwd_pd, *context_.src_mem, *context_.mean_mem, + *context_.variance_mem, *context_.diff_dst_mem, *context_.weights_mem, + *context_.diff_src_mem, *context_.diff_weights_mem)); + context_.bwd_primitives.push_back(*context_.bn_bwd); + } + + struct BatchNormBwdContext context_; + engine cpu_engine_; +}; + +template +class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory { + public: + static MklFusedBatchNormBwdPrimitive* Get( + const MklBatchNormBwdParams& bwdParams) { + auto bn_bwd = static_cast*>( + MklFusedBatchNormBwdPrimitiveFactory + ::GetInstance().GetBatchNormBwd(bwdParams)); + if (bn_bwd == nullptr) { + bn_bwd = new MklFusedBatchNormBwdPrimitive(bwdParams); + MklFusedBatchNormBwdPrimitiveFactory::GetInstance().SetBatchNormBwd( + bwdParams, bn_bwd); + } + return bn_bwd; + } + + static MklFusedBatchNormBwdPrimitiveFactory& GetInstance() { + static MklFusedBatchNormBwdPrimitiveFactory instance_; + return instance_; + } + + private: + MklFusedBatchNormBwdPrimitiveFactory() {} + ~MklFusedBatchNormBwdPrimitiveFactory() {} + + static std::string CreateKey(const MklBatchNormBwdParams& bwdParams) { + std::string prefix = "bn_bwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(bwdParams.src_dims); + key_creator.AddAsKey(bwdParams.diff_dst_dims); + key_creator.AddAsKey(bwdParams.depth); + key_creator.AddAsKey(bwdParams.eps); + key_creator.AddAsKey(bwdParams.training); + return key_creator.GetKey(); + } + + MklPrimitive* GetBatchNormBwd(const MklBatchNormBwdParams& bwdParams) { + std::string key = CreateKey(bwdParams); + return this->GetOp(key); + } + + void SetBatchNormBwd(const MklBatchNormBwdParams& bwdParams, + MklPrimitive* op) { + std::string key = CreateKey(bwdParams); + this->SetOp(key, op); + } +}; + template class MklFusedBatchNormOp : public OpKernel { public: @@ -701,7 +1163,6 @@ class MklFusedBatchNormOp : public OpKernel { void Compute(OpKernelContext* context) override { try { - auto cpu_engine = engine(engine::cpu, 0); const size_t kSrcIndex = 0; // index of src input tensor const size_t kScaleIndex = 1; // index of scale tensor const size_t kShiftIndex = 2; // index of shift tensor @@ -786,7 +1247,7 @@ class MklFusedBatchNormOp : public OpKernel { SetMeanVariance(est_mean_tensor, est_variance_tensor); MklDnnData src(&cpu_engine); - MklDnnData dst(&cpu_engine); + MklDnnData weights(&cpu_engine); memory::format format_m; if (dnn_shape_src.IsMklTensor()) { @@ -800,123 +1261,106 @@ class MklFusedBatchNormOp : public OpKernel { } // set src primitive - memory::dims src_dims; - if (dnn_shape_src.IsMklTensor()) { - src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(), - tensor_format_); - } else { - src_dims = - 
TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); - } + memory::dims src_dims = dnn_shape_src.IsMklTensor() + ? dnn_shape_src.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); auto src_md = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetMklLayout() : memory::desc(src_dims, MklDnnType(), format_m); - src.SetUsrMem(src_md, &src_tensor); - // set weights primitive // MKL-DNN packs scale & shift as "weights": // ...... - auto weights_desc = memory::desc({2, static_cast(depth_)}, - MklDnnType(), memory::format::nc); - auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine); - auto weights_m = memory(weights_pd); - T* weights_data = reinterpret_cast(weights_m.get_data_handle()); + weights.AllocateBuffer(2 * depth_ * sizeof (T)); + T* weights_data = reinterpret_cast(weights.GetAllocatedBuffer()); T* scale_tf = reinterpret_cast(const_cast(scale_tensor.flat().data())); T* shift_tf = reinterpret_cast(const_cast(shift_tensor.flat().data())); - for (int k = 0; k < depth_; k++) { - weights_data[k] = scale_tf[k]; - weights_data[k + depth_] = shift_tf[k]; - } - - // set mean primitive - auto mean_desc = memory::desc({1, static_cast(depth_)}, - MklDnnType(), memory::format::nc); - auto mean_pd = memory::primitive_desc(mean_desc, cpu_engine); + std::memcpy(weights_data, scale_tf, depth_ * sizeof(T)); + std::memcpy(weights_data + depth_, shift_tf, depth_ * sizeof(T)); char* saved_mean_data_tf = reinterpret_cast(saved_mean_tensor->flat().data()); std::memcpy(saved_mean_data_tf, reinterpret_cast(mean_values_), depth_ * sizeof(T)); - auto mean_m = - memory(mean_pd, reinterpret_cast(saved_mean_data_tf)); - // set variance primitive - auto variance_desc = memory::desc({1, static_cast(depth_)}, - MklDnnType(), memory::format::nc); - auto variance_pd = memory::primitive_desc(variance_desc, cpu_engine); char* saved_variance_data_tf = reinterpret_cast(saved_variance_tensor->flat().data()); std::memcpy(saved_variance_data_tf, reinterpret_cast(variance_values_), depth_ * sizeof(T)); - auto variance_m = memory(variance_pd, saved_variance_data_tf); - - prop_kind pk = (is_training_) ? prop_kind::forward_training - : prop_kind::forward_scoring; - auto bnrm_fwd_desc = batch_normalization_forward::desc( - pk, src.GetUsrMemDesc(), epsilon_, - is_training_ ? use_scale_shift - : (use_scale_shift | use_global_stats)); - auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc( - bnrm_fwd_desc, cpu_engine); - - // allocate dst tensor - MklDnnShape dnn_shape_dst; - TensorShape tf_shape_dst; - if (dnn_shape_src.IsMklTensor()) { - dnn_shape_dst.SetMklTensor(true); - auto dst_pd = bnrm_fwd_pd.dst_primitive_desc(); - dnn_shape_dst.SetMklLayout(&dst_pd); - dnn_shape_dst.SetElemType(MklDnnType()); - dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(), src_dims, - format_m); - tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T)); - } else { - dnn_shape_dst.SetMklTensor(false); - tf_shape_dst = src_tensor.shape(); - } - AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst, - dnn_shape_dst); - - // Output of batchnorm has same shape as input. 
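The bn_fwd->Execute(...) call below works on raw pointers because the cached primitive owns memory objects created once against the DummyData placeholder; each call swaps the real buffers in, submits the prebuilt primitive vector, and swaps the placeholders back so no user buffer outlives the call. A compressed sketch of that idiom (illustrative struct, not the TF class):

```cpp
#include <memory>
#include <vector>
#include "mkldnn.hpp"

static char kDummy[1];  // stand-in for TF's DummyData placeholder

// The primitive net is built once; per call we only repoint the memory
// objects at the caller's buffers, run the net, and detach again.
struct CachedBatchNormFwd {
  std::shared_ptr<mkldnn::memory> src_mem, dst_mem;
  std::vector<mkldnn::primitive> net;      // holds the prebuilt bn_fwd
  std::shared_ptr<mkldnn::stream> strm;

  void Execute(const float* src, float* dst) {
    src_mem->set_data_handle(const_cast<float*>(src));
    dst_mem->set_data_handle(dst);
    strm->submit(net).wait();              // run the cached primitive
    src_mem->set_data_handle(kDummy);      // drop references to user
    dst_mem->set_data_handle(kDummy);      // buffers between calls
  }
};
```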
- dst.SetUsrMem(src_md, dst_tensor); - primitive bnrm_fwd_op; - if (is_training_) { - bnrm_fwd_op = - batch_normalization_forward(bnrm_fwd_pd, src.GetOpMem(), weights_m, - dst.GetOpMem(), mean_m, variance_m); + // get batchnorm op from the pool + MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_); + MklFusedBatchNormFwdPrimitive *bn_fwd = + MklFusedBatchNormFwdPrimitiveFactory::Get(fwdParams); + + // check if reorder is needed for src, weights, mean, variance + T* src_data = nullptr; + if (src_md.data.format != bn_fwd->GetSrcFmt()) { + src.SetUsrMem(src_md, &src_tensor); + auto src_target = memory::primitive_desc({{src_dims}, MklDnnType(), + static_cast(bn_fwd->GetSrcFmt())}, cpu_engine); + src.CheckReorderToOpMem(src_target); + src_data = static_cast(src.GetOpMem().get_data_handle()); } else { - bnrm_fwd_op = batch_normalization_forward( - bnrm_fwd_pd, src.GetOpMem(), mean_m, variance_m, - (const primitive::at)weights_m, dst.GetOpMem()); + src_data = static_cast( + const_cast(src_tensor.flat().data())); } - std::vector net; - net.push_back(bnrm_fwd_op); - stream(stream::kind::eager).submit(net).wait(); + + // allocate output (dst) tensor; always set it as MKL-DNN layout + MklDnnShape dnn_shape_dst; + TensorShape tf_shape_dst; + dnn_shape_dst.SetMklTensor(true); + auto dst_pd = bn_fwd->GetDstPd(); + dnn_shape_dst.SetMklLayout(&dst_pd); + dnn_shape_dst.SetElemType(MklDnnType()); + auto ndims = dnn_shape_src.IsMklTensor() + ? dnn_shape_src.GetDimension() + : src_tensor.shape().dims(); + dnn_shape_dst.SetTfLayout(ndims, src_dims, format_m); + tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T)); + AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, + tf_shape_dst, dnn_shape_dst); + + T* weights_op_data = weights_data; + T* mean_op_data = reinterpret_cast( + saved_mean_tensor->flat().data()); + T* variance_op_data = reinterpret_cast( + saved_variance_tensor->flat().data()); + T* dst_data = static_cast(dst_tensor->flat().data()); + + // execution + bn_fwd->Execute(src_data, weights_op_data, dst_data, + mean_op_data, variance_op_data); // copy batch_mean data T* batch_mean_data_tf = reinterpret_cast(batch_mean_tensor->flat().data()); std::memcpy(reinterpret_cast(batch_mean_data_tf), - reinterpret_cast(mean_m.get_data_handle()), + reinterpret_cast(saved_mean_data_tf), depth_ * sizeof(T)); + // TODO(yli135): OpMem is same as usr mem since + // since its format is hard-coded as nc when primitive is created. 
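The adjust_factor computed just below implements Bessel's correction: MKL-DNN produces the population variance (normalized by N, where N = batch * height * width per channel), while the op must return the sample variance (normalized by N - 1), so each channel's value is scaled by N / (N - 1); e.g. N = 32 * 28 * 28 = 25088 gives a factor of about 1.00004. A one-function sketch of the rescale (hypothetical buffers):

```cpp
// Rescale per-channel population variance (1/N) to sample variance
// (1/(N-1)), matching the adjust_factor math in the kernel below.
void ApplyBesselCorrection(const float* pop_var, float* sample_var,
                           int depth, long n /* batch*height*width */) {
  const float adjust = static_cast<float>(n) / (n - 1);
  for (int k = 0; k < depth; ++k) sample_var[k] = pop_var[k] * adjust;
}
```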
// copy batch_variance data with Bessel's correction - // if training mode is on float adjust_factor = 1.0; if (is_training_) { size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3]; size_t adjust_size = orig_size - 1; adjust_factor = (static_cast(orig_size)) / adjust_size; } - for (int k = 0; k < depth_; k++) - batch_variance_tensor->flat().data()[k] = - (reinterpret_cast(variance_m.get_data_handle()))[k] * - adjust_factor; + + auto variance_data = reinterpret_cast(saved_variance_data_tf); + auto batch_variance_data = batch_variance_tensor->flat().data(); + if (is_training_) { + for (int k = 0; k < depth_; k++) { + batch_variance_data[k] = variance_data[k] * adjust_factor; + } + } else { + std::memcpy(batch_variance_data, variance_data, depth_ * sizeof(T)); + } } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + @@ -933,7 +1377,8 @@ class MklFusedBatchNormOp : public OpKernel { bool is_training_; T* mean_values_; T* variance_values_; - int depth_; // batch normalization is done for per channel. + size_t depth_; // batch normalization is done for per channel. + engine cpu_engine = engine(engine::cpu, 0); void ExtractParams(OpKernelContext* context) { const Tensor& input = MklGetInput(context, 0); @@ -990,8 +1435,10 @@ class MklFusedBatchNormOp : public OpKernel { tf_shape_scale, mkl_shape_batch_mean); CHECK_NOTNULL(*batch_mean_tensor); // set NAN mean value in case of empty input tensor - for (int k = 0; k < tf_shape_scale.num_elements(); k++) - (*batch_mean_tensor)->flat().data()[k] = NAN; + int num_elements = tf_shape_scale.num_elements(); + auto batch_mean_data = (*batch_mean_tensor)->flat().data(); + for (int k = 0; k < num_elements; k++) + batch_mean_data[k] = NAN; // allocate batch variance output tensor MklDnnShape mkl_shape_batch_variance; @@ -1001,8 +1448,9 @@ class MklFusedBatchNormOp : public OpKernel { mkl_shape_batch_variance); CHECK_NOTNULL(*batch_variance_tensor); // set NAN variance value in case of empty input tensor - for (int k = 0; k < tf_shape_scale.num_elements(); k++) - (*batch_variance_tensor)->flat().data()[k] = NAN; + auto batch_variance_data = (*batch_variance_tensor)->flat().data(); + for (int k = 0; k < num_elements; k++) + batch_variance_data[k] = NAN; // Mean and variance (without Bessel's correction) saved for backward // computation to serve as pre-computed mean and variance. 
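The NAN loops in these allocation helpers deliberately poison outputs on the empty-input path so that stale reads are detectable. A compact equivalent of the hoisted fill, assuming float data (std::fill_n is a drop-in for the hand-written loops):

```cpp
#include <algorithm>
#include <cmath>

// Poison an allocated output for the empty-input path; equivalent to the
// element-wise NAN loops, with the accessors hoisted out of the loop.
void FillWithNaN(float* data, int num_elements) {
  std::fill_n(data, num_elements, NAN);
}
```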
@@ -1012,8 +1460,9 @@ class MklFusedBatchNormOp : public OpKernel { tf_shape_scale, mkl_shape_saved_mean); CHECK_NOTNULL(*saved_mean_tensor); // set NAN mean value in case of empty input tensor - for (int k = 0; k < tf_shape_scale.num_elements(); k++) - (*saved_mean_tensor)->flat().data()[k] = NAN; + auto saved_mean_data = (*saved_mean_tensor)->flat().data(); + for (int k = 0; k < num_elements; k++) + saved_mean_data[k] = NAN; MklDnnShape mkl_shape_saved_variance; mkl_shape_saved_variance.SetMklTensor(false); @@ -1022,8 +1471,9 @@ class MklFusedBatchNormOp : public OpKernel { mkl_shape_saved_variance); CHECK_NOTNULL(*saved_variance_tensor); // set NAN variance value in case of empty input tensor - for (int k = 0; k < tf_shape_scale.num_elements(); k++) - (*saved_variance_tensor)->flat().data()[k] = NAN; + auto saved_variance_data = (*saved_variance_tensor)->flat().data(); + for (int k = 0; k < num_elements; k++) + saved_variance_data[k] = NAN; } }; @@ -1044,12 +1494,12 @@ class MklFusedBatchNormGradOp : public OpKernel { void Compute(OpKernelContext* context) override { try { - auto cpu_engine = engine(engine::cpu, 0); const size_t kDiffDstIndex = 0; // index of diff_dst tensor const size_t kSrcIndex = 1; // index of src input tensor const size_t kScaleIndex = 2; // index of scale tensor const size_t kMeanIndex = 3; // index of saved_mean tensor const size_t kVarianceIndex = 4; // index of saved_variance tensor + const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex); const Tensor& src_tensor = MklGetInput(context, kSrcIndex); const Tensor& scale_tensor = MklGetInput(context, kScaleIndex); @@ -1060,8 +1510,8 @@ class MklFusedBatchNormGradOp : public OpKernel { MklDnnShape dnn_shape_src, dnn_shape_diff_dst; GetMklShape(context, kSrcIndex, &dnn_shape_src); GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst); - TensorShape tf_shape_src, tf_shape_diff_dst; + TensorShape tf_shape_src, tf_shape_diff_dst; if (dnn_shape_diff_dst.IsMklTensor()) { tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape(); OP_REQUIRES( @@ -1102,6 +1552,7 @@ class MklFusedBatchNormGradOp : public OpKernel { saved_variance_tensor.shape().DebugString())); Tensor* diff_src_tensor = nullptr; + // special case: input with 0 element and 0 batch size if (tf_shape_src.num_elements() == 0 || tf_shape_diff_dst.num_elements() == 0) { HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(), @@ -1117,174 +1568,111 @@ class MklFusedBatchNormGradOp : public OpKernel { ExtractParams(context); } - MklDnnData src(&cpu_engine); - MklDnnData mean(&cpu_engine); - MklDnnData variance(&cpu_engine); - MklDnnData diff_dst(&cpu_engine); - MklDnnData diff_src(&cpu_engine); - - memory::dims src_dims, diff_dst_dims; - if (dnn_shape_src.IsMklTensor()) - src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(), - tensor_format_); - else - src_dims = - TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); - - if (dnn_shape_diff_dst.IsMklTensor()) - diff_dst_dims = TFShapeToMklDnnDimsInNCHW( - dnn_shape_diff_dst.GetTfShape(), tensor_format_); - else - diff_dst_dims = - TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), tensor_format_); - - // set src and diff_dst primitives according to input layout - memory::desc src_md({}, memory::data_undef, memory::format_undef); - memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef); + memory::format format_m; if (dnn_shape_src.IsMklTensor()) { - src_md = dnn_shape_src.GetMklLayout(); - } else { - src_md = memory::desc(src_dims, MklDnnType(), - 
TFDataFormatToMklDnnDataFormat(tensor_format_)); - } - if (dnn_shape_diff_dst.IsMklTensor()) { - diff_dst_md = dnn_shape_diff_dst.GetMklLayout(); + if (dnn_shape_src.IsTensorInNCHWFormat()) + format_m = memory::format::nchw; + else + format_m = memory::format::nhwc; } else { - diff_dst_md = memory::desc(diff_dst_dims, MklDnnType(), - TFDataFormatToMklDnnDataFormat(tensor_format_)); + format_m = TFDataFormatToMklDnnDataFormat(tensor_format_); } - src.SetUsrMem(src_md, &src_tensor); - diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); - - // weights -- DNN packs scales/shifts as weights in order of - // scale, ..., scale, shift, ..., shift - auto weights_desc = - memory::desc({2, depth_}, MklDnnType(), memory::format::nc); - auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine); - auto weights_m = memory(weights_pd); - T* weights_data = reinterpret_cast(weights_m.get_data_handle()); + + MklDnnData src(&cpu_engine); + MklDnnData diff_dst(&cpu_engine); + MklDnnData weights(&cpu_engine); + MklDnnData diff_weights(&cpu_engine); + + memory::dims src_dims = dnn_shape_src.IsMklTensor() + ? dnn_shape_src.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); + memory::dims diff_dst_dims = dnn_shape_diff_dst.IsMklTensor() + ? dnn_shape_diff_dst.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), tensor_format_); + + // set src and diff_dst primitive descriptors + memory::desc src_md = dnn_shape_src.IsMklTensor() + ? dnn_shape_src.GetMklLayout() + : memory::desc(src_dims, MklDnnType(), format_m); + memory::desc diff_dst_md = dnn_shape_diff_dst.IsMklTensor() + ? dnn_shape_diff_dst.GetMklLayout() + : memory::desc(diff_dst_dims, MklDnnType(), format_m); + + // weights -- MKL DNN packs scales/ shifts as weights in order + // of scale, ..., scale, shift, ...., shift + weights.AllocateBuffer(2 * depth_ * sizeof(T)); + T* weights_data_tf = reinterpret_cast(weights.GetAllocatedBuffer()); T* scale_tf = reinterpret_cast(const_cast(scale_tensor.flat().data())); for (int k = 0; k < depth_; k++) { - weights_data[k] = scale_tf[k]; - weights_data[k + depth_] = 0; + weights_data_tf[k] = scale_tf[k]; + weights_data_tf[k + depth_] = 0; + } + + diff_weights.AllocateBuffer(2 * depth_ * sizeof(T)); + + MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, + depth_, epsilon_, is_training_); + MklFusedBatchNormBwdPrimitive *bn_bwd = + MklFusedBatchNormBwdPrimitiveFactory::Get(bwdParams); + + // check if src/diff_dst need to be reordered + T* src_data = nullptr; + if (src_md.data.format != bn_bwd->GetSrcFmt()) { + src.SetUsrMem(src_md, &src_tensor); + auto src_target = memory::primitive_desc({{src_dims}, MklDnnType(), + static_cast(bn_bwd->GetSrcFmt())}, cpu_engine); + src.CheckReorderToOpMem(src_target); + src_data = static_cast(src.GetOpMem().get_data_handle()); + } else { + src_data = static_cast(const_cast( + src_tensor.flat().data())); } - // set mean primitive - memory::dims mv_dims = GetMeanVarianceDims(); - mean.SetUsrMem(mv_dims, memory::format::nc, - const_cast(static_cast( - saved_mean_tensor.flat().data()))); - mean.SetOpMemDesc(mv_dims, memory::format::nc); - - // set variance primitive - variance.SetUsrMem(mv_dims, memory::format::nc, - const_cast(static_cast( - saved_variance_tensor.flat().data()))); - variance.SetOpMemDesc(mv_dims, memory::format::nc); - - // set diff_weight primitive - auto diff_weights_desc = - memory::desc({2, depth_}, MklDnnType(), memory::format::nc); - auto diff_weights_pd = - 
memory::primitive_desc(diff_weights_desc, cpu_engine); - auto diff_weights_m = memory(diff_weights_pd); - - auto bnrm_fwd_desc = batch_normalization_forward::desc( - prop_kind::forward_training, src.GetUsrMemDesc(), epsilon_, - is_training_ ? use_scale_shift - : (use_scale_shift | use_global_stats)); - auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc( - bnrm_fwd_desc, cpu_engine); + T* diff_dst_data = nullptr; + if (diff_dst_md.data.format != bn_bwd->GetDiffDstFmt()) { + diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); + auto diff_dst_target = memory::primitive_desc({{diff_dst_dims}, + MklDnnType(), static_cast( + bn_bwd->GetDiffDstFmt())}, cpu_engine); + diff_dst.CheckReorderToOpMem(diff_dst_target); + diff_dst_data = static_cast( + diff_dst.GetOpMem().get_data_handle()); + } else { + diff_dst_data = static_cast(const_cast( + diff_dst_tensor.flat().data())); + } // Indices of output tensors const size_t kDiffSrcIndex = 0; // index of diff_src tensor - // allocate diff_src tensor + // allocate diff_src tensor, always set as MKL-DNN layout MklDnnShape dnn_shape_diff_src; TensorShape tf_shape_diff_src; - - // MKL-DNN's BN primitive not provide API to fetch internal format - // set common_md as OpMem - // src and diff_dst will reorder to common_md - // diff_src will set as common_md - memory::desc common_md({}, memory::data_undef, memory::format_undef); - if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) { - if (dnn_shape_src.IsMklTensor()) { - common_md = dnn_shape_src.GetMklLayout(); - } else { - common_md = dnn_shape_diff_dst.GetMklLayout(); - } - } else { - common_md = memory::desc(src_dims, MklDnnType(), - TFDataFormatToMklDnnDataFormat(tensor_format_)); - } - // if any of src and diff_dst as mkl layout, - // then we set diff_src as mkl layout - if (dnn_shape_src.IsMklTensor() || - dnn_shape_diff_dst.IsMklTensor()) { - dnn_shape_diff_src.SetMklTensor(true); - // set diff_src's mkl layout as common_md - auto diff_src_pd = memory::primitive_desc(common_md, cpu_engine); - dnn_shape_diff_src.SetMklLayout(&diff_src_pd); - dnn_shape_diff_src.SetElemType(MklDnnType()); - if (dnn_shape_src.IsMklTensor()) { - dnn_shape_diff_src.SetTfLayout( - dnn_shape_src.GetDimension(), - src_dims, - dnn_shape_src.GetTfDataFormat()); - dnn_shape_diff_src.SetTfDimOrder( - dnn_shape_src.GetDimension(), - tensor_format_); - } else { - dnn_shape_diff_src.SetTfLayout( - dnn_shape_diff_dst.GetDimension(), - src_dims, - dnn_shape_diff_dst.GetTfDataFormat()); - dnn_shape_diff_src.SetTfDimOrder( - dnn_shape_diff_dst.GetDimension(), - tensor_format_); - } - tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T)); - } else { - dnn_shape_diff_src.SetMklTensor(false); - // both src and diff_dst are TensorFlow layout, - // so it is OK to get TensorFlow shape. 
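MKL-DNN's scale_shift "weights" are a single 2 x depth buffer: the first depth entries hold the scales (gamma) and the next depth entries the shifts (beta). That is why the forward op memcpy's the two TF tensors into one allocation, the backward op packs scale plus a zeroed shift half, and diff_weights is later split back into diff_scale/diff_shift. A minimal sketch of the packing convention on plain buffers (not the TF helpers):

```cpp
#include <cstring>

// Pack gamma/beta into the 2 x depth "weights" buffer MKL-DNN expects,
// and split the packed gradient back into diff_scale / diff_shift.
void PackWeights(const float* scale, const float* shift, int depth,
                 float* weights /* 2 * depth elements */) {
  std::memcpy(weights, scale, depth * sizeof(float));          // gammas
  std::memcpy(weights + depth, shift, depth * sizeof(float));  // betas
}

void UnpackDiffWeights(const float* diff_weights, int depth,
                       float* diff_scale, float* diff_shift) {
  std::memcpy(diff_scale, diff_weights, depth * sizeof(float));
  std::memcpy(diff_shift, diff_weights + depth, depth * sizeof(float));
}
```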
- tf_shape_diff_src = src_tensor.shape(); - } + dnn_shape_diff_src.SetMklTensor(true); + auto diff_src_pd = bn_bwd->GetDiffSrcPd(); + dnn_shape_diff_src.SetMklLayout(&diff_src_pd); + dnn_shape_diff_src.SetElemType(MklDnnType()); + dnn_shape_diff_src.SetTfLayout(src_dims.size(), src_dims, + format_m); + dnn_shape_diff_src.SetTfDimOrder(src_dims.size(), tensor_format_); + tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T)); AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor, tf_shape_diff_src, dnn_shape_diff_src); - // set diff_src - diff_src.SetUsrMem(common_md, diff_src_tensor); - - prop_kind pk = prop_kind::backward; - auto bnrm_bwd_desc = batch_normalization_backward::desc( - pk, common_md, common_md, epsilon_, - /* for inference, specify use_global_stats - 1. on fwd prop, use mean and variance - provided as inputs - 2. on bwd prop, mean and variance are - considered as constants. Thus, - reduce the amout of MKL computations - */ - is_training_ ? use_scale_shift - : (use_scale_shift | use_global_stats)); - auto bnrm_bwd_pd = batch_normalization_backward::primitive_desc( - bnrm_bwd_desc, cpu_engine, bnrm_fwd_pd); - - std::vector net; - src.CheckReorderToOpMem(memory::primitive_desc(common_md, - cpu_engine), &net); - diff_dst.CheckReorderToOpMem(memory::primitive_desc(common_md, - cpu_engine), &net); - - auto bnrm_bwd_op = batch_normalization_backward( - bnrm_bwd_pd, src.GetOpMem(), mean.GetOpMem(), variance.GetOpMem(), - diff_dst.GetOpMem(), weights_m, diff_src.GetOpMem(), diff_weights_m); - - net.push_back(bnrm_bwd_op); - stream(stream::kind::eager).submit(net).wait(); + + T* mean_data = static_cast(const_cast( + saved_mean_tensor.flat().data())); + T* variance_data = static_cast(const_cast( + saved_variance_tensor.flat().data())); + T* weights_data = weights_data_tf; + T* diff_src_data = static_cast( + diff_src_tensor->flat().data()); + T* diff_weights_data = static_cast( + diff_weights.GetAllocatedBuffer()); + // Execute + bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data, + weights_data, diff_src_data, diff_weights_data); // allocate 4 output TF tensors Tensor* diff_scale_tensor = nullptr; @@ -1293,13 +1681,14 @@ class MklFusedBatchNormGradOp : public OpKernel { &diff_shift_tensor); // copy data: diff_scale and diff_shift - T* diff_weights_data_dnn = - reinterpret_cast(diff_weights_m.get_data_handle()); - for (int i = 0; i < depth_; i++) { - diff_scale_tensor->flat().data()[i] = diff_weights_data_dnn[i]; - diff_shift_tensor->flat().data()[i] = - diff_weights_data_dnn[i + depth_]; - } + auto diff_scale_data = diff_scale_tensor->flat().data(); + auto diff_shift_data = diff_shift_tensor->flat().data(); + std::memcpy(reinterpret_cast(diff_scale_data), + reinterpret_cast(diff_weights_data), + depth_ * sizeof(T)); + std::memcpy(reinterpret_cast(diff_shift_data), + reinterpret_cast(diff_weights_data + depth_), + depth_ * sizeof(T)); } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + @@ -1315,6 +1704,7 @@ class MklFusedBatchNormGradOp : public OpKernel { TensorFormat tensor_format_; int depth_; // batch normalization is done for per channel. 
bool is_training_; + engine cpu_engine = engine(engine::cpu, 0); void ExtractParams(OpKernelContext* context) { const Tensor& input = MklGetInput(context, 0); @@ -1330,8 +1720,10 @@ class MklFusedBatchNormGradOp : public OpKernel { dnn_shape_diff_src.SetMklTensor(false); AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor, tf_shape_src, dnn_shape_diff_src); - for (size_t i = 0; i < (*diff_src_tensor)->shape().num_elements(); i++) - (*diff_src_tensor)->flat().data()[i] = 0; + int num_elements = (*diff_src_tensor)->shape().num_elements(); + auto diff_src_data = (*diff_src_tensor)->flat().data(); + for (size_t i = 0; i < num_elements; i++) + diff_src_data[i] = 0; Tensor* diff_scale_tensor = nullptr; Tensor* diff_shift_tensor = nullptr; @@ -1357,16 +1749,20 @@ class MklFusedBatchNormGradOp : public OpKernel { AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor, tf_shape_scale_shift, mkl_shape_diff_scale); CHECK_NOTNULL(*diff_scale_tensor); - for (size_t i = 0; i < (*diff_scale_tensor)->shape().num_elements(); i++) - (*diff_scale_tensor)->flat().data()[i] = 0; + int diff_scale_num_elements = (*diff_scale_tensor)->shape().num_elements(); + auto diff_scale_data = (*diff_scale_tensor)->flat().data(); + for (size_t i = 0; i < diff_scale_num_elements; i++) + diff_scale_data[i] = 0; MklDnnShape mkl_shape_diff_shift; mkl_shape_diff_shift.SetMklTensor(false); AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor, tf_shape_scale_shift, mkl_shape_diff_shift); CHECK_NOTNULL(*diff_shift_tensor); - for (size_t i = 0; i < (*diff_shift_tensor)->shape().num_elements(); i++) - (*diff_shift_tensor)->flat().data()[i] = 0; + int diff_shift_num_elements = (*diff_shift_tensor)->shape().num_elements(); + auto diff_shift_data = (*diff_shift_tensor)->flat().data(); + for (size_t i = 0; i < diff_shift_num_elements; i++) + diff_shift_data[i] = 0; // Placeholders for estimated_mean and estimated_variance, which are // used for inference and thus not needed here for gradient computation. 
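These HandleEmptyInput rewrites hoist num_elements() and flat<T>().data() out of the zero-fill loops, so each iteration is a plain store instead of repeated accessor calls. A compact equivalent, assuming a TF-like tensor API (the TensorLike type here is hypothetical, for illustration only):

```cpp
#include <algorithm>
#include <cstddef>

// Accessors are evaluated once up front; the fill itself is then a single
// tight loop over raw data.
template <typename T, typename TensorLike>
void ZeroFill(TensorLike* t) {
  const std::size_t n = t->shape().num_elements();
  T* data = t->template flat<T>().data();
  std::fill_n(data, n, T(0));
}
```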
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc index ea537524b11ef1..657f007a2e97ae 100644 --- a/tensorflow/core/kernels/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc @@ -119,6 +119,7 @@ class MklMaxPoolingOp : public OpKernel { mkl_out_shape); Tensor* workspace_tensor; + void* workspace_buf = nullptr; TensorShape workspace_shape; mkl_workspace_shape.SetMklTensor(false); @@ -510,7 +511,6 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { void Compute(OpKernelContext* context) override { try { - auto cpu_engine = engine(engine::cpu, 0); const Tensor& input_tensor = MklGetInput(context, this->kInputTensorIndexInput); MklDnnShape dnn_shape_input; @@ -525,8 +525,9 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { // initialize variables for the pooling op MklPoolParameters pool_params; // Get the input tensor and initialize the pooling parameters - this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params, - &dnn_data_input); + TensorShape input_tensor_shape = input_tensor.shape(); + this->InitMklPoolParameters(context, &pool_params, + dnn_shape_input, input_tensor_shape); OP_REQUIRES_OK(context, context->status()); // Declare output tensor @@ -534,44 +535,72 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { memory::dims output_dims_mkl_order; this->GetOutputDims(pool_params, &output_dims_mkl_order); - // If input is in Mkl layout, then just get the memory format from it - // directly, instead of using input data_format to MaxPool. - if (dnn_shape_input.IsMklTensor()) { - dnn_data_output.SetUsrMem( - output_dims_mkl_order, - static_cast( - dnn_data_input.GetUsrMemDesc().data.format)); - } else { - dnn_data_output.SetUsrMem(output_dims_mkl_order, - this->data_format_mkldnn_); + // If input is an empty tensor, allocate an empty output tensor and return + if (input_tensor.NumElements() == 0) { + const int kOutputIndex = 0; + this->AllocateEmptyOutputTensor(context, kOutputIndex, &pool_params, + output_dims_mkl_order, &output_tensor); + return; } - // describe the memory layout; let mkl-dnn choose the best for the op - dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any); - - auto pool_desc = pooling_forward::desc( - prop_kind::forward, algorithm::pooling_max, - dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(), - memory::dims({pool_params.row_stride, pool_params.col_stride}), - memory::dims({pool_params.window_rows, pool_params.window_cols}), - memory::dims({static_cast(pool_params.pad_top), - static_cast(pool_params.pad_left)}), - memory::dims({static_cast(pool_params.pad_bottom), - static_cast(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_fwd_desc = - pooling_forward::primitive_desc(pool_desc, cpu_engine); - - this->AllocateOutputTensor(context, pool_fwd_desc, output_dims_mkl_order, - this->data_format_mkldnn_, &output_tensor); + // Get the input memory descriptor + memory::desc input_md = dnn_shape_input.IsMklTensor() + ? dnn_shape_input.GetMklLayout() + : memory::desc(TFShapeToMklDnnDimsInNCHW( + input_tensor_shape, + this->data_format_tf_), + MklDnnType(), + this->data_format_mkldnn_); + + // Get src/filter/stride/padding information + memory::dims src_dims = dnn_shape_input.IsMklTensor() + ? 
dnn_shape_input.GetSizesAsMklDnnDims()
+                                : TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
+                                                            this->data_format_tf_);
+
+      memory::dims filter_dims, strides, padding_left, padding_right;
+      this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
+                             &padding_left, &padding_right);
+
+      // Get a pooling op from the cached pool
+      MklPoolingFwdPrimitive *pooling_fwd = nullptr;
+      MklPoolingParams fwdParams(src_dims, output_dims_mkl_order, filter_dims,
+          strides, padding_left, padding_right, algorithm::pooling_max);
+      pooling_fwd = MklPoolingFwdPrimitiveFactory::Get(fwdParams);
+
+      // allocate output tensor
+      this->AllocateOutputTensor(context, *(pooling_fwd->GetPoolingFwdPd()),
+          output_dims_mkl_order, this->data_format_mkldnn_, &output_tensor);
       OP_REQUIRES_OK(context, context->status());
-      dnn_data_output.SetUsrMemDataHandle(output_tensor);
+      dnn_data_output.SetUsrMem(output_dims_mkl_order,
+          pooling_fwd->GetDstMemoryFormat(), output_tensor);

-      AllocateWorkspaceTensor(context, pool_fwd_desc, &dnn_data_wksp);
+      AllocateWorkspaceTensor(context, *(pooling_fwd->GetPoolingFwdPd()),
+                              &dnn_data_wksp);
       OP_REQUIRES_OK(context, context->status());

-      this->PrepareAndExecuteNet(pool_fwd_desc, &dnn_data_input,
-                                 &dnn_data_output, &dnn_data_wksp);
+      // check whether we need to reorder src
+      T* src_data = nullptr;
+      if (input_md.data.format != pooling_fwd->GetSrcMemoryFormat()) {
+        dnn_data_input.SetUsrMem(input_md, &input_tensor);
+        auto src_target_primitive_desc = memory::primitive_desc(
+            {{src_dims}, MklDnnType(), pooling_fwd->GetSrcMemoryFormat()},
+            cpu_engine);
+        dnn_data_input.CheckReorderToOpMem(src_target_primitive_desc);
+        src_data = static_cast(
+            dnn_data_input.GetOpMem().get_data_handle());
+      } else {
+        src_data = static_cast(const_cast(
+            input_tensor.flat().data()));
+      }
+
+      T* dst_data = static_cast(
+          const_cast(output_tensor->flat().data()));
+      T* ws_data = static_cast(
+          dnn_data_wksp.GetOpMem().get_data_handle());
+
+      // execute pooling op
+      pooling_fwd->Execute(src_data, dst_data, ws_data);
     } catch (mkldnn::error& e) {
       string error_msg = "Status: " + std::to_string(e.status) +
                          ", message: " + string(e.message) + ", in file " +
@@ -579,10 +608,11 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase {
       OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
                                               error_msg));
     }
-  }  // Compute
+  }

  private:
   const int kOutputTensorIndexWorkspace = 1;
+  engine cpu_engine = engine(engine::cpu, 0);

   void AllocateWorkspaceTensor(
       OpKernelContext* context,
@@ -616,98 +646,104 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase {
  public:
   explicit MklMaxPoolingGradOp(OpKernelConstruction* context)
       : MklPoolingBackwardOpBase(context) {}
-
   void Compute(OpKernelContext* context) override {
     try {
       auto cpu_engine = engine(engine::cpu, 0);
       const Tensor& orig_input_tensor =
           MklGetInput(context, kInputTensorIndexOrigInput);
-      const Tensor& orig_output_tensor =
-          MklGetInput(context, kInputTensorIndexOrigOutput);
       const Tensor& grad_tensor =
           MklGetInput(context, kInputTensorIndexGradient);
       const Tensor& workspace_tensor =
           MklGetInput(context, kInputTensorIndexWorkspace);
-      MklDnnShape orig_input_mkl_shape, orig_output_mkl_shape, grad_mkl_shape,
-          workspace_mkl_shape;
+      MklDnnShape orig_input_mkl_shape, grad_mkl_shape;
       GetMklShape(context, kInputTensorIndexOrigInput, &orig_input_mkl_shape);
-      GetMklShape(context, kInputTensorIndexOrigOutput, &orig_output_mkl_shape);
       GetMklShape(context, kInputTensorIndexGradient, &grad_mkl_shape);
-      GetMklShape(context, kInputTensorIndexWorkspace,
&workspace_mkl_shape); - - SanityCheckInputs(context, orig_input_tensor, orig_output_tensor, - grad_tensor, workspace_tensor, orig_input_mkl_shape, - orig_output_mkl_shape, grad_mkl_shape, - workspace_mkl_shape); if (!context->status().ok()) return; MklDnnData grad_dnn_data(&cpu_engine); MklDnnData workspace_dnn_data(&cpu_engine); - MklDnnData output_dnn_data(&cpu_engine); - Tensor* output_tensor = nullptr; + MklPoolParameters pool_params; - TensorShape orig_input_shape; - memory::dims output_dims_mkl_order, orig_input_dims_mkl_order; - memory::desc original_input_md = ConfigureOriginalInput( - context, orig_input_tensor, orig_input_mkl_shape, - &orig_input_dims_mkl_order, &pool_params, &orig_input_shape); - - memory::desc original_output_md = this->ConfigureOriginalOutput( - pool_params, orig_output_mkl_shape, output_dims_mkl_order); - - memory::desc target_diff_dst_md = this->ConfigureInputGradient( - grad_mkl_shape, grad_tensor, &grad_dnn_data, original_output_md); - - output_dnn_data.SetUsrMem(original_input_md); - - // Create the forward pooling primitive descriptor so we can - // pass it as a hint to the backward pooling primitive descriptor - auto pool_fwd_desc = pooling_forward::desc( - prop_kind::forward, algorithm::pooling_max, original_input_md, - original_output_md, - memory::dims({pool_params.row_stride, pool_params.col_stride}), - memory::dims({pool_params.window_rows, pool_params.window_cols}), - memory::dims({static_cast(pool_params.pad_top), - static_cast(pool_params.pad_left)}), - memory::dims({static_cast(pool_params.pad_bottom), - static_cast(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_fwd_prim_desc = - pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine); - - auto pool_bkwd_desc = pooling_backward::desc( - algorithm::pooling_max, output_dnn_data.GetUsrMemDesc(), - target_diff_dst_md, - memory::dims({pool_params.row_stride, pool_params.col_stride}), - memory::dims({pool_params.window_rows, pool_params.window_cols}), - memory::dims({static_cast(pool_params.pad_top), - static_cast(pool_params.pad_left)}), - memory::dims({static_cast(pool_params.pad_bottom), - static_cast(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_bkwd_prim_desc = pooling_backward::primitive_desc( - pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc); - - this->AllocateOutputTensor(context, pool_bkwd_prim_desc, + TensorShape orig_input_shape = orig_input_tensor.shape(); + this->InitMklPoolParameters(context, &pool_params, + orig_input_mkl_shape, orig_input_shape); + + memory::dims filter_dims, strides, padding_left, padding_right; + this->PoolParamsToDims(&pool_params, &filter_dims, &strides, + &padding_left, &padding_right); + + memory::dims diff_dst_dims = grad_mkl_shape.IsMklTensor() + ? grad_mkl_shape.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(), + this->data_format_tf_); + memory::dims orig_input_dims_mkl_order = + orig_input_mkl_shape.IsMklTensor() + ? 
orig_input_mkl_shape.GetSizesAsMklDnnDims()
+              : TFShapeToMklDnnDimsInNCHW(orig_input_shape,
+                                          this->data_format_tf_);
+
+      memory::dims output_dims_mkl_order;
+      this->GetOutputDims(pool_params, &output_dims_mkl_order);
+
+      MklPoolingParams bwdParams(orig_input_dims_mkl_order,
+                                 output_dims_mkl_order, filter_dims, strides,
+                                 padding_left, padding_right, algorithm::pooling_max);
+      MklPoolingBwdPrimitive *pooling_bwd =
+          MklPoolingBwdPrimitiveFactory::Get(bwdParams);
+
+      // allocate output tensor and memory primitive
+      Tensor* output_tensor = nullptr;
+      this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
                                  orig_input_dims_mkl_order,
                                  this->data_format_mkldnn_, &output_tensor);
-      output_dnn_data.SetUsrMemDataHandle(output_tensor);
-
-      ConfigureWorkspace(workspace_tensor,
-                         pool_fwd_prim_desc.workspace_primitive_desc(),
-                         &workspace_dnn_data);
-      this->PrepareAndExecuteNet(
-          pool_bkwd_prim_desc, &grad_dnn_data, &output_dnn_data,
-          memory::primitive_desc(target_diff_dst_md, cpu_engine),
-          &workspace_dnn_data);
-    } catch (mkldnn::error& e) {
-      string error_msg = "Status: " + std::to_string(e.status) +
-                         ", message: " + string(e.message) + ", in file " +
-                         string(__FILE__) + ":" + std::to_string(__LINE__);
-      OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
-                                              error_msg));
+      // get diff_dst mem desc
+      memory::desc diff_dst_md = grad_mkl_shape.IsMklTensor()
+          ? grad_mkl_shape.GetMklLayout()
+          : memory::desc(diff_dst_dims, MklDnnType(),
+                         this->data_format_mkldnn_);
+      // check if diff_dst needs to be reordered
+      T* diff_dst_data = nullptr;
+      if (diff_dst_md.data.format != pooling_bwd->GetDiffDstFormat()) {
+        auto target_diff_dst = memory::primitive_desc({{diff_dst_dims},
+            MklDnnType(), pooling_bwd->GetDiffDstFormat()}, cpu_engine);
+        grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
+        grad_dnn_data.CheckReorderToOpMem(target_diff_dst);
+        diff_dst_data = static_cast(
+            grad_dnn_data.GetOpMem().get_data_handle());
+      } else {
+        diff_dst_data = static_cast(
+            const_cast(grad_tensor.flat().data()));
+      }
+
+      void* ws_data = nullptr;
+      auto ws_md =
+          pooling_bwd->GetPoolingFwdPd()->workspace_primitive_desc().desc();
+      if (ws_md.data.format != pooling_bwd->GetWorkspaceFormat()) {
+        memory::dims ws_dims;
+        ws_dims.assign(ws_md.data.dims, ws_md.data.dims + ws_md.data.ndims);
+        auto target_ws = memory::primitive_desc({{ws_dims},
+            pooling_bwd->GetWorkspaceDataType(),
+            pooling_bwd->GetWorkspaceFormat()}, cpu_engine);
+        workspace_dnn_data.SetUsrMem(ws_md, &workspace_tensor);
+        workspace_dnn_data.CheckReorderToOpMem(target_ws);
+        ws_data = workspace_dnn_data.GetOpMem().get_data_handle();
+      } else {
+        ws_data = static_cast(const_cast(
+            workspace_tensor.flat().data()));
+      }
+
+      T* diff_src_data = static_cast(
+          const_cast(output_tensor->flat().data()));
+
+      // execute pooling
+      pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data);
+    } catch (mkldnn::error &e) {
+      string error_msg = "Status: " + std::to_string(e.status) +
+                         ", message: " + string(e.message) + ", in file " +
+                         string(__FILE__) + ":" + std::to_string(__LINE__);
+      OP_REQUIRES_OK(context, errors::Aborted(
+          "Compute received an exception:", error_msg));
     }
-  }  // Compute
+  }

  private:
   // .Input("orig_input: T")
@@ -718,18 +754,6 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase {
   const int kInputTensorIndexOrigOutput = 1;
   const int kInputTensorIndexGradient = 2;
   const int kInputTensorIndexWorkspace = 3;
-  // Output("output: T") in Base Class
-
-  memory::desc ConfigureOriginalInput(
-      OpKernelContext* context, const Tensor& tensor_original_input,
-      const MklDnnShape& original_input_mkl_shape,
-      memory::dims* original_input_dims_mkl_order,
-      MklPoolParameters* pool_params, TensorShape* input_tensor_shape) {
-    *input_tensor_shape = tensor_original_input.shape();
-    return MklPoolingBackwardOpBase::ConfigureOriginalInput(
-        context, tensor_original_input, original_input_mkl_shape,
-        original_input_dims_mkl_order, pool_params, *input_tensor_shape);
-  }

   void ConfigureWorkspace(const Tensor& workspace_tensor,
                           memory::primitive_desc workspace_pd,
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
index 5ef6ce2a578903..42652be44c64e1 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
@@ -24,12 +24,192 @@ limitations under the License.

 namespace tensorflow {

+#ifndef INTEL_MKL_ML
+
+using mkldnn::pooling_max;
+using mkldnn::pooling_avg;
+using mkldnn::pooling_avg_include_padding;
+using mkldnn::pooling_avg_exclude_padding;
+using mkldnn::prop_kind;
+
+template
+void MklPoolingFwdPrimitive::Setup(const MklPoolingParams& fwdParams) {
+  if (fwdParams.alg_kind != pooling_max &&
+      fwdParams.alg_kind != pooling_avg &&
+      fwdParams.alg_kind != pooling_avg_include_padding &&
+      fwdParams.alg_kind != pooling_avg_exclude_padding) {
+    assert("Pooling algorithm kind is not supported\n");
+  }
+
+  context_.alg_kind = fwdParams.alg_kind;
+  // create memory desc
+  // FIXME: Pooling doesn't expose to get the src_primitive_desc,
+  // so src format is currently hard-coded.
+ // A utility function is used to do this, + // which may be broken with future CPU architectures + context_.src_md.reset(new memory::desc({fwdParams.src_dims}, + MklDnnType(), get_desired_format(fwdParams.src_dims[1]))); + context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, + MklDnnType(), memory::format::any)); + + // create a pooling descriptor + context_.fwd_desc.reset(new pooling_forward::desc(prop_kind::forward_training, + fwdParams.alg_kind, *context_.src_md, *context_.dst_md, fwdParams.strides, + fwdParams.filter_dims, fwdParams.padding_left, + fwdParams.padding_right, padding_kind::zero)); + context_.fwd_pd.reset( + new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine_)); + + // store expected primitive format + context_.src_fmt = get_desired_format(fwdParams.src_dims[1]); + context_.dst_fmt = static_cast( + context_.fwd_pd.get()->dst_primitive_desc().desc().data.format); + + // create MKL-DNN internal memory object with dummy data + context_.src_mem.reset( + new memory({{{fwdParams.src_dims}, MklDnnType(), context_.src_fmt}, + cpu_engine_}, DummyData)); + context_.dst_mem.reset( + new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData)); + + // for max pooling, need to return workspace(ws) for backward computing + if (fwdParams.alg_kind == pooling_max) { + auto ws_pd = context_.fwd_pd.get()->workspace_primitive_desc().desc().data; + // store workspace's dims and format to create workspace tensor + context_.ws_fmt = static_cast(ws_pd.format); + context_.ws_dims.assign(ws_pd.dims, ws_pd.dims + ws_pd.ndims); + context_.ws_dt = static_cast(ws_pd.data_type); + context_.ws_size = + context_.fwd_pd.get()->workspace_primitive_desc().get_size(); + context_.ws_mem.reset( + new memory(context_.fwd_pd.get()->workspace_primitive_desc(), + DummyData)); + context_.fwd.reset(new pooling_forward(*context_.fwd_pd, + *context_.src_mem, *context_.dst_mem, *context_.ws_mem)); + } else { + context_.fwd.reset(new pooling_forward(*context_.fwd_pd, + *context_.src_mem, *context_.dst_mem)); + } + + context_.fwd_primitives.push_back(*context_.fwd); +} + +template +void MklPoolingFwdPrimitive::Execute(const T* src_data, const T* dst_data, + const void* ws_data) { + context_.src_mem->set_data_handle( + static_cast(const_cast(src_data))); + context_.dst_mem->set_data_handle( + static_cast(const_cast(dst_data))); + if (context_.alg_kind == pooling_max) { // max pooling must have ws + assert(ws_data != nullptr); + context_.ws_mem->set_data_handle(const_cast(ws_data)); + } + context_.fwd_stream->submit(context_.fwd_primitives); + + // set back data handle + context_.src_mem->set_data_handle(DummyData); + context_.dst_mem->set_data_handle(DummyData); + if (context_.alg_kind == pooling_max) { // max pooling must have ws + assert(ws_data != nullptr); + context_.ws_mem->set_data_handle(DummyData); + } +} + +template class MklPoolingFwdPrimitive; + +template +void MklPoolingBwdPrimitive::Setup(const MklPoolingParams& bwdParams) { + if (bwdParams.alg_kind != pooling_max && bwdParams.alg_kind != pooling_avg + && bwdParams.alg_kind != pooling_avg_include_padding + && bwdParams.alg_kind != pooling_avg_exclude_padding) { + assert("Pooling algorithm kind is not supported\n"); + } + context_.alg_kind = bwdParams.alg_kind; + + // Create memory desc + context_.diff_src_md.reset(new memory::desc({bwdParams.src_dims}, + MklDnnType(), memory::format::any)); + context_.diff_dst_md.reset(new memory::desc({bwdParams.dst_dims}, + MklDnnType(), get_desired_format(bwdParams.dst_dims[1]))); + 
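Editor's note: the hard-coded source format flagged by the FIXME above comes from get_desired_format(channels), which picks a blocked layout from the channel count alone. A sketch of the rule such a helper presumably encodes — 16-channel blocking when C divides by 16, 8-channel blocking when C divides by 8, plain NCHW otherwise. The enum is a local stand-in for mkldnn::memory::format, and the real helper in mkl_util.h may differ on future CPU architectures, which is exactly the risk the FIXME warns about:

```cpp
// Illustrative only: derive a desired MKL-DNN layout from the channel count,
// as the FIXME above describes. Not the actual mkl_util.h implementation.
enum class Format { nchw, nChw8c, nChw16c };

Format GetDesiredFormat(int channels) {
  if (channels % 16 == 0) return Format::nChw16c;  // 16-channel blocking
  if (channels % 8 == 0) return Format::nChw8c;    // 8-channel blocking
  return Format::nchw;                             // plain layout fallback
}
```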
context_.bwd_desc.reset(new pooling_backward::desc(bwdParams.alg_kind, + *context_.diff_src_md, *context_.diff_dst_md, bwdParams.strides, + bwdParams.filter_dims, bwdParams.padding_left, bwdParams.padding_right, + padding_kind::zero)); + + // create a forward primitive, + // which will be used as a hint for creating backward primitive + context_.fwd_desc.reset(new pooling_forward::desc(prop_kind::forward_training, + bwdParams.alg_kind, *context_.diff_src_md, *context_.diff_dst_md, + bwdParams.strides, bwdParams.filter_dims, bwdParams.padding_left, + bwdParams.padding_right, padding_kind::zero)); + context_.fwd_pd.reset( + new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine)); + context_.bwd_pd.reset(new pooling_backward::primitive_desc( + *context_.bwd_desc, cpu_engine, *context_.fwd_pd)); + + // store expected primitive format + context_.diff_src_fmt = static_cast( + context_.bwd_pd.get()->diff_src_primitive_desc().desc().data.format); + context_.diff_dst_fmt = get_desired_format(bwdParams.dst_dims[1]); + + // create MKL-DNN internal memory object with dummy data + context_.diff_src_mem.reset( + new memory(context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData)); + context_.diff_dst_mem.reset(new memory({{{bwdParams.dst_dims}, + MklDnnType(), context_.diff_dst_fmt}, cpu_engine}, DummyData)); + + // for max pooling, need to return workspace for backward + if (bwdParams.alg_kind == pooling_max) { + auto ws_pd = context_.fwd_pd.get()->workspace_primitive_desc().desc().data; + context_.ws_dims.assign(ws_pd.dims, ws_pd.dims + ws_pd.ndims); + context_.ws_fmt = get_desired_format(context_.ws_dims[1]); + context_.ws_dt = static_cast(ws_pd.data_type); + context_.ws_mem.reset(new memory({{{context_.ws_dims}, context_.ws_dt, + context_.ws_fmt}, cpu_engine}, DummyData)); + context_.bwd.reset(new pooling_backward( + *context_.bwd_pd, *context_.diff_dst_mem, *context_.ws_mem, + *context_.diff_src_mem)); + } else { + context_.bwd.reset(new pooling_backward(*context_.bwd_pd, + *context_.diff_dst_mem, *context_.diff_src_mem)); + } + context_.bwd_primitives.push_back(*context_.bwd); +} + +template +void MklPoolingBwdPrimitive::Execute(const T* diff_dst_data, + const T* diff_src_data, const void* ws_data) { + context_.diff_dst_mem->set_data_handle( + static_cast(const_cast(diff_dst_data))); + context_.diff_src_mem->set_data_handle( + static_cast(const_cast(diff_src_data))); + if (context_.alg_kind == pooling_max) { + assert(ws_data != nullptr); + context_.ws_mem->set_data_handle(const_cast(ws_data)); + } + + context_.bwd_stream->submit(context_.bwd_primitives); + // set back data handle + context_.diff_dst_mem->set_data_handle(DummyData); + context_.diff_src_mem->set_data_handle(DummyData); + if (context_.alg_kind == pooling_max) { + assert(ws_data != nullptr); + context_.ws_mem->set_data_handle(DummyData); + } +} + +template class MklPoolingBwdPrimitive; + +#endif + // Initialization for TensorFlow format -void MklPoolParameters::Init(OpKernelContext* context, - const std::vector& ksize, - const std::vector& stride, Padding padding, - TensorFormat data_format, - const TensorShape& tensor_in_shape) { +void MklPoolParameters::Init( + OpKernelContext* context, + const std::vector& ksize, + const std::vector& stride, Padding padding, + TensorFormat data_format, + const TensorShape& tensor_in_shape) { // For maxpooling, tensor_in should have 4 dimensions. 
OP_REQUIRES(context, tensor_in_shape.dims() == 4, errors::InvalidArgument("tensor_in must be 4-dimensional")); diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h index c0dfed7d7d079c..84428ade56271f 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.h +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h @@ -19,6 +19,7 @@ limitations under the License. #ifdef INTEL_MKL #include #include +#include #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/padding.h" @@ -32,6 +33,319 @@ using mkldnn::stream; namespace tensorflow { +#ifndef INTEL_MKL_ML + +using mkldnn::memory; +using mkldnn::pooling_max; +using mkldnn::pooling_avg; +using mkldnn::pooling_avg_include_padding; +using mkldnn::pooling_avg_exclude_padding; +using mkldnn::prop_kind; + +struct MklPoolingParams { + memory::dims src_dims; + memory::dims dst_dims; + memory::dims filter_dims; + memory::dims strides; + memory::dims padding_left; + memory::dims padding_right; + mkldnn::algorithm alg_kind; + + MklPoolingParams(memory::dims src_dims, + memory::dims dst_dims, memory::dims filter_dims, + memory::dims strides, memory::dims padding_left, + memory::dims padding_right, mkldnn::algorithm alg_kind) : + src_dims(src_dims), dst_dims(dst_dims), + filter_dims(filter_dims), strides(strides), + padding_left(padding_left), padding_right(padding_right), + alg_kind(alg_kind) { + } +}; + +template +class MklPoolingFwdPrimitive : public MklPrimitive { + public: + explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams) : + cpu_engine_(engine::cpu, 0) { + context_.fwd_stream.reset(new stream(stream::kind::eager)); + if (context_.fwd == nullptr) + Setup(fwdParams); + } + + ~MklPoolingFwdPrimitive() {} + + // Pooling forward execute + // src_data: input data buffer of src + // ws_data: input data buffer of workspace + // dst_data: output data buffer of dst + void Execute(const T* src_data, const T* dst_data, + const void* ws_data = nullptr); + + std::shared_ptr + GetPoolingFwdPd() const { + return context_.fwd_pd; + } + + memory::format GetSrcMemoryFormat() const { + return context_.src_fmt; + } + + memory::format GetDstMemoryFormat() const { + return context_.dst_fmt; + } + + private: + void Setup(const MklPoolingParams& fwdParams); + + struct PoolingFwdContext { + // algorithm + mkldnn::algorithm alg_kind; + + // expected memory format + memory::format src_fmt; + memory::format dst_fmt; + memory::format ws_fmt; + + // workspace shape + memory::dims ws_dims; + memory::data_type ws_dt; + size_t ws_size; + + // MKL-DNN memory, just dummy data + std::shared_ptr ws_mem; + std::shared_ptr src_mem; + std::shared_ptr dst_mem; + + // desc & primitive desc + std::shared_ptr fwd_desc; + std::shared_ptr fwd_pd; + + // memory desc + std::shared_ptr src_md; + std::shared_ptr dst_md; + + // Pooling primitive + std::shared_ptr fwd; + std::shared_ptr fwd_stream; + std::vector fwd_primitives; + + PoolingFwdContext() : + src_fmt(memory::format::any), dst_fmt(memory::format::any), + ws_fmt(memory::format::any), ws_mem(nullptr), src_mem(nullptr), + dst_mem(nullptr), fwd_desc(nullptr), fwd_pd(nullptr), src_md(nullptr), + dst_md(nullptr), fwd(nullptr), fwd_stream(nullptr) { + } + }; + + struct PoolingFwdContext context_; + engine cpu_engine_; +}; + +template +class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory { + public: + static MklPoolingFwdPrimitive* Get(const MklPoolingParams& fwdParams) { + MklPoolingFwdPrimitive* pooling_forward = nullptr; + + // Get 
pooling primitive from the pool + pooling_forward = static_cast*>( + MklPoolingFwdPrimitiveFactory::GetInstance().GetPoolingFwd(fwdParams)); + + if (pooling_forward == nullptr) { + pooling_forward = new MklPoolingFwdPrimitive(fwdParams); + MklPoolingFwdPrimitiveFactory::GetInstance().SetPoolingFwd( + fwdParams, pooling_forward); + } + return pooling_forward; + } + + static MklPoolingFwdPrimitiveFactory& GetInstance() { + static MklPoolingFwdPrimitiveFactory instance_; + return instance_; + } + + private: + MklPoolingFwdPrimitiveFactory() {} + ~MklPoolingFwdPrimitiveFactory() {} + + // The key to be created will be used to get/set pooling + // primitive op from reuse perspective. + // A pooling key is a string which concates key parameters + // as well as algorithm kind (max versus avg). + static std::string CreateKey(const MklPoolingParams& fwdParams) { + std::string prefix = "pooling_fwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(fwdParams.src_dims); + key_creator.AddAsKey(fwdParams.dst_dims); + key_creator.AddAsKey(fwdParams.filter_dims); + key_creator.AddAsKey(fwdParams.strides); + key_creator.AddAsKey(fwdParams.padding_left); + key_creator.AddAsKey(fwdParams.padding_right); + key_creator.AddAsKey(static_cast(fwdParams.alg_kind)); + return key_creator.GetKey(); + } + + MklPrimitive* GetPoolingFwd(const MklPoolingParams& fwdParams) { + std::string key = CreateKey(fwdParams); + return this->GetOp(key); + } + + void SetPoolingFwd(const MklPoolingParams& fwdParams, MklPrimitive *op) { + std::string key = CreateKey(fwdParams); + this->SetOp(key, op); + } +}; + + +template +class MklPoolingBwdPrimitive : public MklPrimitive { + public: + explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams) : + cpu_engine(engine::cpu, 0) { + context_.bwd_stream.reset(new stream(stream::kind::eager)); + if (context_.bwd == nullptr) + Setup(bwdParams); + } + + ~MklPoolingBwdPrimitive() {} + + // Pooling backward execute + // diff_dst_data: input data buffer of diff_dst + // diff_src_data: output data buffer of diff_src + // ws_data: input data buffer of workspace + void Execute(const T* diff_dst_data, const T* diff_src_data, + const void* ws_data = nullptr); + + public: + std::shared_ptr + GetPoolingFwdPd() const { + return context_.fwd_pd; + } + std::shared_ptr + GetPoolingBwdPd() const { + return context_.bwd_pd; + } + + memory::format GetDiffDstFormat() const { + return context_.diff_dst_fmt; + } + + mkldnn::memory::data_type GetWorkspaceDataType() const { + return context_.ws_dt; + } + memory::format GetWorkspaceFormat() const { + return context_.ws_fmt; + } + + private: + void Setup(const MklPoolingParams& bwdParams); + + // Primitive reuse context for pooling bwd ops + struct PoolingBwdContext { + // algorithm + mkldnn::algorithm alg_kind; + + // expected memory format + mkldnn::memory::format diff_src_fmt; + mkldnn::memory::format diff_dst_fmt; + mkldnn::memory::format ws_fmt; + + // workspace attribute + mkldnn::memory::dims ws_dims; + mkldnn::memory::data_type ws_dt; + + // MKL-DNN memory + std::shared_ptr ws_mem; + std::shared_ptr diff_src_mem; + std::shared_ptr diff_dst_mem; + + // memory desc + std::shared_ptr diff_src_md; + std::shared_ptr diff_dst_md; + + // desc & primitive desc + std::shared_ptr fwd_desc; + std::shared_ptr bwd_desc; + std::shared_ptr fwd_pd; + std::shared_ptr bwd_pd; + + // pooling primitive + std::shared_ptr bwd; + std::shared_ptr bwd_stream; + + std::vector bwd_primitives; + + PoolingBwdContext() : + 
diff_src_fmt(memory::format::any), diff_dst_fmt(memory::format::any), + ws_fmt(memory::format::any), ws_mem(nullptr), diff_src_mem(nullptr), + diff_dst_mem(nullptr), diff_src_md(nullptr), diff_dst_md(nullptr), + fwd_desc(nullptr), bwd_desc(nullptr), fwd_pd(nullptr), bwd_pd(nullptr), + bwd(nullptr), bwd_stream(nullptr) { + } + }; + + struct PoolingBwdContext context_; + engine cpu_engine; +}; + +template +class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory { + public: + static MklPoolingBwdPrimitive *Get(const MklPoolingParams& bwdParams) { + MklPoolingBwdPrimitive* pooling_backward = nullptr; + + // Find a pooling backward primitive from the pool + // If it does not exist, create a new one + pooling_backward = static_cast*>( + MklPoolingBwdPrimitiveFactory::GetInstance().GetPoolingBwd(bwdParams)); + if (pooling_backward == nullptr) { + pooling_backward = new MklPoolingBwdPrimitive(bwdParams); + MklPoolingBwdPrimitiveFactory::GetInstance().SetPoolingBwd( + bwdParams, pooling_backward); + } + return pooling_backward; + } + + static MklPoolingBwdPrimitiveFactory& GetInstance() { + static MklPoolingBwdPrimitiveFactory instance_; + return instance_; + } + + private: + MklPoolingBwdPrimitiveFactory() {} + ~MklPoolingBwdPrimitiveFactory() {} + + // The key to be created will be used to get/set pooling + // primitive op from reuse perspective. + // A pooling key is a string which concates key parameters + // as well as algorithm kind (max versus avg). + static std::string CreateKey(const MklPoolingParams& bwdParams) { + std::string prefix = "pooling_bwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(bwdParams.src_dims); + key_creator.AddAsKey(bwdParams.dst_dims); + key_creator.AddAsKey(bwdParams.filter_dims); + key_creator.AddAsKey(bwdParams.strides); + key_creator.AddAsKey(bwdParams.padding_left); + key_creator.AddAsKey(bwdParams.padding_right); + key_creator.AddAsKey(static_cast(bwdParams.alg_kind)); + return key_creator.GetKey(); + } + + MklPrimitive* GetPoolingBwd(const MklPoolingParams& bwdParams) { + std::string key = CreateKey(bwdParams); + return this->GetOp(key); + } + + void SetPoolingBwd(const MklPoolingParams& bwdParams, MklPrimitive *op) { + std::string key = CreateKey(bwdParams); + this->SetOp(key, op); + } +}; +#endif + typedef Eigen::ThreadPoolDevice CPUDevice; struct MklPoolParameters { @@ -163,6 +477,43 @@ class MklPoolingOpBase : public OpKernel { } } + void PoolParamsToDims(MklPoolParameters* pool_params, + memory::dims* filter_dims, + memory::dims* strides, + memory::dims* padding_left, + memory::dims* padding_right) { + *filter_dims = {pool_params->window_rows, pool_params->window_cols}; + *strides = {pool_params->row_stride, pool_params->col_stride}; + *padding_left = {static_cast(pool_params->pad_top), + static_cast(pool_params->pad_left)}; + *padding_right = {static_cast(pool_params->pad_bottom), + static_cast(pool_params->pad_right)}; + } + + void AllocateEmptyOutputTensor( + OpKernelContext* context, + const int kOutputIndex, + MklPoolParameters* pool_params, + const memory::dims output_dims_mkl_order, + Tensor** output_tensor) { + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(false); + TensorShape output_tf_shape; + if (pool_params->data_format == TensorFormat::FORMAT_NCHW) { + output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order); + } else { + memory::dims output_dims_NHWC_order; + output_dims_NHWC_order = {pool_params->tensor_in_batch, + static_cast(pool_params->out_height), + 
static_cast(pool_params->out_width), + pool_params->out_depth}; + output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order); + } + AllocateOutputSetMklShape(context, kOutputIndex, output_tensor, + output_tf_shape, output_mkl_shape); + CHECK_NOTNULL(output_tensor); + } + // Checks to make sure that the memory we need to allocate // is a multiple of sizeof(T) // returns the number of elements @@ -235,23 +586,6 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase { CHECK_NOTNULL(*output_tensor); } - void PrepareAndExecuteNet( - const pooling_forward::primitive_desc& pool_fwd_desc, - const MklDnnData* src, MklDnnData* dst, - MklDnnData* wksp = nullptr) { - std::vector net; - - // Create pooling primitive and add it to net - if (wksp != nullptr) { - net.push_back(pooling_forward(pool_fwd_desc, src->GetOpMem(), - dst->GetOpMem(), wksp->GetOpMem())); - } else { - net.push_back( - pooling_forward(pool_fwd_desc, src->GetOpMem(), dst->GetOpMem())); - } - stream(stream::kind::eager).submit(net).wait(); - } - void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor, const MklDnnShape& input_mkl_shape) { if (!input_mkl_shape.IsMklTensor()) { @@ -301,67 +635,6 @@ class MklPoolingBackwardOpBase : public MklPoolingOpBase { CHECK_NOTNULL(*output_tensor); } - void PrepareAndExecuteNet( - const pooling_backward::primitive_desc& pool_bkwd_desc, - MklDnnData* input_gradient_diff_dst, MklDnnData* output_diff_src, - const memory::primitive_desc& target_diff_dst_pd, - const MklDnnData* workspace = nullptr) { - std::vector net; - - // If the input gradient isn't in the same format as the output - // reorder it to the same format as the output - input_gradient_diff_dst->CheckReorderToOpMem(target_diff_dst_pd, &net); - - // Create pooling primitive and add it to net - if (nullptr == workspace) { - net.push_back(pooling_backward(pool_bkwd_desc, - input_gradient_diff_dst->GetOpMem(), - output_diff_src->GetOpMem())); - } else { - net.push_back( - pooling_backward(pool_bkwd_desc, input_gradient_diff_dst->GetOpMem(), - workspace->GetOpMem(), output_diff_src->GetOpMem())); - } - stream(stream::kind::eager).submit(net).wait(); - } - - // Max Pooling and Avg Pooling have slightly different implementations - // Takes the Tensor containing original input data and the original - // mkl Dnn Shape and populates other data - memory::desc ConfigureOriginalInput( - OpKernelContext* context, const Tensor& tensor_original_input_shape, - const MklDnnShape& original_input_mkl_shape, - memory::dims* original_input_dims_nchw, MklPoolParameters* pool_params, - const TensorShape& input_tensor_shape) { - CHECK_NOTNULL(original_input_dims_nchw); - CHECK_NOTNULL(pool_params); - this->InitMklPoolParameters(context, pool_params, original_input_mkl_shape, - input_tensor_shape); - - *original_input_dims_nchw = - original_input_mkl_shape.IsMklTensor() - ? original_input_mkl_shape.GetSizesAsMklDnnDims() - : TFShapeToMklDnnDimsInNCHW(input_tensor_shape, - this->data_format_tf_); - - return original_input_mkl_shape.IsMklTensor() - ? original_input_mkl_shape.GetMklLayout() - : memory::desc(*original_input_dims_nchw, MklDnnType(), - this->data_format_mkldnn_); - } - - memory::desc ConfigureOriginalOutput( - const MklPoolParameters& pool_params, - const MklDnnShape& original_output_mkl_shape, - memory::dims output_dims_mkl_order) { - this->GetOutputDims(pool_params, &output_dims_mkl_order); - - return original_output_mkl_shape.IsMklTensor() - ? 
original_output_mkl_shape.GetMklLayout() - : memory::desc(output_dims_mkl_order, MklDnnType(), - this->data_format_mkldnn_); - } - memory::desc ConfigureInputGradient( const MklDnnShape& input_gradient_mkl_shape, const Tensor& input_gradient_tensor, diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index 78abbdb730bd3a..3d5a05be73b149 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" - #ifndef INTEL_MKL_ML #include "mkldnn.hpp" @@ -35,14 +34,411 @@ using mkldnn::prop_kind; using mkldnn::relu_backward; using mkldnn::relu_forward; using mkldnn::stream; +using mkldnn::memory; #else #include "mkl_dnn.h" #include "mkl_dnn_types.h" #endif +#include "tensorflow/core/platform/default/logging.h" #include "tensorflow/core/util/mkl_util.h" namespace tensorflow { +#ifndef INTEL_MKL_ML + +template +class MklEltwiseFwdParams { + public: + memory::dims src_dims; // check if this is needed + memory::desc src_md; + algorithm alg_kind; + T alpha; + T beta; + + MklEltwiseFwdParams(memory::dims src_dims, memory::desc src_md, + algorithm alg_kind, T alpha, T beta) : + src_dims(src_dims), src_md(src_md), + alg_kind(alg_kind), alpha(alpha), beta(beta) { + } +}; + +template +class MklEltwiseFwdPrimitive : public MklPrimitive { + public: + explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams& fwdParams) : + cpu_engine_(engine::cpu, 0) { + // store expected format + context_.src_fmt = static_cast( + fwdParams.src_md.data.format); + context_.fwd_stream.reset(new stream(stream::kind::eager)); + + // create eltwise primitive + if (context_.eltwise_fwd == nullptr) { + Setup(fwdParams); + } + } + + ~MklEltwiseFwdPrimitive() {} + + // Eltwise forward execute + // src_data: input data buffer of src + // dst_data: output data buffer of dst + void Execute(T* src_data, T* dst_data) { + context_.src_mem->set_data_handle(static_cast(src_data)); + context_.dst_mem->set_data_handle(static_cast(dst_data)); + context_.fwd_stream->submit(context_.fwd_primitives); + + // after execution, set data handle back + context_.src_mem->set_data_handle(DummyData); + context_.dst_mem->set_data_handle(DummyData); + } + + std::shared_ptr GetEltwiseFwdPd() { + return context_.fwd_pd; + } + + memory::format GetSrcMemoryFormat() { + return context_.src_fmt; + } + + private: + // Primitive reuse context for eltwise Fwd ops: Relu, Elu, Tanh + struct EltwiseFwdContext { + // expected memory format for this primitive instance + mkldnn::memory::format src_fmt; + + // MKLDNN memory + std::shared_ptr src_mem; + std::shared_ptr dst_mem; + + // desc & prmitive desc + std::shared_ptr fwd_desc; + std::shared_ptr fwd_pd; + + // memory desc + std::shared_ptr src_md; + std::shared_ptr dst_md; + + // memory primitive desc + std::shared_ptr src_mpd; + + // Eltwise primitive + std::shared_ptr eltwise_fwd; + + std::shared_ptr fwd_stream; + std::vector fwd_primitives; + + EltwiseFwdContext() : + src_fmt(memory::format::any), src_mem(nullptr), dst_mem(nullptr), + fwd_desc(nullptr), fwd_pd(nullptr), src_md(nullptr), dst_md(nullptr), + src_mpd(nullptr), eltwise_fwd(nullptr), fwd_stream(nullptr) { + } + }; + + // Eltwise forward primitive setup + void Setup(const MklEltwiseFwdParams& fwdParams) { + // create memory descriptors for eltwise data with specified format + context_.src_md.reset(new memory::desc(fwdParams.src_md.data)); + 
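Editor's note: Execute above is the heart of the primitive-reuse scheme. The cached primitive is constructed once against memory objects holding DummyData, so each invocation only swaps the caller's real buffers in, submits the prebuilt primitive vector on the cached stream, and detaches the buffers again so the cached object never keeps a stale pointer. A stripped-down sketch of that handle-swap protocol; Memory and CachedPrimitive are stand-ins, not the real MKL-DNN or MklPrimitive classes:

```cpp
#include <functional>

// Stand-in for mkldnn::memory: only the data-handle swap matters here.
struct Memory {
  void* handle = nullptr;
  void set_data_handle(void* p) { handle = p; }
};

// Stand-in for a cached primitive: built once, re-run many times.
struct CachedPrimitive {
  Memory src_mem, dst_mem;                   // created against dummy data
  std::function<void(Memory&, Memory&)> op;  // the compiled kernel

  void Execute(void* src_data, void* dst_data) {
    src_mem.set_data_handle(src_data);  // point cached memory at real buffers
    dst_mem.set_data_handle(dst_data);
    op(src_mem, dst_mem);               // stream.submit(primitives).wait()
    src_mem.set_data_handle(nullptr);   // set handles back to dummy data
    dst_mem.set_data_handle(nullptr);
  }
};
```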
context_.src_mpd.reset(new memory::primitive_desc( + *context_.src_md, cpu_engine_)); + + // create a eltwise + context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc( + prop_kind::forward, fwdParams.alg_kind, *context_.src_md, + fwdParams.alpha, fwdParams.beta)); + context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc( + *context_.fwd_desc, cpu_engine_)); + + // create memory primitive based on dummy data + context_.src_mem.reset(new memory(*context_.src_mpd, DummyData)); + context_.dst_mem.reset(new memory( + context_.fwd_pd.get()->dst_primitive_desc(), DummyData)); + + // create eltwise primitive and add it to net + context_.eltwise_fwd.reset(new mkldnn::eltwise_forward(*context_.fwd_pd, + *context_.src_mem, *context_.dst_mem)); + + context_.fwd_primitives.push_back(*context_.eltwise_fwd); + } + + struct EltwiseFwdContext context_; + engine cpu_engine_; +}; + +template +class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory { + public: + static MklEltwiseFwdPrimitive* Get( + const MklEltwiseFwdParams& fwdParams) { + MklEltwiseFwdPrimitive* eltwise_forward = nullptr; + + auto src_fmt = static_cast( + fwdParams.src_md.data.format); + + // Get a eltwise fwd primitive from the cached pool + eltwise_forward = static_cast*>( + MklEltwiseFwdPrimitiveFactory::GetInstance().GetEltwiseFwd( + fwdParams, src_fmt)); + if (eltwise_forward == nullptr) { + eltwise_forward = new MklEltwiseFwdPrimitive(fwdParams); + MklEltwiseFwdPrimitiveFactory::GetInstance().SetEltwiseFwd( + fwdParams, src_fmt, eltwise_forward); + } + return eltwise_forward; + } + + static MklEltwiseFwdPrimitiveFactory& GetInstance() { + static MklEltwiseFwdPrimitiveFactory instance_; + return instance_; + } + + private: + MklEltwiseFwdPrimitiveFactory() {} + ~MklEltwiseFwdPrimitiveFactory() {} + + static std::string CreateKey( + const MklEltwiseFwdParams& fwdParams, memory::format src_fmt) { + std::string prefix = "eltwise_fwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(fwdParams.src_dims); + key_creator.AddAsKey(static_cast(fwdParams.alg_kind)); + key_creator.AddAsKey(static_cast(fwdParams.alpha)); + key_creator.AddAsKey(static_cast(fwdParams.beta)); + key_creator.AddAsKey(static_cast(src_fmt)); + return key_creator.GetKey(); + } + + MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams& fwdParams, + memory::format src_fmt) { + std::string key = CreateKey(fwdParams, src_fmt); + return this->GetOp(key); + } + + void SetEltwiseFwd(const MklEltwiseFwdParams& fwdParams, + memory::format src_fmt, MklPrimitive* op) { + std::string key = CreateKey(fwdParams, src_fmt); + this->SetOp(key, op); + } +}; + +template +class MklEltwiseBwdParams { + public: + memory::dims src_dims; + memory::desc common_md; + algorithm alg_kind; + T alpha; + T beta; + + MklEltwiseBwdParams(const memory::dims &src_dims, + const memory::desc &common_md, + algorithm alg_kind, T alpha, T beta) : + src_dims(src_dims), common_md(common_md), + alg_kind(alg_kind), alpha(alpha), beta(beta) { + } +}; + +template +class MklEltwiseBwdPrimitive : public MklPrimitive { + public: + explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams& bwdParams) : + cpu_engine_(engine::cpu, 0) { + context_.src_fmt = static_cast( + bwdParams.common_md.data.format); + context_.diff_dst_fmt = static_cast( + bwdParams.common_md.data.format); + context_.bwd_stream.reset(new stream(stream::kind::eager)); + // create eltwise primitive + if (context_.eltwise_bwd == nullptr) { + Setup(bwdParams); + } + } + + 
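Editor's note: the factory above funnels every parameter that affects primitive construction (dims, algorithm, alpha/beta, source format) through FactoryKeyCreator into one string key, then reuses primitives via GetOp/SetOp inherited from MklPrimitiveFactory. Those base-class calls presumably sit on a hash map; a minimal sketch of the keyed-cache idea, with synchronization and eviction (which a real multi-threaded kernel cache must handle) deliberately omitted:

```cpp
#include <sstream>
#include <string>
#include <unordered_map>

// Simplified stand-in for FactoryKeyCreator: serialize each parameter into
// one delimited lookup key.
class KeyCreator {
 public:
  template <typename T>
  void AddAsKey(const T& v) {
    std::ostringstream os;
    os << v << ':';
    key_ += os.str();
  }
  std::string GetKey() const { return key_; }

 private:
  std::string key_;
};

struct Primitive {};  // stand-in for MklPrimitive

// One primitive per distinct key, in the spirit of GetOp/SetOp.
// Entries live for the process lifetime, as in a kernel cache.
Primitive* GetOrCreate(const std::string& key) {
  static std::unordered_map<std::string, Primitive*> cache;
  auto it = cache.find(key);
  if (it != cache.end()) return it->second;  // reuse the cached primitive
  Primitive* p = new Primitive();            // build once on first use
  cache[key] = p;
  return p;
}
```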
~MklEltwiseBwdPrimitive() {} + + // Eltwise backward execute + // src_data: input data buffer of src + // diff_dst_data: input data buffer of diff_dst + // diff_src_data: output data buffer of diff_src + + void Execute(T* src_data, T* diff_dst_data, T* diff_src_data) { + context_.src_mem->set_data_handle(static_cast(src_data)); + context_.diff_dst_mem->set_data_handle(static_cast(diff_dst_data)); + context_.diff_src_mem->set_data_handle(static_cast(diff_src_data)); + context_.bwd_stream->submit(context_.bwd_primitives); + + // after execution, set data handle back + context_.src_mem->set_data_handle(DummyData); + context_.diff_dst_mem->set_data_handle(DummyData); + context_.diff_src_mem->set_data_handle(DummyData); + } + + std::shared_ptr GetEltwiseBwdPd() { + return context_.bwd_pd; + } + + memory::format GetSrcMemoryFormat() { + return context_.src_fmt; + } + + memory::format GetDiffDstMemoryFormat() { + return context_.diff_dst_fmt; + } + + private: + // Primitive reuse context for eltwise Bwd ops: Relu, Elu, Tanh + struct EltwiseBwdContext { + // expected memory format for this primitive instance + memory::format src_fmt; + memory::format diff_dst_fmt; + + // MKLDNN memory + std::shared_ptr src_mem; + std::shared_ptr diff_dst_mem; + std::shared_ptr diff_src_mem; + + // desc & prmitive desc + std::shared_ptr bwd_desc; + + // memory desc + std::shared_ptr src_md; + std::shared_ptr diff_dst_md; + std::shared_ptr common_md; + + // memory primitive desc + std::shared_ptr src_mpd; + std::shared_ptr diff_dst_mpd; + + // fwd primitive desc + std::shared_ptr fwd_desc; + std::shared_ptr fwd_pd; + std::shared_ptr bwd_pd; + + // Eltwise primitive + std::shared_ptr eltwise_bwd; + + std::shared_ptr bwd_stream; + std::vector bwd_primitives; + + EltwiseBwdContext() : + src_fmt(memory::format::any), diff_dst_fmt(memory::format::any), + src_mem(nullptr), diff_dst_mem(nullptr), diff_src_mem(nullptr), + src_md(nullptr), diff_dst_md(nullptr), common_md(nullptr), + src_mpd(nullptr), diff_dst_mpd(nullptr), + fwd_desc(nullptr), fwd_pd(nullptr), bwd_pd(nullptr), + eltwise_bwd(nullptr), bwd_stream(nullptr) { + } + }; + + // Eltwise backward primitive setup + void Setup(const MklEltwiseBwdParams& bwdParams) { + // create memory descriptors for eltwise data w/ no specified format + context_.src_md.reset(new memory::desc(bwdParams.common_md.data)); + context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data)); + + context_.src_mpd.reset(new memory::primitive_desc( + *context_.src_md, cpu_engine_)); + context_.diff_dst_mpd.reset(new memory::primitive_desc( + *context_.diff_dst_md, cpu_engine_)); + + // create forward eltwise primitive + context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc( + prop_kind::forward_training, bwdParams.alg_kind, + *context_.src_md, bwdParams.alpha, bwdParams.beta)); + context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc( + *context_.fwd_desc, cpu_engine_)); + context_.bwd_desc.reset(new mkldnn::eltwise_backward::desc( + bwdParams.alg_kind, *context_.diff_dst_md, + *context_.src_md, bwdParams.alpha, bwdParams.beta)); + context_.bwd_pd.reset(new mkldnn::eltwise_backward::primitive_desc( + *context_.bwd_desc, cpu_engine_, *context_.fwd_pd)); + + // create memory primitive based on dummy data + context_.src_mem.reset(new memory(*context_.src_mpd, DummyData)); + context_.diff_dst_mem.reset(new memory(*context_.diff_dst_mpd, DummyData)); + context_.diff_src_mem.reset(new memory( + context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData)); + + // create eltwise 
primitive and add it to net + context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(*context_.bwd_pd, + *context_.src_mem, *context_.diff_dst_mem, *context_.diff_src_mem)); + + context_.bwd_primitives.push_back(*context_.eltwise_bwd); + } + + struct EltwiseBwdContext context_; + engine cpu_engine_; +}; + + +template +class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory { + private: + MklEltwiseBwdPrimitiveFactory() {} + ~MklEltwiseBwdPrimitiveFactory() {} + + public: + static MklEltwiseBwdPrimitive* Get( + const MklEltwiseBwdParams& bwdParams) { + MklEltwiseBwdPrimitive* eltwise_backward = nullptr; + + auto src_fmt = static_cast( + bwdParams.common_md.data.format); + auto diff_dst_fmt = static_cast( + bwdParams.common_md.data.format); + + // try to find a suitable one in pool + eltwise_backward = static_cast*> ( + MklEltwiseBwdPrimitiveFactory::GetInstance().GetEltwiseBwd( + bwdParams, src_fmt, diff_dst_fmt)); + + if (eltwise_backward == nullptr) { + eltwise_backward = new MklEltwiseBwdPrimitive(bwdParams); + MklEltwiseBwdPrimitiveFactory::GetInstance().SetEltwiseBwd( + bwdParams, src_fmt, diff_dst_fmt, eltwise_backward); + } + return eltwise_backward; + } + + static MklEltwiseBwdPrimitiveFactory& GetInstance() { + static MklEltwiseBwdPrimitiveFactory instance_; + return instance_; + } + + private: + static std::string CreateKey( + const MklEltwiseBwdParams& bwdParams, + const memory::format &src_fmt, + const memory::format &diff_dst_fmt) { + std::string prefix = "eltwise_bwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(bwdParams.src_dims); + key_creator.AddAsKey(static_cast(bwdParams.alg_kind)); + key_creator.AddAsKey(static_cast(bwdParams.alpha)); + key_creator.AddAsKey(static_cast(bwdParams.beta)); + key_creator.AddAsKey(static_cast(src_fmt)); + key_creator.AddAsKey(static_cast(diff_dst_fmt)); + return key_creator.GetKey(); + } + + MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams& bwdParams, + const memory::format &src_fmt, const memory::format &diff_dst_fmt) { + std::string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt); + return this->GetOp(key); + } + + void SetEltwiseBwd(const MklEltwiseBwdParams& bwdParams, + const memory::format &src_fmt, + const memory::format &diff_dst_fmt, MklPrimitive *op) { + std::string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt); + this->SetOp(key, op); + } +}; + +#endif + typedef Eigen::ThreadPoolDevice CPUDevice; struct MklReluHelpers { @@ -368,104 +764,109 @@ void MklReluGradOp::Compute(OpKernelContext* context) { mkl_context.MklCleanup(); } - - #else // INTEL_MKL_ML - template class MklReluOpBase : public OpKernel { public: ~MklReluOpBase() {} explicit MklReluOpBase(OpKernelConstruction* context) : OpKernel(context) {} - virtual void Compute_Scalar(OpKernelContext* context) = 0; void Compute(OpKernelContext* context) override { try { - auto cpu_engine = engine(engine::cpu, 0); const size_t src_index = 0; // index of src input tensor const size_t dst_index = 0; // index of dst output tensor const Tensor& src_tensor = MklGetInput(context, src_index); MklDnnShape dnn_shape_src; GetMklShape(context, src_index, &dnn_shape_src); - Tensor* dst_tensor = nullptr; if (src_tensor.dims() == 0) { - Compute_Scalar(context); // scalar case doesn't use in-place operation + Compute_Scalar(context); return; } - // Create relu primitive. 
-      MklDnnData src(&cpu_engine);
-      MklDnnData dst(&cpu_engine);

-      // Set DNN primitive - src
+      MklDnnData src(&cpu_engine);
+      memory::dims src_dims;
       memory::desc src_md({}, memory::data_undef, memory::format_undef);
       if (dnn_shape_src.IsMklTensor()) {
         src_md = dnn_shape_src.GetMklLayout();
+        src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
       } else {
-        auto src_dims = TFShapeToMklDnnDims(src_tensor.shape());
+        src_dims = TFShapeToMklDnnDims(src_tensor.shape());
         auto src_strides = CalculateTFStrides(src_dims);
         // Create blocked memory descriptor
         src_md = MklDnnData::CreateBlockedMemDesc(src_dims, src_strides);
       }
-      src.SetUsrMem(src_md, &src_tensor);

       T alpha = 0, beta = 0;
-      std::shared_ptr relu_fwd_pd;
-      auto relu_fwd_desc = relu_forward::desc(
-          prop_kind::forward_training,
-          // Operator memory descriptor is same as user memory descriptor.
-          alg_kind, src.GetUsrMemDesc(), alpha, beta);
-      relu_fwd_pd.reset(
-          new relu_forward::primitive_desc(relu_fwd_desc, cpu_engine));
-
-      // allocate dst tensor
+
+      // get an eltwise fwd from primitive pool
+      MklEltwiseFwdParams fwdParams(src_dims, src_md,
+                                    alg_kind, alpha, beta);
+      MklEltwiseFwdPrimitive *eltwise_fwd =
+          MklEltwiseFwdPrimitiveFactory::Get(fwdParams);
+
+      // prepare for execution
+      T* src_data = nullptr;
+      // check whether src needs to be reordered
+      if (src_md.data.format != eltwise_fwd->GetSrcMemoryFormat()) {
+        src.SetUsrMem(src_md, &src_tensor);
+        auto src_target_pd = memory::primitive_desc({{src_dims},
+            MklDnnType(), eltwise_fwd->GetSrcMemoryFormat()}, cpu_engine);
+        src.CheckReorderToOpMem(src_target_pd);
+        src_data = static_cast(src.GetOpMem().get_data_handle());
+      } else {
+        src_data = static_cast(
+            const_cast(src_tensor.flat().data()));
+      }
+
+      // allocate dst tensor, always set it as MKL-DNN layout
+      std::shared_ptr
+          eltwise_fwd_pd = eltwise_fwd->GetEltwiseFwdPd();
       MklDnnShape dnn_shape_dst;
       TensorShape tf_shape_dst;
       if (dnn_shape_src.IsMklTensor()) {
         dnn_shape_dst.SetMklTensor(true);
-        auto dst_pd = relu_fwd_pd->dst_primitive_desc();
+        auto dst_pd = eltwise_fwd_pd->dst_primitive_desc();
         dnn_shape_dst.SetMklLayout(&dst_pd);
         dnn_shape_dst.SetElemType(MklDnnType());
         dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
                                   dnn_shape_src.GetSizesAsMklDnnDims(),
                                   dnn_shape_src.GetTfDataFormat());
         tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
       } else {
+        // TODO(yli135): why relu's input is TF tensor in VGG16??
         dnn_shape_dst.SetMklTensor(false);
         tf_shape_dst = src_tensor.shape();
       }
-
-      // Allocate output and MklDnnShape tensors separately for possible
-      // in-place operation
+
+      Tensor* dst_tensor = nullptr;
       OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
-                                  {static_cast(src_index)},
-                                  static_cast(dst_index),
-                                  tf_shape_dst, &dst_tensor));
+          {src_index}, dst_index, tf_shape_dst, &dst_tensor));
       AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst);

-      // Destination memory descriptor is same as source memory descriptor.
-      auto &dst_md = src_md;
-      dst.SetUsrMem(dst_md, dst_tensor);
-
-      // execute net
-      std::vector<primitive> net;
-      auto relu_fwd =
-          relu_forward(*relu_fwd_pd, src.GetOpMem(), dst.GetOpMem());
-      net.push_back(relu_fwd);
-      stream(stream::kind::eager).submit(net).wait();
-    } catch (mkldnn::error& e) {
+      T* dst_data = static_cast<T*>(const_cast<T*>(
+          dst_tensor->flat<T>().data()));
+
+      // execute eltwise
+      eltwise_fwd->Execute(src_data, dst_data);
+    } catch (mkldnn::error &e) {
       string error_msg = "Status: " + std::to_string(e.status) +
-                         ", message: " + string(e.message) + ", in file " +
-                         string(__FILE__) + ":" + std::to_string(__LINE__);
-      OP_REQUIRES_OK(
-          context,
-          errors::Aborted("Operation received an exception:", error_msg));
+                         ", message: " + string(e.message) +
+                         ", in file " + string(__FILE__) + ":" +
+                         std::to_string(__LINE__);
+      OP_REQUIRES_OK(context,
+                     errors::Aborted("Operation received an exception:",
+                                     error_msg));
     }
   }
+
+ private:
+  engine cpu_engine = engine(engine::cpu, 0);
+  std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
 };
 
 template <typename Device, typename T, algorithm alg_kind>
@@ -473,17 +874,16 @@ class MklReluGradOpBase : public OpKernel {
  public:
   ~MklReluGradOpBase() {}
 
-  explicit MklReluGradOpBase(OpKernelConstruction* context)
-      : OpKernel(context) {}
+  explicit MklReluGradOpBase(OpKernelConstruction* context) :
+      OpKernel(context) {
+  }
 
   virtual void Compute_Scalar(OpKernelContext* context) = 0;
 
   void Compute(OpKernelContext* context) {
     try {
-      auto cpu_engine = engine(engine::cpu, 0);
       MklDnnData<T> src(&cpu_engine);
       MklDnnData<T> diff_dst(&cpu_engine);
-      MklDnnData<T> diff_src(&cpu_engine);
 
       const size_t diff_dst_index = 0;  // index of diff_dst input tensor
       const size_t src_index = 1;       // index of src input tensor
@@ -499,37 +899,23 @@ class MklReluGradOpBase : public OpKernel {
       int src_dims_size = src_tensor.dims();
       if (src_dims_size == 0) {
-        Compute_Scalar(context);  // scalar case doesn't use in-place operation
+        Compute_Scalar(context);
         return;
       }
 
-      // Set DNN primitives for src & diff_dst
+      // get an eltwise bwd primitive from the primitive pool
+      memory::dims src_dims = {};
       memory::desc src_md({}, memory::data_undef, memory::format_undef);
       memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);
-
-      // For creating Sum primitive, we need to ensure that all inputs are in
-      // same format. What that means is if we have a mixed input case - where
-      // one input is in Tensorflow format and one input is in MKL format -,
-      // then we need to ensure that all inputs are in same format for
-      // primitive construction. For performance reason, we say that all inputs
-      // are in MKL format in such case, and insert reorder for input that is
-      // in Tensorflow format into MKL format. On the other hand, if both the
-      // inputs are in MKL format or both are in Tensorflow format, then we
-      // dont need reorder.
       if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
-        // If both the inputs are in Tensorflow format, we create blocked
-        // memory descriptor.
-        auto src_dims = TFShapeToMklDnnDims(src_tensor.shape());
+        src_dims = TFShapeToMklDnnDims(src_tensor.shape());
         auto src_strides = CalculateTFStrides(src_dims);
         src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
         diff_dst_md = src_md;
       } else if (dnn_shape_src.IsMklTensor() &&
                  !dnn_shape_diff_dst.IsMklTensor()) {
-        // If one input is in MKL format and other is in Tensorflow, then
-        // create respective descriptors describing the actual case. For input
-        // in Mkl format, we just get Mkl layout from MklDnnShape. For input in
-        // Tensorflow format, we create memory descriptor using data format.
         src_md = dnn_shape_src.GetMklLayout();
+        src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
 
         memory::format src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
         auto src_tf_data_format =
@@ -540,26 +926,23 @@ class MklReluGradOpBase : public OpKernel {
             memory::desc(diff_dst_dims, MklDnnType<T>(), src_mkl_data_format);
       } else if (!dnn_shape_src.IsMklTensor() &&
                  dnn_shape_diff_dst.IsMklTensor()) {
-        // Same comment as above.
         diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
 
         memory::format diff_dst_mkl_data_format =
             dnn_shape_diff_dst.GetTfDataFormat();
         auto diff_dst_tf_data_format =
             MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);
-        auto src_dims = TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
-                                                  diff_dst_tf_data_format);
+        src_dims = TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
+                                             diff_dst_tf_data_format);
         src_md =
             memory::desc(src_dims, MklDnnType<T>(), diff_dst_mkl_data_format);
       } else {
-        // If both the inputs are in MKL format, we use Mkl layout of the input
-        // tensors.
         src_md = dnn_shape_src.GetMklLayout();
         diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
+        src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
       }
 
-      src.SetUsrMem(src_md, &src_tensor);
-      diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+      T alpha = 0, beta = 0;
 
       // As per comment above, we tell MKLDNN that both the inputs are in same
       // format. So we set common memory descriptor in MKL format, if any of the
@@ -574,83 +957,77 @@ class MklReluGradOpBase : public OpKernel {
         common_md = src_md;
       }
 
-      T alpha = 0, beta = 0;
-      std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
-      auto relu_fwd_desc = relu_forward::desc(prop_kind::forward_training,
-                                              alg_kind, src_md, alpha, beta);
-      relu_fwd_pd.reset(
-          new relu_forward::primitive_desc(relu_fwd_desc, cpu_engine));
-      auto relu_bwd_desc =
-          relu_backward::desc(alg_kind, common_md, common_md, alpha, beta);
-      auto relu_bwd_pd = relu_backward::primitive_desc(
-          relu_bwd_desc, cpu_engine, *relu_fwd_pd);
+      MklEltwiseBwdParams<T> bwdParams(src_dims, common_md,
+                                       alg_kind, alpha, beta);
+      MklEltwiseBwdPrimitive<T> *eltwise_bwd =
+          MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);
+      auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();
+
+      // check whether src / diff_dst need to be reordered
+      T* src_data;
+      T* diff_dst_data;
+      if (src_md.data.format != eltwise_bwd->GetSrcMemoryFormat()) {
+        src.SetUsrMem(src_md, &src_tensor);
+        src.CheckReorderToOpMem(
+            eltwise_bwd_pd.get()->diff_src_primitive_desc());
+        src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
+      } else {
+        src_data = static_cast<T*>(
+            const_cast<T*>(src_tensor.flat<T>().data()));
+      }
+
+      if (diff_dst_md.data.format != eltwise_bwd->GetDiffDstMemoryFormat()) {
+        diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+        diff_dst.CheckReorderToOpMem(
+            eltwise_bwd_pd.get()->diff_src_primitive_desc());
+        diff_dst_data = static_cast<T*>(
+            diff_dst.GetOpMem().get_data_handle());
+      } else {
+        diff_dst_data = static_cast<T*>(const_cast<T*>(
+            diff_dst_tensor.flat<T>().data()));
+      }
 
       // allocate diff_src tensor
       MklDnnShape dnn_shape_diff_src;
       TensorShape tf_shape_diff_src;
-      if (dnn_shape_src.IsMklTensor() ||
-          dnn_shape_diff_dst.IsMklTensor()) {
+      if (dnn_shape_src.IsMklTensor()) {
+        auto diff_src_pd = eltwise_bwd_pd->diff_src_primitive_desc();
         dnn_shape_diff_src.SetMklTensor(true);
-        auto diff_src_pd = relu_bwd_pd.diff_src_primitive_desc();
         dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
         dnn_shape_diff_src.SetElemType(MklDnnType<T>());
-        if (dnn_shape_src.IsMklTensor()) {
-          dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(),
-                                         dnn_shape_src.GetSizesAsMklDnnDims(),
-                                         dnn_shape_src.GetTfDataFormat());
-        } else {
-          dnn_shape_diff_src.SetTfLayout(
-              dnn_shape_diff_dst.GetDimension(),
-              dnn_shape_diff_dst.GetSizesAsMklDnnDims(),
-              dnn_shape_diff_dst.GetTfDataFormat());
-        }
-        tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
+        dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(),
+                                       dnn_shape_src.GetSizesAsMklDnnDims(),
+                                       dnn_shape_src.GetTfDataFormat());
+        tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
       } else {
         dnn_shape_diff_src.SetMklTensor(false);
-        // both src and diff_dst are TensorFlow layout,
-        // so it is ok to get TensorFlow shape.
         tf_shape_diff_src = src_tensor.shape();
       }
 
-      // Allocate diff_src and MklDnnShape tensors separately for possible
-      // in-place operation
-      OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
-                                  {static_cast<const int>(diff_dst_index)},
-                                  static_cast<const int>(diff_src_index),
-                                  tf_shape_diff_src, &diff_src_tensor));
-      AllocateOutputSetMklShape(context, diff_src_index, dnn_shape_diff_src);
-
-      // diff_src memory descriptor is same as memory descriptor for both
-      // inputs.
-      diff_src.SetUsrMem(common_md, diff_src_tensor);
-
-      PrepareAndExecuteNet(relu_bwd_pd, &src, &diff_src, &diff_dst);
-    } catch (mkldnn::error& e) {
-      string error_msg = "Status: " + std::to_string(e.status) +
-                         ", message: " + string(e.message) + ", in file " +
-                         string(__FILE__) + ":" + std::to_string(__LINE__);
-      OP_REQUIRES_OK(
-          context,
-          errors::Aborted("Operation received an exception:", error_msg));
+      OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+          {diff_dst_index}, diff_src_index, tf_shape_diff_src,
+          &diff_src_tensor));
+      AllocateOutputSetMklShape(context, diff_src_index, dnn_shape_diff_src);
+
+      T* diff_src_data = static_cast<T*>(const_cast<T*>(
+          diff_src_tensor->flat<T>().data()));
+
+      // execute eltwise bwd
+      eltwise_bwd->Execute(src_data, diff_dst_data, diff_src_data);
+    } catch (mkldnn::error &e) {
+      string error_msg = "Status: " + std::to_string(e.status) +
+                         ", message: " + string(e.message) +
+                         ", in file " + string(__FILE__) + ":" +
+                         std::to_string(__LINE__);
+      OP_REQUIRES_OK(context,
+                     errors::Aborted("Operation received an exception:",
+                                     error_msg));
     }
   }
 
-  void PrepareAndExecuteNet(const relu_backward::primitive_desc& relu_prim_desc,
-                            MklDnnData<T>* src, MklDnnData<T>* diff_src,
-                            MklDnnData<T>* diff_dst) {
-    std::vector<primitive> net;
-
-    // Check if we need to reorder original input tensors into common_md layout
-    // that we set for primitive creation. diff_src_primitive_desc is same as
-    // common_md.
-    src->CheckReorderToOpMem(relu_prim_desc.diff_src_primitive_desc(), &net);
-    diff_dst->CheckReorderToOpMem(relu_prim_desc.diff_src_primitive_desc(),
-                                  &net);
-
-    net.push_back(relu_backward(relu_prim_desc, src->GetOpMem(),
-                                diff_dst->GetOpMem(), diff_src->GetOpMem()));
-    stream(stream::kind::eager).submit(net).wait();
-  }
+ private:
+  engine cpu_engine = engine(engine::cpu, 0);
+  std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
 };
 
 template <typename Device, typename T>
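Note: the rewritten forward and backward Compute() bodies above no longer build relu_forward/relu_backward primitives and submit an eager stream on every call; they fetch a pooled primitive by key and drive it through Execute() with raw buffers. The stand-alone sketch below illustrates that reuse pattern only; EltwiseBwdParams, EltwiseBwdPrimitive, and GetCached() are simplified stand-ins for the patch's MklEltwiseBwdParams<T>, MklEltwiseBwdPrimitive<T>, and MklEltwiseBwdPrimitiveFactory<T>::Get(), whose real definitions fall outside this excerpt.

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Stand-in for MklEltwiseBwdParams<T>: everything that makes two backward
    // primitives interchangeable must go into the key.
    struct EltwiseBwdParams {
      std::vector<int> src_dims;
      int alg_kind;
      float alpha, beta;
      std::string Key() const {
        std::string k = std::to_string(alg_kind) + ":" +
                        std::to_string(alpha) + ":" + std::to_string(beta);
        for (int d : src_dims) k += ":" + std::to_string(d);
        return k;
      }
    };

    // Stand-in for MklEltwiseBwdPrimitive<T>: construction is the expensive
    // part (descriptor/primitive_desc creation); Execute() only binds buffers.
    class EltwiseBwdPrimitive {
     public:
      explicit EltwiseBwdPrimitive(const EltwiseBwdParams& p) : params_(p) {
        std::cout << "building primitive for key " << p.Key() << "\n";
      }
      // Emulates relu backward on plain buffers: dx = (x > 0) ? dy : 0.
      void Execute(const float* src, const float* diff_dst, float* diff_src,
                   size_t n) const {
        for (size_t i = 0; i < n; ++i)
          diff_src[i] = src[i] > 0.0f ? diff_dst[i] : 0.0f;
      }

     private:
      EltwiseBwdParams params_;
    };

    // Stand-in for MklEltwiseBwdPrimitiveFactory<T>::Get(): build once per
    // configuration, then reuse on every subsequent Compute() call.
    EltwiseBwdPrimitive* GetCached(const EltwiseBwdParams& p) {
      static std::map<std::string, EltwiseBwdPrimitive*> pool;
      auto it = pool.find(p.Key());
      if (it != pool.end()) return it->second;
      return pool[p.Key()] = new EltwiseBwdPrimitive(p);
    }

    int main() {
      EltwiseBwdParams params{{2, 3}, /*alg_kind=*/0, 0.0f, 0.0f};
      float src[6] = {-1, 2, -3, 4, -5, 6};
      float dy[6] = {1, 1, 1, 1, 1, 1};
      float dx[6];
      for (int step = 0; step < 2; ++step) {   // second call hits the cache
        GetCached(params)->Execute(src, dy, dx, 6);
      }
      for (float v : dx) std::cout << v << " ";  // prints: 0 1 0 1 0 1
      std::cout << "\n";
    }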
diff --git a/tensorflow/core/kernels/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl_transpose_op.cc
index b180c2ff2006e1..a0a34fc7231af5 100644
--- a/tensorflow/core/kernels/mkl_transpose_op.cc
+++ b/tensorflow/core/kernels/mkl_transpose_op.cc
@@ -15,13 +15,23 @@ limitations under the License.
 
 // See docs in ../ops/array_ops.cc.
 
-#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
+#if defined(INTEL_MKL)
 
 #define EIGEN_USE_THREADS
 
+#if !defined(DO_NOT_USE_ML)
 #include "mkl_trans.h"
+#endif
+
 #include "tensorflow/core/kernels/transpose_functor.h"
 #include "tensorflow/core/kernels/transpose_op.h"
 
+#ifndef INTEL_MKL_ML
+#include "mkldnn.hpp"
+#include "tensorflow/core/util/mkl_util.h"
+
+using mkldnn::stream;
+#endif
+
 namespace tensorflow {
 
 // output = TransposeOp(T input, T perm) takes a tensor
@@ -40,6 +50,7 @@ namespace tensorflow {
 // REQUIRES: perm is a permutation.
 
 namespace {
+#if !defined(DO_NOT_USE_ML)
 template <typename T>
 Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out);
 
@@ -93,11 +104,67 @@ Status MKLTranspose2D(const char trans, const Tensor& in,
 static const char kMKLTranspose = 'T';
 static const char kMKLConjugateTranspose = 'C';
 
+#endif  // if !defined(DO_NOT_USE_ML)
+
+#ifndef INTEL_MKL_ML
+// MKL-DNN based Transpose implementation
+template <typename T>
+Status MKLTransposeND(OpKernelContext* ctx, const Tensor& in, Tensor* out,
+                      const gtl::ArraySlice<int32>& perm);
+
+static inline memory::dims ReorderStrides(const memory::dims& strides,
+                                          const gtl::ArraySlice<int32>& perm) {
+  memory::dims reordered_strides;
+  reordered_strides.resize(strides.size());
+  for (size_t i = 0; i < strides.size(); ++i) {
+    reordered_strides[perm[i]] = strides[i];
+  }
+  return reordered_strides;
+}
+
+// Transpose of N-dimensional tensor using MKL-DNN
+template <typename T>
+Status MKLTransposeND(OpKernelContext* context,
+                      const Tensor& in_tensor, Tensor* out_tensor,
+                      const gtl::ArraySlice<int32>& perm) {
+  try {
+    engine cpu_engine = engine(engine::cpu, 0);
+    MklDnnData<T> in(&cpu_engine);
+    MklDnnData<T> out(&cpu_engine);
+
+    memory::dims in_dims = TFShapeToMklDnnDims(in_tensor.shape());
+    memory::dims out_dims = TFShapeToMklDnnDims(out_tensor->shape());
+    memory::dims in_strides = CalculateTFStrides(in_dims);
+    // Reorder output strides based on permutation requested.
+    memory::dims out_strides = ReorderStrides(CalculateTFStrides(out_dims),
+                                              perm);
+
+    in.SetUsrMem(in_dims, in_strides, &in_tensor);
+    // Output dimensions are the same as input dimensions. We adjust the
+    // layout using strides.
+    out.SetUsrMem(in_dims, out_strides, out_tensor);
+
+    std::vector<primitive> net;
+    net.push_back(in.CreateReorder(in.GetUsrMem(), out.GetUsrMem()));
+    stream(stream::kind::eager).submit(net).wait();
+    return Status::OK();
+  } catch (mkldnn::error &e) {
+    string error_msg = "Status: " + std::to_string(e.status) +
+                       ", message: " + std::string(e.message) +
+                       ", in file " + std::string(__FILE__) + ":" +
+                       std::to_string(__LINE__);
+    return errors::Aborted("Operation received an exception:", error_msg);
+  }
+}
+#endif  // #ifndef INTEL_MKL_ML
+
 }  // namespace
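ReorderStrides() carries the whole trick: MKLTransposeND keeps the input's dimension order for the output memory and permutes only the strides, so the single reorder primitive it builds writes each element straight into its transposed position. A self-contained check of the stride arithmetic (our test harness; only the ReorderStrides body mirrors the patch, the rest is illustrative):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    using dims = std::vector<int64_t>;

    // Same logic as ReorderStrides in the patch: stride i of the output-shape
    // strides becomes the stride of input dimension perm-inverse; concretely,
    // reordered[perm[i]] = strides[i].
    dims ReorderStrides(const dims& strides, const dims& perm) {
      dims reordered(strides.size());
      for (size_t i = 0; i < strides.size(); ++i) {
        reordered[perm[i]] = strides[i];
      }
      return reordered;
    }

    // Row-major (TF-style) strides: the last dimension is contiguous.
    dims TFStrides(const dims& d) {
      dims s(d.size(), 1);
      for (int i = static_cast<int>(d.size()) - 2; i >= 0; --i) {
        s[i] = s[i + 1] * d[i + 1];
      }
      return s;
    }

    int main() {
      dims in_dims = {2, 3, 4, 5};   // NCHW input
      dims perm = {0, 2, 3, 1};      // NCHW -> NHWC
      dims out_dims = {2, 4, 5, 3};  // shape after the transpose
      dims out_strides = ReorderStrides(TFStrides(out_dims), perm);
      // out_strides, read in *input* dimension order, are {60, 1, 15, 3}:
      // input element (n, c, h, w) lands at offset n*60 + c*1 + h*15 + w*3,
      // which is exactly the NHWC layout of a {2,4,5,3} tensor.
      for (auto s : out_strides) std::cout << s << " ";
      std::cout << "\n";
    }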
 
 Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
                                       gtl::ArraySlice<int32> perm,
                                       Tensor* out) {
+#if !defined(DO_NOT_USE_ML)
   if (in.dims() == 2) {
     if (perm[0] == 0 && perm[1] == 1) {
       return Status::OK();
@@ -115,7 +182,21 @@ Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
         break;
     }
   }
-  // Fallback to eigen if transpose parameters not supported by MKL
+#endif
+
+#ifndef INTEL_MKL_ML
+  // MKL-DNN has a limit on the maximum number of dimensions in a tensor.
+  // Fall back to Eigen for unsupported cases.
+  if (in.dims() <= TENSOR_MAX_DIMS) {
+    switch (in.dtype()) {
+      case DT_FLOAT: return MKLTransposeND<float>(ctx, in, out, perm); break;
+      // TODO(nhasabni): support other types such as INT8.
+      default: break;
+    }
+  }
+#endif
+
+  // Fall back to Eigen if the transpose parameters are not supported by MKL
+  // or MKL-DNN.
   typedef Eigen::ThreadPoolDevice CPUDevice;
   return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in,
                                    perm, out);
@@ -125,6 +206,7 @@ Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
                                                const Tensor& in,
                                                gtl::ArraySlice<int32> perm,
                                                Tensor* out) {
+#if !defined(DO_NOT_USE_ML)
   if (in.dims() == 2 && perm[0] == 1 && perm[1] == 0) {
     // TODO(rmlarsen): By setting lda and ldb, we could use the MKL kernels
     // for any transpose that can be reduced to swapping the last two
@@ -143,7 +225,21 @@ Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
       break;
     }
   }
-  // Fallback to eigen if transpose parameters not supported by MKL
+#endif
+
+#ifndef INTEL_MKL_ML
+  // MKL-DNN has a limit on the maximum number of dimensions in a tensor.
+  // Fall back to Eigen for unsupported cases.
+  if (in.dims() <= TENSOR_MAX_DIMS) {
+    switch (in.dtype()) {
+      case DT_FLOAT: return MKLTransposeND<float>(ctx, in, out, perm); break;
+      // TODO(nhasabni): support other types such as INT8.
+      default: break;
+    }
+  }
+#endif
+
+  // Fall back to Eigen if the transpose parameters are not supported by MKL
+  // or MKL-DNN.
   typedef Eigen::ThreadPoolDevice CPUDevice;
   return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(), in,
                                             perm, out);
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 886b3e7492efa8..0f0f65c5a37054 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -218,7 +218,7 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
                                             perm, out);
 }
 
-#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
+#if defined(INTEL_MKL)
 #define REGISTER(T)                                   \
   REGISTER_KERNEL_BUILDER(Name("Transpose")           \
                               .Device(DEVICE_CPU)     \
diff --git a/tensorflow/core/kernels/transpose_op.h b/tensorflow/core/kernels/transpose_op.h
index 709b0a92e90b5f..9e8c57376189d7 100644
--- a/tensorflow/core/kernels/transpose_op.h
+++ b/tensorflow/core/kernels/transpose_op.h
@@ -42,7 +42,7 @@ class TransposeCpuOp : public TransposeOp {
                      gtl::ArraySlice<int32> perm, Tensor* out) override;
 };
 
-#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
+#if defined(INTEL_MKL)
 class MklTransposeCpuOp : public TransposeOp {
  public:
   explicit MklTransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
@@ -85,7 +85,7 @@ class ConjugateTransposeCpuOp : public TransposeOp {
   bool IsConjugate() const override { return true; }
 };
 
-#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
+#if defined(INTEL_MKL)
 class MklConjugateTransposeCpuOp : public TransposeOp {
  public:
   explicit MklConjugateTransposeCpuOp(OpKernelConstruction* ctx)
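With the guard changes above, defined(INTEL_MKL) alone now gates whether the Mkl transpose kernels are registered at all, while DO_NOT_USE_ML and INTEL_MKL_ML only prune individual fast paths inside DoTranspose. A tiny compile-time probe (ours; the conditions mirror the #if structure in mkl_transpose_op.cc above) makes the resulting matrix visible:

    #include <iostream>

    int main() {
    #if defined(INTEL_MKL)
      std::cout << "Mkl*TransposeCpuOp registered\n";
    #if !defined(DO_NOT_USE_ML)
      std::cout << "  2-D fast path: MKL BLAS mkl_?omatcopy\n";
    #endif
    #if !defined(INTEL_MKL_ML)
      std::cout << "  N-D fast path: MKL-DNN reorder (MKLTransposeND)\n";
    #endif
      std::cout << "  otherwise: Eigen fallback\n";
    #else
      std::cout << "plain TransposeCpuOp (Eigen) only\n";
    #endif
    }

Building it with, e.g., c++ -DINTEL_MKL -DDO_NOT_USE_ML probe.cc shows the configuration this patch newly supports: MKL-DNN only, with the MKL-ML BLAS path compiled out.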
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" - +#include "tensorflow/core/platform/cpu_info.h" #ifndef INTEL_MKL_ML #include "mkldnn.hpp" #include "tensorflow/core/lib/core/stringpiece.h" @@ -1504,7 +1504,8 @@ class MklDnnData { /// Operations memory descriptor memory::desc* op_md_; - + /// Operations temp buffer + void* allocated_buffer_; /// CPU engine on which operation will be executed const engine* cpu_engine_; @@ -1513,6 +1514,7 @@ class MklDnnData { : user_memory_(nullptr), reorder_memory_(nullptr), op_md_(nullptr), + allocated_buffer_(nullptr), cpu_engine_(e) {} ~MklDnnData() { @@ -1653,6 +1655,14 @@ class MklDnnData { user_memory_->set_data_handle(GetTensorBuffer(tensor)); } + /// allocate function for data buffer + inline void AllocateBuffer(size_t size) { + allocated_buffer_ = cpu_allocator()->AllocateRaw(64, size); + } + inline void* GetAllocatedBuffer() { + return allocated_buffer_; + } + /// Get the memory primitive for input and output of an op. If inputs /// to an op require reorders, then this function returns memory primitive /// for reorder. Otherwise, it will return memory primitive for user memory. @@ -1883,9 +1893,9 @@ class MklPrimitive { public: virtual ~MklPrimitive() {} - // Dummy data. Its size, hard-coded as 256 here, does - // not matter since MKL should never operate on this buffer. - unsigned char DummyData[256]; + // Dummy data. + // does not matter since MKL should never operate on this buffer. + unsigned char *DummyData = nullptr; }; const mkldnn::memory::dims NONE_DIMS = {}; @@ -1897,8 +1907,9 @@ class MklPrimitiveFactory { ~MklPrimitiveFactory() {} MklPrimitive* GetOp(const std::string& key) { - auto stream_iter = MklPrimitiveFactory::GetHashMap().find(key); - if (stream_iter == MklPrimitiveFactory::GetHashMap().end()) { + auto &map = MklPrimitiveFactory::GetHashMap(); + auto stream_iter = map.find(key); + if (stream_iter == map.end()) { return nullptr; } else { return stream_iter->second; @@ -1906,11 +1917,12 @@ class MklPrimitiveFactory { } void SetOp(const std::string& key, MklPrimitive* op) { - auto stream_iter = MklPrimitiveFactory::GetHashMap().find(key); + auto &map = MklPrimitiveFactory::GetHashMap(); + auto stream_iter = map.find(key); - CHECK(stream_iter == MklPrimitiveFactory::GetHashMap().end()); + CHECK(stream_iter == map.end()); - MklPrimitiveFactory::GetHashMap()[key] = op; + map[key] = op; } private: @@ -1957,6 +1969,21 @@ class FactoryKeyCreator { } }; +static inline memory::format get_desired_format(int channel) { + memory::format fmt_desired = memory::format::any; + + if (port::TestCPUFeature(port::CPUFeature::AVX512F) + && (channel % 16) == 0) { + fmt_desired = memory::format::nChw16c; + } else if (port::TestCPUFeature(port::CPUFeature::AVX2) + && (channel % 8) == 0) { + fmt_desired = memory::format::nChw8c; + } else { + fmt_desired = memory::format::nchw; + } + return fmt_desired; +} + class MklReorderPrimitive : public MklPrimitive { public: explicit MklReorderPrimitive(const memory* from, const memory* to) { @@ -2059,7 +2086,7 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory { MklReorderPrimitiveFactory::Get(from, to); return *reorder_prim->GetPrimitive(); } - + #endif // INTEL_MKL_DNN } // namespace tensorflow diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 3b7674e39703ba..1e89f1d921666d 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -52,8 +52,8 @@ def 
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 3b7674e39703ba..1e89f1d921666d 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -52,8 +52,8 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
   mkl_repository(
       name = "mkl_linux",
       urls = [
-          "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.14/mklml_lnx_2018.0.3.20180406.tgz",
-          "https://github.com/intel/mkl-dnn/releases/download/v0.14/mklml_lnx_2018.0.3.20180406.tgz"
+          "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.15/mklml_lnx_2018.0.3.20180406.tgz",
+          "https://github.com/intel/mkl-dnn/releases/download/v0.15/mklml_lnx_2018.0.3.20180406.tgz"
       ],
       sha256 = "d2305244fdc9b87db7426ed4496e87a4b3977ad3374d73b8000e8b7a5b7aa725",
       strip_prefix = "mklml_lnx_2018.0.3.20180406",
@@ -62,8 +62,8 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
   mkl_repository(
       name = "mkl_windows",
       urls = [
-          "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.14/mklml_win_2018.0.3.20180406.zip",
-          "https://github.com/intel/mkl-dnn/releases/download/v0.14/mklml_win_2018.0.3.20180406.zip"
+          "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.15/mklml_win_2018.0.3.20180406.zip",
+          "https://github.com/intel/mkl-dnn/releases/download/v0.15/mklml_win_2018.0.3.20180406.zip"
       ],
       sha256 = "a584a5bf1c8d2ad70b90d12b52652030e9a338217719064fdb84b7ad0d693694",
       strip_prefix = "mklml_win_2018.0.3.20180406",
@@ -72,8 +72,8 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
   mkl_repository(
       name = "mkl_darwin",
       urls = [
-          "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.14/mklml_mac_2018.0.3.20180406.tgz",
-          "https://github.com/intel/mkl-dnn/releases/download/v0.14/mklml_mac_2018.0.3.20180406.tgz"
+          "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.15/mklml_mac_2018.0.3.20180406.tgz",
+          "https://github.com/intel/mkl-dnn/releases/download/v0.15/mklml_mac_2018.0.3.20180406.tgz"
       ],
       sha256 = "094e3dfd61c816136dc8d12a45cc611ce26c5f4828176a3644cd0b0efa15a25b",
       strip_prefix = "mklml_mac_2018.0.3.20180406",
@@ -87,11 +87,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
   tf_http_archive(
       name = "mkl_dnn",
       urls = [
-          "https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/v0.14.tar.gz",
-          "https://github.com/intel/mkl-dnn/archive/v0.14.tar.gz",
+          "https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/0c1cf54b63732e5a723c5670f66f6dfb19b64d20.zip",
+          "https://github.com/intel/mkl-dnn/archive/0c1cf54b63732e5a723c5670f66f6dfb19b64d20.zip"
       ],
-      sha256 = "efebc53882856afec86457a2da644693f5d59c68772d41d640d6b60a8efc4eb0",
-      strip_prefix = "mkl-dnn-0.14",
+      sha256 = "bfea2893ec978577a0e6c7a8703746bd87f60a4ea7c9ee96e3bafdf7c55a9f12",
+      strip_prefix = "mkl-dnn-0c1cf54b63732e5a723c5670f66f6dfb19b64d20",
       build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"),
   )
 
diff --git a/third_party/mkl_dnn/mkldnn.BUILD b/third_party/mkl_dnn/mkldnn.BUILD
index 57d2e1292b012a..cd1f9a63277f42 100644
--- a/third_party/mkl_dnn/mkldnn.BUILD
+++ b/third_party/mkl_dnn/mkldnn.BUILD
@@ -18,6 +18,7 @@ cc_library(
     srcs = glob([
         "src/common/*.cpp",
         "src/cpu/*.cpp",
+        "src/cpu/gemm/*.cpp",
     ]),
     hdrs = glob(["include/*"]),
     copts = [
@@ -42,6 +43,7 @@ cc_library(
        "src/common",
        "src/cpu",
        "src/cpu/xbyak",
+       "src/cpu/gemm",
     ],
     nocopts = "-fno-exceptions",
     visibility = ["//visibility:public"],