Skip to content

Commit

Permalink
[Unity] Implement FNormalize attribute for operators
Browse files Browse the repository at this point in the history
Some Relax operators have requirements regarding their AST that are
stronger than are checked by the C++ types being used.  These are
similar to checks that are present in the `tvm::relax::WellFormed`
utility, such as checks forbidding the use of undefined variables,
which are also stronger than required by the underlying C++ types.
However, because every operator may have unique requirements, it would
be unreasonable to expect a writer of a `relax::ExprMutator` to be
aware of and to maintain all such requirements.

This PR introduces an operation operator attribute `FNormalize`.  If
defined, this function is used to apply an operator-specific
normalization.

* If no change is required, `FNormalize` should return the input
  argument unmodified.

* `FNormalize` is only responsible for normalization of the operator
  itself.  The expression it returns may be unnormalized (e.g. contain
  nested expressions).

* `FNormalize` receives the `BlockBuilder` as an argument, to allow
  context-dependent normalization.

  For example, an operator whose normalization requires in-line
  expressions may use `BlockBuilder::LookupBinding` to perform
  variable replacement.

* `FNormalize` is applied after `FInferStructInfo`.  `FNormalize` may
  assume that the `relax::Call` passed to `FNormalize` has
  well-defined struct info.

  * Corollary: `FInferStructInfo` may not assume that its
    `relax::Call` argument has been passed through `FNormalize`.

    This is a reasonable requirement, because (1) shape inference
    should depend only on the struct info of arguments and not the
    values themselves, and (2) this only impacts operators that use
    `FNormalize`.

* `FNormalize` should not be used to apply simplifications, and should
  be limited to cases where the same computation may be expressed in
  multiple manners.

  For example, replacing a by-variable tuple with an in-line tuple in
  `R.call_tir` is a form of normalization, but replacing `R.add(arg,
  R.const(0))` with `arg` is a form of simplification.

  This separation is to ensure that `FNormalize` has minimal overhead,
  as some simplifications may have large computational costs, and
  `FNormalize` is applied as part of all `ExprMutator` usage.  A later
  PR will introduce an attribute `FSimplify`, along with a dedicated
  pass to apply simplifications.

* Use of `FNormalize` is suppressed while parsing TVMScript.
  TVMScript must be able to generate test cases that trigger specific
  failure modes, and that may include producing un-normalized relax
  IR.  In addition, TVMScript must be stable when passed through a
  round-trip from IR to text to IR.
  • Loading branch information
Lunderberg committed Nov 3, 2023
1 parent 23371ca commit f4ec8a3
Show file tree
Hide file tree
Showing 12 changed files with 436 additions and 26 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
src/relax/distributed/*.cc
src/relax/distributed/transform/*.cc
src/relax/op/distributed/*.cc
src/relax/testing/*.cc
)

tvm_file_glob(GLOB CODEGEN_SRCS
Expand Down
32 changes: 32 additions & 0 deletions include/tvm/relax/block_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ class BlockBuilder : public ObjectRef {
* \brief Create a BlockBuilder.
*
* \param ctx_mod Optional before-transformation context module for rewriting.
*
* \return The created BlockBuilder.
*
* \note When rewriting an existing IRModule, it is important to pass it in as
Expand All @@ -231,6 +232,37 @@ class BlockBuilder : public ObjectRef {
*/
TVM_DLL static BlockBuilder Create(Optional<IRModule> ctx_mod);

/*! \brief A marker struct to disable FNormalize
*
* This struct is used as a marker to disable the use of FNormalize
* by this block builder. This should only be used for TVMScript
* parsing, which may require producing un-normalized Relax IR for
* testing purposes, and to ensure that round-trips are unchanged.
*
* The name is deliberately verbose to draw attention during a code
* review. The explicit default constructor prevents aggregate
* initialization, ensuring that the full name of the marker struct
* appears at the callsite.
*/
struct DisableOperatorSpecificNormalizationForTVMScript {
explicit DisableOperatorSpecificNormalizationForTVMScript() = default;
};
/*!
* \brief Create a BlockBuilder.
*
* \param ctx_mod Optional before-transformation context module for rewriting.
*
* \param tag An instance of DisableOperatorSpecificNormalizationForTVMScript
*
* \return The created BlockBuilder.
*
* \note When rewriting an existing IRModule, it is important to pass it in as
* ctx_mod so you can lookup the context functions for cross function
* call analysis.
*/
TVM_DLL static BlockBuilder Create(Optional<IRModule> ctx_mod,
DisableOperatorSpecificNormalizationForTVMScript tag);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BlockBuilder, ObjectRef, BlockBuilderNode);
};

