diff --git a/src/io/iter_sparse_batchloader.h b/src/io/iter_sparse_batchloader.h index a89f21acb2a4..e8cddb9e9704 100644 --- a/src/io/iter_sparse_batchloader.h +++ b/src/io/iter_sparse_batchloader.h @@ -150,6 +150,7 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator CHECK(data_stype_ == kCSRStorage || label_stype_ == kCSRStorage); CHECK_GT(inst_cache_.size(), 0); out_.data.clear(); + data_.clear(); offsets_.clear(); size_t total_size = inst_cache_[0].data.size(); diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 942be2c9d818..356afc19de5e 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -111,44 +111,83 @@ def test_NDArrayIter_csr(): assert_almost_equal(batch.data[0].asnumpy(), expected) begin += batch_size -''' def test_LibSVMIter(): - #TODO(haibin) automatic the test instead of hard coded test - cwd = os.getcwd() - data_path = os.path.join(cwd, 'data.t') - label_path = os.path.join(cwd, 'label.t') - with open(data_path, 'w') as fout: - fout.write('1.0 0:0.5 2:1.2\n') - fout.write('-2.0\n') - fout.write('-3.0 0:0.6 1:2.4 2:1.2\n') - fout.write('4 2:-1.2\n') + def get_data(data_dir, data_name, url, data_origin_name): + if not os.path.isdir(data_dir): + os.system("mkdir " + data_dir) + os.chdir(data_dir) + if (not os.path.exists(data_name)): + if sys.version_info[0] >= 3: + from urllib.request import urlretrieve + else: + from urllib import urlretrieve + zippath = os.path.join(data_dir, data_origin_name) + urlretrieve(url, zippath) + import bz2 + bz_file = bz2.BZ2File(data_origin_name, 'rb') + with open(data_name, 'wb') as fout: + try: + content = bz_file.read() + fout.write(content) + finally: + bz_file.close() + os.chdir("..") - with open(label_path, 'w') as fout: - fout.write('1.0\n') - fout.write('-2.0 0:0.125\n') - fout.write('-3.0 2:1.2\n') - fout.write('4 1:1.0 2:-1.2\n') + def check_libSVMIter_synthetic(): + cwd = os.getcwd() + data_path = os.path.join(cwd, 'data.t') + label_path = os.path.join(cwd, 'label.t') + with open(data_path, 'w') as fout: + fout.write('1.0 0:0.5 2:1.2\n') + fout.write('-2.0\n') + fout.write('-3.0 0:0.6 1:2.4 2:1.2\n') + fout.write('4 2:-1.2\n') - data_dir = os.path.join(os.getcwd(), 'data') - f = (data_path, label_path, (3,), (3,), 3) - data_train = mx.io.LibSVMIter(data_libsvm=f[0], - label_libsvm=f[1], - data_shape=f[2], - label_shape=f[3], - batch_size=f[4]) + with open(label_path, 'w') as fout: + fout.write('1.0\n') + fout.write('-2.0 0:0.125\n') + fout.write('-3.0 2:1.2\n') + fout.write('4 1:1.0 2:-1.2\n') - first = mx.nd.array([[ 0.5, 0., 1.2], [ 0., 0., 0.], [ 0.6, 2.4, 1.2]]) - second = mx.nd.array([[ 0., 0., -1.2], [ 0.5, 0., 1.2], [ 0., 0., 0.]]) - i = 0 - for batch in iter(data_train): - expected = first.asnumpy() if i == 0 else second.asnumpy() - assert_almost_equal(data_train.getdata().asnumpy(), expected) - i += 1 -''' + data_dir = os.path.join(cwd, 'data') + data_train = mx.io.LibSVMIter(data_libsvm=data_path, label_libsvm=label_path, + data_shape=(3, ), label_shape=(3, ), batch_size=3) + + first = mx.nd.array([[ 0.5, 0., 1.2], [ 0., 0., 0.], [ 0.6, 2.4, 1.2]]) + second = mx.nd.array([[ 0., 0., -1.2], [ 0.5, 0., 1.2], [ 0., 0., 0.]]) + i = 0 + for batch in iter(data_train): + expected = first.asnumpy() if i == 0 else second.asnumpy() + assert_almost_equal(data_train.getdata().asnumpy(), expected) + i += 1 + + def check_libSVMIter_news_metadata(): + news_metadata = { + 'name': 'news20.t', + 'origin_name': 'news20.t.bz2', + 'url': "http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/news20.t.bz2", + 'shape': 62060, + 'num_classes': 20, + } + data_dir = os.path.join(os.getcwd(), 'data') + get_data(data_dir, news_metadata['name'], news_metadata['url'], + news_metadata['origin_name']) + path = os.path.join(data_dir, news_metadata['name']) + data_train = mx.io.LibSVMIter(data_libsvm=path, + data_shape=(news_metadata['shape'], ), + batch_size=512) + iterator = iter(data_train) + for batch in iterator: + # check the range of labels + assert(np.sum(batch.label[0].asnumpy() > 20) == 0) + assert(np.sum(batch.label[0].asnumpy() <= 0) == 0) + + check_libSVMIter_synthetic() + check_libSVMIter_news_metadata() if __name__ == "__main__": test_NDArrayIter() test_MNISTIter() test_Cifar10Rec() - # test_LibSVMIter() + test_LibSVMIter() test_NDArrayIter_csr()