Skip to content

Commit

Permalink
Process purity when parsing function declarations
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed Mar 26, 2023
1 parent 47f1c19 commit f246b75
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
27 changes: 25 additions & 2 deletions python/tvm/script/parser/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import Any, Dict, Optional

from tvm import relax, tir
from tvm.ir import GlobalVar, structural_equal
from tvm.ir import make_node, GlobalVar, structural_equal
from tvm.relax import Expr, StructInfo
from tvm.relax.utils import convert_to_expr
from tvm.script.ir_builder.relax.frame import BlockFrame
Expand Down Expand Up @@ -220,7 +220,30 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar
param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True)
params.append(relax.Var(arg.arg, param_sinfo))

func_signature = relax.Function.create_empty(params, ret_sinfo)
# find a call to R.func_attr to see if purity should be indicated
# namely, find a call to R.func_attr({..., "IsPure": val, ...})
# (we don't need any other attributes at the function declaration stage)
attrs = None
for item in node.body:
if (
isinstance(item.value, doc.Call)
and isinstance(item.value.func, doc.Attribute)
and item.value.func.attr == "func_attr"
and len(item.value.args) == 1
and isinstance(item.value.args[0], doc.Dict)
):
index = None
for i, key in enumerate(item.value.args[0].keys):
if isinstance(key, doc.Constant) and key.value == "IsPure":
index = i
break
if index is not None:
val = item.value.args[0].values[index]
if isinstance(val, doc.Constant):
purity = bool(val.value)
attrs = make_node("DictAttrs", IsPure=purity)

func_signature = relax.Function.create_empty(params, ret_sinfo, attrs=attrs)
return I.decl_function(node.name, func_signature)


Expand Down
4 changes: 3 additions & 1 deletion src/relax/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,9 @@ Function Function::CreateEmpty(Array<Var> params, StructInfo ret_struct_info, Di
<< "relax.Function requires params to contain checked_type_.";
param_sinfo.push_back(GetStructInfo(param));
}
FuncStructInfo finfo(param_sinfo, ret_struct_info);
// if unannotated, we assume the function is pure
bool purity = attrs.GetAttr<Bool>(relax::attr::kIsPure).value_or(Bool(true))->value;
FuncStructInfo finfo(param_sinfo, ret_struct_info, purity);

// set the fields
ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
Expand Down

0 comments on commit f246b75

Please sign in to comment.