
Commit e766ecc
Merge branch 'v1.x' into 1x_softmax_empty
szha committed Jul 29, 2020 · 2 parents e9a342c + ca6bcf3
Showing 6 changed files with 393 additions and 26 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -252,6 +252,7 @@ List of Contributors
* [Oliver Kowalke](https://github.com/olk)
* [Connor Goggins](https://github.com/connorgoggins)
* [Joe Evans](https://github.com/josephevans)
* [Zhaoqi Zhu](https://github.com/zha0q1)

Label Bot
---------
13 changes: 13 additions & 0 deletions python/mxnet/test_utils.py
@@ -297,6 +297,19 @@ def create_vector(size, dtype=np.int64):
    a = mx.nd.arange(0, size, dtype=dtype)
    return a

# For testing a large square matrix with total size > 2^32 elements
def get_identity_mat(size):
    A = mx.nd.zeros((size, size))
    for i in range(size):
        A[i, i] = 1
    return A

# For testing a batch of large square matrices with total size > 2^32 elements
def get_identity_mat_batch(size):
    A = get_identity_mat(size)
    A_np = A.asnumpy()
    return mx.nd.array([A_np, A_np])
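
For illustration only (not part of this commit), a minimal sketch of how these helpers might be exercised; the small default size keeps the sketch cheap to run, whereas the large-tensor tests these helpers target use sizes whose square exceeds 2^32. The test name and assertions are assumptions, not code from the repository.

import numpy as np
from mxnet.test_utils import get_identity_mat, get_identity_mat_batch

def check_identity_helpers(size=8):
    A = get_identity_mat(size).asnumpy()      # dense (size, size) identity matrix
    assert np.array_equal(A, np.eye(size))
    B = get_identity_mat_batch(size)          # batch of two identical identity matrices
    assert B.shape == (2, size, size)
    assert np.array_equal(B.asnumpy()[0], np.eye(size))

check_identity_helpers()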

def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None,
                        data_init=None, rsp_indices=None, modifier_func=None,
                        shuffle_csr_indices=False, ctx=None):
40 changes: 34 additions & 6 deletions src/operator/numpy/np_matmul_op-inl.h
@@ -138,6 +138,8 @@ inline void MatmulImpl(const OpContext& ctx,
mshadow::Tensor<xpu, 1, DType*> workspace;
mshadow::Tensor<xpu, 3, DType> ans, mlhs, mrhs;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
bool isCPU = std::is_same<xpu, cpu>::value;
// True if either a or b requires broadcasting
if (MatmulNeedBroadcast(a_shape, b_shape)) {
// e.g. a.shape = (2, 3, 1, 4, 2)
// b.shape = (5, 2, 4)
@@ -157,12 +159,38 @@
DType* bc_b_ptr = bc_a_ptr + bc_size_a;
MSHADOW_TYPE_SWITCH_WITH_BOOL(input_a.type_flag_, IType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(input_b.type_flag_, OType, {
Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr,
k_a_shape, k_a_shape_bc, OpReqType::kWriteTo, ndim);
Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
s, bc_size_b, input_b.dptr<IType>(), bc_b_ptr,
k_b_shape, k_b_shape_bc, OpReqType::kWriteTo, ndim);
struct ShapeAndStride aux_data_a, aux_data_b;
PrepareAUXData(&aux_data_a, k_a_shape, k_a_shape_bc, ndim);
PrepareAUXData(&aux_data_b, k_b_shape, k_b_shape_bc, ndim);
if (isCPU) {
if (!aux_data_a.shape_changed) {
Kernel<direct_copy<mshadow_op::identity>, xpu>::Launch(
s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr, OpReqType::kWriteTo);
Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
s, input_b.Size(), input_b.dptr<IType>(), bc_b_ptr,
aux_data_b, OpReqType::kWriteTo, ndim);
} else if (!aux_data_b.shape_changed) {
Kernel<direct_copy<mshadow_op::identity>, xpu>::Launch(
s, bc_size_b, input_b.dptr<IType>(), bc_b_ptr, OpReqType::kWriteTo);
Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
s, input_a.Size(), input_a.dptr<IType>(), bc_a_ptr,
aux_data_a, OpReqType::kWriteTo, ndim);
} else {
Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
s, input_a.Size(), input_a.dptr<IType>(), bc_a_ptr,
aux_data_a, OpReqType::kWriteTo, ndim);
Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
s, input_b.Size(), input_b.dptr<IType>(), bc_b_ptr,
aux_data_b, OpReqType::kWriteTo, ndim);
}
} else {
Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr,
aux_data_a, OpReqType::kWriteTo, ndim);
Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
s, bc_size_b, input_b.dptr<IType>(), bc_b_ptr,
aux_data_b, OpReqType::kWriteTo, ndim);
}
});
});
ans = mshadow::Tensor<xpu, 3, DType>(output.dptr<DType>(),
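
For illustration only (not part of the diff), a rough NumPy sketch of the broadcast-then-batched-matmul idea that MatmulImpl follows: both operands are expanded to a common batch shape before the batched matmul. The function name and shapes are assumptions for the example; the sketch ignores workspace management, dtype switching, and the new CPU fast path above that replaces the broadcast kernel with a plain copy when an operand's shape is already the broadcast shape.

import numpy as np

def matmul_with_broadcast(a, b):
    # e.g. a.shape = (2, 3, 1, 4, 2) and b.shape = (5, 2, 4), as in the comment above
    batch = np.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    a_bc = np.broadcast_to(a, batch + a.shape[-2:])  # materialized by the broadcast kernels
    b_bc = np.broadcast_to(b, batch + b.shape[-2:])
    return np.matmul(a_bc, b_bc)                     # batched matmul on aligned shapes

a = np.random.rand(2, 3, 1, 4, 2)
b = np.random.rand(5, 2, 4)
out = matmul_with_broadcast(a, b)
assert out.shape == (2, 3, 5, 4, 4)
assert np.allclose(out, np.matmul(a, b))             # NumPy broadcasts internally anyway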
208 changes: 190 additions & 18 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -25,6 +25,7 @@
#ifndef MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_
#define MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_

#include <assert.h>
#include <mxnet/operator_util.h>
#include <string>
#include <vector>
@@ -1037,34 +1038,182 @@ void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs,
ReduceAxesBackwardUseInOutImpl<xpu, OP, normalize>(ctx, small, inputs, req, outputs);
}

namespace {  // unnamed namespace to keep the struct's scope within this file
struct ShapeAndStride {
index_t in_stride[MXNET_SPECIAL_MAX_NDIM];
index_t out_stride[MXNET_SPECIAL_MAX_NDIM];
index_t input_shape[MXNET_SPECIAL_MAX_NDIM];
index_t output_shape[MXNET_SPECIAL_MAX_NDIM];
// axes: stores which axes of the input are to be broadcast
index_t axes[MXNET_SPECIAL_MAX_NDIM];
int num_broadcast_axes = -1;
bool shape_changed = false;
};
} // unnamed namespace

/*!
 * \brief Calculates the strides of the input and output tensor dimensions
 *        and saves the mshadow::Shape data in integer arrays for faster access.
 * \param aux_data  holds the stride and shape data.
 * \param in_shape  input shape
 * \param out_shape output shape
 * \param ndim      number of dimensions in the output
 */
inline void PrepareAUXData(ShapeAndStride *aux_data,
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape,
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape,
int ndim) {
int iter = ndim - 1, i = 0;
aux_data->out_stride[iter] = 1;
aux_data->in_stride[iter] = 1;
aux_data->input_shape[iter] = in_shape[iter];
aux_data->output_shape[iter] = out_shape[iter];
if (in_shape[iter] != out_shape[iter]) {
aux_data->axes[i++] = iter;
aux_data->shape_changed = true;
}
iter--;
for (; iter >= 0; --iter) {
aux_data->out_stride[iter] = aux_data->out_stride[iter + 1] * out_shape[iter + 1];
aux_data->in_stride[iter] = aux_data->in_stride[iter + 1] * in_shape[iter + 1];
aux_data->input_shape[iter] = in_shape[iter];
aux_data->output_shape[iter] = out_shape[iter];
if (in_shape[iter] != out_shape[iter]) {
aux_data->axes[i++] = iter;
aux_data->shape_changed = true;
}
}
aux_data->num_broadcast_axes = i;
assert(aux_data->num_broadcast_axes > -1 && aux_data->num_broadcast_axes < 4);
}
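
For illustration only, a minimal Python sketch of the logic above; prepare_aux_data is a hypothetical stand-in that mirrors PrepareAUXData without the fixed-size MXNET_SPECIAL_MAX_NDIM arrays.

def prepare_aux_data(in_shape, out_shape):
    ndim = len(out_shape)
    in_stride, out_stride = [1] * ndim, [1] * ndim
    axes = []                                 # axes where the input is broadcast
    for i in range(ndim - 1, -1, -1):
        if i < ndim - 1:
            in_stride[i] = in_stride[i + 1] * in_shape[i + 1]
            out_stride[i] = out_stride[i + 1] * out_shape[i + 1]
        if in_shape[i] != out_shape[i]:
            axes.append(i)
    return {"in_stride": in_stride, "out_stride": out_stride,
            "axes": axes, "shape_changed": bool(axes)}

# e.g. broadcasting (3, 1, 5) to (3, 4, 5): only axis 1 is broadcast
aux = prepare_aux_data((3, 1, 5), (3, 4, 5))
assert aux["axes"] == [1] and aux["out_stride"] == [20, 5, 1]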

template<typename OP>
struct broadcast_kernel {
struct broadcast_kernel_gpu {
template<typename IType, typename OType>
MSHADOW_XINLINE static void Map(index_t i,
IType *input,
OType *output,
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape,
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape,
const ShapeAndStride& aux_data,
const OpReqType req,
const uint32_t ndim) {
size_t in_stride = 1;
size_t out_stride = 1;
const int ndim) {
index_t idx = i;
index_t in_idx = i;
#pragma unroll 4
for (int iter = ndim - 1; iter >= 0; --iter) {
size_t dim_idx = idx % out_shape[iter];
in_idx -= dim_idx * out_stride;
if (in_shape[iter] != 1) {
in_idx += dim_idx * in_stride;
index_t out_dim_shape = aux_data.output_shape[iter];
index_t out_dim_stride = aux_data.out_stride[iter];
// x % y = x - (x / y) * y
// speeds up the modulo (%) operation on the GPU
index_t dim_idx = idx - (idx / out_dim_shape) * out_dim_shape;
if (aux_data.input_shape[iter] != 1) {
in_idx += dim_idx * (aux_data.in_stride[iter] - out_dim_stride);
} else {
in_idx -= dim_idx * out_dim_stride;
}
idx /= out_shape[iter];
in_stride *= in_shape[iter];
out_stride *= out_shape[iter];
idx /= out_dim_shape;
}
KERNEL_ASSIGN(output[i], req, OP::Map(input[in_idx]));
}
};
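
For illustration only, the index arithmetic of broadcast_kernel_gpu sketched in Python: each output element i is decomposed digit by digit in the mixed radix given by the output shape, and the digits along broadcast axes are dropped to recover the input offset. The helper names and the shapes are assumptions for the example.

def row_major_strides(shape):
    strides, acc = [0] * len(shape), 1
    for i in range(len(shape) - 1, -1, -1):
        strides[i], acc = acc, acc * shape[i]
    return strides

def gpu_broadcast_read_index(i, in_shape, out_shape):
    in_stride, out_stride = row_major_strides(in_shape), row_major_strides(out_shape)
    idx, in_idx = i, i
    for axis in range(len(out_shape) - 1, -1, -1):
        dim_idx = idx % out_shape[axis]          # digit of i along this output axis
        if in_shape[axis] != 1:
            in_idx += dim_idx * (in_stride[axis] - out_stride[axis])
        else:                                    # broadcast axis: drop this digit
            in_idx -= dim_idx * out_stride[axis]
        idx //= out_shape[axis]
    return in_idx

# Broadcasting (3, 1, 5) to (3, 4, 5): output element (1, 1, 2), flat index 27,
# reads input element (1, 0, 2), flat index 7.
assert gpu_broadcast_read_index(27, (3, 1, 5), (3, 4, 5)) == 7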

/**
* Changes the thread workload mapping from one thread per output
* element to one thread per input element to be broadcast.
* This approach leverages vectorization when the fastest-varying
* index (stride = 1) of the tensor is to be broadcast.
* In other cases it still performs better due to improved load balancing.
*/
template<typename OP>
struct broadcast_kernel_cpu {
template<typename IType, typename OType>
MSHADOW_XINLINE static void Map(index_t i,
IType *input,
OType *output,
const ShapeAndStride& aux_data,
const OpReqType req,
const int ndim) {
index_t idx = i;
index_t init_off = 0;
for (int iter = ndim - 1; idx > 0 && iter >= 0; --iter) {
size_t dim_idx = idx % aux_data.input_shape[iter];
init_off += dim_idx * aux_data.out_stride[iter];
idx /= aux_data.input_shape[iter];
}
index_t stride_0, stride_1, stride_2;
// Each case is based on the number of axes to be broadcast
// (1, 2 or 3) after merging axes.
switch (aux_data.num_broadcast_axes) {
// When the input shape is one of the following forms:
// (x_1,1), (x_1,1,x_2) or (1,x_1),
// where x_1, x_2 are the sizes of the dimensions that are not broadcast.
// In the (x_1,1) case the loop is vectorized; in the other two cases
// performance improves by avoiding duplicate stride calculations
// for each output location that input[i] needs to be written to.
case 1 :
stride_0 = aux_data.out_stride[aux_data.axes[0]];
for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) {
KERNEL_ASSIGN(output[init_off + l * stride_0],
req, OP::Map(input[i]));
}
break;
// When the input shape is one of the following forms:
// (x_1,1,x_2,1), (1,x_1,1,x_2) or (x_1,1,x_2,1,x_3),
// where x_1, x_2, x_3 are the sizes of the dimensions that are not broadcast.
// The innermost loop can be vectorized by the compiler; in the outer loops
// performance improves by avoiding duplicate stride calculations
// for each output location that input[i] needs to be written to.
case 2:
stride_1 = aux_data.out_stride[aux_data.axes[1]];
stride_0 = aux_data.out_stride[aux_data.axes[0]];
for (index_t k = 0; k < aux_data.output_shape[aux_data.axes[1]]; k++) {
for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) {
KERNEL_ASSIGN(output[init_off + k * stride_1 + l * stride_0],
req, OP::Map(input[i]));
}
}
break;
// When the input shape is of the form (1,x_1,1,x_2,1),
// where x_1, x_2 are the sizes of the dimensions that are not broadcast.
// The last axis (index 4) is the one the compiler can vectorize over;
// the outer two loops improve performance by avoiding duplicate stride
// calculations for each output location that input[i] needs to be written to.
case 3:
stride_2 = aux_data.out_stride[aux_data.axes[2]];
stride_1 = aux_data.out_stride[aux_data.axes[1]];
stride_0 = aux_data.out_stride[aux_data.axes[0]];
for (index_t j = 0; j < aux_data.output_shape[aux_data.axes[2]]; j++) {
for (index_t k = 0; k < aux_data.output_shape[aux_data.axes[1]]; k++) {
for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) {
KERNEL_ASSIGN(output[init_off + j * stride_2 + k * stride_1 + l * stride_0],
req, OP::Map(input[i]));
}
}
}
break;
}
}
};
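
For illustration only, a Python sketch of the one-input-element-per-thread strategy for the single-broadcast-axis case (the case 1 branch above): the base output offset of input element i is computed once, and the value is then duplicated along the broadcast axis with a fixed stride. The helper names and shapes are assumptions for the example.

def row_major_strides(shape):
    strides, acc = [0] * len(shape), 1
    for i in range(len(shape) - 1, -1, -1):
        strides[i], acc = acc, acc * shape[i]
    return strides

def cpu_broadcast_one_input(i, value, in_shape, out_shape, output):
    out_stride = row_major_strides(out_shape)
    idx, init_off = i, 0                         # base offset of input element i in the output
    for axis in range(len(in_shape) - 1, -1, -1):
        init_off += (idx % in_shape[axis]) * out_stride[axis]
        idx //= in_shape[axis]
    # single broadcast axis (case 1): duplicate the value along that axis
    bc_axis = [a for a in range(len(in_shape)) if in_shape[a] != out_shape[a]][0]
    stride = out_stride[bc_axis]
    for l in range(out_shape[bc_axis]):
        output[init_off + l * stride] = value

# Broadcasting (3, 1, 5) to (3, 4, 5): input element 7, i.e. (1, 0, 2),
# is written to output positions (1, 0..3, 2).
out = [None] * (3 * 4 * 5)
cpu_broadcast_one_input(7, "x", (3, 1, 5), (3, 4, 5), out)
assert [k for k, v in enumerate(out) if v == "x"] == [22, 27, 32, 37]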

template<typename OP>
struct direct_copy {
template<typename IType, typename OType>
MSHADOW_XINLINE static void Map(index_t i,
IType *input,
OType *output,
const OpReqType req) {
KERNEL_ASSIGN(output[i], req, OP::Map(input[i]));
}
};

/**
* On CPU, the number of kernel invocations equals the number of input
* elements, which helps leverage vectorization where possible.
* On GPU, the number of kernel invocations equals the number of output
* elements, which ensures coalesced memory writes to the output
* and improves coalesced memory reads.
*/
template<typename xpu>
inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -1076,8 +1225,14 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
using namespace mshadow::expr;
using namespace mxnet_op;
mxnet::TShape src_shape, dst_shape;
// Combines two or more consecutive broadcast/non-broadcast axes together,
// e.g. (3,4,1,1,5,1,6,7) (2,3,5) (5,10,9) -> (3*4,1*1,5,1,6*7) (1,3) (5*10, 9)
// -> (12,1,5,1,42) (1,3) (50, 9).
// The result is the new input for the broadcast kernels, whose total number
// of dimensions cannot be greater than 5 (an error is thrown otherwise).
BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape);
Stream<xpu> *s = ctx.get_stream<xpu>();
bool isCPU = std::is_same<xpu, cpu>::value;
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape;
@@ -1091,21 +1246,38 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
out_shape[i] = 1;
}
}
if (dst_shape.ndim() == 2) {
struct ShapeAndStride aux_data;
PrepareAUXData(&aux_data, in_shape, out_shape, dst_shape.ndim());
if (!aux_data.shape_changed) {
// If no broadcast is required (i.e. input_shape == output_shape)
// then simply copy the input to the output.
Kernel<direct_copy<mshadow_op::identity>, xpu>::Launch(
s, outputs[0].Size(), inputs[0].dptr<IType>(), outputs[0].dptr<OType>(), req[0]);
} else if (dst_shape.ndim() == 2) {
Tensor<xpu, 2, OType> out =
outputs[0].get_with_shape<xpu, 2, OType>(dst_shape.get<2>(), s);
Tensor<xpu, 2, IType> data =
inputs[0].get_with_shape<xpu, 2, IType>(src_shape.get<2>(), s);
Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], 2);
if (isCPU) {
Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2);
} else {
Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2);
}
} else {
const int ndim = MXNET_SPECIAL_MAX_NDIM;
Tensor<xpu, ndim, OType> out =
outputs[0].get_with_shape<xpu, ndim, OType>(dst_shape.get<ndim>(), s);
Tensor<xpu, ndim, IType> data =
inputs[0].get_with_shape<xpu, ndim, IType>(src_shape.get<ndim>(), s);
Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], ndim);
if (isCPU) {
Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim);
} else {
Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim);
}
}
});
});
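
For illustration only, a rough Python sketch of the axis merging that BroadcastReduceShapeCompact is described as performing in the comment above; merge_axes and the output shape used in the example are assumptions, not the function's real signature. Consecutive axes that are either all broadcast or all non-broadcast are fused, so the kernels see at most a handful of axes.

def merge_axes(in_shape, out_shape):
    merged_in, merged_out = [], []
    for i_dim, o_dim in zip(in_shape, out_shape):
        is_bc = (i_dim != o_dim)                 # is this a broadcast axis?
        prev_bc = bool(merged_in) and (merged_in[-1] != merged_out[-1])
        if merged_in and is_bc == prev_bc:
            merged_in[-1] *= i_dim               # fuse with the previous run of axes
            merged_out[-1] *= o_dim
        else:
            merged_in.append(i_dim)
            merged_out.append(o_dim)
    return tuple(merged_in), tuple(merged_out)

# Broadcasting (3,4,1,1,5,1,6,7) to (3,4,2,9,5,8,6,7): the leading non-broadcast
# axes fuse to 12, the two broadcast axes fuse, and the trailing axes fuse to 42.
assert merge_axes((3, 4, 1, 1, 5, 1, 6, 7),
                  (3, 4, 2, 9, 5, 8, 6, 7)) == ((12, 1, 5, 1, 42),
                                                (12, 18, 5, 8, 42))
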
2 changes: 2 additions & 0 deletions src/operator/tensor/la_op.h
@@ -475,6 +475,8 @@ inline bool LaEigFactShape(const nnvm::NodeAttrs& attrs,
const mxnet::TShape& in_a = (*in_attrs)[0];
const mxnet::TShape& out_u = (*out_attrs)[0];
const mxnet::TShape& out_l = (*out_attrs)[1];
CHECK_LE(in_a.Size(), INT_MAX)
<< "Large tensors are not supported by Linear Algebra operator syevd";
if ( in_a.ndim() >= 2 ) {
// Forward shape inference.
const int ndim(in_a.ndim());
