Add support for dynamic indices
chaoz-dev committed Dec 22, 2022
1 parent 88810a3 commit 4714a08
Showing 1 changed file with 92 additions and 51 deletions.
143 changes: 92 additions & 51 deletions torch2trt/converters/getitem.py
@@ -107,8 +107,12 @@ def _is_list_or_tuple(obj):
     return isinstance(obj, (list, tuple))


-def _is_tensor(obj):
-    return isinstance(obj, (torch.Tensor, tensorrt.ITensor))
+def _is_torch_tensor(obj):
+    return isinstance(obj, torch.Tensor)
+
+
+def _is_trt_tensor(obj):
+    return isinstance(obj, trt.ITensor)


 def _trt(ctx, tensor):
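Splitting the old _is_tensor check in two matters because an index changes type mid-conversion: it enters as a torch.Tensor from user code and becomes a trt.ITensor once it has been added to the network. A minimal sketch of the distinction (the ctx and idx_trt names here are hypothetical, for illustration only):

    idx = torch.tensor([0, 1])  # user-facing index: _is_torch_tensor(idx) -> True
    idx_trt = _trt(ctx, idx)    # network-facing index: _is_trt_tensor(idx_trt) -> True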
@@ -171,6 +175,22 @@ def _permute_trt(ctx, trt, start_axis, end_axis):
     return shuffle_layer.get_output(0)


+def _elementwise_trt(ctx, trt_a, trt_b, op):
+    return ctx.network.add_elementwise(trt_a, trt_b, op).get_output(0)
+
+
+def _elementwise_gt_trt(ctx, trt_a, trt_b):
+    return _elementwise_trt(ctx, trt_a, trt_b, trt.ElementWiseOperation.GREATER)
+
+
+def _select_trt(ctx, condition_trt, then_trt, else_trt):
+    return ctx.network.add_select(condition_trt, then_trt, else_trt).get_output(0)
+
+
+def _shape_trt(ctx, trt):
+    return ctx.network.add_shape(trt).get_output(0)


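The GREATER/select pair above is the usual TensorRT idiom for taking a runtime maximum of two tensors, which the dynamic broadcasting added below depends on. A minimal sketch composing the new helpers (the _max_trt name is hypothetical, not part of this commit):

    def _max_trt(ctx, trt_a, trt_b):
        # Computes where(a > b, a, b) inside the network,
        # so the result is an ITensor resolved at inference time.
        gt_trt = _elementwise_gt_trt(ctx, trt_a, trt_b)
        return _select_trt(ctx, gt_trt, trt_a, trt_b)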
 def slice_to_trt(ctx, dim_size, dim_slice):

     start = 0 if dim_slice.start is None else dim_slice.start
@@ -212,58 +232,91 @@ def _requires_advanced_indexing(indices):
     # This is only necessary when the indices argument contains either list, tensor, or tuple index elements;
     # if there are tensor elements, then at least one must either have greater than 0 dimensions or
     # there must be other list or tuple elements present.
-    return any(((_is_list_or_tuple(index) and len(index) > 0) or (_is_tensor(index) and index.dim() > 0))
+    return any(((_is_list_or_tuple(index) and len(index) > 0) or (_is_torch_tensor(index) and index.dim() > 0))
                for index in indices)


 def _indices_to_trt(ctx, indices):
     # Convert any list, tuple, tensor indices to TRT tensors.
+    # Convert single int inputs to tuple first for uniform processing.
+    # Note that this still leaves slices and Nones in the list of returned indices,
+    # as these cannot be converted to TRT tensors.
     def _to_trt(index):
         if _is_int(index):
             index = (index,)

         if _is_list_or_tuple(index):
             index = _tensor_trt(ctx, index, dtype=torch.int32)._trt
-        elif _is_tensor(index):
-            index = _trt(ctx, index)
+        elif _is_torch_tensor(index):
+            index._trt = _trt(ctx, index)
+            index = index._trt

         return index

     return [_to_trt(index) for index in indices]

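For illustration (assumed behavior based on _to_trt above): an expression like x[0, :, torch.tensor([1, 2])] arrives as indices = (0, slice(None), tensor([1, 2])), and _indices_to_trt returns roughly:

    # [ITensor([0]), slice(None, None, None), ITensor([1, 2])]
    # ints, lists, tuples, and torch tensors become TRT tensors;
    # slices and Nones pass through untouched.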

-def _broadcast_indices(ctx, indices):
-    # All tensor indices must have the same rank for a gather operation.
-    # Since we iterate through each dimension individually to gather one axis at a time,
-    # we perform our own broadcasting and reshape all tensor indices to the same (max) rank for later use.
-
-    # The rank of a 1D tensor is always given by axis 0.
-    max_rank = max(index.shape[0] for index in indices if _is_tensor(index))
-
-    def _broadcast(index):
-        if _is_tensor(index) and index.shape[0] < max_rank:
-            # Note that the tensor shape should be 1 here regardless,
-            # as otherwise this dimension is not broadcastable when running under PyTorch.
-            # However, we'll double check, just in case.
-            assert index.shape[0] == 1
-            index = _cat_trt(ctx, [index] * max_rank)
-        return index
-
-    return [_broadcast(index) for index in indices]
+def _max_length_on_dim0(ctx, indices_trt):
+    # Get the longest length of any index TRT tensor.
+    # This is the length all indices will be broadcast to.
+    #
+    # Note that we currently only handle 1D tensor indices and we only broadcast
+    # single element TRT tensors; consequently, we only need to examine the length on axis 0.
+
+    max_length_trt = make_int_wrapper(0)._trt
+    for index_trt in indices_trt:
+        if not _is_trt_tensor(index_trt):
+            continue
+
+        assert len(index_trt.shape) == 1, f"Encountered tensor index with shape {index_trt.shape} but only indices of rank 1 are currently supported."
+
+        shape_trt = _shape_trt(ctx, index_trt)
+        length_trt = ctx.network.add_slice(shape_trt, [0], [1], [1]).get_output(0)  # Length on axis 0.
+
+        gt_trt = _elementwise_gt_trt(ctx, length_trt, max_length_trt)
+        max_length_trt = _select_trt(ctx, gt_trt, length_trt, max_length_trt)
+
+    return max_length_trt

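Read as plain Python, the loop above has the following reference semantics (a sketch only, with a hypothetical helper name; in the converter every step is a network layer, so the real result is an ITensor whose value is known only at inference time):

    def _max_length_on_dim0_reference(indices):
        # Eager-mode sketch of the runtime computation built above.
        max_length = 0
        for index in indices:
            if torch.is_tensor(index):
                max_length = max(max_length, index.shape[0])
        return max_length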

+def _broadcast_index(ctx, index_trt, broadcast_length_trt):
+    # Broadcast the index to the given length, if it's a tensor index and of length 1.
+    # Note that we currently only broadcast single element 1D TRT tensors.
+
+    if not _is_trt_tensor(index_trt):
+        return index_trt
+
+    # TODO(@chaoz): This implementation broadcasts an index even if it's unnecessary!
+    # ie. For an index already at max shape, we just end up slicing into an equivalent tensor!
+    # We should find a way to shortcut this processing; I'd like to use an if_conditional here,
+    # but because output shapes may be different, it is impossible to do so.
+    slice_layer = ctx.network.add_slice(index_trt, [0], [0], [1])
+    slice_layer.mode = trt.SliceMode.CLAMP
+    slice_layer.set_input(2, broadcast_length_trt)
+    slice_trt = slice_layer.get_output(0)
+
+    return slice_trt

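The clamped slice is the trick that makes broadcasting work without knowing the target length at build time: with trt.SliceMode.CLAMP, reads past the end of the input are clamped to the last valid element, so slicing a length-1 tensor out to length N repeats its single element N times. Equivalent semantics sketched for a concrete case:

    # index_trt = [5], broadcast length = 3
    # out[i] = index[min(i, len(index) - 1)]  ->  [5, 5, 5]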

+def _broadcast_indices(ctx, indices_trt):
+    # All tensor indices must have the same length for a gather operation.
+    # Since we iterate through each dimension individually to gather one axis at a time,
+    # we perform our own broadcasting and reshape all tensor indices to the same max length for later use.
+    max_length_trt = _max_length_on_dim0(ctx, indices_trt)
+    return [_broadcast_index(ctx, index_trt, max_length_trt) for index_trt in indices_trt]


-def _split_indices(indices):
+def _split_indices(indices_trt):
     # Split indices into those used for slicing and those used for gathering.
     # The colon operator (:) fills "blank" indices left by the removal of a slice or gather index,
     # as we select every element in that axis if it is not being gathered or sliced, respectively.
     colon = slice(None, None, None)
-    slices = [colon] * len(indices)
-    gathers = [colon] * len(indices)
+    slices = [colon] * len(indices_trt)
+    gathers = [colon] * len(indices_trt)

-    for axis, index in enumerate(indices):
-        slices_or_gathers = gathers if _is_tensor(index) else slices
-        slices_or_gathers[axis] = index
+    for axis, index_trt in enumerate(indices_trt):
+        slices_or_gathers = gathers if _is_trt_tensor(index_trt) else slices
+        slices_or_gathers[axis] = index_trt

     return slices, gathers

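For example (assumed behavior of the split above): given indices_trt = [ITensor([0, 2]), slice(1, 3), ITensor([1, 1])], the result is

    # slices  = [colon, slice(1, 3), colon]                i.e. [:, 1:3, :]
    # gathers = [ITensor([0, 2]), colon, ITensor([1, 1])]

so slicing and gathering can then be applied as two independent passes.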
@@ -285,8 +338,8 @@ def _requires_post_transpose(gather_indices):
 def _advanced_gathernd_trt(ctx, input_trt, indices_trt):
     # We only need the gather indices going forward,
     # so we parse these out along with some metadata.
-    gather_indices = [GatherIndex(trt=index, axis=axis) for axis, index in enumerate(indices_trt)
-                      if _is_tensor(index)]
+    gather_indices = [GatherIndex(trt=index_trt, axis=axis) for axis, index_trt in enumerate(indices_trt)
+                      if _is_trt_tensor(index_trt)]
     first_gather = gather_indices[0]

     # An "identity" gather index, which just selects every element in the axis.
@@ -335,10 +388,7 @@ def _advanced_gathernd_trt(ctx, input_trt, indices_trt):
     return input_trt


-def _basic_indexing(ctx, input_, slices, output):
-    # TODO(chaoz): Need to handle tensor scalar elements, eg. torch.tensor((0)).
-    # These tensor slice elements are always single scalar values; otherwise, they are considered gather operations.
-
+def _basic_indexing(ctx, input_, slices):
     input_trt = input_._trt

     # Step 1 - Replace ellipsis with expanded slices
@@ -425,22 +475,21 @@ def _basic_indexing(ctx, input_, slices, output):
     final_shape = make_size_wrapper(final_shape)

     layer = ctx.network.add_shuffle(output_trt)
-    layer.reshape_dims = tuple(output.shape)  # exclude batch
+    layer.set_input(1, final_shape._trt)
     output_trt = layer.get_output(0)

     return output_trt


-def _advanced_indexing(ctx, input_, indices, output):
+def _advanced_indexing(ctx, input_, indices):
     # Preprocess indices so that all following operations are on TRT tensors.
-    indices = _indices_to_trt(ctx, indices)
-    indices = _broadcast_indices(ctx, indices)
-    slices, gathers = _split_indices(indices)
+    indices_trt = _indices_to_trt(ctx, indices)
+    indices_trt = _broadcast_indices(ctx, indices_trt)
+    slices, gathers = _split_indices(indices_trt)

     # All indices are gather operations,
     # so we can trivially solve advanced indexing using a single gather layer.
-    if all([_is_tensor(gather) for gather in gathers]):
+    if all([_is_trt_tensor(gather) for gather in gathers]):
         gathers = [_transpose_1d_trt(ctx, gather) for gather in gathers]
         gathers = _cat_trt(ctx, gathers, axis=1)
         return _gathernd_trt(ctx, input_._trt, gathers)
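The shuffle change at the top of this hunk is the heart of dynamic-index support: reshape_dims bakes a fixed output shape into the engine at build time, while set_input(1, ...) feeds the shape in as a second layer input computed by earlier layers. Schematically (shapes invented for illustration):

    # Before: shape fixed when the engine is built.
    #     layer.reshape_dims = (2, 3, 4)
    # After: shape produced inside the network, resolved at inference time.
    #     layer.set_input(1, final_shape._trt)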
@@ -449,7 +498,7 @@ def _advanced_indexing(ctx, input_, indices, output):
     # therefore, we can solve this by first applying a slice layer for all slice operations,
     # then successively apply gather layers for each gather operation
     # and transposing out any slice operations between two gather operations.
-    output_trt = _basic_indexing(ctx, input_, slices, input_[slices])
+    output_trt = _basic_indexing(ctx, input_, slices)
     output_trt = _advanced_gathernd_trt(ctx, output_trt, gathers)

     return output_trt
@@ -487,7 +536,7 @@ def convert_tensor_getitem(ctx):
     # We use basic indexing when only slicing.
     # Advanced indexing is only necessary when we perform gather operations.
     convert_getitem = _basic_indexing if not _requires_advanced_indexing(indices) else _advanced_indexing
-    output._trt = convert_getitem(ctx, input_, indices, output)
+    output._trt = convert_getitem(ctx, input_, indices)


 class LambdaModule(torch.nn.Module):
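Concretely (assumed behavior of the _requires_advanced_indexing dispatch above):

    # x[0, 1:3]                -> _basic_indexing    (ints and slices only)
    # x[torch.tensor([0, 2])]  -> _advanced_indexing (tensor index with dim > 0)
    # x[:, (0, 1)]             -> _advanced_indexing (non-empty tuple index)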
Expand Down Expand Up @@ -646,14 +695,6 @@ def test_tensor_getitem_0d_insert_dim_ellipsis():
return LambdaModule(lambda x: x[None, ...])


# TODO(chaoz): Still need to handle this case, with tensor scalar elements for indexing.
# This is actually a basic indexing case.
#
# @add_module_test(torch.float32, torch.device('cuda'), [(3, 2, 4)], max_batch_size=3)
# def test_tensor_getitem_0d_1tuple_colon():
# return LambdaModule(lambda x: x[(0), :])


@add_module_test(torch.float32, torch.device('cuda'), [(2, 5, 4, 3)])
def test_tensor_getitem_int_tuple():
return LambdaModule(lambda x: x[0, (0, 1)])
