diff --git a/einops/_backends.py b/einops/_backends.py
index 73f75ace..ea85db04 100644
--- a/einops/_backends.py
+++ b/einops/_backends.py
@@ -392,6 +392,8 @@ def __len__(self):
     def __getitem__(self, item):
         return self.elements[item]
 
+    # default equality and hash are used (equal only to itself; hash is taken from id)
+
 
 class TensorflowBackend(AbstractBackend):
     framework_name = "tensorflow"
diff --git a/einops/einops.py b/einops/einops.py
index 6e616bb2..804833b8 100644
--- a/einops/einops.py
+++ b/einops/einops.py
@@ -230,9 +230,14 @@ def _apply_recipe(
     backend, recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths
 ) -> Tensor:
     # this method implements actual work for all backends for 3 operations
-    init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = _reconstruct_from_shape(
-        recipe, backend.shape(tensor), axes_lengths
-    )
+    try:
+        init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = _reconstruct_from_shape(
+            recipe, backend.shape(tensor), axes_lengths
+        )
+    except TypeError:
+        # shape or one of the passed axes lengths is not hashable (i.e. they are symbols)
+        _result = _reconstruct_from_shape_uncached(recipe, backend.shape(tensor), axes_lengths)
+        (init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added) = _result
     if init_shapes is not None:
         tensor = backend.reshape(tensor, init_shapes)
     if axes_reordering is not None:
diff --git a/tests/test_ops.py b/tests/test_ops.py
index 8611f95c..46ffa1f8 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -5,7 +5,7 @@
 from einops import EinopsError
 from einops.einops import rearrange, reduce, repeat, _enumerate_directions, _reductions
 
-from . import collect_test_backends
+from . import collect_test_backends, is_backend_tested
 
 imp_op_backends = collect_test_backends(symbolic=False, layers=False)
 sym_op_backends = collect_test_backends(symbolic=True, layers=False)
@@ -579,3 +579,29 @@ def test_list_inputs():
         repeat(list(x), "... -> b (...)", b=3),
         repeat(x, "... -> b (...)", b=3),
     )
+
+
+def test_torch_compile_with_dynamic_shape():
+    if not is_backend_tested("torch"):
+        pytest.skip()
+    import torch
+    # somewhat reasonable debug messages
+    torch._dynamo.config.verbose = True
+    def func1(x):
+        # test contains an ellipsis
+        a, b, c, *other = x.shape
+        x = rearrange(x, '(a a2) b c ... -> b (c a2) (a ...)', a2=2)
+        # test covers passing an expression as an axis length
+        x = reduce(x, 'b ca2 A -> b A', 'sum', ca2=c * 2)
+        return x
+
+    # it seems static and dynamic modes can't be tested in the same test run.
+    # func1_compiled_static = torch.compile(func1, dynamic=False, fullgraph=True, backend='aot_eager')
+    func1_compiled_dynamic = torch.compile(func1, dynamic=True, fullgraph=True, backend='aot_eager')
+
+    x = torch.randn(size=[4, 5, 6, 3])
+    assert torch.equal(func1_compiled_dynamic(x), func1(x))
+    # check with an input of different dimensionality, and with all shape elements changed
+    x = torch.randn(size=[6, 3, 4, 2, 3])
+    assert torch.equal(func1_compiled_dynamic(x), func1(x))
+
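
Note on the einops/_backends.py change: the class annotated above deliberately keeps
the default object semantics, so an instance compares equal only to itself and its
hash is derived from id(). A minimal illustrative sketch of those semantics (the
Wrapper class here is a stand-in, not einops' actual implementation):

    class Wrapper:
        def __init__(self, elements: tuple):
            self.elements = elements

    t1 = Wrapper((1, 2, 3))
    t2 = Wrapper((1, 2, 3))
    assert t1 == t1   # default __eq__ is identity
    assert t1 != t2   # equal contents, but distinct objects
    # hash(t1) is based on id(t1), so each instance is its own cache key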
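
Note on the einops/einops.py change: _reconstruct_from_shape is the cached variant of
_reconstruct_from_shape_uncached, and caching requires hashable arguments; symbolic
shape elements (e.g. when tracing under torch.compile with dynamic=True) can raise
TypeError on hashing, hence the fallback. A minimal sketch of the pattern, assuming
an lru_cache-style memoizer (names here are illustrative, not einops' internals):

    import functools

    def reconstruct_uncached(shape):
        # stand-in for the real recipe reconstruction
        return tuple(shape)

    reconstruct_cached = functools.lru_cache(1024)(reconstruct_uncached)

    def reconstruct(shape):
        try:
            # fast path: results for concrete (hashable) shapes are memoized
            return reconstruct_cached(shape)
        except TypeError:
            # symbolic shape elements may be unhashable; recompute without caching
            return reconstruct_uncached(shape)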