From 7385118ab896b1f9602db2211f50860c5e0170ce Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Sun, 14 Jun 2020 18:54:15 -0700 Subject: [PATCH] [topi] fix strategy for sparse dense cuda (#5782) --- python/tvm/relay/op/nn/_nn.py | 2 +- python/tvm/relay/op/strategy/cuda.py | 13 +++++++++++++ python/tvm/relay/op/strategy/generic.py | 22 ++++++++++++++++------ python/tvm/relay/op/strategy/x86.py | 15 ++++++++++----- topi/python/topi/cuda/sparse.py | 1 - topi/python/topi/x86/sparse.py | 2 -- 6 files changed, 40 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index c09b873d42e3..1c76f57a6343 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -69,7 +69,7 @@ def compute_sparse_dense(attrs, inputs, out_type): """Compute definition of sparse_dense""" return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])] -reg.register_schedule("nn.sparse_dense", strategy.schedule_sparse_dense) +reg.register_strategy("nn.sparse_dense", strategy.sparse_dense_strategy) reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 4b019cfcbccc..e0091a18de72 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -493,6 +493,19 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target): plevel=15) return strategy + +@sparse_dense_strategy.register(["cuda", "gpu"]) +def sparse_dense_strategy_cuda(attrs, inputs, out_type, target): + """sparse dense cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_sparse_dense(topi.cuda.sparse_dense), + wrap_topi_schedule(topi.cuda.schedule_sparse_dense), + name="sparse_dense.cuda", + plevel=10) + return strategy + + @argsort_strategy.register(["cuda", "gpu"]) def argsort_strategy_cuda(attrs, inputs, out_type, target): """argsort cuda strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 4fa2b11d554d..b1fb421c3e2e 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -599,12 +599,22 @@ def batch_matmul_strategy(attrs, inputs, out_type, target): name="batch_matmul.generic") return strategy -# sparse_dense -@generic_func -def schedule_sparse_dense(attrs, outs, target): - """schedule sparse_dense""" - with target: - return topi.generic.schedule_sparse_dense(outs) +# sparse dense +def wrap_compute_sparse_dense(topi_compute): + """wrap sparse dense topi compute""" + def _compute_sparse_dense(attrs, inputs, out_type): + return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3])] + return _compute_sparse_dense + +@override_native_generic_func("sparse_dense_strategy") +def sparse_dense_strategy(attrs, inputs, out_type, target): + """sparse dense generic strategy""" + logger.warning("sparse dense is not optimized for this platform.") + strategy = _op.OpStrategy() + strategy.add_implementation(wrap_compute_sparse_dense(topi.nn.sparse_dense), + wrap_topi_schedule(topi.generic.schedule_sparse_dense), + name="sparse_dense.generic") + return strategy # sparse_transpose @generic_func diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 0984e400c6c6..b02db416bdc8 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -294,11 +294,16 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): plevel=15) return strategy -@schedule_sparse_dense.register("cpu") -def schedule_sparse_dense_cpu(attrs, outs, target): - """schedule sparse_dense for x86""" - with target: - return topi.x86.schedule_sparse_dense(outs) +@sparse_dense_strategy.register("cpu") +def sparse_dense_strategy_cpu(attrs, inputs, out_type, target): + """sparse dense x86 strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation(wrap_compute_sparse_dense(topi.nn.sparse_dense), + wrap_topi_schedule(topi.x86.schedule_sparse_dense), + name="sparse_dense.x86", + plevel=10) + return strategy + @roi_align_strategy.register("cpu") def roi_align_strategy_cpu(attrs, inputs, out_type, target): diff --git a/topi/python/topi/cuda/sparse.py b/topi/python/topi/cuda/sparse.py index fb875b749750..5b57000f403a 100644 --- a/topi/python/topi/cuda/sparse.py +++ b/topi/python/topi/cuda/sparse.py @@ -63,7 +63,6 @@ def schedule_sparse_dense(cfg, outs): """Create schedule for sparse dense""" # pylint:disable=invalid-name s = te.create_schedule([x.op for x in outs]) - def _callback(op): if op.tag == "sparse_dense_bsrmm": y_bsrmm = op.input_tensors[0] diff --git a/topi/python/topi/x86/sparse.py b/topi/python/topi/x86/sparse.py index 54a5af9ca9f0..02cbd2d76ed3 100644 --- a/topi/python/topi/x86/sparse.py +++ b/topi/python/topi/x86/sparse.py @@ -21,11 +21,9 @@ from ..util import traverse_inline, get_const_int from .util import get_fp32_len - def schedule_sparse_dense(outs): """Create schedule for sparse dense""" s = te.create_schedule([x.op for x in outs]) - def _callback(op): simd_width = get_fp32_len() if op.tag == "sparse_dense_csrmm" and op != outs[0].op: