x86 sse2/avx2 optimization for convolution sgemm/winograd int8 family #4286

Merged
merged 4 commits on Oct 20, 2022
617 changes: 255 additions & 362 deletions src/layer/x86/convolution_3x3_pack8to1_int8.h

Large diffs are not rendered by default.

513 changes: 203 additions & 310 deletions src/layer/x86/convolution_3x3_pack8to4_int8.h

Large diffs are not rendered by default.
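The two winograd headers above are too large to render, but the sgemm diff below shows the rewrite this PR applies across the family: the old mullo/mulhi + unpack sequence, which emulated a widening 16x16-to-32 multiply, is replaced by a single `_mm256_madd_epi16` (`_mm_madd_epi16` in the SSE path), which multiplies int16 lanes and folds adjacent pairs into int32 in one instruction. That halves the live accumulators, which is why the `#if`/`#else` accumulator blocks collapse in the hunks below; in this flattened diff each removed sequence appears directly above the madd lines that replace it. Here is a minimal standalone sketch of the equivalence (my own example, not code from the PR); the fold is exact in this kernel because the int16 operands originate from int8 data, so every pair sum fits comfortably in 32 bits:

```cpp
// Compare the old widening-multiply emulation against _mm256_madd_epi16.
// Build with -mavx2. Standalone sketch, not code from the PR.
#include <immintrin.h>
#include <stdio.h>

static int hsum_epi32(__m256i v)
{
    int t[8];
    _mm256_storeu_si256((__m256i*)t, v);
    int s = 0;
    for (int i = 0; i < 8; i++) s += t[i];
    return s;
}

int main(void)
{
    short a[16], b[16];
    for (int i = 0; i < 16; i++) { a[i] = (short)(i - 8); b[i] = (short)(3 * i + 1); }

    __m256i _a = _mm256_loadu_si256((const __m256i*)a);
    __m256i _b = _mm256_loadu_si256((const __m256i*)b);

    // old path: exact 32-bit products via mullo/mulhi, kept in two
    // separate vectors (hence two accumulators per value vector)
    __m256i _sl = _mm256_mullo_epi16(_a, _b);
    __m256i _sh = _mm256_mulhi_epi16(_a, _b);
    int old_sum = hsum_epi32(_mm256_unpacklo_epi16(_sl, _sh))
                + hsum_epi32(_mm256_unpackhi_epi16(_sl, _sh));

    // new path: multiply and fold adjacent int16 pairs in one instruction
    int new_sum = hsum_epi32(_mm256_madd_epi16(_a, _b));

    printf("old %d new %d\n", old_sum, new_sum); // prints the same value twice
    return 0;
}
```

The same pair-fold is what the VNNI path's `_mm256_dpwssd_epi32` computes, with the accumulate fused in, so the madd and dpwssd branches end up sharing one accumulator layout.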

176 changes: 20 additions & 156 deletions src/layer/x86/convolution_sgemm_int8.h
@@ -338,17 +338,8 @@ static void im2col_sgemm_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const

if (nn4 > 0)
{
#if __AVXVNNI__ || __AVX512VNNI__
__m256i _sum10_02 = _mm256_setzero_si256();
__m256i _sum30_22 = _mm256_setzero_si256();
#else
__m256i _sum10_02 = _mm256_setzero_si256();
__m256i _sum01_13 = _mm256_setzero_si256();
__m256i _sum11_03 = _mm256_setzero_si256();
__m256i _sum30_22 = _mm256_setzero_si256();
__m256i _sum21_33 = _mm256_setzero_si256();
__m256i _sum31_23 = _mm256_setzero_si256();
#endif

int j = 0;
for (; j < nn4; j++)
@@ -371,72 +362,21 @@ static void im2col_sgemm_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const
_sum20_32 = _mm256_dpwssd_epi32(_sum20_32, _val23_16, _w01_16);
_sum30_22 = _mm256_dpwssd_epi32(_sum30_22, _val32_16, _w01_16);
#else
__m256i _sl00_11 = _mm256_mullo_epi16(_val01_16, _w01_16);
__m256i _sh00_11 = _mm256_mulhi_epi16(_val01_16, _w01_16);
__m256i _sl10_01 = _mm256_mullo_epi16(_val10_16, _w01_16);
__m256i _sh10_01 = _mm256_mulhi_epi16(_val10_16, _w01_16);
__m256i _sl20_31 = _mm256_mullo_epi16(_val23_16, _w01_16);
__m256i _sh20_31 = _mm256_mulhi_epi16(_val23_16, _w01_16);
__m256i _sl30_21 = _mm256_mullo_epi16(_val32_16, _w01_16);
__m256i _sh30_21 = _mm256_mulhi_epi16(_val32_16, _w01_16);

_sum00_12 = _mm256_add_epi32(_sum00_12, _mm256_unpacklo_epi16(_sl00_11, _sh00_11));
_sum10_02 = _mm256_add_epi32(_sum10_02, _mm256_unpacklo_epi16(_sl10_01, _sh10_01));
_sum01_13 = _mm256_add_epi32(_sum01_13, _mm256_unpackhi_epi16(_sl00_11, _sh00_11));
_sum11_03 = _mm256_add_epi32(_sum11_03, _mm256_unpackhi_epi16(_sl10_01, _sh10_01));
_sum20_32 = _mm256_add_epi32(_sum20_32, _mm256_unpacklo_epi16(_sl20_31, _sh20_31));
_sum30_22 = _mm256_add_epi32(_sum30_22, _mm256_unpacklo_epi16(_sl30_21, _sh30_21));
_sum21_33 = _mm256_add_epi32(_sum21_33, _mm256_unpackhi_epi16(_sl20_31, _sh20_31));
_sum31_23 = _mm256_add_epi32(_sum31_23, _mm256_unpackhi_epi16(_sl30_21, _sh30_21));
_sum00_12 = _mm256_add_epi32(_sum00_12, _mm256_madd_epi16(_val01_16, _w01_16));
_sum10_02 = _mm256_add_epi32(_sum10_02, _mm256_madd_epi16(_val10_16, _w01_16));
_sum20_32 = _mm256_add_epi32(_sum20_32, _mm256_madd_epi16(_val23_16, _w01_16));
_sum30_22 = _mm256_add_epi32(_sum30_22, _mm256_madd_epi16(_val32_16, _w01_16));
#endif

tmpptr += 16;
kptr0 += 16;
}

#if __AVXVNNI__ || __AVX512VNNI__
_sum00_12 = _mm256_hadd_epi32(_sum00_12, _sum10_02);
_sum20_32 = _mm256_hadd_epi32(_sum20_32, _sum30_22);

_sum00_12 = _mm256_permute4x64_epi64(_sum00_12, _MM_SHUFFLE(2, 1, 3, 0));
_sum20_32 = _mm256_permute4x64_epi64(_sum20_32, _MM_SHUFFLE(2, 1, 3, 0));
#else
// transpose 4x8
{
__m256i _tmp0, _tmp1, _tmp2, _tmp3;
_tmp0 = _mm256_unpacklo_epi32(_sum00_12, _sum10_02);
_tmp1 = _mm256_unpacklo_epi32(_sum01_13, _sum11_03);
_tmp2 = _mm256_unpackhi_epi32(_sum00_12, _sum10_02);
_tmp3 = _mm256_unpackhi_epi32(_sum01_13, _sum11_03);
_sum00_12 = _mm256_unpacklo_epi64(_tmp0, _tmp1);
_sum10_02 = _mm256_unpackhi_epi64(_tmp0, _tmp1);
_sum01_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3);
_sum11_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3);
}
{
__m256i _tmp0, _tmp1, _tmp2, _tmp3;
_tmp0 = _mm256_unpacklo_epi32(_sum20_32, _sum30_22);
_tmp1 = _mm256_unpacklo_epi32(_sum21_33, _sum31_23);
_tmp2 = _mm256_unpackhi_epi32(_sum20_32, _sum30_22);
_tmp3 = _mm256_unpackhi_epi32(_sum21_33, _sum31_23);
_sum20_32 = _mm256_unpacklo_epi64(_tmp0, _tmp1);
_sum30_22 = _mm256_unpackhi_epi64(_tmp0, _tmp1);
_sum21_33 = _mm256_unpacklo_epi64(_tmp2, _tmp3);
_sum31_23 = _mm256_unpackhi_epi64(_tmp2, _tmp3);
}

_sum00_12 = _mm256_add_epi32(_sum00_12, _sum10_02);
_sum01_13 = _mm256_add_epi32(_sum01_13, _sum11_03);
_sum00_12 = _mm256_add_epi32(_sum00_12, _sum01_13);

_sum20_32 = _mm256_add_epi32(_sum20_32, _sum30_22);
_sum21_33 = _mm256_add_epi32(_sum21_33, _sum31_23);
_sum20_32 = _mm256_add_epi32(_sum20_32, _sum21_33);

__m256i _perm_mask = _mm256_set_epi32(6, 4, 3, 1, 7, 5, 2, 0);
_sum00_12 = _mm256_permutevar8x32_epi32(_sum00_12, _perm_mask);
_sum20_32 = _mm256_permutevar8x32_epi32(_sum20_32, _perm_mask);
#endif
}

__m128i _sum00 = _mm256_extracti128_si256(_sum00_12, 0);
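After the multiply loop, each 32-bit lane of an accumulator already holds a folded pair sum, and the hunk above reduces the accumulators with `_mm256_hadd_epi32` followed by `_mm256_permute4x64_epi64(..., _MM_SHUFFLE(2, 1, 3, 0))`. A standalone sketch (my example, not PR code) of what those two intrinsics do to lane order, which is the whole trick: hadd works within each 128-bit half, and the permute regroups the interleaved 64-bit quarters:

```cpp
// Lane bookkeeping of the hadd + permute4x64 reduction. Build with -mavx2.
#include <immintrin.h>
#include <stdio.h>

static void print_epi32(const char* tag, __m256i v)
{
    int t[8];
    _mm256_storeu_si256((__m256i*)t, v);
    printf("%s:", tag);
    for (int i = 0; i < 8; i++) printf(" %d", t[i]);
    printf("\n");
}

int main(void)
{
    __m256i _a = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
    __m256i _b = _mm256_setr_epi32(10, 11, 12, 13, 14, 15, 16, 17);

    // per 128-bit half: {a0+a1, a2+a3, b0+b1, b2+b3}
    __m256i _h = _mm256_hadd_epi32(_a, _b);
    print_epi32("hadd   ", _h); // 1 5 21 25 9 13 29 33

    // _MM_SHUFFLE(2, 1, 3, 0) picks 64-bit quarters q0, q3, q1, q2
    __m256i _p = _mm256_permute4x64_epi64(_h, _MM_SHUFFLE(2, 1, 3, 0));
    print_epi32("permute", _p); // 1 5 29 33 21 25 9 13
    return 0;
}
```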
@@ -532,25 +472,10 @@ static void im2col_sgemm_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const
if (nn4 > 0)
{
#if __AVX2__
#if __AVXVNNI__ || __AVX512VNNI__
__m256i _sum10_02 = _mm256_setzero_si256();
#else
__m256i _sum10_02 = _mm256_setzero_si256();
__m256i _sum01_13 = _mm256_setzero_si256();
__m256i _sum11_03 = _mm256_setzero_si256();
#endif
#else
#if __XOP__
__m128i _sum01 = _mm_setzero_si128();
__m128i _sum11 = _mm_setzero_si128();
#else
__m128i _sum01 = _mm_setzero_si128();
__m128i _sum02 = _mm_setzero_si128();
__m128i _sum03 = _mm_setzero_si128();
__m128i _sum11 = _mm_setzero_si128();
__m128i _sum12 = _mm_setzero_si128();
__m128i _sum13 = _mm_setzero_si128();
#endif
#endif

int j = 0;
@@ -571,15 +496,8 @@ static void im2col_sgemm_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const
_sum00_12 = _mm256_dpwssd_epi32(_sum00_12, _val01_16, _w01_16);
_sum10_02 = _mm256_dpwssd_epi32(_sum10_02, _val10_16, _w01_16);
#else
__m256i _sl00_11 = _mm256_mullo_epi16(_val01_16, _w01_16);
__m256i _sh00_11 = _mm256_mulhi_epi16(_val01_16, _w01_16);
__m256i _sl10_01 = _mm256_mullo_epi16(_val10_16, _w01_16);
__m256i _sh10_01 = _mm256_mulhi_epi16(_val10_16, _w01_16);

_sum00_12 = _mm256_add_epi32(_sum00_12, _mm256_unpacklo_epi16(_sl00_11, _sh00_11));
_sum10_02 = _mm256_add_epi32(_sum10_02, _mm256_unpacklo_epi16(_sl10_01, _sh10_01));
_sum01_13 = _mm256_add_epi32(_sum01_13, _mm256_unpackhi_epi16(_sl00_11, _sh00_11));
_sum11_03 = _mm256_add_epi32(_sum11_03, _mm256_unpackhi_epi16(_sl10_01, _sh10_01));
_sum00_12 = _mm256_add_epi32(_sum00_12, _mm256_madd_epi16(_val01_16, _w01_16));
_sum10_02 = _mm256_add_epi32(_sum10_02, _mm256_madd_epi16(_val10_16, _w01_16));
#endif
#else
__m128i _val01 = _mm_loadl_epi64((const __m128i*)tmpptr);
@@ -604,23 +522,10 @@ static void im2col_sgemm_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const
_sum10 = _mm_maddd_epi16(_val1, _w0, _sum10);
_sum11 = _mm_maddd_epi16(_val1, _w1, _sum11);
#else
__m128i _sl00 = _mm_mullo_epi16(_val0, _w0);
__m128i _sh00 = _mm_mulhi_epi16(_val0, _w0);
__m128i _sl01 = _mm_mullo_epi16(_val0, _w1);
__m128i _sh01 = _mm_mulhi_epi16(_val0, _w1);
__m128i _sl10 = _mm_mullo_epi16(_val1, _w0);
__m128i _sh10 = _mm_mulhi_epi16(_val1, _w0);
__m128i _sl11 = _mm_mullo_epi16(_val1, _w1);
__m128i _sh11 = _mm_mulhi_epi16(_val1, _w1);

_sum00 = _mm_add_epi32(_sum00, _mm_unpacklo_epi16(_sl00, _sh00));
_sum01 = _mm_add_epi32(_sum01, _mm_unpackhi_epi16(_sl00, _sh00));
_sum02 = _mm_add_epi32(_sum02, _mm_unpacklo_epi16(_sl01, _sh01));
_sum03 = _mm_add_epi32(_sum03, _mm_unpackhi_epi16(_sl01, _sh01));
_sum10 = _mm_add_epi32(_sum10, _mm_unpacklo_epi16(_sl10, _sh10));
_sum11 = _mm_add_epi32(_sum11, _mm_unpackhi_epi16(_sl10, _sh10));
_sum12 = _mm_add_epi32(_sum12, _mm_unpacklo_epi16(_sl11, _sh11));
_sum13 = _mm_add_epi32(_sum13, _mm_unpackhi_epi16(_sl11, _sh11));
_sum00 = _mm_add_epi32(_mm_madd_epi16(_val0, _w0), _sum00);
_sum01 = _mm_add_epi32(_mm_madd_epi16(_val0, _w1), _sum01);
_sum10 = _mm_add_epi32(_mm_madd_epi16(_val1, _w0), _sum10);
_sum11 = _mm_add_epi32(_mm_madd_epi16(_val1, _w1), _sum11);
#endif
#endif
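In the pre-AVX2 branch above, the same pair-fold runs at 128 bits: XOP hardware gets `_mm_maddd_epi16`, which fuses the multiply, the pair add, and the accumulate into one instruction, while plain SSE2 spells it as `_mm_madd_epi16` plus `_mm_add_epi32`. A hedged sketch of the two forms (standalone example, not PR code; the XOP branch only compiles with `-mxop`):

```cpp
// The two non-AVX2 multiply-accumulate forms used in the loop above.
#include <stdio.h>
#if defined(__XOP__)
#include <x86intrin.h>
#else
#include <emmintrin.h>
#endif

int main(void)
{
    __m128i _val = _mm_setr_epi16(1, 2, 3, 4, 5, 6, 7, 8);
    __m128i _w   = _mm_setr_epi16(1, 1, 2, 2, 3, 3, 4, 4);
    __m128i _sum = _mm_setr_epi32(100, 100, 100, 100);

#if defined(__XOP__)
    // multiply, fold adjacent pairs, and accumulate in one instruction
    _sum = _mm_maddd_epi16(_val, _w, _sum);
#else
    // same result as madd plus a separate add
    _sum = _mm_add_epi32(_mm_madd_epi16(_val, _w), _sum);
#endif

    int t[4];
    _mm_storeu_si128((__m128i*)t, _sum);
    printf("%d %d %d %d\n", t[0], t[1], t[2], t[3]); // 103 114 133 160
    return 0;
}
```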

@@ -629,67 +534,26 @@ static void im2col_sgemm_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const
}

#if __AVX2__
#if __AVXVNNI__ || __AVX512VNNI__
_sum00_12 = _mm256_hadd_epi32(_sum00_12, _sum10_02);

_sum00_12 = _mm256_permute4x64_epi64(_sum00_12, _MM_SHUFFLE(2, 1, 3, 0));
#else
// transpose 4x8
{
__m256i _tmp0, _tmp1, _tmp2, _tmp3;
_tmp0 = _mm256_unpacklo_epi32(_sum00_12, _sum10_02);
_tmp1 = _mm256_unpacklo_epi32(_sum01_13, _sum11_03);
_tmp2 = _mm256_unpackhi_epi32(_sum00_12, _sum10_02);
_tmp3 = _mm256_unpackhi_epi32(_sum01_13, _sum11_03);
_sum00_12 = _mm256_unpacklo_epi64(_tmp0, _tmp1);
_sum10_02 = _mm256_unpackhi_epi64(_tmp0, _tmp1);
_sum01_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3);
_sum11_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3);
}

_sum00_12 = _mm256_add_epi32(_sum00_12, _sum10_02);
_sum01_13 = _mm256_add_epi32(_sum01_13, _sum11_03);
_sum00_12 = _mm256_add_epi32(_sum00_12, _sum01_13);

__m256i _perm_mask = _mm256_set_epi32(6, 4, 3, 1, 7, 5, 2, 0);
_sum00_12 = _mm256_permutevar8x32_epi32(_sum00_12, _perm_mask);
#endif
#else
#if __XOP__
#if __SSSE3__
_sum00 = _mm_hadd_epi32(_sum00, _sum01);
_sum10 = _mm_hadd_epi32(_sum10, _sum11);
#else
// transpose 4x4
{
__m128i _tmp0, _tmp1, _tmp2, _tmp3;
_tmp0 = _mm_unpacklo_epi32(_sum00, _sum01);
_tmp1 = _mm_unpacklo_epi32(_sum02, _sum03);
_tmp2 = _mm_unpackhi_epi32(_sum00, _sum01);
_tmp3 = _mm_unpackhi_epi32(_sum02, _sum03);
_sum00 = _mm_unpacklo_epi64(_tmp0, _tmp1);
_sum01 = _mm_unpackhi_epi64(_tmp0, _tmp1);
_sum02 = _mm_unpacklo_epi64(_tmp2, _tmp3);
_sum03 = _mm_unpackhi_epi64(_tmp2, _tmp3);
}
{
__m128i _tmp0, _tmp1, _tmp2, _tmp3;
_tmp0 = _mm_unpacklo_epi32(_sum10, _sum11);
_tmp1 = _mm_unpacklo_epi32(_sum12, _sum13);
_tmp2 = _mm_unpackhi_epi32(_sum10, _sum11);
_tmp3 = _mm_unpackhi_epi32(_sum12, _sum13);
_sum10 = _mm_unpacklo_epi64(_tmp0, _tmp1);
_sum11 = _mm_unpackhi_epi64(_tmp0, _tmp1);
_sum12 = _mm_unpacklo_epi64(_tmp2, _tmp3);
_sum13 = _mm_unpackhi_epi64(_tmp2, _tmp3);
}
__m128i _sum00_sh = _mm_shuffle_epi32(_sum00, 216);
__m128i _sum01_sh = _mm_shuffle_epi32(_sum01, 216);
__m128i _sum10_sh = _mm_shuffle_epi32(_sum10, 216);
__m128i _sum11_sh = _mm_shuffle_epi32(_sum11, 216);

_sum00 = _mm_unpacklo_epi64(_sum00_sh, _sum01_sh);
_sum01 = _mm_unpackhi_epi64(_sum00_sh, _sum01_sh);
_sum10 = _mm_unpacklo_epi64(_sum10_sh, _sum11_sh);
_sum11 = _mm_unpackhi_epi64(_sum10_sh, _sum11_sh);

_sum00 = _mm_add_epi32(_sum00, _sum01);
_sum02 = _mm_add_epi32(_sum02, _sum03);
_sum10 = _mm_add_epi32(_sum10, _sum11);
_sum12 = _mm_add_epi32(_sum12, _sum13);

_sum00 = _mm_add_epi32(_sum00, _sum02);
_sum10 = _mm_add_epi32(_sum10, _sum12);
#endif
#endif
}
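The SSE2-only epilogue above replaces the removed 4x4 transpose with `_mm_shuffle_epi32(x, 216)`; 216 is `_MM_SHUFFLE(3, 1, 2, 0)`, which gathers the even-indexed lanes together, so the following `unpacklo/unpackhi_epi64` plus add computes exactly what SSSE3's `_mm_hadd_epi32(a, b)` (used in the `__SSSE3__` branch) would, without requiring SSSE3. A standalone sketch (my example, not PR code):

```cpp
// SSE2 emulation of _mm_hadd_epi32 via shuffle 216 + unpack + add.
#include <emmintrin.h>
#include <stdio.h>

int main(void)
{
    __m128i _a = _mm_setr_epi32(1, 2, 3, 4);
    __m128i _b = _mm_setr_epi32(10, 20, 30, 40);

    __m128i _a_sh = _mm_shuffle_epi32(_a, 216); // {a0, a2, a1, a3}
    __m128i _b_sh = _mm_shuffle_epi32(_b, 216); // {b0, b2, b1, b3}

    __m128i _lo = _mm_unpacklo_epi64(_a_sh, _b_sh); // {a0, a2, b0, b2}
    __m128i _hi = _mm_unpackhi_epi64(_a_sh, _b_sh); // {a1, a3, b1, b3}
    __m128i _s  = _mm_add_epi32(_lo, _hi);          // {a0+a1, a2+a3, b0+b1, b2+b3}

    int t[4];
    _mm_storeu_si128((__m128i*)t, _s);
    printf("%d %d %d %d\n", t[0], t[1], t[2], t[3]); // 3 7 30 70
    return 0;
}
```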