
PyTorch jit.script interoperability #115

Closed · boris-kuz opened this issue May 11, 2021 · 3 comments
Labels
enhancement New feature or request

Comments

boris-kuz commented May 11, 2021

I'm not sure if this is the expected behavior, but not being able to script modules that use einops would be very unfortunate.

Describe the bug
When you try to script a module that uses an einops operation in its forward pass, torch.jit.script throws an error:

NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "~/.local/lib/python3.8/site-packages/einops/einops.py", line 379
def rearrange(tensor, pattern, **axes_lengths):
                                ~~~~~~~~~~~~~ <--- HERE

Reproduction steps

In [1]: from torch import nn
In [2]: import torch
In [3]: from einops import rearrange

In [4]: class Foo(nn.Module):
   ...:     def __init__(self):
   ...:         super(Foo, self).__init__()
   ...:     def forward(self, x):
   ...:         return rearrange(x, 'a b -> b a')
   ...:
In [5]: torch.jit.script(Foo())

Expected behavior
The scripting should work.

Your platform
python @ 3.8.5
einops @ 0.3.0
pytorch @ 1.8.1

boris-kuz added the bug label May 11, 2021
machineko commented May 21, 2021

@arogozhnikov It should be a "simple" fix: either some meta-programming, or just building axes_lengths from a string or dict instead of **kwargs. (Probably a big part of the code would need a rewrite, or a separate logical path, to support torch.jit, as there is a lot of custom typing and meta-magic in einops.)

Not sure about the performance hit; parsing the string should cost about the same as kwargs parsing, or even less with some speedy string-parsing methods.

A very simple example with string parsing (a dict would be even easier):

y1 = reduce_torchscript(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)
# change to
y1 = reduce_torchscript(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', 'h2=2, w2=2')
# or change to
y1 = reduce_torchscript(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', 'h2:2, w2:2')

and

from typing import List, Tuple
import torch

def simple_parse_torchscript(axes_lengths: str) -> List[Tuple[str, int]]:
    # "h2=2, w2=2" -> [("h2", 2), ("w2", 2)]
    container = torch.jit.annotate(List[Tuple[str, int]], [])
    for item in axes_lengths.split(","):
        key, value = item.split("=")
        container.append((key.strip(), int(value)))
    return sorted(container)
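As a quick illustration (this check is mine, not part of the proposal above), the parser maps the string form back to sorted pairs:

print(simple_parse_torchscript("h2=2, w2=2"))
# [('h2', 2), ('w2', 2)]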

# new
def reduce_torchscript(tensor, pattern: str, reduction: Reduction, axes_lengths: str = ""):
    # the parsed (name, length) pairs are hashable, so they can serve as a cache
    # key; _prepare_transformation_recipe still needs a few type changes for this
    hashable_axes_lengths = simple_parse_torchscript(axes_lengths)
    recipe = _prepare_transformation_recipe_torchscript(pattern, reduction, axes_lengths=hashable_axes_lengths)
    return recipe.apply(tensor)

PyTorch JIT should then work fine, but there are a lot of places that would need type changes or extra annotations for TorchScript.

Even simpler would be to just add a few extra methods that build the transformation recipe without any kwargs and with TorchScript-compatible types.

None of this should be hard, but it will take a lot of time, so I'm not sure it's worth changing right now. (Tracing works fine with einops as far as I can tell, so you can just use jit.trace(method, input), which returns a torch.jit.ScriptFunction.)
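A minimal sketch of that tracing route (reusing the Foo module from the repro above; the example input shape is an assumption):

import torch
from torch import nn
from einops import rearrange

class Foo(nn.Module):
    def forward(self, x):
        return rearrange(x, 'a b -> b a')

# tracing records the concrete tensor ops that rearrange performs,
# so the **kwargs in the einops signature never reach the compiler
x = torch.randn(2, 3)
traced = torch.jit.trace(Foo(), x)
assert torch.equal(traced(x), x.t())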

arogozhnikov (Owner) commented May 27, 2021

I think @machineko summarized the situation pretty well.

TorchScript evolves slowly. A week ago I went through the exercise of "what would it take to support TorchScript", and the answer is: actually too much:

  • **kwargs are not supported (I see no good reason for this, it's only a minor exception in type annotations), and there is no way the current interface can work until TorchScript implements it. So this forces a separate entry point, and the same functions would not work for torch and other frameworks at the same time
  • the code inside also requires complete reannotation (narrowing types down to torch-recognizable ones would produce incorrect annotations for other frameworks) and moving away from some current templates (the latter is fine)
  • automated backend resolution does not seem to be supportable within TorchScript at all
  • callback support would have to be dropped

All these points almost demand a separate torch-only codebase.

For now, I recommend tracing; if you need scripting, first trace the einops-containing part and then script the rest (this also works). I'll keep an eye on future TorchScript updates.
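A sketch of that trace-then-script combination (the module, shapes, and layer sizes here are illustrative assumptions; note that tracing bakes in the shapes it sees, so trace with representative inputs):

import torch
from torch import nn
from einops import rearrange

class Rearranger(nn.Module):
    def forward(self, x):
        return rearrange(x, 'b c h w -> b (c h w)')

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # trace the einops-containing part first ...
        self.flatten = torch.jit.trace(Rearranger(), torch.randn(2, 3, 4, 4))
        self.head = nn.Linear(48, 10)

    def forward(self, x):
        return self.head(self.flatten(x))

# ... then script the whole model; the traced submodule is already TorchScript
scripted = torch.jit.script(Model())
print(scripted(torch.randn(2, 3, 4, 4)).shape)  # torch.Size([2, 10])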

arogozhnikov (Owner) commented
Closing, as this was implemented and is available on PyPI for torch layers. As for functions, support is still pending TorchScript improvements.
