-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Unity] Added known tir.Expr to relax.PrimValue #15577
[Unity] Added known tir.Expr to relax.PrimValue #15577
Conversation
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
Prior to this commit, a `relax.PrimValue` could have a datatype, but couldn't have a corresponding `tir.PrimExpr`. As a result, it could not be used to specify tensor shapes. This makes some expressions require fallback to `R.Tensor(ndim=ndim)`, even though the shape could still be inferred. ```python @R.function def func( A: R.Tensor(16, 16), first_n_rows: R.prim("int64"), ) -> R.Tensor([first_n_rows, 16]): # ^^^^^^^^^^^^ # R.Tensor requires a PrimExpr, not relax.Expr # # Operations may require PrimExpr # vvvvvvvvvvvv out = R.op.strided_slice(axis=[0], begin=[0], end=[first_n_rows]) return out ``` This commit adds a `Optional<PrimExpr> value` field to the `PrimStructInfo`. This field acts similarly to the `PrimExpr` fields already used in `ShapeStructInfo`, and may contain symbolic variables. ```python @R.function def func( A: R.Tensor(16, 16), # TIR definitions in signature allow in-line definitions, # similar to R.Tensor and R.Shape. R.Prim takes `dtype` or # `value` kwarg to distinguish between in-line symbolic variable # and string representation of dtype. first_n_rows: R.prim(value="first_n_rows_tir"), ) -> R.Tensor(["first_n_rows_tir", 16]): # Body contains a TIR variable definition, which may be used # in function calls, inferred shape annotations. first_n_rows_tir = T.int64() out = R.op.strided_slice(axis=[0], begin=[0], end=[first_n_rows]) return out ``` Use distinct PrimStructInfo arguments for dtype/value Update TVMScript printer Parser updates, Support R.Prim(value=...) annotations in function signature
4d2eac6
to
087e636
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this change is a good addition and that it's implemented very well, thanks @Lunderberg! I will start updating the spec to account for this.
My main concern are for the cases where shape variables are used "before" they're defined (going left to right in a function). When writing the spec, I assume this wasn't a case we wanted. If that is a case we want to permit, then should our policy be to scan the arguments for binding positions first? That would be good to figure out.
# Guard against incorrect usage. For backwards compatibility, | ||
# the dtype and value are in the opposite order from most | ||
# usages. While PrimStructInfo could take a single positional | ||
# argument and check the type, this would require an API | ||
# difference from TVMScript's PrimProxy, which cannot. | ||
# (PrimProxy uses string arguments for datatype, and also for | ||
# inline variable definitions when used in a function | ||
# signature, and requires separate arguments to distinguish | ||
# the two cases.) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very user-friendly touch :)
"""The bound variable should be replaced when appearing in R.Shape""" | ||
|
||
@R.function(private=True) | ||
def before(A: R.Tensor(["M * N"]), x: R.Shape(["M", "N"])): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we sure we want to permit cases like this? I believe I wrote the specification draft to require argument vars to be in binding positions, left to right. If we permit uses before the binding position, then should the type-checker scan through the arguments, identify the bindings first, and only then check the position?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking so, mainly because it would be the same support that is provided in TIR. (ArgBinder::BindDLTensor
defines the symbolic variables prior to asserting based on the shapes.)
There's also a couple of edge cases that I think would be easier to handle by allowing it. Swapping the order of commutative operators, hoisting R.match_cast
from the body into the function signature, and handling cases with mutually-dependent shapes. (e.g. The first parameter is R.Tensor(["a * b", "b"])
and the second parameter is R.Tensor(["a * b", "a"])
.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I'll have to modify the specification to account for it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, and thank you for all your work keeping the spec up to date!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes generally look good to me. Thanks @Lunderberg for the great addition!
Only one question: should the value be general PrimExpr, or be a tir.Var
. IIUC, it maps the prim value to one symbolic var and is used in the shape deduction. A direct var should be easier than a generic expr in this case.
@Hzfengsy Thank you, and I do think the value should be a
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the clarification. LGTM now!
This test was implemented in apache#15626, but was initially disabled as it depended on functionality not introduced until apache#15577. Since that PR has landed, cleaning up and enabling the unit test.
This test was implemented in apache#15626, but was initially disabled as it depended on functionality not introduced until apache#15577. Since that PR has landed, cleaning up and enabling the unit test.
Prior to this commit, a
relax.PrimValue
could have a datatype, but couldn't have a correspondingtir.PrimExpr
. As a result, it could not be used to specify tensor shapes. This makes some expressions require fallback toR.Tensor(ndim=ndim)
, even though the shape could still be inferred.This commit adds a
Optional<PrimExpr> value
field to thePrimStructInfo
. This field acts similarly to thePrimExpr
fields already used inShapeStructInfo
, and may contain symbolic variables.