
[Gridsample] num_input_elements != num_output_elements (2223936 != 2235392)Node number 2 (RESHAPE) failed to prepare.Failed to apply the default TensorFlow Lite delegate indexed at 0. #308

Closed
On-JungWoan opened this issue Apr 11, 2023 · 14 comments
Labels
third party Third-party tool issues

Comments

@On-JungWoan
Contributor

On-JungWoan commented Apr 11, 2023

Issue Type

Others

onnx2tf version number

1.9.1

onnx version number

1.13.1

tensorflow version number

2.12.0

Download URL for ONNX

https://drive.google.com/file/d/1UZPbL5h6GJUwZTHPpab54TFeJNuaJRfS/view?usp=sharing

Parameter Replacement JSON

None

Description

I noticed that an unknown error occurs when using PyTorch's grid_sample function, so I'm trying to use the custom GridSample implementation you mentioned in #274.

  • Script
import torch


class Model(torch.nn.Module):
    def forward(self, image, grid):
        n, c, h, w = image.shape
        _, gh, gw, _ = grid.shape

        x, y = torch.split(grid, split_size_or_sections=1, dim=3)

        x = ((x + 1) / 2) * (w - 1)
        y = ((y + 1) / 2) * (h - 1)

        x = x.view(n, -1)
        y = y.view(n, -1)

        x0 = torch.floor(x).long()
        y0 = torch.floor(y).long()
        x1 = x0 + 1
        y1 = y0 + 1

        wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
        wb = ((x1 - x) * (y - y0)).unsqueeze(1)
        wc = ((x - x0) * (y1 - y)).unsqueeze(1)
        wd = ((x - x0) * (y - y0)).unsqueeze(1)

        # Emulate grid_sample's default zero padding with a 1-pixel border of zeros
        im_padded = torch.nn.functional.pad(image, pad=[1, 1, 1, 1], mode='constant', value=0)
        padded_h = h + 2
        padded_w = w + 2
        # Shift the sample positions to account for the 1-pixel padding
        x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

        # Clip coordinates to padded image size
        x0 = torch.where(x0 < 0, torch.tensor(0), x0)
        x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0)
        x1 = torch.where(x1 < 0, torch.tensor(0), x1)
        x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1)
        y0 = torch.where(y0 < 0, torch.tensor(0), y0)
        y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0)
        y1 = torch.where(y1 < 0, torch.tensor(0), y1)
        y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1)

        im_padded = im_padded.view(n, c, -1)

        x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
        x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
        x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
        x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

        Ia = torch.gather(im_padded, 2, x0_y0)
        Ib = torch.gather(im_padded, 2, x0_y1)
        Ic = torch.gather(im_padded, 2, x1_y0)
        Id = torch.gather(im_padded, 2, x1_y1)

        return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)

However, when I ran inference with the TFLite model containing the converted custom GridSample function, the following error occurred:

Exception has occurred: RuntimeError
tensorflow/lite/kernels/reshape.cc:85 num_input_elements != num_output_elements (2223936 != 2235392)Node number 2 (RESHAPE) failed to prepare.Failed to apply the default TensorFlow Lite delegate indexed at 0.
  File "/data/ojw/convert/convert/split/grid_sample/inference_tflite.py", line 23, in <module>
    interpreter.allocate_tensors()
RuntimeError: tensorflow/lite/kernels/reshape.cc:85 num_input_elements != num_output_elements (2223936 != 2235392)Node number 2 (RESHAPE) failed to prepare.Failed to apply the default TensorFlow Lite delegate indexed at 0.

Is this a bug?

[image]
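For readers following the custom GridSample above, the flattened-index trick is easier to see on a single channel. Below is a minimal pure-Python sketch of the same steps: a 1-pixel zero border, shifting and clipping the corner indices, gathering with flat index x + y * padded_w, and bilinear weighting. The function name grid_sample_bilinear is hypothetical, batch and channel dimensions are dropped, and the ((x + 1) / 2) * (w - 1) mapping corresponds to align_corners=True in torch.nn.functional.grid_sample.

```python
import math


def grid_sample_bilinear(image, grid):
    """Bilinearly sample a single-channel image at normalized (x, y) points.

    image: list of rows (h x w); grid: iterable of (x, y) pairs in [-1, 1].
    Follows the same steps as the custom Model.forward: align_corners=True
    mapping, 1-pixel zero padding, index clipping and a flat-index gather.
    """
    h, w = len(image), len(image[0])
    padded_h, padded_w = h + 2, w + 2

    # Zero-pad by one pixel on every side and flatten row-major
    # (the equivalent of im_padded.view(n, c, -1) for one channel).
    flat = [0.0] * (padded_h * padded_w)
    for r in range(h):
        for c in range(w):
            flat[(r + 1) * padded_w + (c + 1)] = float(image[r][c])

    def clip(v, hi):
        return min(max(v, 0), hi)

    out = []
    for gx, gy in grid:
        # Map [-1, 1] to pixel coordinates (align_corners=True).
        x = (gx + 1) / 2 * (w - 1)
        y = (gy + 1) / 2 * (h - 1)
        x0, y0 = math.floor(x), math.floor(y)
        x1, y1 = x0 + 1, y0 + 1

        # Bilinear weights are computed from the unpadded coordinates.
        wa = (x1 - x) * (y1 - y)
        wb = (x1 - x) * (y - y0)
        wc = (x - x0) * (y1 - y)
        wd = (x - x0) * (y - y0)

        # Shift into padded coordinates, then clip to the padded image so
        # far out-of-range points land on the zero border.
        px0, px1 = clip(x0 + 1, padded_w - 1), clip(x1 + 1, padded_w - 1)
        py0, py1 = clip(y0 + 1, padded_h - 1), clip(y1 + 1, padded_h - 1)

        # Gather with flat index = x + y * padded_w.
        ia = flat[px0 + py0 * padded_w]
        ib = flat[px0 + py1 * padded_w]
        ic = flat[px1 + py0 * padded_w]
        id_ = flat[px1 + py1 * padded_w]
        out.append(ia * wa + ib * wb + ic * wc + id_ * wd)
    return out


print(grid_sample_bilinear([[1, 2], [3, 4]], [(-1, -1), (0, 0), (1, 1)]))
# [1.0, 2.5, 4.0]
```

With the 2x2 image [[1, 2], [3, 4]], the grid corners return the corner pixels, (0, 0) averages all four pixels to 2.5, and a point one pixel outside the top-left corner reads only the zero border and returns 0.0.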

@PINTO0309
Owner

PINTO0309 commented Apr 11, 2023

[image]

To me, this looks like a bug in the TensorFlow Lite runtime. You are focusing on the wrong place and overlooking what the error message says: it reports a problem with operation #2, but you are looking at operation #3.

Does the model look broken? To me, it looks like the error message is lying.

import os
import warnings
# Ignore all warnings (FutureWarning, DeprecationWarning, RuntimeWarning, ...) so only the TFLite error is shown
warnings.simplefilter(action='ignore', category=Warning)
import random
random.seed(0)
import numpy as np
np.set_printoptions(
    precision=6,
    floatmode='fixed',
    suppress=True,
    edgeitems=3,
    linewidth=100,
)
np.random.seed(0)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.lite.python import interpreter as iw

TFLITE_PATH = 'gird_sample_float32.tflite'
interpreter = iw.Interpreter(
    model_path=TFLITE_PATH,
    num_threads=4,
)
input_details = interpreter.get_input_details()
input_shape_1 = input_details[0]['shape']
input_shape_2 = input_details[1]['shape']
output_details = interpreter.get_output_details()

interpreter.allocate_tensors()  # raises the same RESHAPE RuntimeError as above

PINTO0309 added the "third party" (Third-party tool issues) label on Apr 11, 2023
@On-JungWoan
Contributor Author

I still get the same error when I use this script. Is that because the bug is in TensorFlow and there is currently no fix for it?

@PINTO0309
Owner

> I still get the same error even when I use this script.

That script only reproduces the error; I wrote it because you did not post any code to reproduce it.

Did you look carefully at the images I posted?

@On-JungWoan
Contributor Author

On-JungWoan commented Apr 11, 2023

Yes, I carefully checked the input and output of operation #2, but I can't find any errors.

  • Tensor details

    [image]

  • Input of #2 operation

    [image]

    8*32*74*118 = 2235392

  • Output of #2 operation

    [image]

    8*32*8732 = 2235392

  • Inference script

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

import sys
sys.path.append('/data/ojw/convert')

import time
import pickle
from os import path

dir = '/data/ojw/convert/convert/split/grid_sample'
TFLITE_PATH = path.join(dir, 'tflite/gird_sample_float32.tflite')

with open(path.join(dir, 'input.pkl'), 'rb') as f:
    ipt = pickle.load(f)
with open(path.join(dir, 'output.pkl'), 'rb') as f:
    opt = pickle.load(f)


import tensorflow as tf
# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=TFLITE_PATH, num_threads=4,)
interpreter.allocate_tensors()  # line 23 in the traceback above: raises the RESHAPE RuntimeError
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()    


interpreter.set_tensor(
    input_details[0]['index'],
    ipt[0].cpu().detach().numpy()
)
interpreter.set_tensor(
    input_details[1]['index'],
    ipt[1].cpu().detach().numpy()
)


start = time.time()
interpreter.invoke()
print(time.time() - start)
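The element counts for operation #2 quoted in this comment can be double-checked mechanically. This is a minimal sketch; the shapes are copied from the figures reported above, and the output is assumed to flatten the last two input dimensions (74 * 118 = 8732).

```python
from math import prod

# Shapes as reported above for operation #2 (RESHAPE).
input_shape = (8, 32, 74, 118)
output_shape = (8, 32, 8732)  # last two input dims flattened: 74 * 118 == 8732

n_in = prod(input_shape)
n_out = prod(output_shape)
print(n_in, n_out)  # 2235392 2235392

# A reshape is only valid when both element counts agree, which they do here.
assert n_in == n_out
```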

@PINTO0309
Owner

PINTO0309 commented Apr 11, 2023

So you see no problem with the structure of the model?

Isn't the TFLite runtime error message lying?

num_input_elements != num_output_elements (2223936 != 2235392) Node number 2 (RESHAPE)

The number 2223936 cannot be derived by multiplying the dimensions of any tensor shape defined in the model. Thus, it is clearly a bug in the TFLite runtime.

@On-JungWoan
Contributor Author

On-JungWoan commented Apr 11, 2023

Does the sentence "Isn't the TFLite runtime error message lying?" mean that the error message is incorrect? I want to confirm that I understood your English correctly. If so, I agree that the error message seems to be lying. In that case, would I need to raise an issue directly with the TFLite runtime to resolve the error?

  • Whole model structure

    [image]

@PINTO0309
Owner

PINTO0309 commented Apr 11, 2023

> If that's the case, it seems to me that the error message is lying. Then, Would it be necessary for me to raise an issue directly with TFLite runtime in order to resolve the error?

That's right, I assure you. Anyone would reach the same conclusion.

This error has nothing to do with the overall structure of the model. It is a bug in the runtime's Reshape kernel.

@On-JungWoan
Contributor Author

Thank you for your answer. I'll raise an issue tomorrow and get back to you.

@PINTO0309
Owner

The only options runtime users like us have for this symptom are to locate the bug in the runtime ourselves and submit a pull request, or to submit an issue and wait a year or more.

@PINTO0309
Owner

Very interesting bug. It seems that every dimension except the batch size has been altered:

8×33×72×117=2223936
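This observation can be reproduced with a small brute-force search: keep the batch size of 8, let each remaining dimension of 8×32×74×118 vary by up to ±2 (an arbitrary window chosen only for illustration), and keep the shapes whose element count equals the 2223936 reported by TFLite.

```python
from itertools import product
from math import prod

expected = (8, 32, 74, 118)  # shape the model defines for the RESHAPE input
reported = 2223936           # num_input_elements in the TFLite error message

# Keep the batch size and let each other dimension vary by up to +/-2.
candidates = [
    (expected[0], c, h, w)
    for c, h, w in product(*(range(d - 2, d + 3) for d in expected[1:]))
    if prod((expected[0], c, h, w)) == reported
]
print(candidates)  # [(8, 33, 72, 117)]
```

Within this window the only match is 8×33×72×117, so each non-batch dimension differs from the shape the model defines.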

@On-JungWoan
Contributor Author

On-JungWoan commented Apr 11, 2023

Oh, you're right. After running allocate_tensors() once, I checked the 29th output again and found that its shape had changed, as you said.

[image]
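Instead of checking tensors one by one, the tensors with a suspicious element count can be listed programmatically. This is a minimal sketch: tensors_with_element_count is a hypothetical helper, and it assumes the list returned by tf.lite.Interpreter.get_tensor_details(), whose entries are dicts with 'index', 'name' and 'shape' keys. It also assumes, as observed in this comment, that the details can still be read after allocate_tensors() has run.

```python
from math import prod


def tensors_with_element_count(tensor_details, count):
    """Return (index, name, shape) for every tensor holding exactly `count` elements.

    `tensor_details` is assumed to be the list returned by
    tf.lite.Interpreter.get_tensor_details().
    """
    matches = []
    for d in tensor_details:
        # 'shape' is a NumPy array in TFLite; convert to plain ints for readability.
        shape = tuple(int(s) for s in d['shape'])
        if prod(shape) == count:
            matches.append((d['index'], d['name'], shape))
    return matches


# Hypothetical usage with the interpreter from the inference script:
# try:
#     interpreter.allocate_tensors()
# except RuntimeError:
#     pass
# print(tensors_with_element_count(interpreter.get_tensor_details(), 2223936))
```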

In that case, could I work around this problem temporarily by manually reshaping the input tensor to 8x31x76x119?

@On-JungWoan
Contributor Author

Also, why are there multiple seemingly unnecessary Transpose operations for a single task? This error also occurs right after a Transpose. Is there a way to prevent the input image from being transposed unnecessarily?

[image]

@PINTO0309
Owner

PINTO0309 commented Apr 11, 2023

It would be odd to work around a bug in the TFLite runtime on the onnx2tf side. If the structure of the model were broken, I would fix it, but since it is not broken, I will not implement such a trick.

You should probably try converting the native GridSample op (opset=16) before spending time on such a non-essential investigation.

@On-JungWoan
Contributor Author

I will close this issue, and I will reopen it when I receive a response from TensorFlow.
