You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
import torch
from einops import rearrange
from torch import nn
import numpy as np
from onnx2tf import convert
import shutil
import torchvision
import tensorflow as tf
import cv2
import matplotlib.pyplot as plt
import requests
import onnxruntime


class dummy_network(nn.Module):
    """Minimal module that only wraps ``torchvision.ops.roi_align``.

    Args:
        output_size: Forwarded to ``roi_align`` as ``output_size``.
        spatial_scale: Forwarded to ``roi_align`` as ``spatial_scale``.
    """

    def __init__(self, output_size, spatial_scale):
        super().__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale

    def forward(self, x, roi):
        """Run RoI Align on feature map ``x`` for the boxes given in ``roi``.

        Returns:
            The tensor produced by ``torchvision.ops.roi_align``.
        """
        return torchvision.ops.roi_align(
            x,
            boxes=roi,
            output_size=self.output_size,
            spatial_scale=self.spatial_scale,
        )


def visualize(outputs):
    """Display every image in ``outputs`` on a four-column matplotlib grid.

    Args:
        outputs: Sized iterable of array-likes that support ``astype``; each
            item is cast to ``int`` and drawn with ``plt.imshow``.
    """
    cols = 4
    # Ceiling division so that a count that is an exact multiple of ``cols``
    # does not produce a trailing empty row; keep at least one row so the
    # grid spec stays valid for an empty input.
    rows = max(1, -(-len(outputs) // cols))

    plt.figure(figsize=(8, 8))
    for index, image in enumerate(outputs, start=1):
        plt.subplot(rows, cols, index)
        plt.axis("off")
        plt.imshow(image.astype(int))
    plt.show()
def main():
    """Export the RoI Align model to ONNX and TFLite and compare the outputs.

    Steps:
        1. Download a test image and build the ``x`` / ``roi`` inputs.
        2. Export ``dummy_network`` to ONNX, then convert it with onnx2tf.
        3. Run the PyTorch, ONNX Runtime and TFLite models on the same inputs.
        4. Assert that the ONNX output matches PyTorch and plot all outputs.

    Raises:
        requests.HTTPError: If the image download returns an error status.
        RuntimeError: If the downloaded bytes cannot be decoded as an image.
        AssertionError: If the ONNX output differs from the PyTorch output.
    """
    import os  # local import: only needed to prepare the output directory

    model = dummy_network(output_size=(64, 64), spatial_scale=1.0)
    model.eval()

    dummy_name = "dummy_roi_align"
    onnx_save_path = f"tflite/{dummy_name}.onnx"
    temp_tflite = "tflite/model_float32.tflite"
    tflite_save_path = f"tflite/{dummy_name}.tflite"
    # The ONNX export writes into this folder, so it has to exist up front.
    os.makedirs("tflite", exist_ok=True)

    url = "https://static01.nyt.com/images/2021/09/14/science/07CAT-STRIPES/07CAT-STRIPES-mediumSquareAt3X-v2.jpg"
    # Bound the wait and fail loudly on HTTP errors instead of trying to
    # decode an error page as an image.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    image_nparray = np.asarray(bytearray(response.content), dtype=np.uint8)
    image = cv2.imdecode(image_nparray, cv2.IMREAD_COLOR)
    if image is None:
        raise RuntimeError(f"Could not decode an image from {url}")

    # Build a 1 x 256 x 256 x 3 RGB batch, then move channels first for torch.
    dummy_input_x = np.expand_dims(
        cv2.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), (256, 256)),
        axis=0,
    )
    dummy_input_x = rearrange(torch.Tensor(dummy_input_x), "n h w c -> n c h w")

    # Each row is [batch_index, x0, y0, x1, y1]; the four boxes are the image quadrants.
    # dummy_input_roi = [[0, 0, 0, 0.5, 0.5], [0, 0.5, 0, 1, 0.5], [0, 0, 0.5, 0.5, 1], [0, 0.5, 0.5, 1, 1]]
    dummy_input_roi = [
        [0, 0, 0, 128, 128],
        [0, 128, 0, 256, 128],
        [0, 0, 128, 128, 256],
        [0, 128, 128, 256, 256],
    ]
    # dummy_input_roi = [[0, 0, 1, 2, 3]]
    # dummy_input_roi = [[0, 0, 0, 256, 256]]
    dummy_input_roi = torch.Tensor(dummy_input_roi)

    torch.onnx.export(
        model,
        args=(dummy_input_x, dummy_input_roi),
        f=onnx_save_path,
        input_names=["x", "roi"],
        opset_version=11,
    )
    convert(onnx_save_path, output_folder_path="tflite")
    shutil.move(temp_tflite, tflite_save_path)

    # get torch output
    # -------------------------------------------------------------------------
    with torch.no_grad():
        torch_output = model(dummy_input_x, dummy_input_roi)
    torch_output = rearrange(torch_output, "n c h w -> n h w c")
    # visualize(torch_output)

    # get onnx output
    # -------------------------------------------------------------------------
    onnx_session = onnxruntime.InferenceSession(
        onnx_save_path, providers=['CPUExecutionProvider']
    )
    onnx_inputs = dict(x=dummy_input_x.numpy(), roi=dummy_input_roi.numpy())
    onnx_output = onnx_session.run(None, onnx_inputs)[0]
    onnx_output = rearrange(onnx_output, "n c h w -> n h w c")

    # compare torch output and onnx output
    np.testing.assert_allclose(onnx_output, torch_output.numpy())

    # get tflite output
    # -------------------------------------------------------------------------
    tflite_onnx2tf = tf.lite.Interpreter(model_path=tflite_save_path)
    tflite_onnx2tf.allocate_tensors()
    # NOTE(review): assumes input 0 is the image and input 1 is the roi tensor
    # in the converted model -- confirm against get_input_details().
    input_details = tflite_onnx2tf.get_input_details()
    tflite_onnx2tf.set_tensor(
        input_details[0]['index'],
        rearrange(dummy_input_x.numpy(), "n c h w -> n h w c"),
    )
    tflite_onnx2tf.set_tensor(input_details[1]['index'], dummy_input_roi.numpy())
    tflite_onnx2tf.invoke()
    tflite_output = [
        tflite_onnx2tf.get_tensor(detail['index'])
        for detail in tflite_onnx2tf.get_output_details()
    ][0]

    # Stack torch, onnx and tflite results so they appear as consecutive
    # groups in the plot; convert the tensor explicitly rather than relying
    # on NumPy's implicit conversion.
    total_output = np.concatenate(
        [torch_output.numpy(), onnx_output, tflite_output], axis=0
    )
    visualize(total_output)
    print("convert done")


if __name__ == "__main__":
    main()
Parameter Replacement JSON
N/A
Description
Purpose: Personal development
What: When I tested RoiAlign, it showed unexpected output as below.
pytorch output
onnx output
wrong tflite output
correct tflite output after modification
How: After some debugging, I found that the roi coordinates are fed into tflite in the wrong order. After changing line 124 to `x0, x1, y1, y0 = tf.split(boxes, 4, axis=1)`, I got the correct result shown above.
Although the output values differ slightly, as shown below, due to implementation details in TensorFlow, the conversion appears to work fine. However, I couldn't figure out why the order of the roi coordinates changed from `x0, y0, x1, y1` to `x0, x1, y1, y0`.
The text was updated successfully, but these errors were encountered:
Gather's indices were being transposed from NCHW to NHWC, which shuffled the dimensions. I have therefore added the ability to disable the transposition of Gather's indices.
Also, this problem occurred only when an operation that changes the shape came immediately after the input OP. I have therefore added a step that determines whether the input OP has been transposed and whether a subsequent operation requires re-transposition. #26
Issue Type
Others
onnx2tf version number
1.1.33
Download URL for ONNX
Please use code below to reproduce bug.
Parameter Replacement JSON
N/A
Description
Purpose: Personal development
What: When I tested RoiAlign, it showed unexpected output as below.
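How: After some debugging, I found that the roi coordinates are fed into tflite in the wrong order. After changing line 124 to
`x0, x1, y1, y0 = tf.split(boxes, 4, axis=1)`
, I got the correct result shown above.
onnx2tf/onnx2tf/ops/RoiAlign.py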
Lines 116 to 126 in 0b14623
Although the output values differ slightly, as shown below, due to implementation details in TensorFlow, the conversion appears to work fine. However, I couldn't figure out why the order of the roi coordinates changed from `x0, y0, x1, y1` to `x0, x1, y1, y0`.
The text was updated successfully, but these errors were encountered: