Weird bug in RoiAlign #22

Closed
Hyunseok-Kim0 opened this issue Nov 25, 2022 · 2 comments

Labels
Bug, OP:Gather, OP:Input, Parameter replacement

Comments

Hyunseok-Kim0 (Collaborator) commented Nov 25, 2022

Issue Type

Others

onnx2tf version number

1.1.33

Download URL for ONNX

Please use the code below to reproduce the bug.

import torch
from einops import rearrange
from torch import nn
import numpy as np
from onnx2tf import convert
import os
import shutil
import torchvision
import tensorflow as tf
import cv2
import matplotlib.pyplot as plt
import requests
import onnxruntime


class dummy_network(nn.Module):

    def __init__(self, output_size, spatial_scale):
        super(dummy_network, self).__init__()

        self.output_size = output_size
        self.spatial_scale = spatial_scale

    def forward(self, x, roi):
        x = torchvision.ops.roi_align(x, boxes=roi, output_size=self.output_size, spatial_scale=self.spatial_scale)

        return x


def visualize(outputs):
    plt.figure(figsize=(8, 8))

    col = 4
    row = len(outputs) // 4 + 1

    for i, o in enumerate(outputs, start=1):
        plt.subplot(row, col, i)
        plt.axis("off")
        plt.imshow(o.astype(int))

    plt.show()


def main():
    model = dummy_network(output_size=(64, 64), spatial_scale=1.0)
    model.eval()

    dummy_name = "dummy_roi_align"
    os.makedirs("tflite", exist_ok=True)  # the ONNX export below fails if the folder is missing
    onnx_save_path = f"tflite/{dummy_name}.onnx"
    temp_tflite = "tflite/model_float32.tflite"
    tflite_save_path = f"tflite/{dummy_name}.tflite"

    url = "https://static01.nyt.com/images/2021/09/14/science/07CAT-STRIPES/07CAT-STRIPES-mediumSquareAt3X-v2.jpg"
    image_nparray = np.asarray(bytearray(requests.get(url).content), dtype=np.uint8)
    image = cv2.imdecode(image_nparray, cv2.IMREAD_COLOR)

    dummy_input_x = np.expand_dims(cv2.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), (256, 256)), axis=0)
    dummy_input_x = rearrange(torch.Tensor(dummy_input_x), "n h w c -> n c h w")
    # dummy_input_roi = [[0, 0, 0, 0.5, 0.5], [0, 0.5, 0, 1, 0.5], [0, 0, 0.5, 0.5, 1], [0, 0.5, 0.5, 1, 1]]
    dummy_input_roi = [[0, 0, 0, 128, 128], [0, 128, 0, 256, 128], [0, 0, 128, 128, 256], [0, 128, 128, 256, 256]]
    # dummy_input_roi = [[0, 0, 1, 2, 3]]
    # dummy_input_roi = [[0, 0, 0, 256, 256]]
    dummy_input_roi = torch.Tensor(dummy_input_roi)

    torch.onnx.export(model,
                      args=(dummy_input_x, dummy_input_roi),
                      f=onnx_save_path,
                      input_names=["x", "roi"],
                      opset_version=11)

    convert(onnx_save_path, output_folder_path="tflite")
    shutil.move(temp_tflite, tflite_save_path)


    # get torch output
    # -----------------------------------------------------------------------------------------------
    with torch.no_grad():
        torch_output = model(dummy_input_x, dummy_input_roi)
        torch_output = rearrange(torch_output, "n c h w -> n h w c")

    # visualize(torch_output)

    # get onnx output
    # -----------------------------------------------------------------------------------------------
    onnx_session = onnxruntime.InferenceSession(onnx_save_path, providers=['CPUExecutionProvider'])
    onnx_inputs = dict(x=dummy_input_x.numpy(), roi=dummy_input_roi.numpy())
    onnx_output = onnx_session.run(None, onnx_inputs)[0]
    onnx_output = rearrange(onnx_output, "n c h w -> n h w c")

    # compare torch output and onnx output
    np.testing.assert_allclose(onnx_output, torch_output.numpy())

    # get tflite output
    # -----------------------------------------------------------------------------------------------
    tflite_onnx2tf = tf.lite.Interpreter(model_path=tflite_save_path)
    tflite_onnx2tf.allocate_tensors()
    tflite_onnx2tf.set_tensor(tflite_onnx2tf.get_input_details()[0]['index'],
                              rearrange(dummy_input_x.numpy(), "n c h w -> n h w c"))
    tflite_onnx2tf.set_tensor(tflite_onnx2tf.get_input_details()[1]['index'], dummy_input_roi.numpy())
    tflite_onnx2tf.invoke()
    tflite_output = [tflite_onnx2tf.get_tensor(i['index']) for i in tflite_onnx2tf.get_output_details()][0]

    total_output = np.concatenate([torch_output, onnx_output, tflite_output], axis=0)
    visualize(total_output)

    print("convert done")

    return


if __name__ == "__main__":
    main()

Parameter Replacement JSON

N/A

Description

  1. Purpose: Personal development

  2. What: When I tested RoiAlign, it produced unexpected output, shown below.

[Images: pytorch output, onnx output, wrong tflite output, correct tflite output after modification]

  3. How: After some debugging, I found that the roi coordinates are fed into the tflite model in the wrong order. After changing line 124 to x0, x1, y1, y0 = tf.split(boxes, 4, axis=1), I could get the correct result shown above.
    def transform_fpcoor_for_tf(
        *,
        boxes,
        image_shape,
        crop_size,
        sampling_ratio,
        adaptive_ratio,
    ):
        x0, y0, x1, y1 = tf.split(boxes, 4, axis=1)
        if not adaptive_ratio:
            crop_shape = (
    Although the output values are slightly different due to implementation details in tensorflow, it appears to work fine. However, I could not figure out why the order of the roi coordinates changes from x0, y0, x1, y1 to x0, x1, y1, y0.
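For reference, the coordinate layouts on the two sides differ: torchvision.ops.roi_align takes boxes as [batch_idx, x1, y1, x2, y2] in absolute pixels, while tf.image.crop_and_resize, the usual TF building block for this, expects boxes as normalized [y1, x1, y2, x2]. Any mix-up in the split order therefore shows up directly as a wrongly placed crop like the one above. Below is a minimal sketch of that conversion; crop_rois_nhwc is a hypothetical helper, and crop_and_resize only approximates RoiAlign's sampling, which is also why small value differences are expected.

import numpy as np
import tensorflow as tf

def crop_rois_nhwc(feature_nhwc, rois, output_size, spatial_scale=1.0):
    # rois: [K, 5] in torchvision layout [batch_idx, x1, y1, x2, y2], absolute pixels
    rois = tf.convert_to_tensor(rois, tf.float32)
    batch_idx = tf.cast(rois[:, 0], tf.int32)
    x1, y1, x2, y2 = tf.unstack(rois[:, 1:] * spatial_scale, axis=1)
    h = tf.cast(tf.shape(feature_nhwc)[1] - 1, tf.float32)
    w = tf.cast(tf.shape(feature_nhwc)[2] - 1, tf.float32)
    # crop_and_resize wants normalized [y1, x1, y2, x2]; note the y-first order
    boxes = tf.stack([y1 / h, x1 / w, y2 / h, x2 / w], axis=1)
    return tf.image.crop_and_resize(feature_nhwc, boxes, batch_idx, output_size)

feature = np.random.rand(1, 256, 256, 3).astype(np.float32)
rois = np.array([[0, 0, 0, 128, 128]], dtype=np.float32)
print(crop_rois_nhwc(feature, rois, (64, 64)).shape)  # (1, 64, 64, 3)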

PINTO0309 added the Bug label Nov 25, 2022
PINTO0309 (Owner) commented:

I have also tried tracing, but so far I can't figure out why the coordinate values are inverted.

PINTO0309 (Owner) commented Nov 26, 2022

The operation that transposes Gather indices from NCHW to NHWC was being applied here, which shuffled the dimensions. Therefore, I have added the ability to disable Gather's indices transposition operation.

Also, this problem only occurred when there was an operation with a shape change immediately after the input OP. Therefore, I have added a process that determines whether the input OP has been transposed and whether a subsequent operation requires re-transposition. #26
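One way to see how this shuffles the ROI coordinates (an illustrative sketch only, not the onnx2tf source): if the Gather indices that pick the four coordinate columns out of the [num_rois, 5] roi tensor are remapped with the NCHW-to-NHWC permutation, the columns come out in exactly the order reported above.

import numpy as np

roi = np.array([[0.0, 10.0, 20.0, 110.0, 120.0]], dtype=np.float32)  # [batch, x0, y0, x1, y1]

coord_labels = ["x0", "y0", "x1", "y1"]
nchw_to_nhwc = [0, 2, 3, 1]  # channel-last permutation applied to the 4 coordinate positions

print([coord_labels[i] for i in nchw_to_nhwc])  # ['x0', 'x1', 'y1', 'y0']
print(roi[:, 1:][:, nchw_to_nhwc])              # [[ 10. 110. 120.  20.]]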

https://github.com/PINTO0309/onnx2tf/releases/tag/1.1.35
Added processing to determine if INPUT OP has been transposed. #26
https://github.com/PINTO0309/onnx2tf/releases/tag/1.1.36

onnx2tf -i dummy_roi_align.onnx
  • Before
    [image]

  • After
    [image]
    [image]

PINTO0309 added the Parameter replacement and OP:Gather labels and removed the Bug label Nov 26, 2022
PINTO0309 added a commit that referenced this issue Nov 26, 2022: Added processing to determine if INPUT OP has been transposed. #22
PINTO0309 added the Bug and OP:Input labels Nov 26, 2022