[QUANT][BUG] fix per tensor quantconcat error (#607)
Co-authored-by: lucasktian <lucasktian@tencent.com>
seanxcwang and gttiankai authored Dec 3, 2020
1 parent f932d1e commit 04b783c
Showing 2 changed files with 144 additions and 30 deletions.
121 changes: 101 additions & 20 deletions source/tnn/device/arm/acc/arm_concat_layer_acc.cc
@@ -16,6 +16,7 @@
#include "tnn/device/arm/arm_util.h"
#include "tnn/utils/bfp16.h"
#include "tnn/utils/dims_vector_utils.h"
#include "tnn/utils/naive_compute.h"

namespace TNN_NS {

@@ -77,6 +78,42 @@ int concat_channel(Blob *output, const std::vector<Blob *> &inputs, T *unpack_buf
return 0;
}

/*
check whether all inputs are per-tensor quantized
*/
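// scale_handle stores a single scale for per-tensor quantization and one scale
// per channel for per-channel quantization, so GetDataCount() > 1 marks a
// per-channel input and disables the requantizing fast path below.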
static bool is_per_tensor_quant(const std::vector<Blob *> &inputs) {
bool int8_per_tensor_flag = true;
for (auto &blob : inputs) {
if (reinterpret_cast<BlobInt8 *>(blob)->GetIntResource()->scale_handle.GetDataCount() > 1) {
int8_per_tensor_flag = false;
break;
}
}

return int8_per_tensor_flag;
}

/*
rescale int8 values; per-tensor quantization path only
*/
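// A per-tensor int8 value q represents the real number q * scale. Copying q from
// an input blob with scale s_in into an output blob with scale s_out while
// preserving the real value therefore requires q_out = round(q * s_in / s_out);
// e.g. q = 64, s_in = 0.5, s_out = 1.0 gives q_out = round(64 * 0.5) = 32.
// Callers pass rescale[0] = s_in / s_out.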
static void rescale_int8(int8_t *dst, int8_t *src, float *rescale, int len) {
int n = 0;
#ifdef TNN_USE_NEON
// process 8 elements per iteration: widen s8 -> s16 -> s32, convert to float,
// scale, round back to s32, then narrow with saturation down to s8
for (; n + 7 < len; n += 8) {
int8x8_t v_src = vld1_s8(src + n);
int16x8_t v_src_s16 = vmovl_s8(v_src);
float32x4_t v_src0_f32 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(v_src_s16)));
float32x4_t v_src1_f32 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(v_src_s16)));
int16x4_t v_mul0_s16 = vqmovn_s32(VCVTAQ_S32_F32(vmulq_n_f32(v_src0_f32, rescale[0])));
int16x8_t v_mul_s16 = VQMOVN_HIGH_S32_T(v_mul0_s16, VCVTAQ_S32_F32(vmulq_n_f32(v_src1_f32, rescale[0])));
vst1_s8(dst + n, vqmovn_s16(v_mul_s16));
}
#endif
// scalar tail, and the full path when NEON is unavailable
for (; n < len; n++) {
dst[n] = float2int8(src[n] * rescale[0]);
}
}

/*
concat channel int8, nhwc format
*/
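// Int8 blobs use NHWC storage with the channel dimension rounded up to a multiple
// of 4 (ic_c4 / oc_c4 below), so each spatial position holds a contiguous padded
// channel vector and concat copies input_channel bytes to the proper channel
// offset of every output position.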
@@ -86,17 +123,41 @@ static int concat_channel_i8(Blob *output, const std::vector<Blob *> &inputs) {
auto oc_c4 = ROUND_UP(dims_output[1], 4);

int8_t *output_origin = reinterpret_cast<int8_t *>(GetBlobHandlePtr(output->GetHandle()));
for (int n = 0; n < dims_output[0]; n++) {
int c_offset = 0;
for (int b = 0; b < inputs.size(); b++) {
auto input_channel = inputs[b]->GetBlobDesc().dims[1];
auto ic_c4 = ROUND_UP(input_channel, 4);
auto input_ptr = reinterpret_cast<int8_t *>(GetBlobHandlePtr(inputs[b]->GetHandle())) + n * ic_c4 * full_hw;
auto output_ptr = output_origin + n * full_hw * oc_c4 + c_offset;
for (int cur_hw = 0; cur_hw < full_hw; cur_hw++) {
memcpy(output_ptr + cur_hw * oc_c4, input_ptr + cur_hw * ic_c4, input_channel);

if (!is_per_tensor_quant(inputs)) {
for (int n = 0; n < dims_output[0]; n++) {
int c_offset = 0;
for (int b = 0; b < inputs.size(); b++) {
auto input_channel = inputs[b]->GetBlobDesc().dims[1];
auto ic_c4 = ROUND_UP(input_channel, 4);
auto input_ptr =
reinterpret_cast<int8_t *>(GetBlobHandlePtr(inputs[b]->GetHandle())) + n * ic_c4 * full_hw;
auto output_ptr = output_origin + n * full_hw * oc_c4 + c_offset;
for (int cur_hw = 0; cur_hw < full_hw; cur_hw++) {
memcpy(output_ptr + cur_hw * oc_c4, input_ptr + cur_hw * ic_c4, input_channel);
}
c_offset += input_channel;
}
}
} else {
float *output_scale = reinterpret_cast<BlobInt8 *>(output)->GetIntResource()->scale_handle.force_to<float *>();
for (int n = 0; n < dims_output[0]; n++) {
int c_offset = 0;
for (int b = 0; b < inputs.size(); b++) {
float *input_scale =
reinterpret_cast<BlobInt8 *>(inputs[b])->GetIntResource()->scale_handle.force_to<float *>();
float rescale = input_scale[0] / output_scale[0];
auto input_channel = inputs[b]->GetBlobDesc().dims[1];
auto ic_c4 = ROUND_UP(input_channel, 4);
auto input_ptr =
reinterpret_cast<int8_t *>(GetBlobHandlePtr(inputs[b]->GetHandle())) + n * ic_c4 * full_hw;
auto output_ptr = output_origin + n * full_hw * oc_c4 + c_offset;
for (int cur_hw = 0; cur_hw < full_hw; cur_hw++) {
// memcpy(output_ptr + cur_hw * oc_c4, input_ptr + cur_hw * ic_c4, input_channel);
rescale_int8(output_ptr + cur_hw * oc_c4, input_ptr + cur_hw * ic_c4, &rescale, input_channel);
}
c_offset += input_channel;
}
}
}

@@ -113,16 +174,36 @@ static int concat_common_i8(Blob *output, const std::vector<Blob *> &inputs, int
auto output_stride = DimsVectorUtils::Count(round_output_dims, axis - 1);
auto *output_origin = reinterpret_cast<int8_t *>(GetBlobHandlePtr(output->GetHandle()));

for (int n = 0; n < slice_count; n++) {
auto output_ptr = output_origin + n * output_stride;
for (int b = 0; b < inputs.size(); b++) {
auto input = inputs[b];
auto input_dims = input->GetBlobDesc().dims;
DimsVector round_input_dims = {input_dims[0], input_dims[2], input_dims[3], ROUND_UP(input_dims[1], 4)};
auto input_stride = DimsVectorUtils::Count(round_input_dims, axis - 1);
auto input_ptr = reinterpret_cast<int8_t *>(GetBlobHandlePtr(input->GetHandle())) + n * input_stride;
memcpy(output_ptr, input_ptr, input_stride * sizeof(int8_t));
output_ptr += input_stride;
if (!is_per_tensor_quant(inputs)) {
for (int n = 0; n < slice_count; n++) {
auto output_ptr = output_origin + n * output_stride;
for (int b = 0; b < inputs.size(); b++) {
auto input = inputs[b];
auto input_dims = input->GetBlobDesc().dims;
DimsVector round_input_dims = {input_dims[0], input_dims[2], input_dims[3], ROUND_UP(input_dims[1], 4)};
auto input_stride = DimsVectorUtils::Count(round_input_dims, axis - 1);
auto input_ptr = reinterpret_cast<int8_t *>(GetBlobHandlePtr(input->GetHandle())) + n * input_stride;
memcpy(output_ptr, input_ptr, input_stride * sizeof(int8_t));
output_ptr += input_stride;
}
}
} else {
float *output_scale = reinterpret_cast<BlobInt8 *>(output)->GetIntResource()->scale_handle.force_to<float *>();
for (int n = 0; n < slice_count; n++) {
auto output_ptr = output_origin + n * output_stride;
for (int b = 0; b < inputs.size(); b++) {
float *input_scale =
reinterpret_cast<BlobInt8 *>(inputs[b])->GetIntResource()->scale_handle.force_to<float *>();
float rescale = input_scale[0] / output_scale[0];
auto input = inputs[b];
auto input_dims = input->GetBlobDesc().dims;
DimsVector round_input_dims = {input_dims[0], input_dims[2], input_dims[3], ROUND_UP(input_dims[1], 4)};
auto input_stride = DimsVectorUtils::Count(round_input_dims, axis - 1);
auto input_ptr = reinterpret_cast<int8_t *>(GetBlobHandlePtr(input->GetHandle())) + n * input_stride;
// memcpy(output_ptr, input_ptr, input_stride * sizeof(int8_t));
rescale_int8(output_ptr, input_ptr, &rescale, input_stride);
output_ptr += input_stride;
}
}
}

53 changes: 43 additions & 10 deletions source/tnn/device/cpu/acc/cpu_concat_layer_acc.cc
@@ -12,11 +12,12 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "tnn/utils/naive_compute.h"
#include "tnn/core/blob_int8.h"
#include "tnn/device/cpu/acc/cpu_layer_acc.h"
#include "tnn/device/cpu/cpu_context.h"
#include "tnn/utils/data_type_utils.h"
#include "tnn/utils/dims_vector_utils.h"
#include "tnn/utils/naive_compute.h"

namespace TNN_NS {

@@ -40,6 +41,18 @@ Status CpuConcatLayerAcc::Forward(const std::vector<Blob *> &inputs, const std::
auto output = outputs[0];
auto dims = input->GetBlobDesc().dims;

bool int8_per_tensor_flag = false;
if (output->GetBlobDesc().data_type == DATA_TYPE_INT8) {
int8_per_tensor_flag = true;
// if any blob is per-channel quantized, concat the normal way (raw copy)
for (auto &blob : inputs) {
if (reinterpret_cast<BlobInt8 *>(blob)->GetIntResource()->scale_handle.GetDataCount() > 1) {
int8_per_tensor_flag = false;
break;
}
}
}

const int axis = param->axis;
if (axis > dims.size() || axis < 0) {
LOGE("Error: Concat layer param invalid\n");
@@ -60,16 +73,36 @@
int8_t *output_data = static_cast<int8_t *>(output->GetHandle().base);
int output_concat_axis = output->GetBlobDesc().dims[axis];
int output_concat_axis_offset = 0;
for (size_t i = 0; i < inputs.size(); ++i) {
// use int8_t for all types
int8_t *input_data = static_cast<int8_t *>(inputs[i]->GetHandle().base);
const int input_concat_axis = inputs[i]->GetBlobDesc().dims[axis];
for (int n = 0; n < num_concats; ++n) {
memcpy(output_data + (n * output_concat_axis + output_concat_axis_offset) * concate_size * datasize,
input_data + n * input_concat_axis * concate_size * datasize,
input_concat_axis * concate_size * datasize);

if (!int8_per_tensor_flag) {
for (size_t i = 0; i < inputs.size(); ++i) {
// use int8_t for all types
int8_t *input_data = static_cast<int8_t *>(inputs[i]->GetHandle().base);
const int input_concat_axis = inputs[i]->GetBlobDesc().dims[axis];
for (int n = 0; n < num_concats; ++n) {
memcpy(output_data + (n * output_concat_axis + output_concat_axis_offset) * concate_size * datasize,
input_data + n * input_concat_axis * concate_size * datasize,
input_concat_axis * concate_size * datasize);
}
output_concat_axis_offset += input_concat_axis;
}
} else {
float *output_scale = reinterpret_cast<BlobInt8 *>(output)->GetIntResource()->scale_handle.force_to<float *>();
for (size_t i = 0; i < inputs.size(); ++i) {
float *input_scale =
reinterpret_cast<BlobInt8 *>(inputs[i])->GetIntResource()->scale_handle.force_to<float *>();
int8_t *input_data = static_cast<int8_t *>(inputs[i]->GetHandle().base);
const int input_concat_axis = inputs[i]->GetBlobDesc().dims[axis];
for (int n = 0; n < num_concats; ++n) {
int8_t *concat_dst = output_data + (n * output_concat_axis + output_concat_axis_offset) * concate_size;
int8_t *concat_src = input_data + n * input_concat_axis * concate_size;
// per-tensor quantization: dequantize with the input scale, requantize with the output scale
for (int k = 0; k < input_concat_axis * concate_size; k++) {
concat_dst[k] = float2int8(concat_src[k] * input_scale[0] / output_scale[0]);
}
}
output_concat_axis_offset += input_concat_axis;
}
}

return TNN_OK;

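For reference, the per-tensor rescale that both new code paths apply reduces to the scalar sketch below. float2int8_ref is a hypothetical stand-in for TNN's float2int8, assumed here to round to nearest and saturate to [-128, 127]:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Stand-in for TNN's float2int8: round to nearest, saturate to the int8 range.
static int8_t float2int8_ref(float v) {
    int r = static_cast<int>(std::round(v));
    return static_cast<int8_t>(std::min(127, std::max(-128, r)));
}

int main() {
    // Two per-tensor quantized inputs with identical stored bytes but different scales.
    const float s_in0 = 0.5f, s_in1 = 0.25f, s_out = 0.5f;
    const int8_t in0[4] = {10, -20, 30, -40};
    const int8_t in1[4] = {10, -20, 30, -40};
    int8_t out[8];

    // Requantize each input into the output scale while concatenating:
    // q_out = round(q_in * s_in / s_out).
    for (int i = 0; i < 4; i++) out[i]     = float2int8_ref(in0[i] * (s_in0 / s_out));
    for (int i = 0; i < 4; i++) out[4 + i] = float2int8_ref(in1[i] * (s_in1 / s_out));

    // in1's scale is half of in0's, so its requantized values are halved:
    // prints 10 -20 30 -40 5 -10 15 -20
    for (int i = 0; i < 8; i++) printf("%d ", out[i]);
    printf("\n");
    return 0;
}

Copying the raw bytes instead (the old memcpy path) would place in1's values into the output at the wrong scale, which is the per-tensor concat error this commit fixes.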