Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
CoremlBugFixes:Input variable name can be something other than data; … (
Browse files Browse the repository at this point in the history
#7746)

* CoremlBugFixes:Input variable name can be something other than data; making pad and stride optional.

Earlier, we were not providing data_names argument while creating the module which meant that the input data variable name was assumed to be "data". This is fixed. Also, added a unit test for it (due to which utils.load_model(..) had to be refactored).

The second bug was we missed assuming pad and stride parameters for convolutional layers are optional arguments. Added a unit test for this too.

Also, tested with mnist model from the tutorial (by changing the input variable name to something other than data)

* Minor rewording of a unit test.
  • Loading branch information
pracheer authored and nswamy committed Sep 6, 2017
1 parent 9a021c7 commit 92ee930
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 30 deletions.
17 changes: 13 additions & 4 deletions tools/coreml/converter/_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def convert_convolution(net, node, module, builder):
else:
has_bias = True

if literal_eval(param['pad']) != (0, 0):
if 'pad' in param.keys() and literal_eval(param['pad']) != (0, 0):
pad = literal_eval(param['pad'])
builder.add_padding(
name=name+"_pad",
Expand All @@ -314,7 +314,12 @@ def convert_convolution(net, node, module, builder):
Wb = None

channels = W.shape[1]
stride_height, stride_width = literal_eval(param['stride'])

stride_height = 1
stride_width = 1
if 'stride' in param.keys():
stride_height, stride_width = literal_eval(param['stride'])

kernel_height, kernel_width = literal_eval(param['kernel'])

W = W.transpose((2, 3, 1, 0))
Expand Down Expand Up @@ -367,7 +372,7 @@ def convert_pooling(net, node, module, builder):
raise TypeError("Pooling type %s not supported" % layer_type_mx)

# Add padding if there is any
if literal_eval(param['pad']) != (0, 0):
if 'pad' in param.keys() and literal_eval(param['pad']) != (0, 0):
pad = literal_eval(param['pad'])
builder.add_padding(
name=name+"_pad",
Expand All @@ -380,7 +385,11 @@ def convert_pooling(net, node, module, builder):
output_name=name+"_pad_output")
input_name = name+"_pad_output"

stride_height, stride_width = literal_eval(param['stride'])
stride_height = 1
stride_width = 1
if 'stride' in param.keys():
stride_height, stride_width = literal_eval(param['stride'])

kernel_width, kernel_height = literal_eval(param['kernel'])

type_map = {'valid': 'VALID', 'full': 'INCLUDE_LAST_PIXEL'}
Expand Down
55 changes: 48 additions & 7 deletions tools/coreml/converter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


def load_model(model_name, epoch_num, data_shapes, label_shapes, label_names, gpus=''):
"""Loads and returns a given MXNet model.
"""Returns a module loaded with the provided model.
Parameters
----------
Expand Down Expand Up @@ -53,12 +53,59 @@ def load_model(model_name, epoch_num, data_shapes, label_shapes, label_names, gp
MXNet module
"""
sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, epoch_num)

mod = create_module(sym, data_shapes, label_shapes, label_names, gpus)

mod.set_params(
arg_params=arg_params,
aux_params=aux_params,
allow_missing=True
)

return mod


def create_module(sym, data_shapes, label_shapes, label_names, gpus=''):
"""Creates a new MXNet module.
Parameters
----------
sym : Symbol
An MXNet symbol.
input_shape: tuple
The shape of the input data in the form of (batch_size, channels, height, width)
files: list of strings
List of URLs pertaining to files that need to be downloaded in order to use the model.
data_shapes: list of tuples.
List of tuples where each tuple is a pair of input variable name and its shape.
label_shapes: list of (str, tuple)
Typically is ``data_iter.provide_label``.
label_names: list of str
Name of the output labels in the MXNet symbolic graph.
gpus: str
Comma separated string of gpu ids on which inferences are executed. E.g. 3,5,6 would refer to GPUs 3, 5 and 6.
If empty, we use CPU.
Returns
-------
MXNet module
"""
if gpus == '':
devices = mx.cpu()
else:
devices = [mx.gpu(int(i)) for i in gpus.split(',')]

data_names = [data_shape[0] for data_shape in data_shapes]

mod = mx.mod.Module(
symbol=sym,
data_names=data_names,
context=devices,
label_names=label_names
)
Expand All @@ -67,11 +114,5 @@ def load_model(model_name, epoch_num, data_shapes, label_shapes, label_names, gp
data_shapes=data_shapes,
label_shapes=label_shapes
)
mod.set_params(
arg_params=arg_params,
aux_params=aux_params,
allow_missing=True
)
return mod


2 changes: 1 addition & 1 deletion tools/coreml/mxnet_coreml_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
input_shape = yaml.safe_load(args.input_shape)
data_shapes = []
for key in input_shape:
# We prepend 1 because the coreml model only accept 1 input data at a time.
# We prepend 1 because the coreml model only accept 1 input data at a time (=batch-size).
shape = (1,)+literal_eval(input_shape[key])
input_shape[key] = shape
data_shapes.append((key, shape))
Expand Down
63 changes: 45 additions & 18 deletions tools/coreml/test/test_mxnet_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,22 @@
sys.path.append(current_working_directory + "/../converter/")
import _mxnet_converter as mxnet_converter
from collections import namedtuple

from converter import utils

def _mxnet_remove_batch(input_data):
for blob in input_data:
input_data[blob] = np.reshape(input_data[blob], input_data[blob].shape[1:])
return input_data


def _get_mxnet_module(net, input_shape, mode, label_names, input_names=None):
def _get_mxnet_module(net, data_shapes, mode, label_names, input_names=None):
""" Given a symbolic graph, input shape and the initialization mode,
returns an MXNet module.
"""
mx.random.seed(1993)

mod = mx.mod.Module(
symbol=net,
context=mx.cpu(),
label_names=label_names
)
mod.bind(
for_training=False,
data_shapes=[('data', input_shape)],
label_shapes=input_names
)
mod = utils.create_module(sym=net, data_shapes=data_shapes, label_shapes=input_names, label_names=label_names)

if mode == 'random':
mod.init_params(
initializer=mx.init.Uniform(scale=.1)
Expand All @@ -73,7 +65,7 @@ class SingleLayerTest(unittest.TestCase):
In order to do so, it converts model and generates preds on both CoreML and MXNet and check they are the same.
"""
def _test_mxnet_model(self, net, input_shape, mode, class_labels=None, coreml_mode=None, label_names=None, delta=1e-3,
pre_processing_args=None):
pre_processing_args=None, input_name='data'):
""" Helper method that convert the CoreML model into CoreML and compares the predictions over random data.
Parameters
Expand All @@ -92,21 +84,27 @@ def _test_mxnet_model(self, net, input_shape, mode, class_labels=None, coreml_mo
delta: float
The maximum difference b/w predictions of MXNet and CoreML that is tolerable.
input_name: str
The name of the input variable to the symbolic graph.
"""
mod = _get_mxnet_module(net, input_shape, mode, label_names)

data_shapes=[(input_name, input_shape)]

mod = _get_mxnet_module(net, data_shapes, mode, label_names)

# Generate some dummy data
input_data = {'data': np.random.uniform(-10., 10., input_shape)}
input_data = {input_name: np.random.uniform(-10., 10., input_shape)}
Batch = namedtuple('Batch', ['data'])
mod.forward(Batch([mx.nd.array(input_data['data'])]))
mod.forward(Batch([mx.nd.array(input_data[input_name])]))
mxnet_preds = mod.get_outputs()[0].asnumpy().flatten()

# Get predictions from coreml
coreml_model = mxnet_converter.convert(
model=mod,
class_labels=class_labels,
mode=coreml_mode,
input_shape={'data': input_shape},
input_shape={input_name: input_shape},
preprocessor_args=pre_processing_args
)
coreml_preds = coreml_model.predict(_mxnet_remove_batch(input_data)).values()[0].flatten()
Expand Down Expand Up @@ -512,7 +510,7 @@ def test_tiny_synset_random_input(self):
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=5)
net = mx.sym.SoftmaxOutput(net, name='softmax')
mod = _get_mxnet_module(net,
input_shape=input_shape,
data_shapes=[('data', input_shape)],
mode='random',
label_names=['softmax_label'])

Expand Down Expand Up @@ -941,6 +939,35 @@ def test_pre_processing_args(self):
self._test_mxnet_model(net, input_shape=input_shape, mode='random', label_names=['softmax_label'],
pre_processing_args={'red_bias':0, 'blue_bias':0, 'green_bias':0, 'image_scale':1})

def test_different_input_variables(self):
"""
Verifying the behavior when input variable name is different than the standard name - 'data'.
"""
np.random.seed(1988)
input_shape = (1, 10)
net = mx.sym.Variable('data1')
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=5)
self._test_mxnet_model(net, input_shape=input_shape, mode='zeros', input_name='data1')

def test_really_tiny_conv_optional_params(self):
"""
Verifying the behavior of a convolutional layer when stride and pad are not provided.
"""
np.random.seed(1988)
input_shape = (1, 1, 10, 10)
num_filter = 1
kernel = (1 ,1)

# Define a model
net = mx.sym.Variable('data')
net = mx.symbol.Convolution(
data=net,
num_filter=num_filter,
kernel=kernel,
name='conv_1'
)
self._test_mxnet_model(net, input_shape=input_shape, mode='random')

# TODO test_concat


Expand Down

0 comments on commit 92ee930

Please sign in to comment.