Remove transpose before reshape #193
Hi @PINTO0309, for simplicity I only included the output part of my model. The real model has many conv layers, and the output of those conv layers is what I faked in the example above as the input image. If I use
Frankly, I don't understand the nature of the problem you are having, because only part of the model has been shared. Even if a realistic solution exists, it is often impossible to answer with insufficient information. I am not an esper.
Yes, I understand. I'll update the model to include some convs and ping here.
Here is the updated model: https://drive.google.com/file/d/1_jGqloclOR-hfKToyAUm-3a4czsrxqcx/view?usp=sharing It is still simplified (the real model has many more convs than one), but it should better reflect the problem: the input is (1, 3, 256, 192), and after all the conv layers the output is (1, 23, 64, 48). At this point, onnx2tf introduces the Transpose layer that I need to get rid of, while keeping the channels-last format on all the conv layers.
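For reference, a minimal stand-in for the model described above (my own sketch, not the actual model from the link) that reproduces those shapes with a single conv:

```python
import torch

# Hypothetical single-conv stand-in: maps (1, 3, 256, 192) -> (1, 23, 64, 48),
# matching the shapes described in the comment above.
conv = torch.nn.Conv2d(3, 23, kernel_size=5, stride=4, padding=2)
x = torch.randn(1, 3, 256, 192)
y = conv(x)
print(y.shape)  # torch.Size([1, 23, 64, 48])
```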
Just to be sure: what is the library you are using, to begin with? A framework that cannot use Transpose is quite critical. For example, Hailo-8.
Yes, there is no Transpose in the ONNX model other than here. I am using PyTorch to develop the models, and on Android devices, I am using the
Your request is too device-specific to be addressed by tool-side behavior. Reshaping a tensor with NCW, NCHW, or NCDHW geometry will always break every model except yours unless the Transpose is extrapolated to NWC, NHWC, or NDHWC. In other words, if you want to remove the Transpose by any means, just separate the tensor into two dimensions or less. For example, using `torch.split`:

```python
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = torch.nn.Conv2d(3, 23, (5, 5), stride=4, padding=2)

    def forward(self, x):  # (B, C, H, W)
        y = self.conv(x)
        # Split into per-channel tensors so each reshape touches <= 2 dims
        zs = torch.split(y, split_size_or_sections=1, dim=1)
        aaa = []
        for i in zs:
            bbb = torch.squeeze(i)
            ccc = bbb.reshape([3072])
            ddd = torch.unsqueeze(ccc, 0)
            eee = torch.unsqueeze(ddd, 0)
            aaa.append(eee)
        fff = torch.cat(aaa, dim=1)
        flat = fff - torch.max(fff, dim=-1, keepdim=True).values
        flat = torch.nn.functional.softmax(flat, -1)
        return flat.sum(axis=-1)

model = Model()
x = torch.randn([1, 3, 256, 192])
onnx_file = 'aaa.onnx'
torch.onnx.export(
    model,
    args=(x),
    f=onnx_file,
    opset_version=11,
    input_names=['input'],
    output_names=['output'],
)

import onnx
from onnxsim import simplify
model_onnx2 = onnx.load(onnx_file)
model_simp, check = simplify(model_onnx2)
onnx.save(model_simp, onnx_file)
```

Or, slicing the channels one by one instead of using `torch.split`:

```python
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = torch.nn.Conv2d(3, 23, (5, 5), stride=4, padding=2)

    def forward(self, x):  # (B, C, H, W)
        y = self.conv(x)
        # Slice out each channel individually instead of torch.split
        zs = []
        for i in range(y.shape[1]):
            zs.append(y[:, i:i + 1, ...])
        aaa = []
        for i in zs:
            bbb = torch.squeeze(i)
            ccc = bbb.reshape([3072])
            ddd = torch.unsqueeze(ccc, 0)
            eee = torch.unsqueeze(ddd, 0)
            aaa.append(eee)
        fff = torch.cat(aaa, dim=1)
        flat = fff - torch.max(fff, dim=-1, keepdim=True).values
        flat = torch.nn.functional.softmax(flat, -1)
        return flat.sum(axis=-1)

model = Model()
x = torch.randn([1, 3, 256, 192])
onnx_file = 'aaa.onnx'
torch.onnx.export(
    model,
    args=(x),
    f=onnx_file,
    opset_version=11,
    input_names=['input'],
    output_names=['output'],
)

import onnx
from onnxsim import simplify
model_onnx2 = onnx.load(onnx_file)
model_simp, check = simplify(model_onnx2)
onnx.save(model_simp, onnx_file)
```
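To illustrate the layout issue behind this, here is a small sketch (my own, not from the thread) of why a multi-axis Reshape computed on channels-last tensors needs a Transpose back to NCHW to match the ONNX result:

```python
import torch

# The converted TF graph keeps activations in NHWC, but ONNX Reshape
# semantics assume NCHW element order. Reshaping the NHWC tensor directly
# produces a different element order, which is why onnx2tf must insert a
# Transpose before the Reshape.
y_nchw = torch.arange(1 * 23 * 64 * 48).reshape(1, 23, 64, 48)

# What the ONNX model computes: reshape directly from NCHW.
ref = y_nchw.reshape(1, 23, 64 * 48)

# A naive channels-last reshape (no Transpose): wrong element order.
y_nhwc = y_nchw.permute(0, 2, 3, 1)
naive = y_nhwc.reshape(1, 23, 64 * 48)

# Correct: transpose back to NCHW first, then reshape.
fixed = y_nhwc.permute(0, 3, 1, 2).reshape(1, 23, 64 * 48)

print(torch.equal(ref, naive))  # False
print(torch.equal(ref, fixed))  # True
```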
Issue Type
Feature Request
onnx2tf version number
1.6.6
onnx version number
1.13.0
tensorflow version number
2.10.0
Download URL for ONNX
model without transpose: https://drive.google.com/file/d/1cOIdoZJ1qS-pvVrihk1EumT6tQMI24PD/view?usp=sharing
With extra last dim: https://drive.google.com/file/d/1uT9xqnY7iY9ML7OIaaouETuLQXkdMmCV/view?usp=sharing
Parameter Replacement JSON
Description
Long story short, I want to get rid of the single Transpose layer introduced by the conversion from ONNX to TFLite.
Code for model:
I tried to add an extra dummy dimension to fool the converter, but this confuses the converter, which then does the reshaping incorrectly:
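A sketch of what that dummy-dimension attempt might look like on the PyTorch side (hypothetical; the actual modification is in the linked "extra last dim" model):

```python
import torch

# Hypothetical version of the workaround described above: append a
# trailing size-1 axis before the flatten, hoping the converter treats
# the reshape differently. The values are unchanged, but the converter
# still sees a multi-axis reshape on an NCHW-ordered tensor.
y = torch.randn(1, 23, 64, 48)
y5d = y.unsqueeze(-1)                # (1, 23, 64, 48, 1)
flat = y5d.reshape(1, 23, 64 * 48)  # (1, 23, 3072)
print(flat.shape)  # torch.Size([1, 23, 3072])
```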