Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Move the padding usage check to a separate function
Browse files Browse the repository at this point in the history
  • Loading branch information
PawelGlomski-Intel committed Jan 11, 2021
1 parent 288dceb commit 76018df
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions src/operator/nn/mkldnn/mkldnn_concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,15 @@
namespace mxnet {
namespace op {

MKLDNNConcatFwd::MKLDNNConcatFwd(
int concat_dim, const std::vector<mkldnn::memory::desc> &data_md)
static inline bool IsUsingPadding(const mkldnn::memory::desc &dst_md) {
// make sure a blocked format is used (at least one dimension is blocked)
bool is_blocked_format = dst_md.data.format_kind == mkldnn_blocked &&
dst_md.data.format_desc.blocking.inner_nblks > 0;
return is_blocked_format && !std::equal(dst_md.data.dims, dst_md.data.dims + dst_md.data.ndims,
dst_md.data.padded_dims);
}

MKLDNNConcatFwd::MKLDNNConcatFwd(int concat_dim, const std::vector<mkldnn::memory::desc> &data_md)
: fwd_pd(concat_dim, data_md, CpuEngine::Get()->get_engine()) {
// MKL-DNN introduced padded formats since 0.15 which require more memory
// compared to the actual size of the tensor. Currently, MKL-DNN operators
Expand All @@ -39,14 +46,10 @@ MKLDNNConcatFwd::MKLDNNConcatFwd(

// When fwd_pd uses padding, impose a plain format
const auto &dst_md = fwd_pd.dst_desc();
if (dst_md.data.format_kind == mkldnn_blocked &&
dst_md.data.format_desc.blocking.inner_nblks > 0 &&
!std::equal(dst_md.data.dims, dst_md.data.dims + dst_md.data.ndims,
dst_md.data.padded_dims)) {
if (IsUsingPadding(dst_md)) {
auto plain_dst_tag = static_cast<mkldnn::memory::format_tag>(
GetDefaultFormat(dst_md.data.ndims));
auto plain_dst_md =
mkldnn::memory::desc(dst_md.dims(), dst_md.data_type(), plain_dst_tag);
auto plain_dst_md = mkldnn::memory::desc(dst_md.dims(), dst_md.data_type(), plain_dst_tag);
fwd_pd = mkldnn::concat::primitive_desc(plain_dst_md, concat_dim, data_md,
CpuEngine::Get()->get_engine());
}
Expand Down

0 comments on commit 76018df

Please sign in to comment.