Commit

Modify save_quant_model to support different input and output filenames (#40542)

* Modify save_quant_model.py to support different input and output filenames

* Correct wrong order of arguments
wozna committed Mar 16, 2022
1 parent 23c036d commit dec2b1c
Showing 1 changed file with 52 additions and 9 deletions.
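
For context, a sketch of how the updated script could be invoked with the new flags. The paths are placeholders, and the --quant_model_path / --int8_model_save_path flag names are assumed from the test_args attribute names used at the bottom of the file; only the four new flags are confirmed by this diff.

import subprocess
import sys

# Hypothetical invocation; paths are placeholders and the first two flag
# names are assumed from test_args attribute names, not shown in this diff.
subprocess.check_call([
    sys.executable, 'save_quant_model.py',
    '--quant_model_path', '/path/to/quant_model',
    '--int8_model_save_path', '/path/to/int8_model',
    '--quant_model_filename', 'model',      # new: single-file program name
    '--quant_params_filename', 'params',    # new: single-file params name
    '--save_model_filename', '__model__',   # new: output program file name
    '--save_params_filename', 'params',     # new: output params file name
])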
python/paddle/fluid/contrib/slim/tests/save_quant_model.py (52 additions & 9 deletions)
@@ -52,6 +52,30 @@ def parse_args():
'--debug',
action='store_true',
help='If used, the graph of the Quant model is drawn.')
parser.add_argument(
'--quant_model_filename',
type=str,
default="",
help='Name of the input model file. If empty, look for the default `__model__` file and separate parameter files and use them; if those are not found, attempt to load the `model` and `params` files.'
)
parser.add_argument(
'--quant_params_filename',
type=str,
default="",
help='Name of the file holding all of the input model parameters. If empty, parameters are loaded from separate files. Ignored when quant_model_filename is empty.'
)
parser.add_argument(
'--save_model_filename',
type=str,
default="__model__",
help='Name of the file to save the inference program to. If set to None, the default filename `__model__` is used.'
)
parser.add_argument(
'--save_params_filename',
type=str,
default=None,
help='Name of the file to save all related parameters to. If set to None, parameters are saved in separate files.'
)

test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
@@ -61,18 +85,29 @@ def transform_and_save_int8_model(original_path,
save_path,
ops_to_quantize='',
op_ids_to_skip='',
debug=False):
debug=False,
quant_model_filename='',
quant_params_filename='',
save_model_filename='',
save_params_filename=''):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
with fluid.scope_guard(inference_scope):
if os.path.exists(os.path.join(original_path, '__model__')):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(original_path, exe)
if not quant_model_filename:
if os.path.exists(os.path.join(original_path, '__model__')):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(original_path,
exe)
else:
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
original_path, exe, 'model', 'params')
else:
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(original_path, exe,
'model', 'params')
fetch_targets] = fluid.io.load_inference_model(
original_path, exe, quant_model_filename,
quant_params_filename)

ops_to_quantize_set = set()
print(ops_to_quantize)
@@ -97,8 +132,14 @@ def transform_and_save_int8_model(original_path,
graph = transform_to_mkldnn_int8_pass.apply(graph)
inference_program = graph.to_program()
with fluid.scope_guard(inference_scope):
fluid.io.save_inference_model(save_path, feed_target_names,
fetch_targets, exe, inference_program)
fluid.io.save_inference_model(
save_path,
feed_target_names,
fetch_targets,
exe,
inference_program,
model_filename=save_model_filename,
params_filename=save_params_filename)
print(
"Success! INT8 model obtained from the Quant model can be found at {}\n"
.format(save_path))
@@ -109,4 +150,6 @@ def transform_and_save_int8_model(original_path,
test_args, remaining_args = parse_args()
transform_and_save_int8_model(
test_args.quant_model_path, test_args.int8_model_save_path,
test_args.ops_to_quantize, test_args.op_ids_to_skip, test_args.debug)
test_args.ops_to_quantize, test_args.op_ids_to_skip, test_args.debug,
test_args.quant_model_filename, test_args.quant_params_filename,
test_args.save_model_filename, test_args.save_params_filename)
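
Equivalently, the helper can be called directly from Python. A minimal sketch, assuming it runs in the same module so the fluid imports are in scope; the paths and the ops_to_quantize value are hypothetical, while the keyword names come from the extended signature above.

# Hypothetical example call; keyword names match the new signature,
# paths and the ops_to_quantize value are placeholders.
transform_and_save_int8_model(
    '/path/to/quant_model',           # original_path: Quant model directory
    '/path/to/int8_model',            # save_path: output directory
    ops_to_quantize='conv2d,pool2d',  # comma-separated op names (assumed format)
    op_ids_to_skip='',
    debug=False,
    quant_model_filename='model',     # load the program from a single file
    quant_params_filename='params',   # load all parameters from a single file
    save_model_filename='__model__',  # output program file name
    save_params_filename=None)        # None: save parameters in separate files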
