[Unity] Implement FNormalize attribute for operators #16067

Merged

Lunderberg
Contributor

@Lunderberg Lunderberg commented Nov 3, 2023

Some Relax operators have requirements regarding their AST that are stronger than are checked by the C++ types being used. These are similar to checks that are present in the tvm::relax::WellFormed utility, such as checks forbidding the use of undefined variables, which are also stronger than required by the underlying C++ types. However, because every operator may have unique requirements, it would be unreasonable to expect a writer of a relax::ExprMutator to be aware of and to maintain all such requirements.

This PR introduces an operator attribute, FNormalize. If defined, this function is used to apply an operator-specific normalization. The implementation of FNormalize reflects the following design decisions.

  • If no change is required, FNormalize should return the input argument unmodified.

  • FNormalize is only responsible for normalization of the operator itself. The expression it returns may be unnormalized (e.g. contain nested expressions).

  • FNormalize receives the BlockBuilder as an argument, to allow context-dependent normalization.

    For example, an operator whose normalization requires in-line expressions may use BlockBuilder::LookupBinding to perform variable replacement.

  • FNormalize is applied after FInferStructInfo. FNormalize may assume that the relax::Call passed to FNormalize has well-defined struct info.

    • Corollary: FInferStructInfo may not assume that its relax::Call argument has been passed through FNormalize.

      This is a reasonable requirement, because (1) shape inference should depend only on the struct info of arguments and not the values themselves, and (2) this only impacts operators that use FNormalize.

  • FNormalize should not be used to apply simplifications, and should be limited to cases where the same computation may be expressed in multiple ways.

    For example, replacing a by-variable tuple with an in-line tuple in R.call_tir is a form of normalization, but replacing R.add(arg, R.const(0)) with arg is a form of simplification.

    This separation is to ensure that FNormalize has minimal overhead, as some simplifications may have large computational costs, and FNormalize is applied as part of all ExprMutator usage. A later PR will introduce an attribute FSimplify, along with a dedicated pass to apply simplifications.

  • Use of FNormalize is suppressed while parsing TVMScript. TVMScript must be able to generate test cases that trigger specific failure modes, and that may include producing un-normalized relax IR. In addition, TVMScript must be stable when passed through a round-trip from IR to text to IR.

  • If an IRModule contains any non-normalized operators, the IRModule is ill-formed. That is, all FNormalize operations on a well-formed module are no-ops.
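The contract in the list above can be sketched with a toy model. This is a minimal illustration, not the TVM implementation: `Expr`, `ToyBuilder`, `NormalizeTupleArg`, and `Normalize` are hypothetical stand-ins, and only the name `LookupBinding` is borrowed from the description. It shows a normalizer that returns the same object when nothing needs to change, and that uses the builder's bindings to replace a by-variable tuple with an in-line tuple.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Toy stand-ins for Relax IR nodes; these are NOT the TVM classes.
struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;

struct Expr {
  enum class Kind { kVar, kTuple, kCall };
  Kind kind;
  std::string name;           // variable name, or operator name for a call
  std::vector<ExprPtr> args;  // tuple fields, or call arguments
};

ExprPtr MakeVar(std::string name) {
  return std::make_shared<const Expr>(Expr{Expr::Kind::kVar, std::move(name), {}});
}
ExprPtr MakeTuple(std::vector<ExprPtr> fields) {
  return std::make_shared<const Expr>(Expr{Expr::Kind::kTuple, "", std::move(fields)});
}
ExprPtr MakeCall(std::string op, std::vector<ExprPtr> args) {
  return std::make_shared<const Expr>(Expr{Expr::Kind::kCall, std::move(op), std::move(args)});
}

// Toy stand-in for BlockBuilder: it only remembers variable bindings.
struct ToyBuilder {
  std::unordered_map<std::string, ExprPtr> bindings;
  ExprPtr LookupBinding(const std::string& var) const {
    auto it = bindings.find(var);
    if (it == bindings.end()) return nullptr;
    return it->second;
  }
};

using FNormalize = std::function<ExprPtr(const ToyBuilder&, const ExprPtr&)>;

// Hypothetical rule for a call_tir-like operator: argument 1 must be an in-line tuple.
ExprPtr NormalizeTupleArg(const ToyBuilder& builder, const ExprPtr& call) {
  const ExprPtr& arg = call->args.at(1);
  if (arg->kind == Expr::Kind::kTuple) return call;  // already normalized: same object back
  ExprPtr bound = builder.LookupBinding(arg->name);
  if (!bound || bound->kind != Expr::Kind::kTuple) return call;  // no known tuple to in-line
  return MakeCall(call->name, {call->args.at(0), bound});
}

// Apply the operator's normalizer if one is registered; otherwise return the call untouched.
ExprPtr Normalize(const std::unordered_map<std::string, FNormalize>& registry,
                  const ToyBuilder& builder, const ExprPtr& call) {
  auto it = registry.find(call->name);
  if (it == registry.end()) return call;
  return it->second(builder, call);
}
```

In this sketch, applying `Normalize` a second time to its own output returns the identical pointer, which is the property that the final bullet relies on.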

@Lunderberg Lunderberg force-pushed the unity_operator_specific_normalization branch from 6b6a185 to f4ec8a3 Compare November 3, 2023 20:22
@tqchen
Member

tqchen commented Nov 6, 2023

Thanks for the proposed change. I like how FNormalize can help reduce the overhead of creating certain operators and bring them back to normal form.

I only have one comment, on the well-formed check side. In this case, it is useful to have an intentionally duplicated check that is separate from FNormalize, e.g. a TEnforceExplicitTupleInArgs attribute that enforces that the tuple argument is unpacked, and to check this condition. This provides an extra layer of protection, makes the intention clear, and is also more efficient.

@Lunderberg
Contributor Author

Thank you, and I like the overall design. I think we still want to keep all the normalization logic in FNormalize, without adding boolean flags for specific cases. The more boolean flags we have, the more difficult it is for a developer to know the rules for all flags. For example, a developer would need to check if the TEnforceExplicitTupleInArgs is used as part of the well-formed check, whether it triggers an assert during normalization, whether it triggers a normalization step during normalization, whether the normalization step applies when parsing TVMScript, etc. By implementing each of these features on top of the same FNormalize functionality, new operator-specific normalization rules can be implemented without adding to a developer's mental overhead. A developer only needs to know that the new normalization is handled the same as all existing integrations.

(Also, see the other comment for performance benchmarking.)

@tqchen
Member

tqchen commented Nov 6, 2023

After thinking a bit more, I now agree that we can reuse FNormalize in the well-formed check. Thanks for proposing the change.
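As a rough illustration of reusing the normalizer inside a well-formed check, the sketch below treats a call as well-formed when its registered normalizer hands back the very same object. Everything here is hypothetical (`ToyCall`, `NormalizeSorted`, `IsWellFormedCall`, and the operator name `toy.sorted_args`); it is not the code in well_formed.cc, only one way the idea could look under the "return the input unmodified when no change is required" convention.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Toy call node; NOT the TVM relax::Call class.
struct ToyCall {
  std::string op;
  std::vector<std::string> args;
};
using CallPtr = std::shared_ptr<const ToyCall>;
using FNormalize = std::function<CallPtr(const CallPtr&)>;

// Hypothetical rule: the operator expects its argument names in sorted order.
CallPtr NormalizeSorted(const CallPtr& call) {
  if (std::is_sorted(call->args.begin(), call->args.end())) {
    return call;  // no change required: return the input itself
  }
  std::vector<std::string> sorted = call->args;
  std::sort(sorted.begin(), sorted.end());
  return std::make_shared<const ToyCall>(ToyCall{call->op, sorted});
}

// A call passes the check when its normalizer (if any) would be a no-op.
bool IsWellFormedCall(const std::unordered_map<std::string, FNormalize>& registry,
                      const CallPtr& call) {
  auto it = registry.find(call->op);
  if (it == registry.end()) return true;  // no operator-specific rule to violate
  return it->second(call) == call;        // pointer equality: nothing was rewritten
}
```

The design point is that a new operator-specific rule only has to be written once, as a normalizer, and the check follows from it without a separate flag.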

// How much opt could an opt op Op if an opt op could op opt?
if (auto opt_op = op->op.as<Op>()) {
  auto op = opt_op.value();
  if (apply_f_normalize_ && op_map_normalize_.count(op)) {
Member

Contributor Author

Thank you, and updated! I had checked for a single-parameter .get, and an iterator-style .find, but hadn't found the two-parameter .get.

Updated to use if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); func_normalize != nullptr), here and in well_formed.cc.
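The lookup pattern described here can be sketched in isolation. The class below is a toy, not TVM's attribute map: `ToyAttrMap`, `Normalizer`, and `NormalizeIfRegistered` are hypothetical names, and only the two-parameter "get with default" shape and the C++17 if-with-initializer are taken from the comment above.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Toy attribute map that imitates a "lookup with default value" interface.
template <typename Value>
class ToyAttrMap {
 public:
  void set(const std::string& key, Value value) { data_[key] = std::move(value); }

  // Two-parameter lookup: returns def_value when the key has no entry.
  Value get(const std::string& key, Value def_value) const {
    auto it = data_.find(key);
    if (it == data_.end()) return def_value;
    return it->second;
  }

 private:
  std::unordered_map<std::string, Value> data_;
};

using Normalizer = std::function<std::string(const std::string&)>;

std::string NormalizeIfRegistered(const ToyAttrMap<Normalizer>& map,
                                  const std::string& op, const std::string& expr) {
  // A single lookup, with the result scoped to the if-statement.
  if (auto func = map.get(op, nullptr); func != nullptr) {
    return func(expr);
  }
  return expr;  // no normalizer registered: leave the expression as-is
}
```

Compared with a `count` check followed by a second indexing step, this form performs one lookup and needs no separate unwrapping of the operator.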

// case it produced a nested expression.

if (auto opt_op = call->op.as<Op>()) {
  auto op = opt_op.value();
Member

https://github.com/apache/tvm/blob/main/include/tvm/ir/op.h#L476 We can directly use this function to simplify the logic.

Member

op_map_normalize_.get(call->op, nullptr)

Contributor Author

Thank you, and updated to use if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); func_normalize != nullptr).

@Lunderberg
Contributor Author

Thank you, and changes made as suggested!

@tqchen tqchen merged commit e506bff into apache:unity Nov 7, 2023
15 checks passed
@Lunderberg Lunderberg deleted the unity_operator_specific_normalization branch November 7, 2023 14:12