[Unity] Provide FuncStructInfo from bb.emit_te
#15026
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot
Force-pushed from 8b1056e to cca8b12
output_sinfo = [te_to_sinfo(out) for out in outs]

primfunc_sinfo = FuncStructInfo([*input_sinfo, *output_sinfo], PrimStructInfo("void"))
_update_struct_info(tir_func, primfunc_sinfo)
@tqchen Is this consistent with how we want `FuncStructInfo` to work? I thought `PrimFunc`s would be `ObjectStructInfo` (this is what we wrote in the Relax spec). Perhaps they use a `derive_func` instead? If we want them to use ordinary `FuncStructInfo`, does that also mean we'll allow them to be called outside of `call_tir`?
Good question. I had assumed this was intended, but would be interested to hear more on it. I had mostly assumed that a Relax function and a TIR PrimFunc should expose the same information, so long as they have the same convention. That is, since the callsite makes no distinction between a `GlobalVar` representing a `relax::Function` and one representing a `tir::PrimFunc`, it seemed that the `struct_info_` would depend only on the call sequence, and not the implementation dialect.
Regarding `call_tir`, I think the Relax-to-TIR calls are not restricted to the `R.call_tir` built-in, because the `LowerCallTIR` pass can output a `relax::CallNode` with a `GlobalVar` as its operation.
Yeah, permitting TIR calls outside of `call_tir` is something we're trying to figure out with respect to phase ordering in Relax (see thread). I was under the impression that we did not want direct calls to `PrimFunc`s in the front end, so we should clarify that (we could put this on the agenda for a community meeting).
FWIW, I don't think it would be hard to give `PrimFunc`s `FuncStructInfo`, but there is the issue that they mutate their arguments, so they should be treated as impure (except when called via `call_tir`).
Good point regarding the mutation. Thinking on it, I'm also not sure what the best `FuncStructInfo` would be. It could reasonably be either `FuncStructInfo(params=[*input_tensors, *output_tensors], ret=None)`, which matches the TIR function's signature, or `FuncStructInfo(params=input_tensors, ret=relax.Tuple(output_tensors))`, which matches the exposed semantics in Relax.
The original issue I was running into was that the result of `bb.emit_te` doesn't preserve the output struct information across mutations. If I have a TE function that accepts dynamic shapes, but which is called using static shapes, then the return type of the `relax::Call` should be an inferred static shape. This works during the first usage of BlockBuilder, when a user is calling `bb.emit_te` directly. However, when the module is mutated, any mutation of the call node relies on the `relax::Normalizer` to regenerate the output struct info, and it doesn't have enough information to do so.
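The static-shape inference described in this comment can be illustrated with a toy model in plain Python. This is only a sketch and not TVM's implementation: `infer_output_shape` and the tuple-based shape encoding are hypothetical, introduced here for illustration, with symbolic dimensions written as strings and static dimensions as ints.

```python
from typing import Dict, Sequence, Tuple, Union

Dim = Union[int, str]  # str = symbolic dimension, int = static dimension
Shape = Tuple[Dim, ...]


def infer_output_shape(
    param_shapes: Sequence[Shape],
    ret_shape: Shape,
    arg_shapes: Sequence[Shape],
) -> Shape:
    """Bind symbolic dims in the declared parameters to the call-site shapes,
    then substitute those bindings into the declared return shape.

    Hypothetical helper, not part of TVM.
    """
    if len(param_shapes) != len(arg_shapes):
        raise ValueError("argument count does not match the signature")
    bindings: Dict[str, Dim] = {}
    for declared, actual in zip(param_shapes, arg_shapes):
        if len(declared) != len(actual):
            raise ValueError(f"rank mismatch: {declared} vs {actual}")
        for d, a in zip(declared, actual):
            if isinstance(d, str):
                # First use binds the symbol; later uses must agree with it.
                if bindings.setdefault(d, a) != a:
                    raise ValueError(f"conflicting binding for {d!r}")
            elif d != a:
                raise ValueError(f"static dim mismatch: {d} vs {a}")
    # Symbols that were never bound stay symbolic in the result.
    return tuple(bindings.get(d, d) if isinstance(d, str) else d for d in ret_shape)


# Declared with dynamic shapes (n, m), called with static shapes (4, 8).
static_out = infer_output_shape(
    param_shapes=[("n", "m"), ("m",)],
    ret_shape=("n", "m"),
    arg_shapes=[(4, 8), (8,)],
)
```

The point of the sketch is that the inference needs a signature that separates parameters from the result. Without one, a rewritten call node has nothing to re-derive its output shape from.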
I think if you just call the `PrimFunc` by itself, it will work by mutating the arguments, so the best signature would be the first one you suggested. `call_tir` (the operator) is what's responsible for providing the nice wrapper over the mutation.
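The difference between calling a PrimFunc directly and calling it through `call_tir` can be sketched in plain Python. This is a toy analogy under stated assumptions, not TVM code: `add_one_prim` stands in for a PrimFunc written in destination-passing style, and `call_tir_like` is a hypothetical wrapper that allocates the output itself.

```python
from typing import Callable, List, Sequence


def add_one_prim(inp: List[float], out: List[float]) -> None:
    """Stand-in for a PrimFunc: writes its result into `out`, returns nothing."""
    for i, value in enumerate(inp):
        out[i] = value + 1.0


def call_tir_like(
    prim: Callable[..., None], args: Sequence[List[float]], out_len: int
) -> List[float]:
    """Hypothetical wrapper: allocate a fresh output, run the mutating
    function on it, and hand the result back as an ordinary return value."""
    out = [0.0] * out_len
    prim(*args, out)
    return out


data = [1.0, 2.0, 3.0]

# Direct call: the caller supplies the buffer and observes the mutation.
buffer = [0.0, 0.0, 0.0]
direct_result = add_one_prim(data, buffer)

# Wrapped call: the mutation is hidden, so the call looks pure to the caller.
wrapped = call_tir_like(add_one_prim, [data], out_len=3)
```

In this analogy the direct call returns nothing and is only useful for its side effect on `buffer`, which is why the first signature (all tensors as parameters, no return value) describes it, and why it has to be treated as impure.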
@Lunderberg I've put this PR on the agenda for next week's community meeting. If you can make it, that would be good.
Thank you! I probably won't be able to attend, given the timing, but I agree that discussion would be good. For now, I've converted this PR to a draft, to ensure that it can't be merged prior to discussion.
Conclusion from the meeting: We think it's okay to permit direct calls to PrimFuncs as long as they're treated as impure and to give FuncStructInfo to PrimFuncs, though they should, again, be marked as impure.
I wonder if unit (empty tuple) would make more sense as the return type, incidentally. Also, I do think the FuncStructInfo should have the purity set to false.
Summarizing our conversation from this morning:
- Shape propagation through `bb.emit_te` only works during the initial construction of a Relax module, when the `relax::Call("relax.call_tir", ...)` node is explicitly typed. Re-derivation of the output shape is not implemented, so the shape information can be lost during lowering if the arguments to `call_tir` change.
- Annotating a PrimFunc with `FuncStructInfo` to represent the output of `call_tir` (i.e. pure function, tensor output) wouldn't be accurate, and could cause confusion in the future.
- Annotating a PrimFunc with `FuncStructInfo` to represent the PrimFunc itself (i.e. impure function, mutates arguments) would be accurate, but insufficient for `call_tir` to propagate shapes, as input/output shapes are mixed.
- It would be useful to have purity annotations for each parameter, dividing arguments into read-only, output, and mutate-in-place. This would allow a PrimFunc to be accurately annotated, and would be sufficient for `call_tir` to identify outputs for shape propagation.
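The per-parameter annotation proposed in the last point might look something like the following plain-Python sketch. Nothing here is an existing TVM API: `ParamRole`, `Param`, `output_shapes`, and `wrapper_can_be_pure` are hypothetical names, and the rule that a wrapped call can be pure only when no parameter is mutated in place is an assumption made for illustration.

```python
from dataclasses import dataclass
from enum import Enum
from typing import List, Sequence, Tuple


class ParamRole(Enum):
    """Hypothetical per-parameter roles, following the three-way split above."""
    READ_ONLY = "read_only"
    OUTPUT = "output"
    MUTATE_IN_PLACE = "mutate_in_place"


@dataclass(frozen=True)
class Param:
    name: str
    shape: Tuple[int, ...]
    role: ParamRole


def output_shapes(params: Sequence[Param]) -> List[Tuple[int, ...]]:
    """Pick out the shapes a call_tir-style wrapper would need to allocate."""
    return [p.shape for p in params if p.role is ParamRole.OUTPUT]


def wrapper_can_be_pure(params: Sequence[Param]) -> bool:
    """Assumption: a wrapper that allocates every OUTPUT itself is pure
    as long as no caller-visible argument is mutated in place."""
    return all(p.role is not ParamRole.MUTATE_IN_PLACE for p in params)


# A matmul-like signature: two read-only inputs and one output.
matmul_params = [
    Param("A", (4, 8), ParamRole.READ_ONLY),
    Param("B", (8, 16), ParamRole.READ_ONLY),
    Param("C", (4, 16), ParamRole.OUTPUT),
]
```

With roles attached, the flat parameter list still matches the PrimFunc's real signature, while `output_shapes(matmul_params)` recovers which entries are results.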
Prior to this commit, the PrimFunc generated by `bb.call_te` had no struct info associated with it. This commit updates `gen_call_tir_inputs`, which converts from a TE expression into a TIR PrimFunc, to annotate the PrimFunc with `FuncStructInfo` representing the input and output shapes. Providing this functionality for PrimFuncs produced from TE is a simpler case than a general PrimFunc, as TE has well-defined input and output tensors.
The PrimFunc's GlobalVar is later used as the CallNode::op, and must have correct shape inference at that point.
Force-pushed from cca8b12 to 74e207b
Force-pushed from c95d45f to 45eeb8c