Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity][IR] Purity Tracking #14394

Merged
merged 73 commits into from May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
1168c6f
Preliminary work
slyubomirsky Mar 20, 2023
e4bae67
Won't try to infer purity for now
slyubomirsky Mar 22, 2023
8ba3041
Rename `pure` field to `purity`
slyubomirsky Mar 22, 2023
33b992d
Use attrs to annotate function purity
slyubomirsky Mar 22, 2023
b7b5685
Start implementing purity tracking
slyubomirsky Mar 24, 2023
ecd85c9
Add purity into pretty printer for FuncStructInfo
slyubomirsky Mar 26, 2023
15f864e
Process purity when parsing function declarations
slyubomirsky Mar 26, 2023
ca5aacf
Annotate purity for remaining operators
slyubomirsky Mar 26, 2023
b10f3cf
Whitespace
slyubomirsky Mar 26, 2023
ee8bf11
More whitespace and remove outdated comment
slyubomirsky Mar 26, 2023
7181b4f
More linting fixes
slyubomirsky Mar 26, 2023
d6a4800
One more fixed comment
slyubomirsky Mar 26, 2023
1c94524
Handle purity in the AST printer
slyubomirsky Mar 27, 2023
d288c31
Ensure we are parsing an Expr before checking for an attribute
slyubomirsky Mar 27, 2023
29f01fe
Mark purity for remaining ops
slyubomirsky Mar 27, 2023
34a6c80
Factor out repeated call_pure unwrapping
slyubomirsky Mar 27, 2023
e5c9da5
Add purity wrappers and annotations in test_vm_build
slyubomirsky Mar 27, 2023
82a62b3
One more WrapCallPure in vm_shape_lower
slyubomirsky Mar 27, 2023
c485ed2
Handle call_pure in memory planning
slyubomirsky Mar 27, 2023
e8b0132
One more simplification
slyubomirsky Mar 27, 2023
62d8852
Handle call_pure in one more case
slyubomirsky Mar 27, 2023
f843689
Fix struct info analysis test: cannot pass an opaque (impure) functio…
slyubomirsky Mar 27, 2023
07b55e2
Transfer over StructInfo in the call_pure wrappers
slyubomirsky Mar 27, 2023
7da2125
Make corrections to some VM shape lower tests
slyubomirsky Mar 27, 2023
4e52206
Add transformation to disable purity checking at low levels of compil…
slyubomirsky Mar 29, 2023
3932d4e
Remove purity checking before low-level passes, revert changes
slyubomirsky Mar 29, 2023
e4d9b6a
Also revert changes to VM code generation
slyubomirsky Mar 29, 2023
d005a55
Fix VMShapeLower tests
slyubomirsky Mar 29, 2023
7cc7ef7
Fix TVMScript printer test
slyubomirsky Mar 29, 2023
e8c8365
Fix TVMScript parser test
slyubomirsky Mar 29, 2023
b9fff59
Add special handling for printing call_pure, print, and assert_op
slyubomirsky Mar 29, 2023
d89ff7d
Fix tests in test_transform.py
slyubomirsky Mar 29, 2023
407547b
Remove purity checking in test_codegen_dnnl.py
slyubomirsky Mar 29, 2023
0b1f667
Add purity annotation for tensor_to_shape
slyubomirsky Mar 29, 2023
00316de
Handle call_pure in FuseTIR
slyubomirsky Mar 29, 2023
6ee5773
Be more discerning about inserting call_pure in LegalizeOps
slyubomirsky Mar 29, 2023
2a5cafe
Fix various tests to account for purity
slyubomirsky Mar 29, 2023
b51824c
Add ForcePure annotations in the static memory planning tests
slyubomirsky Mar 29, 2023
66bee84
Add purity annotation for stop_lift_params
slyubomirsky Mar 29, 2023
1239a36
Insert call_pure in LambdaLift for invoking closures
slyubomirsky Mar 29, 2023
4df0097
Fix outdated comment in test
slyubomirsky Mar 29, 2023
511dbfb
Need ToNonDataflow to keep everything consistent
slyubomirsky Mar 30, 2023
d995686
Missing call_pure for a BindParams test
slyubomirsky Mar 30, 2023
88324c4
Preserve purity in RunCodegen
slyubomirsky Mar 30, 2023
3c2eeeb
Address changes during rebase
slyubomirsky Apr 3, 2023
4326188
Rebase fixes and add purity annotations for new ops
slyubomirsky Apr 10, 2023
7d6cd1d
Fix docstring for contains_impure_call
slyubomirsky Apr 10, 2023
99102aa
Remove manipulate ops tests that are intentionally not supported
slyubomirsky Apr 10, 2023
1d42358
Address TODO: At least one builtin is impure
slyubomirsky Apr 10, 2023
ddc15db
Replace call_pure with call_pure_packed, call_pure_dps_packed, invoke…
slyubomirsky Apr 11, 2023
cc128ab
Formatting
slyubomirsky Apr 11, 2023
382e467
Correct rebase errror
slyubomirsky Apr 13, 2023
419a021
Address dynamic_strided_slice
slyubomirsky Apr 13, 2023
4c9d28f
Linting: Remove unused import
slyubomirsky Apr 14, 2023
51a5d42
Fix incorrect purity annotations
slyubomirsky Apr 17, 2023
b91878e
Fix incorrect comment in well_formed
slyubomirsky Apr 17, 2023
1387180
Add explanatory comment in unusual test case
slyubomirsky Apr 17, 2023
cd8675a
Use call_pure_packed in pipeline test
slyubomirsky Apr 18, 2023
f26eedb
Update transform_fuse_ops test
slyubomirsky Apr 24, 2023
697919d
Use ForcePure for the CUDA graph rewrite
slyubomirsky Apr 24, 2023
e00ba5c
Update comment
slyubomirsky Apr 25, 2023
fdb3370
Factor out search for purity annotation
slyubomirsky Apr 25, 2023
0fa4843
Remove call_dps_pure_packed (call_dps_packed is pure) and remove wrap…
slyubomirsky May 3, 2023
446d863
Remove unused var
slyubomirsky May 4, 2023
587ba5f
Correct rebase mistake
slyubomirsky May 4, 2023
ba78ec4
Use ForcePure for tests of low-level codegen
slyubomirsky May 4, 2023
85d5f62
lint
slyubomirsky May 4, 2023
f1982f5
Make LazyTransformParams compatible with purity tracking
slyubomirsky May 11, 2023
d83bfbb
Make is_pure and force_pure fields in Function instead of attrs
slyubomirsky May 16, 2023
3ba972f
Lint
slyubomirsky May 17, 2023
3616ec0
Unused imports
slyubomirsky May 17, 2023
e786f92
Use an attribute (relax.force_pure) to control forcing purity
slyubomirsky May 17, 2023
c920bf5
Indicate that RemovePurityChecking is also required for LazyTransform…
slyubomirsky May 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/tvm/relax/analysis.h
Expand Up @@ -438,6 +438,20 @@ TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func);
*/
TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);

