Add support for sym_ite (pytorch#111440)
This PR adds support for sym_ite. It is useful for converting a SymBool to a SymInt, e.g. in pytorch#109916. Internally, it uses sympy.Piecewise. We cannot use sympy.ITE because it expects the arguments and the output to all be boolean, whereas we want to return a SymInt when converting a SymBool to a SymInt. So we use sympy.Piecewise to express the symbolic relationship.

Note that this PR relies on the range analysis for sympy.Piecewise implemented in https://github.com/pytorch/pytorch/blob/main/torch/utils/_sympy/value_ranges.py.
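For illustration, a minimal standalone sympy sketch (not part of the diff; the symbols s0/s1 are hypothetical, named after the shape symbols in the tests below) of why sympy.Piecewise works here while sympy.ITE does not: the branches may be integer expressions, so the result can stand in for a SymInt.

```python
import sympy

# Hypothetical shape symbols, mirroring the s0/s1 used in the tests below.
s0, s1 = sympy.symbols("s0 s1", integer=True)
pred = sympy.Eq(s0, 5)

# Piecewise((then, pred), (else, True)) mirrors "then if pred else else",
# with integer-valued branches; sympy.ITE would reject non-boolean branches.
expr = sympy.Piecewise((s0, pred), (s1, True))

print(expr.subs({s0: 5, s1: 4}))  # 5 (predicate holds, take the "then" branch)
print(expr.subs({s0: 3, s1: 4}))  # 4 (predicate fails, fall through to "else")
```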

Test Plan:
See added test.

Pull Request resolved: pytorch#111440
Approved by: https://github.com/ezyang
ydwu4 authored and andreigh committed Oct 26, 2023
1 parent a6cd595 commit 0df6ae6
Showing 9 changed files with 132 additions and 5 deletions.
3 changes: 3 additions & 0 deletions c10/core/SymNodeImpl.h
@@ -97,6 +97,9 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
  virtual SymNode sym_not() {
    TORCH_CHECK(false, "NYI");
  };
  virtual SymNode sym_ite(const SymNode& then_val, const SymNode& else_val) {
    TORCH_CHECK(false, "NYI");
  };
  // NB: self is ignored here, only the arguments are used
  virtual SymNode is_contiguous(
      ArrayRef<SymNode> sizes,
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -967,6 +967,7 @@
"parallel_and",
"parallel_or",
"sym_sqrt",
"sym_ite",
"sympy_is_channels_last_contiguous_2d",
"sympy_is_channels_last_contiguous_3d",
"sympy_is_channels_last_strides_2d",
3 changes: 2 additions & 1 deletion docs/source/torch.rst
@@ -706,6 +706,7 @@ Symbolic Numbers
sym_max
sym_min
sym_not
sym_ite

Export Path
-------------
@@ -766,4 +767,4 @@ Operator Tags
.. py:module:: torch.storage
.. py:module:: torch.torch_version
.. py:module:: torch.types
.. py:module:: torch.version
.. py:module:: torch.version
47 changes: 46 additions & 1 deletion test/test_dynamic_shapes.py
@@ -423,6 +423,50 @@ def test_sym_ceil(self):
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")

    def test_sym_ite(self):
        shape_env = ShapeEnv()
        t = create_symint(shape_env, 5)
        f = create_symint(shape_env, 4)
        b1 = True
        r1 = torch.sym_ite(b1, t, f)
        self.assertTrue(r1 is t)
        b2 = False
        r2 = torch.sym_ite(b2, t, f)
        self.assertTrue(r2 is f)
        b3 = t == 5
        r3 = torch.sym_ite(b3, t, f)
        self.assertEqual(len(shape_env.guards), 0)
        self.assertEqual(r3, 5)
        self.assertEqual(type(t), type(r3))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""")
        b4 = f == 5
        r4 = torch.sym_ite(b4, t, f)
        self.assertEqual(len(shape_env.guards), 1)
        self.assertEqual(r4, 4)
        self.assertEqual(type(f), type(r4))
        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""")

    def test_tracing_sym_ite(self):
        def f(x):
            b = x.shape[0] == 5
            ret = torch.sym_ite(b, x.shape[0], x.shape[1])
            return ret

        gm = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5))
        self.assertEqual(len(gm.shape_env.guards), 0)
        self.assertExpectedInline(gm.code.strip(), """\
def forward(self, x_1):
    sym_size = torch.ops.aten.sym_size(x_1, 0)
    eq = sym_size == 5
    sym_size_1 = torch.ops.aten.sym_size(x_1, 1); x_1 = None
    sym_ite = torch.sym_ite(eq, sym_size, sym_size_1); eq = sym_size = sym_size_1 = None
    return sym_ite""")
        r1 = gm(torch.ones(4, 5))
        self.assertIsInstance(r1, int)
        self.assertEqual(r1, 5)
        r2 = gm(torch.ones(5, 4))
        self.assertIsInstance(r2, int)
        self.assertEqual(r2, 5)

    def test_int_conversion(self):
        shape_env = ShapeEnv()
@@ -684,7 +728,8 @@ def guard_fn(v):

    @parametrize("fn", list(symbolic_shapes.magic_methods.keys()))
    def test_bool_method(self, fn):
        if fn not in symbolic_shapes.bool_magic_methods:
        # sym_ite has its own tests
        if fn not in symbolic_shapes.bool_magic_methods or fn == "sym_ite":
            self.skipTest(f"{fn} is non-bool")

        is_unary_fn = fn in symbolic_shapes.unary_magic_methods
11 changes: 10 additions & 1 deletion torch/__init__.py
@@ -55,7 +55,7 @@ def _running_with_deploy():
'set_float32_matmul_precision', 'get_float32_matmul_precision',
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
'SymBool', 'sym_not', 'unravel_index',
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'sym_ite', 'compile', 'vmap',
'export', 'autocast', 'cond',
]

@@ -390,6 +390,9 @@ def __or__(self, other) -> "SymBool":
    def __sym_not__(self) -> "SymBool":
        raise AssertionError("type stub not overridden")

    def __sym_ite__(self, then_val, else_val):
        raise AssertionError("type stub not overridden")

    def __eq__(self, other) -> builtins.bool:
        raise AssertionError("type stub not overridden")

@@ -456,6 +459,12 @@ def sym_min(a, b):
        return b.__sym_min__(a)
    return builtins.min(a, b) # type: ignore[operator]

def sym_ite(b, t, f):
    assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f)
    if isinstance(b, SymBool):
        return b.__sym_ite__(t, f)
    return t if b else f

# Check to see if we can load C extensions, and if not provide some guidance
# on what the problem might be.
try:
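As a quick sanity check of the sym_ite helper defined above, a hedged usage sketch (plain Python bool predicates only; the SymBool path needs a ShapeEnv and is exercised in test_dynamic_shapes.py):

```python
import torch

# With a plain bool predicate, sym_ite reduces to an ordinary conditional.
t, f = 5, 4
assert torch.sym_ite(True, t, f) == 5
assert torch.sym_ite(False, t, f) == 4

# Both branches must share a type (the assert in sym_ite enforces this);
# with a SymBool predicate the call dispatches to __sym_ite__ and returns
# a symbolic value, e.g. a SymInt when converting a SymBool to a SymInt.
```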
18 changes: 18 additions & 0 deletions torch/csrc/utils/python_symnode.h
@@ -145,6 +145,19 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
    return getPyObj().attr("str")().cast<std::string>();
  }

  c10::SymNode dispatch_sym_ite_(
      const char* fname,
      const c10::SymNode& other,
      const c10::SymNode& third) {
    auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
    auto pthird = dynamic_cast<PythonSymNodeImpl*>(third.get());
    TORCH_CHECK(pother);
    TORCH_CHECK(pthird);
    py::gil_scoped_acquire acquire;
    auto r = getPyObj().attr(fname)(pother->getPyObj(), pthird->getPyObj());
    return c10::make_intrusive<PythonSymNodeImpl>(r);
  }

  c10::SymNode dispatch_common_(const char* fname, const c10::SymNode& other) {
    auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
    TORCH_CHECK(pother);
@@ -226,6 +239,11 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode sym_ite(const c10::SymNode& other, const c10::SymNode& third)
      override {
    return dispatch_sym_ite_(__func__, other, third);
  }

  c10::SymNode sym_not() override {
    return dispatch_common_(__func__);
  }
52 changes: 50 additions & 2 deletions torch/fx/experimental/symbolic_shapes.py
@@ -35,6 +35,7 @@
    sym_max,
    sym_min,
    sym_not,
    sym_ite,
    SymBool,
    SymFloat,
    SymInt,
@@ -962,6 +963,9 @@ def sym_min(self, other) -> "SymNode": # noqa: F811
    def sym_max(self, other) -> "SymNode": # noqa: F811
        return self._sym_max(other) # type: ignore[attr-defined]

    def sym_ite(self, then_val, else_val) -> "SymNode":
        return self._sym_ite(then_val, else_val)

    def sym_sqrt(self) -> "SymNode": # noqa: F811
        return self._sym_sqrt() # type: ignore[attr-defined]

@@ -1181,6 +1185,7 @@ def ceil_impl(a):
    'neg': lambda a: -a,
    'sym_min': lambda a, b: sympy.Min(a, b),
    'sym_max': lambda a, b: sympy.Max(a, b),
    'sym_ite': lambda a, t, f: sympy.Piecewise((t, a), (f, True)),
    'sym_sqrt': lambda a: sympy.sqrt(a),
    'abs': lambda a: sympy.Abs(a),
}
@@ -1318,13 +1323,13 @@ def _eval_is_non_overlapping_and_dense(sizes, strides):

# Most methods are only registered on SymInt and SymFloat
# Some methods are only registered on SymBool
only_bool_magic_methods = {"and", "or", "sym_not"}
only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}
# Methods that are also on SymBool, in addition to on SymInt and SymFloat
also_bool_magic_methods = {"eq"}
bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods

magic_methods_on_math = {"ceil", "floor"}
magic_methods_on_submodule = {"sym_float", "sym_sqrt", "sym_min", "sym_max", "sym_not"}
magic_methods_on_submodule = {"sym_float", "sym_sqrt", "sym_min", "sym_max", "sym_not", "sym_ite"}
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}

def method_to_operator(method):
@@ -1463,6 +1468,36 @@ def unary_magic_impl(self):

    if method in unary_magic_methods:
        setattr(SymNode, f"_{method_attr}", unary_magic_impl)
    elif method == "sym_ite":

        def sym_ite_impl(pred_node, then_node, else_node):
            out_hint = then_node.hint if pred_node.hint else else_node.hint
            if SYM_FUNCTION_MODE:
                return to_node(
                    pred_node,
                    _handle_sym_dispatch(
                        sym_ite,
                        (wrap_node(pred_node), wrap_node(then_node), wrap_node(else_node)), {}
                    )
                )

            try:
                out = func(pred_node.expr, then_node.expr, else_node.expr)
            except Exception:
                log.warning("failed to eval %s(%s, %s, %s)", method, pred_node.expr, then_node.expr, else_node.expr)
                raise

            out = safe_expand(out)
            fx_node, _ = pred_node.shape_env.create_fx_call_function(
                sym_ite,
                (
                    pred_node.fx_node,
                    then_node.fx_node,
                    else_node.fx_node
                )
            )
            return SymNode(out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node)
        setattr(SymNode, f"_{method_attr}", sym_ite_impl)
    else:
        setattr(SymNode, f"_{method_attr}", binary_magic_impl)

@@ -1602,6 +1637,19 @@ def rbinary_magic_impl(self, other):

    if method in unary_magic_methods:
        setattr(user_type, f"__{method}__", unary_magic_impl)
    elif method == "sym_ite":

        def sym_ite_magic_impl(pred, then_val, else_val):
            pred_node = pred.node
            then_node = to_node(pred_node, then_val)
            else_node = to_node(pred_node, else_val)
            if then_node is NotImplemented or else_node is NotImplemented:
                return NotImplemented
            assert isinstance(then_node, SymNode) and isinstance(else_node, SymNode) and then_node.pytype == else_node.pytype
            ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
            return get_constant(ret) if ret.node.is_constant() else ret

        setattr(user_type, f"__{method}__", sym_ite_magic_impl)
    else:
        setattr(user_type, f"__{method}__", binary_magic_impl)
    if method in reflectable_magic_methods:
1 change: 1 addition & 0 deletions torch/fx/experimental/validator.py
@@ -276,6 +276,7 @@ def wrapper(*args):
torch.sym_float: lift(ops.to_real),
torch.sym_max: lift(ops.max),
torch.sym_min: lift(ops.min),
torch.sym_ite: lift(lambda b, t, f: t if b else f),
sym_sqrt: lift(ops.sqrt),
# Not lifted because we only use this function as a
# marker for adding the expression as validator input.
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -221,6 +221,7 @@ def get_ignored_functions() -> Set[Callable]:
torch.sym_max,
torch.sym_min,
torch.sym_not,
torch.sym_ite,
torch.sym_constrain_range,
torch.sym_constrain_range_for_size,
torch.tril_indices,
