[QUANT][BUG] fix per tensor quant concat error #607

Merged 2 commits on Dec 3, 2020
Changes from all commits
121 changes: 101 additions & 20 deletions source/tnn/device/arm/acc/arm_concat_layer_acc.cc
@@ -16,6 +16,7 @@
#include "tnn/device/arm/arm_util.h"
#include "tnn/utils/bfp16.h"
#include "tnn/utils/dims_vector_utils.h"
#include "tnn/utils/naive_compute.h"

namespace TNN_NS {

@@ -77,6 +78,42 @@ int concat_channel(Blob *output, const std::vector<Blob *> &inputs, T *unpack_buf
return 0;
}

/*
check whether all inputs are per-tensor quantized
*/
static bool is_per_tensor_quant(const std::vector<Blob *> &inputs) {
bool int8_per_tensor_flag = true;
for (auto &blob : inputs) {
if (reinterpret_cast<BlobInt8 *>(blob)->GetIntResource()->scale_handle.GetDataCount() > 1) {
int8_per_tensor_flag = false;
break;
}
}

return int8_per_tensor_flag;
}
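
For context, here is a minimal sketch (a hypothetical helper, not a TNN API) of the distinction this check relies on: a per-tensor blob carries a single scale shared by all channels, while a per-channel blob carries one scale per channel, so a scale count greater than one signals per-channel quantization.

#include <cstdint>
#include <vector>

// Sketch only: how the scale array is applied in the two quantization
// modes that is_per_tensor_quant separates. Per-tensor: scales.size() == 1
// and every channel shares scales[0]; per-channel: one entry per channel.
float dequantize(int8_t q, const std::vector<float> &scales, int channel) {
    return q * (scales.size() == 1 ? scales[0] : scales[channel]);
}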

/*
rescale int8 values, per-tensor quantization only
*/
static void rescale_int8(int8_t *dst, int8_t *src, float *rescale, int len) {
int n = 0;
#ifdef TNN_USE_NEON
for (; n + 7 < len; n += 8) {
int8x8_t v_src = vld1_s8(src + n);
int16x8_t v_src_s16 = vmovl_s8(v_src);
float32x4_t v_src0_f32 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(v_src_s16)));
float32x4_t v_src1_f32 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(v_src_s16)));
int16x4_t v_mul0_s16 = vqmovn_s32(VCVTAQ_S32_F32(vmulq_n_f32(v_src0_f32, rescale[0])));
int16x8_t v_mul_s16 = VQMOVN_HIGH_S32_T(v_mul0_s16, VCVTAQ_S32_F32(vmulq_n_f32(v_src1_f32, rescale[0])));
vst1_s8(dst + n, vqmovn_s16(v_mul_s16));
}
#endif
for (; n < len; n++) {
dst[n] = float2int8(src[n] * rescale[0]);
}
}
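
A usage sketch with hypothetical values and scales (assuming it sits in the same file as rescale_int8): requantizing from an input scale of 0.10 to an output scale of 0.05 multiplies each element by 2.0, and results outside the int8 range saturate (via vqmovn in the NEON path, via float2int8 in the scalar tail).

// Usage sketch, hypothetical data: rescale eight int8 values by
// input_scale / output_scale = 0.10 / 0.05 = 2.0.
static void rescale_int8_demo() {
    int8_t src[8] = {1, -2, 3, -4, 5, -6, 7, 100};
    int8_t dst[8];
    float rescale = 0.10f / 0.05f;
    rescale_int8(dst, src, &rescale, 8);
    // dst == {2, -4, 6, -8, 10, -12, 14, 127}; 100 * 2 = 200 saturates.
}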

/*
concat channel int8, nhwc format
*/
@@ -86,17 +123,41 @@ static int concat_channel_i8(Blob *output, const std::vector<Blob *> &inputs) {
auto oc_c4 = ROUND_UP(dims_output[1], 4);

int8_t *output_origin = reinterpret_cast<int8_t *>(GetBlobHandlePtr(output->GetHandle()));

if (!is_per_tensor_quant(inputs)) {
for (int n = 0; n < dims_output[0]; n++) {
int c_offset = 0;
for (int b = 0; b < inputs.size(); b++) {
auto input_channel = inputs[b]->GetBlobDesc().dims[1];
auto ic_c4 = ROUND_UP(input_channel, 4);
auto input_ptr =
reinterpret_cast<int8_t *>(GetBlobHandlePtr(inputs[b]->GetHandle())) + n * ic_c4 * full_hw;
auto output_ptr = output_origin + n * full_hw * oc_c4 + c_offset;
for (int cur_hw = 0; cur_hw < full_hw; cur_hw++) {
memcpy(output_ptr + cur_hw * oc_c4, input_ptr + cur_hw * ic_c4, input_channel);
}
c_offset += input_channel;
}
}
} else {
float *output_scale = reinterpret_cast<BlobInt8 *>(output)->GetIntResource()->scale_handle.force_to<float *>();
for (int n = 0; n < dims_output[0]; n++) {
int c_offset = 0;
for (int b = 0; b < inputs.size(); b++) {
float *input_scale =
reinterpret_cast<BlobInt8 *>(inputs[b])->GetIntResource()->scale_handle.force_to<float *>();
float rescale = input_scale[0] / output_scale[0];
auto input_channel = inputs[b]->GetBlobDesc().dims[1];
auto ic_c4 = ROUND_UP(input_channel, 4);
auto input_ptr =
reinterpret_cast<int8_t *>(GetBlobHandlePtr(inputs[b]->GetHandle())) + n * ic_c4 * full_hw;
auto output_ptr = output_origin + n * full_hw * oc_c4 + c_offset;
for (int cur_hw = 0; cur_hw < full_hw; cur_hw++) {
rescale_int8(output_ptr + cur_hw * oc_c4, input_ptr + cur_hw * ic_c4, &rescale, input_channel);
}
c_offset += input_channel;
}
}
}

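The per-tensor branch above cannot reuse the plain memcpy because each input may encode values with a different scale: a stored q represents the real value q * s, so the copy must requantize with q_out = round(q_in * s_in / s_out) for the concatenated values to share the output scale. A worked example with hypothetical scales:

#include <cassert>
#include <cmath>

// Worked example, hypothetical scales: inputs quantized with s_in = 0.20
// and s_in = 0.05 are concatenated into an output with s_out = 0.10.
static void per_tensor_requant_example() {
    const float s_out = 0.10f;
    // Input A: q = 50 encodes 50 * 0.20 = 10.0 -> 100 at the output scale.
    assert(std::lround(50 * (0.20f / s_out)) == 100);
    // Input B: q = 100 encodes 100 * 0.05 = 5.0 -> 50 at the output scale.
    assert(std::lround(100 * (0.05f / s_out)) == 50);
}
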
@@ -113,16 +174,36 @@ static int concat_common_i8(Blob *output, const std::vector<Blob *> &inputs, int axis) {
auto output_stride = DimsVectorUtils::Count(round_output_dims, axis - 1);
auto *output_origin = reinterpret_cast<int8_t *>(GetBlobHandlePtr(output->GetHandle()));

if (!is_per_tensor_quant(inputs)) {
for (int n = 0; n < slice_count; n++) {
auto output_ptr = output_origin + n * output_stride;
for (int b = 0; b < inputs.size(); b++) {
auto input = inputs[b];
auto input_dims = input->GetBlobDesc().dims;
DimsVector round_input_dims = {input_dims[0], input_dims[2], input_dims[3], ROUND_UP(input_dims[1], 4)};
auto input_stride = DimsVectorUtils::Count(round_input_dims, axis - 1);
auto input_ptr = reinterpret_cast<int8_t *>(GetBlobHandlePtr(input->GetHandle())) + n * input_stride;
memcpy(output_ptr, input_ptr, input_stride * sizeof(int8_t));
output_ptr += input_stride;
}
}
} else {
float *output_scale = reinterpret_cast<BlobInt8 *>(output)->GetIntResource()->scale_handle.force_to<float *>();
for (int n = 0; n < slice_count; n++) {
auto output_ptr = output_origin + n * output_stride;
for (int b = 0; b < inputs.size(); b++) {
float *input_scale =
reinterpret_cast<BlobInt8 *>(inputs[b])->GetIntResource()->scale_handle.force_to<float *>();
float rescale = input_scale[0] / output_scale[0];
auto input = inputs[b];
auto input_dims = input->GetBlobDesc().dims;
DimsVector round_input_dims = {input_dims[0], input_dims[2], input_dims[3], ROUND_UP(input_dims[1], 4)};
auto input_stride = DimsVectorUtils::Count(round_input_dims, axis - 1);
auto input_ptr = reinterpret_cast<int8_t *>(GetBlobHandlePtr(input->GetHandle())) + n * input_stride;
rescale_int8(output_ptr, input_ptr, &rescale, input_stride);
output_ptr += input_stride;
}
}
}

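For readers tracing the strides in concat_common_i8: the blob is stored NHWC with the channel dimension rounded up to a multiple of 4, and the per-slice stride is the product of the rounded dims from index axis - 1 onward. A sketch under that assumption (hypothetical shape; count_from stands in for DimsVectorUtils::Count):

#include <functional>
#include <numeric>
#include <vector>

// Hypothetical NCHW shape {1, 3, 8, 8}: the rounded NHWC dims are
// {1, 8, 8, 4}, so a concat on axis == 2 copies slices of
// 8 * 8 * 4 = 256 int8 elements.
static int count_from(const std::vector<int> &dims, int start) {
    return std::accumulate(dims.begin() + start, dims.end(), 1,
                           std::multiplies<int>());
}

static void stride_example() {
    std::vector<int> dims = {1, 3, 8, 8};  // NCHW
    std::vector<int> round_dims = {dims[0], dims[2], dims[3],
                                   ((dims[1] + 3) / 4) * 4};
    int stride = count_from(round_dims, 2 - 1);  // axis == 2 -> 256
    (void)stride;  // illustration only
}
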
53 changes: 43 additions & 10 deletions source/tnn/device/cpu/acc/cpu_concat_layer_acc.cc
@@ -12,11 +12,12 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "tnn/utils/naive_compute.h"
#include "tnn/core/blob_int8.h"
#include "tnn/device/cpu/acc/cpu_layer_acc.h"
#include "tnn/device/cpu/cpu_context.h"
#include "tnn/utils/data_type_utils.h"
#include "tnn/utils/dims_vector_utils.h"
#include "tnn/utils/naive_compute.h"

namespace TNN_NS {

@@ -40,6 +41,18 @@ Status CpuConcatLayerAcc::Forward(const std::vector<Blob *> &inputs, const std::vector<Blob *> &outputs) {
auto output = outputs[0];
auto dims = input->GetBlobDesc().dims;

bool int8_per_tensor_flag = false;
if (output->GetBlobDesc().data_type == DATA_TYPE_INT8) {
int8_per_tensor_flag = true;
// if any blob is per-channel quantized, fall back to the plain concat path
for (auto &blob : inputs) {
if (reinterpret_cast<BlobInt8 *>(blob)->GetIntResource()->scale_handle.GetDataCount() > 1) {
int8_per_tensor_flag = false;
break;
}
}
}

const int axis = param->axis;
if (axis > dims.size() || axis < 0) {
LOGE("Error: Concat layer param invalid\n");
@@ -60,16 +73,36 @@ Status CpuConcatLayerAcc::Forward(const std::vector<Blob *> &inputs, const std::vector<Blob *> &outputs) {
int8_t *output_data = static_cast<int8_t *>(output->GetHandle().base);
int output_concat_axis = output->GetBlobDesc().dims[axis];
int output_concat_axis_offset = 0;

if (!int8_per_tensor_flag) {
for (size_t i = 0; i < inputs.size(); ++i) {
// use int8_t for all types
int8_t *input_data = static_cast<int8_t *>(inputs[i]->GetHandle().base);
const int input_concat_axis = inputs[i]->GetBlobDesc().dims[axis];
for (int n = 0; n < num_concats; ++n) {
memcpy(output_data + (n * output_concat_axis + output_concat_axis_offset) * concate_size * datasize,
input_data + n * input_concat_axis * concate_size * datasize,
input_concat_axis * concate_size * datasize);
}
output_concat_axis_offset += input_concat_axis;
}
} else {
float *output_scale = reinterpret_cast<BlobInt8 *>(output)->GetIntResource()->scale_handle.force_to<float *>();
for (size_t i = 0; i < inputs.size(); ++i) {
float *input_scale =
reinterpret_cast<BlobInt8 *>(inputs[i])->GetIntResource()->scale_handle.force_to<float *>();
int8_t *input_data = static_cast<int8_t *>(inputs[i]->GetHandle().base);
const int input_concat_axis = inputs[i]->GetBlobDesc().dims[axis];
for (int n = 0; n < num_concats; ++n) {
int8_t *concat_dst = output_data + (n * output_concat_axis + output_concat_axis_offset) * concate_size;
int8_t *concat_src = input_data + n * input_concat_axis * concate_size;
// per-tensor concat must dequantize with the input scale and requantize with the output scale
for (int j = 0; j < input_concat_axis * concate_size; j++) {
concat_dst[j] = float2int8(concat_src[j] * input_scale[0] / output_scale[0]);
}
}
output_concat_axis_offset += input_concat_axis;
}
}

return TNN_OK;
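
One design note on the scalar loop above: it divides by output_scale[0] for every element, whereas the ARM path precomputes the ratio once. A hoisted variant (a sketch; requant_copy is a hypothetical helper, float2int8 is the rounding and saturating routine from tnn/utils/naive_compute.h):

#include <cstdint>

// Sketch: same per-tensor requantization with the scale ratio hoisted
// out of the element loop.
static void requant_copy(int8_t *dst, const int8_t *src, int count,
                         float input_scale, float output_scale) {
    const float rescale = input_scale / output_scale;
    for (int j = 0; j < count; j++) {
        dst[j] = float2int8(src[j] * rescale);
    }
}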