Skip to content

Commit

Permalink
Merge pull request #19 from kuke/convert_conv
Browse files Browse the repository at this point in the history
Support conversion of recognize_digits models
  • Loading branch information
Yibing Liu committed Apr 16, 2018
2 parents c5e5a36 + 557e61b commit 57c845a
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 53 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ TBD

## Supported models

We aim to at least support all the models from our model bank. During our preliminary stage, we plan to support the models generated from:
We aim to at least support all the models from our model bank. During our preliminary stage, we have validated the inference model's conversion on following models:

- [fit_a_line](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/test_fit_a_line.py)
- [machine_translation](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/test_machine_translation.py)
- [recognize_digits](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/test_recognize_digits.py)

## License
Provided under the [Apache-2.0 license](LICENSE).
26 changes: 9 additions & 17 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import fluid_onnx.ops as ops
from fluid_onnx.variables import paddle_variable_to_onnx_tensor
from fluid_onnx.variables import PADDLE_TO_ONNX_DTYPE


def parse_args():
Expand Down Expand Up @@ -59,16 +58,8 @@ def convert(args):
for var_name in global_block.vars:
var = global_block.var(var_name)
if var_name not in ['feed', 'fetch'] and var.persistable:
param = fluid.executor.fetch_var(var_name, inference_scope)
param_node = helper.make_node(
'Constant',
inputs=[],
outputs=[var_name],
value=helper.make_tensor(
name=var_name,
dims=var.shape,
data_type=PADDLE_TO_ONNX_DTYPE[var.dtype],
vals=param.flatten().tolist()))
param_node = ops.node_maker['constant'](var=var,
scope=inference_scope)
onnx_nodes.append(param_node)

# Create inputs
Expand All @@ -92,12 +83,13 @@ def convert(args):
if op.type in ops.node_maker:
# TODO(kuke): deal with the corner case that vars in
# different blocks have the same name
node_proto = ops.node_maker[op.type](
inputs=op.input_arg_names,
attrs=op.attr_names,
outputs=op.output_arg_names)
node_proto = ops.node_maker[op.type](operator=op,
scope=inference_scope)

onnx_nodes.append(node_proto)
if isinstance(node_proto, tuple):
onnx_nodes.extend(list(node_proto))
else:
onnx_nodes.append(node_proto)
else:
if op.type not in ['feed', 'fetch']:
raise NotImplementedError("OP[%s] is not supported in "
Expand All @@ -113,7 +105,7 @@ def convert(args):
# Model check
checker.check_model(onnx_model)

# Output readable model
# Print model
print("The converted model is:\n{}".format(onnx_model))

# Save converted model
Expand Down
13 changes: 13 additions & 0 deletions fluid/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
23 changes: 23 additions & 0 deletions fluid/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def get_op_io_info(op):
inputs = dict([(name, op.input(name)) for name in op.input_names])
attrs = dict(
[(name, op.attr(name))
for name in op.attr_names]) if op.attr_names is not None else None
outputs = dict([(name, op.output(name)) for name in op.output_names])

return inputs, attrs, outputs

0 comments on commit 57c845a

Please sign in to comment.