Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1506,6 +1506,7 @@ def create_convert_map(
"triu.default": self._tril_triu(relax.op.triu),
"trunc.default": self._unary_op(relax.op.trunc),
# binary
"add": self._binary_op(relax.op.add, operator.add),
"add.Tensor": self._binary_op(relax.op.add, operator.add),
"add.Scalar": self._binary_op(relax.op.add, operator.add),
"add_.Tensor": self._binary_op(relax.op.add, operator.add),
Expand Down Expand Up @@ -1560,6 +1561,7 @@ def create_convert_map(
"pow.Scalar": self._binary_op(relax.op.power, operator.pow),
"pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow),
"pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
"sub": self._binary_op(relax.op.subtract, operator.sub),
"sub.Tensor": self._binary_op(relax.op.subtract, operator.sub),
"sub.Scalar": self._binary_op(relax.op.subtract, operator.sub),
"__and__.Tensor": self._binary_op(relax.op.bitwise_and, operator.and_),
Expand Down
52 changes: 51 additions & 1 deletion tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -1654,7 +1654,57 @@ def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tuple(
verify_model(RSub2(), example_args2, {}, expected_rsub2)


# IsIn
def test_dynamic_shape_bare_add_sub():
"""Test that bare 'add' and 'sub' ops (from operator.add/sub in dynamic shape arithmetic)."""

class AddModel(torch.nn.Module):
def forward(self, x):
# With dynamic shapes, torch.export may emit operator.add nodes
# for shape arithmetic. We test that the model imports successfully.
return x + x

class SubModel(torch.nn.Module):
def forward(self, x):
return x - x

@I.ir_module
class ExpectedAdd:
@R.function
def main(x: R.Tensor(("s0", 4), dtype="float32")) -> R.Tuple(
R.Tensor(("s0", 4), dtype="float32")
):
s0 = T.int64(is_size_var=True)
R.func_attr({"tir_var_lower_bound": {"s77": 2}})
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The hardcoded symbolic variable name "s77" in func_attr makes this test fragile, as the name generated by PyTorch's export is not guaranteed to be consistent. While map_free_vars=True can map symbolic variables in shapes, it doesn't apply to string keys in the attributes dictionary. Since the main purpose of this test is to verify the translation of bare add and sub operators, and not the range constraints, I suggest removing this line to make the test more robust.

with R.dataflow():
lv: R.Tensor((s0, 4), dtype="float32") = R.add(x, x)
gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)
R.output(gv)
return gv

@I.ir_module
class ExpectedSub:
@R.function
def main(x: R.Tensor(("s0", 4), dtype="float32")) -> R.Tuple(
R.Tensor(("s0", 4), dtype="float32")
):
s0 = T.int64(is_size_var=True)
R.func_attr({"tir_var_lower_bound": {"s77": 2}})
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For the same reason as in ExpectedAdd, the hardcoded symbolic variable name "s77" here makes the test fragile. To improve robustness, I recommend removing this line.

with R.dataflow():
lv: R.Tensor((s0, 4), dtype="float32") = R.subtract(x, x)
gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (torch.randn(8, 4),)
batch = torch.export.Dim("batch", min=2)
dynamic_shapes = {"x": {0: batch}}

verify_model(
AddModel(), example_args, {}, ExpectedAdd, dynamic_shapes=dynamic_shapes, map_free_vars=True
)
verify_model(
SubModel(), example_args, {}, ExpectedSub, dynamic_shapes=dynamic_shapes, map_free_vars=True
)


def test_isin():
Expand Down
Loading