/*!
* \brief Check if the given expression (likely a function body) contains any impure calls.
* \param expr The expression to be examined. If expr is a function, we check the body.
* \param own_name (Optional.) If we are checking a recursive function body,
* the caller can pass the function's name so recursive calls
* can be ignored in the check (must be a Var or GlobalVar).
* \return A boolean indicating if the expression contains any impure calls.
* \note Relies on StructInfo annotations, so ensure that the module has been normalized first.
* Also, an impure call in a *nested* function does *not* mean that the outer expression contains
* an impure call--it only does if the nested function is *later called*.
*/
TVM_DLL bool ContainsImpureCall(const Expr& expr,
const Optional<Expr>& own_name = Optional<Expr>(nullptr));

/*!
* \brief Check if the IRModule is well formed.
*
Expand Down
22 changes: 17 additions & 5 deletions include/tvm/relax/expr.h
Expand Up @@ -920,10 +920,13 @@ class FunctionNode : public BaseFuncNode {
Expr body;
/*! \brief The return type of the function. */
StructInfo ret_struct_info;
/*! \brief Whether the function is annotated as pure or not. */
bool is_pure;
slyubomirsky marked this conversation as resolved.
Show resolved Hide resolved

void VisitAttrs(AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("is_pure", &is_pure);
v->Visit("ret_struct_info", &ret_struct_info);
v->Visit("attrs", &attrs);
v->Visit("struct_info_", &struct_info_);
Expand All @@ -934,15 +937,16 @@ class FunctionNode : public BaseFuncNode {
bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal.DefEqual(params, other->params) && equal(body, other->body) &&
equal(ret_struct_info, other->ret_struct_info) && equal(attrs, other->attrs) &&
equal(struct_info_, other->struct_info_);
equal(ret_struct_info, other->ret_struct_info) && equal(is_pure, other->is_pure) &&
equal(attrs, other->attrs) && equal(struct_info_, other->struct_info_);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce.DefHash(params);
hash_reduce(body);
hash_reduce(ret_struct_info);
hash_reduce(is_pure);
hash_reduce(attrs);
hash_reduce(struct_info_);
}
Expand All @@ -956,14 +960,16 @@ class FunctionNode : public BaseFuncNode {
class Function : public BaseFunc {
public:
TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
Span span = Span());

/*!
* \brief Mimics the constructor but without body Expr.
* \note ret_struct_info is required, since it can not deduced by the body
* \note ret_struct_info is required, since it can not deduced by the body.
*/
TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info,
DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
Expand All @@ -985,6 +991,12 @@ constexpr const char* kComposite = "Composite";
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
/*! \brief The required workspace for an external function. */
constexpr const char* kWorkspaceSize = "WorkspaceSize";

// Note: in the future, we prefer snake_case instead of CamelCase for attributes.
// Past ones will be kept for backwards compatibility.
/*! \brief Override checking purity for this function and treat as pure
* (is_pure must be set to true) */
constexpr const char* kForcePure = "relax.force_pure";
} // namespace attr

/*! \brief The extern function, which can represent packed function. */
Expand Down
24 changes: 20 additions & 4 deletions include/tvm/relax/struct_info.h
Expand Up @@ -296,6 +296,12 @@ class FuncStructInfoNode : public StructInfoNode {
* ret should be ObjectStructInfo()
*/
Optional<StructInfoDeriveFunc> derive_func;
/*!
* \brief Whether the function is pure.
* \note This parameter should be set to true only if the function is pure on all inputs.
* If the function _may_ have visible side effects, set it to false.
*/
bool purity;

/*!
* \return Whether the func struct info is opaque.
Expand All @@ -308,16 +314,18 @@ class FuncStructInfoNode : public StructInfoNode {
v->Visit("ret", &ret);
v->Visit("derive_func", &derive_func);
v->Visit("span", &span);
v->Visit("purity", &purity);
}

bool SEqualReduce(const FuncStructInfoNode* other, SEqualReducer equal) const {
return equal.DefEqual(params, other->params) && equal(ret, other->ret) &&
equal(derive_func, other->derive_func);
equal(purity, other->purity) && equal(derive_func, other->derive_func);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(params);
hash_reduce(ret);
hash_reduce(purity);
hash_reduce(derive_func);
}

Expand All @@ -335,34 +343,42 @@ class FuncStructInfo : public StructInfo {
* \brief Constructor from parameter struct info and return value struct info.
* \param params The struct info of function parameters.
* \param ret The return value struct info.
* \param purity The purity of the function (true by default).
* \param span The span of the AST.
*
* \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from
* params. If you are unsure, you can always erase ret to static.
*/
TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, Span span = Span());
TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, bool purity = true,
Span span = Span());

