From a169cec3635f4272cfc1ca0a5fa29fe766d8c101 Mon Sep 17 00:00:00 2001 From: nihui Date: Sun, 29 Jul 2018 16:19:01 +0800 Subject: [PATCH] core int8 inference, quantize and dequantize, net using flag, caffe2ncnn reads int8 scale table --- src/CMakeLists.txt | 2 + src/blob.cpp | 1 + src/blob.h | 4 + src/layer.h | 3 + src/layer/arm/convolution_arm.cpp | 6 + src/layer/arm/convolutiondepthwise_arm.cpp | 6 + src/layer/arm/innerproduct_arm.cpp | 6 + src/layer/convolution.cpp | 134 ++++++++++ src/layer/convolution.h | 7 + src/layer/convolutiondepthwise.cpp | 177 +++++++++++++ src/layer/convolutiondepthwise.h | 7 + src/layer/dequantize.cpp | 120 +++++++++ src/layer/dequantize.h | 43 ++++ src/layer/innerproduct.cpp | 117 +++++++++ src/layer/innerproduct.h | 7 + src/layer/quantize.cpp | 113 +++++++++ src/layer/quantize.h | 37 +++ src/layer/x86/convolution_x86.cpp | 6 + src/layer/x86/convolutiondepthwise_x86.cpp | 6 + src/modelbin.cpp | 51 ++++ src/modelbin.h | 2 +- src/net.cpp | 177 ++++++++++++- src/net.h | 5 + src/paramdict.cpp | 4 + src/paramdict.h | 5 + tools/caffe/caffe2ncnn.cpp | 277 +++++++++++++++++---- 26 files changed, 1265 insertions(+), 58 deletions(-) create mode 100644 src/layer/dequantize.cpp create mode 100644 src/layer/dequantize.h create mode 100644 src/layer/quantize.cpp create mode 100644 src/layer/quantize.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e5433c1a486..a391925e156 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -148,6 +148,8 @@ ncnn_add_layer(InstanceNorm) ncnn_add_layer(Clip) ncnn_add_layer(Reorg) ncnn_add_layer(YoloDetectionOutput) +ncnn_add_layer(Quantize) +ncnn_add_layer(Dequantize) add_library(ncnn STATIC ${ncnn_SRCS}) diff --git a/src/blob.cpp b/src/blob.cpp index 8af899fb799..66652191d0a 100644 --- a/src/blob.cpp +++ b/src/blob.cpp @@ -19,6 +19,7 @@ namespace ncnn { Blob::Blob() { producer = -1; + int8_scale = 0.f; } } // namespace ncnn diff --git a/src/blob.h b/src/blob.h index 31f2c1d48d7..4f1ab76f8e9 100644 --- a/src/blob.h +++ b/src/blob.h @@ -36,6 +36,10 @@ class Blob int producer; // layer index which need this blob as input std::vector consumers; + +public: + // int8 quantize scale of this blob + float int8_scale; }; } // namespace ncnn diff --git a/src/layer.h b/src/layer.h index b46bf177d58..d24203a956b 100644 --- a/src/layer.h +++ b/src/layer.h @@ -36,6 +36,9 @@ class Option int num_threads; Allocator* blob_allocator; Allocator* workspace_allocator; + +public: + std::vector int8_scales; }; const Option& get_default_option(); diff --git a/src/layer/arm/convolution_arm.cpp b/src/layer/arm/convolution_arm.cpp index 7e12044e0ec..8951915bf52 100644 --- a/src/layer/arm/convolution_arm.cpp +++ b/src/layer/arm/convolution_arm.cpp @@ -194,6 +194,12 @@ int Convolution_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option // convolv with NxN kernel // value = value + bias + if (use_int8_inference) + { + // TODO + return Convolution::forward(bottom_blob, top_blob, opt); + } + if (bottom_blob.dims != 3) { return Convolution::forward(bottom_blob, top_blob, opt); diff --git a/src/layer/arm/convolutiondepthwise_arm.cpp b/src/layer/arm/convolutiondepthwise_arm.cpp index 29402d0c8a9..2aa547dd2d8 100644 --- a/src/layer/arm/convolutiondepthwise_arm.cpp +++ b/src/layer/arm/convolutiondepthwise_arm.cpp @@ -107,6 +107,12 @@ int ConvolutionDepthWise_arm::forward(const Mat& bottom_blob, Mat& top_blob, con // convolv with NxN kernel // value = value + bias + if (use_int8_inference) + { + // TODO + return 
ConvolutionDepthWise::forward(bottom_blob, top_blob, opt); + } + int w = bottom_blob.w; int h = bottom_blob.h; int channels = bottom_blob.c; diff --git a/src/layer/arm/innerproduct_arm.cpp b/src/layer/arm/innerproduct_arm.cpp index 5005ea7da54..d7dc429a057 100644 --- a/src/layer/arm/innerproduct_arm.cpp +++ b/src/layer/arm/innerproduct_arm.cpp @@ -24,6 +24,12 @@ DEFINE_LAYER_CREATOR(InnerProduct_arm) int InnerProduct_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { + if (use_int8_inference) + { + // TODO + return InnerProduct::forward(bottom_blob, top_blob, opt); + } + int w = bottom_blob.w; int h = bottom_blob.h; int channels = bottom_blob.c; diff --git a/src/layer/convolution.cpp b/src/layer/convolution.cpp index d4757f1512a..62498da0962 100644 --- a/src/layer/convolution.cpp +++ b/src/layer/convolution.cpp @@ -24,6 +24,15 @@ Convolution::Convolution() { one_blob_only = true; support_inplace = false; + + quantize = 0; + dequantize = 0; +} + +Convolution::~Convolution() +{ + delete quantize; + delete dequantize; } int Convolution::load_param(const ParamDict& pd) @@ -39,6 +48,9 @@ int Convolution::load_param(const ParamDict& pd) pad_h = pd.get(14, pad_w); bias_term = pd.get(5, 0); weight_data_size = pd.get(6, 0); + weight_data_int8_scale = pd.get(8, 0.f); + + use_int8_inference = pd.use_int8_inference; return 0; } @@ -56,6 +68,46 @@ int Convolution::load_model(const ModelBin& mb) return -100; } + bool weight_data_is_int8 = (weight_data.elemsize == (size_t)1u); + bool weight_data_is_float32 = (weight_data.elemsize == (size_t)4u); + + if (weight_data_is_int8 && !use_int8_inference) + { + fprintf(stderr, "quantized int8 weight loaded but use_int8_inference disabled\n"); + return -1; + } + + if (use_int8_inference) + { + quantize = ncnn::create_layer(ncnn::LayerType::Quantize); + dequantize = ncnn::create_layer(ncnn::LayerType::Dequantize); + } + + if (weight_data_is_float32 && use_int8_inference) + { + if (weight_data_int8_scale != 0.f) + { + // quantize weight to int8 + ncnn::ParamDict pd; + pd.set(0, weight_data_int8_scale);// scale + + quantize->load_param(pd); + + Mat int8_weight_data; + quantize->forward(weight_data, int8_weight_data); + + if (int8_weight_data.empty()) + return -100; + + weight_data = int8_weight_data; + } + else + { + // plain float32 weight, fallback to float32 inference + use_int8_inference = false; + } + } + return 0; } @@ -78,6 +130,9 @@ int Convolution::forward(const Mat& bottom_blob, Mat& top_blob, const Option& op pd.set(0, num_output); pd.set(1, bias_term); pd.set(2, weight_data_size); + pd.set(8, weight_data_int8_scale); + + pd.use_int8_inference = use_int8_inference; op->load_param(pd); @@ -160,6 +215,85 @@ int Convolution::forward(const Mat& bottom_blob, Mat& top_blob, const Option& op } } + if (use_int8_inference) + { + Mat bottom_blob_bordered_int8; + bottom_blob_bordered_int8.create(w, h, channels, (size_t)1u, opt.workspace_allocator); + if (bottom_blob_bordered_int8.empty()) + return -100; + + float bottom_scale = opt.int8_scales[0]; +// fprintf(stderr, "bottom_scale = %f\n", bottom_scale); + + // quantize, scale and round to nearest + { + ncnn::ParamDict pd; + pd.set(0, bottom_scale);// scale + + quantize->load_param(pd); + + quantize->forward(bottom_blob_bordered, bottom_blob_bordered_int8, opt); + } + + // num_output + #pragma omp parallel for num_threads(opt.num_threads) + for (int p=0; p(i*stride_h) + j*stride_w; + + for (int k = 0; k < maxk; k++) + { + int val = sptr[ space_ofs[k] ]; + int w = kptr[k]; + sum += val * w; 
+ } + + kptr += maxk; + } + + outptr[j] = sum; + } + + outptr += outw; + } + } + + // dequantize, reverse scale inplace + { + float top_rescale = 1.f / (bottom_scale * weight_data_int8_scale); + + ncnn::ParamDict pd; + pd.set(0, top_rescale);// scale + pd.set(1, bias_term);// bias_term + pd.set(2, num_output);// bias_data_size + + dequantize->load_param(pd); + + ncnn::Mat weights[1]; + weights[0] = bias_data; + + dequantize->load_model(ModelBinFromMatArray(weights)); + + dequantize->forward_inplace(top_blob, opt); + } + + return 0; + } + // num_output #pragma omp parallel for num_threads(opt.num_threads) for (int p=0; pload_param(pd); + + Mat int8_weight_data; + quantize->forward(weight_data, int8_weight_data); + + if (int8_weight_data.empty()) + return -100; + + weight_data = int8_weight_data; + } + else + { + // plain float32 weight, fallback to float32 inference + use_int8_inference = false; + } + } + return 0; } @@ -138,6 +192,129 @@ int ConvolutionDepthWise::forward(const Mat& bottom_blob, Mat& top_blob, const O } } + if (use_int8_inference) + { + Mat bottom_blob_bordered_int8; + bottom_blob_bordered_int8.create(w, h, channels, (size_t)1u, opt.workspace_allocator); + if (bottom_blob_bordered_int8.empty()) + return -100; + + float bottom_scale = opt.int8_scales[0]; + + // quantize, scale and round to nearest + { + ncnn::ParamDict pd; + pd.set(0, bottom_scale);// scale + + quantize->load_param(pd); + + quantize->forward(bottom_blob_bordered, bottom_blob_bordered_int8, opt); + } + + // depth-wise + if (channels == group && group == num_output) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int g=0; g(i*stride_h) + j*stride_w; + + for (int k = 0; k < maxk; k++) + { + signed char val = sptr[ space_ofs[k] ]; + signed char w = kptr[k]; + sum += val * w; + } + + outptr[j] = sum; + } + + outptr += outw; + } + } + } + else + { + const int channels_g = channels / group; + const int num_output_g = num_output / group; + +#ifdef _WIN32 + #pragma omp parallel for num_threads(opt.num_threads) +#else // _WIN32 + #pragma omp parallel for collapse(2) num_threads(opt.num_threads) +#endif // _WIN32 + for (int g=0; g(i*stride_h) + j*stride_w; + + for (int k = 0; k < maxk; k++) + { + signed char val = sptr[ space_ofs[k] ]; + signed char w = kptr[k]; + sum += val * w; + } + + kptr += maxk; + } + + outptr[j] = sum; + } + + outptr += outw; + } + } + } + } + + // dequantize, reverse scale inplace + { + float top_rescale = 1.f / (bottom_scale * weight_data_int8_scale); + + ncnn::ParamDict pd; + pd.set(0, top_rescale);// scale + pd.set(1, bias_term);// bias_term + pd.set(2, num_output);// bias_data_size + + dequantize->load_param(pd); + + ncnn::Mat weights[1]; + weights[0] = bias_data; + + dequantize->load_model(ModelBinFromMatArray(weights)); + + dequantize->forward_inplace(top_blob, opt); + } + + return 0; + } + // depth-wise if (channels == group && group == num_output) { diff --git a/src/layer/convolutiondepthwise.h b/src/layer/convolutiondepthwise.h index 0327eea1078..186bf3c946e 100644 --- a/src/layer/convolutiondepthwise.h +++ b/src/layer/convolutiondepthwise.h @@ -23,6 +23,7 @@ class ConvolutionDepthWise : public Layer { public: ConvolutionDepthWise(); + ~ConvolutionDepthWise(); virtual int load_param(const ParamDict& pd); @@ -45,10 +46,16 @@ class ConvolutionDepthWise : public Layer int weight_data_size; int group; + float weight_data_int8_scale; // model Mat weight_data; Mat bias_data; + + bool use_int8_inference; + + ncnn::Layer* quantize; + ncnn::Layer* dequantize; }; } // 
namespace ncnn diff --git a/src/layer/dequantize.cpp b/src/layer/dequantize.cpp new file mode 100644 index 00000000000..ecc6bd8a695 --- /dev/null +++ b/src/layer/dequantize.cpp @@ -0,0 +1,120 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "dequantize.h" + +namespace ncnn { + +DEFINE_LAYER_CREATOR(Dequantize) + +Dequantize::Dequantize() +{ + one_blob_only = true; + support_inplace = true; +} + +int Dequantize::load_param(const ParamDict& pd) +{ + scale = pd.get(0, 1.f); + bias_term = pd.get(1, 0); + bias_data_size = pd.get(2, 0); + + return 0; +} + +int Dequantize::load_model(const ModelBin& mb) +{ + if (bias_term) + { + bias_data = mb.load(bias_data_size, 1); + if (bias_data.empty()) + return -100; + } + + return 0; +} + +int Dequantize::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + int dims = bottom_top_blob.dims; + + if (dims == 1) + { + fprintf(stderr, "dim=1 bias_data_size=%d\n", bias_data_size); + int w = bottom_top_blob.w; + + int* intptr = bottom_top_blob; + float* ptr = bottom_top_blob; + + if (bias_term) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i=0; iload_param(pd); + + Mat int8_weight_data; + quantize->forward(weight_data, int8_weight_data); + + if (int8_weight_data.empty()) + return -100; + + weight_data = int8_weight_data; + } + else + { + // plain float32 weight, fallback to float32 inference + use_int8_inference = false; + } + } + return 0; } @@ -61,6 +115,69 @@ int InnerProduct::forward(const Mat& bottom_blob, Mat& top_blob, const Option& o if (top_blob.empty()) return -100; + if (use_int8_inference) + { + Mat bottom_blob_int8; + bottom_blob_int8.create(w, h, channels, (size_t)1u, opt.workspace_allocator); + if (bottom_blob_int8.empty()) + return -100; + + float bottom_scale = opt.int8_scales[0]; +// fprintf(stderr, "bottom_scale = %f\n", bottom_scale); + + // quantize, scale and round to nearest + { + ncnn::ParamDict pd; + pd.set(0, bottom_scale);// scale + + quantize->load_param(pd); + + quantize->forward(bottom_blob, bottom_blob_int8, opt); + } + + // num_output + #pragma omp parallel for num_threads(opt.num_threads) + for (int p=0; pload_param(pd); + + ncnn::Mat weights[1]; + weights[0] = bias_data; + + dequantize->load_model(ModelBinFromMatArray(weights)); + + dequantize->forward_inplace(top_blob, opt); + } + + return 0; + } + // num_output #pragma omp parallel for num_threads(opt.num_threads) for (int p=0; p + +namespace ncnn { + +DEFINE_LAYER_CREATOR(Quantize) + +Quantize::Quantize() +{ + one_blob_only = true; + support_inplace = false; +} + +int Quantize::load_param(const ParamDict& pd) +{ + scale = pd.get(0, 1.f); + + return 0; +} + +static inline signed char float2int8(float v) +{ + int int32 = round(v); + if (int32 > 127) return 127; + if (int32 < -128) return -128; + return (signed char)int32; +} + +int Quantize::forward(const Mat& 
bottom_blob, Mat& top_blob, const Option& opt) const +{ + int dims = bottom_blob.dims; + + if (dims == 1) + { + int w = bottom_blob.w; + + top_blob.create(w, (size_t)1u, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + const float* ptr = bottom_blob; + signed char* outptr = top_blob; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i=0; i int8_weights; + int8_weights.resize(align_data_size); + nread = fread(int8_weights.data(), align_data_size, 1, binfp); + if (nread != 1) + { + fprintf(stderr, "ModelBin read int8_weights failed %d\n", nread); + return Mat(); + } + + Mat m(w, (size_t)1u); + if (m.empty()) + return m; + + memcpy(m.data, int8_weights.data(), w); + + return m; + } + else if (flag_struct.tag == 0x0002C056) + { + Mat m(w); + if (m.empty()) + return m; + + // raw data with extra scaling + nread = fread(m, w * sizeof(float), 1, binfp); + if (nread != 1) + { + fprintf(stderr, "ModelBin read weight_data failed %d\n", nread); + return Mat(); + } + + return m; + } Mat m(w); if (m.empty()) @@ -195,6 +232,20 @@ Mat ModelBinFromMemory::load(int w, int type) const mem += alignSize(w * sizeof(unsigned short), 4); return m; } + else if (flag_struct.tag == 0x000D4B38) + { + // int8 data + Mat m = Mat(w, (signed char*)mem, 1u); + mem += alignSize(w, 4); + return m; + } + else if (flag_struct.tag == 0x0002C056) + { + // raw data with extra scaling + Mat m = Mat(w, (float*)mem); + mem += w * sizeof(float); + return m; + } if (flag != 0) { diff --git a/src/modelbin.h b/src/modelbin.h index 8eaa8ae9c49..3237beed16d 100644 --- a/src/modelbin.h +++ b/src/modelbin.h @@ -29,7 +29,7 @@ class ModelBin // 0 = auto // 1 = float32 // 2 = float16 - // 3 = uint8 + // 3 = int8 // load vec virtual Mat load(int w, int type) const = 0; // load image diff --git a/src/net.cpp b/src/net.cpp index e60efe04a45..ff8bf8088f2 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -32,6 +32,9 @@ namespace ncnn { Net::Net() { + use_winograd_convolution = 1; + use_sgemm_convolution = 1; + use_int8_inference = 1; } Net::~Net() @@ -115,10 +118,12 @@ int Net::load_param(FILE* fp) blobs.resize(blob_count); ParamDict pd; + pd.use_winograd_convolution = use_winograd_convolution; + pd.use_sgemm_convolution = use_sgemm_convolution; + pd.use_int8_inference = use_int8_inference; - int layer_index = 0; int blob_index = 0; - while (!feof(fp)) + for (int i=0; itype = std::string(layer_type); layer->name = std::string(layer_name); -// fprintf(stderr, "new layer %d %s\n", layer_index, layer_name); +// fprintf(stderr, "new layer %d %s\n", i, layer_name); layer->bottoms.resize(bottom_count); - for (int i=0; ibottoms[i] = bottom_blob_index; + layer->bottoms[j] = bottom_blob_index; } layer->tops.resize(top_count); - for (int i=0; itops[i] = blob_index; + layer->tops[j] = blob_index; blob_index++; } @@ -215,9 +221,61 @@ int Net::load_param(FILE* fp) continue; } - layers[layer_index] = layer; + layers[i] = layer; + } + + while (!feof(fp)) + { + int nscan = 0; + + char blob_name[257]; + nscan = fscanf(fp, "%256s", blob_name); + if (nscan != 1) + { + continue; + } + + int blob_index = find_blob_index_by_name(blob_name); + if (blob_index == -1) + { + clear(); + return -1; + } + + Blob& blob = blobs[blob_index]; + + // blob specific params + int pdlr = pd.load_param(fp); + if (pdlr != 0) + { + fprintf(stderr, "ParamDict load_param failed\n"); + continue; + } + + // set blob params + blob.int8_scale = pd.get(0, 0.f); + } + + // fill all blob int8_scale + if (use_int8_inference) + { + for (int i=0; ibottoms.size() != 1) + 
continue; - layer_index++; + const Blob& prev_blob = blobs[layer->bottoms[0]]; + blob.int8_scale = prev_blob.int8_scale; + } } return 0; @@ -260,6 +318,9 @@ int Net::load_param_bin(FILE* fp) blobs.resize(blob_count); ParamDict pd; + pd.use_winograd_convolution = use_winograd_convolution; + pd.use_sgemm_convolution = use_sgemm_convolution; + pd.use_int8_inference = use_int8_inference; for (int i=0; ibottoms.size() != 1) + continue; + + const Blob& prev_blob = blobs[layer->bottoms[0]]; + blob.int8_scale = prev_blob.int8_scale; + } + } + return 0; } @@ -430,6 +533,9 @@ int Net::load_param(const unsigned char* _mem) blobs.resize(blob_count); ParamDict pd; + pd.use_winograd_convolution = use_winograd_convolution; + pd.use_sgemm_convolution = use_sgemm_convolution; + pd.use_int8_inference = use_int8_inference; for (int i=0; ibottoms.size() != 1) + continue; + + const Blob& prev_blob = blobs[layer->bottoms[0]]; + blob.int8_scale = prev_blob.int8_scale; + } + } + return mem - _mem; } @@ -643,6 +792,9 @@ int Net::forward_layer(int layer_index, std::vector& blob_mats, Option& opt Mat bottom_blob = blob_mats[bottom_blob_index]; + opt.int8_scales.resize(1); + opt.int8_scales[0] = blobs[bottom_blob_index].int8_scale; + if (opt.lightmode) { // delete after taken in light mode @@ -696,6 +848,7 @@ int Net::forward_layer(int layer_index, std::vector& blob_mats, Option& opt // load bottom blobs std::vector bottom_blobs; bottom_blobs.resize(layer->bottoms.size()); + opt.int8_scales.resize(layer->bottoms.size()); for (size_t i=0; ibottoms.size(); i++) { int bottom_blob_index = layer->bottoms[i]; @@ -709,6 +862,8 @@ int Net::forward_layer(int layer_index, std::vector& blob_mats, Option& opt bottom_blobs[i] = blob_mats[bottom_blob_index]; + opt.int8_scales[i] = blobs[bottom_blob_index].int8_scale; + if (opt.lightmode) { // delete after taken in light mode diff --git a/src/net.h b/src/net.h index 99b8d6fa036..f1d8629a44d 100644 --- a/src/net.h +++ b/src/net.h @@ -78,6 +78,11 @@ class Net // construct an Extractor from network Extractor create_extractor() const; +public: + int use_winograd_convolution; + int use_sgemm_convolution; + int use_int8_inference; + protected: friend class Extractor; #if NCNN_STRING diff --git a/src/paramdict.cpp b/src/paramdict.cpp index d901dabe0d5..6988a1a8f30 100644 --- a/src/paramdict.cpp +++ b/src/paramdict.cpp @@ -20,6 +20,10 @@ namespace ncnn { ParamDict::ParamDict() { + use_winograd_convolution = 1; + use_sgemm_convolution = 1; + use_int8_inference = 1; + clear(); } diff --git a/src/paramdict.h b/src/paramdict.h index 4271621db9a..ea6f1a0c85e 100644 --- a/src/paramdict.h +++ b/src/paramdict.h @@ -45,6 +45,11 @@ class ParamDict // set array void set(int id, const Mat& v); +public: + int use_winograd_convolution; + int use_sgemm_convolution; + int use_int8_inference; + protected: friend class Net; diff --git a/tools/caffe/caffe2ncnn.cpp b/tools/caffe/caffe2ncnn.cpp index 71c52454bb0..aa8d4faa3dc 100644 --- a/tools/caffe/caffe2ncnn.cpp +++ b/tools/caffe/caffe2ncnn.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -98,6 +99,64 @@ static unsigned short float2half(float value) return fp16; } +// round to nearest +static signed char float2int8(float value) +{ + float tmp; + if (value >= 0.f) tmp = value + 0.5; + else tmp = value - 0.5; + + if (tmp > 127) + return 127; + if (tmp < -128) + return -128; + + return tmp; +} + +static bool read_int8scale_table(const char* filepath, std::map& blob_int8scale_table, std::map& weight_int8scale_table) +{ + 
blob_int8scale_table.clear(); + weight_int8scale_table.clear(); + + FILE* fp = fopen(filepath, "rb"); + if (!fp) + { + fprintf(stderr, "fopen %s failed\n", filepath); + return false; + } + + char line[1024]; + while (!feof(fp)) + { + char* s = fgets(line, 1024, fp); + if (!s) + break; + + char key[256]; + float scale = 1.f; + int nscan = sscanf(line, "%255s %f", key, &scale); + if (nscan != 2) + continue; + + std::string keystr = key; + + // XYZ_param_N pattern + if (strstr(key, "_param_")) + { + weight_int8scale_table[ keystr ] = scale; + } + else + { + blob_int8scale_table[ keystr ] = scale; + } + } + + fclose(fp); + + return true; +} + static int quantize_weight(float *data, size_t data_length, std::vector& float16_weights) { float16_weights.resize(data_length); @@ -115,6 +174,23 @@ static int quantize_weight(float *data, size_t data_length, std::vector& int8_weights) +{ + int8_weights.resize(data_length); + + for (size_t i = 0; i < data_length; i++) + { + float f = data[i]; + + signed char int8 = float2int8(f * scale); + + int8_weights[i] = int8; + } + + // magic tag for int8 + return 0x000D4B38; +} + static bool quantize_weight(float *data, size_t data_length, int quantize_level, std::vector &quantize_table, std::vector &quantize_index) { assert(quantize_level != 0); @@ -206,9 +282,9 @@ static bool read_proto_from_binary(const char* filepath, google::protobuf::Messa int main(int argc, char** argv) { - if (!(argc == 3 || argc == 5 || argc == 6)) + if (!(argc == 3 || argc == 5 || argc == 6 || argc == 7)) { - fprintf(stderr, "Usage: %s [caffeproto] [caffemodel] [ncnnproto] [ncnnbin] [quantizelevel]\n", argv[0]); + fprintf(stderr, "Usage: %s [caffeproto] [caffemodel] [ncnnproto] [ncnnbin] [quantizelevel] [int8scaletable]\n", argv[0]); return -1; } @@ -216,7 +292,8 @@ int main(int argc, char** argv) const char* caffemodel = argv[2]; const char* ncnn_prototxt = argc >= 5 ? argv[3] : "ncnn.proto"; const char* ncnn_modelbin = argc >= 5 ? argv[4] : "ncnn.bin"; - const char* quantize_param = argc == 6 ? argv[5] : "0"; + const char* quantize_param = argc >= 6 ? argv[5] : "0"; + const char* int8scale_table_path = argc == 7 ? 
argv[6] : NULL; int quantize_level = atoi(quantize_param); if (quantize_level != 0 && quantize_level != 256 && quantize_level != 65536) { @@ -242,6 +319,18 @@ int main(int argc, char** argv) return -1; } + std::map blob_int8scale_table; + std::map weight_int8scale_table; + if (int8scale_table_path) + { + bool s2 = read_int8scale_table(int8scale_table_path, blob_int8scale_table, weight_int8scale_table); + if (!s2) + { + fprintf(stderr, "read_int8scale_table failed\n"); + return -1; + } + } + FILE* pp = fopen(ncnn_prototxt, "wb"); FILE* bp = fopen(ncnn_modelbin, "wb"); @@ -538,42 +627,87 @@ int main(int argc, char** argv) std::vector quantize_index; std::vector float16_weights; + std::vector int8_weights; + + bool has_int8scale = false; + float int8scale = 1.f; // we will not quantize the bias values - if (j == 0 && quantize_level != 0) + if (j == 0) { - if (quantize_level == 256) + if (int8scale_table_path) { - quantize_tag = quantize_weight((float *)blob.data().data(), blob.data_size(), quantize_level, quantize_table, quantize_index); + char key[256]; + sprintf(key, "%s_param_%d", layer.name().c_str(), j); + if (weight_int8scale_table.find(std::string(key)) != weight_int8scale_table.end()) + { + has_int8scale = true; + int8scale = weight_int8scale_table[std::string(key)]; + } + } + + if (has_int8scale) + { + fprintf(pp, " 8=%.8e", int8scale); + + if (quantize_level == 0) + { + quantize_tag = 0x0002C056; + } + else if (quantize_level == 256) + { + quantize_tag = quantize_weight((float *)blob.data().data(), blob.data_size(), int8scale, int8_weights); + } + } + else if (quantize_level == 256) + { + quantize_tag = quantize_weight((float *)blob.data().data(), blob.data_size(), quantize_level, quantize_table, quantize_index); } else if (quantize_level == 65536) { - quantize_tag = quantize_weight((float *)blob.data().data(), blob.data_size(), float16_weights); + quantize_tag = quantize_weight((float *)blob.data().data(), blob.data_size(), float16_weights); } - } - // write quantize tag first - if (j == 0) + // write quantize tag first fwrite(&quantize_tag, sizeof(int), 1, bp); - if (quantize_tag) - { - int p0 = ftell(bp); - if (quantize_level == 256) + if (quantize_tag) { - // write quantize table and index - fwrite(quantize_table.data(), sizeof(float), quantize_table.size(), bp); - fwrite(quantize_index.data(), sizeof(unsigned char), quantize_index.size(), bp); + int p0 = ftell(bp); + if (has_int8scale) + { + if (quantize_level == 0) + { + // write original data and int8scale + fwrite(blob.data().data(), sizeof(float), blob.data_size(), bp); + } + else if (quantize_level == 256) + { + fwrite(int8_weights.data(), sizeof(signed char), int8_weights.size(), bp); + } + } + else if (quantize_level == 256) + { + // write quantize table and index + fwrite(quantize_table.data(), sizeof(float), quantize_table.size(), bp); + fwrite(quantize_index.data(), sizeof(unsigned char), quantize_index.size(), bp); + } + else if (quantize_level == 65536) + { + fwrite(float16_weights.data(), sizeof(unsigned short), float16_weights.size(), bp); + } + + // padding to 32bit align + int nwrite = ftell(bp) - p0; + int nalign = alignSize(nwrite, 4); + unsigned char padding[4] = {0x00, 0x00, 0x00, 0x00}; + fwrite(padding, sizeof(unsigned char), nalign - nwrite, bp); } - else if (quantize_level == 65536) + else { - fwrite(float16_weights.data(), sizeof(unsigned short), float16_weights.size(), bp); + // write original data + fwrite(blob.data().data(), sizeof(float), blob.data_size(), bp); } - // padding to 32bit align - int 
nwrite = ftell(bp) - p0; - int nalign = alignSize(nwrite, 4); - unsigned char padding[4] = {0x00, 0x00, 0x00, 0x00}; - fwrite(padding, sizeof(unsigned char), nalign - nwrite, bp); } else { @@ -799,45 +933,90 @@ int main(int argc, char** argv) std::vector quantize_index; std::vector float16_weights; + std::vector int8_weights; + + bool has_int8scale = false; + float int8scale = 1.f; // we will not quantize the bias values - if (j == 0 && quantize_level != 0) + if (j == 0) { - if (quantize_level == 256) + if (int8scale_table_path) + { + char key[256]; + sprintf(key, "%s_param_%d", layer.name().c_str(), j); + if (weight_int8scale_table.find(std::string(key)) != weight_int8scale_table.end()) + { + has_int8scale = true; + int8scale = weight_int8scale_table[std::string(key)]; + } + } + + if (has_int8scale) + { + fprintf(pp, " 8=%.8e", int8scale); + + if (quantize_level == 0) + { + quantize_tag = 0x0002C056; + } + else if (quantize_level == 256) + { + quantize_tag = quantize_weight((float *)blob.data().data(), blob.data_size(), int8scale, int8_weights); + } + } + else if (quantize_level == 256) { - quantize_tag = quantize_weight((float *)blob.data().data(), blob.data_size(), quantize_level, quantize_table, quantize_index); + quantize_tag = quantize_weight((float *)blob.data().data(), blob.data_size(), quantize_level, quantize_table, quantize_index); } else if (quantize_level == 65536) { - quantize_tag = quantize_weight((float *)blob.data().data(), blob.data_size(), float16_weights); + quantize_tag = quantize_weight((float *)blob.data().data(), blob.data_size(), float16_weights); } - } - // write quantize tag first - if (j == 0) + // write quantize tag first fwrite(&quantize_tag, sizeof(int), 1, bp); - if (quantize_tag) - { - int p0 = ftell(bp); - if (quantize_level == 256) + if (quantize_tag) { - // write quantize table and index - fwrite(quantize_table.data(), sizeof(float), quantize_table.size(), bp); - fwrite(quantize_index.data(), sizeof(unsigned char), quantize_index.size(), bp); + int p0 = ftell(bp); + if (has_int8scale) + { + if (quantize_level == 0) + { + // write original data and int8scale + fwrite(blob.data().data(), sizeof(float), blob.data_size(), bp); + } + else if (quantize_level == 256) + { + fwrite(int8_weights.data(), sizeof(signed char), int8_weights.size(), bp); + } + } + else if (quantize_level == 256) + { + // write quantize table and index + fwrite(quantize_table.data(), sizeof(float), quantize_table.size(), bp); + fwrite(quantize_index.data(), sizeof(unsigned char), quantize_index.size(), bp); + } + else if (quantize_level == 65536) + { + fwrite(float16_weights.data(), sizeof(unsigned short), float16_weights.size(), bp); + } + + // padding to 32bit align + int nwrite = ftell(bp) - p0; + int nalign = alignSize(nwrite, 4); + unsigned char padding[4] = {0x00, 0x00, 0x00, 0x00}; + fwrite(padding, sizeof(unsigned char), nalign - nwrite, bp); } - else if (quantize_level == 65536) + else { - fwrite(float16_weights.data(), sizeof(unsigned short), float16_weights.size(), bp); + // write original data + fwrite(blob.data().data(), sizeof(float), blob.data_size(), bp); } - // padding to 32bit align - int nwrite = ftell(bp) - p0; - int nalign = alignSize(nwrite, 4); - unsigned char padding[4] = {0x00, 0x00, 0x00, 0x00}; - fwrite(padding, sizeof(unsigned char), nalign - nwrite, bp); } else - { + { // write original data fwrite(blob.data().data(), sizeof(float), blob.data_size(), bp); } @@ -1379,6 +1558,12 @@ int main(int argc, char** argv) } + // concat blob_int8scale_table + for 
(std::map::const_iterator it = blob_int8scale_table.begin(); it != blob_int8scale_table.end(); it++) + { + fprintf(pp, "%-16s 0=%.8e\n", it->first.c_str(), it->second); + } + fclose(pp); fclose(bp);
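
For reference, the quantization step introduced in src/layer/quantize.cpp maps every float value to a signed 8-bit integer by multiplying with the blob scale, rounding to nearest and saturating to [-128, 127]. A minimal standalone sketch of that mapping follows; the helper name and the sample scale are illustrative only and not part of the patch.

    #include <cmath>
    #include <cstdio>

    // round-to-nearest with saturation, mirroring float2int8() in quantize.cpp
    static signed char quantize_value(float v, float scale)
    {
        int q = (int)roundf(v * scale);
        if (q > 127) q = 127;
        if (q < -128) q = -128;
        return (signed char)q;
    }

    int main()
    {
        const float scale = 25.4f; // assumed activation scale from offline calibration
        const float samples[] = { -3.2f, 0.01f, 4.999f, 10.f };
        for (float v : samples)
            printf("%.3f -> %d\n", v, quantize_value(v, scale)); // 10.f saturates to 127
        return 0;
    }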
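
The dequantize step at the end of the int8 Convolution, ConvolutionDepthWise and InnerProduct forward paths undoes both scales at once: the int32 accumulator holds roughly bottom_scale * weight_data_int8_scale times the float result, so Dequantize multiplies by top_rescale = 1 / (bottom_scale * weight_data_int8_scale) and only then adds the float32 bias. A small numeric sketch of that identity; every value below is made up for illustration.

    #include <cstdio>

    int main()
    {
        const float bottom_scale = 120.f;  // assumed activation scale
        const float weight_scale = 1500.f; // assumed weight_data_int8_scale

        const float x = 0.37f;  // one input value
        const float w = 0.021f; // one weight value

        // what the int8 path accumulates (rounding error ignored)
        const int xq = (int)(x * bottom_scale + 0.5f);
        const int wq = (int)(w * weight_scale + 0.5f);
        const int sum = xq * wq; // int32 accumulator

        // reverse scale as Dequantize::forward_inplace does, then add bias
        const float top_rescale = 1.f / (bottom_scale * weight_scale);
        const float bias = 0.1f;
        printf("float32 result: %f\n", x * w + bias);
        printf("int8    result: %f\n", sum * top_rescale + bias);
        return 0;
    }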
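
On the runtime side the new Net::use_int8_inference flag (default 1 in this patch) gates the whole int8 path; it is copied into ParamDict and read back by the layers' load_param()/load_model(), so it must be set before loading the model. A minimal usage sketch; the file and blob names below are placeholders for an actual model.

    #include "net.h"

    int main()
    {
        ncnn::Net net;
        net.use_int8_inference = 1; // already the default after this patch, set 0 to force float32
        net.load_param("net.param");
        net.load_model("net.bin");

        ncnn::Mat in(224, 224, 3);
        in.fill(0.5f);

        ncnn::Extractor ex = net.create_extractor();
        ex.input("data", in);

        ncnn::Mat out;
        ex.extract("prob", out);

        return 0;
    }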
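
read_int8scale_table() in tools/caffe/caffe2ncnn.cpp parses a plain text file with one "key scale" pair per line: keys containing _param_ (the XYZ_param_N pattern) go into the weight scale table, all other keys into the blob/activation scale table. The weight scale is emitted as field 8= on the layer line, the blob scales are appended at the end of the .param file. A hypothetical table for a layer named conv1 could look like this; the numbers are placeholders that would normally come from an offline calibration tool.

    conv1_param_0 1561.24816895
    conv1 127.86813354

The table is passed as the optional sixth argument, e.g. ./caffe2ncnn deploy.prototxt net.caffemodel net.param net.bin 256 net.table (file names are placeholders). With quantizelevel 256 the converter stores the weights as int8 (weight tag 0x000D4B38); with quantizelevel 0 it keeps float32 weights plus the scale (tag 0x0002C056), and the weights are then quantized at load time in Convolution::load_model().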
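
In the generated text .param the blob scales appear after all layer lines, one "%-16s 0=scale" entry per blob, and Net::load_param() stores them into Blob::int8_scale; a blob without an entry of its own inherits the scale of the single bottom blob of its producer layer. With the hypothetical table above, the tail of the .param file would read roughly:

    conv1            0=1.27868134e+02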