12 changes: 10 additions & 2 deletions python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -235,10 +235,16 @@ def _meshgrid(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+def _is_gpu_target():
+    target = tvm.target.Target.current(allow_none=True)
+    return target is not None and "gpu" in target.keys
+
+
 @register_legalize("relax.scatter_elements")
 def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
+    te_func = topi.gpu.scatter_elements if _is_gpu_target() else topi.scatter_elements
     return bb.call_te(
-        topi.scatter_elements,
+        te_func,
         call.args[0],
         call.args[1],
         call.args[2],
@@ -250,10 +256,12 @@ def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
 @register_legalize("relax.scatter_nd")
 def _scatter_nd(bb: BlockBuilder, call: Call) -> Expr:
     # TODO(relax-team): Support native scatter_nd without te extern
+    base_te = topi.gpu.scatter_nd if _is_gpu_target() else topi.scatter_nd
+
     def scatter_nd(data, indices, updates, reduction):
         axes = list(range(len(indices.shape)))
         indices = topi.transpose(indices, axes[-1:] + axes[:-1])
-        return topi.scatter_nd(data, indices, updates, reduction)
+        return base_te(data, indices, updates, reduction)
 
     return bb.call_te(
         scatter_nd,
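Note on the dispatch: `tvm.target.Target.current` returns a target only inside a `with tvm.target.Target(...)` scope, and `target.keys` holds the target's classification keys (for CUDA these include "cuda" and "gpu"), which is why the tests below run LegalizeOps under a target scope. A minimal sketch of the behavior (illustrative only, not part of this diff):

import tvm

# Outside any target scope there is no current target, so the CPU topi path is taken.
assert tvm.target.Target.current(allow_none=True) is None

# Inside a CUDA scope the "gpu" key is present, so the GPU topi path is taken.
with tvm.target.Target("cuda"):
    target = tvm.target.Target.current(allow_none=True)
    assert target is not None and "gpu" in target.keys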
2 changes: 2 additions & 0 deletions python/tvm/topi/gpu/__init__.py
@@ -20,4 +20,6 @@
 """GPU specific declaration."""
 
 from .scan import cumsum, cumprod
+from .scatter_elements import scatter_elements
+from .scatter_nd import scatter_nd
 from .sort import *
162 changes: 162 additions & 0 deletions python/tvm/topi/gpu/scatter_elements.py
@@ -0,0 +1,162 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""scatter_elements related operators"""

import tvm
from tvm import te, tir
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import tir as T

from .. import utils
from ..math import cast
from ..utils import ceil_div


def scatter_elements(data, indices, updates, axis=0, reduction="update"):
    """GPU implementation of scatter_elements with explicit thread bindings"""
    if not isinstance(axis, int):
        axis = utils.get_const_int(axis)

    # Prepare ranges and strides
    shape = data.shape
    if axis < 0:
        axis = len(shape) + axis
    axis_range = cast(shape[axis], indices.dtype)

    full_range = 1
    after_axis_range = 1
    for i, value in enumerate(shape):
        full_range *= value
        if i > axis:
            after_axis_range *= value
    before_axis_stride = axis_range * after_axis_range

    ind_shape = indices.shape
    ind_axis_range = ind_shape[axis]

    ind_before_axis_range = 1
    ind_after_axis_range = 1
    for i, value in enumerate(ind_shape):
        if i < axis:
            ind_before_axis_range *= value
        elif i > axis:
            ind_after_axis_range *= value
    ind_before_axis_stride = ind_axis_range * ind_after_axis_range
    ind_full_range_excl_axis = ind_before_axis_range * ind_after_axis_range

    def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr, reduce_func):
        # pylint: disable=invalid-name
        data = T.buffer_proxy(data_ptr)
        indices = T.buffer_proxy(indices_ptr)
        updates = T.buffer_proxy(updates_ptr)
        out = T.buffer_proxy(out_ptr)

        max_threads = int(tvm.target.Target.current(allow_none=False).attrs["max_num_threads"])

        with IRBuilder() as ib:
            with T.seq_scope():
                # Init: copy data into out, one element per thread
                nthread_bx_init = cast(ceil_div(full_range, max_threads), "int32")
                tx_init = te.thread_axis("threadIdx.x")
                bx_init = te.thread_axis("blockIdx.x")
                with T.frame_scope(
                    [
                        T.attr(bx_init, "thread_extent", nthread_bx_init),
                        T.attr(tx_init, "thread_extent", max_threads),
                    ]
                ):
                    tid = bx_init * max_threads + tx_init
                    with T.If(tid < full_range):
                        with T.Then():
                            out[tid] = data[tid]

                # Scatter: one thread per position outside `axis`; walk `axis` serially
                nthread_bx_scat = cast(ceil_div(ind_full_range_excl_axis, max_threads), "int32")
                tx_scat = te.thread_axis("threadIdx.x")
                bx_scat = te.thread_axis("blockIdx.x")
                with T.frame_scope(
                    [
                        T.attr(bx_scat, "thread_extent", nthread_bx_scat),
                        T.attr(tx_scat, "thread_extent", max_threads),
                    ]
                ):
                    fused = bx_scat * max_threads + tx_scat
                    with T.If(fused < ind_full_range_excl_axis):
                        with T.Then():
                            i = fused // ind_after_axis_range
                            j = fused % ind_after_axis_range
                            pre_index1 = i * ind_before_axis_stride + j
                            pre_index2 = i * before_axis_stride + j
                            with T.serial(0, ind_axis_range) as k:
                                # Offset along indices or updates
                                index1 = pre_index1 + k * ind_after_axis_range
                                # Get index and shift to the positive side if needed
                                k_new = indices[index1]
                                shifted_index = k_new + (k_new < 0) * axis_range
                                # Offset along data
                                index2 = pre_index2 + shifted_index * after_axis_range
                                reduce_func(out, index2, updates[index1])

        return ib.get()

    def update_func(dst_ptr, dst_index, update):
        dst_ptr[dst_index] = update

    def add_func(dst_ptr, dst_index, update):
        dst_ptr[dst_index] += update

    def mul_func(dst_ptr, dst_index, update):
        dst_ptr[dst_index] *= update

    def mean_func(dst_ptr, dst_index, update):
        dst_ptr[dst_index] = (dst_ptr[dst_index] + update) / 2

    def min_func(dst_ptr, dst_index, update):
        dst_ptr[dst_index] = tir.min(dst_ptr[dst_index], update)

    def max_func(dst_ptr, dst_index, update):
        dst_ptr[dst_index] = tir.max(dst_ptr[dst_index], update)

    reduce_func = None
    if reduction == "update":
        reduce_func = update_func
    elif reduction == "add":
        reduce_func = add_func
    elif reduction == "mul":
        reduce_func = mul_func
    elif reduction == "mean":
        reduce_func = mean_func
    elif reduction == "min":
        reduce_func = min_func
    elif reduction == "max":
        reduce_func = max_func
    else:
        raise NotImplementedError(
            "scatter_elements reduction not in [update, add, mul, mean, min, max]:", reduction
        )

    out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf")
    return te.extern(
        [data.shape],
        [data, indices, updates],
        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], reduce_func),
        dtype=data.dtype,
        out_buffers=[out_buf],
        name="scatter_elements.gpu",
        tag="scatter_elements.gpu",
    )
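For reference, the kernel above follows ONNX ScatterElements semantics: the init pass copies `data` into `out`, and the scatter pass parallelizes every dimension except `axis`, walking `axis` serially per thread. A NumPy sketch of the `reduction="update"` case (hypothetical helper, illustration only, not part of this diff):

import numpy as np

def scatter_elements_ref(data, indices, updates, axis=0):
    # out[..., indices[idx], ...] = updates[idx], with indices replacing the
    # coordinate along `axis`; negative indices wrap, like shifted_index above.
    out = data.copy()
    for idx in np.ndindex(*indices.shape):
        dst = list(idx)
        dst[axis] = indices[idx]
        out[tuple(dst)] = updates[idx]
    return out

data = np.zeros((4, 8), "float32")
indices = np.tile(np.arange(4, dtype="int64"), (2, 2))  # shape (2, 8)
updates = np.ones((2, 8), "float32")
print(scatter_elements_ref(data, indices, updates, axis=0))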
129 changes: 129 additions & 0 deletions python/tvm/topi/gpu/scatter_nd.py
@@ -0,0 +1,129 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# ruff: noqa: E741
"""scatter_nd related operators"""

import tvm
from tvm import te, tir  # hide redefinition of min and max
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import tir as T

from ..math import cast
from ..scatter import _verify_scatter_nd_inputs
from ..utils import ceil_div


def scatter_nd(data, indices, updates, mode):
    """GPU implementation of scatter_nd with explicit thread bindings."""
    _verify_scatter_nd_inputs(data, indices, updates)

    def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
        # pylint: disable=invalid-name
        data = T.buffer_proxy(data_ptr)
        indices = T.buffer_proxy(indices_ptr)
        updates = T.buffer_proxy(updates_ptr)
        out = T.buffer_proxy(out_ptr)

        # We combine all the indices dimensions but the first one into a single
        # dimension so we can iterate it in a single loop instead of an arbitrary
        # number of loops. We do the same thing for all the update dimensions.
        fused_indices_dimension = 1
        for i in indices_ptr.shape[1:]:
            fused_indices_dimension *= i

        fused_updates_dimension = 1
        for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
            fused_updates_dimension *= i

        fused_shape = 1
        for i in data_ptr.shape:
            fused_shape *= i

        max_threads = int(tvm.target.Target.current(allow_none=False).attrs["max_num_threads"])

        with IRBuilder() as ib:
            with T.seq_scope():
                # Init: copy data into out, one element per thread
                nthread_bx_init = cast(ceil_div(fused_shape, max_threads), "int32")
                tx_init = te.thread_axis("threadIdx.x")
                bx_init = te.thread_axis("blockIdx.x")
                with T.frame_scope(
                    [
                        T.attr(bx_init, "thread_extent", nthread_bx_init),
                        T.attr(tx_init, "thread_extent", max_threads),
                    ]
                ):
                    tid = bx_init * max_threads + tx_init
                    with T.If(tid < fused_shape):
                        with T.Then():
                            out[tid] = data[tid]

                # Scatter
                nthread_bx_scat = cast(ceil_div(fused_updates_dimension, max_threads), "int32")
                tx_scat = te.thread_axis("threadIdx.x")
                bx_scat = te.thread_axis("blockIdx.x")
                with T.frame_scope(
                    [
                        T.attr(bx_scat, "thread_extent", nthread_bx_scat),
                        T.attr(tx_scat, "thread_extent", max_threads),
                    ]
                ):
                    j = bx_scat * max_threads + tx_scat
                    with T.If(j < fused_updates_dimension):
                        with T.Then():
                            with T.serial(0, fused_indices_dimension) as i:
                                offset = fused_updates_dimension
                                index = j  # x_M, ..., x_{N-1} part of the index into out.
                                # Build up the indices[0, y_0, ..], ..,
                                # indices[M-1, y_0, ..] part of the index into out.
                                for l in reversed(range(indices_ptr.shape[0].value)):
                                    # indices[l, y_0, ... y_{k-1}]
                                    index += offset * indices[i + l * fused_indices_dimension]
                                    offset *= data_ptr.shape[l]
                                if mode == "update":
                                    out[index] = updates[i * fused_updates_dimension + j]
                                elif mode == "add":
                                    out[index] += updates[i * fused_updates_dimension + j]
                                elif mode == "mul":
                                    out[index] *= updates[i * fused_updates_dimension + j]
                                elif mode == "min":
                                    out[index] = tir.min(
                                        out[index], updates[i * fused_updates_dimension + j]
                                    )
                                elif mode == "max":
                                    out[index] = tir.max(
                                        out[index], updates[i * fused_updates_dimension + j]
                                    )
                                else:
                                    raise NotImplementedError(
                                        "scatter_nd mode not in [update, add, mul, min, max]:",
                                        mode,
                                    )

        return ib.get()

    out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf")
    return te.extern(
        [data.shape],
        [data, indices, updates],
        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
        dtype=data.dtype,
        out_buffers=[out_buf],
        name="scatter_nd.gpu",
        tag="scatter_nd.gpu",
    )
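The flattened index arithmetic above rebuilds an output coordinate from the M = `indices.shape[0]` leading entries of `indices`. Note the transposed layout: the coordinate axis comes first, which is why the Relax legalization transposes `indices` before calling this kernel. A NumPy reference sketch for the `update` and `add` modes (hypothetical helper, illustration only, not part of this diff):

import numpy as np

def scatter_nd_ref(data, indices, updates, mode="update"):
    # indices: (M, Y_0, ..., Y_{K-1}); each column of M entries is one
    # coordinate into data, matching the layout the GPU kernel expects.
    out = data.copy()
    M = indices.shape[0]
    for y in np.ndindex(*indices.shape[1:]):
        coord = tuple(int(indices[(l,) + y]) for l in range(M))
        if mode == "update":
            out[coord] = updates[y]
        elif mode == "add":
            out[coord] += updates[y]
    return out

data = np.zeros((4, 8), "float32")
indices = np.array([[0, 2, 3], [1, 4, 7]], "int64")  # M=2 rows, 3 points
updates = np.array([10.0, 20.0, 30.0], "float32")
print(scatter_nd_ref(data, indices, updates))  # writes out[0, 1], out[2, 4], out[3, 7]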
46 changes: 46 additions & 0 deletions tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -1551,6 +1551,29 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_scatter_elements_gpu(target, dev):
+    """scatter_elements lowered for GPU must build"""
+
+    @I.ir_module
+    class Mod:
+        @R.function
+        def main(
+            x: R.Tensor((4, 8), "float32"),
+            indices: R.Tensor((2, 8), "int64"),
+            updates: R.Tensor((2, 8), "float32"),
+        ):
+            with R.dataflow():
+                lv = R.scatter_elements(x, indices, updates, axis=0)
+                gv = lv
+                R.output(gv)
+            return gv
+
+    with tvm.target.Target(target):
+        mod = LegalizeOps()(Mod)
+    relax.build(mod, target=target)
+
+
 def test_layout_transform():
     transformation = lambda a, b, c: (a, c, b // 3, b % 3)
     pad_value = 2
@@ -1838,5 +1861,28 @@ def scatter_nd(var_data: T.handle, var_indices: T.handle, var_updates: T.handle,
     tvm.ir.assert_structural_equal(After, Expected)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_scatter_nd_gpu(target, dev):
+    """scatter_nd lowered for GPU must build"""
+
+    @I.ir_module
+    class Mod:
+        @R.function
+        def main(
+            data: R.Tensor((4, 8), "float32"),
+            indices: R.Tensor((3, 2), "int64"),
+            updates: R.Tensor((3,), "float32"),
+        ):
+            with R.dataflow():
+                lv = R.scatter_nd(data, indices, updates)
+                gv = lv
+                R.output(gv)
+            return gv
+
+    with tvm.target.Target(target):
+        mod = LegalizeOps()(Mod)
+    relax.build(mod, target=target)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
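When a CUDA device is actually present, the built module can also be executed through the Relax VM to check results end to end. A hedged sketch under that assumption (illustrative only, not part of this diff; `mod` is the legalized module from the test above):

import numpy as np
import tvm
from tvm import relax

dev = tvm.cuda(0)
ex = relax.build(mod, target="cuda")
vm = relax.VirtualMachine(ex, dev)
x = tvm.nd.array(np.zeros((4, 8), "float32"), dev)
indices = tvm.nd.array(np.zeros((2, 8), "int64"), dev)
updates = tvm.nd.array(np.ones((2, 8), "float32"), dev)
res = vm["main"](x, indices, updates)
print(res.numpy())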