[Unity][Op] Avoid indices in TIR matmul being 0 in legalization#14701
Merged
junrushao merged 1 commit intoapache:unityfrom Apr 22, 2023
Merged
Conversation
This PR changes a behavior of the legalization of matmul, so that we
do not use 0 as indices in the generated TIR in certain case.
Since the matmul op supports broadcasting and batching, previously when
legalizing a matmul op, for the broadcasting dimensions, we will emit
indices "0" for those broadcasting dimensions with length 1. For
example, the code below is a TIR produced by matmul legalization.
```python
@T.prim_func
def matmul(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(5)), "float32"), B: T.Buffer((T.int64(1), T.int64(1), T.int64(5), T.int64(7)), "float32"), matmul_1: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(7)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(1), T.int64(4), T.int64(7), T.int64(5)):
with T.block("matmul"):
v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
T.reads(A[T.int64(0), T.int64(0), v_i2, v_k], B[T.int64(0), T.int64(0), v_k, v_i3])
T.writes(matmul_1[v_i0, v_i1, v_i2, v_i3])
with T.init():
matmul_1[v_i0, v_i1, v_i2, v_i3] = T.float32(0)
matmul_1[v_i0, v_i1, v_i2, v_i3] = matmul_1[v_i0, v_i1, v_i2, v_i3] + A[T.int64(0), T.int64(0), v_i2, v_k] * B[T.int64(0), T.int64(0), v_k, v_i3]
```
You can see us using `T.int64(0)` to index the first dim of `A` and `B`.
However, when both `A` and `B` have length 1 at that dimension, it is
more canonical to use a variable as the index, as this is more
acceptable and detectable by analysis functions generally.
Therefore, this PR updates the behavior, so that we will emit variable
as indices when both sides have length 1, just as the example above.
We have a unit test to demonstrate the effect after changing.
Collaborator
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
1 similar comment
Collaborator
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR changes a behavior of the legalization of matmul, so that we do not use 0 as indices in the generated TIR in certain case.
Since the matmul op supports broadcasting and batching, previously when legalizing a matmul op, for the broadcasting dimensions, we will emit indices "0" for those broadcasting dimensions with length 1. For example, the code below is a TIR produced by matmul legalization.
You can see us using
T.int64(0)to index the first dim ofAandB.However, when both
AandBhave length 1 at that dimension, it is more canonical to use a variable as the index, as this is more acceptable and detectable by analysis functions generally.Therefore, this PR updates the behavior, so that we will emit variable as indices when both sides have length 1, just as the example above. We have a unit test to demonstrate the effect after changing.