
[DN-DAB-DETR] The output of ONNX's Mul OP is different from the TFLite's output. #327

Closed
On-JungWoan opened this issue Apr 25, 2023 · 78 comments
Labels
OP:BatchMatMul OP:Softmax Transformer

Comments

@On-JungWoan
Contributor

Issue Type

Others

onnx2tf version number

1.9.12

onnx version number

1.13.1

onnxruntime version number

1.13.1

onnxsim (onnx_simplifier) version number

0.4.17

tensorflow version number

2.12.0

Download URL for ONNX

https://drive.google.com/file/d/1rjqhNfn85we2IG6YwKOJlhKv-fMRFha7/view?usp=sharing

Parameter Replacement JSON

{
  "format_version": 1,
  "operations": [
    {
      "op_name": "/backbone/backbone.1/Unsqueeze",
      "param_target": "outputs",
      "param_name": "/backbone/backbone.1/Unsqueeze_output_0",
      "post_process_transpose_perm": [3,0,1,2]
    },
    {
      "op_name": "/backbone/backbone.1/Unsqueeze_1",
      "param_target": "outputs",
      "param_name": "/backbone/backbone.1/Unsqueeze_1_output_0",
      "post_process_transpose_perm": [3,0,1,2]
    }
  ]
}

Description

Hi @PINTO0309! I'm trying to convert DN-DAB-DETR to TFLite. However, when I multiply two tensors, the ONNX output differs from the TFLite output. Below is my result using -cotof.

  • -cotof result (screenshot)

  • my onnx graph (screenshot)

  • python script
test = pos*pos_scales

Should I avoid the tensor multiplication?
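For illustration, here is a minimal numpy sketch (synthetic shapes, not the actual model) of how a single mis-transposed operand makes an elementwise Mul diverge silently, which is what a large -cotof error on a plain Mul suggests:

import numpy as np

# When all non-batch dims are equal, a wrong transpose is shape-compatible,
# so the result is silently wrong instead of raising a shape error.
pos = np.random.rand(1, 3, 3, 3).astype(np.float32)         # NCHW
pos_scales = np.random.rand(1, 3, 3, 3).astype(np.float32)  # NCHW

onnx_out = pos * pos_scales                                 # ONNX reference

to_nhwc = lambda x: x.transpose(0, 2, 3, 1)
to_nchw = lambda x: x.transpose(0, 3, 1, 2)

good = to_nchw(to_nhwc(pos) * to_nhwc(pos_scales))          # both operands converted
bad = to_nchw(to_nhwc(pos) * pos_scales)                    # one operand left in NCHW

print(np.allclose(onnx_out, good))  # True
print(np.allclose(onnx_out, bad))   # False: silent mismatch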

@PINTO0309 changed the title [Mul] The output of ONNX's Mul OP is different from the TFLite's output. → [DN-DAB-DETR] The output of ONNX's Mul OP is different from the TFLite's output. Apr 25, 2023
@On-JungWoan
Contributor Author

When I test just the single Mul OP with real numpy data, the outputs of the Mul OP all match.

import numpy as np
import tensorflow as tf

TFLITE_PATH = 'ScaleMul_float32.tflite'
POS_PATH = 'pos.npy'
POS_SCALE_PATH = 'pos_scales.npy'

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=TFLITE_PATH)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

interpreter.set_tensor(
    input_details[0]['index'],
    np.load(POS_PATH).transpose(0,2,1)
)
interpreter.set_tensor(
    input_details[1]['index'],
    np.load(POS_SCALE_PATH).transpose(0,2,1)
)

interpreter.invoke()

out = interpreter.get_tensor(output_details[0]['index'])

res = np.allclose(
    np.load('/data/ojw/convert/samples/res_mul.npy'),
    out.transpose(0,2,1)
)
print(res)
True

However, when I use the -cotof option, I still get the Unmatched message. I set everything to be the same as in this test environment, but did not get the same result. Is this a bug in the conversion process?

@PINTO0309
Owner

PINTO0309 commented Apr 25, 2023

Thank you.

I will be very busy until the second week of May and will not have enough time to debug and maintain at the same pace as before.

  • onnx (screenshot)

If the accuracy check is Unmatched on this simple model, then the problem appears to be in the accuracy check logic itself. But I don't have time to debug.

Probably related to this issue as well.
[InSPyReNet] Swin Transformer Support question #312

By the way,

onnx2tf \
-i ScaleMul.onnx \
-kat onnx__Mul_0 onnx__Mul_1 \
-cind "onnx::Mul_0" pos.npy \
-cind "onnx::Mul_1" pos_scales.npy \
-cotof

(screenshot)

  • tflite (screenshot)

@PINTO0309
Owner

PINTO0309 commented Apr 26, 2023

  • https://github.com/PINTO0309/onnx2tf/releases/tag/1.9.13

  • onnx2tf.py

    • Added sanitizing of : and / just before first dummy inference
    • ScaleMul.onnx.zip
    • Before
      onnx2tf \
      -i ScaleMul.onnx \
      -kat onnx__Mul_0 onnx__Mul_1 \
      -cind "onnx::Mul_0" pos.npy \
      -cind "onnx::Mul_1" pos_scales.npy \
      -cotof
      
    • After
      onnx2tf \
      -i ScaleMul.onnx \
      -kat onnx__Mul_0 onnx__Mul_1 \
      -cind onnx__Mul_0 pos.npy \
      -cind onnx__Mul_1 pos_scales.npy \
      -cotof
      
  • onnx (screenshot)

  • result (screenshot)

  • tflite (screenshot)

@On-JungWoan
Contributor Author

On-JungWoan commented Apr 26, 2023

Thank you for your detailed answer. I get the Matched message on the small model. However, I still get the Unmatched message on the original model, and when I run inference with the tflite model, I only get -inf values.

onnx2tf -i dn_dab_detr_480x480.onnx  -prf replace.json \
-kat input.1 -onimc /transformer/encoder/Mul_output_0

(screenshot)



While trying to find the cause of the problem, I noticed something strange. The original model (tflite) takes only one input for the Mul OP instead of two; the other input appears to be simply stored as a binary file. Is this the reason for the problem?

  • Small model

@PINTO0309
Owner

PINTO0309 commented Apr 26, 2023

While trying to find the cause of the problem, I noticed something strange. The original model (tflite) takes only one input for the Mul OP instead of two.

To me this is not a strange phenomenon at all. One of the two inputs of Mul in this ONNX model is a constant, and TensorFlow Lite has a far more powerful optimizer than ONNX, able to precompute and fold away all such redundant operations.
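As a sketch of that folding behavior (the shapes and names here are illustrative, not taken from the model), a Mul whose second operand comes from an all-constant subgraph gets baked into a single buffer at conversion time:

import tensorflow as tf

# The Transpose/Mul chain below has only constant inputs, so the TFLite
# converter precomputes it; the runtime graph keeps one input and one
# baked-in constant for the Mul.
c = tf.transpose(tf.ones([1, 2, 3, 4]) * 0.5, [0, 2, 3, 1])

@tf.function(input_signature=[tf.TensorSpec([1, 3, 4, 2], tf.float32)])
def f(x):
    return x * c

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [f.get_concrete_function()])
tflite_model = converter.convert()  # the Transpose no longer exists at runtime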

(screenshots)

This Transpose appears to be specified such that it always outputs a fixed value.
(screenshot)

In other words, it is very likely that the very large error in the simple Mul operation you first pointed out is caused by a conversion problem somewhere in the figure below. This is almost certain: there is an error in onnx2tf's transformation of the subgraph that computes the constant.
(screenshot)

It is usually very difficult to identify this kind of problem. We need to trace the conversion logs in order, starting from the Resize OP that generates the constants, and identify where the transposition error occurs.

@PINTO0309
Owner

The value is already broken in /backbone/backbone.0/Gather.

INFO: onnx_op_type: Resize onnx_op_name: /backbone/backbone.0/Resize
INFO:  input_name.1: /backbone/backbone.0/Constant_output_0 shape: [1, 1, 480, 480] dtype: <class 'numpy.float32'>
INFO:  input_name.2:  shape: None dtype: None
INFO:  input_name.3:  shape: None dtype: None
INFO:  input_name.4: /backbone/backbone.0/Concat_1_output_0 shape: [4] dtype: <class 'numpy.int64'>
INFO:  output_name.1: /backbone/backbone.0/Resize_output_0 shape: [1, 1, 15, 15] dtype: float32
INFO: tf_op_type: upsampling2d_nearest
INFO:  input.1.images: shape: (1, 480, 480, 1) dtype: float32 
INFO:  input.2.boxes: 
INFO:  input.3.box_indices: 
INFO:  input.4.new_size/crop_size: shape: (2,) dtype: <dtype: 'int32'> 
INFO:  input.5.method: val: nearest 
INFO:  input.6.extrapolation_value: val: 0.0 
INFO:  input.7.align_corners: val: False 
INFO:  output.1.output: shape: (1, 15, 15, 1) dtype: <dtype: 'float32'> 

INFO: onnx_op_type: Gather onnx_op_name: /backbone/backbone.0/Gather
INFO:  input_name.1: /backbone/backbone.0/Resize_output_0 shape: [1, 1, 15, 15] dtype: float32
INFO:  input_name.2: /Constant_output_0 shape: [] dtype: <class 'numpy.int64'>
INFO:  output_name.1: /backbone/backbone.0/Gather_output_0 shape: [1, 15, 15] dtype: float32
INFO: tf_op_type: gather_v2
INFO:  input.1.params: shape: (1, 15, 15, 1) dtype: <dtype: 'float32'> 
INFO:  input.2.indices: shape: () dtype: <dtype: 'int64'> 
INFO:  input.3.axis: val: 0 
INFO:  output.1.output: shape: (15, 15, 1) dtype: <dtype: 'float32'>

INFO: onnx_op_type: CumSum onnx_op_name: /backbone/backbone.1/CumSum
INFO:  input_name.1: /backbone/backbone.1/Cast_output_0 shape: [1, 15, 15] dtype: float32
INFO:  input_name.2: /backbone/backbone.1/Constant_1_output_0 shape: [] dtype: <class 'numpy.int32'>
INFO:  output_name.1: /backbone/backbone.1/CumSum_output_0 shape: [1, 15, 15] dtype: float32
INFO: tf_op_type: cumsum
INFO:  input.1.x: shape: (15, 15, 1) dtype: <dtype: 'float32'> 
INFO:  input.2.axis: shape: () dtype: int32 
INFO:  input.3.exclusive: val: False 
INFO:  input.4.reverse: val: False 
INFO:  output.1.output: shape: (15, 15, 1) dtype: <dtype: 'float32'> 

INFO: onnx_op_type: CumSum onnx_op_name: /backbone/backbone.1/CumSum_1
INFO:  input_name.1: /backbone/backbone.1/Cast_output_0 shape: [1, 15, 15] dtype: float32
INFO:  input_name.2: /backbone/backbone.1/Constant_2_output_0 shape: [] dtype: <class 'numpy.int32'>
INFO:  output_name.1: /backbone/backbone.1/CumSum_1_output_0 shape: [1, 15, 15] dtype: float32
INFO: tf_op_type: cumsum
INFO:  input.1.x: shape: (15, 15, 1) dtype: <dtype: 'float32'> 
INFO:  input.2.axis: shape: () dtype: int32 
INFO:  input.3.exclusive: val: False 
INFO:  input.4.reverse: val: False 
INFO:  output.1.output: shape: (15, 15, 1) dtype: <dtype: 'float32'> 
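In numpy terms, the broken Gather is consistent with the ONNX axis being copied verbatim instead of being remapped after the NCHW-to-NHWC layout change (a sketch using the shapes from the log above; the axis=3 replacement shown below addresses exactly this):

import numpy as np

x_nchw = np.random.rand(1, 1, 15, 15).astype(np.float32)
onnx_out = np.take(x_nchw, 0, axis=0)        # ONNX Gather -> (1, 15, 15)

x_nhwc = x_nchw.transpose(0, 2, 3, 1)        # converted layout: (1, 15, 15, 1)
tf_wrong = np.take(x_nhwc, 0, axis=0)        # axis kept as 0 -> (15, 15, 1)
tf_fixed = np.take(x_nhwc, 0, axis=3)        # remapped axis  -> (1, 15, 15)

print(tf_wrong.shape, tf_fixed.shape)        # (15, 15, 1) (1, 15, 15)
print(np.allclose(onnx_out, tf_fixed))       # True (both dropped dims are size 1)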

@On-JungWoan
Contributor Author

On-JungWoan commented Apr 26, 2023

Thank you for kindly responding despite your busy schedule. I will try to track down the remaining problems myself so as not to bother you.

@PINTO0309
Owner

PINTO0309 commented Apr 26, 2023

Note that only the Unsqueeze transposition has a special specification. Since the internal logic of Unsqueeze requires very complex handling, I provide the ability to replace it with Reshape.

https://github.com/PINTO0309/onnx2tf#parameter-replacement

{
  "format_version": 1,
  "operations": [

    {
      "op_name": "/backbone/backbone.0/Gather",
      "param_target": "attributes",
      "param_name": "axis",
      "values": 3
    },
    {
      "op_name": "/backbone/backbone.1/Unsqueeze",
      "param_target": "op",
      "new_shape": [1,15,15,1]
    },
    {
      "op_name": "/backbone/backbone.1/Unsqueeze_1",
      "param_target": "op",
      "new_shape": [1,15,15,1]
    }

  ]
}

(screenshot)

@PINTO0309 added OP:Gather OP:Unsqueeze labels Apr 26, 2023
@On-JungWoan
Contributor Author

On-JungWoan commented Apr 26, 2023

I made the following modifications to the model and finally got a Matched message for all the outputs. As you mentioned, there seems to be a bug in how the Mul operator handles constants.

  • Before

  • After

  • Result

Thank you for your help!

@On-JungWoan
Contributor Author

By the way, there seem to be some errors in the Placeholder operator here and there. Is it okay to ignore them?

(screenshot)

@PINTO0309
Owner

It is hard to know what is going on without tracing it in detail, but it may be related to the following issue.

[InSPyReNet] Swin Transformer Support question #312

@On-JungWoan
Contributor Author

Thank you for your fast response. I will check the issue that you mentioned. Can I close this issue?

@PINTO0309
Owner

It is up to you to decide whether to close the issue.

@On-JungWoan
Contributor Author

On-JungWoan commented Apr 27, 2023

I am currently having a problem with the Gemm operator. All the other outputs are correct, but only the output of the Gemm operator shows an Unmatched message. Is this just an error in the calculation method?

(screenshots)

ONNX Link

replace.json
{
  "format_version": 1,
  "operations": [
    {
      "op_name": "/backbone/backbone.1/Unsqueeze",
      "param_target": "op",
      "new_shape": [1,15,15,1]
    },
    {
      "op_name": "/backbone/backbone.1/Unsqueeze_1",
      "param_target": "op",
      "new_shape": [1,15,15,1]
    }
  ]
}



The strange thing here is that when I checked the -cotof result of the operator using the -kat option, I was able to get a Matched message.

onnx2tf -i new_dn_dab_detr_480x480.onnx \
-prf replace.json \
-onimc /transformer/encoder/layers.0/self_attn/Gemm_output_0 \
-cotof -cotoa 1e-1

(screenshot)

Should I modify the replace.json file?

@PINTO0309
Owner

The ONNX file you shared with me was corrupted and unreadable when I downloaded it.

@On-JungWoan
Contributor Author

Oh, I'm sorry. That's my mistake. Here is the new ONNX file link.

@PINTO0309
Owner

PINTO0309 commented Apr 27, 2023

The strange thing here is that when I checked the -cotof result of the operator using the -kat option, I was able to get a Matched message.

Please give me as much info as possible.

onnx2tf -i new_dn_dab_detr_480x480.onnx \
-prf replace.json \
-onimc /transformer/encoder/layers.0/self_attn/Gemm_output_0 \
-cotof -cotoa 1e-1

The command you posted does not include the -kat option, so I cannot reproduce the normal and abnormal patterns respectively.

Also, how did you generate the model in this image? If I can't reproduce the model on my PC, the man-hours required for testing become very large, and that would be very difficult for me.
(screenshot)

@On-JungWoan
Contributor Author

On-JungWoan commented Apr 27, 2023

1. Model shown in image

Also, how did you generate the model for this image? If I can't reproduce the model you have on my PC, the man-hours required for testing will be very large, and it will be very difficult for me.

First, the part shown in the image is the MultiHeadAttention process, and below is the link to the corresponding model.


For this part of the model, Matched messages are displayed for all outputs.

onnx2tf -i multiheadatn.onnx \
-kat key_padding_mask \
-cotof -cotoa 1e-1

(screenshot)



2. The part I think is weird

2-1. -onimc for the first layer's Gemm output

The Matched message is displayed for all outputs even when -onimc is applied to the Gemm output of the first layer (/transformer/encoder/layers.0/self_attn/Gemm_output_0).

onnx2tf -i new_dn_dab_detr_480x480.onnx \
-prf replace.json \
-onimc /transformer/encoder/layers.0/self_attn/Gemm_output_0 \
-cotof -cotoa 1e-1 \
-kat input.1

(screenshots)


2-2. -onimc for the second layer's Gemm output

However, when I apply -onimc to the second layer (/transformer/encoder/layers.1/self_attn/Gemm_output_0), the Gemm output of the first layer, which was correct until now, becomes incorrect.

onnx2tf -i new_dn_dab_detr_480x480.onnx \
-prf replace.json \
-onimc /transformer/encoder/layers.1/self_attn/Gemm_output_0 \
-cotof -cotoa 1e-1 \
-kat input.1

(screenshot)

Is this a bug or my mistake?

@On-JungWoan
Contributor Author

On-JungWoan commented Apr 29, 2023

Hi, @PINTO0309! I found a very strange issue while experimenting with nn.MultiheadAttention. When the input to the TFLite-converted MultiheadAttention exceeds a certain threshold, the ONNX and TFLite results differ. I have been struggling with this all weekend, and I think it is a clear bug. I'm sorry to bother you when you're busy, but I would really appreciate it if you could take a look.

  • Max value < 10
scale=1e-2

( ... )

print(len(res[res==False]))
0

  • Max value > 10
scale=1e+1

( ... )

print(len(res[res==False]))
150092


Model

import torch.nn as nn

class MHAtn(nn.Module):
    def __init__(self):
        super().__init__()
        # embed_dim=256, num_heads=8, dropout=0.0
        self.self_attn = nn.MultiheadAttention(256, 8, 0.0)

    def forward(self, q, k, src):
        return self.self_attn(q, k, src)
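For completeness, a hedged sketch of how a module like this could be exported for such a comparison (the file name, input/output names, and opset are assumptions, not taken from this thread; MHAtn is the class above):

import torch

model = MHAtn().eval()
q = torch.rand(225, 1, 256)    # (sequence, batch, embed_dim)
k = torch.rand(225, 1, 256)
src = torch.rand(225, 1, 256)

torch.onnx.export(
    model, (q, k, src), 'mh_attn.onnx',   # hypothetical output path
    input_names=['query', 'key', 'value'],
    output_names=['attn_output', 'attn_weights'],
    opset_version=12,
)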

@PINTO0309
Owner

PINTO0309 commented Apr 30, 2023

Thank you. I am allocating some time to a research implementation (a private implementation in my research position), so my responses will be a little slow for a while. I am not ignoring this issue.

@On-JungWoan
Contributor Author

On-JungWoan commented May 1, 2023

I thought this issue was related to MatMul, so I read the issue comments you wrote, quoted below.

Frequent accuracy degradation occurs only when the sizes of all dimensions except the batch size are the same, as in [1,256,256] in the figure below.

This means that for tensors of structures other than 4 or 5 dimensional, the tool may make a wrong decision about transposition.



On the other hand, wouldn't it be possible to adjust the weight passed to the input of MatMul to 4 dimensions for nn.Linear?

(screenshot)

If the weight of nn.Linear has shape [256, 256], is there any solution to this?
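The risk can be seen in a few lines of numpy (a toy 8x8 stand-in for the 256x256 case): when the weight is square, a wrongly transposed activation still type-checks, so the error is silent rather than a shape mismatch.

import numpy as np

x = np.random.rand(1, 8, 8).astype(np.float32)  # stand-in for [1, 256, 256]
w = np.random.rand(8, 8).astype(np.float32)     # square nn.Linear weight

ok = x @ w.T                                    # what nn.Linear computes
bad = x.transpose(0, 2, 1) @ w.T                # mis-transposed input, same shape

print(ok.shape == bad.shape)  # True: nothing to catch at the shape level
print(np.allclose(ok, bad))   # False: silently wrong values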

@PINTO0309
Owner

PINTO0309 commented May 1, 2023

This is an incorrect test pattern because it generates different random values for ONNX and TFLite.

onnx_dummy_input = np.random.rand(625,1,256).astype(np.float32) * scale
tflite_dummy_input = np.random.rand(625,256,1).astype(np.float32) * scale

However, you are correct: increasing the value of scale seems to break the values for some reason.

scale=1e+1 # Error
query = np.random.rand(225,1,256).astype(np.float32)*scale
key = np.random.rand(225,1,256).astype(np.float32)*scale
value = np.random.rand(225,1,256).astype(np.float32)*scale
key_padding_mask = np.random.rand(1,225).astype(np.float32)*scale

# Set ONNX model
ort_session = onnxruntime.InferenceSession(
    onnx_path,
    providers=['CUDAExecutionProvider']
)

# Set TFLite model
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()
tf_lite_model = interpreter.get_signature_runner()

ort_inputs = {}
ort_inputs['query'] = query
ort_inputs['key'] = key
ort_inputs['value'] = value
ort_inputs['key_padding_mask'] = key_padding_mask

# Get output of ONNX
onnx_out_list = ort_session.run(None, ort_inputs)

# Get output of TFLite 
tt_lite_output = tf_lite_model(
    query=tf.constant(query, dtype=tf.float32),
    key=tf.constant(key, dtype=tf.float32),
    value=tf.constant(value, dtype=tf.float32),
    key_padding_mask=tf.constant(key_padding_mask, dtype=tf.float32),
)

(screenshot)

Despite using the same model, both outputs were nearly identical when small values were entered.

scale=1e-2
query = np.random.rand(225,1,256).astype(np.float32)*scale
key = np.random.rand(225,1,256).astype(np.float32)*scale
value = np.random.rand(225,1,256).astype(np.float32)*scale
key_padding_mask = np.random.rand(1,225).astype(np.float32)*scale

# Set ONNX model
ort_session = onnxruntime.InferenceSession(
    onnx_path,
    providers=['CUDAExecutionProvider']
)

# Set TFLite model
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()
tf_lite_model = interpreter.get_signature_runner()

ort_inputs = {}
ort_inputs['query'] = query
ort_inputs['key'] = key
ort_inputs['value'] = value
ort_inputs['key_padding_mask'] = key_padding_mask

# Get output of ONNX
onnx_out_list = ort_session.run(None, ort_inputs)

# Get output of TFLite 
tt_lite_output = tf_lite_model(
    query=tf.constant(query, dtype=tf.float32),
    key=tf.constant(key, dtype=tf.float32),
    value=tf.constant(value, dtype=tf.float32),
    key_padding_mask=tf.constant(key_padding_mask, dtype=tf.float32),
)

(screenshot)

@PINTO0309
Owner

We already know that the 256x256 MatMul pattern triggers transposition mistakes and causes errors, but beyond that, there appears to be a fatal problem in the TFLite runtime itself.

@On-JungWoan
Contributor Author

Do you mean this issue is caused by a third party?

@PINTO0309
Owner

PINTO0309 commented May 3, 2023

Test data is not properly transposed from NHWC to NCHW. I am trying to figure out how to deal with this.

Essentially, the Numpy array of test data should be NHWC.
input_nhwc.npy.zip

https://github.com/PINTO0309/onnx2tf#cli-parameter
(screenshot)

I even identified the problem here.

# -cind
if custom_input_op_name_np_data_path:
    for param in custom_input_op_name_np_data_path:
        input_op_name = str(param[0])
        numpy_file_path = str(param[1])
        custom_input_data = np.load(numpy_file_path)
        input_datas[input_op_name] = custom_input_data
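A hedged sketch of the kind of fix this implies (variable names mirror the snippet above; the actual change inside onnx2tf may differ): 4D custom input data saved in NCHW order would need to be transposed to NHWC before being fed to the TF side of the accuracy check.

import numpy as np

custom_input_data = np.load(numpy_file_path)  # numpy_file_path as above
if custom_input_data.ndim == 4:
    # NCHW -> NHWC so the TF graph receives the layout it expects
    custom_input_data = custom_input_data.transpose(0, 2, 3, 1)
input_datas[input_op_name] = custom_input_data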

@PINTO0309
Owner

PINTO0309 commented May 4, 2023

I beat the bug. 👍

(screenshot)

@On-JungWoan
Contributor Author

On-JungWoan commented May 4, 2023

Thank you for your effort. I appreciate it. 😊

Could you let me know the command you used? I am still seeing the Unmatched message. Below are the versions of the tools I used.

  • onnx2tf: 1.9.19
  • onnxsim: 0.4.28
  • TensorFlow: 2.12.0
  • onnxruntime: 1.13.1

@PINTO0309
Owner

Please wait a little longer. Right now I am in the middle of a regression test with my CI. I will upgrade to v1.10.0 when it is all green.

@On-JungWoan
Contributor Author

I understand. Thank you so much.

@PINTO0309
Owner

PINTO0309 commented May 4, 2023

https://github.com/PINTO0309/onnx2tf/releases/tag/1.10.0

The following command can be used for the conversion. Please try it.

Originally, the problem was quite complicated: a combination of multiple issues in how the -cind option handled the test data and in the Softmax and MatMul dimensional-correction process.

input_nhwc.npy.zip

onnx2tf -i mod_dn_dab_detr.onnx -cotof

(screenshot)

$ pip show onnx2tf onnxsim onnx onnxruntime-gpu tensorflow

Name: onnx2tf
Version: 1.10.0
Summary: Self-Created Tools to convert ONNX files (NCHW) to TensorFlow/TFLite/Keras format (NHWC). The purpose of this tool is to solve the massive Transpose extrapolation problem in onnx-tensorflow (onnx-tf).
Home-page: https://github.com/PINTO0309/onnx2tf
Author: Katsuya Hyodo
Author-email: rmsdh122@yahoo.co.jp
License: MIT License
Location: /usr/local/lib/python3.8/dist-packages
Editable project location: /usr/local/lib/python3.8/dist-packages
Requires: 
Required-by: simple-onnx-processing-tools
---
Name: onnxsim
Version: 0.4.17
Summary: Simplify your ONNX model
Home-page: https://github.com/daquexian/onnx-simplifier
Author: ONNX Simplifier Authors
Author-email: daquexian566@gmail.com
License: Apache License v2.0
Location: /usr/local/lib/python3.8/dist-packages
Requires: onnx, rich
Required-by: 
---
Name: onnx
Version: 1.13.1
Summary: Open Neural Network Exchange
Home-page: https://github.com/onnx/onnx
Author: ONNX
Author-email: onnx-technical-discuss@lists.lfaidata.foundation
License: Apache License v2.0
Location: /usr/local/lib/python3.8/dist-packages
Requires: numpy, protobuf, typing-extensions
Required-by: deepsparse, onnigiri, onnx-coreml, onnx-graphsurgeon, onnx-simplifier, onnx-tf, onnxconverter-common, onnxoptimizer, onnxruntime-extensions, onnxsim, onnxsim-no-ort, sclblonnx, softneuro, sparseml, sparsezoo, tf2onnx
---
Name: onnxruntime-gpu
Version: 1.13.1
Summary: ONNX Runtime is a runtime accelerator for Machine Learning models
Home-page: https://onnxruntime.ai
Author: Microsoft Corporation
Author-email: onnxruntime@microsoft.com
License: MIT License
Location: /usr/local/lib/python3.8/dist-packages
Requires: coloredlogs, flatbuffers, numpy, packaging, protobuf, sympy
Required-by: 
---
Name: tensorflow
Version: 2.12.0rc0
Summary: TensorFlow is an open source machine learning framework for everyone.
Home-page: https://www.tensorflow.org/
Author: Google Inc.
Author-email: packages@tensorflow.org
License: Apache 2.0
Location: /usr/local/lib/python3.8/dist-packages
Requires: absl-py, astunparse, flatbuffers, gast, google-pasta, grpcio, h5py, jax, keras, libclang, numpy, opt-einsum, packaging, protobuf, setuptools, six, tensorboard, tensorflow-estimator, tensorflow-io-gcs-filesystem, termcolor, typing-extensions, wrapt
Required-by: dopamine-rl, tensorflowjs

@PINTO0309 removed the Parameter replacement label May 4, 2023
@On-JungWoan
Contributor Author

On-JungWoan commented May 4, 2023

Thank you for your hard work. When I run inference with the converted tflite model, I still get -inf values in half of the output. Is this a bug in the TFLite runtime?


  • Inference script
import numpy as np
import tensorflow as tf

TFLITE_PATH = 'my/tflite/path/final_model_float32.tflite'

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=TFLITE_PATH)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

interpreter.set_tensor(
    input_details[0]['index'],
    np.ones(list(input_details[0]['shape'])).astype(np.float32)
)

interpreter.invoke()

out_list = []
for i in range(len(output_details)):
    out_list.append(interpreter.get_tensor(output_details[i]['index']))

print(out_list[0])
array([[[-inf, -inf, -inf, ..., -inf, -inf, -inf],
         [-inf, -inf, -inf, ..., -inf, -inf, -inf],
         [-inf, -inf, -inf, ..., -inf, -inf, -inf],
         ...,
         [-inf, -inf, -inf, ..., -inf, -inf, -inf],
         [-inf, -inf, -inf, ..., -inf, -inf, -inf],
         [-inf, -inf, -inf, ..., -inf, -inf, -inf]]], dtype=float32)

@On-JungWoan
Contributor Author

On-JungWoan commented May 4, 2023

I still get the Unmatched message in the -cotof results. What is the problem on my side?

  • Command
onnx2tf -i mod_dn_dab_detr.onnx -cotof

(screenshot)

  • Versions
Name: onnx2tf
Version: 1.10.0
Summary: Self-Created Tools to convert ONNX files (NCHW) to TensorFlow/TFLite/Keras format (NHWC). The purpose of this tool is to solve the massive Transpose extrapolation problem in onnx-tensorflow (onnx-tf).
Home-page: https://github.com/PINTO0309/onnx2tf
Author: Katsuya Hyodo
Author-email: rmsdh122@yahoo.co.jp
License: MIT License
Location: /home/user/anaconda3/envs/tflite/lib/python3.9/site-packages
Requires: 
Required-by: simple-onnx-processing-tools
---
Name: onnxsim
Version: 0.4.17
Summary: Simplify your ONNX model
Home-page: https://github.com/daquexian/onnx-simplifier
Author: ONNX Simplifier Authors
Author-email: daquexian566@gmail.com
License: Apache License v2.0
Location: /home/user/anaconda3/envs/tflite/lib/python3.9/site-packages
Requires: onnx, rich
Required-by: 
---
Name: onnx
Version: 1.13.1
Summary: Open Neural Network Exchange
Home-page: https://github.com/onnx/onnx
Author: ONNX
Author-email: onnx-technical-discuss@lists.lfaidata.foundation
License: Apache License v2.0
Location: /home/user/anaconda3/envs/tflite/lib/python3.9/site-packages
Requires: numpy, protobuf, typing-extensions
Required-by: onnx-graphsurgeon, onnxsim
---
Name: onnxruntime-gpu
Version: 1.13.1
Summary: ONNX Runtime is a runtime accelerator for Machine Learning models
Home-page: https://onnxruntime.ai
Author: Microsoft Corporation
Author-email: onnxruntime@microsoft.com
License: MIT License
Location: /home/user/anaconda3/envs/tflite/lib/python3.9/site-packages
Requires: coloredlogs, flatbuffers, numpy, packaging, protobuf, sympy
Required-by: 
---
Name: tensorflow
Version: 2.12.0rc0
Summary: TensorFlow is an open source machine learning framework for everyone.
Home-page: https://www.tensorflow.org/
Author: Google Inc.
Author-email: packages@tensorflow.org
License: Apache 2.0
Location: /home/user/anaconda3/envs/tflite/lib/python3.9/site-packages
Requires: absl-py, astunparse, flatbuffers, gast, google-pasta, grpcio, h5py, jax, keras, libclang, numpy, opt-einsum, packaging, protobuf, setuptools, six, tensorboard, tensorflow-estimator, tensorflow-io-gcs-filesystem, termcolor, typing-extensions, wrapt
Required-by: 

My library versions are exactly the same as yours, @PINTO0309. Also, with mod_dn_dab_detr.onnx, I am still seeing the Unmatched message.

@PINTO0309
Owner

PINTO0309 commented May 4, 2023

The ONNX file you shared with me is corrupt and cannot be displayed on Netron.

I still get -inf values in half of the output. Is this a bug in the TFLite runtime?

If the accuracy check after conversion is normal, I can only surmise that the cause is either a problem on the runtime side or the fact that all the input data is 1.

To begin with, is the mod_dn_dab_detr.onnx file you are using for the conversion the same as the mod_dn_dab_detr.onnx you shared with me the other day? That is a completely separate question from the result being -inf.

I can attest to you that there is nothing wrong with my environment.

docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.10.0

rm mod_dn_dab_detr.onnx
wget https://s3.ap-northeast-2.wasabisys.com/temp-models/onnx2tf_327/mod_dn_dab_detr.onnx

onnx2tf -i mod_dn_dab_detr.onnx -cotof
INFO: onnx_output_name: /Split_output_0 tf_output_name: tf.strided_slice_142/StridedSlice:0 shape: (1, 1, 300, 60) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /Split_output_1 tf_output_name: tf.strided_slice_143/StridedSlice:0 shape: (1, 1, 300, 60) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /Split_output_2 tf_output_name: tf.strided_slice_144/StridedSlice:0 shape: (1, 1, 300, 60) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /Split_output_3 tf_output_name: tf.strided_slice_145/StridedSlice:0 shape: (1, 1, 300, 60) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /Split_output_4 tf_output_name: tf.strided_slice_146/StridedSlice:0 shape: (1, 1, 300, 60) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /bbox_embed/layers.1/Add_output_0 tf_output_name: tf.math.add_365/Add:0 shape: (6, 1, 300, 256) dtype: float32 validate_result:  Unmatched  max_abs_error: 0.0001277923583984375
INFO: onnx_output_name: 7389 tf_output_name: tf.compat.v1.squeeze_9610//Squeeze:0 shape: (1, 300, 60) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: 7391 tf_output_name: tf.compat.v1.squeeze_9611//Squeeze_1:0 shape: (1, 300, 60) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: 7393 tf_output_name: tf.compat.v1.squeeze_9612//Squeeze_2:0 shape: (1, 300, 60) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: 7395 tf_output_name: tf.compat.v1.squeeze_9613//Squeeze_3:0 shape: (1, 300, 60) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: 7397 tf_output_name: tf.compat.v1.squeeze_9614//Squeeze_4:0 shape: (1, 300, 60) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /bbox_embed/Relu_1_output_0 tf_output_name: tf.nn.relu_105/Relu:0 shape: (6, 1, 300, 256) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /bbox_embed/layers.2/MatMul_output_0 tf_output_name: tf.linalg.matmul_4703/MatMul:0 shape: (6, 1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /bbox_embed/layers.2/Add_output_0 tf_output_name: tf.math.add_366/Add:0 shape: (6, 1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /Slice_output_0 tf_output_name: tf.strided_slice_148/StridedSlice:0 shape: (6, 1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /Add_output_0 tf_output_name: tf.math.add_367/Add:0 shape: (6, 1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /ScatterND_output_0 tf_output_name: tf.tensor_scatter_nd_update_30/TensorScatterUpdate:0 shape: (6, 1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /ScatterND_1_output_0 tf_output_name: tf.tensor_scatter_nd_update_31/TensorScatterUpdate:0 shape: (6, 1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /Sigmoid_output_0 tf_output_name: tf.math.sigmoid_10/Sigmoid:0 shape: (6, 1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: 7371 tf_output_name: tf.compat.v1.gather_62/GatherV2:0 shape: (1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /Slice_5_output_0 tf_output_name: tf.strided_slice_150/StridedSlice:0 shape: (5, 1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /Split_1_output_0 tf_output_name: tf.strided_slice_151/StridedSlice:0 shape: (1, 1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /Split_1_output_1 tf_output_name: tf.strided_slice_152/StridedSlice:0 shape: (1, 1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /Split_1_output_2 tf_output_name: tf.strided_slice_153/StridedSlice:0 shape: (1, 1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /Split_1_output_3 tf_output_name: tf.strided_slice_154/StridedSlice:0 shape: (1, 1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: /Split_1_output_4 tf_output_name: tf.strided_slice_155/StridedSlice:0 shape: (1, 1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: 7405 tf_output_name: tf.compat.v1.squeeze_9620//Squeeze_5:0 shape: (1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: 7407 tf_output_name: tf.compat.v1.squeeze_9621//Squeeze_6:0 shape: (1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: 7409 tf_output_name: tf.compat.v1.squeeze_9622//Squeeze_7:0 shape: (1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: 7411 tf_output_name: tf.compat.v1.squeeze_9623//Squeeze_8:0 shape: (1, 300, 4) dtype: float32 validate_result:  Matches 
INFO: onnx_output_name: 7413 tf_output_name: tf.compat.v1.squeeze_9624//Squeeze_9:0 shape: (1, 300, 4) dtype: float32 validate_result:  Matches 

(screenshot)

If you see Unmatched, it is either a problem with the ONNX file you are using or a problem with your environment.

@On-JungWoan
Contributor Author

I will use your docker environment. Thank you for your answer.

@On-JungWoan
Contributor Author

On-JungWoan commented May 4, 2023

Thank you, as always, for your hard work. I tested in the docker environment you shared and confirmed that the Matched message is displayed in the -cotof results (for mod_dn_dab_detr.onnx). However, when I actually run inference with tflite, the output differs from onnx: in float16, many output values are different, and even with float32, half of the values come out as -inf. I got these results even when running inference in your docker environment. How can I solve this problem?

(screenshot)

import os
import pickle
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)
warnings.simplefilter(action='ignore', category=DeprecationWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
import random
random.seed(0)
import numpy as np
np.set_printoptions(
    precision=6,
    floatmode='fixed',
    suppress=True,
    edgeitems=3,
    linewidth=100,
)
np.random.seed(0)

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.lite.python import interpreter as iw

TFLITE_PATH = 'saved_model/mod_dn_dab_detr_float32.tflite'
interpreter = iw.Interpreter(
    model_path=TFLITE_PATH,
    num_threads=4,
)
input_details = interpreter.get_input_details()
input_shape_1 = input_details[0]['shape']
output_details = interpreter.get_output_details()

test_data1 = np.random.randn(*input_shape_1).astype(np.float32)

interpreter.allocate_tensors()
interpreter.set_tensor(
    input_details[0]['index'],
    test_data1,
)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])

print('')
print('tflite ===============================================')
print(output_data)


# with open('output/tflite_out.pkl', 'wb') as f:
#     pickle.dump(out_list, f)

@PINTO0309
Owner

I can't say whether this is really the correct test data for this model, because I don't know the model. Is it correct to feed random values? Since the accuracy check of the model passes, it is better to first suspect something other than the model's structure.

test_data1 = np.random.randn(*input_shape_1).astype(np.float32)

You may want to find out which operations have divergent values.
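One way to do that (a sketch reusing TFLITE_PATH and test_data1 from the script above): the experimental_preserve_all_tensors flag keeps intermediate tensors readable, so you can scan them for -inf after a single invoke.

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(
    model_path=TFLITE_PATH,
    experimental_preserve_all_tensors=True,
)
interpreter.allocate_tensors()
interpreter.set_tensor(interpreter.get_input_details()[0]['index'], test_data1)
interpreter.invoke()

# Report tensors that contain -inf to narrow down the offending op.
for detail in interpreter.get_tensor_details():
    try:
        t = interpreter.get_tensor(detail['index'])
    except ValueError:
        continue  # some tensors have no readable data
    if np.issubdtype(t.dtype, np.floating) and np.isneginf(t).any():
        print(detail['index'], detail['name'])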

@On-JungWoan
Contributor Author

On-JungWoan commented May 4, 2023

I have tried various methods besides the dummy test data shown below, but I always got -inf values. By the way, the strange thing is that only mod_dn_dab_detr_float32 outputs -inf; mod_dn_dab_detr_float16, although its values are not quite right, does produce values.

np.ones(input_shape_1).astype(np.float32) # unmatched
torch.rand(input_shape_1, dtype=torch.float32).numpy() # unmatched
np.random.rand(1,800,800,3).astype(np.float32) # unmatched

# etc ...

Then, do you mean that there is a problem with the tflite runtime? Deformable DETR was converted successfully, and I have no idea why DN-DAB-DETR cannot be.

@PINTO0309
Owner

https://s3.ap-northeast-2.wasabisys.com/temp-models/onnx2tf_327/test/mod_dn_dab_detr.onnx
https://s3.ap-northeast-2.wasabisys.com/temp-models/onnx2tf_327/test/mod_dn_dab_detr_float32.tflite
https://s3.ap-northeast-2.wasabisys.com/temp-models/onnx2tf_327/test/mod_dn_dab_detr_float16.tflite

I cannot reproduce it.

onnx2tf -i mod_dn_dab_detr.onnx -coion
import os
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)
warnings.simplefilter(action='ignore', category=DeprecationWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
import random
random.seed(0)
import numpy as np
np.set_printoptions(
    precision=6,
    floatmode='fixed',
    suppress=True,
    edgeitems=3,
    linewidth=100,
)
np.random.seed(0)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import onnxruntime

ONNX_PATH = 'mod_dn_dab_detr.onnx'
ort_session = onnxruntime.InferenceSession(ONNX_PATH, providers=['CPUExecutionProvider'])
input_name = ort_session.get_inputs()[0].name
input_shape = ort_session.get_inputs()[0].shape
output_num = len(ort_session.get_outputs())

test_data = np.ones(input_shape, dtype=np.float32)

ort_inputs = {
    ort_session.get_inputs()[0].name: test_data,
}
onnx_out = ort_session.run(
    ['7370'],
    ort_inputs,
)

TFLITE_PATH = 'mod_dn_dab_detr_float32.tflite'
interpreter = tf.lite.Interpreter(model_path=TFLITE_PATH)
tf_lite_model = interpreter.get_signature_runner()
tflite_output_float32 = tf_lite_model(
    input_1=test_data.transpose(0,2,3,1),
)
TFLITE_PATH = 'mod_dn_dab_detr_float16.tflite'
interpreter = tf.lite.Interpreter(model_path=TFLITE_PATH)
tf_lite_model = interpreter.get_signature_runner()
tflite_output_float16 = tf_lite_model(
    input_1=test_data.transpose(0,2,3,1),
)

print('=====================================================')
print(onnx_out[0].shape)
print(tflite_output_float32['7370'].shape)
print(tflite_output_float16['7370'].shape)

print('onnx =================================================')
print(onnx_out[0])
print('')
print('tflite float32 =======================================')
print(tflite_output_float32['7370'])
print('')
print('tflite float16 =======================================')
print(tflite_output_float16['7370'])
=====================================================
(1, 300, 60)
(1, 300, 60)
(1, 300, 60)
onnx =================================================
[[[ -9.889225  -5.058238  -4.400069 ...  -4.518981  -4.422979  -5.619113]
  [ -9.666203  -5.503964  -4.429813 ...  -4.284597  -3.964971  -4.635274]
  [-10.165474  -5.403321  -4.708091 ...  -4.778224  -4.726847  -5.837096]
  ...
  [-10.237301  -5.250232  -4.427892 ...  -4.601456  -4.263485  -5.674369]
  [ -9.821136  -5.886144  -4.720225 ...  -4.816403  -4.191678  -4.233029]
  [ -9.900963  -5.236600  -4.576357 ...  -4.596010  -4.396852  -5.339815]]]

tflite float32 =======================================
[[[ -9.889221  -5.058239  -4.400070 ...  -4.518981  -4.422979  -5.619112]
  [ -9.666204  -5.503964  -4.429811 ...  -4.284597  -3.964969  -4.635273]
  [-10.165472  -5.403322  -4.708091 ...  -4.778225  -4.726850  -5.837097]
  ...
  [-10.237301  -5.250231  -4.427890 ...  -4.601455  -4.263482  -5.674368]
  [ -9.821136  -5.886144  -4.720224 ...  -4.816402  -4.191678  -4.233030]
  [ -9.900966  -5.236598  -4.576355 ...  -4.596011  -4.396853  -5.339818]]]

tflite float16 =======================================
[[[ -9.889221  -5.058239  -4.400070 ...  -4.518981  -4.422979  -5.619112]
  [ -9.666204  -5.503964  -4.429811 ...  -4.284597  -3.964969  -4.635273]
  [-10.165472  -5.403322  -4.708091 ...  -4.778225  -4.726850  -5.837097]
  ...
  [-10.237301  -5.250231  -4.427890 ...  -4.601455  -4.263482  -5.674368]
  [ -9.821136  -5.886144  -4.720224 ...  -4.816402  -4.191678  -4.233030]
  [ -9.900966  -5.236598  -4.576355 ...  -4.596011  -4.396853  -5.339818]]]

@On-JungWoan
Contributor Author

On-JungWoan commented May 4, 2023

tflite_output_float32 = tf_lite_model(
input_1=test_data.transpose(0,2,3,1),
)

I'm so sorry to bother you, but how can I get the input name?

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/unist/anaconda3/envs/onx_test/lib/python3.9/site-packages/tensorflow/lite/python/interpreter.py", line 237, in __call__
    raise ValueError('Invalid Input name (%s) for SignatureDef' %
ValueError: Invalid Input name (input_1) for SignatureDef

(screenshot)
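For reference, the valid signature input names can be listed with get_signature_list() (a sketch using the same TFLITE_PATH; the exact names below are an assumption):

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path=TFLITE_PATH)
print(interpreter.get_signature_list())
# e.g. {'serving_default': {'inputs': ['input.1'], 'outputs': ['7370', ...]}}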

@PINTO0309
Owner

PINTO0309 commented May 5, 2023

https://s3.ap-northeast-2.wasabisys.com/temp-models/onnx2tf_327/test2/mod_dn_dab_detr.onnx
https://s3.ap-northeast-2.wasabisys.com/temp-models/onnx2tf_327/test2/mod_dn_dab_detr_float32.tflite
https://s3.ap-northeast-2.wasabisys.com/temp-models/onnx2tf_327/test2/mod_dn_dab_detr_float16.tflite

import os
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)
warnings.simplefilter(action='ignore', category=DeprecationWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
import random
random.seed(0)
import numpy as np
np.set_printoptions(
    precision=6,
    floatmode='fixed',
    suppress=True,
    edgeitems=3,
    linewidth=100,
)
np.random.seed(0)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import onnxruntime

ONNX_PATH = 'mod_dn_dab_detr.onnx'
ort_session = onnxruntime.InferenceSession(ONNX_PATH, providers=['CPUExecutionProvider'])
input_name = ort_session.get_inputs()[0].name
input_shape = ort_session.get_inputs()[0].shape
output_num = len(ort_session.get_outputs())

test_data = np.ones(input_shape, dtype=np.float32)

ort_inputs = {
    ort_session.get_inputs()[0].name: test_data,
}
onnx_out = ort_session.run(
    ['7370'],
    ort_inputs,
)

TFLITE_PATH = 'mod_dn_dab_detr_float32.tflite'
interpreter = tf.lite.Interpreter(model_path=TFLITE_PATH)
tf_lite_model = interpreter.get_signature_runner()
inputs = {'input.1': test_data.transpose(0,2,3,1)}
tflite_output_float32 = tf_lite_model(
    **inputs
)
TFLITE_PATH = 'mod_dn_dab_detr_float16.tflite'
interpreter = tf.lite.Interpreter(model_path=TFLITE_PATH)
tf_lite_model = interpreter.get_signature_runner()
inputs = {'input.1': test_data.transpose(0,2,3,1)}
tflite_output_float16 = tf_lite_model(
    **inputs
)

print('=====================================================')
print(onnx_out[0].shape)
print(tflite_output_float32['7370'].shape)
print(tflite_output_float16['7370'].shape)

print('onnx =================================================')
print(onnx_out[0])
print('')
print('tflite float32 =======================================')
print(tflite_output_float32['7370'])
print('')
print('tflite float16 =======================================')
print(tflite_output_float16['7370'])
=====================================================
(1, 300, 60)
(1, 300, 60)
(1, 300, 60)
onnx =================================================
[[[ -9.889225  -5.058238  -4.400069 ...  -4.518981  -4.422979  -5.619113]
  [ -9.666203  -5.503964  -4.429813 ...  -4.284597  -3.964971  -4.635274]
  [-10.165474  -5.403321  -4.708091 ...  -4.778224  -4.726847  -5.837096]
  ...
  [-10.237301  -5.250232  -4.427892 ...  -4.601456  -4.263485  -5.674369]
  [ -9.821136  -5.886144  -4.720225 ...  -4.816403  -4.191678  -4.233029]
  [ -9.900963  -5.236600  -4.576357 ...  -4.596010  -4.396852  -5.339815]]]

tflite float32 =======================================
[[[ -9.889221  -5.058239  -4.400070 ...  -4.518981  -4.422979  -5.619112]
  [ -9.666204  -5.503964  -4.429811 ...  -4.284597  -3.964969  -4.635273]
  [-10.165472  -5.403322  -4.708091 ...  -4.778225  -4.726850  -5.837097]
  ...
  [-10.237301  -5.250231  -4.427890 ...  -4.601455  -4.263482  -5.674368]
  [ -9.821136  -5.886144  -4.720224 ...  -4.816402  -4.191678  -4.233030]
  [ -9.900966  -5.236598  -4.576355 ...  -4.596011  -4.396853  -5.339818]]]

tflite float16 =======================================
[[[ -9.889221  -5.058239  -4.400070 ...  -4.518981  -4.422979  -5.619112]
  [ -9.666204  -5.503964  -4.429811 ...  -4.284597  -3.964969  -4.635273]
  [-10.165472  -5.403322  -4.708091 ...  -4.778225  -4.726850  -5.837097]
  ...
  [-10.237301  -5.250231  -4.427890 ...  -4.601455  -4.263482  -5.674368]
  [ -9.821136  -5.886144  -4.720224 ...  -4.816402  -4.191678  -4.233030]
  [ -9.900966  -5.236598  -4.576355 ...  -4.596011  -4.396853  -5.339818]]]

@On-JungWoan
Contributor Author

On-JungWoan commented May 5, 2023

Your code works very well. I want to express my gratitude for always helping me. As a token of my appreciation, I have decided to become your sponsor. Although I may not be able to provide a large amount of support as a student, I still want to show my gratitude in this way.

I will close this issue. Thank you again @PINTO0309 :)

@PINTO0309
Owner

Although I may not be able to provide a large amount of support as a student, I still want to show my gratitude to you in this way.

You don't have to strain yourself. I earned all of my tuition money on my own and attended college. I am well aware that there are times in our lives when we struggle financially.

When you graduate and earn a lot of money, please buy me a beer.

@On-JungWoan
Contributor Author

Thank you for saying so. I will definitely do that 😊
