
Cannot use converted model with dynamic input shape #521

Closed
sklum opened this issue Sep 28, 2023 · 6 comments

Labels
Dynamic batch / Dynamic shape, OP:Concat, OP:Gather, OP:Resize, OP:Unsqueeze, Parameter replacement, TFJS

Comments


sklum commented Sep 28, 2023

Issue Type

Others

OS

Linux

onnx2tf version number

1.18.1

onnx version number

1.14.1

onnxruntime version number

1.16.0

onnxsim (onnx_simplifier) version number

0.4.31

tensorflow version number

2.14.0

Download URL for ONNX

https://drive.google.com/file/d/1XIRHjWYzWHwsZXOgcT4RLOJ6kxXE1BVT/view?usp=share_link

Parameter Replacement JSON

Unclear which parameters need replacement.

Description

Hi @PINTO0309, thanks for the great tool and all of your hard work.

I'm trying to convert a custom model from onnx to tensorflow to tfjs with a dynamic input shape and am having problems.

As an example, take the mobilenet0.25_Final.pth model from https://github.com/biubug6/Pytorch_Retinaface.

I'm converting from pytorch to onnx using the following (where the RetinaFace definition comes from here):

import torch
import onnx
import onnxruntime
import numpy as np
from model import RetinaFace

torch_model = RetinaFace(
    {
        "name": "mobilenet0.25",
        "min_sizes": [[16, 32], [64, 128], [256, 512]],
        "steps": [8, 16, 32],
        "variance": [0.1, 0.2],
        "clip": False,
        "loc_weight": 2.0,
        "gpu_train": True,
        "batch_size": 32,
        "ngpu": 1,
        "epoch": 250,
        "decay1": 190,
        "decay2": 220,
        "image_size": 640,
        "pretrain": False,
        "return_layers": {"stage1": 1, "stage2": 2, "stage3": 3},
        "in_channel": 32,
        "out_channel": 64,
    },
    "test",
)

torch_model.load_state_dict(
    torch.load(
        "./mobilenet0.25_Final.pth", map_location=torch.device("cpu")
    )
)

torch_model.eval()

output_file = "./retinaface.onnx"
example_input = torch.randn(1, 3, 448, 448, requires_grad=False)

torch_out = torch_model(example_input)

torch.onnx.export(
    torch_model,
    example_input,
    output_file,
    input_names=["input_1"],
    opset_version=18,
    dynamic_axes={"input_1": [2, 3]}
)
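
As a quick sanity check that the export actually produced dynamic H and W axes, the exported model's input metadata can be inspected directly (a minimal sketch using the onnx API; this is not part of the original script):

import onnx

model = onnx.load("./retinaface.onnx")
for dim in model.graph.input[0].type.tensor_type.shape.dim:
    # A dynamic axis carries a symbolic dim_param; a static axis a dim_value.
    print(dim.dim_param or dim.dim_value)
# Expected: 1, 3, then two symbolic names for the H and W axes.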

I check that the onnx model works on the python side with the following:

def onnx_forward(onnx_file, example_input):

    sess_options = onnxruntime.SessionOptions()
    session = onnxruntime.InferenceSession(onnx_file, sess_options, providers=['AzureExecutionProvider', 'CPUExecutionProvider'])
    input_name = session.get_inputs()[0].name
    output = session.run([], {input_name: example_input.numpy()})
    output = output[0]
    return output

example_dynamic_input = torch.randn(1, 3, 480, 640, requires_grad=False)
torch_dynamic_out = torch_model(example_dynamic_input)

onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model, full_check=True)

onnx_out = onnx_forward(output_file, example_input)
np.testing.assert_almost_equal(torch_out[0].data.numpy(), onnx_out, decimal=3)

onnx_dynamic_out = onnx_forward(output_file, example_dynamic_input)
np.testing.assert_almost_equal(torch_dynamic_out[0].data.numpy(), onnx_dynamic_out, decimal=3)

I then convert the onnx model to tf using:

onnx2tf -i retinaface.onnx -osd -o retinaface_tf -cotof

When running the above, the following is the first shape issue I have:

INFO: onnx_output_name: wa/fpn/Shape_3_output_0 tf_output_name: tf.compat.v1.shape/wa/fpn/Shape_3:0 shape: (4,) dtype: int64 validate_result: Skipped (Deleted or Shape Unmatched)

Beyond a number of errors / warnings like these, the model converts successfully, but when using it in my tfjs-based system (after converting with tensorflowjs_converter) I get the following shape mismatch at inference time:

Invalid TF_Status: 3
Message: Incompatible shapes: [1,28,28,64] vs. [1,64,28,64]

If I remove the dynamic_axes, things work fine at the fixed input size. There are a number of layers with the [1,28,28,64] shape, so it's been challenging to track down which layer is the problematic one.

FWIW, I've also tried with the -kat and -nuo options.

This error doesn't happen during this workflow, but does the "Alternatively, if the input OP has a dynamic dimension, use the -b or -ois option to rewrite it to a static shape and try again." error message appearing in other places mean that dynamic shapes are not supported? Based on your recent commits I assume that they are indeed supported in some way.

Is there simply a need for a parameter replacement in my case or am I hitting an edge case in dynamic inputs somehow? Any guidance would be appreciated. Please let me know if more information / models would be helpful - I will provide what I can.

PINTO0309 added the Dynamic batch / Dynamic shape and Parameter replacement labels on Sep 29, 2023

PINTO0309 commented Sep 29, 2023

The price of onnx2tf's considerably higher model optimization efficiency compared to onnx-tensorflow is that the optimization may fail if there are two or more undefined dimensions in the input tensor. If there is a series of tensors with axis size None, the correct axis positions can be lost in the process of model transformation.


Conversion of models with multiple undefined dimensions is supported in principle, but the probability of conversion failure is higher and the user must compensate for conversion errors. A JSON file can be used to correct onnx2tf's axis transposition behavior.

In the case of RetinaFace, there was an error in the axis correction of Gather and Concat. When the correction was instructed in JSON, the conversion was successful, and the inference operation could be performed without any problems and with variable axes.

If you do not understand what I am saying, I would not recommend dealing with a model that has a high conversion difficulty involving undefined dimensions.

When running the above, the following is the first shape issue I have:

INFO: onnx_output_name: wa/fpn/Shape_3_output_0 tf_output_name: tf.compat.v1.shape/wa/fpn/Shape_3:0 shape: (4,) dtype: int64 validate_result: Skipped (Deleted or Shape Unmatched)

Skipped (Deleted or Shape Unmatched) appears for all 1D OP outputs, so most of these can safely be ignored. If it appears for OP outputs of two or more dimensions, it suggests that an OP transposition has failed somewhere before that OP. In the case of your RetinaFace, I was getting many of these warning messages for all outputs above 2 dimensions immediately after Gather and Concat. Those Gather and Concat operations derive the tensor size for the Resize immediately following them. Thus, if onnx2tf misreads the Gather axis and the Concat axis, as in what you posted this time, it generates an inconsistent tensor, resulting in an infeasible model. Wrong: [1,64,28,64]
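
To make the failure mode concrete, here is a sketch of the pattern (an illustration of the idea only, not onnx2tf internals). In the ONNX graph the Resize target size is derived from the NCHW shape, so Gather reads H and W at indices 2 and 3; once the surrounding tensors have been converted to NHWC, those same indices point at the wrong axes:

import numpy as np

# ONNX side (NCHW): the Resize size is assembled roughly as
#   shape = Shape(x)         -> [1, 64, 28, 28]
#   h     = Gather(shape, 2) -> 28
#   w     = Gather(shape, 3) -> 28
#   size  = Concat(...)      -> target size for Resize
x_nchw = np.zeros((1, 64, 28, 28), dtype=np.float32)

# After conversion the same activation is NHWC, so its shape vector is
# [1, 28, 28, 64]. Gather(shape, 3) now returns C=64 instead of W=28,
# and the assembled size becomes inconsistent, e.g. [1,64,28,64] as in
# the error above.
x_nhwc = x_nchw.transpose(0, 2, 3, 1)        # shape (1, 28, 28, 64)

# The JSON therefore remaps the Gather indices to the NHWC axis positions:
nchw_to_nhwc = {0: 0, 1: 3, 2: 1, 3: 2}      # N,C,H,W -> N,H,W,C
print(nchw_to_nhwc[2], nchw_to_nhwc[3])      # 1 2 -> the "values" in the JSON

The post_process_transpose_perm of [0,2,3,1] entries in the JSON below compensate for the same NCHW-to-NHWC reordering on the Concat outputs.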


https://github.com/PINTO0309/onnx2tf/releases/tag/1.18.2

pip install onnx2tf -U

wget https://github.com/PINTO0309/onnx2tf/releases/download/1.16.31/flatc.tar.gz \
  && tar -zxvf flatc.tar.gz \
  && sudo chmod +x flatc \
  && sudo mv flatc /usr/bin/
  • replace_retinaface_dynamic.json

    {
      "format_version": 1,
      "operations": [
        {
          "op_name": "/fpn/Gather",
          "param_target": "inputs",
          "param_name": "/fpn/Constant_output_0",
          "values": 1
        },
        {
          "op_name": "/fpn/Gather_1",
          "param_target": "inputs",
          "param_name": "/fpn/Constant_1_output_0",
          "values": 2
        },
        {
          "op_name": "/fpn/Concat_1",
          "param_target": "outputs",
          "param_name": "/fpn/Concat_1_output_0",
          "post_process_transpose_perm": [0,2,3,1]
        },
        {
          "op_name": "/fpn/Gather_2",
          "param_target": "inputs",
          "param_name": "/fpn/Constant_output_0",
          "values": 1
        },
        {
          "op_name": "/fpn/Gather_3",
          "param_target": "inputs",
          "param_name": "/fpn/Constant_1_output_0",
          "values": 2
        },
        {
          "op_name": "/fpn/Concat_3",
          "param_target": "outputs",
          "param_name": "/fpn/Concat_3_output_0",
          "post_process_transpose_perm": [0,2,3,1]
        }
      ]
    }
  • convert

    onnx2tf \
    -i retinaface_onnx_dynamic.onnx \
    -prf replace_retinaface_dynamic.json \
    -osd \
    -coion \
    -cotof


    tensorflowjs_converter \
    --input_format tf_saved_model \
    --output_format tfjs_graph_model \
    saved_model \
    tfjs_model
  • converted models

    1. saved_model.zip
    2. tfjs_model.zip


    saved_model_cli show --dir saved_model/ --all
    
    MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
    
    signature_def['__saved_model_init_op']:
      The given SavedModel SignatureDef contains the following input(s):
      The given SavedModel SignatureDef contains the following output(s):
        outputs['__saved_model_init_op'] tensor_info:
            dtype: DT_INVALID
            shape: unknown_rank
            name: NoOp
      Method name is: 
    
    signature_def['serving_default']:
      The given SavedModel SignatureDef contains the following input(s):
        inputs['input_1'] tensor_info:
            dtype: DT_FLOAT
            shape: (1, -1, -1, 3)
            name: serving_default_input_1:0
      The given SavedModel SignatureDef contains the following output(s):
        outputs['525'] tensor_info:
            dtype: DT_FLOAT
            shape: (1, -1, 4)
            name: PartitionedCall:0
        outputs['606'] tensor_info:
            dtype: DT_FLOAT
            shape: (1, -1, 10)
            name: PartitionedCall:1
        outputs['607'] tensor_info:
            dtype: DT_FLOAT
            shape: (1, -1, 2)
            name: PartitionedCall:2
      Method name is: tensorflow/serving/predict
    
  • tflite tests - assuming that tflite and tfjs behave the same unless there is a bug in the tfjs converter

    import numpy as np
    import tensorflow as tf
    from pprint import pprint
    
    interpreter = tf.lite.Interpreter(model_path="retinaface_onnx_dynamic_float32.tflite")
    tf_lite_model = interpreter.get_signature_runner()
    inputs = {
        'input_1': np.ones([1,480,640,3], dtype=np.float32),
    }
    tf_lite_output = tf_lite_model(**inputs)
    print(f"[TFLite] Model Predictions shape: {tf_lite_output['525'].shape}")
    print(f"[TFLite] Model Predictions shape: {tf_lite_output['606'].shape}")
    print(f"[TFLite] Model Predictions shape: {tf_lite_output['607'].shape}")
    print(f"[TFLite] Model Predictions:")
    pprint(tf_lite_output)
    [TFLite] Model Predictions shape: (1, 12600, 4)
    [TFLite] Model Predictions shape: (1, 12600, 10)
    [TFLite] Model Predictions shape: (1, 12600, 2)
    [TFLite] Model Predictions:
    {'525': array([[[ 0.5037499 , -0.53452146,  0.28746593, -0.1060854 ],
            [ 0.0795919 ,  0.11518952,  0.2050548 ,  0.18294893],
            [ 0.14049551,  0.08239099,  0.09347238,  0.07261422],
            ...,
            [ 0.1204751 ,  0.4000509 , -0.9219683 , -1.0032947 ],
            [ 0.6076698 ,  1.0467663 , -1.3310252 , -0.77822804],
            [ 0.93253464,  0.9066995 , -1.2800018 , -0.8087649 ]]],
          dtype=float32),
     '606': array([[[-0.779792  , -1.3634459 , -0.9822485 , ..., -1.5434346 ,
             -1.5022398 , -1.5166504 ],
            [-1.4359287 , -1.416552  , -1.53112   , ..., -1.5689636 ,
             -1.5718539 , -1.5501782 ],
            [-1.5408219 , -1.5390366 , -1.5441248 , ..., -1.5436357 ,
             -1.5436357 , -1.5436357 ],
            ...,
            [-0.22368906, -0.33528268,  0.43932757, ...,  0.62453854,
              0.64479786,  0.5371737 ],
            [ 0.25273618, -0.30643156,  1.4769835 , ...,  1.2033877 ,
              0.99543476,  1.2796272 ],
            [ 0.49980396, -0.35802573,  1.6034117 , ...,  1.1993165 ,
              1.3116333 ,  1.3412124 ]]], dtype=float32),
     '607': array([[[4.9240157e-01, 5.0759840e-01],
            [4.8854157e-01, 5.1145840e-01],
            [5.3600717e-01, 4.6399277e-01],
            ...,
            [9.9783522e-01, 2.1648051e-03],
            [9.9984205e-01, 1.5789027e-04],
            [9.9987757e-01, 1.2243164e-04]]], dtype=float32)}
    
    import numpy as np
    import tensorflow as tf
    from pprint import pprint
    
    interpreter = tf.lite.Interpreter(model_path="retinaface_onnx_dynamic_float32.tflite")
    tf_lite_model = interpreter.get_signature_runner()
    inputs = {
        'input_1': np.ones([1,192,320,3], dtype=np.float32),
    }
    tf_lite_output = tf_lite_model(**inputs)
    print(f"[TFLite] Model Predictions shape: {tf_lite_output['525'].shape}")
    print(f"[TFLite] Model Predictions shape: {tf_lite_output['606'].shape}")
    print(f"[TFLite] Model Predictions shape: {tf_lite_output['607'].shape}")
    print(f"[TFLite] Model Predictions:")
    pprint(tf_lite_output)
    [TFLite] Model Predictions shape: (1, 2520, 4)
    [TFLite] Model Predictions shape: (1, 2520, 10)
    [TFLite] Model Predictions shape: (1, 2520, 2)
    [TFLite] Model Predictions:
    {'525': array([[[ 0.5037502 , -0.534521  ,  0.287466  , -0.1060854 ],
            [ 0.07959143,  0.11518921,  0.20505464,  0.18294904],
            [ 0.14049558,  0.08239092,  0.09347239,  0.07261468],
            ...,
            [ 0.06551935,  0.4618321 , -0.94103575, -1.0230165 ],
            [ 0.49285072,  0.9942304 , -1.325173  , -0.7776669 ],
            [ 0.8382109 ,  0.94311976, -1.3364108 , -0.869154  ]]],
          dtype=float32),
     '606': array([[[-0.7797919 , -1.3634459 , -0.9822487 , ..., -1.5434344 ,
             -1.5022393 , -1.5166501 ],
            [-1.4359276 , -1.4165508 , -1.5311191 , ..., -1.5689638 ,
             -1.5718228 , -1.5446154 ],
            [-1.5317626 , -1.5467222 , -1.5688723 , ..., -1.330126  ,
             -1.4217492 , -1.3949665 ],
            ...,
            [-0.254573  , -0.24012746,  0.38229224, ...,  0.6902082 ,
              0.59906906,  0.5935865 ],
            [ 0.20296118, -0.2921868 ,  1.4164804 , ...,  1.1968634 ,
              0.98056495,  1.2392025 ],
            [ 0.4711579 , -0.26603216,  1.5369333 , ...,  1.2316816 ,
              1.2790531 ,  1.3474693 ]]], dtype=float32),
     '607': array([[[4.9240148e-01, 5.0759852e-01],
            [4.8854169e-01, 5.1145828e-01],
            [5.3600734e-01, 4.6399271e-01],
            ...,
            [9.9757117e-01, 2.4287710e-03],
            [9.9981421e-01, 1.8576751e-04],
            [9.9986351e-01, 1.3642981e-04]]], dtype=float32)}
    


PINTO0309 commented Sep 30, 2023

Handling of the redundant ONNX output from PyTorch was improved with a dedicated optimization, eliminating the need to create a JSON file.

https://github.com/PINTO0309/onnx2tf/releases/tag/1.18.3

  • Improved stability of the overall model transformation when the input tensor contains two or more undefined dimensions.
  • Resize
    • Added optimization process for Resize with undefined dimensions.
    • Valid only if the input tensor is fixed in NHWC.
    • Axis transposition correction by JSON in this pattern is no longer necessary.
    • https://github.com/PINTO0309/onnx2tf/files/12758261/retinaface_onnx_dynamic.onnx.zip
    • convert - For your use, you do not need to specify -coion.
      onnx2tf \
      -i retinaface_onnx_dynamic.onnx \
      -osd \
      -coion \
      -cotof
      [comparison images: onnx / tflite before / tflite after]

PINTO0309 commented:

Good luck.


sklum commented Sep 30, 2023

Thanks for all your help @PINTO0309. Everything here makes sense, but I wasn't able to get to this work today. I'll review in detail on Monday and let you know if I run into any additional issues with the updates.


sklum commented Oct 4, 2023

Alright @PINTO0309, we're making great progress here. I can now run the dynamic input model without any shape issues on the tfjs side. However, I am having issues with the "correctness" of the output that I'm hoping you can help me with.

Circling back to the pytorch -> onnx conversion code above, I've changed the process to be as follows:

img = np.float32(np.ones([375,448,3]))
img -= (104, 117, 123)
img = img.transpose(2, 0, 1)
img = np.expand_dims(img, 0)
example_input = torch.from_numpy(img)
torch_out = torch_model(example_input)

This matches the general functionality in detect.py from here.

I have updated the onnx_forward function to return all the output tensors instead of the first. Then, as before, I check the model after conversion with:

onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model, full_check=True)

onnx_out = onnx_forward(output_file, example_input)
np.testing.assert_almost_equal(torch_out[1].data.numpy(), onnx_out[1], decimal=3)

This passes successfully. Printing the confidence outputs of these models (torch_out[1] and onnx_out[1]) at this point we get a (1, 6944, 2) tensor from both, where the values in the final dim of the tensor are [~1, ~0] across all 6944 elements. This makes sense as this model is trained and so it shouldn't be detecting anything in an input of ones.

Now, setting aside tfjs, if we run the tflite code you provided using the same input as above, with an NHWC input instead of NCHW because of onnx2tf:

import numpy as np
import tensorflow as tf
from pprint import pprint
interpreter = tf.lite.Interpreter(model_path="retinaface_onnx_dynamic_float32.tflite")
tf_lite_model = interpreter.get_signature_runner()

img = np.float32(np.ones([375,448,3]))
img -= (104, 117, 123)
img = np.expand_dims(img, 0)

inputs = {
    'input_1': img,
}

tf_lite_output = tf_lite_model(**inputs)
print(f"[TFLite] Model Predictions shape: {tf_lite_output['525'].shape}")
print(f"[TFLite] Model Predictions shape: {tf_lite_output['606'].shape}")
print(f"[TFLite] Model Predictions shape: {tf_lite_output['607'].shape}")
print(f"[TFLite] Model Predictions:")
pprint(tf_lite_output)

The tf_lite_output for the confidences is:

 '607': array([[[4.6728298e-01, 5.3271705e-01],
        [4.7566253e-01, 5.2433747e-01],
        [5.1309991e-01, 4.8690012e-01],
        ...,
        [9.9740821e-01, 2.5917361e-03],
        [9.9948138e-01, 5.1866425e-04],
        [9.9974173e-01, 2.5822222e-04]]], dtype=float32)

Note that multiple elements are [~.5, ~.5]. This is basically the same as the output you were showing in your comment above (and it aligns with what I'm getting on the tfjs side).

So, am I doing something wrong with the input shapes or transpositions here? What am I doing wrong such that the -cotof option isn't catching this? I assume something is going wrong with the expectation of transposition in the model, but it's not clear to me at the moment.


PINTO0309 commented Oct 7, 2023

I was late checking the issue because I was training other models.

Models containing more than one None have a non-zero chance of making a transposition error. The checks performed in -cotof force a comparison between the NCHW tensor and the NHWC tensor.

By "forced" I mean that it compares, for example, NCHW: [1,9,128,9] with NHWC: [1,128,9,9]. Since the tensor shapes of ONNX and PyTorch are completely different from those of TensorFlow to begin with, the values have to be compared under that assumption: a brute-force check transposes all combinations of axes to find the arrangement with the smallest error before calculating the error.

Thus, if you are unlucky enough to have a model with a structure like [1,9,9,9,9] in the middle of the model, the check itself will succeed correctly, but the axis of the model transformation itself may still be wrong.
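
As an illustration of that brute-force matching (a minimal sketch, not onnx2tf's actual validation code):

import itertools
import numpy as np

def closest_permutation(onnx_out: np.ndarray, tf_out: np.ndarray):
    # Try every axis permutation of tf_out and keep the one whose values
    # come closest to onnx_out; only shape-compatible permutations count.
    best_perm, best_err = None, np.inf
    for perm in itertools.permutations(range(tf_out.ndim)):
        candidate = tf_out.transpose(perm)
        if candidate.shape != onnx_out.shape:
            continue
        err = float(np.abs(candidate - onnx_out).max())
        if err < best_err:
            best_perm, best_err = perm, err
    return best_perm, best_err

# With a shape like (1, 9, 9, 9, 9) many permutations produce the same
# shape, so a small error can be found even when the conversion picked
# the wrong axis order - exactly the failure mode described above.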

Structural checking of a model with multiple None dimensions is quite difficult even for the human eye, but for when such a situation arises, there is a function to check where the model transformation went wrong.

The -onimc option stops the conversion partway through the model and outputs the partially converted model. It is a bit tedious, but you need to try several conversions up to various midpoints of the model and run an inference test each time to see which part of the model is mis-transposed.

For example,

# Split the model at the middle position for debugging
# Specify the output name of the OP
$ wget https://github.com/PINTO0309/onnx2tf/releases/download/0.0.2/resnet18-v1-7.onnx
$ onnx2tf -i resnet18-v1-7.onnx -onimc resnetv15_stage2_conv1_fwd resnetv15_stage2_conv2_fwd

Once you know where you have made a transposition error on an axis, you can use JSON to correct the transposition error.

For example,

https://github.com/PINTO0309/onnx2tf#parameter-replacement

https://github.com/PINTO0309/onnx2tf/blob/main/replace.json

# Parameter replacement (Resize,Transpose,Softmax)
$ rm replace.json
$ wget https://github.com/PINTO0309/onnx2tf/releases/download/1.1.27/human_segmentation_pphumanseg_2021oct.onnx
$ wget https://github.com/PINTO0309/onnx2tf/releases/download/1.1.27/replace.json
$ onnx2tf -i human_segmentation_pphumanseg_2021oct.onnx -prf replace.json

The reason the -cotof check is likely to report Matches even though an axis is in the wrong position is OPs that do not rewrite the values themselves, such as Gather. (This is a possibility, not a definitive list of problem areas for RetinaFace.)

It would be hard to blindly examine the suspect areas, so if I were you, I would generate a fixed-resolution RetinaFace model and compare its structure with the model with None. This makes it easier to see where a Transpose is missing, or conversely, where a useless Transpose has been inserted; a rough sketch of such a comparison follows.
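
For that structural comparison, here is a rough sketch that diffs the operator counts of two converted tflite files using tf.lite.experimental.Analyzer (the static filename is hypothetical, and the parsing of the analyzer's text dump is best-effort since its exact format can vary between TF versions):

import io
import re
from collections import Counter
from contextlib import redirect_stdout
import tensorflow as tf

def op_counts(path):
    # Analyzer prints one line per operator ("Op#0 TRANSPOSE(...)");
    # capture that dump and tally the op types.
    buf = io.StringIO()
    with redirect_stdout(buf):
        tf.lite.experimental.Analyzer.analyze(model_path=path)
    return Counter(re.findall(r"Op#\d+ (\w+)", buf.getvalue()))

static = op_counts("retinaface_static_float32.tflite")    # hypothetical fixed-size export
dynamic = op_counts("retinaface_onnx_dynamic_float32.tflite")
# A surplus of TRANSPOSE ops on one side shows where a transposition
# was inserted or is missing.
for op in sorted(set(static) | set(dynamic)):
    if static[op] != dynamic[op]:
        print(f"{op}: static={static[op]} dynamic={dynamic[op]}")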

If I get enough time during the holidays I will check out the model too.
