Skip to content

Commit

Permalink
opt
Browse files: browse the repository at this point in the history
  • Loading branch information
nihui committed Apr 29, 2024
1 parent bd66ba1 commit 3ddb97b
Showing 1 changed file with 259 additions and 12 deletions.
271 changes: 259 additions & 12 deletions src/layer/arm/lstm_int8.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -77,7 +77,42 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
signed char* kptr = weight_data_tm_dr.row<signed char>(q);
float* descales_ptr = weight_data_tm_int8_descales_dr.row(q);

for (int i = 0; i < size; i++)
int i = 0;
#if __ARM_FEATURE_DOTPROD
for (; i + 3 < size; i += 4)
{
kptr[0] = weight_xc_I[i];
kptr[1] = weight_xc_I[i + 1];
kptr[2] = weight_xc_I[i + 2];
kptr[3] = weight_xc_I[i + 3];
kptr[4] = weight_xc_F[i];
kptr[5] = weight_xc_F[i + 1];
kptr[6] = weight_xc_F[i + 2];
kptr[7] = weight_xc_F[i + 3];
kptr[8 + 0] = weight_xc_O[i];
kptr[8 + 1] = weight_xc_O[i + 1];
kptr[8 + 2] = weight_xc_O[i + 2];
kptr[8 + 3] = weight_xc_O[i + 3];
kptr[8 + 4] = weight_xc_G[i];
kptr[8 + 5] = weight_xc_G[i + 1];
kptr[8 + 6] = weight_xc_G[i + 2];
kptr[8 + 7] = weight_xc_G[i + 3];
kptr += 16;
}
#endif // __ARM_FEATURE_DOTPROD
for (; i + 1 < size; i += 2)
{
kptr[0] = weight_xc_I[i];
kptr[1] = weight_xc_I[i + 1];
kptr[2] = weight_xc_F[i];
kptr[3] = weight_xc_F[i + 1];
kptr[4] = weight_xc_O[i];
kptr[5] = weight_xc_O[i + 1];
kptr[6] = weight_xc_G[i];
kptr[7] = weight_xc_G[i + 1];
kptr += 8;
}
for (; i < size; i++)
{
kptr[0] = weight_xc_I[i];
kptr[1] = weight_xc_F[i];
Expand All @@ -86,7 +121,42 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
kptr += 4;
}

for (int i = 0; i < num_output; i++)
i = 0;
#if __ARM_FEATURE_DOTPROD
for (; i + 3 < num_output; i += 4)
{
kptr[0] = weight_hc_I[i];
kptr[1] = weight_hc_I[i + 1];
kptr[2] = weight_hc_I[i + 2];
kptr[3] = weight_hc_I[i + 3];
kptr[4] = weight_hc_F[i];
kptr[5] = weight_hc_F[i + 1];
kptr[6] = weight_hc_F[i + 2];
kptr[7] = weight_hc_F[i + 3];
kptr[8 + 0] = weight_hc_O[i];
kptr[8 + 1] = weight_hc_O[i + 1];
kptr[8 + 2] = weight_hc_O[i + 2];
kptr[8 + 3] = weight_hc_O[i + 3];
kptr[8 + 4] = weight_hc_G[i];
kptr[8 + 5] = weight_hc_G[i + 1];
kptr[8 + 6] = weight_hc_G[i + 2];
kptr[8 + 7] = weight_hc_G[i + 3];
kptr += 16;
}
#endif // __ARM_FEATURE_DOTPROD
for (; i + 1 < num_output; i += 2)
{
kptr[0] = weight_hc_I[i];
kptr[1] = weight_hc_I[i + 1];
kptr[2] = weight_hc_F[i];
kptr[3] = weight_hc_F[i + 1];
kptr[4] = weight_hc_O[i];
kptr[5] = weight_hc_O[i + 1];
kptr[6] = weight_hc_G[i];
kptr[7] = weight_hc_G[i + 1];
kptr += 8;
}
for (; i < num_output; i++)
{
kptr[0] = weight_hc_I[i];
kptr[1] = weight_hc_F[i];
Expand Down Expand Up @@ -183,15 +253,184 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
const signed char* kptr = weight_data_tm.row<const signed char>(q);
const float* descales_ptr = weight_data_tm_int8_descales.row(q);

const float descale_xc_I = descales_ptr[0];
const float descale_xc_F = descales_ptr[1];
const float descale_xc_O = descales_ptr[2];
const float descale_xc_G = descales_ptr[3];
const float descale_hc_I = descales_ptr[4];
const float descale_hc_F = descales_ptr[5];
const float descale_hc_O = descales_ptr[6];
const float descale_hc_G = descales_ptr[7];
float* gates_data = gates.row(q);

#if __ARM_NEON
int32x4_t _lstm_IFOGx0 = vdupq_n_s32(0);
int i = 0;
#if __ARM_FEATURE_DOTPROD
int32x4_t _sum1 = vdupq_n_s32(0);
int32x4_t _sum2 = vdupq_n_s32(0);
int32x4_t _sum3 = vdupq_n_s32(0);
for (; i + 15 < size; i += 16)
{
int32x4_t _xi01 = vreinterpretq_s32_s8(vld1q_s8(x + i));
int8x16_t _xi0 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 0));
int8x16_t _xi1 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 1));
int8x16_t _xi2 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 2));
int8x16_t _xi3 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 3));
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
int8x16_t _w2 = vld1q_s8(kptr + 32);
int8x16_t _w3 = vld1q_s8(kptr + 48);
_lstm_IFOGx0 = vdotq_s32(_lstm_IFOGx0, _w0, _xi0);
_sum1 = vdotq_s32(_sum1, _w1, _xi1);
_sum2 = vdotq_s32(_sum2, _w2, _xi2);
_sum3 = vdotq_s32(_sum3, _w3, _xi3);

