Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions python/tvm/relax/transform/lazy_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,15 @@ def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr:
# rewrite get item
tuple_get_item = super().visit_tuple_getitem_(op)
if tuple_get_item.tuple_value == self.input_tuple_param:
return relax.Call(
relax.ExternFunc("get_item"),
[relax.PrimValue(tuple_get_item.index)],
None,
[relax.ObjectStructInfo()],
get_item_result = self.builder_.emit(
relax.Call(
relax.ExternFunc("get_item"),
[relax.PrimValue(tuple_get_item.index)],
None,
[relax.ObjectStructInfo()],
)
)
return self.builder_.match_cast(get_item_result, op.struct_info)
else:
return tuple_get_item

Expand Down
120 changes: 110 additions & 10 deletions tests/python/relax/test_transform_lazy_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,23 @@ def main_transform_params() -> R.Tuple:
R.func_attr({"relax.force_pure": True})
cls = Expected
lv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,))
_: R.Object = R.call_packed("set_item", R.prim_value(0), lv, sinfo_args=(R.Object,))
_1: R.Tuple = R.vm.kill_object(lv)
gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
lv, R.Tensor((16, 16, 3, 3), dtype="float32")
)
lv_m: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
_: R.Object = R.call_packed("set_item", R.prim_value(0), lv_m, sinfo_args=(R.Object,))
_1: R.Tuple = R.vm.kill_object(lv_m)
lv1: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast(
lv1, R.Tensor((3, 16, 3, 3), dtype="float32")
)
lv1_m: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3
lv2 = R.call_tir(
cls.transform_layout_IOHW_to_OIHW,
(lv1,),
(lv1_m,),
out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
)
_2: R.Tuple = R.vm.kill_object(lv1)
_2: R.Tuple = R.vm.kill_object(lv1_m)
_3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2, sinfo_args=(R.Object,))
gv: R.Tuple = R.tuple()
return gv
Expand Down Expand Up @@ -146,13 +154,17 @@ def main_transform_params(slice_shape_expr: R.Shape(["slice_index"])):
slice_index = T.int64()

param = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
gv: R.Tensor((16, 16), dtype="float32") = R.match_cast(
param, R.Tensor((16, 16), dtype="float32")
)
param_m: R.Tensor((16, 16), dtype="float32") = gv
transformed = R.call_tir(
cls.slice_buffer,
(param,),
(param_m,),
tir_vars=[slice_index],
out_sinfo=R.Tensor((16,), dtype="float32"),
)
unused_1_ = R.vm.kill_object(param)
unused_1_ = R.vm.kill_object(param_m)
unused_2_ = R.call_packed(
"set_item", R.prim_value(0), transformed, sinfo_args=(R.Object,)
)
Expand All @@ -175,14 +187,100 @@ def slice_buffer(
tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)


# TODO(tvm-team): remove once the regression gets fixed
@pytest.mark.skip("temp disable, minor regression on read/write region in zero dim buffer")
# Checks LazyTransformParams on a parameter tuple in which one tensor has a
# symbolic dimension ("ic"): the pass is run on `Before` and the result must be
# structurally equal to `Expected`.
def test_param_shape_symbolic():
# Input module: `main_transform_params` receives every parameter at once as a
# single tuple argument and returns the transformed parameters as a tuple.
@I.ir_module
class Before:
@T.prim_func
def transform_layout_IOHW_to_OIHW(var_w1: T.handle, var_out: T.handle):
ic = T.int32()
w1 = T.match_buffer(var_w1, (ic, 16, 3, 3), "float32")
out = T.match_buffer(var_out, (16, ic, 3, 3), "float32")
for ax0, ax1, ax2, ax3 in T.grid(16, ic, 3, 3):
with T.block("layout_transform"):
o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(w1[i, o, h, w])
T.writes(out[o, i, h, w])
# Element-wise copy with the first two axes swapped.
out[o, i, h, w] = w1[i, o, h, w]

