Skip to content

Commit

Permalink
beta 0.2.0.2
Browse files Browse the repository at this point in the history
- CPU
  - add padding support
  - fix bug in permute when channel % 4 != 0
  - fix bug in exp with extreme value
- OpenCL
  - add protective logic
- OpenGL
  - add protective logic
  - support NCHW format in Squeeze and Reshape
- Converter
  - add ShuffleChannel support for Caffe
  - add Clip/Transpose/Unary/Pad support for ONNX
  • Loading branch information
liqing committed Jul 2, 2019
1 parent ad759eb commit db155b4
Show file tree
Hide file tree
Showing 87 changed files with 2,005 additions and 765 deletions.
10 changes: 5 additions & 5 deletions demo/exec/pictureRecognition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ int main(int argc, const char* argv[]) {
}
MNN_PRINT("origin size: %d, %d\n", width, height);
Matrix trans;
// Dst -> [0, 1]
trans.postScale(1.0 / size_w, 1.0 / size_h);
//[0, 1] -> Src
trans.postScale(width, height);
// Set scale, from dst scale to src
trans.setScale((float)(width-1) / (size_w-1), (float)(height-1) / (size_h-1));
ImageProcess::Config config;
config.filterType = BILINEAR;
float mean[3] = {103.94f, 116.78f, 123.68f};
::memcpy(config.mean, mean, sizeof(mean));
float normals[3] = {0.017f, 0.017f, 0.017f};
// float mean[3] = {127.5f, 127.5f, 127.5f};
// float normals[3] = {0.00785f, 0.00785f, 0.00785f};
::memcpy(config.mean, mean, sizeof(mean));
::memcpy(config.normal, normals, sizeof(normals));
config.sourceFormat = RGBA;
config.destFormat = BGR;
Expand Down
2 changes: 1 addition & 1 deletion project/android/build_32.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ cmake ../../../ \
-DANDROID_ABI="armeabi-v7a" \
-DANDROID_STL=c++_static \
-DCMAKE_BUILD_TYPE=Release \
-DANDROID_NATIVE_API_LEVEL=android-21 \
-DANDROID_NATIVE_API_LEVEL=android-19 \
-DANDROID_TOOLCHAIN=gcc \
-DMNN_BUILD_FOR_ANDROID_COMMAND=true \
-DMNN_DEBUG=false \
Expand Down
12 changes: 12 additions & 0 deletions project/ios/MNN.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
4843AA5922A7E9AB00889A63 /* CPUConv2DBackPropFilter.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4843AA5122A7E9AB00889A63 /* CPUConv2DBackPropFilter.cpp */; };
4843AA5A22A7E9AB00889A63 /* CPUSoftmaxGrad.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4843AA5222A7E9AB00889A63 /* CPUSoftmaxGrad.cpp */; };
4843AA5B22A7E9AB00889A63 /* CPUSoftmaxGrad.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 4843AA5322A7E9AB00889A63 /* CPUSoftmaxGrad.hpp */; };
4847D41D22C0739A0049F3CA /* ShapePadding.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4847D41C22C0739A0049F3CA /* ShapePadding.cpp */; };
4847D42022C07E850049F3CA /* CPUPadding.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4847D41E22C07E850049F3CA /* CPUPadding.cpp */; };
4847D42122C07E850049F3CA /* CPUPadding.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 4847D41F22C07E850049F3CA /* CPUPadding.hpp */; };
4851BE102122C1BC009BB0AC /* Tensor.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 4851BE0F2122C1BC009BB0AC /* Tensor.hpp */; settings = {ATTRIBUTES = (Public, ); }; };
485DD411217F495500129159 /* CPUQuantizedAdd.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 485DD40B217F495400129159 /* CPUQuantizedAdd.hpp */; };
485DD412217F495500129159 /* CPUQuantizedSoftmax.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 485DD40C217F495500129159 /* CPUQuantizedSoftmax.cpp */; };
Expand Down Expand Up @@ -752,6 +755,9 @@
4843AA5122A7E9AB00889A63 /* CPUConv2DBackPropFilter.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CPUConv2DBackPropFilter.cpp; sourceTree = "<group>"; };
4843AA5222A7E9AB00889A63 /* CPUSoftmaxGrad.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CPUSoftmaxGrad.cpp; sourceTree = "<group>"; };
4843AA5322A7E9AB00889A63 /* CPUSoftmaxGrad.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = CPUSoftmaxGrad.hpp; sourceTree = "<group>"; };
4847D41C22C0739A0049F3CA /* ShapePadding.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = ShapePadding.cpp; sourceTree = "<group>"; };
4847D41E22C07E850049F3CA /* CPUPadding.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = CPUPadding.cpp; sourceTree = "<group>"; };
4847D41F22C07E850049F3CA /* CPUPadding.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = CPUPadding.hpp; sourceTree = "<group>"; };
4851BE0F2122C1BC009BB0AC /* Tensor.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = Tensor.hpp; sourceTree = "<group>"; };
485DD40B217F495400129159 /* CPUQuantizedAdd.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = CPUQuantizedAdd.hpp; sourceTree = "<group>"; };
485DD40C217F495500129159 /* CPUQuantizedSoftmax.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CPUQuantizedSoftmax.cpp; sourceTree = "<group>"; };
Expand Down Expand Up @@ -1908,6 +1914,8 @@
48B904A522953E0F003116BB /* CPUZeroLike.hpp */,
4829D54E22AF5C340093E3BE /* CPUSetDiff1D.cpp */,
4829D54F22AF5C340093E3BE /* CPUSetDiff1D.hpp */,
4847D41E22C07E850049F3CA /* CPUPadding.cpp */,
4847D41F22C07E850049F3CA /* CPUPadding.hpp */,
);
name = cpu;
path = backend/cpu;
Expand Down Expand Up @@ -2341,6 +2349,7 @@
EBB38EF421E748B9005F76D7 /* ShapeUnpack.cpp */,
EBB38EDC21E748B9005F76D7 /* ShapeWhere.cpp */,
48B904A8229550CF003116BB /* ShapeSelect.cpp */,
4847D41C22C0739A0049F3CA /* ShapePadding.cpp */,
);
path = shape;
sourceTree = "<group>";
Expand Down Expand Up @@ -2475,6 +2484,7 @@
488875D5215B639F0079B12E /* MetalInterp.hpp in Headers */,
488875B4215B639F0079B12E /* MetalReLU.hpp in Headers */,
48887644215B639F0079B12E /* ConvolutionTiledExecutor.hpp in Headers */,
4847D42122C07E850049F3CA /* CPUPadding.hpp in Headers */,
488875B7215B639F0079B12E /* MetalSlice.hpp in Headers */,
92EEFEB2217F0CBB00F89377 /* CPUCrop.hpp in Headers */,
921722F021DDF63A004583BF /* GpuLibrary_generated.h in Headers */,
Expand Down Expand Up @@ -2827,6 +2837,7 @@
92C674F922549A1600011D33 /* MetalReLU6.mm in Sources */,
488875D3215B639F0079B12E /* MetalSpatialProduct.metal in Sources */,
48887630215B639F0079B12E /* CPUTopKV2.cpp in Sources */,
4847D42022C07E850049F3CA /* CPUPadding.cpp in Sources */,
48BF218621A4257500AFF78E /* MNNSamplerC1BilinearOpt.S in Sources */,
CE96FE8121707D58004AB400 /* MetalMatMul.metal in Sources */,
48887689215B639F0079B12E /* MNNCubicLineC4.S in Sources */,
Expand All @@ -2844,6 +2855,7 @@
488875FF215B639F0079B12E /* CPUSize.cpp in Sources */,
EB4925C3224A147E00C512BB /* CPUMoments.cpp in Sources */,
92256950219D6E0200F251E2 /* MetalRange.mm in Sources */,
4847D41D22C0739A0049F3CA /* ShapePadding.cpp in Sources */,
924F132521ABD47F006D46A4 /* MetalQuantizedSoftmax.metal in Sources */,
EBB38F1521E748B9005F76D7 /* ShapeWhere.cpp in Sources */,
488876D9215B639F0079B12E /* CPUTanh.cpp in Sources */,
Expand Down
3 changes: 1 addition & 2 deletions schema/default/MNN.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,7 @@ enum OpType : int {

table Plugin {
type: string;
sizeCompute: Net;
buffer: [byte];
buffer: [Blob];
}

union OpParameter {
Expand Down
2 changes: 1 addition & 1 deletion source/backend/cpu/CPUDequantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void dequantizeMinFirst(uint8_t* input, float* output, float* rangeScale, float*
namespace MNN {

template <typename T>
CPUDequantize<T>::CPUDequantize(Backend* backend, QuantizeMode mode, const Op* op) : mMode(mode), Execution(backend) {
CPUDequantize<T>::CPUDequantize(Backend* backend, QuantizeMode mode, const Op* op) : Execution(backend), mMode(mode) {
mHalfRange = !std::is_signed<T>::value ? 0.0f
: (static_cast<double>(std::numeric_limits<T>::max()) -
static_cast<double>(std::numeric_limits<T>::min()) + 1) /
Expand Down
6 changes: 3 additions & 3 deletions source/backend/cpu/CPUGatherV2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ class CPUGatherV2Creator : public CPUBackend::Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const MNN::Op *op, Backend *backend) const override {
switch (op->main_as_GatherV2()->Tparams()) {
case DataType_DT_INT32:
switch (inputs[0]->getType().code) {
case halide_type_int:
return new CPUGatherV2<int32_t>(backend, op);
case DataType_DT_FLOAT:
case halide_type_float:
return new CPUGatherV2<float>(backend, op);
default:
MNN_ASSERT(false); // unsupported type
Expand Down
2 changes: 2 additions & 0 deletions source/backend/cpu/CPUOPRegister.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ extern void ___CPUMomentsCreator__OpType_Moments__();
extern void ___CPUNonMaxSuppressionV2Creator__OpType_NonMaxSuppressionV2__();
extern void ___CPUNormalizeCreator__OpType_Normalize__();
extern void ___CPUPackCreator__OpType_Pack__();
extern void ___CPUPaddingCreator__OpType_Padding__();
extern void ___CPUPermuteCreator__OpType_Permute__();
extern void ___CPUPoolCreator__OpType_Pooling__();
extern void ___CPUPoolGradCreator__OpType_PoolGrad__();
Expand Down Expand Up @@ -119,6 +120,7 @@ ___CPUMomentsCreator__OpType_Moments__();
___CPUNonMaxSuppressionV2Creator__OpType_NonMaxSuppressionV2__();
___CPUNormalizeCreator__OpType_Normalize__();
___CPUPackCreator__OpType_Pack__();
___CPUPaddingCreator__OpType_Padding__();
___CPUPermuteCreator__OpType_Permute__();
___CPUPoolCreator__OpType_Pooling__();
___CPUPoolGradCreator__OpType_PoolGrad__();
Expand Down
93 changes: 93 additions & 0 deletions source/backend/cpu/CPUPadding.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
//
// CPUPadding.cpp
// MNN
//
// Created by MNN on 2019/6/24.
// Copyright © 2018 Alibaba. All rights reserved.
//

#include "CPUPadding.hpp"
#include "Macro.h"
#include "TensorUtils.hpp"
namespace MNN {
/**
 * Copies inputs[0] into outputs[0] at an offset taken from inputs[1], leaving
 * every other output element zero.
 *
 * This is the implementation CPUPaddingCreator selects for NHWC tensors.
 *
 * @param inputs  inputs[0] is the data tensor; inputs[1] holds int32 padding
 *                amounts. Only the even-indexed entries (2 * dim) are read here
 *                -- presumably the leading pad of each dimension, TODO confirm.
 * @param outputs outputs[0] receives the padded result.
 * @return NO_ERROR always.
 */
ErrorCode CPUPadding::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    Tensor *src         = inputs[0];
    Tensor *dst         = outputs[0];
    const int32_t *pads = inputs[1]->host<int32_t>();

    // Clear the whole destination; only the region covered by the source is overwritten below.
    ::memset(dst->host<char>(), 0, dst->size());

    // Strides are expressed in elements, so every offset is scaled by the element size.
    const auto elemBytes = src->getType().bytes();
    // A single copy moves the complete innermost dimension (dim 3) of the source.
    const auto copyBytes = src->length(3) * elemBytes;

    for (int n = 0; n < src->length(0); ++n) {
        const char *srcN = src->host<char>() + src->stride(0) * n * elemBytes;
        char *dstN       = dst->host<char>() + dst->stride(0) * (n + pads[2 * 0]) * elemBytes;
        for (int y = 0; y < src->length(1); ++y) {
            const char *srcY = srcN + src->stride(1) * y * elemBytes;
            char *dstY       = dstN + dst->stride(1) * (y + pads[2 * 1]) * elemBytes;
            for (int x = 0; x < src->length(2); ++x) {
                const char *srcX = srcY + src->stride(2) * x * elemBytes;
                // Shift by the dim-2 pad first, then by the dim-3 pad inside the innermost run.
                char *dstX = dstY + dst->stride(2) * (x + pads[2 * 2]) * elemBytes + pads[3 * 2] * elemBytes;
                ::memcpy(dstX, srcX, copyBytes);
            }
        }
    }
    return NO_ERROR;
}

/**
 * Padding for float tensors stored in channel blocks of 4 (the NC4HW4 path
 * chosen by CPUPaddingCreator).
 *
 * The output is zero-filled, then each input row is copied into the output row
 * shifted by the batch, height and width pads. Channel planes are copied
 * one-to-one; the creator rejects channel padding for this path.
 *
 * @param inputs  inputs[0] is the data tensor; inputs[1] holds int32 padding
 *                amounts, read at indices 0 (batch), 4 (height) and 6 (width).
 * @param outputs outputs[0] receives the padded result.
 * @return NO_ERROR always.
 */
ErrorCode CPUPaddingPacked::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    Tensor *src = inputs[0];
    Tensor *dst = outputs[0];

    const auto srcWidth   = src->width();
    const auto srcHeight  = src->height();
    const auto srcChannel = src->channel();
    const auto srcBatch   = src->batch();

    const auto dstWidth  = dst->width();
    const auto dstHeight = dst->height();

    // Number of 4-channel blocks, rounding a partial block up.
    const auto channelBlocks = UP_DIV(srcChannel, 4);
    const int32_t *pads      = inputs[1]->host<int32_t>();

    // Everything not overwritten by the copies below stays zero.
    ::memset(dst->host<float>(), 0, dst->size());

    for (int n = 0; n < srcBatch; ++n) {
        const float *srcBatchPtr = src->host<float>() + src->stride(0) * n;
        float *dstBatchPtr       = dst->host<float>() + dst->stride(0) * (pads[2 * 0] + n);
        for (int c = 0; c < channelBlocks; ++c) {
            // Each block is a full plane of 4-float pixels.
            const float *srcPlane = srcBatchPtr + c * srcWidth * srcHeight * 4;
            float *dstPlane       = dstBatchPtr + c * dstWidth * dstHeight * 4;

            for (int h = 0; h < srcHeight; ++h) {
                const float *srcRow = srcPlane + h * srcWidth * 4;
                // Move down by the height pad, then right by the width pad (4 floats per pixel).
                float *dstRow = dstPlane + (h + pads[2 * 2]) * dstWidth * 4 + pads[2 * 3] * 4;

                ::memcpy(dstRow, srcRow, srcWidth * 4 * sizeof(float));
            }
        }
    }

    return NO_ERROR;
}
/**
 * Creator that chooses the CPU implementation for OpType_Padding.
 *
 * - 4-dimensional NHWC input           -> CPUPadding
 * - any other 4-dimensional input      -> CPUPaddingPacked, provided the data is
 *   float and the channel padding entries (padding[2], padding[3]) are zero.
 *
 * Returns nullptr (after logging) when the request is not supported.
 */
class CPUPaddingCreator : public CPUBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
                                const MNN::Op *op, Backend *backend) const override {
        // Both the data tensor (inputs[0]) and the padding tensor (inputs[1]) are dereferenced below,
        // so reject ops that do not carry them instead of indexing out of range.
        if (inputs.size() < 2 || nullptr == inputs[0] || nullptr == inputs[1]) {
            MNN_ERROR("Padding needs data and padding inputs\n");
            return nullptr;
        }
        if (inputs[0]->dimensions() != 4) {
            MNN_ERROR("Currently padding only support NHWC or NC4HW4\n");
            return nullptr;
        }
        if (TensorUtils::getDescribe(inputs[0])->dimensionFormat == MNN_DATA_FORMAT_NHWC) {
            return new CPUPadding(backend);
        }

        // The packed path has restrictions that depend on the padding values themselves,
        // so they must be readable at creation time.
        auto paddingPtr = inputs[1]->host<int32_t>();
        if (nullptr == paddingPtr) {
            MNN_ERROR("Padding values are not available when creating NC4HW4 padding\n");
            return nullptr;
        }
        // Entries 2 and 3 belong to dimension 1 (channel); the packed copy cannot shift channels.
        if (paddingPtr[2] != 0 || paddingPtr[3] != 0) {
            MNN_ERROR("Currently padding NC4HW4 don't support channel padding\n");
            return nullptr;
        }
        // CPUPaddingPacked addresses the buffers as float.
        if (inputs[0]->buffer().type.code != halide_type_float) {
            MNN_ERROR("Currently padding NC4HW4 only support float padding\n");
            return nullptr;
        }
        return new CPUPaddingPacked(backend);
    }
};

// Register CPUPaddingCreator as the CPU backend creator for OpType_Padding.
REGISTER_CPU_OP_CREATOR(CPUPaddingCreator, OpType_Padding);
}; // namespace MNN
31 changes: 31 additions & 0 deletions source/backend/cpu/CPUPadding.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//
// CPUPadding.hpp
// MNN
//
// Created by MNN on 2019/6/24.
// Copyright © 2018 Alibaba. All rights reserved.
//

#ifndef CPUPadding_hpp
#define CPUPadding_hpp

#include <stdio.h>
#include "CPUBackend.hpp"
namespace MNN {
/**
 * Padding execution for the packed (non-NHWC) path; see CPUPadding.cpp for the
 * copy logic. Holds no state beyond the backend passed to Execution.
 */
class CPUPaddingPacked : public Execution {
public:
    CPUPaddingPacked(Backend *bn) : Execution(bn) {
    }

    /** Zero-fills outputs[0] and copies inputs[0] into it at the offset given by inputs[1]. */
    ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
};
/**
 * Padding execution used for NHWC tensors; see CPUPadding.cpp for the copy
 * logic. Holds no state beyond the backend passed to Execution.
 */
class CPUPadding : public Execution {
public:
    CPUPadding(Backend *bn) : Execution(bn) {
    }

    /** Zero-fills outputs[0] and copies inputs[0] into it at the offset given by inputs[1]. */
    ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
};
}; // namespace MNN

#endif /* CPUPadding_hpp */
22 changes: 19 additions & 3 deletions source/backend/cpu/CPUPermute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ErrorCode CPUPermute::onResize(const std::vector<Tensor *> &inputs, const std::v
ErrorCode CPUPermute::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
MNN_ASSERT(1 == inputs.size());
MNN_ASSERT(1 == outputs.size());

auto &input = inputs[0]->buffer();
auto &output = outputs[0]->buffer();

Expand Down Expand Up @@ -73,7 +73,7 @@ ErrorCode CPUPermute::onExecute(const std::vector<Tensor *> &inputs, const std::
if (output.dimensions > 3) {
outputWidth = output.dim[3].extent;
}
const int outputChannelAlign4 = ALIGN_UP4(output.dim[1].extent);
const int outputChannel = output.dim[1].extent;

int strides[4][4]; // map from change of output index to change of input index on N, C4, H and W

Expand All @@ -99,7 +99,7 @@ ErrorCode CPUPermute::onExecute(const std::vector<Tensor *> &inputs, const std::

for (int ob = 0, outputIndex = 0, inputIndex = 0; ob < output.dim[0].extent; ++ob) {
const int inputIndex1 = inputIndex;
for (int oz = 0; oz < outputChannelAlign4; oz += 4) {
for (int oz = 0; oz <= outputChannel - 4; oz += 4) {
const int inputIndex2 = inputIndex;
for (int oy = 0; oy < outputHeight; ++oy) {
const int inputIndex3 = inputIndex;
Expand All @@ -114,6 +114,22 @@ ErrorCode CPUPermute::onExecute(const std::vector<Tensor *> &inputs, const std::
}
inputIndex = inputIndex2 + ocTotalStride;
}
if (outputChannel % 4 != 0) {
for (int oy = 0; oy < outputHeight; ++oy) {
const int inputIndex3 = inputIndex;
for (int ox = 0; ox < outputWidth; ++ox) {
originOutput[outputIndex++] = originInput[inputIndex];
for (int oz = 0; oz < outputChannel % 4 - 1; ++oz) {
originOutput[outputIndex++] = originInput[inputIndex + strides[1][oz]];
}
for (int oz = outputChannel % 4; oz < 4; ++oz) {
originOutput[outputIndex++] = 0.0f;
}
inputIndex += strides[3][ox % 4];
}
inputIndex = inputIndex3 + strides[2][oy % 4];
}
}
inputIndex = inputIndex1 + strides[0][ob % 4];
}

Expand Down
2 changes: 1 addition & 1 deletion source/backend/cpu/CPUReshape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class CPUReshape : public Execution {
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

protected:
MNN_DATA_FORMAT mDimType;
Tensor mStorage;
MNN_DATA_FORMAT mDimType;
Tensor mWrapTensorForInput;
Tensor mWrapTensorForOutput;
};
Expand Down
1 change: 0 additions & 1 deletion source/backend/cpu/CPUSlice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ ErrorCode CPUSlice::onExecute(const std::vector<Tensor*>& inputs, const std::vec
auto input = inputs[0];
const auto tensorFormat = input->getDimensionType();
if (Tensor::CAFFE == tensorFormat) {
MNN_ASSERT(inputs[0]->buffer().dim[1].flags == MNN::Tensor::REORDER_4);
if (mAxis == 1) {
_sliceChannel(inputs[0], outputs, mTempInput.get());
return NO_ERROR;
Expand Down
16 changes: 8 additions & 8 deletions source/backend/cpu/arm/arm32/MNNExpC8.S
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,22 @@ vpush {q5, q6}

vld1.32 {q0, q1}, [r2]

vmov.i32 q5, #24
vneg.s32 q6, q5
vmov.i32 q2, #87
vcvt.f32.s32 q2, q2
vneg.f32 q3, q2

Loop:

vld1.32 {q8, q9}, [r1]!

vmin.f32 q8, q8, q2
vmin.f32 q9, q9, q2
vmax.f32 q8, q8, q3
vmax.f32 q9, q9, q3

vneg.f32 q10, q8
vneg.f32 q11, q9


vmul.f32 q8, q10, d0[1]
vmul.f32 q9, q11, d0[1]
vcvt.s32.f32 q8, q8
Expand All @@ -40,11 +45,6 @@ vcvt.s32.f32 q9, q9
vcvt.f32.s32 q12, q8
vcvt.f32.s32 q13, q9

vmin.s32 q8, q8, q5
vmin.s32 q9, q9, q5
vmax.s32 q8, q8, q6
vmax.s32 q9, q9, q6

//q10, q11: t
vmls.f32 q10, q12, d0[0]
vmls.f32 q11, q13, d0[0]
Expand Down
Loading

0 comments on commit db155b4

Please sign in to comment.