Commit

Fix wrong method in some cases
daquexian committed Aug 22, 2019
1 parent 3d58980 commit fcddeb4
Showing 2 changed files with 70 additions and 28 deletions.
91 changes: 63 additions & 28 deletions dabnn/layers/BinConv.cpp
@@ -25,14 +25,16 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
       stride_h(stride_h),
       stride_w(stride_w) {
     auto &mat_map = net.lock()->mat_map_;
-    const auto binaized_name = "binaized_for_" + output + "_cal";
-    if (mat_map.find(binaized_name) == mat_map.end()) {
-        auto &input_mat = *mat_map[input];
-        mat_map[binaized_name] =
-            std::make_shared<Mat>(input_mat.h, input_mat.w, input_mat.elem_c,
-                                  DataType::Bit, binaized_name);
+    if (method() == Method::DIRECT_CONV || method() == Method::BCONV_NAIVE) {
+        const auto binaized_name = "binaized_for_" + output + "_cal";
+        if (mat_map.find(binaized_name) == mat_map.end()) {
+            auto &input_mat = *mat_map[input];
+            mat_map[binaized_name] = std::make_shared<Mat>(
+                input_mat.h, input_mat.w, input_mat.elem_c, DataType::Bit,
+                binaized_name);
+        }
+        binarized_mat = mat(binaized_name);
     }
-    binarized_mat = mat(binaized_name);
 
     const auto pad_name = "pad_for_" + output + "_cal";
     if (mat_map.find(pad_name) == mat_map.end()) {
@@ -43,18 +45,17 @@
     }
     padded_mat = mat(pad_name);
 
-    const auto col_mat_name = "col_for_" + output + "_cal";
-    if (mat_map.find(col_mat_name) == mat_map.end()) {
-        const auto len =
-            output_mat->h * output_mat->w *
-            align_to(weight_mat->h * weight_mat->w * input_mat->elem_c, 128);
-        mat_map[col_mat_name] =
-            std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
-    }
-    col_mat = mat(col_mat_name);
-
-    if (net.lock()->optimize && !direct_conv_compatible() &&
-        gemm_compatible()) {
+    if (method() == Method::BGEMM || method() == Method::BGEMM_NAIVE) {
+        const auto col_mat_name = "col_for_" + output + "_cal";
+        if (mat_map.find(col_mat_name) == mat_map.end()) {
+            const auto len =
+                output_mat->h * output_mat->w *
+                align_to(weight_mat->h * weight_mat->w * input_mat->elem_c,
+                         128);
+            mat_map[col_mat_name] =
+                std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
+        }
+        col_mat = mat(col_mat_name);
         const auto trans_weight_mat_name = "trans_" + weight;
         // transpose the weight for bgemm
         const int m = weight_mat->n;
@@ -76,6 +77,24 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
     }
 }
 
+BinConv::Method BinConv::method() const {
+    if (net_.lock()->optimize) {
+        if (direct_conv_compatible()) {
+            return Method::DIRECT_CONV;
+        } else if (gemm_compatible()) {
+            return Method::BGEMM;
+        } else {
+            return Method::BCONV_NAIVE;
+        }
+    } else {
+        if (weight_mat->c == 1) {
+            return Method::BCONV_NAIVE;
+        } else {
+            return Method::BGEMM_NAIVE;
+        }
+    }
+}
+
 bool BinConv::direct_conv_compatible() const {
 #ifdef __aarch64__
     if (weight_mat->h == 3 && weight_mat->w == 3 && input_mat->elem_c == 64 &&
@@ -121,12 +140,14 @@ bool BinConv::gemm_compatible() const {
 }
 
 void BinConv::forward_impl() const {
-    if (net_.lock()->optimize) {
-        if (direct_conv_compatible()) {
+    switch (method()) {
+        case Method::DIRECT_CONV: {
             pack_mat(*input_mat, *binarized_mat);
             pad(*binarized_mat, pad_h, pad_w, *padded_mat);
             bconv_3x3(*padded_mat, *weight_mat, *output_mat, stride_h);
-        } else if (gemm_compatible()) {
+            break;
+        }
+        case Method::BGEMM: {
             output_mat->fill<float>(0.f);
 
             bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
@@ -139,17 +160,31 @@
         bgemm(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
               m, static_cast<uint64_t *>(col_mat->data), k,
               static_cast<float *>(output_mat->data), m);
-        } else {
+            break;
+        }
+        case Method::BGEMM_NAIVE: {
+            output_mat->fill<float>(0.f);
+
+            bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
+                                       pad_h, pad_w, stride_h, stride_w, 1, 1,
+                                       *col_mat);
+
+            const int m = weight_mat->n;
+            const int n = output_mat->h * output_mat->w;
+            const int k = weight_mat->total() / weight_mat->n;
+            bgemm_naive(m, n, k,
+                        static_cast<uint64_t *>(transposed_weight_mat->data), m,
+                        static_cast<uint64_t *>(col_mat->data), k,
+                        static_cast<float *>(output_mat->data), m);
+            break;
+        }
+        case Method::BCONV_NAIVE: {
             pack_mat(*input_mat, *binarized_mat);
             baseline_bconv(*binarized_mat, *weight_mat, weight_mat->h,
                            weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1,
                            1, output_mat->c, *output_mat);
+            break;
         }
-    } else {
-        pack_mat(*input_mat, *binarized_mat);
-        baseline_bconv(*binarized_mat, *weight_mat, weight_mat->h,
-                       weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1, 1,
-                       output_mat->c, *output_mat);
     }
 }
 
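Note on the change above: the constructor used to allocate the binarized and im2col scratch buffers unconditionally, while forward_impl() selected an implementation through a separate if/else chain, and the unoptimized path always fell back to baseline_bconv. The commit centralizes the selection in method(): the constructor now allocates only the buffers the selected method needs, forward_impl() switches on the same value, and the unoptimized path uses bgemm_naive whenever weight_mat->c != 1. The sketch below is a minimal standalone illustration of that pattern, not dabnn code; every name in it (ScratchConv, direct_ok, the buffer sizes) is invented for illustration.

// A minimal sketch of "decide once, consult everywhere" dispatch.
// The selector is the single source of truth: the constructor
// allocates exactly the scratch the selected method will use, and
// forward() switches on the same value, so the two cannot disagree.
#include <cstdio>
#include <vector>

class ScratchConv {
   public:
    enum class Method { DIRECT_CONV, BGEMM, BCONV_NAIVE, BGEMM_NAIVE };

    ScratchConv(bool optimize, bool direct_ok, bool gemm_ok)
        : optimize_(optimize), direct_ok_(direct_ok), gemm_ok_(gemm_ok) {
        const Method m = method();
        if (m == Method::DIRECT_CONV || m == Method::BCONV_NAIVE) {
            binarized_.resize(1024);  // packed-bit input (size arbitrary here)
        }
        if (m == Method::BGEMM || m == Method::BGEMM_NAIVE) {
            col_.resize(4096);  // im2col scratch (size arbitrary here)
        }
    }

    Method method() const {
        if (optimize_) {
            if (direct_ok_) return Method::DIRECT_CONV;
            if (gemm_ok_) return Method::BGEMM;
            return Method::BCONV_NAIVE;
        }
        return Method::BGEMM_NAIVE;  // BinConv also inspects weight channels here
    }

    void forward() const {
        switch (method()) {
            case Method::DIRECT_CONV: std::puts("direct conv");  break;
            case Method::BGEMM:       std::puts("bgemm");        break;
            case Method::BGEMM_NAIVE: std::puts("bgemm_naive");  break;
            case Method::BCONV_NAIVE: std::puts("bconv_naive");  break;
        }
    }

   private:
    bool optimize_, direct_ok_, gemm_ok_;
    std::vector<unsigned char> binarized_, col_;
};

With one selector, adding another method means touching method() and the switch sites that consult it, rather than keeping several hand-written conditions in sync.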
7 changes: 7 additions & 0 deletions dabnn/layers/BinConv.h
@@ -26,8 +26,15 @@ class BinConv : public Layer {
     virtual std::string to_str() const;
 
    private:
+    enum Method {
+        DIRECT_CONV = 0,
+        BGEMM,
+        BCONV_NAIVE,
+        BGEMM_NAIVE
+    };
     bool direct_conv_compatible() const;
     bool gemm_compatible() const;
+    Method method() const;
 };
 }  // namespace bnn
 
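Background for the BGEMM paths in the diff above: fused_binarize_im2col packs input sign bits into 64-bit words laid out im2col-style in col_mat, so the convolution becomes a matrix multiply whose inner product reduces to XOR plus popcount. Below is a sketch of what a naive kernel of this shape can look like under the usual +1/-1 sign-bit convention; it mirrors the (m, n, k, lda, ldb, ldc) argument layout of the bgemm_naive call above, but it is an assumption written from scratch, and dabnn's own kernels may use a different accumulation convention.

#include <cstdint>

// Illustrative naive binary GEMM (not dabnn's implementation).
// a: k x m packed words with leading dimension lda, i.e. the transposed
//    weight, so the m dimension is contiguous;
// b: k x n packed words, one column per output position, leading dim ldb;
// c: m x n float output, column-major, leading dimension ldc.
// __builtin_popcountll is the GCC/Clang builtin bit counter.
void binary_gemm_naive(int m, int n, int k, const uint64_t *a, int lda,
                       const uint64_t *b, int ldb, float *c, int ldc) {
    for (int j = 0; j < n; j++) {
        for (int i = 0; i < m; i++) {
            int distance = 0;  // Hamming distance over the 64*k packed bits
            for (int p = 0; p < k; p++) {
                distance +=
                    __builtin_popcountll(a[p * lda + i] ^ b[j * ldb + p]);
            }
            // For +1/-1 vectors stored as sign bits, the dot product is
            // (total bits) - 2 * (Hamming distance).
            c[j * ldc + i] += 64.0f * k - 2.0f * distance;
        }
    }
}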