diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h
index 62dfcfc27d369..3b27d6ae6a6cc 100644
--- a/c10/core/SymNodeImpl.h
+++ b/c10/core/SymNodeImpl.h
@@ -97,6 +97,9 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
   virtual SymNode sym_not() {
     TORCH_CHECK(false, "NYI");
   };
+  virtual SymNode sym_ite(const SymNode& then_val, const SymNode& else_val) {
+    TORCH_CHECK(false, "NYI");
+  };
   // NB: self is ignored here, only the arguments are used
   virtual SymNode is_contiguous(
       ArrayRef<SymNode> sizes,
diff --git a/docs/source/conf.py b/docs/source/conf.py
index b6880b9d55996..99e6422161275 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -967,6 +967,7 @@
     "parallel_and",
     "parallel_or",
     "sym_sqrt",
+    "sym_ite",
    "sympy_is_channels_last_contiguous_2d",
    "sympy_is_channels_last_contiguous_3d",
    "sympy_is_channels_last_strides_2d",
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 089c802fa6549..82fa1273f05c4 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -706,6 +706,7 @@ Symbolic Numbers
     sym_max
     sym_min
     sym_not
+    sym_ite
 
 Export Path
 -------------
@@ -766,4 +767,4 @@ Operator Tags
 .. py:module:: torch.storage
 .. py:module:: torch.torch_version
 .. py:module:: torch.types
-.. py:module:: torch.version
\ No newline at end of file
+.. py:module:: torch.version
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index 925cfea801006..07180f8210dfa 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -423,6 +423,50 @@ def test_sym_ceil(self):
         self.assertIsInstance(r, torch.SymInt, msg=type(r))
         self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
 
+    def test_sym_ite(self):
+        shape_env = ShapeEnv()
+        t = create_symint(shape_env, 5)
+        f = create_symint(shape_env, 4)
+        b1 = True
+        r1 = torch.sym_ite(b1, t, f)
+        self.assertTrue(r1 is t)
+        b2 = False
+        r2 = torch.sym_ite(b2, t, f)
+        self.assertTrue(r2 is f)
+        b3 = t == 5
+        r3 = torch.sym_ite(b3, t, f)
+        self.assertEqual(len(shape_env.guards), 0)
+        self.assertEqual(r3, 5)
+        self.assertEqual(type(t), type(r3))
+        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""")
+        b4 = f == 5
+        r4 = torch.sym_ite(b4, t, f)
+        self.assertEqual(len(shape_env.guards), 1)
+        self.assertEqual(r4, 4)
+        self.assertEqual(type(f), type(r4))
+        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""")
+
+    def test_tracing_sym_ite(self):
+        def f(x):
+            b = x.shape[0] == 5
+            ret = torch.sym_ite(b, x.shape[0], x.shape[1])
+            return ret
+
+        gm = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5))
+        self.assertEqual(len(gm.shape_env.guards), 0)
+        self.assertExpectedInline(gm.code.strip(), """\
+def forward(self, x_1):
+    sym_size = torch.ops.aten.sym_size(x_1, 0)
+    eq = sym_size == 5
+    sym_size_1 = torch.ops.aten.sym_size(x_1, 1);  x_1 = None
+    sym_ite = torch.sym_ite(eq, sym_size, sym_size_1);  eq = sym_size = sym_size_1 = None
+    return sym_ite""")
+        r1 = gm(torch.ones(4, 5))
+        self.assertIsInstance(r1, int)
+        self.assertEqual(r1, 5)
+        r2 = gm(torch.ones(5, 4))
+        self.assertIsInstance(r2, int)
+        self.assertEqual(r2, 5)
 
     def test_int_conversion(self):
         shape_env = ShapeEnv()
@@ -684,7 +728,8 @@ def guard_fn(v):
 
     @parametrize("fn", list(symbolic_shapes.magic_methods.keys()))
     def test_bool_method(self, fn):
-        if fn not in symbolic_shapes.bool_magic_methods:
+        # sym_ite has its own tests
+        if fn not in symbolic_shapes.bool_magic_methods or fn == "sym_ite":
             self.skipTest(f"{fn} is non-bool")
 
         is_unary_fn = fn in symbolic_shapes.unary_magic_methods
