diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
index 3fe660cf968b4e..4ccb10a7cf4305 100644
--- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -262,6 +262,7 @@ class MklFusedBatchNormOp : public OpKernel {
   }
 
   void MklCreateInputLayout(OpKernelContext* context) {
+    const Tensor& input = MklGetInput(context, 0);
     bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
     if (input_in_mkl_format) {
       mkl_lt_input =
@@ -544,6 +545,7 @@ class MklFusedBatchNormGradOp : public OpKernel {
   }
 
   void MklCreateInputLayout(OpKernelContext* context) {
+    const Tensor& input = MklGetInput(context, 0);
     bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
     if (input_in_mkl_format) {
       mkl_lt_input =
@@ -684,6 +686,466 @@ class MklFusedBatchNormGradOp : public OpKernel {
 
 #ifndef INTEL_MKL_ML
 
+struct MklBatchNormFwdParams {
+  memory::dims src_dims;
+  int depth;
+  float eps;
+  bool training;
+
+  MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps,
+                        bool training)
+      : src_dims(src_dims), depth(depth), eps(eps), training(training) {}
+};
+
+template <typename T>
+class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
+ public:
+  explicit MklFusedBatchNormFwdPrimitive(
+      const MklBatchNormFwdParams& fwdParams)
+      : cpu_engine_(engine::cpu, 0) {
+    context_.fwd_stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
+    if (context_.bn_fwd == nullptr) Setup(fwdParams);
+  }
+
+  ~MklFusedBatchNormFwdPrimitive() {}
+
+  // BatchNormalization forward execute
+  //   src_data:      input data buffer of src
+  //   weights_data:  input data buffer of weights
+  //   dst_data:      output data buffer of dst
+  //   mean_data:     input/output data buffer of means
+  //   variance_data: input/output data buffer of variances
+  void Execute(const T* src_data, const T* weights_data, const T* dst_data,
+               const T* mean_data, const T* variance_data) {
+    context_.src_mem->set_data_handle(
+        static_cast<void*>(const_cast<T*>(src_data)));
+    context_.dst_mem->set_data_handle(
+        static_cast<void*>(const_cast<T*>(dst_data)));
+
+    if (context_.flags & use_scale_shift)
+      context_.weights_mem->set_data_handle(
+          static_cast<void*>(const_cast<T*>(weights_data)));
+
+    if ((context_.pkind == prop_kind::forward_training) ||
+        (context_.flags & use_global_stats)) {
+      context_.mean_mem->set_data_handle(
+          static_cast<void*>(const_cast<T*>(mean_data)));
+      context_.variance_mem->set_data_handle(
+          static_cast<void*>(const_cast<T*>(variance_data)));
+    }
+
+    // execution
+    context_.fwd_stream->submit(context_.fwd_primitives);
+
+    context_.src_mem->set_data_handle(DummyData);
+    context_.dst_mem->set_data_handle(DummyData);
+
+    if (context_.flags & use_scale_shift)
+      context_.weights_mem->set_data_handle(DummyData);
+
+    if ((context_.pkind == prop_kind::forward_training) ||
+        (context_.flags & use_global_stats)) {
+      context_.mean_mem->set_data_handle(DummyData);
+      context_.variance_mem->set_data_handle(DummyData);
+    }
+  }
+
+  memory::primitive_desc GetDstPd() const {
+    return (*context_.dst_mem).get_primitive_desc();
+  }
+
+  mkldnn_memory_format_t GetSrcFmt() const {
+    return (*context_.src_mem).get_primitive_desc().desc().data.format;
+  }
+
+  mkldnn_memory_format_t GetDstFmt() const {
+    return (*context_.dst_mem).get_primitive_desc().desc().data.format;
+  }
+
+ private:
+  // Primitive reuse context for the BatchNorm fwd op
+  struct BatchNormFwdContext {
+    // Flags to indicate whether it is training or inference mode
+    int64 flags;
+
+    // algorithm kind
+    mkldnn::prop_kind pkind;
+
+    // MKL-DNN memory
+    std::shared_ptr<mkldnn::memory> src_mem;
+    std::shared_ptr<mkldnn::memory> weights_mem;
+    std::shared_ptr<mkldnn::memory> dst_mem;
+    std::shared_ptr<mkldnn::memory> mean_mem;
+    std::shared_ptr<mkldnn::memory> variance_mem;
+
+    // BatchNorm forward primitive
+    std::shared_ptr<mkldnn::primitive> bn_fwd;
+    std::shared_ptr<mkldnn::stream> fwd_stream;
+    std::vector<mkldnn::primitive> fwd_primitives;
+
+    BatchNormFwdContext()
+        : flags(0),
+          pkind(mkldnn::forward_training),
+          src_mem(nullptr),
+          weights_mem(nullptr),
+          dst_mem(nullptr),
+          mean_mem(nullptr),
+          variance_mem(nullptr),
+          bn_fwd(nullptr),
+          fwd_stream(nullptr) {}
+  };
+
+  void Setup(const MklBatchNormFwdParams& fwdParams) {
+    context_.flags = fwdParams.training ? use_scale_shift
+                                        : (use_scale_shift | use_global_stats);
+    context_.pkind = fwdParams.training ? prop_kind::forward_training
+                                        : prop_kind::forward_scoring;
+
+    // memory desc
+    auto src_md = memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
+                               get_desired_format(fwdParams.src_dims[1]));
+
+    // fwd desc & primitive desc
+    auto fwd_desc = batch_normalization_forward::desc(
+        context_.pkind, src_md, fwdParams.eps, context_.flags);
+    auto fwd_pd = batch_normalization_forward::primitive_desc(
+        fwd_desc, cpu_engine_);
+
+    // memory primitives
+    context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData));
+    context_.dst_mem.reset(new memory(fwd_pd.dst_primitive_desc(), DummyData));
+
+    if (context_.flags & use_scale_shift) {
+      auto weights_desc = memory::desc({2, fwdParams.depth}, MklDnnType<T>(),
+                                       memory::format::nc);
+      context_.weights_mem.reset(
+          new memory({weights_desc, cpu_engine_}, DummyData));
+    }
+
+    if (fwdParams.training || (context_.flags & use_global_stats)) {
+      auto mean_desc = memory::desc({1, fwdParams.depth}, MklDnnType<T>(),
+                                    memory::format::nc);
+      context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData));
+
+      auto variance_desc = memory::desc({1, fwdParams.depth}, MklDnnType<T>(),
+                                        memory::format::nc);
+      context_.variance_mem.reset(
+          new memory({variance_desc, cpu_engine_}, DummyData));
+    }
+
+    // BatchNorm forward primitive
+    if (!fwdParams.training && !(context_.flags & use_global_stats)) {
+      if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) {
+        context_.bn_fwd.reset(new batch_normalization_forward(
+            fwd_pd, *context_.src_mem, *context_.weights_mem,
+            *context_.dst_mem));
+      } else {
+        context_.bn_fwd.reset(new batch_normalization_forward(
+            fwd_pd, *context_.src_mem, *context_.dst_mem));
+      }
+    } else if (context_.flags & use_global_stats) {
+      if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) {
+        context_.bn_fwd.reset(new batch_normalization_forward(
+            fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem,
+            (const primitive::at)*context_.variance_mem, *context_.weights_mem,
+            *context_.dst_mem));
+      } else {
+        context_.bn_fwd.reset(new batch_normalization_forward(
+            fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem,
+            (const primitive::at)*context_.variance_mem, *context_.dst_mem));
+      }
+    } else {
+      if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) {
+        context_.bn_fwd.reset(new batch_normalization_forward(
+            fwd_pd, *context_.src_mem, *context_.weights_mem,
+            *context_.dst_mem, *context_.mean_mem, *context_.variance_mem));
+      } else {
+        context_.bn_fwd.reset(new batch_normalization_forward(
+            fwd_pd, *context_.src_mem, *context_.dst_mem, *context_.mean_mem,
+            *context_.variance_mem));
+      }
+    }
+
+    context_.fwd_primitives.push_back(*context_.bn_fwd);
+  }
+
+  mkldnn::memory::desc get_desc_data(const mkldnn::memory& m) const {
+    return m.get_primitive_desc().desc().data;
+  }
+
+  struct BatchNormFwdContext context_;
+  engine cpu_engine_;
+};
+
+template <typename T>
+class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+  static MklFusedBatchNormFwdPrimitive<T>* Get(
+      const MklBatchNormFwdParams& fwdParams) {
+    auto bn_fwd = static_cast<MklFusedBatchNormFwdPrimitive<T>*>(
+        MklFusedBatchNormFwdPrimitiveFactory<T>::GetInstance().GetBatchNormFwd(
+            fwdParams));
+
+    if (bn_fwd == nullptr) {
+      bn_fwd = new MklFusedBatchNormFwdPrimitive<T>(fwdParams);
+      MklFusedBatchNormFwdPrimitiveFactory<T>::GetInstance().SetBatchNormFwd(
+          fwdParams, bn_fwd);
+    }
+    return bn_fwd;
+  }
+
+  static MklFusedBatchNormFwdPrimitiveFactory& GetInstance() {
+    static MklFusedBatchNormFwdPrimitiveFactory instance_;
+    return instance_;
+  }
+
+ private:
+  MklFusedBatchNormFwdPrimitiveFactory() {}
+  ~MklFusedBatchNormFwdPrimitiveFactory() {}
+
+  static std::string CreateKey(const MklBatchNormFwdParams& fwdParams) {
+    std::string prefix = "bn_fwd";
+    FactoryKeyCreator key_creator;
+    key_creator.AddAsKey(prefix);
+    key_creator.AddAsKey(fwdParams.src_dims);
+    key_creator.AddAsKey(fwdParams.depth);
+    key_creator.AddAsKey(fwdParams.eps);
+    key_creator.AddAsKey(fwdParams.training);
+    return key_creator.GetKey();
+  }
+
+  MklPrimitive* GetBatchNormFwd(const MklBatchNormFwdParams& fwdParams) {
+    std::string key = CreateKey(fwdParams);
+    return this->GetOp(key);
+  }
+
+  void SetBatchNormFwd(const MklBatchNormFwdParams& fwdParams,
+                       MklPrimitive* op) {
+    std::string key = CreateKey(fwdParams);
+    this->SetOp(key, op);
+  }
+};
+
+struct MklBatchNormBwdParams {
+  memory::dims src_dims;
+  memory::dims diff_dst_dims;
+  int depth;
+  float eps;
+  bool training;
+
+  MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims,
+                        int depth, float eps, bool training)
+      : src_dims(src_dims),
+        diff_dst_dims(diff_dst_dims),
+        depth(depth),
+        eps(eps),
+        training(training) {}
+};
+
+template <typename T>
+class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
+ public:
+  explicit MklFusedBatchNormBwdPrimitive(
+      const MklBatchNormBwdParams& bwdParams)
+      : cpu_engine_(engine::cpu, 0) {
+    context_.bwd_stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
+    if (context_.bn_bwd == nullptr) Setup(bwdParams);
+  }
+
+  ~MklFusedBatchNormBwdPrimitive() {}
+
+  // BatchNormalization backward execute
+  //   src_data:          input data buffer of src
+  //   mean_data:         input data buffer of mean
+  //   variance_data:     input data buffer of variance
+  //   diff_dst_data:     input data buffer of diff_dst
+  //   weights_data:      input data buffer of weights
+  //   diff_src_data:     output data buffer of diff_src
+  //   diff_weights_data: output data buffer of diff_weights
+  void Execute(const T* src_data, const T* mean_data, const T* variance_data,
+               const T* diff_dst_data, const T* weights_data,
+               const T* diff_src_data, const T* diff_weights_data) {
+    context_.src_mem->set_data_handle(
+        static_cast<void*>(const_cast<T*>(src_data)));
+    context_.mean_mem->set_data_handle(
+        static_cast<void*>(const_cast<T*>(mean_data)));
+    context_.variance_mem->set_data_handle(
+        static_cast<void*>(const_cast<T*>(variance_data)));
+    context_.diff_dst_mem->set_data_handle(
+        static_cast<void*>(const_cast<T*>(diff_dst_data)));
+
+    if (context_.flags & use_scale_shift) {
+      context_.weights_mem->set_data_handle(
+          static_cast<void*>(const_cast<T*>(weights_data)));
+      context_.diff_weights_mem->set_data_handle(
+          static_cast<void*>(const_cast<T*>(diff_weights_data)));
+    }
+
+    context_.diff_src_mem->set_data_handle(
+        static_cast<void*>(const_cast<T*>(diff_src_data)));
+
+    // execution
+    context_.bwd_stream->submit(context_.bwd_primitives);
+
+    context_.src_mem->set_data_handle(DummyData);
+    context_.mean_mem->set_data_handle(DummyData);
+    context_.variance_mem->set_data_handle(DummyData);
+    context_.diff_dst_mem->set_data_handle(DummyData);
+
+    if (context_.flags & use_scale_shift) {
+      context_.weights_mem->set_data_handle(DummyData);
+      context_.diff_weights_mem->set_data_handle(DummyData);
+    }
+
+    context_.diff_src_mem->set_data_handle(DummyData);
+  }
+
+  mkldnn_memory_format_t GetSrcFmt() {
+    return (*context_.src_mem).get_primitive_desc().desc().data.format;
+  }
+
+  mkldnn_memory_format_t GetDiffDstFmt() {
+    return (*context_.diff_dst_mem).get_primitive_desc().desc().data.format;
+  }
+
+  memory::primitive_desc GetDiffSrcPd() {
+    return (*context_.diff_src_mem).get_primitive_desc();
+  }
+
+ private:
+  struct BatchNormBwdContext {
+    // Flags to indicate whether it is training or inference mode
+    int64 flags;
+
+    // MKL-DNN memory
+    std::shared_ptr<mkldnn::memory> src_mem;
+    std::shared_ptr<mkldnn::memory> mean_mem;
+    std::shared_ptr<mkldnn::memory> variance_mem;
+    std::shared_ptr<mkldnn::memory> diff_dst_mem;
+    std::shared_ptr<mkldnn::memory> weights_mem;
+    std::shared_ptr<mkldnn::memory> diff_weights_mem;
+    std::shared_ptr<mkldnn::memory> diff_src_mem;
+
+    // BatchNorm backward primitive
+    std::shared_ptr<mkldnn::primitive> bn_bwd;
+    std::vector<mkldnn::primitive> bwd_primitives;
+    std::shared_ptr<mkldnn::stream> bwd_stream;
+
+    BatchNormBwdContext()
+        : src_mem(nullptr),
+          mean_mem(nullptr),
+          variance_mem(nullptr),
+          diff_dst_mem(nullptr),
+          weights_mem(nullptr),
+          diff_weights_mem(nullptr),
+          diff_src_mem(nullptr),
+          bwd_stream(nullptr) {}
+  };
+
+  void Setup(const MklBatchNormBwdParams& bwdParams) {
+    context_.flags = bwdParams.training ? use_scale_shift
+                                        : (use_scale_shift | use_global_stats);
+
+    // memory desc
+    auto src_md = memory::desc({bwdParams.src_dims}, MklDnnType<T>(),
+                               get_desired_format(bwdParams.src_dims[1]));
+    auto diff_dst_md = memory::desc({bwdParams.diff_dst_dims}, MklDnnType<T>(),
+                                    get_desired_format(bwdParams.diff_dst_dims[1]));
+    auto variance_desc = memory::desc({1, bwdParams.depth}, MklDnnType<T>(),
+                                      memory::format::nc);
+    auto mean_desc = memory::desc({1, bwdParams.depth}, MklDnnType<T>(),
+                                  memory::format::nc);
+    auto weights_desc = memory::desc({2, bwdParams.depth}, MklDnnType<T>(),
+                                     memory::format::nc);
+    auto diff_weights_desc = weights_desc;
+
+    // fwd desc & primitive desc
+    auto fwd_desc = batch_normalization_forward::desc(
+        prop_kind::forward_training, src_md, bwdParams.eps,
+        bwdParams.training ? use_scale_shift
+                           : (use_scale_shift | use_global_stats));
+    auto fwd_pd = batch_normalization_forward::primitive_desc(
+        fwd_desc, cpu_engine_);
+
+    // BatchNorm backward primitive
+    //
+    // For inference, specify use_global_stats:
+    //   1. on fwd propagation, use mean and variance provided as inputs.
+    //   2. on bwd propagation, mean and variance are considered as constants.
+    // This reduces the amount of MKL computation.
+    auto bwd_desc = batch_normalization_backward::desc(
+        prop_kind::backward, diff_dst_md, src_md, bwdParams.eps,
+        bwdParams.training ? use_scale_shift
+                           : (use_scale_shift | use_global_stats));
+    auto bn_bwd_pd = batch_normalization_backward::primitive_desc(
+        bwd_desc, cpu_engine_, fwd_pd);
+
+    // memory primitives
+    context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData));
+    context_.diff_dst_mem.reset(
+        new memory({diff_dst_md, cpu_engine_}, DummyData));
+    context_.variance_mem.reset(
+        new memory({variance_desc, cpu_engine_}, DummyData));
+    context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData));
+    context_.weights_mem.reset(
+        new memory({weights_desc, cpu_engine_}, DummyData));
+    context_.diff_weights_mem.reset(
+        new memory({diff_weights_desc, cpu_engine_}, DummyData));
+    context_.diff_src_mem.reset(new memory({src_md, cpu_engine_}, DummyData));
+
+    context_.bn_bwd.reset(new batch_normalization_backward(
+        bn_bwd_pd, *context_.src_mem, *context_.mean_mem,
+        *context_.variance_mem, *context_.diff_dst_mem, *context_.weights_mem,
+        *context_.diff_src_mem, *context_.diff_weights_mem));
+    context_.bwd_primitives.push_back(*context_.bn_bwd);
+  }
+
+  struct BatchNormBwdContext context_;
+  engine cpu_engine_;
+};
+
+template <typename T>
+class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+  static MklFusedBatchNormBwdPrimitive<T>* Get(
+      const MklBatchNormBwdParams& bwdParams) {
+    auto bn_bwd = static_cast<MklFusedBatchNormBwdPrimitive<T>*>(
+        MklFusedBatchNormBwdPrimitiveFactory<T>::GetInstance().GetBatchNormBwd(
+            bwdParams));
+    if (bn_bwd == nullptr) {
+      bn_bwd = new MklFusedBatchNormBwdPrimitive<T>(bwdParams);
+      MklFusedBatchNormBwdPrimitiveFactory<T>::GetInstance().SetBatchNormBwd(
+          bwdParams, bn_bwd);
+    }
+    return bn_bwd;
+  }
+
+  static MklFusedBatchNormBwdPrimitiveFactory& GetInstance() {
+    static MklFusedBatchNormBwdPrimitiveFactory instance_;
+    return instance_;
+  }
+
+ private:
+  MklFusedBatchNormBwdPrimitiveFactory() {}
+  ~MklFusedBatchNormBwdPrimitiveFactory() {}
+
+  static std::string CreateKey(const MklBatchNormBwdParams& bwdParams) {
+    std::string prefix = "bn_bwd";
+    FactoryKeyCreator key_creator;
+    key_creator.AddAsKey(prefix);
+    key_creator.AddAsKey(bwdParams.src_dims);
+    key_creator.AddAsKey(bwdParams.diff_dst_dims);
+    key_creator.AddAsKey(bwdParams.depth);
+    key_creator.AddAsKey(bwdParams.eps);
+    key_creator.AddAsKey(bwdParams.training);
+    return key_creator.GetKey();
+  }
+
+  MklPrimitive* GetBatchNormBwd(const MklBatchNormBwdParams& bwdParams) {
+    std::string key = CreateKey(bwdParams);
+    return this->GetOp(key);
+  }
+
+  void SetBatchNormBwd(const MklBatchNormBwdParams& bwdParams,
+                       MklPrimitive* op) {
+    std::string key = CreateKey(bwdParams);
+    this->SetOp(key, op);
+  }
+};
+
 template <typename T>
 class MklFusedBatchNormOp : public OpKernel {
  public:
@@ -701,7 +1163,6 @@ class MklFusedBatchNormOp : public OpKernel {
 
   void Compute(OpKernelContext* context) override {
     try {
-      auto cpu_engine = engine(engine::cpu, 0);
       const size_t kSrcIndex = 0;    // index of src input tensor
       const size_t kScaleIndex = 1;  // index of scale tensor
       const size_t kShiftIndex = 2;  // index of shift tensor
@@ -786,7 +1247,7 @@ class MklFusedBatchNormOp : public OpKernel {
       SetMeanVariance(est_mean_tensor, est_variance_tensor);
 
       MklDnnData<T> src(&cpu_engine);
-      MklDnnData<T> dst(&cpu_engine);
+      MklDnnData<T> weights(&cpu_engine);
 
       memory::format format_m;
       if (dnn_shape_src.IsMklTensor()) {
@@ -800,123 +1261,108 @@ class MklFusedBatchNormOp : public OpKernel {
       }
 
       // set src primitive
-      memory::dims src_dims;
-      if (dnn_shape_src.IsMklTensor()) {
-        src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(),
-                                             tensor_format_);
-      } else {
-        src_dims =
-            TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
-      }
+      memory::dims src_dims = dnn_shape_src.IsMklTensor()
+          ? dnn_shape_src.GetSizesAsMklDnnDims()
+          : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
       auto src_md = dnn_shape_src.IsMklTensor()
                         ? dnn_shape_src.GetMklLayout()
                         : memory::desc(src_dims, MklDnnType<T>(), format_m);
-      src.SetUsrMem(src_md, &src_tensor);
-
       // set weights primitive
       // MKL-DNN packs scale & shift as "weights":
       // <scale>...<scale><shift>...<shift>
-      auto weights_desc = memory::desc({2, static_cast<int>(depth_)},
-                                       MklDnnType<T>(), memory::format::nc);
-      auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine);
-      auto weights_m = memory(weights_pd);
-      T* weights_data = reinterpret_cast<T*>(weights_m.get_data_handle());
+      weights.AllocateBuffer(2 * depth_ * sizeof(T));
+      T* weights_data = reinterpret_cast<T*>(weights.GetAllocatedBuffer());
       T* scale_tf =
           reinterpret_cast<T*>(const_cast<T*>(scale_tensor.flat<T>().data()));
       T* shift_tf =
           reinterpret_cast<T*>(const_cast<T*>(shift_tensor.flat<T>().data()));
-      for (int k = 0; k < depth_; k++) {
-        weights_data[k] = scale_tf[k];
-        weights_data[k + depth_] = shift_tf[k];
-      }
-
-      // set mean primitive
-      auto mean_desc = memory::desc({1, static_cast<int>(depth_)},
-                                    MklDnnType<T>(), memory::format::nc);
-      auto mean_pd = memory::primitive_desc(mean_desc, cpu_engine);
+      std::memcpy(weights_data, scale_tf, depth_ * sizeof(T));
+      std::memcpy(weights_data + depth_, shift_tf, depth_ * sizeof(T));
 
       char* saved_mean_data_tf =
          reinterpret_cast<char*>(saved_mean_tensor->flat<T>().data());
       std::memcpy(saved_mean_data_tf, reinterpret_cast<char*>(mean_values_),
                   depth_ * sizeof(T));
-      auto mean_m =
-          memory(mean_pd, reinterpret_cast<void*>(saved_mean_data_tf));
-
-      // set variance primitive
-      auto variance_desc = memory::desc({1, static_cast<int>(depth_)},
-                                        MklDnnType<T>(), memory::format::nc);
-      auto variance_pd = memory::primitive_desc(variance_desc, cpu_engine);
       char* saved_variance_data_tf =
          reinterpret_cast<char*>(saved_variance_tensor->flat<T>().data());
       std::memcpy(saved_variance_data_tf,
                   reinterpret_cast<char*>(variance_values_),
                   depth_ * sizeof(T));
-      auto variance_m = memory(variance_pd, saved_variance_data_tf);
-
-      prop_kind pk = (is_training_) ? prop_kind::forward_training
-                                    : prop_kind::forward_scoring;
-      auto bnrm_fwd_desc = batch_normalization_forward::desc(
-          pk, src.GetUsrMemDesc(), epsilon_,
-          is_training_ ? use_scale_shift
-                       : (use_scale_shift | use_global_stats));
-      auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc(
-          bnrm_fwd_desc, cpu_engine);
-
-      // allocate dst tensor
-      MklDnnShape dnn_shape_dst;
-      TensorShape tf_shape_dst;
-      if (dnn_shape_src.IsMklTensor()) {
-        dnn_shape_dst.SetMklTensor(true);
-        auto dst_pd = bnrm_fwd_pd.dst_primitive_desc();
-        dnn_shape_dst.SetMklLayout(&dst_pd);
-        dnn_shape_dst.SetElemType(MklDnnType<T>());
-        dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(), src_dims,
-                                  format_m);
-        tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
-      } else {
-        dnn_shape_dst.SetMklTensor(false);
-        tf_shape_dst = src_tensor.shape();
-      }
-      AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst,
-                                dnn_shape_dst);
-      // Output of batchnorm has same shape as input.
-      dst.SetUsrMem(src_md, dst_tensor);
+      // get batchnorm op from the pool
+      MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_);
+      MklFusedBatchNormFwdPrimitive<T>* bn_fwd =
+          MklFusedBatchNormFwdPrimitiveFactory<T>::Get(fwdParams);
 
-      primitive bnrm_fwd_op;
-      if (is_training_) {
-        bnrm_fwd_op =
-            batch_normalization_forward(bnrm_fwd_pd, src.GetOpMem(), weights_m,
-                                        dst.GetOpMem(), mean_m, variance_m);
+      // check if reorder is needed for src, weights, mean, variance
+      std::vector<primitive> net;
+      T* src_data = nullptr;
+      if (src_md.data.format != bn_fwd->GetSrcFmt()) {
+        src.SetUsrMem(src_md, &src_tensor);
+        auto src_target = memory::primitive_desc({{src_dims}, MklDnnType<T>(),
+            static_cast<memory::format>(bn_fwd->GetSrcFmt())}, cpu_engine);
+        src.CheckReorderToOpMem(src_target, &net);
+        src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
       } else {
-        bnrm_fwd_op = batch_normalization_forward(
-            bnrm_fwd_pd, src.GetOpMem(), mean_m, variance_m,
-            (const primitive::at)weights_m, dst.GetOpMem());
+        src_data = static_cast<T*>(
+            const_cast<T*>(src_tensor.flat<T>().data()));
       }
-      std::vector<primitive> net;
-      net.push_back(bnrm_fwd_op);
       stream(stream::kind::eager).submit(net).wait();
 
+      // allocate output (dst) tensor; always set it as MKL-DNN layout
+      MklDnnShape dnn_shape_dst;
+      TensorShape tf_shape_dst;
+      dnn_shape_dst.SetMklTensor(true);
+      auto dst_pd = bn_fwd->GetDstPd();
+      dnn_shape_dst.SetMklLayout(&dst_pd);
+      dnn_shape_dst.SetElemType(MklDnnType<T>());
+      auto ndims = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetDimension()
+                                               : src_tensor.shape().dims();
+      dnn_shape_dst.SetTfLayout(ndims, src_dims, format_m);
+      tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
+      AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor,
+                                tf_shape_dst, dnn_shape_dst);
+
+      T* weights_op_data = weights_data;
+      T* mean_op_data = reinterpret_cast<T*>(
+          saved_mean_tensor->flat<T>().data());
+      T* variance_op_data = reinterpret_cast<T*>(
+          saved_variance_tensor->flat<T>().data());
+      T* dst_data = static_cast<T*>(dst_tensor->flat<T>().data());
+
+      // execution
+      bn_fwd->Execute(src_data, weights_op_data, dst_data,
+                      mean_op_data, variance_op_data);
+
       // copy batch_mean data
       T* batch_mean_data_tf =
          reinterpret_cast<T*>(batch_mean_tensor->flat<T>().data());
       std::memcpy(reinterpret_cast<char*>(batch_mean_data_tf),
-                  reinterpret_cast<char*>(mean_m.get_data_handle()),
+                  reinterpret_cast<char*>(saved_mean_data_tf),
                   depth_ * sizeof(T));
+      // TODO(yli135): OpMem is same as usr mem since its format is
+      // hard-coded as nc when the primitive is created.
 
       // copy batch_variance data with Bessel's correction
-      // if training mode is on
       float adjust_factor = 1.0;
       if (is_training_) {
         size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3];
         size_t adjust_size = orig_size - 1;
         adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
       }
