[TIR] lower_warp_memory cannot handle >1 warp buffers #5366

Closed
roastduck opened this issue Apr 18, 2020 · 0 comments · Fixed by #5368

The lower_warp_memory pass cannot handle more than one warp buffer: every warp buffer except the first one fails to be correctly transformed into warp shuffles.
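For context, here is a toy plain-Python model of what "transformed into warp shuffles" means in this example, where the buffer length equals the warp size. Each lane keeps its own element of a warp buffer in a register, and a read of another lane's element becomes a shuffle. This is not TVM code: `shfl` is a hypothetical helper that models a CUDA `__shfl_sync` with all 32 lanes active.

```python
WARP_SIZE = 32

def shfl(values, src_lanes):
    """Hypothetical model of a full-warp shuffle (all 32 lanes active):
    lane k receives the register held by lane src_lanes[k] % WARP_SIZE."""
    return [values[src % WARP_SIZE] for src in src_lanes]

m = 32  # one warp, matching the reproducer
a = [float(i) for i in range(m)]       # lane i holds AA[i] in a register
b = [float(10 * i) for i in range(m)]  # lane i holds BB[i] in a register
src = [(i + 1) % m for i in range(m)]  # lane i needs element (i + 1) % m

# C[i] = A[(i + 1) % m] + B[(i + 1) % m]: every warp-buffer read becomes a
# shuffle, so BOTH buffers have to be lowered for C to come out right.
c = [x + y for x, y in zip(shfl(a, src), shfl(b, src))]
assert c == [a[(i + 1) % m] + b[(i + 1) % m] for i in range(m)]
print(c[:3])  # [11.0, 22.0, 33.0]
```

In this model, a second warp buffer needs exactly the same treatment as the first, which is why leaving it untransformed is a bug rather than a missed optimization.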

To reproduce:

```python
import tvm
from tvm import te

# Written against the TVM API of April 2020; some calls
# (e.g. tvm.target.create) may differ in newer releases.
dtype = "float32"
target = "cuda"
m = 32
A = te.placeholder((m,), name='A', dtype=dtype)
B = te.placeholder((m,), name='B', dtype=dtype)
# Each output reads the neighbouring element of BOTH inputs, so both
# warp-scope caches need cross-lane (shuffle) access.
C = te.compute((m,), lambda i: A[(i + 1) % m] + B[(i + 1) % m], name='C')

cuda_target = tvm.target.create("cuda")
assert m <= cuda_target.thread_warp_size
with cuda_target:
    s = te.create_schedule(C.op)
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")

    # Two warp-scope buffers: only the first one is lowered correctly.
    AA = s.cache_read(A, "warp", [C])
    BB = s.cache_read(B, "warp", [C])
    xo, xi = s[C].split(C.op.axis[0], nparts=1)
    s[C].bind(xi, tx)
    s[C].bind(xo, bx)
    s[AA].compute_at(s[C], xo)
    s[BB].compute_at(s[C], xo)
    # Load each cached buffer with the same thread layout as C (one element per lane).
    xo, xi = s[AA].split(s[AA].op.axis[0], nparts=1)
    s[AA].bind(xo, bx)
    s[AA].bind(xi, tx)
    xo, xi = s[BB].split(s[BB].op.axis[0], nparts=1)
    s[BB].bind(xo, bx)
    s[BB].bind(xi, tx)

# Inspect the lowered IR and the generated CUDA source.
print(tvm.lower(s, [A, B, C], target, simple_mode=True))
compute = tvm.build(s, [A, B, C], target, name="run")
print(compute.imported_modules[0].get_source())
```

I think the problem is that WarpMemoryRewriter::VisitStmt_(const AllocateNode*) in lower_warp_memory.cc does not continue the recursion after rewriting the first buffer, so any warp buffer allocated inside it is never visited.
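The suspected failure mode can be illustrated with a toy Python model. Everything here is hypothetical: `Allocate`, `Body`, `rewrite_buggy` and `rewrite_fixed` are invented stand-ins, not TVM's classes or the actual code in lower_warp_memory.cc. The two allocations are nested to mimic AA and BB both being allocated around the compute of C, and "rewriting" a warp buffer is reduced to relabelling its scope as `"local"`.

```python
from dataclasses import dataclass, replace
from typing import Union

@dataclass(frozen=True)
class Body:
    """Stand-in for the innermost statement (the compute of C)."""
    text: str

@dataclass(frozen=True)
class Allocate:
    """Hypothetical, simplified stand-in for a TIR allocation node."""
    buffer: str
    scope: str
    body: "Stmt"

Stmt = Union[Allocate, Body]

def rewrite_buggy(stmt: Stmt) -> Stmt:
    if isinstance(stmt, Allocate):
        if stmt.scope == "warp":
            # Rewrites this buffer but returns without visiting stmt.body,
            # so a warp buffer allocated further down is never seen.
            return replace(stmt, scope="local")
        return replace(stmt, body=rewrite_buggy(stmt.body))
    return stmt

def rewrite_fixed(stmt: Stmt) -> Stmt:
    if isinstance(stmt, Allocate):
        # Keep recursing into the body, then rewrite this allocation.
        body = rewrite_fixed(stmt.body)
        scope = "local" if stmt.scope == "warp" else stmt.scope
        return replace(stmt, scope=scope, body=body)
    return stmt

def scopes(stmt: Stmt):
    """List (buffer, scope) pairs from the outermost allocation inwards."""
    out = []
    while isinstance(stmt, Allocate):
        out.append((stmt.buffer, stmt.scope))
        stmt = stmt.body
    return out

prog = Allocate("AA", "warp", Allocate("BB", "warp", Body("C[i] = AA[...] + BB[...]")))
print(scopes(rewrite_buggy(prog)))  # [('AA', 'local'), ('BB', 'warp')]
print(scopes(rewrite_fixed(prog)))  # [('AA', 'local'), ('BB', 'local')]
```

In this model, only the outer buffer is rewritten when the visitor returns early; continuing the traversal into the body of the rewritten allocation is enough to reach the second warp buffer.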

I will fix it.
