
Using torch.compile with einops

Alex Rogozhnikov edited this page Apr 24, 2023 · 2 revisions

PyTorch 2.0 introduces torch.compile, which 'compiles' python code into graphs.

  • if you use einops layers (Rearrange, Reduce, EinMix) - no action needed, they work perfectly with torch.compile, torch.jit.script, torch.jit.trace
  • if you use einops functions (rearrange, reduce, repeat, einsum, pack, unpack), you need to allow einops ops in the graph:
from einops._torch_specific import allow_ops_in_compiled_graph  # requires einops>=0.6.1
allow_ops_in_compiled_graph()

Explanation

If you use torch.compile without calling allow_ops_in_compiled_graph first, torch.compile will break the graph at every einops function call. This causes a significant slowdown.

In experiments with transformers (see https://github.com/arogozhnikov/einops/issues/250#issuecomment-1508138804 for details), torch.compile optimizes plain pytorch, einops layers, and einops functions equally well. But it has to be informed that einops functions are 'well-behaved' and can be included in the graph. That's why you need to call allow_ops_in_compiled_graph first.