-      for (int k = 0; k < depth_; k++)
-        batch_variance_tensor->flat<T>().data()[k] =
-            (reinterpret_cast<T*>(variance_m.get_data_handle()))[k] *
-            adjust_factor;
+
+      auto variance_data = reinterpret_cast<T*>(saved_variance_data_tf);
+      auto batch_variance_data = batch_variance_tensor->flat<T>().data();
+      if (is_training_) {
+        for (int k = 0; k < depth_; k++) {
+          batch_variance_data[k] = variance_data[k] * adjust_factor;
+        }
+      } else {
+        std::memcpy(batch_variance_data, variance_data, depth_ * sizeof(T));
+      }
     } catch (mkldnn::error& e) {
       string error_msg = "Status: " + std::to_string(e.status) +
                          ", message: " + string(e.message) + ", in file " +
@@ -933,7 +1379,8 @@ class MklFusedBatchNormOp : public OpKernel {
   bool is_training_;
   T* mean_values_;
   T* variance_values_;
-  int depth_;  // batch normalization is done for per channel.
+  size_t depth_;  // batch normalization is done per channel.
+  engine cpu_engine = engine(engine::cpu, 0);
 
   void ExtractParams(OpKernelContext* context) {
     const Tensor& input = MklGetInput(context, 0);
@@ -990,8 +1437,10 @@ class MklFusedBatchNormOp : public OpKernel {
                               tf_shape_scale, mkl_shape_batch_mean);
     CHECK_NOTNULL(*batch_mean_tensor);
     // set NAN mean value in case of empty input tensor
-    for (int k = 0; k < tf_shape_scale.num_elements(); k++)
-      (*batch_mean_tensor)->flat<T>().data()[k] = NAN;
+    int num_elements = tf_shape_scale.num_elements();
+    auto batch_mean_data = (*batch_mean_tensor)->flat<T>().data();
+    for (int k = 0; k < num_elements; k++)
+      batch_mean_data[k] = NAN;
 
     // allocate batch variance output tensor
     MklDnnShape mkl_shape_batch_variance;
@@ -1001,8 +1450,9 @@ class MklFusedBatchNormOp : public OpKernel {
                               mkl_shape_batch_variance);
     CHECK_NOTNULL(*batch_variance_tensor);
     // set NAN variance value in case of empty input tensor
-    for (int k = 0; k < tf_shape_scale.num_elements(); k++)
-      (*batch_variance_tensor)->flat<T>().data()[k] = NAN;
+    auto batch_variance_data = (*batch_variance_tensor)->flat<T>().data();
+    for (int k = 0; k < num_elements; k++)
+      batch_variance_data[k] = NAN;
 
     // Mean and variance (without Bessel's correction) saved for backward
     // computation to serve as pre-computed mean and variance.
@@ -1012,8 +1462,9 @@ class MklFusedBatchNormOp : public OpKernel {
                               tf_shape_scale, mkl_shape_saved_mean);
     CHECK_NOTNULL(*saved_mean_tensor);
     // set NAN mean value in case of empty input tensor
-    for (int k = 0; k < tf_shape_scale.num_elements(); k++)
-      (*saved_mean_tensor)->flat<T>().data()[k] = NAN;
+    auto saved_mean_data = (*saved_mean_tensor)->flat<T>().data();
+    for (int k = 0; k < num_elements; k++)
+      saved_mean_data[k] = NAN;
 
     MklDnnShape mkl_shape_saved_variance;
     mkl_shape_saved_variance.SetMklTensor(false);
@@ -1022,8 +1473,9 @@ class MklFusedBatchNormOp : public OpKernel {
                               mkl_shape_saved_variance);
     CHECK_NOTNULL(*saved_variance_tensor);
     // set NAN variance value in case of empty input tensor
-    for (int k = 0; k < tf_shape_scale.num_elements(); k++)
-      (*saved_variance_tensor)->flat<T>().data()[k] = NAN;
+    auto saved_variance_data = (*saved_variance_tensor)->flat<T>().data();
+    for (int k = 0; k < num_elements; k++)
+      saved_variance_data[k] = NAN;
   }
 };
 
@@ -1044,24 +1496,24 @@ class MklFusedBatchNormGradOp : public OpKernel {
 
   void Compute(OpKernelContext* context) override {
     try {
-      auto cpu_engine = engine(engine::cpu, 0);
       const size_t kDiffDstIndex = 0;   // index of diff_dst tensor
       const size_t kSrcIndex = 1;       // index of src input tensor
       const size_t kScaleIndex = 2;     // index of scale tensor
       const size_t kMeanIndex = 3;      // index of saved_mean tensor
       const size_t kVarianceIndex = 4;  // index of saved_variance tensor
+
       const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex);
       const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
       const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
       const Tensor& saved_mean_tensor = MklGetInput(context, kMeanIndex);
       const Tensor& saved_variance_tensor =
-          MklGetInput(context, kVarianceIndex);
+          MklGetInput(context, kVarianceIndex);
 
       MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
       GetMklShape(context, kSrcIndex, &dnn_shape_src);
       GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst);
-      TensorShape tf_shape_src, tf_shape_diff_dst;
 
+      TensorShape tf_shape_src, tf_shape_diff_dst;
       if (dnn_shape_diff_dst.IsMklTensor()) {
         tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape();
         OP_REQUIRES(
@@ -1102,6 +1554,7 @@ class MklFusedBatchNormGradOp : public OpKernel {
                     saved_variance_tensor.shape().DebugString()));
 
       Tensor* diff_src_tensor = nullptr;
+      // special case: input with 0 element and 0 batch size
       if (tf_shape_src.num_elements() == 0 ||
           tf_shape_diff_dst.num_elements() == 0) {
         HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
@@ -1117,174 +1570,114 @@ class MklFusedBatchNormGradOp : public OpKernel {
         ExtractParams(context);
       }
 
+      memory::format format_m;
+      if (dnn_shape_src.IsMklTensor()) {
+        if (dnn_shape_src.IsTensorInNCHWFormat())
+          format_m = memory::format::nchw;
+        else
+          format_m = memory::format::nhwc;
+      } else {
+        format_m = TFDataFormatToMklDnnDataFormat(tensor_format_);
+      }
+
       MklDnnData<T> src(&cpu_engine);
-      MklDnnData<T> mean(&cpu_engine);
-      MklDnnData<T> variance(&cpu_engine);
       MklDnnData<T> diff_dst(&cpu_engine);
-      MklDnnData<T> diff_src(&cpu_engine);
+      MklDnnData<T> weights(&cpu_engine);
+      MklDnnData<T> diff_weights(&cpu_engine);
+
+      memory::dims src_dims = dnn_shape_src.IsMklTensor()
+          ? dnn_shape_src.GetSizesAsMklDnnDims()
+          : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
+      memory::dims diff_dst_dims = dnn_shape_diff_dst.IsMklTensor()
+          ? dnn_shape_diff_dst.GetSizesAsMklDnnDims()
+          : TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), tensor_format_);
+
+      // set src and diff_dst primitive descriptors
+      memory::desc src_md = dnn_shape_src.IsMklTensor()
+          ? dnn_shape_src.GetMklLayout()
+          : memory::desc(src_dims, MklDnnType<T>(), format_m);
+      memory::desc diff_dst_md = dnn_shape_diff_dst.IsMklTensor()
+          ? dnn_shape_diff_dst.GetMklLayout()
+          : memory::desc(diff_dst_dims, MklDnnType<T>(), format_m);
+
+      // weights -- MKL-DNN packs scales/shifts as weights in order
+      // of scale, ..., scale, shift, ..., shift
+      weights.AllocateBuffer(2 * depth_ * sizeof(T));
+      T* weights_data_tf = reinterpret_cast<T*>(weights.GetAllocatedBuffer());
+      T* scale_tf =
+          reinterpret_cast<T*>(const_cast<T*>(scale_tensor.flat<T>().data()));
+      for (int k = 0; k < depth_; k++) {
+        weights_data_tf[k] = scale_tf[k];
+        weights_data_tf[k + depth_] = 0;
+      }
 
-      memory::dims src_dims, diff_dst_dims;
-      if (dnn_shape_src.IsMklTensor())
-        src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(),
-                                             tensor_format_);
-      else
-        src_dims =
-            TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
+      diff_weights.AllocateBuffer(2 * depth_ * sizeof(T));
 
-      if (dnn_shape_diff_dst.IsMklTensor())
-        diff_dst_dims = TFShapeToMklDnnDimsInNCHW(
-            dnn_shape_diff_dst.GetTfShape(), tensor_format_);
-      else
-        diff_dst_dims =
-            TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), tensor_format_);
+      MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims,
+                                      depth_, epsilon_, is_training_);
+      MklFusedBatchNormBwdPrimitive<T>* bn_bwd =
+          MklFusedBatchNormBwdPrimitiveFactory<T>::Get(bwdParams);
 
