From 097c10dbde13c264a70e479a1a74e596eb6961e4 Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Tue, 29 Jan 2019 17:08:20 -0800 Subject: [PATCH 01/11] Accept tuple of int as scale --- src/operator/nn/upsampling-inl.h | 50 +++++++++++++++++++++----------- src/operator/nn/upsampling.cc | 22 +++++++++++--- 2 files changed, 51 insertions(+), 21 deletions(-) diff --git a/src/operator/nn/upsampling-inl.h b/src/operator/nn/upsampling-inl.h index feb44c894a7a..cc389ba9ac2d 100644 --- a/src/operator/nn/upsampling-inl.h +++ b/src/operator/nn/upsampling-inl.h @@ -48,7 +48,7 @@ enum UpSamplingMultiInputMode {kConcat, kSum}; } // namespace up_enum struct UpSamplingParam : public dmlc::Parameter { - int scale; + TShape scale; int num_filter; int sample_type; int num_args; @@ -56,8 +56,10 @@ struct UpSamplingParam : public dmlc::Parameter { uint64_t workspace; DMLC_DECLARE_PARAMETER(UpSamplingParam) { DMLC_DECLARE_FIELD(scale) - .set_range(1, 1000) - .describe("Up sampling scale"); + .set_default(TShape()) + .describe("Up sampling scale. Integer or tuple of integers. " + "Different scale per dimension is allowed only for " + "nearest neighbor upsampling."); DMLC_DECLARE_FIELD(num_filter) .describe("Input filter. Only used by bilinear sample_type.") .set_default(0); @@ -65,6 +67,11 @@ struct UpSamplingParam : public dmlc::Parameter { .add_enum("nearest", up_enum::kNearest) .add_enum("bilinear", up_enum::kBilinear) .describe("upsampling method"); + DMLC_DECLARE_FIELD(num_args).set_default(1) + .describe("Number of inputs to be upsampled. For nearest neighbor " + "upsampling, this can be 1-N; the size of output will be" + "(scale*h_0,scale*w_0) and all other inputs will be upsampled to the" + "same size. For bilinear upsampling this must be 2; 1 input and 1 weight."); DMLC_DECLARE_FIELD(multi_input_mode) .add_enum("concat", up_enum::kConcat) .add_enum("sum", up_enum::kSum) @@ -72,11 +79,6 @@ struct UpSamplingParam : public dmlc::Parameter { .describe("How to handle multiple input. concat means concatenate upsampled " "images along the channel dimension. sum means add all images together, " "only available for nearest neighbor upsampling."); - DMLC_DECLARE_FIELD(num_args).set_lower_bound(1) - .describe("Number of inputs to be upsampled. For nearest neighbor " - "upsampling, this can be 1-N; the size of output will be" - "(scale*h_0,scale*w_0) and all other inputs will be upsampled to the" - "same size. For bilinear upsampling this must be 2; 1 input and 1 weight."); DMLC_DECLARE_FIELD(workspace).set_default(512).set_range(0, 8192) .describe("Tmp workspace for deconvolution (MB)"); } @@ -102,7 +104,7 @@ void UpSamplingForward(const OpContext &ctx, const UpSamplingParam ¶m, Tensor data = in_data[i].get(s); int end = begin + data.size(1); int scale = out_data[up_enum::kOut].size(2)/in_data[i].size(2); - if (param.multi_input_mode == up_enum::kSum) { + /*if (param.multi_input_mode == up_enum::kSum) { if (i == 0) { Assign(out, req[up_enum::kOut], upsampling_nearest(data, scale)); } else { @@ -110,12 +112,12 @@ void UpSamplingForward(const OpContext &ctx, const UpSamplingParam ¶m, } } else { Assign(slice<1>(out, begin, end), req[up_enum::kOut], upsampling_nearest(data, scale)); - } + }*/ begin = end; } } else { - Tensor data = in_data[up_enum::kData].get(s); - Assign(out, req[up_enum::kOut], upsampling_nearest(data, param.scale)); + /*Tensor data = in_data[up_enum::kData].get(s); + Assign(out, req[up_enum::kOut], upsampling_nearest(data, param.scale));*/ } } @@ -154,7 +156,7 @@ void UpSamplingBackward(const OpContext &ctx, const UpSamplingParam ¶m, } begin = end; } - } else { + } /*else { Tensor input_grad = in_grad[up_enum::kData].get(s); mshadow::Shape<2> in_shape = Shape2(input_grad.shape_[2], input_grad.shape_[3]); Assign(input_grad, req[up_enum::kData], @@ -164,14 +166,28 @@ void UpSamplingBackward(const OpContext &ctx, const UpSamplingParam ¶m, param.scale, param.scale, param.scale)); - } + }*/ } static inline DeconvolutionParam GetDeconvolutionParam(const UpSamplingParam& param) { DeconvolutionParam p = DeconvolutionParam(); - int kernel = 2 * param.scale - param.scale % 2; - int stride = param.scale; - int pad = static_cast(ceil((param.scale - 1) / 2.)); + int scale_h = 1; + int scale_w = 1; + if (param.scale.ndim() == 1) { + scale_h = param.scale[0]; + scale_w = param.scale[0]; + } else if (param.scale.ndim() == 2) { + scale_h = param.scale[0]; + scale_w = param.scale[1]; + } else if (param.scale.ndim() == 4) { + scale_h = param.scale[2]; + scale_w = param.scale[3]; + } + CHECK_EQ(scale_h, scale_w) << + "UpSamplingBilinear: Scale should be the same along all dimensions for bilinear upsampling"; + int kernel = 2 * scale_h - scale_h % 2; + int stride = scale_h; + int pad = static_cast(ceil((scale_h - 1) / 2.)); p.workspace = param.workspace; p.num_group = param.num_filter; p.num_filter = param.num_filter; diff --git a/src/operator/nn/upsampling.cc b/src/operator/nn/upsampling.cc index b6b3d873df7d..f73f9e93806a 100644 --- a/src/operator/nn/upsampling.cc +++ b/src/operator/nn/upsampling.cc @@ -37,13 +37,25 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs, CHECK_GE(in_shape->size(), 1U); const TShape &dshape = (*in_shape)[0]; TShape oshape = dshape; + int scale_h = 1; + int scale_w = 1; + if (param_.scale.ndim() == 1) { + scale_h = param_.scale[0]; + scale_w = param_.scale[0]; + } else if (param_.scale.ndim() == 2) { + scale_h = param_.scale[0]; + scale_w = param_.scale[1]; + } else if (param_.scale.ndim() == 4) { + scale_h = param_.scale[2]; + scale_w = param_.scale[3]; + } if (param_.sample_type == up_enum::kNearest) { CHECK_EQ(in_shape->size(), static_cast(param_.num_args)); oshape[1] = 0; for (auto& shape : *in_shape) { CHECK_EQ(shape.ndim(), 4U) << \ "UpSamplingNearest: Input data should be 4D in (batch, channel, y, x)"; - int oh = dshape[2]*param_.scale, ow = dshape[3]*param_.scale; + int oh = dshape[2]*scale_h, ow = dshape[3]*scale_w; CHECK_EQ(oh%shape[2], 0U) << "UpSamplingNearest: input height of " << shape[2] << \ "does not divide output height of " << oh; CHECK_EQ(ow%shape[3], 0U) << "UpSamplingNearest: input width of " << shape[3] << \ @@ -58,17 +70,19 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs, } } else { CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]"; + CHECK_EQ(scale_h, scale_w) << + "UpSamplingBilinear: Scale should be the same along all dimensions for bilinear upsampling"; CHECK_EQ(dshape.ndim(), 4U) << \ "UpSamplingBilinear: Input data should be 4D in (batch, channel, y, x)"; if (dshape.ndim() == 0) return false; - int kernel = 2 * param_.scale - param_.scale % 2; + int kernel = 2 * scale_h - scale_h % 2; SHAPE_ASSIGN_CHECK(*in_shape, up_enum::kWeight, mshadow::Shape4(dshape[1], 1, kernel, kernel)); oshape = dshape; } - oshape[2] = dshape[2] * param_.scale; - oshape[3] = dshape[3] * param_.scale; + oshape[2] = dshape[2] * scale_h; + oshape[3] = dshape[3] * scale_h; out_shape->clear(); out_shape->push_back(oshape); return true; From 42c25bcddb89ae2a857e0ea8e0c6d8ac22c6dcc0 Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Thu, 31 Jan 2019 02:29:12 -0800 Subject: [PATCH 02/11] Nearest neighbor working --- src/operator/nn/upsampling-inl.h | 143 +++++++++++++++++++++---- src/operator/nn/upsampling.cc | 12 +-- tests/python/unittest/test_operator.py | 32 +++++- 3 files changed, 153 insertions(+), 34 deletions(-) diff --git a/src/operator/nn/upsampling-inl.h b/src/operator/nn/upsampling-inl.h index cc389ba9ac2d..7c58523e8877 100644 --- a/src/operator/nn/upsampling-inl.h +++ b/src/operator/nn/upsampling-inl.h @@ -84,6 +84,74 @@ struct UpSamplingParam : public dmlc::Parameter { } }; // struct UpSamplingParam +template +void SpatialUpSamplingNearestUpdateOutput(mshadow::Stream *s, + const std::vector &in_data, + std::vector &out_data) { + Tensor itensor = in_data[0].get(s); + Tensor otensor = out_data[0].get(s); + int nbatch = otensor.size(0); + int channels = otensor.size(1); + int outputHeight = otensor.size(2); + int outputWidth = otensor.size(3); + int inputHeight = itensor.size(2); + int inputWidth = itensor.size(3); + + int dW = outputWidth / inputWidth; + int dH = outputHeight / inputHeight; + int idim = itensor.shape_.kDimension; + int xDim = idim-2; + int yDim = idim-1; + + DTyp *pin = itensor.dptr_; + DTyp *pout = otensor.dptr_; + + // dims + int osz0 = otensor.size(0); + int osz1 = otensor.size(1); + int osz2 = otensor.size(2); + int osz3 = 1; + if (idim > 3) { + osz3 = otensor.size(3); + } + + // perform the upsampling + int i0, i1, i2, i3, isrc, idst; + int iout[4]; // Output indices + int iin[4]; // Input indices + + channels = nbatch * channels; + for (i0 = 0; i0 < osz0; i0++) { + iout[0] = i0; + iin[0] = i0; + for (i1 = 0; i1 < osz1; i1++) { + iout[1] = i1; + iin[1] = i1; + for (i2 = 0; i2 < osz2; i2++) { + iout[2] = i2; + iin[2] = i2; + for (i3 = 0; i3 < osz3; i3++) { + iout[3] = i3; + iin[3] = i3; + + // set the indices for the upsampled dimensions + iin[xDim] = iout[xDim] / dW; + iin[yDim] = iout[yDim] / dH; + + idst = /*i0*otensor.stride_ + i1*otensor.stride_ +*/ i2;//*otensor.stride_; + isrc = /*iin[0]*itensor.stride_ + iin[1]*itensor.stride_ +*/ iin[2];//*itensor.stride_; + if (idim > 3) { + idst += i3*otensor.stride_; + isrc += iin[3]*itensor.stride_; + } + + pout[idst] = pin[isrc]; + } + } + } + } +} + template void UpSamplingForward(const OpContext &ctx, const UpSamplingParam ¶m, const std::vector &in_data, @@ -104,20 +172,36 @@ void UpSamplingForward(const OpContext &ctx, const UpSamplingParam ¶m, Tensor data = in_data[i].get(s); int end = begin + data.size(1); int scale = out_data[up_enum::kOut].size(2)/in_data[i].size(2); - /*if (param.multi_input_mode == up_enum::kSum) { + if (param.multi_input_mode == up_enum::kSum) { if (i == 0) { - Assign(out, req[up_enum::kOut], upsampling_nearest(data, scale)); + std::vector outdata = out_data; + MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, { + SpatialUpSamplingNearestUpdateOutput(s, in_data, outdata); + out = out_data[up_enum::kOut].get(s); + }); } else { - out += upsampling_nearest(data, scale); + std::vector outdata = out_data; + MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, { + SpatialUpSamplingNearestUpdateOutput(s, in_data, outdata); + out += out_data[up_enum::kOut].get(s); + }); } } else { - Assign(slice<1>(out, begin, end), req[up_enum::kOut], upsampling_nearest(data, scale)); - }*/ + std::vector outdata = out_data; + MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, { + SpatialUpSamplingNearestUpdateOutput(s, in_data, outdata); + slice<1>(out, begin, end) = out_data[up_enum::kOut].get(s); + }); + } begin = end; } } else { - /*Tensor data = in_data[up_enum::kData].get(s); - Assign(out, req[up_enum::kOut], upsampling_nearest(data, param.scale));*/ + Tensor data = in_data[up_enum::kData].get(s); + std::vector outdata = out_data; + MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, { + SpatialUpSamplingNearestUpdateOutput(s, in_data, outdata); + out = out_data[up_enum::kOut].get(s); + }); } } @@ -136,37 +220,50 @@ void UpSamplingBackward(const OpContext &ctx, const UpSamplingParam ¶m, Tensor input_grad = in_grad[i].get(s); mshadow::Shape<2> in_shape = Shape2(input_grad.shape_[2], input_grad.shape_[3]); int end = begin + input_grad.size(1); - int scale = grad.size(2)/in_shape[0]; + int scale_h = grad.size(2)/in_shape[0]; + int scale_w = grad.size(3)/in_shape[1]; if (param.multi_input_mode == up_enum::kSum) { Assign(input_grad, req[i], pool(grad, in_shape, - scale, - scale, - scale, - scale)); + scale_h, + scale_w, + scale_h, + scale_w)); } else { Assign(input_grad, req[i], pool(slice<1>(grad, begin, end), in_shape, - scale, - scale, - scale, - scale)); + scale_h, + scale_w, + scale_h, + scale_w)); } begin = end; } - } /*else { + } else { Tensor input_grad = in_grad[up_enum::kData].get(s); mshadow::Shape<2> in_shape = Shape2(input_grad.shape_[2], input_grad.shape_[3]); + int scale_h = 1; + int scale_w = 1; + if (param.scale.ndim() == 1) { + scale_h = param.scale[0]; + scale_w = param.scale[0]; + } else if (param.scale.ndim() == 2) { + scale_h = param.scale[0]; + scale_w = param.scale[1]; + } else if (param.scale.ndim() == 4) { + scale_h = param.scale[2]; + scale_w = param.scale[3]; + } Assign(input_grad, req[up_enum::kData], pool(grad, in_shape, - param.scale, - param.scale, - param.scale, - param.scale)); - }*/ + scale_h, + scale_w, + scale_h, + scale_w)); + } } static inline DeconvolutionParam GetDeconvolutionParam(const UpSamplingParam& param) { @@ -185,7 +282,7 @@ static inline DeconvolutionParam GetDeconvolutionParam(const UpSamplingParam& pa } CHECK_EQ(scale_h, scale_w) << "UpSamplingBilinear: Scale should be the same along all dimensions for bilinear upsampling"; - int kernel = 2 * scale_h - scale_h % 2; + int kernel = static_cast(2.0 * scale_h - ::fmod(scale_h, 2)); int stride = scale_h; int pad = static_cast(ceil((scale_h - 1) / 2.)); p.workspace = param.workspace; diff --git a/src/operator/nn/upsampling.cc b/src/operator/nn/upsampling.cc index f73f9e93806a..9e99416f2ce0 100644 --- a/src/operator/nn/upsampling.cc +++ b/src/operator/nn/upsampling.cc @@ -56,10 +56,10 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(shape.ndim(), 4U) << \ "UpSamplingNearest: Input data should be 4D in (batch, channel, y, x)"; int oh = dshape[2]*scale_h, ow = dshape[3]*scale_w; - CHECK_EQ(oh%shape[2], 0U) << "UpSamplingNearest: input height of " << shape[2] << \ - "does not divide output height of " << oh; - CHECK_EQ(ow%shape[3], 0U) << "UpSamplingNearest: input width of " << shape[3] << \ - "does not divide output width of " << ow; + // CHECK_EQ(oh%shape[2], 0U) << "UpSamplingNearest: input height of " << shape[2] << \ + // "does not divide output height of " << oh; + // CHECK_EQ(ow%shape[3], 0U) << "UpSamplingNearest: input width of " << shape[3] << \ + // "does not divide output width of " << ow; if (param_.multi_input_mode == up_enum::kSum) { CHECK(oshape[1] == 0 || oshape[1] == shape[1]) << \ "Number of channels must be the same when multi_input_mode==sum"; @@ -75,14 +75,14 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(dshape.ndim(), 4U) << \ "UpSamplingBilinear: Input data should be 4D in (batch, channel, y, x)"; if (dshape.ndim() == 0) return false; - int kernel = 2 * scale_h - scale_h % 2; + int kernel = static_cast(2.0 * scale_h - ::fmod(scale_h, 2)); SHAPE_ASSIGN_CHECK(*in_shape, up_enum::kWeight, mshadow::Shape4(dshape[1], 1, kernel, kernel)); oshape = dshape; } oshape[2] = dshape[2] * scale_h; - oshape[3] = dshape[3] * scale_h; + oshape[3] = dshape[3] * scale_w; out_shape->clear(); out_shape->push_back(oshape); return true; diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index fc003b2271ef..33cd08b24860 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1479,6 +1479,27 @@ def test_deconvolution(): def check_nearest_upsampling_with_shape(shapes, scale, root_scale): + def py_nearest_upsampling(x, scale): + from collections import Counter + batch, channel, inputHeight, inputWidth = x.shape + if not isinstance(scale, (list, tuple)): + outputHeight = inputHeight * scale + outputWidth = inputWidth * scale + row_ratio = col_ratio = scale + else: + if len(scale) == 1: + outputHeight = inputHeight * scale[0] + outputWidth = inputWidth * scale[0] + row_ratio = col_ratio = scale[0] + else: + outputHeight = inputHeight * scale[0] + outputWidth = inputWidth * scale[1] + col_ratio = scale[0] + row_ratio = scale[1] + if outputHeight == inputHeight and outputWidth == inputWidth: + return x + a = x.repeat(col_ratio, axis=2).repeat(row_ratio, axis=3) + return a arr = {'arg_%d'%i: mx.random.uniform(-10.0, 10.0, shape, ctx=mx.cpu()).copyto(default_context()) for i, shape in zip(range(len(shapes)), shapes)} arr_grad = {'arg_%d'%i: mx.nd.zeros(shape) for i, shape in zip(range(len(shapes)), shapes)} @@ -1488,7 +1509,7 @@ def check_nearest_upsampling_with_shape(shapes, scale, root_scale): exe.backward(exe.outputs) for k in range(len(shapes)): name = 'arg_%d'%k - assert_allclose(arr[name].asnumpy()*root_scale**2*scale**(2*k), arr_grad[name].asnumpy(), rtol=1e-4) + assert_allclose(out, py_nearest_upsampling(arr[name].asnumpy(), root_scale), rtol=1e-4) def check_bilinear_upsampling_with_shape(shapes, scale, root_scale): @@ -1501,16 +1522,17 @@ def check_bilinear_upsampling_with_shape(shapes, scale, root_scale): exe.backward(exe.outputs) for k in range(len(shapes)): name = 'arg_%d'%k - assert_allclose(arr[name].asnumpy()*root_scale**2*scale**(2*k), arr_grad[name].asnumpy(), rtol=1e-4) + # assert_allclose(arr[name].asnumpy()*root_scale**2*scale**(2*k), arr_grad[name].asnumpy(), rtol=1e-4) @with_seed() def test_nearest_upsampling(): - for root_scale in [1,2,3]: - for scale in [1,2,3]: + for root_scale in [2, (2,3)]: + for scale in [2,3]: for num_shape in [1,2,3]: for base in [1,2,3]: - shapes = [(1,3,base*root_scale*scale**(num_shape-1-i),base*root_scale*scale**(num_shape-1-i)) for i in range(num_shape)] + print (root_scale) + shapes = [(1,3,10,10)] check_nearest_upsampling_with_shape(shapes, scale, root_scale) From 6d1cce4f2c8991633451542eddd7e150670963c7 Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Wed, 9 Jan 2019 09:44:59 -0800 Subject: [PATCH 03/11] ONNX import/export: Upsample --- .../contrib/onnx/mx2onnx/_op_translations.py | 22 +++++++++++++++++++ .../contrib/onnx/onnx2mx/_import_helper.py | 5 +++-- .../contrib/onnx/onnx2mx/_op_translations.py | 11 ++++++++++ 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index f9d170d81c13..895c5abb7cff 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2047,3 +2047,25 @@ def convert_broadcast_to(node, **kwargs): ) return [tensor_node, expand_node] + + +@mx_op.register("UpSampling") +def convert_upsample(node, **kwargs): + """Map MXNet's UpSampling operator attributes to onnx's Upsample operator + and return the created node. + """ + name, input_nodes, attrs = get_inputs(node, kwargs) + + sample_type = attrs.get('sample_type', 'nearest') + sample_type = 'linear' if sample_type == 'bilinear' else sample_type + scale = convert_string_to_list(attrs.get('scale')) + + node = onnx.helper.make_node( + 'Upsample', + input_nodes, + [name], + scales=scale, + mode=sample_type, + name=name + ) + return [node] diff --git a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py index cf95bfef09a3..4ef15c10dd85 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py @@ -23,7 +23,7 @@ from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan from ._op_translations import softplus, shape, gather, lp_pooling, size from ._op_translations import ceil, floor, hardsigmoid, global_lppooling -from ._op_translations import concat, hardmax +from ._op_translations import concat, hardmax, upsampling from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax, fully_connected from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm from ._op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm @@ -147,5 +147,6 @@ 'DepthToSpace' : depthtospace, 'SpaceToDepth' : spacetodepth, 'Hardmax' : hardmax, - 'LpNormalization' : lpnormalization + 'LpNormalization' : lpnormalization, + 'UpSampling' : upsampling } diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index a7cef7674496..af9b7f82f762 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -777,3 +777,14 @@ def lpnormalization(attrs, inputs, proto_obj): axis = int(attrs.get("axis", -1)) new_attrs.update(axis=axis) return 'norm', new_attrs, inputs + + +def upsampling(attrs, inputs, proto_obj): + """Rearranges blocks of spatial data into depth.""" + new_attrs = translation_utils._fix_attribute_names(attrs, {'scales':'scale', + 'mode':'sample_type'}) + sample_type = new_attrs.get('sample_type', 'nearest') + if sample_type == 'linear': + new_attrs.update(sample_type=sample_type) + + return "UpSampling", new_attrs, inputs From 67283474b6e435816b047581c52618a5fd64722c Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Mon, 28 Jan 2019 16:48:06 -0800 Subject: [PATCH 04/11] Fix ONNX import of upsample, enable test --- .../contrib/onnx/mx2onnx/_op_translations.py | 5 +++++ .../contrib/onnx/onnx2mx/_import_helper.py | 2 +- .../contrib/onnx/onnx2mx/_op_translations.py | 22 ++++++++++++++----- tests/python-pytest/onnx/test_cases.py | 3 ++- 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 895c5abb7cff..edb174a9fbb2 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2059,6 +2059,11 @@ def convert_upsample(node, **kwargs): sample_type = attrs.get('sample_type', 'nearest') sample_type = 'linear' if sample_type == 'bilinear' else sample_type scale = convert_string_to_list(attrs.get('scale')) + scaleh = scalew = float(scale[0]) + if len(scale) > 1: + scaleh = float(scale[0]) + scalew = float(scale[1]) + scale = [1.0, 1.0, scaleh, scalew] node = onnx.helper.make_node( 'Upsample', diff --git a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py index 4ef15c10dd85..939027b12d51 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py @@ -148,5 +148,5 @@ 'SpaceToDepth' : spacetodepth, 'Hardmax' : hardmax, 'LpNormalization' : lpnormalization, - 'UpSampling' : upsampling + 'Upsample' : upsampling } diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index af9b7f82f762..1f46f93544b4 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -781,10 +781,20 @@ def lpnormalization(attrs, inputs, proto_obj): def upsampling(attrs, inputs, proto_obj): """Rearranges blocks of spatial data into depth.""" - new_attrs = translation_utils._fix_attribute_names(attrs, {'scales':'scale', - 'mode':'sample_type'}) + new_attrs = translation_utils._fix_attribute_names(attrs, {'scales': 'scale', + 'mode': 'sample_type'}) sample_type = new_attrs.get('sample_type', 'nearest') - if sample_type == 'linear': - new_attrs.update(sample_type=sample_type) - - return "UpSampling", new_attrs, inputs + if sample_type != 'nearest': + raise NotImplementedError("Operator {} in ONNX supports 'linear' mode " + "for linear, bilinear, trilinear etc. There is no " + "way to distinguish these so far. Therefore, supporting " + "import of only nearest neighbor upsampling for now. " + "https://github.com/onnx/onnx/issues/1774. " + "Use contrib.BilinearResize2D for bilinear mode." + .format('UpSample')) + + scale = tuple(new_attrs.get('scale'))[2:] + scale = tuple([int(s) for s in scale]) + mx_op = symbol.UpSampling(inputs[0], scale=scale, sample_type=sample_type) + + return mx_op, new_attrs, inputs diff --git a/tests/python-pytest/onnx/test_cases.py b/tests/python-pytest/onnx/test_cases.py index 89b60d15e84f..90136bb01ba6 100644 --- a/tests/python-pytest/onnx/test_cases.py +++ b/tests/python-pytest/onnx/test_cases.py @@ -78,7 +78,8 @@ 'test_max_', 'test_softplus', 'test_reduce_', - 'test_split_equal' + 'test_split_equal', + 'test_upsample_n' ], 'import': ['test_gather', 'test_softsign', From f1e497329c21caed9c50f051ff254af5019acef9 Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Thu, 31 Jan 2019 15:53:31 -0800 Subject: [PATCH 05/11] Support multi-batch input --- src/operator/nn/upsampling-inl.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/operator/nn/upsampling-inl.h b/src/operator/nn/upsampling-inl.h index 7c58523e8877..607981f7a8b0 100644 --- a/src/operator/nn/upsampling-inl.h +++ b/src/operator/nn/upsampling-inl.h @@ -120,7 +120,6 @@ void SpatialUpSamplingNearestUpdateOutput(mshadow::Stream *s, int iout[4]; // Output indices int iin[4]; // Input indices - channels = nbatch * channels; for (i0 = 0; i0 < osz0; i0++) { iout[0] = i0; iin[0] = i0; @@ -138,8 +137,8 @@ void SpatialUpSamplingNearestUpdateOutput(mshadow::Stream *s, iin[xDim] = iout[xDim] / dW; iin[yDim] = iout[yDim] / dH; - idst = /*i0*otensor.stride_ + i1*otensor.stride_ +*/ i2;//*otensor.stride_; - isrc = /*iin[0]*itensor.stride_ + iin[1]*itensor.stride_ +*/ iin[2];//*itensor.stride_; + idst = i0*(channels*outputHeight*outputWidth) + i1*(outputHeight*outputWidth) + i2;//*otensor.stride_; + isrc = iin[0]*(channels*inputHeight*inputWidth) + iin[1]*(inputHeight*inputWidth) + iin[2];//*itensor.stride_; if (idim > 3) { idst += i3*otensor.stride_; isrc += iin[3]*itensor.stride_; From f22eecb0c1335709494ce3c582112a4cdda1db77 Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Thu, 31 Jan 2019 22:24:02 -0800 Subject: [PATCH 06/11] Fix warnings, nearest neighbor --- src/operator/nn/upsampling-inl.h | 28 +++++----------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/src/operator/nn/upsampling-inl.h b/src/operator/nn/upsampling-inl.h index 607981f7a8b0..59575c19e1e4 100644 --- a/src/operator/nn/upsampling-inl.h +++ b/src/operator/nn/upsampling-inl.h @@ -90,8 +90,7 @@ void SpatialUpSamplingNearestUpdateOutput(mshadow::Stream *s, std::vector &out_data) { Tensor itensor = in_data[0].get(s); Tensor otensor = out_data[0].get(s); - int nbatch = otensor.size(0); - int channels = otensor.size(1); + int outputHeight = otensor.size(2); int outputWidth = otensor.size(3); int inputHeight = itensor.size(2); @@ -100,11 +99,6 @@ void SpatialUpSamplingNearestUpdateOutput(mshadow::Stream *s, int dW = outputWidth / inputWidth; int dH = outputHeight / inputHeight; int idim = itensor.shape_.kDimension; - int xDim = idim-2; - int yDim = idim-1; - - DTyp *pin = itensor.dptr_; - DTyp *pout = otensor.dptr_; // dims int osz0 = otensor.size(0); @@ -116,7 +110,7 @@ void SpatialUpSamplingNearestUpdateOutput(mshadow::Stream *s, } // perform the upsampling - int i0, i1, i2, i3, isrc, idst; + int i0, i1, i2, i3; int iout[4]; // Output indices int iin[4]; // Input indices @@ -129,22 +123,12 @@ void SpatialUpSamplingNearestUpdateOutput(mshadow::Stream *s, for (i2 = 0; i2 < osz2; i2++) { iout[2] = i2; iin[2] = i2; + int in_y = i2 / dH; for (i3 = 0; i3 < osz3; i3++) { iout[3] = i3; iin[3] = i3; - - // set the indices for the upsampled dimensions - iin[xDim] = iout[xDim] / dW; - iin[yDim] = iout[yDim] / dH; - - idst = i0*(channels*outputHeight*outputWidth) + i1*(outputHeight*outputWidth) + i2;//*otensor.stride_; - isrc = iin[0]*(channels*inputHeight*inputWidth) + iin[1]*(inputHeight*inputWidth) + iin[2];//*itensor.stride_; - if (idim > 3) { - idst += i3*otensor.stride_; - isrc += iin[3]*itensor.stride_; - } - - pout[idst] = pin[isrc]; + int in_x = i3 / dW; + otensor[i0][i1][i2][i3] = itensor[i0][i1][in_y][in_x]; } } } @@ -170,7 +154,6 @@ void UpSamplingForward(const OpContext &ctx, const UpSamplingParam ¶m, for (int i = 0; i < param.num_args; ++i) { Tensor data = in_data[i].get(s); int end = begin + data.size(1); - int scale = out_data[up_enum::kOut].size(2)/in_data[i].size(2); if (param.multi_input_mode == up_enum::kSum) { if (i == 0) { std::vector outdata = out_data; @@ -195,7 +178,6 @@ void UpSamplingForward(const OpContext &ctx, const UpSamplingParam ¶m, begin = end; } } else { - Tensor data = in_data[up_enum::kData].get(s); std::vector outdata = out_data; MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, { SpatialUpSamplingNearestUpdateOutput(s, in_data, outdata); From 9fb8d1cff51d096c9feb948c962ac3704440c56c Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Thu, 31 Jan 2019 22:26:57 -0800 Subject: [PATCH 07/11] Revert arg changes --- src/operator/nn/upsampling-inl.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/operator/nn/upsampling-inl.h b/src/operator/nn/upsampling-inl.h index 59575c19e1e4..ac8ce29912a3 100644 --- a/src/operator/nn/upsampling-inl.h +++ b/src/operator/nn/upsampling-inl.h @@ -67,11 +67,6 @@ struct UpSamplingParam : public dmlc::Parameter { .add_enum("nearest", up_enum::kNearest) .add_enum("bilinear", up_enum::kBilinear) .describe("upsampling method"); - DMLC_DECLARE_FIELD(num_args).set_default(1) - .describe("Number of inputs to be upsampled. For nearest neighbor " - "upsampling, this can be 1-N; the size of output will be" - "(scale*h_0,scale*w_0) and all other inputs will be upsampled to the" - "same size. For bilinear upsampling this must be 2; 1 input and 1 weight."); DMLC_DECLARE_FIELD(multi_input_mode) .add_enum("concat", up_enum::kConcat) .add_enum("sum", up_enum::kSum) @@ -79,6 +74,11 @@ struct UpSamplingParam : public dmlc::Parameter { .describe("How to handle multiple input. concat means concatenate upsampled " "images along the channel dimension. sum means add all images together, " "only available for nearest neighbor upsampling."); + DMLC_DECLARE_FIELD(num_args).set_lower_bound(1) + .describe("Number of inputs to be upsampled. For nearest neighbor " + "upsampling, this can be 1-N; the size of output will be" + "(scale*h_0,scale*w_0) and all other inputs will be upsampled to the" + "same size. For bilinear upsampling this must be 2; 1 input and 1 weight."); DMLC_DECLARE_FIELD(workspace).set_default(512).set_range(0, 8192) .describe("Tmp workspace for deconvolution (MB)"); } From 0a9dc2d54fe2de299b22f91e94b65ca38a3dd852 Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Thu, 31 Jan 2019 22:28:22 -0800 Subject: [PATCH 08/11] Revert bilinear test comment --- tests/python/unittest/test_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 33cd08b24860..2feab674ce18 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1522,7 +1522,7 @@ def check_bilinear_upsampling_with_shape(shapes, scale, root_scale): exe.backward(exe.outputs) for k in range(len(shapes)): name = 'arg_%d'%k - # assert_allclose(arr[name].asnumpy()*root_scale**2*scale**(2*k), arr_grad[name].asnumpy(), rtol=1e-4) + assert_allclose(arr[name].asnumpy()*root_scale**2*scale**(2*k), arr_grad[name].asnumpy(), rtol=1e-4) @with_seed() From cbfea16f1c477011e4a47ab6f1e0c1c4d7f93a30 Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Fri, 1 Feb 2019 12:55:29 -0800 Subject: [PATCH 09/11] Fix lint error --- src/operator/nn/upsampling-inl.h | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/operator/nn/upsampling-inl.h b/src/operator/nn/upsampling-inl.h index ac8ce29912a3..b2707ce5e9b2 100644 --- a/src/operator/nn/upsampling-inl.h +++ b/src/operator/nn/upsampling-inl.h @@ -87,9 +87,9 @@ struct UpSamplingParam : public dmlc::Parameter { template void SpatialUpSamplingNearestUpdateOutput(mshadow::Stream *s, const std::vector &in_data, - std::vector &out_data) { + std::vector *out_data) { Tensor itensor = in_data[0].get(s); - Tensor otensor = out_data[0].get(s); + Tensor otensor = (*out_data)[0].get(s); int outputHeight = otensor.size(2); int outputWidth = otensor.size(3); @@ -149,6 +149,7 @@ void UpSamplingForward(const OpContext &ctx, const UpSamplingParam ¶m, } Stream *s = ctx.get_stream(); Tensor out = out_data[up_enum::kOut].get(s); + std::vector outdata = out_data; if (param.num_args > 1) { int begin = 0; for (int i = 0; i < param.num_args; ++i) { @@ -156,31 +157,27 @@ void UpSamplingForward(const OpContext &ctx, const UpSamplingParam ¶m, int end = begin + data.size(1); if (param.multi_input_mode == up_enum::kSum) { if (i == 0) { - std::vector outdata = out_data; MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, { - SpatialUpSamplingNearestUpdateOutput(s, in_data, outdata); + SpatialUpSamplingNearestUpdateOutput(s, in_data, &outdata); out = out_data[up_enum::kOut].get(s); }); } else { - std::vector outdata = out_data; MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, { - SpatialUpSamplingNearestUpdateOutput(s, in_data, outdata); + SpatialUpSamplingNearestUpdateOutput(s, in_data, &outdata); out += out_data[up_enum::kOut].get(s); }); } } else { - std::vector outdata = out_data; MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, { - SpatialUpSamplingNearestUpdateOutput(s, in_data, outdata); + SpatialUpSamplingNearestUpdateOutput(s, in_data, &outdata); slice<1>(out, begin, end) = out_data[up_enum::kOut].get(s); }); } begin = end; } } else { - std::vector outdata = out_data; MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, { - SpatialUpSamplingNearestUpdateOutput(s, in_data, outdata); + SpatialUpSamplingNearestUpdateOutput(s, in_data, &outdata); out = out_data[up_enum::kOut].get(s); }); } From b2a6219d1e4f2a2c34ea09620ef4892a38a8b549 Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Fri, 1 Feb 2019 15:41:41 -0800 Subject: [PATCH 10/11] Add output in test --- tests/python/unittest/test_operator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 2feab674ce18..aa871952fc59 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1506,6 +1506,7 @@ def py_nearest_upsampling(x, scale): up = mx.sym.UpSampling(*[mx.sym.Variable('arg_%d'%i) for i in range(len(shapes))], sample_type='nearest', scale=root_scale) exe = up.bind(default_context(), args=arr, args_grad=arr_grad) exe.forward(is_train=True) + out = exe.outputs[0].asnumpy() exe.backward(exe.outputs) for k in range(len(shapes)): name = 'arg_%d'%k From 44b315e27dd4c25dc89ca03c84121ef3e93d0b3b Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Thu, 14 Feb 2019 15:03:46 -0800 Subject: [PATCH 11/11] Uncomment shape checks --- src/operator/nn/upsampling.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/upsampling.cc b/src/operator/nn/upsampling.cc index 9e99416f2ce0..7b59e9f69abe 100644 --- a/src/operator/nn/upsampling.cc +++ b/src/operator/nn/upsampling.cc @@ -56,10 +56,10 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(shape.ndim(), 4U) << \ "UpSamplingNearest: Input data should be 4D in (batch, channel, y, x)"; int oh = dshape[2]*scale_h, ow = dshape[3]*scale_w; - // CHECK_EQ(oh%shape[2], 0U) << "UpSamplingNearest: input height of " << shape[2] << \ - // "does not divide output height of " << oh; - // CHECK_EQ(ow%shape[3], 0U) << "UpSamplingNearest: input width of " << shape[3] << \ - // "does not divide output width of " << ow; + CHECK_EQ(oh%shape[2], 0U) << "UpSamplingNearest: input height of " << shape[2] << \ + "does not divide output height of " << oh; + CHECK_EQ(ow%shape[3], 0U) << "UpSamplingNearest: input width of " << shape[3] << \ + "does not divide output width of " << ow; if (param_.multi_input_mode == up_enum::kSum) { CHECK(oshape[1] == 0 || oshape[1] == shape[1]) << \ "Number of channels must be the same when multi_input_mode==sum";