The current TIR syntax printer (introduced in #9680) fails when the script contains dynamic shapes:
```python
from tvm.script import tir as T


@T.prim_func
def f(a: T.handle, b: T.handle, c: T.handle):
    # Dynamic shape variables
    N = T.var("int32")
    M = T.var("int32")
    K = T.var("int32")
    A = T.match_buffer(a, (N, K), "float32")
    B = T.match_buffer(b, (K, M), "float32")
    C = T.match_buffer(c, (N, M), "float32")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


print(f.script())
```
Expected behavior
The output script should be the same as the input.
Actual behavior
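`M`, `N`, and `K` are used before they are declared: the function signature references them in the `T.Buffer` annotations, but the printer emits their `T.var` declarations only afterwards, inside the function body.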
```python
# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(N, K), "float32"], B: T.Buffer[(K, M), "float32"], C: T.Buffer[(N, M), "float32"]) -> None:
    K = T.var("int32")
    M = T.var("int32")
    N = T.var("int32")
    # body
    # with T.block("root")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
```
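To make the ordering problem concrete, here is a small stdlib-only sketch that parses the signature and declarations of both scripts with Python's `ast` module and reports names that appear in parameter annotations but are only declared later via `T.var(...)` in the body. The helper `used_before_declared` is hypothetical and not part of TVM; it only inspects the signature and top-level body statements, so treat it as a rough heuristic rather than a TVMScript validator.

```python
import ast
import textwrap

# Signature and declarations as printed by the buggy printer (body trimmed).
PRINTED = textwrap.dedent('''
    @T.prim_func
    def func(A: T.Buffer[(N, K), "float32"], B: T.Buffer[(K, M), "float32"], C: T.Buffer[(N, M), "float32"]) -> None:
        K = T.var("int32")
        M = T.var("int32")
        N = T.var("int32")
''')

# Signature and declarations from the original input script (body trimmed).
ORIGINAL = textwrap.dedent('''
    @T.prim_func
    def f(a: T.handle, b: T.handle, c: T.handle):
        N = T.var("int32")
        M = T.var("int32")
        K = T.var("int32")
        A = T.match_buffer(a, (N, K), "float32")
''')


def used_before_declared(source):
    """Hypothetical helper: names referenced in parameter annotations that are
    declared only later, via `X = T.var(...)` at the top level of the body."""
    func = next(n for n in ast.walk(ast.parse(source))
                if isinstance(n, ast.FunctionDef))
    params = {arg.arg for arg in func.args.args}
    # Every bare name appearing in any parameter annotation.
    used = {node.id
            for arg in func.args.args if arg.annotation is not None
            for node in ast.walk(arg.annotation) if isinstance(node, ast.Name)}
    # Names bound by top-level `X = T.var(...)` statements in the body.
    declared_later = set()
    for stmt in func.body:
        call = stmt.value if isinstance(stmt, ast.Assign) else None
        if (isinstance(call, ast.Call) and isinstance(call.func, ast.Attribute)
                and call.func.attr == "var"):
            declared_later |= {t.id for t in stmt.targets if isinstance(t, ast.Name)}
    return sorted((used & declared_later) - params)


print(used_before_declared(PRINTED))   # ['K', 'M', 'N']
print(used_before_declared(ORIGINAL))  # []
```

The input script declares `N`, `M`, and `K` before any use, so the check passes; the printed output references all three in the signature before the body declares them.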
The same problem occurs if I pass the tensor shapes as parameters:
```python
@T.prim_func
def f(a: T.handle, b: T.handle, c: T.handle, N: T.int32, M: T.int32, K: T.int32):
    A = T.match_buffer(a, (N, K), "float32")
    B = T.match_buffer(b, (K, M), "float32")
    C = T.match_buffer(c, (N, M), "float32")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
```