-      // set src and diff_dst primitives according to input layout
-      memory::desc src_md({}, memory::data_undef, memory::format_undef);
-      memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);
-      if (dnn_shape_src.IsMklTensor()) {
-        src_md = dnn_shape_src.GetMklLayout();
+      // check if src/diff_dst need to be reordered
+      std::vector<primitive> net;
+      T* src_data = nullptr;
+      if (src_md.data.format != bn_bwd->GetSrcFmt()) {
+        src.SetUsrMem(src_md, &src_tensor);
+        auto src_target = memory::primitive_desc({{src_dims}, MklDnnType<T>(),
+            static_cast<memory::format>(bn_bwd->GetSrcFmt())}, cpu_engine);
+        src.CheckReorderToOpMem(src_target, &net);
+        src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
       } else {
-        src_md = memory::desc(src_dims, MklDnnType<T>(),
-                              TFDataFormatToMklDnnDataFormat(tensor_format_));
+        src_data = static_cast<T*>(const_cast<T*>(
+            src_tensor.flat<T>().data()));
       }
-      if (dnn_shape_diff_dst.IsMklTensor()) {
-        diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
+
+      T* diff_dst_data = nullptr;
+      if (diff_dst_md.data.format != bn_bwd->GetDiffDstFmt()) {
+        diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+        auto diff_dst_target = memory::primitive_desc({{diff_dst_dims},
+            MklDnnType<T>(), static_cast<memory::format>(
+            bn_bwd->GetDiffDstFmt())}, cpu_engine);
+        diff_dst.CheckReorderToOpMem(diff_dst_target, &net);
+        diff_dst_data = static_cast<T*>(
+            diff_dst.GetOpMem().get_data_handle());
       } else {
-        diff_dst_md = memory::desc(diff_dst_dims, MklDnnType<T>(),
-                                   TFDataFormatToMklDnnDataFormat(tensor_format_));
+        diff_dst_data = static_cast<T*>(const_cast<T*>(
+            diff_dst_tensor.flat<T>().data()));
       }
-      src.SetUsrMem(src_md, &src_tensor);
-      diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
-
-      // weights -- DNN packs scales/shifts as weights in order of
-      // scale, ..., scale, shift, ..., shift
-      auto weights_desc =
-          memory::desc({2, depth_}, MklDnnType<T>(), memory::format::nc);
-      auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine);
-      auto weights_m =
-          memory(weights_pd);
-      T* weights_data = reinterpret_cast<T*>(weights_m.get_data_handle());
-      T* scale_tf =
-          reinterpret_cast<T*>(const_cast<T*>(scale_tensor.flat<T>().data()));
-      for (int k = 0; k < depth_; k++) {
-        weights_data[k] = scale_tf[k];
-        weights_data[k + depth_] = 0;
-      }
-
-      // set mean primitive
-      memory::dims mv_dims = GetMeanVarianceDims();
-      mean.SetUsrMem(mv_dims, memory::format::nc,
-                     const_cast<void*>(static_cast<const void*>(
-                         saved_mean_tensor.flat<T>().data())));
-      mean.SetOpMemDesc(mv_dims, memory::format::nc);
-
-      // set variance primitive
-      variance.SetUsrMem(mv_dims, memory::format::nc,
-                         const_cast<void*>(static_cast<const void*>(
-                             saved_variance_tensor.flat<T>().data())));
-      variance.SetOpMemDesc(mv_dims, memory::format::nc);
-
-      // set diff_weight primitive
-      auto diff_weights_desc =
-          memory::desc({2, depth_}, MklDnnType<T>(), memory::format::nc);
-      auto diff_weights_pd =
-          memory::primitive_desc(diff_weights_desc, cpu_engine);
-      auto diff_weights_m = memory(diff_weights_pd);
-
-      auto bnrm_fwd_desc = batch_normalization_forward::desc(
-          prop_kind::forward_training, src.GetUsrMemDesc(), epsilon_,
-          is_training_ ? use_scale_shift
-                       : (use_scale_shift | use_global_stats));
-      auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc(
-          bnrm_fwd_desc, cpu_engine);
+      stream(stream::kind::eager).submit(net).wait();
 
       // Indices of output tensors
       const size_t kDiffSrcIndex = 0;  // index of diff_src tensor
 
-      // allocate diff_src tensor
+      // allocate diff_src tensor, always set as MKL-DNN layout
       MklDnnShape dnn_shape_diff_src;
       TensorShape tf_shape_diff_src;
-
-      // MKL-DNN's BN primitive not provide API to fetch internal format
-      // set common_md as OpMem
-      // src and diff_dst will reorder to common_md
-      // diff_src will set as common_md
-      memory::desc common_md({}, memory::data_undef, memory::format_undef);
-      if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
-        if (dnn_shape_src.IsMklTensor()) {
-          common_md = dnn_shape_src.GetMklLayout();
-        } else {
-          common_md = dnn_shape_diff_dst.GetMklLayout();
-        }
-      } else {
-        common_md = memory::desc(src_dims, MklDnnType<T>(),
-                                 TFDataFormatToMklDnnDataFormat(tensor_format_));
-      }
-      // if any of src and diff_dst as mkl layout,
-      // then we set diff_src as mkl layout
-      if (dnn_shape_src.IsMklTensor() ||
-          dnn_shape_diff_dst.IsMklTensor()) {
-        dnn_shape_diff_src.SetMklTensor(true);
-        // set diff_src's mkl layout as common_md
-        auto diff_src_pd = memory::primitive_desc(common_md, cpu_engine);
-        dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
-        dnn_shape_diff_src.SetElemType(MklDnnType<T>());
-        if (dnn_shape_src.IsMklTensor()) {
-          dnn_shape_diff_src.SetTfLayout(
-              dnn_shape_src.GetDimension(),
-              src_dims,
-              dnn_shape_src.GetTfDataFormat());
-          dnn_shape_diff_src.SetTfDimOrder(
-              dnn_shape_src.GetDimension(),
-              tensor_format_);
-        } else {
-          dnn_shape_diff_src.SetTfLayout(
-              dnn_shape_diff_dst.GetDimension(),
-              src_dims,
-              dnn_shape_diff_dst.GetTfDataFormat());
-          dnn_shape_diff_src.SetTfDimOrder(
-              dnn_shape_diff_dst.GetDimension(),
-              tensor_format_);
-        }
-        tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
-      } else {
-        dnn_shape_diff_src.SetMklTensor(false);
-        // both src and diff_dst are TensorFlow layout,
-        // so it is OK to get TensorFlow shape.
-        tf_shape_diff_src = src_tensor.shape();
-      }
+      dnn_shape_diff_src.SetMklTensor(true);
+      auto diff_src_pd = bn_bwd->GetDiffSrcPd();
+      dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
+      dnn_shape_diff_src.SetElemType(MklDnnType<T>());
+      dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(), src_dims,
+                                     format_m);
+      dnn_shape_diff_src.SetTfDimOrder(dnn_shape_src.GetDimension(),
+                                       tensor_format_);
+      tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
       AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor,
                                 tf_shape_diff_src, dnn_shape_diff_src);
 
