[BugFix][TIR] Affine-binding check should not simplify trivial iterators#13203
Conversation
|
CC @zxybazh would you like to validate the PR using your scripts before merging it in? |
|
Sure, I can run a tuning locally for ResNet50 and validate the database before merging it. |
|
My validation script is 75% done with the database scanning. It located an incorrect-output issue, shown below; I'll dig in a bit more to confirm whether it's related.

Validation failed!
Original IRModule:
------------------------------
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(p0: T.Buffer[(1, 7, 7, 2048), "float32"], tensor: T.Buffer[(1, 1, 1, 2048), "float32"]) -> None:
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "main"})
        # body
        # with T.block("root")
        tensor_1 = T.alloc_buffer([1, 1, 1, 2048], dtype="float32")
        for i0, i1, i2, i3, i4, i5 in T.grid(1, 1, 1, 2048, 7, 7):
            with T.block("tensor"):
                ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
                T.reads(p0[ax0, ax1 * 7 + rv0, ax2 * 7 + rv1, ax3])
                T.writes(tensor_1[ax0, ax1, ax2, ax3])
                with T.init():
                    tensor_1[ax0, ax1, ax2, ax3] = T.float32(0)
                tensor_1[ax0, ax1, ax2, ax3] = tensor_1[ax0, ax1, ax2, ax3] + p0[ax0, ax1 * 7 + rv0, ax2 * 7 + rv1, ax3]
        for i0, i1, i2, i3 in T.grid(1, 1, 1, 2048):
            with T.block("tensor_1"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(tensor_1[ax0, ax1, ax2, ax3])
                T.writes(tensor[ax0, ax1, ax2, ax3])
                tensor[ax0, ax1, ax2, ax3] = tensor_1[ax0, ax1, ax2, ax3] * T.float32(0.020408163265306121)
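For readers decoding the workload: the constant `T.float32(0.020408163265306121)` is 1/49, so this sum-then-scale pair of blocks is a 7x7 global average pool (consistent with ResNet50's final pooling layer, 1x7x7x2048 reduced to 1x1x1x2048). A quick sanity check, independent of TVM:

```python
# Standalone check (not part of TVM): the scaling constant printed in the
# module above is 1 / 49, i.e. the reduction block sums over the 7x7
# spatial window and the second block rescales it to an average.
scale = 0.020408163265306121
assert abs(scale - 1 / 49) < 1e-9
print(scale)
```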
Scheduled IRModule:
------------------------------
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(p0: T.Buffer[(1, 7, 7, 2048), "float32"], tensor: T.Buffer[(1, 1, 1, 2048), "float32"]) -> None:
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "main"})
        # body
        # with T.block("root")
        tensor_shared = T.alloc_buffer([1, 1, 1, 2048], dtype="float32", scope="shared")
        for i0_i1_i2_i3_0_fused in T.thread_binding(32, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": T.int64(16), "pragma_unroll_explicit": T.int64(1)}):
            for ax0, ax1, ax2, ax3, ax4_ax5_fused_0 in T.grid(1, 1, 1, 64, 1):
                for ax4_ax5_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("tensor"):
                        T.where(ax4_ax5_fused_0 * 64 + ax4_ax5_fused_1 < 49)
                        ax0_1, ax1_1, ax2_1 = T.axis.remap("SSS", [ax0, ax1, ax2])
                        ax3_1 = T.axis.spatial(2048, i0_i1_i2_i3_0_fused * 64 + ax3)
                        rv0 = T.axis.reduce(7, (ax4_ax5_fused_0 * 64 + ax4_ax5_fused_1) // 7)
                        rv1 = T.axis.reduce(7, (ax4_ax5_fused_0 * 64 + ax4_ax5_fused_1) % 7)
                        T.reads(p0[ax0_1, ax1_1 * 7 + rv0, ax2_1 * 7 + rv1, ax3_1])
                        T.writes(tensor_shared[ax0_1, ax1_1, ax2_1, ax3_1])
                        with T.init():
                            tensor_shared[ax0_1, ax1_1, ax2_1, ax3_1] = T.float32(0)
                        tensor_shared[ax0_1, ax1_1, ax2_1, ax3_1] = tensor_shared[ax0_1, ax1_1, ax2_1, ax3_1] + p0[ax0_1, ax1_1 * 7 + rv0, ax2_1 * 7 + rv1, ax3_1]
            for i3_1 in T.thread_binding(64, thread="threadIdx.x"):
                with T.block("tensor_1"):
                    ax0 = T.axis.spatial(1, 0)
                    ax1 = T.axis.spatial(1, 0)
                    ax2 = T.axis.spatial(1, 0)
                    ax3 = T.axis.spatial(2048, i0_i1_i2_i3_0_fused * 64 + i3_1)
                    T.reads(tensor_shared[ax0, ax1, ax2, ax3])
                    T.writes(tensor[ax0, ax1, ax2, ax3])
                    tensor[ax0, ax1, ax2, ax3] = tensor_shared[ax0, ax1, ax2, ax3] * T.float32(0.020408163265306121)
|
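In the scheduled module, the two 7-extent reduction loops are fused onto a single 64-thread `threadIdx.x` axis, with `rv0 = fused // 7`, `rv1 = fused % 7` and the `T.where(... < 49)` predicate masking the 15 excess threads. A small plain-Python sketch (independent of TVM) confirms this decomposition covers the 7x7 reduction domain exactly once:

```python
# Recover (rv0, rv1) from the fused thread index, mirroring the scheduled
# IRModule: 64 threads, predicate `fused < 49` masks the tail threads.
pairs = []
for fused in range(64):      # ax4_ax5_fused_0 * 64 + ax4_ax5_fused_1
    if fused < 49:           # the T.where(...) predicate
        pairs.append((fused // 7, fused % 7))   # (rv0, rv1)

# Every point of the 7x7 reduction domain is visited exactly once.
assert sorted(pairs) == [(a, b) for a in range(7) for b in range(7)]
print(len(pairs))  # 49
```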
|
I just tried and saw no error. Let me know if there's anything wrong. |
|
Hi @MasterJH5574, I just double-checked the error cases. It turns out it was not working when I cherry-picked your PR, but it was fixed after I rebased onto your PR. I will continue the verification. Current progress: 15,264/20,039; this should be good to merge after the remaining records are verified. |
|
The verification of all 20,039 tuning records for ResNet50 end-to-end tuning has finished, and the previously mentioned issue with cross-thread reduction is fixed. I found some other incorrect output with a specific conv2d workload, but that is out of scope for this PR and I'll follow up in a separate issue. Thanks @MasterJH5574 for the fix and @wrongtest-intellif for the review! |
…ors (apache#13203) * Fix affine bindings * Regression test and test update
This PR fixes a bug in the affine-binding check in TIR. Previously, the affine-binding check called DetectIterMap with trivial iterators simplified. However, this caused some blocks with affine bindings to not be recognized as having affine bindings (see the regression test for a concrete example). This PR specifies that trivial iterators are not simplified when checking affine bindings.

cc @spectrometerHBH @vinx13 @junrushao @wrongtest-intellif
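As a rough illustration of the failure mode (a toy model, not TVM's actual DetectIterMap implementation): if a unit-extent ("trivial") loop iterator is simplified to the constant 0 before matching, a binding such as `ax0 = i0` with `extent(i0) == 1` no longer references `i0`, so a matcher that requires every loop variable to appear in the bindings can wrongly reject a block that is in fact affine. The names and checks below are hypothetical, for illustration only:

```python
# Toy model (NOT TVM's real DetectIterMap): call bindings "affine" here
# iff each binding is a single loop var and together they cover every
# loop var exactly once.

def is_affine(bindings, loop_vars):
    """bindings: one set of referenced loop-var names per block iter."""
    used = sorted(v for b in bindings for v in b)
    return used == sorted(loop_vars) and all(len(b) == 1 for b in bindings)

def simplify_trivial(bindings, extents):
    # Drop references to unit-extent loops, mimicking the old behavior of
    # simplifying trivial iterators to a constant before matching.
    return [{v for v in b if extents[v] != 1} for b in bindings]

loop_vars = ["i0", "i1"]
extents = {"i0": 1, "i1": 2048}
bindings = [{"i0"}, {"i1"}]  # ax0 = i0, ax1 = i1 -- clearly affine

print(is_affine(bindings, loop_vars))                             # True
print(is_affine(simplify_trivial(bindings, extents), loop_vars))  # False: i0 was dropped
```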