
[Unity] Allow FLegalize to produce Relax operations #15842

Merged

Conversation

Lunderberg (Contributor)


Prior to this commit, a `FLegalize` function needed to produce an implementation that can be used as input by `relax.transform.AnnotateTIROpPattern`, and could not lower to other Relax operations. This commit allows Relax operations to be included in the output of `FLegalize`, with the result being further legalized if required.
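A minimal sketch of what this change enables (the operator name "my.custom_add3" and the module `Before` are hypothetical, not from this PR): a legalization function may now return plain Relax operations, and a single run of `LegalizeOps` keeps iterating until everything is lowered.

```python
import tvm
from tvm import relax
from tvm.relax.transform import LegalizeOps

def legalize_custom_add3(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr:
    a, b, c = call.args
    # With this PR, returning Relax operations is allowed; R.add is itself
    # legalized to TIR in a later iteration of the same pass.
    return relax.op.add(relax.op.add(a, b), c)

# Customized legalization map (op name -> LegalizeFunc); `Before` is assumed
# to be an IRModule containing calls to the custom operator (not shown).
After = LegalizeOps({"my.custom_add3": legalize_custom_add3})(Before)
```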
sunggg (Contributor) left a comment


Thank you, @Lunderberg! A couple of questions.

```python
):
    return relax.Call(custom_op, [A, Weight, Bias])

AfterFirstIter = LegalizeOps()(Before)
```
sunggg (Contributor)

Does the user need to run the LegalizeOps pass multiple times depending on their custom ops? For example, would the user need to call it twice?

Lunderberg (Contributor, Author)

Nope. With this change, the LegalizeOps pass will continue until no additional legalization can be applied, so the user only needs to call the function once.
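In other words (a sketch mirroring the test above, with `Before` the not-yet-legalized module), a second invocation of the pass is a no-op:

```python
AfterFirstIter = LegalizeOps()(Before)
AfterSecondIter = LegalizeOps()(AfterFirstIter)
# The first invocation already ran to a fixed point, so a second run changes nothing.
tvm.ir.assert_structural_equal(AfterSecondIter, AfterFirstIter)
```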

sunggg (Contributor)

Ah, sorry. I missed that you do an equality check between AfterFirstIter and AfterSecondIter. Makes sense to me.


```python
    return add_sinfo

def legalize(bb: relax.BlockBuilder, call: relax.Call):
```
sunggg (Contributor)

We have a similar pass, although it does not support any recursion: https://github.com/apache/tvm/blob/unity/python/tvm/relax/transform/transform.py#L994

Is there any use-case for recursion? Or is it more for future-proofing?

Lunderberg (Contributor, Author)

There are a couple of reasons I'd been thinking of, most of which fall somewhere between future-planning and user-friendliness. (A bit of a brain dump follows.)

  1. User-friendliness, to make it easier to write legalization steps. For example, R.nn.rms_norm could be written in terms of R.std instead of requiring a direct lowering to a TIR implementation (a sketch follows this list).
  2. Future-planning for user-defined custom intrinsics. If the legalization of these custom operators is defined in terms of standard relax operators, LegalizeOps would need to recursively expand them to allow AnnotateTIROpPattern to recognize the results.
  3. Future-planning for partial legalization. If each operator has a "composite_level", then we could selectively lower operators that are above some level of complexity. This would be a generalization of the OpDecomposer, able to decompose any operator above a chosen complexity level.
  4. Future-planning for defining the requirements of graph-level optimization passes. If an optimization pass handles all relax operators up to some composite_level, new operators could be added without impacting that optimization pass, so long as those operators define a partial legalization that decomposes them.
  5. Centralizing the definition of each operator. With composite operators defined in terms of lower-complexity operators, the OpDecomposer could be identical to the rules used by LegalizeOps, avoiding duplicate operator definitions.
  6. Future-planning to minimize the need for TIR pattern recognition. For example, R.nn.attention is implemented in terms of topi.transpose and topi.reshape, and would require pattern-matching similar to RewriteDataflowReshape to un-lower these back to Relax operations. If R.nn.attention were instead decomposed into R.permute_dims and R.reshape, we'd get this for free.
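To make item 1 concrete, here is a rough sketch of a purely Relax-level legalization for R.nn.rms_norm. The attribute names (`axes`, `epsilon`) and the exact decomposition are illustrative assumptions, not the in-tree lowering; with this PR, each emitted operator would itself be legalized by later iterations of LegalizeOps.

```python
from tvm import relax

def legalize_rms_norm(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr:
    data, weight = call.args
    axes = list(call.attrs.axes)
    eps = relax.const(float(call.attrs.epsilon), "float32")
    # rms_norm(x) = weight * x / sqrt(mean(x^2) + eps)
    mean_sq = relax.op.mean(relax.op.multiply(data, data), axes, keepdims=True)
    denom = relax.op.sqrt(relax.op.add(mean_sq, eps))
    return relax.op.multiply(relax.op.divide(data, denom), weight)
```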

sunggg (Contributor)

Thank you, @Lunderberg, for the kind explanation. I like the idea of a "composite_level" and of centralizing the definitions. Can we check whether DecomposeOpsForInference and DecomposeOpsForTraining can be supported with this PR, to see if we can replace them? If so, we can discuss their deprecation as well.

Lunderberg (Contributor, Author)

Taking a look, the DecomposeOpsFor* passes currently serve two distinct roles. The first is to lower the relax.nn.batch_norm, relax.nn.layer_norm, and relax.tensor_to_shape operators into lower-level Relax implementations. The second is to mutate the relax.nn.batch_norm operator into a training-specific version.

I think the first role, lowering complex Relax operators into less complex ones, will be supported by the partial lowering intended for LegalizeOps. The second role is independent of the legalization, and would be best kept as a standalone pass. That pass would also become much simpler: relax.nn.batch_norm(data, gamma, beta, prev_mean, prev_var) could be updated in place to relax.nn.batch_norm(data, gamma, beta, weighted_avg(mean(data), prev_mean), weighted_avg(var(data), prev_var)), rather than needing a full re-implementation of relax.nn.batch_norm.

Though, those are probably changes that would be best for a follow-up PR.
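A loose sketch of that simplification (not part of this PR: the `weighted_avg` helper, the `momentum` value, and the NCHW reduction axes are all assumptions for illustration). The training-specific pass would only rewrite the mean/variance arguments, leaving the batch_norm call itself for LegalizeOps to expand:

```python
from tvm import relax

def weighted_avg(current: relax.Expr, previous: relax.Expr, momentum: float) -> relax.Expr:
    # running_stat = momentum * batch_stat + (1 - momentum) * previous_stat
    return relax.op.add(
        relax.op.multiply(relax.const(momentum, "float32"), current),
        relax.op.multiply(relax.const(1.0 - momentum, "float32"), previous),
    )

def batch_norm_for_training(call: relax.Call, momentum: float = 0.1) -> relax.Call:
    data, gamma, beta, prev_mean, prev_var = call.args
    axes = [0, 2, 3]  # reduce over all axes except the channel axis (NCHW assumed)
    new_mean = weighted_avg(relax.op.mean(data, axes), prev_mean, momentum)
    new_var = weighted_avg(relax.op.variance(data, axes), prev_var, momentum)
    return relax.Call(call.op, [data, gamma, beta, new_mean, new_var], call.attrs)
```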

sunggg (Contributor)

Interesting! I did not know there is a training-specific version of batch norm. SGTM. Let's discuss it in the follow-up PR.

Lunderberg merged commit ec9e0a0 into apache:unity on Oct 19, 2023. 18 checks passed.
Lunderberg deleted the unity_flegalize_to_simpler_relax_ops branch on October 19, 2023 at 19:29.