From 2628c11a60a33f57b166ec17a45cf42913cd3bdd Mon Sep 17 00:00:00 2001 From: Adam Scott Date: Sun, 3 May 2026 02:52:18 -0400 Subject: [PATCH] [BugFix][Relax] Fix scatter_elements and scatter_nd CUDA compilation `topi.scatter_elements` and `topi.scatter_nd` emit bare `T.parallel` loops in their `te.extern` IRBuilder bodies, which trip `VerifyMemory` on CUDA targets: RuntimeError: Memory verification failed ... Did you forget to bind? CPU (LLVM) is unaffected. This fix adds dedicated GPU implementations, `topi.gpu.scatter_elements` and `topi.gpu.scatter_nd`, whose IRBuilder bodies bind the loops to `blockIdx.x`/`threadIdx.x` instead of using `T.parallel`. The Relax legalization of `relax.scatter_elements` and `relax.scatter_nd` selects them when `Target.current()` is a GPU target and keeps the existing `topi` versions otherwise. Fixes #19451. --- .../transform/legalize_ops/manipulate.py | 12 +- python/tvm/topi/gpu/__init__.py | 2 + python/tvm/topi/gpu/scatter_elements.py | 162 ++++++++++++++++++ python/tvm/topi/gpu/scatter_nd.py | 129 ++++++++++++++ .../test_transform_legalize_ops_manipulate.py | 46 +++++ 5 files changed, 349 insertions(+), 2 deletions(-) create mode 100644 python/tvm/topi/gpu/scatter_elements.py create mode 100644 python/tvm/topi/gpu/scatter_nd.py diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 2a1d249ef737..fc7ee0d12eb8 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -235,10 +235,16 @@ def _meshgrid(bb: BlockBuilder, call: Call) -> Expr: ) +def _is_gpu_target(): + target = tvm.target.Target.current(allow_none=True) + return target is not None and "gpu" in target.keys + + @register_legalize("relax.scatter_elements") def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr: + te_func = topi.gpu.scatter_elements if _is_gpu_target() else topi.scatter_elements return bb.call_te( - topi.scatter_elements, + te_func, call.args[0], call.args[1], call.args[2], @@ -250,10 +256,12 @@ def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr: 
@register_legalize("relax.scatter_nd") def _scatter_nd(bb: BlockBuilder, call: Call) -> Expr: # TODO(relax-team): Support native scatter_nd without te extern + base_te = topi.gpu.scatter_nd if _is_gpu_target() else topi.scatter_nd + def scatter_nd(data, indices, updates, reduction): axes = list(range(len(indices.shape))) indices = topi.transpose(indices, axes[-1:] + axes[:-1]) - return topi.scatter_nd(data, indices, updates, reduction) + return base_te(data, indices, updates, reduction) return bb.call_te( scatter_nd, diff --git a/python/tvm/topi/gpu/__init__.py b/python/tvm/topi/gpu/__init__.py index e56a1d712390..69998957f39f 100644 --- a/python/tvm/topi/gpu/__init__.py +++ b/python/tvm/topi/gpu/__init__.py @@ -20,4 +20,6 @@ """GPU specific declaration.""" from .scan import cumsum, cumprod +from .scatter_elements import scatter_elements +from .scatter_nd import scatter_nd from .sort import * diff --git a/python/tvm/topi/gpu/scatter_elements.py b/python/tvm/topi/gpu/scatter_elements.py new file mode 100644 index 000000000000..a7d94218628c --- /dev/null +++ b/python/tvm/topi/gpu/scatter_elements.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name +"""scatter_elements related operators""" + +import tvm +from tvm import te, tirx +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import tirx as T + +from .. import utils +from ..math import cast +from ..utils import ceil_div + + +def scatter_elements(data, indices, updates, axis=0, reduction="update"): + """GPU implementation of scatter_elements with explicit thread bindings""" + if not isinstance(axis, int): + axis = utils.get_const_int(axis) + + # Prepare ranges and strides + shape = data.shape + if axis < 0: + axis = len(shape) + axis + axis_range = cast(shape[axis], indices.dtype) + + full_range = 1 + after_axis_range = 1 + for i, value in enumerate(shape, 0): + full_range *= value + if i > axis: + after_axis_range *= value + before_axis_stride = axis_range * after_axis_range + + ind_shape = indices.shape + ind_axis_range = ind_shape[axis] + + ind_before_axis_range = 1 + ind_after_axis_range = 1 + for i, value in enumerate(ind_shape, 0): + if i < axis: + ind_before_axis_range *= value + elif i > axis: + ind_after_axis_range *= value + ind_before_axis_stride = ind_axis_range * ind_after_axis_range + ind_full_range_excl_axis = ind_before_axis_range * ind_after_axis_range + + def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr, reduce_func): + # pylint: disable=invalid-name + data = T.buffer_proxy(data_ptr) + indices = T.buffer_proxy(indices_ptr) + updates = T.buffer_proxy(updates_ptr) + out = T.buffer_proxy(out_ptr) + + max_threads = int(tvm.target.Target.current(allow_none=False).attrs["max_num_threads"]) + + with IRBuilder() as ib: + with T.seq_scope(): + # Init + nthread_bx_init = cast(ceil_div(full_range, max_threads), "int32") + tx_init = te.thread_axis("threadIdx.x") + bx_init = te.thread_axis("blockIdx.x") + with T.frame_scope( + [ + T.attr(bx_init, "thread_extent", nthread_bx_init), + T.attr(tx_init, "thread_extent", max_threads), + ] + ): + tid = bx_init * max_threads + tx_init + with T.If(tid 
< full_range): + with T.Then(): + out[tid] = data[tid] + + # Scatter + nthread_bx_scat = cast(ceil_div(ind_full_range_excl_axis, max_threads), "int32") + tx_scat = te.thread_axis("threadIdx.x") + bx_scat = te.thread_axis("blockIdx.x") + with T.frame_scope( + [ + T.attr(bx_scat, "thread_extent", nthread_bx_scat), + T.attr(tx_scat, "thread_extent", max_threads), + ] + ): + fused = bx_scat * max_threads + tx_scat + with T.If(fused < ind_full_range_excl_axis): + with T.Then(): + i = fused // ind_after_axis_range + j = fused % ind_after_axis_range + pre_index1 = i * ind_before_axis_stride + j + pre_index2 = i * before_axis_stride + j + with T.serial(0, ind_axis_range) as k: + # Offset along indices or updates + index1 = pre_index1 + k * ind_after_axis_range + # Get index and shift to positive side if need + k_new = indices[index1] + shifted_index = k_new + (k_new < 0) * axis_range + # Offset along data + index2 = pre_index2 + shifted_index * after_axis_range + reduce_func(out, index2, updates[index1]) + + return ib.get() + + def update_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] = update + + def add_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] += update + + def mul_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] *= update + + def mean_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] = (dst_ptr[dst_index] + update) / 2 + + def min_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] = tirx.min(dst_ptr[dst_index], update) + + def max_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] = tirx.max(dst_ptr[dst_index], update) + + reduce_func = None + if reduction == "update": + reduce_func = update_func + elif reduction == "add": + reduce_func = add_func + elif reduction == "mul": + reduce_func = mul_func + elif reduction == "mean": + reduce_func = mean_func + elif reduction == "min": + reduce_func = min_func + elif reduction == "max": + reduce_func = max_func + else: + raise NotImplementedError( + "scatter_elements reduction not in 
[update, add, mul, mean, min, max]:", reduction + ) + + out_buf = tirx.decl_buffer(data.shape, data.dtype, "out_buf") + return te.extern( + [data.shape], + [data, indices, updates], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], reduce_func), + dtype=data.dtype, + out_buffers=[out_buf], + name="scatter_elements.gpu", + tag="scatter_elements.gpu", + ) diff --git a/python/tvm/topi/gpu/scatter_nd.py b/python/tvm/topi/gpu/scatter_nd.py new file mode 100644 index 000000000000..a29cd68a8e37 --- /dev/null +++ b/python/tvm/topi/gpu/scatter_nd.py @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name +# ruff: noqa: E741 +"""scatter_nd related operators""" + +import tvm +from tvm import te, tirx # hide redefinition of min and max +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import tirx as T + +from ..math import cast +from ..scatter import _verify_scatter_nd_inputs +from ..utils import ceil_div + + +def scatter_nd(data, indices, updates, mode): + """GPU implementation of scatter_nd with explicit thread bindings.""" + _verify_scatter_nd_inputs(data, indices, updates) + + def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): + # pylint: disable=invalid-name + data = T.buffer_proxy(data_ptr) + indices = T.buffer_proxy(indices_ptr) + updates = T.buffer_proxy(updates_ptr) + out = T.buffer_proxy(out_ptr) + + # We combine all the indices dimensions but the first one into a single + # dimension so we can iterate it in single loop instead of an arbitrary + # number of loops. We do the same thing for all the update dimensions. + fused_indices_dimension = 1 + for i in indices_ptr.shape[1:]: + fused_indices_dimension *= i + + fused_updates_dimension = 1 + for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]: + fused_updates_dimension *= i + + fused_shape = 1 + for i in data_ptr.shape: + fused_shape *= i + + max_threads = int(tvm.target.Target.current(allow_none=False).attrs["max_num_threads"]) + + with IRBuilder() as ib: + with T.seq_scope(): + # Init + nthread_bx_init = cast(ceil_div(fused_shape, max_threads), "int32") + tx_init = te.thread_axis("threadIdx.x") + bx_init = te.thread_axis("blockIdx.x") + with T.frame_scope( + [ + T.attr(bx_init, "thread_extent", nthread_bx_init), + T.attr(tx_init, "thread_extent", max_threads), + ] + ): + tid = bx_init * max_threads + tx_init + with T.If(tid < fused_shape): + with T.Then(): + out[tid] = data[tid] + + # Scatter + nthread_bx_scat = cast(ceil_div(fused_updates_dimension, max_threads), "int32") + tx_scat = te.thread_axis("threadIdx.x") + bx_scat = 
te.thread_axis("blockIdx.x") + with T.frame_scope( + [ + T.attr(bx_scat, "thread_extent", nthread_bx_scat), + T.attr(tx_scat, "thread_extent", max_threads), + ] + ): + j = bx_scat * max_threads + tx_scat + with T.If(j < fused_updates_dimension): + with T.Then(): + with T.serial(0, fused_indices_dimension) as i: + offset = fused_updates_dimension + index = j # x_M, .. x_{N-1} part of the index into out. + # Build up the indices[0, y_0, ..], .., + # indices[M-1, y_0, ..] part of the index into out. + for l in reversed(range(indices_ptr.shape[0].value)): + # indices[l, y_0, ... y_{k-1}] + index += offset * indices[i + l * fused_indices_dimension] + offset *= data_ptr.shape[l] + if mode == "update": + out[index] = updates[i * fused_updates_dimension + j] + elif mode == "add": + out[index] += updates[i * fused_updates_dimension + j] + elif mode == "mul": + out[index] *= updates[i * fused_updates_dimension + j] + elif mode == "min": + out[index] = tirx.min( + out[index], updates[i * fused_updates_dimension + j] + ) + elif mode == "max": + out[index] = tirx.max( + out[index], updates[i * fused_updates_dimension + j] + ) + else: + raise NotImplementedError( + "scatter_nd mode not in [update, add, mul, min, max]:", + mode, + ) + + return ib.get() + + out_buf = tirx.decl_buffer(data.shape, data.dtype, "out_buf") + return te.extern( + [data.shape], + [data, indices, updates], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), + dtype=data.dtype, + out_buffers=[out_buf], + name="scatter_nd.gpu", + tag="scatter_nd.gpu", + ) diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index 05b6c50c923b..a8f1e906f50b 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -1551,6 +1551,29 @@ def main( tvm.ir.assert_structural_equal(mod, Expected) +@tvm.testing.parametrize_targets("cuda") +def 
test_scatter_elements_gpu(target, dev): + """scatter_elements lowered for GPU must build""" + + @I.ir_module + class Mod: + @R.function + def main( + x: R.Tensor((4, 8), "float32"), + indices: R.Tensor((2, 8), "int64"), + updates: R.Tensor((2, 8), "float32"), + ): + with R.dataflow(): + lv = R.scatter_elements(x, indices, updates, axis=0) + gv = lv + R.output(gv) + return gv + + with tvm.target.Target(target): + mod = LegalizeOps()(Mod) + relax.build(mod, target=target) + + def test_layout_transform(): transformation = lambda a, b, c: (a, c, b // 3, b % 3) pad_value = 2 @@ -1838,5 +1861,28 @@ def scatter_nd(var_data: T.handle, var_indices: T.handle, var_updates: T.handle, tvm.ir.assert_structural_equal(After, Expected) +@tvm.testing.parametrize_targets("cuda") +def test_scatter_nd_gpu(target, dev): + """scatter_nd lowered for GPU must build""" + + @I.ir_module + class Mod: + @R.function + def main( + data: R.Tensor((4, 8), "float32"), + indices: R.Tensor((3, 2), "int64"), + updates: R.Tensor((3,), "float32"), + ): + with R.dataflow(): + lv = R.scatter_nd(data, indices, updates) + gv = lv + R.output(gv) + return gv + + with tvm.target.Target(target): + mod = LegalizeOps()(Mod) + relax.build(mod, target=target) + + if __name__ == "__main__": tvm.testing.main()