Skip to content

Commit

Permalink
[BUGFIX] fix illegal memory access bug in reduce op schedule by const…
Browse files Browse the repository at this point in the history
…riant threadIdx.y

Signed-off-by: ziqiang.pzq <ziqiang.pzq@alibaba-inc.com>
  • Loading branch information
ziqiang.pzq committed Jul 28, 2021
1 parent 3445532 commit a86e87e
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion python/tvm/topi/cuda/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
# pylint: disable=invalid-name,unused-variable,too-many-locals,len-as-condition
"""Schedule for reduce operators"""
from __future__ import absolute_import as _abs
from operator import mul
from functools import reduce
import tvm
from tvm import te
from .. import tag
Expand Down Expand Up @@ -80,13 +82,18 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
if is_idx_reduce:
sch[temp_idx_input].compute_at(sch[real_output], outer_in)
sch[temp_val_input].compute_at(sch[real_output], outer_in)
sch[real_output].set_store_predicate(
tvm.tir.all(
thread_x.equal(0), block_x * num_thread + thread_y < reduce(mul, real_output.shape)
)
)
else:
if is_idx_reduce:
spatial_axis = sch[real_output].fuse(*(sch[real_output].op.axis))
sch[real_output].bind(spatial_axis, te.thread_axis("blockIdx.x"))
sch[temp_idx_input].compute_at(sch[real_output], spatial_axis)
sch[temp_val_input].compute_at(sch[real_output], spatial_axis)
sch[real_output].set_store_predicate(thread_x.equal(0))
sch[real_output].set_store_predicate(thread_x.equal(0))
return sch


Expand Down

0 comments on commit a86e87e

Please sign in to comment.