diff --git a/dabnn/layers/BinConv.cpp b/dabnn/layers/BinConv.cpp index ccf7ff1..a818903 100644 --- a/dabnn/layers/BinConv.cpp +++ b/dabnn/layers/BinConv.cpp @@ -83,11 +83,13 @@ BinConv::Method BinConv::method() const { return Method::DIRECT_CONV; } else if (gemm_compatible()) { return Method::BGEMM; - } else { + } else if (input_mat->elem_c == 64) { return Method::BCONV_NAIVE; + } else { + return Method::BGEMM_NAIVE; } } else { - if (weight_mat->c == 1) { + if (input_mat->elem_c == 64) { return Method::BCONV_NAIVE; } else { return Method::BGEMM_NAIVE;