kptr += 64;
}
for (; i + 7 < size; i += 8)
{
int32x2_t _xi01 = vreinterpret_s32_s8(vld1_s8(x + i));
int8x16_t _xi0 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 0));
int8x16_t _xi1 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 1));
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
_lstm_IFOGx0 = vdotq_s32(_lstm_IFOGx0, _w0, _xi0);
_sum1 = vdotq_s32(_sum1, _w1, _xi1);

kptr += 32;
}
_lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum1);
_lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum2);
_lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum3);
#endif // __ARM_FEATURE_DOTPROD
for (; i + 3 < size; i += 4)
{
#if __ARM_FEATURE_DOTPROD
int8x16_t _xi = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(x + i)), 0));
int8x16_t _w = vld1q_s8(kptr);
_lstm_IFOGx0 = vdotq_s32(_lstm_IFOGx0, _w, _xi);
#else
int16x4_t _xi01 = vreinterpret_s16_s8(vld1_s8(x + i));
int8x8_t _xi0 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 0));
int8x8_t _xi1 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 1));
int8x16_t _w01 = vld1q_s8(kptr);

int16x8_t _lstm_IFOGx = vmull_s8(vget_low_s8(_w01), _xi0);
_lstm_IFOGx = vmlal_s8(_lstm_IFOGx, vget_high_s8(_w01), _xi1);
_lstm_IFOGx0 = vpadalq_s16(_lstm_IFOGx0, _lstm_IFOGx);
#endif // __ARM_FEATURE_DOTPROD

kptr += 16;
}
for (; i + 1 < size; i += 2)
{
int8x8_t _xi = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(x + i)), 0));
int8x8_t _w = vld1_s8(kptr);

int16x8_t _lstm_IFOGx = vmull_s8(_w, _xi);
_lstm_IFOGx0 = vpadalq_s16(_lstm_IFOGx0, _lstm_IFOGx);

kptr += 8;
}
for (; i < size; i++)
{
int8x8_t _xi = vdup_n_s8(x[i]);
int8x8_t _w = vld1_s8(kptr);

int16x8_t _lstm_IFOGx = vmull_s8(_w, _xi);
_lstm_IFOGx0 = vaddw_s16(_lstm_IFOGx0, vget_low_s16(_lstm_IFOGx));

kptr += 4;
}

int32x4_t _lstm_IFOGh0 = vdupq_n_s32(0);
i = 0;
#if __ARM_FEATURE_DOTPROD
_sum1 = vdupq_n_s32(0);
_sum2 = vdupq_n_s32(0);
_sum3 = vdupq_n_s32(0);
for (; i + 15 < num_output; i += 16)
{
int32x4_t _h_cont01 = vreinterpretq_s32_s8(vld1q_s8(hs + i));
int8x16_t _h_cont0 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 0));
int8x16_t _h_cont1 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 1));
int8x16_t _h_cont2 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 2));
int8x16_t _h_cont3 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 3));
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
int8x16_t _w2 = vld1q_s8(kptr + 32);
int8x16_t _w3 = vld1q_s8(kptr + 48);
_lstm_IFOGh0 = vdotq_s32(_lstm_IFOGh0, _w0, _h_cont0);
_sum1 = vdotq_s32(_sum1, _w1, _h_cont1);
_sum2 = vdotq_s32(_sum2, _w2, _h_cont2);
_sum3 = vdotq_s32(_sum3, _w3, _h_cont3);

