diff --git a/bench/f32-rdsum.cc b/bench/f32-rdsum.cc index 50f5b9b522f..d914ea2a7a8 100644 --- a/bench/f32-rdsum.cc +++ b/bench/f32-rdsum.cc @@ -82,6 +82,93 @@ BENCHMARK_CAPTURE(f32_rsum_discontig, scalar_c4, #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + BENCHMARK_CAPTURE(f32_rsum_discontig, avx_c16, + xnn_f32_rdsum_ukernel_7p7x__avx_c16, + xnn_init_f32_scale_avx_params, + benchmark::utils::CheckAVX) + ->Apply(BenchmarkBatch) + ->UseRealTime(); +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + BENCHMARK_CAPTURE(f32_rsum_discontig, avx_c32, + xnn_f32_rdsum_ukernel_7p7x__avx_c32, + xnn_init_f32_scale_avx_params, + benchmark::utils::CheckAVX) + ->Apply(BenchmarkBatch) + ->UseRealTime(); +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + BENCHMARK_CAPTURE(f32_rsum_discontig, avx_c64, + xnn_f32_rdsum_ukernel_7p7x__avx_c64, + xnn_init_f32_scale_avx_params, + benchmark::utils::CheckAVX) + ->Apply(BenchmarkBatch) + ->UseRealTime(); +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + BENCHMARK_CAPTURE(f32_rsum_discontig, avx512f_c16, + xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, + xnn_init_f32_scale_scalar_params, + benchmark::utils::CheckAVX512F) + ->Apply(BenchmarkBatch) + ->UseRealTime(); +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + BENCHMARK_CAPTURE(f32_rsum_discontig, avx512f_c32, + xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, + xnn_init_f32_scale_scalar_params, + benchmark::utils::CheckAVX512F) + ->Apply(BenchmarkBatch) + ->UseRealTime(); +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + BENCHMARK_CAPTURE(f32_rsum_discontig, avx512f_c64, + xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, + xnn_init_f32_scale_scalar_params, + benchmark::utils::CheckAVX512F) + ->Apply(BenchmarkBatch) + ->UseRealTime(); +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD + BENCHMARK_CAPTURE(f32_rsum_discontig, wasmsimd_c16, + xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, + xnn_init_f32_scale_scalar_params) + ->Apply(BenchmarkBatch) + ->UseRealTime(); +#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD + + +#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD + BENCHMARK_CAPTURE(f32_rsum_discontig, wasmsimd_c32, + xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, + xnn_init_f32_scale_scalar_params) + ->Apply(BenchmarkBatch) + ->UseRealTime(); +#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD + + +#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD + BENCHMARK_CAPTURE(f32_rsum_discontig, wasmsimd_c64, + xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, + xnn_init_f32_scale_scalar_params) + ->Apply(BenchmarkBatch) + ->UseRealTime(); +#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD + + #ifndef XNNPACK_BENCHMARK_NO_MAIN BENCHMARK_MAIN(); #endif diff --git a/cmake/microkernels.cmake b/cmake/microkernels.cmake index 1dcc392c412..0679bb2d167 100644 --- a/cmake/microkernels.cmake +++ b/cmake/microkernels.cmake @@ -132,6 +132,9 @@ SET(ALL_AVX_MICROKERNEL_SRCS src/f32-qu8-vcvt/gen/f32-qu8-vcvt-avx-u16.c src/f32-qu8-vcvt/gen/f32-qu8-vcvt-avx-u24.c src/f32-qu8-vcvt/gen/f32-qu8-vcvt-avx-u32.c + src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c16.c + src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c + src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c src/f32-rminmax/gen/f32-rmax-avx-u8.c src/f32-rminmax/gen/f32-rmax-avx-u16-acc2.c src/f32-rminmax/gen/f32-rmax-avx-u24-acc3.c @@ -1321,6 
+1324,10 @@ SET(ALL_AVX512F_MICROKERNEL_SRCS src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192-acc3.c src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192-acc6.c src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192.c + src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c16.c + src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c32.c + src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c + src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c128.c src/f32-rminmax/gen/f32-rmax-avx512f-u16.c src/f32-rminmax/gen/f32-rmax-avx512f-u32-acc2.c src/f32-rminmax/gen/f32-rmax-avx512f-u48-acc3.c @@ -8897,6 +8904,9 @@ SET(ALL_WASMSIMD_MICROKERNEL_SRCS src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-wasmsimd-rr2-p5-u20-acc2.c src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-wasmsimd-rr2-p5-u20-acc5.c src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-wasmsimd-rr2-p5-u20.c + src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c16.c + src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c32.c + src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c64.c src/f32-rminmax/gen/f32-rmax-wasmsimd-minmax-u4.c src/f32-rminmax/gen/f32-rmax-wasmsimd-minmax-u8-acc2.c src/f32-rminmax/gen/f32-rmax-wasmsimd-minmax-u12-acc3.c diff --git a/microkernels.bzl b/microkernels.bzl index f040a7418f4..b2bedd03fde 100644 --- a/microkernels.bzl +++ b/microkernels.bzl @@ -129,6 +129,9 @@ ALL_AVX_MICROKERNEL_SRCS = [ "src/f32-qu8-vcvt/gen/f32-qu8-vcvt-avx-u16.c", "src/f32-qu8-vcvt/gen/f32-qu8-vcvt-avx-u24.c", "src/f32-qu8-vcvt/gen/f32-qu8-vcvt-avx-u32.c", + "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c16.c", + "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c", + "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c", "src/f32-rminmax/gen/f32-rmax-avx-u8.c", "src/f32-rminmax/gen/f32-rmax-avx-u16-acc2.c", "src/f32-rminmax/gen/f32-rmax-avx-u24-acc3.c", @@ -1321,6 +1324,10 @@ ALL_AVX512F_MICROKERNEL_SRCS = [ "src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192-acc3.c", "src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192-acc6.c", "src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192.c", + "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c16.c", + "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c32.c", + "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c", + "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c128.c", "src/f32-rminmax/gen/f32-rmax-avx512f-u16.c", "src/f32-rminmax/gen/f32-rmax-avx512f-u32-acc2.c", "src/f32-rminmax/gen/f32-rmax-avx512f-u48-acc3.c", @@ -8932,6 +8939,9 @@ ALL_WASMSIMD_MICROKERNEL_SRCS = [ "src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-wasmsimd-rr2-p5-u20-acc2.c", "src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-wasmsimd-rr2-p5-u20-acc5.c", "src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-wasmsimd-rr2-p5-u20.c", + "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c16.c", + "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c32.c", + "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c64.c", "src/f32-rminmax/gen/f32-rmax-wasmsimd-minmax-u4.c", "src/f32-rminmax/gen/f32-rmax-wasmsimd-minmax-u8-acc2.c", "src/f32-rminmax/gen/f32-rmax-wasmsimd-minmax-u12-acc3.c", diff --git a/scripts/generate-f32-rdsum.sh b/scripts/generate-f32-rdsum.sh index b2e87c187a3..1d186c0e815 100755 --- a/scripts/generate-f32-rdsum.sh +++ b/scripts/generate-f32-rdsum.sh @@ -7,14 +7,30 @@ 
#################################### Scalar ################################### tools/xngen src/f32-rdsum/scalar.c.in -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-scalar.c & -#################################### NEON ################################### +#################################### NEON ##################################### tools/xngen src/f32-rdsum/neon.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-neon-c16.c & tools/xngen src/f32-rdsum/neon.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-neon-c32.c & tools/xngen src/f32-rdsum/neon.c.in -D CHANNELS=64 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-neon-c64.c & -#################################### SSE #################################### +#################################### SSE ###################################### tools/xngen src/f32-rdsum/sse.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-sse-c16.c & tools/xngen src/f32-rdsum/sse.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-sse-c32.c & tools/xngen src/f32-rdsum/sse.c.in -D CHANNELS=64 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-sse-c64.c & +#################################### AVX ###################################### +tools/xngen src/f32-rdsum/avx.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c16.c & +tools/xngen src/f32-rdsum/avx.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c & +tools/xngen src/f32-rdsum/avx.c.in -D CHANNELS=64 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c & + +#################################### AVX512F #######$########################### +tools/xngen src/f32-rdsum/avx512.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c16.c & +tools/xngen src/f32-rdsum/avx512.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c32.c & +tools/xngen src/f32-rdsum/avx512.c.in -D CHANNELS=64 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c & +tools/xngen src/f32-rdsum/avx512.c.in -D CHANNELS=128 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c128.c & + +#################################### WAsm SIMD ################################ +tools/xngen src/f32-rdsum/wasm-simd.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c16.c & +tools/xngen src/f32-rdsum/wasm-simd.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c32.c & +tools/xngen src/f32-rdsum/wasm-simd.c.in -D CHANNELS=64 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c64.c & + wait diff --git a/src/f32-rdsum/avx.c.in b/src/f32-rdsum/avx.c.in new file mode 100644 index 00000000000..7d0ed3a9320 --- /dev/null +++ b/src/f32-rdsum/avx.c.in @@ -0,0 +1,143 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
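For orientation, the `7p7x` in these kernel names reflects ACCUMULATORS=7: each pass of the inner loop folds seven rows into the accumulators, and CHANNELS selects how many floats are handled per outer iteration. Below is a minimal scalar sketch of the contract the generated kernels implement (the function name is hypothetical and the `zero`/`params` plumbing is simplified; this is not the shipped scalar kernel):

// Reference semantics of xnn_f32_rdsum_ukernel_*: sum `rows` rows of `channels`
// floats (rows are `input_stride` BYTES apart), scale the sum, and accumulate it
// into `output`.
#include <stddef.h>
#include <stdint.h>

static void f32_rdsum_reference(size_t rows, size_t channels,
                                const float* input, size_t input_stride,
                                float scale, float* output) {
  for (size_t c = 0; c < channels; c++) {
    float acc = 0.0f;
    const float* row = input;
    for (size_t r = 0; r < rows; r++) {
      acc += row[c];
      row = (const float*) ((uintptr_t) row + input_stride);  // stride is in bytes
    }
    output[c] += acc * scale;  // the kernels accumulate into, not overwrite, output
  }
}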
+ +$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" +#include + +#include + +#include +#include +#include + + +$UNROLL = CHANNELS >> 3 +void xnn_f32_rdsum_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__avx_c${CHANNELS}( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m256 vscale = _mm256_set1_ps(params->avx.scale); + + size_t input_increment = ${ACCUMULATORS} * input_stride; + for (; channels >= ${CHANNELS}; channels -= ${CHANNELS}) { + const float* i0 = input; + $for i in range(1, ACCUMULATORS): + const float* i${i} = (const float*) ((uintptr_t) input + ${i} * input_stride); + + $for i in range(UNROLL): + __m256 vacc${i} = _mm256_setzero_ps(); + + for (int r = rows; r > 0; r -= ${ACCUMULATORS}) { + $for N in range(1, ACCUMULATORS, 2): + if XNN_UNPREDICTABLE(r < ${N+1}) { + i${N} = zero; + } + if XNN_UNPREDICTABLE(r <= ${N+1}) { + i${N+1} = zero; + } + $for c in range(UNROLL): + __m256 vin${c}; + $for j in range(ACCUMULATORS): + $for c in range(UNROLL): + vin${c} = _mm256_loadu_ps(&i${j}[${c*8}]); + $for c in range(UNROLL): + vacc${c} = _mm256_add_ps(vin${c}, vacc${c}); + $for N in range(0, ACCUMULATORS): + i${N} = (const float*) ((uintptr_t) i${N} + input_increment); + } + $for i in range(UNROLL): + vacc${i} = _mm256_mul_ps(vacc${i}, vscale); + + const float* o = output; + $for i in range(0, UNROLL): + __m256 vo${i} = _mm256_loadu_ps(o); o += 8; + $for i in range(0, UNROLL): + vacc${i} = _mm256_add_ps(vo${i}, vacc${i}); + $for i in range(0, UNROLL): + _mm256_storeu_ps(output, vacc${i}); output += 8; + + input = (const float*) ((uintptr_t) input + ${CHANNELS} * sizeof(float)); + } + __m256i vmask; + if (channels != 0) { + input_increment = ${ACCUMULATORS} * input_stride; + const float* i0 = input; + $for i in range(1, ACCUMULATORS): + const float* i${i} = (const float*) ((uintptr_t) input + ${i} * input_stride); + __m256 vacc[${UNROLL}]; + $for i in range(UNROLL): + vacc[${i}] = _mm256_setzero_ps(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 0x7; + for (int r = rows; r > 0; r -= ${ACCUMULATORS}) { + $for N in range(1, ACCUMULATORS, 2): + if XNN_UNPREDICTABLE(r < ${N+1}) { + i${N} = zero; + } + if XNN_UNPREDICTABLE(r <= ${N+1}) { + i${N+1} = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + $for c in range(ACCUMULATORS): + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i${c}[i*8]), vacc[i]); + } + + if (remainder) { + vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) ¶ms->avx.mask_table[7] - (channels & 0x7) * sizeof(float))); + $for c in range(ACCUMULATORS): + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i${c}[num_full_chunks*8], vmask), vacc[num_full_chunks]); + } + $for N in range(ACCUMULATORS): + i${N} = (const float*) ((uintptr_t) i${N} + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm256_mul_ps(vacc[i], vscale); + } + + __m256 vo[${UNROLL}]; + const float* o = output; + for (int i = 0; i < channels >> 3; ++i) { + vo[i] = _mm256_loadu_ps(o); o += 8; + } + for (int i = 0; i < channels >> 3; ++i) { + vacc[i] = _mm256_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 3; ++i) { + _mm256_storeu_ps(output, vacc[i]); output += 8; + } + if (remainder) { + const size_t pos = 
num_full_chunks; + __m256 vout = vacc[pos]; + const __m256 vdata = _mm256_maskload_ps(output, vmask); + vout = _mm256_add_ps(vout, vdata); + __m128 vout_lo = _mm256_castps256_ps128(vout); + if (channels & 4) { + _mm_storeu_ps(output, vout_lo); + vout_lo = _mm256_extractf128_ps(vout, 1); + output += 4; + } + if (channels & 2) { + _mm_storel_pi((__m64*) output, vout_lo); + vout_lo = _mm_movehl_ps(vout_lo, vout_lo); + output += 2; + } + if (channels & 1) { + _mm_store_ss(output, vout_lo); + } + } + } +} diff --git a/src/f32-rdsum/avx512.c.in b/src/f32-rdsum/avx512.c.in new file mode 100644 index 00000000000..161438d1a78 --- /dev/null +++ b/src/f32-rdsum/avx512.c.in @@ -0,0 +1,134 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" +#include + +#include + +#include +#include +#include + + +$UNROLL = CHANNELS >> 4 +void xnn_f32_rdsum_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__avx512f_c${CHANNELS}( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m512 vscale = _mm512_set1_ps(params->scalar.scale); + + size_t input_increment = ${ACCUMULATORS} * input_stride; + for (; channels >= ${CHANNELS}; channels -= ${CHANNELS}) { + const float* i0 = input; + $for i in range(1, ACCUMULATORS): + const float* i${i} = (const float*) ((uintptr_t) input + ${i} * input_stride); + + $for i in range(UNROLL): + __m512 vacc${i} = _mm512_setzero_ps(); + + for (int r = rows; r > 0; r -= ${ACCUMULATORS}) { + $for N in range(1, ACCUMULATORS, 2): + if XNN_UNPREDICTABLE(r < ${N+1}) { + i${N} = zero; + } + if XNN_UNPREDICTABLE(r <= ${N+1}) { + i${N+1} = zero; + } + $for c in range(UNROLL): + __m512 vin${c}; + $for j in range(ACCUMULATORS): + $for c in range(UNROLL): + vin${c} = _mm512_loadu_ps(&i${j}[${c*16}]); + $for c in range(UNROLL): + vacc${c} = _mm512_add_ps(vin${c}, vacc${c}); + $for N in range(0, ACCUMULATORS): + i${N} = (const float*) ((uintptr_t) i${N} + input_increment); + } + $for i in range(UNROLL): + vacc${i} = _mm512_mul_ps(vacc${i}, vscale); + + const float* o = output; + $for i in range(0, UNROLL): + const __m512 vo${i} = _mm512_loadu_ps(o); o += 16; + $for i in range(0, UNROLL): + vacc${i} = _mm512_add_ps(vo${i}, vacc${i}); + $for i in range(0, UNROLL): + _mm512_storeu_ps(output, vacc${i}); output += 16; + + input = (const float*) ((uintptr_t) input + ${CHANNELS} * sizeof(float)); + } + if (channels != 0) { + input_increment = ${ACCUMULATORS} * input_stride; + const float* i0 = input; + $for i in range(1, ACCUMULATORS): + const float* i${i} = (const float*) ((uintptr_t) input + ${i} * input_stride); + __m512 vacc[${UNROLL}]; + $for i in range(UNROLL): + vacc[${i}] = _mm512_setzero_ps(); + + const size_t num_full_chunks = channels >> 4; + const size_t num_chunks = round_up_po2(channels, 16) >> 4; + const size_t remainder = channels & 0xF; + const size_t batch = channels & 0xF; + __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + if (remainder) { + assert(batch >= 1); + assert(batch <= 15); + vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + } + for (int r = rows; r > 0; r -= ${ACCUMULATORS}) { + $for N in range(1, ACCUMULATORS, 2): + 
if XNN_UNPREDICTABLE(r < ${N+1}) { + i${N} = zero; + } + if XNN_UNPREDICTABLE(r <= ${N+1}) { + i${N+1} = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + $for c in range(ACCUMULATORS): + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i${c}[i*16]), vacc[i]); + } + + if (remainder) { + $for c in range(ACCUMULATORS): + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i${c}[num_full_chunks*16])); + } + $for N in range(ACCUMULATORS): + i${N} = (const float*) ((uintptr_t) i${N} + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm512_mul_ps(vacc[i], vscale); + } + + __m512 vo[${UNROLL}]; + const float* o = output; + for (int i = 0; i < channels >> 4; ++i) { + vo[i] = _mm512_loadu_ps(o); o += 16; + } + for (int i = 0; i < channels >> 4; ++i) { + vacc[i] = _mm512_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 4; ++i) { + _mm512_storeu_ps(output, vacc[i]); output += 16; + } + if (remainder) { + const size_t pos = num_full_chunks; + __m512 vout = vacc[pos]; + vout = _mm512_maskz_add_ps(vmask, vout, _mm512_maskz_loadu_ps(vmask, output)); + _mm512_mask_storeu_ps(output, vmask, vout); + } + } +} diff --git a/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c16.c b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c16.c new file mode 100644 index 00000000000..c82ccad6d6f --- /dev/null +++ b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c16.c @@ -0,0 +1,218 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum/avx.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +#include +#include +#include + + +void xnn_f32_rdsum_ukernel_7p7x__avx_c16( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m256 vscale = _mm256_set1_ps(params->avx.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 16; channels -= 16) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + __m256 vin0; + __m256 vin1; + vin0 = _mm256_loadu_ps(&i0[0]); + vin1 = _mm256_loadu_ps(&i0[8]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_loadu_ps(&i1[0]); + vin1 = _mm256_loadu_ps(&i1[8]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_loadu_ps(&i2[0]); + vin1 = _mm256_loadu_ps(&i2[8]); 
+ vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_loadu_ps(&i3[0]); + vin1 = _mm256_loadu_ps(&i3[8]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_loadu_ps(&i4[0]); + vin1 = _mm256_loadu_ps(&i4[8]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_loadu_ps(&i5[0]); + vin1 = _mm256_loadu_ps(&i5[8]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_loadu_ps(&i6[0]); + vin1 = _mm256_loadu_ps(&i6[8]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm256_mul_ps(vacc0, vscale); + vacc1 = _mm256_mul_ps(vacc1, vscale); + + const float* o = output; + __m256 vo0 = _mm256_loadu_ps(o); o += 8; + __m256 vo1 = _mm256_loadu_ps(o); o += 8; + vacc0 = _mm256_add_ps(vo0, vacc0); + vacc1 = _mm256_add_ps(vo1, vacc1); + _mm256_storeu_ps(output, vacc0); output += 8; + _mm256_storeu_ps(output, vacc1); output += 8; + + input = (const float*) ((uintptr_t) input + 16 * sizeof(float)); + } + __m256i vmask; + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + __m256 vacc[2]; + vacc[0] = _mm256_setzero_ps(); + vacc[1] = _mm256_setzero_ps(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 0x7; + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i0[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i1[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i2[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i3[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i4[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i5[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i6[i*8]), vacc[i]); + } + + if (remainder) { + vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) ¶ms->avx.mask_table[7] - (channels & 0x7) * sizeof(float))); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i0[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i1[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = 
_mm256_add_ps(_mm256_maskload_ps(&i2[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i3[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i4[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i5[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i6[num_full_chunks*8], vmask), vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm256_mul_ps(vacc[i], vscale); + } + + __m256 vo[2]; + const float* o = output; + for (int i = 0; i < channels >> 3; ++i) { + vo[i] = _mm256_loadu_ps(o); o += 8; + } + for (int i = 0; i < channels >> 3; ++i) { + vacc[i] = _mm256_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 3; ++i) { + _mm256_storeu_ps(output, vacc[i]); output += 8; + } + if (remainder) { + const size_t pos = num_full_chunks; + __m256 vout = vacc[pos]; + const __m256 vdata = _mm256_maskload_ps(output, vmask); + vout = _mm256_add_ps(vout, vdata); + __m128 vout_lo = _mm256_castps256_ps128(vout); + if (channels & 4) { + _mm_storeu_ps(output, vout_lo); + vout_lo = _mm256_extractf128_ps(vout, 1); + output += 4; + } + if (channels & 2) { + _mm_storel_pi((__m64*) output, vout_lo); + vout_lo = _mm_movehl_ps(vout_lo, vout_lo); + output += 2; + } + if (channels & 1) { + _mm_store_ss(output, vout_lo); + } + } + } +} diff --git a/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c new file mode 100644 index 00000000000..9a16e3f3fe9 --- /dev/null +++ b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c @@ -0,0 +1,260 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum/avx.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
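The channel tail of the AVX kernels (the `channels != 0` block above) cannot use full 8-wide loads, so it builds a lane mask by loading from an offset into `params->avx.mask_table`. Assuming the table follows the convention of XNNPACK's other AVX microkernels (seven -1 words followed by seven 0 words), loading eight words starting `(channels & 7)` entries before index 7 yields all-ones in exactly the remaining lanes. A self-contained sketch with a local table and a hypothetical helper name:

// Illustrative only: blend the last (channels & 7) floats using an AVX maskload.
#include <stddef.h>
#include <stdint.h>
#include <immintrin.h>

static const int32_t kMaskTable[14] = {-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0};

static __m256 masked_tail_sum(const float* in, const float* out, size_t remainder /* 1..7 */) {
  // Points at &kMaskTable[7 - remainder]: the first `remainder` words are -1, the rest 0.
  const __m256i vmask = _mm256_loadu_si256(
      (const __m256i*) ((uintptr_t) &kMaskTable[7] - remainder * sizeof(int32_t)));
  const __m256 vin = _mm256_maskload_ps(in, vmask);   // inactive lanes read as 0.0f
  const __m256 vout = _mm256_maskload_ps(out, vmask);
  return _mm256_add_ps(vin, vout);
}

The kernels then write the blended tail back lane by lane through 128-bit stores (`_mm_storeu_ps`, `_mm_storel_pi`, `_mm_store_ss`), as in the `channels & 4/2/1` ladder above.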
+ +#include + +#include + +#include +#include +#include + + +void xnn_f32_rdsum_ukernel_7p7x__avx_c32( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m256 vscale = _mm256_set1_ps(params->avx.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 32; channels -= 32) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + __m256 vacc2 = _mm256_setzero_ps(); + __m256 vacc3 = _mm256_setzero_ps(); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + __m256 vin0; + __m256 vin1; + __m256 vin2; + __m256 vin3; + vin0 = _mm256_loadu_ps(&i0[0]); + vin1 = _mm256_loadu_ps(&i0[8]); + vin2 = _mm256_loadu_ps(&i0[16]); + vin3 = _mm256_loadu_ps(&i0[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i1[0]); + vin1 = _mm256_loadu_ps(&i1[8]); + vin2 = _mm256_loadu_ps(&i1[16]); + vin3 = _mm256_loadu_ps(&i1[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i2[0]); + vin1 = _mm256_loadu_ps(&i2[8]); + vin2 = _mm256_loadu_ps(&i2[16]); + vin3 = _mm256_loadu_ps(&i2[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i3[0]); + vin1 = _mm256_loadu_ps(&i3[8]); + vin2 = _mm256_loadu_ps(&i3[16]); + vin3 = _mm256_loadu_ps(&i3[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i4[0]); + vin1 = _mm256_loadu_ps(&i4[8]); + vin2 = _mm256_loadu_ps(&i4[16]); + vin3 = _mm256_loadu_ps(&i4[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i5[0]); + vin1 = _mm256_loadu_ps(&i5[8]); + vin2 = _mm256_loadu_ps(&i5[16]); + vin3 = _mm256_loadu_ps(&i5[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i6[0]); + vin1 = _mm256_loadu_ps(&i6[8]); + vin2 = _mm256_loadu_ps(&i6[16]); + vin3 = _mm256_loadu_ps(&i6[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = 
_mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm256_mul_ps(vacc0, vscale); + vacc1 = _mm256_mul_ps(vacc1, vscale); + vacc2 = _mm256_mul_ps(vacc2, vscale); + vacc3 = _mm256_mul_ps(vacc3, vscale); + + const float* o = output; + __m256 vo0 = _mm256_loadu_ps(o); o += 8; + __m256 vo1 = _mm256_loadu_ps(o); o += 8; + __m256 vo2 = _mm256_loadu_ps(o); o += 8; + __m256 vo3 = _mm256_loadu_ps(o); o += 8; + vacc0 = _mm256_add_ps(vo0, vacc0); + vacc1 = _mm256_add_ps(vo1, vacc1); + vacc2 = _mm256_add_ps(vo2, vacc2); + vacc3 = _mm256_add_ps(vo3, vacc3); + _mm256_storeu_ps(output, vacc0); output += 8; + _mm256_storeu_ps(output, vacc1); output += 8; + _mm256_storeu_ps(output, vacc2); output += 8; + _mm256_storeu_ps(output, vacc3); output += 8; + + input = (const float*) ((uintptr_t) input + 32 * sizeof(float)); + } + __m256i vmask; + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + __m256 vacc[4]; + vacc[0] = _mm256_setzero_ps(); + vacc[1] = _mm256_setzero_ps(); + vacc[2] = _mm256_setzero_ps(); + vacc[3] = _mm256_setzero_ps(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 0x7; + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i0[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i1[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i2[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i3[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i4[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i5[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i6[i*8]), vacc[i]); + } + + if (remainder) { + vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) ¶ms->avx.mask_table[7] - (channels & 0x7) * sizeof(float))); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i0[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i1[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i2[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i3[num_full_chunks*8], vmask), vacc[num_full_chunks]); + 
vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i4[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i5[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i6[num_full_chunks*8], vmask), vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm256_mul_ps(vacc[i], vscale); + } + + __m256 vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 3; ++i) { + vo[i] = _mm256_loadu_ps(o); o += 8; + } + for (int i = 0; i < channels >> 3; ++i) { + vacc[i] = _mm256_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 3; ++i) { + _mm256_storeu_ps(output, vacc[i]); output += 8; + } + if (remainder) { + const size_t pos = num_full_chunks; + __m256 vout = vacc[pos]; + const __m256 vdata = _mm256_maskload_ps(output, vmask); + vout = _mm256_add_ps(vout, vdata); + __m128 vout_lo = _mm256_castps256_ps128(vout); + if (channels & 4) { + _mm_storeu_ps(output, vout_lo); + vout_lo = _mm256_extractf128_ps(vout, 1); + output += 4; + } + if (channels & 2) { + _mm_storel_pi((__m64*) output, vout_lo); + vout_lo = _mm_movehl_ps(vout_lo, vout_lo); + output += 2; + } + if (channels & 1) { + _mm_store_ss(output, vout_lo); + } + } + } +} diff --git a/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c new file mode 100644 index 00000000000..7ff30de01ec --- /dev/null +++ b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c @@ -0,0 +1,344 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum/avx.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
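The `r < N` / `r <= N` checks at the top of every 7-row pass in the kernels above handle the row tail: once fewer than seven rows remain, the surplus row pointers are redirected to the caller-provided `zero` buffer, so the fully unrolled adds still run but contribute nothing, and the loop exits before the stale pointers are read again. A reduced sketch of the same idea with three rows per pass (hypothetical helper, not part of this diff):

// Sketch: sum `rows` rows of 8 floats, three rows per pass, padding with a zero row.
#include <stddef.h>
#include <stdint.h>
#include <immintrin.h>

static __m256 sum_rows_3x(size_t rows, const float* input, size_t input_stride,
                          const float* zero /* at least 8 floats of 0.0f */) {
  const float* i0 = input;
  const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride);
  const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride);
  const size_t input_increment = 3 * input_stride;
  __m256 vacc = _mm256_setzero_ps();
  for (ptrdiff_t r = (ptrdiff_t) rows; r > 0; r -= 3) {
    if (r < 2) { i1 = zero; }   // only row 0 of this pass is valid
    if (r <= 2) { i2 = zero; }  // at most rows 0 and 1 of this pass are valid
    vacc = _mm256_add_ps(_mm256_loadu_ps(i0), vacc);
    vacc = _mm256_add_ps(_mm256_loadu_ps(i1), vacc);
    vacc = _mm256_add_ps(_mm256_loadu_ps(i2), vacc);
    i0 = (const float*) ((uintptr_t) i0 + input_increment);
    i1 = (const float*) ((uintptr_t) i1 + input_increment);
    i2 = (const float*) ((uintptr_t) i2 + input_increment);
  }
  return vacc;  // caller scales and accumulates into the output, as the kernels do
}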
+ +#include + +#include + +#include +#include +#include + + +void xnn_f32_rdsum_ukernel_7p7x__avx_c64( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m256 vscale = _mm256_set1_ps(params->avx.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 64; channels -= 64) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + __m256 vacc2 = _mm256_setzero_ps(); + __m256 vacc3 = _mm256_setzero_ps(); + __m256 vacc4 = _mm256_setzero_ps(); + __m256 vacc5 = _mm256_setzero_ps(); + __m256 vacc6 = _mm256_setzero_ps(); + __m256 vacc7 = _mm256_setzero_ps(); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + __m256 vin0; + __m256 vin1; + __m256 vin2; + __m256 vin3; + __m256 vin4; + __m256 vin5; + __m256 vin6; + __m256 vin7; + vin0 = _mm256_loadu_ps(&i0[0]); + vin1 = _mm256_loadu_ps(&i0[8]); + vin2 = _mm256_loadu_ps(&i0[16]); + vin3 = _mm256_loadu_ps(&i0[24]); + vin4 = _mm256_loadu_ps(&i0[32]); + vin5 = _mm256_loadu_ps(&i0[40]); + vin6 = _mm256_loadu_ps(&i0[48]); + vin7 = _mm256_loadu_ps(&i0[56]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_loadu_ps(&i1[0]); + vin1 = _mm256_loadu_ps(&i1[8]); + vin2 = _mm256_loadu_ps(&i1[16]); + vin3 = _mm256_loadu_ps(&i1[24]); + vin4 = _mm256_loadu_ps(&i1[32]); + vin5 = _mm256_loadu_ps(&i1[40]); + vin6 = _mm256_loadu_ps(&i1[48]); + vin7 = _mm256_loadu_ps(&i1[56]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_loadu_ps(&i2[0]); + vin1 = _mm256_loadu_ps(&i2[8]); + vin2 = _mm256_loadu_ps(&i2[16]); + vin3 = _mm256_loadu_ps(&i2[24]); + vin4 = _mm256_loadu_ps(&i2[32]); + vin5 = _mm256_loadu_ps(&i2[40]); + vin6 = _mm256_loadu_ps(&i2[48]); + vin7 = _mm256_loadu_ps(&i2[56]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = 
_mm256_loadu_ps(&i3[0]); + vin1 = _mm256_loadu_ps(&i3[8]); + vin2 = _mm256_loadu_ps(&i3[16]); + vin3 = _mm256_loadu_ps(&i3[24]); + vin4 = _mm256_loadu_ps(&i3[32]); + vin5 = _mm256_loadu_ps(&i3[40]); + vin6 = _mm256_loadu_ps(&i3[48]); + vin7 = _mm256_loadu_ps(&i3[56]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_loadu_ps(&i4[0]); + vin1 = _mm256_loadu_ps(&i4[8]); + vin2 = _mm256_loadu_ps(&i4[16]); + vin3 = _mm256_loadu_ps(&i4[24]); + vin4 = _mm256_loadu_ps(&i4[32]); + vin5 = _mm256_loadu_ps(&i4[40]); + vin6 = _mm256_loadu_ps(&i4[48]); + vin7 = _mm256_loadu_ps(&i4[56]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_loadu_ps(&i5[0]); + vin1 = _mm256_loadu_ps(&i5[8]); + vin2 = _mm256_loadu_ps(&i5[16]); + vin3 = _mm256_loadu_ps(&i5[24]); + vin4 = _mm256_loadu_ps(&i5[32]); + vin5 = _mm256_loadu_ps(&i5[40]); + vin6 = _mm256_loadu_ps(&i5[48]); + vin7 = _mm256_loadu_ps(&i5[56]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_loadu_ps(&i6[0]); + vin1 = _mm256_loadu_ps(&i6[8]); + vin2 = _mm256_loadu_ps(&i6[16]); + vin3 = _mm256_loadu_ps(&i6[24]); + vin4 = _mm256_loadu_ps(&i6[32]); + vin5 = _mm256_loadu_ps(&i6[40]); + vin6 = _mm256_loadu_ps(&i6[48]); + vin7 = _mm256_loadu_ps(&i6[56]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm256_mul_ps(vacc0, vscale); + vacc1 = _mm256_mul_ps(vacc1, vscale); + vacc2 = _mm256_mul_ps(vacc2, vscale); + vacc3 = _mm256_mul_ps(vacc3, vscale); + vacc4 = _mm256_mul_ps(vacc4, vscale); + vacc5 = _mm256_mul_ps(vacc5, vscale); + vacc6 = _mm256_mul_ps(vacc6, vscale); + vacc7 = _mm256_mul_ps(vacc7, vscale); + + const float* o = output; + __m256 vo0 = _mm256_loadu_ps(o); o += 8; + __m256 vo1 = _mm256_loadu_ps(o); o += 8; + __m256 vo2 = _mm256_loadu_ps(o); o += 8; + __m256 vo3 = _mm256_loadu_ps(o); o += 8; + __m256 vo4 = _mm256_loadu_ps(o); o += 8; + __m256 vo5 = _mm256_loadu_ps(o); o += 8; + __m256 vo6 = _mm256_loadu_ps(o); o += 8; + __m256 vo7 = _mm256_loadu_ps(o); o += 8; + vacc0 = _mm256_add_ps(vo0, vacc0); + vacc1 = _mm256_add_ps(vo1, vacc1); + vacc2 = _mm256_add_ps(vo2, vacc2); + vacc3 = 
_mm256_add_ps(vo3, vacc3); + vacc4 = _mm256_add_ps(vo4, vacc4); + vacc5 = _mm256_add_ps(vo5, vacc5); + vacc6 = _mm256_add_ps(vo6, vacc6); + vacc7 = _mm256_add_ps(vo7, vacc7); + _mm256_storeu_ps(output, vacc0); output += 8; + _mm256_storeu_ps(output, vacc1); output += 8; + _mm256_storeu_ps(output, vacc2); output += 8; + _mm256_storeu_ps(output, vacc3); output += 8; + _mm256_storeu_ps(output, vacc4); output += 8; + _mm256_storeu_ps(output, vacc5); output += 8; + _mm256_storeu_ps(output, vacc6); output += 8; + _mm256_storeu_ps(output, vacc7); output += 8; + + input = (const float*) ((uintptr_t) input + 64 * sizeof(float)); + } + __m256i vmask; + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + __m256 vacc[8]; + vacc[0] = _mm256_setzero_ps(); + vacc[1] = _mm256_setzero_ps(); + vacc[2] = _mm256_setzero_ps(); + vacc[3] = _mm256_setzero_ps(); + vacc[4] = _mm256_setzero_ps(); + vacc[5] = _mm256_setzero_ps(); + vacc[6] = _mm256_setzero_ps(); + vacc[7] = _mm256_setzero_ps(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 0x7; + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i0[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i1[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i2[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i3[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i4[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i5[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i6[i*8]), vacc[i]); + } + + if (remainder) { + vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) ¶ms->avx.mask_table[7] - (channels & 0x7) * sizeof(float))); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i0[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i1[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i2[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i3[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i4[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i5[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i6[num_full_chunks*8], vmask), vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + 
input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm256_mul_ps(vacc[i], vscale); + } + + __m256 vo[8]; + const float* o = output; + for (int i = 0; i < channels >> 3; ++i) { + vo[i] = _mm256_loadu_ps(o); o += 8; + } + for (int i = 0; i < channels >> 3; ++i) { + vacc[i] = _mm256_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 3; ++i) { + _mm256_storeu_ps(output, vacc[i]); output += 8; + } + if (remainder) { + const size_t pos = num_full_chunks; + __m256 vout = vacc[pos]; + const __m256 vdata = _mm256_maskload_ps(output, vmask); + vout = _mm256_add_ps(vout, vdata); + __m128 vout_lo = _mm256_castps256_ps128(vout); + if (channels & 4) { + _mm_storeu_ps(output, vout_lo); + vout_lo = _mm256_extractf128_ps(vout, 1); + output += 4; + } + if (channels & 2) { + _mm_storel_pi((__m64*) output, vout_lo); + vout_lo = _mm_movehl_ps(vout_lo, vout_lo); + output += 2; + } + if (channels & 1) { + _mm_store_ss(output, vout_lo); + } + } + } +} diff --git a/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c128.c b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c128.c new file mode 100644 index 00000000000..9aa3e46112c --- /dev/null +++ b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c128.c @@ -0,0 +1,335 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum/avx512.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +#include +#include +#include + + +void xnn_f32_rdsum_ukernel_7p7x__avx512f_c128( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m512 vscale = _mm512_set1_ps(params->scalar.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 128; channels -= 128) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + __m512 vacc0 = _mm512_setzero_ps(); + __m512 vacc1 = _mm512_setzero_ps(); + __m512 vacc2 = _mm512_setzero_ps(); + __m512 vacc3 = _mm512_setzero_ps(); + __m512 vacc4 = _mm512_setzero_ps(); + __m512 vacc5 = _mm512_setzero_ps(); + __m512 vacc6 = _mm512_setzero_ps(); + __m512 vacc7 = _mm512_setzero_ps(); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + __m512 vin0; + __m512 vin1; + __m512 vin2; + __m512 vin3; + __m512 vin4; + __m512 vin5; + __m512 vin6; + __m512 vin7; + 
vin0 = _mm512_loadu_ps(&i0[0]); + vin1 = _mm512_loadu_ps(&i0[16]); + vin2 = _mm512_loadu_ps(&i0[32]); + vin3 = _mm512_loadu_ps(&i0[48]); + vin4 = _mm512_loadu_ps(&i0[64]); + vin5 = _mm512_loadu_ps(&i0[80]); + vin6 = _mm512_loadu_ps(&i0[96]); + vin7 = _mm512_loadu_ps(&i0[112]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vacc4 = _mm512_add_ps(vin4, vacc4); + vacc5 = _mm512_add_ps(vin5, vacc5); + vacc6 = _mm512_add_ps(vin6, vacc6); + vacc7 = _mm512_add_ps(vin7, vacc7); + vin0 = _mm512_loadu_ps(&i1[0]); + vin1 = _mm512_loadu_ps(&i1[16]); + vin2 = _mm512_loadu_ps(&i1[32]); + vin3 = _mm512_loadu_ps(&i1[48]); + vin4 = _mm512_loadu_ps(&i1[64]); + vin5 = _mm512_loadu_ps(&i1[80]); + vin6 = _mm512_loadu_ps(&i1[96]); + vin7 = _mm512_loadu_ps(&i1[112]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vacc4 = _mm512_add_ps(vin4, vacc4); + vacc5 = _mm512_add_ps(vin5, vacc5); + vacc6 = _mm512_add_ps(vin6, vacc6); + vacc7 = _mm512_add_ps(vin7, vacc7); + vin0 = _mm512_loadu_ps(&i2[0]); + vin1 = _mm512_loadu_ps(&i2[16]); + vin2 = _mm512_loadu_ps(&i2[32]); + vin3 = _mm512_loadu_ps(&i2[48]); + vin4 = _mm512_loadu_ps(&i2[64]); + vin5 = _mm512_loadu_ps(&i2[80]); + vin6 = _mm512_loadu_ps(&i2[96]); + vin7 = _mm512_loadu_ps(&i2[112]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vacc4 = _mm512_add_ps(vin4, vacc4); + vacc5 = _mm512_add_ps(vin5, vacc5); + vacc6 = _mm512_add_ps(vin6, vacc6); + vacc7 = _mm512_add_ps(vin7, vacc7); + vin0 = _mm512_loadu_ps(&i3[0]); + vin1 = _mm512_loadu_ps(&i3[16]); + vin2 = _mm512_loadu_ps(&i3[32]); + vin3 = _mm512_loadu_ps(&i3[48]); + vin4 = _mm512_loadu_ps(&i3[64]); + vin5 = _mm512_loadu_ps(&i3[80]); + vin6 = _mm512_loadu_ps(&i3[96]); + vin7 = _mm512_loadu_ps(&i3[112]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vacc4 = _mm512_add_ps(vin4, vacc4); + vacc5 = _mm512_add_ps(vin5, vacc5); + vacc6 = _mm512_add_ps(vin6, vacc6); + vacc7 = _mm512_add_ps(vin7, vacc7); + vin0 = _mm512_loadu_ps(&i4[0]); + vin1 = _mm512_loadu_ps(&i4[16]); + vin2 = _mm512_loadu_ps(&i4[32]); + vin3 = _mm512_loadu_ps(&i4[48]); + vin4 = _mm512_loadu_ps(&i4[64]); + vin5 = _mm512_loadu_ps(&i4[80]); + vin6 = _mm512_loadu_ps(&i4[96]); + vin7 = _mm512_loadu_ps(&i4[112]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vacc4 = _mm512_add_ps(vin4, vacc4); + vacc5 = _mm512_add_ps(vin5, vacc5); + vacc6 = _mm512_add_ps(vin6, vacc6); + vacc7 = _mm512_add_ps(vin7, vacc7); + vin0 = _mm512_loadu_ps(&i5[0]); + vin1 = _mm512_loadu_ps(&i5[16]); + vin2 = _mm512_loadu_ps(&i5[32]); + vin3 = _mm512_loadu_ps(&i5[48]); + vin4 = _mm512_loadu_ps(&i5[64]); + vin5 = _mm512_loadu_ps(&i5[80]); + vin6 = _mm512_loadu_ps(&i5[96]); + vin7 = _mm512_loadu_ps(&i5[112]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vacc4 = _mm512_add_ps(vin4, vacc4); + vacc5 = _mm512_add_ps(vin5, vacc5); + vacc6 = _mm512_add_ps(vin6, vacc6); + vacc7 = _mm512_add_ps(vin7, vacc7); + vin0 = _mm512_loadu_ps(&i6[0]); + vin1 = 
_mm512_loadu_ps(&i6[16]); + vin2 = _mm512_loadu_ps(&i6[32]); + vin3 = _mm512_loadu_ps(&i6[48]); + vin4 = _mm512_loadu_ps(&i6[64]); + vin5 = _mm512_loadu_ps(&i6[80]); + vin6 = _mm512_loadu_ps(&i6[96]); + vin7 = _mm512_loadu_ps(&i6[112]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vacc4 = _mm512_add_ps(vin4, vacc4); + vacc5 = _mm512_add_ps(vin5, vacc5); + vacc6 = _mm512_add_ps(vin6, vacc6); + vacc7 = _mm512_add_ps(vin7, vacc7); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm512_mul_ps(vacc0, vscale); + vacc1 = _mm512_mul_ps(vacc1, vscale); + vacc2 = _mm512_mul_ps(vacc2, vscale); + vacc3 = _mm512_mul_ps(vacc3, vscale); + vacc4 = _mm512_mul_ps(vacc4, vscale); + vacc5 = _mm512_mul_ps(vacc5, vscale); + vacc6 = _mm512_mul_ps(vacc6, vscale); + vacc7 = _mm512_mul_ps(vacc7, vscale); + + const float* o = output; + const __m512 vo0 = _mm512_loadu_ps(o); o += 16; + const __m512 vo1 = _mm512_loadu_ps(o); o += 16; + const __m512 vo2 = _mm512_loadu_ps(o); o += 16; + const __m512 vo3 = _mm512_loadu_ps(o); o += 16; + const __m512 vo4 = _mm512_loadu_ps(o); o += 16; + const __m512 vo5 = _mm512_loadu_ps(o); o += 16; + const __m512 vo6 = _mm512_loadu_ps(o); o += 16; + const __m512 vo7 = _mm512_loadu_ps(o); o += 16; + vacc0 = _mm512_add_ps(vo0, vacc0); + vacc1 = _mm512_add_ps(vo1, vacc1); + vacc2 = _mm512_add_ps(vo2, vacc2); + vacc3 = _mm512_add_ps(vo3, vacc3); + vacc4 = _mm512_add_ps(vo4, vacc4); + vacc5 = _mm512_add_ps(vo5, vacc5); + vacc6 = _mm512_add_ps(vo6, vacc6); + vacc7 = _mm512_add_ps(vo7, vacc7); + _mm512_storeu_ps(output, vacc0); output += 16; + _mm512_storeu_ps(output, vacc1); output += 16; + _mm512_storeu_ps(output, vacc2); output += 16; + _mm512_storeu_ps(output, vacc3); output += 16; + _mm512_storeu_ps(output, vacc4); output += 16; + _mm512_storeu_ps(output, vacc5); output += 16; + _mm512_storeu_ps(output, vacc6); output += 16; + _mm512_storeu_ps(output, vacc7); output += 16; + + input = (const float*) ((uintptr_t) input + 128 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + __m512 vacc[8]; + vacc[0] = _mm512_setzero_ps(); + vacc[1] = _mm512_setzero_ps(); + vacc[2] = _mm512_setzero_ps(); + vacc[3] = _mm512_setzero_ps(); + vacc[4] = _mm512_setzero_ps(); + vacc[5] = _mm512_setzero_ps(); + vacc[6] = _mm512_setzero_ps(); + vacc[7] = _mm512_setzero_ps(); + + const size_t num_full_chunks = channels >> 4; + const size_t num_chunks = round_up_po2(channels, 16) >> 4; + const size_t remainder = channels & 0xF; + const size_t batch = channels & 0xF; + __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - 
UINT32_C(1))); + if (remainder) { + assert(batch >= 1); + assert(batch <= 15); + vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + } + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i0[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i1[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i2[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i3[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i4[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i5[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i6[i*16]), vacc[i]); + } + + if (remainder) { + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i0[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i1[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i2[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i3[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i4[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i5[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i6[num_full_chunks*16])); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm512_mul_ps(vacc[i], vscale); + } + + __m512 vo[8]; + const float* o = output; + for (int i = 0; i < channels >> 4; ++i) { + vo[i] = _mm512_loadu_ps(o); o += 16; + } + for (int i = 0; i < channels >> 4; ++i) { + vacc[i] = _mm512_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 4; ++i) { + _mm512_storeu_ps(output, vacc[i]); output += 16; + } + if (remainder) { + const size_t pos = num_full_chunks; + __m512 vout = vacc[pos]; + vout = _mm512_maskz_add_ps(vmask, vout, _mm512_maskz_loadu_ps(vmask, output)); + _mm512_mask_storeu_ps(output, vmask, vout); + } + } +} diff --git a/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c16.c b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c16.c new file mode 100644 index 00000000000..b80585598b8 --- /dev/null +++ b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c16.c @@ -0,0 +1,188 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum/avx512.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
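(Reviewer note, not part of the diff.) The generated kernels below all share the same 7p7x structure, which is easier to see with the vector unrolling stripped away. The following plain-C sketch is illustrative only: `rdsum_7p7x_scalar_sketch` is an invented name, and the heap-allocated scratch array is a simplification of the register accumulators in the real kernels. Rows are reduced seven at a time, surplus row pointers are redirected to the caller-supplied `zero` buffer so the unrolled body never needs a separate tail path, and the scaled sum is finally added to whatever is already in `output`:

#include <stdint.h>
#include <stdlib.h>

static void rdsum_7p7x_scalar_sketch(
    size_t rows, size_t channels,
    const float* input, size_t input_stride,  // stride in bytes, as in the kernels
    const float* zero, float scale, float* output)
{
  float* acc = (float*) calloc(channels, sizeof(float));  // all-zero partial sums
  if (acc == NULL) return;
  const float* i[7];
  for (size_t n = 0; n < 7; n++) {
    i[n] = (const float*) ((uintptr_t) input + n * input_stride);
  }
  for (int r = (int) rows; r > 0; r -= 7) {
    // Fewer than 7 rows left: point the surplus rows at `zero` so the body
    // below still reads 7 rows. This mirrors the XNN_UNPREDICTABLE redirects
    // (i1 = zero when r < 2, i2 = zero when r <= 2, and so on).
    for (int n = 1; n < 7; n++) {
      if (n >= r) i[n] = zero;
    }
    for (size_t c = 0; c < channels; c++) {
      for (int n = 0; n < 7; n++) {
        acc[c] += i[n][c];
      }
    }
    for (int n = 0; n < 7; n++) {
      i[n] = (const float*) ((uintptr_t) i[n] + 7 * input_stride);
    }
  }
  // Scale once at the end and accumulate into the existing output contents,
  // matching the tail of the generated kernels.
  for (size_t c = 0; c < channels; c++) {
    output[c] += scale * acc[c];
  }
  free(acc);
}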
+ +#include + +#include + +#include +#include +#include + + +void xnn_f32_rdsum_ukernel_7p7x__avx512f_c16( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m512 vscale = _mm512_set1_ps(params->scalar.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 16; channels -= 16) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + __m512 vacc0 = _mm512_setzero_ps(); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + __m512 vin0; + vin0 = _mm512_loadu_ps(&i0[0]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vin0 = _mm512_loadu_ps(&i1[0]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vin0 = _mm512_loadu_ps(&i2[0]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vin0 = _mm512_loadu_ps(&i3[0]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vin0 = _mm512_loadu_ps(&i4[0]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vin0 = _mm512_loadu_ps(&i5[0]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vin0 = _mm512_loadu_ps(&i6[0]); + vacc0 = _mm512_add_ps(vin0, vacc0); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm512_mul_ps(vacc0, vscale); + + const float* o = output; + const __m512 vo0 = _mm512_loadu_ps(o); o += 16; + vacc0 = _mm512_add_ps(vo0, vacc0); + _mm512_storeu_ps(output, vacc0); output += 16; + + input = (const float*) ((uintptr_t) input + 16 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + __m512 vacc[1]; + vacc[0] = _mm512_setzero_ps(); + + const size_t num_full_chunks = channels >> 4; + const size_t num_chunks = round_up_po2(channels, 16) >> 4; + const size_t remainder = channels & 0xF; + const size_t batch = channels & 0xF; + __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + if (remainder) { + assert(batch >= 1); + assert(batch <= 15); + vmask = 
_cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + } + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i0[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i1[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i2[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i3[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i4[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i5[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i6[i*16]), vacc[i]); + } + + if (remainder) { + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i0[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i1[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i2[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i3[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i4[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i5[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i6[num_full_chunks*16])); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm512_mul_ps(vacc[i], vscale); + } + + __m512 vo[1]; + const float* o = output; + for (int i = 0; i < channels >> 4; ++i) { + vo[i] = _mm512_loadu_ps(o); o += 16; + } + for (int i = 0; i < channels >> 4; ++i) { + vacc[i] = _mm512_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 4; ++i) { + _mm512_storeu_ps(output, vacc[i]); output += 16; + } + if (remainder) { + const size_t pos = num_full_chunks; + __m512 vout = vacc[pos]; + vout = _mm512_maskz_add_ps(vmask, vout, _mm512_maskz_loadu_ps(vmask, output)); + _mm512_mask_storeu_ps(output, vmask, vout); + } + } +} diff --git a/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c32.c b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c32.c new file mode 100644 index 00000000000..369b14bd14e --- /dev/null +++ b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c32.c @@ -0,0 +1,209 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum/avx512.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
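(Reviewer note, not part of the diff.) In the AVX512F variants, the last partial chunk of channels cannot use a full 16-wide load or store, so the remainder path builds a `__mmask16` whose low `channels & 0xF` bits are set and drives every load, add, and store through it. The mask is computed once, before the row loop, and reused for every row and for the final read-modify-write of `output`. A minimal, self-contained sketch of that pattern (the name `masked_tail_add_sketch` is invented; an AVX512F target is assumed):

#include <assert.h>
#include <stdint.h>
#include <immintrin.h>

// Illustrative only: add the first n floats of src into dst for 0 < n < 16,
// using the same mask construction as the remainder path of these kernels.
static void masked_tail_add_sketch(float* dst, const float* src, size_t n) {
  assert(n >= 1 && n <= 15);
  // Low n bits set: lanes [0, n) participate, the others are never touched.
  const __mmask16 vmask =
      _cvtu32_mask16((uint32_t) ((UINT32_C(1) << n) - UINT32_C(1)));
  const __m512 vsrc = _mm512_maskz_loadu_ps(vmask, src);
  const __m512 vdst = _mm512_maskz_loadu_ps(vmask, dst);
  // maskz_add zeroes the unselected lanes, which is harmless here because
  // the masked store below never writes them back.
  const __m512 vsum = _mm512_maskz_add_ps(vmask, vdst, vsrc);
  _mm512_mask_storeu_ps(dst, vmask, vsum);
}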
+ +#include + +#include + +#include +#include +#include + + +void xnn_f32_rdsum_ukernel_7p7x__avx512f_c32( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m512 vscale = _mm512_set1_ps(params->scalar.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 32; channels -= 32) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + __m512 vacc0 = _mm512_setzero_ps(); + __m512 vacc1 = _mm512_setzero_ps(); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + __m512 vin0; + __m512 vin1; + vin0 = _mm512_loadu_ps(&i0[0]); + vin1 = _mm512_loadu_ps(&i0[16]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vin0 = _mm512_loadu_ps(&i1[0]); + vin1 = _mm512_loadu_ps(&i1[16]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vin0 = _mm512_loadu_ps(&i2[0]); + vin1 = _mm512_loadu_ps(&i2[16]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vin0 = _mm512_loadu_ps(&i3[0]); + vin1 = _mm512_loadu_ps(&i3[16]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vin0 = _mm512_loadu_ps(&i4[0]); + vin1 = _mm512_loadu_ps(&i4[16]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vin0 = _mm512_loadu_ps(&i5[0]); + vin1 = _mm512_loadu_ps(&i5[16]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vin0 = _mm512_loadu_ps(&i6[0]); + vin1 = _mm512_loadu_ps(&i6[16]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm512_mul_ps(vacc0, vscale); + vacc1 = _mm512_mul_ps(vacc1, vscale); + + const float* o = output; + const __m512 vo0 = _mm512_loadu_ps(o); o += 16; + const __m512 vo1 = _mm512_loadu_ps(o); o += 16; + vacc0 = _mm512_add_ps(vo0, vacc0); + vacc1 = _mm512_add_ps(vo1, vacc1); + _mm512_storeu_ps(output, vacc0); output += 16; + _mm512_storeu_ps(output, vacc1); output += 16; + + input = (const float*) ((uintptr_t) input + 32 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) 
((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + __m512 vacc[2]; + vacc[0] = _mm512_setzero_ps(); + vacc[1] = _mm512_setzero_ps(); + + const size_t num_full_chunks = channels >> 4; + const size_t num_chunks = round_up_po2(channels, 16) >> 4; + const size_t remainder = channels & 0xF; + const size_t batch = channels & 0xF; + __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + if (remainder) { + assert(batch >= 1); + assert(batch <= 15); + vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + } + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i0[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i1[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i2[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i3[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i4[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i5[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i6[i*16]), vacc[i]); + } + + if (remainder) { + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i0[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i1[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i2[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i3[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i4[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i5[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i6[num_full_chunks*16])); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm512_mul_ps(vacc[i], vscale); + } + + __m512 vo[2]; + const float* o = output; + for (int i = 0; i < channels >> 4; ++i) { + vo[i] = _mm512_loadu_ps(o); o += 16; + } + for (int i = 0; i < channels >> 4; ++i) { + vacc[i] = _mm512_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 4; ++i) { + _mm512_storeu_ps(output, vacc[i]); output += 16; + } + if (remainder) { + const size_t pos = num_full_chunks; + __m512 vout = vacc[pos]; + vout = 
_mm512_maskz_add_ps(vmask, vout, _mm512_maskz_loadu_ps(vmask, output)); + _mm512_mask_storeu_ps(output, vmask, vout); + } + } +} diff --git a/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c new file mode 100644 index 00000000000..3736df87615 --- /dev/null +++ b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c @@ -0,0 +1,251 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum/avx512.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +#include +#include +#include + + +void xnn_f32_rdsum_ukernel_7p7x__avx512f_c64( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m512 vscale = _mm512_set1_ps(params->scalar.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 64; channels -= 64) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + __m512 vacc0 = _mm512_setzero_ps(); + __m512 vacc1 = _mm512_setzero_ps(); + __m512 vacc2 = _mm512_setzero_ps(); + __m512 vacc3 = _mm512_setzero_ps(); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + __m512 vin0; + __m512 vin1; + __m512 vin2; + __m512 vin3; + vin0 = _mm512_loadu_ps(&i0[0]); + vin1 = _mm512_loadu_ps(&i0[16]); + vin2 = _mm512_loadu_ps(&i0[32]); + vin3 = _mm512_loadu_ps(&i0[48]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_loadu_ps(&i1[0]); + vin1 = _mm512_loadu_ps(&i1[16]); + vin2 = _mm512_loadu_ps(&i1[32]); + vin3 = _mm512_loadu_ps(&i1[48]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_loadu_ps(&i2[0]); + vin1 = _mm512_loadu_ps(&i2[16]); + vin2 = _mm512_loadu_ps(&i2[32]); + vin3 = _mm512_loadu_ps(&i2[48]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_loadu_ps(&i3[0]); + vin1 = _mm512_loadu_ps(&i3[16]); + vin2 = _mm512_loadu_ps(&i3[32]); + vin3 = _mm512_loadu_ps(&i3[48]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_loadu_ps(&i4[0]); + vin1 = _mm512_loadu_ps(&i4[16]); + vin2 = _mm512_loadu_ps(&i4[32]); + vin3 = _mm512_loadu_ps(&i4[48]); 
+ vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_loadu_ps(&i5[0]); + vin1 = _mm512_loadu_ps(&i5[16]); + vin2 = _mm512_loadu_ps(&i5[32]); + vin3 = _mm512_loadu_ps(&i5[48]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_loadu_ps(&i6[0]); + vin1 = _mm512_loadu_ps(&i6[16]); + vin2 = _mm512_loadu_ps(&i6[32]); + vin3 = _mm512_loadu_ps(&i6[48]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm512_mul_ps(vacc0, vscale); + vacc1 = _mm512_mul_ps(vacc1, vscale); + vacc2 = _mm512_mul_ps(vacc2, vscale); + vacc3 = _mm512_mul_ps(vacc3, vscale); + + const float* o = output; + const __m512 vo0 = _mm512_loadu_ps(o); o += 16; + const __m512 vo1 = _mm512_loadu_ps(o); o += 16; + const __m512 vo2 = _mm512_loadu_ps(o); o += 16; + const __m512 vo3 = _mm512_loadu_ps(o); o += 16; + vacc0 = _mm512_add_ps(vo0, vacc0); + vacc1 = _mm512_add_ps(vo1, vacc1); + vacc2 = _mm512_add_ps(vo2, vacc2); + vacc3 = _mm512_add_ps(vo3, vacc3); + _mm512_storeu_ps(output, vacc0); output += 16; + _mm512_storeu_ps(output, vacc1); output += 16; + _mm512_storeu_ps(output, vacc2); output += 16; + _mm512_storeu_ps(output, vacc3); output += 16; + + input = (const float*) ((uintptr_t) input + 64 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + __m512 vacc[4]; + vacc[0] = _mm512_setzero_ps(); + vacc[1] = _mm512_setzero_ps(); + vacc[2] = _mm512_setzero_ps(); + vacc[3] = _mm512_setzero_ps(); + + const size_t num_full_chunks = channels >> 4; + const size_t num_chunks = round_up_po2(channels, 16) >> 4; + const size_t remainder = channels & 0xF; + const size_t batch = channels & 0xF; + __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + if (remainder) { + assert(batch >= 1); + assert(batch <= 15); + vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + } + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i0[i*16]), vacc[i]); + vacc[i] = 
_mm512_add_ps(_mm512_loadu_ps(&i1[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i2[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i3[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i4[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i5[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i6[i*16]), vacc[i]); + } + + if (remainder) { + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i0[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i1[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i2[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i3[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i4[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i5[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i6[num_full_chunks*16])); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm512_mul_ps(vacc[i], vscale); + } + + __m512 vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 4; ++i) { + vo[i] = _mm512_loadu_ps(o); o += 16; + } + for (int i = 0; i < channels >> 4; ++i) { + vacc[i] = _mm512_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 4; ++i) { + _mm512_storeu_ps(output, vacc[i]); output += 16; + } + if (remainder) { + const size_t pos = num_full_chunks; + __m512 vout = vacc[pos]; + vout = _mm512_maskz_add_ps(vmask, vout, _mm512_maskz_loadu_ps(vmask, output)); + _mm512_mask_storeu_ps(output, vmask, vout); + } + } +} diff --git a/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c16.c b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c16.c new file mode 100644 index 00000000000..b07592edbb6 --- /dev/null +++ b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c16.c @@ -0,0 +1,235 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum/wasm-simd.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC // +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
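(Reviewer note, not part of the diff.) The wasmsimd variants below handle the sub-vector channel tail differently, since wasm SIMD has no masked loads or stores: after all whole 4-wide chunks are written, the last one to three scaled sums sit in a single `v128_t`, and the tail is split into a 2-lane piece stored with `wasm_v128_store64_lane` and a 1-lane piece added via `wasm_f32x4_extract_lane`. A small sketch of just that step, mirroring the kernel code (the name `wasmsimd_tail_add_sketch` is invented):

#include <stddef.h>
#include <wasm_simd128.h>

// Illustrative only: add the first n lanes (0 < n < 4) of vout into output,
// the same way the tails of the wasmsimd kernels in this diff do.
static void wasmsimd_tail_add_sketch(float* output, v128_t vout, size_t n) {
  if (n & 2) {
    // Read two floats, add the low two lanes of vout, store 64 bits back.
    const v128_t vo = wasm_f32x4_make(output[0], output[1], 0.0f, 0.0f);
    wasm_v128_store64_lane(output, wasm_f32x4_add(vo, vout), 0);
    // Move the upper half of vout down so lane 0 holds the next value.
    vout = wasm_v64x2_shuffle(vout, vout, 1, 1);
    output += 2;
  }
  if (n & 1) {
    *output += wasm_f32x4_extract_lane(vout, 0);
  }
}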
+ +#include + +#include + +#include +#include +#include + +void xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const v128_t vscale = wasm_v128_load32_splat(¶ms->scalar.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 16; channels -= 16) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + v128_t vacc0 = wasm_i32x4_const_splat(0.f); + v128_t vacc1 = wasm_i32x4_const_splat(0.f); + v128_t vacc2 = wasm_i32x4_const_splat(0.f); + v128_t vacc3 = wasm_i32x4_const_splat(0.f); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + v128_t vin0; + v128_t vin1; + v128_t vin2; + v128_t vin3; + vin0 = wasm_v128_load(&i0[0]); + vin1 = wasm_v128_load(&i0[4]); + vin2 = wasm_v128_load(&i0[8]); + vin3 = wasm_v128_load(&i0[12]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vin0 = wasm_v128_load(&i1[0]); + vin1 = wasm_v128_load(&i1[4]); + vin2 = wasm_v128_load(&i1[8]); + vin3 = wasm_v128_load(&i1[12]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vin0 = wasm_v128_load(&i2[0]); + vin1 = wasm_v128_load(&i2[4]); + vin2 = wasm_v128_load(&i2[8]); + vin3 = wasm_v128_load(&i2[12]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vin0 = wasm_v128_load(&i3[0]); + vin1 = wasm_v128_load(&i3[4]); + vin2 = wasm_v128_load(&i3[8]); + vin3 = wasm_v128_load(&i3[12]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vin0 = wasm_v128_load(&i4[0]); + vin1 = wasm_v128_load(&i4[4]); + vin2 = wasm_v128_load(&i4[8]); + vin3 = wasm_v128_load(&i4[12]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vin0 = wasm_v128_load(&i5[0]); + vin1 = wasm_v128_load(&i5[4]); + vin2 = wasm_v128_load(&i5[8]); + vin3 = wasm_v128_load(&i5[12]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vin0 = wasm_v128_load(&i6[0]); + vin1 = wasm_v128_load(&i6[4]); + vin2 = wasm_v128_load(&i6[8]); + vin3 = wasm_v128_load(&i6[12]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, 
vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = wasm_f32x4_mul(vacc0, vscale); + vacc1 = wasm_f32x4_mul(vacc1, vscale); + vacc2 = wasm_f32x4_mul(vacc2, vscale); + vacc3 = wasm_f32x4_mul(vacc3, vscale); + + const float* o = output; + v128_t vo0 = wasm_v128_load(o); o += 4; + v128_t vo1 = wasm_v128_load(o); o += 4; + v128_t vo2 = wasm_v128_load(o); o += 4; + v128_t vo3 = wasm_v128_load(o); o += 4; + vacc0 = wasm_f32x4_add(vo0, vacc0); + vacc1 = wasm_f32x4_add(vo1, vacc1); + vacc2 = wasm_f32x4_add(vo2, vacc2); + vacc3 = wasm_f32x4_add(vo3, vacc3); + wasm_v128_store(output, vacc0); output += 4; + wasm_v128_store(output, vacc1); output += 4; + wasm_v128_store(output, vacc2); output += 4; + wasm_v128_store(output, vacc3); output += 4; + + input = (const float*) ((uintptr_t) input + 16 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + v128_t vacc[4]; + vacc[0] = wasm_i32x4_const_splat(0.f); + vacc[1] = wasm_i32x4_const_splat(0.f); + vacc[2] = wasm_i32x4_const_splat(0.f); + vacc[3] = wasm_i32x4_const_splat(0.f); + + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_chunks; ++i) { + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i0[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i1[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i2[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i3[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i4[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i5[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i6[i*4]), vacc[i]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (int i = 0; i < num_chunks; ++i) { + vacc[i] = wasm_f32x4_mul(vacc[i], vscale); + } + + v128_t vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = wasm_v128_load(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + 
vacc[i] = wasm_f32x4_add(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + wasm_v128_store(output, vacc[i]); output += 4; + } + const size_t pos = channels / 4; + v128_t vout = vacc[pos]; + if (channels & 2) { + v128_t vo = wasm_f32x4_make(output[0], output[1], 0.f, 0.f); + wasm_v128_store64_lane(output, wasm_f32x4_add(vo, vout), 0); + vout = wasm_v64x2_shuffle(vout, vout, 1, 1); + output += 2; + } + if (channels & 1) { + *output += wasm_f32x4_extract_lane(vout, 0); + } + } +} diff --git a/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c32.c b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c32.c new file mode 100644 index 00000000000..c1a82067571 --- /dev/null +++ b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c32.c @@ -0,0 +1,319 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum/wasm-simd.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC // +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +#include +#include +#include + +void xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const v128_t vscale = wasm_v128_load32_splat(¶ms->scalar.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 32; channels -= 32) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + v128_t vacc0 = wasm_i32x4_const_splat(0.f); + v128_t vacc1 = wasm_i32x4_const_splat(0.f); + v128_t vacc2 = wasm_i32x4_const_splat(0.f); + v128_t vacc3 = wasm_i32x4_const_splat(0.f); + v128_t vacc4 = wasm_i32x4_const_splat(0.f); + v128_t vacc5 = wasm_i32x4_const_splat(0.f); + v128_t vacc6 = wasm_i32x4_const_splat(0.f); + v128_t vacc7 = wasm_i32x4_const_splat(0.f); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + v128_t vin0; + v128_t vin1; + v128_t vin2; + v128_t vin3; + v128_t vin4; + v128_t vin5; + v128_t vin6; + v128_t vin7; + vin0 = wasm_v128_load(&i0[0]); + vin1 = wasm_v128_load(&i0[4]); + vin2 = wasm_v128_load(&i0[8]); + vin3 = wasm_v128_load(&i0[12]); + vin4 = wasm_v128_load(&i0[16]); + vin5 = wasm_v128_load(&i0[20]); + vin6 = wasm_v128_load(&i0[24]); + vin7 = wasm_v128_load(&i0[28]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vacc4 = wasm_f32x4_add(vin4, vacc4); + vacc5 = wasm_f32x4_add(vin5, vacc5); + vacc6 = wasm_f32x4_add(vin6, vacc6); + vacc7 = wasm_f32x4_add(vin7, vacc7); + vin0 = wasm_v128_load(&i1[0]); + vin1 = wasm_v128_load(&i1[4]); 
+ vin2 = wasm_v128_load(&i1[8]); + vin3 = wasm_v128_load(&i1[12]); + vin4 = wasm_v128_load(&i1[16]); + vin5 = wasm_v128_load(&i1[20]); + vin6 = wasm_v128_load(&i1[24]); + vin7 = wasm_v128_load(&i1[28]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vacc4 = wasm_f32x4_add(vin4, vacc4); + vacc5 = wasm_f32x4_add(vin5, vacc5); + vacc6 = wasm_f32x4_add(vin6, vacc6); + vacc7 = wasm_f32x4_add(vin7, vacc7); + vin0 = wasm_v128_load(&i2[0]); + vin1 = wasm_v128_load(&i2[4]); + vin2 = wasm_v128_load(&i2[8]); + vin3 = wasm_v128_load(&i2[12]); + vin4 = wasm_v128_load(&i2[16]); + vin5 = wasm_v128_load(&i2[20]); + vin6 = wasm_v128_load(&i2[24]); + vin7 = wasm_v128_load(&i2[28]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vacc4 = wasm_f32x4_add(vin4, vacc4); + vacc5 = wasm_f32x4_add(vin5, vacc5); + vacc6 = wasm_f32x4_add(vin6, vacc6); + vacc7 = wasm_f32x4_add(vin7, vacc7); + vin0 = wasm_v128_load(&i3[0]); + vin1 = wasm_v128_load(&i3[4]); + vin2 = wasm_v128_load(&i3[8]); + vin3 = wasm_v128_load(&i3[12]); + vin4 = wasm_v128_load(&i3[16]); + vin5 = wasm_v128_load(&i3[20]); + vin6 = wasm_v128_load(&i3[24]); + vin7 = wasm_v128_load(&i3[28]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vacc4 = wasm_f32x4_add(vin4, vacc4); + vacc5 = wasm_f32x4_add(vin5, vacc5); + vacc6 = wasm_f32x4_add(vin6, vacc6); + vacc7 = wasm_f32x4_add(vin7, vacc7); + vin0 = wasm_v128_load(&i4[0]); + vin1 = wasm_v128_load(&i4[4]); + vin2 = wasm_v128_load(&i4[8]); + vin3 = wasm_v128_load(&i4[12]); + vin4 = wasm_v128_load(&i4[16]); + vin5 = wasm_v128_load(&i4[20]); + vin6 = wasm_v128_load(&i4[24]); + vin7 = wasm_v128_load(&i4[28]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vacc4 = wasm_f32x4_add(vin4, vacc4); + vacc5 = wasm_f32x4_add(vin5, vacc5); + vacc6 = wasm_f32x4_add(vin6, vacc6); + vacc7 = wasm_f32x4_add(vin7, vacc7); + vin0 = wasm_v128_load(&i5[0]); + vin1 = wasm_v128_load(&i5[4]); + vin2 = wasm_v128_load(&i5[8]); + vin3 = wasm_v128_load(&i5[12]); + vin4 = wasm_v128_load(&i5[16]); + vin5 = wasm_v128_load(&i5[20]); + vin6 = wasm_v128_load(&i5[24]); + vin7 = wasm_v128_load(&i5[28]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vacc4 = wasm_f32x4_add(vin4, vacc4); + vacc5 = wasm_f32x4_add(vin5, vacc5); + vacc6 = wasm_f32x4_add(vin6, vacc6); + vacc7 = wasm_f32x4_add(vin7, vacc7); + vin0 = wasm_v128_load(&i6[0]); + vin1 = wasm_v128_load(&i6[4]); + vin2 = wasm_v128_load(&i6[8]); + vin3 = wasm_v128_load(&i6[12]); + vin4 = wasm_v128_load(&i6[16]); + vin5 = wasm_v128_load(&i6[20]); + vin6 = wasm_v128_load(&i6[24]); + vin7 = wasm_v128_load(&i6[28]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vacc4 = wasm_f32x4_add(vin4, vacc4); + vacc5 = wasm_f32x4_add(vin5, vacc5); + vacc6 = wasm_f32x4_add(vin6, vacc6); + vacc7 = wasm_f32x4_add(vin7, vacc7); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const 
float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = wasm_f32x4_mul(vacc0, vscale); + vacc1 = wasm_f32x4_mul(vacc1, vscale); + vacc2 = wasm_f32x4_mul(vacc2, vscale); + vacc3 = wasm_f32x4_mul(vacc3, vscale); + vacc4 = wasm_f32x4_mul(vacc4, vscale); + vacc5 = wasm_f32x4_mul(vacc5, vscale); + vacc6 = wasm_f32x4_mul(vacc6, vscale); + vacc7 = wasm_f32x4_mul(vacc7, vscale); + + const float* o = output; + v128_t vo0 = wasm_v128_load(o); o += 4; + v128_t vo1 = wasm_v128_load(o); o += 4; + v128_t vo2 = wasm_v128_load(o); o += 4; + v128_t vo3 = wasm_v128_load(o); o += 4; + v128_t vo4 = wasm_v128_load(o); o += 4; + v128_t vo5 = wasm_v128_load(o); o += 4; + v128_t vo6 = wasm_v128_load(o); o += 4; + v128_t vo7 = wasm_v128_load(o); o += 4; + vacc0 = wasm_f32x4_add(vo0, vacc0); + vacc1 = wasm_f32x4_add(vo1, vacc1); + vacc2 = wasm_f32x4_add(vo2, vacc2); + vacc3 = wasm_f32x4_add(vo3, vacc3); + vacc4 = wasm_f32x4_add(vo4, vacc4); + vacc5 = wasm_f32x4_add(vo5, vacc5); + vacc6 = wasm_f32x4_add(vo6, vacc6); + vacc7 = wasm_f32x4_add(vo7, vacc7); + wasm_v128_store(output, vacc0); output += 4; + wasm_v128_store(output, vacc1); output += 4; + wasm_v128_store(output, vacc2); output += 4; + wasm_v128_store(output, vacc3); output += 4; + wasm_v128_store(output, vacc4); output += 4; + wasm_v128_store(output, vacc5); output += 4; + wasm_v128_store(output, vacc6); output += 4; + wasm_v128_store(output, vacc7); output += 4; + + input = (const float*) ((uintptr_t) input + 32 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + v128_t vacc[8]; + vacc[0] = wasm_i32x4_const_splat(0.f); + vacc[1] = wasm_i32x4_const_splat(0.f); + vacc[2] = wasm_i32x4_const_splat(0.f); + vacc[3] = wasm_i32x4_const_splat(0.f); + vacc[4] = wasm_i32x4_const_splat(0.f); + vacc[5] = wasm_i32x4_const_splat(0.f); + vacc[6] = wasm_i32x4_const_splat(0.f); + vacc[7] = wasm_i32x4_const_splat(0.f); + + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_chunks; ++i) { + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i0[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i1[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i2[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i3[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i4[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i5[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i6[i*4]), vacc[i]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + 
i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (int i = 0; i < num_chunks; ++i) { + vacc[i] = wasm_f32x4_mul(vacc[i], vscale); + } + + v128_t vo[8]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = wasm_v128_load(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + vacc[i] = wasm_f32x4_add(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + wasm_v128_store(output, vacc[i]); output += 4; + } + const size_t pos = channels / 4; + v128_t vout = vacc[pos]; + if (channels & 2) { + v128_t vo = wasm_f32x4_make(output[0], output[1], 0.f, 0.f); + wasm_v128_store64_lane(output, wasm_f32x4_add(vo, vout), 0); + vout = wasm_v64x2_shuffle(vout, vout, 1, 1); + output += 2; + } + if (channels & 1) { + *output += wasm_f32x4_extract_lane(vout, 0); + } + } +} diff --git a/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c64.c b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c64.c new file mode 100644 index 00000000000..f0dad1d10e1 --- /dev/null +++ b/src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c64.c @@ -0,0 +1,487 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum/wasm-simd.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC // +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +#include +#include +#include + +void xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const v128_t vscale = wasm_v128_load32_splat(¶ms->scalar.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 64; channels -= 64) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + v128_t vacc0 = wasm_i32x4_const_splat(0.f); + v128_t vacc1 = wasm_i32x4_const_splat(0.f); + v128_t vacc2 = wasm_i32x4_const_splat(0.f); + v128_t vacc3 = wasm_i32x4_const_splat(0.f); + v128_t vacc4 = wasm_i32x4_const_splat(0.f); + v128_t vacc5 = wasm_i32x4_const_splat(0.f); + v128_t vacc6 = wasm_i32x4_const_splat(0.f); + v128_t vacc7 = wasm_i32x4_const_splat(0.f); + v128_t vacc8 = wasm_i32x4_const_splat(0.f); + v128_t vacc9 = wasm_i32x4_const_splat(0.f); + v128_t vacc10 = wasm_i32x4_const_splat(0.f); + v128_t vacc11 = wasm_i32x4_const_splat(0.f); + v128_t vacc12 = wasm_i32x4_const_splat(0.f); + v128_t vacc13 = wasm_i32x4_const_splat(0.f); + v128_t vacc14 = wasm_i32x4_const_splat(0.f); + v128_t vacc15 = wasm_i32x4_const_splat(0.f); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r 
<= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + v128_t vin0; + v128_t vin1; + v128_t vin2; + v128_t vin3; + v128_t vin4; + v128_t vin5; + v128_t vin6; + v128_t vin7; + v128_t vin8; + v128_t vin9; + v128_t vin10; + v128_t vin11; + v128_t vin12; + v128_t vin13; + v128_t vin14; + v128_t vin15; + vin0 = wasm_v128_load(&i0[0]); + vin1 = wasm_v128_load(&i0[4]); + vin2 = wasm_v128_load(&i0[8]); + vin3 = wasm_v128_load(&i0[12]); + vin4 = wasm_v128_load(&i0[16]); + vin5 = wasm_v128_load(&i0[20]); + vin6 = wasm_v128_load(&i0[24]); + vin7 = wasm_v128_load(&i0[28]); + vin8 = wasm_v128_load(&i0[32]); + vin9 = wasm_v128_load(&i0[36]); + vin10 = wasm_v128_load(&i0[40]); + vin11 = wasm_v128_load(&i0[44]); + vin12 = wasm_v128_load(&i0[48]); + vin13 = wasm_v128_load(&i0[52]); + vin14 = wasm_v128_load(&i0[56]); + vin15 = wasm_v128_load(&i0[60]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vacc4 = wasm_f32x4_add(vin4, vacc4); + vacc5 = wasm_f32x4_add(vin5, vacc5); + vacc6 = wasm_f32x4_add(vin6, vacc6); + vacc7 = wasm_f32x4_add(vin7, vacc7); + vacc8 = wasm_f32x4_add(vin8, vacc8); + vacc9 = wasm_f32x4_add(vin9, vacc9); + vacc10 = wasm_f32x4_add(vin10, vacc10); + vacc11 = wasm_f32x4_add(vin11, vacc11); + vacc12 = wasm_f32x4_add(vin12, vacc12); + vacc13 = wasm_f32x4_add(vin13, vacc13); + vacc14 = wasm_f32x4_add(vin14, vacc14); + vacc15 = wasm_f32x4_add(vin15, vacc15); + vin0 = wasm_v128_load(&i1[0]); + vin1 = wasm_v128_load(&i1[4]); + vin2 = wasm_v128_load(&i1[8]); + vin3 = wasm_v128_load(&i1[12]); + vin4 = wasm_v128_load(&i1[16]); + vin5 = wasm_v128_load(&i1[20]); + vin6 = wasm_v128_load(&i1[24]); + vin7 = wasm_v128_load(&i1[28]); + vin8 = wasm_v128_load(&i1[32]); + vin9 = wasm_v128_load(&i1[36]); + vin10 = wasm_v128_load(&i1[40]); + vin11 = wasm_v128_load(&i1[44]); + vin12 = wasm_v128_load(&i1[48]); + vin13 = wasm_v128_load(&i1[52]); + vin14 = wasm_v128_load(&i1[56]); + vin15 = wasm_v128_load(&i1[60]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vacc4 = wasm_f32x4_add(vin4, vacc4); + vacc5 = wasm_f32x4_add(vin5, vacc5); + vacc6 = wasm_f32x4_add(vin6, vacc6); + vacc7 = wasm_f32x4_add(vin7, vacc7); + vacc8 = wasm_f32x4_add(vin8, vacc8); + vacc9 = wasm_f32x4_add(vin9, vacc9); + vacc10 = wasm_f32x4_add(vin10, vacc10); + vacc11 = wasm_f32x4_add(vin11, vacc11); + vacc12 = wasm_f32x4_add(vin12, vacc12); + vacc13 = wasm_f32x4_add(vin13, vacc13); + vacc14 = wasm_f32x4_add(vin14, vacc14); + vacc15 = wasm_f32x4_add(vin15, vacc15); + vin0 = wasm_v128_load(&i2[0]); + vin1 = wasm_v128_load(&i2[4]); + vin2 = wasm_v128_load(&i2[8]); + vin3 = wasm_v128_load(&i2[12]); + vin4 = wasm_v128_load(&i2[16]); + vin5 = wasm_v128_load(&i2[20]); + vin6 = wasm_v128_load(&i2[24]); + vin7 = wasm_v128_load(&i2[28]); + vin8 = wasm_v128_load(&i2[32]); + vin9 = wasm_v128_load(&i2[36]); + vin10 = wasm_v128_load(&i2[40]); + vin11 = wasm_v128_load(&i2[44]); + vin12 = wasm_v128_load(&i2[48]); + vin13 = wasm_v128_load(&i2[52]); + vin14 = wasm_v128_load(&i2[56]); + vin15 = wasm_v128_load(&i2[60]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + 
vacc4 = wasm_f32x4_add(vin4, vacc4); + vacc5 = wasm_f32x4_add(vin5, vacc5); + vacc6 = wasm_f32x4_add(vin6, vacc6); + vacc7 = wasm_f32x4_add(vin7, vacc7); + vacc8 = wasm_f32x4_add(vin8, vacc8); + vacc9 = wasm_f32x4_add(vin9, vacc9); + vacc10 = wasm_f32x4_add(vin10, vacc10); + vacc11 = wasm_f32x4_add(vin11, vacc11); + vacc12 = wasm_f32x4_add(vin12, vacc12); + vacc13 = wasm_f32x4_add(vin13, vacc13); + vacc14 = wasm_f32x4_add(vin14, vacc14); + vacc15 = wasm_f32x4_add(vin15, vacc15); + vin0 = wasm_v128_load(&i3[0]); + vin1 = wasm_v128_load(&i3[4]); + vin2 = wasm_v128_load(&i3[8]); + vin3 = wasm_v128_load(&i3[12]); + vin4 = wasm_v128_load(&i3[16]); + vin5 = wasm_v128_load(&i3[20]); + vin6 = wasm_v128_load(&i3[24]); + vin7 = wasm_v128_load(&i3[28]); + vin8 = wasm_v128_load(&i3[32]); + vin9 = wasm_v128_load(&i3[36]); + vin10 = wasm_v128_load(&i3[40]); + vin11 = wasm_v128_load(&i3[44]); + vin12 = wasm_v128_load(&i3[48]); + vin13 = wasm_v128_load(&i3[52]); + vin14 = wasm_v128_load(&i3[56]); + vin15 = wasm_v128_load(&i3[60]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vacc4 = wasm_f32x4_add(vin4, vacc4); + vacc5 = wasm_f32x4_add(vin5, vacc5); + vacc6 = wasm_f32x4_add(vin6, vacc6); + vacc7 = wasm_f32x4_add(vin7, vacc7); + vacc8 = wasm_f32x4_add(vin8, vacc8); + vacc9 = wasm_f32x4_add(vin9, vacc9); + vacc10 = wasm_f32x4_add(vin10, vacc10); + vacc11 = wasm_f32x4_add(vin11, vacc11); + vacc12 = wasm_f32x4_add(vin12, vacc12); + vacc13 = wasm_f32x4_add(vin13, vacc13); + vacc14 = wasm_f32x4_add(vin14, vacc14); + vacc15 = wasm_f32x4_add(vin15, vacc15); + vin0 = wasm_v128_load(&i4[0]); + vin1 = wasm_v128_load(&i4[4]); + vin2 = wasm_v128_load(&i4[8]); + vin3 = wasm_v128_load(&i4[12]); + vin4 = wasm_v128_load(&i4[16]); + vin5 = wasm_v128_load(&i4[20]); + vin6 = wasm_v128_load(&i4[24]); + vin7 = wasm_v128_load(&i4[28]); + vin8 = wasm_v128_load(&i4[32]); + vin9 = wasm_v128_load(&i4[36]); + vin10 = wasm_v128_load(&i4[40]); + vin11 = wasm_v128_load(&i4[44]); + vin12 = wasm_v128_load(&i4[48]); + vin13 = wasm_v128_load(&i4[52]); + vin14 = wasm_v128_load(&i4[56]); + vin15 = wasm_v128_load(&i4[60]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vacc4 = wasm_f32x4_add(vin4, vacc4); + vacc5 = wasm_f32x4_add(vin5, vacc5); + vacc6 = wasm_f32x4_add(vin6, vacc6); + vacc7 = wasm_f32x4_add(vin7, vacc7); + vacc8 = wasm_f32x4_add(vin8, vacc8); + vacc9 = wasm_f32x4_add(vin9, vacc9); + vacc10 = wasm_f32x4_add(vin10, vacc10); + vacc11 = wasm_f32x4_add(vin11, vacc11); + vacc12 = wasm_f32x4_add(vin12, vacc12); + vacc13 = wasm_f32x4_add(vin13, vacc13); + vacc14 = wasm_f32x4_add(vin14, vacc14); + vacc15 = wasm_f32x4_add(vin15, vacc15); + vin0 = wasm_v128_load(&i5[0]); + vin1 = wasm_v128_load(&i5[4]); + vin2 = wasm_v128_load(&i5[8]); + vin3 = wasm_v128_load(&i5[12]); + vin4 = wasm_v128_load(&i5[16]); + vin5 = wasm_v128_load(&i5[20]); + vin6 = wasm_v128_load(&i5[24]); + vin7 = wasm_v128_load(&i5[28]); + vin8 = wasm_v128_load(&i5[32]); + vin9 = wasm_v128_load(&i5[36]); + vin10 = wasm_v128_load(&i5[40]); + vin11 = wasm_v128_load(&i5[44]); + vin12 = wasm_v128_load(&i5[48]); + vin13 = wasm_v128_load(&i5[52]); + vin14 = wasm_v128_load(&i5[56]); + vin15 = wasm_v128_load(&i5[60]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = 
wasm_f32x4_add(vin3, vacc3); + vacc4 = wasm_f32x4_add(vin4, vacc4); + vacc5 = wasm_f32x4_add(vin5, vacc5); + vacc6 = wasm_f32x4_add(vin6, vacc6); + vacc7 = wasm_f32x4_add(vin7, vacc7); + vacc8 = wasm_f32x4_add(vin8, vacc8); + vacc9 = wasm_f32x4_add(vin9, vacc9); + vacc10 = wasm_f32x4_add(vin10, vacc10); + vacc11 = wasm_f32x4_add(vin11, vacc11); + vacc12 = wasm_f32x4_add(vin12, vacc12); + vacc13 = wasm_f32x4_add(vin13, vacc13); + vacc14 = wasm_f32x4_add(vin14, vacc14); + vacc15 = wasm_f32x4_add(vin15, vacc15); + vin0 = wasm_v128_load(&i6[0]); + vin1 = wasm_v128_load(&i6[4]); + vin2 = wasm_v128_load(&i6[8]); + vin3 = wasm_v128_load(&i6[12]); + vin4 = wasm_v128_load(&i6[16]); + vin5 = wasm_v128_load(&i6[20]); + vin6 = wasm_v128_load(&i6[24]); + vin7 = wasm_v128_load(&i6[28]); + vin8 = wasm_v128_load(&i6[32]); + vin9 = wasm_v128_load(&i6[36]); + vin10 = wasm_v128_load(&i6[40]); + vin11 = wasm_v128_load(&i6[44]); + vin12 = wasm_v128_load(&i6[48]); + vin13 = wasm_v128_load(&i6[52]); + vin14 = wasm_v128_load(&i6[56]); + vin15 = wasm_v128_load(&i6[60]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vacc4 = wasm_f32x4_add(vin4, vacc4); + vacc5 = wasm_f32x4_add(vin5, vacc5); + vacc6 = wasm_f32x4_add(vin6, vacc6); + vacc7 = wasm_f32x4_add(vin7, vacc7); + vacc8 = wasm_f32x4_add(vin8, vacc8); + vacc9 = wasm_f32x4_add(vin9, vacc9); + vacc10 = wasm_f32x4_add(vin10, vacc10); + vacc11 = wasm_f32x4_add(vin11, vacc11); + vacc12 = wasm_f32x4_add(vin12, vacc12); + vacc13 = wasm_f32x4_add(vin13, vacc13); + vacc14 = wasm_f32x4_add(vin14, vacc14); + vacc15 = wasm_f32x4_add(vin15, vacc15); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = wasm_f32x4_mul(vacc0, vscale); + vacc1 = wasm_f32x4_mul(vacc1, vscale); + vacc2 = wasm_f32x4_mul(vacc2, vscale); + vacc3 = wasm_f32x4_mul(vacc3, vscale); + vacc4 = wasm_f32x4_mul(vacc4, vscale); + vacc5 = wasm_f32x4_mul(vacc5, vscale); + vacc6 = wasm_f32x4_mul(vacc6, vscale); + vacc7 = wasm_f32x4_mul(vacc7, vscale); + vacc8 = wasm_f32x4_mul(vacc8, vscale); + vacc9 = wasm_f32x4_mul(vacc9, vscale); + vacc10 = wasm_f32x4_mul(vacc10, vscale); + vacc11 = wasm_f32x4_mul(vacc11, vscale); + vacc12 = wasm_f32x4_mul(vacc12, vscale); + vacc13 = wasm_f32x4_mul(vacc13, vscale); + vacc14 = wasm_f32x4_mul(vacc14, vscale); + vacc15 = wasm_f32x4_mul(vacc15, vscale); + + const float* o = output; + v128_t vo0 = wasm_v128_load(o); o += 4; + v128_t vo1 = wasm_v128_load(o); o += 4; + v128_t vo2 = wasm_v128_load(o); o += 4; + v128_t vo3 = wasm_v128_load(o); o += 4; + v128_t vo4 = wasm_v128_load(o); o += 4; + v128_t vo5 = wasm_v128_load(o); o += 4; + v128_t vo6 = wasm_v128_load(o); o += 4; + v128_t vo7 = wasm_v128_load(o); o += 4; + v128_t vo8 = wasm_v128_load(o); o += 4; + v128_t vo9 = wasm_v128_load(o); o += 4; + v128_t vo10 = wasm_v128_load(o); o += 4; + v128_t vo11 = wasm_v128_load(o); o += 4; + v128_t vo12 = wasm_v128_load(o); o += 4; + v128_t vo13 = wasm_v128_load(o); o += 4; + v128_t vo14 = wasm_v128_load(o); o += 4; + v128_t vo15 = wasm_v128_load(o); o += 4; + vacc0 = wasm_f32x4_add(vo0, vacc0); + vacc1 = 
wasm_f32x4_add(vo1, vacc1); + vacc2 = wasm_f32x4_add(vo2, vacc2); + vacc3 = wasm_f32x4_add(vo3, vacc3); + vacc4 = wasm_f32x4_add(vo4, vacc4); + vacc5 = wasm_f32x4_add(vo5, vacc5); + vacc6 = wasm_f32x4_add(vo6, vacc6); + vacc7 = wasm_f32x4_add(vo7, vacc7); + vacc8 = wasm_f32x4_add(vo8, vacc8); + vacc9 = wasm_f32x4_add(vo9, vacc9); + vacc10 = wasm_f32x4_add(vo10, vacc10); + vacc11 = wasm_f32x4_add(vo11, vacc11); + vacc12 = wasm_f32x4_add(vo12, vacc12); + vacc13 = wasm_f32x4_add(vo13, vacc13); + vacc14 = wasm_f32x4_add(vo14, vacc14); + vacc15 = wasm_f32x4_add(vo15, vacc15); + wasm_v128_store(output, vacc0); output += 4; + wasm_v128_store(output, vacc1); output += 4; + wasm_v128_store(output, vacc2); output += 4; + wasm_v128_store(output, vacc3); output += 4; + wasm_v128_store(output, vacc4); output += 4; + wasm_v128_store(output, vacc5); output += 4; + wasm_v128_store(output, vacc6); output += 4; + wasm_v128_store(output, vacc7); output += 4; + wasm_v128_store(output, vacc8); output += 4; + wasm_v128_store(output, vacc9); output += 4; + wasm_v128_store(output, vacc10); output += 4; + wasm_v128_store(output, vacc11); output += 4; + wasm_v128_store(output, vacc12); output += 4; + wasm_v128_store(output, vacc13); output += 4; + wasm_v128_store(output, vacc14); output += 4; + wasm_v128_store(output, vacc15); output += 4; + + input = (const float*) ((uintptr_t) input + 64 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + v128_t vacc[16]; + vacc[0] = wasm_i32x4_const_splat(0.f); + vacc[1] = wasm_i32x4_const_splat(0.f); + vacc[2] = wasm_i32x4_const_splat(0.f); + vacc[3] = wasm_i32x4_const_splat(0.f); + vacc[4] = wasm_i32x4_const_splat(0.f); + vacc[5] = wasm_i32x4_const_splat(0.f); + vacc[6] = wasm_i32x4_const_splat(0.f); + vacc[7] = wasm_i32x4_const_splat(0.f); + vacc[8] = wasm_i32x4_const_splat(0.f); + vacc[9] = wasm_i32x4_const_splat(0.f); + vacc[10] = wasm_i32x4_const_splat(0.f); + vacc[11] = wasm_i32x4_const_splat(0.f); + vacc[12] = wasm_i32x4_const_splat(0.f); + vacc[13] = wasm_i32x4_const_splat(0.f); + vacc[14] = wasm_i32x4_const_splat(0.f); + vacc[15] = wasm_i32x4_const_splat(0.f); + + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_chunks; ++i) { + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i0[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i1[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i2[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i3[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i4[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i5[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i6[i*4]), vacc[i]); + } + i0 = (const float*) ((uintptr_t) i0 + 
input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (int i = 0; i < num_chunks; ++i) { + vacc[i] = wasm_f32x4_mul(vacc[i], vscale); + } + + v128_t vo[16]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = wasm_v128_load(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + vacc[i] = wasm_f32x4_add(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + wasm_v128_store(output, vacc[i]); output += 4; + } + const size_t pos = channels / 4; + v128_t vout = vacc[pos]; + if (channels & 2) { + v128_t vo = wasm_f32x4_make(output[0], output[1], 0.f, 0.f); + wasm_v128_store64_lane(output, wasm_f32x4_add(vo, vout), 0); + vout = wasm_v64x2_shuffle(vout, vout, 1, 1); + output += 2; + } + if (channels & 1) { + *output += wasm_f32x4_extract_lane(vout, 0); + } + } +} diff --git a/src/f32-rdsum/wasm-simd.c.in b/src/f32-rdsum/wasm-simd.c.in new file mode 100644 index 00000000000..3b14a50e5f8 --- /dev/null +++ b/src/f32-rdsum/wasm-simd.c.in @@ -0,0 +1,123 @@ +// Copyright 2024 Google LLC // +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" +#include + +#include + +#include +#include +#include + +$UNROLL = CHANNELS >> 2 +void xnn_f32_rdsum_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__wasmsimd_c${CHANNELS}( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const v128_t vscale = wasm_v128_load32_splat(¶ms->scalar.scale); + + size_t input_increment = ${ACCUMULATORS} * input_stride; + for (; channels >= ${CHANNELS}; channels -= ${CHANNELS}) { + const float* i0 = input; + $for i in range(1, ACCUMULATORS): + const float* i${i} = (const float*) ((uintptr_t) input + ${i} * input_stride); + + $for i in range(UNROLL): + v128_t vacc${i} = wasm_i32x4_const_splat(0.f); + + for (int r = rows; r > 0; r -= ${ACCUMULATORS}) { + $for N in range(1, ACCUMULATORS, 2): + if XNN_UNPREDICTABLE(r < ${N+1}) { + i${N} = zero; + } + if XNN_UNPREDICTABLE(r <= ${N+1}) { + i${N+1} = zero; + } + $for c in range(UNROLL): + v128_t vin${c}; + $for j in range(ACCUMULATORS): + $for c in range(UNROLL): + vin${c} = wasm_v128_load(&i${j}[${c*4}]); + $for c in range(UNROLL): + vacc${c} = wasm_f32x4_add(vin${c}, vacc${c}); + $for N in range(0, ACCUMULATORS): + i${N} = (const float*) ((uintptr_t) i${N} + input_increment); + } + $for i in range(UNROLL): + vacc${i} = wasm_f32x4_mul(vacc${i}, vscale); + + const float* o = output; + $for i in range(0, UNROLL): + v128_t vo${i} = wasm_v128_load(o); o += 4; + $for i in range(0, UNROLL): + vacc${i} = wasm_f32x4_add(vo${i}, vacc${i}); + $for i in range(0, UNROLL): + wasm_v128_store(output, vacc${i}); output += 4; + + input = (const float*) ((uintptr_t) input + ${CHANNELS} * sizeof(float)); + } + if (channels != 0) { + input_increment = ${ACCUMULATORS} * input_stride; + const float* i0 = input; + $for i in range(1, ACCUMULATORS): + const float* i${i} = (const float*) 
((uintptr_t) input + ${i} * input_stride); + v128_t vacc[${UNROLL}]; + $for i in range(UNROLL): + vacc[${i}] = wasm_i32x4_const_splat(0.f); + + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + for (int r = rows; r > 0; r -= ${ACCUMULATORS}) { + $for N in range(1, ACCUMULATORS, 2): + if XNN_UNPREDICTABLE(r < ${N+1}) { + i${N} = zero; + } + if XNN_UNPREDICTABLE(r <= ${N+1}) { + i${N+1} = zero; + } + for (int i = 0; i < num_chunks; ++i) { + $for c in range(ACCUMULATORS): + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i${c}[i*4]), vacc[i]); + } + $for N in range(ACCUMULATORS): + i${N} = (const float*) ((uintptr_t) i${N} + input_increment); + } + for (int i = 0; i < num_chunks; ++i) { + vacc[i] = wasm_f32x4_mul(vacc[i], vscale); + } + + v128_t vo[${UNROLL}]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = wasm_v128_load(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + vacc[i] = wasm_f32x4_add(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + wasm_v128_store(output, vacc[i]); output += 4; + } + const size_t pos = channels / 4; + v128_t vout = vacc[pos]; + if (channels & 2) { + v128_t vo = wasm_f32x4_make(output[0], output[1], 0.f, 0.f); + wasm_v128_store64_lane(output, wasm_f32x4_add(vo, vout), 0); + vout = wasm_v64x2_shuffle(vout, vout, 1, 1); + output += 2; + } + if (channels & 1) { + *output += wasm_f32x4_extract_lane(vout, 0); + } + } +} diff --git a/src/xnnpack/gavgpool.h b/src/xnnpack/gavgpool.h index e497ed5deff..a7c6e76cc7b 100644 --- a/src/xnnpack/gavgpool.h +++ b/src/xnnpack/gavgpool.h @@ -37,7 +37,6 @@ DECLARE_F32_GAVGPOOL_MINMAX_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_gavgpool_minmax_u DECLARE_F32_GAVGPOOL_MINMAX_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_gavgpool_minmax_ukernel_7p7x__wasmsimd_arm_c4) DECLARE_F32_GAVGPOOL_MINMAX_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_gavgpool_minmax_ukernel_7p7x__wasmsimd_x86_c4) - #define DECLARE_F32_GAVGPOOL_MINMAX_UNIPASS_UKERNEL_FUNCTION(fn_name) \ XNN_INTERNAL void fn_name( \ size_t rows, \ diff --git a/src/xnnpack/reduce.h b/src/xnnpack/reduce.h index bf3668b4686..a113547928d 100644 --- a/src/xnnpack/reduce.h +++ b/src/xnnpack/reduce.h @@ -333,7 +333,16 @@ DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__scalar_c4) DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__sse_c16) DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__sse_c32) DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__sse_c64) -DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__sse_c128) +DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__avx_c16) +DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__avx_c32) +DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__avx_c64) +DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16) +DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32) +DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64) +DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16) +DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32) +DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64) +DECLARE_F32_RDSUM_UKERNEL_FUNCTION(xnn_f32_rdsum_ukernel_7p7x__avx512f_c128) #ifdef __cplusplus } // extern "C" diff --git a/test/f32-gavgpool-minmax.yaml b/test/f32-gavgpool-minmax.yaml index c6f0090d60a..21b42a11d8d 100644 --- a/test/f32-gavgpool-minmax.yaml +++ b/test/f32-gavgpool-minmax.yaml @@ -2,7 +2,7 @@ 
# # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - +# # ARM NEON - name: xnn_f32_gavgpool_minmax_ukernel_7p7x__neon_c4 init: xnn_init_f32_scaleminmax_scalar_params diff --git a/test/f32-rdsum.cc b/test/f32-rdsum.cc index 96119dd6b14..dd7b82d8183 100644 --- a/test/f32-rdsum.cc +++ b/test/f32-rdsum.cc @@ -1415,3 +1415,1776 @@ TEST(F32_RDSUM_7P7X__SCALAR_C4, channels_gt_4_multipass_fulltile_with_input_stri } } #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + TEST(F32_RDSUM_7P7X__AVX_C16, channels_eq_16_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + RDSumMicrokernelTester() + .rows(14) + .channels(16) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_eq_16_2pass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + RDSumMicrokernelTester() + .rows(14) + .channels(16) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_eq_16_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_eq_16_2pass_subtile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_eq_16_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_eq_16_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_div_16_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 32; channels < 128; channels += 16) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_div_16_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 32; channels < 128; channels += 16) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_div_16_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 32; channels < 128; channels += 16) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_div_16_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 32; channels < 128; channels += 16) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + 
.rows(rows) + .channels(channels) + .input_stride(263) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_lt_16_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 16; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_lt_16_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 16; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_lt_16_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 16; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_lt_16_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 16; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_gt_16_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 17; channels < 32; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_gt_16_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 17; channels < 32; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_gt_16_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 17; channels < 32; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C16, channels_gt_16_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 17; channels < 32; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(47) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c16, xnn_init_f32_scale_avx_params); + } + } + } +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + TEST(F32_RDSUM_7P7X__AVX_C32, channels_eq_32_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + RDSumMicrokernelTester() + .rows(14) + .channels(32) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_eq_32_2pass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + RDSumMicrokernelTester() + .rows(14) + .channels(32) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + + TEST(F32_RDSUM_7P7X__AVX_C32, 
channels_eq_32_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_eq_32_2pass_subtile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_eq_32_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_eq_32_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_div_32_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 64; channels < 256; channels += 32) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_div_32_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 64; channels < 256; channels += 32) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_div_32_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 64; channels < 256; channels += 32) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_div_32_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 64; channels < 256; channels += 32) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(521) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_lt_32_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 32; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_lt_32_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 32; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_lt_32_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 32; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, 
xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_lt_32_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 32; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_gt_32_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 33; channels < 64; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_gt_32_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 33; channels < 64; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_gt_32_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 33; channels < 64; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C32, channels_gt_32_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 33; channels < 64; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(79) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c32, xnn_init_f32_scale_avx_params); + } + } + } +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + TEST(F32_RDSUM_7P7X__AVX_C64, channels_eq_64_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + RDSumMicrokernelTester() + .rows(14) + .channels(64) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_eq_64_2pass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + RDSumMicrokernelTester() + .rows(14) + .channels(64) + .input_stride(67) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_eq_64_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(64) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_eq_64_2pass_subtile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(64) + .input_stride(67) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_eq_64_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(64) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_eq_64_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(64) + .input_stride(67) + 
.Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_div_64_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 128; channels < 512; channels += 64) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_div_64_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 128; channels < 512; channels += 64) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_div_64_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 128; channels < 512; channels += 64) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_div_64_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 128; channels < 512; channels += 64) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(1031) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_lt_64_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 64; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_lt_64_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 64; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_lt_64_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 64; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_lt_64_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 1; channels < 64; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(67) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_gt_64_2pass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 65; channels < 128; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_gt_64_2pass_subtile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 65; channels < 128; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + 
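All of the test cases above and below exercise the same contract, so a scalar restatement of it may help when reading them: the kernel sums `rows` rows of the strided input per channel, scales the sums by `params->scalar.scale`, and accumulates the result on top of the existing `output` values (visible in the wasmsimd kernel body earlier in this diff, where the accumulators are multiplied by `vscale` and added to the loaded outputs before the store). The sketch below is only a reference formulation, not XNNPACK code, and it takes the stride in floats rather than bytes for readability.

#include <stddef.h>

// Reference semantics of an f32 rdsum ukernel, per channel c:
//   output[c] += scale * sum_{r < rows} input[r * stride + c]
// (stride here is in elements; the actual kernels take input_stride in bytes.)
static void rdsum_reference(size_t rows, size_t channels,
                            const float* input, size_t stride,
                            float scale, float* output) {
  for (size_t c = 0; c < channels; c++) {
    float acc = 0.0f;
    for (size_t r = 0; r < rows; r++) {
      acc += input[r * stride + c];
    }
    output[c] += acc * scale;  // accumulate on top of the existing output
  }
}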
TEST(F32_RDSUM_7P7X__AVX_C64, channels_gt_64_multipass_fulltile) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 65; channels < 128; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX_C64, channels_gt_64_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX; + for (size_t channels = 65; channels < 128; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(149) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx_c64, xnn_init_f32_scale_avx_params); + } + } + } +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_eq_16_2pass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + RDSumMicrokernelTester() + .rows(14) + .channels(16) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_eq_16_2pass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + RDSumMicrokernelTester() + .rows(14) + .channels(16) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_eq_16_2pass_subtile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_eq_16_2pass_subtile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_eq_16_multipass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_eq_16_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_div_16_2pass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 32; channels < 128; channels += 16) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_div_16_2pass_subtile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 32; channels < 128; channels += 16) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_div_16_multipass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 32; channels < 128; channels += 16) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + 
.rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_div_16_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 32; channels < 128; channels += 16) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(263) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_lt_16_2pass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 1; channels < 16; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_lt_16_2pass_subtile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 1; channels < 16; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_lt_16_multipass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 1; channels < 16; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_lt_16_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 1; channels < 16; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_gt_16_2pass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 17; channels < 32; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_gt_16_2pass_subtile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 17; channels < 32; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_gt_16_multipass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 17; channels < 32; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C16, channels_gt_16_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 17; channels < 32; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(47) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c16, xnn_init_f32_scale_scalar_params); + } + } + } +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + 
TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_eq_32_2pass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + RDSumMicrokernelTester() + .rows(14) + .channels(32) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_eq_32_2pass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + RDSumMicrokernelTester() + .rows(14) + .channels(32) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_eq_32_2pass_subtile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_eq_32_2pass_subtile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_eq_32_multipass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_eq_32_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_div_32_2pass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 64; channels < 256; channels += 32) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_div_32_2pass_subtile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 64; channels < 256; channels += 32) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_div_32_multipass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 64; channels < 256; channels += 32) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_div_32_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 64; channels < 256; channels += 32) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(521) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_lt_32_2pass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 1; channels < 32; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, 
xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_lt_32_2pass_subtile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 1; channels < 32; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_lt_32_multipass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 1; channels < 32; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_lt_32_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 1; channels < 32; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_gt_32_2pass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 33; channels < 64; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_gt_32_2pass_subtile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 33; channels < 64; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_gt_32_multipass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 33; channels < 64; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C32, channels_gt_32_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 33; channels < 64; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(79) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c32, xnn_init_f32_scale_scalar_params); + } + } + } +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_eq_64_2pass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + RDSumMicrokernelTester() + .rows(14) + .channels(64) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_eq_64_2pass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + RDSumMicrokernelTester() + .rows(14) + .channels(64) + .input_stride(67) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_eq_64_2pass_subtile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(64) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + + 
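The rows(14) and rows <= 35 choices in these tests line up with the 7p7x pass structure: 14 rows is exactly two 7-row passes, and stepping rows by 7 up to 35 covers the multipass path. A condensed scalar sketch of that row handling follows, based on the pointer-redirect pattern visible in the wasmsimd kernel earlier in this diff (single channel, stride in elements, XNN_UNPREDICTABLE hints omitted); the real kernels keep seven live row pointers and bump them by 7 * input_stride instead of recomputing offsets.

#include <stddef.h>

// Rows are consumed seven at a time; on the final partial pass, row pointers
// past the remaining count are redirected to a zero buffer, so the 7-way
// unrolled accumulation never needs a per-row tail branch.
static float sum_rows_7p7x(size_t rows, const float* input, size_t stride,
                           const float* zero) {
  float acc = 0.0f;
  size_t row = 0;
  for (int r = (int) rows; r > 0; r -= 7) {
    const float* i[7];
    for (int k = 0; k < 7; k++) {
      // Fewer than k+1 rows left in this pass: read zeros instead.
      i[k] = (k < r) ? input + (row + (size_t) k) * stride : zero;
    }
    for (int k = 0; k < 7; k++) {
      acc += i[k][0];
    }
    row += 7;
  }
  return acc;
}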
TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_eq_64_2pass_subtile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(64) + .input_stride(67) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_eq_64_multipass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(64) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_eq_64_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(64) + .input_stride(67) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_div_64_2pass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 128; channels < 512; channels += 64) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_div_64_2pass_subtile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 128; channels < 512; channels += 64) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_div_64_multipass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 128; channels < 512; channels += 64) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_div_64_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 128; channels < 512; channels += 64) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(1031) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_lt_64_2pass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 1; channels < 64; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_lt_64_2pass_subtile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 1; channels < 64; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_lt_64_multipass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 1; channels < 64; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, 
channels_lt_64_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 1; channels < 64; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(67) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_gt_64_2pass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 65; channels < 128; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_gt_64_2pass_subtile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 65; channels < 128; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_gt_64_multipass_fulltile) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 65; channels < 128; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__AVX512F_C64, channels_gt_64_multipass_fulltile_with_input_stride) { + TEST_REQUIRES_X86_AVX512F; + for (size_t channels = 65; channels < 128; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(149) + .Test(xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, xnn_init_f32_scale_scalar_params); + } + } + } +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_eq_16_2pass_fulltile) { + RDSumMicrokernelTester() + .rows(14) + .channels(16) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_eq_16_2pass_fulltile_with_input_stride) { + RDSumMicrokernelTester() + .rows(14) + .channels(16) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_eq_16_2pass_subtile) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_eq_16_2pass_subtile_with_input_stride) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_eq_16_multipass_fulltile) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_eq_16_multipass_fulltile_with_input_stride) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(16) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + + 
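The channels_lt_* and channels_gt_* cases below drive the remainder path of the wasmsimd kernels, including the final sub-4-channel epilogue. That epilogue is easy to miss in the kernel body above, so it is restated here on its own; it mirrors the `channels & 2` / `channels & 1` blocks of the generated code, with the wrapper function and its name added only for illustration.

#include <stddef.h>
#include <wasm_simd128.h>

// Fold the last partial accumulator vector into `output` when the channel
// count is not a multiple of 4: two lanes at a time, then a single lane.
static void store_tail_channels(v128_t vout, float* output, size_t channels) {
  if (channels & 2) {
    // Load the two existing outputs, add the low lanes of vout, store 64 bits.
    const v128_t vo = wasm_f32x4_make(output[0], output[1], 0.0f, 0.0f);
    wasm_v128_store64_lane(output, wasm_f32x4_add(vo, vout), 0);
    vout = wasm_v64x2_shuffle(vout, vout, 1, 1);  // shift lanes 2..3 down
    output += 2;
  }
  if (channels & 1) {
    *output += wasm_f32x4_extract_lane(vout, 0);
  }
}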
TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_div_16_2pass_fulltile) { + for (size_t channels = 32; channels < 128; channels += 16) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_div_16_2pass_subtile) { + for (size_t channels = 32; channels < 128; channels += 16) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_div_16_multipass_fulltile) { + for (size_t channels = 32; channels < 128; channels += 16) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_div_16_multipass_fulltile_with_input_stride) { + for (size_t channels = 32; channels < 128; channels += 16) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(263) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_lt_16_2pass_fulltile) { + for (size_t channels = 1; channels < 16; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_lt_16_2pass_subtile) { + for (size_t channels = 1; channels < 16; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_lt_16_multipass_fulltile) { + for (size_t channels = 1; channels < 16; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_lt_16_multipass_fulltile_with_input_stride) { + for (size_t channels = 1; channels < 16; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(19) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_gt_16_2pass_fulltile) { + for (size_t channels = 17; channels < 32; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_gt_16_2pass_subtile) { + for (size_t channels = 17; channels < 32; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_gt_16_multipass_fulltile) { + for (size_t channels = 17; channels < 32; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + 
.channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C16, channels_gt_16_multipass_fulltile_with_input_stride) { + for (size_t channels = 17; channels < 32; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(47) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, xnn_init_f32_scale_scalar_params); + } + } + } +#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD + + +#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_eq_32_2pass_fulltile) { + RDSumMicrokernelTester() + .rows(14) + .channels(32) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_eq_32_2pass_fulltile_with_input_stride) { + RDSumMicrokernelTester() + .rows(14) + .channels(32) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_eq_32_2pass_subtile) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_eq_32_2pass_subtile_with_input_stride) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_eq_32_multipass_fulltile) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_eq_32_multipass_fulltile_with_input_stride) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(32) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_div_32_2pass_fulltile) { + for (size_t channels = 64; channels < 256; channels += 32) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_div_32_2pass_subtile) { + for (size_t channels = 64; channels < 256; channels += 32) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_div_32_multipass_fulltile) { + for (size_t channels = 64; channels < 256; channels += 32) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_div_32_multipass_fulltile_with_input_stride) { + for (size_t channels = 64; channels < 256; channels += 32) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(521) + 
.Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_lt_32_2pass_fulltile) { + for (size_t channels = 1; channels < 32; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_lt_32_2pass_subtile) { + for (size_t channels = 1; channels < 32; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_lt_32_multipass_fulltile) { + for (size_t channels = 1; channels < 32; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_lt_32_multipass_fulltile_with_input_stride) { + for (size_t channels = 1; channels < 32; channels++) { + for (size_t rows = 1; rows <= 35; rows += 7) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(37) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_gt_32_2pass_fulltile) { + for (size_t channels = 33; channels < 64; channels++) { + RDSumMicrokernelTester() + .rows(14) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_gt_32_2pass_subtile) { + for (size_t channels = 33; channels < 64; channels++) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_gt_32_multipass_fulltile) { + for (size_t channels = 33; channels < 64; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C32, channels_gt_32_multipass_fulltile_with_input_stride) { + for (size_t channels = 33; channels < 64; channels++) { + for (size_t rows = 1; rows < 35; rows += 14) { + RDSumMicrokernelTester() + .rows(rows) + .channels(channels) + .input_stride(79) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32, xnn_init_f32_scale_scalar_params); + } + } + } +#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD + + +#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD + TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_eq_64_2pass_fulltile) { + RDSumMicrokernelTester() + .rows(14) + .channels(64) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params); + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_eq_64_2pass_fulltile_with_input_stride) { + RDSumMicrokernelTester() + .rows(14) + .channels(64) + .input_stride(67) + .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params); + } + + TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_eq_64_2pass_subtile) { + for (size_t rows = 1; rows < 14; rows++) { + RDSumMicrokernelTester() + .rows(rows) + .channels(64) + 
+        .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+    }
+  }
+
+  TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_eq_64_2pass_subtile_with_input_stride) {
+    for (size_t rows = 1; rows < 14; rows++) {
+      RDSumMicrokernelTester()
+        .rows(rows)
+        .channels(64)
+        .input_stride(67)
+        .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+    }
+  }
+
+  TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_eq_64_multipass_fulltile) {
+    for (size_t rows = 1; rows <= 35; rows += 7) {
+      RDSumMicrokernelTester()
+        .rows(rows)
+        .channels(64)
+        .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+    }
+  }
+
+  TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_eq_64_multipass_fulltile_with_input_stride) {
+    for (size_t rows = 1; rows <= 35; rows += 7) {
+      RDSumMicrokernelTester()
+        .rows(rows)
+        .channels(64)
+        .input_stride(67)
+        .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+    }
+  }
+
+  TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_div_64_2pass_fulltile) {
+    for (size_t channels = 128; channels < 512; channels += 64) {
+      RDSumMicrokernelTester()
+        .rows(14)
+        .channels(channels)
+        .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+    }
+  }
+
+  TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_div_64_2pass_subtile) {
+    for (size_t channels = 128; channels < 512; channels += 64) {
+      for (size_t rows = 1; rows < 14; rows++) {
+        RDSumMicrokernelTester()
+          .rows(rows)
+          .channels(channels)
+          .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+      }
+    }
+  }
+
+  TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_div_64_multipass_fulltile) {
+    for (size_t channels = 128; channels < 512; channels += 64) {
+      for (size_t rows = 1; rows <= 35; rows += 7) {
+        RDSumMicrokernelTester()
+          .rows(rows)
+          .channels(channels)
+          .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+      }
+    }
+  }
+
+  TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_div_64_multipass_fulltile_with_input_stride) {
+    for (size_t channels = 128; channels < 512; channels += 64) {
+      for (size_t rows = 1; rows <= 35; rows += 7) {
+        RDSumMicrokernelTester()
+          .rows(rows)
+          .channels(channels)
+          .input_stride(1031)
+          .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+      }
+    }
+  }
+
+  TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_lt_64_2pass_fulltile) {
+    for (size_t channels = 1; channels < 64; channels++) {
+      RDSumMicrokernelTester()
+        .rows(14)
+        .channels(channels)
+        .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+    }
+  }
+
+  TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_lt_64_2pass_subtile) {
+    for (size_t channels = 1; channels < 64; channels++) {
+      for (size_t rows = 1; rows < 14; rows++) {
+        RDSumMicrokernelTester()
+          .rows(rows)
+          .channels(channels)
+          .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+      }
+    }
+  }
+
+  TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_lt_64_multipass_fulltile) {
+    for (size_t channels = 1; channels < 64; channels++) {
+      for (size_t rows = 1; rows <= 35; rows += 7) {
+        RDSumMicrokernelTester()
+          .rows(rows)
+          .channels(channels)
+          .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+      }
+    }
+  }
+
+  TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_lt_64_multipass_fulltile_with_input_stride) {
+    for (size_t channels = 1; channels < 64; channels++) {
+      for (size_t rows = 1; rows <= 35; rows += 7) {
+        RDSumMicrokernelTester()
+          .rows(rows)
+          .channels(channels)
+          .input_stride(67)
+          .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+      }
+    }
+  }
+
+  TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_gt_64_2pass_fulltile) {
+    for (size_t channels = 65; channels < 128; channels++) {
+      RDSumMicrokernelTester()
+        .rows(14)
+        .channels(channels)
+        .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+    }
+  }
+
+  TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_gt_64_2pass_subtile) {
+    for (size_t channels = 65; channels < 128; channels++) {
+      for (size_t rows = 1; rows < 14; rows++) {
+        RDSumMicrokernelTester()
+          .rows(rows)
+          .channels(channels)
+          .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+      }
+    }
+  }
+
+  TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_gt_64_multipass_fulltile) {
+    for (size_t channels = 65; channels < 128; channels++) {
+      for (size_t rows = 1; rows < 35; rows += 14) {
+        RDSumMicrokernelTester()
+          .rows(rows)
+          .channels(channels)
+          .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+      }
+    }
+  }
+
+  TEST(F32_RDSUM_7P7X__WASMSIMD_C64, channels_gt_64_multipass_fulltile_with_input_stride) {
+    for (size_t channels = 65; channels < 128; channels++) {
+      for (size_t rows = 1; rows < 35; rows += 14) {
+        RDSumMicrokernelTester()
+          .rows(rows)
+          .channels(channels)
+          .input_stride(149)
+          .Test(xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64, xnn_init_f32_scale_scalar_params);
+      }
+    }
+  }
+#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
diff --git a/test/f32-rdsum.yaml b/test/f32-rdsum.yaml
index e5abbd9cd8b..547593a54d1 100644
--- a/test/f32-rdsum.yaml
+++ b/test/f32-rdsum.yaml
@@ -20,3 +20,24 @@
   init: xnn_init_f32_scale_sse_params
 - name: xnn_f32_rdsum_ukernel_7p7x__sse_c64
   init: xnn_init_f32_scale_sse_params
+# x86 AVX
+- name: xnn_f32_rdsum_ukernel_7p7x__avx_c16
+  init: xnn_init_f32_scale_avx_params
+- name: xnn_f32_rdsum_ukernel_7p7x__avx_c32
+  init: xnn_init_f32_scale_avx_params
+- name: xnn_f32_rdsum_ukernel_7p7x__avx_c64
+  init: xnn_init_f32_scale_avx_params
+# x86 AVX512F
+- name: xnn_f32_rdsum_ukernel_7p7x__avx512f_c16
+  init: xnn_init_f32_scale_scalar_params
+- name: xnn_f32_rdsum_ukernel_7p7x__avx512f_c32
+  init: xnn_init_f32_scale_scalar_params
+- name: xnn_f32_rdsum_ukernel_7p7x__avx512f_c64
+  init: xnn_init_f32_scale_scalar_params
+# WAsmSIMD
+- name: xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16
+  init: xnn_init_f32_scale_scalar_params
+- name: xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32
+  init: xnn_init_f32_scale_scalar_params
+- name: xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64
+  init: xnn_init_f32_scale_scalar_params
diff --git a/test/gavgpool-microkernel-tester.h b/test/gavgpool-microkernel-tester.h
index c029f2de06b..f58878437be 100644
--- a/test/gavgpool-microkernel-tester.h
+++ b/test/gavgpool-microkernel-tester.h
@@ -598,8 +598,7 @@ class GAvgPoolMicrokernelTester {
     std::vector<float> output_ref(channels());
     for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
-      std::fill(output.begin(), output.end(), std::nanf(""));
-
+      std::fill(output.begin(), output.end(), 0.f); // kernels accumulate.
       // Compute reference results, without clamping.
       for (size_t c = 0; c < channels(); c++) {
         float acc = 0.0f;
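
Note on the tester hunk above: the output buffer is now zero-filled instead of NaN-poisoned because these reduction microkernels add their result into the destination rather than overwriting it. The same convention underlies the new rdsum tests: rows(14) exercises exactly two 7-row passes of the 7p7x kernels, the multipass cases step rows up to 35, and input_stride lets consecutive rows sit further apart than the channel count. The sketch below is an illustrative reference for what such a kernel computes, not XNNPACK API; the function name, signature, and the accumulate-into-output behavior are assumptions drawn from the "kernels accumulate" comment.

    #include <cstddef>

    // Illustrative reference (hypothetical, not an XNNPACK entry point):
    // for each of `channels` outputs, sum the matching element of every row,
    // where rows are `input_stride` floats apart (input_stride >= channels),
    // scale the sum, and accumulate into `output`, which the caller is
    // assumed to have zero-initialized first.
    static void reference_rdsum(size_t rows, size_t channels, size_t input_stride,
                                const float* input, float scale, float* output) {
      for (size_t c = 0; c < channels; c++) {
        float acc = 0.0f;
        for (size_t r = 0; r < rows; r++) {
          acc += input[r * input_stride + c];
        }
        output[c] += acc * scale;  // accumulate rather than overwrite
      }
    }

A real 7p7x kernel processes up to seven rows per pass with vector accumulators and carries the partial sums across passes, but it must agree with this per-channel scaled sum bit-for-bit within the tester's tolerance.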