Skip to content

Commit

Permalink
[ARM] use fp32 for avg pooling to avoid overflow (#867)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanxcwang committed Mar 26, 2021
1 parent 30dd74b commit b2244a5
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
23 changes: 23 additions & 0 deletions source/tnn/device/arm/acc/Half8.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "tnn/utils/half.hpp"
#include "tnn/utils/half_utils_inner.h"
#include "tnn/device/arm/acc/TNNVector.h"
#include "tnn/device/arm/acc/Float4.h"
#ifdef TNN_USE_NEON
#include <arm_neon.h>
#include "tnn/device/arm/acc/neon_mathfun.h"
Expand Down Expand Up @@ -47,6 +48,9 @@ struct Half4 {
Half4(const Half4&& lr) {
value = std::move(lr.value);
}
Half4(const Float4& lr) {
value = vcvt_f16_f32(lr.value);
}
static Half4 load(const __fp16* addr) {
Half4 v;
v.value = vld1_f16(addr);
Expand All @@ -60,6 +64,9 @@ struct Half4 {
v1.value = v.val[0];
v2.value = v.val[1];
}
static void add_to_f32(Half4& v1, Float4& v2) {
v2.value = vaddq_f32(v2.value, vcvt_f32_f16(v1.value));
}
Half4& operator=(const Half4& lr) {
value = lr.value;
return *this;
Expand Down Expand Up @@ -456,6 +463,9 @@ struct Half4 {
Half4(const Half4&& lr) {
value = std::move(lr.value);
}
Half4(const Float4& lr) {
value = vreinterpret_s16_f16(vcvt_f16_f32(lr.value));
}
static Half4 load(const fp16_t* addr) {
Half4 v;
asm volatile(
Expand All @@ -479,6 +489,9 @@ struct Half4 {
v1.value = v.val[0];
v2.value = v.val[1];
}
static void add_to_f32(Half4& v1, Float4& v2) {
v2.value = vaddq_f32(v2.value, vcvt_f32_f16(vreinterpret_f16_s16(v1.value)));
}
Half4& operator=(const Half4& lr) {
value = lr.value;
return *this;
Expand Down Expand Up @@ -1113,11 +1126,21 @@ struct Half4 : TNNVector<fp16_t, 4> {
value[i] = lr.value[i];
}
}
Half4(const Float4& lr) {
for (int i = 0; i < 4; ++i) {
value[i] = (fp16_t)lr.value[i];
}
}
Half4(const TNNVector<fp16_t, 4>& lr) {
for (int i = 0; i < 4; ++i) {
value[i] = lr.value[i];
}
}
static void add_to_f32(Half4& v1, Float4& v2) {
for (int i = 0; i < 4; ++i) {
v2.value[i] = v2.value[i] + (float)v1.value[i];
}
}
};

struct Half8 : TNNVector<fp16_t, 8> {
Expand Down
16 changes: 12 additions & 4 deletions source/tnn/device/arm/acc/compute_arm82/compute_half.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,16 +214,24 @@ void AvgPoolingHalf(const fp16_t* src, long iw, long ih, fp16_t* dst, long ow, l
const auto src_ptr = src + (srcOriginY * iw + srcOriginX) * 8;
auto dst_ptr = dst + (oy * ow + ox) * 8;

Half8 vavg = Half8(fp16_t(0.f));

Float4 vavg_low = Float4(0.f);
Float4 vavg_high = Float4(0.f);
for (long ky = kys; ky < kye; ++ky) {
const auto src_ptr_h = src_ptr + (ky * iw) * 8;
Half8 vavg = Half8((fp16_t)0.f);
for (long kx = kxs; kx < kxe; kx++) {
vavg = vavg + Half8::load(src_ptr_h + kx * 8);
}
Half4 v0, v1;
Half8::get_low(vavg, v0);
Half8::get_high(vavg, v1);
Half4::add_to_f32(v0, vavg_low);
Half4::add_to_f32(v1, vavg_high);
}

Half8::save(dst_ptr, vavg * Half8(fp16_t(kernel_count)));
vavg_low = vavg_low * Float4(kernel_count);
vavg_high = vavg_high * Float4(kernel_count);
Half4::save(dst_ptr, Half4(vavg_low));
Half4::save(dst_ptr + 4, Half4(vavg_high));
}
}
}
Expand Down

0 comments on commit b2244a5

Please sign in to comment.