Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[WIP] Sparse Tensor (#5800)
Browse files Browse the repository at this point in the history
* squash

merge with 38f7c55

compiles on GPU

update check alloc:

Checkpoint. Pass elem-sum gpu test

bug fix for copyfromto. sparse sgd test pass on gpu

inefficient implementation for csr copy

update submodule

fix lint

Simple bind with infer storage type (#32)

* Symbol binding for sparse tensor development. (#31)

* Initial checkin

* Add init functions for simple bind in graph_executor

* Add simple_bind c_api

* Add simple bind c-api

* Assign zeros to in_args, arg_grads, and aux_states

* Add simple_bind2 python interface

* Fix python interface bugs

* Interface changes

* Fix

* Fix core dump

* Add bind_ith_exec c_api

* Change simple_bind2

* Fix seg fault

* Finish simple_bind

* Change _bind_ith_exec

* Refactor simple_bind initialization flow for bind

* Consolidate bind and simple_bind graph init flow

* Fix bug

* Clean up

* Add comments

* Clean up

* Clean up

* Minor correction

* Rename APIs in graph executor

* Refactor

* Rebase

* Delete deprecated functions

* Move more front-end work to backend

* Bug fix

* Fix failed tests

* Minor fix

* Fix lint

* Fix lint

* Revert unnecessary changes

* Revert

* Revert

* Clean up

* Fix lint

Conflicts:
	python/mxnet/symbol.py
	src/executor/graph_executor.cc

* Add inferstorage to graph executor

* re-enable tests for sparse embedding with simple_bind

* type switch fix in sparse embedding"
;

change `default` to `default_storage` for cast storage op (#33)

* change default to default_storage

* disable cpp test build temporarily

attempt to fix windows build error, and fix lint (#34)

update nnvm submodule (#37)

Scipy build (#38)

* update nnvm submodule

* add scipy pip install for dockerfile

Python3 unit tests (#39)

* change xrange to range for python3 compatiblity"

* remove more xrange from tests

replace long with int for python3 (#40)

fix the rest of TShape constructor errors (#41)

fix lint (#42)

fix wrong usage of mshadow::Shape1" (#43)

implementation for Csr slice on cpu (#36)

* CPU implementation for CSR

remove seg_len from csr slice

add some docs for slice csr

change indptr, values, etc to be private member

bug fix in sparse embedding

update nnvm submoduel

fix lint

update unit test for sparse nd"

* add const for SliceCsrIndPtr kernel

Fix sparse dot according to the new RSP definition (#35)

* Fix csr dot dns

* Fix sparse dot

* Add fallback and test cases for dot(csr, dns)=dns

* Add int type switch

* Fix

* Fix

* Fix

update mshadow submodule (#44)

Fix dns to rsp (#46)

fix lint (#47)

add runtime storage fallback detection" (#48)

* add runtime storage fallback detection"

* replace cast storage ex with cast storage impl

Fm example (#45)

* update csr slice logic to avoid confusion. add more exmaples.

* add hint to module.update

* more testcases(fallback) for sparse_nd

* add to_csr() and to_rsp() method. More unit test (fallback now)

* add fm test. fix lint

* register sparse sgd under Optim.SGD

* update dmlc-core submoduel

* change indptr to _indptr temporarily. add const ref to fname

fix lint

fix lint; (#51)

Guard gpu cast storage (#50)

* Clean up

* Fix typo

Rearrange unit test files (#52)

fix lint. add scipy for python_test. fix scipy.sparse import error. fix truediv for python3

fix travis test (#54)

* remove pyc files

* add verbose for travis nosetests

cleanup some testing code and enums (#57)

* update Makefile

* refactor test_sparse_operator

* change `default_storage` back to `default`

* remove unused cpp tests

port libsvm parser to mxnet as libsvm iter (#55)

* copied csv iter to libsvm iter

test

libsvm iter draft

handle round batch == false for csr batch loader

code refactoring

add get stype, shape interface to iiter

separate class for sparse iter

add missing file

fix mem corruption'

rename variables

add comments

also read label from libsvm

add test. update docs. update submodule

Conflicts:
	python/mxnet/sparse_ndarray.py

* update submodule

* fix lint

* update test

* revert naming change

add benchmark scritp for dot (#59)

* add benchmark scritp for dot

add gpu option for bench

add get_data funciton for benchmark

print t_sparse, too;

add comment

change nnz to dnesity

add backward

* add comment

update fm test (#62)

introduce CSRNDarray and rowsparseNDarray to python frontend api (#58)

* introduce CSRNDarray and rowsparseNDarray to python frontend api

* temporarily disable fm_module test

fix lint (#64)

fix typo. disable libsvm io test (#65)

Improve dot (#61)

* Init checkin

* Fix

* Adjust dot parallelization methods

* Set num_omp_threads for benchmark from command line

* Fix omp thread number

* Clean up

* Add scipy as dot baseline

* Fix format

sparse_retain op (#66)

* Initial checkin

* Fix bugs

* Add unit test for sparse_retain

* Add example and modify test

add storage cast for outputs that have non-default storage (#67)

fix gpu build (#69)

Fix test_sparse_retain python3 issue (#68)

revert nnvm version

* draft for sgd rsp rsp (#75)

support sgd(rsp, rsp)

support dot(csr, rsp) when rsp is full

add ref to const ndarray params

support sparse embedding with rsp weight'

fix lint

modify embedding backward to produce dense grad

remove invalid_rid for rsp->dns

remove previous embedding op changes

pass sparse embedding test

add STORAGE_TYPE_ASSIGN_CHECK

remove backward storage infer

* fix lint (#78)

* fix lint (#79)

* serial elemwise sum impl (#80)

update module kvstore interface

add other missing params and functions

revert some interface changes

revert some more changes

reomve explicit casting for gradients on kvstore

update Comm interface

update fm example

Conflicts:
	python/mxnet/model.py
	python/mxnet/ndarray.py

* bug fix for initializing module with row_sparse weight (#81)

* bug fix for initializing module with row_sparse weight

* update log message

* Sparse ndarray serialization and deserialization (#77)

* Initial checkin

* Add unit tests

* Fix lint

* Fix lint (#84)

* Sgd with row_sparse weight, dns gradient (#83)

* sgd rsp dns draft

* support sgd_mom(rsp, dns, rsp)

* update doc

* remove cast storage for kv updater

* code refactoring

* update mshadow version (#88)

* csr slice bug fix (#90)

* benchmark dot code refactor (#87)

* q^x6x add some code in benchmark

* refactor

* minor fixes

* fix

* lint fix

* Add unit test (#91)

* add unittest

* minor fix

* remove commented lines

* change test func name

* add test rsp

* kvstore push row sparse (#93)

* Add multi-thread cpu elemwise sum for rsps

* Minor fix

* Add flag to switch between serial and multi-thread kvstore push

* Fix lint in sparse_ndarray.py

* Revert "Fix lint in sparse_ndarray.py"

This reverts commit d7225ec.

* Fix ndarray init in copy(ctx)

* Add env var to control the flow of serial/parallel reduce

* Refactor

* Fix copy ndarray bug

* Fix lint

* Refactor

* Fix windows openmp build failure (#94)

* update mshadow submoduel (#95)

* Revert "update mshadow submoduel (#95)" (#96)

This reverts commit 1a129e4.

* Refactor sparse tensor code (#99)

* Initial checkin test_sparse_ndarray passes

* Fix test failure

* Clean up

* Clean up

* Move init backend op to ndarray_utils

* Fix lint

* Eliminate circular dependency on headers

* More refactor

* Fix gpu build and consolidate Slice for dense and sparse

* Clean up

* More refactor

* Clean up

* Fix gpu build

* Fix comment

* fix pylint (#100)

* Fix refactor sparse gpu test (#104)

* Fix gpu build

* Fix

* Fix gpu test failure

* change idx types from int32 to int64 (#101)

Conflicts:
	python/mxnet/test_utils.py
	tests/python/unittest/test_sparse_operator.py

update mshadow submodule

fix extra quotes in test script

change indptr type to int64

better err message for rsp"

* revert LOG(DEBUG) change (#105)

* fix undefined zeros in optimizer.py (#106)

* move init dns zeros to init_op.h for kvstore to use (#107)

* Refactor cast storage (#109)

* Refactor cast_storage

* Add cast_storage cc and cu files

* Remove redundant comments

* Replace std::accumulate with ParallelAccumulate

* Clean up

* Fix windows build

* Rowsparse kv (#111)

* update kvstore unit test

Conflicts:
	tests/python/unittest/test_kvstore.py

update model/module.py

Conflicts:
	python/mxnet/model.py
	python/mxnet/module/module.py

fix lint

resolve conflict

remove int keys in kvstore

update cast to str function

* fix failed dist_sync_kv test

* bug fix in comm to ensure merged gradient is of the right type

bug fix in comm

* row sparse dist kvstore draft (push only)

row_sparse pull

* add ndarray row sparse shared mem constructor

* code refactoring

* add test for row_sparse weight

bug fix for kv server slicing

add async support

rsolve race condition in kvstore

* resolve error after reb ase

* fix lint (#113)

* rename some python funciton (#114)

* _to_rsp

* _to_csr. raise NotImplementedError

* todense

* fix lint (#115)
  • Loading branch information
eric-haibin-lin authored and piiswrong committed Jun 26, 2017
1 parent d75ef8e commit 7b2ef68
Show file tree
Hide file tree
Showing 95 changed files with 7,928 additions and 735 deletions.
6 changes: 3 additions & 3 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -215,17 +215,17 @@ del /Q *.7z
// Python unittest for CPU
def python_ut(docker_type) {
timeout(time: max_time, unit: 'MINUTES') {
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/unittest"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/unittest"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/unittest"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/train"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/train"
}
}

// GPU test has two parts. 1) run unittest on GPU, 2) compare the results on
// both CPU and GPU
def python_gpu_ut(docker_type) {
timeout(time: max_time, unit: 'MINUTES') {
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/gpu"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/gpu"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/gpu"
}
}
Expand Down
228 changes: 228 additions & 0 deletions benchmark/python/sparse_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
import ctypes

from mxnet.test_utils import *
import scipy.sparse as sp
import os
import time
import argparse

from mxnet.base import check_call, _LIB
from util import get_data, estimate_density

parser = argparse.ArgumentParser(description="Benchmark sparse operators",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--num-omp-threads', type=int, default=1, help='number of omp threads to set in MXNet')
args = parser.parse_args()

# some data information
kdda = {
'data_mini': 'kdda.t.mini',
'data_name': 'kdda.t',
'data_origin_name': 'kdda.t.bz2',
'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2",
'feature_dim': 20216830,
'm': 200,
'batch_size': [64]
}

avazu = {
'data_mini': 'avazu-app.t.mini',
'data_name': 'avazu-app.t',
'data_origin_name': 'avazu-app.t.bz2',
'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2",
'feature_dim': 1000000,
'm': 500,
'batch_size': [64, 128]
}


def measure_cost(repeat, f, *args, **kwargs):
# start bench
start = time.time()
results = []
for i in range(repeat):
results.append(f(*args, **kwargs))
for result in results:
result.wait_to_read()
end = time.time()
diff = end - start
return diff / repeat


def test_dot_real(data_dict):
def get_iter(path, data_shape, batch_size):
data_train = mx.io.LibSVMIter(data_libsvm=path,
data_shape=data_shape,
batch_size=batch_size)
data_iter = iter(data_train)
return data_iter

data_dir = os.path.join(os.getcwd(), 'data')

path = os.path.join(data_dir, data_dict['data_name'])
if not os.path.exists(path):
get_data(
data_dir,
data_dict['data_name'],
data_dict['url'],
data_dict['data_origin_name']
)
assert os.path.exists(path)

k = data_dict['feature_dim']
m = data_dict['m']
density = estimate_density(path, data_dict['feature_dim'])

mini_path = os.path.join(data_dir, data_dict['data_mini'])
if not os.path.exists(mini_path):
os.system("head -n 2000 %r > %r" % (path, mini_path))
assert os.path.exists(mini_path)

print "Running Benchmarking on %r data" % data_dict['data_mini']
for batch_size in data_dict['batch_size']: # iterator through different batch size of choice
print "batch_size is %d" % batch_size
# model
data_shape = (k, )
train_iter = get_iter(mini_path, data_shape, batch_size)
weight = mx.nd.random_uniform(low=0, high=1, shape=(k, m))

csr_data = []
dns_data = []
num_batch = 0
for batch in train_iter:
data = train_iter.getdata()
csr_data.append(data)
dns_data.append(data.todense())
num_batch += 1
bag_of_data = [csr_data, dns_data]
num_repeat = 5
costs = []
for d in bag_of_data:
weight.wait_to_read()
cost = 0.
count = 0
for d_batch in d:
d_batch.wait_to_read()
cost += measure_cost(num_repeat, mx.nd.dot, d_batch, weight)
count += 1
costs.append(cost/count)
t_sparse = costs[0]
t_dense = costs[1]
ratio = t_dense / t_sparse
print('density(%)\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse')
fmt = "%0.4f\t\t%d\t%d\t%d\t%0.2f\t\t\t%0.4f\t%0.6f"
print(fmt % (density * 100, batch_size, m, k, ratio, t_dense, t_sparse))


def test_dot_synthetic():
"""benchmark mx.nd.dot(sparse_ndarray, dense_ndarray) with given density.
`t_sparse` is the time cost of dot(csr, dns), while `t_dense` is the time cost
of dot(dns, dns), with the same matrix except that it is in default storage type.
"""
def measure_cost_forward_baseline(repeat, dot, lhs, rhs):
start = time.time()
for i in range(repeat):
dot(lhs, rhs)
end = time.time()
diff = end - start
return diff / repeat

def measure_cost_backward_baseline(repeat, dot, transpose, lhs, rhs):
start = time.time()
for i in range(repeat):
dot(transpose(lhs), rhs)
end = time.time()
diff = end - start
return diff / repeat

def bench_dot_forward(m, k, n, density, ctx, repeat):
set_default_context(ctx)
dns = mx.nd.random_uniform(shape=(k, n)).copyto(ctx)
data_shape = (m, k)
csr_data = rand_ndarray(data_shape, 'csr', density)
dns_data = csr_data.todense()
rhs_dns_np = dns.asnumpy()
lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy()) # csr in scipy
lhs_dns_np = lhs_csr_sp.todense()

data = [dns_data, csr_data]
costs = []
for d in data:
dns.wait_to_read()
d.wait_to_read()
cost = measure_cost(repeat, mx.nd.dot, d, dns)
costs.append(cost)
ratio = costs[0] / costs[1]

costs_baseline = []
cost = measure_cost_forward_baseline(repeat, np.dot, lhs_dns_np, rhs_dns_np)
costs_baseline.append(cost)
cost = measure_cost_forward_baseline(repeat, sp.spmatrix.dot, lhs_csr_sp, rhs_dns_np)
costs_baseline.append(cost)
ratio_baseline = costs_baseline[0] / costs_baseline[1]
fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.2f\t\t\t%0.2f\t%0.5f\t\t%0.2f\t\t\t\t%0.6f\t%0.5f"
print(fmt % (density * 100, str(ctx), n, m, k, ratio, costs[0], costs[1],
ratio_baseline, costs_baseline[0], costs_baseline[1]))

def bench_dot_backward(m, k, n, density, ctx, repeat):
set_default_context(ctx)
dns = mx.nd.random_uniform(shape=(m, n)).copyto(ctx)
data_shape = (m, k)
csr_data = rand_ndarray(data_shape, 'csr', density)
dns_data = csr_data.todense()
rhs_dns_np = dns.asnumpy()
lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy())
lhs_dns_np = lhs_csr_sp.todense()

data = [dns_data, csr_data]
costs = []
for d in data:
dns.wait_to_read()
d.wait_to_read()
cost = measure_cost(repeat, mx.nd.dot, d, dns, transpose_a=True)
costs.append(cost)
ratio = costs[0] / costs[1]

costs_baseline = []
cost = measure_cost_backward_baseline(repeat, np.dot, np.transpose, lhs_dns_np, rhs_dns_np)
costs_baseline.append(cost)
cost = measure_cost_backward_baseline(repeat, sp.spmatrix.dot, sp.spmatrix.transpose, lhs_csr_sp, rhs_dns_np)
costs_baseline.append(cost)
ratio_baseline = costs_baseline[0] / costs_baseline[1]
fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.2f\t\t\t%0.2f\t%0.5f\t\t%0.2f\t\t\t\t%0.6f\t%0.5f"
print(fmt % (density * 100, str(ctx), n, m, k, ratio, costs[0], costs[1],
ratio_baseline, costs_baseline[0], costs_baseline[1]))

print("A = sparse NDArray of shape(m, k)")
print("B = dense NDArray of shape(k, n)")
print("dot_forward\tdot(csr, dns)")
print('density(%)\tcontext\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse'
'\tt_scipy_dense/t_scipy_sparse\tt_scipy_dense\tt_scipy_sparse')

check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(args.num_omp_threads)))
# TODO(haibin) make these runtime options
m = 512
k = [50000, 100000]
n = [64, 128]
density = [1.00, 0.90, 0.70, 0.50, 0.30, 0.20, 0.10, 0.07, 0.05, 0.02, 0.01, 0.005, 0.001]
num_repeat = 10
# contexts = [mx.cpu(), mx.gpu(0)]
contexts = [mx.cpu()]
for i in range(2):
for ctx in contexts:
for den in density:
bench_dot_forward(m, k[i], n[i], den, ctx, num_repeat)

print("dot_backward\tdot(csr.T, dns)")
print('density(%)\tcontext\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse'
'\tt_scipy_dense/t_scipy_sparse\tt_scipy_dense\tt_scipy_sparse')
for i in range(2):
for ctx in contexts:
for den in density:
bench_dot_backward(m, k[i], n[i], den, ctx, num_repeat)


if __name__ == "__main__":
test_dot_real(avazu)
test_dot_real(kdda)
test_dot_synthetic()
33 changes: 33 additions & 0 deletions benchmark/python/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os
import random


def get_data(data_dir, data_name, url, data_origin_name):
if not os.path.isdir(data_dir):
os.system("mkdir " + data_dir)
os.chdir(data_dir)
if (not os.path.exists(data_name)):
import urllib
zippath = os.path.join(data_dir, data_origin_name)
urllib.urlretrieve(url, zippath)
os.system("bzip2 -d %r" % data_origin_name)
os.chdir("..")


def estimate_density(DATA_PATH, feature_size):
"""sample 10 times of a size of 1000 for estimating the density of the sparse dataset"""
if not os.path.exists(DATA_PATH):
raise Exception("Data is not there!")
density = []
P = 0.01
for _ in xrange(10):
num_non_zero = 0
num_sample = 0
with open(DATA_PATH) as f:
for line in f:
if (random.random() < P):
num_non_zero += len(line.split(" ")) - 1
num_sample += 1
density.append(num_non_zero * 1.0 / (feature_size * num_sample))
return sum(density) / len(density)

Loading

0 comments on commit 7b2ef68

Please sign in to comment.