From 0994a40c64a897533e3ae1b920034c3b35227d7b Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Mon, 6 May 2024 07:01:23 -0700 Subject: [PATCH] Accumulating AVX rdsum microkernels PiperOrigin-RevId: 631032144 --- bench/f32-rdsum.cc | 30 + cmake/microkernels.cmake | 3 + microkernels.bzl | 3 + scripts/generate-f32-rdsum.sh | 5 + src/f32-rdsum/avx.c.in | 143 ++++ .../gen/f32-rdsum-7p7x-minmax-avx-c16.c | 218 +++++++ .../gen/f32-rdsum-7p7x-minmax-avx-c32.c | 260 ++++++++ .../gen/f32-rdsum-7p7x-minmax-avx-c64.c | 344 ++++++++++ src/xnnpack/gavgpool.h | 3 + src/xnnpack/reduce.h | 5 +- test/f32-rdsum.cc | 609 ++++++++++++++++++ test/f32-rdsum.yaml | 7 + test/gavgpool-microkernel-tester.h | 3 +- 13 files changed, 1630 insertions(+), 3 deletions(-) create mode 100644 src/f32-rdsum/avx.c.in create mode 100644 src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c16.c create mode 100644 src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c create mode 100644 src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c diff --git a/bench/f32-rdsum.cc b/bench/f32-rdsum.cc index 50f5b9b522f..fbf45838f75 100644 --- a/bench/f32-rdsum.cc +++ b/bench/f32-rdsum.cc @@ -82,6 +82,36 @@ BENCHMARK_CAPTURE(f32_rsum_discontig, scalar_c4, #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + BENCHMARK_CAPTURE(f32_rsum_discontig, avx_c16, + xnn_f32_rdsum_ukernel_7p7x__avx_c16, + xnn_init_f32_scale_avx_params, + benchmark::utils::CheckAVX) + ->Apply(BenchmarkBatch) + ->UseRealTime(); +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + BENCHMARK_CAPTURE(f32_rsum_discontig, avx_c32, + xnn_f32_rdsum_ukernel_7p7x__avx_c32, + xnn_init_f32_scale_avx_params, + benchmark::utils::CheckAVX) + ->Apply(BenchmarkBatch) + ->UseRealTime(); +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + BENCHMARK_CAPTURE(f32_rsum_discontig, avx_c64, + xnn_f32_rdsum_ukernel_7p7x__avx_c64, + xnn_init_f32_scale_avx_params, + benchmark::utils::CheckAVX) + ->Apply(BenchmarkBatch) + ->UseRealTime(); +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + #ifndef XNNPACK_BENCHMARK_NO_MAIN BENCHMARK_MAIN(); #endif diff --git a/cmake/microkernels.cmake b/cmake/microkernels.cmake index 1dcc392c412..74365d24070 100644 --- a/cmake/microkernels.cmake +++ b/cmake/microkernels.cmake @@ -132,6 +132,9 @@ SET(ALL_AVX_MICROKERNEL_SRCS src/f32-qu8-vcvt/gen/f32-qu8-vcvt-avx-u16.c src/f32-qu8-vcvt/gen/f32-qu8-vcvt-avx-u24.c src/f32-qu8-vcvt/gen/f32-qu8-vcvt-avx-u32.c + src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c16.c + src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c + src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c src/f32-rminmax/gen/f32-rmax-avx-u8.c src/f32-rminmax/gen/f32-rmax-avx-u16-acc2.c src/f32-rminmax/gen/f32-rmax-avx-u24-acc3.c diff --git a/microkernels.bzl b/microkernels.bzl index f040a7418f4..7a123b50bfb 100644 --- a/microkernels.bzl +++ b/microkernels.bzl @@ -129,6 +129,9 @@ ALL_AVX_MICROKERNEL_SRCS = [ "src/f32-qu8-vcvt/gen/f32-qu8-vcvt-avx-u16.c", "src/f32-qu8-vcvt/gen/f32-qu8-vcvt-avx-u24.c", "src/f32-qu8-vcvt/gen/f32-qu8-vcvt-avx-u32.c", + "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c16.c", + "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c", + "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c", "src/f32-rminmax/gen/f32-rmax-avx-u8.c", "src/f32-rminmax/gen/f32-rmax-avx-u16-acc2.c", "src/f32-rminmax/gen/f32-rmax-avx-u24-acc3.c", diff --git a/scripts/generate-f32-rdsum.sh b/scripts/generate-f32-rdsum.sh index b2e87c187a3..8af194d0c30 100755 --- a/scripts/generate-f32-rdsum.sh +++ 
b/scripts/generate-f32-rdsum.sh
@@ -17,4 +17,9 @@ tools/xngen src/f32-rdsum/sse.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-r
 tools/xngen src/f32-rdsum/sse.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-sse-c32.c &
 tools/xngen src/f32-rdsum/sse.c.in -D CHANNELS=64 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-sse-c64.c &
 
+#################################### AVX ####################################
+tools/xngen src/f32-rdsum/avx.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c16.c &
+tools/xngen src/f32-rdsum/avx.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c &
+tools/xngen src/f32-rdsum/avx.c.in -D CHANNELS=64 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c &
+
 wait
diff --git a/src/f32-rdsum/avx.c.in b/src/f32-rdsum/avx.c.in
new file mode 100644
index 00000000000..7d0ed3a9320
--- /dev/null
+++ b/src/f32-rdsum/avx.c.in
@@ -0,0 +1,143 @@
+// Copyright 2024 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/math.h>
+#include <xnnpack/reduce.h>
+
+
+$UNROLL = CHANNELS >> 3
+void xnn_f32_rdsum_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__avx_c${CHANNELS}(
+    size_t rows,
+    size_t channels,
+    const float* input,
+    size_t input_stride,
+    const float* zero,
+    float* output,
+    const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)])
+{
+  assert(rows != 0);
+  assert(channels != 0);
+  assert(input != NULL);
+  assert(output != NULL);
+
+  const __m256 vscale = _mm256_set1_ps(params->avx.scale);
+
+  size_t input_increment = ${ACCUMULATORS} * input_stride;
+  for (; channels >= ${CHANNELS}; channels -= ${CHANNELS}) {
+    const float* i0 = input;
+    $for i in range(1, ACCUMULATORS):
+      const float* i${i} = (const float*) ((uintptr_t) input + ${i} * input_stride);
+
+    $for i in range(UNROLL):
+      __m256 vacc${i} = _mm256_setzero_ps();
+
+    for (int r = rows; r > 0; r -= ${ACCUMULATORS}) {
+      $for N in range(1, ACCUMULATORS, 2):
+        if XNN_UNPREDICTABLE(r < ${N+1}) {
+          i${N} = zero;
+        }
+        if XNN_UNPREDICTABLE(r <= ${N+1}) {
+          i${N+1} = zero;
+        }
+      $for c in range(UNROLL):
+        __m256 vin${c};
+      $for j in range(ACCUMULATORS):
+        $for c in range(UNROLL):
+          vin${c} = _mm256_loadu_ps(&i${j}[${c*8}]);
+        $for c in range(UNROLL):
+          vacc${c} = _mm256_add_ps(vin${c}, vacc${c});
+      $for N in range(0, ACCUMULATORS):
+        i${N} = (const float*) ((uintptr_t) i${N} + input_increment);
+    }
+    $for i in range(UNROLL):
+      vacc${i} = _mm256_mul_ps(vacc${i}, vscale);
+
+    const float* o = output;
+    $for i in range(0, UNROLL):
+      __m256 vo${i} = _mm256_loadu_ps(o); o += 8;
+    $for i in range(0, UNROLL):
+      vacc${i} = _mm256_add_ps(vo${i}, vacc${i});
+    $for i in range(0, UNROLL):
+      _mm256_storeu_ps(output, vacc${i}); output += 8;
+
+    input = (const float*) ((uintptr_t) input + ${CHANNELS} * sizeof(float));
+  }
+  __m256i vmask;
+  if (channels != 0) {
+    input_increment = ${ACCUMULATORS} * input_stride;
+    const float* i0 = input;
+    $for i in range(1, ACCUMULATORS):
+      const float* i${i} = (const float*) ((uintptr_t) input + ${i} * input_stride);
+    __m256 vacc[${UNROLL}];
+    $for i in range(UNROLL):
+      vacc[${i}] = _mm256_setzero_ps();
+
+    const size_t num_full_chunks = channels >> 3;
+    const size_t num_chunks = round_up_po2(channels, 8) >> 3;
+    const size_t remainder = channels & 0x7;
+    for (int r = rows; r > 0; r -= ${ACCUMULATORS}) {
+      $for N in range(1, ACCUMULATORS, 2):
+        if XNN_UNPREDICTABLE(r < ${N+1}) {
+          i${N} = zero;
+        }
+        if XNN_UNPREDICTABLE(r <= ${N+1}) {
+          i${N+1} = zero;
+        }
+      for (int i = 0; i < num_full_chunks; ++i) {
+        $for c in range(ACCUMULATORS):
+          vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i${c}[i*8]), vacc[i]);
+      }
+
+      if (remainder) {
+        vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx.mask_table[7] - (channels & 0x7) * sizeof(float)));
+        $for c in range(ACCUMULATORS):
+          vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i${c}[num_full_chunks*8], vmask), vacc[num_full_chunks]);
+      }
+      $for N in range(ACCUMULATORS):
+        i${N} = (const float*) ((uintptr_t) i${N} + input_increment);
+    }
+    for (size_t i = 0; i < num_chunks; ++i) {
+      vacc[i] = _mm256_mul_ps(vacc[i], vscale);
+    }
+
+    __m256 vo[${UNROLL}];
+    const float* o = output;
+    for (int i = 0; i < channels >> 3; ++i) {
+      vo[i] = _mm256_loadu_ps(o); o += 8;
+    }
+    for (int i = 0; i < channels >> 3; ++i) {
+      vacc[i] = _mm256_add_ps(vo[i], vacc[i]);
+    }
+    for (int i = 0; i < channels >> 3; ++i) {
+      _mm256_storeu_ps(output, vacc[i]); output += 8;
+    }
+    if (remainder) {
+      const size_t pos = num_full_chunks;
+      __m256 vout = vacc[pos];
+      const __m256 vdata = _mm256_maskload_ps(output, vmask);
+      vout = _mm256_add_ps(vout, vdata);
+      __m128 vout_lo = _mm256_castps256_ps128(vout);
+      if (channels & 4) {
+        _mm_storeu_ps(output, vout_lo);
+        vout_lo = _mm256_extractf128_ps(vout, 1);
+        output += 4;
+      }
+      if (channels & 2) {
+        _mm_storel_pi((__m64*) output, vout_lo);
+        vout_lo = _mm_movehl_ps(vout_lo, vout_lo);
+        output += 2;
+      }
+      if (channels & 1) {
+        _mm_store_ss(output, vout_lo);
+      }
+    }
+  }
+}
diff --git a/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c16.c b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c16.c
new file mode 100644
index 00000000000..c82ccad6d6f
--- /dev/null
+++ b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c16.c
@@ -0,0 +1,218 @@
+// Auto-generated file. Do not edit!
+//   Template: src/f32-rdsum/avx.c.in
+//   Generator: tools/xngen
+//
+// Copyright 2024 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+ +#include + +#include + +#include +#include +#include + + +void xnn_f32_rdsum_ukernel_7p7x__avx_c16( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m256 vscale = _mm256_set1_ps(params->avx.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 16; channels -= 16) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + __m256 vin0; + __m256 vin1; + vin0 = _mm256_loadu_ps(&i0[0]); + vin1 = _mm256_loadu_ps(&i0[8]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_loadu_ps(&i1[0]); + vin1 = _mm256_loadu_ps(&i1[8]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_loadu_ps(&i2[0]); + vin1 = _mm256_loadu_ps(&i2[8]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_loadu_ps(&i3[0]); + vin1 = _mm256_loadu_ps(&i3[8]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_loadu_ps(&i4[0]); + vin1 = _mm256_loadu_ps(&i4[8]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_loadu_ps(&i5[0]); + vin1 = _mm256_loadu_ps(&i5[8]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_loadu_ps(&i6[0]); + vin1 = _mm256_loadu_ps(&i6[8]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm256_mul_ps(vacc0, vscale); + vacc1 = _mm256_mul_ps(vacc1, vscale); + + const float* o = output; + __m256 vo0 = _mm256_loadu_ps(o); o += 8; + __m256 vo1 = _mm256_loadu_ps(o); o += 8; + vacc0 = _mm256_add_ps(vo0, vacc0); + vacc1 = _mm256_add_ps(vo1, vacc1); + _mm256_storeu_ps(output, vacc0); output += 8; + _mm256_storeu_ps(output, vacc1); output += 8; + + input = (const float*) ((uintptr_t) input + 16 * sizeof(float)); + } + __m256i vmask; + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) 
input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + __m256 vacc[2]; + vacc[0] = _mm256_setzero_ps(); + vacc[1] = _mm256_setzero_ps(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 0x7; + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i0[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i1[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i2[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i3[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i4[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i5[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i6[i*8]), vacc[i]); + } + + if (remainder) { + vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) ¶ms->avx.mask_table[7] - (channels & 0x7) * sizeof(float))); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i0[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i1[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i2[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i3[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i4[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i5[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i6[num_full_chunks*8], vmask), vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm256_mul_ps(vacc[i], vscale); + } + + __m256 vo[2]; + const float* o = output; + for (int i = 0; i < channels >> 3; ++i) { + vo[i] = _mm256_loadu_ps(o); o += 8; + } + for (int i = 0; i < channels >> 3; ++i) { + vacc[i] = _mm256_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 3; ++i) { + _mm256_storeu_ps(output, vacc[i]); output += 8; + } + if (remainder) { + const size_t pos = num_full_chunks; + __m256 vout = vacc[pos]; + const __m256 vdata = _mm256_maskload_ps(output, vmask); + vout = _mm256_add_ps(vout, vdata); + __m128 vout_lo = _mm256_castps256_ps128(vout); + if (channels & 4) { + _mm_storeu_ps(output, vout_lo); + vout_lo = _mm256_extractf128_ps(vout, 1); + output += 4; + } + if (channels & 2) { + 
_mm_storel_pi((__m64*) output, vout_lo); + vout_lo = _mm_movehl_ps(vout_lo, vout_lo); + output += 2; + } + if (channels & 1) { + _mm_store_ss(output, vout_lo); + } + } + } +} diff --git a/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c new file mode 100644 index 00000000000..9a16e3f3fe9 --- /dev/null +++ b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c @@ -0,0 +1,260 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum/avx.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +#include +#include +#include + + +void xnn_f32_rdsum_ukernel_7p7x__avx_c32( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m256 vscale = _mm256_set1_ps(params->avx.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 32; channels -= 32) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + __m256 vacc2 = _mm256_setzero_ps(); + __m256 vacc3 = _mm256_setzero_ps(); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + __m256 vin0; + __m256 vin1; + __m256 vin2; + __m256 vin3; + vin0 = _mm256_loadu_ps(&i0[0]); + vin1 = _mm256_loadu_ps(&i0[8]); + vin2 = _mm256_loadu_ps(&i0[16]); + vin3 = _mm256_loadu_ps(&i0[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i1[0]); + vin1 = _mm256_loadu_ps(&i1[8]); + vin2 = _mm256_loadu_ps(&i1[16]); + vin3 = _mm256_loadu_ps(&i1[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i2[0]); + vin1 = _mm256_loadu_ps(&i2[8]); + vin2 = _mm256_loadu_ps(&i2[16]); + vin3 = _mm256_loadu_ps(&i2[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i3[0]); + vin1 = _mm256_loadu_ps(&i3[8]); + vin2 = _mm256_loadu_ps(&i3[16]); + vin3 = _mm256_loadu_ps(&i3[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i4[0]); + vin1 = _mm256_loadu_ps(&i4[8]); + vin2 = _mm256_loadu_ps(&i4[16]); + vin3 = 
_mm256_loadu_ps(&i4[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i5[0]); + vin1 = _mm256_loadu_ps(&i5[8]); + vin2 = _mm256_loadu_ps(&i5[16]); + vin3 = _mm256_loadu_ps(&i5[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i6[0]); + vin1 = _mm256_loadu_ps(&i6[8]); + vin2 = _mm256_loadu_ps(&i6[16]); + vin3 = _mm256_loadu_ps(&i6[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm256_mul_ps(vacc0, vscale); + vacc1 = _mm256_mul_ps(vacc1, vscale); + vacc2 = _mm256_mul_ps(vacc2, vscale); + vacc3 = _mm256_mul_ps(vacc3, vscale); + + const float* o = output; + __m256 vo0 = _mm256_loadu_ps(o); o += 8; + __m256 vo1 = _mm256_loadu_ps(o); o += 8; + __m256 vo2 = _mm256_loadu_ps(o); o += 8; + __m256 vo3 = _mm256_loadu_ps(o); o += 8; + vacc0 = _mm256_add_ps(vo0, vacc0); + vacc1 = _mm256_add_ps(vo1, vacc1); + vacc2 = _mm256_add_ps(vo2, vacc2); + vacc3 = _mm256_add_ps(vo3, vacc3); + _mm256_storeu_ps(output, vacc0); output += 8; + _mm256_storeu_ps(output, vacc1); output += 8; + _mm256_storeu_ps(output, vacc2); output += 8; + _mm256_storeu_ps(output, vacc3); output += 8; + + input = (const float*) ((uintptr_t) input + 32 * sizeof(float)); + } + __m256i vmask; + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + __m256 vacc[4]; + vacc[0] = _mm256_setzero_ps(); + vacc[1] = _mm256_setzero_ps(); + vacc[2] = _mm256_setzero_ps(); + vacc[3] = _mm256_setzero_ps(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 0x7; + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i0[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i1[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i2[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i3[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i4[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i5[i*8]), 
vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i6[i*8]), vacc[i]); + } + + if (remainder) { + vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) ¶ms->avx.mask_table[7] - (channels & 0x7) * sizeof(float))); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i0[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i1[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i2[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i3[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i4[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i5[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i6[num_full_chunks*8], vmask), vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm256_mul_ps(vacc[i], vscale); + } + + __m256 vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 3; ++i) { + vo[i] = _mm256_loadu_ps(o); o += 8; + } + for (int i = 0; i < channels >> 3; ++i) { + vacc[i] = _mm256_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 3; ++i) { + _mm256_storeu_ps(output, vacc[i]); output += 8; + } + if (remainder) { + const size_t pos = num_full_chunks; + __m256 vout = vacc[pos]; + const __m256 vdata = _mm256_maskload_ps(output, vmask); + vout = _mm256_add_ps(vout, vdata); + __m128 vout_lo = _mm256_castps256_ps128(vout); + if (channels & 4) { + _mm_storeu_ps(output, vout_lo); + vout_lo = _mm256_extractf128_ps(vout, 1); + output += 4; + } + if (channels & 2) { + _mm_storel_pi((__m64*) output, vout_lo); + vout_lo = _mm_movehl_ps(vout_lo, vout_lo); + output += 2; + } + if (channels & 1) { + _mm_store_ss(output, vout_lo); + } + } + } +} diff --git a/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c new file mode 100644 index 00000000000..7ff30de01ec --- /dev/null +++ b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c @@ -0,0 +1,344 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum/avx.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include + +#include + +#include +#include +#include + + +void xnn_f32_rdsum_ukernel_7p7x__avx_c64( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m256 vscale = _mm256_set1_ps(params->avx.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 64; channels -= 64) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + __m256 vacc2 = _mm256_setzero_ps(); + __m256 vacc3 = _mm256_setzero_ps(); + __m256 vacc4 = _mm256_setzero_ps(); + __m256 vacc5 = _mm256_setzero_ps(); + __m256 vacc6 = _mm256_setzero_ps(); + __m256 vacc7 = _mm256_setzero_ps(); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + __m256 vin0; + __m256 vin1; + __m256 vin2; + __m256 vin3; + __m256 vin4; + __m256 vin5; + __m256 vin6; + __m256 vin7; + vin0 = _mm256_loadu_ps(&i0[0]); + vin1 = _mm256_loadu_ps(&i0[8]); + vin2 = _mm256_loadu_ps(&i0[16]); + vin3 = _mm256_loadu_ps(&i0[24]); + vin4 = _mm256_loadu_ps(&i0[32]); + vin5 = _mm256_loadu_ps(&i0[40]); + vin6 = _mm256_loadu_ps(&i0[48]); + vin7 = _mm256_loadu_ps(&i0[56]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_loadu_ps(&i1[0]); + vin1 = _mm256_loadu_ps(&i1[8]); + vin2 = _mm256_loadu_ps(&i1[16]); + vin3 = _mm256_loadu_ps(&i1[24]); + vin4 = _mm256_loadu_ps(&i1[32]); + vin5 = _mm256_loadu_ps(&i1[40]); + vin6 = _mm256_loadu_ps(&i1[48]); + vin7 = _mm256_loadu_ps(&i1[56]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_loadu_ps(&i2[0]); + vin1 = _mm256_loadu_ps(&i2[8]); + vin2 = _mm256_loadu_ps(&i2[16]); + vin3 = _mm256_loadu_ps(&i2[24]); + vin4 = _mm256_loadu_ps(&i2[32]); + vin5 = _mm256_loadu_ps(&i2[40]); + vin6 = _mm256_loadu_ps(&i2[48]); + vin7 = _mm256_loadu_ps(&i2[56]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = 
_mm256_loadu_ps(&i3[0]); + vin1 = _mm256_loadu_ps(&i3[8]); + vin2 = _mm256_loadu_ps(&i3[16]); + vin3 = _mm256_loadu_ps(&i3[24]); + vin4 = _mm256_loadu_ps(&i3[32]); + vin5 = _mm256_loadu_ps(&i3[40]); + vin6 = _mm256_loadu_ps(&i3[48]); + vin7 = _mm256_loadu_ps(&i3[56]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_loadu_ps(&i4[0]); + vin1 = _mm256_loadu_ps(&i4[8]); + vin2 = _mm256_loadu_ps(&i4[16]); + vin3 = _mm256_loadu_ps(&i4[24]); + vin4 = _mm256_loadu_ps(&i4[32]); + vin5 = _mm256_loadu_ps(&i4[40]); + vin6 = _mm256_loadu_ps(&i4[48]); + vin7 = _mm256_loadu_ps(&i4[56]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_loadu_ps(&i5[0]); + vin1 = _mm256_loadu_ps(&i5[8]); + vin2 = _mm256_loadu_ps(&i5[16]); + vin3 = _mm256_loadu_ps(&i5[24]); + vin4 = _mm256_loadu_ps(&i5[32]); + vin5 = _mm256_loadu_ps(&i5[40]); + vin6 = _mm256_loadu_ps(&i5[48]); + vin7 = _mm256_loadu_ps(&i5[56]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_loadu_ps(&i6[0]); + vin1 = _mm256_loadu_ps(&i6[8]); + vin2 = _mm256_loadu_ps(&i6[16]); + vin3 = _mm256_loadu_ps(&i6[24]); + vin4 = _mm256_loadu_ps(&i6[32]); + vin5 = _mm256_loadu_ps(&i6[40]); + vin6 = _mm256_loadu_ps(&i6[48]); + vin7 = _mm256_loadu_ps(&i6[56]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm256_mul_ps(vacc0, vscale); + vacc1 = _mm256_mul_ps(vacc1, vscale); + vacc2 = _mm256_mul_ps(vacc2, vscale); + vacc3 = _mm256_mul_ps(vacc3, vscale); + vacc4 = _mm256_mul_ps(vacc4, vscale); + vacc5 = _mm256_mul_ps(vacc5, vscale); + vacc6 = _mm256_mul_ps(vacc6, vscale); + vacc7 = _mm256_mul_ps(vacc7, vscale); + + const float* o = output; + __m256 vo0 = _mm256_loadu_ps(o); o += 8; + __m256 vo1 = _mm256_loadu_ps(o); o += 8; + __m256 vo2 = _mm256_loadu_ps(o); o += 8; + __m256 vo3 = _mm256_loadu_ps(o); o += 8; + __m256 vo4 = _mm256_loadu_ps(o); o += 8; + __m256 vo5 = _mm256_loadu_ps(o); o += 8; + __m256 vo6 = _mm256_loadu_ps(o); o += 8; + __m256 vo7 = _mm256_loadu_ps(o); o += 8; + vacc0 = _mm256_add_ps(vo0, vacc0); + vacc1 = _mm256_add_ps(vo1, vacc1); + vacc2 = _mm256_add_ps(vo2, vacc2); + vacc3 = 
_mm256_add_ps(vo3, vacc3); + vacc4 = _mm256_add_ps(vo4, vacc4); + vacc5 = _mm256_add_ps(vo5, vacc5); + vacc6 = _mm256_add_ps(vo6, vacc6); + vacc7 = _mm256_add_ps(vo7, vacc7); + _mm256_storeu_ps(output, vacc0); output += 8; + _mm256_storeu_ps(output, vacc1); output += 8; + _mm256_storeu_ps(output, vacc2); output += 8; + _mm256_storeu_ps(output, vacc3); output += 8; + _mm256_storeu_ps(output, vacc4); output += 8; + _mm256_storeu_ps(output, vacc5); output += 8; + _mm256_storeu_ps(output, vacc6); output += 8; + _mm256_storeu_ps(output, vacc7); output += 8; + + input = (const float*) ((uintptr_t) input + 64 * sizeof(float)); + } + __m256i vmask; + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + __m256 vacc[8]; + vacc[0] = _mm256_setzero_ps(); + vacc[1] = _mm256_setzero_ps(); + vacc[2] = _mm256_setzero_ps(); + vacc[3] = _mm256_setzero_ps(); + vacc[4] = _mm256_setzero_ps(); + vacc[5] = _mm256_setzero_ps(); + vacc[6] = _mm256_setzero_ps(); + vacc[7] = _mm256_setzero_ps(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 0x7; + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i0[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i1[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i2[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i3[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i4[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i5[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i6[i*8]), vacc[i]); + } + + if (remainder) { + vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) ¶ms->avx.mask_table[7] - (channels & 0x7) * sizeof(float))); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i0[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i1[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i2[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i3[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i4[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i5[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i6[num_full_chunks*8], vmask), vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + 
input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm256_mul_ps(vacc[i], vscale); + } + + __m256 vo[8]; + const float* o = output; + for (int i = 0; i < channels >> 3; ++i) { + vo[i] = _mm256_loadu_ps(o); o += 8; + } + for (int i = 0; i < channels >> 3; ++i) { + vacc[i] = _mm256_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 3; ++i) { + _mm256_storeu_ps(output, vacc[i]); output += 8; + } + if (remainder) { + const size_t pos = num_full_chunks; + __m256 vout = vacc[pos]; + const __m256 vdata = _mm256_maskload_ps(output, vmask); + vout = _mm256_add_ps(vout, vdata); + __m128 vout_lo = _mm256_castps256_ps128(vout); + if (channels & 4) { + _mm_storeu_ps(output, vout_lo); + vout_lo = _mm256_extractf128_ps(vout, 1); + output += 4; + } + if (channels & 2) { + _mm_storel_pi((__m64*) output, vout_lo); + vout_lo = _mm_movehl_ps(vout_lo, vout_lo); + output += 2; + } + if (channels & 1) { + _mm_store_ss(output, vout_lo); + } + } + } +} diff --git a/src/xnnpack/gavgpool.h b/src/xnnpack/gavgpool.h index e497ed5deff..bc0786590a6 100644 --- a/src/xnnpack/gavgpool.h +++ b/src/xnnpack/gavgpool.h @@ -36,6 +36,9 @@ DECLARE_F32_GAVGPOOL_MINMAX_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_gavgpool_minmax_u DECLARE_F32_GAVGPOOL_MINMAX_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_gavgpool_minmax_ukernel_7p7x__wasm_c1) DECLARE_F32_GAVGPOOL_MINMAX_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_gavgpool_minmax_ukernel_7p7x__wasmsimd_arm_c4) DECLARE_F32_GAVGPOOL_MINMAX_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_gavgpool_minmax_ukernel_7p7x__wasmsimd_x86_c4) +DECLARE_F32_GAVGPOOL_MINMAX_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_rdsum_minmax_ukernel_7p7x__avx_c16) +DECLARE_F32_GAVGPOOL_MINMAX_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_rdsum_minmax_ukernel_7p7x__avx_c32) +DECLARE_F32_GAVGPOOL_MINMAX_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_rdsum_minmax_ukernel_7p7x__avx_c64) #define DECLARE_F32_GAVGPOOL_MINMAX_UNIPASS_UKERNEL_FUNCTION(fn_name) \ diff --git a/src/xnnpack/reduce.h b/src/xnnpack/reduce.h index bf3668b4686..ad1a0b3ad49 100644 --- a/src/xnnpack/reduce.h +++ b/src/xnnpack/reduce.h @@ -333,7 +333,10 @@ DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__scalar_c4) DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__sse_c16) DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__sse_c32) DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__sse_c64) -DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__sse_c128) +DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__avx_c16) +DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__avx_c32) +DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__avx_c64) +DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__avx_c128) #ifdef __cplusplus } // extern "C" diff --git a/test/f32-rdsum.cc b/test/f32-rdsum.cc index 96119dd6b14..b1331244857 100644 --- a/test/f32-rdsum.cc +++ b/test/f32-rdsum.cc @@ -1415,3 +1415,612 @@ TEST(F32_RDSUM_7P7X__SCALAR_C4, channels_gt_4_multipass_fulltile_with_input_stri } } #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + TEST(F32_RDSUM_7P7X__AVX_C16, channels_eq_16_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + RDSumMicrokernelTester() + .rows(14) + .channels(16) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, 
xnn_init_f32_scale_avx_params); + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_eq_16_2pass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + RDSumMicrokernelTester() + .rows(14) + .channels(16) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_eq_16_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_eq_16_2pass_subtile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_eq_16_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_eq_16_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_div_16_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 32; channels < 128; channels += 16) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_div_16_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 32; channels < 128; channels += 16) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_div_16_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 32; channels < 128; channels += 16) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_div_16_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 32; channels < 128; channels += 16) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(263) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_lt_16_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 16; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_lt_16_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 16; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } 
+ } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_lt_16_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 16; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_lt_16_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 16; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_gt_16_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 17; channels < 32; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_gt_16_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 17; channels < 32; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_gt_16_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 17; channels < 32; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_gt_16_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 17; channels < 32; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(47) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + TEST(F32_RDSUM_7P7X__AVX_C32, channels_eq_32_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + RDSumMicrokernelTester() + .rows(14) + .channels(32) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_eq_32_2pass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + RDSumMicrokernelTester() + .rows(14) + .channels(32) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_eq_32_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_eq_32_2pass_subtile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_eq_32_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + 
.Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_eq_32_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_div_32_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 64; channels < 256; channels += 32) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_div_32_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 64; channels < 256; channels += 32) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_div_32_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 64; channels < 256; channels += 32) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_div_32_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 64; channels < 256; channels += 32) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(521) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_lt_32_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 32; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_lt_32_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 32; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_lt_32_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 32; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_lt_32_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 32; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_gt_32_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 33; channels < 64; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, 
channels_gt_32_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 33; channels < 64; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_gt_32_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 33; channels < 64; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_gt_32_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 33; channels < 64; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(79) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + TEST(F32_RDSUM_7P7X__AVX_C64, channels_eq_64_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + RDSumMicrokernelTester() + .rows(14) + .channels(64) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_eq_64_2pass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + RDSumMicrokernelTester() + .rows(14) + .channels(64) + .input_stride(67) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_eq_64_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(64) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_eq_64_2pass_subtile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(64) + .input_stride(67) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_eq_64_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(64) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_eq_64_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(64) + .input_stride(67) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_div_64_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 128; channels < 512; channels += 64) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_div_64_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 128; channels < 512; channels += 64) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, 
channels_div_64_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 128; channels < 512; channels += 64) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_div_64_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 128; channels < 512; channels += 64) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(1031) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_lt_64_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 64; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_lt_64_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 64; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_lt_64_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 64; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_lt_64_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 64; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(67) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_gt_64_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 65; channels < 128; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_gt_64_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 65; channels < 128; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_gt_64_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 65; channels < 128; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_gt_64_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 65; channels < 128; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(149) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 
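
Note (not part of the patch): the channels_lt_* and channels_gt_* cases above exercise the masked remainder path of these kernels, where the last partial vector of results is written out four, two, and one lane at a time. The helper below is a minimal standalone sketch of that store pattern; the name store_f32_tail is hypothetical.

#include <immintrin.h>
#include <stddef.h>

// Sketch: store the lowest `count` lanes (0 < count < 8) of an 8-lane
// accumulator, mirroring the 4/2/1 split in the kernels' remainder path.
static void store_f32_tail(float* output, __m256 vout, size_t count) {
  __m128 vout_lo = _mm256_castps256_ps128(vout);  // start with lanes 0..3
  if (count & 4) {
    _mm_storeu_ps(output, vout_lo);               // write lanes 0..3
    vout_lo = _mm256_extractf128_ps(vout, 1);     // continue with lanes 4..7
    output += 4;
  }
  if (count & 2) {
    _mm_storel_pi((__m64*) output, vout_lo);      // write the next two lanes
    vout_lo = _mm_movehl_ps(vout_lo, vout_lo);    // shift remaining lanes down
    output += 2;
  }
  if (count & 1) {
    _mm_store_ss(output, vout_lo);                // write the final lane
  }
}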
diff --git a/test/f32-rdsum.yaml b/test/f32-rdsum.yaml index e5abbd9cd8b..f3228096674 100644 --- a/test/f32-rdsum.yaml +++ b/test/f32-rdsum.yaml @@ -20,3 +20,10 @@ init: xnn_init_f32_scale_sse_params - name: xnn_f32_rdsum_ukernel_7p7x__sse_c64 init: xnn_init_f32_scale_sse_params +# x86 AVX +- name: xnn_f32_rdsum_ukernel_7p7x__avx_c16 + init: xnn_init_f32_scale_avx_params +- name: xnn_f32_rdsum_ukernel_7p7x__avx_c32 + init: xnn_init_f32_scale_avx_params +- name: xnn_f32_rdsum_ukernel_7p7x__avx_c64 + init: xnn_init_f32_scale_avx_params diff --git a/test/gavgpool-microkernel-tester.h b/test/gavgpool-microkernel-tester.h index c029f2de06b..f58878437be 100644 --- a/test/gavgpool-microkernel-tester.h +++ b/test/gavgpool-microkernel-tester.h @@ -598,8 +598,7 @@ class GAvgPoolMicrokernelTester { std::vector output_ref(channels()); for (size_t iteration = 0; iteration < iterations(); iteration++) { std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); - std::fill(output.begin(), output.end(), std::nanf("")); - + std::fill(output.begin(), output.end(), 0.f); // kernels accumulate. // Compute reference results, without clamping. for (size_t c = 0; c < channels(); c++) { float acc = 0.0f;
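
Note (not part of the patch): the sketch below illustrates the calling convention visible in the diff. From the kernel source, channels is counted in elements, input_stride is in bytes, zero must point to at least `channels` zero-valued floats (it stands in for missing rows in the last group of 7), and the kernel accumulates into output, which is why the tester above now zero-fills the output instead of NaN-filling it. The wrapper function and the exact xnn_init_f32_scale_avx_params signature are assumptions, mirroring the existing SSE variants.

#include <stdlib.h>
#include <string.h>

#include <xnnpack/microparams-init.h>
#include <xnnpack/reduce.h>

// Sketch: sum `rows` rows of `channels` contiguous floats into `output`
// using one of the new accumulating AVX rdsum microkernels.
void reduce_rows_example(const float* input, size_t rows, size_t channels, float* output) {
  union xnn_f32_scale_params params;
  xnn_init_f32_scale_avx_params(&params, /*scale=*/1.0f);  // assumed init signature

  float* zero = calloc(channels, sizeof(float));  // padding rows read as zeros
  memset(output, 0, channels * sizeof(float));    // the kernel adds into output

  xnn_f32_rdsum_ukernel_7p7x__avx_c32(
      rows, channels, input,
      /*input_stride=*/channels * sizeof(float),  // bytes between consecutive rows
      zero, output, &params);

  free(zero);
}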