Added tests to verify Large Vector Support for initial set of ops (apache#15943)

* Adding tests to verify support for Large Tensors in additional Ops, along with new C APIs supporting 64-bit indexing

* removing skipped tests

* enabling Large Index support for slice and softmax

* removing tests not required for vector testing
access2rohit authored and Rohit Kumar Srivastava committed Sep 25, 2019
1 parent b358ccd commit 7448111
Showing 5 changed files with 162 additions and 20 deletions.
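
For context, a minimal sketch of the workload these changes target: a single vector with more than 2**32 elements. This is illustrative only; it assumes an MXNet build with 64-bit tensor size support enabled (e.g. USE_INT64_TENSOR_SIZE=1), tens of gigabytes of free host memory, and a hypothetical length, and is not part of the commit itself.

import mxnet as mx

# Hypothetical length just past the 32-bit element-count boundary.
LARGE_X = 2**32 + 1

a = mx.nd.ones(shape=LARGE_X)   # a single vector holding more than 2**32 elements
b = a + 1                       # elementwise ops must use 64-bit element offsets
assert b.shape == (LARGE_X,)
assert b[-1] == 2               # indexing the tail exercises the int64 index path
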
11 changes: 11 additions & 0 deletions python/mxnet/test_utils.py
@@ -261,6 +261,17 @@ def assign_each2(input1, input2, function):

return output

# For testing Large Tensors having total size > 2^32 elements
def create_2d_tensor(rows, columns, dtype=np.int64):
a = mx.nd.arange(0, rows, dtype=dtype).reshape(rows, 1)
b = mx.nd.broadcast_to(a, shape=(a.shape[0], columns))
return b

# For testing Large Vectors having total size > 2^32 elements
def create_vector(size, dtype=np.int64):
a = mx.nd.arange(0, size, dtype=dtype)
return a

def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None,
data_init=None, rsp_indices=None, modifier_func=None,
shuffle_csr_indices=False, ctx=None):
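
A quick usage sketch for the two helpers added above, which are shared between the nightly large-tensor and large-vector tests. The sizes here are small stand-ins; the nightly tests pass sizes whose element counts exceed 2**32.

import numpy as np
from mxnet.test_utils import create_2d_tensor, create_vector

rows, columns, size = 4, 3, 10        # small stand-ins for the nightly sizes

t = create_2d_tensor(rows, columns, dtype=np.int64)
assert t.shape == (rows, columns)
assert t[-1][0] == rows - 1           # row i is filled with the value i

v = create_vector(size, dtype=np.int64)
assert v.shape == (size,)
assert v[-1] == size - 1              # v holds 0 .. size-1
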
18 changes: 9 additions & 9 deletions src/operator/softmax_output-inl.h
@@ -117,9 +117,9 @@ class SoftmaxOutputOp : public Operator {
CHECK_EQ(out_data.size(), 1U) << "SoftmaxOutput Output: [output]";
Stream<xpu> *s = ctx.get_stream<xpu>();
if (param_.multi_output) {
int n = in_data[softmaxout_enum::kData].size(0);
int k = in_data[softmaxout_enum::kData].size(1);
Shape<3> s3 = Shape3(n, k, static_cast<int>(in_data[softmaxout_enum::kData].Size()/n/k));
index_t n = in_data[softmaxout_enum::kData].size(0);
index_t k = in_data[softmaxout_enum::kData].size(1);
Shape<3> s3 = Shape3(n, k, static_cast<index_t>(in_data[softmaxout_enum::kData].Size()/n/k));
Tensor<xpu, 3, DType> data =
in_data[softmaxout_enum::kData].get_with_shape<xpu, 3, DType>(s3, s);
Tensor<xpu, 3, DType> out =
@@ -131,8 +131,8 @@ class SoftmaxOutputOp : public Operator {
Tensor<xpu, 2, DType> out = out_data[softmaxout_enum::kOut].FlatTo2D<xpu, DType>(s);
Softmax(out, data);
} else {
int n = in_data[softmaxout_enum::kData].size(0);
int k = in_data[softmaxout_enum::kData].Size()/n;
index_t n = in_data[softmaxout_enum::kData].size(0);
index_t k = in_data[softmaxout_enum::kData].Size()/n;
Shape<2> s2 = Shape2(n, k);
Tensor<xpu, 2, DType> data =
in_data[softmaxout_enum::kData].get_with_shape<xpu, 2, DType>(s2, s);
@@ -171,9 +171,9 @@ class SoftmaxOutputOp : public Operator {
grad = (out - label) * scalar<DType>(param_.grad_scale);
}
} else if (param_.multi_output) {
int n = out_data[softmaxout_enum::kOut].size(0);
int k = out_data[softmaxout_enum::kOut].size(1);
Shape<3> s3 = Shape3(n, k, static_cast<int>(out_data[softmaxout_enum::kOut].Size()/n/k));
index_t n = out_data[softmaxout_enum::kOut].size(0);
index_t k = out_data[softmaxout_enum::kOut].size(1);
Shape<3> s3 = Shape3(n, k, static_cast<index_t>(out_data[softmaxout_enum::kOut].Size()/n/k));
Shape<2> s2 = Shape2(s3[0], s3[2]);
Tensor<xpu, 2, DType> label =
in_data[softmaxout_enum::kLabel].get_with_shape<xpu, 2, DType>(s2, s);
@@ -224,7 +224,7 @@ class SoftmaxOutputOp : public Operator {
// Tensor<xpu, 2, DType> out = out_data[softmaxout_enum::kOut].FlatTo2D<xpu, DType>(s);
// Tensor<xpu, 2, DType> grad = in_grad[softmaxout_enum::kData].FlatTo2D<xpu, DType>(s);
} else {
int n = out_data[softmaxout_enum::kOut].size(0);
index_t n = out_data[softmaxout_enum::kOut].size(0);
data_shape = Shape2(n, out_data[softmaxout_enum::kOut].Size()/n);
}
Tensor<xpu, 1, DType> label = in_data[softmaxout_enum::kLabel].get_with_shape<xpu, 1, DType>(
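
The int → index_t changes above matter because the element counts flowing through these shape computations can exceed the signed 32-bit range once a tensor holds more than 2**31 elements. A back-of-the-envelope check in plain Python, with made-up dimensions:

import numpy as np

INT32_MAX = 2**31 - 1

# Hypothetical multi-output softmax input: n examples, k classes, trailing dim m.
n, k, m = 2, 2, 2**30
total = n * k * m                 # what in_data[...].Size() would report

# n, k and Size()/n/k each fit in 32 bits here, but the total element count
# (and any flat offset derived from it) does not, so holding intermediates in
# a 32-bit int can silently overflow; index_t is 64-bit in large-tensor builds.
assert total > INT32_MAX
assert total <= np.iinfo(np.int64).max
print("total elements = %d > INT32_MAX = %d" % (total, INT32_MAX))
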
6 changes: 3 additions & 3 deletions src/operator/tensor/matrix_op-inl.h
@@ -717,8 +717,8 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
}

inline void SetSliceOpOutputDimSize(const mxnet::TShape& dshape,
const index_t i, const int b,
const int e, const int s,
const index_t i, const index_t b,
const index_t e, const index_t s,
mxnet::TShape* oshape) {
if (!mxnet::dim_size_is_known(dshape, i)) {
(*oshape)[i] = -1;
@@ -750,7 +750,7 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
common::StaticArray<index_t, ndim> begin, end, step;
GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step);
for (int i = 0; i < param.begin.ndim(); ++i) {
const int b = begin[i], e = end[i], s = step[i];
const index_t b = begin[i], e = end[i], s = step[i];
SetSliceOpOutputDimSize(dshape, i, b, e, s, &oshape);
}
})
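
Widening begin, end, and step to index_t in SetSliceOpOutputDimSize and SliceOpShape is what lets a slice start beyond the 2**32-element mark. A minimal sketch of such a call, using a hypothetical length and the same build assumptions as the earlier sketch:

import numpy as np
import mxnet as mx

SIZE = 2**32 + 1000                       # hypothetical vector length > 2**32
a = mx.nd.arange(0, SIZE, dtype=np.int64)

# begin and end both exceed 2**31 - 1, so they only survive shape inference
# intact when carried as 64-bit indices.
res = mx.nd.slice(a, begin=(SIZE - 100,), end=(SIZE,))
assert res.shape == (100,)
assert res[0] == SIZE - 100
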
8 changes: 1 addition & 7 deletions tests/nightly/test_large_array.py
@@ -19,7 +19,7 @@
import numpy as np
import mxnet as mx

from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward
from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor
from mxnet import gluon, nd
from tests.python.unittest.common import with_seed

@@ -31,12 +31,6 @@
LARGE_SIZE = LARGE_X * SMALL_Y


def create_2d_tensor(rows, columns, dtype=np.int64):
a = nd.arange(0, rows, dtype=dtype).reshape(rows, 1)
b = nd.broadcast_to(a, shape=(a.shape[0], columns))
return nd.array(b, dtype=dtype)


def test_gluon_embedding():
m = gluon.nn.Embedding(SMALL_Y, MEDIUM_X)
m.initialize()
139 changes: 138 additions & 1 deletion tests/nightly/test_large_vector.py
@@ -18,7 +18,7 @@
import numpy as np
import mxnet as mx

from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d
from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, create_vector
from mxnet import gluon, nd
from tests.python.unittest.common import with_seed

@@ -30,9 +30,146 @@
def test_slice():
a = nd.ones(LARGE_X)
res = nd.slice(a, begin=(LARGE_X - MEDIUM_X), end=LARGE_X)
assert a[0] == 1
assert res.shape[0] == MEDIUM_X


def test_ndarray_zeros():
a = nd.zeros(shape=LARGE_X)
assert a[-1] == 0
assert a.shape == (LARGE_X,)
assert a.size == LARGE_X


def test_ndarray_ones():
a = nd.ones(shape=LARGE_X)
assert a[-1] == 1
assert nd.sum(a).asnumpy() == LARGE_X


@with_seed()
def test_ndarray_random_uniform():
a = nd.random.uniform(shape=LARGE_X)
assert a[-1] != 0


@with_seed()
def test_ndarray_random_randint():
a = nd.random.randint(100, 10000, shape=LARGE_X)
assert a.shape == (LARGE_X,)
# check if randint can generate value greater than 2**32 (large)
low_large_value = 2**32
high_large_value = 2**34
a = nd.random.randint(low_large_value, high_large_value, dtype=np.int64)
low = mx.nd.array([low_large_value], dtype='int64')
high = mx.nd.array([high_large_value], dtype='int64')
    assert a >= low and a < high


def test_ndarray_empty():
a = nd.empty(LARGE_X)
assert a.shape == (LARGE_X,)


def test_elementwise():
a = nd.ones(shape=LARGE_X)
b = nd.ones(shape=LARGE_X)
res = a + b
assert res[-1].asnumpy() == 2
res = a + 1
assert res[-1].asnumpy() == 2
res = nd.sqrt(a + 8)
assert res[-1].asnumpy() == 3


def test_reduce():
a = nd.ones(shape=(LARGE_X, 1))
assert nd.sum(a).asnumpy() == a.shape[0] * a.shape[1]


def test_clip():
a = create_vector(LARGE_X)
res = nd.clip(a, a_min=100, a_max=1000)
assert np.sum(res[-1].asnumpy() == 1000) == 1


def test_argmin():
a = create_vector(LARGE_X, dtype=np.float32)
assert a[0] == 0
idx = mx.nd.argmin(a, axis=0)
assert idx[0] == 0
assert idx.shape[0] == 1


def test_take():
a = nd.ones(shape=LARGE_X)
idx = nd.arange(LARGE_X - 1000, LARGE_X)
res = nd.take(a, idx)
assert np.sum(res.asnumpy() == 1) == res.shape[0]


def test_slice_assign():
a = nd.ones(shape=LARGE_X)
a[LARGE_X-1:LARGE_X] = 1000
assert np.sum(a[-1].asnumpy() == 1000) == 1


def test_expand_dims():
a = nd.ones(shape=LARGE_X)
res = nd.expand_dims(a, axis=0)
assert res[0][0] == 1
assert res.shape == (1, a.shape[0])


def test_squeeze():
a = nd.ones(shape=LARGE_X)
data = nd.expand_dims(a, axis=0)
res = nd.squeeze(data)
assert a[0] == res[0]
assert res.shape == a.shape


def test_broadcast_div():
a = nd.ones(shape=LARGE_X)
b = nd.ones(shape=LARGE_X) * 2
res = a / b
assert np.sum(res.asnumpy() == 0.5) == a.shape[0]


def test_Dense(ctx=mx.cpu(0)):
data = mx.nd.ones(shape=LARGE_X)
linear = gluon.nn.Dense(2)
linear.initialize(ctx=ctx)
res = linear(data)
res.wait_to_read()
assert res.shape == (LARGE_X, 2)


def test_argsort():
b = create_vector(size=LARGE_X)
s = nd.argsort(b, axis=0, is_ascend=False, dtype=np.int64)
mx.nd.waitall()
assert (s[0].asnumpy() == (LARGE_X - 1)).all()


def test_sort():
b = create_vector(size=LARGE_X)
s = nd.sort(b, axis=0, is_ascend=False)
    assert np.sum(s[-1].asnumpy() == 0) == 1
    s = nd.sort(b, is_ascend=True)
    assert np.sum(s[0].asnumpy() == 0) == 1


def test_topk():
b = create_vector(size=LARGE_X)
ind = nd.topk(b, k=10, axis=0, dtype=np.int64)
assert np.sum(ind.asnumpy() == (LARGE_X - 1)) == 1
ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both", is_ascend=False)
assert np.all(ind == val)
val = nd.topk(b, k=1, axis=0, dtype=np.int64, ret_typ="value")
assert val.sum() == (LARGE_X - 1)


if __name__ == '__main__':
import nose
nose.runmodule()
