Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
adding new broadcast_axis to np_matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohit Kumar Srivastava committed Jun 27, 2020
1 parent f68517b commit df8c232
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
27 changes: 21 additions & 6 deletions src/operator/numpy/np_matmul_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ inline void MatmulImpl(const OpContext& ctx,
mshadow::Tensor<xpu, 3, DType> ans, mlhs, mrhs;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
bool isCPU = std::is_same<xpu, cpu>::value;
// Is true if either a or b requires broadcast or not
if (MatmulNeedBroadcast(a_shape, b_shape)) {
// e.g. a.shape = (2, 3, 1, 4, 2)
// b.shape = (5, 2, 4)
Expand All @@ -162,12 +163,26 @@ inline void MatmulImpl(const OpContext& ctx,
PrepareAUXData(&aux_data_a, k_a_shape, k_a_shape_bc, ndim);
PrepareAUXData(&aux_data_b, k_b_shape, k_b_shape_bc, ndim);
if (isCPU) {
Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
s, input_a.Size(), input_a.dptr<IType>(), bc_a_ptr,
aux_data_a, OpReqType::kWriteTo, ndim);
Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
s, input_b.Size(), input_b.dptr<IType>(), bc_b_ptr,
aux_data_b, OpReqType::kWriteTo, ndim);
if (!aux_data_a.shape_changed) {
Kernel<direct_copy<mshadow_op::identity>, xpu>::Launch(
s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr, OpReqType::kWriteTo);
Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
s, input_b.Size(), input_b.dptr<IType>(), bc_b_ptr,
aux_data_b, OpReqType::kWriteTo, ndim);
} else if (!aux_data_b.shape_changed) {
Kernel<direct_copy<mshadow_op::identity>, xpu>::Launch(
s, bc_size_b, input_b.dptr<IType>(), bc_b_ptr, OpReqType::kWriteTo);
Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
s, input_a.Size(), input_a.dptr<IType>(), bc_a_ptr,
aux_data_a, OpReqType::kWriteTo, ndim);
} else {
Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
s, input_a.Size(), input_a.dptr<IType>(), bc_a_ptr,
aux_data_a, OpReqType::kWriteTo, ndim);
Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
s, input_b.Size(), input_b.dptr<IType>(), bc_b_ptr,
aux_data_b, OpReqType::kWriteTo, ndim);
}
} else {
Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr,
Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -1156,7 +1156,7 @@ struct broadcast_kernel_cpu {
// Each case is based on the number of axis to be broadcasted
// (1, 2 or 3) after merging axes.
switch (aux_data.num_broadcast_axes) {
// when input shape is one of the follwing forms
// when input shape is one of the following forms
// (x_1,1) or (x_1,1,x_2) or (1,x_1)
// x_1, x_2 are size of the dimensions that are not to be broadcasted
// in case of (x_1,1) the system leverages vectorization but in other 2
Expand Down

0 comments on commit df8c232

Please sign in to comment.