use bgemm_naive instead of bconv_naive as fallback due to the 128-align weight
daquexian committed Aug 22, 2019
1 parent 7af8be7 commit be606fa
Showing 1 changed file with 30 additions and 27 deletions.
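
In outline, the dispatch in BinConv::forward_impl() after this commit is roughly the following (condensed from the diff below; argument lists and the enclosing optimize check are elided, and the comments are added here for orientation):

    if (direct_conv_compatible()) {
        // pack and pad the input, then the hand-written 3x3 binary direct conv
        bconv_3x3(*padded_mat, *weight_mat, *output_mat, stride_h);
    } else {
        // fused binarize + im2col into col_mat, then a binary GEMM against the
        // transposed weight, whose k dimension is padded to a multiple of 128 bits
        if (gemm_compatible()) {
            bgemm(m, n, k, /* ... */);        // optimized kernel
        } else {
            bgemm_naive(m, n, k, /* ... */);  // portable fallback, replacing baseline_bconv
        }
    }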
dabnn/layers/BinConv.cpp: 57 changes (30 additions, 27 deletions)
@@ -25,14 +25,16 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
       stride_h(stride_h),
       stride_w(stride_w) {
     auto &mat_map = net.lock()->mat_map_;
-    const auto binaized_name = "binaized_for_" + output + "_cal";
-    if (mat_map.find(binaized_name) == mat_map.end()) {
-        auto &input_mat = *mat_map[input];
-        mat_map[binaized_name] =
-            std::make_shared<Mat>(input_mat.h, input_mat.w, input_mat.elem_c,
-                                  DataType::Bit, binaized_name);
+    if (direct_conv_compatible()) {
+        const auto binaized_name = "binaized_for_" + output + "_cal";
+        if (mat_map.find(binaized_name) == mat_map.end()) {
+            auto &input_mat = *mat_map[input];
+            mat_map[binaized_name] =
+                std::make_shared<Mat>(input_mat.h, input_mat.w, input_mat.elem_c,
+                                      DataType::Bit, binaized_name);
+        }
+        binarized_mat = mat(binaized_name);
     }
-    binarized_mat = mat(binaized_name);
 
     const auto pad_name = "pad_for_" + output + "_cal";
     if (mat_map.find(pad_name) == mat_map.end()) {
@@ -43,18 +45,18 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
     }
     padded_mat = mat(pad_name);
 
-    const auto col_mat_name = "col_for_" + output + "_cal";
-    if (mat_map.find(col_mat_name) == mat_map.end()) {
-        const auto len =
-            output_mat->h * output_mat->w *
-            align_to(weight_mat->h * weight_mat->w * input_mat->elem_c, 128);
-        mat_map[col_mat_name] =
-            std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
-    }
-    col_mat = mat(col_mat_name);
+    if (net.lock()->optimize && !direct_conv_compatible()) {
+
+        const auto col_mat_name = "col_for_" + output + "_cal";
+        if (mat_map.find(col_mat_name) == mat_map.end()) {
+            const auto len =
+                output_mat->h * output_mat->w *
+                align_to(weight_mat->h * weight_mat->w * input_mat->elem_c, 128);
+            mat_map[col_mat_name] =
+                std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
+        }
+        col_mat = mat(col_mat_name);
 
-    if (net.lock()->optimize && !direct_conv_compatible() &&
-        gemm_compatible()) {
         const auto trans_weight_mat_name = "trans_" + weight;
         // transpose the weight for bgemm
         const int m = weight_mat->n;
@@ -126,7 +128,7 @@ void BinConv::forward_impl() const {
             pack_mat(*input_mat, *binarized_mat);
             pad(*binarized_mat, pad_h, pad_w, *padded_mat);
             bconv_3x3(*padded_mat, *weight_mat, *output_mat, stride_h);
-        } else if (gemm_compatible()) {
+        } else {
             output_mat->fill<float>(0.f);
 
             bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
@@ -136,14 +138,15 @@ void BinConv::forward_impl() const {
             const int m = weight_mat->n;
             const int n = output_mat->h * output_mat->w;
             const int k = weight_mat->total() / weight_mat->n;
-            bgemm(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
-                  m, static_cast<uint64_t *>(col_mat->data), k,
-                  static_cast<float *>(output_mat->data), m);
-        } else {
-            pack_mat(*input_mat, *binarized_mat);
-            baseline_bconv(*binarized_mat, *weight_mat, weight_mat->h,
-                           weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1,
-                           1, output_mat->c, *output_mat);
+            if (gemm_compatible()) {
+                bgemm(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
+                      m, static_cast<uint64_t *>(col_mat->data), k,
+                      static_cast<float *>(output_mat->data), m);
+            } else {
+                bgemm_naive(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
+                            m, static_cast<uint64_t *>(col_mat->data), k,
+                            static_cast<float *>(output_mat->data), m);
+            }
         }
     } else {
         pack_mat(*input_mat, *binarized_mat);
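
For reference, a naive bit-packed GEMM with the call shape used above, bgemm_naive(m, n, k, a, lda, b, ldb, c, ldc), can be sketched as below. This is an illustrative sketch only, not the dabnn kernel: the column-major layouts implied by lda = m, ldb = k, ldc = m, the reading of k as a count of 64-bit words, and the popcount accumulation convention are assumptions, and the names bgemm_naive_sketch and popcount64 are invented for the example.

    #include <cstdint>

    // Portable popcount for one 64-bit word of packed binarized values.
    inline int popcount64(uint64_t x) {
    #if defined(__GNUC__) || defined(__clang__)
        return __builtin_popcountll(x);
    #else
        int n = 0;
        while (x) { x &= x - 1; ++n; }
        return n;
    #endif
    }

    // C (m x n, float, ldc = m) accumulates binary dot products of the bit-packed
    // operands: A is m x k words (column-major, lda = m), B is k x n words
    // (column-major, ldb = k); each uint64_t word carries 64 binarized channels.
    void bgemm_naive_sketch(int m, int n, int k, const uint64_t *a, int lda,
                            const uint64_t *b, int ldb, float *c, int ldc) {
        for (int j = 0; j < n; ++j) {
            for (int i = 0; i < m; ++i) {
                int mismatches = 0;
                for (int p = 0; p < k; ++p) {
                    // XOR marks bit positions where the two +/-1 values differ.
                    mismatches += popcount64(a[p * lda + i] ^ b[j * ldb + p]);
                }
                // Assumed convention: dot product = (total bits) - 2 * mismatches.
                c[j * ldc + i] += static_cast<float>(64 * k - 2 * mismatches);
            }
        }
    }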
