Skip to content

[Unity][Op] Avoid indices in TIR matmul being 0 in legalization#14701

Merged
junrushao merged 1 commit intoapache:unityfrom
MasterJH5574:unity-dev/2023-04-22-matmul-legalize
Apr 22, 2023
Merged

[Unity][Op] Avoid indices in TIR matmul being 0 in legalization#14701
junrushao merged 1 commit intoapache:unityfrom
MasterJH5574:unity-dev/2023-04-22-matmul-legalize

Conversation

@MasterJH5574
Copy link
Copy Markdown
Contributor

This PR changes a behavior of the legalization of matmul, so that we do not use 0 as indices in the generated TIR in certain case.

Since the matmul op supports broadcasting and batching, previously when legalizing a matmul op, for the broadcasting dimensions, we will emit indices "0" for those broadcasting dimensions with length 1. For example, the code below is a TIR produced by matmul legalization.

@T.prim_func
def matmul(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(5)), "float32"), B: T.Buffer((T.int64(1), T.int64(1), T.int64(5), T.int64(7)), "float32"), matmul_1: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(7)), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(1), T.int64(4), T.int64(7), T.int64(5)):
        with T.block("matmul"):
            v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
            T.reads(A[T.int64(0), T.int64(0), v_i2, v_k], B[T.int64(0), T.int64(0), v_k, v_i3])
            T.writes(matmul_1[v_i0, v_i1, v_i2, v_i3])
            with T.init():
                matmul_1[v_i0, v_i1, v_i2, v_i3] = T.float32(0)
            matmul_1[v_i0, v_i1, v_i2, v_i3] = matmul_1[v_i0, v_i1, v_i2, v_i3] + A[T.int64(0), T.int64(0), v_i2, v_k] * B[T.int64(0), T.int64(0), v_k, v_i3]

You can see us using T.int64(0) to index the first dim of A and B.

However, when both A and B have length 1 at that dimension, it is more canonical to use a variable as the index, as this is more acceptable and detectable by analysis functions generally.

Therefore, this PR updates the behavior, so that we will emit variable as indices when both sides have length 1, just as the example above. We have a unit test to demonstrate the effect after changing.

This PR changes a behavior of the legalization of matmul, so that we
do not use 0 as indices in the generated TIR in certain case.

Since the matmul op supports broadcasting and batching, previously when
legalizing a matmul op, for the broadcasting dimensions, we will emit
indices "0" for those broadcasting dimensions with length 1. For
example, the code below is a TIR produced by matmul legalization.
```python
@T.prim_func
def matmul(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(5)), "float32"), B: T.Buffer((T.int64(1), T.int64(1), T.int64(5), T.int64(7)), "float32"), matmul_1: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(7)), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(1), T.int64(4), T.int64(7), T.int64(5)):
        with T.block("matmul"):
            v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
            T.reads(A[T.int64(0), T.int64(0), v_i2, v_k], B[T.int64(0), T.int64(0), v_k, v_i3])
            T.writes(matmul_1[v_i0, v_i1, v_i2, v_i3])
            with T.init():
                matmul_1[v_i0, v_i1, v_i2, v_i3] = T.float32(0)
            matmul_1[v_i0, v_i1, v_i2, v_i3] = matmul_1[v_i0, v_i1, v_i2, v_i3] + A[T.int64(0), T.int64(0), v_i2, v_k] * B[T.int64(0), T.int64(0), v_k, v_i3]
```
You can see us using `T.int64(0)` to index the first dim of `A` and `B`.

However, when both `A` and `B` have length 1 at that dimension, it is
more canonical to use a variable as the index, as this is more
acceptable and detectable by analysis functions generally.

Therefore, this PR updates the behavior, so that we will emit variable
as indices when both sides have length 1, just as the example above.
We have a unit test to demonstrate the effect after changing.
@tvm-bot
Copy link
Copy Markdown
Collaborator

tvm-bot commented Apr 22, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

1 similar comment
@tvm-bot
Copy link
Copy Markdown
Collaborator

tvm-bot commented Apr 22, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

Copy link
Copy Markdown
Member

@junrushao junrushao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@junrushao junrushao merged commit ee6e26f into apache:unity Apr 22, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants