Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FRONTEND][TFLITE][BugFix] Fix int16 transpose conv loading #15173

Merged
merged 1 commit into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3300,6 +3300,7 @@ def convert_transpose_conv(self, op):
kernel_zero_point = weights_tensor.qnn_params["zero_point"]
input_scale = input_tensor.qnn_params["scale"]
kernel_scale = weights_tensor.qnn_params["scale"]
out_dtype = "int64" if output_tensor_type_str == "int16" else "int32"
out = _qnn.op.conv2d_transpose(
in_expr,
weight_expr_iohw,
Expand All @@ -3313,7 +3314,7 @@ def convert_transpose_conv(self, op):
kernel_size=(int(kernel_h), int(kernel_w)),
data_layout="NHWC",
kernel_layout="IOHW",
out_dtype="int32",
out_dtype=out_dtype,
)
else:
out = _op.nn.conv2d_transpose(
Expand Down
2 changes: 1 addition & 1 deletion src/relay/qnn/op/convolution_transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs
ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
<< "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) ||
data->dtype == DataType::Int(64))
param->out_dtype == DataType::Int(64))
<< "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype;
ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

Expand Down
80 changes: 80 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,6 +1648,86 @@ def test_forward_transpose_conv():
)


def _test_tflite2_quantized_transpose_conv(
input_shape,
kernel_shape,
filters,
padding="valid",
strides=(1, 1),
data_format=None,
int_quant_dtype=tf.int8,
):
"""One iteration of TFLite2 quantized tranpose conv with given shapes and attributes"""
data_format = "channels_last" if data_format == "NHWC" else "channels_first"
data = np.random.uniform(0, 1, input_shape).astype("float32")
_ = np.random.uniform(0, 1, kernel_shape).astype("float32")

data_in = tf.keras.layers.Input(shape=data.shape[1:], batch_size=1)
transpose_conv = tf.keras.layers.Conv2DTranspose(
filters=filters,
kernel_size=(kernel_shape[0], kernel_shape[1]),
padding=padding,
strides=strides,
use_bias=True,
)(data_in)
keras_model = tf.keras.models.Model(data_in, transpose_conv)

# To create quantized values with dynamic range of activations, needs representative dataset
def representative_data_gen():
for _ in range(1):
yield [data]

tflite_model_quant = _quantize_keras_model(
keras_model,
representative_data_gen,
is_float_input=True,
is_float_output=True,
int_quant_dtype=int_quant_dtype,
)

# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
try:
import tflite.Model

tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_quant, 0)
except AttributeError:
import tflite

tflite_model = tflite.Model.GetRootAsModel(tflite_model_quant, 0)
except ImportError as exc:
raise ImportError("The tflite package must be installed") from exc

subgraph = tflite_model.Subgraphs(0)
model_input = subgraph.InputsAsNumpy()
input_node = subgraph.Tensors(model_input).Name().decode("utf-8")

tflite_output = run_tflite_graph(tflite_model_quant, data)

if tf.__version__ < LooseVersion("2.9"):
input_node = data_in.name.replace(":0", "")
else:
input_node = "serving_default_" + data_in.name + ":0"

tvm_output = run_tvm_graph(tflite_model_quant, data, input_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, atol=1e-2
)


def test_forward_quantized_transpose_conv():
"""Quantized convolution"""
for int_quant_dtype in [tf.int8, tf.int16]:
_test_tflite2_quantized_transpose_conv(
(1, 1, 5, 64),
(3, 3),
64,
padding="same",
strides=(1, 2),
data_format="NHWC",
int_quant_dtype=int_quant_dtype,
)


#######################################################################
# Reshape
# -------
Expand Down