Skip to content

Commit

Permalink
Start implementing purity tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed Mar 24, 2023
1 parent c03e189 commit c36741f
Show file tree
Hide file tree
Showing 26 changed files with 747 additions and 113 deletions.
14 changes: 14 additions & 0 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,20 @@ TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func);
*/
TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);

/*!
* \brief Check if the given expression (likely a function body) contains any impure calls.
* \param expr The expression to be examined. If expr is a function, we check the body.
* \param own_name (Optional.) If we are checking a recursive function body,
* the caller can pass the function's name so recursive calls
* can be ignored in the check (must be a Var or GlobalVar).
* \return A boolean indicating if the expression contains any impure calls.
* \note Relies on StructInfo annotations, so ensure that the module has been normalized first.
* Also, an impure call in a *nested* function does *not* mean that the outer expression contains
* an impure call--it only does if the nested function is *later called*.
*/
TVM_DLL bool ContainsImpureCall(const Expr& expr,
const Optional<Expr>& own_name = Optional<Expr>(nullptr));

/*!
* \brief Check if the IRModule is well formed.
*
Expand Down
22 changes: 22 additions & 0 deletions include/tvm/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,28 @@ TVM_DLL bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank
*/
TVM_DLL bool IsLeafOrTuple(const Expr& expr);

/*!
* \brief Check if the given Call node is an impure operation. If the callee is a general expression,
* this simply requires checking the purity field of the FuncStructInfo. If it is an Op, then this checks
* the `fPurity` field.
*
* \param call The input call
*
* \return True iff the call is impure (definitely or possibly results in a visible side effect).
* That is, a call is considered pure only if definitely does not result in a visible side effect.
*/
TVM_DLL bool IsImpureCall(const Call& call);

/*!
* \brief Wrap the Call node in the call_pure op, transferring over the attributes and sinfo_args.
*
* \param call The input call
*
* \return A Call to the call_pure op that wraps the original call.
*/
TVM_DLL Call WrapCallPure(const Call& call);
// implementation is in op.cc

