Input tensors not being read torch neuronx 2.1.2 #906

Closed
PrateekAg1511 opened this issue Jun 13, 2024 · 4 comments

@PrateekAg1511

When tracing a torchcrf decode function with torch_neuronx 2.1.2, the input tensors are not being read and I get the following warnings:

aws_neuron_venv_pytorch/lib64/python3.9/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:144: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=0, shape=torch.Size([1, 60, 184]), dtype=torch.float32)
warnings.warn(
aws_neuron_venv_pytorch/lib64/python3.9/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:144: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=1, shape=torch.Size([1, 60]), dtype=torch.uint8)
warnings.warn(

Code:

import torch
import torch_neuronx
import torch_xla.core.xla_model as xm
from torchcrf import CRF

num_tags = 184
model = CRF(num_tags)

emissions = torch.rand([1, 60, 184])
mask = torch.ones([1, 60], dtype=torch.uint8)

def decode_fn(emissions, mask):
    a = model.decode(emissions, mask)  # returns List[List[int]]
    a = torch.Tensor(a)
    a = a.to(xm.xla_device())
    return a

inputs_crf = (emissions, mask)

trace_crf = torch_neuronx.trace(decode_fn, inputs_crf)

Looking for some help with this!

@jluntamazon
Contributor

This cannot be traced in torch_neuronx because the function cannot be traced in torch itself (see: https://pytorch.org/docs/stable/generated/torch.jit.trace.html). Tracing works by tracking the operations applied to tensors throughout the compute graph.
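
For illustration (this snippet is not from torchcrf), anything that happens at the Python level, such as control flow, is invisible to the trace; only the tensor operations executed for the example input are recorded:

import torch

def f(x):
    # Python-level branch: only the tensor ops on the path taken for the
    # example input are recorded; the condition itself is not traced.
    if x.sum() > 0:
        return x * 2
    return x - 1

# torch.jit.trace emits a TracerWarning about converting a tensor to a bool,
# and the traced graph always follows the x * 2 branch.
traced = torch.jit.trace(f, (torch.ones(3),))
print(traced(-torch.ones(3)))  # tensor([-2., -2., -2.]), not x - 1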

There are two separate issues that prevent torchcrf from producing a traceable graph:

  1. The graph uses calls to Tensor.item (https://pytorch.org/docs/stable/generated/torch.Tensor.item.html). This converts the tensor to a primitive Python type, which prevents the resulting value from being tracked by the trace.

    The torchcrf calls are made here: https://github.com/kmkurn/pytorch-crf/blob/master/torchcrf/__init__.py#L326-L334

    Here is an example reproduction which shows that torch will error when attempting to trace a graph which returns a .item() value:

import torch

def item_example(tensor):
    # .item() returns a Python scalar, which the tracer cannot track
    return tensor.item()

inputs = (torch.tensor(1),)
trace = torch.jit.trace(item_example, inputs)  # errors: traced functions must return tensors
  2. The graph returns a List[List[int]], which is not allowed by torch tracing.

    The torchcrf output type is defined here: https://github.com/kmkurn/pytorch-crf/blob/master/torchcrf/__init__.py#L319

    Here is an example reproduction which shows that torch will error when attempting to return a list of lists:

import torch

def nested_list_example(tensor):
    # a nested list is not a valid output type for a traced function
    return [[tensor]]

inputs = (torch.tensor(1),)
trace = torch.jit.trace(nested_list_example, inputs)  # errors on the List[List[...]] output

To resolve these issues, you would have to modify torchcrf to remove the .item() calls and instead return a Tuple[List[torch.Tensor]] or any other compatible type of your choice. I was able to successfully and accurately execute the model after making these changes in a local version of the torchcrf package.
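
As a minimal sketch of the pattern (this is not the actual torchcrf Viterbi decode), keeping everything as tensors makes a decode-style function traceable:

import torch

def traceable_decode(emissions):
    # Tensor in, tensor out: argmax stays a tensor, so the trace can
    # follow it end to end (contrast with .item() above).
    return emissions.argmax(dim=-1)

inputs = (torch.rand(1, 60, 184),)
trace = torch.jit.trace(traceable_decode, inputs)
print(trace(*inputs).shape)  # torch.Size([1, 60])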

In general, a good rule of thumb is to first ensure that your model can be traced using torch.jit.trace() before trying torch_neuronx.trace().
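
For example, with a hypothetical pre-check helper (check_traceable is not a torch API, just a sketch of the workflow):

import torch

def check_traceable(fn, example_inputs):
    # Hypothetical helper: surface torch tracing errors early, before
    # attempting the more expensive torch_neuronx.trace compilation.
    try:
        return torch.jit.trace(fn, example_inputs)
    except Exception as err:
        raise RuntimeError(f"Not torch-traceable: {err}") from err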

It is important to note that the compute defined in this module may not be a good candidate for Neuron hardware if it is executed in isolation. Neuron hardware excels at dense, numerically intensive compute with many matrix multiplications.

@PrateekAg1511
Author

@jluntamazon Thanks a lot for the response and great insights!

I implemented it and it works successfully!

I will be using CRF as the last layer of an NER model along with a BERT model, hence I was looking to compile the entire model with torch_neuronx.

@PrateekAg1511
Author

@jluntamazon Now I am facing an issue with batch inference using this approach. My output is a Tuple[List[torch.Tensor]], which works well for batch_size = 1, but when I try to use DataParallel on the traced model, it reports an inconsistent size between inputs and outputs. I looked into jit trace and found that even converting directly to torch.tensor would not work, since the torch.tensor call is treated as a constant. I also tried creating a torch.zeros(batch_size, seq_length) tensor and then writing the decoded values into it, but that did not work either.
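
Roughly the torch.zeros variant I tried (a simplified sketch; pad_decoded and the list-of-tensors input are illustrative, not the full model):

import torch

# Copy each variable-length decoded sequence into a fixed-shape output
# tensor. Under torch.jit.trace the Python loop is unrolled for the example
# batch size, which is presumably why the traced graph does not generalize
# to other batch sizes.
def pad_decoded(tags_per_sequence, batch_size, seq_length):
    out = torch.zeros(batch_size, seq_length)
    for i, tags in enumerate(tags_per_sequence):
        out[i, : tags.shape[0]] = tags.float()
    return out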

Any pointers on how to make DataParallel work with CRF?

It would be of great help!

@jyang-aws
Contributor

Closing as the initial issue is resolved. A new issue will be opened to follow up on the support for DataParallel.