Expand Down
27 changes: 24 additions & 3 deletions include/tvm/relax/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,30 @@ using FInferStructInfo =
using FCallPacked = String;

/*!
* \brief The function type of a legalization function, which takes a
* BlockBuilder and the Call to be legalized, and outputs the legalization
* result Expr.
* \brief The function type of a normalization function.
*
* A normalization function is used when a `relax::Call` may be
* expressed in multiple syntactically valid and semantically
* equivalent forms, to normalize to a single representation.
*
* \param bb The BlockBuilder context.
*
* \param call The call to be normalized. It is provided by-value, to
* avoid copies for the common case where the call is already normalized.
*/
using FNormalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, Call call)>;

/*! \brief The function type of a legalization function.
*
* A legalization function is used to replace a `relax::Call` with
* more concrete implementations. For example, the operation
* `relax.op.add` may be replaced with a call to a TIR function
* implementing addition of two tensors.
*
* The purpose of `FLegalize` is to remove calls to the operator while
* lowering. Therefore, unlike `FNormalize`, the resulting expression
* may *not* contain the original operator.
*
* \param bb The BlockBuilder context.
* \param call The call to be legalized.
*/
Expand Down
19 changes: 17 additions & 2 deletions python/tvm/relax/ir/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,29 @@
class WellFormedInstrument:
"""An instrument that checks the input/output IRModule of the Pass
is well formed. It will skip specific passes, like Normalize.
Parameters
----------
check_struct_info: bool
If True, validate the struct info in the module. If False,
skip these checks.
validate_before_transform: bool
If True (default), perform a well-formed check before running
a transform. If False, only perform the well-formed check
after running a transform.
"""

def __init__(self, check_struct_info=True):
def __init__(self, check_struct_info: bool = True, validate_before_transform: bool = True):
self.skip_pass_name = ["Normalize", "ResolveGlobals"]
self.check_struct_info = check_struct_info
self.validate_before_transform = validate_before_transform

def run_before_pass(self, mod, pass_info):
self._check(mod, pass_info.name, "Before")
if self.validate_before_transform:
self._check(mod, pass_info.name, "Before")

def run_after_pass(self, mod, pass_info):
self._check(mod, pass_info.name, "After")
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relax/testing/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=unused-argument, invalid-name, no-else-return, abstract-method, arguments-differ
"""Relax transformation passes for testing"""

import tvm
from tvm import ir, relax
from tvm.ir import transform
from tvm.ir.module import IRModule
Expand Down Expand Up @@ -122,3 +123,8 @@ def transform(self):
return new_mod

return Lowerer().transform()


def ApplyEmptyCppMutator() -> tvm.ir.transform.Pass:
packed_func = tvm.get_global_func("relax.testing.transform.ApplyEmptyCppMutator")
return packed_func()
48 changes: 38 additions & 10 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/struct_info_functor.h>
#include <tvm/relax/utils.h>
#include <tvm/tir/expr_functor.h>
Expand Down Expand Up @@ -282,17 +283,17 @@ class WellFormedChecker : public relax::ExprVisitor,
}
}

