From 78d30c24e947eaafeea7e9fc9c778fd407dcbe12 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Fri, 5 Jun 2020 15:29:26 +0000 Subject: [PATCH 1/5] fixing broadcast_axis kernel to int32 --- src/operator/tensor/broadcast_reduce_op.h | 33 +++++++++++++++-------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 12af331eefb0..60edf1147bee 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -1085,7 +1085,7 @@ inline void PrepareAUXData(ShapeAndStride *aux_data, } } // unnamed namespace -template +template struct broadcast_kernel { template MSHADOW_XINLINE static void Map(index_t i, @@ -1094,15 +1094,16 @@ struct broadcast_kernel { const ShapeAndStride& aux_data, const OpReqType req, const int ndim) { - index_t idx = i; - index_t in_idx = i; + printf("size of IDXType=%d",sizeof(IDXType)); + IDXType idx = i; + IDXType in_idx = i; #pragma unroll 4 - for (int iter = ndim - 1; iter >= 0; --iter) { - index_t out_dim_shape = aux_data.output_shape[iter]; - index_t out_dim_stride = aux_data.out_stride[iter]; + for (IDXType iter = ndim - 1; iter >= 0; --iter) { + IDXType out_dim_shape = aux_data.output_shape[iter]; + IDXType out_dim_stride = aux_data.out_stride[iter]; // x % y = x - (x / y) * y // speeds up modulo(%) operation in GPU - index_t dim_idx = idx - (idx / out_dim_shape) * out_dim_shape; + IDXType dim_idx = idx - (idx / out_dim_shape) * out_dim_shape; if (aux_data.input_shape[iter] != 1) { in_idx += dim_idx * (aux_data.in_stride[iter] - out_dim_stride); } else { @@ -1147,16 +1148,26 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, outputs[0].get_with_shape(dst_shape.get<2>(), s); Tensor data = inputs[0].get_with_shape(src_shape.get<2>(), s); - Kernel, xpu>::Launch( - s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2); + if (ctx.run_ctx.get_ctx().dev_type == Context::kGPU) { + Kernel, xpu>::Launch( + s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2); + } else { + Kernel, xpu>::Launch( + s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2); + } } else { const int ndim = MXNET_SPECIAL_MAX_NDIM; Tensor out = outputs[0].get_with_shape(dst_shape.get(), s); Tensor data = inputs[0].get_with_shape(src_shape.get(), s); - Kernel, xpu>::Launch( - s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim); + if (ctx.run_ctx.get_ctx().dev_type == Context::kGPU) { + Kernel, xpu>::Launch( + s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim); + } else { + Kernel, xpu>::Launch( + s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2); + } } }); }); From ba9d26a0a8312212f8d9e1c73b68df069b4478f2 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Fri, 5 Jun 2020 15:30:02 +0000 Subject: [PATCH 2/5] fixing slice_axis kernel to int32 --- 3rdparty/mshadow/mshadow/extension/slice.h | 36 +++++++++++----------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/3rdparty/mshadow/mshadow/extension/slice.h b/3rdparty/mshadow/mshadow/extension/slice.h index 038818f03c3e..9ac71a0d6ff9 100644 --- a/3rdparty/mshadow/mshadow/extension/slice.h +++ b/3rdparty/mshadow/mshadow/extension/slice.h @@ -45,10 +45,10 @@ struct SliceExp : public TRValue { static const int dimslice = srcdim - dimsrc_m_slice; const SrcExp &src_; - index_t ch_begin_; - index_t ch_old_; + int ch_begin_; + int ch_old_; Shape shape_; - SliceExp(const SrcExp &src, index_t begin, 
index_t end) + SliceExp(const SrcExp &src, int begin, int end) : src_(src), ch_begin_(begin) { shape_ = ShapeCheck::Check(src_); ch_old_ = shape_[dimslice]; @@ -81,7 +81,7 @@ struct SliceExp : public TRValue inline SliceExp -slice(const TRValue &src, index_t begin, index_t end) { +slice(const TRValue &src, int begin, int end) { TypeCheckPass::kDim == srcdim> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return SliceExp(src.self(), begin, end); @@ -129,26 +129,26 @@ struct Plan, DType> { : src_(MakePlan(e.src_)), height_(e.shape_.ProdShape(dimslice + 1, srcdim - 1)), ch_begin_(e.ch_begin_), ch_old_(e.ch_old_), ch_(e.shape_[dimslice]) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - const index_t y = i % height_; + MSHADOW_XINLINE DType Eval(int i, int j) const { + const int y = i % height_; i /= height_; - const index_t c = i % ch_ + ch_begin_; - const index_t b = i / ch_; - const index_t x = j; + const int c = i % ch_ + ch_begin_; + const int b = i / ch_; + const int x = j; return src_.Eval((b * ch_old_ + c) * height_ + y, x); } - MSHADOW_XINLINE DType &REval(index_t i, index_t j) { - const index_t y = i % height_; + MSHADOW_XINLINE DType &REval(int i, int j) { + const int y = i % height_; i /= height_; - const index_t c = i % ch_ + ch_begin_; - const index_t b = i / ch_; - const index_t x = j; + const int c = i % ch_ + ch_begin_; + const int b = i / ch_; + const int x = j; return src_.REval((b * ch_old_ + c) * height_ + y, x); } private: Plan src_; - const index_t height_, ch_begin_, ch_old_, ch_; + const int height_, ch_begin_, ch_old_, ch_; }; // struct Plan template, DType> { explicit Plan(const SliceExp &e) : src_(MakePlan(e.src_)), ch_begin_(e.ch_begin_) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + MSHADOW_XINLINE DType Eval(int y, int x) const { return src_.Eval(y, x + ch_begin_); } - MSHADOW_XINLINE DType &REval(index_t y, index_t x) { + MSHADOW_XINLINE DType &REval(int y, int x) { return src_.REval(y, x + ch_begin_); } private: Plan src_; - const index_t ch_begin_; + const int ch_begin_; }; } // namespace expr } // namespace mshadow From e6e3bc704599258bc5a6764ac4dc7f93c768890e Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Mon, 20 Apr 2020 16:38:09 +0000 Subject: [PATCH 3/5] adding comments explaining code optimizations --- src/operator/tensor/broadcast_reduce_op.h | 112 +++++++++++++++++++++- 1 file changed, 111 insertions(+), 1 deletion(-) diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 60edf1147bee..bcbd868b1285 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -1049,6 +1049,7 @@ void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs, ReduceAxesBackwardUseInOutImpl(ctx, small, inputs, req, outputs); } +<<<<<<< HEAD namespace { // unnamed namespace to keep scope of the struct within the file struct ShapeAndStride { index_t in_stride[MXNET_SPECIAL_MAX_NDIM]; @@ -1087,12 +1088,24 @@ inline void PrepareAUXData(ShapeAndStride *aux_data, template struct broadcast_kernel { +======= +/** + * Changed the thread workload mapping from 1 + * thread/output element to 1 thread/input to be broadcasted + * This approach leverages vectorization when fastest varying + * index(stride=1) of the tensor is to be broadcasted. + * In other cases it simply performs better by better load balancing. + */ +template +struct broadcast_kernel_cpu { +>>>>>>> 9919eff50... 
adding comments explaining code optimizations template MSHADOW_XINLINE static void Map(index_t i, IType *input, OType *output, const ShapeAndStride& aux_data, const OpReqType req, +<<<<<<< HEAD const int ndim) { printf("size of IDXType=%d",sizeof(IDXType)); IDXType idx = i; @@ -1112,11 +1125,65 @@ struct broadcast_kernel { idx /= out_dim_shape; } KERNEL_ASSIGN(output[i], req, OP::Map(input[in_idx])); +======= + const uint32_t ndim, + const size_t *axes, + const size_t *out_stride, + const int num_broadcast_axes) { + index_t idx = i; + index_t init_off = 0; + for (int iter = ndim - 1; idx > 0 && iter >= 0; --iter) { + size_t dim_idx = idx % in_shape[iter]; + init_off += dim_idx * out_stride[iter]; + idx /= in_shape[iter]; + } + index_t stride_0, stride_1, stride_2; + // Each case is based on the number of axis to be broadcasted + // (1, 2 or 3) after merging axes. + switch (num_broadcast_axes) { + // when input shape is amogst one of the form + // [(x,1), (x,1,x), (1,x)] + // x can be any +ve number >=0 and they need not be equal to each other + case 1 : + stride_0 = out_stride[axes[0]]; + for (int l=0; l < out_shape[axes[0]]; l++) { + KERNEL_ASSIGN(output[init_off + l*stride_0], + req, OP::Map(input[i])); + } + break; + // when input shape is amogst one of the form + // [(x,1,x,1), (1,x,1,x), (x,1,x,1,x)] + // x can be any +ve number >1 or =0(the axis ) and they need not be equal to each other + case 2: + stride_1 = out_stride[axes[1]], stride_0 = out_stride[axes[0]]; + for (int k=0; k < out_shape[axes[1]]; k++) { + for (int l=0; l < out_shape[axes[0]]; l++) { + KERNEL_ASSIGN(output[init_off + k*stride_1 + l*stride_0], + req, OP::Map(input[i])); + } + } + break; + // when input shape is of the form [(1,x,1,x,1)] and + // x can be any +ve number >=0 and they need not be equal to each other + case 3: + stride_2 = out_stride[axes[2]], stride_1 = out_stride[axes[1]]; + stride_0 = out_stride[axes[0]]; + for (int j=0; j < out_shape[axes[2]]; j++) { + for (int k=0; k < out_shape[axes[1]]; k++) { + for (int l=0; l < out_shape[axes[0]]; l++) { + KERNEL_ASSIGN(output[init_off + j*stride_2 + k*stride_1 + l*stride_0], + req, OP::Map(input[i])); + } + } + } + break; + } +>>>>>>> 9919eff50... adding comments explaining code optimizations } }; template -inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, +inline void BroadcastComputeImplCpu(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, @@ -1126,6 +1193,10 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, using namespace mshadow::expr; using namespace mxnet_op; mxnet::TShape src_shape, dst_shape; + // combines 2 or more consecutive broadcast/non-broadcast axes together + // e.g. (3,4,1,1,5,1,6,7) (2,3,5) (5,10,9) -> (12,1,5,1,42) (1,3) (50, 9) + // and this is the new input for broadcast_kernel whose total + // num of dimensions cannot be greater than 5(throws an error otherwise). BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape); Stream *s = ctx.get_stream(); MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, { @@ -1141,13 +1212,41 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, out_shape[i] = 1; } } +<<<<<<< HEAD struct ShapeAndStride aux_data; PrepareAUXData(&aux_data, in_shape, out_shape, dst_shape.ndim()); if (dst_shape.ndim() == 2) { +======= + // axes: stores which axes in input is to broadcasted + // stride: stores offset corresponding to an index of output tensor. 
+ // It is calculated using shape of the output tensor. + size_t axes[dst_shape.ndim()], out_stride[dst_shape.ndim()]; + int iter = dst_shape.ndim() - 1, i = 0; + bool shape_changed = false; + out_stride[iter] = 1; + if (in_shape[iter] != dst_shape[iter]) { + axes[i++] = iter; + shape_changed = true; + } + --iter; + for (; iter >= 0; --iter) { + if (in_shape[iter] != dst_shape[iter]) { + axes[i++] = iter; + shape_changed = true; + } + out_stride[iter] = out_stride[iter+1] * dst_shape[iter+1]; + } + if (!shape_changed) { + // If no broadcast is required (i.e. input_shape == output_shape) + // then simply copy input to outout. + mxnet_op::copy(ctx.get_stream(), outputs[0], inputs[0]); + } else if (dst_shape.ndim() == 2) { +>>>>>>> 9919eff50... adding comments explaining code optimizations Tensor out = outputs[0].get_with_shape(dst_shape.get<2>(), s); Tensor data = inputs[0].get_with_shape(src_shape.get<2>(), s); +<<<<<<< HEAD if (ctx.run_ctx.get_ctx().dev_type == Context::kGPU) { Kernel, xpu>::Launch( s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2); @@ -1155,12 +1254,18 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, Kernel, xpu>::Launch( s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2); } +======= + Kernel, xpu>::Launch( + s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape, + out_shape, req[0], 2, axes, out_stride, 1); +>>>>>>> 9919eff50... adding comments explaining code optimizations } else { const int ndim = MXNET_SPECIAL_MAX_NDIM; Tensor out = outputs[0].get_with_shape(dst_shape.get(), s); Tensor data = inputs[0].get_with_shape(src_shape.get(), s); +<<<<<<< HEAD if (ctx.run_ctx.get_ctx().dev_type == Context::kGPU) { Kernel, xpu>::Launch( s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim); @@ -1168,6 +1273,11 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, Kernel, xpu>::Launch( s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2); } +======= + Kernel, xpu>::Launch( + s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape, + out_shape, req[0], ndim, axes, out_stride, i); +>>>>>>> 9919eff50... 
adding comments explaining code optimizations } }); }); From f68517bfd3d5a7db1c52d2134679e9a2b806700c Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Wed, 17 Jun 2020 01:49:10 +0000 Subject: [PATCH 4/5] combining CPU and GPU implementation method signatures and cleaned up code --- 3rdparty/mshadow/mshadow/extension/slice.h | 36 +-- src/operator/numpy/np_matmul_op-inl.h | 22 +- src/operator/tensor/broadcast_reduce_op.h | 260 +++++++++++---------- 3 files changed, 170 insertions(+), 148 deletions(-) diff --git a/3rdparty/mshadow/mshadow/extension/slice.h b/3rdparty/mshadow/mshadow/extension/slice.h index 9ac71a0d6ff9..038818f03c3e 100644 --- a/3rdparty/mshadow/mshadow/extension/slice.h +++ b/3rdparty/mshadow/mshadow/extension/slice.h @@ -45,10 +45,10 @@ struct SliceExp : public TRValue { static const int dimslice = srcdim - dimsrc_m_slice; const SrcExp &src_; - int ch_begin_; - int ch_old_; + index_t ch_begin_; + index_t ch_old_; Shape shape_; - SliceExp(const SrcExp &src, int begin, int end) + SliceExp(const SrcExp &src, index_t begin, index_t end) : src_(src), ch_begin_(begin) { shape_ = ShapeCheck::Check(src_); ch_old_ = shape_[dimslice]; @@ -81,7 +81,7 @@ struct SliceExp : public TRValue inline SliceExp -slice(const TRValue &src, int begin, int end) { +slice(const TRValue &src, index_t begin, index_t end) { TypeCheckPass::kDim == srcdim> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return SliceExp(src.self(), begin, end); @@ -129,26 +129,26 @@ struct Plan, DType> { : src_(MakePlan(e.src_)), height_(e.shape_.ProdShape(dimslice + 1, srcdim - 1)), ch_begin_(e.ch_begin_), ch_old_(e.ch_old_), ch_(e.shape_[dimslice]) {} - MSHADOW_XINLINE DType Eval(int i, int j) const { - const int y = i % height_; + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + const index_t y = i % height_; i /= height_; - const int c = i % ch_ + ch_begin_; - const int b = i / ch_; - const int x = j; + const index_t c = i % ch_ + ch_begin_; + const index_t b = i / ch_; + const index_t x = j; return src_.Eval((b * ch_old_ + c) * height_ + y, x); } - MSHADOW_XINLINE DType &REval(int i, int j) { - const int y = i % height_; + MSHADOW_XINLINE DType &REval(index_t i, index_t j) { + const index_t y = i % height_; i /= height_; - const int c = i % ch_ + ch_begin_; - const int b = i / ch_; - const int x = j; + const index_t c = i % ch_ + ch_begin_; + const index_t b = i / ch_; + const index_t x = j; return src_.REval((b * ch_old_ + c) * height_ + y, x); } private: Plan src_; - const int height_, ch_begin_, ch_old_, ch_; + const index_t height_, ch_begin_, ch_old_, ch_; }; // struct Plan template, DType> { explicit Plan(const SliceExp &e) : src_(MakePlan(e.src_)), ch_begin_(e.ch_begin_) {} - MSHADOW_XINLINE DType Eval(int y, int x) const { + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { return src_.Eval(y, x + ch_begin_); } - MSHADOW_XINLINE DType &REval(int y, int x) { + MSHADOW_XINLINE DType &REval(index_t y, index_t x) { return src_.REval(y, x + ch_begin_); } private: Plan src_; - const int ch_begin_; + const index_t ch_begin_; }; } // namespace expr } // namespace mshadow diff --git a/src/operator/numpy/np_matmul_op-inl.h b/src/operator/numpy/np_matmul_op-inl.h index c1f0eed4414f..973fe409ae48 100644 --- a/src/operator/numpy/np_matmul_op-inl.h +++ b/src/operator/numpy/np_matmul_op-inl.h @@ -138,6 +138,7 @@ inline void MatmulImpl(const OpContext& ctx, mshadow::Tensor workspace; mshadow::Tensor ans, mlhs, mrhs; mshadow::Stream *s = ctx.get_stream(); + bool isCPU = std::is_same::value; if 
(MatmulNeedBroadcast(a_shape, b_shape)) { // e.g. a.shape = (2, 3, 1, 4, 2) // b.shape = (5, 2, 4) @@ -160,12 +161,21 @@ inline void MatmulImpl(const OpContext& ctx, struct ShapeAndStride aux_data_a, aux_data_b; PrepareAUXData(&aux_data_a, k_a_shape, k_a_shape_bc, ndim); PrepareAUXData(&aux_data_b, k_b_shape, k_b_shape_bc, ndim); - Kernel, xpu>::Launch( - s, bc_size_a, input_a.dptr(), bc_a_ptr, - aux_data_a, OpReqType::kWriteTo, ndim); - Kernel, xpu>::Launch( - s, bc_size_b, input_b.dptr(), bc_b_ptr, - aux_data_b, OpReqType::kWriteTo, ndim); + if (isCPU) { + Kernel, xpu>::Launch( + s, input_a.Size(), input_a.dptr(), bc_a_ptr, + aux_data_a, OpReqType::kWriteTo, ndim); + Kernel, xpu>::Launch( + s, input_b.Size(), input_b.dptr(), bc_b_ptr, + aux_data_b, OpReqType::kWriteTo, ndim); + } else { + Kernel, xpu>::Launch( + s, bc_size_a, input_a.dptr(), bc_a_ptr, + aux_data_a, OpReqType::kWriteTo, ndim); + Kernel, xpu>::Launch( + s, bc_size_b, input_b.dptr(), bc_b_ptr, + aux_data_b, OpReqType::kWriteTo, ndim); + } }); }); ans = mshadow::Tensor(output.dptr(), diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index bcbd868b1285..c5222e08a590 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -25,6 +25,7 @@ #ifndef MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_ #define MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_ +#include #include #include #include @@ -1049,14 +1050,18 @@ void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs, ReduceAxesBackwardUseInOutImpl(ctx, small, inputs, req, outputs); } -<<<<<<< HEAD namespace { // unnamed namespace to keep scope of the struct within the file struct ShapeAndStride { index_t in_stride[MXNET_SPECIAL_MAX_NDIM]; index_t out_stride[MXNET_SPECIAL_MAX_NDIM]; index_t input_shape[MXNET_SPECIAL_MAX_NDIM]; index_t output_shape[MXNET_SPECIAL_MAX_NDIM]; + // axes: stores which axes in input is to broadcasted + index_t axes[MXNET_SPECIAL_MAX_NDIM]; + int num_broadcast_axes = -1; + bool shape_changed = false; }; +} // unnamed namespace /*! * \brief Calculates Stride of input and output tensor dimesnions @@ -1071,52 +1076,48 @@ inline void PrepareAUXData(ShapeAndStride *aux_data, mshadow::Shape in_shape, mshadow::Shape out_shape, int ndim) { - int iter = ndim - 1; + int iter = ndim - 1, i = 0; aux_data->out_stride[iter] = 1; aux_data->in_stride[iter] = 1; aux_data->input_shape[iter] = in_shape[iter]; aux_data->output_shape[iter] = out_shape[iter]; + if (in_shape[iter] != out_shape[iter]) { + aux_data->axes[i++] = iter; + aux_data->shape_changed = true; + } iter--; for (; iter >= 0; --iter) { aux_data->out_stride[iter] = aux_data->out_stride[iter + 1] * out_shape[iter + 1]; aux_data->in_stride[iter] = aux_data->in_stride[iter + 1] * in_shape[iter + 1]; aux_data->input_shape[iter] = in_shape[iter]; aux_data->output_shape[iter] = out_shape[iter]; + if (in_shape[iter] != out_shape[iter]) { + aux_data->axes[i++] = iter; + aux_data->shape_changed = true; + } } + aux_data->num_broadcast_axes = i; + assert(aux_data->num_broadcast_axes > -1 && aux_data->num_broadcast_axes < 4); } -} // unnamed namespace -template -struct broadcast_kernel { -======= -/** - * Changed the thread workload mapping from 1 - * thread/output element to 1 thread/input to be broadcasted - * This approach leverages vectorization when fastest varying - * index(stride=1) of the tensor is to be broadcasted. - * In other cases it simply performs better by better load balancing. 
- */ template -struct broadcast_kernel_cpu { ->>>>>>> 9919eff50... adding comments explaining code optimizations +struct broadcast_kernel_gpu { template MSHADOW_XINLINE static void Map(index_t i, IType *input, OType *output, const ShapeAndStride& aux_data, const OpReqType req, -<<<<<<< HEAD const int ndim) { - printf("size of IDXType=%d",sizeof(IDXType)); - IDXType idx = i; - IDXType in_idx = i; + index_t idx = i; + index_t in_idx = i; #pragma unroll 4 - for (IDXType iter = ndim - 1; iter >= 0; --iter) { - IDXType out_dim_shape = aux_data.output_shape[iter]; - IDXType out_dim_stride = aux_data.out_stride[iter]; + for (int iter = ndim - 1; iter >= 0; --iter) { + index_t out_dim_shape = aux_data.output_shape[iter]; + index_t out_dim_stride = aux_data.out_stride[iter]; // x % y = x - (x / y) * y // speeds up modulo(%) operation in GPU - IDXType dim_idx = idx - (idx / out_dim_shape) * out_dim_shape; + index_t dim_idx = idx - (idx / out_dim_shape) * out_dim_shape; if (aux_data.input_shape[iter] != 1) { in_idx += dim_idx * (aux_data.in_stride[iter] - out_dim_stride); } else { @@ -1125,65 +1126,108 @@ struct broadcast_kernel_cpu { idx /= out_dim_shape; } KERNEL_ASSIGN(output[i], req, OP::Map(input[in_idx])); -======= - const uint32_t ndim, - const size_t *axes, - const size_t *out_stride, - const int num_broadcast_axes) { - index_t idx = i; - index_t init_off = 0; - for (int iter = ndim - 1; idx > 0 && iter >= 0; --iter) { - size_t dim_idx = idx % in_shape[iter]; - init_off += dim_idx * out_stride[iter]; - idx /= in_shape[iter]; + } +}; + +/** + * Changed the thread workload mapping from 1 + * thread/output element to 1 thread/input to be broadcasted + * This approach leverages vectorization when fastest varying + * index(stride=1) of the tensor is to be broadcasted. + * In other cases it simply performs better by better load balancing. + */ +template +struct broadcast_kernel_cpu { + template + MSHADOW_XINLINE static void Map(index_t i, + IType *input, + OType *output, + const ShapeAndStride& aux_data, + const OpReqType req, + const int ndim) { + index_t idx = i; + index_t init_off = 0; + for (int iter = ndim - 1; idx > 0 && iter >= 0; --iter) { + size_t dim_idx = idx % aux_data.input_shape[iter]; + init_off += dim_idx * aux_data.out_stride[iter]; + idx /= aux_data.input_shape[iter]; + } + index_t stride_0, stride_1, stride_2; + // Each case is based on the number of axis to be broadcasted + // (1, 2 or 3) after merging axes. + switch (aux_data.num_broadcast_axes) { + // when input shape is one of the follwing forms + // (x_1,1) or (x_1,1,x_2) or (1,x_1) + // x_1, x_2 are size of the dimensions that are not to be broadcasted + // in case of (x_1,1) the system leverages vectorization but in other 2 + // the performance is improved due avoidance of duplicate stride calculations + // for each output location input[i] needs to be written to. + case 1 : + stride_0 = aux_data.out_stride[aux_data.axes[0]]; + for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) { + KERNEL_ASSIGN(output[init_off + l * stride_0], + req, OP::Map(input[i])); + } + break; + // when input shape is one of the follwing forms + // (x_1,1,x_2,1) or (1,x_1,1,x_2) or (x_1,1,x_2,1,x_3) + // x_1, x_2, x_3 are size of the dimensions that are not to be broadcasted + // in the inner most loop can be vectorized by compiler in outer loops + // the performance is improved due avoidance of duplicate stride calculations + // for each output location input[i] needs to be written to. 
+ case 2: + stride_1 = aux_data.out_stride[aux_data.axes[1]]; + stride_0 = aux_data.out_stride[aux_data.axes[0]]; + for (index_t k = 0; k < aux_data.output_shape[aux_data.axes[1]]; k++) { + for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) { + KERNEL_ASSIGN(output[init_off + k * stride_1 + l * stride_0], + req, OP::Map(input[i])); + } } - index_t stride_0, stride_1, stride_2; - // Each case is based on the number of axis to be broadcasted - // (1, 2 or 3) after merging axes. - switch (num_broadcast_axes) { - // when input shape is amogst one of the form - // [(x,1), (x,1,x), (1,x)] - // x can be any +ve number >=0 and they need not be equal to each other - case 1 : - stride_0 = out_stride[axes[0]]; - for (int l=0; l < out_shape[axes[0]]; l++) { - KERNEL_ASSIGN(output[init_off + l*stride_0], + break; + // when input shape is of the form (1,x_1,1,x_2,1) + // x_1, x_2 are size of the dimensions that are not to be broadcasted + // here the last axis which is [4] is the one where compiler can vectorize + // the code the outer 2 loops improve preformance by avoiding + // duplicate stride calculations + // for each output location input[i] needs to be written to. + case 3: + stride_2 = aux_data.out_stride[aux_data.axes[2]]; + stride_1 = aux_data.out_stride[aux_data.axes[1]]; + stride_0 = aux_data.out_stride[aux_data.axes[0]]; + for (index_t j = 0; j < aux_data.output_shape[aux_data.axes[2]]; j++) { + for (index_t k = 0; k < aux_data.output_shape[aux_data.axes[1]]; k++) { + for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) { + KERNEL_ASSIGN(output[init_off + j * stride_2 + k * stride_1 + l * stride_0], req, OP::Map(input[i])); } - break; - // when input shape is amogst one of the form - // [(x,1,x,1), (1,x,1,x), (x,1,x,1,x)] - // x can be any +ve number >1 or =0(the axis ) and they need not be equal to each other - case 2: - stride_1 = out_stride[axes[1]], stride_0 = out_stride[axes[0]]; - for (int k=0; k < out_shape[axes[1]]; k++) { - for (int l=0; l < out_shape[axes[0]]; l++) { - KERNEL_ASSIGN(output[init_off + k*stride_1 + l*stride_0], - req, OP::Map(input[i])); - } - } - break; - // when input shape is of the form [(1,x,1,x,1)] and - // x can be any +ve number >=0 and they need not be equal to each other - case 3: - stride_2 = out_stride[axes[2]], stride_1 = out_stride[axes[1]]; - stride_0 = out_stride[axes[0]]; - for (int j=0; j < out_shape[axes[2]]; j++) { - for (int k=0; k < out_shape[axes[1]]; k++) { - for (int l=0; l < out_shape[axes[0]]; l++) { - KERNEL_ASSIGN(output[init_off + j*stride_2 + k*stride_1 + l*stride_0], - req, OP::Map(input[i])); - } - } - } - break; + } } ->>>>>>> 9919eff50... adding comments explaining code optimizations + break; + } } }; +template +struct direct_copy { + template + MSHADOW_XINLINE static void Map(index_t i, + IType *input, + OType *output, + const OpReqType req) { + KERNEL_ASSIGN(output[i], req, OP::Map(input[i])); + } +}; + +/** + * When CPU context is used the no. of kernel launches are equal to + * the no. of input elements, this helps leverage vectorization when possible + * When GPU context is used no. of kernel launches are equal to + * the no. of output elements, this ensures coalesced memory writes to output + * and improves coalesced memory reads. 
+ */ template -inline void BroadcastComputeImplCpu(const nnvm::NodeAttrs& attrs, +inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, @@ -1194,11 +1238,13 @@ inline void BroadcastComputeImplCpu(const nnvm::NodeAttrs& attrs, using namespace mxnet_op; mxnet::TShape src_shape, dst_shape; // combines 2 or more consecutive broadcast/non-broadcast axes together - // e.g. (3,4,1,1,5,1,6,7) (2,3,5) (5,10,9) -> (12,1,5,1,42) (1,3) (50, 9) + // e.g. (3,4,1,1,5,1,6,7) (2,3,5) (5,10,9) -> (3*4,1*1,5,1,6*7) (1,3) (5*10, 9) + // -> (12,1,5,1,42) (1,3) (50, 9) // and this is the new input for broadcast_kernel whose total // num of dimensions cannot be greater than 5(throws an error otherwise). BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape); Stream *s = ctx.get_stream(); + bool isCPU = std::is_same::value; MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, { MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { mshadow::Shape in_shape; @@ -1212,72 +1258,38 @@ inline void BroadcastComputeImplCpu(const nnvm::NodeAttrs& attrs, out_shape[i] = 1; } } -<<<<<<< HEAD struct ShapeAndStride aux_data; PrepareAUXData(&aux_data, in_shape, out_shape, dst_shape.ndim()); - if (dst_shape.ndim() == 2) { -======= - // axes: stores which axes in input is to broadcasted - // stride: stores offset corresponding to an index of output tensor. - // It is calculated using shape of the output tensor. - size_t axes[dst_shape.ndim()], out_stride[dst_shape.ndim()]; - int iter = dst_shape.ndim() - 1, i = 0; - bool shape_changed = false; - out_stride[iter] = 1; - if (in_shape[iter] != dst_shape[iter]) { - axes[i++] = iter; - shape_changed = true; - } - --iter; - for (; iter >= 0; --iter) { - if (in_shape[iter] != dst_shape[iter]) { - axes[i++] = iter; - shape_changed = true; - } - out_stride[iter] = out_stride[iter+1] * dst_shape[iter+1]; - } - if (!shape_changed) { + if (!aux_data.shape_changed) { // If no broadcast is required (i.e. input_shape == output_shape) // then simply copy input to outout. - mxnet_op::copy(ctx.get_stream(), outputs[0], inputs[0]); + Kernel, xpu>::Launch( + s, outputs[0].Size(), inputs[0].dptr(), outputs[0].dptr(), req[0]); } else if (dst_shape.ndim() == 2) { ->>>>>>> 9919eff50... adding comments explaining code optimizations Tensor out = outputs[0].get_with_shape(dst_shape.get<2>(), s); Tensor data = inputs[0].get_with_shape(src_shape.get<2>(), s); -<<<<<<< HEAD - if (ctx.run_ctx.get_ctx().dev_type == Context::kGPU) { - Kernel, xpu>::Launch( - s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2); - } else { - Kernel, xpu>::Launch( + if (isCPU) { + Kernel, xpu>::Launch( + s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2); + } else { + Kernel, xpu>::Launch( s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2); - } -======= - Kernel, xpu>::Launch( - s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape, - out_shape, req[0], 2, axes, out_stride, 1); ->>>>>>> 9919eff50... 
adding comments explaining code optimizations + } } else { const int ndim = MXNET_SPECIAL_MAX_NDIM; Tensor out = outputs[0].get_with_shape(dst_shape.get(), s); Tensor data = inputs[0].get_with_shape(src_shape.get(), s); -<<<<<<< HEAD - if (ctx.run_ctx.get_ctx().dev_type == Context::kGPU) { - Kernel, xpu>::Launch( + if (isCPU) { + Kernel, xpu>::Launch( + s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim); + } else { + Kernel, xpu>::Launch( s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim); - } else { - Kernel, xpu>::Launch( - s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2); - } -======= - Kernel, xpu>::Launch( - s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape, - out_shape, req[0], ndim, axes, out_stride, i); ->>>>>>> 9919eff50... adding comments explaining code optimizations + } } }); }); From df8c232309e1cf6171d6d1351b8169c713be3f80 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Sat, 27 Jun 2020 01:04:39 +0000 Subject: [PATCH 5/5] adding new broadcast_axis to np_matmul --- src/operator/numpy/np_matmul_op-inl.h | 27 ++++++++++++++++++----- src/operator/tensor/broadcast_reduce_op.h | 2 +- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/operator/numpy/np_matmul_op-inl.h b/src/operator/numpy/np_matmul_op-inl.h index 973fe409ae48..8f1b4f9f3a30 100644 --- a/src/operator/numpy/np_matmul_op-inl.h +++ b/src/operator/numpy/np_matmul_op-inl.h @@ -139,6 +139,7 @@ inline void MatmulImpl(const OpContext& ctx, mshadow::Tensor ans, mlhs, mrhs; mshadow::Stream *s = ctx.get_stream(); bool isCPU = std::is_same::value; + // Is true if either a or b requires broadcast or not if (MatmulNeedBroadcast(a_shape, b_shape)) { // e.g. a.shape = (2, 3, 1, 4, 2) // b.shape = (5, 2, 4) @@ -162,12 +163,26 @@ inline void MatmulImpl(const OpContext& ctx, PrepareAUXData(&aux_data_a, k_a_shape, k_a_shape_bc, ndim); PrepareAUXData(&aux_data_b, k_b_shape, k_b_shape_bc, ndim); if (isCPU) { - Kernel, xpu>::Launch( - s, input_a.Size(), input_a.dptr(), bc_a_ptr, - aux_data_a, OpReqType::kWriteTo, ndim); - Kernel, xpu>::Launch( - s, input_b.Size(), input_b.dptr(), bc_b_ptr, - aux_data_b, OpReqType::kWriteTo, ndim); + if (!aux_data_a.shape_changed) { + Kernel, xpu>::Launch( + s, bc_size_a, input_a.dptr(), bc_a_ptr, OpReqType::kWriteTo); + Kernel, xpu>::Launch( + s, input_b.Size(), input_b.dptr(), bc_b_ptr, + aux_data_b, OpReqType::kWriteTo, ndim); + } else if (!aux_data_b.shape_changed) { + Kernel, xpu>::Launch( + s, bc_size_b, input_b.dptr(), bc_b_ptr, OpReqType::kWriteTo); + Kernel, xpu>::Launch( + s, input_a.Size(), input_a.dptr(), bc_a_ptr, + aux_data_a, OpReqType::kWriteTo, ndim); + } else { + Kernel, xpu>::Launch( + s, input_a.Size(), input_a.dptr(), bc_a_ptr, + aux_data_a, OpReqType::kWriteTo, ndim); + Kernel, xpu>::Launch( + s, input_b.Size(), input_b.dptr(), bc_b_ptr, + aux_data_b, OpReqType::kWriteTo, ndim); + } } else { Kernel, xpu>::Launch( s, bc_size_a, input_a.dptr(), bc_a_ptr, diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index c5222e08a590..bd2af77c5f9d 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -1156,7 +1156,7 @@ struct broadcast_kernel_cpu { // Each case is based on the number of axis to be broadcasted // (1, 2 or 3) after merging axes. 
switch (aux_data.num_broadcast_axes) { - // when input shape is one of the follwing forms + // when input shape is one of the following forms // (x_1,1) or (x_1,1,x_2) or (1,x_1) // x_1, x_2 are size of the dimensions that are not to be broadcasted // in case of (x_1,1) the system leverages vectorization but in other 2
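
To make the new CPU mapping in broadcast_kernel_cpu concrete: one loop iteration per input element, each of which is then written to every output position it broadcasts to, with the innermost stride-1 loop left for the compiler to vectorize. The following is a minimal standalone sketch of that idea only; the shapes, variable names and the single-broadcast-axis case are illustrative and not code from the patch.

// Standalone sketch of the CPU mapping: one "thread" (iteration) per
// *input* element, streamed across its broadcast axis in the output.
#include <cstdio>
#include <vector>

int main() {
  // Input shape (3, 1) broadcast to output shape (3, 4);
  // axis 1 is the single broadcast axis after merging.
  const int in_rows = 3, out_cols = 4;
  std::vector<float> input = {1.f, 2.f, 3.f};
  std::vector<float> output(in_rows * out_cols, 0.f);

  const int out_stride_row = out_cols;  // stride of axis 0 in the output
  const int out_stride_col = 1;         // stride of the broadcast axis

  for (int i = 0; i < in_rows; ++i) {   // one iteration per input element
    // init_off: offset of the first output element this input maps to.
    const int init_off = i * out_stride_row;
    // Inner loop over the broadcast axis; stride 1 lets the compiler vectorize.
    for (int l = 0; l < out_cols; ++l) {
      output[init_off + l * out_stride_col] = input[i];
    }
  }

  // Prints one row of 1s, 2s and 3s each: every input element is read once.
  for (int r = 0; r < in_rows; ++r) {
    for (int c = 0; c < out_cols; ++c) printf("%.0f ", output[r * out_cols + c]);
    printf("\n");
  }
  return 0;
}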
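For contrast, a sketch of the per-output-element mapping that the patch keeps for the GPU path (broadcast_kernel_gpu): every flat output index is decomposed dimension by dimension to locate its source input index, and the x % y == x - (x / y) * y identity from the kernel replaces the modulo. Plain C++ with hard-coded 2-D shapes is used purely for illustration; the real kernel runs through mshadow's Kernel launch on GPU.

// Standalone sketch of the GPU-style mapping: one iteration per *output*
// element; strides on broadcast (size-1) input axes contribute nothing.
#include <cstdio>
#include <vector>

int main() {
  const int ndim = 2;
  const int in_shape[ndim]   = {3, 1};  // input  (3, 1)
  const int out_shape[ndim]  = {3, 4};  // output (3, 4)
  const int in_stride[ndim]  = {1, 1};  // row-major strides of the input
  const int out_stride[ndim] = {4, 1};  // row-major strides of the output

  std::vector<float> input = {1.f, 2.f, 3.f};
  std::vector<float> output(12, 0.f);

  for (int i = 0; i < static_cast<int>(output.size()); ++i) {
    int idx = i, in_idx = i;
    for (int iter = ndim - 1; iter >= 0; --iter) {
      const int dim = out_shape[iter];
      const int dim_idx = idx - (idx / dim) * dim;  // == idx % dim
      if (in_shape[iter] != 1) {
        in_idx += dim_idx * (in_stride[iter] - out_stride[iter]);
      } else {
        in_idx -= dim_idx * out_stride[iter];  // broadcast axis: drop this term
      }
      idx /= dim;
    }
    output[i] = input[in_idx];
  }

  for (int r = 0; r < 3; ++r) {
    for (int c = 0; c < 4; ++c) printf("%.0f ", output[r * 4 + c]);
    printf("\n");
  }
  return 0;
}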
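Both kernels operate on shapes that BroadcastReduceShapeCompact has already compacted, per the comment in BroadcastComputeImpl about combining consecutive broadcast/non-broadcast axes. Below is a rough sketch of that merging idea, under the assumption that axes are grouped purely by whether the input extent is 1; the real BroadcastReduceShapeCompact lives elsewhere in MXNet and handles more cases.

// Sketch only: fold runs of consecutive axes of the same kind
// (all broadcast or all non-broadcast) into single merged axes.
#include <cstdio>
#include <vector>

int main() {
  // input (3,4,1,1,5) broadcast to output (3,4,6,7,5)
  std::vector<long> in  = {3, 4, 1, 1, 5};
  std::vector<long> out = {3, 4, 6, 7, 5};

  std::vector<long> merged_in, merged_out;
  for (size_t i = 0; i < in.size(); ++i) {
    const bool bcast = (in[i] == 1 && out[i] != 1);
    const bool prev_bcast =
        !merged_in.empty() && merged_in.back() == 1 && merged_out.back() != 1;
    if (!merged_in.empty() && bcast == prev_bcast) {
      merged_in.back()  *= in[i];   // same kind of axis: fold into previous
      merged_out.back() *= out[i];
    } else {
      merged_in.push_back(in[i]);
      merged_out.push_back(out[i]);
    }
  }

  // Prints: 12 1 5 -> 12 42 5  (only one merged broadcast axis remains)
  for (long d : merged_in)  printf("%ld ", d);
  printf("-> ");
  for (long d : merged_out) printf("%ld ", d);
  printf("\n");
  return 0;
}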