diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc index 96893959050a..1dd2dc31ee0c 100644 --- a/src/operator/nn/mkldnn/mkldnn_concat.cc +++ b/src/operator/nn/mkldnn/mkldnn_concat.cc @@ -29,8 +29,15 @@ namespace mxnet { namespace op { -MKLDNNConcatFwd::MKLDNNConcatFwd( - int concat_dim, const std::vector &data_md) +static inline bool IsUsingPadding(const mkldnn::memory::desc &dst_md) { + // make sure a blocked format is used (at least one dimension is blocked) + bool is_blocked_format = dst_md.data.format_kind == mkldnn_blocked && + dst_md.data.format_desc.blocking.inner_nblks > 0; + return is_blocked_format && !std::equal(dst_md.data.dims, dst_md.data.dims + dst_md.data.ndims, + dst_md.data.padded_dims); +} + +MKLDNNConcatFwd::MKLDNNConcatFwd(int concat_dim, const std::vector &data_md) : fwd_pd(concat_dim, data_md, CpuEngine::Get()->get_engine()) { // MKL-DNN introduced padded formats since 0.15 which require more memory // compared to the actual size of the tensor. Currently, MKL-DNN operators @@ -39,14 +46,10 @@ MKLDNNConcatFwd::MKLDNNConcatFwd( // When fwd_pd uses padding, impose a plain format const auto &dst_md = fwd_pd.dst_desc(); - if (dst_md.data.format_kind == mkldnn_blocked && - dst_md.data.format_desc.blocking.inner_nblks > 0 && - !std::equal(dst_md.data.dims, dst_md.data.dims + dst_md.data.ndims, - dst_md.data.padded_dims)) { + if (IsUsingPadding(dst_md)) { auto plain_dst_tag = static_cast( GetDefaultFormat(dst_md.data.ndims)); - auto plain_dst_md = - mkldnn::memory::desc(dst_md.dims(), dst_md.data_type(), plain_dst_tag); + auto plain_dst_md = mkldnn::memory::desc(dst_md.dims(), dst_md.data_type(), plain_dst_tag); fwd_pd = mkldnn::concat::primitive_desc(plain_dst_md, concat_dim, data_md, CpuEngine::Get()->get_engine()); }