[Unity][IR] Purity Tracking #14394

Merged on May 18, 2023 (73 commits)

@slyubomirsky (Contributor) commented Mar 24, 2023

In this PR, I am beginning to implement the tracking of function purity as part of the StructInfo system. This will allow the compiler to enforce that no impure function (one that can possibly have visible side effects) can be called in a DataflowBlock. Tracking this requires noting which operators are pure or impure (I am presently doing this with an operator attribute called FPurity, a simple boolean), though dealing with calls to other Relax functions means that this information must also be tracked via the StructInfo system.

Additionally, it is difficult to infer the purity of a function in the general case (when there are calls to other Relax functions), so this change does require users to annotate impure functions using a new field on functions, is_pure (in TVMScript, this can be done using R.is_pure() or R.is_impure()). Since most Relax functions are likely to be pure and purity is the default assumption, this will hopefully not be a large imposition on users. We can consider inferring purity in the easier cases, since those are likely to be common.

Note that PackedFuncs are conservatively treated as impure. However, when one is needed inside a dataflow block, a PackedFunc that is, in reality, pure can be called via the new operator call_pure_packed or the existing operator call_dps_packed (any PackedFunc used with the latter is assumed to be pure). Similarly, a new operator invoke_pure_closure is introduced as a counterpart to invoke_closure for dealing with closure objects, though this should really only come up with the LambdaLifting pass.

As an "escape hatch" from the purity system, one can set the attribute relax.force_pure, which tells the compiler to treat the entire function as pure even if it contains an impure call. Likewise, although PackedFuncs are normally treated as impure, a user can call them inside dataflow blocks via call_pure_packed or call_dps_packed when appropriate. These mechanisms cover the following situations:

  1. A function performs side effects, but only on a value that is not exposed anywhere else or on a new value that will be returned. Even though the individual actions are "impure," the function as a whole satisfies the definition of purity. relax.force_pure is useful here.
  2. A PackedFunc is, in reality, pure. call_pure_packed or call_dps_packed is useful in this situation.

Changes include:

  • Enforcing that impure functions are not used in DataflowBlocks in the well-formed check.
  • Enforcing that functions that are not labeled impure do not contain impure calls (unless relax.force_pure is set).
  • Implementing the call_pure operator.
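
The two enforcement rules in the list above can be illustrated with a toy model. This is only a sketch in plain Python, not TVM code: the classes `Call`, `Block`, and `Function`, the `FPURITY` table, and `check_well_formed` are hypothetical stand-ins for the real IR nodes, the operator attribute registry, and the well-formed check. It also assumes that force_pure does not relax the rule for dataflow blocks.

```python
from dataclasses import dataclass
from typing import List

# Hypothetical stand-in for the per-operator FPurity attribute.
FPURITY = {"relax.add": True, "relax.cumsum": True, "relax.print": False}


@dataclass
class Call:
    op: str  # operator name, looked up in FPURITY


@dataclass
class Block:
    calls: List[Call]
    is_dataflow: bool = False


@dataclass
class Function:
    blocks: List[Block]
    is_pure: bool = True      # purity is the default assumption
    force_pure: bool = False  # models the relax.force_pure escape hatch


def check_well_formed(func: Function) -> List[str]:
    """Return a list of purity violations (an empty list means well formed)."""
    errors = []
    for block in func.blocks:
        for call in block.calls:
            if FPURITY[call.op]:
                continue
            if block.is_dataflow:
                # Rule 1: no impure call inside a dataflow block.
                errors.append(f"impure call to {call.op} in a dataflow block")
            elif func.is_pure and not func.force_pure:
                # Rule 2: a function not labeled impure may not contain
                # impure calls unless force_pure is set.
                errors.append(f"impure call to {call.op} in a pure function")
    return errors


# A function explicitly labeled impure may call relax.print outside dataflow.
logging_fn = Function([Block([Call("relax.print")])], is_pure=False)
print(check_well_formed(logging_fn))
```

In this model, the same impure call is rejected if the function keeps the default `is_pure=True`, and accepted again once `force_pure=True` is set.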

Still to be done:

  • Label purity for all operators
  • Address certain tricky bugs

(Also address the design concerns below)


The process of implementing purity tracking has revealed a few issues that may require some further design discussion.

Labeling purity or impurity

Using function attributes to label purity/impurity seems very messy. It might be worth making this part of the Function node's AST to avoid having to wrangle attributes in many places in the codebase.

The treatment of call_pure

call_pure presents a dilemma because many passes look for calls to certain operators, but with call_pure, these would be "wrapped" like so: Call(Op("relax.call_pure"), [inner_op, arg1, ..., argn], attrs, sinfo_args). In the posted PR, many passes needed to be revised to look for instances of call_pure so that they could be treated the same as calls to ordinary operators. It might be less disruptive to passes to turn call_pure into a specialized AST node that literally wraps a call node. That way, visitors and mutators could have an easy default case, and the lower-level code generation passes could simply ignore the call_pure node and deal with the wrapped call. As painful as introducing a new AST node would be, it would likely be more maintainable than requiring every pass to special-case the call_pure operator.
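
To make the wrapping problem concrete, here is a small sketch of what a pass has to do when the operator it is looking for may be hidden inside a call_pure wrapper. It is plain Python rather than TVM code; `Op`, `Call`, `wrap_call_pure`, `unwrap_call_pure`, and `matches_op` are hypothetical names, and attrs and sinfo_args are omitted for brevity.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Op:
    name: str


@dataclass(frozen=True)
class Call:
    op: Op
    args: Tuple[object, ...]


# Hypothetical operator name used only for this illustration.
CALL_PURE = Op("relax.call_pure")


def wrap_call_pure(call: Call) -> Call:
    """Wrap a call as Call(call_pure, [inner_op, arg1, ..., argn])."""
    return Call(CALL_PURE, (call.op, *call.args))


def unwrap_call_pure(call: Call) -> Call:
    """Return the inner call if `call` is a call_pure wrapper, else `call`."""
    if call.op == CALL_PURE:
        inner_op, *inner_args = call.args
        return Call(inner_op, tuple(inner_args))
    return call


def matches_op(call: Call, name: str) -> bool:
    """Match an operator whether or not it is wrapped in call_pure."""
    return unwrap_call_pure(call).op.name == name


plain = Call(Op("relax.add"), ("x", "y"))
wrapped = wrap_call_pure(plain)
print(matches_op(plain, "relax.add"), matches_op(wrapped, "relax.add"))
```

Every pass that pattern-matches on operators would need a helper like `matches_op`, which is the maintenance burden the paragraph above describes.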

The more restricted call_pure_packed/invoke_pure_closure are less likely to come up in passes and so are less likely to need special handling.

Staging

Related to the above issue, having purity tracking in the well-formedness analysis means that even low-level passes like VMLowerBuiltin, which replaces some Relax operators with PackedFuncs (treated as impure by default, as they are in principle dangerous black boxes), have to insert calls to call_pure to avoid triggering purity errors. Perhaps it might be useful to ignore purity by the time we reach lower stages of compilation, much as the ToNonDataflow pass is used in the VM build process to eliminate DataflowBlocks.

Update: Based on the TVM Unity Community Meeting on Mar. 28, 2023, I've provisionally adopted a variant of this approach thanks to a suggestion by @tqchen. Namely, there is a pass called RemovePurityChecking that simply adds the ForcePure attribute to all pure functions and unwraps all invocations of call_pure_packed, call_pure_dps_packed, and invoke_pure_closure. This allows lower-level passes to remain unchanged and not to have to consider purity at all. It is inserted after ToNonDataflow in the build() function in vm_build.py.
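
As a rough illustration of what such a pass does, the following toy sketch marks pure functions with ForcePure and unwraps the purity wrappers. It is plain Python, not the actual pass: the `Call` and `Function` classes and the helper names are hypothetical, the exact unwrapped form of each operator is an assumption, and call_pure_dps_packed is left out because its unwrapped form is not described here.

```python
from dataclasses import dataclass, field
from typing import Dict, Tuple


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[str, ...]


@dataclass
class Function:
    body: Tuple[Call, ...]
    is_pure: bool = True
    attrs: Dict[str, int] = field(default_factory=dict)


def unwrap(call: Call) -> Call:
    """Replace a purity wrapper with its ordinary counterpart (assumed forms)."""
    if call.op == "relax.call_pure_packed":
        # Assumption: the first argument names the PackedFunc being called.
        callee, *rest = call.args
        return Call(callee, tuple(rest))
    if call.op == "relax.invoke_pure_closure":
        return Call("relax.invoke_closure", call.args)
    return call


def remove_purity_checking(func: Function) -> Function:
    """Return a copy with ForcePure set on pure functions and wrappers removed."""
    attrs = dict(func.attrs)
    if func.is_pure:
        attrs["ForcePure"] = 1
    return Function(tuple(unwrap(c) for c in func.body), func.is_pure, attrs)


before = Function((
    Call("relax.call_pure_packed", ("my_packed_func", "x")),
    Call("relax.invoke_pure_closure", ("clo", "y")),
    Call("relax.add", ("x", "y")),
))
after = remove_purity_checking(before)
print(after.attrs, [c.op for c in after.body])
```

After this rewrite, later passes only see ordinary calls, which is why they can ignore purity entirely.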

Update 2: Thanks to suggestions by @tqchen, I've made is_pure and force_pure into fields on functions rather than attributes and added syntactic sugar in TVMScript to make these easy to set.

Update 3: Per @tqchen's suggestion, force_pure will be kept as an attribute instead. This is because it acts more as a directive to the compiler than as an inherent property of a function.

@tvm-bot (Collaborator) commented Mar 24, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@slyubomirsky marked this pull request as draft March 24, 2023 03:39
@slyubomirsky changed the title from "[Unity][IR][WIP] Purity Tracking" to "[Unity][IR] Purity Tracking" Mar 26, 2023
@slyubomirsky marked this pull request as ready for review March 26, 2023 22:56
@slyubomirsky (Contributor Author)

This will probably keep failing tests because many of them depend on syntactic matches, which this change has broken in a lot of places. I will start going through them, but some have proven very difficult to debug.

Comment on lines 1312 to 1329
# slight hack: normally, we would prefer to use True, but the func attrs, when printed,
# will have it as 1, so it would fail roundtripping otherwise
R.func_attr({"ForcePure": 1})
@slyubomirsky (Contributor Author) commented Mar 29, 2023

@yongwww, would you happen to know how I could avoid this (print the attribute as True instead of 1)? It's not super-pressing, but it would help for roundtripping

@yongwww (Member) commented Apr 7, 2023

The relevant printing code should be at https://github.com/apache/tvm/blob/unity/src/script/printer/relax/function.cc#L54-L58. Feel free to leave a TODO for me.

@slyubomirsky (Contributor Author)

More general question (maybe one to leave to a community meeting): Where should we communicate that we expect certain passes to be run before others? We already have ToNonDataflow and this PR adds RemovePurityChecking. I also saw that a line in VMShapeLower indicates that it expects LambdaLifting to have already been used.

Comment on lines +147 to +171
Optional<ExprDoc> PrintAssertOp(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) {
  static const Op& assert_op = Op::Get("relax.assert_op");
  if (!n->op.same_as(assert_op)) {
    return NullOpt;
  }
  ICHECK(n->args.size() >= 2);
  // special handling: it is important to indicate that the format string (second argument)
  // is the _format_ string, or else roundtripping will fail
  // (the format string will be interpreted as an argument and there will be a new default format
  // string given)
  Array<ExprDoc> args;
  args.push_back(d->AsDoc<ExprDoc>(n->args[0], n_p->Attr("args")->ArrayIndex(0)));
  ExprDoc second_arg = d->AsDoc<ExprDoc>(n->args[1], n_p->Attr("args")->ArrayIndex(1));
  for (size_t i = 2; i < n->args.size(); i++) {
    args.push_back(d->AsDoc<ExprDoc>(n->args[i], n_p->Attr("args")->ArrayIndex(i)));
  }
  return Relax(d, "assert_op")->Call(args, {"format"}, {second_arg});
}

Optional<ExprDoc> PrintRelaxPrint(const relax::Call& n, const ObjectPath& n_p,
                                  const IRDocsifier& d) {
  static const Op& print_op = Op::Get("relax.print");
  if (!n->op.same_as(print_op)) {
    return NullOpt;
  }
  ICHECK(n->args.size() >= 1);
  // special handling: it is important to indicate that the format string (first argument)
  // is the _format_ string, or else roundtripping will fail
  // (the format string will be interpreted as an argument and there will be a new default format
  // string given)
  ExprDoc first_arg = d->AsDoc<ExprDoc>(n->args[0], n_p->Attr("args")->ArrayIndex(0));
  Array<ExprDoc> args;
  for (size_t i = 1; i < n->args.size(); i++) {
    args.push_back(d->AsDoc<ExprDoc>(n->args[i], n_p->Attr("args")->ArrayIndex(i)));
  }
  return Relax(d, "print")->Call(args, {"format"}, {first_arg});
}
@slyubomirsky (Contributor Author)

One would think this isn't related to these changes at all, but in the process of testing roundtripping for call_pure, I discovered roundtripping bugs for print and assert_op (both having to do with their format strings).
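
The kind of roundtripping failure described here can be reproduced with a toy printer and parser. The sketch below is plain Python and does not use the TVMScript printer; `parse_print`, `DEFAULT_FORMAT`, and both printer functions are hypothetical, and the internal argument layout (format string first) is an assumption made for the illustration.

```python
DEFAULT_FORMAT = "{}"  # hypothetical default format string


def parse_print(*values, format=DEFAULT_FORMAT):
    """Toy parser entry point: stores the format string as the first argument."""
    return [format, *values]


def print_positional(args):
    """Buggy printer: emits every stored argument positionally."""
    return "parse_print(" + ", ".join(repr(a) for a in args) + ")"


def print_with_keyword(args):
    """Fixed printer: emits the format string as the `format` keyword."""
    fmt, *values = args
    parts = [repr(v) for v in values] + [f"format={fmt!r}"]
    return "parse_print(" + ", ".join(parts) + ")"


def roundtrip(printer, args):
    """Print the call, then parse the printed text again."""
    return eval(printer(args), {"parse_print": parse_print})


original = parse_print("x", format="value: {}")
print(roundtrip(print_positional, original) == original)
print(roundtrip(print_with_keyword, original) == original)
```

With the positional printer, the format string comes back as an ordinary value and a new default format string is prepended, so the reparsed call differs from the original. Printing it as a keyword argument avoids that.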

Comment on lines 105 to 112
// preserve the purity: if the func was originally pure, wrap call_pure
bool purity = GetStructInfoAs<FuncStructInfoNode>(gvar)->purity;
auto ret = create_call_dps_packed(new_func, func->ret_struct_info);
if (purity) {
  return WrapCallPure(ret);
}
return ret;
@slyubomirsky (Contributor Author) commented Mar 30, 2023

This is another case where being clear about staging might make some of this reasoning unnecessary. We do codegen before doing the VM build in test_codegen_dnnl.py, so I wasn't sure if I should have eliminated dataflow blocks. The fusion passes seem to rely on having dataflow blocks, so I can't trivially get rid of purity checking and call_pure before fusion (unless we also disable checking purity from the well-formed check--we could consider having a third, internal-only annotation, that disables all purity checks, even inside dataflow blocks).

python/tvm/relax/analysis/analysis.py (outdated, resolved)
python/tvm/script/parser/relax/parser.py (outdated, resolved)
@@ -177,7 +227,9 @@ TVM_REGISTER_OP("relax.call_builtin_with_ctx")
.set_num_inputs(4)
Member

Just noticed this should be 2; it is not related to your change.

@slyubomirsky (Contributor Author)

Good observation

src/relax/op/op.cc (outdated, resolved)
@@ -1357,7 +1370,8 @@ TVM_REGISTER_OP("relax.cumsum")
     .set_attrs_type<CumsumAttrs>()
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCumsum);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCumsum)
+    .set_attr<Bool>("FPurity", Bool(true));
Member

Will we run into issues if FPurity is not specified explicitly?

@slyubomirsky (Contributor Author)

We should specify it explicitly. I have an assert in the well-formedness checker that it's defined.
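
The idea of requiring an explicit label, rather than silently falling back to a default, might look like the following in a toy setting. This is a plain Python sketch, not the actual checker: `OpRegistry`, its methods, and the operator name `relax.unlabeled_op` are hypothetical.

```python
class OpRegistry:
    """Toy operator registry; attributes are stored per operator name."""

    def __init__(self):
        self._attrs = {}

    def register(self, name, **attrs):
        self._attrs[name] = dict(attrs)

    def purity(self, name):
        """Look up FPurity, refusing to guess when it was never set."""
        attrs = self._attrs[name]
        if "FPurity" not in attrs:
            raise ValueError(f"operator {name} does not define FPurity")
        return attrs["FPurity"]


registry = OpRegistry()
registry.register("relax.cumsum", FPurity=True)
registry.register("relax.print", FPurity=False)
registry.register("relax.unlabeled_op")  # hypothetical op with no purity label

print(registry.purity("relax.cumsum"))
```

Failing loudly on a missing label means a newly added operator cannot be treated as pure or impure by accident.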

@slyubomirsky (Contributor Author)

Tests failing due to the rebase (have to update some more ops/passes), will address shortly.

@slyubomirsky (Contributor Author)

Updated to reflect the fix from #14864, thanks to @jinhongyii 😄

6 participants