[torch.Tensor.__getitem__] Add advanced indexing (GatherND) support to torch.Tensor.__getitem__ #783
base: master
Conversation
WIP but somewhat working advanced indexing.
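For context, here is a minimal sketch (my own, not code from this PR) of the kind of advanced indexing the converter has to reproduce, alongside the explicit gather it corresponds to. The shapes mirror the test further down; note `torch.gather` requires int64 indices, unlike the int32 indices TensorRT expects:

```python
import torch

tensor = torch.rand(2, 4, 3, 4)
indices = torch.tensor((2, 0, 1), dtype=torch.int64)

# Advanced indexing: select slices 2, 0, 1 along dim 1.
out = tensor[:, indices]  # shape (2, 3, 3, 4)

# Equivalent explicit gather along dim 1: broadcast the index
# tensor to the output shape, then gather element-wise.
expanded = indices.view(1, -1, 1, 1).expand(2, 3, 3, 4)
assert torch.equal(out, tensor.gather(1, expanded))
```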
@jaybdub This is ready for review, but I may need to rebase... I started this before 0.4.0 was released, and certainly before the major API changes to master.
Note that this is dependent on #770.
Also note that I've left in a couple of TODOs... I'll leave those for future PRs.
From my local testing, it looks like this implementation should now correctly support dynamic shapes as well.
@jaybdub This PR is ready for review.
```python
indices = torch.tensor((2, 0, 1), dtype=torch.int32).cuda()

module_trt = torch2trt(module, [tensor, indices], min_shapes=[(1, 1, 1, 1), (1, 1, 1)], max_shapes=[(7, 7, 7, 7), (5, 5, 5)], log_level=trt.Logger.INFO)

assert torch.allclose(module_trt(tensor, indices), module(tensor, indices), atol=1e-4, rtol=1e-4)

tensor = torch.rand(2, 4, 3, 4).cuda()
indices = torch.tensor((2, 0, 1), dtype=torch.int32).cuda()

assert torch.allclose(module_trt(tensor, indices), module(tensor, indices), atol=1e-4, rtol=1e-4)

tensor = torch.rand(4, 6, 5, 6).cuda()
indices = torch.tensor((2, 0, 1), dtype=torch.int32).cuda()

assert torch.allclose(module_trt(tensor, indices), module(tensor, indices), atol=1e-4, rtol=1e-4)


if __name__ == '__main__':
    torch.manual_seed(0)
    # test_getitem_dynamic()
    test_getitem_dynamic_gathernd()
```
@jaybdub Added a unit test for dynamic gatherND. Let me know if this is the right place for this test.
```python
tensor = torch.rand(2, 4, 3, 4).cuda()
indices = torch.tensor((2, 0, 1), dtype=torch.int32).cuda()
```
@jaybdub I think `indices` here could also be made to support dynamic values, but the current iteration of this isn't working... it seems like the input flattener is removing this tensor from `inputs` before the `add_inputs` call for some reason (so we only see one input tensor instead of two). I'll have to dig into this more later.
Updated the algorithm to remove unnecessary gather operations: only one gather operation is now needed, whereas multiple were used before. Besides being more efficient, this also sidesteps a problem we noticed where intermediate operations were failing due to the large intermediate shapes they produced.
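To illustrate the idea, here is a hedged sketch; `gather_nd` below is a hypothetical helper mimicking GatherND semantics, not the PR's converter code. Stacking the per-dimension index tensors once lets a single GatherND-style lookup replace a chain of gathers:

```python
import torch

def gather_nd(data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # indices has shape (..., k) with k <= data.dim(); the last axis
    # indexes the first k dimensions of `data` (GatherND semantics).
    k = indices.shape[-1]
    flat = indices.reshape(-1, k)
    rows = [data[tuple(row.tolist())] for row in flat]
    out = torch.stack(rows)
    return out.reshape(*indices.shape[:-1], *data.shape[k:])

data = torch.rand(4, 5, 6)
i = torch.tensor([0, 2, 3])
j = torch.tensor([1, 1, 4])

# Multi-gather formulation: one lookup per indexed dimension.
multi = data[i][torch.arange(3), j]

# Single-gather formulation: stack the index tensors, gather once.
single = gather_nd(data, torch.stack([i, j], dim=-1))
assert torch.equal(multi, single)
```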
The following documentation also appears at the top of the file:

> For posterity, here are some more examples of transposing different combinations of gatherND operations; the examples here are what's effectively implemented by the algorithm:
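The examples themselves are not reproduced here. As a stand-in, the following is my own minimal sketch, assuming standard PyTorch semantics, of the transpose-gather-transpose pattern described: when the indexed dimension is not the leading one, permute it to the front, gather along the leading axis, then permute back.

```python
import torch

t = torch.rand(2, 4, 3)
idx = torch.tensor([2, 0, 1])

# Direct advanced indexing on dim 1.
direct = t[:, idx]  # shape (2, 3, 3)

# Move dim 1 to the front, index along the leading axis (what a
# GatherND on axis 0 would do), then move it back into place.
via_gather = t.permute(1, 0, 2)[idx].permute(1, 0, 2)
assert torch.equal(direct, via_gather)
```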