diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index cb801ba72872..ffe4b97e33db 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -30,6 +30,7 @@ from random import getrandbits from collections import namedtuple import tempfile +import numpy as np import tvm._ffi import tvm.ir.transform @@ -560,9 +561,11 @@ def run_through_rpc( raise AttributeError( "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices" ) - args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info] - for arg in args: - random_fill(arg) + args = [nd.array(np.zeros(x[0], dtype=x[1]), ctx=ctx) for x in build_result.arg_info] + if "scatter" not in measure_input.task.name: + # the index tensor of scatter op cannot be randomly initialized + for arg in args: + random_fill(arg) ctx.sync() costs = time_f(*args).results diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 04c16ddd344c..3863df0fd831 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -783,10 +783,23 @@ def scatter_cuda(attrs, inputs, out_type, target): strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_scatter(topi.cuda.scatter), - wrap_topi_schedule(topi.generic.schedule_extern), + wrap_topi_schedule(topi.cuda.schedule_scatter), name="scatter.cuda", plevel=10, ) + + rank = len(inputs[0].shape) + + with SpecializedCondition(rank == 1): + if target.kind.name == "cuda" and get_global_func( + "tvm.contrib.thrust.stable_sort_by_key", allow_missing=True + ): + strategy.add_implementation( + wrap_compute_scatter(topi.cuda.scatter_via_sort), + wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort), + name="scatter_via_sort.cuda", + plevel=9, # use the sequential version by default + ) return strategy diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 363832ef8b2f..8dd9dc5844dd 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1123,7 +1123,7 @@ def wrap_compute_scatter(topi_compute): """Wrap scatter topi compute""" def _compute_scatter(attrs, inputs, _): - return [topi_compute(inputs[0], inputs[1], inputs[2], axis=attrs.axis)] + return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis)] return _compute_scatter diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index be602c8ab7a3..b34bd1df14e4 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -17,16 +17,33 @@ # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument """Scatter operator """ import tvm -from tvm import te +from tvm import te, autotvm from ..scatter import _verify_scatter_nd_inputs +from ..generic import schedule_extern from .nms import atomic_add from .sort import stable_sort_by_key_thrust, is_thrust_available +from ..utils import prod def ceil_div(a, b): return (a + b - 1) // b +def _memcpy_ir(ib, out_ptr, data_ptr, shape): + fused = prod(shape) + with ib.new_scope(): + num_thread = int(tvm.target.Target.current(allow_none=False).max_num_threads) + num_blocks = ceil_div(fused, num_thread) + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", num_blocks) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", num_thread) + tid = bx * num_thread + tx + + with ib.if_scope(tid < fused): + out_ptr[tid] = data_ptr[tid] + + def gen_ir_1d(data, indices, updates, axis, out, update_func): """Generate scatter ir for 1d inputs @@ -63,10 +80,7 @@ def gen_ir_1d(data, indices, updates, axis, out, update_func): out_ptr = ib.buffer_ptr(out) data_ptr = ib.buffer_ptr(data) - with ib.new_scope(): - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", n) - out_ptr[bx] = data_ptr[bx] + _memcpy_ir(ib, out_ptr, data_ptr, data.shape) indices_ptr = ib.buffer_ptr(indices) updates_ptr = ib.buffer_ptr(updates) @@ -114,8 +128,6 @@ def gen_ir_2d(data, indices, updates, axis, out, update_func): ret : tir The computational ir. """ - warp_size = tvm.target.Target.current(False).thread_warp_size - n = data.shape[0] c = data.shape[1] @@ -124,16 +136,7 @@ def gen_ir_2d(data, indices, updates, axis, out, update_func): out_ptr = ib.buffer_ptr(out) data_ptr = ib.buffer_ptr(data) - with ib.new_scope(): - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", n) - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", warp_size) - with ib.for_range(0, ceil_div(c, warp_size), name="j") as j_: - j = j_ * warp_size + tx - with ib.if_scope(j < c): - idx = bx * c + j - out_ptr[idx] = data_ptr[idx] + _memcpy_ir(ib, out_ptr, data_ptr, data.shape) indices_ptr = ib.buffer_ptr(indices) updates_ptr = ib.buffer_ptr(updates) @@ -205,18 +208,7 @@ def gen_ir_3d(data, indices, updates, axis, out, update_func): out_ptr = ib.buffer_ptr(out) data_ptr = ib.buffer_ptr(data) - with ib.new_scope(): - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", n) - by = te.thread_axis("blockIdx.y") - ib.scope_attr(by, "thread_extent", c) - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", warp_size) - with ib.for_range(0, ceil_div(h, warp_size), name="k") as k_: - k = k_ * warp_size + tx - with ib.if_scope(k < h): - idx = (bx * c + by) * h + k - out_ptr[idx] = data_ptr[idx] + _memcpy_ir(ib, out_ptr, data_ptr, data.shape) indices_ptr = ib.buffer_ptr(indices) updates_ptr = ib.buffer_ptr(updates) @@ -311,20 +303,7 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func): out_ptr = ib.buffer_ptr(out) data_ptr = ib.buffer_ptr(data) - with ib.new_scope(): - i = te.thread_axis("blockIdx.x") - ib.scope_attr(i, "thread_extent", n) - j = te.thread_axis("blockIdx.y") - ib.scope_attr(j, "thread_extent", c) - k = te.thread_axis("blockIdx.z") - ib.scope_attr(k, "thread_extent", h) - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", warp_size) - with ib.for_range(0, ceil_div(w, warp_size), name="l") as l_: - l = l_ * warp_size + tx - with ib.if_scope(l < w): - idx = ((i * c + j) * h + k) * w + l - out_ptr[idx] = data_ptr[idx] + _memcpy_ir(ib, out_ptr, data_ptr, data.shape) indices_ptr = ib.buffer_ptr(indices) updates_ptr = ib.buffer_ptr(updates) @@ -417,7 +396,71 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func): return ib.get() -def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _): +@autotvm.register_topi_compute("scatter.cuda") +def scatter(cfg, data, indices, updates, axis=0): + """Update data at positions defined by indices with values in updates + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + indices : relay.Expr + The index locations to update. + + updates : relay.Expr + The values to update. + + axis : int + The axis to scatter on + + Returns + ------- + ret : relay.Expr + The computed result. + """ + if axis < 0: + axis += len(data.shape) + assert axis >= 0 + assert axis < len(data.shape) + + rank = len(data.shape) + assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions" + + ir_funcs = { + 1: gen_ir_1d, + 2: gen_ir_2d, + 3: gen_ir_3d, + 4: gen_ir_4d, + } + + def update_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] = update + + out_shape = data.shape + out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf") + + cfg.add_flop(1) # A dummy value to satisfy AutoTVM + + out = te.extern( + [out_shape], + [data, indices, updates], + lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func), + dtype=data.dtype, + out_buffers=[out_buf], + name="scatter_gpu", + tag="scatter_gpu", + ) + + return out + + +@autotvm.register_topi_schedule("scatter.cuda") +def schedule_scatter(_, outs): + return schedule_extern(outs) + + +def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, out): """Generate scatter ir for 1d inputs, using a sorting based approach. By sorting indices and comparing neighboring two indices, we can tell which of elements in the indices tensor can scatter its update value into the output. @@ -438,9 +481,6 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _): updates : tir.Tensor The values to update, sorted by indices. - axis : int - The axis to scatter on. It must be 0 for this function. - out : tir.Tensor The output tensor. @@ -449,7 +489,6 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _): ret : tir The computational ir. """ - assert axis == 0 n = data.shape[0] ib = tvm.tir.ir_builder.create() @@ -504,7 +543,8 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _): return ib.get() -def scatter(data, indices, updates, axis=0): +@autotvm.register_topi_compute("scatter_via_sort.cuda") +def scatter_via_sort(cfg, data, indices, updates, axis=0): """Update data at positions defined by indices with values in updates Parameters @@ -528,49 +568,34 @@ def scatter(data, indices, updates, axis=0): """ if axis < 0: axis += len(data.shape) - assert axis >= 0 - assert axis < len(data.shape) + assert axis == 0 and len(data.shape) == 1, "sorting based scatter only supported for 1d input" + assert is_thrust_available(), "Thrust is required for this op" - rank = len(data.shape) - assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions" - - ir_funcs = { - 1: gen_ir_1d, - 2: gen_ir_2d, - 3: gen_ir_3d, - 4: gen_ir_4d, - } - - def update_func(dst_ptr, dst_index, update): - dst_ptr[dst_index] = update + cfg.add_flop(1) # A dummy value to satisfy AutoTVM out_shape = data.shape out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf") - in_bufs = [data] - - if rank == 1 and is_thrust_available(): - ir_funcs[1] = gen_scatter_1d_thrust - indices_sorted, updates_sorted = stable_sort_by_key_thrust( - indices, updates, for_scatter=True - ) - in_bufs += [indices_sorted, updates_sorted] - else: - in_bufs += [indices, updates] + indices_sorted, updates_sorted = stable_sort_by_key_thrust(indices, updates, for_scatter=True) out = te.extern( [out_shape], - in_bufs, - lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func), + [data, indices_sorted, updates_sorted], + lambda ins, outs: gen_scatter_1d_thrust(ins[0], ins[1], ins[2], outs[0]), dtype=data.dtype, out_buffers=[out_buf], - name="scatter_gpu", - tag="scatter_gpu", + name="scatter_via_sort_gpu", + tag="scatter_via_sort_gpu", ) return out +@autotvm.register_topi_schedule("scatter_via_sort.cuda") +def schedule_scatter_via_sort(_, outs): + return schedule_extern(outs) + + def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _): """Generate scatter add ir for 1d inputs, using atomic_add instruction