Skip to content

Commit

Permalink
cleanup memcpy ir
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 18, 2021
1 parent e206c01 commit ad00c94
Showing 1 changed file with 20 additions and 41 deletions.
61 changes: 20 additions & 41 deletions python/tvm/topi/cuda/scatter.py
Expand Up @@ -22,12 +22,28 @@
from ..generic import schedule_extern
from .nms import atomic_add
from .sort import stable_sort_by_key_thrust, is_thrust_available
from ..utils import prod


def ceil_div(a, b):
return (a + b - 1) // b


def _memcpy_ir(ib, out_ptr, data_ptr, shape):
fused = prod(shape)
with ib.new_scope():
num_thread = int(tvm.target.Target.current(allow_none=False).max_num_threads)
num_blocks = ceil_div(fused, num_thread)
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", num_blocks)
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", num_thread)
tid = bx * num_thread + tx

with ib.if_scope(tid < fused):
out_ptr[tid] = data_ptr[tid]


def gen_ir_1d(data, indices, updates, axis, out, update_func):
"""Generate scatter ir for 1d inputs
Expand Down Expand Up @@ -64,10 +80,7 @@ def gen_ir_1d(data, indices, updates, axis, out, update_func):
out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)

with ib.new_scope():
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", n)
out_ptr[bx] = data_ptr[bx]
_memcpy_ir(ib, out_ptr, data_ptr, data.shape)

indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
Expand Down Expand Up @@ -115,8 +128,6 @@ def gen_ir_2d(data, indices, updates, axis, out, update_func):
ret : tir
The computational ir.
"""
warp_size = tvm.target.Target.current(False).thread_warp_size

n = data.shape[0]
c = data.shape[1]

Expand All @@ -125,16 +136,7 @@ def gen_ir_2d(data, indices, updates, axis, out, update_func):
out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)

with ib.new_scope():
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", n)
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", warp_size)
with ib.for_range(0, ceil_div(c, warp_size), name="j") as j_:
j = j_ * warp_size + tx
with ib.if_scope(j < c):
idx = bx * c + j
out_ptr[idx] = data_ptr[idx]
_memcpy_ir(ib, out_ptr, data_ptr, data.shape)

indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
Expand Down Expand Up @@ -206,18 +208,7 @@ def gen_ir_3d(data, indices, updates, axis, out, update_func):
out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)

with ib.new_scope():
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", n)
by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", c)
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", warp_size)
with ib.for_range(0, ceil_div(h, warp_size), name="k") as k_:
k = k_ * warp_size + tx
with ib.if_scope(k < h):
idx = (bx * c + by) * h + k
out_ptr[idx] = data_ptr[idx]
_memcpy_ir(ib, out_ptr, data_ptr, data.shape)

indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
Expand Down Expand Up @@ -312,19 +303,7 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func):

out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)
with ib.new_scope():
fused = n * c * h * w
num_thread = int(tvm.target.Target.current(allow_none=False).max_num_threads)
num_blocks = ceil_div(fused, num_thread)

bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", num_blocks)
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", num_thread)
tid = bx * num_thread + tx

with ib.if_scope(tid < fused):
out_ptr[tid] = data_ptr[tid]
_memcpy_ir(ib, out_ptr, data_ptr, data.shape)

indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
Expand Down

0 comments on commit ad00c94

Please sign in to comment.