diff --git a/src/operator/nn/upsampling-inl.h b/src/operator/nn/upsampling-inl.h index e1d7d66ac5ae..543cbd664604 100644 --- a/src/operator/nn/upsampling-inl.h +++ b/src/operator/nn/upsampling-inl.h @@ -87,9 +87,9 @@ struct UpSamplingParam : public dmlc::Parameter { template void SpatialUpSamplingNearestUpdateOutput(mshadow::Stream *s, const std::vector &in_data, - std::vector &out_data) { + std::vector *out_data) { Tensor itensor = in_data[0].get(s); - Tensor otensor = out_data[0].get(s); + Tensor otensor = (*out_data)[0].get(s); int outputHeight = otensor.size(2); int outputWidth = otensor.size(3); @@ -149,6 +149,7 @@ void UpSamplingForward(const OpContext &ctx, const UpSamplingParam ¶m, } Stream *s = ctx.get_stream(); Tensor out = out_data[up_enum::kOut].get(s); + std::vector outdata = out_data; if (param.num_args > 1) { int begin = 0; for (int i = 0; i < param.num_args; ++i) { @@ -156,31 +157,27 @@ void UpSamplingForward(const OpContext &ctx, const UpSamplingParam ¶m, int end = begin + data.size(1); if (param.multi_input_mode == up_enum::kSum) { if (i == 0) { - std::vector outdata = out_data; MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, { - SpatialUpSamplingNearestUpdateOutput(s, in_data, outdata); + SpatialUpSamplingNearestUpdateOutput(s, in_data, &outdata); out = out_data[up_enum::kOut].get(s); }); } else { - std::vector outdata = out_data; MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, { - SpatialUpSamplingNearestUpdateOutput(s, in_data, outdata); + SpatialUpSamplingNearestUpdateOutput(s, in_data, &outdata); out += out_data[up_enum::kOut].get(s); }); } } else { - std::vector outdata = out_data; MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, { - SpatialUpSamplingNearestUpdateOutput(s, in_data, outdata); + SpatialUpSamplingNearestUpdateOutput(s, in_data, &outdata); slice<1>(out, begin, end) = out_data[up_enum::kOut].get(s); }); } begin = end; } } else { - std::vector outdata = out_data; MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, { - SpatialUpSamplingNearestUpdateOutput(s, in_data, outdata); + SpatialUpSamplingNearestUpdateOutput(s, in_data, &outdata); out = out_data[up_enum::kOut].get(s); }); }