@R.function
def main_transform_params(
params: R.Tuple(
R.Tensor((3, "ic", 3, 3), dtype="float32"),
R.Tensor((16, 16, 3, 3), dtype="float32"),
)
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 3), dtype="float32")
):
ic = T.int64()
# we expect ToNonDataflow and RemovePurityTracking to be invoked first
R.func_attr({"relax.force_pure": True})
cls = Before
# params[1] is forwarded unchanged as output 0.
lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
# params[0] carries the symbolic dimension `ic`; it goes through the layout
# transform and becomes output 1.
lv1: R.Tensor((3, ic, 3, 3), dtype="float32") = params[0]
lv2 = R.call_tir(
cls.transform_layout_IOHW_to_OIHW,
(lv1,),
out_sinfo=R.Tensor((ic, 3, 3, 3), dtype="float32"),
)
gv: R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"),
R.Tensor((ic, 3, 3, 3), dtype="float32"),
) = (lv, lv2)
return gv

# Reference module: the function takes no arguments and returns an empty tuple;
# inputs are obtained through packed "get_item" calls and outputs are handed to
# packed "set_item" calls.
@I.ir_module
class Expected:
@T.prim_func
def transform_layout_IOHW_to_OIHW(var_w1: T.handle, var_out: T.handle):
ic = T.int32()
w1 = T.match_buffer(var_w1, (ic, 16, 3, 3), "float32")
out = T.match_buffer(var_out, (16, ic, 3, 3), "float32")
for ax0, ax1, ax2, ax3 in T.grid(16, ic, 3, 3):
with T.block("layout_transform"):
o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(w1[i, o, h, w])
T.writes(out[o, i, h, w])
out[o, i, h, w] = w1[i, o, h, w]

@R.function
def main_transform_params() -> R.Tuple:
R.func_attr({"relax.force_pure": True})
ic = T.int64()
cls = Expected
# Stands in for `params[1]`: "get_item" yields an R.Object, and the following
# match_cast restores the tensor struct info the tuple element had in `Before`.
gv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,))
gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
gv, R.Tensor((16, 16, 3, 3), dtype="float32")
)
lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
# Output 0 is passed to "set_item", then `lv` is passed to kill_object.
_: R.Object = R.call_packed("set_item", R.prim_value(0), lv, sinfo_args=(R.Object,))
_1: R.Tuple = R.vm.kill_object(lv)
# Stands in for `params[0]`: the match_cast target keeps the symbolic
# dimension `ic` in the tensor shape.
gv2: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
gv3: R.Tensor((3, ic, 3, 3), dtype="float32") = R.match_cast(
gv2, R.Tensor((3, ic, 3, 3), dtype="float32")
)
lv1: R.Tensor((3, ic, 3, 3), dtype="float32") = gv3
lv2 = R.call_tir(
cls.transform_layout_IOHW_to_OIHW,
(lv1,),
out_sinfo=R.Tensor((ic, 3, 3, 3), dtype="float32"),
)
# `lv1` is passed to kill_object right after the call_tir that reads it;
# the call_tir result is then passed to "set_item" as output 1.
_2: R.Tuple = R.vm.kill_object(lv1)
_3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2, sinfo_args=(R.Object,))
gv4: R.Tuple = R.tuple()
return gv4

# Apply the pass to `Before` and compare against the hand-written `Expected`
# module; map_free_vars=True is passed to the structural comparison.
after = LazyTransformParams()(Before)
tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)


def test_output_with_use_site():
@I.ir_module
class Module:
@T.prim_func
def copy(x: T.Buffer((), "float32"), y: T.Buffer((), "float32")):
with T.block("block"):
T.reads(x[()])
T.writes(y[()])
y[()] = x[()]

@R.function
Expand Down Expand Up @@ -212,8 +310,10 @@ def main_transform_params() -> R.Tuple:
R.func_attr({"relax.force_pure": True})
cls = Expected
x: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
y = R.call_tir(cls.copy, (x,), out_sinfo=R.Tensor((), dtype="float32"))
_: R.Tuple = R.vm.kill_object(x)
gv: R.Tensor((), dtype="float32") = R.match_cast(x, R.Tensor((), dtype="float32"))
x_m: R.Tensor((), dtype="float32") = gv
y = R.call_tir(cls.copy, (x_m,), out_sinfo=R.Tensor((), dtype="float32"))
_: R.Tuple = R.vm.kill_object(x_m)
z = R.call_tir(cls.copy, (y,), out_sinfo=R.Tensor((), dtype="float32"))
_1: R.Object = R.call_packed("set_item", R.prim_value(0), y, sinfo_args=(R.Object,))
_2: R.Object = R.call_packed("set_item", R.prim_value(1), z, sinfo_args=(R.Object,))
Expand Down