From 69e6c0473e6ee40bed3ef21facc9b78b66b41e60 Mon Sep 17 00:00:00 2001 From: bgawrych Date: Wed, 19 Jan 2022 13:59:17 +0100 Subject: [PATCH] Optimize 'take' operator for CPU (#20745) * Improve performance of take operator * remove comment * Fix build * fix sanity * Add comment * review * Update src/operator/tensor/indexing_op.h Co-authored-by: bartekkuncer Co-authored-by: Sheng Zha Co-authored-by: bartekkuncer --- src/operator/tensor/indexing_op.cc | 99 +++++++++++++++++++++--------- src/operator/tensor/indexing_op.h | 5 +- 2 files changed, 72 insertions(+), 32 deletions(-) diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 30825415e481..28eca41b2ae4 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -60,6 +60,51 @@ struct TakeZeroAxisCPU { } }; +template +struct TakeNonzeroAxisCPU { + /*! + * \brief Map function for take operator + * \param i global thread id + * \param out_data ptr to output buffer + * \param in_data ptr to input buffer + * \param indices ptr to indices buffer + * \param outer_dim_stride stride of dimension before axis + * \param axis_dim_stride stride of axis dimension + * \param idx_size size of the indices tensor + * \param axis_dim dim size of the axis dimension + * \param axis axis id + */ + template + MSHADOW_XINLINE static void Map(index_t i, + DType* out_data, + const DType* in_data, + const IType* indices, + const index_t outer_dim_stride, + const index_t axis_dim_stride, + const int idx_size, + const int axis_dim, + const int axis) { + for (index_t j = 0; j < static_cast(idx_size); ++j) { + int index = indices[j]; + if (clip) { + index = std::max(index, 0); + index = std::min(axis_dim - 1, index); + } else { + index %= axis_dim; + index += (index < 0) ? 
axis_dim : 0; + } + size_t in_offset = i * outer_dim_stride + index * axis_dim_stride; + size_t out_offset = (i * idx_size + j) * axis_dim_stride; +#pragma GCC diagnostic push +#if __GNUC__ >= 8 +#pragma GCC diagnostic ignored "-Wclass-memaccess" +#endif + std::memcpy(out_data + out_offset, in_data + in_offset, axis_dim_stride * sizeof(DType)); +#pragma GCC diagnostic pop + } + } +}; + /* * \brief returns true if all indices are between [min, max] * \param data_ptr the indices to check @@ -323,6 +368,7 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { using namespace mxnet_op; + if (req[take_::kOut] == kNullOp) return; const TakeParam& param = nnvm::get(attrs.parsed); @@ -375,39 +421,32 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) { in_strides[i] = stride; } - mshadow::Shape<10> out_strides; - stride = 1; - for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) { - out_strides[i] = stride; + int outer_dimensions = 1; + for (int i = 0; i < actual_axis; i++) { + outer_dimensions *= oshape[i]; } if (param.mode == take_::kClip) { - Kernel, cpu>::Launch(s, - oshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - out_strides[actual_axis - 1], - in_strides[actual_axis - 1], - in_strides[actual_axis], - arrshape.ndim(), - oshape.ndim(), - idxshape.ndim(), - arrshape[actual_axis], - actual_axis); + Kernel, cpu>::Launch(s, + outer_dimensions, + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + in_strides[actual_axis - 1], + in_strides[actual_axis], + idxshape.Size(), + arrshape[actual_axis], + actual_axis); } else { - Kernel, cpu>::Launch(s, - oshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - out_strides[actual_axis - 1], - in_strides[actual_axis - 1], - in_strides[actual_axis], - 
arrshape.ndim(), - oshape.ndim(), - idxshape.ndim(), - arrshape[actual_axis], - actual_axis); + Kernel, cpu>::Launch(s, + outer_dimensions, + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + in_strides[actual_axis - 1], + in_strides[actual_axis], + idxshape.Size(), + arrshape[actual_axis], + actual_axis); } } }); diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 81a04aa24027..cd97be8dcfc0 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -215,8 +215,9 @@ inline bool EmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs, return dispatched; } -/*! \brief name the struct TakeNonzeroAxis for general take when - * axis is not zero, use TakeZeroAxisGPU or TakeZeroAxisCPU for axis zero +/*! \brief TakeNonzeroAxis is designated for general take when + * axis is not zero (for CPU optimized version use TakeNonzeroAxisCPU and + for axis zero use TakeZeroAxisGPU or TakeZeroAxisCPU) */ template struct TakeNonzeroAxis {