diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 59f9e475bf93..f515ba620196 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -438,6 +438,20 @@ TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func); */ TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func); +/*! + * \brief Check if the given expression (likely a function body) contains any impure calls. + * \param expr The expression to be examined. If expr is a function, we check the body. + * \param own_name (Optional.) If we are checking a recursive function body, + * the caller can pass the function's name so recursive calls + * can be ignored in the check (must be a Var or GlobalVar). + * \return A boolean indicating if the expression contains any impure calls. + * \note Relies on StructInfo annotations, so ensure that the module has been normalized first. + * Also, an impure call in a *nested* function does *not* mean that the outer expression contains + * an impure call--it only does if the nested function is *later called*. + */ +TVM_DLL bool ContainsImpureCall(const Expr& expr, + const Optional& own_name = Optional(nullptr)); + /*! * \brief Check if the IRModule is well formed. * diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index f090610019bd..36a8109c35b6 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -920,10 +920,13 @@ class FunctionNode : public BaseFuncNode { Expr body; /*! \brief The return type of the function. */ StructInfo ret_struct_info; + /*! \brief Whether the function is annotated as pure or not. */ + bool is_pure; void VisitAttrs(AttrVisitor* v) { v->Visit("params", ¶ms); v->Visit("body", &body); + v->Visit("is_pure", &is_pure); v->Visit("ret_struct_info", &ret_struct_info); v->Visit("attrs", &attrs); v->Visit("struct_info_", &struct_info_); @@ -934,8 +937,8 @@ class FunctionNode : public BaseFuncNode { bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); return equal.DefEqual(params, other->params) && equal(body, other->body) && - equal(ret_struct_info, other->ret_struct_info) && equal(attrs, other->attrs) && - equal(struct_info_, other->struct_info_); + equal(ret_struct_info, other->ret_struct_info) && equal(is_pure, other->is_pure) && + equal(attrs, other->attrs) && equal(struct_info_, other->struct_info_); } void SHashReduce(SHashReducer hash_reduce) const { @@ -943,6 +946,7 @@ class FunctionNode : public BaseFuncNode { hash_reduce.DefHash(params); hash_reduce(body); hash_reduce(ret_struct_info); + hash_reduce(is_pure); hash_reduce(attrs); hash_reduce(struct_info_); } @@ -956,14 +960,16 @@ class FunctionNode : public BaseFuncNode { class Function : public BaseFunc { public: TVM_DLL explicit Function(Array params, Expr body, Optional ret_struct_info, - DictAttrs attrs = NullValue(), Span span = Span()); + bool is_pure = true, DictAttrs attrs = NullValue(), + Span span = Span()); /*! * \brief Mimics the constructor but without body Expr. - * \note ret_struct_info is required, since it can not deduced by the body + * \note ret_struct_info is required, since it can not deduced by the body. 
*/ TVM_DLL static Function CreateEmpty(Array params, StructInfo ret_struct_info, - DictAttrs attrs = NullValue(), Span span = Span()); + bool is_pure = true, DictAttrs attrs = NullValue(), + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode); @@ -985,6 +991,12 @@ constexpr const char* kComposite = "Composite"; constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern"; /*! \brief The required workspace for an external function. */ constexpr const char* kWorkspaceSize = "WorkspaceSize"; + +// Note: in the future, we prefer snake_case instead of CamelCase for attributes. +// Past ones will be kept for backwards compatibility. +/*! \brief Override checking purity for this function and treat as pure + * (is_pure must be set to true) */ +constexpr const char* kForcePure = "relax.force_pure"; } // namespace attr /*! \brief The extern function, which can represent packed function. */ diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index 0c1973bceac9..190174248e3e 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -296,6 +296,12 @@ class FuncStructInfoNode : public StructInfoNode { * ret should be ObjectStructInfo() */ Optional derive_func; + /*! + * \brief Whether the function is pure. + * \note This parameter should be set to true only if the function is pure on all inputs. + * If the function _may_ have visible side effects, set it to false. + */ + bool purity; /*! * \return Whether the func struct info is opaque. @@ -308,16 +314,18 @@ class FuncStructInfoNode : public StructInfoNode { v->Visit("ret", &ret); v->Visit("derive_func", &derive_func); v->Visit("span", &span); + v->Visit("purity", &purity); } bool SEqualReduce(const FuncStructInfoNode* other, SEqualReducer equal) const { return equal.DefEqual(params, other->params) && equal(ret, other->ret) && - equal(derive_func, other->derive_func); + equal(purity, other->purity) && equal(derive_func, other->derive_func); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.DefHash(params); hash_reduce(ret); + hash_reduce(purity); hash_reduce(derive_func); } @@ -335,34 +343,42 @@ class FuncStructInfo : public StructInfo { * \brief Constructor from parameter struct info and return value struct info. * \param params The struct info of function parameters. * \param ret The return value struct info. + * \param purity The purity of the function (true by default). * \param span The span of the AST. * * \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from * params. If you are unsure, you can always erase ret to static. */ - TVM_DLL FuncStructInfo(Array params, StructInfo ret, Span span = Span()); + TVM_DLL FuncStructInfo(Array params, StructInfo ret, bool purity = true, + Span span = Span()); /*! * \brief Constructing an opaque function struct info using derive_func. * * \param derive_func Derivation function. + * \param purity The purity of the function + * (false by default: most external functions are not pure). * \param span The span of the AST. * * \return The FuncStructInfo for opaque packedfunc. * \note Defaults to an derive func that always return ObjectStructInfo if not specified. */ - TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, Span span = Span()); + TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity = false, + Span span = Span()); /*! 
* \brief Construct an opaque function using from return struct info. * * \param ret The struct info of the return value. + * \param purity The purity of the function + * (false by default: most external functions are not pure). * \param span The span of the AST. * * \return The FuncStructInfo for opaque packedfunc. * \note Defaults to an derive func that always return ObjectStructInfo if not specified. */ - TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), Span span = Span()); + TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), bool purity = false, + Span span = Span()); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo, FuncStructInfoNode); }; diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 0e9c42da9623..138720ec13a0 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -83,6 +83,19 @@ TVM_DLL Pass LambdaLift(); */ TVM_DLL Pass ToNonDataflow(); +/*! + * \brief Activate force_pure on all pure functions in the module + * and unwrap all pure override ops into the normal versions. + * + * This effectively means that there will be no more purity tracking, + * useful for low-level code generation. + * + * \return The Pass. + * + * \note Should be used after ToNonDataflow() + */ +TVM_DLL Pass RemovePurityChecking(); + /*! * \brief Perform explicit tensor allocation for call_tir and call_dps_packed. * diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index d04a91f1d1d6..a1f587e14e90 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -81,6 +81,18 @@ TVM_DLL bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank */ TVM_DLL bool IsLeafOrTuple(const Expr& expr); +/*! + * \brief Check if the given Call node is an impure operation. If the callee is a general + * expression, this simply requires checking the purity field of the FuncStructInfo. If it is an Op, + * then this checks the `FPurity` field. + * + * \param call The input call + * + * \return True iff the call is impure (definitely or possibly results in a visible side effect). + * That is, a call is considered pure only if it definitely does not result in a visible side effect. + */ +TVM_DLL bool IsImpureCall(const Call& call); + /*! * \brief Copy the given function. All variables that are bound inside the original function * would be copied to satisfy the restriction in the well-formed check: Variables in diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index 0f544d3abcc2..9a8f835e819b 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -97,7 +97,8 @@ class FunctionFrameNode : public SeqExprFrameNode { * take the specified `ret_struct_info`. */ Optional ret_struct_info; - + /*! \brief Whether the function is annotated as pure */ + Optional is_pure; /*! \brief The function attributes. */ Map attrs; /*! \brief The block builder to create Relax function.
*/ @@ -108,6 +109,7 @@ class FunctionFrameNode : public SeqExprFrameNode { v->Visit("name", &name); v->Visit("params", ¶ms); v->Visit("ret_struct_info", &ret_struct_info); + v->Visit("is_pure", &is_pure); v->Visit("attrs", &attrs); v->Visit("binding_blocks", &binding_blocks); v->Visit("output", &output); diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h index 42aa591a95b7..ca705d11dc36 100644 --- a/include/tvm/script/ir_builder/relax/ir.h +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -57,6 +57,12 @@ TVM_DLL void FuncName(const String& name); */ TVM_DLL void FuncAttrs(Map attrs); +/*! + * \brief Specify the purity of the last function frame. + * \param purity Whether the function is pure. + */ +TVM_DLL void FuncIsPure(bool purity); + /*! * \brief Specify the return struct info of the last function frame. * \param ret_sinfo The return struct info. diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 42ba452bae38..e4d91c59ac48 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -56,7 +56,7 @@ from .exec_builder import ExecBuilder # Operator -from .op.base import call_tir, call_dps_packed +from .op.base import call_tir, call_pure_packed, call_dps_packed # BlockBuilder from .block_builder import BlockBuilder diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index e3b3c288efce..4abd609437a2 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -21,7 +21,7 @@ configuring the passes and scripting them in Python. """ -from typing import Dict, List, Union, Callable +from typing import Dict, List, Optional, Union, Callable from enum import IntEnum import tvm @@ -327,6 +327,34 @@ def has_reshape_pattern(func: tir.PrimFunc) -> bool: return _ffi_api.has_reshape_pattern(func) # type: ignore +def contains_impure_call(expr: Expr, own_name: Optional[Union[Var, GlobalVar]] = None) -> bool: + """ + Check if the given expression (likely a function body) contains any impure calls. + + Parameter + --------- + expr : Expr + The expression to be examined. If expr is a function, we check the body. + + own_name : Var or GlobalVar (optional) + For a recursive function, the analysis can ignore the self-calls + for checking purity. + + Returns + ------- + ret : bool + True if there is an impure call + (call to a function that may have visible side effects). + + Notes + ----- + Relies on StructInfo annotations, so ensure that the module has been normalized first. + Also, an impure call in a *nested* function does *not* mean that the outer expression contains + an impure call--it only does if the nested function is *later called*. + """ + return _ffi_api.contains_impure_call(expr, own_name) + + def get_var2val(func: Function) -> Dict[Var, Expr]: """ Get a mapping from Var to Expr for each variable in the function. 
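For reference, a minimal usage sketch of the new analysis entry point added above (assumptions: the TVMScript helpers introduced elsewhere in this patch, namely R.is_impure() and R.print; the module and function names are illustrative only):

    from tvm.script import ir as I, relax as R
    from tvm.relax.analysis import contains_impure_call

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            R.is_impure()  # annotate the function as impure so the well-formed check passes
            R.print(x, format="x: {}")  # relax.print is registered with FPurity = false
            return x

    assert contains_impure_call(Module["main"])

Because the analysis relies on StructInfo annotations, the function should come from a normalized module (TVMScript parsing already produces one).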
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index 19fc2a39ea10..2dd429d1d018 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -385,7 +385,7 @@ def __init__(self, mod): def visit_function_(self, f): if f.attrs is None or "Composite" not in f.attrs: body = super().visit_expr(f.body) - new_f = Function(f.params, body, f.ret_struct_info, f.attrs, f.span) + new_f = Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span) if f.attrs and "global_symbol" in f.attrs and "cutlass" in f.attrs["global_symbol"]: composite_func = body.blocks[0].bindings[0].value diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index fdf98c179b7c..6474db1775d4 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -560,6 +560,7 @@ class Function(BaseFunc, Scriptable): params: List[Var] body: Expr ret_struct_info: StructInfo + is_pure: bool attrs: Optional[tvm.ir.DictAttrs] def __init__( @@ -567,22 +568,32 @@ def __init__( params: List[Var], body: Expr, ret_struct_info: Optional[StructInfo] = None, + is_pure: Optional[bool] = True, attrs: Optional[tvm.ir.DictAttrs] = None, span: Optional[Span] = None, ) -> None: self.__init_handle_by_constructor__( - _ffi_api.Function, params, body, ret_struct_info, attrs, span # type: ignore - ) + _ffi_api.Function, + params, + body, + ret_struct_info, + is_pure, + attrs, + span, # type: ignore + ) # type: ignore @staticmethod def create_empty( params: List[Var], ret_struct_info: StructInfo, + is_pure: Optional[bool] = True, attrs: Optional[tvm.ir.DictAttrs] = None, span: Optional[Span] = None, ): """Construct a relax.Function but without body""" - return _ffi_api.FunctionCreateEmpty(params, ret_struct_info, attrs, span) # type: ignore + return _ffi_api.FunctionCreateEmpty( + params, ret_struct_info, is_pure, attrs, span + ) # type: ignore def __call__(self, *args): """Invoke the global function. diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 67f5f5707093..a1a183003ff0 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -118,6 +118,10 @@ def call_dps_packed( """ Call a destination-passing-style packed function and return the output. + Note: The called function is assumed to be _pure_ (other than modifying the designated + output arguments). If the function _does_ result in other side effects, then the compiler + may end up removing, reordering, or repeating those effects--no guarantees can be made. + Parameters ---------- func : Union[str, Expr] @@ -217,7 +221,7 @@ def invoke_closure( closure: Expr, args: Expr, sinfo_args: Union[List[StructInfo], StructInfo], -) -> Object: +) -> Call: """ Invoke a closure. @@ -234,8 +238,8 @@ def invoke_closure( Returns ------- - ret: Object - The result. + ret: Call + A call to `invoke_closure`. """ if not isinstance(sinfo_args, (list, tuple)): @@ -466,3 +470,91 @@ def shape_to_tensor(expr: Expr) -> Expr: A relax Call, which transforms the shape values to the tensor """ return _ffi_api.shape_to_tensor(expr) # type: ignore # pylint: disable=no-member + + +@args_converter.auto +def call_pure_packed( + func: Union[str, ExternFunc, GlobalVar], + *args: Expr, + sinfo_args: Union[StructInfo, List[StructInfo]], +) -> Expr: + """ + Construct a call to a packed function that should be treated as pure, + even though packed calls are normally not treated as pure. 
+ + The resulting call will have the same semantics as calling the packed function directly. + + Note: This should be used for cases when the user knows that calling the packed function + with these arguments will _in reality_ not cause any side effects. + If it is used for a call that _does_ result in side effects, then the compiler + may end up removing, reordering, or repeating that call, with no guarantees + made about any side effects from the callee. + + Parameters + ---------- + func : Union[str, ExternFunc, GlobalVar] + The name (global symbol) for a PackedFunc or an ExternFunc node. + + args: Expr + The arguments for the PackedFunc. + + sinfo_args: Union[StructInfo, List[StructInfo]] + The list of structure info arguments (giving the structural info for the returned value). + + Returns + ------- + result : Expr + A Relax call, corresponding to + `call_pure_packed(ExternFunc(func), args, DictAttrs(kwargs), sinfo_args)` + """ + if isinstance(func, ExternFunc): + func = func.global_symbol + + op = ExternFunc(func) + if sinfo_args is None: + raise ValueError("R.call_pure_packed is required to have sinfo_args") + if isinstance(sinfo_args, tuple): # type: ignore + sinfo_args = list(sinfo_args) + elif not isinstance(sinfo_args, list): + sinfo_args = [sinfo_args] + # note: if we need attributes, we can also take them here + + return _ffi_api.call_pure_packed(op, args, None, sinfo_args) # type: ignore # pylint: disable=no-member + + +@args_converter.auto +def invoke_pure_closure( + closure: Expr, + args: Expr, + sinfo_args: Union[List[StructInfo], StructInfo], +) -> Call: + """ + Invoke a closure and indicate to the compiler that it is pure. + + Note: This should be used for cases when the user knows that calling the closure + with these arguments will _in reality_ not cause any side effects. + If it is used for a call that _does_ result in side effects, then the compiler + may end up removing, reordering, or repeating that call, with no guarantees + made about any side effects from the callee. + + Parameters + ---------- + closure : Expr + The VMClosure object. + + args : Expr + The input arguments. + + sinfo_args: Union[List[StructInfo], StructInfo] + The structure info arguments of the CallNode + + Returns + ------- + ret: Call + A call to `invoke_pure_closure`. + """ + + if not isinstance(sinfo_args, (list, tuple)): + sinfo_args = [sinfo_args] + + return _ffi_api.invoke_pure_closure(closure, args, sinfo_args) # type: ignore diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py index 2ff027b22924..3dcc3dc9a04b 100644 --- a/python/tvm/relax/struct_info.py +++ b/python/tvm/relax/struct_info.py @@ -149,16 +149,25 @@ class FuncStructInfo(StructInfo): ret: StructInfo The struct info of return value + + purity: bool + Whether the function is pure (has no visible side effects). + Note: We consider a function to be pure only if it is pure on all inputs. + If a function can have visible side effects only in some cases, + we still consider it impure.
""" params: Optional[List[StructInfo]] ret: StructInfo derive_func: Optional[EnvFunc] + purity: bool span: Span - def __init__(self, params: List[StructInfo], ret: StructInfo, span: Span = None) -> None: + def __init__( + self, params: List[StructInfo], ret: StructInfo, purity: bool = True, span: Span = None + ) -> None: self.__init_handle_by_constructor__( - _ffi_api.FuncStructInfo, params, ret, span # type: ignore + _ffi_api.FuncStructInfo, params, ret, purity, span # type: ignore ) @staticmethod @@ -166,6 +175,7 @@ def opaque_func( *, ret: Optional[StructInfo] = None, derive_func: Optional[EnvFunc] = None, + purity: bool = False, span: Span = None, ) -> "FuncStructInfo": """ @@ -183,6 +193,9 @@ def opaque_func( derive_func: Optional[EnvFunc] The environment function used for derivation + purity: bool + Whether the function is pure (false by default, as most opaque functions are not pure) + span: Optional[Span] Optional span information of the ast. @@ -194,4 +207,4 @@ def opaque_func( ---- We cannot specify ret and derive_func simultaneously. """ - return _ffi_api.FuncStructInfoOpaqueFunc(ret, derive_func, span) # type: ignore + return _ffi_api.FuncStructInfoOpaqueFunc(ret, derive_func, purity, span) # type: ignore diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py index 6727b2429202..1ed16363b20a 100644 --- a/python/tvm/relax/testing/ast_printer.py +++ b/python/tvm/relax/testing/ast_printer.py @@ -147,6 +147,7 @@ def visit_function_(self, op: relax.Function) -> str: "params": self.build_list(map(self.visit_expr, op.params)), "body": self.visit_expr(op.body), "ret_struct_info": self.visit_struct_info_(op.ret_struct_info), + "is_pure": op.is_pure, } if op.attrs: fields["attrs"] = self.build_list( @@ -295,6 +296,7 @@ def visit_struct_info_(self, struct_info_node: relax.StructInfo) -> str: map(self.visit_struct_info_, struct_info_node.params) ) fields["ret"] = self.visit_struct_info_(struct_info_node.ret) + fields["purity"] = bool(struct_info_node.purity) return self.build_ast_node("FuncStructInfo", **fields) else: raise ValueError( diff --git a/python/tvm/relax/transform/lazy_transform_params.py b/python/tvm/relax/transform/lazy_transform_params.py index 01deee8197f9..9ce7bd003862 100644 --- a/python/tvm/relax/transform/lazy_transform_params.py +++ b/python/tvm/relax/transform/lazy_transform_params.py @@ -144,7 +144,7 @@ def transform(self, func: relax.Function) -> relax.Function: self.memory_free_insertion = liveness.var_liveness_end # Step 3. rewrite get item and set item new_body = self.visit_expr(func.body) - return relax.Function([], new_body, relax.ObjectStructInfo(), func.attrs) + return relax.Function([], new_body, relax.ObjectStructInfo(), attrs=func.attrs) def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr: # rewrite get item @@ -191,6 +191,8 @@ class LazyTransformParams: """ Convert transform_params functions into a lazy version. (Load the input to memory on demand, and immediately free it after the last use.) + + Note: ToNonDataflow() and RemovePurityTracking() should be invoked before this pass. 
""" def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule: diff --git a/python/tvm/relax/transform/legalize_ops/index.py b/python/tvm/relax/transform/legalize_ops/index.py index 8ee1bed9b9c7..13228c4805d5 100644 --- a/python/tvm/relax/transform/legalize_ops/index.py +++ b/python/tvm/relax/transform/legalize_ops/index.py @@ -19,8 +19,9 @@ import logging from tvm import topi, tir, te +from ...op import call_pure_packed from ...block_builder import BlockBuilder -from ...expr import Call, Expr, ExternFunc +from ...expr import Call, Expr from ...struct_info import ShapeStructInfo from .common import register_legalize @@ -105,10 +106,12 @@ def get_length(begin, end, strides, length): # Get shape length ndim = int(output_shape.struct_info.shape[0]) output_shape = bb.emit( - Call( - ExternFunc("vm.builtin.tensor_to_shape"), - [output_shape], - sinfo_args=[ShapeStructInfo(ndim=ndim)], + # TODO(@relax-team): Ideally, we should use the tensor_to_shape op here to + # address the issue with purity, but that introduces a staging issue: + # we need to apply DecomposeOpsForInference in that case + # and it's unclear when in the build it should happen + call_pure_packed( + "vm.builtin.tensor_to_shape", output_shape, sinfo_args=ShapeStructInfo(ndim=ndim) ) ) output_shape_vars = [tir.Var("s", "int64") for i in range(ndim)] diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 508e8bccba8b..a7955f754cdb 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -226,6 +226,25 @@ def ToNonDataflow() -> tvm.ir.transform.Pass: return _ffi_api.ToNonDataflow() # type: ignore +def RemovePurityChecking() -> tvm.ir.transform.Pass: + """Activate relax.force_pure on all pure functions in the module + and unwrap all pure override ops into the normal versions. + + This effectively means that there will be no more purity tracking, + useful for low-level code generation. + + Returns + ------- + ret: tvm.ir.transform.Pass + The Pass. + + Note + ---- + Should be used after ToNonDataflow() + """ + return _ffi_api.RemovePurityChecking() # type: ignore + + def LambdaLift(): """A pass that lifts local functions into global. diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index a8339398339d..fb71f0f1f8ef 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -307,6 +307,7 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): passes = [] passes.append(relax.transform.RewriteDataflowReshape()) passes.append(relax.transform.ToNonDataflow()) + passes.append(relax.transform.RemovePurityChecking()) passes.append(relax.transform.CallTIRRewrite()) passes.append(relax.transform.StaticPlanBlockMemory()) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 39327c4b4a25..7a1ecca4d8d9 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -48,6 +48,7 @@ broadcast_to, builtin, call_builtin_with_ctx, + call_pure_packed, call_tir, call_dps_packed, ceil, @@ -76,6 +77,7 @@ greater_equal, image, invoke_closure, + invoke_pure_closure, isfinite, isinf, isnan, @@ -203,6 +205,23 @@ def func_attr(attrs: Dict[py_str, tvm_Object]) -> None: return _ffi_api.FuncAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member +def is_pure(purity: bool = True) -> None: + """Specify the purity of the last function frame. 
+ + Parameters + ---------- + purity: bool + The annotated purity. + """ + return _ffi_api.FuncIsPure(purity) # type: ignore[attr-defined] # pylint: disable=no-member + + +def is_impure() -> None: + """Specify that the last function frame is annotated as impure. + (Syntactic sugar for R.is_pure(False))""" + return _ffi_api.FuncIsPure(False) # type: ignore[attr-defined] # pylint: disable=no-member + + def func_ret_struct_info(ret_sinfo: StructInfo) -> None: """Specify the return struct info of the last function frame. Parameters @@ -560,6 +579,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "broadcast_to", "builtin", "call_packed", + "call_pure_packed", "call_tir", "call_dps_packed", "call_builtin_with_ctx", @@ -601,6 +621,9 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "greater_equal", "image", "invoke_closure", + "invoke_pure_closure", + "is_impure", + "is_pure", "isfinite", "isinf", "isnan", diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index acb490a813b8..70e51734585d 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -149,12 +149,14 @@ def Tensor( class CallableProxy(StructInfoProxy): params: List[StructInfoProxy] ret: StructInfoProxy + purity: bool + """Function type. A function type consists of a list of type parameters to enable the definition of generic functions, a set of type constraints which we omit for the time being, - a sequence of argument types, and a return type. + a sequence of argument types, the purity of the function, and a return type. Parameters ---------- @@ -164,18 +166,23 @@ class CallableProxy(StructInfoProxy): ret : StructInfoProxy The return StructInfoProxy. + purity : bool + Whether the callable is pure. + """ def __init__( self, params: Union[StructInfoProxy, List[StructInfoProxy]], ret: StructInfoProxy, + purity: bool = True, ) -> None: if not isinstance(params, (list, tuple)): params = [params] # convert `R.Tensor` to `R.Tensor()` self.params = [param() if callable(param) else param for param in params] self.ret = ret() if callable(ret) else ret + self.purity = purity def get_symbolic_vars(self) -> Set[str]: return set().union(*[p.get_symbolic_vars() for p in self.params]) @@ -183,14 +190,15 @@ def get_symbolic_vars(self) -> Set[str]: def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> FuncStructInfo: params = [param.as_struct_info(dict_globals) for param in self.params] ret = self.ret.as_struct_info(dict_globals) - return FuncStructInfo(params, ret) + return FuncStructInfo(params, ret, purity=self.purity) def Callable( params: Union[StructInfoProxy, List[StructInfoProxy]], ret: StructInfoProxy, + purity: bool = True, ) -> CallableProxy: - return CallableProxy(params, ret) + return CallableProxy(params, ret, purity=purity) ############################### R.Tuple ################################ diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index 06fc51b7a607..3dfde96714b6 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -202,6 +202,25 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: self.visit_body(node.body) +def find_purity_annotation(node: doc.FunctionDef) -> bool: + """ + Check if is_pure is specified in the function body. + Returns the annotated purity if present, otherwise defaulting to True. + This allows for specifying the purity in the function signature. 
+ """ + for item in node.body: + if ( + isinstance(item, doc.Expr) + and isinstance(item.value, doc.Call) + and isinstance(item.value.func, doc.Attribute) + and item.value.func.attr == "is_pure" + and len(item.value.args) == 1 + and isinstance(item.value.args[0], doc.Constant) + ): + return bool(item.value.args[0].value) + return True + + @dispatch.register(token="relax", type_name="tvm_declare_function") def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar: with self.var_table.with_frame(): @@ -220,7 +239,9 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) params.append(relax.Var(arg.arg, param_sinfo)) - func_signature = relax.Function.create_empty(params, ret_sinfo) + is_pure = find_purity_annotation(node) + + func_signature = relax.Function.create_empty(params, ret_sinfo, is_pure=is_pure) return I.decl_function(node.name, func_signature) diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc index 4132039a5e34..108fe69372b6 100644 --- a/src/relax/analysis/analysis.cc +++ b/src/relax/analysis/analysis.cc @@ -141,6 +141,48 @@ tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } tvm::Array AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); } +bool ContainsImpureCall(const Expr& expr, const Optional& own_name) { + class ImpureCallChecker : public ExprVisitor { + public: + explicit ImpureCallChecker(const Optional& own_name) : own_name_(own_name) {} + + bool Check(const Expr& expr) { + contains_impure_ = false; + VisitExpr(expr); + return contains_impure_; + } + + void VisitExpr_(const FunctionNode* func) override { + // we don't visit inner functions because an impure call in an inner function + // does *not* mean the outer function contains an impure call + } + + void VisitExpr_(const CallNode* call) override { + // ignore recursive calls if we find one + if (!(own_name_ && own_name_.value().same_as(call->op))) { + if (IsImpureCall(GetRef(call))) { + contains_impure_ = true; + } + } + ExprVisitor::VisitExpr_(call); + } + + private: + const Optional& own_name_; + bool contains_impure_ = false; + }; + + if (own_name) { + ICHECK(own_name.value().as() || own_name.value().as()) + << "Must pass a Var or GlobalVar for own_name"; + } + ImpureCallChecker checker(own_name); + if (auto func = expr.as()) { + return checker.Check(func->body); + } + return checker.Check(expr); +} + TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars); TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars); @@ -149,5 +191,7 @@ TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars); TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars); +TVM_REGISTER_GLOBAL("relax.analysis.contains_impure_call").set_body_typed(ContainsImpureCall); + } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index d2ef8c4e73ac..19e93f36a439 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -95,7 +95,8 @@ StructInfo StructInfoFromType(const Type& type) { Array params = func_type->arg_types.Map([](const Type& param) { return StructInfoFromType(param); }); StructInfo ret = StructInfoFromType(func_type->ret_type); - return FuncStructInfo(params, ret, func_type->span); + // TODO(relax-team): Maybe add purity into the type as well + 
return FuncStructInfo(params, ret, true, func_type->span); } else { LOG(FATAL) << "Unsupported type: " << type; return StructInfo(); @@ -362,6 +363,11 @@ class StructInfoBaseChecker return BaseCheckResult::kFailL0; } + // Check purity: Pure functions are a subtype of impure functions + if (lhs->purity && !rhs->purity) { + return BaseCheckResult::kFailL0; + } + // lhs opaque handling if (lhs->IsOpaque()) { if (lhs->derive_func.defined()) { @@ -774,6 +780,9 @@ class StructInfoLCAFinder auto* rhs = other.as(); if (rhs == nullptr) return ObjectStructInfo(lhs->span); + // the unified function is pure only if both are pure + bool purity = lhs->purity && rhs->purity; + // lhs opaque handling if (lhs->IsOpaque()) { if (lhs->derive_func.defined()) { @@ -781,13 +790,13 @@ class StructInfoLCAFinder return GetRef(lhs); } else { // Create a new opaque with object return - return FuncStructInfo::OpaqueFunc(ObjectStructInfo(), lhs->span); + return FuncStructInfo::OpaqueFunc(ObjectStructInfo(), purity, lhs->span); } } else { // no derivation function, only depends on ret StructInfo ret = this->VisitStructInfo(lhs->ret, rhs->ret); if (ret.same_as(lhs->ret)) return GetRef(lhs); - return FuncStructInfo::OpaqueFunc(ret, lhs->span); + return FuncStructInfo::OpaqueFunc(ret, purity, lhs->span); } } // rhs is opaque, lhs is not @@ -795,7 +804,7 @@ class StructInfoLCAFinder // unify ret value, note that rhs's ret is context free(because it is opaque) // so result of the unify is also context-free. StructInfo ret = this->VisitStructInfo(lhs->ret, rhs->ret); - return FuncStructInfo::OpaqueFunc(ret, lhs->span); + return FuncStructInfo::OpaqueFunc(ret, purity, lhs->span); } // Both lhs and rhs are not opaque @@ -825,9 +834,9 @@ class StructInfoLCAFinder } else { // fail to unify the params if (!params.defined()) { - return FuncStructInfo::OpaqueFunc(ret, lhs->span); + return FuncStructInfo::OpaqueFunc(ret, purity, lhs->span); } else { - return FuncStructInfo(params.value(), ret, lhs->span); + return FuncStructInfo(params.value(), ret, purity, lhs->span); } } } diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index aeae975bf53e..b37662af858b 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -56,6 +56,13 @@ * * The op or args fields of Call nodes * * Inside the fields of Tuple nodes * 13. Expr always has checked_type_ (with the exception of Op). + * 14. DataflowBlocks may not contain If nodes. + * 15. DataflowBlocks may not contain calls to impure functions or operators + * (only checked if check_struct_info is true). + * 16. If a function has is_pure set to true and the kForcePure attribute is not set, + * the body may not contain any impure call (only checked if check_struct_info is true). + * 17. If the kForcePure attribute is set for a function, + * that function's is_pure field must be true. */ #include #include @@ -220,6 +227,14 @@ class WellFormedChecker : public relax::ExprVisitor, } }); + // ensure the purity attributes are valid + if (op->GetAttr(relax::attr::kForcePure).value_or(Bool(false))->value && !op->is_pure) { + Malformed(Diagnostic::Error(op->span) + << "Function " << op << " has true for " << relax::attr::kForcePure + << " but false for is_pure; " << relax::attr::kForcePure + << " should be true only if is_pure is also true."); + } + // check all expr are well defined. 
for (Var param : op->params) { this->VisitVarDef(param); @@ -239,6 +254,18 @@ class WellFormedChecker : public relax::ExprVisitor, Malformed(Diagnostic::Error(op) << "Function must have defined ret_struct_info"); } + // if we are not forcing purity and the function is annotated as pure, it must not contain an + // impure call + if (check_struct_info_ && + !op->GetAttr(relax::attr::kForcePure).value_or(Bool(false))->value && op->is_pure && + ContainsImpureCall(op->body)) { + Malformed(Diagnostic::Error(op) + << "Function " << op << " is annotated as pure but contains an impure call; " + << "please set " << relax::attr::kForcePure << " to true " + << "or use a pure operator variant (e.g., call_pure_packed) " + << "if it is necessary to override this judgment."); + } + if (auto seq = op->body.as()) { this->VisitSeqExpr(seq); } else { @@ -279,9 +306,15 @@ class WellFormedChecker : public relax::ExprVisitor, } CheckStructInfo(op); + if (is_dataflow_ && check_struct_info_ && IsImpureCall(GetRef(op))) { + Malformed(Diagnostic::Error(op) << "There cannot be an impure call inside a dataflow block."); + } } void VisitExpr_(const IfNode* op) final { + if (is_dataflow_) { + Malformed(Diagnostic::Error(op) << "If nodes are not allowed to appear in dataflow blocks."); + } if (IsLeafOrTuple(op->cond)) { this->VisitExpr(op->cond); } else { @@ -346,6 +379,7 @@ class WellFormedChecker : public relax::ExprVisitor, } else { this->VisitExpr(binding->value); } + this->VisitVarDef(binding->var); if (is_lambda) { recur_vars_.erase(binding->var); diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index f4b272979bb6..694bcd40d6e1 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -277,7 +277,7 @@ class VMShapeLowerMutator auto new_body = builder_->Normalize(SeqExpr(blocks, body_seq->body)); // create a new function - return Function(func->params, new_body, func->ret_struct_info, func->attrs); + return Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs); } //------------------------------------------------------- diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 5f9ce63c97dc..6d5448f49924 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -572,7 +572,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorbody)) { return GetRef(op); } else { - return Function(op->params, new_body, op->ret_struct_info, op->attrs); + return Function(op->params, new_body, op->ret_struct_info, op->is_pure, op->attrs); } } diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 5392be7cb69b..7cd356e0cae3 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -417,7 +417,7 @@ TVM_REGISTER_GLOBAL("relax.SeqExpr") TVM_REGISTER_NODE_TYPE(FunctionNode); -Function::Function(Array params, Expr body, Optional ret_struct_info, +Function::Function(Array params, Expr body, Optional ret_struct_info, bool is_pure, DictAttrs attrs, Span span) { // Set the function type. 
// For function, we take a conservative approach and require the function type @@ -449,13 +449,14 @@ Function::Function(Array params, Expr body, Optional ret_struct ret_struct_info = body_sinfo; } - FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value()); + FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value(), is_pure); // set the fields ObjectPtr n = make_object(); n->params = std::move(params); n->body = std::move(body); n->ret_struct_info = std::move(ret_struct_info.value()); + n->is_pure = is_pure; n->checked_type_ = GetStaticType(func_sinfo); n->struct_info_ = std::move(func_sinfo); n->attrs = std::move(attrs); @@ -465,23 +466,26 @@ Function::Function(Array params, Expr body, Optional ret_struct TVM_REGISTER_GLOBAL("relax.Function") .set_body_typed([](Array params, Expr body, Optional ret_struct_info, - DictAttrs attrs, - Span span) { return Function(params, body, ret_struct_info, attrs, span); }); + bool is_pure, DictAttrs attrs, Span span) { + return Function(params, body, ret_struct_info, is_pure, attrs, span); + }); -Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, DictAttrs attrs, - Span span) { +Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, bool is_pure, + DictAttrs attrs, Span span) { Array param_sinfo; for (const Var& param : params) { ICHECK(param->checked_type_.defined()) << "relax.Function requires params to contain checked_type_."; param_sinfo.push_back(GetStructInfo(param)); } - FuncStructInfo finfo(param_sinfo, ret_struct_info); + + FuncStructInfo finfo(param_sinfo, ret_struct_info, is_pure); // set the fields ObjectPtr n = make_object(); n->params = std::move(params); n->body = Expr(); + n->is_pure = is_pure; n->checked_type_ = GetStaticType(finfo); n->struct_info_ = std::move(finfo); n->ret_struct_info = std::move(ret_struct_info); @@ -491,8 +495,9 @@ Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, Di } TVM_REGISTER_GLOBAL("relax.FunctionCreateEmpty") - .set_body_typed([](Array params, StructInfo ret_struct_info, DictAttrs attrs, Span span) { - return Function::CreateEmpty(params, ret_struct_info, attrs, span); + .set_body_typed([](Array params, StructInfo ret_struct_info, bool is_pure, DictAttrs attrs, + Span span) { + return Function::CreateEmpty(params, ret_struct_info, is_pure, attrs, span); }); // Special opaque derivation function for ExternFunc diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 3f0fc86a2a37..cb74400d7a19 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -410,7 +410,7 @@ Expr ExprMutatorBase::VisitExpr_(const FunctionNode* op) { if (body.same_as(op->body)) { return GetRef(op); } else { - return Function(op->params, body, op->ret_struct_info, op->attrs); + return Function(op->params, body, op->ret_struct_info, op->is_pure, op->attrs); } } @@ -589,7 +589,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { if (all_params_unchanged && body.same_as(op->body)) { return GetRef(op); } else { - return Function(params, body, op->ret_struct_info, op->attrs); + return Function(params, body, op->ret_struct_info, op->is_pure, op->attrs); } } diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index 4004ad28d560..c290711dcdad 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -145,25 +145,29 @@ TVM_REGISTER_GLOBAL("relax.TupleStructInfo") }); // Func -FuncStructInfo::FuncStructInfo(Array params, StructInfo ret, Span span) { 
+FuncStructInfo::FuncStructInfo(Array params, StructInfo ret, bool purity, Span span) { ObjectPtr n = make_object(); n->params = std::move(params); n->ret = std::move(ret); + n->purity = std::move(purity); n->span = span; data_ = std::move(n); } -FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfoDeriveFunc derive_func, Span span) { +FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity, + Span span) { ObjectPtr n = make_object(); n->derive_func = std::move(derive_func); n->ret = ObjectStructInfo(); + n->purity = std::move(purity); n->span = span; return FuncStructInfo(n); } -FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, Span span) { +FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, bool purity, Span span) { ObjectPtr n = make_object(); n->ret = std::move(ret); + n->purity = std::move(purity); n->span = span; return FuncStructInfo(n); } @@ -171,18 +175,18 @@ FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, Span span) { TVM_REGISTER_NODE_TYPE(FuncStructInfoNode); TVM_REGISTER_GLOBAL("relax.FuncStructInfo") - .set_body_typed([](Array params, StructInfo ret, Span span) { - return FuncStructInfo(params, ret, span); + .set_body_typed([](Array params, StructInfo ret, bool purity, Span span) { + return FuncStructInfo(params, ret, purity, span); }); TVM_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc") .set_body_typed([](Optional ret, Optional derive_func, - Span span) { + bool purity, Span span) { if (derive_func.defined()) { ICHECK(!ret.defined()) << "ValueError: Cannot specify both ret and derive_func"; - return FuncStructInfo::OpaqueFunc(derive_func.value(), span); + return FuncStructInfo::OpaqueFunc(derive_func.value(), purity, span); } else { - return FuncStructInfo::OpaqueFunc(ret.value_or(ObjectStructInfo()), span); + return FuncStructInfo::OpaqueFunc(ret.value_or(ObjectStructInfo()), purity, span); } }); diff --git a/src/relax/ir/struct_info_functor.cc b/src/relax/ir/struct_info_functor.cc index 199491e3c63f..d7b9bef3dd3c 100644 --- a/src/relax/ir/struct_info_functor.cc +++ b/src/relax/ir/struct_info_functor.cc @@ -122,7 +122,7 @@ StructInfo StructInfoMutator::VisitStructInfo_(const FuncStructInfoNode* op) { return GetRef(op); } else { ICHECK(ret.defined()) << "FuncStructInfo that contains params must contain ret"; - return FuncStructInfo(params.value(), ret, op->span); + return FuncStructInfo(params.value(), ret, op->purity, op->span); } } diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index 6d49bea6b656..3c3bb151366e 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -122,7 +122,8 @@ TVM_REGISTER_OP("relax.image.resize2d") .add_argument("size", "Shape", "The output image shape.") .set_attr("FInferStructInfo", InferStructInfoResize2D) .set_attr("FRelaxInferLayout", InferLayoutResize2d) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index 56e5a04e123d..c83e49c70c57 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -121,7 +121,8 @@ TVM_REGISTER_OP("relax.nn.attention") .add_argument("value", "Tensor", "The input values tensor.") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionAttention) - .set_attr("FInferStructInfo", 
InferStructInfoAttention); + .set_attr("FInferStructInfo", InferStructInfoAttention) + .set_attr("FPurity", Bool(true)); TVM_REGISTER_OP("relax.nn.attention_bias") .set_attrs_type() @@ -132,7 +133,8 @@ TVM_REGISTER_OP("relax.nn.attention_bias") .add_argument("bias", "Tensor", "The input bias tensor.") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionAttention) - .set_attr("FInferStructInfo", InferStructInfoAttention); + .set_attr("FInferStructInfo", InferStructInfoAttention) + .set_attr("FPurity", Bool(true)); } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index ae84409c2a14..d698cf9757d3 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -182,7 +182,8 @@ TVM_REGISTER_OP("relax.nn.conv1d") .set_attr("FInferStructInfo", InferStructInfoConv1d) .set_attr("FRelaxInferLayout", InferLayoutConv1d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) - .set_attr("FInferMixedPrecision", InferMixedPrecisionConv1d); + .set_attr("FInferMixedPrecision", InferMixedPrecisionConv1d) + .set_attr("FPurity", Bool(true)); /* relax.nn.conv2d */ TVM_REGISTER_NODE_TYPE(Conv2DAttrs); @@ -348,7 +349,8 @@ TVM_REGISTER_OP("relax.nn.conv2d") .set_attr("FInferStructInfo", InferStructInfoConv2d) .set_attr("FRelaxInferLayout", InferLayoutConv2d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) - .set_attr("FInferMixedPrecision", InferMixedPrecisionConv2d); + .set_attr("FInferMixedPrecision", InferMixedPrecisionConv2d) + .set_attr("FPurity", Bool(true)); /* relax.nn.conv2d_transpose */ TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs); @@ -492,7 +494,8 @@ TVM_REGISTER_OP("relax.nn.conv2d_transpose") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoConv2dTranspose); + .set_attr("FInferStructInfo", InferStructInfoConv2dTranspose) + .set_attr("FPurity", Bool(true)); } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index ec2205d1b739..215c9ead8110 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -80,7 +80,8 @@ TVM_REGISTER_OP("relax.nn.softmax") .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoSoftmax) - .set_attr("FRelaxInferLayout", InferLayoutSoftmax); + .set_attr("FRelaxInferLayout", InferLayoutSoftmax) + .set_attr("FPurity", Bool(true)); /* relax.nn.log_softmax */ Expr log_softmax(Expr data, int axis) { @@ -96,7 +97,8 @@ TVM_REGISTER_OP("relax.nn.log_softmax") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoSoftmax); + .set_attr("FInferStructInfo", InferStructInfoSoftmax) + .set_attr("FPurity", Bool(true)); bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, const Array& input_sinfo, Array axes) { @@ -234,7 +236,8 @@ TVM_REGISTER_OP("relax.nn.batch_norm") .add_argument("moving_mean", "Tensor", "Running mean of input.") .add_argument("moving_var", "Tensor", "Running variance of input.") .set_attr("FInferStructInfo", InferStructInfoBatchNorm) - .set_attr("FRelaxInferLayout", InferLayoutBatchNorm); + .set_attr("FRelaxInferLayout", InferLayoutBatchNorm) + .set_attr("FPurity", Bool(true)); /* relax.nn.layer_norm */ TVM_REGISTER_NODE_TYPE(LayerNormAttrs); @@ 
-296,7 +299,8 @@ TVM_REGISTER_OP("relax.nn.layer_norm") .add_argument("beta", "Tensor", "The beta offset factor.") .set_attr("FInferStructInfo", InferStructInfoLayerNorm) .set_attr("FRelaxInferLayout", InferLayoutLayerNorm) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.nn.group_norm */ TVM_REGISTER_NODE_TYPE(GroupNormAttrs); @@ -407,7 +411,8 @@ TVM_REGISTER_OP("relax.nn.group_norm") .add_argument("beta", "Tensor", "The beta offset factor.") .set_attr("FInferStructInfo", InferStructInfoGroupNorm) .set_attr("FRelaxInferLayout", InferLayoutGroupNorm) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.nn.dropout */ TVM_REGISTER_NODE_TYPE(DropoutAttrs); @@ -433,7 +438,8 @@ TVM_REGISTER_OP("relax.nn.dropout") .add_argument("data", "Tensor", "Input to which dropout will be applied.") .set_attr("FInferStructInfo", InferStructInfoDropout) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.nn.cross_entropy_with_logits */ StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx) { @@ -491,7 +497,8 @@ TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits") .set_num_inputs(2) .add_argument("predictions", "Tensor", "The predictions.") .add_argument("labels", "Tensor", "The labels.") - .set_attr("FInferStructInfo", InferStructInfoCrossEntropy); + .set_attr("FInferStructInfo", InferStructInfoCrossEntropy) + .set_attr("FPurity", Bool(true)); /* relax.nn.nll_loss */ TVM_REGISTER_NODE_TYPE(NLLLossAttrs); @@ -716,7 +723,8 @@ TVM_REGISTER_OP("relax.nn.nll_loss") .add_argument("predictions", "Tensor", "The prediction tensor.") .add_argument("targets", "Tensor", "The target tensor.") .add_argument("weights", "Optional", "The weight of each target values.") - .set_attr("FInferStructInfo", InferStructInfoNLLLoss); + .set_attr("FInferStructInfo", InferStructInfoNLLLoss) + .set_attr("FPurity", Bool(true)); } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index c31ce3dd0ba6..bfbb4b4284de 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -140,7 +140,8 @@ TVM_REGISTER_OP("relax.nn.max_pool2d") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPool2D) .set_attr("FRelaxInferLayout", InferLayoutPool2d) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); Expr avg_pool2d(Expr data, Array pool_size, Array strides, Array padding, Array dilation, bool ceil_mode, String layout, @@ -157,7 +158,8 @@ TVM_REGISTER_OP("relax.nn.avg_pool2d") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPool2D) .set_attr("FRelaxInferLayout", InferLayoutPool2d) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.nn.adaptive_avg_pool2d */ TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs); @@ -240,7 +242,8 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") .add_argument("data", "Tensor", "The 
input tensor") .set_attr("FInferStructInfo", InferStructInfoAdaptiveAvgPool2D) .set_attr("FRelaxInferLayout", InferLayoutAdaptiveAvgPool2D) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); } // namespace relax } // namespace tvm diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index f2106f155023..f1fb5c52bd86 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -73,6 +73,53 @@ StructInfo InferStructInfoShapeOf(const Call& call, const BlockBuilder& ctx) { return ShapeStructInfo(tensor_shape->values); } +// call_pure_packed + +StructInfo InferStructInfoCallPurePacked(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() < 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "call_pure_packed must be called with at least one argument"); + } + + // the callee must be an opaque function + auto callee = call->args[0]; + ICHECK(!callee.as()) << "call_pure_packed cannot be used with an op node"; + auto opt = MatchStructInfo(callee); + ICHECK(opt) << "Callee must have a function struct info"; + FuncStructInfo finfo = opt.value(); + ICHECK(finfo->IsOpaque()) << "call_pure_packed must be called with an opaque function, but " + << callee << " is not opaque"; + + // same logic as from DeriveCallRetStructInfo for ordinary calls + if (finfo->derive_func.defined()) { + // derive using custom derivation function. + return finfo->derive_func.value()(call, ctx); + } else { + // directly return the normal value. + return finfo->ret; + } +} + +RELAY_REGISTER_OP("relax.call_pure_packed") + .set_num_inputs(-1) + .add_argument("args", "Array", + "The first argument is the function being called. The rest are the " + "arguments to that function.") + .set_attr("FInferStructInfo", InferStructInfoCallPurePacked) + .set_attr("FPurity", Bool(true)); + +Expr MakeCallPurePacked(const Expr& callee, Array args, const Attrs& attrs, + Array sinfo_args) { + static const Op& op = Op::Get("relax.call_pure_packed"); + Array call_args = {callee}; + for (auto arg : args) { + call_args.push_back(arg); + } + return Call(op, call_args, attrs, sinfo_args); +} + +TVM_REGISTER_GLOBAL("relax.op.call_pure_packed").set_body_typed(MakeCallPurePacked); + // call_tir StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { @@ -93,7 +140,8 @@ RELAY_REGISTER_OP("relax.call_tir") .add_argument("packed_ints", "Expr", "ShapeExpr representing a tuple of ints to unpack during runtime. 
Omitted from " "args if unused") - .set_attr("FInferStructInfo", InferStructInfoCallTIR); + .set_attr("FInferStructInfo", InferStructInfoCallTIR) + .set_attr("FPurity", Bool(true)); Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, Optional packed_ints) { @@ -138,7 +186,10 @@ RELAY_REGISTER_OP("relax.call_dps_packed") .set_num_inputs(2) .add_argument("func", "Expr", "The destination-passing-style function.") .add_argument("args", "Tuple", "The input arguments.") - .set_attr("FInferStructInfo", InferStructInfoCallDPSPacked); + .set_attr("FInferStructInfo", InferStructInfoCallDPSPacked) + // technically, an impure op could be used with this, but there is + // little reason to use DPS with an impure op + .set_attr("FPurity", Bool(true)); Expr MakeCallDPSPacked(Expr func, Tuple args, Array out_sinfo_list) { for (const TensorStructInfo& sinfo : out_sinfo_list) { @@ -177,7 +228,9 @@ TVM_REGISTER_OP("relax.call_builtin_with_ctx") .set_num_inputs(4) .add_argument("func", "Expr", "The builtin packed func.") .add_argument("args", "Tuple", "The input arguments.") - .set_attr("FInferStructInfo", InferStructInfoCallBuiltinWithCtx); + .set_attr("FInferStructInfo", InferStructInfoCallBuiltinWithCtx) + // Most builtins are pure, but some are not, like `vm.builtin.attention_kv_cache_append` + .set_attr("FPurity", Bool(false)); Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, Array sinfo_args) { static const Op& op = Op::Get("relax.call_builtin_with_ctx"); @@ -188,7 +241,8 @@ TVM_REGISTER_GLOBAL("relax.op.call_builtin_with_ctx").set_body_typed(MakeCallBui TVM_REGISTER_OP("relax.null_value") .set_num_inputs(0) - .set_attr("FInferStructInfo", ReturnObjectStructInfo); + .set_attr("FInferStructInfo", ReturnObjectStructInfo) + .set_attr("FPurity", Bool(true)); Expr MakeCallNullValue() { static const Op& op = Op::Get("relax.null_value"); @@ -205,7 +259,8 @@ RELAY_REGISTER_OP("relax.print") "The first value is Python-style format string to use to print. The others " "are values to print") .set_attr("FInferStructInfo", ReturnVoidStructInfo) - .set_attr("FCallPacked", "relax.run.print"); + .set_attr("FCallPacked", "relax.run.print") + .set_attr("FPurity", Bool(false)); Expr MakePrint(Array vals, StringImm format) { Array params; @@ -247,7 +302,8 @@ RELAY_REGISTER_OP("relax.assert_op") "Python-style format string to use for displaying an error message, if the " "assert fails. The others are used as format arguments if there is an error.") .set_attr("FInferStructInfo", InferAssertStructInfo) - .set_attr("FCallPacked", "relax.run.assert_op"); + .set_attr("FCallPacked", "relax.run.assert_op") + .set_attr("FPurity", Bool(false)); Expr MakeAssertOp(Expr condition, Array vals, StringImm format) { static const Op& op = Op::Get("relax.assert_op"); @@ -267,7 +323,8 @@ RELAY_REGISTER_OP("relax.make_closure") .set_num_inputs(2) .add_argument("func", "Expr", "The closure.") .add_argument("args", "Tuple", "The captured variables.") - .set_attr("FInferStructInfo", ReturnObjectStructInfo); + .set_attr("FInferStructInfo", ReturnObjectStructInfo) + .set_attr("FPurity", Bool(true)); Expr MakeClosure(Expr func, Tuple args) { static const Op& op = Op::Get("relax.make_closure"); @@ -292,7 +349,9 @@ RELAY_REGISTER_OP("relax.invoke_closure") .set_num_inputs(2) .add_argument("closure", "Expr", "The VMClosure.") .add_argument("args", "Tuple", "The captured variables.") - .set_attr("FInferStructInfo", InferStructInfoInvokeClosure); + .set_attr("FInferStructInfo", InferStructInfoInvokeClosure) + // Not all closures are pure. 
Use invoke_pure_closure for specifying purity + .set_attr("FPurity", Bool(false)); Expr InvokeClosure(Expr closure, Tuple args, Array sinfo_args) { static const Op& op = Op::Get("relax.invoke_closure"); @@ -301,12 +360,29 @@ Expr InvokeClosure(Expr closure, Tuple args, Array sinfo_args) { TVM_REGISTER_GLOBAL("relax.op.invoke_closure").set_body_typed(InvokeClosure); +// invoke_pure_closure + +RELAY_REGISTER_OP("relax.invoke_pure_closure") + .set_num_inputs(2) + .add_argument("closure", "Expr", "The VMClosure.") + .add_argument("args", "Tuple", "The captured variables.") + .set_attr("FInferStructInfo", InferStructInfoInvokeClosure) + .set_attr("FPurity", Bool(true)); + +Expr InvokePureClosure(Expr closure, Tuple args, Array sinfo_args) { + static const Op& op = Op::Get("relax.invoke_pure_closure"); + return Call(op, {closure, args}, {}, sinfo_args); +} + +TVM_REGISTER_GLOBAL("relax.op.invoke_pure_closure").set_body_typed(InvokePureClosure); + // shape_of RELAY_REGISTER_OP("relax.shape_of") .set_num_inputs(1) .add_argument("input", "Expr", "The input expression") - .set_attr("FInferStructInfo", InferStructInfoShapeOf); + .set_attr("FInferStructInfo", InferStructInfoShapeOf) + .set_attr("FPurity", Bool(true)); Expr MakeShapeOf(Expr expr) { static const Op& op = Op::Get("relax.shape_of"); @@ -332,7 +408,8 @@ StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& c RELAY_REGISTER_OP("relax.tensor_to_shape") .set_num_inputs(1) .add_argument("input", "Expr", "The input expression") - .set_attr("FInferStructInfo", ReturnTensorToShapeStructInfo); + .set_attr("FInferStructInfo", ReturnTensorToShapeStructInfo) + .set_attr("FPurity", Bool(true)); Expr MakeTensorToShape(Expr expr) { static const Op& op = Op::Get("relax.tensor_to_shape"); @@ -355,7 +432,8 @@ RELAY_REGISTER_OP("relax.shape_to_tensor") .set_num_inputs(1) .add_argument("input", "Expr", "The input expression") .set_attr("FInferStructInfo", ReturnShapeToTensorStructInfo) - .set_attr("FCallPacked", "relax.run.shape_to_tensor"); + .set_attr("FCallPacked", "relax.run.shape_to_tensor") + .set_attr("FPurity", Bool(true)); Expr MakeShapeToTensor(Expr expr) { static const Op& op = Op::Get("relax.shape_to_tensor"); @@ -386,7 +464,9 @@ RELAY_REGISTER_OP("relax.builtin.alloc_tensor") .add_argument("runtime_device_index", "PrimValue", "The device index indicating on which device the tensor is to be " "allocated at runtime. Index -1 is reserved for the host device.") - .set_attr("FInferStructInfo", InferStructInfoAllocateTensor); + .set_attr("FInferStructInfo", InferStructInfoAllocateTensor) + // memory allocation isn't considered a "visible effect" as far as purity is concerned + .set_attr("FPurity", Bool(true)); Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_index) { static const Op& op = Op::Get("relax.builtin.alloc_tensor"); @@ -407,7 +487,9 @@ RELAY_REGISTER_OP("relax.memory.alloc_storage") .add_argument("storage_scope", "StringImm", "The storage scope of the storage to allocate. 
Default is global.") .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") - .set_attr("FInferStructInfo", ReturnObjectStructInfo); + .set_attr("FInferStructInfo", ReturnObjectStructInfo) + // memory allocation isn't considered a "visible effect" as far as purity is concerned + .set_attr("FPurity", Bool(true)); Expr MakeAllocStorage(Expr size, PrimValue virtual_device_index, StringImm storage_scope, DataTypeImm dtype) { @@ -436,7 +518,9 @@ RELAY_REGISTER_OP("relax.memory.alloc_tensor") .add_argument("offset", "PrimValue", "Storage offset to allocate the tensor.") .add_argument("shape", "Expr", "The shape of the tensor to allocate.") .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") - .set_attr("FInferStructInfo", InferStructInfoMemAllocTensor); + .set_attr("FInferStructInfo", InferStructInfoMemAllocTensor) + // memory allocation isn't considered a "visible effect" as far as purity is concerned + .set_attr("FPurity", Bool(true)); Expr MakeMemAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm dtype) { static const Op& op = Op::Get("relax.memory.alloc_tensor"); @@ -450,7 +534,9 @@ TVM_REGISTER_GLOBAL("relax.op.memory.alloc_tensor").set_body_typed(MakeMemAllocT RELAY_REGISTER_OP("relax.memory.kill_storage") .set_num_inputs(1) .add_argument("storage", "Expr", "The storage to be killed.") - .set_attr("FInferStructInfo", ReturnVoidStructInfo); + .set_attr("FInferStructInfo", ReturnVoidStructInfo) + // deallocation also isn't considered a "visible effect" as far as purity is concerned + .set_attr("FPurity", Bool(true)); Expr MakeMemKillStorage(Expr storage) { static const Op& op = Op::Get("relax.memory.kill_storage"); @@ -464,7 +550,9 @@ TVM_REGISTER_GLOBAL("relax.op.memory.kill_storage").set_body_typed(MakeMemKillSt RELAY_REGISTER_OP("relax.memory.kill_tensor") .set_num_inputs(1) .add_argument("tensor", "Expr", "The tensor to be killed.") - .set_attr("FInferStructInfo", ReturnVoidStructInfo); + .set_attr("FInferStructInfo", ReturnVoidStructInfo) + // memory deallocation also isn't considered a "visible effect" as far as purity is concerned + .set_attr("FPurity", Bool(true)); Expr MakeMemKillTensor(Expr tensor) { static const Op& op = Op::Get("relax.memory.kill_tensor"); @@ -482,7 +570,9 @@ RELAY_REGISTER_OP("relax.vm.alloc_storage") .add_argument("runtime_device_index", "PrimValue", "The device index indicating on which device the tensor is " "to be allocated at runtime.") - .set_attr("FInferStructInfo", ReturnObjectStructInfo); + .set_attr("FInferStructInfo", ReturnObjectStructInfo) + // memory allocation isn't considered a "visible effect" as far as purity is concerned + .set_attr("FPurity", Bool(true)); Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm dtype) { static const Op& op = Op::Get("relax.vm.alloc_storage"); @@ -517,7 +607,9 @@ RELAY_REGISTER_OP("relax.vm.alloc_tensor") .add_argument("offset", "PrimValue", "Storage offset to allocate the tensor.") .add_argument("shape", "Expr", "The shape of the tensor to allocate.") .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") - .set_attr("FInferStructInfo", InferStructInfoVMAllocTensor); + .set_attr("FInferStructInfo", InferStructInfoVMAllocTensor) + // memory allocation isn't considered a "visible effect" as far as purity is concerned + .set_attr("FPurity", Bool(true)); Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm dtype) { static const Op& op = Op::Get("relax.vm.alloc_tensor"); @@ 
-531,7 +623,9 @@ TVM_REGISTER_GLOBAL("relax.op.vm.alloc_tensor").set_body_typed(MakeVMAllocTensor TVM_REGISTER_OP("relax.vm.kill_object") .set_num_inputs(1) .add_argument("obj", "Expr", "The object to be killed.") - .set_attr("FInferStructInfo", ReturnVoidStructInfo); + .set_attr("FInferStructInfo", ReturnVoidStructInfo) + // deallocation also isn't considered a "visible effect" as far as purity is concerned + .set_attr("FPurity", Bool(true)); Expr MakeVMKillObject(Expr obj) { static const Op& op = Op::Get("relax.vm.kill_object"); @@ -547,7 +641,8 @@ RELAY_REGISTER_OP("relax.vm.call_tir_dyn") .add_argument("func", "Expr", "The destination-passing-style function.") .add_argument("args", "Tuple", "The input arguments (list of tensors and last argument is ShapeExpr)") - .set_attr("FInferStructInfo", ReturnVoidStructInfo); + .set_attr("FInferStructInfo", ReturnVoidStructInfo) + .set_attr("FPurity", Bool(true)); Expr MakeCallTIRDyn(Expr func, Tuple args) { static const Op& op = Op::Get("relax.vm.call_tir_dyn"); @@ -564,7 +659,8 @@ StructInfo InferStructInfoStopLiftParams(const Call& call, const BlockBuilder& c RELAY_REGISTER_OP("relax.builtin.stop_lift_params") .set_num_inputs(1) .add_argument("x", "Expr", "The input data") - .set_attr("FInferStructInfo", InferStructInfoStopLiftParams); + .set_attr("FInferStructInfo", InferStructInfoStopLiftParams) + .set_attr("FPurity", Bool(true)); Expr MakeStopLiftParams(Expr x) { static const Op& op = Op::Get("relax.builtin.stop_lift_params"); diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index f7cff638cd98..a6b437111b46 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -82,12 +82,13 @@ Array GetTensorStructInfoFromTuple(const Call& call, const Blo * \param OpRegName The name of operator to register. The name passed in will * be prepended with a prefix "relax." as the identifier string in the operator registry. */ -#define RELAX_REGISTER_UNARY_OP(OpRegName) \ - TVM_REGISTER_OP("relax." OpRegName) \ - .set_num_inputs(1) \ - .add_argument("x", "Tensor", "The input tensor.") \ - .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) \ - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) +#define RELAX_REGISTER_UNARY_OP(OpRegName) \ + TVM_REGISTER_OP("relax." OpRegName) \ + .set_num_inputs(1) \ + .add_argument("x", "Tensor", "The input tensor.") \ + .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) \ + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) \ + .set_attr("FPurity", Bool(true)) /*! * \brief Quick helper macro to expose a make-function to construct the operator. diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index e386f9019fd4..06f3944d8543 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -37,18 +37,19 @@ namespace relax { * 1. be prepended with a prefix "relax.op." as the FFI identifier string for the make function, * 2. be prepended with a prefix "relax." as the identifier string in the operator registry. */ -#define RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName) \ - Expr OpName(Expr x1, Expr x2) { \ - static const Op& op = Op::Get("relax." #OpName); \ - return Call(op, {x1, x2}, Attrs(), {}); \ - } \ - TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ - TVM_REGISTER_OP("relax." 
#OpName) \ - .set_num_inputs(2) \ - .add_argument("x1", "Tensor", "The first input tensor.") \ - .add_argument("x2", "Tensor", "The second input tensor.") \ - .set_attr("FRelaxInferLayout", InferLayoutBinaryEwise) \ - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) +#define RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName) \ + Expr OpName(Expr x1, Expr x2) { \ + static const Op& op = Op::Get("relax." #OpName); \ + return Call(op, {x1, x2}, Attrs(), {}); \ + } \ + TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_REGISTER_OP("relax." #OpName) \ + .set_num_inputs(2) \ + .add_argument("x1", "Tensor", "The first input tensor.") \ + .add_argument("x2", "Tensor", "The second input tensor.") \ + .set_attr("FRelaxInferLayout", InferLayoutBinaryEwise) \ + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) \ + .set_attr("FPurity", Bool(true)) #define RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(OpName) \ RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName).set_attr( \ diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index 053ca28a6c8d..dabf3155f0f8 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -86,7 +86,8 @@ TVM_REGISTER_OP("relax.full") .add_argument("shape", "Shape", "The shape of the created tensor.") .add_argument("fill_value", "Tensor", "The scalar tensor, denoting the value to fill.") .set_attr("FInferStructInfo", InferStructInfoFull) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.full_like */ Expr full_like(Expr x, Expr fill_value, DataType dtype) { @@ -124,7 +125,8 @@ TVM_REGISTER_OP("relax.full_like") .add_argument("x", "Tensor", "The input tensor.") .add_argument("fill_value", "Tensor", "The scalar value to fill.") .set_attr("FInferStructInfo", InferStructInfoFullLike) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); // Structure info inference for ones and zeros StructInfo InferStructInfoOnesZeros(const Call& call, const BlockBuilder& ctx) { @@ -181,13 +183,15 @@ TVM_REGISTER_OP("relax.ones") .set_num_inputs(1) .add_argument("shape", "Shape", "The shape of the created tensor.") .set_attr("FInferStructInfo", InferStructInfoOnesZeros) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); TVM_REGISTER_OP("relax.ones_like") .set_attrs_type() .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike); + .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike) + .set_attr("FPurity", Bool(true)); /* relax.zeros & relax.zeros_like */ Expr zeros(Expr shape, DataType dtype) { @@ -214,13 +218,15 @@ TVM_REGISTER_OP("relax.zeros") .set_num_inputs(1) .add_argument("shape", "Shape", "The shape of the created tensor.") .set_attr("FInferStructInfo", InferStructInfoOnesZeros) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); TVM_REGISTER_OP("relax.zeros_like") .set_attrs_type() .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", 
InferStructInfoOnesLikeZerosLike); + .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike) + .set_attr("FPurity", Bool(true)); /* relax.arange */ Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) { @@ -271,7 +277,8 @@ TVM_REGISTER_OP("relax.arange") .add_argument("end", "PrimValue", "The ending value for the set of points.") .add_argument("step", "PrimValue", "The gap between each pair of adjacent points.") .set_attr("FInferStructInfo", InferStructInfoArange) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.tril & relax.triu */ TVM_REGISTER_NODE_TYPE(TriluAttrs); @@ -310,13 +317,15 @@ TVM_REGISTER_OP("relax.tril") .set_attrs_type() .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoTrilTriu); + .set_attr("FInferStructInfo", InferStructInfoTrilTriu) + .set_attr("FPurity", Bool(true)); TVM_REGISTER_OP("relax.triu") .set_attrs_type() .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoTrilTriu); + .set_attr("FInferStructInfo", InferStructInfoTrilTriu) + .set_attr("FPurity", Bool(true)); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc index 18747fedcda0..bc24285cf9c7 100644 --- a/src/relax/op/tensor/datatype.cc +++ b/src/relax/op/tensor/datatype.cc @@ -56,7 +56,8 @@ TVM_REGISTER_OP("relax.astype") .add_argument("x", "Tensor", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoAstype) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.wrap_param */ TVM_REGISTER_NODE_TYPE(WrapParamAttrs); @@ -83,7 +84,8 @@ TVM_REGISTER_OP("relax.wrap_param") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor") - .set_attr("FInferStructInfo", InferStructInfoWrapParam); + .set_attr("FInferStructInfo", InferStructInfoWrapParam) + .set_attr("FPurity", Bool(true)); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc index a3bddd951ba5..2fef2d09b9ec 100644 --- a/src/relax/op/tensor/grad.cc +++ b/src/relax/op/tensor/grad.cc @@ -43,7 +43,8 @@ StructInfo InferStructInfoNoGrad(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.grad.no_grad") .set_num_inputs(0) - .set_attr("FInferStructInfo", InferStructInfoNoGrad); + .set_attr("FInferStructInfo", InferStructInfoNoGrad) + .set_attr("FPurity", Bool(true)); /* relax.grad.nll_loss_backward */ Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, Optional weights, @@ -78,7 +79,8 @@ TVM_REGISTER_OP("relax.grad.nll_loss_backward") .add_argument("predictions", "Tensor", "The prediction tensor.") .add_argument("targets", "Tensor", "The target tensor.") .add_argument("weights", "Optional", "The weight of each target values.") - .set_attr("FInferStructInfo", InferStructInfoNLLLossBackward); + .set_attr("FInferStructInfo", InferStructInfoNLLLossBackward) + .set_attr("FPurity", Bool(true)); /* relax.grad.max_pool2d_backward */ Expr max_pool2d_backward(Expr output_grad, Expr data, Array pool_size, @@ -107,7 +109,8 @@ TVM_REGISTER_OP("relax.grad.max_pool2d_backward") 
.add_argument("output_grad", "Tensor", "The output gradient.") .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoMaxPool2DBackward); + .set_attr("FInferStructInfo", InferStructInfoMaxPool2DBackward) + .set_attr("FPurity", Bool(true)); /* relax.grad.avg_pool2d_backward */ Expr avg_pool2d_backward(Expr output_grad, Expr data, Array pool_size, @@ -136,7 +139,8 @@ TVM_REGISTER_OP("relax.grad.avg_pool2d_backward") .add_argument("output_grad", "Tensor", "The output gradient.") .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoAvgPool2DBackward); + .set_attr("FInferStructInfo", InferStructInfoAvgPool2DBackward) + .set_attr("FPurity", Bool(true)); /* relax.grad.take_backward */ TVM_REGISTER_NODE_TYPE(TakeAttrs); @@ -161,7 +165,8 @@ TVM_REGISTER_OP("relax.grad.take_backward") .add_argument("output_grad", "Tensor", "The output gradient.") .add_argument("x", "Tensor", "The source tensor.") .add_argument("indices", "Tensor", "The indices of the values to extract.") - .set_attr("FInferStructInfo", InferStructInfoTakeBackward); + .set_attr("FInferStructInfo", InferStructInfoTakeBackward) + .set_attr("FPurity", Bool(true)); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index c3d38db4e194..a9c61bb56a35 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -95,7 +95,8 @@ TVM_REGISTER_OP("relax.take") .set_num_inputs(2) .add_argument("x", "Tensor", "The source tensor.") .add_argument("indices", "Tensor", "The indices of the values to extract.") - .set_attr("FInferStructInfo", InferStructInfoTake); + .set_attr("FInferStructInfo", InferStructInfoTake) + .set_attr("FPurity", Bool(true)); /* relax.strided_slice */ TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); @@ -237,7 +238,8 @@ TVM_REGISTER_OP("relax.strided_slice") .add_argument("x", "Tensor", "The source tensor to be sliced.") .set_attr("FInferStructInfo", InferStructInfoStridedSlice) .set_attr("FRelaxInferLayout", InferLayoutStridedSlice) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.dynamic_strided_slice */ Expr dynamic_strided_slice(Expr x, // @@ -311,7 +313,8 @@ TVM_REGISTER_OP("relax.dynamic_strided_slice") .add_argument("begin", "Tensor", "The indices to begin with in the slicing.") .add_argument("end", "Tensor", "Indices indicating end of the slice.") .add_argument("strides", "Tensor", "The stride values.") - .set_attr("FInferStructInfo", InferStructInfoDynStridedSlice); + .set_attr("FInferStructInfo", InferStructInfoDynStridedSlice) + .set_attr("FPurity", Bool(true)); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc index 5f47e366c43f..b05fbaa5d3a9 100644 --- a/src/relax/op/tensor/linear_algebra.cc +++ b/src/relax/op/tensor/linear_algebra.cc @@ -126,7 +126,8 @@ TVM_REGISTER_OP("relax.matmul") .add_argument("x2", "Tensor", "The second input tensor.") .set_attr("FInferStructInfo", InferStructInfoMatmul) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) - .set_attr("FInferMixedPrecision", InferMixedPrecisionMatmul); + .set_attr("FInferMixedPrecision", InferMixedPrecisionMatmul) + .set_attr("FPurity", Bool(true)); /* relax.einsum */ TVM_REGISTER_NODE_TYPE(EinsumAttrs); @@ -188,7 +189,8 @@ 
TVM_REGISTER_OP("relax.einsum") .set_attrs_type() .set_num_inputs(1) .add_argument("operands", "Tensor", "The input tensors.") - .set_attr("FInferStructInfo", InferStructInfoEinsum); + .set_attr("FInferStructInfo", InferStructInfoEinsum) + .set_attr("FPurity", Bool(true)); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index d66388c34979..5b298110be55 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -108,7 +108,8 @@ TVM_REGISTER_OP("relax.broadcast_to") .add_argument("x", "Tensor", "The input tensor.") .add_argument("shape", "Shape", "The target shape.") .set_attr("FInferStructInfo", InferStructInfoBroadcastTo) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.concat */ TVM_REGISTER_NODE_TYPE(ConcatAttrs); @@ -278,7 +279,8 @@ TVM_REGISTER_OP("relax.concat") .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") .set_attr("FInferStructInfo", InferStructInfoConcat) .set_attr("FRelaxInferLayout", InferLayoutConcat) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.expand_dims */ TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs); @@ -375,7 +377,8 @@ TVM_REGISTER_OP("relax.expand_dims") .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoExpandDims) .set_attr("FRelaxInferLayout", InferLayoutExpandDims) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); // Helper function for flatten and reshape. 
PrimExpr ComputeShapeProduct(const Array& shape_values) { @@ -416,7 +419,8 @@ TVM_REGISTER_OP("relax.flatten") .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoFlatten) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.layout_transform */ TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs); @@ -479,7 +483,8 @@ TVM_REGISTER_OP("relax.layout_transform") .set_attrs_type() .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoLayoutTransform) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.permute_dims */ TVM_REGISTER_NODE_TYPE(PermuteDimsAttrs); @@ -591,7 +596,8 @@ TVM_REGISTER_OP("relax.permute_dims") .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoPermuteDims) .set_attr("FRelaxInferLayout", InferLayoutPermuteDims) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.reshape */ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { @@ -739,7 +745,8 @@ TVM_REGISTER_OP("relax.reshape") .add_argument("x", "Tensor", "The input tensor.") .add_argument("shape", "Shape", "The input new shape.") .set_attr("FInferStructInfo", InferStructInfoReshape) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.split */ TVM_REGISTER_NODE_TYPE(SplitAttrs); @@ -873,7 +880,8 @@ TVM_REGISTER_OP("relax.split") .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoSplit) .set_attr("FRelaxInferLayout", InferLayoutSplit) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.squeeze */ TVM_REGISTER_NODE_TYPE(SqueezeAttrs); @@ -1029,7 +1037,8 @@ TVM_REGISTER_OP("relax.squeeze") .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoSqueeze) .set_attr("FRelaxInferLayout", InferLayoutSqueeze) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); void CheckCollapseShape(const Call& call, const BlockBuilder& ctx, const Array& data_shape, const Array& target_shape) { @@ -1110,7 +1119,8 @@ TVM_REGISTER_OP("relax.collapse_sum_like") .add_argument("data", "Tensor", "The input tensor.") .add_argument("collapse_target", "Tensor", "The tensor whose shape is the shape to collapse to.") - .set_attr("FInferStructInfo", InferStructInfoCollapseSumLike); + .set_attr("FInferStructInfo", InferStructInfoCollapseSumLike) + .set_attr("FPurity", Bool(true)); /* relax.collapse_sum_to */ Expr collapse_sum_to(Expr data, Expr shape) { @@ -1159,7 +1169,8 @@ TVM_REGISTER_OP("relax.collapse_sum_to") .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") .add_argument("shape", "Shape", "The shape to collapse to.") - .set_attr("FInferStructInfo", InferStructInfoCollapseSumTo); + 
.set_attr("FInferStructInfo", InferStructInfoCollapseSumTo) + .set_attr("FPurity", Bool(true)); /* relax.repeat */ TVM_REGISTER_NODE_TYPE(RepeatAttrs); @@ -1223,7 +1234,8 @@ TVM_REGISTER_OP("relax.repeat") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoRepeat); + .set_attr("FInferStructInfo", InferStructInfoRepeat) + .set_attr("FPurity", Bool(true)); /* relax.tile */ TVM_REGISTER_NODE_TYPE(TileAttrs); @@ -1285,7 +1297,8 @@ TVM_REGISTER_OP("relax.tile") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoTile); + .set_attr("FInferStructInfo", InferStructInfoTile) + .set_attr("FPurity", Bool(true)); /* relax.flip */ TVM_REGISTER_NODE_TYPE(FlipAttrs); @@ -1321,7 +1334,8 @@ TVM_REGISTER_OP("relax.flip") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoFlip); + .set_attr("FInferStructInfo", InferStructInfoFlip) + .set_attr("FPurity", Bool(true)); /* relax.scatter_elements */ TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs); @@ -1435,7 +1449,8 @@ TVM_REGISTER_OP("relax.scatter_elements") .add_argument("data", "Tensor", "The input tensor.") .add_argument("indices", "Tensor", "The indices tensor.") .add_argument("updates", "Tensor", "The input tensor of updates.") - .set_attr("FInferStructInfo", InferStructInfoScatterElements); + .set_attr("FInferStructInfo", InferStructInfoScatterElements) + .set_attr("FPurity", Bool(true)); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index 71f37c743ff2..e1d684916cd2 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -93,7 +93,8 @@ TVM_REGISTER_OP("relax.where") .add_argument("condition", "Tensor", "When True, yield `x1`; otherwise, yield `x2`.") .add_argument("x1", "Tensor", "The first input tensor.") .add_argument("x2", "Tensor", "The second input tensor.") - .set_attr("FInferStructInfo", InferStructInfoWhere); + .set_attr("FInferStructInfo", InferStructInfoWhere) + .set_attr("FPurity", Bool(true)); /* relax.argmax & relax.argmin */ TVM_REGISTER_NODE_TYPE(ArgmaxArgminAttrs); @@ -155,19 +156,20 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx return TensorStructInfo(ShapeExpr(out_shape), out_dtype); } -#define RELAX_REGISTER_ARGMAX_ARGMIN_OP(OpName) \ - Expr OpName(Expr x, Optional axis, bool keepdims) { \ - ObjectPtr attrs = make_object(); \ - attrs->axis = std::move(axis); \ - attrs->keepdims = std::move(keepdims); \ - static const Op& op = Op::Get("relax." #OpName); \ - return Call(op, {std::move(x)}, Attrs(attrs)); \ - } \ - TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ - TVM_REGISTER_OP("relax." #OpName) \ - .set_num_inputs(1) \ - .add_argument("x", "Tensor", "The input data tensor") \ - .set_attr("FInferStructInfo", InferStructInfoArgmaxArgmin); +#define RELAX_REGISTER_ARGMAX_ARGMIN_OP(OpName) \ + Expr OpName(Expr x, Optional axis, bool keepdims) { \ + ObjectPtr attrs = make_object(); \ + attrs->axis = std::move(axis); \ + attrs->keepdims = std::move(keepdims); \ + static const Op& op = Op::Get("relax." #OpName); \ + return Call(op, {std::move(x)}, Attrs(attrs)); \ + } \ + TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_REGISTER_OP("relax." 
#OpName) \ + .set_num_inputs(1) \ + .add_argument("x", "Tensor", "The input data tensor") \ + .set_attr("FInferStructInfo", InferStructInfoArgmaxArgmin) \ + .set_attr("FPurity", Bool(true)); RELAX_REGISTER_ARGMAX_ARGMIN_OP(argmax); RELAX_REGISTER_ARGMAX_ARGMIN_OP(argmin); diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index 8df0813ed2b5..cb6a332d49eb 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -133,7 +133,8 @@ TVM_REGISTER_OP("relax.unique") "The dimension to apply unique. If it is NullOpt, the unique values of the flattened input " "are returned.") .set_attr("FInferStructInfo", InferStructInfoUnique) - .set_attr("FCallPacked", "relax.run.unique"); + .set_attr("FCallPacked", "relax.run.unique") + .set_attr("FPurity", Bool(true)); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index 6e67a2fdc28c..6d1cc86f0a5b 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -179,7 +179,8 @@ TVM_REGISTER_OP("relax.cumsum") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoCumsum); + .set_attr("FInferStructInfo", InferStructInfoCumsum) + .set_attr("FPurity", Bool(true)); TVM_REGISTER_NODE_TYPE(StatisticalAttrs); diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h index 29b7da5d6b70..23a6da99f142 100644 --- a/src/relax/op/tensor/statistical.h +++ b/src/relax/op/tensor/statistical.h @@ -55,7 +55,8 @@ namespace relax { .set_num_inputs(1) \ .add_argument("x", "Tensor", "The input data tensor") \ .set_attr("FInferStructInfo", InferStructInfoStatistical) \ - .set_attr("FRelaxInferLayout", InferLayoutStatistical) + .set_attr("FRelaxInferLayout", InferLayoutStatistical) \ + .set_attr("FPurity", Bool(true)) /*! * \brief Computes the maximum value of tensor elements over given axes. 
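For context, a minimal sketch (not itself part of the patch) of how the FPurity registrations above surface to users. It relies only on APIs exercised elsewhere in this diff — rx.op.print, the is_pure argument to rx.Function, rx.transform.Normalize, and rx.analysis.well_formed — plus rx.op.add, whose operator is registered pure through the RELAX_REGISTER_BINARY_OP_AND_IMPL change above; treat the exact Python bindings as assumptions.

import tvm
from tvm import relax as rx
from tvm.script import relax as R

x = rx.Var("x", R.Tensor((), dtype="int32"))
y = rx.Var("y")

# relax.add carries FPurity=true, so a function that is pure by default may call it
# and still pass the well-formedness check.
pure_block = rx.BindingBlock([rx.VarBinding(y, rx.op.add(x, x))])
pure_func = rx.Function([x], rx.SeqExpr([pure_block], y), R.Tensor((), dtype="int32")).with_attr(
    "global_symbol", "pure_func"
)
assert rx.analysis.well_formed(rx.transform.Normalize()(tvm.IRModule.from_expr(pure_func)))

# relax.print carries FPurity=false, so the enclosing function must be labeled impure
# (is_pure=False) for the module to remain well formed.
impure_block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))])
impure_func = rx.Function(
    [x], rx.SeqExpr([impure_block], x), R.Tensor((), dtype="int32"), is_pure=False
).with_attr("global_symbol", "impure_func")
assert rx.analysis.well_formed(rx.transform.Normalize()(tvm.IRModule.from_expr(impure_func)))

The relax.force_pure attribute covers the remaining case: a function that calls impure operators internally but is pure from the caller's point of view, as the test_force_pure case added to the well-formedness tests later in this diff exercises.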
diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index 940192bd8e45..d1ff5b78635b 100644 --- a/src/relax/op/tensor/ternary.cc +++ b/src/relax/op/tensor/ternary.cc @@ -112,7 +112,8 @@ TVM_REGISTER_OP("relax.ewise_fma") .add_argument("x3", "Tensor", "The operand of the addition") .set_attr("FInferStructInfo", InferStructInfoEwiseFMA) .set_attr("FRelaxInferLayout", InferLayoutEwiseFMA) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); Expr ewise_fma(Expr x1, Expr x2, Expr x3) { static const Op& op = Op::Get("relax.ewise_fma"); diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index 5d4a39067f58..6713c4e31af6 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -67,7 +67,8 @@ TVM_REGISTER_OP("relax.clip") .add_argument("x", "Tensor", "The input tensor.") .add_argument("min", "PrimValue", "The lower-bound of the range to be clipped to") .add_argument("max", "PrimValue", "The upper-bound of the range to be clipped to") - .set_attr("FInferStructInfo", ReturnStructInfoFromArg<0>); + .set_attr("FInferStructInfo", ReturnStructInfoFromArg<0>) + .set_attr("FPurity", Bool(true)); Expr clip(Expr x, Expr min, Expr max) { CHECK(min->IsInstance()) diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc index 0a88e4569fa8..37582e301550 100644 --- a/src/relax/training/utils.cc +++ b/src/relax/training/utils.cc @@ -78,7 +78,7 @@ class AppendLossMutator : private ExprMutator { loss_function_->params.end()); Expr new_body = this->VisitExpr(func->body); - return Function(new_params, new_body, NullOpt, func->attrs); + return Function(new_params, new_body, NullOpt, func->is_pure, func->attrs); } Expr VisitExpr_(const SeqExprNode* seq_expr) final { diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index b20f982efb01..95bbfbee7ca8 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -67,7 +67,7 @@ class ExternFunctionRewriter : ExprMutator { new_params.push_back(workspace_param); return Function(new_params, VisitExpr(func_node->body), func_node->ret_struct_info, - func_node->attrs); + func_node->is_pure, func_node->attrs); } return ExprMutator::VisitExpr_(func_node); } @@ -127,8 +127,8 @@ class WorkspaceProvider : ExprMutator { auto gvar = mod_->GetGlobalVar("main"); auto func = Downcast(mod_->Lookup(gvar)); - auto new_func = - Function(func->params, VisitExpr(func->body), func->ret_struct_info, func->attrs); + auto new_func = Function(func->params, VisitExpr(func->body), func->ret_struct_info, + func->is_pure, func->attrs); builder_->UpdateFunction(gvar, new_func); return builder_->GetContextIRModule(); } diff --git a/src/relax/transform/decompose_ops.cc b/src/relax/transform/decompose_ops.cc index 5c7bfaf96297..899c80c1c454 100644 --- a/src/relax/transform/decompose_ops.cc +++ b/src/relax/transform/decompose_ops.cc @@ -154,8 +154,10 @@ Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) { ICHECK(sinfo); // call builtin function that converts tensor to shape tuple // TODO(@sunggg): Register operator for "vm.builtin.tensor_to_shape" - Var call = builder->Emit( - Call(ExternFunc("vm.builtin.tensor_to_shape"), {expr}, {}, {GetRef(sinfo)})); + static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed"); + Var call = + builder->Emit(Call(call_pure_packed_op, 
{ExternFunc("vm.builtin.tensor_to_shape"), expr}, {}, + {GetRef(sinfo)})); // Operators like reshape take the output of `TensorToShape` as their output shape. // Because TOPI expects to have such output shape in symbolic shape at least (i.e., diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index 9c9252ddfa72..6c772d2e204e 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -105,7 +105,8 @@ class CommonSubexprEliminator : public ExprMutator { if (new_body.same_as(func->body)) { return GetRef(func); } - return Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span); + return Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs, + func->span); } // this should happen only for the inner function case diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index ad1dc3eb9814..8940768ced13 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -473,6 +473,7 @@ class FunctionCreator : public ExprMutator { Function function = Function(/*params=*/params_, // /*body=*/body, // /*ret_struct_info=*/NullOpt, // + /*is_pure=*/true, // /*attrs=*/DictAttrs(group_attrs)); Array free_vars = FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; }); @@ -482,6 +483,7 @@ class FunctionCreator : public ExprMutator { function = Function(/*params=*/params_, // /*body=*/body, // /*ret_struct_info=*/NullOpt, // + /*is_pure=*/true, // /*attrs=*/DictAttrs(group_attrs)); } function_ = SymbolicVarRenewMutator::Renew(function); @@ -1088,7 +1090,7 @@ class CompositeFunctionAnnotator : public ExprMutator { auto new_body = VisitExpr(func->body); if (!new_body.same_as(func->body)) { auto new_func = Function(func->params, VisitExpr(func->body), func->ret_struct_info, - func->attrs, func->span); + func->is_pure, func->attrs, func->span); builder_->UpdateFunction(entry.first, new_func); } } @@ -1131,7 +1133,9 @@ class CompositeFunctionAnnotator : public ExprMutator { params.push_back(new_v); } - return Function(param_vars, Call(f_inner, params), func_node->ret_struct_info); + // pure if the inner func is pure (no need to force purity if it's forced for the inner func) + return Function(param_vars, Call(f_inner, params), func_node->ret_struct_info, + Downcast(f_inner)->is_pure); } private: diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 432ddca0a751..d5dcd64cc726 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -835,6 +835,7 @@ class TIRFuseMutator : public ExprMutator { Expr VisitExpr_(const CallNode* op) final { static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + Call call = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(op))); if (call->op->IsInstance()) { diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index e7bdea603663..7645ae8cb6c6 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -349,7 +349,7 @@ class GradientMutator : private ExprMutator { Expr new_body = this->VisitExpr(func->body); - return Function(func->params, new_body, NullOpt, func->attrs); + return Function(func->params, new_body, NullOpt, func->is_pure, func->attrs); } Expr VisitExpr_(const SeqExprNode* seq_expr) final { diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc index 74920823100a..e3ed24cd9ed7 100644 --- 
a/src/relax/transform/lambda_lift.cc +++ b/src/relax/transform/lambda_lift.cc @@ -87,7 +87,20 @@ class LambdaLifter : public ExprMutator { if (this->var_remap_.find(var->vid) != this->var_remap_.end()) { clo_arg = this->var_remap_.at(var->vid); } - return Call(invoke_closure_op_, {clo_arg, Tuple(call_node->args)}, {}, + + // if the original op was pure, we should use invoke_pure_closure + Call orig_call = Downcast(val); + bool purity; + if (orig_call->op.as()) { + auto orig_op = Downcast(orig_call->op); + static const auto& purity_map = Op::GetAttrMap("FPurity"); + purity = purity_map.count(orig_op) && purity_map[orig_op]->value; + } else { + purity = GetStructInfoAs(orig_call->op)->purity; + } + + return Call(purity ? invoke_pure_closure_op_ : invoke_closure_op_, + {clo_arg, Tuple(call_node->args)}, {}, {GetStructInfo(GetRef(call_node))}); } auto it = lambda_map_.find(var); @@ -177,9 +190,11 @@ class LambdaLifter : public ExprMutator { if (all_params_unchanged && body.same_as(func_node->body)) { visited_func = GetRef(func_node); } else if (const auto& body_sinfo = MatchStructInfo(body)) { - visited_func = Function(params, body, body_sinfo.value(), func_node->attrs); + visited_func = + Function(params, body, body_sinfo.value(), func_node->is_pure, func_node->attrs); } else { - visited_func = Function(params, body, func_node->ret_struct_info, func_node->attrs); + visited_func = + Function(params, body, func_node->ret_struct_info, func_node->is_pure, func_node->attrs); } auto new_func = Downcast(visited_func); @@ -190,6 +205,7 @@ class LambdaLifter : public ExprMutator { /*params=*/new_func->params, /*body=*/new_func->body, /*ret_struct_info=*/new_func->ret_struct_info, + /*is_pure=*/new_func->is_pure, /*attrs=*/new_func->attrs, /*span=*/new_func->span); } else { @@ -206,6 +222,7 @@ class LambdaLifter : public ExprMutator { lifted_func = Function(/*params=*/closure_params, /*body=*/Bind(new_func->body, rebinding_map), /*ret_struct_info=*/new_func->ret_struct_info, + /*is_pure=*/new_func->is_pure, /*attrs=*/new_func->attrs, /*span=*/func->span); @@ -280,7 +297,8 @@ class LambdaLifter : public ExprMutator { for (auto pair : glob_funcs) { if (auto* n = pair.second.as()) { auto func = GetRef(n); - func = Function(func->params, VisitExpr(func->body), func->ret_struct_info, func->attrs); + func = Function(func->params, VisitExpr(func->body), func->ret_struct_info, func->is_pure, + func->attrs); builder_->UpdateFunction(pair.first, func); } } @@ -295,6 +313,7 @@ class LambdaLifter : public ExprMutator { /*! \brief Cache ops that would be used later to reduce lookup overhead. 
*/ const Op& make_closure_op_ = Op::Get("relax.make_closure"); const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); + const Op& invoke_pure_closure_op_ = Op::Get("relax.invoke_pure_closure"); }; namespace transform { diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 0953a8dacf0c..4469f3558593 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -75,9 +75,41 @@ class LegalizeMutator : public ExprMutator { private: using ExprMutator::VisitExpr_; + bool WrapPureCondition(const Op& op, const Expr& legalized) { + static const auto& purity_map = Op::GetAttrMap("FPurity"); + + // unlikely for this condition not to be met + if (const CallNode* call = legalized.as()) { + // if the original op is not pure, don't wrap + if (!(purity_map.count(op) && purity_map[op]->value)) { + return false; + } + if (const OpNode* call_op = call->op.as()) { + auto res_op = GetRef(call_op); + if (purity_map.count(res_op)) { + // if the legalized op is already pure, we *don't* need a wrapper + return !purity_map[res_op]->value; + } + } + // simplest case: wrap if the original op was pure and the result is somehow not + return true; + } + return false; + } + + Call WrapPureCall(const Call& ret) { + static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed"); + Array ret_args = {ret->op}; + for (auto arg : ret->args) { + ret_args.push_back(arg); + } + return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args); + } + Expr VisitExpr_(const CallNode* call) final { Call visited_call = Downcast(this->VisitExprPostOrder_(call)); static const auto& legalize_map = Op::GetAttrMap("FLegalize"); + static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed"); static const Op& call_tir_op = Op::Get("relax.call_tir"); static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); auto* op_node = visited_call->op.as(); @@ -103,15 +135,24 @@ class LegalizeMutator : public ExprMutator { // Priority: customize > default. // Check if it has customize legalization registered. if (cmap_.defined() && cmap_.value().count(op->name)) { - return cmap_.value()[op->name](this->builder_, visited_call); + auto ret = cmap_.value()[op->name](this->builder_, visited_call); + if (ret.IsObjectRef() && WrapPureCondition(op, ret.AsObjectRef())) { + return WrapPureCall(Downcast(ret.AsObjectRef())); + } + return ret; } // Check if it has default legalization registered. if (legalize_map.count(op)) { - return legalize_map[op](this->builder_, visited_call); + auto ret = legalize_map[op](this->builder_, visited_call); + if (WrapPureCondition(op, ret)) { + return WrapPureCall(Downcast(ret)); + } + return ret; } // No legalization. 
- if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op) { + if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op && + op != call_pure_packed_op) { LOG(WARNING) << "No legalization func for " << op->name << " is found."; } return visited_call; diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index cbeedb66944b..f7c9a4189dbb 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -277,7 +277,7 @@ class TransformParamsLifter : public ExprMutator { new_attrs = NullValue(); } - Function new_func(new_params, new_body, func->ret_struct_info, new_attrs); + Function new_func(new_params, new_body, func->ret_struct_info, func->is_pure, new_attrs); return new_func; } diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index f444d5c4f63f..81ee2ac7a124 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -288,8 +288,8 @@ class CompositeInliner : public ExprMutator { Function Run(Function func) { inlined_functions_ = Map(); auto new_body = VisitExpr(func->body); - auto new_func = - Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span); + auto new_func = Function(func->params, new_body, func->ret_struct_info, func->is_pure, + func->attrs, func->span); return new_func; } diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index 915498178f0f..fdd2ccc17e4f 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -47,7 +47,7 @@ class NormalizeMutator : public ExprMutatorBase { if (body.same_as(op->body)) { return GetRef(op); } else { - return Function(op->params, body, op->ret_struct_info, op->attrs); + return Function(op->params, body, op->ret_struct_info, op->is_pure, op->attrs); } } diff --git a/src/relax/transform/remove_purity_checking.cc b/src/relax/transform/remove_purity_checking.cc new file mode 100644 index 000000000000..a8719c9d90f6 --- /dev/null +++ b/src/relax/transform/remove_purity_checking.cc @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * \file src/relax/transform/remove_purity_checking.cc + * \brief Apply kForcePure in all pure functions and unwrap all calls to pure overrides + */ +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class PurityRemover : public ExprMutator { + public: + Function RemovePurity(Function func) { + bool purity = func->is_pure; + auto ret = func; + if (purity) { + ret = std::move(WithAttr(func, relax::attr::kForcePure, Bool(true))); + } + auto new_body = VisitExpr(ret->body); + if (!new_body.same_as(ret->body)) { + return Function(ret->params, new_body, ret->ret_struct_info, ret->is_pure, ret->attrs, + ret->span); + } + return ret; + } + + Expr VisitExpr_(const CallNode* call) override { + if (call->op == call_pure_packed_op_) { + auto ret = Call(call->args[0], Array(call->args.begin() + 1, call->args.end()), + call->attrs, call->sinfo_args); + return VisitExpr(ret); + } + if (call->op == invoke_pure_closure_op_) { + auto ret = Call(invoke_closure_op_, call->args, call->attrs, call->sinfo_args); + return VisitExpr(ret); + } + return ExprMutator::VisitExpr_(call); + } + + Expr VisitExpr_(const FunctionNode* func) override { + // handling inner functions: we will remove purity annotations from them too + return RemovePurity(GetRef(func)); + } + + private: + const Op& call_pure_packed_op_ = Op::Get("relax.call_pure_packed"); + const Op& invoke_pure_closure_op_ = Op::Get("relax.invoke_pure_closure"); + const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); +}; + +Function RemovePurityChecking(const Function& f) { return PurityRemover().RemovePurity(f); } + +namespace transform { + +Pass RemovePurityChecking() { + runtime::TypedPackedFunc pass_func = + [=](const Function& f, IRModule mod, PassContext pc) { + return relax::RemovePurityChecking(f); + }; + return CreateFunctionPass(pass_func, 0, "RemovePurityChecking", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.RemovePurityChecking").set_body_typed(RemovePurityChecking); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc index 42ec5fca9d08..2839060ce134 100644 --- a/src/relax/transform/rewrite_cuda_graph.cc +++ b/src/relax/transform/rewrite_cuda_graph.cc @@ -130,7 +130,10 @@ class FuncBuilder : public ExprMutator { auto output = builder_->Emit(Tuple(outputs)); auto block = builder_->EndBlock(); auto body = builder_->Normalize(SeqExpr({block}, output)); - auto func = Function(params, body, Downcast(output->struct_info_.value())); + Map attrs; + attrs.Set(relax::attr::kForcePure, Bool(true)); + auto func = Function(params, body, Downcast(output->struct_info_.value()), + /*is_pure=*/true, /*attrs=*/DictAttrs(attrs)); return func; } diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 05cf498a4b8a..e6aa450ff8e8 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -347,6 +347,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { void VisitExpr_(const CallNode* call) final { static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn"); + if (call->op == alloc_tensor_op) { // Create a storage token for builtin alloc_tensor. 
this->CreateToken(call); diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 4a348123ce27..489a36a5a413 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -276,7 +276,7 @@ class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator { return GetRef(op); } else { auto new_ret_sinfo = this->VisitExprDepStructInfoField(op->ret_struct_info); - return Function(params, body, new_ret_sinfo, op->attrs); + return Function(params, body, new_ret_sinfo, op->is_pure, op->attrs); } } diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 3b1364c6010b..b0816b0eda5c 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -19,6 +19,7 @@ #include "transform/utils.h" +#include #include namespace tvm { @@ -54,7 +55,9 @@ class ExprBinder : public ExprMutator { if (all_params_unchanged && body.same_as(op->body)) { return GetRef(op); } else { - return Function(params, body, VisitExprDepStructInfoField(op->ret_struct_info), op->attrs); + // purity won't be affected, no need to update annotation + return Function(params, body, VisitExprDepStructInfoField(op->ret_struct_info), op->is_pure, + op->attrs); } } @@ -111,6 +114,18 @@ bool IsLeafOrTuple(const Expr& expr) { expr.as() || expr.as(); } +bool IsImpureCall(const Call& call) { + if (auto op_ptr = call->op.as()) { + auto op = GetRef(op_ptr); + static auto purity_map = Op::GetAttrMap("FPurity"); + ICHECK(purity_map.count(op)) << "Cannot find the registered purity of this op: " << op->name; + return !(purity_map[op]->value); + } + // the StructInfo must be FuncStructInfo + auto func_struct_info = GetStructInfoAs(call->op); + return !func_struct_info->purity; +} + /*! * \brief Copy a new Relax function with new remapped vars and symbolic vars. * To get the var mapping from old vars to new vars, see FuncCopier in src/relax/transform/utils.h. diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index c78b9e73c534..00bbd2a551a6 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -61,6 +61,7 @@ void FunctionFrameNode::ExitWithScope() { tvm::relax::Function func(/*params=*/params, /*body=*/body, /*ret_struct_info=*/ret_struct_info, + /*is_pure=*/is_pure.value_or(Bool(true))->value, /*attrs=*/dict_attrs); // Step 2: Update IRModule. 
if (builder->frames.empty()) { diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 71a0651de859..5c39bedd4379 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -87,6 +87,15 @@ void FuncAttrs(Map attrs) { frame->attrs = attrs; } +void FuncIsPure(bool purity) { + FunctionFrame frame = FindFunctionFrame("R.is_pure"); + if (frame->is_pure.defined()) { + LOG(FATAL) << "ValueError: Duplicate function purity annotations, previous one is:\n" + << frame->is_pure.value(); + } + frame->is_pure = Bool(purity); +} + void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo) { FunctionFrame frame = FindFunctionFrame("R.func_ret_struct_info"); if (frame->ret_struct_info.defined()) { @@ -123,6 +132,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function) TVM_REGISTER_GLOBAL("script.ir_builder.relax.Arg").set_body_typed(Arg); TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncName").set_body_typed(FuncName); TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncAttrs").set_body_typed(FuncAttrs); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncIsPure").set_body_typed(FuncIsPure); TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetStructInfo").set_body_typed(FuncRetStructInfo); TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetValue").set_body_typed(FuncRetValue); diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index c32ab8be2f0e..9bf9e50ee857 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -132,6 +132,44 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& return Relax(d, "call_tir")->Call(args, kwargs_keys, kwargs_values); } +Optional PrintAssertOp(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) { + static const Op& assert_op = Op::Get("relax.assert_op"); + if (!n->op.same_as(assert_op)) { + return NullOpt; + } + ICHECK(n->args.size() >= 2); + // special handling: it is important to indicate that the format string (second argument) + // is the _format_ string, or else roundtripping will fail + // (the format string will be interpreted as an argument and there will be a new default format + // string given) + Array args; + args.push_back(d->AsDoc(n->args[0], n_p->Attr("args")->ArrayIndex(0))); + ExprDoc second_arg = d->AsDoc(n->args[1], n_p->Attr("args")->ArrayIndex(1)); + for (size_t i = 2; i < n->args.size(); i++) { + args.push_back(d->AsDoc(n->args[i], n_p->Attr("args")->ArrayIndex(i))); + } + return Relax(d, "assert_op")->Call(args, {"format"}, {second_arg}); +} + +Optional PrintRelaxPrint(const relax::Call& n, const ObjectPath& n_p, + const IRDocsifier& d) { + static const Op& print_op = Op::Get("relax.print"); + if (!n->op.same_as(print_op)) { + return NullOpt; + } + ICHECK(n->args.size() >= 1); + // special handling: it is important to indicate that the format string (first argument) + // is the _format_ string, or else roundtripping will fail + // (the format string will be interpreted as an argument and there will be a new default format + // string given) + ExprDoc first_arg = d->AsDoc(n->args[0], n_p->Attr("args")->ArrayIndex(0)); + Array args; + for (size_t i = 1; i < n->args.size(); i++) { + args.push_back(d->AsDoc(n->args[i], n_p->Attr("args")->ArrayIndex(i))); + } + return Relax(d, "print")->Call(args, {"format"}, {first_arg}); +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::Call n, ObjectPath n_p, IRDocsifier d) -> Doc { @@ -139,6 +177,14 @@ 
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (Optional doc = PrintCallTIRDPSPacked(n, n_p, d)) { return doc.value(); } + // Special case: assert_op + if (Optional doc = PrintAssertOp(n, n_p, d)) { + return doc.value(); + } + // Special case: print + if (Optional doc = PrintRelaxPrint(n, n_p, d)) { + return doc.value(); + } ExprDoc prefix{nullptr}; Array args; Array kwargs_keys; diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index fd7bdddfcaf5..95169712d9a0 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -56,7 +56,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprStmtDoc(Relax(d, "func_attr") // ->Call({d->AsDoc(n->attrs, n_p->Attr("attrs"))}))); } - // Step 5. Print body + // Step 5. Print purity attributes (only include if it's impure) + if (!n->is_pure) { + (*f)->stmts.push_back(ExprStmtDoc(Relax(d, "is_impure")->Call({}))); + } + // Step 6. Print body Array body = PrintSeqExpr(Downcast(n->body), n_p->Attr("body"), d, /*use_ret=*/true); (*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end()); diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc index c541619ec887..49162bb8242b 100644 --- a/src/script/printer/relax/struct_info.cc +++ b/src/script/printer/relax/struct_info.cc @@ -145,8 +145,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) params_doc.push_back(d->AsDoc(params[i], params_p->ArrayIndex(i))); } return Relax(d, "Callable") - ->Call({TupleDoc(params_doc), // - d->AsDoc(n->ret, n_p->Attr("ret"))}); + ->Call({TupleDoc(params_doc), // + d->AsDoc(n->ret, n_p->Attr("ret")), // + LiteralDoc::Boolean(n->purity, n_p->Attr("purity"))}); }); TVM_SCRIPT_REPR(relax::ObjectStructInfoNode, ReprPrintRelax); diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index 97acb79c3d24..88fc7491c2d4 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -21,6 +21,7 @@ #include #include +#include #include #include diff --git a/tests/python/relax/test_analysis_contains_impure_call.py b/tests/python/relax/test_analysis_contains_impure_call.py new file mode 100644 index 000000000000..bc7d663517eb --- /dev/null +++ b/tests/python/relax/test_analysis_contains_impure_call.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +import tvm.testing +from tvm import relax as rx +from tvm.relax.analysis import contains_impure_call +from tvm.script import relax as R + + +def test_simple_pure_case(): + @tvm.script.ir_module + class PureTest: + @R.function + def pure_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y = R.add(x, x) + z = R.multiply(x, y) + return R.add(z, R.const(1, "int32")) + + assert not contains_impure_call(PureTest["pure_func"]) + + +def test_simple_impure_case(): + @tvm.script.ir_module + class ImpureTest: + @R.function + def impure_func() -> R.Object: + R.is_impure() + y = R.print(format="I am a message") + return y + + assert contains_impure_call(ImpureTest["impure_func"]) + + +def test_nested_function(): + @tvm.script.ir_module + class NestedTest: + @R.function + def pure_with_impure_nested() -> R.Tensor((), "int32"): + # unused + @R.function + def impure_inner() -> R.Object: + R.is_impure() + y = R.print(format="Another, worse, message") + return y + + x = R.const(0, dtype="int32") + return R.add(x, x) + + assert not contains_impure_call(NestedTest["pure_with_impure_nested"]) + assert contains_impure_call( + NestedTest["pure_with_impure_nested"].body.blocks[0].bindings[0].value + ) + + +def test_ignoring_recursive_call(): + # Ignoring a recursive call. This can be useful if some transformation + # removes an impure operation and the compiler needs to check if the impure + # function has become pure + @tvm.script.ir_module + class RecursiveTest: + @R.function + def recursive_impure() -> R.Object: + R.is_impure() + x = R.const(1, "int32") + y = R.add(x, x) + z = R.print(x, y, format="{} {}") + w = RecursiveTest.recursive_impure() + return w + + assert contains_impure_call(RecursiveTest["recursive_impure"]) + # but if we remove the impure call... + body = RecursiveTest["recursive_impure"].body + own_name = body.blocks[0].bindings[-1].value.op + # skipping the call to print... + new_bindings = [ + body.blocks[0].bindings[0], + body.blocks[0].bindings[1], + body.blocks[0].bindings[-1], + ] + # Note: we construct the function in this way so that we keep the old vars + # with their current StructInfo. That would get fixed during normalization. + # However, this situation is meant to correspond to an intermediate state + # that might arise within a pass. 
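+ # (Normalization here means, e.g., rerunning rx.transform.Normalize() over the enclosing module, as the well-formedness tests below do.)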
+ new_body = rx.SeqExpr([rx.BindingBlock(new_bindings)], body.body) + + # if we didn't ignore the recursive call, the fact the var's StructInfo + # calls it impure would throw it off + assert not contains_impure_call(new_body, own_name=own_name) + assert contains_impure_call(new_body) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py index 85136d803bdb..d279b60b541c 100644 --- a/tests/python/relax/test_analysis_struct_info_analysis.py +++ b/tests/python/relax/test_analysis_struct_info_analysis.py @@ -458,7 +458,9 @@ def func_shape_mixed(c): func_shape_mixed(3), [ rx.ShapeStructInfo([10, 20]), - rx.FuncStructInfo.opaque_func(ret=rx.ShapeStructInfo(ndim=2)), + # have to specify purity because an impure function cannot be passed + # where a pure one is expected + rx.FuncStructInfo.opaque_func(ret=rx.ShapeStructInfo(ndim=2), purity=True), ], rx.ShapeStructInfo([30, 3]), ) diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 97f076dc6ce1..4c815b9bb4ea 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -547,5 +547,77 @@ def local(x: R.Tensor(["m", "n"], "float32")): assert rx.analysis.well_formed(mod) +def test_conditional_in_dataflow_block(): + # error: not allowed to have a conditional inside a dataflow block + x = rx.Var("x", rx.TensorStructInfo([], dtype="int32")) + y = rx.Var("y", rx.TensorStructInfo([], dtype="int32")) + block = rx.DataflowBlock([rx.VarBinding(y, rx.If(rx.const(True, dtype="bool"), x, x))]) + func = rx.Function([x], rx.SeqExpr([block], y), R.Tensor((), dtype="int32")).with_attr( + "global_symbol", "foo" + ) + mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) + assert not rx.analysis.well_formed(mod) + + +def test_unlabeled_impure(): + x = rx.Var("x", R.Tensor((), dtype="int32")) + y = rx.Var("y") + block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))]) + # print is impure, but the function is not labeled as impure + func = rx.Function([x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32")).with_attr( + "global_symbol", "foo" + ) + mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) + assert not rx.analysis.well_formed(mod) + + +def test_labeled_impure(): + # the function is labeled impure so the impure operation is permitted + x = rx.Var("x", R.Tensor((), dtype="int32")) + y = rx.Var("y") + block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))]) + # print is impure, but the function is not labeled as impure + func = rx.Function( + [x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32"), is_pure=False + ).with_attrs({"global_symbol": "foo"}) + mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) + assert rx.analysis.well_formed(mod) + + +def test_force_pure(): + x = rx.Var("x", R.Tensor((), dtype="int32")) + y = rx.Var("y") + block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))]) + # print is impure, but force_pure overrides the judgment + func = rx.Function([x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32")).with_attrs( + {"global_symbol": "foo", "relax.force_pure": True} + ) + mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) + assert rx.analysis.well_formed(mod) + + +def test_force_pure_improper(): + # we require both the is_pure and force_pure flags to be set together + x = rx.Var("x", R.Tensor((), 
dtype="int32")) + # otherwise inoffensive, but the flags are wrong + func = rx.Function( + [x], rx.SeqExpr([], x), R.Tensor((), dtype="int32"), is_pure=False + ).with_attrs({"global_symbol": "foo", "relax.force_pure": True}) + mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) + assert not rx.analysis.well_formed(mod) + + +def test_impure_in_dataflow_block(): + # even if force_pure is set, an impure operation cannot appear in a dataflow block + x = rx.Var("x", R.Tensor((), dtype="int32")) + y = rx.DataflowVar("y") + block = rx.DataflowBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))]) + func = rx.Function([x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32")).with_attrs( + {"global_symbol": "foo", "relax.force_pure": True} + ) + mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) + assert not rx.analysis.well_formed(mod) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 84b8cb1d0930..e0ddab5c67bc 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -225,6 +225,7 @@ def test_func(): assert "params=" in func_str assert "body=" in func_str assert "ret_struct_info=" in func_str + assert "is_pure=" in func_str assert "attrs=" in func_str assert '"global_symbol": "func"' in func_str assert "SeqExpr(" in func_str @@ -350,7 +351,7 @@ def test_struct_info(): simple_func = rx.FuncStructInfo([], rx.ObjectStructInfo()) assert ( strip_whitespace(printer.visit_struct_info_(simple_func)) - == "FuncStructInfo(params=[],ret=ObjectStructInfo())" + == "FuncStructInfo(params=[],ret=ObjectStructInfo(),purity=True)" ) @@ -362,6 +363,7 @@ def f( y: R.Tensor(("m",), "float32"), r: R.Tensor(dtype="int64"), ) -> R.Object: + R.is_impure() m = T.int64() z: R.Tensor((32, m), "float32") = R.multiply(x, y) w: R.Tensor = R.multiply(z, z) @@ -385,6 +387,8 @@ def f( # the function has an annotated return type assert "ret_struct_info=ObjectStructInfo()" in f_str + # the purity attribute is set to false + assert "is_pure=False" assert isinstance(f.body, rx.SeqExpr) extern_call = f.body.blocks[0].bindings[-1].value diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py index 4b194f154238..50b69a3c35b2 100644 --- a/tests/python/relax/test_backend_transform_shape_lower.py +++ b/tests/python/relax/test_backend_transform_shape_lower.py @@ -23,6 +23,8 @@ from tvm.script import relax as R from tvm.script import tir as T +# note: we expected RemovePurityChecking to be run first, so we force purity in most test cases + def test_const_shape_arg(): MS = MatchShapeCode @@ -31,6 +33,7 @@ def test_const_shape_arg(): class Before: @R.function def main(x: R.Shape([1, 2]), y: R.Shape): + R.func_attr({"relax.force_pure": True}) return x @T.prim_func @@ -42,6 +45,7 @@ def extra_func(H: T.Buffer(T.int64(4), "int64")): class Expected: @R.function def main(x: R.Shape([1, 2]), y: R.Shape): + R.func_attr({"relax.force_pure": True}) shape_heap = R.null_value() _ = R.call_packed("vm.builtin.check_shape_info", x, 2, "", sinfo_args=[R.Tuple()]) _ = R.call_packed("vm.builtin.check_shape_info", y, -1, "", sinfo_args=[R.Tuple()]) @@ -77,12 +81,14 @@ def test_static_fn_check(): class Before: @R.function def main(f: R.Callable([R.Object], R.Object), y: R.Shape([1, 2])): + R.func_attr({"relax.force_pure": True}) return y @tvm.script.ir_module class Expected: @R.function def main(f: R.Callable([R.Object], R.Object), 
y: R.Shape([1, 2])): + R.func_attr({"relax.force_pure": True}) shape_heap = R.null_value() _ = R.call_packed("vm.builtin.check_func_info", f, "", sinfo_args=[R.Tuple()]) _ = R.call_packed("vm.builtin.check_shape_info", y, 2, "", sinfo_args=[R.Tuple()]) @@ -113,6 +119,7 @@ def test_simple_symbolic_shape(): class Before: @R.function def main(x: R.Tensor(["n", 2, "m"], "float32")): + R.func_attr({"relax.force_pure": True}) return x sindex = { @@ -124,13 +131,19 @@ def main(x: R.Tensor(["n", 2, "m"], "float32")): class Expected: @R.function def main(x: R.Tensor(["n", 2, "m"], "float32")): + R.func_attr({"relax.force_pure": True}) shape_heap = R.call_builtin_with_ctx( "vm.builtin.alloc_shape_heap", [R.prim_value(2)], sinfo_args=[R.Tensor(ndim=1, dtype="int64")], ) _ = R.call_packed( - "vm.builtin.check_tensor_info", x, 3, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + "vm.builtin.check_tensor_info", + x, + 3, + R.dtype("float32"), + "", + sinfo_args=[R.Tuple()], ) _ = R.call_packed( "vm.builtin.match_shape", @@ -164,6 +177,7 @@ class Before: def main( x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None) ) -> R.Shape(ndim=3): + R.func_attr({"relax.force_pure": True}) m = T.int64() k = T.int64() z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None)) @@ -185,6 +199,7 @@ def shape_func(H: T.Buffer(T.int64(4), "int64")): def main( x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None) ) -> R.Shape(ndim=3): + R.func_attr({"relax.force_pure": True}) m = T.int64() k = T.int64() cls = Expected @@ -194,7 +209,12 @@ def main( sinfo_args=[R.Tensor(ndim=1, dtype="int64")], ) _ = R.call_packed( - "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + "vm.builtin.check_tensor_info", + x, + 2, + R.dtype("float32"), + "", + sinfo_args=[R.Tuple()], ) _ = R.call_packed( "vm.builtin.check_tensor_info", y, 3, R.dtype(""), "", sinfo_args=[R.Tuple()] @@ -274,6 +294,7 @@ def main( R.Tensor(["n", "m"], "float32"), R.Tuple(R.Shape, R.Tensor(["n", "k"], "int32")) ) ): + R.func_attr({"relax.force_pure": True}) return x # slot assignment: @@ -287,6 +308,7 @@ def main( R.Tensor(["n", "m"], "float32"), R.Tuple(R.Shape, R.Tensor(["n", "k"], "int32")) ) ): + R.func_attr({"relax.force_pure": True}) shape_heap = R.call_builtin_with_ctx( "vm.builtin.alloc_shape_heap", [R.prim_value(3)], @@ -359,6 +381,7 @@ class Before: def main( x: R.Tensor(["n", "m"], "float32"), y: R.Object ) -> R.Tuple(R.Tensor(["n", "m"], "float32")): + R.func_attr({"relax.force_pure": True}) return y # slot assignment: @@ -373,6 +396,7 @@ class Expected: def main( x: R.Tensor(["n", "m"], "float32"), y: R.Object ) -> R.Tuple(R.Tensor(["n", "m"], "float32")): + R.func_attr({"relax.force_pure": True}) shape_heap = R.call_builtin_with_ctx( "vm.builtin.alloc_shape_heap", [R.prim_value(2)], diff --git a/tests/python/relax/test_pipeline.py b/tests/python/relax/test_pipeline.py index 2dac42b3346d..a9f0863214b4 100644 --- a/tests/python/relax/test_pipeline.py +++ b/tests/python/relax/test_pipeline.py @@ -57,12 +57,12 @@ def create_kv_cache(reserve_slots: R.Shape(["m"])): # just allocate minimum slot since it is only used to signal dtype m = T.int64() init_data = R.ones((1, 4), "float32") - kv_cache = R.call_packed( + kv_cache = R.call_pure_packed( "vm.builtin.attention_kv_cache_create", init_data, R.shape([m, 4]), 0, - sinfo_args=[R.Object], + sinfo_args=[R.Object()], ) return kv_cache @@ -73,6 +73,7 @@ def main( shape: R.Shape(["L", 4]), kv_cache: R.Object, ): + R.is_impure() L = T.int64() # 
computation of the current value curr_value = R.add(x, y) diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index 776abbce764d..e6c947e9ef09 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -60,6 +60,7 @@ def test_unique(): class PrintTest: @R.function def foo(x: R.Tensor((), "int32")): + R.is_impure() # results have to be bound, but we don't use them # TODO: We should allow calls whose results are not bound for side effects; # it would be easy syntactic sugar to add. @@ -91,32 +92,38 @@ def test_print(): class AssertOpTest: @R.function def passes(x: R.Tensor((), "int32")): + R.is_impure() p1 = R.assert_op(relax.const(True)) return x @R.function def pass_with_args(x: R.Tensor((), "int32")): + R.is_impure() p1 = R.assert_op(relax.const(True), x, format="You won't see me") return x @R.function def simple_fail(x: R.Tensor((), "int32")): + R.is_impure() p1 = R.assert_op(relax.const(False)) return x @R.function def fail_with_message(x: R.Tensor((), "int32")): + R.is_impure() p1 = R.assert_op(relax.const(False), format="I failed...") return x @R.function def fail_with_args(x: R.Tensor((), "int32")): + R.is_impure() # no format p1 = R.assert_op(relax.const(False), [x, x]) return x @R.function def fail_with_formatted_message(x: R.Tensor((), "int32")): + R.is_impure() p1 = R.assert_op(relax.const(False), x, format="Number: {}") return x @@ -231,5 +238,21 @@ def test_op_shape_to_tensor(): assert np.array_equal(outs.numpy(), np.array([3, 2])) +def test_op_call_pure_packed(): + @tvm.script.ir_module + class CallPureTest: + @R.function + def pure_copy(x: R.Tensor((3, 4), "float32")): + z = R.call_pure_packed( + "vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32")) + ) + return z + + np.random.seed(0) # to avoid flakiness + arr = np.random.rand(3, 4).astype("float32") + copy_found = run_cpu(CallPureTest, "pure_copy", tvm.nd.array(arr)) + assert (copy_found.numpy() == arr).all() + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 2dc06b4a9d51..2476f6e1f399 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -32,9 +32,21 @@ class TestToNonDataflow: def foo(x: R.Tensor(("m", "n"), "float32")): m, n = T.int64(), T.int64() with R.dataflow(): - lv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + lv0 = R.call_dps_packed( + "test.op.identity", + (x,), + R.Tensor( + (m, n), + dtype="float32", + ), + ) gv0 = R.call_dps_packed( - "test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32") + "test.op.identity", + (lv0,), + R.Tensor( + (m, n), + dtype="float32", + ), ) R.output(gv0) return gv0 @@ -81,6 +93,8 @@ def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): @R.function def foo(x: R.Tensor(("m", "n"), "float32")): + # we expect RemovePurityChecking to have been used before this point + R.func_attr({"relax.force_pure": True}) m, n = T.int64(), T.int64() gv0 = R.call_tir(TestCallTIRRewrite.exp, (x,), R.Tensor((m, n), dtype="float32")) return gv0 @@ -110,11 +124,130 @@ def foo(x: R.Tensor(("m", "n"), "float32")): assert s2.op.name_hint == "exp" +def test_transform_remove_purity_checking(): + @tvm.script.ir_module + class Before: + @R.function + def base(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y = R.add(x, x) + z = R.add(x, y) + return z + + @R.function + def use_call_pure_packed(x: 
R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y = R.add(x, x) + z = R.call_pure_packed("vm.builtin.copy", y, sinfo_args=(R.Tensor((), dtype="int32"))) + return z + + @R.function + def use_invoke_pure_closure(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + closure = R.make_closure(Before.base, ()) + res = R.invoke_pure_closure(closure, (x,), sinfo_args=R.Tensor((), "int32")) + return res + + @R.function + def impure_func() -> R.Object: + R.is_impure() + y = R.print(format="I am impure!") + return y + + @R.function + def nested_pure_func() -> R.Tensor((), "int32"): + @R.function + def nested(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y = R.add(x, x) + q = R.call_pure_packed( + "vm.builtin.copy", y, sinfo_args=(R.Tensor((), dtype="int32")) + ) + return q + + z = R.const(1, dtype="int32") + w = nested(z) + return w + + @R.function + def nested_impure_func() -> R.Tensor((), "int32"): + R.is_impure() + + @R.function + def nested() -> R.Object: + R.is_impure() + x = R.print(format="Oops!") + return x + + y = R.const(1, dtype="int32") + z = nested() + return y + + @tvm.script.ir_module + class Expected: + @R.function + def base(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.func_attr({"relax.force_pure": True}) + y = R.add(x, x) + z = R.add(x, y) + return z + + @R.function + def use_call_pure_packed(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.func_attr({"relax.force_pure": True}) + y = R.add(x, x) + z = R.call_packed("vm.builtin.copy", y, sinfo_args=(R.Tensor((), dtype="int32"))) + return z + + @R.function + def use_invoke_pure_closure(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.func_attr({"relax.force_pure": True}) + closure = R.make_closure(Expected.base, ()) + res = R.invoke_closure(closure, (x,), sinfo_args=R.Tensor((), "int32")) + return res + + @R.function + def impure_func() -> R.Object: + R.is_impure() + y = R.print(format="I am impure!") + return y + + @R.function + def nested_pure_func() -> R.Tensor((), "int32"): + R.func_attr({"relax.force_pure": True}) + + @R.function + def nested(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.func_attr({"relax.force_pure": True}) + y = R.add(x, x) + q = R.call_packed("vm.builtin.copy", y, sinfo_args=(R.Tensor((), dtype="int32"))) + return q + + z = R.const(1, dtype="int32") + w = nested(z) + return w + + @R.function + def nested_impure_func() -> R.Tensor((), "int32"): + R.is_impure() + + @R.function + def nested() -> R.Object: + R.is_impure() + x = R.print(format="Oops!") + return x + + y = R.const(1, dtype="int32") + z = nested() + return y + + new_mod = relax.transform.RemovePurityChecking()(Before) + tvm.ir.assert_structural_equal(new_mod, Expected) + + def test_call_dps_packed_rewrite(): @tvm.script.ir_module class TestCallDPSPackedRewrite: @R.function def foo(x: R.Tensor(("m", "n"), "float32")): + # we expect RemovePurityChecking to have been used before this point + R.func_attr({"relax.force_pure": True}) m, n = T.int64(), T.int64() gv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) return gv0 @@ -148,6 +281,8 @@ def test_vm_builtin_lower(): class TestVMBuiltinLower: @R.function def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: + # we expected RemovePurityChecking to have been called first + R.func_attr({"relax.force_pure": True}) m, n = T.int64(), T.int64() alloc = R.builtin.alloc_tensor(R.shape([m, n]), runtime_device_index=0, dtype="float32") _ = R.call_packed( diff --git a/tests/python/relax/test_transform_decompose_ops.py 
b/tests/python/relax/test_transform_decompose_ops.py index dea133f9291d..85657ab245ea 100644 --- a/tests/python/relax/test_transform_decompose_ops.py +++ b/tests/python/relax/test_transform_decompose_ops.py @@ -376,7 +376,7 @@ def main(t: R.Tensor(dtype="int64", ndim=1)) -> R.Shape(ndim=3): x = T.int64() x_1 = T.int64() x_2 = T.int64() - gv: R.Shape(ndim=3) = R.call_packed( + gv: R.Shape(ndim=3) = R.call_pure_packed( "vm.builtin.tensor_to_shape", t, sinfo_args=(R.Shape(ndim=3),) ) y: R.Shape([x, x_1, x_2]) = R.match_cast(gv, R.Shape([x, x_1, x_2])) diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 1a4af26bd8ee..169539b07243 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -1343,7 +1343,7 @@ def main(s: R.Shape(["n"]), kv_cache: R.Object): lv0 = R.emit_te(topi.full, [n, n], "float32", 0) lv1 = R.emit_te(topi.trilu, lv0, tvm.tir.const(1, "int32"), upper=True) lv2 = R.emit_te(topi.broadcast_to, lv1, [1, 1, n, n]) - gv = R.call_packed( + gv = R.call_pure_packed( "vm.builtin.attention_kv_cache_view", kv_cache, R.shape([1 + n, 32, 128]), @@ -1375,7 +1375,7 @@ def main(s: R.Shape(["n"]), kv_cache: R.Object): lv: R.Tensor([1, 1, n, n], "float32") = cls.fused_full_trilu_broadcast_to( R.shape([n]) ) - gv = R.call_packed( + gv = R.call_pure_packed( "vm.builtin.attention_kv_cache_view", kv_cache, R.shape([1 + n, 32, 128]), @@ -1393,13 +1393,13 @@ class Module: @R.function def main(inp: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): with R.dataflow(): - lv = R.call_packed( + lv = R.call_pure_packed( "my_func1", inp, R.prim_value(0), sinfo_args=[R.Tensor((2, 2), dtype="float32")] ) - lv1 = R.call_packed( + lv1 = R.call_pure_packed( "my_func2", lv, R.str("str"), sinfo_args=[R.Tensor((2, 2), dtype="float32")] ) - gv = R.call_packed( + gv = R.call_pure_packed( "my_func3", lv1, R.dtype("float32"), diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index c7aa7984be88..aabbd544bd7d 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -694,7 +694,7 @@ def main(x: R.Tensor((2, 3), "float32")): R.output(y) return y - # FuseTIR should does no change to it. + # FuseTIR should do no change to it. 
_check(Module, Module) diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py index 017a673e8fcf..98f35a4b98ec 100644 --- a/tests/python/relax/test_transform_lambda_lift.py +++ b/tests/python/relax/test_transform_lambda_lift.py @@ -92,7 +92,9 @@ def main( ) -> R.Tensor((2, 3), "float32"): outer_func = Expected.lifted_func_0 in_call = outer_func(x) - res = R.invoke_closure(in_call, (y,), sinfo_args=(R.Tensor((2, 3), dtype="float32"))) + res = R.invoke_pure_closure( + in_call, (y,), sinfo_args=(R.Tensor((2, 3), dtype="float32")) + ) return res @R.function @@ -142,7 +144,7 @@ class Expected: def lifted_func_0( i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32"), x: R.Tensor((2, 3), "float32") ) -> R.Tensor((2, 3), "float32"): - cond: R.Tensor((), "bool") = R.call_packed( + cond: R.Tensor((), "bool") = R.call_pure_packed( "test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((), dtype="bool")) ) c: R.Tensor((), "int32") = R.const(1, dtype="int32") @@ -158,7 +160,7 @@ def lifted_func_0( @R.function def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), dtype="float32"): while_loop = R.make_closure(Expected.lifted_func_0, (x,)) - gv: R.Tensor((2, 3), dtype="float32") = R.invoke_closure( + gv: R.Tensor((2, 3), dtype="float32") = R.invoke_pure_closure( while_loop, (R.const(0), x), sinfo_args=(R.Tensor((2, 3), dtype="float32")), @@ -174,7 +176,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor: def while_loop( i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32") ) -> R.Tensor((2, 3), "float32"): - cond: R.Tensor((), "bool") = R.call_packed( + cond: R.Tensor((), "bool") = R.call_pure_packed( "test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((), dtype="bool")) ) c: R.Tensor((), "int32") = R.const(1, dtype="int32") @@ -303,5 +305,44 @@ def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor(dtype="float32", ndim= _check_save_roundtrip(after) +def test_impure_function(): + @tvm.script.ir_module + class Expected: + @R.function + def lifted_func_0() -> R.Tuple: + R.is_impure() + y = R.print(format="Wow!") + return y + + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.is_impure() + inner = Expected.lifted_func_0 + gv1 = inner() + return x + + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.is_impure() + + @R.function + def inner() -> R.Tuple: + R.is_impure() + y = R.print(format="Wow!") + return y + + gv1 = inner() + return x + + before = Before + expected = Expected + after = transform.LambdaLift()(before) + assert len(after.functions) == 2 + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index 0fc08d5ef487..3de4a1ff0ac8 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -44,6 +44,8 @@ def main_transform_params( ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): + # we expect ToNonDataflow and RemovePurityTracking to be invoked first + R.func_attr({"relax.force_pure": True}) cls = Before lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1] lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0] @@ -74,6 +76,7 @@ def transform_layout_IOHW_to_OIHW( 
@R.function def main_transform_params() -> R.Tuple(R.Object, R.Object): + R.func_attr({"relax.force_pure": True}) cls = Expected lv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,)) lv1: R.Object = R.call_packed("set_item", R.prim_value(0), lv, sinfo_args=(R.Object,)) diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index 85ade3f140fa..8c10255741e3 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -484,7 +484,7 @@ def main( (x, begin, end, strides), out_sinfo=R.Tensor((4,), dtype="int64"), ) - gv1: R.Shape(ndim=4) = R.call_packed( + gv1: R.Shape(ndim=4) = R.call_pure_packed( "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=4),) ) gv2: R.Shape([s, s_1, s_2, s_3]) = R.match_cast( @@ -683,7 +683,7 @@ def main( (x, begin, end, strides), out_sinfo=R.Tensor((2,), dtype="int64"), ) - gv1: R.Shape(ndim=2) = R.call_packed( + gv1: R.Shape(ndim=2) = R.call_pure_packed( "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=2),) ) gv2: R.Shape([s, s_1]) = R.match_cast(gv1, R.Shape([s, s_1])) diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index 28e2f3ad0e22..6e4a7683324b 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -722,9 +722,7 @@ def reshape( @R.function def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor((3,), dtype="int64"): x_1 = T.int64() - gv: R.Shape([3]) = R.call_packed( - "vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape([3]),) - ) + gv: R.Shape([3]) = R.call_pure_packed("vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape([3]),)) y: R.Shape([x_1]) = R.match_cast(gv, R.Shape([x_1])) lv: R.Shape([x_1]) = R.shape([x_1]) gv_1 = R.call_tir(Expected.reshape, (x,), out_sinfo=R.Tensor((x_1,), dtype="int64")) @@ -1013,52 +1011,6 @@ def collapse_sum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), r tvm.ir.assert_structural_equal(mod, Expected) -def test_collapse_sum_like_symbolic(): - # fmt: off - @tvm.script.ir_module - class CollapseSumLike: - @R.function - def main(x: R.Tensor(("a", "b", "a"), "float32"), y: R.Tensor(("b", 1), "float32")) -> R.Tensor(("b", 1), "float32"): - b = T.int64() - gv: R.Tensor((b, 1), "float32") = R.collapse_sum_like(x, y) - return gv - - @I.ir_module - class Expected: - @T.prim_func - def collapse_sum(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - a, b = T.int64(), T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, a)) - rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, (b, T.int64(1))) - # with T.block("root"): - for ax0, ax1, k0, k2 in T.grid(b, T.int64(1), a, a): - with T.block("rxplaceholder_red"): - v_ax0, v_ax1, v_k0, v_k2 = T.axis.remap("SSRR", [ax0, ax1, k0, k2]) - T.reads(rxplaceholder[v_k0, v_ax0, v_k2]) - T.writes(rxplaceholder_red[v_ax0, v_ax1]) - with T.init(): - rxplaceholder_red[v_ax0, v_ax1] = T.float32(0) - rxplaceholder_red[v_ax0, v_ax1] = (rxplaceholder_red[v_ax0, v_ax1] + rxplaceholder[v_k0, v_ax0, v_k2]) - - @R.function - def main( - x: R.Tensor(("a", "b", "a"), dtype="float32"), - y: R.Tensor(("b", 1), dtype="float32"), - ) -> R.Tensor(("b", 1), dtype="float32"): - b = T.int64() - a = T.int64() - 
cls = Expected - gv = R.call_tir( - cls.collapse_sum, (x,), out_sinfo=R.Tensor((b, 1), dtype="float32") - ) - return gv - # fmt: on - - mod = LegalizeOps()(CollapseSumLike) - tvm.ir.assert_structural_equal(mod, Expected) - - def test_collapse_sum_to(): # fmt: off @tvm.script.ir_module @@ -1095,54 +1047,6 @@ def collapse_sum(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), " tvm.ir.assert_structural_equal(mod, Expected) -def test_collapse_sum_to_symbolic(): - # fmt: off - @tvm.script.ir_module - class CollapseSumTo: - @R.function - def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("b", 1), "float32"): - b = T.int64() - gv: R.Tensor((b, 1), "float32") = R.collapse_sum_to(x, (b, 1)) - return gv - - @I.ir_module - class Expected: - @T.prim_func - def collapse_sum(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - a, b, c = T.int64(), T.int64(), T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c)) - rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, (b, T.int64(1))) - # with T.block("root"): - for ax0, ax1, k0, k2 in T.grid(b, T.int64(1), a, c): - with T.block("rxplaceholder_red"): - v_ax0, v_ax1, v_k0, v_k2 = T.axis.remap("SSRR", [ax0, ax1, k0, k2]) - T.reads(rxplaceholder[v_k0, v_ax0, v_k2]) - T.writes(rxplaceholder_red[v_ax0, v_ax1]) - with T.init(): - rxplaceholder_red[v_ax0, v_ax1] = T.float32(0) - rxplaceholder_red[v_ax0, v_ax1] = ( - rxplaceholder_red[v_ax0, v_ax1] + rxplaceholder[v_k0, v_ax0, v_k2] - ) - - @R.function - def main( - x: R.Tensor(("a", "b", "c"), dtype="float32") - ) -> R.Tensor(("b", 1), dtype="float32"): - b = T.int64() - a = T.int64() - c = T.int64() - cls = Expected - gv = R.call_tir( - cls.collapse_sum, (x,), out_sinfo=R.Tensor((b, 1), dtype="float32") - ) - return gv - # fmt: on - - mod = LegalizeOps()(CollapseSumTo) - tvm.ir.assert_structural_equal(mod, Expected) - - def test_repeat(): # fmt: off @I.ir_module diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py index 40c0a4a87698..931d206afbb1 100644 --- a/tests/python/relax/test_transform_rewrite_cuda_graph.py +++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py @@ -39,6 +39,8 @@ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T @R.function def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # force_pure is expected because purity checking should be disabled before this pass + R.func_attr({"relax.force_pure": True}) cls = Before storage: R.Object = R.memory.alloc_storage(R.shape([32]), 0, "global", "float32") alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), "float32") @@ -82,6 +84,7 @@ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T @R.function def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object, R.Object): + R.func_attr({"relax.force_pure": True}) storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32")) storage1: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32")) storage2: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32")) @@ -90,6 +93,7 @@ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object, R.Object): @R.function def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), 
storage: R.Object, storage2: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")): + R.func_attr({"relax.force_pure": True}) cls = Expected _2: R.Tuple = cls.exp(alloc, alloc1) _3: R.Tuple = R.memory.kill_tensor(alloc) @@ -104,6 +108,8 @@ def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tenso @R.function def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # this comes after RemovePurityChecking, so we expect purity to be forced + R.func_attr({"relax.force_pure": True}) cls = Expected gv: R.Tuple(R.Object, R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object, R.Object),)) storage: R.Object = gv[0] @@ -149,6 +155,7 @@ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T @R.function def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = Before storage: R.Object = R.memory.alloc_storage(R.shape([32]), 0, "global", "float32") alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), "float32") @@ -188,6 +195,7 @@ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T @R.function def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object): + R.func_attr({"relax.force_pure": True}) storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32")) storage1: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32")) gv: R.Tuple(R.Object, R.Object) = (storage, storage1) @@ -195,6 +203,7 @@ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object): @R.function def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")): + R.func_attr({"relax.force_pure": True}) cls = Expected _: R.Tuple = cls.exp(alloc, alloc1) lv0: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (alloc1,) @@ -210,6 +219,7 @@ def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tenso @R.function def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = Expected gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object),)) storage: R.Object = gv[0] diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py index e669f012f795..ffc0a586e569 100644 --- a/tests/python/relax/test_transform_static_plan_block_memory.py +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -51,6 +51,8 @@ def pad(rxplaceholder: T.Buffer(T.int64(8), "float32"), PadInput: T.Buffer(T.int @R.function def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # we expected RemovePurityChecking to have been invoked first + R.func_attr({"relax.force_pure": True}) cls = Module alloc: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), dtype="float32", runtime_device_index=0) _: R.Tuple() = cls.exp(x, alloc) @@ -98,6 +100,7 @@ def pad(rxplaceholder: T.Buffer(T.int64(8), "float32"), PadInput: T.Buffer(T.int @R.function def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = Expected storage: R.Object = R.memory.alloc_storage(R.shape([32]), virtual_device_index=0, storage_scope="global", dtype="float32") alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), dtype="float32") @@ -154,6 +157,7 @@ def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_resh @R.function def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = ExpectedLowered storage: R.Object = R.vm.alloc_storage(R.shape([32]), R.prim_value(0), R.dtype("float32")) alloc: R.Tensor((2, 4), dtype="float32") = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), R.dtype("float32")) @@ -213,6 +217,7 @@ def add1( def main( x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32") ) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = Module alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( R.shape([2, 3]), dtype="float32", runtime_device_index=0 @@ -248,6 +253,7 @@ def add1( def main( x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32") ) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = Expected storage: R.Object = R.memory.alloc_storage( R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" @@ -288,6 +294,7 @@ def add1( @R.function def main(y: R.Tensor((2, 3), dtype="bool")) -> R.Tensor((2, 3), dtype="bool"): + R.func_attr({"relax.force_pure": True}) cls = Module alloc: R.Tensor((2, 3), dtype="bool") = R.builtin.alloc_tensor( R.shape([2, 3]), dtype="bool", runtime_device_index=0 @@ -308,6 +315,7 @@ def add1( @R.function def main(y: R.Tensor((2, 3), dtype="bool")) -> R.Tensor((2, 3), dtype="bool"): + R.func_attr({"relax.force_pure": True}) cls = Expected storage: R.Object = R.memory.alloc_storage( R.shape([6]), virtual_device_index=0, storage_scope="global", dtype="bool" @@ -340,6 +348,7 @@ def add( def main( x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") ) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = Module alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( R.shape([2, 3]), dtype="float32", runtime_device_index=0 @@ -367,6 +376,7 @@ def add( def main( x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") ) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = Expected storage: R.Object = R.memory.alloc_storage( R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" @@ -403,6 +413,7 @@ def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): @R.function def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = Module alloc: R.Tensor((), dtype="bool") = R.builtin.alloc_tensor( R.shape([]), dtype="bool", runtime_device_index=0 @@ -436,6 +447,7 @@ def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): def main( cond: R.Tensor((), dtype="bool"), x: R.Tensor((2, 3), dtype="float32") ) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = Module alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( R.shape([2, 3]), dtype="float32", runtime_device_index=0 @@ -464,6 +476,7 @@ def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): def 
main( cond: R.Tensor((), dtype="bool"), x: R.Tensor((2, 3), dtype="float32") ) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = Module alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( R.shape([2, 3]), dtype="float32", runtime_device_index=0 @@ -500,6 +513,7 @@ def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): @R.function def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"relax.force_pure": True}) alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( R.shape([2, 3]), dtype="float32", runtime_device_index=0 ) @@ -555,6 +569,7 @@ def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): @R.function def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = Expected storage: R.Object = R.memory.alloc_storage( R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" @@ -633,6 +648,7 @@ def test_call_func_other_than_primfunc(): class Module: @R.function def main(x: R.Tensor((2, 3), "float32")): + R.func_attr({"relax.force_pure": True}) alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( R.shape([2, 3]), dtype="float32", runtime_device_index=0 ) @@ -650,6 +666,8 @@ def test_call_packed_external_func(): class Module: @R.function def main(x: R.Tensor((2, 3), "float32")): + # the extern func may or may not be pure, depends on what we're calling + R.is_impure() alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( R.shape([2, 3]), dtype="float32", runtime_device_index=0 ) @@ -666,6 +684,7 @@ def main(x: R.Tensor((2, 3), "float32")): class Expected: @R.function def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + R.is_impure() storage: R.Object = R.memory.alloc_storage( R.shape([24]), R.prim_value(0), R.str("global"), R.dtype("float32") ) @@ -700,6 +719,7 @@ def exp(var_A: T.handle, var_B: T.handle): @R.function def main(x: R.Tensor(("m", "n"), "float32")): + R.func_attr({"relax.force_pure": True}) m = T.int64() n = T.int64() alloc: R.Tensor((m, n), dtype="float32") = R.builtin.alloc_tensor( @@ -719,6 +739,7 @@ def test_zero_reference(): class Module: @R.function def main(x: R.Tensor((2, 3), "float32")): + R.func_attr({"relax.force_pure": True}) alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( R.shape([2, 3]), dtype="float32", runtime_device_index=0 ) @@ -728,6 +749,7 @@ def main(x: R.Tensor((2, 3), "float32")): class Expected: @R.function def main(x: R.Tensor((2, 3), "float32")): + R.func_attr({"relax.force_pure": True}) storage: R.Object = R.memory.alloc_storage( R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" ) @@ -756,6 +778,7 @@ def add( def main( x: R.Tensor((2, 50), dtype="float32"), y: R.Tensor((100,), dtype="float32") ) -> R.Tensor((2, 25, 2), dtype="float32"): + R.func_attr({"relax.force_pure": True}) lv: R.Tensor((2, 25, 2), dtype="float32") = R.reshape(x, (2, 25, 2)) lv1: R.Tensor((2, 25, 2), dtype="float32") = R.reshape(y, (2, 25, 2)) alloc: R.Tensor((2, 25, 2), dtype="float32") = R.builtin.alloc_tensor( @@ -793,6 +816,7 @@ def add1( def func1( x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32") ) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = Module alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( R.shape([2, 3]), dtype="float32", runtime_device_index=0 
@@ -810,6 +834,7 @@ def func1( def func2( x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") ) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = Module alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( R.shape([2, 3]), dtype="float32", runtime_device_index=0 @@ -845,6 +870,7 @@ def add1( def func1( x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32") ) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = Expected storage: R.Object = R.memory.alloc_storage( R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" @@ -872,6 +898,7 @@ def func1( def func2( x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") ) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"relax.force_pure": True}) cls = Expected storage: R.Object = R.memory.alloc_storage( R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" @@ -925,7 +952,7 @@ def pad(rxplaceholder: T.handle, PadInput: T.handle): @R.function def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"): - R.func_attr({"tir_var_upper_bound": {"n": 4}}) + R.func_attr({"tir_var_upper_bound": {"n": 4}, "relax.force_pure": True}) n = T.int64() cls = Module alloc: R.Tensor((2, n), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, n]), dtype="float32", runtime_device_index=0) @@ -975,7 +1002,7 @@ def reshape(rxplaceholder: T.handle, T_reshape: T.handle): @R.function def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"): n = T.int64() - R.func_attr({"tir_var_upper_bound": {"n": 4}}) + R.func_attr({"tir_var_upper_bound": {"n": 4}, "relax.force_pure": True}) cls = Expected storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32")) alloc: R.Tensor((2, n), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, n]), R.dtype("float32")) @@ -1021,7 +1048,7 @@ def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle): def main(x: R.Tensor(("n", "m", "T.max(n - m, 1)"), dtype="float32")) -> R.Tensor(("n", "m", "T.max(n - m, 1)"), dtype="float32"): n = T.int64() m = T.int64() - R.func_attr({"tir_var_upper_bound": {"m": 5, "n": 20}}) + R.func_attr({"tir_var_upper_bound": {"m": 5, "n": 20}, "relax.force_pure": True}) cls = Module alloc: R.Tensor((n, m, T.max(n - m, 1)), dtype="float32") = R.builtin.alloc_tensor(R.shape([n, m, T.max(n - m, 1)]), R.dtype("float32"), R.prim_value(0)) _: R.Tuple = cls.tir_exp(x, alloc) @@ -1044,7 +1071,7 @@ def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle): def main(x: R.Tensor(("n", "m", "T.max(n - m, 1)"), dtype="float32")) -> R.Tensor(("n", "m", "T.max(n - m, 1)"), dtype="float32"): n = T.int64() m = T.int64() - R.func_attr({"tir_var_upper_bound": {"m": 5, "n": 20}}) + R.func_attr({"tir_var_upper_bound": {"m": 5, "n": 20}, "relax.force_pure": True}) cls = Expected storage: R.Object = R.memory.alloc_storage(R.shape([8000]), R.prim_value(0), R.str("global"), R.dtype("float32")) alloc: R.Tensor((n, m, T.max(n - m, 1)), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n, m, T.max(n - m, 1)]), R.dtype("float32")) @@ -1083,7 +1110,7 @@ def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle): @R.function def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"): n = T.int64() - R.func_attr({"tir_var_upper_bound": {"n": 20}}) + 
R.func_attr({"tir_var_upper_bound": {"n": 20}, "relax.force_pure": True}) cls = Module alloc: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0)) _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n]))) @@ -1109,7 +1136,7 @@ def tir_full(var_full: T.handle, n: T.int64): @R.function def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"): n = T.int64() - R.func_attr({"tir_var_upper_bound": {"n": 20}}) + R.func_attr({"tir_var_upper_bound": {"n": 20}, "relax.force_pure": True}) cls = Expected storage: R.Object = R.memory.alloc_storage(R.shape([80]), R.prim_value(0), R.str("global"), R.dtype("float32")) alloc: R.Tensor((n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n]), R.dtype("float32")) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index c924fb0a7ad2..fef13d234ec6 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1324,6 +1324,7 @@ def add( @R.function def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"relax.force_pure": 1}) cls = Module alloc = R.builtin.alloc_tensor(R.shape([2, 4]), dtype="float32", runtime_device_index=0) _: R.Tuple() = cls.add(x, R.const(1, "float32"), alloc) @@ -1379,5 +1380,46 @@ def foo(x: R.Tensor((128, 128), "float32")): _check(Module) +def test_assert_op(): + @I.ir_module + class AssertOp: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.is_impure() + y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}") + return x + + _check(AssertOp) + + +def test_print(): + @I.ir_module + class Print: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.is_impure() + y = R.print(x, format="x: {}") + return x + + _check(Print) + + +def test_call_pure_packed(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + z = R.call_pure_packed("vm.builtin.copy", x, sinfo_args=R.Tensor((32, 32), "float32")) + return z + + x = relax.Var("x", R.Tensor((32, 32), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + z = bb.emit( + R.call_pure_packed("vm.builtin.copy", x, sinfo_args=[R.Tensor((32, 32), "float32")]) + ) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index bffa741353a9..e76fe1d9020f 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -156,10 +156,9 @@ def test_func_struct_info(): ) _assert_print( obj, - """ -a = T.int64() -R.Callable((R.Prim("float32"), R.Object, R.Shape([1, a, 3])), R.Tensor((1, 2, 3), dtype="float32")) -""", + "a = T.int64()\n" + 'R.Callable((R.Prim("float32"), R.Object, R.Shape([1, a, 3])), ' + 'R.Tensor((1, 2, 3), dtype="float32"), True)', ) @@ -529,5 +528,57 @@ def foo(x: R.Tensor((128,), dtype="float32")) -> R.Tensor((128,), dtype="float32 ) +def test_assert_op(): + @I.ir_module + class AssertOpMod: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.is_impure() + y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}") + return x + + _assert_print( + AssertOpMod, + """ +# from tvm.script import ir as I +# from tvm.script import relax as R + +@I.ir_module +class Module: + @R.function + def 
main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + R.is_impure() + y: R.Tuple = R.assert_op(R.const(False, "bool"), x, format=R.str("x: {}")) + return x +""", + ) + + +def test_print(): + @I.ir_module + class PrintMod: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.is_impure() + y = R.print(x, format="x: {}") + return x + + _assert_print( + PrintMod, + """ +# from tvm.script import ir as I +# from tvm.script import relax as R + +@I.ir_module +class Module: + @R.function + def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + R.is_impure() + y: R.Tuple = R.print(x, format=R.str("x: {}")) + return x +""", + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 9cf544515695..baf0d7c0b14f 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -39,7 +39,7 @@ def test_vm_compile_simple(exec_mode): class TestVMCompileStage0: @R.function def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): - z = R.call_packed( + z = R.call_pure_packed( "test.vm.identity", x, y, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) ) return y @@ -286,8 +286,6 @@ def te_func(A): mod = bb.get() - new_mod = relax.transform.CallTIRRewrite()(mod) - target = tvm.target.Target("llvm", host="llvm") ex = relax.build(mod, target, exec_mode=exec_mode) @@ -488,7 +486,9 @@ def tuple_get_item( t = (x, y) a = t[0] b = t[1] - c = R.call_packed("test.vm.add", a, b, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))) + c = R.call_pure_packed( + "test.vm.add", a, b, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) + ) return c mod = TestVMTupleGetItem @@ -507,11 +507,13 @@ def test_lower_memory_alloc_storage_tensor(exec_mode): class TestMemoryAllocStorageTensor: @R.function def main(x: R.Tensor((2, 3), dtype="float32")): + R.func_attr({"relax.force_pure": True}) cls = TestMemoryAllocStorageTensor storage = R.memory.alloc_storage( R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" ) y = R.memory.alloc_tensor(storage, 0, R.shape([2, 3]), dtype="float32") + # this is an impure operation, but the overall function is pure so we force purity _ = cls.copy(x, y) return y @@ -566,7 +568,9 @@ def relax_matmul_tir( def relax_matmul_packed( x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") ) -> R.Object: - gv0 = R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))) + gv0 = R.call_pure_packed( + "test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) + ) return gv0 @R.function @@ -593,17 +597,17 @@ def test_recursion(exec_mode): class TestVMRecursion: @R.function def recursion(n: R.Tensor((1,), "float32")) -> R.Tensor: - cond = R.call_packed( + cond = R.call_pure_packed( "test.vm.equal_zero", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) ) if cond: res = R.const(1.0) else: - gv0 = R.call_packed( + gv0 = R.call_pure_packed( "test.vm.subtract_one", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) ) tmp = TestVMRecursion.recursion(gv0) - res = R.call_packed( + res = R.call_pure_packed( "test.vm.add", tmp, tmp, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) ) return res @@ -626,7 +630,7 @@ def test_vm_closure(exec_mode): class TestClosure: @R.function def lifted_func_1(x: R.Tensor((2, 3), "float32"), env: R.Tensor((2, 3), "float32")): - return R.call_packed("test.vm.add", x, env, sinfo_args=(R.Tensor)) + return R.call_pure_packed("test.vm.add", x, env, 
sinfo_args=(R.Tensor())) @R.function def main( @@ -635,7 +639,7 @@ def main( ): cls = TestClosure clo = R.make_closure(cls.lifted_func_1, (x,)) - res = R.invoke_closure(clo, (y,), sinfo_args=(R.Tensor)) + res = R.invoke_pure_closure(clo, (y,), sinfo_args=(R.Tensor())) return res mod = TestClosure @@ -654,7 +658,7 @@ def test_time_evaluator(exec_mode): class TestTimeEvaluator: @R.function def main(x: R.Tensor((1,), "float32"), y: R.Tensor((1,), "float32")): - return R.call_packed( + return R.call_pure_packed( "test.vm.add", x, y, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) )
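Taken together, the purity surface exercised throughout these tests can be summarized in one compact sketch (illustrative only; the module, helper names, and assertions below are not part of the change itself):

import tvm
from tvm import relax as rx
from tvm.relax.analysis import contains_impure_call
from tvm.script import relax as R


@tvm.script.ir_module
class Example:
    @R.function
    def pure_packed(x: R.Tensor((3, 4), "float32")):
        # call_pure_packed asserts that this particular packed call has no
        # visible side effects, so it may appear in a function that is pure.
        z = R.call_pure_packed(
            "vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32"))
        )
        return z

    @R.function
    def logging(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
        # R.print has a visible side effect, so the function is marked impure.
        R.is_impure()
        y = R.print(x, format="x: {}")
        return x


# The analysis reflects the annotations above.
assert not contains_impure_call(Example["pure_packed"])
assert contains_impure_call(Example["logging"])

# Before lowering, RemovePurityChecking strips the purity-only wrappers
# (e.g. call_pure_packed becomes call_packed) and attaches the
# "relax.force_pure" attribute to pure functions, so later passes that emit
# calls such as raw allocations are not rejected by the well-formed check.
lowered = rx.transform.RemovePurityChecking()(Example)

The Before/Expected pairs in test_transform.py above show the same flow at the module level.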