diff --git a/dabnn/layers/BinConv.cpp b/dabnn/layers/BinConv.cpp index e769181..12ec514 100644 --- a/dabnn/layers/BinConv.cpp +++ b/dabnn/layers/BinConv.cpp @@ -26,9 +26,9 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight, const auto binaized_name = "binaized_for_" + output + "_cal"; if (mat_map.find(binaized_name) == mat_map.end()) { auto &input_mat = *mat_map[input]; - mat_map[binaized_name] = std::make_shared( - input_mat.h, input_mat.w, input_mat.elem_c, - DataType::Bit, binaized_name); + mat_map[binaized_name] = + std::make_shared(input_mat.h, input_mat.w, input_mat.elem_c, + DataType::Bit, binaized_name); } binarized_mat = mat(binaized_name); @@ -56,15 +56,16 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight, const int m = weight_mat->n; BNN_ASSERT(weight_mat->total() % m == 0, ""); const int k = weight_mat->total() / m; - transposed_weight_mat = - std::make_shared(m, k * 64, DataType::Bit); + transposed_weight_mat = std::make_shared(m, k * 64, DataType::Bit); auto *trans_data_ptr = static_cast(transposed_weight_mat->data); auto *data_ptr = static_cast(weight_mat->data); FORZ(i, k) { - FORZ(j, m) { - BNN_ASSERT(static_cast(i * m + j) < transposed_weight_mat->total(), i * m + j, " ", transposed_weight_mat->total()); - trans_data_ptr[i * m + j] = data_ptr[j * k + i]; + FORZ(j, m) { + BNN_ASSERT(static_cast(i * m + j) < + transposed_weight_mat->total(), + i * m + j, " ", transposed_weight_mat->total()); + trans_data_ptr[i * m + j] = data_ptr[j * k + i]; } } net_.lock()->add_mat(trans_weight_mat_name, transposed_weight_mat); @@ -115,7 +116,9 @@ void BinConv::forward_impl() const { bconv_3x3(*padded_mat, *weight_mat, *output_mat, stride_h); } else if (gemm_compatible()) { output_mat->fill(0.f); - bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1, 1, *col_mat); + bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w, + pad_h, pad_w, stride_h, stride_w, 1, 1, + *col_mat); const int m = weight_mat->n; const int n = output_mat->h * output_mat->w; const int k = weight_mat->h * weight_mat->w * weight_mat->c; @@ -130,9 +133,9 @@ void BinConv::forward_impl() const { } } else { pack_mat(*input_mat, *binarized_mat); - baseline_bconv(*binarized_mat, *weight_mat, weight_mat->h, weight_mat->w, - pad_h, pad_w, stride_h, stride_w, 1, 1, output_mat->c, - *output_mat); + baseline_bconv(*binarized_mat, *weight_mat, weight_mat->h, + weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1, 1, + output_mat->c, *output_mat); } }