This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-310] [ONNX-MXNet] API to import ONNX models into Gluon. #10605

Merged
merged 22 commits into from
Jun 4, 2018

Conversation

anirudhacharya
Member

@anirudhacharya anirudhacharya commented Apr 18, 2018

Description

API and corresponding tests to import ONNX models into Gluon, plus changes to match ONNX's op set 7.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • New API to import ONNX models into Gluon
  • Changes to match ONNX's op set 7

Comments

@spidydev @Roshrini @anirudh2290

@piiswrong
Contributor

This is unnecessary. You can use gluon's SymbolBlock to directly load symbol models.

data_names = [input_tensor[0] for input_tensor in metadata['input_tensor_data']]
data_inputs = [symbol.var(data_name) for data_name in data_names]

from ....gluon import SymbolBlock
Contributor

Do we need this here? Can we move it to the top?
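As a standalone check, the metadata-driven extraction of input names quoted above can be run on its own (the metadata dict below is a hypothetical example of the shape the ONNX graph parser returns):

```python
# Hypothetical metadata: a list of (input_name, input_shape) pairs,
# mirroring what the importer's get_graph_metadata would produce.
metadata = {'input_tensor_data': [('data_0', (1, 3, 224, 224)), ('data_1', (1, 10))]}

# Same comprehension as in the PR: keep only the input names.
data_names = [input_tensor[0] for input_tensor in metadata['input_tensor_data']]
print(data_names)  # ['data_0', 'data_1']
```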

#test_elu_example
#test_leakyrelu_example

#GLUON_TEST.include('test_elu_example')
Contributor

Remove these lines.

except ImportError:
raise ImportError("Onnx and protobuf need to be installed. "
+ "Instructions to install - https://github.com/onnx/onnx")
model_proto = onnx.load(model_file)
Contributor

Add a check for whether the file exists and throw an appropriate error message.
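A minimal sketch of the requested check, independent of the surrounding importer code (the helper name and message wording are mine, not from the PR):

```python
import os

def check_model_file(model_file):
    """Fail fast with a clear message before handing the path to onnx.load."""
    if not os.path.isfile(model_file):
        raise ValueError("ONNX model file not found: {}".format(model_file))
    return model_file
```

A unit test would then exercise the missing-file branch by asserting that a nonexistent path raises.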

Contributor

?

Member Author

I will add the file check.

data_inputs = [symbol.var(data_name) for data_name in data_names]

from ....gluon import SymbolBlock
net = SymbolBlock(outputs=sym, inputs=data_inputs)
Contributor

Add a few comments to explain the logic here.

except ImportError:
raise ImportError("Onnx and protobuf need to be installed. Instructions to"
+ " install - https://github.com/onnx/onnx#installation")
model_proto = onnx.load(model_file)
Contributor

file exists check

Member Author

not needed.

else:
raise NotImplementedError("Only CPU context is supported for now")

if node.op_type in ['Conv']:
Contributor

Add a comment explaining why we are doing this for 'Conv'.

if device == 'CPU':
ctx = mx.cpu()
else:
raise NotImplementedError("Only CPU context is supported for now")
Contributor

Shouldn't ONNX be implementation independent and thus not care about the used device type?

Member Author

ONNX is implementation independent. Here we are running a particular ONNX model using gluon, so we need to specify the context. In the CI pipeline these tests are running on a CPU, hence the above assignment. Saying "GPU is not implemented" is pretty misleading. I will correct this in the code.

for param in aux_params:
if param in net_params:
net_params[param].shape = aux_params[param].shape
net_params[param]._load_init(aux_params[param], ctx=cpu())
Contributor

Why are we defaulting to CPU? Can we not import the model on GPU straight away? We should let the user pass in a ctx argument that defaults to CPU.

Member Author

yes, will make this change.

@szha szha requested review from yzhliu and zhreshold and removed request for szha May 21, 2018 22:31
@zhreshold
Member

Since we are recomposing the network using symbols, why specifically target Gluon?
It makes more sense to me if it is an API to import ONNX models into MXNet.
It is always simple to convert a symbol to a SymbolBlock for use with Gluon.

@anirudhacharya
Member Author

@zhreshold there is already an import API for MXNet - https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/contrib/onnx/_import/import_model.py#L24 This PR is to facilitate directly loading the ONNX model and its parameters into the Gluon interface. I think @piiswrong also made the same comment above.

Anirudh Acharya added 6 commits May 23, 2018 14:35
gluonImport

# Conflicts:
#	python/mxnet/contrib/onnx/__init__.py
#	python/mxnet/contrib/onnx/_import/import_model.py
#	python/mxnet/contrib/onnx/_import/import_onnx.py
@rajanksin
Contributor

Please create a folder inside python-pytest/onnx for "import", and move all the import-specific files in there. The "export"-specific backend will be added pretty soon.

if op_name == 'broadcast_add':
op_sym = symbol.broadcast_add(op_sym, inputs[0])
op_sym = symbol.broadcast_add(inputs[0], op_sym)
elif op_name == 'broadcast_mul':
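The operand swap in this hunk is behavior-preserving for the commutative ops shown, but it matters for broadcast_sub and broadcast_div; a quick numpy illustration (numpy stands in for mx.sym here):

```python
import numpy as np

inputs0 = np.array([[1.0, 2.0], [3.0, 4.0]])
op_sym = np.array([10.0, 20.0])

# add/mul are commutative, so either argument order gives the same result...
assert np.allclose(inputs0 + op_sym, op_sym + inputs0)

# ...but sub/div are not, so (inputs[0], op_sym) vs (op_sym, inputs[0]) differ.
assert not np.allclose(inputs0 - op_sym, op_sym - inputs0)
```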
Member

The "_fix_broadcast()" function should be called for all broadcast_add, broadcast_sub, broadcast_mul, and broadcast_div operations.
Here you are handling only broadcast_add and broadcast_mul.

Member Author

done

@@ -16,18 +16,19 @@
# under the License.

# coding: utf-8
"""backend rep for onnx test infrastructure"""
"""MXNet backend rep for onnx test infrastructure"""
from collections import namedtuple
Member

Not used anywhere, remove?

except ImportError:
raise ImportError("Onnx and protobuf need to be installed. "
+ "Instructions to install - https://github.com/onnx/onnx")
model_proto = onnx.load(model_file)
Contributor

?

"""Gluon backend for ONNX"""

@staticmethod
def make_graph(node, inputs):
Contributor

Do we need this function ?

@@ -74,80 +74,6 @@ def make_graph(node, inputs):

return graph_proto

Contributor

The make_graph(node, inputs) function is not needed anymore.


metadata = graph.get_graph_metadata(model_proto.graph)
return metadata

def get_model_metadata(model_file):
Member

this method is repeated?

data_names = [input_tensor[0] for input_tensor in metadata['input_tensor_data']]
data_inputs = [symbol.var(data_name) for data_name in data_names]

ctx = gpu() if context == 'GPU' else cpu()
Member

from .... import cpu, gpu

Member Author

done

@@ -514,8 +514,9 @@ integrationtest_ubuntu_cpu_onnx() {
set -ex
export PYTHONPATH=./python/
python example/onnx/super_resolution.py
pytest tests/python-pytest/onnx/onnx_backend_test.py
pytest tests/python-pytest/onnx/mxnet_backend_test.py
Contributor

Shouldn't this be onnx/import/mxnet_backend_test.py?

pytest tests/python-pytest/onnx/onnx_test.py
pytest tests/python-pytest/onnx/gluon_backend_test.py
Contributor

same here

@Roshrini
Member

Can you rename 'tests/python-pytest/onnx/import/onnx_test.py' to 'tests/python-pytest/onnx/import/onnx_import_test.py'?

Contributor

@rajanksin rajanksin left a comment

lgtm

@@ -33,6 +34,9 @@ def __init__(self):
self._params = {}
self._num_input = 0
self._num_param = 0
self.auxDict = {}
self.argDict = {}
Member

Make it consistent with the snake_case used elsewhere.

return 'elemwise_div', new_attr, inputs
broadcast_axis = attrs['axis']
op_value = translation_utils._fix_broadcast('broadcast_div', inputs,
broadcast_axis, cls)
Member

Can you explain what broadcast_axis is here? What will broadcast_axis be when adding two tensors of shape (4,5) and (1,1)?

Member Author

broadcast_axis comes from ONNX's axis attribute in operators that support broadcasting - https://github.com/onnx/onnx/blob/master/docs/Changelog.md#attributes-103

With op set version 6, broadcasting (1,1) onto (4,5) would not be permissible. If we broadcast (5,) onto (4,5), broadcast_axis will be 1. On the other hand, if we broadcast (4,) onto (4,5), broadcast_axis will be 0.

ONNX's op set version 7 updates the broadcast rules to align with numpy broadcasting rules. When that is consistently updated in the ONNX repo, we will also update the translation code in mxnet.
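A mxnet-free sketch of the reshape that _fix_broadcast builds from broadcast_axis (the helper name below is mine; the quoted diff elsewhere in this thread shows the same two lines operating on the real inputs):

```python
def reshape_shape_for_axis(lhs_ndim, broadcast_axis):
    # Broadcast a 1-D rhs onto an lhs of rank lhs_ndim along broadcast_axis:
    # every dimension becomes 1 except the broadcast axis, which keeps its
    # size (-1 lets the framework infer it).
    shape = [1] * lhs_ndim
    shape[broadcast_axis] = -1
    return tuple(shape)

# (5,) broadcast onto (4, 5): axis 1 -> rhs is reshaped to (1, 5)
print(reshape_shape_for_axis(2, 1))  # (1, -1)
# (4,) broadcast onto (4, 5): axis 0 -> rhs is reshaped to (4, 1)
print(reshape_shape_for_axis(2, 0))  # (-1, 1)
```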

@@ -43,32 +43,42 @@ def add(attrs, inputs, cls):
"""Adding two tensors"""
new_attr = {}
if 'broadcast' in attrs and attrs['broadcast'] == 1:
op_value = translation_utils._fix_bias_shape('broadcast_add', inputs, cls)
broadcast_axis = attrs['axis']
Member

Won't block on this, but the code in the add, subtract, multiply, and divide functions can be reused.

input0_shape = get_input_shape(inputs[0], cls)
#creating reshape shape
reshape_shape = list(len(input0_shape) * (1,))
reshape_shape[broadcast_axis] = -1
Member

Is broadcast_axis always going to be a scalar?

Member Author

Yes, broadcast_axis comes from ONNX's axis attribute in operators that support broadcasting - https://github.com/onnx/onnx/blob/master/docs/Changelog.md#attributes-103

elif op_name == 'broadcast_sub':
op_sym = symbol.broadcast_sub(inputs[0], op_sym)
elif op_name == 'broadcast_div':
op_sym = symbol.broadcast_div(inputs[0], op_sym)
Member

Can change the above if/else logic to:

op_sym = getattr(symbol, op_name)(inputs[0], op_sym)
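The suggested getattr dispatch can be sanity-checked without mxnet by standing in a small namespace for the symbol module (the stand-in and function name below are illustrative, not part of the PR):

```python
from types import SimpleNamespace

# Stand-in for mxnet's symbol module: operator name -> binary function.
symbol = SimpleNamespace(
    broadcast_sub=lambda lhs, rhs: lhs - rhs,
    broadcast_div=lambda lhs, rhs: lhs / rhs,
)

def apply_broadcast_op(op_name, lhs, rhs):
    # Dispatch by operator name instead of an if/elif chain.
    return getattr(symbol, op_name)(lhs, rhs)

print(apply_broadcast_op('broadcast_sub', 10.0, 4.0))  # 6.0
print(apply_broadcast_op('broadcast_div', 10.0, 4.0))  # 2.5
```

The design win is that adding a new broadcast operator then requires no change to the dispatch site.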

@@ -148,21 +152,29 @@ def _fix_bias(op_name, attrs, num_inputs):
raise ValueError("Unexpected number of inputs for: {}".format(op_name))
return attrs

def _fix_bias_shape(op_name, inputs, cls):
def _fix_broadcast(op_name, inputs, broadcast_axis, cls):
Member

Shouldn't this be obj instead of cls? Same question for everywhere cls is used in op_translations and translation_utils

Member Author

No, I meant to use cls; it's just a convention. The same way self is used to access attributes inside the object itself, cls is often used to reference class and instance variables outside the object.

Member

Using cls for an instance is misleading. cls is the preferred variable name for anything that is meant to be class. For your case, since it is an instance of a class that is passed to _fix_broadcast and elsewhere, it should be something that indicates an instance like obj.

Member Author

Okay, I will rename it to proto_obj.

BACKEND_TESTS.include(basic_model_test)

BACKEND_TESTS.exclude('.*broadcast.*')
BACKEND_TESTS.exclude('.*bcast.*')
Member

Why are we excluding the broadcast tests?

Member Author

Because there is an issue with the way broadcast operator tests are written in ONNX.

For example, if we try to broadcast a (5,) dim array onto a (3,4,5) dim array, mxnet's forward pass will fail because mxnet's interface expects the same batch size on the two arrays, i.e. (1,5) and (1,3,4,5).

So

x = mx.nd.array(np.random.rand(3,4,5))
y = mx.nd.array(np.random.rand(5,))
mx.nd.broadcast_add(x,y)

will pass, but the following will fail

xvar = mx.sym.var('x')
yvar = mx.sym.var('y')
bcast_add = mx.sym.broadcast_add(xvar, yvar)

There are broadcast operators in the various models being tested and they work fine, as the data in such models comes with a valid batch_size.
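For contrast, numpy-style broadcasting, which ONNX op set 7 adopts, accepts these shapes directly without a matching batch dimension:

```python
import numpy as np

x = np.random.rand(3, 4, 5)
y = np.random.rand(5)

# Under numpy rules the trailing dimensions align, so no (1, ...) batch
# axis needs to be prepended to either operand.
z = x + y
print(z.shape)  # (3, 4, 5)
```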

Member

xvar = mx.sym.var('x')
yvar = mx.sym.var('y')
bcast_add = mx.sym.broadcast_add(xvar, yvar)

This should not fail. Can you give a minimal reproducible script which uses mx.sym.broadcast_add and fails?

Member Author

import mxnet as mx
import numpy as np

x = mx.nd.array(np.random.rand(3, 4, 5))
y = mx.nd.array(np.random.rand(5,))
xvar = mx.sym.var('x')
yvar = mx.sym.var('y')
bcast_add = mx.sym.broadcast_add(xvar, yvar)

data_names = ['x', 'y']
data_shapes = [('x', x.shape), ('y', y.shape)]
print("data shapes", data_shapes)

mod = mx.mod.Module(symbol=bcast_add, context=mx.cpu(),
                    data_names=data_names, label_names=None)
# Shape inference fails here: symbolic broadcast_add expects both operands
# to have the same number of dimensions, so (3, 4, 5) vs (5,) is rejected.
mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
mod.init_params()

data_forward = [x, y]
print("data forward", data_forward[0].shape)

mod.forward(mx.io.DataBatch(data_forward))
result = mod.get_outputs()
print("Model Result", result)

Member

Okay, I didn't know the Module API was being used to test individual operators in ONNX.
A bcast_add.bind call followed by forward() on the executor should work just fine. As discussed, please write a special test for broadcast if we cannot test it using the backend testing framework.

Member Author

densenet121, which runs successfully on CI, has broadcast multiply and add and tests broadcasting - https://s3.amazonaws.com/download.onnx/models/opset_6/densenet121.tar.gz

Member Author

added.

@with_seed()
def test_broadcast():
"""Test for broadcasting in onnx operators."""
input1 = np.random.rand(1, 3, 4, 5).astype("float32")
Member

Do we need to prepend the 1? Does the test pass with tensors of shape (3, 4, 5) and (5,)?

Member Author

It will not; if it did, we could have used the ONNX tests themselves.

Member

ok

@anirudh2290
Member

@piiswrong do you have any concerns?

@anirudh2290 anirudh2290 merged commit f754498 into apache:master Jun 4, 2018
@anirudhacharya anirudhacharya deleted the gluonImport branch June 4, 2018 23:27
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
…e#10605)

* gluon import

* gluon tests

* shape issues.

* remove the dim_change list

* onnx backend tests

* changes to match onnx op set version 7

* fix

* lint fix

* add new folder

* fix

* fix

* rename test file

* comments

* comment fix

* check for opset differences.

* fix

* bcast test
XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018

8 participants