diff --git a/third_party/onnx b/third_party/onnx
index 20835f8..9f70c11 160000
--- a/third_party/onnx
+++ b/third_party/onnx
@@ -1 +1 @@
-Subproject commit 20835f8f238a232de2f1540fc109de62759e2495
+Subproject commit 9f70c118910f7174f65dac84c90d46d608d9cfd2
diff --git a/tools/onnx2bnn/OnnxConverter.cpp b/tools/onnx2bnn/OnnxConverter.cpp
index 5da62e3..9794f43 100644
--- a/tools/onnx2bnn/OnnxConverter.cpp
+++ b/tools/onnx2bnn/OnnxConverter.cpp
@@ -191,10 +191,11 @@ std::vector<std::string> OnnxConverter::Convert(
     // Please check out "dabnn_*" passes in
     // https://github.com/daquexian/onnx/blob/optimizer_for_bnn/onnx/optimizer/passes
     // for details.
-    vector<string> optimizers{"eliminate_nop_pad",
-                              "extract_constant_to_initializer",
-                              "dabnn_convert_gemm_with_reshape_or_flatten_to_conv_and_reshape",
-                              "dabnn_bconv_strict"};
+    vector<string> optimizers{
+        "eliminate_nop_pad", "extract_constant_to_initializer",
+        "dabnn_eliminate_dropout",
+        "dabnn_convert_gemm_with_reshape_or_flatten_to_conv_and_reshape",
+        "dabnn_bconv_strict"};
     if (level == Level::kModerate || level == Level::kAggressive) {
         optimizers.push_back("dabnn_bconv_moderate");
     }
@@ -231,13 +232,23 @@ std::vector<std::string> OnnxConverter::Convert(
         }
 
         Shape shape;
-        for (const auto &dim : input.type().tensor_type().shape().dim()) {
+        const auto &dims = input.type().tensor_type().shape().dim();
+        FORZ(i, dims.size()) {
+            if (i == 0) {
+                // We ignore the value of batch dimension since dabnn doesn't
+                // support batch input
+                shape.push_back(1);
+                continue;
+            }
+            const auto &dim = dims.Get(i);
             if (dim.value_case() ==
                 ONNX_NAMESPACE::TensorShapeProto_Dimension::kDimValue) {
                 shape.push_back(static_cast<uint32_t>(dim.dim_value()));
             } else {
                 throw std::invalid_argument(
-                    "The input of graph doesn't have dim_value");
+                    "Dim " + std::to_string(i) + " of input \"" + input.name() +
+                    "\" is not static, please re-export your ONNX model with "
+                    "static input shape");
             }
         }
         Shape nhwc_shape{shape[0], shape[2], shape[3], shape[1]};
@@ -248,17 +259,16 @@ std::vector<std::string> OnnxConverter::Convert(
     }
 
     vector<string> binary_conv_outputs;
-    vector<string> skipped_act;
     bool has_reshape = false;
     for (const auto &node : model_proto_.graph().node()) {
-        if (has_reshape) {
-            throw std::invalid_argument(
-                "Reshape can only be the last layer for now");
-        }
         NodeAttrHelper helper(node);
         const auto &op = node.op_type();
         VLOG(5) << "Node " << node.name();
         if (op == "Conv") {
+            if (has_reshape) {
+                throw std::invalid_argument("Reshape before " + op +
+                                            " is not supported");
+            }
             VLOG(5) << "Start converting Conv";
             auto strides = helper.get("strides", vector<int>{1, 1});
             auto pads = helper.get("pads", vector<int>{0, 0, 0, 0});
@@ -308,6 +318,10 @@ std::vector<std::string> OnnxConverter::Convert(
             VLOG(5) << "Converting Conv completed";
         } else if (op == "AveragePool" || op == "MaxPool" ||
                    op == "GlobalAveragePool" || op == "GlobalMaxPool") {
+            if (has_reshape) {
+                throw std::invalid_argument("Reshape before " + op +
+                                            " is not supported");
+            }
             VLOG(5) << "Start converting Pool";
             auto input_name = m(node.input(0));
             auto output_name = m(node.output(0));
@@ -396,6 +410,10 @@ std::vector<std::string> OnnxConverter::Convert(
             layers_.push_back(layer);
             VLOG(5) << "Converting Relu completed";
         } else if (op == "Add") {
+            if (has_reshape) {
+                throw std::invalid_argument("Reshape before " + op +
+                                            " is not supported");
+            }
             VLOG(5) << "Start converting Add";
             auto input1_name = m(node.input(0));
             auto input2_name = m(node.input(1));
@@ -409,6 +427,9 @@ std::vector<std::string> OnnxConverter::Convert(
             layers_.push_back(layer);
             VLOG(5) << "Converting Add completed";
         } else if (op == "Gemm") {
+            if (has_reshape) {
+                has_reshape = false;
+            }
             VLOG(5) << "Start converting Gemm";
             auto transA = helper.get("transA", 0);
             auto transB = helper.get("transB", 0);
@@ -467,6 +488,10 @@ std::vector<std::string> OnnxConverter::Convert(
             layers_.push_back(layer);
             VLOG(5) << "Converting Softmax completed";
         } else if (op == "Concat") {
+            if (has_reshape) {
+                throw std::invalid_argument("Reshape before " + op +
+                                            " is not supported");
+            }
             VLOG(5) << "Start converting Concat";
             vector<string> concat_inputs_str;
             for (const auto &onnx_input : node.input()) {
@@ -486,11 +511,6 @@ std::vector<std::string> OnnxConverter::Convert(
                                           0, 0, 0, 0, 0, 0, param);
             layers_.push_back(layer);
             VLOG(5) << "Converting Concat completed";
-        } else if (op == "Dropout") {
-            VLOG(5) << "Start converting Dropout";
-            // Dropout does nothing, so the output is the same as the input
-            name_map_[node.output(0)] = m(node.input(0));
-            VLOG(5) << "Converting Dropout completed";
         } else if (op == "Reshape") {
             VLOG(5) << "Start converting Reshape";
             has_reshape = true;
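The new error message tells users to "re-export your ONNX model with static input shape". Below is a minimal sketch (not part of the patch) of what that re-export looks like from PyTorch; the stand-in network, input size, and file name are hypothetical placeholders.

import torch
import torch.nn as nn

# Hypothetical stand-in network; substitute your own model.
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
net.eval()

# Exporting with a fixed-size example input and without dynamic_axes
# gives every dimension a concrete dim_value in the ONNX graph, which
# is what the converter's new shape check requires. Per this patch,
# only the batch dimension may stay dynamic, since dabnn forces it to 1.
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(net, dummy_input, "model.onnx")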