Skip to content

Commit

Permalink
Allow passing in dynamic_shapes without original argument name (pytor…
Browse files Browse the repository at this point in the history
…ch#112298)

Pull Request resolved: pytorch#112298
Approved by: https://github.com/avikchaudhuri
  • Loading branch information
tugsbayasgalan authored and Skylion007 committed Nov 14, 2023
1 parent 2b3ae15 commit 8744293
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 7 deletions.
65 changes: 65 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from torch.utils._pytree import (
LeafSpec,
tree_flatten,
tree_map,
tree_unflatten,
TreeSpec,
treespec_loads,
Expand Down Expand Up @@ -1456,5 +1457,69 @@ def forward(self, arg0_1):
zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
return (zeros,)""")

def test_non_arg_name_dynamic_shapes_api(self):
    """dynamic_shapes can be passed positionally (as a tuple) instead of
    being keyed by the original argument names of the exported callable.

    Each tuple element lines up with the corresponding positional argument:
    ``None`` means fully static; a ``{dim_index: Dim}`` dict marks dims dynamic.
    """
    def foo(a, b):
        return a.sum() + b.sum()

    dim = torch.export.Dim("dim")
    # (None, {0: dim}): `a` stays static, dim 0 of `b` is dynamic.
    ep = torch.export.export(
        foo,
        (torch.randn(4, 4), torch.randn(4, 4)),
        dynamic_shapes=(None, {0: dim}),
    )

    test_inp = (torch.randn(4, 4), torch.randn(7, 4))
    self.assertEqual(ep(*test_inp), foo(*test_inp))

    # With no dynamism requested anywhere, a differently-shaped second
    # input must be rejected at runtime.
    ep_v2 = torch.export.export(
        foo,
        (torch.randn(4, 4), torch.randn(4, 4)),
        dynamic_shapes=(None, None),
    )
    # NOTE: raw string — the pattern contains regex escapes (`\[`), which
    # are invalid escape sequences in a plain string literal.
    with self.assertRaisesRegex(
        RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 4"
    ):
        ep_v2(*test_inp)

def test_non_arg_name_dynamic_shapes_api_with_kwarg(self):
    """Positional dynamic_shapes also cover keyword arguments, matched by
    the order of parameters in the function signature rather than by the
    order the kwargs happen to be passed in."""
    def foo(a, b, kw1, kw2):
        return a.sum() + b.sum() + kw1.sum() - kw2.sum()

    dim = torch.export.Dim("dim")
    dim_for_kw1 = torch.export.Dim("dim_for_kw1")

    export_args = (torch.randn(4, 4), torch.randn(4, 4))
    export_kwargs = {"kw2": torch.ones(4, 4), "kw1": torch.zeros(4, 4)}
    # Dynamism is keyed to signature order (a, b, kw1, kw2), even though
    # the kwargs above are supplied in a different order.
    ep = torch.export.export(
        foo,
        export_args,
        export_kwargs,
        dynamic_shapes=(None, {0: dim}, {0: dim_for_kw1}, None),
    )

    test_inp = (torch.randn(4, 4), torch.randn(7, 4))
    test_kwargs = {"kw2": torch.ones(4, 4), "kw1": torch.zeros(9, 4)}
    # This should work even if the kwarg order are flipped.
    self.assertEqual(ep(*test_inp, **test_kwargs), foo(*test_inp, **test_kwargs))

def test_non_arg_name_dynamic_shapes_api_with_container_type(self):
    """Positional dynamic_shapes may mirror the nested (pytree) structure
    of the inputs: a spec built with tree_map over the input tuple is
    matched leaf-for-leaf against the arguments."""
    def foo(a, b):
        return a[0].sum() + a[1].sum() + b.sum()

    inp_a = (torch.randn(4, 4), torch.randn(4, 4))
    inp_b = torch.randn(4, 4)
    inp = (inp_a, inp_b)

    count = 0

    def dynamify_inp(x):
        # Mark only the second leaf, a[1], as dynamic (dim 0, min size 3);
        # every other leaf stays fully static (None).
        nonlocal count
        if count == 1:
            dim = torch.export.Dim("dim", min=3)
            count += 1
            return {0: dim}
        count += 1
        return None

    # Build a dynamic_shapes spec with the same nesting as `inp`.
    dynamic_shapes = tree_map(dynamify_inp, inp)

    ep = torch.export.export(foo, inp, dynamic_shapes=dynamic_shapes)

    # a[1] has dim 0 == 2, below the declared minimum of 3 — must raise.
    test_inp = ((torch.randn(4, 4), torch.randn(2, 4)), torch.randn(4, 4))
    # NOTE: raw string — the pattern contains regex escapes (`\[`, `\]`),
    # which are invalid escape sequences in a plain string literal.
    with self.assertRaisesRegex(
        RuntimeError,
        r"Input arg1_1.shape\[0\] is outside of specified dynamic range \[3, inf\]",
    ):
        ep(*test_inp)


# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    run_tests()
12 changes: 8 additions & 4 deletions torch/_export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _process_dynamic_shapes(
f: Callable,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
dynamic_shapes: Optional[Dict[str, Any]] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> Optional[List[Constraint]]:
if dynamic_shapes is None or len(dynamic_shapes) == 0:
return None
Expand Down Expand Up @@ -209,6 +209,8 @@ def update_symbols(tensor, shape):
signature = inspect.signature(f.forward) if isinstance(f, torch.nn.Module) else inspect.signature(f)
combined_args = signature.bind(*args, **kwargs).arguments

# This means user didn't specify dynamic shapes with argument names.
combined_args = combined_args if isinstance(dynamic_shapes, Mapping) else list(combined_args.values()) # type: ignore[assignment]
for tensor, shape in tree_zip(combined_args, dynamic_shapes):
update_symbols(tensor, shape)

Expand All @@ -229,15 +231,17 @@ def export__RC__(
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Dict[str, Any]] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram:
"""
API for exporting with dynamic shape specifications instead of constraints.
It should be considered "release candidate" (RC), meant to replace `export`.
Here, `dynamic_shapes` is expected to be a (possibly partial) dict from
argument names of `f` to dynamic shape specifications, as follows:
Here, `dynamic_shapes` is expected to be a dict from
argument names of `f` to dynamic shape specifications, OR a tuple where each element
corresponds to the original order of the arguments defined in the function signature,
as follows:
- The dynamic shape of a tensor argument can be specified as:
- Either a dict from dynamic dimension indices to Dim types. It is not
required to include static dimension indices in this dict, but when
Expand Down
11 changes: 8 additions & 3 deletions torch/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def export(
kwargs: Optional[Dict[str, Any]] = None,
*,
constraints: Optional[List[Constraint]] = None,
dynamic_shapes: Optional[Dict[str, Any]] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram:
"""
Expand Down Expand Up @@ -412,8 +412,13 @@ def export(
range of shapes. See :func:`dynamic_dim` docstring for examples on
how to use it.
dynamic_shapes: Should be a dict from argument names of ``f`` to their dynamic shape specifications,
as follows. The dynamic shape of a tensor argument can be specified as either
dynamic_shapes: Should either be:
1) a dict from argument names of ``f`` to their dynamic shape specifications,
2) a tuple that specifies dynamic shape specifications for each input in original order.
If you are specifying dynamism on keyword args, you will need to pass them in the order that
is defined in the original function signature.
The dynamic shape of a tensor argument can be specified as either
(1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
not required to include static dimension indices in this dict, but when they are,
they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
Expand Down

0 comments on commit 8744293

Please sign in to comment.