diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index e741be6e3744..f0217fc1ec85 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -263,6 +263,15 @@ def _expand_dims(inputs, attrs): new_attrs['axis'] = _required_attr(attrs, 'axis') return _get_nnvm_op(op_name)(*inputs, **new_attrs) +def _lrn(inputs, attrs): + op_name, new_attrs = "lrn", {} + new_attrs['alpha'] = attrs.get('alpha', 0.0001) + new_attrs['beta'] = attrs.get('beta', 0.75) + new_attrs['bias'] = attrs.get('knorm', 2) + # NCHW format and normalization along channel axis + new_attrs['axis'] = 1 + new_attrs['size'] = _required_attr(attrs, 'nsize') + return _get_nnvm_op(op_name)(*inputs, **new_attrs) _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', '__div_symbol__', '__mul_scalar__', '__mul_symbol__', @@ -314,7 +323,8 @@ def _expand_dims(inputs, attrs): 'sum_axis' : _rename('sum'), 'UpSampling' : _upsampling, 'clip' : _clip, - 'expand_dims' : _expand_dims + 'expand_dims' : _expand_dims, + 'LRN' : _lrn } def _convert_symbol(op_name, inputs, attrs, diff --git a/nnvm/tests/python/frontend/mxnet/test_forward.py b/nnvm/tests/python/frontend/mxnet/test_forward.py index a3356689b6e4..6c086cb367e8 100644 --- a/nnvm/tests/python/frontend/mxnet/test_forward.py +++ b/nnvm/tests/python/frontend/mxnet/test_forward.py @@ -149,6 +149,11 @@ def test_forward_pooling(): mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max') verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8)) +def test_forward_lrn(): + data = mx.sym.var('data') + mx_sym = mx.sym.LRN(data, alpha=2, beta=2, knorm=1, nsize=5) + verify_mxnet_frontend_impl(mx_sym, (1, 10, 24, 24), (1, 10, 24, 24)) + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -163,3 +168,4 @@ def test_forward_pooling(): test_forward_split_squeeze() test_forward_expand_dims() test_forward_pooling() + test_forward_lrn()