/*!
* \brief Constructing an opaque function struct info using derive_func.
*
* \param derive_func Derivation function.
* \param purity The purity of the function
* (false by default: most external functions are not pure).
* \param span The span of the AST.
*
* \return The FuncStructInfo for opaque packedfunc.
* \note Defaults to an derive func that always return ObjectStructInfo if not specified.
*/
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, Span span = Span());
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity = false,
Span span = Span());

/*!
* \brief Construct an opaque function using from return struct info.
*
* \param ret The struct info of the return value.
* \param purity The purity of the function
* (false by default: most external functions are not pure).
* \param span The span of the AST.
*
* \return The FuncStructInfo for opaque packedfunc.
* \note Defaults to an derive func that always return ObjectStructInfo if not specified.
*/
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), Span span = Span());
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), bool purity = false,
Span span = Span());

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo, FuncStructInfoNode);
};
Expand Down
13 changes: 13 additions & 0 deletions include/tvm/relax/transform.h
Expand Up @@ -83,6 +83,19 @@ TVM_DLL Pass LambdaLift();
*/
TVM_DLL Pass ToNonDataflow();

/*!
* \brief Activate force_pure on all pure functions in the module
* and unwrap all pure override ops into the normal versions.
*
* This effectively means that there will be no more purity tracking,
* useful for low-level code generation.
*
* \return The Pass.
*
* \note Should be used after ToNonDataflow()
*/
TVM_DLL Pass RemovePurityChecking();

/*!
* \brief Perform explicit tensor allocation for call_tir and call_dps_packed.
*
Expand Down
12 changes: 12 additions & 0 deletions include/tvm/relax/utils.h
Expand Up @@ -81,6 +81,18 @@ TVM_DLL bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank
*/
TVM_DLL bool IsLeafOrTuple(const Expr& expr);

/*!
* \brief Check if the given Call node is an impure operation. If the callee is a general
* expression, this simply requires checking the purity field of the FuncStructInfo. If it is an Op,
* then this checks the `fPurity` field.
*
* \param call The input call
*
* \return True iff the call is impure (definitely or possibly results in a visible side effect).
* That is, a call is considered pure only if definitely does not result in a visible side effect.
*/
TVM_DLL bool IsImpureCall(const Call& call);

/*!
* \brief Copy the given function. All variables that are bound inside the original function
* would be copied to satisfy the restriction in the well-formed check: Variables in
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/script/ir_builder/relax/frame.h
Expand Up @@ -97,7 +97,8 @@ class FunctionFrameNode : public SeqExprFrameNode {
* take the specified `ret_struct_info`.
*/
Optional<tvm::relax::StructInfo> ret_struct_info;

