Fix non 128 bug #51

Merged · 14 commits · Aug 22, 2019
6 changes: 6 additions & 0 deletions .daq_pm/configs/run_net_x86
@@ -0,0 +1,6 @@
# It is a configuration file for [project_manager.vim](https://github.com/daquexian/project_manager.vim)
name binary-nn
type cpp
build_dir build_main_x86
cmake_options -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_BUILD_TYPE=Debug -DBNN_BUILD_TEST=OFF -DBNN_BUILD_BENCHMARK=OFF -DBNN_BUILD_MAIN_LIB=ON
target run
5 changes: 5 additions & 0 deletions dabnn/CMakeLists.txt
@@ -41,6 +41,11 @@ target_include_directories(dabnn
${CMAKE_CURRENT_BINARY_DIR}
${PROJECT_SOURCE_DIR}
)
target_include_directories(dabnn
SYSTEM
PUBLIC
${PROJECT_SOURCE_DIR}/third_party/eigen
)
target_link_libraries(dabnn
glog::glog
flatbuffers
91 changes: 63 additions & 28 deletions dabnn/layers/BinConv.cpp
@@ -25,14 +25,16 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
stride_h(stride_h),
stride_w(stride_w) {
auto &mat_map = net.lock()->mat_map_;
const auto binaized_name = "binaized_for_" + output + "_cal";
if (mat_map.find(binaized_name) == mat_map.end()) {
auto &input_mat = *mat_map[input];
mat_map[binaized_name] =
std::make_shared<Mat>(input_mat.h, input_mat.w, input_mat.elem_c,
DataType::Bit, binaized_name);
if (method() == Method::DIRECT_CONV || method() == Method::BCONV_NAIVE) {
const auto binaized_name = "binaized_for_" + output + "_cal";
if (mat_map.find(binaized_name) == mat_map.end()) {
auto &input_mat = *mat_map[input];
mat_map[binaized_name] = std::make_shared<Mat>(
input_mat.h, input_mat.w, input_mat.elem_c, DataType::Bit,
binaized_name);
}
binarized_mat = mat(binaized_name);
}
binarized_mat = mat(binaized_name);

const auto pad_name = "pad_for_" + output + "_cal";
if (mat_map.find(pad_name) == mat_map.end()) {
@@ -43,18 +45,17 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
}
padded_mat = mat(pad_name);

const auto col_mat_name = "col_for_" + output + "_cal";
if (mat_map.find(col_mat_name) == mat_map.end()) {
const auto len =
output_mat->h * output_mat->w *
align_to(weight_mat->h * weight_mat->w * input_mat->elem_c, 128);
mat_map[col_mat_name] =
std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
}
col_mat = mat(col_mat_name);

if (net.lock()->optimize && !direct_conv_compatible() &&
gemm_compatible()) {
if (method() == Method::BGEMM || method() == Method::BGEMM_NAIVE) {
const auto col_mat_name = "col_for_" + output + "_cal";
if (mat_map.find(col_mat_name) == mat_map.end()) {
const auto len =
output_mat->h * output_mat->w *
align_to(weight_mat->h * weight_mat->w * input_mat->elem_c,
128);
mat_map[col_mat_name] =
std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
}
col_mat = mat(col_mat_name);
const auto trans_weight_mat_name = "trans_" + weight;
// transpose the weight for bgemm
const int m = weight_mat->n;
@@ -76,6 +77,24 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
}
}

BinConv::Method BinConv::method() const {
if (net_.lock()->optimize) {
if (direct_conv_compatible()) {
return Method::DIRECT_CONV;
} else if (gemm_compatible()) {
return Method::BGEMM;
} else {
return Method::BCONV_NAIVE;
}
} else {
if (weight_mat->c == 1) {
return Method::BCONV_NAIVE;
} else {
return Method::BGEMM_NAIVE;
}
}
}

bool BinConv::direct_conv_compatible() const {
#ifdef __aarch64__
if (weight_mat->h == 3 && weight_mat->w == 3 && input_mat->elem_c == 64 &&
@@ -121,12 +140,14 @@ bool BinConv::gemm_compatible() const {
}

void BinConv::forward_impl() const {
if (net_.lock()->optimize) {
if (direct_conv_compatible()) {
switch (method()) {
case Method::DIRECT_CONV: {
pack_mat(*input_mat, *binarized_mat);
pad(*binarized_mat, pad_h, pad_w, *padded_mat);
bconv_3x3(*padded_mat, *weight_mat, *output_mat, stride_h);
} else if (gemm_compatible()) {
break;
}
case Method::BGEMM: {
output_mat->fill<float>(0.f);

bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
@@ -139,17 +160,31 @@ void BinConv::forward_impl() const {
bgemm(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
m, static_cast<uint64_t *>(col_mat->data), k,
static_cast<float *>(output_mat->data), m);
} else {
break;
}
case Method::BGEMM_NAIVE: {
output_mat->fill<float>(0.f);

bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
pad_h, pad_w, stride_h, stride_w, 1, 1,
*col_mat);

const int m = weight_mat->n;
const int n = output_mat->h * output_mat->w;
const int k = weight_mat->total() / weight_mat->n;
bgemm_naive(m, n, k,
static_cast<uint64_t *>(transposed_weight_mat->data), m,
static_cast<uint64_t *>(col_mat->data), k,
static_cast<float *>(output_mat->data), m);
break;
}
case Method::BCONV_NAIVE: {
pack_mat(*input_mat, *binarized_mat);
baseline_bconv(*binarized_mat, *weight_mat, weight_mat->h,
weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1,
1, output_mat->c, *output_mat);
break;
}
} else {
pack_mat(*input_mat, *binarized_mat);
baseline_bconv(*binarized_mat, *weight_mat, weight_mat->h,
weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1, 1,
output_mat->c, *output_mat);
}
}

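The align_to(..., 128) in the col-buffer sizing above is what the PR title appears to refer to: kernel volumes whose bit counts are not a multiple of 128 must still be padded up to whole 128-bit groups before bit-packing. A minimal sketch of such a round-up helper, with assumed (not verified) semantics matching dabnn's align_to:

// Sketch only: assumes align_to rounds value up to the nearest multiple
// of alignment, which is how the col-buffer length above is computed.
inline int align_to(int value, int alignment) {
    return (value + alignment - 1) / alignment * alignment;
}

// Example: a 3x3 kernel over 192 input channels occupies
// align_to(3 * 3 * 192, 128) == align_to(1728, 128) == 1792 bits
// per output position, i.e. 28 uint64 words instead of exactly 27.

With the new method() dispatch, this padded buffer is only allocated for the BGEMM paths that actually consume it, rather than unconditionally in the constructor.
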
7 changes: 7 additions & 0 deletions dabnn/layers/BinConv.h
@@ -26,8 +26,15 @@ class BinConv : public Layer {
virtual std::string to_str() const;

private:
enum Method {
DIRECT_CONV = 0,
BGEMM,
BCONV_NAIVE,
BGEMM_NAIVE
};
bool direct_conv_compatible() const;
bool gemm_compatible() const;
Method method() const;
};
} // namespace bnn

20 changes: 8 additions & 12 deletions dabnn/mat.h
@@ -273,7 +273,6 @@ inline Mat::Mat(int _n, int _w, int _h, int _c, void *_data, DataType data_type,
", ", h, ", ", c);
}
elemsize = data_type == DataType::Float ? sizeof(float) : sizeof(uint64_t);
BNN_ASSERT(c > 0, c);
std::stringstream ss;
ss << "Not align, w: " << w << ", c: " << c << ", elemsize: " << elemsize;
BNN_ASSERT(!require_align || w * c == 1 || w * c * elemsize % 16 == 0,
@@ -283,7 +282,10 @@ inline Mat::Mat(int _n, int _w, int _h, int _c, void *_data, DataType data_type,
} else {
hstep = w * c;
}
BNN_ASSERT(hstep > 0, hstep);
if (data_num == 0) {
BNN_ASSERT(c > 0, c);
BNN_ASSERT(hstep > 0, hstep);
}

external_memory = true;
}
@@ -529,11 +531,6 @@ inline void Mat::create(int _w, int _h, int _c, DataType _data_type) {
h = _h;
c = _c;

if (w * c != 1 && w * c * elemsize % 16 != 0) {
LOG(FATAL) << "Not align, w: " << w << ", c: " << c
<< ", elemsize: " << elemsize;
throw std::invalid_argument("Not align!");
}
hstep = ncnn::alignSize(w * c * elemsize, 16) / elemsize;

if (total() > 0) {
@@ -563,11 +560,6 @@ inline void Mat::create(int _n, int _w, int _h, int _c, DataType _data_type,
if (h != 0) dims++;
if (c != 0) dims++;

if (require_align && w * c != 1 && w * c * elemsize % 16 != 0) {
LOG(FATAL) << "Not align, w: " << w << ", c: " << c
<< ", elemsize: " << elemsize;
throw std::invalid_argument("Not align!");
}
if (require_align) {
hstep = ncnn::alignSize(w * c * elemsize, 16) / elemsize;
} else {
@@ -612,24 +604,28 @@ inline size_t Mat::total() const {

template <typename T>
inline const T *Mat::point(int _n, int _h, int _w) const {
BNN_ASSERT(w * c == 1 || w * c * elemsize % 16 == 0, "");
BNN_ASSERT((_n == 0 && _h == 0 && _w == 0) || hstep > 0, hstep);
return (T *)data + _n * h * hstep + _h * hstep + _w * c;
}

template <typename T>
inline const T *Mat::point(int _h, int _w) const {
BNN_ASSERT(w * c == 1 || w * c * elemsize % 16 == 0, "");
BNN_ASSERT((_h == 0 && _w == 0) || hstep > 0, hstep);
return (T *)data + _h * hstep + _w * c;
}

template <typename T>
inline T *Mat::point(int _n, int _h, int _w) {
BNN_ASSERT(w * c == 1 || w * c * elemsize % 16 == 0, "");
BNN_ASSERT((_n == 0 && _h == 0 && _w == 0) || hstep > 0, hstep);
return (T *)data + _n * h * hstep + _h * hstep + _w * c;
}

template <typename T>
inline T *Mat::point(int _h, int _w) {
BNN_ASSERT(w * c == 1 || w * c * elemsize % 16 == 0, "");
BNN_ASSERT((_h == 0 && _w == 0) || hstep > 0, hstep);
return (T *)data + _h * hstep + _w * c;
}
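The relocated assertions above make each point() access verify the invariant that create() used to enforce up front: rows must be 16-byte aligned (or degenerate, w * c == 1) before strided indexing with hstep is meaningful. A worked example under assumed values, using the (h, w, c, DataType) constructor order seen elsewhere in this diff:

// Hypothetical values for illustration only; not taken from the tests.
bnn::Mat m(28, 28, 3, bnn::DataType::Float);
// w * c * elemsize = 28 * 3 * 4 = 336 bytes, already a multiple of 16,
// so hstep = ncnn::alignSize(336, 16) / 4 = 84 elements per row.
const float *p = m.point<float>(2, 1);
// p == (float *)m.data + 2 * 84 + 1 * 3, i.e. element offset 171.

Deferring the check from create() to point() is what lets a Mat with a non-aligned (non-128-multiple) shape exist at all, as long as nothing indexes into it row-wise.
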
12 changes: 6 additions & 6 deletions tests/net_test.cpp
@@ -34,15 +34,15 @@ TEST(net, bireal18imagenet_comparison) {
std::shared_ptr<bnn::Mat> blob1, blob2;
{
auto net = bnn::Net::create();
net->read("/data/local/tmp/model_imagenet.dab");
net->optimize = false;
net->read("/data/local/tmp/model_imagenet.dab");
net->run(input);
blob1 = net->get_blob(blob_name);
}
{
auto net = bnn::Net::create();
net->read("/data/local/tmp/model_imagenet.dab");
net->optimize = true;
net->read("/data/local/tmp/model_imagenet.dab");
net->run(input);
blob2 = net->get_blob(blob_name);
}
@@ -56,8 +56,8 @@ TEST(net, bireal18imagenet) {
const std::string blob_name = "188";
{
auto net = bnn::Net::create();
net->read("/data/local/tmp/model_imagenet.dab");
net->optimize = true;
net->read("/data/local/tmp/model_imagenet.dab");
net->run(input);
const auto blob = net->get_blob(blob_name);
ASSERT_NEAR((*blob)[0], -0.9431, 1e-4);
@@ -74,15 +74,15 @@ TEST(net, bireal18imagenetstem_comparison) {
std::shared_ptr<bnn::Mat> blob1, blob2;
{
auto net = bnn::Net::create();
net->read("/data/local/tmp/model_imagenet_stem.dab");
net->optimize = false;
net->read("/data/local/tmp/model_imagenet_stem.dab");
net->run(input);
blob1 = net->get_blob(blob_name);
}
{
auto net = bnn::Net::create();
net->read("/data/local/tmp/model_imagenet_stem.dab");
net->optimize = true;
net->read("/data/local/tmp/model_imagenet_stem.dab");
net->run(input);
blob2 = net->get_blob(blob_name);
}
@@ -96,8 +96,8 @@ TEST(net, bireal18imagenetstem) {
const std::string blob_name = "216";
{
auto net = bnn::Net::create();
net->read("/data/local/tmp/model_imagenet_stem.dab");
net->optimize = true;
net->read("/data/local/tmp/model_imagenet_stem.dab");
net->run(input);
const auto &blob = net->get_blob(blob_name);
ASSERT_NEAR((*blob)[0], 1.9842, 1e-4);
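The only change in these tests is ordering: optimize is now set before read(). The likely reason (an inference from the BinConv constructor changes above, not stated in the PR) is that layers are constructed during read(), and BinConv::method() consults net->optimize at construction time to decide which scratch buffers to allocate; setting the flag afterwards would leave the wrong buffers in place. The resulting usage pattern, as a sketch given an input Mat as in the tests:

// Assumed pattern after this PR: configure the net before loading the
// model, because layer construction during read() reads net->optimize.
auto net = bnn::Net::create();
net->optimize = true;   // must precede read()
net->read("/data/local/tmp/model_imagenet.dab");
net->run(input);
const auto blob = net->get_blob("188");
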
63 files renamed without changes.