Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

LibsvmIter fix #6898

Merged
merged 6 commits into from
Jul 4, 2017
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/io/iter_sparse_batchloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator<TBlobBatch>
CHECK(data_stype_ == kCSRStorage || label_stype_ == kCSRStorage);
CHECK_GT(inst_cache_.size(), 0);
out_.data.clear();
data_.clear();
offsets_.clear();

size_t total_size = inst_cache_[0].data.size();
Expand Down
91 changes: 60 additions & 31 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,44 +111,73 @@ def test_NDArrayIter_csr():
assert_almost_equal(batch.data[0].asnumpy(), expected)
begin += batch_size

'''
def test_LibSVMIter():
#TODO(haibin) automatic the test instead of hard coded test
cwd = os.getcwd()
data_path = os.path.join(cwd, 'data.t')
label_path = os.path.join(cwd, 'label.t')
with open(data_path, 'w') as fout:
fout.write('1.0 0:0.5 2:1.2\n')
fout.write('-2.0\n')
fout.write('-3.0 0:0.6 1:2.4 2:1.2\n')
fout.write('4 2:-1.2\n')
def get_data(data_dir, data_name, url, data_origin_name):
if not os.path.isdir(data_dir):
os.system("mkdir " + data_dir)
os.chdir(data_dir)
if (not os.path.exists(data_name)):
import urllib
zippath = os.path.join(data_dir, data_origin_name)
urllib.urlretrieve(url, zippath)
os.system("bzip2 -d %r" % data_origin_name)
os.chdir("..")

with open(label_path, 'w') as fout:
fout.write('1.0\n')
fout.write('-2.0 0:0.125\n')
fout.write('-3.0 2:1.2\n')
fout.write('4 1:1.0 2:-1.2\n')
def check_libSVMIter_synthetic():
cwd = os.getcwd()
data_path = os.path.join(cwd, 'data.t')
label_path = os.path.join(cwd, 'label.t')
with open(data_path, 'w') as fout:
fout.write('1.0 0:0.5 2:1.2\n')
fout.write('-2.0\n')
fout.write('-3.0 0:0.6 1:2.4 2:1.2\n')
fout.write('4 2:-1.2\n')

data_dir = os.path.join(os.getcwd(), 'data')
f = (data_path, label_path, (3,), (3,), 3)
data_train = mx.io.LibSVMIter(data_libsvm=f[0],
label_libsvm=f[1],
data_shape=f[2],
label_shape=f[3],
batch_size=f[4])
with open(label_path, 'w') as fout:
fout.write('1.0\n')
fout.write('-2.0 0:0.125\n')
fout.write('-3.0 2:1.2\n')
fout.write('4 1:1.0 2:-1.2\n')

first = mx.nd.array([[ 0.5, 0., 1.2], [ 0., 0., 0.], [ 0.6, 2.4, 1.2]])
second = mx.nd.array([[ 0., 0., -1.2], [ 0.5, 0., 1.2], [ 0., 0., 0.]])
i = 0
for batch in iter(data_train):
expected = first.asnumpy() if i == 0 else second.asnumpy()
assert_almost_equal(data_train.getdata().asnumpy(), expected)
i += 1
'''
data_dir = os.path.join(cwd, 'data')
data_train = mx.io.LibSVMIter(data_libsvm=data_path, label_libsvm=label_path,
data_shape=(3, ), label_shape=(3, ), batch_size=3)

first = mx.nd.array([[ 0.5, 0., 1.2], [ 0., 0., 0.], [ 0.6, 2.4, 1.2]])
second = mx.nd.array([[ 0., 0., -1.2], [ 0.5, 0., 1.2], [ 0., 0., 0.]])
i = 0
for batch in iter(data_train):
expected = first.asnumpy() if i == 0 else second.asnumpy()
assert_almost_equal(data_train.getdata().asnumpy(), expected)
i += 1

def check_libSVMIter_news_metadata():
news_metadata = {
'name': 'news20.t',
'origin_name': 'news20.t.bz2',
'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/news20.t.bz2",
'shape': 62060,
'num_classes': 20,
}
data_dir = os.path.join(os.getcwd(), 'data')
get_data(data_dir, news_metadata['name'], news_metadata['url'],
news_metadata['origin_name'])
path = os.path.join(data_dir, news_metadata['name'])
data_train = mx.io.LibSVMIter(data_libsvm=path,
data_shape=(news_metadata['shape'], ),
batch_size=512)
iterator = iter(data_train)
for batch in iterator:
# check the range of labels
assert(np.sum(batch.label[0].asnumpy() > 20) == 0)
assert(np.sum(batch.label[0].asnumpy() <= 0) == 0)

check_libSVMIter_synthetic()
check_libSVMIter_news_metadata()

if __name__ == "__main__":
test_NDArrayIter()
test_MNISTIter()
test_Cifar10Rec()
# test_LibSVMIter()
test_LibSVMIter()
test_NDArrayIter_csr()