Skip to content

Commit

Permalink
Merge pull request #296 from On-JungWoan/custom_input
Browse files Browse the repository at this point in the history
new parameter `-cind`, `--custom_input_op_name_np_data_path`
  • Loading branch information
PINTO0309 committed Apr 9, 2023
2 parents 107d4ee + 60917f4 commit 67b73ed
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 78 deletions.
108 changes: 78 additions & 30 deletions onnx2tf/onnx2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
from onnx2tf.utils.colors import Color
from sng4onnx import generate as op_name_auto_generate


def convert(
input_onnx_file_path: Optional[str] = '',
onnx_graph: Optional[onnx.ModelProto] = None,
Expand All @@ -65,7 +64,7 @@ def convert(
copy_onnx_input_output_names_to_tflite: Optional[bool] = False,
output_integer_quantized_tflite: Optional[bool] = False,
quant_type: Optional[str] = 'per-channel',
quant_calib_input_op_name_np_data_path: Optional[List] = None,
custom_input_op_name_np_data_path: Optional[List] = None,
input_output_quant_dtype: Optional[str] = 'int8',
not_use_onnxsim: Optional[bool] = False,
not_use_opname_auto_generate: Optional[bool] = False,
Expand Down Expand Up @@ -153,7 +152,7 @@ def convert(
Selects whether "per-channel" or "per-tensor" quantization is used.\n
Default: "per-channel"
quant_calib_input_op_name_np_data_path: Optional[List]
custom_input_op_name_np_data_path: Optional[List]
INPUT Name of OP and path of calibration data file (Numpy) for quantization\n
and mean and std.\n
The specification can be omitted only when the input OP is a single 4D tensor image data.\n
Expand Down Expand Up @@ -423,7 +422,16 @@ def convert(
model: tf.keras.Model
Model
"""

# Validate custom input entries: each must have 2 (name, path) or 4 (name, path, mean, std) elements
if custom_input_op_name_np_data_path is not None:
for param in custom_input_op_name_np_data_path:
if len(param) not in [2, 4]:
error_msg = f'' + \
f'{Color.RED}ERROR:{Color.RESET} ' + \
f"'-cind' option must have INPUT_NAME, NUMPY_FILE_PATH, MEAN(optional), STD(optional)"
print(error_msg)
raise ValueError(error_msg)

# Either input_onnx_file_path or onnx_graph must be specified
if not input_onnx_file_path and not onnx_graph:
print(
Expand Down Expand Up @@ -670,6 +678,7 @@ def convert(
onnx_outputs_for_validation: List[np.ndarray] = dummy_onnx_inference(
onnx_graph=onnx_graph,
output_names=full_ops_output_names,
custom_input_op_name_np_data_path=custom_input_op_name_np_data_path
)
"""
onnx_tensor_infos_for_validation:
Expand Down Expand Up @@ -737,7 +746,8 @@ def convert(
print('')
print(f'{Color.REVERCE}Model convertion started{Color.RESET}', '=' * 60)

with graph.node_ids():
with graph.node_ids():

onnx_graph_input_names: List[str] = [
inputop.name for inputop in graph.inputs
]
Expand All @@ -762,12 +772,12 @@ def convert(
"""
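# Automatic calibration check: without custom input data, INT8 quantization requires every input to be a Float32 4D tensor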
if output_integer_quantized_tflite \
and quant_calib_input_op_name_np_data_path is None \
and custom_input_op_name_np_data_path is None \
and (graph_input.dtype != np.float32 or len(graph_input.shape) != 4):
print(
f'{Color.RED}ERROR:{Color.RESET} ' +
f'For INT8 quantization, the input data type must be Float32. ' +
f'Also, if --quant_calib_input_op_name_np_data_path is not specified, ' +
f'Also, if --custom_input_op_name_np_data_path is not specified, ' +
f'all input OPs must assume 4D tensor image data. ' +
f'INPUT Name: {graph_input.name} INPUT Shape: {graph_input.shape} INPUT dtype: {graph_input.dtype}'
)
Expand Down Expand Up @@ -1066,7 +1076,7 @@ def convert(
model_input.name for model_input in model.inputs
]
data_count = 0
if quant_calib_input_op_name_np_data_path is None \
if custom_input_op_name_np_data_path is None \
and model.inputs[0].dtype == tf.float32 \
and len(model.inputs[0].shape) == 4:

Expand Down Expand Up @@ -1095,15 +1105,25 @@ def convert(
MEAN,
STD,
]
elif quant_calib_input_op_name_np_data_path is not None:
for param in quant_calib_input_op_name_np_data_path:
elif custom_input_op_name_np_data_path is not None:
for param in custom_input_op_name_np_data_path:
if len(param) != 4:
print(
f"{Color.RED}ERROR:{Color.RESET} " +
"If you want to use custom input with the '-oiqt' option, " +
"{input_op_name}, {numpy_file_path}, {mean}, and {std} must all be entered. " +
f"However, you have only entered {len(param)} options. "
)
sys.exit(1)

input_op_name = str(param[0])
numpy_file_path = str(param[1])
calib_data = np.load(numpy_file_path)
if data_count == 0:
data_count = calib_data.shape[0]
mean = param[2]
std = param[3]

calib_data_dict[input_op_name] = \
[
calib_data.copy(),
Expand Down Expand Up @@ -1363,6 +1383,7 @@ def representative_dataset_gen():
onnx_graph=onnx_graph,
output_names=ops_output_names,
test_data_nhwc=test_data_nhwc,
custom_input_op_name_np_data_path=custom_input_op_name_np_data_path
)
except Exception as ex:
print(
Expand All @@ -1377,6 +1398,7 @@ def representative_dataset_gen():
model=model,
inputs=inputs,
test_data_nhwc=test_data_nhwc,
custom_input_op_name_np_data_path=custom_input_op_name_np_data_path,
)
# Validation
onnx_tensor_infos = {
Expand Down Expand Up @@ -1539,19 +1561,31 @@ def main():
'Default: "per-channel"'
)
parser.add_argument(
'-qcind',
'--quant_calib_input_op_name_np_data_path',
'-cind',
'--custom_input_op_name_np_data_path',
type=str,
action='append',
nargs=4,
nargs='+',
help=\
'Input name of OP and path of data file (Numpy) for custom input for -cotof or -oiqt, \n' +
'and mean (optional) and std (optional). \n' +

'\n<Usage in -cotof> \n' +
'When using -cotof, custom input defined by the user, instead of dummy data, is used. \n' +
'In this case, mean and std are omitted from the input. \n' +
'-cind {input_op_name} {numpy_file_path} \n' +
'ex) -cind onnx::Equal_0 test_cind/x_1.npy -cind onnx::Add_1 test_cind/x_2.npy -cotof \n' +
'The input_op_name must be the same as in ONNX, \n'
'and it may not work if the input format is different between ONNX and TF. \n'

'\n<Usage in -oiqt> \n' +
'INPUT Name of OP and path of calibration data file (Numpy) for quantization \n' +
'and mean and std. \n' +
'The specification can be omitted only when the input OP is a single 4D tensor image data. \n' +
'If omitted, it is automatically calibrated using 20 normalized MS-COCO images. \n' +
'The type of the input OP must be Float32. \n' +
'Data for calibration must be pre-normalized to a range of 0 to 1. \n' +
'-qcind {input_op_name} {numpy_file_path} {mean} {std} \n' +
'-cind {input_op_name} {numpy_file_path} {mean} {std} \n' +
'Numpy file paths must be specified the same number of times as the number of input OPs. \n' +
'Normalize the value of the input OP based on the tensor specified in mean and std. \n' +
'(input_value - mean) / std \n' +
Expand Down Expand Up @@ -1587,9 +1621,16 @@ def main():
' input2: [n,5] \n' +
' mean: [1] -> [0.3] \n' +
' std: [1] -> [0.07] \n' +
'-qcind "input0" "../input0.npy" [[[[0.485, 0.456, 0.406]]]] [[[[0.229, 0.224, 0.225]]]] \n' +
'-qcind "input1" "./input1.npy" [0.1, ..., 0.64] [0.05, ..., 0.08] \n' +
'-qcind "input2" "input2.npy" [0.3] [0.07]'
'-cind "input0" "../input0.npy" [[[[0.485, 0.456, 0.406]]]] [[[[0.229, 0.224, 0.225]]]] \n' +
'-cind "input1" "./input1.npy" [0.1, ..., 0.64] [0.05, ..., 0.08] \n' +
'-cind "input2" "input2.npy" [0.3] [0.07]' +

'\n<Using -cotof and -oiqt at the same time> \n' +
'To use -cotof and -oiqt simultaneously, \n' +
'you need to enter the Input name of OP, path of data file, mean, and std all together. \n' +
'And the data file must be in Float32 format, \n' +
'and {input_op_name}, {numpy_file_path}, {mean}, and {std} must all be entered. \n' +
'Otherwise, an error will occur during the -oiqt stage.'
)
parser.add_argument(
'-ioqd',
Expand Down Expand Up @@ -1963,18 +2004,24 @@ def main():
# [{input_op_name} {numpy_file_path} {mean} {std}],
# [{input_op_name} {numpy_file_path} {mean} {std}],
# ]
calib_params = []
if args.quant_calib_input_op_name_np_data_path is not None:
for param in args.quant_calib_input_op_name_np_data_path:
input_op_name = str(param[0])
numpy_file_path = str(param[1])
mean = np.asarray(ast.literal_eval(param[2]), dtype=np.float32)
std = np.asarray(ast.literal_eval(param[3]), dtype=np.float32)
calib_params.append(
[input_op_name, numpy_file_path, mean, std]
)
if len(calib_params) == 0:
calib_params = None
custom_params = []
if args.custom_input_op_name_np_data_path is not None:
for param in args.custom_input_op_name_np_data_path:
tmp = []
if len(param) == 2:
tmp.append(str(param[0])) # input_op_name
tmp.append(str(param[1])) # numpy_file_path

if len(param) == 4:
tmp.append(np.asarray(ast.literal_eval(param[2]), dtype=np.float32)) # mean
tmp.append(np.asarray(ast.literal_eval(param[3]), dtype=np.float32)) # std

custom_params.append(
tmp
)

if len(custom_params) == 0:
custom_params = None

args.replace_to_pseudo_operators = [
name.lower() for name in args.replace_to_pseudo_operators
Expand All @@ -1992,7 +2039,7 @@ def main():
copy_onnx_input_output_names_to_tflite=args.copy_onnx_input_output_names_to_tflite,
output_integer_quantized_tflite=args.output_integer_quantized_tflite,
quant_type=args.quant_type,
quant_calib_input_op_name_np_data_path=calib_params,
custom_input_op_name_np_data_path=custom_params,
input_output_quant_dtype=args.input_output_quant_dtype,
not_use_onnxsim=args.not_use_onnxsim,
not_use_opname_auto_generate=args.not_use_opname_auto_generate,
Expand Down Expand Up @@ -2032,3 +2079,4 @@ def main():

if __name__ == '__main__':
main()

Loading

0 comments on commit 67b73ed

Please sign in to comment.