From 5cd5ae4019883413c22f69c5564d6a6e7f44cddd Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Fri, 30 Jun 2017 16:47:52 +0000 Subject: [PATCH 1/6] fix bug in libsvm iter which causes mem corruption --- src/io/iter_sparse_batchloader.h | 1 + tests/python/unittest/test_io.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/io/iter_sparse_batchloader.h b/src/io/iter_sparse_batchloader.h index a89f21acb2a4..e8cddb9e9704 100644 --- a/src/io/iter_sparse_batchloader.h +++ b/src/io/iter_sparse_batchloader.h @@ -150,6 +150,7 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator CHECK(data_stype_ == kCSRStorage || label_stype_ == kCSRStorage); CHECK_GT(inst_cache_.size(), 0); out_.data.clear(); + data_.clear(); offsets_.clear(); size_t total_size = inst_cache_[0].data.size(); diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 942be2c9d818..7981a7862c1d 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -111,7 +111,6 @@ def test_NDArrayIter_csr(): assert_almost_equal(batch.data[0].asnumpy(), expected) begin += batch_size -''' def test_LibSVMIter(): #TODO(haibin) automatic the test instead of hard coded test cwd = os.getcwd() @@ -145,10 +144,22 @@ def test_LibSVMIter(): assert_almost_equal(data_train.getdata().asnumpy(), expected) i += 1 ''' +def test_LibSVMIter(): + kdda = os.path.join(os.getcwd(), 'news20.t') + data_train = mx.io.LibSVMIter(data_libsvm=kdda, + #label_libsvm=kdda, + data_shape=(62060, ), + #label_shape=f[3], + batch_size=512) + it = iter(data_train) + for batch in it: + print(batch.data[0]) + #break +''' if __name__ == "__main__": test_NDArrayIter() test_MNISTIter() test_Cifar10Rec() - # test_LibSVMIter() + test_LibSVMIter() test_NDArrayIter_csr() From fe47e23e0de64c0288cefd3550cb4eb456bdc2f4 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Fri, 30 Jun 2017 19:22:30 +0000 Subject: [PATCH 2/6] add test for news dataset --- tests/python/unittest/test_io.py | 105 +++++++++++++++++++------------ 1 file changed, 64 insertions(+), 41 deletions(-) diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 7981a7862c1d..f9c906035395 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -112,50 +112,73 @@ def test_NDArrayIter_csr(): begin += batch_size def test_LibSVMIter(): - #TODO(haibin) automatic the test instead of hard coded test - cwd = os.getcwd() - data_path = os.path.join(cwd, 'data.t') - label_path = os.path.join(cwd, 'label.t') - with open(data_path, 'w') as fout: - fout.write('1.0 0:0.5 2:1.2\n') - fout.write('-2.0\n') - fout.write('-3.0 0:0.6 1:2.4 2:1.2\n') - fout.write('4 2:-1.2\n') + def get_data(data_dir, data_name, url, data_origin_name): + if not os.path.isdir(data_dir): + os.system("mkdir " + data_dir) + os.chdir(data_dir) + if (not os.path.exists(data_name)): + import urllib + zippath = os.path.join(data_dir, data_origin_name) + urllib.urlretrieve(url, zippath) + os.system("bzip2 -d %r" % data_origin_name) + os.chdir("..") - with open(label_path, 'w') as fout: - fout.write('1.0\n') - fout.write('-2.0 0:0.125\n') - fout.write('-3.0 2:1.2\n') - fout.write('4 1:1.0 2:-1.2\n') + def check_libSVMIter_synthetic(): + cwd = os.getcwd() + data_path = os.path.join(cwd, 'data.t') + label_path = os.path.join(cwd, 'label.t') + with open(data_path, 'w') as fout: + fout.write('1.0 0:0.5 2:1.2\n') + fout.write('-2.0\n') + fout.write('-3.0 0:0.6 1:2.4 2:1.2\n') + fout.write('4 2:-1.2\n') - data_dir = os.path.join(os.getcwd(), 'data') - f = (data_path, label_path, (3,), (3,), 3) - data_train = mx.io.LibSVMIter(data_libsvm=f[0], - label_libsvm=f[1], - data_shape=f[2], - label_shape=f[3], - batch_size=f[4]) + with open(label_path, 'w') as fout: + fout.write('1.0\n') + fout.write('-2.0 0:0.125\n') + fout.write('-3.0 2:1.2\n') + fout.write('4 1:1.0 2:-1.2\n') - first = mx.nd.array([[ 0.5, 0., 1.2], [ 0., 0., 0.], [ 0.6, 2.4, 1.2]]) - second = mx.nd.array([[ 0., 0., -1.2], [ 0.5, 0., 1.2], [ 0., 0., 0.]]) - i = 0 - for batch in iter(data_train): - expected = first.asnumpy() if i == 0 else second.asnumpy() - assert_almost_equal(data_train.getdata().asnumpy(), expected) - i += 1 -''' -def test_LibSVMIter(): - kdda = os.path.join(os.getcwd(), 'news20.t') - data_train = mx.io.LibSVMIter(data_libsvm=kdda, - #label_libsvm=kdda, - data_shape=(62060, ), - #label_shape=f[3], - batch_size=512) - it = iter(data_train) - for batch in it: - print(batch.data[0]) - #break -''' + data_dir = os.path.join(os.getcwd(), 'data') + f = (data_path, label_path, (3,), (3,), 3) + # TODO refactor this + data_train = mx.io.LibSVMIter(data_libsvm=f[0], + label_libsvm=f[1], + data_shape=f[2], + label_shape=f[3], + batch_size=f[4]) + + first = mx.nd.array([[ 0.5, 0., 1.2], [ 0., 0., 0.], [ 0.6, 2.4, 1.2]]) + second = mx.nd.array([[ 0., 0., -1.2], [ 0.5, 0., 1.2], [ 0., 0., 0.]]) + i = 0 + for batch in iter(data_train): + expected = first.asnumpy() if i == 0 else second.asnumpy() + assert_almost_equal(data_train.getdata().asnumpy(), expected) + i += 1 + + def check_libSVMIter_news_metadata(): + news_metadata = { + 'name': 'news20.t', + 'origin_name': 'news20.t.bz2', + 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/news20.t.bz2", + 'shape': 62060, + 'num_classes': 20, + } + data_dir = os.path.join(os.getcwd(), 'data') + get_data(data_dir, news_metadata['name'], news_metadata['url'], + news_metadata['origin_name']) + path = os.path.join(os.getcwd(), news_metadata['name']) + data_train = mx.io.LibSVMIter(data_libsvm=path, + data_shape=(news_metadata['shape'], ), + batch_size=512) + iterator = iter(data_train) + for batch in iterator: + # check the range of labels + assert(np.sum(batch.label[0].asnumpy() > 20) == 0) + assert(np.sum(batch.label[0].asnumpy() <= 0) == 0) + + check_libSVMIter_synthetic() + check_libSVMIter_news_metadata() if __name__ == "__main__": test_NDArrayIter() From ad0993ff5c3c5e2dc29e1af578fc60294b202bef Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Fri, 30 Jun 2017 20:37:08 +0000 Subject: [PATCH 3/6] fix wrong path in test --- tests/python/unittest/test_io.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index f9c906035395..692e950f0230 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -139,14 +139,9 @@ def check_libSVMIter_synthetic(): fout.write('-3.0 2:1.2\n') fout.write('4 1:1.0 2:-1.2\n') - data_dir = os.path.join(os.getcwd(), 'data') - f = (data_path, label_path, (3,), (3,), 3) - # TODO refactor this - data_train = mx.io.LibSVMIter(data_libsvm=f[0], - label_libsvm=f[1], - data_shape=f[2], - label_shape=f[3], - batch_size=f[4]) + data_dir = os.path.join(cwd, 'data') + data_train = mx.io.LibSVMIter(data_libsvm=data_path, label_libsvm=label_path, + data_shape=(3, ), label_shape=(3, ), batch_size=3) first = mx.nd.array([[ 0.5, 0., 1.2], [ 0., 0., 0.], [ 0.6, 2.4, 1.2]]) second = mx.nd.array([[ 0., 0., -1.2], [ 0.5, 0., 1.2], [ 0., 0., 0.]]) @@ -167,7 +162,7 @@ def check_libSVMIter_news_metadata(): data_dir = os.path.join(os.getcwd(), 'data') get_data(data_dir, news_metadata['name'], news_metadata['url'], news_metadata['origin_name']) - path = os.path.join(os.getcwd(), news_metadata['name']) + path = os.path.join(data_dir, news_metadata['name']) data_train = mx.io.LibSVMIter(data_libsvm=path, data_shape=(news_metadata['shape'], ), batch_size=512) From b26506616b8e6fd860bc598987ae0f99f47ace20 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Sun, 2 Jul 2017 17:16:50 +0000 Subject: [PATCH 4/6] fix import error for urllib --- tests/python/unittest/test_io.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 692e950f0230..45cad0b30b37 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -117,9 +117,12 @@ def get_data(data_dir, data_name, url, data_origin_name): os.system("mkdir " + data_dir) os.chdir(data_dir) if (not os.path.exists(data_name)): - import urllib + if sys.version_info[0] >= 3: + from urllib.request import urlretrieve + else: + from urllib import urlretrieve zippath = os.path.join(data_dir, data_origin_name) - urllib.urlretrieve(url, zippath) + urlretrieve(url, zippath) os.system("bzip2 -d %r" % data_origin_name) os.chdir("..") From 3201b55ccf6ea5ac29fec16a594e7d2846da6e6a Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Sun, 2 Jul 2017 22:25:34 +0000 Subject: [PATCH 5/6] update url --- tests/python/unittest/test_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 45cad0b30b37..c48ebbe8d334 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -158,7 +158,7 @@ def check_libSVMIter_news_metadata(): news_metadata = { 'name': 'news20.t', 'origin_name': 'news20.t.bz2', - 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/news20.t.bz2", + 'url': "http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/news20.t.bz2", 'shape': 62060, 'num_classes': 20, } From 6243732279825a83261af5ec4818d41b9f1aec7a Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Mon, 3 Jul 2017 18:09:32 +0000 Subject: [PATCH 6/6] replace bz command with bz module --- tests/python/unittest/test_io.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index c48ebbe8d334..356afc19de5e 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -123,7 +123,14 @@ def get_data(data_dir, data_name, url, data_origin_name): from urllib import urlretrieve zippath = os.path.join(data_dir, data_origin_name) urlretrieve(url, zippath) - os.system("bzip2 -d %r" % data_origin_name) + import bz2 + bz_file = bz2.BZ2File(data_origin_name, 'rb') + with open(data_name, 'wb') as fout: + try: + content = bz_file.read() + fout.write(content) + finally: + bz_file.close() os.chdir("..") def check_libSVMIter_synthetic():