/*! \brief Whether the function is annotated as pure */
Optional<Bool> is_pure;
/*! \brief The function attributes. */
Map<String, ObjectRef> attrs;
/*! \brief The block builder to create Relax function. */
Expand All @@ -108,6 +109,7 @@ class FunctionFrameNode : public SeqExprFrameNode {
v->Visit("name", &name);
v->Visit("params", &params);
v->Visit("ret_struct_info", &ret_struct_info);
v->Visit("is_pure", &is_pure);
v->Visit("attrs", &attrs);
v->Visit("binding_blocks", &binding_blocks);
v->Visit("output", &output);
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/script/ir_builder/relax/ir.h
Expand Up @@ -57,6 +57,12 @@ TVM_DLL void FuncName(const String& name);
*/
TVM_DLL void FuncAttrs(Map<String, ObjectRef> attrs);

/*!
* \brief Specify the purity of the last function frame.
* \param purity Whether the function is pure.
*/
TVM_DLL void FuncIsPure(bool purity);

/*!
* \brief Specify the return struct info of the last function frame.
* \param ret_sinfo The return struct info.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/__init__.py
Expand Up @@ -56,7 +56,7 @@
from .exec_builder import ExecBuilder

# Operator
from .op.base import call_tir, call_dps_packed
from .op.base import call_tir, call_pure_packed, call_dps_packed

# BlockBuilder
from .block_builder import BlockBuilder
Expand Down
30 changes: 29 additions & 1 deletion python/tvm/relax/analysis/analysis.py
Expand Up @@ -21,7 +21,7 @@
configuring the passes and scripting them in Python.
"""

from typing import Dict, List, Union, Callable
from typing import Dict, List, Optional, Union, Callable
from enum import IntEnum

import tvm
Expand Down Expand Up @@ -327,6 +327,34 @@ def has_reshape_pattern(func: tir.PrimFunc) -> bool:
return _ffi_api.has_reshape_pattern(func) # type: ignore


def contains_impure_call(expr: Expr, own_name: Optional[Union[Var, GlobalVar]] = None) -> bool:
"""
Check if the given expression (likely a function body) contains any impure calls.

Parameter
---------
expr : Expr
The expression to be examined. If expr is a function, we check the body.

own_name : Var or GlobalVar (optional)
For a recursive function, the analysis can ignore the self-calls
for checking purity.

Returns
-------
ret : bool
True if there is an impure call
(call to a function that may have visible side effects).

Notes
-----
Relies on StructInfo annotations, so ensure that the module has been normalized first.
Also, an impure call in a *nested* function does *not* mean that the outer expression contains
an impure call--it only does if the nested function is *later called*.
"""
return _ffi_api.contains_impure_call(expr, own_name)


def get_var2val(func: Function) -> Dict[Var, Expr]:
"""
Get a mapping from Var to Expr for each variable in the function.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/backend/contrib/cutlass.py
Expand Up @@ -385,7 +385,7 @@ def __init__(self, mod):
def visit_function_(self, f):
if f.attrs is None or "Composite" not in f.attrs:
body = super().visit_expr(f.body)
new_f = Function(f.params, body, f.ret_struct_info, f.attrs, f.span)
new_f = Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span)

if f.attrs and "global_symbol" in f.attrs and "cutlass" in f.attrs["global_symbol"]:
composite_func = body.blocks[0].bindings[0].value
Expand Down
17 changes: 14 additions & 3 deletions python/tvm/relax/expr.py
Expand Up @@ -560,29 +560,40 @@ class Function(BaseFunc, Scriptable):
params: List[Var]
body: Expr
ret_struct_info: StructInfo
is_pure: bool
attrs: Optional[tvm.ir.DictAttrs]

def __init__(
self,
params: List[Var],
body: Expr,
ret_struct_info: Optional[StructInfo] = None,
is_pure: Optional[bool] = True,
attrs: Optional[tvm.ir.DictAttrs] = None,
span: Optional[Span] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.Function, params, body, ret_struct_info, attrs, span # type: ignore
)
_ffi_api.Function,
params,
body,
ret_struct_info,
is_pure,
attrs,
span, # type: ignore
) # type: ignore

@staticmethod
def create_empty(
params: List[Var],
ret_struct_info: StructInfo,
is_pure: Optional[bool] = True,
attrs: Optional[tvm.ir.DictAttrs] = None,
span: Optional[Span] = None,
):
"""Construct a relax.Function but without body"""
return _ffi_api.FunctionCreateEmpty(params, ret_struct_info, attrs, span) # type: ignore
return _ffi_api.FunctionCreateEmpty(
params, ret_struct_info, is_pure, attrs, span
) # type: ignore

def __call__(self, *args):
"""Invoke the global function.
Expand Down