
[torch.Tensor.__getitem__] Add advanced indexing (GatherND) support to torch.Tensor.__getitem__ #783

Open: chaoz-dev wants to merge 2 commits into master from chaoz/getitem-advanced-indexing
Conversation

@chaoz-dev (Contributor) commented Aug 4, 2022

The following documentation also appears at the top of the file:

  Our conversion of __getitem__ needs to handle both basic and advanced indexing (specifically GatherND).
  See the NumPy indexing documentation, which PyTorch follows, for more information on the different types of indexing:
  https://numpy.org/doc/stable/user/basics.indexing.html

  We use the following terms to describe our algorithm:
    t, a PyTorch tensor of arbitrary shape and dimensionality on which we are calling __getitem__.
    s, a slice index; e.g. the operators :, ..., None, ().
    g, a gather index; e.g. (x,...), [x,...], torch.tensor((x,...)) for any arbitrary scalar x.
        Note that we currently only handle 1D gather indices, so g is always 1D where described.

  Our algorithm works as follows:
    For an input tensor t, we check the indices argument.
    This results in the following cases:

    1. If all of the indices are slices, e.g. t[s,s,s,...], this is considered basic indexing,
    and we can trivially convert it to TRT using the slice layer (along with some supporting layers).

    2. If there are any gather indices, regardless of the presence of slice indices,
    e.g. t[...,g,g,g,...], this is considered advanced indexing,
    and we are no longer just slicing but also gathering on the input tensor.
    We convert differently depending on the composition of the indices.

    2a. If all of the indices are gather indices and there are no slice indices, e.g. t[g,g,g,...],
    then we can trivially convert this to TRT using a single gather layer.

    2b. If we have a mix of slice and gather indices, e.g. t[s,s,g,g,...], then the TRT conversion gets more complex.
    First, we split the indices into slice-only indices and gather-only indices of the same dimensions,
    substituting the colon operator for each axis where a gather or slice index was removed from the slice-only
    or gather-only indices, respectively. This allows us to process the slice and gather indices separately,
    since the colon operator lets us ignore an axis when not processing that particular type of index.

    Consequently, we can now process t as if the indices contained only slice operations, e.g. t[s,s,:,:,...],
    using the same basic indexing methodology described in case (1), i.e. with a slice layer.
    Afterwards, all slicing operations are complete and only gather operations remain.

    Now, using the output of the slice layer, we process all of the gather indices, e.g. t[:,:,g,g,...].
    Since the TRT gather layer does not handle slice indices (i.e. colon operators),
    we cannot pass all of the indices to the gather layer as in case (2a).
    This is especially problematic when a colon operator sits between two gather operations, e.g. t[g,:,g].

    As a result, to account for the axes carrying a colon operator,
    we repeatedly transpose (permute) t until all of the axes we are gathering on are adjacent;
    in other words, t[g,:,g] == transposed(t)[g,g,:] is a valid equivalency
    (we call this coalescing the gather indices for brevity; see the sketch after this documentation block).
    This moves any dimension with a colon operator out from between dimensions with gather operations,
    and allows us to use the TRT gather layer to perform the needed gatherND operation,
    as now only gather indices are present in the indexing operation.

    The following examples, using a 4D tensor of shape (3,3,3,3), show the equivalent transpose operations needed
    so that all gather indices can be coalesced when indexing:

    t[:,g,:,:] == t.transpose(1,0)[g].transpose(0,1)
    t[:,:,g,:] == t.transpose(2,1).transpose(1,0)[g].transpose(0,1).transpose(1,2)
    t[:,:,:,g] == t.transpose(3,2).transpose(2,1).transpose(1,0)[g].transpose(0,1).transpose(1,2).transpose(2,3)
    t[g,:,g,:] == t.transpose(2,1)[g,g]
    t[g,:,:,g] == t.transpose(3,2).transpose(2,1)[g,g]
    t[:,g,g,:] == t.transpose(1,0).transpose(2,1)[g,g].transpose(0,1)
    t[:,g,:,g] == t.transpose(1,0).transpose(3,2).transpose(2,1)[g,g]
    t[:,:,g,g] == t.transpose(2,1).transpose(1,0).transpose(3,2).transpose(2,1)[g,g].transpose(0,1).transpose(1,2)
    t[g,g,:,g] == t.transpose(3,2)[g,g,g]
    t[g,:,g,g] == t.transpose(2,1).transpose(3,2)[g,g,g]
    t[:,g,g,g] == t.transpose(1,0).transpose(2,1).transpose(3,2)[g,g,g].transpose(0,1)

    Note the following from the above examples:
    - The first gather operation always transposes its axis to dimension 0, if it is not already there.
    - Final transposes are needed after the gather operation if and only if the gather indices were already coalesced together.
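
The cases above can be sanity-checked in plain PyTorch. The snippet below only illustrates the indexing identities (assuming a small 3D tensor and a 1D gather index); it is not the converter code from this PR:

import torch

t = torch.arange(3 * 3 * 3).reshape(3, 3, 3)
g = torch.tensor([2, 0, 1])

# Case (2a): all indices are gather indices; a single gather covers it.
assert torch.equal(t[g, g, g], torch.stack([t[2, 2, 2], t[0, 0, 0], t[1, 1, 1]]))

# Case (2b): a colon between two gather indices; transposing the colon axis
# out from between the gather axes preserves the result, so a single gather
# over adjacent (coalesced) axes suffices.
lhs = t[g, :, g]               # gather indices separated by a colon operator
rhs = t.transpose(2, 1)[g, g]  # gather indices coalesced onto adjacent axes
assert torch.equal(lhs, rhs)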

For posterity, here are some more examples of transposing different combinations of gatherND operations; these are what the algorithm effectively implements (a quick numerical check follows the list):

t2[:,x] == t2.transpose(1,0)[x].transpose(0,1)
t3[:,x,:] == t3.transpose(1,0)[x].transpose(0,1)
t3[:,:,x] == t3.transpose(2,1).transpose(1,0)[x].transpose(0,1).transpose(1,2)
t3[x,:,x] == t3.transpose(2,1)[x,x]
t3[:,x,x] == t3.transpose(1,0).transpose(2,1)[x,x].transpose(0,1)
t4[:,x,:,:] == t4.transpose(1,0)[x].transpose(0,1)
t4[:,:,x,:] == t4.transpose(2,1).transpose(1,0)[x].transpose(0,1).transpose(1,2)
t4[:,:,:,x] == t4.transpose(3,2).transpose(2,1).transpose(1,0)[x].transpose(0,1).transpose(1,2).transpose(2,3)
t4[x,:,x,:] == t4.transpose(1,2)[x,x]
t4[x,:,:,x] == t4.transpose(3,2).transpose(2,1)[x,x]
t4[:,x,x,:] == t4.transpose(1,0).transpose(2,1)[x,x].transpose(0,1)
t4[:,x,:,x] == t4.transpose(1,0).transpose(3,2).transpose(2,1)[x,x]
t4[:,:,x,x] == t4.transpose(2,1).transpose(1,0).transpose(3,2).transpose(2,1)[x,x].transpose(0,1).transpose(1,2)
t4[x,x,:,x] == t4.transpose(3,2)[x,x,x]
t4[x,:,x,x] == t4.transpose(2,1).transpose(3,2)[x,x,x]
t4[:,x,x,x] == t4.transpose(1,0).transpose(2,1).transpose(3,2)[x,x,x].transpose(0,1)
t5[x,:,x,:,:] == t5.transpose(2,1)[x,x]
t5[x,:,:,x,:] == t5.transpose(3,2).transpose(2,1)[x,x]
t5[x,:,:,:,x] == t5.transpose(4,3).transpose(3,2).transpose(2,1)[x,x]
t5[:,x,:,x,:] == t5.transpose(1,0).transpose(3,2).transpose(2,1)[x,x]
t5[:,:,x,:,x] == t5.transpose(2,1).transpose(1,0).transpose(4,3).transpose(3,2).transpose(2,1)[x,x]
t5[:,:,:,x,x] == t5.transpose(3,2).transpose(2,1).transpose(1,0).transpose(4,3).transpose(3,2).transpose(2,1)[x,x].transpose(0,1).transpose(1,2).transpose(2,3)
t5[x,x,:,x,:] == t5.transpose(3,2)[x,x,x]
t5[x,x,:,:,x] == t5.transpose(4,3).transpose(3,2)[x,x,x]
t5[x,:,x,x,:] == t5.transpose(2,1).transpose(3,2)[x,x,x]
t5[x,:,x,:,x] == t5.transpose(2,1).transpose(4,3).transpose(3,2)[x,x,x]
t5[x,:,:,x,x] == t5.transpose(3,2).transpose(2,1).transpose(4,3).transpose(3,2)[x,x,x]
t5[x,:,x,x,x] == t5.transpose(2,1).transpose(3,2).transpose(4,3)[x,x,x,x]
t5[x,0,x,x,:] == t5[:,0][x,x,x]
t5[x,0,:,x] == t5[:,0].transpose(2,1)[x,x]
t5[x,None,x] == t5[:,None].transpose(2,1)[x,x]
t5[x,None,:,x] == t5[:,None].transpose(3,2).transpose(2,1)[x,x]
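
As a quick numerical check (illustrative only, not part of the PR), a couple of the identities above can be verified directly in PyTorch with a small integer tensor:

import torch

t4 = torch.arange(3 ** 4).reshape(3, 3, 3, 3)
x = torch.tensor([2, 0, 1])

# t4[:,x,:,:] == t4.transpose(1,0)[x].transpose(0,1)
assert torch.equal(t4[:, x, :, :], t4.transpose(1, 0)[x].transpose(0, 1))

# t4[x,:,x,:] == t4.transpose(1,2)[x,x]
assert torch.equal(t4[x, :, x, :], t4.transpose(1, 2)[x, x])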

@chaoz-dev (Contributor Author)

WIP but somewhat working advanced indexing.
Needs to be based off of #770

chaoz-dev changed the title from [WIP] to [WIP][torch.Tensor.__getitem__] Add advanced indexing (GatherND) support to torch.Tensor.__getitem__ on Aug 4, 2022
chaoz-dev marked this pull request as ready for review on August 8, 2022 21:39
@chaoz-dev (Contributor Author)

@jaybdub This is ready for review, but I may need to rebase... I started this before 0.4.0 was released, and certainly before the major API changes to master.

@chaoz-dev (Contributor Author)

Note that this is dependent on #770

chaoz-dev changed the title from [WIP][torch.Tensor.__getitem__] Add advanced indexing (GatherND) support to torch.Tensor.__getitem__ to [torch.Tensor.__getitem__] Add advanced indexing (GatherND) support to torch.Tensor.__getitem__ on Aug 8, 2022
@chaoz-dev (Contributor Author)

Also note that I've left in a couple of TODOs... I'll leave those for future PRs.

chaoz-dev force-pushed the chaoz/getitem-advanced-indexing branch 3 times, most recently from 88a1c2c to 2e2399b on August 9, 2022 02:30
@chaoz-dev (Contributor Author) commented Aug 9, 2022

NUM_SUCCESSFUL_CONVERSION: 88
NUM_FAILED_CONVERSION: 0
NUM_ABOVE_TOLERANCE: 0
NUM_pSNR_TOLERANCE: 0

chaoz-dev force-pushed the chaoz/getitem-advanced-indexing branch from 2e2399b to 9c6cba6 on November 22, 2022 00:34
@chaoz-dev (Contributor Author)

From my local testing, it looks like this implementation now correctly supports dynamic shapes as well.

@chaoz-dev (Contributor Author)

@jaybdub This PR is ready for review

Comment on lines +172 to +193
indices = torch.tensor((2, 0, 1), dtype=torch.int32).cuda()

module_trt = torch2trt(module, [tensor, indices], min_shapes=[(1, 1, 1, 1), (1, 1, 1)], max_shapes=[(7, 7, 7, 7), (5, 5, 5)], log_level=trt.Logger.INFO)

assert torch.allclose(module_trt(tensor, indices), module(tensor, indices), atol=1e-4, rtol=1e-4)

tensor = torch.rand(2, 4, 3, 4).cuda()
indices = torch.tensor((2, 0, 1), dtype=torch.int32).cuda()

assert torch.allclose(module_trt(tensor, indices), module(tensor, indices), atol=1e-4, rtol=1e-4)

tensor = torch.rand(4, 6, 5, 6).cuda()
indices = torch.tensor((2, 0, 1), dtype=torch.int32).cuda()

assert torch.allclose(module_trt(tensor, indices), module(tensor, indices), atol=1e-4, rtol=1e-4)


if __name__ == '__main__':
    torch.manual_seed(0)
    # test_getitem_dynamic()
    test_getitem_dynamic_gathernd()
@chaoz-dev (Contributor Author)

@jaybdub Added a unit test for dynamic gatherND. Let me know if this is the right place for this test.


tensor = torch.rand(2, 4, 3, 4).cuda()
indices = torch.tensor((2, 0, 1), dtype=torch.int32).cuda()
@chaoz-dev (Contributor Author)

@jaybdub I think indices here could also be made to support dynamic values, but the current iteration of this isn't working... seems like the input flattener is removing this tensor from inputs before the add_inputs call for some reason (so we only see one input tensor instead of two). I'll have to dig into this more later.

chaoz-dev force-pushed the chaoz/getitem-advanced-indexing branch 2 times, most recently from 4714a08 to 7979f1f on December 23, 2022 08:19
chaoz-dev force-pushed the chaoz/getitem-advanced-indexing branch from 7979f1f to a9d009c on January 31, 2023 04:30
@chaoz-dev (Contributor Author) commented Jan 31, 2023

Updated the algorithm to remove unnecessary gather operations... now only one gather operation is needed, whereas multiple were used before. Besides being more efficient, this also sidesteps a problem we noticed where intermediate operations were failing due to the large shapes produced.
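
For intuition about why a single gather suffices once the indices are coalesced, here is a small reference sketch (illustrative only, not the converter code; gather_nd_reference is a hypothetical helper) showing that one gather over stacked index tuples reproduces chained 1D gather indices:

import torch

def gather_nd_reference(t, idx):
    # idx has shape (N, k): row j selects t[idx[j, 0], ..., idx[j, k-1]],
    # so a single gather over the stacked tuples replaces k per-axis gathers.
    return torch.stack([t[tuple(row.tolist())] for row in idx])

t = torch.arange(3 ** 3).reshape(3, 3, 3)
g = torch.tensor([2, 0, 1])
idx = torch.stack([g, g], dim=-1)  # (3, 2) index tuples over the first two axes
assert torch.equal(gather_nd_reference(t, idx), t[g, g])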

chaoz-dev force-pushed the chaoz/getitem-advanced-indexing branch from a9d009c to f927c6a on January 31, 2023 04:32
chaoz-dev force-pushed the chaoz/getitem-advanced-indexing branch from f927c6a to 636915e on January 31, 2023 17:00