aten::index nodes take multiple indices in PyTorch model but cause an error when trying to convert to TFLite #282

Closed
MariosHavWaller opened this issue Mar 22, 2024 · 1 comment · Fixed by #286
Labels
enhancement New feature or request work/small work that can be done within 6 hour


@MariosHavWaller

MariosHavWaller commented Mar 22, 2024

I'm using TinyNN to convert a PyTorch model to TFLite. The model exports to ONNX without any problems, but converting it to TFLite fails with:

assert len(filtered_dims) == 1, "Multiple indices for aten::index is not supported"
           ^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Multiple indices for aten::index is not supported

I assume this is caused by two aten::index.Tensor nodes in the model (I'm working with MobileSAMv2). In the PyTorch graph, each of these nodes takes three inputs, two of which are index tensors, and produces a single output. I'm not sure whether this is a limitation of the TinyNN converter or something I'm missing on my side. Is this something TinyNN could support, or is there a workaround?
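For illustration, here is a minimal sketch of the kind of expression that produces such a node, assuming (not confirmed from the MobileSAMv2 source) that the two nodes come from picking one entry per batch item with two index tensors at once, e.g. `x[torch.arange(B), best_idx]`. That traces to a single `aten::index` whose index list holds two tensors. A possible user-side workaround, not a TinyNN feature, is to merge the first two dimensions and index with one linear index `row * N + col`, so only one index tensor remains. The helper names `pick_two_index` and `pick_one_index` are hypothetical, and whether TinyNN converts the rewritten form cleanly is an assumption.

```python
import torch


def pick_two_index(x: torch.Tensor, best_idx: torch.Tensor) -> torch.Tensor:
    # Advanced indexing with two index tensors at once: this traces to one
    # aten::index node whose index list contains two tensors.
    rows = torch.arange(x.shape[0], device=x.device)
    return x[rows, best_idx]


def pick_one_index(x: torch.Tensor, best_idx: torch.Tensor) -> torch.Tensor:
    # Equivalent rewrite (hypothetical workaround): merge dims 0 and 1, then
    # index with a single linear index row * x.shape[1] + col.
    rows = torch.arange(x.shape[0], device=x.device)
    flat = x.flatten(0, 1)  # shape (B * N, ...)
    return flat[rows * x.shape[1] + best_idx]


if __name__ == "__main__":
    masks = torch.randn(2, 4, 8, 8)   # (batch, num_masks, H, W)
    best_idx = torch.tensor([3, 1])   # one chosen mask index per batch item
    a = pick_two_index(masks, best_idx)
    b = pick_one_index(masks, best_idx)
    print(a.shape, torch.equal(a, b))
```

The same rewrite applies to a 2-D tensor such as per-mask scores of shape (B, N), where both versions return a tensor of shape (B,).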

This is the conversion code I have so far:

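import torch
from tinynn.converter import TFLiteConverter

# `sam` (the loaded MobileSAMv2 model), `return_single_mask` and SamOnnxModel
# are defined elsewhere in the export script.
model = SamOnnxModel(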
    model=sam,
    return_single_mask=return_single_mask,
)

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]

image_embeddings = torch.randn(1, embed_dim, *embed_size, dtype=torch.float32)
point_coords = torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.int32)
point_labels = torch.randint(low=0, high=4, size=(1, 5), dtype=torch.int32)
mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float32)
has_mask_input = torch.tensor([1], dtype=torch.float32)
orig_im_size = torch.tensor([1500, 2250], dtype=torch.float32)

dummy_input_actual = (image_embeddings, point_coords, point_labels, mask_input, has_mask_input, orig_im_size)

name = 'mobileSAM.tflite'
# Trace the model with the dummy inputs and write the TFLite file to `name`
converter = TFLiteConverter(model, dummy_input_actual, name)
converter.convert()
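Before calling the converter, it may help to confirm which `aten::index` nodes actually carry more than one index. The sketch below is an assumption-based diagnostic, not part of TinyNN: it traces a module with `torch.jit.trace`, walks the inlined TorchScript graph, and reports the length of each `aten::index` node's index list, assuming that list is built by a `prim::ListConstruct` node as in typical traced graphs. `index_list_sizes` and the toy `TwoIndexPick` module are hypothetical names.

```python
import torch


def index_list_sizes(module: torch.nn.Module, example_inputs: tuple) -> list:
    """Trace `module` and return, for every aten::index node, the number of
    entries in its index list (a value > 1 is the multi-index case)."""
    traced = torch.jit.trace(module, example_inputs)
    sizes = []
    # inlined_graph folds submodule calls into one graph so nested nodes are visited.
    for node in traced.inlined_graph.nodes():
        if node.kind() == "aten::index":
            # Input 1 is the Tensor?[] index list, produced by a prim::ListConstruct node.
            index_list = list(node.inputs())[1].node()
            sizes.append(len(list(index_list.inputs())))
    return sizes


class TwoIndexPick(torch.nn.Module):
    # Hypothetical toy module reproducing the x[rows, best_idx] pattern.
    def forward(self, x: torch.Tensor, best_idx: torch.Tensor) -> torch.Tensor:
        rows = torch.arange(x.shape[0], device=x.device)
        return x[rows, best_idx]


if __name__ == "__main__":
    x = torch.randn(2, 4, 8, 8)
    best_idx = torch.tensor([3, 1])
    print(index_list_sizes(TwoIndexPick(), (x, best_idx)))
```

Applied to the script above, a call such as `index_list_sizes(model, dummy_input_actual)` would be expected to list one entry per `aten::index` node in the exported graph.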
@peterjc123 peterjc123 added enhancement New feature or request work/small work that can be done within 6 hour labels Mar 23, 2024
@peterjc123
Collaborator

This can be supported by TinyNN. Stay tuned.
