Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] Avoid adding unnecessary slicing #7479

Merged
merged 3 commits into from
Feb 26, 2021

Conversation

masahi
Copy link
Member

@masahi masahi commented Feb 19, 2021

PyTorch object detection models have many uses of slicing like arr[:, None] which is essentially nop, see for example https://github.com/pytorch/vision/blob/f7fae490980885e426fef01bb214025b9eddb832/torchvision/models/detection/roi_heads.py#L80

The current aten::slice converter does not detect such nop slicing and inserts unnecessary strided_slice. The worst of all, many of arr[:, None] converts to dynamic strided slice, which results in too much any_dim.

I simplified aten::slice conversion and fixed the fast path detection. I've updated MaskRCNN rewrite pattern to remove one of dyn.strided_slice, which serves as a test case.

please review @kevinthesun @anijain2305 @siju-samuel

@masahi
Copy link
Member Author

masahi commented Feb 24, 2021

ping @kevinthesun @anijain2305 @siju-samuel this should be easy and no brainer

Copy link
Contributor

@kevinthesun kevinthesun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@masahi masahi merged commit e664b2f into apache:main Feb 26, 2021
@apivovarov
Copy link
Contributor

Lokiiiiii pushed a commit to Lokiiiiii/tvm that referenced this pull request Mar 2, 2021
* simplyfing

* improved fast path for slice

* update rewrite pattern for maskrcnn
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Mar 2, 2021
* simplyfing

* improved fast path for slice

* update rewrite pattern for maskrcnn
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants