[Bug] torch.Tensor.__getitem__ fails when indexing/slicing with a single element, with error: TypeError: 'int' object is not iterable #768

Open
chaoz-dev opened this issue Jul 22, 2022 · 3 comments

@chaoz-dev
Contributor

Problem
The torch2trt converter for torch.Tensor.__getitem__ fails when a tensor is indexed or sliced with a single element, e.g.

tensor = torch.rand(2, 3)
tensor[0]

The issue is in torch2trt/converters/getitem.py, which assumes that slices is iterable. That assumption does not hold in the use case above: with a single index, slices is the index itself (here, the int given as the indexing argument) rather than a tuple.
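For context, this comes from how Python itself dispatches subscripts: obj[0] hands the bare int to __getitem__, while obj[0, 1] hands over a tuple. The sketch below illustrates that with a hypothetical IndexRecorder class (an assumption for illustration, not part of torch2trt or PyTorch) that simply returns whatever index it receives.

```python
class IndexRecorder:
    """Hypothetical helper: returns the raw argument passed to __getitem__."""

    def __getitem__(self, index):
        return index


rec = IndexRecorder()

single = rec[0]     # a bare int, not a tuple
multi = rec[0, 1]   # a tuple of two ints
full = rec[:]       # a single slice object
dots = rec[...]     # the Ellipsis singleton

print(type(single).__name__, type(multi).__name__)

# Iterating over the single-element index raises the same TypeError
# that appears in the traceback for this issue.
message = None
try:
    for _ in single:
        pass
except TypeError as err:
    message = str(err)
print(message)
```

Only the multi-index form arrives as a tuple; the int, slice, and Ellipsis forms all arrive unwrapped, so any code that loops over the index directly has to handle them separately.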

Script
Running the following script getitem-element.py using NGC 22.06-py3:

  import logging
  import tensorrt
  import torch
  import torch2trt


  logging.basicConfig(level=logging.INFO)
  torch.manual_seed(0)

  DEVICE = 'cuda:0'
  TENSOR = torch.rand(2, 3).to(DEVICE)


  class Model(torch.nn.Module):
      def __init__(self):
          super().__init__()

      def forward(self, tensor):
          return tensor[0]


  if __name__ == "__main__":
      model = Model().eval().to(DEVICE)
      out = model(TENSOR)
      print(f'Expected model output: {out}')

      model_trt = torch2trt.torch2trt(
          model, [TENSOR], max_batch_size=TENSOR.shape[0], log_level=tensorrt.Logger.INFO
      )
      out = model_trt(TENSOR)
      print(f'TRT model output: {out}')

produces the following output:

root@8f319e91dd9a:/opt# python /scripts/getitem-element.py
Expected model output: tensor([0.4963, 0.7682, 0.0885], device='cuda:0')
[07/22/2022-23:49:16] [TRT] [I] [MemUsageChange] Init CUDA: CPU +464, GPU +0, now: CPU 1299, GPU 817 (MiB)
[07/22/2022-23:49:16] [TRT] [I] [MemUsageSnapshot] Begin constructing builder kernel library: CPU 1299 MiB, GPU 817 MiB
[07/22/2022-23:49:16] [TRT] [I] [MemUsageSnapshot] End constructing builder kernel library: CPU 1453 MiB, GPU 859 MiB
Traceback (most recent call last):
  File "/scripts/getitem-element.py", line 27, in <module>
    model_trt = torch2trt.torch2trt(
  File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/torch2trt.py", line 736, in torch2trt
    outputs = module(*inputs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/scripts/getitem-element.py", line 19, in forward
    return tensor[0]
  File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/torch2trt.py", line 307, in wrapper
    converter["converter"](ctx)
  File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/converters/getitem.py", line 34, in convert_tensor_getitem
    num_ellipsis = len(input.shape) - num_slice_types(slices)
  File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/converters/getitem.py", line 18, in num_slice_types
    for s in slices:
TypeError: 'int' object is not iterable
@chaoz-dev
Contributor Author

This is likely the same issue presented in #247.

@chaoz-dev
Contributor Author

I'll post a solution shortly.

I believe the following should work: if slices is not already a tuple, wrap it in one, then consume it as the iterable input exactly as before.

I believe this also matches PyTorch's indexing behavior; specifically,

tensor[(0,)] == tensor[0]
tensor[(0, 1)] == tensor[0][1]
tensor[(0, 1), 0] == tensor[[0, 1], 0] # This case isn't handled yet; see #755 
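A minimal sketch of the tuple conversion described above, in plain Python. The name as_index_tuple is hypothetical and this is not the actual torch2trt patch; it only illustrates the normalization step that would run before the converter iterates over slices.

```python
def as_index_tuple(slices):
    """Hypothetical helper: return the index as a tuple.

    A tuple is passed through unchanged; any other index (int, slice,
    Ellipsis, list, ...) is wrapped in a 1-tuple so it can be iterated.
    """
    if isinstance(slices, tuple):
        return slices
    return (slices,)


# A single int becomes a 1-tuple, mirroring tensor[0] == tensor[(0,)].
single = as_index_tuple(0)

# A multi-index is already a tuple and is left as-is.
multi = as_index_tuple((0, 1))

# Non-int single indices are wrapped the same way.
full = as_index_tuple(slice(None))
dots = as_index_tuple(Ellipsis)

for s in single:  # iteration now works for the single-element case
    print(type(s).__name__)
```

Checking isinstance(slices, tuple) rather than trying to iterate keeps a list index such as [0, 1] intact as one element, which is what the unhandled advanced-indexing case would need.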

@chaoz-dev
Contributor Author

chaoz-dev commented Jul 23, 2022

There also appear to be issues with the : and ... arguments when they are used on the first dimension of the tensor.

Caused by #769
