From 87442935180629484acde7a2c208fe8ad8facc02 Mon Sep 17 00:00:00 2001
From: Tugsbayasgalan Manlaibaatar
Date: Wed, 1 Nov 2023 13:42:25 -0700
Subject: [PATCH] Allow passing in dynamic_shapes without original argument
 name (#112298)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112298
Approved by: https://github.com/avikchaudhuri
---
 test/export/test_export.py | 65 ++++++++++++++++++++++++++++++++++++++
 torch/_export/__init__.py  | 12 ++++---
 torch/export/__init__.py   | 11 +++++--
 3 files changed, 81 insertions(+), 7 deletions(-)

diff --git a/test/export/test_export.py b/test/export/test_export.py
index cec9dcf685c2..01dd28a9aa77 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -24,6 +24,7 @@ from torch.utils._pytree import (
     LeafSpec,
     tree_flatten,
+    tree_map,
     tree_unflatten,
     TreeSpec,
     treespec_loads,
@@ -1456,5 +1457,69 @@ def forward(self, arg0_1):
     zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
     return (zeros,)""")
 
+    def test_non_arg_name_dynamic_shapes_api(self):
+        def foo(a, b):
+            return a.sum() + b.sum()
+
+        dim = torch.export.Dim("dim")
+        ep = torch.export.export(foo, (torch.randn(4, 4), torch.randn(4, 4)), dynamic_shapes=(None, {0: dim}))
+
+        test_inp = (torch.randn(4, 4), torch.randn(7, 4))
+        self.assertEqual(ep(*test_inp), foo(*test_inp))
+
+        ep_v2 = torch.export.export(foo, (torch.randn(4, 4), torch.randn(4, 4)), dynamic_shapes=(None, None))
+        with self.assertRaisesRegex(RuntimeError, "Input arg1_1.shape\[0\] is specialized at 4"):
+            ep_v2(*test_inp)
+
+    def test_non_arg_name_dynamic_shapes_api_with_kwarg(self):
+        def foo(a, b, kw1, kw2):
+            return a.sum() + b.sum() + kw1.sum() - kw2.sum()
+
+        dim = torch.export.Dim("dim")
+        dim_for_kw1 = torch.export.Dim("dim_for_kw1")
+        ep = torch.export.export(
+            foo,
+            (torch.randn(4, 4), torch.randn(4, 4)),
+            {"kw2": torch.ones(4, 4), "kw1": torch.zeros(4, 4)},
+            # We are specifying dynamism on the first kwarg even though the user
+            # passed the kwargs in a different order
+            dynamic_shapes=(None, {0: dim}, {0: dim_for_kw1}, None))
+
+        test_inp = (torch.randn(4, 4), torch.randn(7, 4))
+        test_kwargs = {"kw2": torch.ones(4, 4), "kw1": torch.zeros(9, 4)}
+        # This should work even if the kwarg order is flipped.
+        self.assertEqual(ep(*test_inp, **test_kwargs), foo(*test_inp, **test_kwargs))
+
+    def test_non_arg_name_dynamic_shapes_api_with_container_type(self):
+        def foo(a, b):
+            return a[0].sum() + a[1].sum() + b.sum()
+
+        inp_a = (torch.randn(4, 4), torch.randn(4, 4))
+        inp_b = torch.randn(4, 4)
+        inp = (inp_a, inp_b)
+
+        count = 0
+        def dynamify_inp(x):
+            # Mark the second input a[1] dynamic
+            nonlocal count
+            if count == 1:
+                dim = torch.export.Dim("dim", min=3)
+                count += 1
+                return {0: dim}
+            count += 1
+            return None
+
+        dynamic_shapes = tree_map(dynamify_inp, inp)
+
+        ep = torch.export.export(foo, inp, dynamic_shapes=dynamic_shapes)
+
+        test_inp = ((torch.randn(4, 4), torch.randn(2, 4)), torch.randn(4, 4))
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Input arg1_1.shape\[0\] is outside of specified dynamic range \[3, inf\]"
+        ):
+            ep(*test_inp)
+
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 891290f8fe4f..872824269dfd 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -82,7 +82,7 @@ def _process_dynamic_shapes(
     f: Callable,
     args: Tuple[Any, ...],
     kwargs: Optional[Dict[str, Any]] = None,
-    dynamic_shapes: Optional[Dict[str, Any]] = None,
+    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
 ) -> Optional[List[Constraint]]:
     if dynamic_shapes is None or len(dynamic_shapes) == 0:
         return None
@@ -209,6 +209,8 @@ def update_symbols(tensor, shape):
     signature = inspect.signature(f.forward) if isinstance(f, torch.nn.Module) else inspect.signature(f)
     combined_args = signature.bind(*args, **kwargs).arguments
+    # This means the user didn't specify dynamic shapes with argument names.
+    combined_args = combined_args if isinstance(dynamic_shapes, Mapping) else list(combined_args.values())  # type: ignore[assignment]
     for tensor, shape in tree_zip(combined_args, dynamic_shapes):
         update_symbols(tensor, shape)
@@ -229,15 +231,17 @@ def export__RC__(
     args: Tuple[Any, ...],
     kwargs: Optional[Dict[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Dict[str, Any]] = None,
+    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     preserve_module_call_signature: Tuple[str, ...] = (),
 ) -> ExportedProgram:
     """
     API for exporting with dynamic shape specifications instead of constraints.
     It should be considered "release candidate" (RC), meant to replace `export`.
 
-    Here, `dynamic_shapes` is expected to be a (possibly partial) dict from
-    argument names of `f` to dynamic shape specifications, as follows:
+    Here, `dynamic_shapes` is expected to be either a dict from argument names
+    of `f` to dynamic shape specifications, or a tuple whose elements give the
+    dynamic shape specifications for the arguments in the order they appear in
+    the function signature, as follows:
     - The dynamic shape of a tensor argument can be specified as:
       - Either a dict from dynamic dimension indices to Dim types. It is not
         required to include static dimension indices in this dict, but when
diff --git a/torch/export/__init__.py b/torch/export/__init__.py
index 19fe75346fff..599a3264e54a 100644
--- a/torch/export/__init__.py
+++ b/torch/export/__init__.py
@@ -353,7 +353,7 @@ def export(
     kwargs: Optional[Dict[str, Any]] = None,
     *,
     constraints: Optional[List[Constraint]] = None,
-    dynamic_shapes: Optional[Dict[str, Any]] = None,
+    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     preserve_module_call_signature: Tuple[str, ...] = (),
 ) -> ExportedProgram:
     """
@@ -412,8 +412,13 @@ def export(
             range of shapes.
 
             See :func:`dynamic_dim` docstring for examples on how to use it.
 
-        dynamic_shapes: Should be a dict from argument names of ``f`` to their dynamic shape specifications,
-            as follows. The dynamic shape of a tensor argument can be specified as either
+        dynamic_shapes: Should be either:
+            1) a dict from argument names of ``f`` to their dynamic shape specifications, or
+            2) a tuple of dynamic shape specifications, one per input, in the order in which
+            the arguments appear in the original function signature. When specifying dynamism
+            for keyword arguments, order their specifications by the signature, not by call order.
+
+            The dynamic shape of a tensor argument can be specified as either
             (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
             not required to include static dimension indices in this dict, but when they
             are, they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
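
For reference, here is a minimal usage sketch of the tuple-based `dynamic_shapes` form this patch adds, mirroring `test_non_arg_name_dynamic_shapes_api` above. The shapes and the `Dim` name are illustrative, and the snippet assumes a PyTorch build that includes this change:

    import torch

    def foo(a, b):
        return a.sum() + b.sum()

    # One spec per argument, in signature order: `a` stays fully static (None),
    # while dimension 0 of `b` is marked dynamic via a {dim_index: Dim} dict.
    dim = torch.export.Dim("dim")
    ep = torch.export.export(
        foo,
        (torch.randn(4, 4), torch.randn(4, 4)),
        dynamic_shapes=(None, {0: dim}),
    )

    # The exported program now accepts a different size for b's first dimension.
    print(ep(torch.randn(4, 4), torch.randn(7, 4)))

Note that when keyword arguments are involved, the tuple is ordered by the function signature, not by the order in which the kwargs are passed at call time, as exercised by `test_non_arg_name_dynamic_shapes_api_with_kwarg` above.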