diff --git a/torch/__init__.py b/torch/__init__.py
index 969fd5d765f85..05c00c50dceb1 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -55,7 +55,7 @@ def _running_with_deploy():
     'set_float32_matmul_precision', 'get_float32_matmul_precision',
     'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
     'SymBool', 'sym_not', 'unravel_index',
-    'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
+    'sym_int', 'sym_float', 'sym_max', 'sym_min', 'sym_ite', 'compile', 'vmap',
     'export', 'autocast', 'cond',
 ]
 
@@ -390,6 +390,9 @@ def __or__(self, other) -> "SymBool":
     def __sym_not__(self) -> "SymBool":
         raise AssertionError("type stub not overridden")
 
+    def __sym_ite__(self, then_val, else_val):
+        raise AssertionError("type stub not overridden")
+
     def __eq__(self, other) -> builtins.bool:
         raise AssertionError("type stub not overridden")
 
@@ -456,6 +459,12 @@ def sym_min(a, b):
         return b.__sym_min__(a)
     return builtins.min(a, b)  # type: ignore[operator]
 
+def sym_ite(b, t, f):
+    assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f)
+    if isinstance(b, SymBool):
+        return b.__sym_ite__(t, f)
+    return t if b else f
+
 # Check to see if we can load C extensions, and if not provide some guidance
 # on what the problem might be.
 try:
diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h
index 1c7de6d204cd9..112de773ea382 100644
--- a/torch/csrc/utils/python_symnode.h
+++ b/torch/csrc/utils/python_symnode.h
@@ -145,6 +145,19 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
     return getPyObj().attr("str")().cast<std::string>();
   }
 
+  c10::SymNode dispatch_sym_ite_(
+      const char* fname,
+      const c10::SymNode& other,
+      const c10::SymNode& third) {
+    auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
+    auto pthird = dynamic_cast<PythonSymNodeImpl*>(third.get());
+    TORCH_CHECK(pother);
+    TORCH_CHECK(pthird);
+    py::gil_scoped_acquire acquire;
+    auto r = getPyObj().attr(fname)(pother->getPyObj(), pthird->getPyObj());
+    return c10::make_intrusive<PythonSymNodeImpl>(r);
+  }
+
   c10::SymNode dispatch_common_(const char* fname, const c10::SymNode& other) {
     auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
     TORCH_CHECK(pother);
@@ -226,6 +239,11 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
     return dispatch_common_(__func__, other);
   }
 
+  c10::SymNode sym_ite(const c10::SymNode& other, const c10::SymNode& third)
+      override {
+    return dispatch_sym_ite_(__func__, other, third);
+  }
+
   c10::SymNode sym_not() override {
     return dispatch_common_(__func__);
   }
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 256756fc1ca7d..0a597f6e919b9 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -35,6 +35,7 @@
     sym_max,
     sym_min,
     sym_not,
+    sym_ite,
     SymBool,
     SymFloat,
     SymInt,
@@ -962,6 +963,9 @@ def sym_min(self, other) -> "SymNode":  # noqa: F811
     def sym_max(self, other) -> "SymNode":  # noqa: F811
         return self._sym_max(other)  # type: ignore[attr-defined]
 
+    def sym_ite(self, then_val, else_val) -> "SymNode":
+        return self._sym_ite(then_val, else_val)
+
     def sym_sqrt(self) -> "SymNode":  # noqa: F811
         return self._sym_sqrt()  # type: ignore[attr-defined]
 
@@ -1181,6 +1185,7 @@ def ceil_impl(a):
     'neg': lambda a: -a,
     'sym_min': lambda a, b: sympy.Min(a, b),
     'sym_max': lambda a, b: sympy.Max(a, b),
+    'sym_ite': lambda a, t, f: sympy.Piecewise((t, a), (f, True)),
     'sym_sqrt': lambda a: sympy.sqrt(a),
     'abs': lambda a: sympy.Abs(a),
 }
