Skip to content

Commit

Permalink
Indicate that RemovePurityChecking is also required for LazyTransform…
Browse files Browse the repository at this point in the history
…Params
  • Loading branch information
slyubomirsky committed May 18, 2023
1 parent e786f92 commit c920bf5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
21 changes: 10 additions & 11 deletions python/tvm/relax/transform/lazy_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,11 @@ def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr:
# rewrite get item
tuple_get_item = super().visit_tuple_getitem_(op)
if tuple_get_item.tuple_value == self.input_tuple_param:
return relax.call_pure_packed(
return relax.Call(
relax.ExternFunc("get_item"),
relax.PrimValue(tuple_get_item.index),
sinfo_args=(relax.ObjectStructInfo(),),
[relax.PrimValue(tuple_get_item.index)],
None,
[relax.ObjectStructInfo()],
)
else:
return tuple_get_item
Expand All @@ -165,15 +166,11 @@ def visit_var_binding_(self, binding: relax.VarBinding) -> None:
var_before_setitem = self.builder_.emit(value)
# rewrite set item
new_var = self.builder_.emit(
# TODO(@relax-team): This is wrong! This is not pure,
# but there is no other way to allow this inside a dataflow block.
# Properly speaking, this pass should require ToNonDataflow first,
# but the liveness analysis requires dataflow blocks. This should be refactored
relax.call_pure_packed(
relax.Call(
relax.ExternFunc("set_item"),
index,
var_before_setitem,
sinfo_args=(relax.ObjectStructInfo(),),
[index, var_before_setitem],
None,
[relax.ObjectStructInfo()],
)
)
self.set_var_remap(binding.var.vid, new_var)
Expand All @@ -194,6 +191,8 @@ class LazyTransformParams:
"""
Convert transform_params functions into a lazy version.
(Load the input to memory on demand, and immediately free it after the last use.)
Note: ToNonDataflow() and RemovePurityTracking() should be invoked before this pass.
"""

def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule:
Expand Down
3 changes: 3 additions & 0 deletions tests/python/relax/test_transform_lazy_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def main_transform_params(
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
):
# we expect ToNonDataflow and RemovePurityTracking to be invoked first
R.func_attr({"relax.force_pure": True})
cls = Before
lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
Expand Down Expand Up @@ -74,6 +76,7 @@ def transform_layout_IOHW_to_OIHW(

@R.function
def main_transform_params() -> R.Tuple(R.Object, R.Object):
R.func_attr({"relax.force_pure": True})
cls = Expected
lv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,))
lv1: R.Object = R.call_packed("set_item", R.prim_value(0), lv, sinfo_args=(R.Object,))
Expand Down

0 comments on commit c920bf5

Please sign in to comment.