Embedding layer of ViT not supported with dynamic batch size #861

Closed
fabiozappo opened this issue Apr 4, 2024 · 1 comment
fabiozappo commented Apr 4, 2024

When converting an OpenAI CLIP model to Neuron, everything works fine with a fixed batch size, but with dynamic batch size enabled, inference crashes on any batch size different from the one used for compilation.

According to the error, Neuron fails on the Embedding layer, which is currently not supported.
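
For context, the unsupported aten::embedding op appears to come from the position embedding of the vision tower, which the Hugging Face CLIP implementation (CLIPVisionEmbeddings) builds as an nn.Embedding looked up with constant position ids. A minimal sketch of that pattern (not the exact transformers source) shows why it produces a fixed-shape (1, 50, 768) tensor for ViT-B/32:

import torch

# Sketch of the pattern used by CLIPVisionEmbeddings: a learned position
# embedding indexed with constant position ids. For ViT-B/32 at 224x224
# there are 50 positions (49 patches + 1 CLS token) of width 768, so the
# lookup always yields a (1, 50, 768) tensor regardless of image batch size.
position_embedding = torch.nn.Embedding(50, 768)
position_ids = torch.arange(50).unsqueeze(0)  # shape (1, 50)
pos = position_embedding(position_ids)        # aten::embedding -> (1, 50, 768)

When this op is partitioned onto CPU, its fixed batch dimension of 1 clashes with the dynamic batch dimension of the image tensor, matching the shapes reported in the error below.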

Here is an example: the model works with batch size 1 (the same as compilation) but fails with batch size 2.

....../venv/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py:287: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
....../venv/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py:327: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
INFO:Neuron:The following operations are currently supported in torch-neuron for this model:
INFO:Neuron:prim::ListConstruct
INFO:Neuron:aten::slice
INFO:Neuron:aten::mul
INFO:Neuron:aten::Int
INFO:Neuron:aten::to
INFO:Neuron:aten::size
INFO:Neuron:aten::dropout
INFO:Neuron:prim::Constant
INFO:Neuron:aten::select
INFO:Neuron:aten::add
INFO:Neuron:aten::layer_norm
INFO:Neuron:aten::cat
INFO:Neuron:aten::contiguous
INFO:Neuron:aten::sigmoid
INFO:Neuron:aten::view
INFO:Neuron:aten::linear
INFO:Neuron:aten::bmm
INFO:Neuron:aten::expand
INFO:Neuron:aten::reshape
INFO:Neuron:prim::NumToTensor
INFO:Neuron:aten::_convolution
INFO:Neuron:aten::softmax
INFO:Neuron:aten::transpose
INFO:Neuron:aten::flatten
INFO:Neuron:The following operations are currently not supported in torch-neuron for this model:
INFO:Neuron:aten::embedding
INFO:Neuron:99.94% of all operations (including primitives) (1594 of 1595) are supported
INFO:Neuron:99.84% of arithmetic operations (615 of 616) are supported
....../venv/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py:287: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
....../venv/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py:327: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
INFO:Neuron:There are 1 ops of 1 different types in the TorchScript that are not compiled by neuron-cc: aten::embedding, (For more information see https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/compiler/neuron-cc/neuron-cc-ops/neuron-cc-ops-pytorch.html)
INFO:Neuron:Number of arithmetic operators (pre-compilation) before = 616, fused = 615, percent fused = 99.84%
WARNING:tensorflow:From ....../venv/lib/python3.8/site-packages/torch_neuron/ops/aten.py:2413: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
INFO:Neuron:Compiling function _NeuronGraph$926 with neuron-cc
INFO:Neuron:Compiling with command line: '....../venv/bin/neuron-cc compile /tmp/tmpgktch6il/model --framework TENSORFLOW --pipeline compile SaveTemps --output /tmp/tmpgktch6il/graph_def.neff --verbose 35'
.....
Compiler status PASS
INFO:Neuron:Number of arithmetic operators (post-compilation) before = 616, compiled = 615, percent compiled = 99.84%
INFO:Neuron:The neuron partitioner created 1 sub-graphs
INFO:Neuron:Neuron successfully compiled 1 sub-graphs, Total fused subgraphs = 1, Percent of model sub-graphs successfully compiled = 100.0%
INFO:Neuron:Compiled these operators (and operator counts) to Neuron:
INFO:Neuron: => aten::Int: 145
INFO:Neuron: => aten::_convolution: 1
INFO:Neuron: => aten::add: 25
INFO:Neuron: => aten::bmm: 24
INFO:Neuron: => aten::cat: 1
INFO:Neuron: => aten::contiguous: 36
INFO:Neuron: => aten::dropout: 12
INFO:Neuron: => aten::expand: 1
INFO:Neuron: => aten::flatten: 1
INFO:Neuron: => aten::layer_norm: 26
INFO:Neuron: => aten::linear: 73
INFO:Neuron: => aten::mul: 48
INFO:Neuron: => aten::reshape: 12
INFO:Neuron: => aten::select: 1
INFO:Neuron: => aten::sigmoid: 12
INFO:Neuron: => aten::size: 37
INFO:Neuron: => aten::slice: 2
INFO:Neuron: => aten::softmax: 12
INFO:Neuron: => aten::to: 1
INFO:Neuron: => aten::transpose: 61
INFO:Neuron: => aten::view: 84
INFO:Neuron:Not compiled operators (and operator counts) to Neuron:
INFO:Neuron: => aten::embedding: 1 [not supported]
Computing visual embeddings of shape 512 with neuron, batch size 1
Traceback (most recent call last):
  File "test_resnet.py", line 70, in <module>
    visual_embeddings = neuron_image_encoder(dummy_image)
  File "....../venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
......venv/lib/python3.8/site-packages/torch/_ops.py(442): __call__
......venv/lib/python3.8/site-packages/torch_neuron/decorators.py(416): forward
....../venv/lib/python3.8/site-packages/torch/nn/modules/module.py(1182): _slow_forward
....../venv/lib/python3.8/site-packages/torch/nn/modules/module.py(1194): _call_impl
....../venv/lib/python3.8/site-packages/torch_neuron/graph.py(580): __call__
....../venv/lib/python3.8/site-packages/torch_neuron/graph.py(209): run_op
....../venv/lib/python3.8/site-packages/torch_neuron/graph.py(198): __call__
....../venv/lib/python3.8/site-packages/torch_neuron/runtime.py(69): forward
....../venv/lib/python3.8/site-packages/torch/nn/modules/module.py(1182): _slow_forward
....../venv/lib/python3.8/site-packages/torch/nn/modules/module.py(1194): _call_impl
....../venv/lib/python3.8/site-packages/torch/jit/_trace.py(976): trace_module
....../venv/lib/python3.8/site-packages/torch/jit/_trace.py(759): trace
....../venv/lib/python3.8/site-packages/torch_neuron/tensorboard.py(324): tb_parse
....../venv/lib/python3.8/site-packages/torch_neuron/tensorboard.py(550): tb_graph
....../venv/lib/python3.8/site-packages/torch_neuron/decorators.py(526): maybe_generate_tb_graph_def
....../venv/lib/python3.8/site-packages/torch_neuron/convert.py(580): maybe_determine_names_from_tensorboard
....../venv/lib/python3.8/site-packages/torch_neuron/convert.py(233): trace
test_resnet.py(31): export_to_neuron
test_resnet.py(62): <module>
RuntimeError: Inconsistent batch sizes found on inputs. All batch tensors must have the same dim 0 size.
        Input tensor #0 shape: 2 3 224 224
        Input tensor #1 shape: 1 50 768

You can reproduce the error with the following script:

import torch
import torch_neuron
from transformers import CLIPProcessor, CLIPModel


# Torch wrapper to isolate the image tower
class ClipImageEncoder(torch.nn.Module):
    def __init__(self, model):
        super(ClipImageEncoder, self).__init__()
        self.model = model

    def forward(self, x: torch.Tensor):
        x = self.model.get_image_features(x)
        return x

def export_to_neuron(dummy_input, model):
    if not isinstance(dummy_input, list):
        dummy_input = [dummy_input]

    # analyze_model reports which operators are and are not supported on Neuron
    torch.neuron.analyze_model(model, example_inputs=dummy_input)
    model_neuron = torch.neuron.trace(model, example_inputs=dummy_input, separate_weights=True, dynamic_batch_size=True)

    return model_neuron


if __name__ == "__main__":

    # Load the pretrained CLIP model and processor
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # Input information
    img_size = 224
    device = "cpu"

    # Create the model to convert
    image_encoder = ClipImageEncoder(model)

    # Define the inputs; the torch model expects torch tensors
    dummy_image = torch.randn(1, 3, img_size, img_size).to(device)

    # Convert the image encoder to Neuron
    neuron_image_encoder = export_to_neuron(dummy_image, image_encoder)

    dummy_image = torch.randn(1, 3, img_size, img_size).to(device)
    visual_embeddings = neuron_image_encoder(dummy_image)
    print(f"Computing visual embeddings of shape {visual_embeddings.shape[1]} with neuron, batch size {visual_embeddings.shape[0]}")

    dummy_image = torch.randn(2, 3, img_size, img_size).to(device)
    visual_embeddings = neuron_image_encoder(dummy_image)
    print(f"Computing visual embeddings of shape {visual_embeddings.shape[1]} with neuron, batch size {visual_embeddings.shape[0]}")

Is there any version of Neuron that supports Embedding layers?

fabiozappo changed the title from "Embedding layer of ViT not suppoer" to "Embedding layer of ViT not supported with dynamic batch size" on Apr 4, 2024
fabiozappo (Author) commented

I was able to solve this by adding fallback=False, as suggested in #609.
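
For reference, a minimal sketch of the modified trace call, assuming the same model and dummy_input as the script above (fallback is an argument of torch.neuron.trace, True by default):

model_neuron = torch.neuron.trace(
    model,
    example_inputs=dummy_input,
    separate_weights=True,
    dynamic_batch_size=True,
    fallback=False,  # don't partition unsupported ops onto a CPU subgraph
)

With fallback=False the graph is no longer split into a CPU subgraph whose output carries a fixed batch dimension, which is presumably what resolves the shape mismatch.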