/*!
* \brief Copy the given function. All variables that are bound inside the original function
* would be copied to satisfy the restriction in the well-formed check: Variables in
Expand Down
30 changes: 29 additions & 1 deletion python/tvm/relax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
configuring the passes and scripting them in Python.
"""

from typing import Dict, List, Union, Callable
from typing import Dict, List, Optional, Union, Callable
from enum import IntEnum

import tvm
Expand Down Expand Up @@ -276,6 +276,34 @@ def has_reshape_pattern(func: tir.PrimFunc) -> bool:
return _ffi_api.has_reshape_pattern(func) # type: ignore


def contains_impure_call(expr: Expr, own_name: Optional[Union[Var, GlobalVar]] = None) -> bool:
"""
Check if the given expression (likely a function body) contains any impure calls.
Parameter
---------
expr : Expr
The expression to be examined. If expr is a function, we check the body.
own_name : Var or GlobalVar (optional)
For a recursive function, the analysis can ignore the self-calls
for checking purity.
Returns
-------
ret : bool
True if there is an impure call
(call to a function that may have visible side effects).
Notes
-----
Relies on StructInfo annotations, so ensure that the module has been normalized first.
Also, an impure call in a *nested* function does *not* mean that the outer expression contains
an impure call--it only does if the nested function is *later called*.
"""
return _ffi_api.contains_impure_call(expr, own_name)


def get_var2val(func: Function) -> Dict[Var, Expr]:
"""
Get a mapping from Var to Expr for each variable in the function.
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,34 @@ def shape_of(expr: Expr) -> Expr:
A relax Call, which gets the shape of the input
"""
return _ffi_api.shape_of(expr) # type: ignore # pylint: disable=no-member


def call_pure(inner_call: Call) -> Expr:
"""
Indicate to the compiler that the given Call node should be treated as pure,
even if the callee is not pure according to the StructInfo system.
The resulting call will have the same semantics as invoking the Call directly.
Note: This should be used for cases when the user knows that calling the callee
with these arguments will _in reality_ not cause any side effects.
If it is used for a call that _does_ result in side effects, then the compiler
may end up removing, reordering, or repeating that call, with no guarantees
made about any side effects from the callee.
Parameters
----------
inner_call : Call
A call that should be treated as pure
Returns
-------
result : Expr
A Relax call, corresponding to `call_pure(inner_call.op, inner_call.args)`
"""
if not isinstance(inner_call, Call):
raise ValueError(
"call_pure must take a Call node directly "
"in order to transfer over attrs and StructInfo args"
)
return _ffi_api.call_pure(inner_call) # type: ignore # pylint: disable=no-member
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
broadcast_to,
builtin,
call_builtin_with_ctx,
call_pure,
call_tir,
call_dps_packed,
ceil,
Expand Down Expand Up @@ -547,6 +548,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"broadcast_to",
"builtin",
"call_packed",
"call_pure",
"call_tir",
"call_dps_packed",
"call_builtin_with_ctx",
Expand Down
14 changes: 11 additions & 3 deletions python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,14 @@ def Tensor(
class CallableProxy(StructInfoProxy):
params: List[StructInfoProxy]
ret: StructInfoProxy
purity: bool

"""Function type.
A function type consists of a list of type parameters to enable
the definition of generic functions,
a set of type constraints which we omit for the time being,
a sequence of argument types, and a return type.
a sequence of argument types, the purity of the function, and a return type.
Parameters
----------
Expand All @@ -162,33 +164,39 @@ class CallableProxy(StructInfoProxy):
ret : StructInfoProxy
The return StructInfoProxy.
purity : bool
Whether the callable is pure.
"""

def __init__(
self,
params: Union[StructInfoProxy, List[StructInfoProxy]],
ret: StructInfoProxy,
purity: bool = True,
) -> None:
if not isinstance(params, (list, tuple)):
params = [params]
# convert `R.Tensor` to `R.Tensor()`
self.params = [param() if callable(param) else param for param in params]
self.ret = ret() if callable(ret) else ret
self.purity = purity

def get_symbolic_vars(self) -> Set[str]:
return set().union(*[p.get_symbolic_vars() for p in self.params])

def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> FuncStructInfo:
params = [param.as_struct_info(dict_globals) for param in self.params]
ret = self.ret.as_struct_info(dict_globals)
return FuncStructInfo(params, ret)
return FuncStructInfo(params, ret, purity=self.purity)


def Callable(
params: Union[StructInfoProxy, List[StructInfoProxy]],
ret: StructInfoProxy,
purity: bool = True,
) -> CallableProxy:
return CallableProxy(params, ret)
return CallableProxy(params, ret, purity=purity)


############################### R.Tuple ################################
Expand Down
44 changes: 44 additions & 0 deletions src/relax/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,48 @@ tvm::Array<Var> AllVars(const Expr& expr) { return VarVisitor().All(expr); }

tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); }

bool ContainsImpureCall(const Expr& expr, const Optional<Expr>& own_name) {
class ImpureCallChecker : public ExprVisitor {
public:
explicit ImpureCallChecker(const Optional<Expr>& own_name) : own_name_(own_name) {}

bool Check(const Expr& expr) {
contains_impure_ = false;
VisitExpr(expr);
return contains_impure_;
}

void VisitExpr_(const FunctionNode* func) override {
// we don't visit inner functions because an impure call in an inner function
// does *not* mean the outer function contains an impure call
}

void VisitExpr_(const CallNode* call) override {
// ignore recursive calls if we find one
if (!(own_name_ && own_name_.value().same_as(call->op))) {
if (IsImpureCall(GetRef<Call>(call))) {
contains_impure_ = true;
}
}
ExprVisitor::VisitExpr_(call);
}

private:
const Optional<Expr>& own_name_;
bool contains_impure_ = false;
};

if (own_name) {
ICHECK(own_name.value().as<VarNode>() || own_name.value().as<GlobalVarNode>())
<< "Must pass a Var or GlobalVar for own_name";
}
ImpureCallChecker checker(own_name);
if (auto func = expr.as<FunctionNode>()) {
return checker.Check(func->body);
}
return checker.Check(expr);
}

TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars);

TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars);
Expand All @@ -149,5 +191,7 @@ TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars);

TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars);

TVM_REGISTER_GLOBAL("relax.analysis.contains_impure_call").set_body_typed(ContainsImpureCall);

} // namespace relax
} // namespace tvm
36 changes: 36 additions & 0 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@
* * The op or args fields of Call nodes
* * Inside the fields of Tuple nodes
* 12. Expr always has checked_type_ (with the exception of Op).
* 13. DataflowBlocks may not contain If nodes.
* 14. DataflowBlocks may not contain calls to impure functions or operators
* (only checked if check_struct_info is true).
* 15. If a function is annotated as pure (kIsPure is true)
* and purity is not forced (kForcePure is true),
* the body may not contain any impure call
* (only checked if check_struct_info is true).
* 16. If a function's purity is forced, kForcePure cannot be true
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
Expand Down Expand Up @@ -213,6 +221,15 @@ class WellFormedChecker : public relax::ExprVisitor,
}
});

// ensure the purity attributes are valid
if (op->GetAttr<Bool>(relax::attr::kForcePure).value_or(Bool(false))->value &&
!op->GetAttr<Bool>(relax::attr::kIsPure).value_or(Bool(true))->value) {
Malformed(Diagnostic::Error(op->span)
<< "Function " << op
<< " has a ForcePure annotation but its IsPure annotation is false;"
<< " ForcePure should be used only if IsPure is annotated as true.");
}

// check all expr are well defined.
for (Var param : op->params) {
this->VisitVarDef(param);
Expand All @@ -232,6 +249,18 @@ class WellFormedChecker : public relax::ExprVisitor,
Malformed(Diagnostic::Error(op) << "Function must have defined ret_struct_info");
}

// if we are not forcing purity and the function is annotated as pure, it must not contain an
// impure call
if (check_struct_info_ &&
!op->GetAttr<Bool>(relax::attr::kForcePure).value_or(Bool(false))->value &&
op->GetAttr<Bool>(relax::attr::kIsPure).value_or(Bool(true))->value &&
ContainsImpureCall(op->body)) {
Malformed(Diagnostic::Error(op)
<< "Function " << op << " is annotated as pure but contains an impure call; "
<< "please use the ForcePure attribute or wrap the call with call_pure "
<< "if it should be considered pure despite containing an impure call.");
}

if (auto seq = op->body.as<SeqExprNode>()) {
this->VisitSeqExpr(seq);
} else {
Expand Down Expand Up @@ -265,9 +294,15 @@ class WellFormedChecker : public relax::ExprVisitor,
}

CheckStructInfo(op);
if (is_dataflow_ && check_struct_info_ && IsImpureCall(GetRef<Call>(op))) {
Malformed(Diagnostic::Error(op) << "There cannot be an impure call inside a dataflow block.");
}
}

void VisitExpr_(const IfNode* op) final {
if (is_dataflow_) {
Malformed(Diagnostic::Error(op) << "If nodes are not allowed to appear in dataflow blocks.");
}
if (IsLeafOrTuple(op->cond)) {
this->VisitExpr(op->cond);
} else {
Expand Down Expand Up @@ -332,6 +367,7 @@ class WellFormedChecker : public relax::ExprVisitor,
} else {
this->VisitExpr(binding->value);
}

this->VisitVarDef(binding->var);
if (is_lambda) {
recur_vars_.erase(binding->var);
Expand Down
Loading

0 comments on commit c36741f

Please sign in to comment.