Problem
The torch2trt conversion for `torch.Tensor.__getitem__` fails when indexing or slicing with a single element, e.g.:
```python
tensor = torch.rand(2, 3)
tensor[0]
```
The issue is in `torch2trt/converters/getitem.py`, where `slices` is assumed to be iterable. That assumption does not hold in the case above: `slices` is then a single element (specifically, the int given as the index) rather than a tuple.
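To illustrate why `slices` is not always iterable, here is a small pure-Python sketch (no torch or torch2trt needed) of what Python hands to `__getitem__` for different index expressions. `IndexProbe` and `count_entries` are hypothetical names used only for illustration; `count_entries` merely mimics a helper that loops over `slices`, like the `for s in slices:` line in `num_slice_types` from the traceback.

```python
class IndexProbe:
    """Hypothetical helper: returns whatever key Python passes to __getitem__."""

    def __getitem__(self, key):
        return key


def count_entries(slices):
    """Hypothetical stand-in for a helper that loops over `slices`
    (as num_slice_types does with `for s in slices:`)."""
    count = 0
    for _ in slices:
        count += 1
    return count


probe = IndexProbe()

single = probe[0]            # like tensor[0]:    key is the bare int 0
multi = probe[0, 1]          # like tensor[0, 1]: key is the tuple (0, 1)
lone_slice = probe[0:2]      # like tensor[0:2]:  key is slice(0, 2, None)

print(count_entries(multi))  # 2: a tuple iterates fine

try:
    count_entries(single)    # a bare int cannot be iterated
except TypeError as exc:
    error_message = str(exc)
    print(error_message)     # 'int' object is not iterable
```

A lone slice arrives unwrapped as well, so any single-index expression, not just an int, reaches the converter without a surrounding tuple.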
Script
Running the following script getitem-element.py using NGC 22.06-py3:
```python
import logging

import tensorrt
import torch
import torch2trt

logging.basicConfig(level=logging.INFO)
torch.manual_seed(0)

DEVICE = 'cuda:0'
TENSOR = torch.rand(2, 3).to(DEVICE)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, tensor):
        return tensor[0]


if __name__ == "__main__":
    model = Model().eval().to(DEVICE)
    out = model(TENSOR)
    print(f'Expected model output: {out}')

    model_trt = torch2trt.torch2trt(
        model, [TENSOR], max_batch_size=TENSOR.shape[0], log_level=tensorrt.Logger.INFO
    )
    out = model_trt(TENSOR)
    print(f'TRT model output: {out}')
```
produces the following output:
```
root@8f319e91dd9a:/opt# python /scripts/getitem-element.py
Expected model output: tensor([0.4963, 0.7682, 0.0885], device='cuda:0')
[07/22/2022-23:49:16] [TRT] [I] [MemUsageChange] Init CUDA: CPU +464, GPU +0, now: CPU 1299, GPU 817 (MiB)
[07/22/2022-23:49:16] [TRT] [I] [MemUsageSnapshot] Begin constructing builder kernel library: CPU 1299 MiB, GPU 817 MiB
[07/22/2022-23:49:16] [TRT] [I] [MemUsageSnapshot] End constructing builder kernel library: CPU 1453 MiB, GPU 859 MiB
Traceback (most recent call last):
  File "/scripts/getitem-element.py", line 27, in <module>
    model_trt = torch2trt.torch2trt(
  File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/torch2trt.py", line 736, in torch2trt
    outputs = module(*inputs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/scripts/getitem-element.py", line 19, in forward
    return tensor[0]
  File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/torch2trt.py", line 307, in wrapper
    converter["converter"](ctx)
  File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/converters/getitem.py", line 34, in convert_tensor_getitem
    num_ellipsis = len(input.shape) - num_slice_types(slices)
  File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/converters/getitem.py", line 18, in num_slice_types
    for s in slices:
TypeError: 'int' object is not iterable
```
I believe the following should work: we should be able to convert `slices` into a tuple if it is not already one, and consume that tuple as the iterable input, as before.
I believe this follows PyTorch behavior correctly as well; specifically:
```python
tensor[(0,)] == tensor[0]
tensor[(0, 1)] == tensor[0][1]
tensor[(0, 1), 0] == tensor[[0, 1], 0]  # This case isn't handled yet; see #755
```
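A minimal sketch of the proposed normalization, in plain Python so it runs without a GPU. `normalize_slices` and `count_slice_types` are hypothetical names, not torch2trt's API, and `count_slice_types` assumes that `num_slice_types` counts the int and slice entries (the ones that consume an input dimension). Wrapping a bare key in a 1-tuple matches the `tensor[(0,)] == tensor[0]` equivalence above.

```python
def normalize_slices(slices):
    """Hypothetical helper: wrap a bare index (int, slice, None, Ellipsis,
    list, ...) in a 1-tuple; tuples are returned unchanged."""
    if not isinstance(slices, tuple):
        slices = (slices,)
    return slices


def count_slice_types(slices):
    """Hypothetical stand-in for num_slice_types, assumed to count the
    entries that consume an input dimension (ints and slices)."""
    count = 0
    for s in normalize_slices(slices):  # safe to iterate after normalization
        if isinstance(s, (int, slice)):
            count += 1
    return count


# tensor[0] on a 2-D input: the key 0 becomes (0,), as with tensor[(0,)].
ndim = 2
remaining = ndim - count_slice_types(0)
print(normalize_slices(0), remaining)  # (0,) 1

# Tuple keys such as tensor[0, 1] pass through untouched.
print(normalize_slices((0, 1)))        # (0, 1)
```

Because tuple keys are returned as-is, only the bare single-index case takes a new path; the existing multi-index handling is unchanged.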