-      // set diff_src
-      diff_src.SetUsrMem(common_md, diff_src_tensor);
-
-      prop_kind pk = prop_kind::backward;
-      auto bnrm_bwd_desc = batch_normalization_backward::desc(
-          pk, common_md, common_md, epsilon_,
-          /* for inference, specify use_global_stats
-             1. on fwd prop, use mean and variance
-                provided as inputs
-             2. on bwd prop, mean and variance are
-                considered as constants. Thus,
-                reduce the amout of MKL computations
-          */
-          is_training_ ? use_scale_shift
-                       : (use_scale_shift | use_global_stats));
-      auto bnrm_bwd_pd = batch_normalization_backward::primitive_desc(
-          bnrm_bwd_desc, cpu_engine, bnrm_fwd_pd);
-
-      std::vector<primitive> net;
-      src.CheckReorderToOpMem(memory::primitive_desc(common_md,
-                                                     cpu_engine), &net);
-      diff_dst.CheckReorderToOpMem(memory::primitive_desc(common_md,
-                                                          cpu_engine), &net);
-
-      auto bnrm_bwd_op = batch_normalization_backward(
-          bnrm_bwd_pd, src.GetOpMem(), mean.GetOpMem(), variance.GetOpMem(),
-          diff_dst.GetOpMem(), weights_m, diff_src.GetOpMem(), diff_weights_m);
-      net.push_back(bnrm_bwd_op);
-      stream(stream::kind::eager).submit(net).wait();
+      T* mean_data = static_cast<T*>(const_cast<T*>(
+          saved_mean_tensor.flat<T>().data()));
+      T* variance_data = static_cast<T*>(const_cast<T*>(
+          saved_variance_tensor.flat<T>().data()));
+      T* weights_data = weights_data_tf;
+      T* diff_src_data = static_cast<T*>(
+          diff_src_tensor->flat<T>().data());
+      T* diff_weights_data = static_cast<T*>(
+          diff_weights.GetAllocatedBuffer());
+      // Execute
+      bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data,
+                      weights_data, diff_src_data, diff_weights_data);
 
       // allocate 4 output TF tensors
       Tensor* diff_scale_tensor = nullptr;
