[TIR] Fix lower_warp_memory when there are >1 warp buffers (#5368)
* fix recursion in lower_warp_memory

* post-order mutation
roastduck committed Apr 19, 2020
1 parent 3264895 commit a2d6fe6
Showing 2 changed files with 53 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/tir/transforms/lower_warp_memory.cc
@@ -377,12 +377,13 @@ class WarpMemoryRewriter : private StmtMutator {
 
  private:
   Stmt VisitStmt_(const AllocateNode* op) {
+    auto ret = StmtMutator::VisitStmt_(op);
+    op = ret.as<AllocateNode>();
     if (warp_buffer_.count(op->buffer_var.get())) {
       WarpAccessRewriter rewriter(warp_size_, &analyzer_);
-      return rewriter.Rewrite(op);
-    } else {
-      return StmtMutator::VisitStmt_(op);
+      ret = rewriter.Rewrite(op);
     }
+    return ret;
   }
 
   Stmt VisitStmt_(const AttrStmtNode* op) {
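For context on the "post-order mutation" note in the commit message: a warp-scope cache_read stage lowers to an Allocate node, so two warp buffers end up as nested allocations in the statement tree. The old visitor returned as soon as it rewrote the outer warp allocation and never descended into its body, leaving the inner buffer untouched; the fix mutates the body first and only then applies the warp rewrite. Below is a minimal standalone sketch of that difference. It is plain C++ with made-up names (Alloc, VisitPreOrder, VisitPostOrder, the rewritten flag), not TVM's actual StmtMutator or WarpAccessRewriter API.

// Standalone sketch, not TVM source: `Alloc` stands in for tir::AllocateNode,
// and flipping `rewritten` stands in for what WarpAccessRewriter::Rewrite does.
#include <iostream>
#include <memory>
#include <string>
#include <utility>

struct Alloc {
  std::string name;             // buffer name, e.g. "A.warp"
  bool is_warp;                 // buffer lives in "warp" scope
  bool rewritten;               // has the warp rewrite been applied?
  std::unique_ptr<Alloc> body;  // nested allocation (nullptr at the leaf)
};

// Old behaviour: rewrite the first warp allocation found and return,
// never descending into its body, so a nested warp buffer is missed.
void VisitPreOrder(Alloc* node) {
  if (!node) return;
  if (node->is_warp) {
    node->rewritten = true;
    return;  // early return skips node->body
  }
  VisitPreOrder(node->body.get());
}

// Fixed behaviour (post-order): mutate the body first, then rewrite this
// node, so every warp allocation in the subtree is handled.
void VisitPostOrder(Alloc* node) {
  if (!node) return;
  VisitPostOrder(node->body.get());
  if (node->is_warp) {
    node->rewritten = true;
  }
}

// Two nested warp allocations, the shape produced by two cache_read(..., "warp") stages.
std::unique_ptr<Alloc> MakeNested() {
  std::unique_ptr<Alloc> inner(new Alloc{"B.warp", true, false, nullptr});
  return std::unique_ptr<Alloc>(new Alloc{"A.warp", true, false, std::move(inner)});
}

int main() {
  auto t1 = MakeNested();
  VisitPreOrder(t1.get());
  std::cout << "pre-order : A.warp=" << t1->rewritten
            << " B.warp=" << t1->body->rewritten << "\n";  // prints 1 and 0

  auto t2 = MakeNested();
  VisitPostOrder(t2.get());
  std::cout << "post-order: A.warp=" << t2->rewritten
            << " B.warp=" << t2->body->rewritten << "\n";  // prints 1 and 1
  return 0;
}

In the actual change above, StmtMutator::VisitStmt_(op) is the children-first step and WarpAccessRewriter::Rewrite then runs on the already-mutated allocation, so every warp buffer in the subtree gets rewritten.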
49 changes: 49 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -132,7 +132,56 @@ def check_cuda(dtype):
     check_cuda("float32")
     check_cuda("float16")
 
+def test_lower_warp_memory_cuda_2_buffers():
+    def check_cuda(dtype):
+        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
+            return
+
+        m = 32
+        A = te.placeholder((m,), name='A', dtype=dtype)
+        B = te.placeholder((m,), name='B', dtype=dtype)
+        C = te.compute((m,), lambda i: A[(i + 1) % m] + B[(i + 1) % m], name='C')
+
+        cuda_target = tvm.target.create("cuda")
+        assert m <= cuda_target.thread_warp_size
+        with cuda_target:
+            s = te.create_schedule(C.op)
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+
+            AA = s.cache_read(A, "warp", [C])
+            BB = s.cache_read(B, "warp", [C])
+            xo, xi = s[C].split(C.op.axis[0], nparts=1)
+            s[C].bind(xi, tx)
+            s[C].bind(xo, bx)
+            s[AA].compute_at(s[C], xo)
+            s[BB].compute_at(s[C], xo)
+            xo, xi = s[AA].split(s[AA].op.axis[0], nparts=1)
+            s[AA].bind(xo, bx)
+            s[AA].bind(xi, tx)
+            xo, xi = s[BB].split(s[BB].op.axis[0], nparts=1)
+            s[BB].bind(xo, bx)
+            s[BB].bind(xi, tx)
+
+            ctx = tvm.gpu(0)
+            func = tvm.build(s, [A, B, C], "cuda")
+            AB_np = np.array(list(range(m)), dtype=dtype)
+            C_np = np.array(list(range(1, m)) + [0], dtype=dtype) * 2
+            A_nd = tvm.nd.array(AB_np, ctx)
+            B_nd = tvm.nd.array(AB_np, ctx)
+            C_nd = tvm.nd.array(np.zeros(C_np.shape, dtype=C_np.dtype), ctx)
+            func(A_nd, B_nd, C_nd)
+            tvm.testing.assert_allclose(C_nd.asnumpy(), C_np, rtol=1e-3)
+
+    check_cuda("float32")
+    check_cuda("float16")
+
 if __name__ == "__main__":
     test_lower_warp_memory_local_scope()
     test_lower_warp_memory_cuda_end_to_end()
     test_lower_warp_memory_cuda_half_a_warp()
+    test_lower_warp_memory_cuda_2_buffers()
