use bgemm_naive instead of bconv_naive when non-optimize
daquexian committed Aug 22, 2019
1 parent 418c8d7 commit cec3499
Showing 1 changed file with 27 additions and 33 deletions.
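In effect: before this commit, running without the optimize flag dispatched to baseline_bconv (the "bconv_naive" of the commit title), a naive direct binary convolution. After it, the non-optimized path reuses the im2col + GEMM route and falls back to bgemm_naive for the matrix multiply, so only the optimized direct 3x3 kernel (bconv_3x3) and the optimized bgemm remain gated on the optimize flag.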
60 changes: 27 additions & 33 deletions dabnn/layers/BinConv.cpp
@@ -29,9 +29,9 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
         const auto binaized_name = "binaized_for_" + output + "_cal";
         if (mat_map.find(binaized_name) == mat_map.end()) {
             auto &input_mat = *mat_map[input];
-            mat_map[binaized_name] =
-                std::make_shared<Mat>(input_mat.h, input_mat.w, input_mat.elem_c,
-                                      DataType::Bit, binaized_name);
+            mat_map[binaized_name] = std::make_shared<Mat>(
+                input_mat.h, input_mat.w, input_mat.elem_c, DataType::Bit,
+                binaized_name);
         }
         binarized_mat = mat(binaized_name);
     }
@@ -46,12 +46,12 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
     padded_mat = mat(pad_name);
 
     if (net.lock()->optimize && !direct_conv_compatible()) {
-
         const auto col_mat_name = "col_for_" + output + "_cal";
         if (mat_map.find(col_mat_name) == mat_map.end()) {
             const auto len =
                 output_mat->h * output_mat->w *
-                align_to(weight_mat->h * weight_mat->w * input_mat->elem_c, 128);
+                align_to(weight_mat->h * weight_mat->w * input_mat->elem_c,
+                         128);
             mat_map[col_mat_name] =
                 std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
         }
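The len computation above sizes the bit-packed im2col buffer: one aligned patch per output pixel. A minimal sketch of that arithmetic, assuming align_to rounds its first argument up to the next multiple of the second (the 128 keeps each packed patch on a 128-bit boundary); the helper below is an illustration, not dabnn's actual implementation:

    // Round x up to the next multiple of n (assumed behavior of align_to).
    constexpr int align_to(int x, int n) { return (x + n - 1) / n * n; }

    // Hypothetical example: a 3x3 kernel over 256 input channels with a
    // 14x14 output. Each output pixel needs 3 * 3 * 256 = 2304 bits, which
    // is already a multiple of 128, so the buffer holds 14 * 14 * 2304 bits.
    const int len = 14 * 14 * align_to(3 * 3 * 256, 128);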
@@ -123,36 +123,30 @@ bool BinConv::gemm_compatible() const {
 }
 
 void BinConv::forward_impl() const {
-    if (net_.lock()->optimize) {
-        if (direct_conv_compatible()) {
-            pack_mat(*input_mat, *binarized_mat);
-            pad(*binarized_mat, pad_h, pad_w, *padded_mat);
-            bconv_3x3(*padded_mat, *weight_mat, *output_mat, stride_h);
+    if (net_.lock()->optimize && direct_conv_compatible()) {
+        pack_mat(*input_mat, *binarized_mat);
+        pad(*binarized_mat, pad_h, pad_w, *padded_mat);
+        bconv_3x3(*padded_mat, *weight_mat, *output_mat, stride_h);
+    } else {
+        output_mat->fill<float>(0.f);
+
+        bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
+                                   pad_h, pad_w, stride_h, stride_w, 1, 1,
+                                   *col_mat);
+
+        const int m = weight_mat->n;
+        const int n = output_mat->h * output_mat->w;
+        const int k = weight_mat->total() / weight_mat->n;
+        if (net_.lock()->optimize && gemm_compatible()) {
+            bgemm(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
+                  m, static_cast<uint64_t *>(col_mat->data), k,
+                  static_cast<float *>(output_mat->data), m);
         } else {
-            output_mat->fill<float>(0.f);
-
-            bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
-                                       pad_h, pad_w, stride_h, stride_w, 1, 1,
-                                       *col_mat);
-
-            const int m = weight_mat->n;
-            const int n = output_mat->h * output_mat->w;
-            const int k = weight_mat->total() / weight_mat->n;
-            if (gemm_compatible()) {
-                bgemm(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
-                      m, static_cast<uint64_t *>(col_mat->data), k,
-                      static_cast<float *>(output_mat->data), m);
-            } else {
-                bgemm_naive(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
-                            m, static_cast<uint64_t *>(col_mat->data), k,
-                            static_cast<float *>(output_mat->data), m);
-            }
+            bgemm_naive(m, n, k,
+                        static_cast<uint64_t *>(transposed_weight_mat->data), m,
+                        static_cast<uint64_t *>(col_mat->data), k,
+                        static_cast<float *>(output_mat->data), m);
         }
-    } else {
-        pack_mat(*input_mat, *binarized_mat);
-        baseline_bconv(*binarized_mat, *weight_mat, weight_mat->h,
-                       weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1, 1,
-                       output_mat->c, *output_mat);
     }
 }
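After this hunk, forward_impl has three paths: the optimized direct 3x3 binary convolution (optimize and direct_conv_compatible), the optimized bit-packed GEMM (optimize and gemm_compatible), and bgemm_naive for everything else, all sharing the same fused_binarize_im2col front end. For context, a binary GEMM of this kind conventionally accumulates XOR-popcounts between bit-packed operands; the sketch below illustrates that idea under the layout implied by the call site (column-major, k counted in 64-bit words), and is an assumption, not dabnn's actual bgemm_naive:

    #include <cstdint>

    // Illustrative naive binary GEMM (an assumption, not dabnn's code).
    // a: m x k bit-packed matrix with leading dimension lda, element (i, w)
    //    at a[w * lda + i]; b: k x n with column j contiguous (ldb words);
    //    c: m x n column-major floats with leading dimension ldc.
    // __builtin_popcountll is a GCC/Clang builtin.
    void bgemm_naive_sketch(int m, int n, int k, const uint64_t *a, int lda,
                            const uint64_t *b, int ldb, float *c, int ldc) {
        for (int j = 0; j < n; j++) {
            for (int i = 0; i < m; i++) {
                int count = 0;
                for (int w = 0; w < k; w++) {
                    // popcount(x ^ y) counts the bit positions where the
                    // two packed +1/-1 vectors disagree.
                    count += __builtin_popcountll(a[w * lda + i] ^ b[j * ldb + w]);
                }
                // How the raw mismatch count maps to the +1/-1 dot product
                // (e.g. bits - 2 * count) is a convention this sketch omits.
                c[j * ldc + i] = static_cast<float>(count);
            }
        }
    }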

