[Unity] Allow FLegalize to produce Relax operations #15842
Conversation
Prior to this commit, a `FLegalize` function needed to produce an implementation that can be used as input by `relax.transform.AnnotateTIROpPattern`, and could not lower to other Relax operations. This commit allows Relax operations to be included in the output of `FLegalize`, with the result being further legalized if required.
Thank you, @Lunderberg! A couple questions.
`return relax.Call(custom_op, [A, Weight, Bias])`

…

`AfterFirstIter = LegalizeOps()(Before)`
Does the user need to perform multiple `LegalizeOps` passes depending on their custom ops? For example, would the user need to call it twice here?
Nope. With this change, the `LegalizeOps` pass will continue until no additional legalization can be applied, so the user only needs to call the function once.
Ah, sorry. I missed that you do an equality check between `AfterFirstIter` and `AfterSecondIter`. Makes sense to me.
`return add_sinfo`

…

`def legalize(bb: relax.BlockBuilder, call: relax.Call):`
We have a similar pass, although it does not support any recursion: https://github.com/apache/tvm/blob/unity/python/tvm/relax/transform/transform.py#L994

Is there any use-case for recursion? Or is it more for future-proofing?
There are a couple of reasons I'd been thinking of, most of which fall somewhere between future-planning and user-friendliness. (Bit of a brain dump follows.)
- User-friendliness, to make it easier to write legalization steps. For example, `R.nn.rms_norm` could be written in terms of `R.std` instead of requiring a direct lowering to a TIR implementation.
- Future-planning for user-defined custom intrinsics. If the legalization of these custom operators is defined in terms of standard relax operators, `LegalizeOps` would need to recursively expand them to allow `AnnotateTIROpPattern` to recognize the results.
- Future-planning for partial legalization. If each operator has a "composite_level", then we could selectively lower operators that are above some level of complexity. This would be a generalization of the `OpDecomposer`, to decompose any such operator.
- Future-planning for defining the requirements of graph-level optimization passes. If an optimization pass handles all relax operators up to some `composite_level`, new operators could be added without impacting that optimization pass, so long as those operators define a partial legalization that decomposes them.
- Centralizing the definition of each operator. With composite operators defined in terms of lower-complexity operators, the `OpDecomposer` could be identical to the rules used by `LegalizeOps`, avoiding duplicate operator definitions.
- Future-planning to minimize the need for TIR pattern recognition. For example, `R.nn.attention` is implemented in terms of `topi.transpose` and `topi.reshape`, and would require pattern-matching similar to `RewriteDataflowReshape` to un-lower these back to Relax operations. If `R.nn.attention` were instead decomposed into `R.permute_dims` and `R.reshape`, we'd get this for free.
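The "composite_level" idea above could be sketched as follows. This is a hypothetical, pure-Python illustration (no such field exists in TVM today; all levels, op names, and rules are made up): ops whose level exceeds a threshold are decomposed, while simpler ops are left alone, giving a partial legalization.

```python
# Hypothetical complexity ranking: higher means "more composite".
COMPOSITE_LEVEL = {
    "relax.nn.attention": 3,
    "relax.permute_dims": 1,
    "relax.matmul": 1,
    "relax.reshape": 1,
}

# Decomposition into lower-level Relax ops (not directly into TIR).
DECOMPOSE_RULES = {
    "relax.nn.attention": ["relax.permute_dims", "relax.matmul", "relax.reshape"],
}

def partially_legalize(ops, max_level):
    """Lower only ops whose composite_level exceeds max_level,
    repeating until no op above the threshold remains."""
    changed = True
    while changed:
        changed = False
        out = []
        for op in ops:
            if COMPOSITE_LEVEL.get(op, 0) > max_level and op in DECOMPOSE_RULES:
                out.extend(DECOMPOSE_RULES[op])
                changed = True
            else:
                out.append(op)
        ops = out
    return ops
```

With `max_level=2`, `relax.nn.attention` is decomposed into the three level-1 ops; with `max_level=3` it is kept as-is, so an optimization pass that only understands ops up to a given level can still run unchanged.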
Thank you, @Lunderberg, for the kind explanation. I like the idea of a "composite_level" and centralizing the definitions. Can we check whether `DecomposeOpsForInference` and `DecomposeOpsForTraining` can be supported with this PR, to see if we can replace them? If so, we can discuss their deprecation as well.
Taking a look, the `DecomposeOpsFor*` passes are currently filling two distinct roles. The first role is to lower the `relax.nn.batch_norm`, `relax.nn.layer_norm`, and `relax.tensor_to_shape` operators into lower-level Relax implementations. The second role is to mutate the `relax.nn.batch_norm` operator into a training-specific version.

I think the first role, lowering Relax operators into less complex Relax operators, will be supported by the partial lowering intended for `LegalizeOps`. The second role is independent of the legalization, and would be best kept as a standalone pass. That second role would also become much simpler, as `relax.nn.batch_norm(data, gamma, beta, prev_mean, prev_var)` could be updated to `relax.nn.batch_norm(data, gamma, beta, weighted_avg(mean(data), prev_mean), weighted_avg(var(data), prev_var))`, rather than needing a full re-definition of `relax.nn.batch_norm`.

Though, those are probably changes that would be best left for a follow-up PR.
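To make the two roles concrete, here is a rough numeric sketch in plain Python (not the TVM API; `eps` and the `momentum` default are assumptions). Role 1 is batch norm written in terms of simpler ops (`mean`, `var`, scale/shift); role 2 only swaps the statistics fed in, via a weighted average with the previous running values:

```python
def mean(xs):
    return sum(xs) / len(xs)

def var(xs):
    m = mean(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)

def batch_norm(data, gamma, beta, use_mean, use_var, eps=1e-5):
    """Role 1: batch norm expressed via lower-level operations."""
    return [gamma * (x - use_mean) / (use_var + eps) ** 0.5 + beta for x in data]

def weighted_avg(new, prev, momentum=0.1):
    """Running-statistic update used by the training variant."""
    return momentum * new + (1.0 - momentum) * prev

def batch_norm_training(data, gamma, beta, prev_mean, prev_var):
    """Role 2: same batch_norm, but with updated statistics.
    Only the arguments change; no separate op definition is needed."""
    m = weighted_avg(mean(data), prev_mean)
    v = weighted_avg(var(data), prev_var)
    return batch_norm(data, gamma, beta, m, v), (m, v)
```

The point of the sketch is that the training variant is an argument rewrite around the same `batch_norm`, which is why it can stay as a small standalone pass once the lowering role moves into `LegalizeOps`.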
Interesting! I did not know there is a training-specific version of batch norm. SGTM. Let's discuss it in the follow-up PR.