Skip to content

Commit

Permalink
Enhance the support for reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
daquexian committed Nov 10, 2019
1 parent 9e375c9 commit c766981
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions tools/onnx2bnn/OnnxConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,12 @@ std::vector<std::string> OnnxConverter::Convert(
for (const auto &node : model_proto_.graph().node()) {
NodeAttrHelper helper(node);
const auto &op = node.op_type();
if (has_reshape && op != "Gemm") {
throw std::invalid_argument(
"Reshape can only be the last layer or precede a gemm layer "
"for now");
}
has_reshape = false;
VLOG(5) << "Node " << node.name();
if (op == "Conv") {
if (has_reshape) {
throw std::invalid_argument("Reshape before " + op +
" is not supported");
}
VLOG(5) << "Start converting Conv";
auto strides = helper.get("strides", vector<int>{1, 1});
auto pads = helper.get("pads", vector<int>{0, 0, 0, 0});
Expand Down Expand Up @@ -319,6 +317,10 @@ std::vector<std::string> OnnxConverter::Convert(
VLOG(5) << "Converting Conv completed";
} else if (op == "AveragePool" || op == "MaxPool" ||
op == "GlobalAveragePool" || op == "GlobalMaxPool") {
if (has_reshape) {
throw std::invalid_argument("Reshape before " + op +
" is not supported");
}
VLOG(5) << "Start converting Pool";
auto input_name = m(node.input(0));
auto output_name = m(node.output(0));
Expand Down Expand Up @@ -407,6 +409,10 @@ std::vector<std::string> OnnxConverter::Convert(
layers_.push_back(layer);
VLOG(5) << "Converting Relu completed";
} else if (op == "Add") {
if (has_reshape) {
throw std::invalid_argument("Reshape before " + op +
" is not supported");
}
VLOG(5) << "Start converting Add";
auto input1_name = m(node.input(0));
auto input2_name = m(node.input(1));
Expand All @@ -420,6 +426,9 @@ std::vector<std::string> OnnxConverter::Convert(
layers_.push_back(layer);
VLOG(5) << "Converting Add completed";
} else if (op == "Gemm") {
if (has_reshape) {
has_reshape = false;
}
VLOG(5) << "Start converting Gemm";
auto transA = helper.get("transA", 0);
auto transB = helper.get("transB", 0);
Expand Down Expand Up @@ -478,6 +487,10 @@ std::vector<std::string> OnnxConverter::Convert(
layers_.push_back(layer);
VLOG(5) << "Converting Softmax completed";
} else if (op == "Concat") {
if (has_reshape) {
throw std::invalid_argument("Reshape before " + op +
" is not supported");
}
VLOG(5) << "Start converting Concat";
vector<std::string> concat_inputs_str;
for (const auto &onnx_input : node.input()) {
Expand Down

0 comments on commit c766981

Please sign in to comment.