[Unity] Implement FNormalize attribute for operators #16067
Some Relax operators have requirements regarding their AST that are stronger than those checked by the C++ types being used. These are similar to checks that are present in the `tvm::relax::WellFormed` utility, such as checks forbidding the use of undefined variables, which are also stronger than required by the underlying C++ types. However, because every operator may have unique requirements, it would be unreasonable to expect a writer of a `relax::ExprMutator` to be aware of and to maintain all such requirements.

This PR introduces an operator attribute `FNormalize`. If defined, this function is used to apply an operator-specific normalization.

* If no change is required, `FNormalize` should return the input argument unmodified.
* `FNormalize` is only responsible for normalization of the operator itself. The expression it returns may be unnormalized (e.g. contain nested expressions).
* `FNormalize` receives the `BlockBuilder` as an argument, to allow context-dependent normalization. For example, an operator whose normalization requires in-line expressions may use `BlockBuilder::LookupBinding` to perform variable replacement.
* `FNormalize` is applied after `FInferStructInfo`. `FNormalize` may assume that the `relax::Call` passed to `FNormalize` has well-defined struct info.
  * Corollary: `FInferStructInfo` may not assume that its `relax::Call` argument has been passed through `FNormalize`. This is a reasonable requirement, because (1) shape inference should depend only on the struct info of arguments and not the values themselves, and (2) this only impacts operators that use `FNormalize`.
* `FNormalize` should not be used to apply simplifications, and should be limited to cases where the same computation may be expressed in multiple manners. For example, replacing a by-variable tuple with an in-line tuple in `R.call_tir` is a form of normalization, but replacing `R.add(arg, R.const(0))` with `arg` is a form of simplification. This separation is to ensure that `FNormalize` has minimal overhead, as some simplifications may have large computational costs, and `FNormalize` is applied as part of all `ExprMutator` usage. A later PR will introduce an attribute `FSimplify`, along with a dedicated pass to apply simplifications.
* Use of `FNormalize` is suppressed while parsing TVMScript. TVMScript must be able to generate test cases that trigger specific failure modes, and that may include producing un-normalized relax IR. In addition, TVMScript must be stable when passed through a round-trip from IR to text to IR.
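The design decisions above can be illustrated with a small toy model. This is a sketch in plain Python, not TVM code: `Var`, `TupleExpr`, `Call`, `ToyBuilder`, `F_NORMALIZE`, and `normalize_call_tir` are all hypothetical names invented for the illustration, and the real attribute is registered and dispatched in C++. It shows a per-operator normalizer that returns its input unmodified when nothing needs to change, uses the builder's bindings to in-line a by-variable tuple, and can be switched off (as the description says happens while parsing TVMScript).

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple, Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class TupleExpr:
    fields: Tuple["Expr", ...]


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple["Expr", ...]


Expr = Union[Var, TupleExpr, Call]


class ToyBuilder:
    """Hypothetical stand-in for a block builder that tracks variable bindings."""

    def __init__(self) -> None:
        self.bindings: Dict[Var, Expr] = {}
        self.apply_f_normalize = True  # toy analogue of suppressing FNormalize

    def lookup_binding(self, var: Var) -> Optional[Expr]:
        return self.bindings.get(var)

    def normalize(self, call: Call) -> Expr:
        # In the design described above, struct-info inference would run
        # before this point; the toy model skips it.
        f_normalize = F_NORMALIZE.get(call.op)
        if self.apply_f_normalize and f_normalize is not None:
            return f_normalize(self, call)
        return call


def normalize_call_tir(builder: ToyBuilder, call: Call) -> Expr:
    """Replace a by-variable argument tuple with the in-line tuple it is bound to."""
    callee, arg_tuple = call.args
    if isinstance(arg_tuple, TupleExpr):
        return call  # already in normal form: return the input unmodified
    bound = builder.lookup_binding(arg_tuple) if isinstance(arg_tuple, Var) else None
    if isinstance(bound, TupleExpr):
        return Call(call.op, (callee, bound))
    return call


# Toy per-operator attribute map: only operators that need normalization appear.
F_NORMALIZE: Dict[str, Callable[[ToyBuilder, Call], Expr]] = {
    "call_tir": normalize_call_tir,
}

builder = ToyBuilder()
args = Var("args")
builder.bindings[args] = TupleExpr((Var("x"), Var("y")))

call = Call("call_tir", (Var("func"), args))
normalized = builder.normalize(call)         # by-variable tuple becomes in-line
again = builder.normalize(normalized)        # second pass changes nothing
add = Call("add", (Var("x"), Var("y")))
add_result = builder.normalize(add)          # no normalizer registered for "add"
```

Note that the toy normalizer never rewrites `add(x, 0)` to `x`: under the separation described above, that would be a simplification rather than a normalization.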
Thanks for the proposed change. I like how `FNormalize` can help reduce the overhead of creating certain operators and bring them back to normal form. I only have one comment on the well-formed check side. In this case, it is useful to have an intentionally duplicated check that is different from `FNormalize`, e.g. have a
Thank you, and I like the overall design. I think we still want to keep all the normalization logic in (Also, see the other comment for performance benchmarking.)
After thinking a bit more, I now agree that we can reuse `FNormalize` in the well-formed check. Thanks for proposing the change.
src/relax/ir/block_builder.cc
Outdated
    // How much opt could an opt op Op if an opt op could op opt?
    if (auto opt_op = op->op.as<Op>()) {
      auto op = opt_op.value();
      if (apply_f_normalize_ && op_map_normalize_.count(op)) {
We can use this function https://github.com/apache/tvm/blob/main/include/tvm/ir/op.h#L476
Thank you, and updated! I had checked for a single-parameter `.get`, and an iterator-style `.find`, but hadn't found the two-parameter `.get`.
Updated to use `if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); func_normalize != nullptr)`, here and in `well_formed.cc`.
src/relax/analysis/well_formed.cc
Outdated
    // case it produced a nested expression.

    if (auto opt_op = call->op.as<Op>()) {
      auto op = opt_op.value();
We can directly use this function to simplify the logic: https://github.com/apache/tvm/blob/main/include/tvm/ir/op.h#L476
`op_map_normalize_.get(call->op, nullptr)`
Thank you, and updated to use `if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); func_normalize != nullptr)`.
Thank you, and changes made as suggested!
The PR description also lists one further design decision for `FNormalize`, beyond those given above:

* If an `IRModule` contains any non-normalized operators, the `IRModule` is ill-formed. That is, all `FNormalize` operations on a well-formed module are no-ops.
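One way to read the "no-op on a well-formed module" rule is as a check that re-applies each operator's normalizer and flags any call for which the result is not the very same object. The following is a minimal sketch of that idea in plain Python, assuming a toy IR; `Call`, `F_NORMALIZE`, `normalize_call_tir`, and `non_normalized_calls` are hypothetical names for this illustration and do not reflect the actual contents of `well_formed.cc`.

```python
from typing import Callable, Dict, List, NamedTuple, Tuple


class Call(NamedTuple):
    op: str
    args: Tuple[object, ...]


# Toy bindings: variable name -> the in-line tuple it is bound to.
Bindings = Dict[str, Tuple[str, ...]]


def normalize_call_tir(call: Call, bindings: Bindings) -> Call:
    """Toy normalizer: in-line a by-variable argument tuple."""
    callee, args = call.args
    if isinstance(args, str) and args in bindings:
        return Call(call.op, (callee, bindings[args]))
    return call  # already normalized: return the same object


# Toy per-operator attribute map.
F_NORMALIZE: Dict[str, Callable[[Call, Bindings], Call]] = {
    "call_tir": normalize_call_tir,
}


def non_normalized_calls(calls: List[Call], bindings: Bindings) -> List[Call]:
    """Return the calls whose normalizer would change them."""
    offenders = []
    for call in calls:
        # Lookup with a default, so operators without a normalizer are skipped.
        f_normalize = F_NORMALIZE.get(call.op)
        if f_normalize is not None and f_normalize(call, bindings) is not call:
            offenders.append(call)
    return offenders


bindings = {"args": ("x", "y")}
good = Call("call_tir", ("func", ("x", "y")))   # in-line tuple: normal form
bad = Call("call_tir", ("func", "args"))        # by-variable tuple: not normal form
other = Call("add", ("x", "y"))                 # no normalizer registered

well_formed_result = non_normalized_calls([good, other], bindings)
ill_formed_result = non_normalized_calls([good, bad, other], bindings)
```

The identity comparison relies on the convention from the description that a normalizer returns its input unmodified when no change is required.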