
Commit

Merge pull request #275 from arogozhnikov/torch-dynamic-compile
cover dynamic shapes in torch.compile, introduce fallback if shape was not cacheable
arogozhnikov committed Aug 10, 2023
2 parents f656128 + 8ac8d4f commit 6fc4e09
Showing 3 changed files with 37 additions and 4 deletions.
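
In brief: torch.compile(..., dynamic=True) traces tensor sizes symbolically rather than specializing on concrete values, so einops can no longer key its per-shape recipe cache on them; this commit adds a fallback to the uncached path. A minimal sketch of the usage this enables (the collapse function is hypothetical, not from this repo; assumes torch >= 2.0 and einops at this commit):

import torch
from einops import rearrange

def collapse(x):
    # hypothetical example function, not part of the einops test suite
    return rearrange(x, "b c h w -> b (c h w)")

# dynamic=True traces sizes symbolically instead of specializing on them
compiled = torch.compile(collapse, dynamic=True)

print(compiled(torch.randn(2, 3, 4, 5)).shape)  # torch.Size([2, 60])
print(compiled(torch.randn(7, 2, 6, 6)).shape)  # torch.Size([7, 72]); same graph should be reused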
2 changes: 2 additions & 0 deletions einops/_backends.py
@@ -392,6 +392,8 @@ def __len__(self):
     def __getitem__(self, item):
         return self.elements[item]
 
+    # default equality and hash is used (True only with itself, hash taken of id)
+
 
 class TensorflowBackend(AbstractBackend):
     framework_name = "tensorflow"
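
As an aside, the new comment describes plain Python semantics: a class that defines neither __eq__ nor __hash__ compares equal only to itself and hashes by object identity. A tiny illustration with a hypothetical Box class standing in for the class this comment documents:

class Box:
    # defines neither __eq__ nor __hash__, so Python defaults apply
    def __init__(self, elements):
        self.elements = elements

a = Box((1, 2))
b = Box((1, 2))
assert a == a               # default __eq__ is identity: equal only to itself
assert a != b               # same contents, but different objects
assert hash(a) == hash(a)   # default __hash__ is stable, derived from id()
assert hash(a) != hash(b)   # two live objects hash differently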
11 changes: 8 additions & 3 deletions einops/einops.py
@@ -230,9 +230,14 @@ def _apply_recipe(
     backend, recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths
 ) -> Tensor:
     # this method implements actual work for all backends for 3 operations
-    init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = _reconstruct_from_shape(
-        recipe, backend.shape(tensor), axes_lengths
-    )
+    try:
+        init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = _reconstruct_from_shape(
+            recipe, backend.shape(tensor), axes_lengths
+        )
+    except TypeError:
+        # shape or one of the passed axes lengths is not hashable (i.e. they are symbols)
+        _result = _reconstruct_from_shape_uncached(recipe, backend.shape(tensor), axes_lengths)
+        (init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added) = _result
     if init_shapes is not None:
         tensor = backend.reshape(tensor, init_shapes)
     if axes_reordering is not None:
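
The fallback works because functools.lru_cache raises TypeError when any argument is unhashable, so the cached _reconstruct_from_shape (in einops, the lru_cache-wrapped counterpart of _reconstruct_from_shape_uncached) fails fast on symbolic shapes. A minimal self-contained sketch of the same pattern, with analyze standing in for the real recipe reconstruction:

import functools

def analyze_uncached(shape):
    # stand-in for _reconstruct_from_shape_uncached: a pure function of the shape
    return tuple(dim * 2 for dim in shape)

# stand-in for _reconstruct_from_shape: the same function behind an LRU cache
analyze = functools.lru_cache(1024)(analyze_uncached)

def apply_recipe(shape):
    try:
        return analyze(shape)        # fast path: one cache entry per concrete shape
    except TypeError:
        # lru_cache raises TypeError if any argument is unhashable,
        # e.g. a shape holding symbolic sizes produced by tracing
        return analyze_uncached(shape)

print(apply_recipe((2, 3)))   # hashable tuple: cached -> (4, 6)
print(apply_recipe([2, 3]))   # unhashable list: falls back -> (4, 6)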
28 changes: 27 additions & 1 deletion tests/test_ops.py
@@ -5,7 +5,7 @@
 
 from einops import EinopsError
 from einops.einops import rearrange, reduce, repeat, _enumerate_directions, _reductions
-from . import collect_test_backends
+from . import collect_test_backends, is_backend_tested
 
 imp_op_backends = collect_test_backends(symbolic=False, layers=False)
 sym_op_backends = collect_test_backends(symbolic=True, layers=False)
@@ -579,3 +579,29 @@ def test_list_inputs():
         repeat(list(x), "... -> b (...)", b=3),
         repeat(x, "... -> b (...)", b=3),
     )
+
+
+def test_torch_compile_with_dynamic_shape():
+    if not is_backend_tested("torch"):
+        pytest.skip()
+    import torch
+    # somewhat reasonable debug messages
+    torch._dynamo.config.verbose = True
+
+    def func1(x):
+        # test contains ellipsis
+        a, b, c, *other = x.shape
+        x = rearrange(x, "(a a2) b c ... -> b (c a2) (a ...)", a2=2)
+        # test contains passing an expression as an axis length
+        x = reduce(x, "b ca2 A -> b A", "sum", ca2=c * 2)
+        return x
+
+    # seems we can't test static and dynamic compilation in the same test run.
+    # func1_compiled_static = torch.compile(func1, dynamic=False, fullgraph=True, backend="aot_eager")
+    func1_compiled_dynamic = torch.compile(func1, dynamic=True, fullgraph=True, backend="aot_eager")
+
+    x = torch.randn(size=[4, 5, 6, 3])
+    assert torch.equal(func1_compiled_dynamic(x), func1(x))
+    # check with an input of different dimensionality, and with all shape elements changed
+    x = torch.randn(size=[6, 3, 4, 2, 3])
+    assert torch.equal(func1_compiled_dynamic(x), func1(x))
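Note that the second check feeds a 5-D tensor after compiling on a 4-D one: changing the rank forces a re-trace, while dynamic=True keeps the individual sizes symbolic, so both the ellipsis handling and the new uncached fallback in _apply_recipe are exercised.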