
[Unity] Provide FuncStructInfo from bb.emit_te #15026

Closed

Conversation

Lunderberg
Contributor

Prior to this commit, the PrimFunc generated by `bb.call_te` had no struct info associated with it. This commit updates `gen_call_tir_inputs`, which converts a TE expression into a TIR PrimFunc, to annotate the PrimFunc with `FuncStructInfo` representing the input and output shapes.

Providing this functionality for PrimFuncs produced from TE is a simpler case than for a general PrimFunc, as TE has well-defined input and output tensors. The PrimFunc's GlobalVar is later used as the `CallNode::op`, and must have correct shape inference at that point.
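
For illustration, a minimal sketch of the workflow this touches, using the standard `relax.BlockBuilder` API; the PrimFunc name and the struct info described in the comments are assumptions for illustration, not output copied from this PR:

```python
import tvm
from tvm import relax, topi

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((16, 32), "float32"))
y = relax.Var("y", relax.TensorStructInfo((16, 32), "float32"))

with bb.function("main", [x, y]):
    # emit_te lowers topi.add to a TIR PrimFunc and emits a call_tir to it.
    z = bb.emit_te(topi.add, x, y)
    bb.emit_func_output(z)

mod = bb.get()
# With this change, the generated PrimFunc (name may vary, e.g. mod["add"])
# would carry a FuncStructInfo whose parameters are the input and output
# tensors, roughly [Tensor((16, 32), "float32")] * 3, with no Relax-level
# return value.
```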

@tvm-bot
Collaborator

tvm-bot commented Jun 5, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@Lunderberg Lunderberg force-pushed the relax_funcstructinfo_from_call_te branch from 8b1056e to cca8b12 on June 6, 2023 15:23
output_sinfo = [te_to_sinfo(out) for out in outs]  # struct info for each TE output tensor

# The PrimFunc follows destination-passing style: inputs and outputs all
# appear as parameters, and nothing is returned.
primfunc_sinfo = FuncStructInfo([*input_sinfo, *output_sinfo], PrimStructInfo("void"))
_update_struct_info(tir_func, primfunc_sinfo)
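
For context, a helper like `te_to_sinfo` presumably maps each `te.Tensor` to a `TensorStructInfo`. A hypothetical sketch of such a helper, not the PR's actual implementation:

```python
from tvm import relax, te

def te_to_sinfo(tensor: te.Tensor) -> relax.TensorStructInfo:
    # A te.Tensor already carries a (possibly symbolic) shape and a dtype,
    # which is exactly what TensorStructInfo records.
    return relax.TensorStructInfo(tensor.shape, tensor.dtype)
```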
Contributor

@tqchen Is this consistent with how we want FuncStructInfo to work? I thought PrimFuncs would be ObjectStructInfo (this is what we wrote in the Relax spec). Perhaps they use a derive_func instead? If we want them to use ordinary FuncStructInfo, does that also mean we'll allow them to be called outside of call_tir?

Contributor Author

Good question. I had assumed this was intended, but would be interested to hear more about it. I had mostly assumed that a Relax function and a TIR PrimFunc should expose the same information, so long as they follow the same convention. That is, since the callsite makes no distinction between a GlobalVar representing a relax::Function and one representing a tir::PrimFunc, it seemed that the struct_info_ should depend only on the calling convention, and not on the implementation dialect.

Regarding call_tir, I think the Relax-to-TIR calls are not restricted to the R.call_tir built-in, because the LowerCallTIR pass can output a relax::CallNode with a GlobalVar as its operation.

Contributor

Yeah, permitting TIR calls outside of call_tir is something we're trying to figure out with respect to phase ordering in Relax (see thread). I was under the impression that we did not want direct calls to PrimFuncs in the front end, so we should clarify that (we could put this on the agenda for a community meeting).

FWIW, I don't think it would be hard to give PrimFuncs FuncStructInfo, but there is the issue that they mutate their arguments, so they should be treated as impure (except when called via call_tir).

Contributor Author

Good point regarding the mutation. Thinking on it, I'm also not sure what the best FuncStructInfo would be. It could reasonably be either `FuncStructInfo(params=[*input_tensors, *output_tensors], ret=None)`, which matches the TIR function's signature, or `FuncStructInfo(params=input_tensors, ret=relax.Tuple(output_tensors))`, which matches the semantics exposed in Relax.

The original issue I was running into was that the result of bb.emit_te doesn't preserve the output struct information across mutations. If I have a TE function that accepts dynamic shapes, but which is called using static shapes, then the return type of the relax::Call should be an inferred static shape. This works during the first usage of BlockBuilder, when a user is calling bb.emit_te directly. However, when the module is mutated, any mutation of the call node relies on the relax::Normalizer to regenerate the output struct info, and it doesn't have enough information to do so.
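
A rough sketch of the two candidate annotations described above, with placeholder `input_sinfo`/`output_sinfo` standing in for the struct info derived from the TE tensors; neither form is asserted to be what the PR finally uses:

```python
from tvm import relax

input_sinfo = [relax.TensorStructInfo((16, 32), "float32")]
output_sinfo = [relax.TensorStructInfo((16, 32), "float32")]

# Option 1: mirror the TIR signature (destination-passing style; results are
# written into the trailing output parameters, nothing is returned).
sinfo_tir_view = relax.FuncStructInfo(
    [*input_sinfo, *output_sinfo],   # params: inputs followed by outputs
    relax.TupleStructInfo([]),       # ret: unit, nothing is returned
)

# Option 2: mirror the semantics exposed through call_tir (inputs in, a tuple
# of freshly allocated output tensors back).
sinfo_relax_view = relax.FuncStructInfo(
    input_sinfo,                          # params: only the inputs
    relax.TupleStructInfo(output_sinfo),  # ret: the output tensors
)
```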

Contributor

I think if you just call the PrimFunc by itself, it will work by mutating the arguments, so the best signature would be the first one you suggested. call_tir (the operator) is what's responsible for providing the nice wrapper over the mutation.

Contributor

@Lunderberg I've put this PR on the agenda for next week's community meeting. If you can make it, that would be good.

Contributor Author

Thank you! I probably won't be able to attend, given the timing, but I agree that discussion would be good. For now, I've converted this PR to a draft, to ensure that it can't be merged prior to discussion.

Contributor

Conclusion from the meeting: we think it's okay to permit direct calls to PrimFuncs, and to give PrimFuncs FuncStructInfo, provided that in both cases they are treated as impure.

Contributor

I wonder if unit (empty tuple) would make more sense as the return type, incidentally. Also, I do think the FuncStructInfo should have the purity set to false.
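
Concretely, that suggestion would look roughly like the following; whether the Python constructor exposes the purity flag as a `purity` keyword at this point is an assumption:

```python
from tvm import relax

input_sinfo = [relax.TensorStructInfo((16, 32), "float32")]
output_sinfo = [relax.TensorStructInfo((16, 32), "float32")]

primfunc_sinfo = relax.FuncStructInfo(
    [*input_sinfo, *output_sinfo],  # inputs followed by outputs
    relax.TupleStructInfo([]),      # unit (empty tuple) instead of "void"
    purity=False,                   # the PrimFunc mutates its output arguments
)
```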

Contributor Author

Summarizing our conversation from this morning:

  • Shape propagation through bb.emit_te only works during the initial construction of a Relax module, when the relax::Call("relax.call_tir",...) node is explicitly typed. Re-derivation of the output shape is not implemented, and so the shape information can be lost during lowering if the arguments to call_tir change.

  • Annotating a PrimFunc with FuncStructInfo to represent the output of call_tir (i.e. pure function, tensor output) wouldn't be accurate, and could cause confusion in the future.

  • Annotating a PrimFunc with FuncStructInfo to represent the PrimFunc itself (i.e. impure function, mutates arguments) would be accurate, but insufficient for call_tir to propagate shapes, as input/output shapes are mixed.

  • It would be useful to have purity annotations for each parameter, dividing arguments into read-only, output, and mutate-in-place. This would allow a PrimFunc to be accurately annotated, and would be sufficient for call_tir to identify the outputs for shape propagation (see the sketch after this list).
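
A purely hypothetical sketch of that last idea; the `param_access` attribute name does not exist in TVM and is only meant to illustrate the shape such an annotation could take:

```python
from tvm import te

# Build a trivial PrimFunc to hang the hypothetical annotation on.
A = te.placeholder((16,), name="A")
B = te.compute((16,), lambda i: A[i] + 1.0, name="B")
prim_func = te.create_prim_func([A, B])

# Hypothetical: tag each parameter with its access mode so that call_tir
# could identify which parameters are outputs for shape propagation.
prim_func = prim_func.with_attr(
    "param_access", {"A": "read_only", "B": "output"}
)
```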

@Lunderberg Lunderberg marked this pull request as draft June 12, 2023 15:42