[Transform] Improve symbolic variable handling in FuseOps #16450

Status: Open · wants to merge 1 commit into `main`
Conversation

Lunderberg (Contributor)
Prior to this commit, `FuseOps` and `FuseOpsByPattern` exposed a symbolic variable to the fused function if it was used within the fused function but wasn't inferable from other parameter shapes. While this prevents undefined symbolic variables, it can cause issues for downstream use of `CodegenJSON`, which requires all arguments to be tensors or tuples of tensors.

Frequently, all uses of a non-inferable symbolic shape occur within a symbolic expression that can be inferred. For example, a function that takes `arg: R.Tensor([N+1])` and returns `R.add(arg, R.const(1))` cannot infer `N`. However, all occurrences of `N` occur as part of the expression `N+1`, and the value of `N+1` can be inferred. Therefore, if we replace `N+1` with `M`, the additional `ShapeTuple` argument isn't required.
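To illustrate the substitution idea, here is a minimal plain-Python sketch. This is not TVM's actual implementation; the `Var`/`Add` expression classes and the `try_replace` helper are hypothetical, and the sketch only handles the simple case where the shared subexpression is the whole shape expression (it does not search for `N+1` nested inside larger exprs):

```python
# Toy sketch: if every shape expression that mentions a symbolic
# variable is the same expression, that expression can be replaced by
# a fresh variable, so the original variable need not be exposed.
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Add:
    lhs: object
    rhs: object


def mentions(expr, var):
    """Check whether `var` occurs anywhere inside `expr`."""
    if expr == var:
        return True
    if isinstance(expr, Add):
        return mentions(expr.lhs, var) or mentions(expr.rhs, var)
    return False


def try_replace(shape_exprs, var, fresh):
    """If every shape expr mentioning `var` is the same expression,
    replace that expression with `fresh`; otherwise return None,
    meaning `var` must still be passed explicitly."""
    used = [e for e in shape_exprs if mentions(e, var)]
    if used and all(e == used[0] for e in used):
        return [fresh if e == used[0] else e for e in shape_exprs]
    return None


N, M = Var("N"), Var("M")
shape = [Add(N, 1)]              # arg: R.Tensor([N + 1])
print(try_replace(shape, N, M))  # N+1 is replaced by the fresh var M
```

With two distinct expressions over `N` (say `N+1` and `N+42`), `try_replace` returns `None`, matching the case where the extra `ShapeTuple` argument is still required.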

In addition, prior to this commit, the `CompositeFunctionAnnotator` visited the body of functions without the parameters being considered in-scope. As a result, `EraseToWellDefined` would remove known shapes from the function body's `StructInfo`.
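The scoping issue can be shown with a toy sketch (plain Python, not TVM's actual `EraseToWellDefined`; the `erase_to_well_defined` helper is hypothetical): a shape is erased whenever it mentions a symbolic variable not currently in scope, so forgetting to treat the function parameters as in-scope erases shapes that are actually well-defined.

```python
# Toy sketch of the scoping bug: struct info whose symbolic vars are
# not in scope gets erased (returned as None).  If the visitor does
# not add the function parameters to the scope before visiting the
# body, shapes defined by those parameters are lost.
def erase_to_well_defined(shape_vars, in_scope):
    """Keep the shape only if every symbolic var it uses is in scope."""
    return shape_vars if shape_vars <= in_scope else None


params = {"n"}        # symbolic vars defined by the function parameters
body_shape = {"n"}    # a shape in the body that uses `n`

# Bug: parameters not considered in-scope, so the known shape is erased.
print(erase_to_well_defined(body_shape, in_scope=set()))

# Fix: parameters in scope, so the known shape is kept.
print(erase_to_well_defined(body_shape, in_scope=params))
```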

tqchen (Member) commented Jan 22, 2024

Ideally we don't want to change the `FuseOps` behavior, since there are cases where the expressions are intermediate (e.g. intermediate compute includes values that contain exprs like `n * 4`).

Because of this, maybe we should look into composing them: run `FuseOps` first, then rewrite the signatures.

Lunderberg (Contributor, Author)

I could see having a post-processing pass to update the signature, maybe as an extension of `RemoveUnusedParameters`. There would still need to be an update to `FuseOps` to mark the fused functions as private, since the post-processing step would only be allowed to update the signature of internal functions.

Though, could you expand on what you mean by intermediate expressions? In either case, whether implemented in `FuseOps` or in a post-processing pass, I think intermediate expressions would be handled correctly. If an expression `n*4` can be inferred from the tensor shapes, but `n+42` also appears in the fused function, then there would still be a shape expr used to expose `n` to the fused function.
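That exposure decision can be sketched in plain Python (again hypothetical helper names, not TVM internals): a free variable still needs an explicit shape argument whenever it is used in any expression that cannot be inferred from the tensor parameter shapes.

```python
# Toy sketch of when a symbolic variable must still be exposed to the
# fused function.  Exprs are represented as (string, {vars}) pairs.
def vars_needing_exposure(used_exprs, inferable_exprs, free_vars):
    """`used_exprs`: symbolic exprs appearing in the fused function.
    `inferable_exprs`: exprs recoverable from tensor parameter shapes.
    A var needs an explicit shape argument if it appears in any used
    expr that is not itself inferable."""
    exposed = set()
    for var in free_vars:
        for expr, expr_vars in used_exprs:
            if var in expr_vars and expr not in inferable_exprs:
                exposed.add(var)
    return exposed


# n*4 is inferable from a tensor shape, but n+42 also appears, so n
# must still be exposed via a shape expr:
print(vars_needing_exposure(
    [("n*4", {"n"}), ("n+42", {"n"})], {"n*4"}, {"n"}))  # {'n'}

# If only n*4 were used, n would not need exposure:
print(vars_needing_exposure(
    [("n*4", {"n"})], {"n*4"}, {"n"}))                   # set()
```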

Lunderberg (Contributor, Author)

Rebased onto main to resolve conflicts.

For the long term, I think I agree that it would be cleaner and more general-purpose to separate the functionality into three distinct passes:

  1. `FuseOps`, with the first commit in this PR to preserve symbolic variables in the `ret_struct_info`.
  2. A not-yet-existing `HoistCommonSubexpressions`, which would recognize that a symbolic variable is always used within a specific expression, and would hoist that expression to the calling scope.
  3. Applying `RemoveUnusedParameters` to remove the no-longer-required `R.shape` param.

Lunderberg (Contributor, Author)

I've separated the first commit of this PR branch into an independent PR (#16637), as the bugfix it provides is independent of the concerns raised and does not require the not-yet-implemented `HoistCommonSubexpressions` transform.
