Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax][BlockBuilder] Use PrimValue to provide tir_vars #17087

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

Lunderberg
Copy link
Contributor

Prior to this commit, if a TIR variable was required to compute the output of BlockBuilder.call_te, but that TIR variable could not be inferred from the shape of any tensor arguments, it would be provided in an optional tir_vars argument to R.call_tir. In C++, this would be then be accessed as an optional
call->args[2].as<ShapeExprNode>().

This extra argument can cause unexpected bugs. For example, the bug that was fixed in #17086 was caused by RewriteDataflowReshape identifying the output buffer using prim_func->buffer_map.Get(prim_func->params.back()), which is only correct if tir_vars is empty. Rather than fixing these issues as they come up, it would be better to make the general Relax guarantees stronger by removing the tir_vars argument altogether.

Use of extra R.shape parameter to specify additional tir_vars predates the existence of relax::PrimValue, and is no longer required. This commit updates BlockBuilder.call_te to use additional relax.PrimValue arguments to handle symbolic values that cannot be inferred from tensor shapes, rather than tir_vars.

Prior to this commit, if a TIR variable was required to compute the
output of `BlockBuilder.call_te`, but that TIR variable could not be
inferred from the shape of any tensor arguments, it would be provided
in an optional `tir_vars` argument to `R.call_tir`.  In C++, this
would be then be accessed as an optional
`call->args[2].as<ShapeExprNode>()`.

This extra argument can cause unexpected bugs.  For example,
`RewriteDataflowReshape` identifies the output buffer using
`prim_func->buffer_map.Get(prim_func->params.back())`, which is only
correct if `tir_vars` is empty.  Rather than fixing these issues as
they come up, it would be better to make the general Relax guarantees
stronger by removing the `tir_vars` argument altogether.

Use of extra `R.shape` parameter to specify additional `tir_vars`
predates the existence of `relax::PrimValue`, and is no longer
required.  This commit updates `BlockBuilder.call_te` to use
additional `relax.PrimValue` arguments to handle symbolic values that
cannot be inferred from tensor shapes, rather than `tir_vars`.
@Lunderberg Lunderberg requested a review from masahi June 12, 2024 19:33
@masahi
Copy link
Member

masahi commented Jun 12, 2024

cc @tqchen @Hzfengsy

@Lunderberg Lunderberg requested a review from sunggg June 18, 2024 18:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants