Skip to content

Commit

Permalink
Revert "use bgemm_naive instead of bconv_naive as fallback due to the…
Browse files Browse the repository at this point in the history
… 128-align weight"

This reverts commit be606fa.
  • Loading branch information
daquexian committed Aug 22, 2019
1 parent cec3499 commit 337f55f
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions dabnn/layers/BinConv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
stride_h(stride_h),
stride_w(stride_w) {
auto &mat_map = net.lock()->mat_map_;
<<<<<<< HEAD
if (direct_conv_compatible()) {
const auto binaized_name = "binaized_for_" + output + "_cal";
if (mat_map.find(binaized_name) == mat_map.end()) {
Expand All @@ -34,7 +35,16 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
binaized_name);
}
binarized_mat = mat(binaized_name);
=======
const auto binaized_name = "binaized_for_" + output + "_cal";
if (mat_map.find(binaized_name) == mat_map.end()) {
auto &input_mat = *mat_map[input];
mat_map[binaized_name] =
std::make_shared<Mat>(input_mat.h, input_mat.w, input_mat.elem_c,
DataType::Bit, binaized_name);
>>>>>>> parent of be606fa... use bgemm_naive instead of bconv_naive as fallback due to the 128-align weight
}
binarized_mat = mat(binaized_name);

const auto pad_name = "pad_for_" + output + "_cal";
if (mat_map.find(pad_name) == mat_map.end()) {
Expand All @@ -45,6 +55,7 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
}
padded_mat = mat(pad_name);

<<<<<<< HEAD
if (net.lock()->optimize && !direct_conv_compatible()) {
const auto col_mat_name = "col_for_" + output + "_cal";
if (mat_map.find(col_mat_name) == mat_map.end()) {
Expand All @@ -56,7 +67,20 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
}
col_mat = mat(col_mat_name);
=======
const auto col_mat_name = "col_for_" + output + "_cal";
if (mat_map.find(col_mat_name) == mat_map.end()) {
const auto len =
output_mat->h * output_mat->w *
align_to(weight_mat->h * weight_mat->w * input_mat->elem_c, 128);
mat_map[col_mat_name] =
std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
}
col_mat = mat(col_mat_name);
>>>>>>> parent of be606fa... use bgemm_naive instead of bconv_naive as fallback due to the 128-align weight

if (net.lock()->optimize && !direct_conv_compatible() &&
gemm_compatible()) {
const auto trans_weight_mat_name = "trans_" + weight;
// transpose the weight for bgemm
const int m = weight_mat->n;
Expand Down Expand Up @@ -123,6 +147,7 @@ bool BinConv::gemm_compatible() const {
}

void BinConv::forward_impl() const {
<<<<<<< HEAD
if (net_.lock()->optimize && direct_conv_compatible()) {
pack_mat(*input_mat, *binarized_mat);
pad(*binarized_mat, pad_h, pad_w, *padded_mat);
Expand All @@ -138,14 +163,38 @@ void BinConv::forward_impl() const {
const int n = output_mat->h * output_mat->w;
const int k = weight_mat->total() / weight_mat->n;
if (net_.lock()->optimize && gemm_compatible()) {
=======
if (net_.lock()->optimize) {
if (direct_conv_compatible()) {
pack_mat(*input_mat, *binarized_mat);
pad(*binarized_mat, pad_h, pad_w, *padded_mat);
bconv_3x3(*padded_mat, *weight_mat, *output_mat, stride_h);
} else if (gemm_compatible()) {
output_mat->fill<float>(0.f);

bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
pad_h, pad_w, stride_h, stride_w, 1, 1,
*col_mat);

const int m = weight_mat->n;
const int n = output_mat->h * output_mat->w;
const int k = weight_mat->total() / weight_mat->n;
>>>>>>> parent of be606fa... use bgemm_naive instead of bconv_naive as fallback due to the 128-align weight
bgemm(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
m, static_cast<uint64_t *>(col_mat->data), k,
static_cast<float *>(output_mat->data), m);
} else {
<<<<<<< HEAD
bgemm_naive(m, n, k,
static_cast<uint64_t *>(transposed_weight_mat->data), m,
static_cast<uint64_t *>(col_mat->data), k,
static_cast<float *>(output_mat->data), m);
=======
pack_mat(*input_mat, *binarized_mat);
baseline_bconv(*binarized_mat, *weight_mat, weight_mat->h,
weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1,
1, output_mat->c, *output_mat);
>>>>>>> parent of be606fa... use bgemm_naive instead of bconv_naive as fallback due to the 128-align weight
}
}
}
Expand Down

0 comments on commit 337f55f

Please sign in to comment.