Skip to content

Commit

Permalink
opt++
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Apr 29, 2024
1 parent 79fe31c commit 5cc4072
Showing 1 changed file with 20 additions and 32 deletions.
52 changes: 20 additions & 32 deletions src/layer/arm/rnn_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,31 +292,25 @@ static void rnn_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
int32x4_t _sum3 = vdupq_n_s32(0);
for (; i + 15 < size; i += 16)
{
int32x4_t _xi01 = vreinterpretq_s32_s8(vld1q_s8(x + i));
int8x16_t _xi0 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 0));
int8x16_t _xi1 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 1));
int8x16_t _xi2 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 2));
int8x16_t _xi3 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 3));
int8x16_t _xi = vld1q_s8(x + i);
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
int8x16_t _w2 = vld1q_s8(kptr + 32);
int8x16_t _w3 = vld1q_s8(kptr + 48);
_rnn_Hx0 = vdotq_s32(_rnn_Hx0, _w0, _xi0);
_sum1 = vdotq_s32(_sum1, _w1, _xi1);
_sum2 = vdotq_s32(_sum2, _w2, _xi2);
_sum3 = vdotq_s32(_sum3, _w3, _xi3);
_rnn_Hx0 = vdotq_laneq_s32(_rnn_Hx0, _w0, _xi, 0);
_sum1 = vdotq_laneq_s32(_sum1, _w1, _xi, 1);
_sum2 = vdotq_laneq_s32(_sum2, _w2, _xi, 2);
_sum3 = vdotq_laneq_s32(_sum3, _w3, _xi, 3);

kptr += 64;
}
for (; i + 7 < size; i += 8)
{
int32x2_t _xi01 = vreinterpret_s32_s8(vld1_s8(x + i));
int8x16_t _xi0 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 0));
int8x16_t _xi1 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 1));
int8x8_t _xi = vld1_s8(x + i);
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
_rnn_Hx0 = vdotq_s32(_rnn_Hx0, _w0, _xi0);
_sum1 = vdotq_s32(_sum1, _w1, _xi1);
_rnn_Hx0 = vdotq_lane_s32(_rnn_Hx0, _w0, _xi, 0);
_sum1 = vdotq_lane_s32(_sum1, _w1, _xi, 1);

kptr += 32;
}
Expand All @@ -327,9 +321,9 @@ static void rnn_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
for (; i + 3 < size; i += 4)
{
#if __ARM_FEATURE_DOTPROD
int8x16_t _xi = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(x + i)), 0));
int8x8_t _xi = vld1_s8(x + i);
int8x16_t _w = vld1q_s8(kptr);
_rnn_Hx0 = vdotq_s32(_rnn_Hx0, _w, _xi);
_rnn_Hx0 = vdotq_lane_s32(_rnn_Hx0, _w, _xi, 0);
#else
int16x4_t _xi01 = vreinterpret_s16_s8(vld1_s8(x + i));
int8x8_t _xi0 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 0));
Expand Down Expand Up @@ -372,31 +366,25 @@ static void rnn_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
_sum3 = vdupq_n_s32(0);
for (; i + 15 < num_output; i += 16)
{
int32x4_t _h_cont01 = vreinterpretq_s32_s8(vld1q_s8(hs + i));
int8x16_t _h_cont0 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 0));
int8x16_t _h_cont1 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 1));
int8x16_t _h_cont2 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 2));
int8x16_t _h_cont3 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 3));
int8x16_t _h_cont = vld1q_s8(hs + i);
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
int8x16_t _w2 = vld1q_s8(kptr + 32);
int8x16_t _w3 = vld1q_s8(kptr + 48);
_rnn_Hh0 = vdotq_s32(_rnn_Hh0, _w0, _h_cont0);
_sum1 = vdotq_s32(_sum1, _w1, _h_cont1);
_sum2 = vdotq_s32(_sum2, _w2, _h_cont2);
_sum3 = vdotq_s32(_sum3, _w3, _h_cont3);
_rnn_Hh0 = vdotq_laneq_s32(_rnn_Hh0, _w0, _h_cont, 0);
_sum1 = vdotq_laneq_s32(_sum1, _w1, _h_cont, 1);
_sum2 = vdotq_laneq_s32(_sum2, _w2, _h_cont, 2);
_sum3 = vdotq_laneq_s32(_sum3, _w3, _h_cont, 3);

kptr += 64;
}
for (; i + 7 < num_output; i += 8)
{
int32x2_t _h_cont01 = vreinterpret_s32_s8(vld1_s8(hs + i));
int8x16_t _h_cont0 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 0));
int8x16_t _h_cont1 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 1));
int8x8_t _h_cont = vld1_s8(hs + i);
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
_rnn_Hh0 = vdotq_s32(_rnn_Hh0, _w0, _h_cont0);
_sum1 = vdotq_s32(_sum1, _w1, _h_cont1);
_rnn_Hh0 = vdotq_lane_s32(_rnn_Hh0, _w0, _h_cont, 0);
_sum1 = vdotq_lane_s32(_sum1, _w1, _h_cont, 1);

kptr += 32;
}
Expand All @@ -407,9 +395,9 @@ static void rnn_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
for (; i + 3 < num_output; i += 4)
{
#if __ARM_FEATURE_DOTPROD
int8x16_t _h_cont = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(hs + i)), 0));
int8x8_t _h_cont = vld1_s8(hs + i);
int8x16_t _w = vld1q_s8(kptr);
_rnn_Hh0 = vdotq_s32(_rnn_Hh0, _w, _h_cont);
_rnn_Hh0 = vdotq_lane_s32(_rnn_Hh0, _w, _h_cont, 0);
#else
int16x4_t _h_cont01 = vreinterpret_s16_s8(vld1_s8(hs + i));
int8x8_t _h_cont0 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 0));
Expand Down

0 comments on commit 5cc4072

Please sign in to comment.