-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Transform] Improve symbolic variable handling in FuseOps #16450
base: main
Are you sure you want to change the base?
Conversation
Ideally we don't want to change FuseOps behavior, since in cases where expressions are intermediate (e.g. intermediate compute include values that contains exprs like n * 4). This is because we should get maybe we should look into compose them? FuseOps first then rewrite signatures |
I could see having a post-processing pass to update the signature, maybe as an extension of Though, could you expand on what you mean by intermediate expressions? In either case, whether implemented in |
b34ffe9
to
b014332
Compare
b014332
to
3556c4f
Compare
Rebased onto main to resolve conflicts. For long-term, I think I agree that it would be cleaner and more general-purpose to have the functionality separated out into three distinct passes:
|
3556c4f
to
ebb6278
Compare
I've separated the first commit of this PR branch into an independent PR (#16637), as the bugfix it provides is independent of the concerns raised, and does not require the not-yet-implemented |
Prior to this commit, `FuseOps` and `FuseOpsByPattern` exposed a symbolic variable to the fused function if it was used within the fused function, but wasn't inferable from other parameter shapes. While this prevents undefined symbolic variables, it can cause issues for downstream use of `CodegenJSON`, which requires all arguments to be tensors, or tuple of tensors. Frequently, all uses of a non-inferable symbolic shape occur within a symbolic expression that can be inferred. For example, a function that takes `arg: R.Tensor([N+1])` and returns `R.add(arg, R.const(1))` cannot infer `N`. However, all occurrences of `N` occur as part of the expression `N+1`, and the value of `N+1` can be inferred. Therefore, if we replace `N+1` with `M`, the additional `ShapeTuple` argument isn't required.
ebb6278
to
db735e3
Compare
Prior to this commit,
FuseOps
andFuseOpsByPattern
exposed a symbolic variable to the fused function if it was used within the fused function, but wasn't inferable from other parameter shapes. While this prevents undefined symbolic variables, it can cause issues for downstream use ofCodegenJSON
, which requires all arguments to be tensors, or tuple of tensors.Frequently, all uses of a non-inferable symbolic shape occur within a symbolic expression that can be inferred. For example, a function that takes
arg: R.Tensor([N+1])
and returnsR.add(arg, R.const(1))
cannot inferN
. However, all occurrences ofN
occur as part of the expressionN+1
, and the value ofN+1
can be inferred. Therefore, if we replaceN+1
withM
, the additionalShapeTuple
argument isn't required.In addition, prior to this commit, the
CompositeFunctionAnnotator
visited the body of functions without the parameters being considered in-scope. As a result,EraseToWellDefined
would remove known shapes from the function body'sStructInfo
.