Dev roi align #1511

Merged 8 commits on Feb 25, 2022
4 changes: 4 additions & 0 deletions platforms/ios/tnn.xcodeproj/project.pbxproj
@@ -833,6 +833,7 @@
E4D05BF9259F161000921502 /* arm_prelu_fp16_layer.cc in Sources */ = {isa = PBXBuildFile; fileRef = E4D05BF6259F161000921502 /* arm_prelu_fp16_layer.cc */; };
E4D05BFA259F161000921502 /* arm_relu_fp16_layer.cc in Sources */ = {isa = PBXBuildFile; fileRef = E4D05BF7259F161000921502 /* arm_relu_fp16_layer.cc */; };
E4D05C03259F1BA700921502 /* arm_add_layer_acc.cc in Sources */ = {isa = PBXBuildFile; fileRef = E4D05C02259F1BA700921502 /* arm_add_layer_acc.cc */; };
E4F8F98E278BD01C005F7B63 /* arm_roialign_layer_acc.cc in Sources */ = {isa = PBXBuildFile; fileRef = E4F8F98D278BD01C005F7B63 /* arm_roialign_layer_acc.cc */; };
EC0BE13725144B5E009BD69A /* detection_post_process_utils.h in Headers */ = {isa = PBXBuildFile; fileRef = EC0BE13425144B5D009BD69A /* detection_post_process_utils.h */; };
EC0BE13825144B5E009BD69A /* detection_post_process_utils.cc in Sources */ = {isa = PBXBuildFile; fileRef = EC0BE13525144B5D009BD69A /* detection_post_process_utils.cc */; };
EC0BE13925144B5E009BD69A /* string_utils.cc in Sources */ = {isa = PBXBuildFile; fileRef = EC0BE13625144B5D009BD69A /* string_utils.cc */; };
@@ -1925,6 +1926,7 @@
E4D05BF6259F161000921502 /* arm_prelu_fp16_layer.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = arm_prelu_fp16_layer.cc; sourceTree = "<group>"; };
E4D05BF7259F161000921502 /* arm_relu_fp16_layer.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = arm_relu_fp16_layer.cc; sourceTree = "<group>"; };
E4D05C02259F1BA700921502 /* arm_add_layer_acc.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = arm_add_layer_acc.cc; sourceTree = "<group>"; };
E4F8F98D278BD01C005F7B63 /* arm_roialign_layer_acc.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = arm_roialign_layer_acc.cc; sourceTree = "<group>"; };
EC0BE13425144B5D009BD69A /* detection_post_process_utils.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = detection_post_process_utils.h; sourceTree = "<group>"; };
EC0BE13525144B5D009BD69A /* detection_post_process_utils.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = detection_post_process_utils.cc; sourceTree = "<group>"; };
EC0BE13625144B5D009BD69A /* string_utils.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = string_utils.cc; sourceTree = "<group>"; };
@@ -3065,6 +3067,7 @@
3620403D26E8D4E000935711 /* arm_layer_norm_layer_acc.cc */,
3620403E26E8D4E000935711 /* arm_layer_norm_layer_acc.h */,
4E187D262672030500804FDF /* arm_concat_layer_acc.h */,
E4F8F98D278BD01C005F7B63 /* arm_roialign_layer_acc.cc */,
4E187D252672030500804FDF /* arm_grid_sample_layer_acc.cc */,
4E187D222672030500804FDF /* arm_inverse_layer_acc.cc */,
4E187D232672030500804FDF /* arm_padv2_layer_acc.cc */,
@@ -4558,6 +4561,7 @@
9D32FCF524557EEC002DCDAB /* reduce_l2_layer.cc in Sources */,
9DF543F9258B1366006CEC97 /* arm_detection_output_layer_acc.cc in Sources */,
9D32FF0E24557EED002DCDAB /* flatten_layer_interpreter.cc in Sources */,
E4F8F98E278BD01C005F7B63 /* arm_roialign_layer_acc.cc in Sources */,
EC0BE15525144BB8009BD69A /* rsqrt_layer.cc in Sources */,
E4D05BB8259DCB2E00921502 /* arm_sigmoid_fp16_layer.cc in Sources */,
9DD1FB7D247CE9BE00800139 /* metal_relu6_layer_acc.metal in Sources */,
384 changes: 384 additions & 0 deletions source/tnn/device/arm/acc/arm_roialign_layer_acc.cc

Large diffs are not rendered by default.
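
Since GitHub collapses the 384-line kernel, here is a minimal scalar sketch of average-mode RoiAlign as the ONNX operator defines it, for orientation only. `RoiAlignRef` and `BilinearSample` are illustrative names, and the plain NCHW layout is an assumption; the real acc operates on TNN's NC4HW4 blobs and also implements max mode.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Bilinear sample with zero padding outside the feature map
// (matches the ONNX/Caffe2 reference behavior).
static float BilinearSample(const float* plane, int height, int width, float y, float x) {
    if (y < -1.0f || y > height || x < -1.0f || x > width) return 0.0f;
    y = std::max(y, 0.0f);
    x = std::max(x, 0.0f);
    int y_low = static_cast<int>(y), x_low = static_cast<int>(x);
    int y_high = y_low + 1, x_high = x_low + 1;
    if (y_low >= height - 1) { y_low = y_high = height - 1; y = static_cast<float>(y_low); }
    if (x_low >= width - 1)  { x_low = x_high = width - 1;  x = static_cast<float>(x_low); }
    float ly = y - y_low, lx = x - x_low, hy = 1.0f - ly, hx = 1.0f - lx;
    return hy * hx * plane[y_low * width + x_low]  + hy * lx * plane[y_low * width + x_high] +
           ly * hx * plane[y_high * width + x_low] + ly * lx * plane[y_high * width + x_high];
}

// rois: [num_rois, 4] as (x1, y1, x2, y2); batch_indices: [num_rois];
// output: [num_rois, channels, out_h, out_w].
void RoiAlignRef(const float* feat, int channels, int height, int width,
                 const float* rois, const int32_t* batch_indices, int num_rois,
                 int out_h, int out_w, float spatial_scale, int sampling_ratio,
                 float* output) {
    for (int r = 0; r < num_rois; ++r) {
        const float* src = feat + batch_indices[r] * channels * height * width;
        float x0 = rois[r * 4 + 0] * spatial_scale, y0 = rois[r * 4 + 1] * spatial_scale;
        float x1 = rois[r * 4 + 2] * spatial_scale, y1 = rois[r * 4 + 3] * spatial_scale;
        float bin_w = std::max(x1 - x0, 1.0f) / out_w;
        float bin_h = std::max(y1 - y0, 1.0f) / out_h;
        // sampling_ratio == 0 means adaptive: ceil(bin size) samples per bin.
        int grid_w = sampling_ratio > 0 ? sampling_ratio : static_cast<int>(std::ceil(bin_w));
        int grid_h = sampling_ratio > 0 ? sampling_ratio : static_cast<int>(std::ceil(bin_h));
        for (int c = 0; c < channels; ++c) {
            const float* plane = src + c * height * width;
            for (int ph = 0; ph < out_h; ++ph) {
                for (int pw = 0; pw < out_w; ++pw) {
                    float sum = 0.0f;
                    for (int iy = 0; iy < grid_h; ++iy) {
                        float y = y0 + ph * bin_h + (iy + 0.5f) * bin_h / grid_h;
                        for (int ix = 0; ix < grid_w; ++ix) {
                            float x = x0 + pw * bin_w + (ix + 0.5f) * bin_w / grid_w;
                            sum += BilinearSample(plane, height, width, y, x);
                        }
                    }
                    output[((r * channels + c) * out_h + ph) * out_w + pw] = sum / (grid_h * grid_w);
                }
            }
        }
    }
}
```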

30 changes: 22 additions & 8 deletions source/tnn/device/arm/arm_blob_converter.cc
@@ -90,7 +90,7 @@ Status ArmBlobConverterAcc::ConvertToMatAsync(Mat& image, MatConvertParam param,
fused_int8_scale[i] = param.scale[i] * scale_data[scale_idx];
fused_int8_bias[i] = param.bias[i];
}
} else if (desc.data_type == DATA_TYPE_INT32) {
} else if (desc.data_type == DATA_TYPE_INT32 && desc.data_format == DATA_FORMAT_NCHW) {
int count = DimsVectorUtils::Count(blob_->GetBlobDesc().dims);
int ele_size = DataTypeUtils::GetBytesSize(desc.data_type);
if (image.GetMatType() == NC_INT32) {
@@ -277,7 +277,7 @@ REGISTER_BLOB_CONVERTER(Arm, DEVICE_ARM);
convert data type from Tin to Tout, data format from nc4hw4 2 nchw
*/
template <typename Tin, typename Tout>
void FloatBlobToNCHW(const Tin *src, Tout *dst, int channel, int hw) {
void BlobToNCHW(const Tin *src, Tout *dst, int channel, int hw) {
if (channel % 4 == 0 && hw == 1 && sizeof(Tin) == sizeof(Tout)) {
memcpy(dst, src, channel * sizeof(Tin));
return;
@@ -286,10 +286,11 @@ void FloatBlobToNCHW(const Tin *src, Tout *dst, int channel, int hw) {
return;
}

template void FloatBlobToNCHW(const float *src, bfp16_t *dst, int channel, int hw);
template void FloatBlobToNCHW(const float *src, float *dst, int channel, int hw);
template void FloatBlobToNCHW(const bfp16_t *src, float *dst, int channel, int hw);
template void FloatBlobToNCHW(const bfp16_t *src, bfp16_t *dst, int channel, int hw);
template void BlobToNCHW(const float *src, bfp16_t *dst, int channel, int hw);
template void BlobToNCHW(const float *src, float *dst, int channel, int hw);
template void BlobToNCHW(const bfp16_t *src, float *dst, int channel, int hw);
template void BlobToNCHW(const bfp16_t *src, bfp16_t *dst, int channel, int hw);
template void BlobToNCHW(const int32_t *src, int32_t *dst, int channel, int hw);

template <typename Tin, typename Tout>
void HalfBlobToNCHW(const Tin *src, Tout *dst, int channel, int hw) {
@@ -1348,12 +1349,12 @@ static Status ConvertFloatBlobToFloatMat(Mat& image, char* handle_ptr, const Mat
RawBuffer scale_biased(c_r4 * hw * sizeof(float));
ScaleBias(reinterpret_cast<T_blob*>(handle_ptr) + n * c_r4 * hw, channel, hw, param.scale.data(),
param.bias.data(), scale_biased.force_to<T_blob*>());
FloatBlobToNCHW(scale_biased.force_to<T_blob*>(),
BlobToNCHW(scale_biased.force_to<T_blob*>(),
reinterpret_cast<T_mat*>(image.GetData()) + n * channel * hw, channel, hw);
}
} else {
for (int n = 0; n < batch; n++) {
FloatBlobToNCHW(reinterpret_cast<T_blob*>(handle_ptr) + n * c_r4 * hw,
BlobToNCHW(reinterpret_cast<T_blob*>(handle_ptr) + n * c_r4 * hw,
reinterpret_cast<T_mat*>(image.GetData()) + n * channel * hw, channel, hw);
}
}
@@ -1369,6 +1370,18 @@ static Status ConvertInt8BlobToInt8Mat(Mat& image, char* handle_ptr, const MatCo
reinterpret_cast<int8_t*>(handle_ptr), reinterpret_cast<int8_t*>(image.GetData()), batch, channel, hw);
}

static Status ConvertInt32BlobToInt32Mat(Mat& image, char* handle_ptr, const MatConvertParam& param,
const DimsVector& dims, const int hw, const int c_r4,
std::vector<float>& fused_int8_scale, std::vector<float>& fused_int8_bias) {
auto batch = DimsFunctionUtils::GetDim(dims, 0);
auto channel = DimsFunctionUtils::GetDim(dims, 1);
for (int n = 0; n < batch; n++) {
BlobToNCHW(reinterpret_cast<int32_t*>(handle_ptr) + n * c_r4 * hw,
reinterpret_cast<int32_t*>(image.GetData()) + n * channel * hw, channel, hw);
}
return TNN_OK;
}

// convert from blob to mat
REGISTER_ARM_BLOB_CONVERT_FUNC(N8UC4, DATA_TYPE_INT8, CVT_DIR_BLOB2MAT, ConvertInt8BlobToN8UC4)
REGISTER_ARM_BLOB_CONVERT_FUNC(N8UC4, DATA_TYPE_FLOAT, CVT_DIR_BLOB2MAT, ConvertFloatBlobToN8UC4)
@@ -1377,6 +1390,7 @@ REGISTER_ARM_BLOB_CONVERT_FUNC(N8UC3, DATA_TYPE_FLOAT, CVT_DIR_BLO
REGISTER_ARM_BLOB_CONVERT_FUNC(NCHW_FLOAT, DATA_TYPE_INT8, CVT_DIR_BLOB2MAT, ConvertInt8BlobToNCHWFloat)
REGISTER_ARM_BLOB_CONVERT_FUNC(NCHW_FLOAT, DATA_TYPE_FLOAT, CVT_DIR_BLOB2MAT, (ConvertFloatBlobToFloatMat<float,float>))
REGISTER_ARM_BLOB_CONVERT_FUNC(NCHW_FLOAT, DATA_TYPE_BFP16, CVT_DIR_BLOB2MAT, (ConvertFloatBlobToFloatMat<float, bfp16_t>))
REGISTER_ARM_BLOB_CONVERT_FUNC(NC_INT32, DATA_TYPE_INT32, CVT_DIR_BLOB2MAT, ConvertInt32BlobToInt32Mat)
REGISTER_ARM_BLOB_CONVERT_FUNC(RESERVED_BFP16_TEST, DATA_TYPE_BFP16, CVT_DIR_BLOB2MAT, (ConvertFloatBlobToFloatMat<bfp16_t, bfp16_t>))
REGISTER_ARM_BLOB_CONVERT_FUNC(RESERVED_INT8_TEST, DATA_TYPE_INT8, CVT_DIR_BLOB2MAT, ConvertInt8BlobToInt8Mat)

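The rename from FloatBlobToNCHW to BlobToNCHW fits, since the template is pure layout conversion rather than float-specific math, and the new int32_t instantiation backs the NC_INT32 converter registered above. For clarity, a minimal sketch of the NC4HW4 to NCHW unpack the template performs; `UnpackC4` is a hypothetical name, and the real code also fast-paths `channel % 4 == 0 && hw == 1` with memcpy, as the hunk above shows.

```cpp
// Hypothetical sketch of the NC4HW4 -> NCHW unpack behind BlobToNCHW.
// NC4HW4 packs channels in groups of four innermost:
// src[((c / 4) * hw + i) * 4 + (c % 4)] holds channel c at spatial index i.
template <typename Tin, typename Tout>
void UnpackC4(const Tin* src, Tout* dst, int channel, int hw) {
    for (int c = 0; c < channel; ++c) {
        const Tin* s = src + (c / 4) * hw * 4 + (c % 4);
        Tout* d = dst + c * hw;
        for (int i = 0; i < hw; ++i) {
            d[i] = static_cast<Tout>(s[i * 4]);
        }
    }
}
```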
4 changes: 4 additions & 0 deletions test/unit_test/layer_test/layer_test.cc
@@ -272,6 +272,8 @@ Status LayerTest::GenerateRandomBlob(Blob* cpu_blob, Blob* device_blob, void* co
} else if (blob_desc_device.data_type == DATA_TYPE_HALF && device_blob->GetBlobDesc().device_type == DEVICE_ARM) {
// the value is initialized as half
mat_type = RESERVED_FP16_TEST;
} else if (blob_desc_device.data_type == DATA_TYPE_INT32) {
mat_type = NC_INT32;
}
TNN_NS::Mat input_mat_cpu(DEVICE_NAIVE, mat_type, blob_desc.dims);
void* input_data = input_mat_cpu.GetData();
@@ -302,6 +304,8 @@ Status LayerTest::GenerateRandomBlob(Blob* cpu_blob, Blob* device_blob, void* co
} else {
InitRandom(static_cast<bfp16_t*>(input_data), blob_count, bfp16_t(1.0f + magic_num));
}
} else if (mat_type == NC_INT32) {
InitRandom(static_cast<int32_t*>(input_data), blob_count, integer_input_min_, integer_input_max_);
}

// default param for the blob_converter
2 changes: 2 additions & 0 deletions test/unit_test/layer_test/layer_test.h
@@ -55,6 +55,8 @@ class LayerTest : public ::testing::Test {

protected:
int ensure_input_positive_ = 0;
int integer_input_min_ = 0;
int integer_input_max_ = 1;

static std::shared_ptr<Instance> instance_cpu_;
static std::shared_ptr<Instance> instance_device_;
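These bounds drive random generation for integer inputs; the RoiAlign test below sets integer_input_max_ to the batch size so the random batch_indices stay valid, i.e. within [0, batch). A minimal sketch of the int32 InitRandom overload this relies on follows; uniform sampling over a half-open range is an assumption, and the real helper lives in unit_test_common.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdlib>

// Assumed overload: fill a buffer with uniform random ints in [range_min, range_max).
void InitRandom(int32_t* data, size_t count, int32_t range_min, int32_t range_max) {
    for (size_t i = 0; i < count; ++i) {
        data[i] = range_min + std::rand() % std::max(range_max - range_min, 1);
    }
}
```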
80 changes: 80 additions & 0 deletions test/unit_test/layer_test/test_roialign_layer.cc
@@ -0,0 +1,80 @@
// Tencent is pleased to support the open source community by making TNN available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "test/unit_test/layer_test/layer_test.h"
#include "test/unit_test/unit_test_common.h"
#include "test/unit_test/utils/network_helpers.h"
#include "tnn/utils/dims_utils.h"

namespace TNN_NS {

class RoiAlignLayerTest
: public LayerTest,
public ::testing::WithParamInterface<std::tuple<int, int, int, int, int, int, int, int, float>> {};

INSTANTIATE_TEST_SUITE_P(LayerTest, RoiAlignLayerTest,
::testing::Combine(BASIC_BATCH_CHANNEL_SIZE,
// num_rois
testing::Values(1, 4, 7, 16),
// pool type 0:max 1:avg
testing::Values(0, 1),
// output_height
testing::Values(1, 3, 8),
// output_width
testing::Values(1, 3, 8),
// sampling_ratio
testing::Values(0, 1, 2),
// spatial_scale
testing::Values(0.125, 0.25, 0.05)));

TEST_P(RoiAlignLayerTest, RoiAlignLayer) {
// get param
int batch = std::get<0>(GetParam());
int channel = std::get<1>(GetParam());
int input_size = std::get<2>(GetParam());
int num_rois = std::get<3>(GetParam());
int mode = std::get<4>(GetParam());
int output_height = std::get<5>(GetParam());
int output_width = std::get<6>(GetParam());
int sampling_ratio = std::get<7>(GetParam());
float spatial_scale = std::get<8>(GetParam());

integer_input_max_ = batch;

DeviceType dev = ConvertDeviceType(FLAGS_dt);

if (DEVICE_ARM != dev) {
GTEST_SKIP();
}

// param
std::shared_ptr<RoiAlignLayerParam> param(new RoiAlignLayerParam());
param->name = "RoiAlign";
param->mode = mode;
param->output_height = output_height;
param->output_width = output_width;
param->sampling_ratio = sampling_ratio;
param->spatial_scale = spatial_scale;

// generate interpreter
std::vector<int> input_dims = {batch, channel, input_size, input_size};
std::vector<int> rois_dims = {num_rois, 4};
std::vector<int> batch_indices_dims = {num_rois};
std::vector<DataType> input_dtype = {DATA_TYPE_FLOAT, DATA_TYPE_FLOAT, DATA_TYPE_INT32};
auto interpreter =
GenerateInterpreter("RoiAlign", {input_dims, rois_dims, batch_indices_dims}, param, nullptr, 1, input_dtype);
Run(interpreter);
}

} // namespace TNN_NS
16 changes: 14 additions & 2 deletions test/unit_test/unit_test_common.cc
@@ -93,11 +93,22 @@ InputShapesMap GenerateInputShapeMap(std::vector<std::vector<int>>& input_vec) {
return shape_map;
}

InputDataTypeMap GenerateInputDataTypeMap(const std::vector<DataType>& input_dtype) {
InputDataTypeMap dtype_map;
for (int i = 0; i < input_dtype.size(); ++i) {
std::ostringstream ostr;
ostr << "input" << i;
dtype_map[ostr.str()] = input_dtype[i];
}
return dtype_map;
}

std::shared_ptr<AbstractModelInterpreter> GenerateInterpreter(std::string layer_type_str,
std::vector<std::vector<int>> input_vec,
std::shared_ptr<LayerParam> param,
std::shared_ptr<LayerResource> resource,
int output_count) {
int output_count,
std::vector<DataType> input_dtype) {
auto interpreter = CreateModelInterpreter(MODEL_TYPE_TNN);
if (!interpreter) {
return nullptr;
@@ -111,7 +122,8 @@ std::shared_ptr<AbstractModelInterpreter> GenerateInterpreter(std::string layer_
NetResource* net_resource = default_interpreter->GetNetResource();

// generate net structure
net_structure->inputs_shape_map = GenerateInputShapeMap(input_vec);
net_structure->inputs_shape_map = GenerateInputShapeMap(input_vec);
net_structure->input_data_type_map = GenerateInputDataTypeMap(input_dtype);

std::shared_ptr<LayerInfo> layer_info = std::make_shared<LayerInfo>();
layer_info->type = GlobalConvertLayerType(layer_type_str);
3 changes: 2 additions & 1 deletion test/unit_test/unit_test_common.h
@@ -37,7 +37,8 @@ std::shared_ptr<AbstractModelInterpreter> GenerateInterpreter(std::string layer_
std::vector<std::vector<int>> input_vec,
std::shared_ptr<LayerParam> param,
std::shared_ptr<LayerResource> resource = nullptr,
int output_count = 1);
int output_count = 1,
std::vector<DataType> input_dtype = {});

} // namespace TNN_NS