void VisitExpr_(const CallNode* op) final {
if (IsLeafOrTuple(op->op)) {
void VisitExpr_(const CallNode* call) final {
if (IsLeafOrTuple(call->op)) {
const FunctionNode* prev_visited_func = cur_visited_func_;
cur_visited_func_ = nullptr; // close the symbolic var dup check
this->VisitExpr(op->op);
this->VisitExpr(call->op);
cur_visited_func_ = prev_visited_func;
} else {
Malformed(Diagnostic::Error(op) << "The called expression must be a leaf expression");
Malformed(Diagnostic::Error(call) << "The called expression must be a leaf expression");
}
for (size_t i = 0; i < op->args.size(); i++) {
Expr arg = op->args[i];
for (size_t i = 0; i < call->args.size(); i++) {
Expr arg = call->args[i];
if (IsLeafOrTuple(arg)) {
this->VisitExpr(arg);
} else {
Expand All @@ -301,13 +302,38 @@ class WellFormedChecker : public relax::ExprVisitor,
}
}

for (const StructInfo& sinfo_arg : op->sinfo_args) {
for (const StructInfo& sinfo_arg : call->sinfo_args) {
this->VisitStructInfo(sinfo_arg);
}

CheckStructInfo(op);
if (is_dataflow_ && check_struct_info_ && IsImpureCall(GetRef<Call>(op))) {
Malformed(Diagnostic::Error(op) << "There cannot be an impure call inside a dataflow block.");
CheckStructInfo(call);
if (is_dataflow_ && check_struct_info_ && IsImpureCall(GetRef<Call>(call))) {
Malformed(Diagnostic::Error(call)
<< "There cannot be an impure call inside a dataflow block.");
}

// If the operation has defined a custom normalization function
// using the FNormalize attribute, the call node must be normalized in order to be well-formed.
// If we apply the FNormalize and it produces any change, modified the expression, re-visit in
// case it produced a nested expression.

if (auto opt_op = call->op.as<Op>()) {
auto op = opt_op.value();
if (op_map_normalize_.count(op)) {
auto func_normalize = op_map_normalize_[op];

auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_);
auto before_normalize = GetRef<Call>(call);
auto after_normalize = func_normalize(dummy_builder, before_normalize);
if (!before_normalize.same_as(after_normalize)) {
Malformed(
Diagnostic::Error(call)
<< "If an operator defines an operator-specific normalization function (FNormalize), "
<< "calls to that operator must be normalized with it. "
<< "However, normalization of " << before_normalize << " resulted in "
<< after_normalize);
}
}
}
}

Expand Down Expand Up @@ -538,6 +564,8 @@ class WellFormedChecker : public relax::ExprVisitor,
std::unordered_map<Var, const FunctionNode*, ObjectPtrHash, ObjectPtrEqual> param_var_func_map_;
std::unordered_map<tir::Var, const FunctionNode*, ObjectPtrHash, ObjectPtrEqual>
symbolic_var_func_map_;

tvm::OpAttrMap<FNormalize> op_map_normalize_ = Op::GetAttrMap<FNormalize>("FNormalize");
};

bool WellFormed(IRModule m, bool check_struct_info) {
Expand Down
47 changes: 38 additions & 9 deletions src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,9 @@ class BlockBuilderImpl : public BlockBuilderNode {
class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&)> {
public:
explicit Normalizer(IRModule context_mod) : BlockBuilderImpl(context_mod) {}
explicit Normalizer(IRModule context_mod,
BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript)
: BlockBuilderImpl(context_mod), apply_f_normalize_(false) {}

Expr Normalize(const Expr& expr) final {
Expr normalized = this->VisitExpr(expr);
Expand Down Expand Up @@ -578,18 +581,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&

Expr VisitExpr_(const CallNode* op) final {
Expr new_op = this->NormalizeArgument(op->op);
bool unchanged = new_op.same_as(op->op);

Array<Expr> new_args;

for (Expr arg : op->args) {
Expr new_arg = this->NormalizeArgument(arg);
new_args.push_back(new_arg);
unchanged &= new_arg.same_as(arg);
}
Array<Expr> new_args = op->args.Map([this](const Expr& arg) { return NormalizeArgument(arg); });

Call call;
if (unchanged) {
if (new_op.same_as(op->op) && new_args.same_as(op->args)) {
call = GetRef<Call>(op);
} else {
call = Call(new_op, new_args, op->attrs, op->sinfo_args);
Expand All @@ -600,6 +596,27 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
UpdateStructInfo(call, inferred_sinfo);
}

// In the spirit of joy and child-like whimsy, the conditional
// statement below is provided in both C++ and tongue-twister
// form. As very few C++ compilers are fluent in tongue-twister,
// only the C++ form is uncommented.
//
// How much opt could an opt op Op if an opt op could op opt?
if (auto opt_op = op->op.as<Op>()) {
auto op = opt_op.value();
if (apply_f_normalize_ && op_map_normalize_.count(op)) {
// If the operation has defined a custom normalization
// function using the FNormalize attribute, apply it. If the
// normalization modified the expression, re-visit in case it
// produced a nested expression.
auto func_normalize = op_map_normalize_[op];
Expr normalized = func_normalize(GetRef<BlockBuilder>(this), call);
if (!normalized.same_as(call)) {
return VisitExpr(normalized);
}
}
}

return call;
}

Expand Down Expand Up @@ -917,13 +934,25 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
Op::GetAttrMap<FInferStructInfo>("FInferStructInfo");
tvm::OpAttrMap<FInferStructInfo> op_map_dist_infer_struct_info_ =
Op::GetAttrMap<FInferStructInfo>("dist.FInferStructInfo");
/*! \brief Operator normalization function */
tvm::OpAttrMap<FNormalize> op_map_normalize_ = Op::GetAttrMap<FNormalize>("FNormalize");

/*! \brief Whether the FNormalize function should be applied */
bool apply_f_normalize_{true};
};

BlockBuilder BlockBuilder::Create(Optional<IRModule> mod) {
ObjectPtr<BlockBuilderNode> n = make_object<Normalizer>(mod.value_or(IRModule()));
return BlockBuilder(n);
}

BlockBuilder BlockBuilder::Create(Optional<IRModule> mod,
BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript) {
ObjectPtr<BlockBuilderNode> n = make_object<Normalizer>(
mod.value_or(IRModule()), BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript());
return BlockBuilder(n);
}

//---------------------------------------
// User facing function registration.
//---------------------------------------
Expand Down
9 changes: 9 additions & 0 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,14 @@ StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) {
return call->sinfo_args[0];
}

Expr NormalizeCallTIR(const BlockBuilder&, Call call) {
// Temporary implementation to ensure that at least one op has a
// registered value for FNormalize. This temporary implementation
// is fully implemented in follow-up PR
// https://github.com/apache/tvm/pull/16068.
return std::move(call);
}

RELAY_REGISTER_OP("relax.call_tir")
.set_num_inputs(3)
.add_argument("func", "Expr", "The destination-passing-style function.")
Expand All @@ -261,6 +269,7 @@ RELAY_REGISTER_OP("relax.call_tir")
"ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from "
"args if unused")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
.set_attr<FNormalize>("FNormalize", NormalizeCallTIR)
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
Expand Down
43 changes: 43 additions & 0 deletions src/relax/testing/transform.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>

namespace tvm {
namespace relax {
namespace testing {

class EmptyCppMutator : public relax::ExprMutator {};

tvm::transform::Pass ApplyEmptyCppMutator() {
auto pass_func = [](Function func, IRModule, tvm::transform::PassContext) -> Function {
EmptyCppMutator mutator;
return Downcast<Function>(mutator.VisitExpr(std::move(func)));
};
return tvm::relax::transform::CreateFunctionPass(pass_func, 0,
"relax.testing.ApplyEmptyCppMutator", {});
}

TVM_REGISTER_GLOBAL("relax.testing.transform.ApplyEmptyCppMutator")
.set_body_typed(ApplyEmptyCppMutator);

} // namespace testing
} // namespace relax
} // namespace tvm
3 changes: 2 additions & 1 deletion src/script/ir_builder/relax/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ FunctionFrame Function(const Bool& is_pure, const Bool& is_private) {
if (const Optional<ir::IRModuleFrame> mod_frame = ir_builder->GetLastFrame<ir::IRModuleFrame>()) {
mod = tvm::IRModule(mod_frame.value()->functions);
}
n->block_builder = tvm::relax::BlockBuilder::Create(/*mod=*/mod);
n->block_builder = tvm::relax::BlockBuilder::Create(
/*mod=*/mod, tvm::relax::BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript());
n->is_pure = is_pure;
n->is_private = is_private;
return FunctionFrame(n);
Expand Down

0 comments on commit f4ec8a3

Please sign in to comment.