Skip to content

Commit

Permalink
feat(tools): add batch-size option for onnx conversion process (#582)
Browse files Browse the repository at this point in the history
feat(tools): add batch-size option for onnx conversion process
  • Loading branch information
developer0hye committed Aug 27, 2021
1 parent 57ec302 commit a1a2958
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions tools/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def make_parser():
parser.add_argument(
"-o", "--opset", default=11, type=int, help="onnx opset version"
)
parser.add_argument("--batch-size", type=int, default=1, help="batch size")
parser.add_argument(
"--dynamic", action="store_true", help="whether the input shape should be dynamic or not"
)
parser.add_argument("--no-onnxsim", action="store_true", help="use onnxsim or not")
parser.add_argument(
"-f",
Expand Down Expand Up @@ -77,13 +81,16 @@ def main():
model.head.decode_in_inference = False

logger.info("loading checkpoint done.")
dummy_input = torch.randn(1, 3, exp.test_size[0], exp.test_size[1])
dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])

torch.onnx._export(
model,
dummy_input,
args.output_name,
input_names=[args.input],
output_names=[args.output],
dynamic_axes={args.input: {0: 'batch'},
args.output: {0: 'batch'}} if args.dynamic else None,
opset_version=args.opset,
)
logger.info("generated onnx model named {}".format(args.output_name))
Expand All @@ -93,9 +100,13 @@ def main():

from onnxsim import simplify

input_shapes = {args.input: list(dummy_input.shape)} if args.dynamic else None

# use onnxsimplify to reduce reduent model.
onnx_model = onnx.load(args.output_name)
model_simp, check = simplify(onnx_model)
model_simp, check = simplify(onnx_model,
dynamic_input_shape=args.dynamic,
input_shapes=input_shapes)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, args.output_name)
logger.info("generated simplified onnx model named {}".format(args.output_name))
Expand Down

0 comments on commit a1a2958

Please sign in to comment.