Remove transpose before reshape #193

Closed
H4dr1en opened this issue Feb 17, 2023 · 10 comments
Labels
GPU Delegate, third party (Third-party tool issues)

Comments

H4dr1en commented Feb 17, 2023

Issue Type

Feature Request

onnx2tf version number

1.6.6

onnx version number

1.13.0

tensorflow version number

2.10.0

Download URL for ONNX

model without transpose: https://drive.google.com/file/d/1cOIdoZJ1qS-pvVrihk1EumT6tQMI24PD/view?usp=sharing

With extra last dim: https://drive.google.com/file/d/1uT9xqnY7iY9ML7OIaaouETuLQXkdMmCV/view?usp=sharing

Parameter Replacement JSON

{}

Description

  1. Product development. Fixing this issue will allow me to run the model in production.
  2. I am running TFLite models on Android devices without OpenCL. The app crashes with an error that the transpose layer cannot be run on the GPU, so TFLite falls back to the OpenCL implementation and crashes because that library is not available.
    Long story short, I want to get rid of the single transpose layer introduced by the conversion from ONNX to TFLite.
[Screenshots: ONNX model without transpose vs. TFLite model with transpose]

Code for model:

class Model(torch.nn.Module):
    def forward(self, x):  # (B, C, H, W)
        flat = x.view(1, -1, 64 * 48)  # (B, C, H * W)
        flat = flat - torch.max(flat, dim=-1, keepdim=True).values
        flat = torch.nn.functional.softmax(flat, -1)
        return flat.sum(axis=-1)  # (B, C)
  3. I tried various workarounds to prevent a transpose from being added, but could not succeed so far. It's not clear to me why onnx2tf can't stay channels-last here instead of introducing a transpose layer. Is there a way to overcome that?

I tried to add an extra dummy dimension to fool the converter, but this confuses the converter, which then does the reshaping incorrectly:

class Model(torch.nn.Module):
    def forward(self, x):  # (B, C, H, W)
        flat = x[..., None]  # (B, C, H, W, 1)
        flat = flat.view(1, -1, 64 * 48, 1)
        flat = flat - torch.max(flat, dim=-2, keepdim=True).values
        flat = torch.nn.functional.softmax(flat, -2)
        return flat[..., 0]
[Screenshots: ONNX model vs. TFLite model]
  4. I need the model to run on these Android devices.
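For context, the transpose exists because onnx2tf keeps tensors channels-last (NHWC) while the ONNX Reshape assumes channels-first (NCHW) memory order, so a transpose back to NCHW is needed before the reshape to preserve element order. A minimal NumPy sketch of the mismatch (editor's illustration, not from the thread):

```python
import numpy as np

x = np.arange(24).reshape(1, 2, 3, 4)        # NCHW, as ONNX sees the tensor
nchw_flat = x.reshape(1, 2, 12)              # what the ONNX Reshape computes

nhwc = x.transpose(0, 2, 3, 1)               # channels-last, as onnx2tf keeps tensors
wrong = nhwc.reshape(1, 2, 12)               # reshaping NHWC directly scrambles channels
fixed = nhwc.transpose(0, 3, 1, 2).reshape(1, 2, 12)  # transpose back first, then reshape

assert not np.array_equal(wrong, nchw_flat)  # naive reshape gives a different element order
assert np.array_equal(fixed, nchw_flat)      # transpose + reshape matches ONNX semantics
```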
PINTO0309 (Owner) commented:

onnx2tf -i model.onnx -kat image

[screenshot]


H4dr1en commented Feb 17, 2023

Hi @PINTO0309,

For simplification I only included the output part of my model; the real model has many conv layers, and their output is what I fake as the input image in the example above. If I use -kat, the whole model will not run channels-last. Is it possible to limit this effect to the part shown above?

PINTO0309 (Owner) commented:

Frankly, I don't understand the nature of the problem you are having because only part of the model has been shared. Even if there is a realistic solution, it is often impossible to answer if there is insufficient information. I am not an esper.


H4dr1en commented Feb 17, 2023

Yes, I understand. I'll update the model to include some convs and ping here.


H4dr1en commented Feb 17, 2023

Here is the updated model: https://drive.google.com/file/d/1_jGqloclOR-hfKToyAUm-3a4czsrxqcx/view?usp=sharing

It is still simplified (there are many more convs than one) but should better reflect the problem: the input is (1, 3, 256, 192), and after all the conv layers the output is (1, 23, 64, 48). At this point onnx2tf introduces the transpose layer that I need to get rid of, while keeping the channels-last format on all the conv layers.

[Screenshots: ONNX vs. TFLite]


PINTO0309 commented Feb 17, 2023

Just to be sure, is there no Transpose other than this part? I am referring to the ONNX model.

What is the library you are using, to begin with? A framework that cannot use Transpose is quite critical. For example, Hailo-8.


H4dr1en commented Feb 17, 2023

Yes, no transpose in the ONNX model other than here

I am using PyTorch to develop the models, and on Android I run inference on the TFLite model with the org.tensorflow:tensorflow-lite / org.tensorflow:tensorflow-lite-gpu modules. They can run transpose, just not on all devices (such as Nexus phones): TFLite falls back to the OpenCL implementation of transpose when it is not otherwise available, and OpenCL is not present on all phones.


PINTO0309 commented Feb 17, 2023

Your request is too device-specific to be addressed by tool-side behavior. Reshaping a tensor with NCW, NCHW, or NCDHW geometry will always break every model except yours unless a Transpose to NWC, NHWC, or NDHWC is inserted first. In other words, if you want to remove the Transpose by any means, keep each Reshape to two dimensions or less.

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.conv = torch.nn.Conv2d(3, 23, (5,5), stride=4, padding=2)

    def forward(self, x):   # (B, C, H, W)
        y = self.conv(x)
        zs = torch.split(y, split_size_or_sections=1, dim=1)
        aaa = []
        for i in zs:
            bbb = torch.squeeze(i)
            ccc = bbb.reshape([3072])
            ddd = torch.unsqueeze(ccc, 0)
            eee = torch.unsqueeze(ddd, 0)
            aaa.append(eee)
        fff = torch.cat(aaa, dim=1)
        flat = fff - torch.max(fff, dim=-1, keepdim=True).values
        flat = torch.nn.functional.softmax(flat, -1)
        return flat.sum(axis=-1)

model = Model()

x = torch.randn([1,3,256,192])
onnx_file = 'aaa.onnx'
torch.onnx.export(
    model,
    args=(x),
    f=onnx_file,
    opset_version=11,
    input_names=[
        'input',
    ],
    output_names=[
        'output',
    ],
)
import onnx
from onnxsim import simplify
model_onnx2 = onnx.load(onnx_file)
model_simp, check = simplify(model_onnx2)
onnx.save(model_simp, onnx_file)

[screenshot]

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.conv = torch.nn.Conv2d(3, 23, (5,5), stride=4, padding=2)

    def forward(self, x):   # (B, C, H, W)
        y = self.conv(x)
        zs = []
        for i in range(y.shape[1]):
            zs.append(y[:,i:i+1,...])
        aaa = []
        for i in zs:
            bbb = torch.squeeze(i)
            ccc = bbb.reshape([3072])
            ddd = torch.unsqueeze(ccc, 0)
            eee = torch.unsqueeze(ddd, 0)
            aaa.append(eee)
        fff = torch.cat(aaa, dim=1)
        flat = fff - torch.max(fff, dim=-1, keepdim=True).values
        flat = torch.nn.functional.softmax(flat, -1)
        return flat.sum(axis=-1)

model = Model()

x = torch.randn([1,3,256,192])
onnx_file = 'aaa.onnx'
torch.onnx.export(
    model,
    args=(x),
    f=onnx_file,
    opset_version=11,
    input_names=[
        'input',
    ],
    output_names=[
        'output',
    ],
)
import onnx
from onnxsim import simplify
model_onnx2 = onnx.load(onnx_file)
model_simp, check = simplify(model_onnx2)
onnx.save(model_simp, onnx_file)

[screenshot]
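The split/squeeze/reshape/concat rewrite above should be numerically identical to the original view-based flatten, since each channel is flattened in the same row-major order. A quick sanity check (editor's sketch, not from the thread):

```python
import torch

# Tensor with the shape discussed in the thread: (B, C, H, W) = (1, 23, 64, 48).
y = torch.randn(1, 23, 64, 48)

# Original formulation: one Reshape over (C, H, W) at once.
ref = y.view(1, -1, 64 * 48)

# Workaround: flatten each channel separately so every Reshape stays 1-D.
parts = [
    torch.squeeze(c).reshape([64 * 48])[None, None]  # (1, 1, H*W)
    for c in torch.split(y, 1, dim=1)
]
out = torch.cat(parts, dim=1)  # (1, C, H*W)

assert out.shape == (1, 23, 3072)
assert torch.equal(ref, out)  # both paths produce the same element order
```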

PINTO0309 (Owner) commented:

onnx2tf -i model_2.onnx -nodafc 2

[screenshot]

https://github.com/PINTO0309/onnx2tf/releases/tag/1.7.1
