
[Bug] New TIR syntax printer fails to handle dynamic shapes #9953

@yzh119

Description


The current TIR syntax printer (introduced in #9680) fails when the script contains dynamic shapes:

from tvm.script import tir as T

@T.prim_func
def f(a: T.handle, b: T.handle, c: T.handle):
    N = T.var("int32")
    M = T.var("int32")
    K = T.var("int32")
    A = T.match_buffer(a, (N, K), "float32")
    B = T.match_buffer(b, (K, M), "float32")
    C = T.match_buffer(c, (N, M), "float32")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

print(f.script())

Expected behavior

The printed script should be equivalent to the input script.

Actual behavior

M, N, and K are used in the function signature before they are declared in the function body:

# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(N, K), "float32"], B: T.Buffer[(K, M), "float32"], C: T.Buffer[(N, M), "float32"]) -> None:
    K = T.var("int32")
    M = T.var("int32")
    N = T.var("int32")
    # body
    # with T.block("root")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
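A regression test for this class of bug could parse the printed script with Python's standard ast module and flag names that the signature annotations reference but that nothing binds at module level. The sketch below assumes the printed script is treated as ordinary Python source; names_used_before_declaration is a hypothetical helper, not part of TVM. ast.parse only builds a syntax tree, so TVM does not need to be installed to run it.

```python
import ast
import builtins


def names_used_before_declaration(source: str) -> list[str]:
    """Hypothetical helper: report names referenced in the signature
    annotations of top-level functions that are not bound at module level
    (imports, assignments, defs) and are not builtins."""
    tree = ast.parse(source)  # parse only; nothing is imported or executed

    # Names visible in the module scope, where annotations are evaluated.
    bound = set(dir(builtins))
    for node in tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                bound.add(alias.asname or alias.name.split(".")[0])
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            bound.add(node.name)
        elif isinstance(node, ast.Assign):
            for target in node.targets:
                for sub in ast.walk(target):
                    if isinstance(sub, ast.Name):
                        bound.add(sub.id)

    missing = set()
    for node in tree.body:
        if not isinstance(node, ast.FunctionDef):
            continue
        params = node.args.posonlyargs + node.args.args + node.args.kwonlyargs
        annotations = [p.annotation for p in params if p.annotation is not None]
        if node.returns is not None:
            annotations.append(node.returns)
        # Assignments inside the function body (e.g. K = T.var("int32"))
        # are not visible to the annotations, so they do not count here.
        for ann in annotations:
            for sub in ast.walk(ann):
                if isinstance(sub, ast.Name) and sub.id not in bound:
                    missing.add(sub.id)
    return sorted(missing)


# Abridged copy of the printed output above, with the import uncommented.
printed = '''
from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(N, K), "float32"], B: T.Buffer[(K, M), "float32"], C: T.Buffer[(N, M), "float32"]) -> None:
    K = T.var("int32")
    M = T.var("int32")
    N = T.var("int32")
'''
print(names_used_before_declaration(printed))  # ['K', 'M', 'N']
```

The check is purely syntactic, so it flags the problem even though the body does declare K, M, and N: those declarations come after the signature that needs them.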

The same problem occurs when the tensor shapes are passed as function parameters:

@T.prim_func
def f(a: T.handle, b: T.handle, c: T.handle, N: T.int32, M: T.int32, K: T.int32):
    A = T.match_buffer(a, (N, K), "float32")
    B = T.match_buffer(b, (K, M), "float32")
    C = T.match_buffer(c, (N, M), "float32")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
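The printed output for this variant is not shown here. Assuming the printer again folds the T.match_buffer calls into T.Buffer[(N, K), "float32"] parameter annotations, making N, M, and K parameters would not fix the problem when the printed script is run as ordinary Python: annotations are evaluated in the scope enclosing the def (eagerly when the def executes before Python 3.14, lazily on access from 3.14 on), and sibling parameters are not visible there. The stdlib-only sketch below uses a hypothetical stand-in signature with builtins in place of T.int32 and T.Buffer.

```python
def annotation_sees_parameter() -> str:
    """Return the error raised when an annotation names a sibling parameter."""
    # Hypothetical stand-in for a printed signature such as
    #   def func(N: T.int32, A: T.Buffer[(N, K), "float32"])
    src = "def f(N: int, A: list[N]):\n    pass\n"
    namespace = {}
    try:
        exec(src, namespace)  # evaluates annotations eagerly before 3.14
        namespace["f"].__annotations__  # forces evaluation where it is deferred
    except NameError as err:
        return str(err)
    return "no error"


print(annotation_sees_parameter())  # name 'N' is not defined
```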
