forked from deeplearning4j/deeplearning4j
/
resize_neighbor.cpp
115 lines (102 loc) · 5.29 KB
/
resize_neighbor.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author sgazeos@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_resize_bilinear)
//#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/image_resize.h>
namespace nd4j {
namespace ops {
// resize_nearest_neighbor: validates the arguments and delegates the actual
// resize to helpers::resizeNeighborFunctor.
//
// Inputs:
//   0 - image to resize, rank 3 or rank 4
//   1 - (optional) array of exactly two values: new width, new height
// Int args:
//   - when input 1 is present: at most one int arg, the "center" flag
//     (non-zero means true)
//   - otherwise: width, height and an optional third "center" flag
// Output:
//   0 - resized image; must have the same rank as the input
//
// Returns the status produced by helpers::resizeNeighborFunctor.
CUSTOM_OP_IMPL(resize_nearest_neighbor, 1, 1, false, 0, -2) {
    auto image = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);
    int width;
    int height;
    bool center = false; // - default value

    if (block.width() > 1) {
        // Target size is supplied as the second input array.
        auto newImageSize = INPUT_VARIABLE(1);
        REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_nearest_neighbor: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
        REQUIRE_TRUE(block.numI() <= 1, 0, "resize_nearest_neighbor: Resize params already given by the second param. Int params are expensive.");
        width = newImageSize->e<int>(0);
        height = newImageSize->e<int>(1);
        if (block.numI() == 1) {
            center = 0 != INT_ARG(0);
        }
    }
    else {
        // Target size is supplied through int args: both INT_ARG(0) and
        // INT_ARG(1) are read below, so at least two int args must be present.
        REQUIRE_TRUE(block.numI() >= 2, 0, "resize_nearest_neighbor: Neither resize width nor height are provided.");
        REQUIRE_TRUE(block.numI() <= 3, 0, "resize_nearest_neighbor: Too many int params are provided: expected width, height and an optional center flag.");
        width = INT_ARG(0);
        height = INT_ARG(1);
        if (block.numI() == 3)
            center = 0 != INT_ARG(2);
    }

    auto inRank = image->rankOf();
    REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 3D or 4D tensor, but rank %i occured", inRank);
    REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());

    // The helper is always called with rank-4 arrays: a rank-3 input/output is
    // reshaped with a leading dimension of 1.
    auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
    auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});

    return helpers::resizeNeighborFunctor(block.launchContext(), &source, width, height, center, &target);
}
// Shape function for resize_nearest_neighbor.
//
// The output has the same rank as the input (3 or 4). The two resized
// dimensions are replaced by the requested width and height; the remaining
// dimensions (leading one for rank 4, trailing one for both ranks) are copied
// from the input shape. Strides, order and data type follow the input.
//
// Width/height come either from the second input array (two values) or from
// the first two int args.
DECLARE_SHAPE_FN(resize_nearest_neighbor) {
    auto in = inputShape->at(0);
    auto inRank = shape::rank(in);

    REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: input image should be 3D or 4D "
                                                 "tensor, but input has rank %i",
                 inRank);

    int width;
    int height;
    if (block.width() > 1) {
        // Target size is supplied as the second input array.
        auto newImageSize = INPUT_VARIABLE(1);
        REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_nearest_neighbor: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
        REQUIRE_TRUE(block.numI() <= 1, 0, "resize_nearest_neighbor: Resize params already given by the second param. Int params are expensive.");
        width = newImageSize->e<int>(0);
        height = newImageSize->e<int>(1);
    }
    else {
        // Both INT_ARG(0) and INT_ARG(1) are read below, so at least two int
        // args must be present (a third one is the optional center flag).
        REQUIRE_TRUE(block.numI() >= 2, 0, "resize_nearest_neighbor: Neither resize width nor height are provided.");
        REQUIRE_TRUE(block.numI() <= 3, 0, "resize_nearest_neighbor: Too many int params are provided: expected width, height and an optional center flag.");
        width = INT_ARG(0);
        height = INT_ARG(1);
    }

    Nd4jLong* outputShape;
    ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
    outputShape[0] = inRank;
    if (inRank == 4) {
        outputShape[1] = in[1];
        outputShape[2] = width;
        outputShape[3] = height;
        outputShape[4] = in[4];
    }
    else { // input shape is 3D, so result also should be 3D
        outputShape[1] = width;
        outputShape[2] = height;
        outputShape[3] = in[3];
    }
    ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in));

    // The list is created only once all validation has passed, so a failed
    // check never leaves an already-allocated list behind.
    auto shapeList = SHAPELIST();
    shapeList->push_back(CONSTANT(outputShape));
    return shapeList;
}
// Data type constraints for resize_nearest_neighbor: inputs are declared with
// nd4j::DataType::ANY, outputs with the ALL_FLOATS type set.
DECLARE_TYPES(resize_nearest_neighbor) {
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes(nd4j::DataType::ANY);
    descriptor->setAllowedOutputTypes({ALL_FLOATS});
}
}
}
#endif