diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index f9d170d81c13..edb174a9fbb2 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -2047,3 +2047,30 @@ def convert_broadcast_to(node, **kwargs):
     )
 
     return [tensor_node, expand_node]
+
+
+@mx_op.register("UpSampling")
+def convert_upsample(node, **kwargs):
+    """Map MXNet's UpSampling operator attributes to onnx's Upsample operator
+    and return the created node.
+    """
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    sample_type = attrs.get('sample_type', 'nearest')
+    sample_type = 'linear' if sample_type == 'bilinear' else sample_type
+    scale = convert_string_to_list(attrs.get('scale'))
+    scaleh = scalew = float(scale[0])
+    if len(scale) > 1:
+        scaleh = float(scale[0])
+        scalew = float(scale[1])
+    scale = [1.0, 1.0, scaleh, scalew]
+
+    node = onnx.helper.make_node(
+        'Upsample',
+        input_nodes,
+        [name],
+        scales=scale,
+        mode=sample_type,
+        name=name
+    )
+    return [node]
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
index cf95bfef09a3..939027b12d51 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
@@ -23,7 +23,7 @@
 from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
 from ._op_translations import softplus, shape, gather, lp_pooling, size
 from ._op_translations import ceil, floor, hardsigmoid, global_lppooling
-from ._op_translations import concat, hardmax
+from ._op_translations import concat, hardmax, upsampling
 from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax, fully_connected
 from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm
 from ._op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm
@@ -147,5 +147,6 @@
     'DepthToSpace' : depthtospace,
     'SpaceToDepth' : spacetodepth,
     'Hardmax' : hardmax,
-    'LpNormalization' : lpnormalization
+    'LpNormalization' : lpnormalization,
+    'Upsample' : upsampling
 }
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index a7cef7674496..1f46f93544b4 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -777,3 +777,24 @@ def lpnormalization(attrs, inputs, proto_obj):
     axis = int(attrs.get("axis", -1))
     new_attrs.update(axis=axis)
     return 'norm', new_attrs, inputs
+
+
+def upsampling(attrs, inputs, proto_obj):
+    """Maps the ONNX Upsample operator to MXNet's UpSampling operator."""
+    new_attrs = translation_utils._fix_attribute_names(attrs, {'scales': 'scale',
+                                                               'mode': 'sample_type'})
+    sample_type = new_attrs.get('sample_type', 'nearest')
+    if sample_type != 'nearest':
+        raise NotImplementedError("Operator {} in ONNX uses a single 'linear' mode "
+                                  "for linear, bilinear, trilinear etc., and there is "
+                                  "no way to distinguish between these so far. Only "
+                                  "nearest neighbor upsampling is therefore supported "
+                                  "for import. See https://github.com/onnx/onnx/issues/1774. "
+                                  "Use contrib.BilinearResize2D for bilinear mode."
+                                  .format('Upsample'))
+
+    scale = tuple(new_attrs.get('scale'))[2:]
+    scale = tuple([int(s) for s in scale])
+    mx_op = symbol.UpSampling(inputs[0], scale=scale, sample_type=sample_type)
+
+    return mx_op, new_attrs, inputs
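For reference, the intended round trip through these converters looks roughly like this. This is a minimal sketch, assuming an MXNet build that includes this patch; `upsample.onnx` is a placeholder path:

```python
import numpy as np
import mxnet as mx

# An UpSampling symbol that convert_upsample should map to an ONNX Upsample
# node with scales=[1.0, 1.0, 2.0, 2.0] (batch and channel axes unscaled).
data = mx.sym.Variable('data')
up = mx.sym.UpSampling(data, sample_type='nearest', scale=2, num_args=1)

onnx_path = mx.contrib.onnx.export_model(up, {}, [(1, 3, 4, 4)],
                                         np.float32, 'upsample.onnx')

# Importing goes back through upsampling() above, which keeps only the two
# spatial entries of 'scales' and rebuilds symbol.UpSampling.
sym, arg_params, aux_params = mx.contrib.onnx.import_model(onnx_path)
```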
diff --git a/src/operator/nn/upsampling-inl.h b/src/operator/nn/upsampling-inl.h
index feb44c894a7a..b2707ce5e9b2 100644
--- a/src/operator/nn/upsampling-inl.h
+++ b/src/operator/nn/upsampling-inl.h
@@ -48,7 +48,7 @@ enum UpSamplingMultiInputMode {kConcat, kSum};
 }  // namespace up_enum
 
 struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
-  int scale;
+  TShape scale;
   int num_filter;
   int sample_type;
   int num_args;
@@ -56,8 +56,10 @@ struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
   uint64_t workspace;
   DMLC_DECLARE_PARAMETER(UpSamplingParam) {
     DMLC_DECLARE_FIELD(scale)
-    .set_range(1, 1000)
-    .describe("Up sampling scale");
+    .set_default(TShape())
+    .describe("Up sampling scale. Integer or tuple of integers. "
+              "Different scale per dimension is allowed only for "
+              "nearest neighbor upsampling.");
     DMLC_DECLARE_FIELD(num_filter)
     .describe("Input filter. Only used by bilinear sample_type.")
     .set_default(0);
@@ -82,6 +84,57 @@ struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
   }
 };  // struct UpSamplingParam
 
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingNearestUpdateOutput(mshadow::Stream<xpu> *s,
+                                          const std::vector<TBlob> &in_data,
+                                          std::vector<TBlob> *out_data) {
+  Tensor<xpu, 4, DType> itensor = in_data[0].get<xpu, 4, DType>(s);
+  Tensor<xpu, 4, DType> otensor = (*out_data)[0].get<xpu, 4, DType>(s);
+
+  int outputHeight = otensor.size(2);
+  int outputWidth = otensor.size(3);
+  int inputHeight = itensor.size(2);
+  int inputWidth = itensor.size(3);
+
+  int dW = outputWidth / inputWidth;
+  int dH = outputHeight / inputHeight;
+  int idim = itensor.shape_.kDimension;
+
+  // dims
+  int osz0 = otensor.size(0);
+  int osz1 = otensor.size(1);
+  int osz2 = otensor.size(2);
+  int osz3 = 1;
+  if (idim > 3) {
+    osz3 = otensor.size(3);
+  }
+
+  // perform the upsampling
+  int i0, i1, i2, i3;
+  int iout[4];  // Output indices
+  int iin[4];  // Input indices
+
+  for (i0 = 0; i0 < osz0; i0++) {
+    iout[0] = i0;
+    iin[0] = i0;
+    for (i1 = 0; i1 < osz1; i1++) {
+      iout[1] = i1;
+      iin[1] = i1;
+      for (i2 = 0; i2 < osz2; i2++) {
+        iout[2] = i2;
+        iin[2] = i2;
+        int in_y = i2 / dH;
+        for (i3 = 0; i3 < osz3; i3++) {
+          iout[3] = i3;
+          iin[3] = i3;
+          int in_x = i3 / dW;
+          otensor[i0][i1][i2][i3] = itensor[i0][i1][in_y][in_x];
+        }
+      }
+    }
+  }
+}
+
 template<typename xpu, typename DType>
 void UpSamplingForward(const OpContext &ctx, const UpSamplingParam &param,
                        const std::vector<TBlob> &in_data,
@@ -96,26 +149,37 @@ void UpSamplingForward(const OpContext &ctx, const UpSamplingParam &param,
   }
   Stream<xpu> *s = ctx.get_stream<xpu>();
   Tensor<xpu, 4, DType> out = out_data[up_enum::kOut].get<xpu, 4, DType>(s);
+  std::vector<TBlob> outdata = out_data;
   if (param.num_args > 1) {
     int begin = 0;
     for (int i = 0; i < param.num_args; ++i) {
       Tensor<xpu, 4, DType> data = in_data[i].get<xpu, 4, DType>(s);
       int end = begin + data.size(1);
-      int scale = out_data[up_enum::kOut].size(2)/in_data[i].size(2);
       if (param.multi_input_mode == up_enum::kSum) {
         if (i == 0) {
-          Assign(out, req[up_enum::kOut], upsampling_nearest(data, scale));
+          MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, {
+            SpatialUpSamplingNearestUpdateOutput<xpu, DTyp, AccReal>(s, in_data, &outdata);
+            out = out_data[up_enum::kOut].get<xpu, 4, DType>(s);
+          });
         } else {
-          out += upsampling_nearest(data, scale);
+          MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, {
+            SpatialUpSamplingNearestUpdateOutput<xpu, DTyp, AccReal>(s, in_data, &outdata);
+            out += out_data[up_enum::kOut].get<xpu, 4, DType>(s);
+          });
         }
       } else {
-        Assign(slice<1>(out, begin, end), req[up_enum::kOut], upsampling_nearest(data, scale));
+        MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, {
+          SpatialUpSamplingNearestUpdateOutput<xpu, DTyp, AccReal>(s, in_data, &outdata);
+          slice<1>(out, begin, end) = out_data[up_enum::kOut].get<xpu, 4, DType>(s);
+        });
       }
       begin = end;
     }
   } else {
-    Tensor<xpu, 4, DType> data = in_data[up_enum::kData].get<xpu, 4, DType>(s);
-    Assign(out, req[up_enum::kOut], upsampling_nearest(data, param.scale));
+    MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, {
+      SpatialUpSamplingNearestUpdateOutput<xpu, DTyp, AccReal>(s, in_data, &outdata);
+      out = out_data[up_enum::kOut].get<xpu, 4, DType>(s);
+    });
   }
 }
 
@@ -134,44 +198,71 @@ void UpSamplingBackward(const OpContext &ctx, const UpSamplingParam &param,
       Tensor<xpu, 4, DType> input_grad = in_grad[i].get<xpu, 4, DType>(s);
       mshadow::Shape<2> in_shape = Shape2(input_grad.shape_[2], input_grad.shape_[3]);
       int end = begin + input_grad.size(1);
-      int scale = grad.size(2)/in_shape[0];
+      int scale_h = grad.size(2)/in_shape[0];
+      int scale_w = grad.size(3)/in_shape[1];
       if (param.multi_input_mode == up_enum::kSum) {
         Assign(input_grad, req[i],
                pool<mshadow::red::sum>(grad,
                                        in_shape,
-                                       scale,
-                                       scale,
-                                       scale,
-                                       scale));
+                                       scale_h,
+                                       scale_w,
+                                       scale_h,
+                                       scale_w));
       } else {
         Assign(input_grad, req[i],
                pool<mshadow::red::sum>(slice<1>(grad, begin, end),
                                        in_shape,
-                                       scale,
-                                       scale,
-                                       scale,
-                                       scale));
+                                       scale_h,
+                                       scale_w,
+                                       scale_h,
+                                       scale_w));
       }
       begin = end;
     }
   } else {
     Tensor<xpu, 4, DType> input_grad = in_grad[up_enum::kData].get<xpu, 4, DType>(s);
     mshadow::Shape<2> in_shape = Shape2(input_grad.shape_[2], input_grad.shape_[3]);
+    int scale_h = 1;
+    int scale_w = 1;
+    if (param.scale.ndim() == 1) {
+      scale_h = param.scale[0];
+      scale_w = param.scale[0];
+    } else if (param.scale.ndim() == 2) {
+      scale_h = param.scale[0];
+      scale_w = param.scale[1];
+    } else if (param.scale.ndim() == 4) {
+      scale_h = param.scale[2];
+      scale_w = param.scale[3];
+    }
     Assign(input_grad, req[up_enum::kData],
            pool<mshadow::red::sum>(grad,
                                    in_shape,
-                                   param.scale,
-                                   param.scale,
-                                   param.scale,
-                                   param.scale));
+                                   scale_h,
+                                   scale_w,
+                                   scale_h,
+                                   scale_w));
   }
 }
 
 static inline DeconvolutionParam GetDeconvolutionParam(const UpSamplingParam& param) {
   DeconvolutionParam p = DeconvolutionParam();
-  int kernel = 2 * param.scale - param.scale % 2;
-  int stride = param.scale;
-  int pad = static_cast<int>(ceil((param.scale - 1) / 2.));
+  int scale_h = 1;
+  int scale_w = 1;
+  if (param.scale.ndim() == 1) {
+    scale_h = param.scale[0];
+    scale_w = param.scale[0];
+  } else if (param.scale.ndim() == 2) {
+    scale_h = param.scale[0];
+    scale_w = param.scale[1];
+  } else if (param.scale.ndim() == 4) {
+    scale_h = param.scale[2];
+    scale_w = param.scale[3];
+  }
+  CHECK_EQ(scale_h, scale_w) <<
+    "UpSamplingBilinear: Scale should be the same along all dimensions for bilinear upsampling";
+  int kernel = static_cast<int>(2.0 * scale_h - ::fmod(scale_h, 2));
+  int stride = scale_h;
+  int pad = static_cast<int>(ceil((scale_h - 1) / 2.));
   p.workspace = param.workspace;
   p.num_group = param.num_filter;
   p.num_filter = param.num_filter;
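For intuition, the nearest-neighbor kernel added above reduces to the following NumPy sketch (an illustration only, not part of the patch): every output pixel `(y, x)` copies input pixel `(y // dH, x // dW)`.

```python
import numpy as np

def nearest_upsample(x, d_h, d_w):
    """Mirror of SpatialUpSamplingNearestUpdateOutput for NCHW arrays."""
    n, c, h, w = x.shape
    out = np.empty((n, c, h * d_h, w * d_w), dtype=x.dtype)
    for y in range(h * d_h):
        for x_out in range(w * d_w):
            # Same index arithmetic as the inner loops of the C++ kernel.
            out[:, :, y, x_out] = x[:, :, y // d_h, x_out // d_w]
    return out
```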
diff --git a/src/operator/nn/upsampling.cc b/src/operator/nn/upsampling.cc
index b6b3d873df7d..7b59e9f69abe 100644
--- a/src/operator/nn/upsampling.cc
+++ b/src/operator/nn/upsampling.cc
@@ -37,13 +37,25 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs,
   CHECK_GE(in_shape->size(), 1U);
   const TShape &dshape = (*in_shape)[0];
   TShape oshape = dshape;
+  int scale_h = 1;
+  int scale_w = 1;
+  if (param_.scale.ndim() == 1) {
+    scale_h = param_.scale[0];
+    scale_w = param_.scale[0];
+  } else if (param_.scale.ndim() == 2) {
+    scale_h = param_.scale[0];
+    scale_w = param_.scale[1];
+  } else if (param_.scale.ndim() == 4) {
+    scale_h = param_.scale[2];
+    scale_w = param_.scale[3];
+  }
   if (param_.sample_type == up_enum::kNearest) {
     CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
     oshape[1] = 0;
     for (auto& shape : *in_shape) {
       CHECK_EQ(shape.ndim(), 4U) << \
         "UpSamplingNearest: Input data should be 4D in (batch, channel, y, x)";
-      int oh = dshape[2]*param_.scale, ow = dshape[3]*param_.scale;
+      int oh = dshape[2]*scale_h, ow = dshape[3]*scale_w;
       CHECK_EQ(oh%shape[2], 0U) << "UpSamplingNearest: input height of " << shape[2] << \
         "does not divide output height of " << oh;
       CHECK_EQ(ow%shape[3], 0U) << "UpSamplingNearest: input width of " << shape[3] << \
@@ -58,17 +70,19 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs,
     }
   } else {
     CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
+    CHECK_EQ(scale_h, scale_w) <<
+      "UpSamplingBilinear: Scale should be the same along all dimensions for bilinear upsampling";
     CHECK_EQ(dshape.ndim(), 4U) << \
       "UpSamplingBilinear: Input data should be 4D in (batch, channel, y, x)";
     if (dshape.ndim() == 0) return false;
-    int kernel = 2 * param_.scale - param_.scale % 2;
+    int kernel = static_cast<int>(2.0 * scale_h - ::fmod(scale_h, 2));
     SHAPE_ASSIGN_CHECK(*in_shape,
                        up_enum::kWeight,
                        mshadow::Shape4(dshape[1], 1, kernel, kernel));
     oshape = dshape;
   }
-  oshape[2] = dshape[2] * param_.scale;
-  oshape[3] = dshape[3] * param_.scale;
+  oshape[2] = dshape[2] * scale_h;
+  oshape[3] = dshape[3] * scale_w;
   out_shape->clear();
   out_shape->push_back(oshape);
   return true;
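The scale-normalization rule above now appears in three places (UpSamplingShape, UpSamplingBackward, GetDeconvolutionParam). In Python terms it behaves like this hypothetical helper:

```python
def scale_hw(scale):
    # ndim == 1: one factor for both axes; ndim == 2: (h, w);
    # ndim == 4: full (n, c, h, w) layout, only h and w are used.
    if len(scale) == 1:
        return scale[0], scale[0]
    if len(scale) == 2:
        return scale[0], scale[1]
    if len(scale) == 4:
        return scale[2], scale[3]
    return 1, 1

assert scale_hw((2,)) == (2, 2)
assert scale_hw((2, 3)) == (2, 3)
assert scale_hw((1, 1, 2, 3)) == (2, 3)
```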
diff --git a/tests/python-pytest/onnx/test_cases.py b/tests/python-pytest/onnx/test_cases.py
index 89b60d15e84f..90136bb01ba6 100644
--- a/tests/python-pytest/onnx/test_cases.py
+++ b/tests/python-pytest/onnx/test_cases.py
@@ -78,7 +78,8 @@
                'test_max_',
                'test_softplus',
                'test_reduce_',
-               'test_split_equal'
+               'test_split_equal',
+               'test_upsample_n'
                ],
     'import': ['test_gather',
                'test_softsign',
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index fc003b2271ef..aa871952fc59 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1479,16 +1479,38 @@ def test_deconvolution():
 
 
 def check_nearest_upsampling_with_shape(shapes, scale, root_scale):
+    def py_nearest_upsampling(x, scale):
+        # Reference implementation: repeat rows by the height scale and
+        # columns by the width scale.
+        batch, channel, input_height, input_width = x.shape
+        if not isinstance(scale, (list, tuple)):
+            output_height = input_height * scale
+            output_width = input_width * scale
+            row_ratio = col_ratio = scale
+        else:
+            if len(scale) == 1:
+                output_height = input_height * scale[0]
+                output_width = input_width * scale[0]
+                row_ratio = col_ratio = scale[0]
+            else:
+                output_height = input_height * scale[0]
+                output_width = input_width * scale[1]
+                row_ratio = scale[0]
+                col_ratio = scale[1]
+        if output_height == input_height and output_width == input_width:
+            return x
+        return x.repeat(row_ratio, axis=2).repeat(col_ratio, axis=3)
     arr = {'arg_%d'%i: mx.random.uniform(-10.0, 10.0, shape, ctx=mx.cpu()).copyto(default_context()) for i, shape in zip(range(len(shapes)), shapes)}
     arr_grad = {'arg_%d'%i: mx.nd.zeros(shape) for i, shape in zip(range(len(shapes)), shapes)}
     up = mx.sym.UpSampling(*[mx.sym.Variable('arg_%d'%i) for i in range(len(shapes))], sample_type='nearest', scale=root_scale)
     exe = up.bind(default_context(), args=arr, args_grad=arr_grad)
     exe.forward(is_train=True)
+    out = exe.outputs[0].asnumpy()
     exe.backward(exe.outputs)
     for k in range(len(shapes)):
         name = 'arg_%d'%k
-        assert_allclose(arr[name].asnumpy()*root_scale**2*scale**(2*k), arr_grad[name].asnumpy(), rtol=1e-4)
+        assert_allclose(out, py_nearest_upsampling(arr[name].asnumpy(), root_scale), rtol=1e-4)
 
 
 def check_bilinear_upsampling_with_shape(shapes, scale, root_scale):
@@ -1506,11 +1528,11 @@ def check_bilinear_upsampling_with_shape(shapes, scale, root_scale):
 
 @with_seed()
 def test_nearest_upsampling():
-    for root_scale in [1,2,3]:
-        for scale in [1,2,3]:
+    for root_scale in [2, (2, 3)]:
+        for scale in [2, 3]:
             for num_shape in [1,2,3]:
                 for base in [1,2,3]:
-                    shapes = [(1,3,base*root_scale*scale**(num_shape-1-i),base*root_scale*scale**(num_shape-1-i)) for i in range(num_shape)]
+                    shapes = [(1,3,10,10)]
                     check_nearest_upsampling_with_shape(shapes, scale, root_scale)
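End to end, the extended operator behaves as sketched below (assumes a build that includes this patch; nearest mode accepts a per-axis scale tuple, while bilinear mode still requires equal factors):

```python
import mxnet as mx

x = mx.nd.arange(4).reshape((1, 1, 2, 2))
y = mx.nd.UpSampling(x, sample_type='nearest', scale=(2, 3), num_args=1)
print(y.shape)  # (1, 1, 4, 6): height scaled by 2, width by 3
```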