diff --git a/tools/onnx2bnn/OnnxConverter.cpp b/tools/onnx2bnn/OnnxConverter.cpp index 72b48aa..842be24 100644 --- a/tools/onnx2bnn/OnnxConverter.cpp +++ b/tools/onnx2bnn/OnnxConverter.cpp @@ -289,6 +289,17 @@ std::vector OnnxConverter::Convert( expected_binary_conv_outputs.end()); if (binary_conv) { binary_conv_outputs.push_back(node.output(0)); + bool precede_bn = false; + for (const auto &node2 : model_proto_.graph().node()) { + if (node2.op_type() == "BatchNormalization" && + node2.input(0) == node.output(0)) { + precede_bn = true; + break; + } + } + if (!precede_bn) { + throw std::invalid_argument("Binary convolutions should precede BatchNorm"); + } } AddConv(m(node.input(0)), strides, pads, dilations, group, ori_weight_name, bias_name, m(node.output(0)), binary_conv); @@ -556,6 +567,13 @@ void OnnxConverter::CalculateCoeff(const ONNX_NAMESPACE::NodeProto &node, height * coeff_a_data[i]; } + if (node2.input_size() == 2) { + const auto &bias = onnx_float_tensors_[node2.input(2)]; + + FORZ(i, coeff_b_data.size()) { + coeff_b_data[i] += coeff_a_data[i] * bias.data[i]; + } + } } { FORZ(i, coeff_a_data.size()) { coeff_a_data[i] *= -2; }