
[BugFix][TIR] Affine-binding check should not simplify trivial iterators#13203

Merged
zxybazh merged 2 commits into apache:main from MasterJH5574:bugfix/2022-10-26-affine-bindings
Oct 28, 2022

Conversation

@MasterJH5574
Contributor

This PR fixes a bug in the affine-binding check in TIR. Previously, the check called DetectIterMap with trivial iterators simplified. However, this caused some blocks that do have affine bindings to not be recognized as such (see the regression test for a concrete example). This PR specifies that we do not simplify trivial iterators when checking affine bindings.

cc @spectrometerHBH @vinx13 @junrushao @wrongtest-intellif
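To make the failure mode concrete, here is a toy illustration in plain Python (not TVM's actual implementation). A "trivial" iterator is a loop with extent 1, whose variable simplifies to the constant 0; if that substitution happens before the affine match, the binding no longer mentions a recognized loop variable and the check wrongly rejects it:

```python
# Toy model of the bug: simplifying unit-extent ("trivial") loop
# variables to 0 before affine-binding detection makes the match fail.
# Names and structure are illustrative only, not TVM APIs.

def is_affine(bindings, loop_vars, simplify_trivial):
    """bindings: list of loop-variable names bound to block iters.
    loop_vars: dict mapping variable name -> loop extent."""
    if simplify_trivial:
        # Replace variables of unit-extent loops with the constant 0.
        bindings = [0 if loop_vars.get(b) == 1 else b for b in bindings]
    # The constant 0 is not a recognized loop variable, so a trivially
    # affine binding like ax0 = i0 (extent 1) is rejected -- the bug.
    return all(b in loop_vars for b in bindings)

loop_vars = {"i0": 1, "i1": 2048}
bindings = ["i0", "i1"]
# simplify_trivial=True  rejects the bindings (pre-fix behavior)
# simplify_trivial=False accepts them (behavior after this PR)
```

The real check operates on TIR expressions via DetectIterMap rather than variable names, but the shape of the problem is the same: simplification erases the information the detector needs.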

@tvm-bot
Collaborator

tvm-bot commented Oct 26, 2022

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@junrushao
Member

CC @zxybazh would you like to validate the PR using your scripts before merging it in?

@zxybazh
Member

zxybazh commented Oct 26, 2022

Sure, I can run tuning locally for ResNet50 and validate the database before merging.

@zxybazh
Member

zxybazh commented Oct 27, 2022

My validation script is 75% done with the database scan. It located an incorrect-output issue, shown below; I'll dig in a bit more to confirm whether it's related.

Validation failed!

Original IRModule:
------------------------------
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(p0: T.Buffer[(1, 7, 7, 2048), "float32"], tensor: T.Buffer[(1, 1, 1, 2048), "float32"]) -> None:
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "main"})
        # body
        # with T.block("root")
        tensor_1 = T.alloc_buffer([1, 1, 1, 2048], dtype="float32")
        for i0, i1, i2, i3, i4, i5 in T.grid(1, 1, 1, 2048, 7, 7):
            with T.block("tensor"):
                ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
                T.reads(p0[ax0, ax1 * 7 + rv0, ax2 * 7 + rv1, ax3])
                T.writes(tensor_1[ax0, ax1, ax2, ax3])
                with T.init():
                    tensor_1[ax0, ax1, ax2, ax3] = T.float32(0)
                tensor_1[ax0, ax1, ax2, ax3] = tensor_1[ax0, ax1, ax2, ax3] + p0[ax0, ax1 * 7 + rv0, ax2 * 7 + rv1, ax3]
        for i0, i1, i2, i3 in T.grid(1, 1, 1, 2048):
            with T.block("tensor_1"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(tensor_1[ax0, ax1, ax2, ax3])
                T.writes(tensor[ax0, ax1, ax2, ax3])
                tensor[ax0, ax1, ax2, ax3] = tensor_1[ax0, ax1, ax2, ax3] * T.float32(0.020408163265306121)
    


Scheduled IRModule:
------------------------------
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(p0: T.Buffer[(1, 7, 7, 2048), "float32"], tensor: T.Buffer[(1, 1, 1, 2048), "float32"]) -> None:
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "main"})
        # body
        # with T.block("root")
        tensor_shared = T.alloc_buffer([1, 1, 1, 2048], dtype="float32", scope="shared")
        for i0_i1_i2_i3_0_fused in T.thread_binding(32, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":T.int64(16), "pragma_unroll_explicit":T.int64(1)}):
            for ax0, ax1, ax2, ax3, ax4_ax5_fused_0 in T.grid(1, 1, 1, 64, 1):
                for ax4_ax5_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("tensor"):
                        T.where(ax4_ax5_fused_0 * 64 + ax4_ax5_fused_1 < 49)
                        ax0_1, ax1_1, ax2_1 = T.axis.remap("SSS", [ax0, ax1, ax2])
                        ax3_1 = T.axis.spatial(2048, i0_i1_i2_i3_0_fused * 64 + ax3)
                        rv0 = T.axis.reduce(7, (ax4_ax5_fused_0 * 64 + ax4_ax5_fused_1) // 7)
                        rv1 = T.axis.reduce(7, (ax4_ax5_fused_0 * 64 + ax4_ax5_fused_1) % 7)
                        T.reads(p0[ax0_1, ax1_1 * 7 + rv0, ax2_1 * 7 + rv1, ax3_1])
                        T.writes(tensor_shared[ax0_1, ax1_1, ax2_1, ax3_1])
                        with T.init():
                            tensor_shared[ax0_1, ax1_1, ax2_1, ax3_1] = T.float32(0)
                        tensor_shared[ax0_1, ax1_1, ax2_1, ax3_1] = tensor_shared[ax0_1, ax1_1, ax2_1, ax3_1] + p0[ax0_1, ax1_1 * 7 + rv0, ax2_1 * 7 + rv1, ax3_1]
            for i3_1 in T.thread_binding(64, thread="threadIdx.x"):
                with T.block("tensor_1"):
                    ax0 = T.axis.spatial(1, 0)
                    ax1 = T.axis.spatial(1, 0)
                    ax2 = T.axis.spatial(1, 0)
                    ax3 = T.axis.spatial(2048, i0_i1_i2_i3_0_fused * 64 + i3_1)
                    T.reads(tensor_shared[ax0, ax1, ax2, ax3])
                    T.writes(tensor[ax0, ax1, ax2, ax3])
                    tensor[ax0, ax1, ax2, ax3] = tensor_shared[ax0, ax1, ax2, ax3] * T.float32(0.020408163265306121)
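One detail worth sanity-checking in the scheduled module above is the cross-thread reduction: the two 7-wide reduction loops are fused into a single 49-iteration loop, padded to 64 threads, and the tail is masked with `T.where(... < 49)`. The `rv0`/`rv1` recovery via floordiv/floormod should enumerate exactly the original 7x7 grid, which a quick plain-Python sketch confirms:

```python
# Sketch: verify that the fused-and-padded thread loop in the scheduled
# IRModule covers exactly the original 7x7 reduction grid, once, with
# the tail threads masked out by the T.where predicate.

pairs = []
for f0 in range(1):            # ax4_ax5_fused_0
    for f1 in range(64):       # ax4_ax5_fused_1 (threadIdx.x)
        f = f0 * 64 + f1
        if f < 49:             # the T.where(...) predicate
            rv0, rv1 = f // 7, f % 7
            pairs.append((rv0, rv1))

# 49 distinct pairs, exactly the 7x7 grid of (rv0, rv1).
assert set(pairs) == {(a, b) for a in range(7) for b in range(7)}
assert len(pairs) == 49
```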

@MasterJH5574
Contributor Author

@zxybazh Sorry, what do you mean by “incorrect output”? From my point of view, this script is essentially no different from the one in #12976. Do you mean that with this PR, the script still fails?

@MasterJH5574
Contributor Author

I just tried it and saw no error. Let me know if there’s anything wrong.

@zxybazh
Member

zxybazh commented Oct 28, 2022

Hi @MasterJH5574, I just double-checked the error cases. It turns out the issue still reproduced when I cherry-picked your PR, but was fixed after I rebased onto it. I'll continue the verification. Current progress: 15,264/20,039; it should be good to merge after the remaining records are verified.

@zxybazh
Member

zxybazh commented Oct 28, 2022

The verification of 20,039 tuning records from end-to-end ResNet50 tuning has finished, and the previously mentioned cross-thread-reduction issue is fixed. I found some other incorrect output with a specific conv2d workload, but that's out of scope here and I'll follow up in a separate issue. Thanks @MasterJH5574 for the fix and @wrongtest-intellif for the review!
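The validation being discussed compares a scheduled kernel's numerical output against a reference for the same workload. For the module shown earlier, the reference is a global average pool: sum over a 7x7 spatial window, then scale by 1/49 (the `0.020408...` constant in the TIR). A minimal plain-Python sketch of such a reference (function name and structure are illustrative, not TVM's validation scripts):

```python
# Hedged sketch, no TVM: reference computation for the workload above.
# A record validator could compare a scheduled kernel's output against
# this elementwise (within floating-point tolerance).

def reference_avg_pool(p0, H=7, W=7, C=2048):
    """p0: nested lists of shape (1, H, W, C); returns (1, 1, 1, C)."""
    out = [[[[0.0] * C]]]
    for c in range(C):
        s = 0.0
        for h in range(H):
            for w in range(W):
                s += p0[0][h][w][c]
        # Matches the TIR: reduce-sum, then multiply by 1/(H*W).
        out[0][0][0][c] = s * (1.0 / (H * W))
    return out
```

In practice a validator would run both the untuned and the scheduled module on random inputs and assert closeness rather than recompute in Python, but the reference above pins down what "correct output" means for this particular kernel.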

@zxybazh zxybazh merged commit 3cce973 into apache:main Oct 28, 2022
xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 10, 2022
…ors (apache#13203)

* Fix affine bindings

* Regression test and test update
xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 25, 2022
…ors (apache#13203)

* Fix affine bindings

* Regression test and test update