Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-891] Support tuple of scales in upsample operator #14042

Closed
wants to merge 11 commits into from
27 changes: 27 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Expand Up @@ -2047,3 +2047,30 @@ def convert_broadcast_to(node, **kwargs):
)

return [tensor_node, expand_node]


@mx_op.register("UpSampling")
def convert_upsample(node, **kwargs):
"""Map MXNet's UpSampling operator attributes to onnx's Upsample operator
and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

sample_type = attrs.get('sample_type', 'nearest')
sample_type = 'linear' if sample_type == 'bilinear' else sample_type
scale = convert_string_to_list(attrs.get('scale'))
scaleh = scalew = float(scale[0])
if len(scale) > 1:
scaleh = float(scale[0])
scalew = float(scale[1])
scale = [1.0, 1.0, scaleh, scalew]

node = onnx.helper.make_node(
'Upsample',
input_nodes,
[name],
scales=scale,
mode=sample_type,
name=name
)
return [node]
5 changes: 3 additions & 2 deletions python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
Expand Up @@ -23,7 +23,7 @@
from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
from ._op_translations import softplus, shape, gather, lp_pooling, size
from ._op_translations import ceil, floor, hardsigmoid, global_lppooling
from ._op_translations import concat, hardmax
from ._op_translations import concat, hardmax, upsampling
from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax, fully_connected
from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm
from ._op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm
Expand Down Expand Up @@ -147,5 +147,6 @@
'DepthToSpace' : depthtospace,
'SpaceToDepth' : spacetodepth,
'Hardmax' : hardmax,
'LpNormalization' : lpnormalization
'LpNormalization' : lpnormalization,
'Upsample' : upsampling
}
21 changes: 21 additions & 0 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Expand Up @@ -777,3 +777,24 @@ def lpnormalization(attrs, inputs, proto_obj):
axis = int(attrs.get("axis", -1))
new_attrs.update(axis=axis)
return 'norm', new_attrs, inputs


def upsampling(attrs, inputs, proto_obj):
"""Rearranges blocks of spatial data into depth."""
new_attrs = translation_utils._fix_attribute_names(attrs, {'scales': 'scale',
'mode': 'sample_type'})
sample_type = new_attrs.get('sample_type', 'nearest')
if sample_type != 'nearest':
raise NotImplementedError("Operator {} in ONNX supports 'linear' mode "
"for linear, bilinear, trilinear etc. There is no "
"way to distinguish these so far. Therefore, supporting "
"import of only nearest neighbor upsampling for now. "
"https://github.com/onnx/onnx/issues/1774. "
"Use contrib.BilinearResize2D for bilinear mode."
.format('UpSample'))

scale = tuple(new_attrs.get('scale'))[2:]
scale = tuple([int(s) for s in scale])
mx_op = symbol.UpSampling(inputs[0], scale=scale, sample_type=sample_type)

return mx_op, new_attrs, inputs
141 changes: 116 additions & 25 deletions src/operator/nn/upsampling-inl.h
Expand Up @@ -48,16 +48,18 @@ enum UpSamplingMultiInputMode {kConcat, kSum};
} // namespace up_enum

struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
int scale;
TShape scale;
int num_filter;
int sample_type;
int num_args;
int multi_input_mode;
uint64_t workspace;
DMLC_DECLARE_PARAMETER(UpSamplingParam) {
DMLC_DECLARE_FIELD(scale)
.set_range(1, 1000)
.describe("Up sampling scale");
.set_default(TShape())
.describe("Up sampling scale. Integer or tuple of integers. "
"Different scale per dimension is allowed only for "
"nearest neighbor upsampling.");
DMLC_DECLARE_FIELD(num_filter)
.describe("Input filter. Only used by bilinear sample_type.")
.set_default(0);
Expand All @@ -82,6 +84,57 @@ struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
}
}; // struct UpSamplingParam

template<typename xpu, typename DTyp, typename AccReal>
void SpatialUpSamplingNearestUpdateOutput(mshadow::Stream<cpu> *s,
const std::vector<TBlob> &in_data,
std::vector<TBlob> *out_data) {
Tensor<xpu, 4, DTyp> itensor = in_data[0].get<xpu, 4, DTyp>(s);
Tensor<xpu, 4, DTyp> otensor = (*out_data)[0].get<xpu, 4, DTyp>(s);

int outputHeight = otensor.size(2);
int outputWidth = otensor.size(3);
int inputHeight = itensor.size(2);
int inputWidth = itensor.size(3);

int dW = outputWidth / inputWidth;
int dH = outputHeight / inputHeight;
int idim = itensor.shape_.kDimension;

// dims
int osz0 = otensor.size(0);
int osz1 = otensor.size(1);
int osz2 = otensor.size(2);
int osz3 = 1;
if (idim > 3) {
osz3 = otensor.size(3);
}

// perform the upsampling
int i0, i1, i2, i3;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about using index_t as datatype for all of them
Since for large operator support it was found to be useful -#13418

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll make this change

int iout[4]; // Output indices
int iin[4]; // Input indices

for (i0 = 0; i0 < osz0; i0++) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this nested for loop be vectorized?

iout[0] = i0;
iin[0] = i0;
for (i1 = 0; i1 < osz1; i1++) {
iout[1] = i1;
iin[1] = i1;
for (i2 = 0; i2 < osz2; i2++) {
iout[2] = i2;
iin[2] = i2;
int in_y = i2 / dH;
for (i3 = 0; i3 < osz3; i3++) {
iout[3] = i3;
iin[3] = i3;
int in_x = i3 / dW;
otensor[i0][i1][i2][i3] = itensor[i0][i1][in_y][in_x];
}
}
}
}
}

template<typename xpu, typename DType>
void UpSamplingForward(const OpContext &ctx, const UpSamplingParam &param,
const std::vector<TBlob> &in_data,
Expand All @@ -96,26 +149,37 @@ void UpSamplingForward(const OpContext &ctx, const UpSamplingParam &param,
}
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 4, DType> out = out_data[up_enum::kOut].get<xpu, 4, DType>(s);
std::vector<TBlob> outdata = out_data;
if (param.num_args > 1) {
int begin = 0;
for (int i = 0; i < param.num_args; ++i) {
Tensor<xpu, 4, DType> data = in_data[i].get<xpu, 4, DType>(s);
int end = begin + data.size(1);
int scale = out_data[up_enum::kOut].size(2)/in_data[i].size(2);
if (param.multi_input_mode == up_enum::kSum) {
if (i == 0) {
Assign(out, req[up_enum::kOut], upsampling_nearest(data, scale));
MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, {
SpatialUpSamplingNearestUpdateOutput<xpu, DTyp, AccReal>(s, in_data, &outdata);
out = out_data[up_enum::kOut].get<xpu, 4, DType>(s);
});
} else {
out += upsampling_nearest(data, scale);
MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, {
SpatialUpSamplingNearestUpdateOutput<xpu, DTyp, AccReal>(s, in_data, &outdata);
out += out_data[up_enum::kOut].get<xpu, 4, DType>(s);
});
}
} else {
Assign(slice<1>(out, begin, end), req[up_enum::kOut], upsampling_nearest(data, scale));
MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, {
SpatialUpSamplingNearestUpdateOutput<xpu, DTyp, AccReal>(s, in_data, &outdata);
slice<1>(out, begin, end) = out_data[up_enum::kOut].get<xpu, 4, DType>(s);
});
}
begin = end;
}
} else {
Tensor<xpu, 4, DType> data = in_data[up_enum::kData].get<xpu, 4, DType>(s);
Assign(out, req[up_enum::kOut], upsampling_nearest(data, param.scale));
MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, {
SpatialUpSamplingNearestUpdateOutput<xpu, DTyp, AccReal>(s, in_data, &outdata);
out = out_data[up_enum::kOut].get<xpu, 4, DType>(s);
});
}
}

Expand All @@ -134,44 +198,71 @@ void UpSamplingBackward(const OpContext &ctx, const UpSamplingParam &param,
Tensor<xpu, 4, DType> input_grad = in_grad[i].get<xpu, 4, DType>(s);
mshadow::Shape<2> in_shape = Shape2(input_grad.shape_[2], input_grad.shape_[3]);
int end = begin + input_grad.size(1);
int scale = grad.size(2)/in_shape[0];
int scale_h = grad.size(2)/in_shape[0];
int scale_w = grad.size(3)/in_shape[1];
if (param.multi_input_mode == up_enum::kSum) {
Assign(input_grad, req[i],
pool<mshadow::red::sum>(grad,
in_shape,
scale,
scale,
scale,
scale));
scale_h,
scale_w,
scale_h,
scale_w));
} else {
Assign(input_grad, req[i],
pool<mshadow::red::sum>(slice<1>(grad, begin, end),
in_shape,
scale,
scale,
scale,
scale));
scale_h,
scale_w,
scale_h,
scale_w));
}
begin = end;
}
} else {
Tensor<xpu, 4, DType> input_grad = in_grad[up_enum::kData].get<xpu, 4, DType>(s);
mshadow::Shape<2> in_shape = Shape2(input_grad.shape_[2], input_grad.shape_[3]);
int scale_h = 1;
int scale_w = 1;
if (param.scale.ndim() == 1) {
scale_h = param.scale[0];
scale_w = param.scale[0];
} else if (param.scale.ndim() == 2) {
scale_h = param.scale[0];
scale_w = param.scale[1];
} else if (param.scale.ndim() == 4) {
scale_h = param.scale[2];
scale_w = param.scale[3];
}
Assign(input_grad, req[up_enum::kData],
pool<mshadow::red::sum>(grad,
in_shape,
param.scale,
param.scale,
param.scale,
param.scale));
scale_h,
scale_w,
scale_h,
scale_w));
}
}

static inline DeconvolutionParam GetDeconvolutionParam(const UpSamplingParam& param) {
DeconvolutionParam p = DeconvolutionParam();
int kernel = 2 * param.scale - param.scale % 2;
int stride = param.scale;
int pad = static_cast<int>(ceil((param.scale - 1) / 2.));
int scale_h = 1;
int scale_w = 1;
if (param.scale.ndim() == 1) {
scale_h = param.scale[0];
scale_w = param.scale[0];
} else if (param.scale.ndim() == 2) {
scale_h = param.scale[0];
scale_w = param.scale[1];
} else if (param.scale.ndim() == 4) {
scale_h = param.scale[2];
scale_w = param.scale[3];
}
CHECK_EQ(scale_h, scale_w) <<
"UpSamplingBilinear: Scale should be the same along all dimensions for bilinear upsampling";
int kernel = static_cast<int>(2.0 * scale_h - ::fmod(scale_h, 2));
int stride = scale_h;
int pad = static_cast<int>(ceil((scale_h - 1) / 2.));
p.workspace = param.workspace;
p.num_group = param.num_filter;
p.num_filter = param.num_filter;
Expand Down
22 changes: 18 additions & 4 deletions src/operator/nn/upsampling.cc
Expand Up @@ -37,13 +37,25 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs,
CHECK_GE(in_shape->size(), 1U);
const TShape &dshape = (*in_shape)[0];
TShape oshape = dshape;
int scale_h = 1;
int scale_w = 1;
if (param_.scale.ndim() == 1) {
scale_h = param_.scale[0];
scale_w = param_.scale[0];
} else if (param_.scale.ndim() == 2) {
scale_h = param_.scale[0];
scale_w = param_.scale[1];
} else if (param_.scale.ndim() == 4) {
scale_h = param_.scale[2];
scale_w = param_.scale[3];
}
if (param_.sample_type == up_enum::kNearest) {
CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
oshape[1] = 0;
for (auto& shape : *in_shape) {
CHECK_EQ(shape.ndim(), 4U) << \
"UpSamplingNearest: Input data should be 4D in (batch, channel, y, x)";
int oh = dshape[2]*param_.scale, ow = dshape[3]*param_.scale;
int oh = dshape[2]*scale_h, ow = dshape[3]*scale_w;
CHECK_EQ(oh%shape[2], 0U) << "UpSamplingNearest: input height of " << shape[2] << \
"does not divide output height of " << oh;
CHECK_EQ(ow%shape[3], 0U) << "UpSamplingNearest: input width of " << shape[3] << \
Expand All @@ -58,17 +70,19 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs,
}
} else {
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
CHECK_EQ(scale_h, scale_w) <<
"UpSamplingBilinear: Scale should be the same along all dimensions for bilinear upsampling";
CHECK_EQ(dshape.ndim(), 4U) << \
"UpSamplingBilinear: Input data should be 4D in (batch, channel, y, x)";
if (dshape.ndim() == 0) return false;
int kernel = 2 * param_.scale - param_.scale % 2;
int kernel = static_cast<int>(2.0 * scale_h - ::fmod(scale_h, 2));
SHAPE_ASSIGN_CHECK(*in_shape,
up_enum::kWeight,
mshadow::Shape4(dshape[1], 1, kernel, kernel));
oshape = dshape;
}
oshape[2] = dshape[2] * param_.scale;
oshape[3] = dshape[3] * param_.scale;
oshape[2] = dshape[2] * scale_h;
oshape[3] = dshape[3] * scale_w;
out_shape->clear();
out_shape->push_back(oshape);
return true;
Expand Down
3 changes: 2 additions & 1 deletion tests/python-pytest/onnx/test_cases.py
Expand Up @@ -78,7 +78,8 @@
'test_max_',
'test_softplus',
'test_reduce_',
'test_split_equal'
'test_split_equal',
'test_upsample_n'
],
'import': ['test_gather',
'test_softsign',
Expand Down