@@ -1318,13 +1323,13 @@ def _eval_is_non_overlapping_and_dense(sizes, strides):
 
 # Most methods are only registered on SymInt and SymFloat
 # Some methods are only be registered on SymBool
-only_bool_magic_methods = {"and", "or", "sym_not"}
+only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}
 # Methods that are also on SymBool, in addition to on SymInt and SymFloat
 also_bool_magic_methods = {"eq"}
 bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods
 
 magic_methods_on_math = {"ceil", "floor"}
-magic_methods_on_submodule = {"sym_float", "sym_sqrt", "sym_min", "sym_max", "sym_not"}
+magic_methods_on_submodule = {"sym_float", "sym_sqrt", "sym_min", "sym_max", "sym_not", "sym_ite"}
 magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
 
 def method_to_operator(method):
@@ -1463,6 +1468,36 @@ def unary_magic_impl(self):
 
     if method in unary_magic_methods:
         setattr(SymNode, f"_{method_attr}", unary_magic_impl)
+    elif method == "sym_ite":
+
+        def sym_ite_impl(pred_node, then_node, else_node):
+            out_hint = then_node.hint if pred_node.hint else else_node.hint
+            if SYM_FUNCTION_MODE:
+                return to_node(
+                    pred_node,
+                    _handle_sym_dispatch(
+                        sym_ite,
+                        (wrap_node(pred_node), wrap_node(then_node), wrap_node(else_node)), {}
+                    )
+                )
+
+            try:
+                out = func(pred_node.expr, then_node.expr, else_node.expr)
+            except Exception:
+                log.warning("failed to eval %s(%s, %s, %s)", method, pred_node.expr, then_node.expr, else_node.expr)
+                raise
+
+            out = safe_expand(out)
+            fx_node, _ = pred_node.shape_env.create_fx_call_function(
+                sym_ite,
+                (
+                    pred_node.fx_node,
+                    then_node.fx_node,
+                    else_node.fx_node
+                )
+            )
+            return SymNode(out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node)
+        setattr(SymNode, f"_{method_attr}", sym_ite_impl)
     else:
         setattr(SymNode, f"_{method_attr}",
                 binary_magic_impl)
@@ -1602,6 +1637,19 @@ def rbinary_magic_impl(self, other):
 
     if method in unary_magic_methods:
         setattr(user_type, f"__{method}__", unary_magic_impl)
+    elif method == "sym_ite":
+
+        def sym_ite_magic_impl(pred, then_val, else_val):
+            pred_node = pred.node
+            then_node = to_node(pred_node, then_val)
+            else_node = to_node(pred_node, else_val)
+            if then_node is NotImplemented or else_node is NotImplemented:
+                return NotImplemented
+            assert isinstance(then_node, SymNode) and isinstance(else_node, SymNode) and then_node.pytype == else_node.pytype
+            ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
+            return get_constant(ret) if ret.node.is_constant() else ret
+
+        setattr(user_type, f"__{method}__", sym_ite_magic_impl)
     else:
         setattr(user_type, f"__{method}__", binary_magic_impl)
     if method in reflectable_magic_methods:
diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py
index 44c703706c78b..8ca34aaa4a089 100644
--- a/torch/fx/experimental/validator.py
+++ b/torch/fx/experimental/validator.py
@@ -276,6 +276,7 @@ def wrapper(*args):
             torch.sym_float: lift(ops.to_real),
             torch.sym_max: lift(ops.max),
             torch.sym_min: lift(ops.min),
+            torch.sym_ite: lift(lambda b, t, f: t if b else f),
             sym_sqrt: lift(ops.sqrt),
             # Not lifted because we only use this function as a
             # marker for adding the expression as validator input.
diff --git a/torch/overrides.py b/torch/overrides.py
index 3931d5bf51838..b94896b1b5e94 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -221,6 +221,7 @@ def get_ignored_functions() -> Set[Callable]:
         torch.sym_max,
         torch.sym_min,
         torch.sym_not,
+        torch.sym_ite,
         torch.sym_constrain_range,
         torch.sym_constrain_range_for_size,
         torch.tril_indices,