[Unity] Added known tir.Expr to relax.PrimValue #15577

Merged 5 commits on Sep 7, 2023


@Lunderberg (Contributor) commented Aug 16, 2023

Prior to this commit, a relax.PrimValue could have a datatype, but couldn't have a corresponding tir.PrimExpr. As a result, it could not be used to specify tensor shapes. This forced some expressions to fall back to R.Tensor(ndim=ndim), even though the shape could still be inferred.

# Prior to this commit, attempts to use `R.Prim` when
# a `tir.PrimExpr` is required would result in an error.
@R.function
def func(
    A: R.Tensor([16, 16]),
    first_n_rows: R.Prim("int64"),
) -> R.Tensor([first_n_rows, 16]):
    #          ^^^^^^^^^^^^
    #          R.Tensor requires a PrimExpr, not relax.Expr
    #
    #                               Operations may require PrimExpr
    #                                                  vvvvvvvvvvvv
    out = R.strided_slice(A, axes=[0], begin=[0], end=[first_n_rows])
    return out

This commit adds an Optional<PrimExpr> value field to PrimStructInfo. This field acts similarly to the PrimExpr fields already used in ShapeStructInfo, and may contain symbolic variables.

@R.function
def func(
    A: R.Tensor([16, 16]),

    # TIR definitions in the signature allow in-line definitions,
    # similar to R.Tensor and R.Shape.  R.Prim takes a `dtype` or
    # `value` kwarg to distinguish between an in-line symbolic variable
    # and the string representation of a dtype.
    first_n_rows: R.Prim(value="first_n_rows_tir"),
) -> R.Tensor(["first_n_rows_tir", 16]):

    # Body contains a TIR variable definition, which may be used
    # in function calls and inferred shape annotations.
    first_n_rows_tir = T.int64()
    out = R.strided_slice(A, axes=[0], begin=[0], end=[first_n_rows])
    return out

Use distinct PrimStructInfo arguments for dtype/value

Update TVMScript printer

Parser updates, Support R.Prim(value=...) annotations in function signature

@slyubomirsky (Contributor) left a comment

I think this change is a good addition and that it's implemented very well, thanks @Lunderberg! I will start updating the spec to account for this.

My main concern is the cases where shape variables are used "before" they're defined (going left to right in a function). When writing the spec, I assumed this wasn't a case we wanted. If it is a case we want to permit, should our policy be to scan the arguments for binding positions first? That would be good to figure out.

Comment on lines +61 to +69
# Guard against incorrect usage. For backwards compatibility,
# the dtype and value are in the opposite order from most
# usages. While PrimStructInfo could take a single positional
# argument and check the type, this would require an API
# difference from TVMScript's PrimProxy, which cannot.
# (PrimProxy uses string arguments for datatype, and also for
# inline variable definitions when used in a function
# signature, and requires separate arguments to distinguish
# the two cases.)

Contributor

Very user-friendly touch :)
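
The guard described in the quoted comment can be illustrated with a small stand-alone model. This is only a sketch: `ToyPrimStructInfo` is a hypothetical class that does not import TVM, and the dtype it infers for a bare value is an assumption of the toy rather than TVM's behavior. It shows why separate `dtype` and `value` arguments make a misplaced positional value easy to detect.

```python
from typing import Optional, Union


class ToyPrimStructInfo:
    """Hypothetical stand-in for PrimStructInfo; not the TVM class."""

    def __init__(
        self,
        dtype: Optional[str] = None,
        value: Union[int, float, None] = None,
    ):
        # Guard: a number in the first positional slot almost certainly
        # means the caller intended `value=...`.
        if isinstance(dtype, (int, float)):
            raise TypeError(
                "The first positional argument is the dtype; "
                "pass the value with the `value` keyword."
            )
        if dtype is None and value is None:
            raise TypeError("At least one of `dtype` and `value` is required.")
        if dtype is None:
            # Toy assumption: infer a dtype from the Python type of the value.
            dtype = "int64" if isinstance(value, int) else "float32"
        self.dtype = dtype
        self.value = value


by_dtype = ToyPrimStructInfo("int64")   # dtype only, value unknown
by_value = ToyPrimStructInfo(value=16)  # known value, dtype inferred by the toy
```

Calling `ToyPrimStructInfo(16)` raises a `TypeError` in this model instead of silently treating `16` as a datatype.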

"""The bound variable should be replaced when appearing in R.Shape"""

@R.function(private=True)
def before(A: R.Tensor(["M * N"]), x: R.Shape(["M", "N"])):

Contributor

Are we sure we want to permit cases like this? I believe I wrote the specification draft to require argument vars to be in binding positions, left to right. If we permit uses before the binding position, then should the type-checker scan through the arguments, identify the bindings first, and only then check the position?

Contributor, Author

I was thinking so, mainly because it would be the same support that is provided in TIR. (ArgBinder::BindDLTensor defines the symbolic variables prior to asserting based on the shapes.)

There are also a couple of edge cases that I think would be easier to handle by allowing it: swapping the order of commutative operators, hoisting R.match_cast from the body into the function signature, and handling cases with mutually-dependent shapes (e.g. the first parameter is R.Tensor(["a * b", "b"]) and the second parameter is R.Tensor(["a * b", "a"])).
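
The "define first, then assert" ordering can be sketched as a two-pass binder. This is a toy model under stated assumptions, not `ArgBinder` or any TVM API: shapes are lists of strings, a dimension that is a bare identifier counts as a definition, and `toy_evaluate` and `toy_bind_signature` are hypothetical helpers.

```python
import ast
import operator

_BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.FloorDiv: operator.floordiv,
}


def toy_evaluate(expr, bindings):
    """Evaluate a small integer expression such as "a * b"."""

    def visit(node):
        if isinstance(node, ast.Expression):
            return visit(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, int):
            return node.value
        if isinstance(node, ast.Name):
            if node.id not in bindings:
                raise ValueError(f"variable {node.id!r} is never defined")
            return bindings[node.id]
        if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
            return _BIN_OPS[type(node.op)](visit(node.left), visit(node.right))
        raise ValueError(f"unsupported expression: {expr!r}")

    return visit(ast.parse(expr, mode="eval"))


def toy_bind_signature(param_shapes, arg_shapes):
    """Bind symbolic variables across all parameters, then check expressions."""
    bindings = {}
    deferred = []
    # Pass 1: every bare identifier defines (or re-checks) a variable,
    # regardless of which parameter it appears in.
    for p_shape, a_shape in zip(param_shapes, arg_shapes):
        if len(p_shape) != len(a_shape):
            raise ValueError("rank mismatch")
        for dim_expr, actual in zip(p_shape, a_shape):
            if dim_expr.isidentifier():
                if bindings.setdefault(dim_expr, actual) != actual:
                    raise ValueError(f"conflicting values for {dim_expr!r}")
            else:
                deferred.append((dim_expr, actual))
    # Pass 2: compound expressions are asserted once all variables are known.
    for dim_expr, actual in deferred:
        if toy_evaluate(dim_expr, bindings) != actual:
            raise ValueError(f"{dim_expr!r} does not match {actual}")
    return bindings


# Mutually-dependent shapes: each parameter uses a variable defined by the other.
bound = toy_bind_signature([["a * b", "b"], ["a * b", "a"]], [(6, 3), (6, 2)])
```

Because definitions are collected before any expression is checked, the order of the parameters does not matter in this model; a strict left-to-right binder would reject the first parameter, since `a` is not yet defined when `"a * b"` is encountered.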

Contributor

Okay, I'll have to modify the specification to account for it.

Contributor, Author

Sounds good, and thank you for all your work keeping the spec up to date!

@tqchen (Member) commented Sep 6, 2023

cc @junrushao @Hzfengsy

@Hzfengsy (Member) left a comment

The changes generally look good to me. Thanks @Lunderberg for the great addition!

Only one question: should the value be a general PrimExpr, or a tir.Var? IIUC, it maps the prim value to one symbolic var and is used in shape deduction. A direct var should be easier than a generic expr in this case.

@Lunderberg (Contributor, Author)

@Hzfengsy Thank you, and I do think the value should be a tir.PrimExpr and not just a tir.Var for a few reasons.

  • API consistency when applying BindSymbolicVars. When a function is updated using BindSymbolicVars, the allowed values for each parameter are restricted, but the calling convention is otherwise kept identical. For example, binding {n: 16} would reduce the set of allowed values for R.Tensor([n, n]) (the set of all square tensors) to R.Tensor([16, 16]) (the set of all tensors of shape [16,16]), and remains part of the function signature. Similarly, a PrimValue parameter would be reduced from R.Prim(dtype='int64', value=n) (the set of all 64-bit integers) to R.Prim(dtype='int64', value=16) (the set containing only the integer 16), and remains part of the function signature.

    If R.Prim were instead restricted to a tir.Var, it wouldn't be possible to write R.Prim(dtype='int64', value=16), and the parameter would need to be removed entirely. As a result, the calling scope would need to be updated (e.g. from my_func(tensor, 16) to my_func(tensor)).

  • API consistency with R.Shape. Like an R.Prim, an R.Shape may consist entirely of static values. However, for similar reasons, parameters of type R.Shape([n, n]) are specialized to R.Shape([16,16]) and retained in the function signature, even though they only have a single legal value of ShapeTuple([16,16]).

  • Usefulness as constituents in other StructInfo objects. Currently, many types of StructInfo may internally contain a PrimExpr (e.g. TensorStructInfo, ShapeStructInfo), and so visitors that inspect the struct info for PrimExpr instances must handle each case independently, performing the handoff from StructInfoFunctor to tir::ExprFunctor. If these were instead expressed in terms of PrimStructInfo, this would only need to occur in a single location. Because TensorStructInfo and ShapeStructInfo may contain either static values or expressions in terms of other symbolic variables, this potential usage would require PrimStructInfo to also contain these values.
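
The first two points can be made concrete with a small model of specialization. The sketch below does not use TVM: `ToyTensorInfo`, `ToyPrimInfo`, and `toy_bind_symbolic_vars` are hypothetical names, and symbolic variables are represented as plain strings. It only illustrates that substituting `{n: 16}` narrows each parameter while leaving the number of parameters unchanged.

```python
from dataclasses import dataclass, replace
from typing import Dict, List, Optional, Tuple, Union

# A dimension or value is either static (int) or a symbolic variable name (str).
Dim = Union[int, str]


@dataclass(frozen=True)
class ToyTensorInfo:
    shape: Tuple[Dim, ...]


@dataclass(frozen=True)
class ToyPrimInfo:
    dtype: str
    value: Optional[Dim] = None


def _substitute(dim, bindings: Dict[str, int]):
    # Only symbolic names are replaced; static values and None pass through.
    return bindings.get(dim, dim) if isinstance(dim, str) else dim


def toy_bind_symbolic_vars(params: List, bindings: Dict[str, int]) -> List:
    """Specialize each parameter in place; no parameter is ever dropped."""
    specialized = []
    for param in params:
        if isinstance(param, ToyTensorInfo):
            new_shape = tuple(_substitute(d, bindings) for d in param.shape)
            specialized.append(ToyTensorInfo(new_shape))
        elif isinstance(param, ToyPrimInfo):
            specialized.append(replace(param, value=_substitute(param.value, bindings)))
        else:
            raise TypeError(f"unsupported parameter: {param!r}")
    return specialized


signature = [ToyTensorInfo(("n", "n")), ToyPrimInfo("int64", "n")]
specialized = toy_bind_symbolic_vars(signature, {"n": 16})
```

In this model the specialized signature still has two parameters, so a call site of the form `my_func(tensor, 16)` would not need to change. If the prim parameter could only hold a variable, the second entry would have no valid representation after binding.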

@Hzfengsy (Member) left a comment

Thanks for the clarification. LGTM now!

@Lunderberg Lunderberg merged commit 755af1f into apache:unity Sep 7, 2023
8 checks passed
@Lunderberg Lunderberg deleted the unity_primstructinfo_to_tir branch September 7, 2023 15:04
Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Oct 24, 2023
This test was implemented in apache#15626,
but was initially disabled as it depended on functionality not
introduced until apache#15577.  Since that
PR has landed, cleaning up and enabling the unit test.
Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Oct 25, 2023
This test was implemented in apache#15626,
but was initially disabled as it depended on functionality not
introduced until apache#15577.  Since that
PR has landed, cleaning up and enabling the unit test.
masahi pushed a commit that referenced this pull request Oct 30, 2023
This test was implemented in #15626,
but was initially disabled as it depended on functionality not
introduced until #15577.  Since that
PR has landed, cleaning up and enabling the unit test.