PyTorch jit.script interoperability #115
@arogozhnikov It should be a "simple" fix: either some meta-programming, or just building `axes_lengths` from a string or dict instead of using `**kwargs` (though probably a big part of the code would need rewriting, or a separate logical path to support `torch.jit`, since there are a lot of custom types and meta-magic in einops). I don't know about the performance hit; parsing should be about the same as kwargs parsing, or even faster using some speedy string-parsing methods. A very simple example; with string parsing, a dict should be even easier:

```python
y1 = reduce_torchscript(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)
# change to
y1 = reduce_torchscript(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', 'h2=2, w2=2')
# or change to
y1 = reduce_torchscript(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', 'h2:2, w2:2')
```

and

```python
from typing import List, Tuple

import torch


def simple_parse_torchscript(axes_lengths: str) -> List[Tuple[str, int]]:
    # Parse a string like "h2=2, w2=2" into a sorted list of (axis, length) pairs.
    data = axes_lengths.split(",")
    container = torch.jit.annotate(List[Tuple[str, int]], [])
    for item in data:
        key, value = item.split("=")
        container.append((key.strip(), int(value)))
    return sorted(container)


# new (Reduction and _prepare_transformation_recipe_torchscript refer to
# einops internals that would need torchscript-compatible counterparts)
def reduce_torchscript(tensor, pattern: str, reduction: Reduction, axes_lengths: str = ""):
    # Torch-hashable, or something like this; a few things in
    # _prepare_transformation_recipe would still need to change.
    hashable_axes_lengths = simple_parse_torchscript(axes_lengths)
    recipe = _prepare_transformation_recipe_torchscript(pattern, reduction, axes_lengths=hashable_axes_lengths)
    return recipe.apply(tensor)
```

Then PyTorch jit should work fine, but there are a lot of places where types would need changing or annotations added for torchscript. Even simpler would be just adding a few extra methods that create the transformation recipe without any kwargs and with torchscript types. None of this should be hard, but it would take a lot of time, so I'm not sure it's worth changing right now (tracing works fine with einops as far as I can tell, so you can just use that).
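Since the comment above ends by pointing at tracing as the working path today, here is a minimal sketch of tracing a module whose `forward` uses einops (the `PatchMerge` module and the shapes are illustrative assumptions, not from this thread):

```python
import torch
from einops import rearrange


class PatchMerge(torch.nn.Module):
    # Hypothetical module: folds 2x2 spatial patches into the channel axis.
    def forward(self, x):
        return rearrange(x, 'b c (h h2) (w w2) -> b (c h2 w2) h w', h2=2, w2=2)


module = PatchMerge()
example = torch.randn(1, 3, 8, 8)
# Tracing runs the module eagerly once and records only the tensor ops,
# so einops' Python-level machinery never has to be compiled.
traced = torch.jit.trace(module, example)
print(traced(example).shape)  # torch.Size([1, 12, 4, 4])
```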
I think @machineko summarized the situation pretty well. Torchscript evolves slowly; I went through the exercise of "what would it take to support torchscript" a week ago, and the answer is: actually too much. The requirements almost demand introducing a separate torch-only codebase. As of now, I recommend using tracing, and if you need scripting, first trace the einops-containing part and then script the result (this also works). I'll keep an eye on future torchscript updates.
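A sketch of the trace-then-script route recommended above (the `Pool` and `Net` modules and the shapes are illustrative assumptions):

```python
import torch
from einops import reduce


class Pool(torch.nn.Module):
    # Hypothetical einops-using submodule.
    def forward(self, x):
        return reduce(x, 'b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2)


# Trace the einops-containing part first ...
traced_pool = torch.jit.trace(Pool(), torch.randn(1, 3, 8, 8))


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = traced_pool  # already-compiled traced submodule

    def forward(self, x):
        return self.pool(x) + 1.0


# ... then script the enclosing module; the traced submodule is embedded as-is.
scripted = torch.jit.script(Net())
print(scripted(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 3, 4, 4])
```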
Closing, as this was implemented and is available on PyPI for torch layers. As for functions, support is still pending torchscript improvements.
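For the layers case mentioned in the closing comment, something like the following should script, assuming a recent einops release (the surrounding module is an illustrative assumption):

```python
import torch
from einops.layers.torch import Rearrange


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # einops torch layers are nn.Modules, so they can take part in scripting.
        self.merge = Rearrange('b c (h h2) (w w2) -> b (c h2 w2) h w', h2=2, w2=2)

    def forward(self, x):
        return self.merge(x)


scripted = torch.jit.script(Net())
print(scripted(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 12, 4, 4])
```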
I'm not sure if this is the expected behavior, but not being able to script modules that use einops would be very unfortunate.
Describe the bug
When you try to `script` a class that uses an `einops` operation in its forward path, `torch.jit.script` throws an error.
Reproduction steps
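A minimal reproduction sketch consistent with the description (the `Flatten` module and pattern are illustrative assumptions, not from the original report):

```python
import torch
from einops import rearrange


class Flatten(torch.nn.Module):
    # Hypothetical module: any einops call in forward triggers the failure.
    def forward(self, x):
        return rearrange(x, 'b c h w -> b (c h w)')


# On einops 0.3.0 with pytorch 1.8.1, this raises a scripting error, since
# torch.jit.script cannot compile einops' Python-level internals.
scripted = torch.jit.script(Flatten())
```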
Expected behavior
The scripting should work.
Your platform
- python @ 3.8.5
- einops @ 0.3.0
- pytorch @ 1.8.1