@@ -1293,13 +1686,14 @@ class MklFusedBatchNormGradOp : public OpKernel {
                        &diff_shift_tensor);
 
       // copy data: diff_scale and diff_shift
-      T* diff_weights_data_dnn =
-          reinterpret_cast<T*>(diff_weights_m.get_data_handle());
-      for (int i = 0; i < depth_; i++) {
-        diff_scale_tensor->flat<T>().data()[i] = diff_weights_data_dnn[i];
-        diff_shift_tensor->flat<T>().data()[i] =
-            diff_weights_data_dnn[i + depth_];
-      }
+      auto diff_scale_data = diff_scale_tensor->flat<T>().data();
+      auto diff_shift_data = diff_shift_tensor->flat<T>().data();
+      std::memcpy(reinterpret_cast<char*>(diff_scale_data),
+                  reinterpret_cast<char*>(diff_weights_data),
+                  depth_ * sizeof(T));
+      std::memcpy(reinterpret_cast<char*>(diff_shift_data),
+                  reinterpret_cast<T*>(diff_weights_data) + depth_,
+                  depth_ * sizeof(T));
     } catch (mkldnn::error& e) {
       string error_msg = "Status: " + std::to_string(e.status) +
                          ", message: " + string(e.message) + ", in file " +
@@ -1315,6 +1709,7 @@ class MklFusedBatchNormGradOp : public OpKernel {
   TensorFormat tensor_format_;
   int depth_;  // batch normalization is done for per channel.
   bool is_training_;
+  engine cpu_engine = engine(engine::cpu, 0);
 
   void ExtractParams(OpKernelContext* context) {
     const Tensor& input = MklGetInput(context, 0);
@@ -1330,8 +1725,10 @@ class MklFusedBatchNormGradOp : public OpKernel {
     dnn_shape_diff_src.SetMklTensor(false);
     AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor,
                               tf_shape_src, dnn_shape_diff_src);
-    for (size_t i = 0; i < (*diff_src_tensor)->shape().num_elements(); i++)
-      (*diff_src_tensor)->flat<T>().data()[i] = 0;
+    int num_elements = (*diff_src_tensor)->shape().num_elements();
+    auto diff_src_data = (*diff_src_tensor)->flat<T>().data();
+    for (size_t i = 0; i < num_elements; i++)
+      diff_src_data[i] = 0;
 
     Tensor* diff_scale_tensor = nullptr;
     Tensor* diff_shift_tensor = nullptr;
@@ -1357,16 +1754,20 @@ class MklFusedBatchNormGradOp : public OpKernel {
     AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor,
                               tf_shape_scale_shift, mkl_shape_diff_scale);
     CHECK_NOTNULL(*diff_scale_tensor);
-    for (size_t i = 0; i < (*diff_scale_tensor)->shape().num_elements(); i++)
-      (*diff_scale_tensor)->flat<T>().data()[i] = 0;
+    int diff_scale_num_elements = (*diff_scale_tensor)->shape().num_elements();
+    auto diff_scale_data = (*diff_scale_tensor)->flat<T>().data();
+    for (size_t i = 0; i < diff_scale_num_elements; i++)
+      diff_scale_data[i] = 0;
 
     MklDnnShape mkl_shape_diff_shift;
     mkl_shape_diff_shift.SetMklTensor(false);
     AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor,
                               tf_shape_scale_shift, mkl_shape_diff_shift);
     CHECK_NOTNULL(*diff_shift_tensor);
-    for (size_t i = 0; i < (*diff_shift_tensor)->shape().num_elements(); i++)
-      (*diff_shift_tensor)->flat<T>().data()[i] = 0;
+    int diff_shift_num_elements = (*diff_shift_tensor)->shape().num_elements();
+    auto diff_shift_data = (*diff_shift_tensor)->flat<T>().data();
+    for (size_t i = 0; i < diff_shift_num_elements; i++)
+      diff_shift_data[i] = 0;
 
     // Placeholders for estimated_mean and estimated_variance, which are
     // used for inference and thus not needed here for gradient computation.
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index bb447e03938024..efba2d89419e25 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -39,7 +39,7 @@ limitations under the License.
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
-
+#include "tensorflow/core/platform/cpu_info.h"
 #ifndef INTEL_MKL_ML
 #include "mkldnn.hpp"
 #include "tensorflow/core/lib/core/stringpiece.h"
@@ -1504,7 +1504,8 @@ class MklDnnData {
   /// Operations memory descriptor
   memory::desc* op_md_;
-
+  /// Operations temp buffer
+  void* allocated_buffer_;
   /// CPU engine on which operation will be executed
   const engine* cpu_engine_;
 
@@ -1513,6 +1514,7 @@ class MklDnnData {
       : user_memory_(nullptr),
         reorder_memory_(nullptr),
         op_md_(nullptr),
+        allocated_buffer_(nullptr),
         cpu_engine_(e) {}
 
   ~MklDnnData() {
@@ -1653,6 +1655,14 @@ class MklDnnData {
     user_memory_->set_data_handle(GetTensorBuffer(tensor));
   }
 
+  /// allocate function for data buffer
+  inline void AllocateBuffer(size_t size) {
+    allocated_buffer_ = cpu_allocator()->AllocateRaw(64, size);
+  }
+  inline void* GetAllocatedBuffer() {
+    return allocated_buffer_;
+  }
+
   /// Get the memory primitive for input and output of an op. If inputs
   /// to an op require reorders, then this function returns memory primitive
   /// for reorder. Otherwise, it will return memory primitive for user memory.
@@ -1957,6 +1967,21 @@ class FactoryKeyCreator {
   }
 };
 
+static inline memory::format get_desired_format(int channel) {
+  memory::format fmt_desired = memory::format::any;
+
+  if (port::TestCPUFeature(port::CPUFeature::AVX512F)
+      && (channel % 16) == 0) {
+    fmt_desired = memory::format::nChw16c;
+  } else if (port::TestCPUFeature(port::CPUFeature::AVX2)
+             && (channel % 8) == 0) {
+    fmt_desired = memory::format::nChw8c;
+  } else {
+    fmt_desired = memory::format::nchw;
+  }
+  return fmt_desired;
+}
+
 class MklReorderPrimitive : public MklPrimitive {
  public:
  explicit MklReorderPrimitive(const memory* from, const memory* to) {