Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add input_shape_dict and rm raw_input #899

Merged
merged 2 commits into from
Sep 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ x2paddle --framework=caffe --prototxt=deploy.prototxt --weight=deploy.caffemodel
| --weight | 当framework为caffe时,该参数指定caffe模型的参数文件路径 |
| --save_dir | 指定转换后的模型保存目录路径 |
| --model | 当framework为tensorflow/onnx时,该参数指定tensorflow的pb模型文件或onnx模型路径 |
| --input_shape_dict | **[可选]** For ONNX, 定义ONNX模型输入大小 |
| --caffe_proto | **[可选]** 由caffe.proto编译成caffe_pb2.py文件的存放路径,当存在自定义Layer时使用,默认为None |
| --define_input_shape | **[可选]** For TensorFlow, 当指定该参数时,强制用户输入每个Placeholder的shape,见[文档Q2](./docs/inference_model_convertor/FAQ.md) |
| --enable_code_optim | **[可选]** For PyTorch, 是否对生成代码进行优化,默认为False |
Expand Down
11 changes: 10 additions & 1 deletion x2paddle/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ def arg_parser():
action="store_true",
default=False,
help="define input shape for tf model")
parser.add_argument(
"--input_shape_dict",
"-isd",
type=_text_type,
default=None,
help="define input shapes, e.g --input_shape_dict=\"{'image':[1, 3, 608, 608]}\" or" \
"--input_shape_dict=\"{'image':[1, 3, 608, 608], 'im_shape': [1, 2], 'scale_factor': [1, 2]}\"")
parser.add_argument(
"--convert_torch_project",
"-tp",
Expand Down Expand Up @@ -265,6 +272,7 @@ def caffe2paddle(proto_file,

def onnx2paddle(model_path,
save_dir,
input_shape_dict=None,
convert_to_lite=False,
lite_valid_places="arm",
lite_model_type="naive_buffer",
Expand Down Expand Up @@ -292,7 +300,7 @@ def onnx2paddle(model_path,

from x2paddle.decoder.onnx_decoder import ONNXDecoder
from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
model = ONNXDecoder(model_path, enable_onnx_checker)
model = ONNXDecoder(model_path, input_shape_dict, enable_onnx_checker)
mapper = ONNXOpMapper(model)
mapper.paddle_graph.build()
logging.info("Model optimizing ...")
Expand Down Expand Up @@ -481,6 +489,7 @@ def main():
onnx2paddle(
args.model,
args.save_dir,
input_shape_dict=args.input_shape_dict,
convert_to_lite=args.to_lite,
lite_valid_places=args.lite_valid_places,
lite_model_type=args.lite_model_type,
Expand Down
41 changes: 6 additions & 35 deletions x2paddle/decoder/onnx_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,12 @@ def dtype(self):


class ONNXGraph(Graph):
def __init__(self, onnx_model):
def __init__(self, onnx_model, input_shape_dict):
super(ONNXGraph, self).__init__(onnx_model)
self.fixed_input_shape = {}
if input_shape_dict is not None:
for k, v in eval(input_shape_dict).items():
self.fixed_input_shape["x2paddle_" + k] = v
self.initializer = {}
self.place_holder_nodes = list()
self.value_infos = {}
Expand Down Expand Up @@ -216,45 +219,13 @@ def get_symbolic_shape(self, dims):
shape.append(dim.dim_value)
return shape

def check_input_shape(self, vi):
if vi.type.HasField('tensor_type'):
for dim in vi.type.tensor_type.shape.dim:
if dim.HasField(
'dim_param') and vi.name not in self.fixed_input_shape:
shape = self.get_symbolic_shape(
vi.type.tensor_type.shape.dim)
print(
"Unknown shape for input tensor[tensor name: '{}'] -> shape: {}, Please define shape of input here,\nNote:you can use visualization tools like Netron to check input shape."
.format(vi.name, shape))
right_shape_been_input = False
while not right_shape_been_input:
try:
shape = raw_input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
)
except NameError:
shape = input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
)
if shape.count("-1") > 1:
print("Only 1 dimension can be -1, type again:)")
else:
right_shape_been_input = True
if shape == 'N':
break
shape = [int(dim) for dim in shape.strip().split(',')]
assert shape.count(-1) <= 1, "Only one dimension can be -1"
self.fixed_input_shape[vi.name] = shape
break

def get_place_holder_nodes(self):
"""
generate place_holder node of ONNX model
"""
inner_nodes = self.get_inner_nodes()
for ipt_vi in self.graph.input:
if ipt_vi.name not in inner_nodes:
self.check_input_shape(ipt_vi)
self.place_holder_nodes.append(ipt_vi.name)

def get_output_nodes(self):
Expand Down Expand Up @@ -416,7 +387,7 @@ def allocate_shapes(self):


class ONNXDecoder(object):
def __init__(self, onnx_model, enable_onnx_checker):
def __init__(self, onnx_model, input_shape_dict, enable_onnx_checker):
onnx_model = onnx.load(onnx_model)
print('model ir_version: {}, op version: {}'.format(
onnx_model.ir_version, onnx_model.opset_import[0].version))
Expand All @@ -427,7 +398,7 @@ def __init__(self, onnx_model, enable_onnx_checker):

onnx_model = self.optimize_model_skip_op(onnx_model)
onnx_model = self.optimize_node_name(onnx_model)
self.graph = ONNXGraph(onnx_model)
self.graph = ONNXGraph(onnx_model, input_shape_dict)

def build_value_refs(self, nodes):
"""
Expand Down