From f246b7536b997338c5e3274fcb3d703ae437f4d9 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sun, 26 Mar 2023 18:40:13 -0400 Subject: [PATCH] Process purity when parsing function declarations --- python/tvm/script/parser/relax/parser.py | 27 ++++++++++++++++++++++-- src/relax/ir/expr.cc | 4 +++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index 06fc51b7a6072..a6d91dff4394d 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -21,7 +21,7 @@ from typing import Any, Dict, Optional from tvm import relax, tir -from tvm.ir import GlobalVar, structural_equal +from tvm.ir import make_node, GlobalVar, structural_equal from tvm.relax import Expr, StructInfo from tvm.relax.utils import convert_to_expr from tvm.script.ir_builder.relax.frame import BlockFrame @@ -220,7 +220,30 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) params.append(relax.Var(arg.arg, param_sinfo)) - func_signature = relax.Function.create_empty(params, ret_sinfo) + # find a call to R.func_attr to see if purity should be indicated + # namely, find a call to R.func_attr({..., "IsPure": val, ...}) + # (we don't need any other attributes at the function declaration stage) + attrs = None + for item in node.body: + if ( + isinstance(item.value, doc.Call) + and isinstance(item.value.func, doc.Attribute) + and item.value.func.attr == "func_attr" + and len(item.value.args) == 1 + and isinstance(item.value.args[0], doc.Dict) + ): + index = None + for i, key in enumerate(item.value.args[0].keys): + if isinstance(key, doc.Constant) and key.value == "IsPure": + index = i + break + if index is not None: + val = item.value.args[0].values[index] + if isinstance(val, doc.Constant): + purity = bool(val.value) + attrs = make_node("DictAttrs", IsPure=purity) + + func_signature = relax.Function.create_empty(params, ret_sinfo, attrs=attrs) return I.decl_function(node.name, func_signature) diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index a77cffb507fe6..b1c2733a92ccb 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -478,7 +478,9 @@ Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, Di << "relax.Function requires params to contain checked_type_."; param_sinfo.push_back(GetStructInfo(param)); } - FuncStructInfo finfo(param_sinfo, ret_struct_info); + // if unannotated, we assume the function is pure + bool purity = attrs.GetAttr(relax::attr::kIsPure).value_or(Bool(true))->value; + FuncStructInfo finfo(param_sinfo, ret_struct_info, purity); // set the fields ObjectPtr n = make_object();