Commit fdfe987: format

daquexian committed Aug 20, 2019
1 parent 976a71c commit fdfe987
Showing 1 changed file with 15 additions and 12 deletions.
dabnn/layers/BinConv.cpp (27 changes: 15 additions & 12 deletions)
```diff
@@ -26,9 +26,9 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
     const auto binaized_name = "binaized_for_" + output + "_cal";
     if (mat_map.find(binaized_name) == mat_map.end()) {
         auto &input_mat = *mat_map[input];
-        mat_map[binaized_name] = std::make_shared<Mat>(
-            input_mat.h, input_mat.w, input_mat.elem_c,
-            DataType::Bit, binaized_name);
+        mat_map[binaized_name] =
+            std::make_shared<Mat>(input_mat.h, input_mat.w, input_mat.elem_c,
+                                  DataType::Bit, binaized_name);
     }
     binarized_mat = mat(binaized_name);
```
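
A note on what this hunk allocates: a `DataType::Bit` mat stores one bit per element, so 64 binarized input values fit in a single `uint64_t` word. A minimal standalone sketch of that sign-packing convention (`pack_64` is a hypothetical helper for illustration, not dabnn's `pack_mat`):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch: pack the signs of floats into uint64_t words,
// one bit per element, mirroring the DataType::Bit convention.
// This is NOT dabnn's pack_mat, just an illustration of the idea.
std::vector<uint64_t> pack_64(const std::vector<float> &v) {
    std::vector<uint64_t> out((v.size() + 63) / 64, 0);
    for (std::size_t i = 0; i < v.size(); i++) {
        if (v[i] < 0) {                        // negative -> set bit (one of
            out[i / 64] |= 1ULL << (i % 64);   // two possible conventions)
        }
    }
    return out;
}
```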

```diff
@@ -56,15 +56,16 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
         const int m = weight_mat->n;
         BNN_ASSERT(weight_mat->total() % m == 0, "");
         const int k = weight_mat->total() / m;
-        transposed_weight_mat =
-            std::make_shared<Mat>(m, k * 64, DataType::Bit);
+        transposed_weight_mat = std::make_shared<Mat>(m, k * 64, DataType::Bit);
         auto *trans_data_ptr =
             static_cast<uint64_t *>(transposed_weight_mat->data);
         auto *data_ptr = static_cast<uint64_t *>(weight_mat->data);
         FORZ(i, k) {
-            FORZ(j, m) {
-                BNN_ASSERT(static_cast<size_t>(i * m + j) < transposed_weight_mat->total(), i * m + j, " ", transposed_weight_mat->total());
-                trans_data_ptr[i * m + j] = data_ptr[j * k + i];
+            FORZ(j, m) {
+                BNN_ASSERT(static_cast<size_t>(i * m + j) <
+                               transposed_weight_mat->total(),
+                           i * m + j, " ", transposed_weight_mat->total());
+                trans_data_ptr[i * m + j] = data_ptr[j * k + i];
             }
         }
         net_.lock()->add_mat(trans_weight_mat_name, transposed_weight_mat);
```
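
Context for the loop being re-wrapped here: the packed weight buffer holds m filters of k `uint64_t` words each (row-major, index `j * k + i`), and the loop rewrites it to k-major order (index `i * m + j`) so the m weight words consumed together sit contiguously. A minimal sketch of the same index swap (`transpose_packed` is a hypothetical name, assuming the layout just described):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch of the transpose in this hunk: rewrite m rows of
// k packed words (index j * k + i) into k-major order (index i * m + j),
// the same swap as trans_data_ptr / data_ptr above.
std::vector<uint64_t> transpose_packed(const std::vector<uint64_t> &w,
                                       int m, int k) {
    std::vector<uint64_t> t(w.size());
    for (int i = 0; i < k; i++) {
        for (int j = 0; j < m; j++) {
            t[i * m + j] = w[j * k + i];
        }
    }
    return t;
}
```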
Expand Down Expand Up @@ -115,7 +116,9 @@ void BinConv::forward_impl() const {
bconv_3x3(*padded_mat, *weight_mat, *output_mat, stride_h);
} else if (gemm_compatible()) {
output_mat->fill<float>(0.f);
bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1, 1, *col_mat);
bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
pad_h, pad_w, stride_h, stride_w, 1, 1,
*col_mat);
const int m = weight_mat->n;
const int n = output_mat->h * output_mat->w;
const int k = weight_mat->h * weight_mat->w * weight_mat->c;
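
The `m`/`n`/`k` context lines set up the GEMM view of the convolution: `m` output channels, `n = out_h * out_w` output positions, and `k` the per-filter reduction length. In a binary GEMM each such dot product reduces to XNOR plus popcount over packed words; a minimal sketch of that inner kernel, assuming a ±1 encoding with one bit per element (`xnor_dot` is a hypothetical name, not dabnn's GEMM routine):

```cpp
#include <cstdint>

// Hypothetical sketch: dot product of two ±1 vectors packed 64-per-word.
// Matching bits (XNOR == 1) contribute +1, mismatches -1, so the result
// is total_bits - 2 * popcount(a XOR b). __builtin_popcountll is the
// GCC/Clang intrinsic.
int xnor_dot(const uint64_t *a, const uint64_t *b, int words) {
    int mismatches = 0;
    for (int i = 0; i < words; i++) {
        mismatches += __builtin_popcountll(a[i] ^ b[i]);
    }
    return words * 64 - 2 * mismatches;
}
```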
```diff
@@ -130,9 +133,9 @@ void BinConv::forward_impl() const {
         }
     } else {
         pack_mat(*input_mat, *binarized_mat);
-        baseline_bconv(*binarized_mat, *weight_mat, weight_mat->h, weight_mat->w,
-                       pad_h, pad_w, stride_h, stride_w, 1, 1, output_mat->c,
-                       *output_mat);
+        baseline_bconv(*binarized_mat, *weight_mat, weight_mat->h,
+                       weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1, 1,
+                       output_mat->c, *output_mat);
     }
 }
```
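
Stepping back, the two forward_impl hunks show the layer's three-way dispatch: a hand-tuned 3x3 kernel, a binarize + im2col + binary-GEMM path, and a reference fallback. A schematic restatement under assumed predicate names (the real guard for the first branch is not visible in this diff):

```cpp
// Schematic of the dispatch order in BinConv::forward_impl, paraphrased
// from the hunks above; the predicate names here are assumptions.
enum class ConvPath { Direct3x3, BinaryGemm, Baseline };

ConvPath choose_path(bool direct_3x3_compatible, bool gemm_compatible) {
    if (direct_3x3_compatible) return ConvPath::Direct3x3;  // bconv_3x3
    if (gemm_compatible) return ConvPath::BinaryGemm;       // binary GEMM on col_mat
    return ConvPath::Baseline;                               // baseline_bconv
}
```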
