diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 9e2e244cb146..9e88a85e035d 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3300,6 +3300,7 @@ def convert_transpose_conv(self, op): kernel_zero_point = weights_tensor.qnn_params["zero_point"] input_scale = input_tensor.qnn_params["scale"] kernel_scale = weights_tensor.qnn_params["scale"] + out_dtype = "int64" if output_tensor_type_str == "int16" else "int32" out = _qnn.op.conv2d_transpose( in_expr, weight_expr_iohw, @@ -3313,7 +3314,7 @@ def convert_transpose_conv(self, op): kernel_size=(int(kernel_h), int(kernel_w)), data_layout="NHWC", kernel_layout="IOHW", - out_dtype="int32", + out_dtype=out_dtype, ) else: out = _op.nn.conv2d_transpose( diff --git a/src/relay/qnn/op/convolution_transpose.cc b/src/relay/qnn/op/convolution_transpose.cc index 951c1bdfb051..0b24ae71ca8c 100644 --- a/src/relay/qnn/op/convolution_transpose.cc +++ b/src/relay/qnn/op/convolution_transpose.cc @@ -99,7 +99,7 @@ bool QnnConv2DTransposeRel(const Array& types, int num_inputs, const Attrs ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8)) << "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype; ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) || - data->dtype == DataType::Int(64)) + param->out_dtype == DataType::Int(64)) << "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype; ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0."; diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 3b3dcc59f057..c65e48b40288 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1648,6 +1648,86 @@ def test_forward_transpose_conv(): ) +def _test_tflite2_quantized_transpose_conv( + input_shape, 
kernel_shape, + filters, + padding="valid", + strides=(1, 1), + data_format=None, + int_quant_dtype=tf.int8, +): + """One iteration of TFLite2 quantized transpose conv with given shapes and attributes""" + data_format = "channels_last" if data_format == "NHWC" else "channels_first" + data = np.random.uniform(0, 1, input_shape).astype("float32") + _ = np.random.uniform(0, 1, kernel_shape).astype("float32") + + data_in = tf.keras.layers.Input(shape=data.shape[1:], batch_size=1) + transpose_conv = tf.keras.layers.Conv2DTranspose( + filters=filters, + kernel_size=(kernel_shape[0], kernel_shape[1]), + padding=padding, + strides=strides, + use_bias=True, + )(data_in) + keras_model = tf.keras.models.Model(data_in, transpose_conv) + + # A representative dataset is needed to create quantized values with the dynamic range of activations + def representative_data_gen(): + for _ in range(1): + yield [data] + + tflite_model_quant = _quantize_keras_model( + keras_model, + representative_data_gen, + is_float_input=True, + is_float_output=True, + int_quant_dtype=int_quant_dtype, + ) + + # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 + try: + import tflite.Model + + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_quant, 0) + except AttributeError: + import tflite + + tflite_model = tflite.Model.GetRootAsModel(tflite_model_quant, 0) + except ImportError as exc: + raise ImportError("The tflite package must be installed") from exc + + subgraph = tflite_model.Subgraphs(0) + model_input = subgraph.InputsAsNumpy() + input_node = subgraph.Tensors(model_input).Name().decode("utf-8") + + tflite_output = run_tflite_graph(tflite_model_quant, data) + + if tf.__version__ < LooseVersion("2.9"): + input_node = data_in.name.replace(":0", "") + else: + input_node = "serving_default_" + data_in.name + ":0" + + tvm_output = run_tvm_graph(tflite_model_quant, data, input_node) + tvm.testing.assert_allclose( + np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, 
atol=1e-2 + ) + + +def test_forward_quantized_transpose_conv(): + """Quantized transpose convolution with int8 and int16 quantization""" + for int_quant_dtype in [tf.int8, tf.int16]: + _test_tflite2_quantized_transpose_conv( + (1, 1, 5, 64), + (3, 3), + 64, + padding="same", + strides=(1, 2), + data_format="NHWC", + int_quant_dtype=int_quant_dtype, + ) + + ####################################################################### # Reshape # -------