kptr += 64;
}
for (; i + 7 < num_output; i += 8)
{
int32x2_t _h_cont01 = vreinterpret_s32_s8(vld1_s8(hs + i));
int8x16_t _h_cont0 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 0));
int8x16_t _h_cont1 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 1));
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
_lstm_IFOGh0 = vdotq_s32(_lstm_IFOGh0, _w0, _h_cont0);
_sum1 = vdotq_s32(_sum1, _w1, _h_cont1);

kptr += 32;
}
_lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum1);
_lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum2);
_lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum3);
#endif // __ARM_FEATURE_DOTPROD
for (; i + 3 < num_output; i += 4)
{
#if __ARM_FEATURE_DOTPROD
int8x16_t _h_cont = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(hs + i)), 0));
int8x16_t _w = vld1q_s8(kptr);
_lstm_IFOGh0 = vdotq_s32(_lstm_IFOGh0, _w, _h_cont);
#else
int16x4_t _h_cont01 = vreinterpret_s16_s8(vld1_s8(hs + i));
int8x8_t _h_cont0 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 0));
int8x8_t _h_cont1 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 1));
int8x16_t _w01 = vld1q_s8(kptr);

int16x8_t _lstm_IFOGh = vmull_s8(vget_low_s8(_w01), _h_cont0);
_lstm_IFOGh = vmlal_s8(_lstm_IFOGh, vget_high_s8(_w01), _h_cont1);
_lstm_IFOGh0 = vpadalq_s16(_lstm_IFOGh0, _lstm_IFOGh);
#endif // __ARM_FEATURE_DOTPROD

kptr += 16;
}
for (; i + 1 < num_output; i += 2)
{
int8x8_t _h_cont = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(hs + i)), 0));
int8x8_t _w = vld1_s8(kptr);

int16x8_t _lstm_IFOGh = vmull_s8(_w, _h_cont);
_lstm_IFOGh0 = vpadalq_s16(_lstm_IFOGh0, _lstm_IFOGh);

kptr += 8;
}
for (; i < num_output; i++)
{
int8x8_t _h_cont = vdup_n_s8(hs[i]);
int8x8_t _w = vld1_s8(kptr);

int16x8_t _lstm_IFOGh = vmull_s8(_w, _h_cont);
_lstm_IFOGh0 = vaddw_s16(_lstm_IFOGh0, vget_low_s16(_lstm_IFOGh));

kptr += 4;
}

float32x4_t _descale_x = vdupq_n_f32(descale_x);
float32x4_t _descale_h = vdupq_n_f32(descale_h);

float32x4_t _lstm_IFOG0 = vld1q_f32(bias_c_IFOG);

float32x4_t _descale_xc_IFOG = vld1q_f32(descales_ptr);

_lstm_IFOG0 = vmlaq_f32(_lstm_IFOG0, vcvtq_f32_s32(_lstm_IFOGx0), vmulq_f32(_descale_x, _descale_xc_IFOG));

float32x4_t _descale_hc_IFOG = vld1q_f32(descales_ptr + 4);

_lstm_IFOG0 = vmlaq_f32(_lstm_IFOG0, vcvtq_f32_s32(_lstm_IFOGh0), vmulq_f32(_descale_h, _descale_hc_IFOG));

vst1q_f32(gates_data, _lstm_IFOG0);
#else
int Ix = 0;
int Fx = 0;
int Ox = 0;
Expand Down Expand Up @@ -224,17 +463,25 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
kptr += 4;
}

const float descale_xc_I = descales_ptr[0];
const float descale_xc_F = descales_ptr[1];
const float descale_xc_O = descales_ptr[2];
const float descale_xc_G = descales_ptr[3];
const float descale_hc_I = descales_ptr[4];
const float descale_hc_F = descales_ptr[5];
const float descale_hc_O = descales_ptr[6];
const float descale_hc_G = descales_ptr[7];

float I = bias_c_IFOG[0] + Ix * (descale_x * descale_xc_I) + Ih * (descale_h * descale_hc_I);
float F = bias_c_IFOG[1] + Fx * (descale_x * descale_xc_F) + Fh * (descale_h * descale_hc_F);
float O = bias_c_IFOG[2] + Ox * (descale_x * descale_xc_O) + Oh * (descale_h * descale_hc_O);
float G = bias_c_IFOG[3] + Gx * (descale_x * descale_xc_G) + Gh * (descale_h * descale_hc_G);

float* gates_data = gates.row(q);

gates_data[0] = I;
gates_data[1] = F;
gates_data[2] = O;
gates_data[3] = G;
#endif // __ARM_NEON
}

// lstm unit
Expand Down

0 comments on commit 3ddb97b

Please sign in to comment.