[MXNET-310] [ONNX-MXNet] API to import ONNX models into Gluon. #10605
Conversation
This is unnecessary. You can use Gluon's SymbolBlock to load symbol models directly.
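For context, a rough sketch of the SymbolBlock route being referred to (the file names and the input name 'data' are placeholder assumptions, not from this PR):

import mxnet as mx
from mxnet import gluon

# Load a previously exported symbol/params pair straight into Gluon.
sym = mx.sym.load('model-symbol.json')
net = gluon.SymbolBlock(outputs=sym, inputs=mx.sym.var('data'))
net.collect_params().load('model-0000.params', ctx=mx.cpu())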
data_names = [input_tensor[0] for input_tensor in metadata['input_tensor_data']]
data_inputs = [symbol.var(data_name) for data_name in data_names]

from ....gluon import SymbolBlock
Do we need this here? Can we move it to the top?
#test_elu_example
#test_leakyrelu_example

#GLUON_TEST.include('test_elu_example')
Remove these commented-out lines.
except ImportError:
    raise ImportError("Onnx and protobuf need to be installed. "
                      + "Instructions to install - https://github.com/onnx/onnx")
model_proto = onnx.load(model_file)
Add a check for whether the file exists and throw an appropriate error message.
?
I will add the file check.
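A minimal sketch of the requested check, extending the snippet above (model_file and onnx come from that snippet):

import os

if not os.path.isfile(model_file):
    raise ValueError("ONNX model file does not exist: %s" % model_file)
model_proto = onnx.load(model_file)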
data_inputs = [symbol.var(data_name) for data_name in data_names]

from ....gluon import SymbolBlock
net = SymbolBlock(outputs=sym, inputs=data_inputs)
Add a few comments to explain the logic here.
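Something along these lines, as a sketch of the comments being asked for (code taken from the diff above, comments added):

# Create a symbolic variable for every graph input listed in the ONNX metadata;
# these become the inputs of the resulting Gluon block.
data_names = [input_tensor[0] for input_tensor in metadata['input_tensor_data']]
data_inputs = [symbol.var(data_name) for data_name in data_names]

# Wrap the converted symbol and its inputs into a SymbolBlock so the network
# can be used directly through the Gluon API.
from ....gluon import SymbolBlock
net = SymbolBlock(outputs=sym, inputs=data_inputs)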
except ImportError:
    raise ImportError("Onnx and protobuf need to be installed. Instructions to"
                      + " install - https://github.com/onnx/onnx#installation")
model_proto = onnx.load(model_file)
file exists check
not needed.
else:
    raise NotImplementedError("Only CPU context is supported for now")

if node.op_type in ['Conv']:
Add a comment explaining why we are doing this for 'Conv'.
if device == 'CPU':
    ctx = mx.cpu()
else:
    raise NotImplementedError("Only CPU context is supported for now")
Shouldn't ONNX be implementation independent and thus not care about the used device type?
ONNX is implementation independent. Here we are running a particular ONNX model using gluon, so we need to specify the context. In the CI pipeline these tests are running on a CPU, hence the above assignment. Saying "GPU is not implemented" is pretty misleading. I will correct this in the code.
for param in aux_params:
    if param in net_params:
        net_params[param].shape = aux_params[param].shape
        net_params[param]._load_init(aux_params[param], ctx=cpu())
Why default to CPU? Can we not import the model on GPU straight away? We should let the user pass in a ctx argument that defaults to CPU.
yes, will make this change.
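One possible shape for that change, as a sketch only (the helper name _load_params is hypothetical; the body mirrors the diff above with the hard-coded cpu() replaced by a caller-supplied ctx):

from mxnet import cpu

def _load_params(net_params, arg_params, aux_params, ctx=None):
    # Default to CPU, but honour whatever context the caller passes in.
    ctx = cpu() if ctx is None else ctx
    for param in arg_params:
        if param in net_params:
            net_params[param].shape = arg_params[param].shape
            net_params[param]._load_init(arg_params[param], ctx=ctx)
    for param in aux_params:
        if param in net_params:
            net_params[param].shape = aux_params[param].shape
            net_params[param]._load_init(aux_params[param], ctx=ctx)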
Since we are recomposing the network using symbol, why specifically target Gluon?
@zhreshold there is an import API for MXNet - https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/contrib/onnx/_import/import_model.py#L24 This PR is to facilitate directly loading the ONNX model and the parameters into the Gluon interface. I think @piiswrong also made the same comment above.
Please create a folder inside python-pytest/onnx for "import", and move all the import-specific files in there. An "export"-specific backend will be added pretty soon.
if op_name == 'broadcast_add':
    op_sym = symbol.broadcast_add(op_sym, inputs[0])
    op_sym = symbol.broadcast_add(inputs[0], op_sym)
elif op_name == 'broadcast_mul':
Calling "_fix_broadcast()" function for all broadcast_add, broadcast_sub, broadcast_mul, broadcast_div operations.
Here, you are handling only broadcas_add and broadcast_mul
done
@@ -16,18 +16,19 @@
# under the License.

# coding: utf-8
"""backend rep for onnx test infrastructure"""
"""MXNet backend rep for onnx test infrastructure"""
from collections import namedtuple
Not used anywhere; remove?
except ImportError:
    raise ImportError("Onnx and protobuf need to be installed. "
                      + "Instructions to install - https://github.com/onnx/onnx")
model_proto = onnx.load(model_file)
?
"""Gluon backend for ONNX""" | ||
|
||
@staticmethod | ||
def make_graph(node, inputs): |
Do we need this function?
@@ -74,80 +74,6 @@ def make_graph(node, inputs):

    return graph_proto
The make_graph(node, inputs) function is not needed any more.
metadata = graph.get_graph_metadata(model_proto.graph)
return metadata

def get_model_metadata(model_file):
this method is repeated?
data_names = [input_tensor[0] for input_tensor in metadata['input_tensor_data']]
data_inputs = [symbol.var(data_name) for data_name in data_names]

ctx = gpu() if context == 'GPU' else cpu()
from .... import cpu, gpu
done
ci/docker/runtime_functions.sh
@@ -514,8 +514,9 @@ integrationtest_ubuntu_cpu_onnx() {
    set -ex
    export PYTHONPATH=./python/
    python example/onnx/super_resolution.py
    pytest tests/python-pytest/onnx/onnx_backend_test.py
    pytest tests/python-pytest/onnx/mxnet_backend_test.py
Shouldn't this be onnx/import/mxnet_backend_test.py?
ci/docker/runtime_functions.sh
pytest tests/python-pytest/onnx/onnx_test.py
pytest tests/python-pytest/onnx/gluon_backend_test.py
same here
Can you rename 'tests/python-pytest/onnx/import/onnx_test.py' to 'tests/python-pytest/onnx/import/onnx_import_test.py'?
lgtm
@@ -33,6 +34,9 @@ def __init__(self):
    self._params = {}
    self._num_input = 0
    self._num_param = 0
    self.auxDict = {}
    self.argDict = {}
Make these consistent with the snake_case naming used elsewhere.
    return 'elemwise_div', new_attr, inputs
broadcast_axis = attrs['axis']
op_value = translation_utils._fix_broadcast('broadcast_div', inputs,
                                            broadcast_axis, cls)
Can you explain what broadcast_axis is here? What will broadcast_axis be when adding two tensors of shape (4,5) and (1,1)?
broadcast_axis comes from ONNX's axis attribute in operators that support broadcasting - https://github.com/onnx/onnx/blob/master/docs/Changelog.md#attributes-103
With op set version 6, broadcasting (1,1) onto (4,5) would not be permissible. If we broadcast (5,) onto (4,5), broadcast_axis will be 1; if we broadcast (4,) onto (4,5), broadcast_axis will be 0.
ONNX's op set version 7 updates the broadcast rules to align with numpy broadcasting rules. When that is consistently updated in the ONNX repo, we will also update the translation code in MXNet.
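For illustration, a small numpy sketch of those two cases, mirroring the reshape the converter performs:

import numpy as np

a = np.random.rand(4, 5)

# axis = 1: a (5,) tensor is reshaped to (1, 5) before broadcasting onto (4, 5)
b = np.random.rand(5).reshape(1, -1)
print((a + b).shape)  # (4, 5)

# axis = 0: a (4,) tensor is reshaped to (4, 1) before broadcasting onto (4, 5)
c = np.random.rand(4).reshape(-1, 1)
print((a + c).shape)  # (4, 5)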
@@ -43,32 +43,42 @@ def add(attrs, inputs, cls):
    """Adding two tensors"""
    new_attr = {}
    if 'broadcast' in attrs and attrs['broadcast'] == 1:
        op_value = translation_utils._fix_bias_shape('broadcast_add', inputs, cls)
        broadcast_axis = attrs['axis']
Won't block on this here, but the add, subtract, multiply and divide functions can share code.
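Not part of this PR, but one hedged sketch of that reuse (the factory name _broadcast_arithmetic is hypothetical; _fix_broadcast and the return convention are taken from the diff, and the surrounding module's imports are assumed):

def _broadcast_arithmetic(broadcast_op, elemwise_op):
    """Build a converter for an ONNX arithmetic op with optional broadcasting."""
    def convert(attrs, inputs, proto_obj):
        new_attr = {}
        if 'broadcast' in attrs and attrs['broadcast'] == 1:
            broadcast_axis = attrs['axis']
            op_value = translation_utils._fix_broadcast(broadcast_op, inputs,
                                                        broadcast_axis, proto_obj)
            return op_value, new_attr, inputs
        return elemwise_op, new_attr, inputs
    return convert

add = _broadcast_arithmetic('broadcast_add', 'elemwise_add')
subtract = _broadcast_arithmetic('broadcast_sub', 'elemwise_sub')
multiply = _broadcast_arithmetic('broadcast_mul', 'elemwise_mul')
divide = _broadcast_arithmetic('broadcast_div', 'elemwise_div')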
input0_shape = get_input_shape(inputs[0], cls)
# creating reshape shape
reshape_shape = list(len(input0_shape) * (1,))
reshape_shape[broadcast_axis] = -1
Is broadcast_axis always going to be a scalar?
Yes, broadcast_axis comes from ONNX's axis attribute in operators that support broadcasting - https://github.com/onnx/onnx/blob/master/docs/Changelog.md#attributes-103
elif op_name == 'broadcast_sub':
    op_sym = symbol.broadcast_sub(inputs[0], op_sym)
elif op_name == 'broadcast_div':
    op_sym = symbol.broadcast_div(inputs[0], op_sym)
You can change the above if/else logic to:
op_sym = getattr(symbol, op_name)(inputs[0], op_sym)
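In context that would look roughly like this (op names taken from the surrounding diff):

# All four broadcast ops have the same call shape, so dispatch by name.
if op_name in ('broadcast_add', 'broadcast_sub', 'broadcast_mul', 'broadcast_div'):
    op_sym = getattr(symbol, op_name)(inputs[0], op_sym)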
@@ -148,21 +152,29 @@ def _fix_bias(op_name, attrs, num_inputs):
        raise ValueError("Unexpected number of inputs for: {}".format(op_name))
    return attrs

def _fix_bias_shape(op_name, inputs, cls):
def _fix_broadcast(op_name, inputs, broadcast_axis, cls):
Shouldn't this be obj instead of cls? Same question for everywhere cls is used in op_translations and translation_utils.
No, I meant to use cls. It's just a convention: the same way self is used to access an attribute inside the object (class) itself, cls is often used to reference class and instance variables outside the object.
Using cls for an instance is misleading. cls is the preferred variable name for anything that is meant to be a class. For your case, since it is an instance of a class that is passed to _fix_broadcast and elsewhere, it should be something that indicates an instance, like obj.
Okay, I will rename it to proto_obj.
BACKEND_TESTS.include(basic_model_test)

BACKEND_TESTS.exclude('.*broadcast.*')
BACKEND_TESTS.exclude('.*bcast.*')
Why exclude the broadcast tests?
Because there is an issue with the way broadcast operator tests are written in ONNX.
For example, if we try to broadcast a (5,) array onto a (3,4,5) array, MXNet's forward pass will fail because MXNet's interface expects the same batch size on the two arrays, i.e. (1,5) and (1,3,4,5).
So
import mxnet as mx
import numpy as np
x = mx.nd.array(np.random.rand(3,4,5))
y = mx.nd.array(np.random.rand(5,))
mx.nd.broadcast_add(x,y)
will pass, but the following will fail
xvar = mx.sym.var('x')
yvar = mx.sym.var('y')
bcast_add = mx.sym.broadcast_add(xvar, yvar)
There are broadcast operators in the various models that are being tested and they work fine, as the data in such models comes with a valid batch_size.
xvar = mx.sym.var('x')
yvar = mx.sym.var('y')
bcast_add = mx.sym.broadcast_add(xvar, yvar)
This should not fail. Can you give a minimal reproducible script which uses mx.sym.broadcast_add and fails?
import mxnet as mx
import numpy as np
from collections import namedtuple
x = mx.nd.array(np.random.rand(3,4,5))
y = mx.nd.array(np.random.rand(5,))
xvar = mx.sym.var('x')
yvar = mx.sym.var('y')
bcast_add = mx.sym.broadcast_add(xvar, yvar)
data_names = ['x', 'y']
data_shapes = []
data_shapes.append(('x', x.shape))
data_shapes.append(('y', y.shape))
print("data shapes", data_shapes)
mod = mx.mod.Module(symbol=bcast_add, context=mx.cpu(), data_names=data_names, label_names=None)
mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
mod.set_params(arg_params=None, aux_params=None)
mod.init_params()
data_forward = []
data_forward.append(x)
data_forward.append(y)
print("data forward", data_forward[0].shape)
mod.forward(mx.io.DataBatch(data_forward))
result = mod.get_outputs()
print("Model Result", result)
Okay, I didn't know that the Module API was being used to test individual operators in ONNX.
A bcast_add.bind call followed by forward() on the executor should work just fine. As discussed, please write a special test for broadcast if we cannot test it using the backend testing framework.
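For reference, a sketch of that plain-executor route, reusing the shapes from the example above:

import mxnet as mx
import numpy as np

xvar = mx.sym.var('x')
yvar = mx.sym.var('y')
bcast_add = mx.sym.broadcast_add(xvar, yvar)

# Bind the arrays directly instead of going through the Module API.
exe = bcast_add.bind(mx.cpu(), args={'x': mx.nd.array(np.random.rand(3, 4, 5)),
                                     'y': mx.nd.array(np.random.rand(5,))})
outputs = exe.forward()
print(outputs[0].shape)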
densenet121, which runs successfully on CI, has broadcast multiply and add and tests broadcasting - https://s3.amazonaws.com/download.onnx/models/opset_6/densenet121.tar.gz
added.
@with_seed()
def test_broadcast():
    """Test for broadcasting in onnx operators."""
    input1 = np.random.rand(1, 3, 4, 5).astype("float32")
Do we need to prepend the 1? Does the test pass with tensors of shape (3, 4, 5) and (5,)?
It will not; if it did, we could have used the ONNX tests themselves.
ok
@piiswrong do you have any concerns?
…e#10605) * gluon import * gluon tests * shape issues. * remove the dim_change list * onnx backend tests * changes to match onnx op set version 7 * fix * lint fix * add new folder * fix * fix * rename test file * comments * comment fix * check for opset differences. * fix * bcast test
Description
API and corresponding tests to import ONNX models into Gluon, along with changes to match ONNX op set 7.
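A hedged usage sketch of the new import path (the function name import_to_gluon and the file name are assumptions for illustration, not a definitive API reference):

import mxnet as mx
from mxnet.contrib import onnx as onnx_mxnet

# Load an ONNX model file directly into a Gluon SymbolBlock.
net = onnx_mxnet.import_to_gluon('super_resolution.onnx', ctx=mx.cpu())

# The returned block can be called like any other Gluon network.
data = mx.nd.random.uniform(shape=(1, 1, 224, 224))
output = net(data)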
Comments
@spidydev @Roshrini @anirudh2290