Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AVX512F rdsum accumulating microkernels #6335

Merged
merged 1 commit into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 30 additions & 0 deletions bench/f32-rdsum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,36 @@ BENCHMARK_CAPTURE(f32_rsum_discontig, scalar_c4,
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64


#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f32_rsum_discontig, avx512f_c16,
xnn_f32_rdsum_ukernel_7p7x__avx512f_c16,
xnn_init_f32_scale_scalar_params,
benchmark::utils::CheckAVX512F)
->Apply(BenchmarkBatch)
->UseRealTime();
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64


#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f32_rsum_discontig, avx512f_c32,
xnn_f32_rdsum_ukernel_7p7x__avx512f_c32,
xnn_init_f32_scale_scalar_params,
benchmark::utils::CheckAVX512F)
->Apply(BenchmarkBatch)
->UseRealTime();
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64


#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f32_rsum_discontig, avx512f_c64,
xnn_f32_rdsum_ukernel_7p7x__avx512f_c64,
xnn_init_f32_scale_scalar_params,
benchmark::utils::CheckAVX512F)
->Apply(BenchmarkBatch)
->UseRealTime();
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64


#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
BENCHMARK_CAPTURE(f32_rsum_discontig, wasmsimd_c16,
xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16,
Expand Down
4 changes: 4 additions & 0 deletions cmake/microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,10 @@ SET(ALL_AVX512F_MICROKERNEL_SRCS
src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192-acc3.c
src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192-acc6.c
src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192.c
src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c16.c
src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c32.c
src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c
src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c128.c
src/f32-rminmax/gen/f32-rmax-avx512f-u16.c
src/f32-rminmax/gen/f32-rmax-avx512f-u32-acc2.c
src/f32-rminmax/gen/f32-rmax-avx512f-u48-acc3.c
Expand Down
4 changes: 4 additions & 0 deletions microkernels.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,10 @@ ALL_AVX512F_MICROKERNEL_SRCS = [
"src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192-acc3.c",
"src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192-acc6.c",
"src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192.c",
"src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c16.c",
"src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c32.c",
"src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c",
"src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c128.c",
"src/f32-rminmax/gen/f32-rmax-avx512f-u16.c",
"src/f32-rminmax/gen/f32-rmax-avx512f-u32-acc2.c",
"src/f32-rminmax/gen/f32-rmax-avx512f-u48-acc3.c",
Expand Down
6 changes: 6 additions & 0 deletions scripts/generate-f32-rdsum.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ tools/xngen src/f32-rdsum/avx.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-r
tools/xngen src/f32-rdsum/avx.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c &
tools/xngen src/f32-rdsum/avx.c.in -D CHANNELS=64 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c &

#################################### AVX512F #######$###########################
tools/xngen src/f32-rdsum/avx512.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c16.c &
tools/xngen src/f32-rdsum/avx512.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c32.c &
tools/xngen src/f32-rdsum/avx512.c.in -D CHANNELS=64 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c &
tools/xngen src/f32-rdsum/avx512.c.in -D CHANNELS=128 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c128.c &

#################################### WAsm SIMD ################################
tools/xngen src/f32-rdsum/wasm-simd.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c16.c &
tools/xngen src/f32-rdsum/wasm-simd.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c32.c &
Expand Down
134 changes: 134 additions & 0 deletions src/f32-rdsum/avx512.c.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/reduce.h>
#include <xnnpack/math.h>


$UNROLL = CHANNELS >> 4
void xnn_f32_rdsum_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__avx512f_c${CHANNELS}(
size_t rows,
size_t channels,
const float* input,
size_t input_stride,
const float* zero,
float* output,
const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)])
{
assert(rows != 0);
assert(channels != 0);
assert(input != NULL);
assert(output != NULL);

const __m512 vscale = _mm512_set1_ps(params->scalar.scale);

size_t input_increment = ${ACCUMULATORS} * input_stride;
for (; channels >= ${CHANNELS}; channels -= ${CHANNELS}) {
const float* i0 = input;
$for i in range(1, ACCUMULATORS):
const float* i${i} = (const float*) ((uintptr_t) input + ${i} * input_stride);

$for i in range(UNROLL):
__m512 vacc${i} = _mm512_setzero_ps();

for (int r = rows; r > 0; r -= ${ACCUMULATORS}) {
$for N in range(1, ACCUMULATORS, 2):
if XNN_UNPREDICTABLE(r < ${N+1}) {
i${N} = zero;
}
if XNN_UNPREDICTABLE(r <= ${N+1}) {
i${N+1} = zero;
}
$for c in range(UNROLL):
__m512 vin${c};
$for j in range(ACCUMULATORS):
$for c in range(UNROLL):
vin${c} = _mm512_loadu_ps(&i${j}[${c*16}]);
$for c in range(UNROLL):
vacc${c} = _mm512_add_ps(vin${c}, vacc${c});
$for N in range(0, ACCUMULATORS):
i${N} = (const float*) ((uintptr_t) i${N} + input_increment);
}
$for i in range(UNROLL):
vacc${i} = _mm512_mul_ps(vacc${i}, vscale);

const float* o = output;
$for i in range(0, UNROLL):
const __m512 vo${i} = _mm512_loadu_ps(o); o += 16;
$for i in range(0, UNROLL):
vacc${i} = _mm512_add_ps(vo${i}, vacc${i});
$for i in range(0, UNROLL):
_mm512_storeu_ps(output, vacc${i}); output += 16;

input = (const float*) ((uintptr_t) input + ${CHANNELS} * sizeof(float));
}
if (channels != 0) {
input_increment = ${ACCUMULATORS} * input_stride;
const float* i0 = input;
$for i in range(1, ACCUMULATORS):
const float* i${i} = (const float*) ((uintptr_t) input + ${i} * input_stride);
__m512 vacc[${UNROLL}];
$for i in range(UNROLL):
vacc[${i}] = _mm512_setzero_ps();

const size_t num_full_chunks = channels >> 4;
const size_t num_chunks = round_up_po2(channels, 16) >> 4;
const size_t remainder = channels & 0xF;
const size_t batch = channels & 0xF;
__mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1)));
if (remainder) {
assert(batch >= 1);
assert(batch <= 15);
vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1)));
}
for (int r = rows; r > 0; r -= ${ACCUMULATORS}) {
$for N in range(1, ACCUMULATORS, 2):
if XNN_UNPREDICTABLE(r < ${N+1}) {
i${N} = zero;
}
if XNN_UNPREDICTABLE(r <= ${N+1}) {
i${N+1} = zero;
}
for (int i = 0; i < num_full_chunks; ++i) {
$for c in range(ACCUMULATORS):
vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i${c}[i*16]), vacc[i]);
}

if (remainder) {
$for c in range(ACCUMULATORS):
vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i${c}[num_full_chunks*16]));
}
$for N in range(ACCUMULATORS):
i${N} = (const float*) ((uintptr_t) i${N} + input_increment);
}
for (size_t i = 0; i < num_chunks; ++i) {
vacc[i] = _mm512_mul_ps(vacc[i], vscale);
}

__m512 vo[${UNROLL}];
const float* o = output;
for (int i = 0; i < channels >> 4; ++i) {
vo[i] = _mm512_loadu_ps(o); o += 16;
}
for (int i = 0; i < channels >> 4; ++i) {
vacc[i] = _mm512_add_ps(vo[i], vacc[i]);
}
for (int i = 0; i < channels >> 4; ++i) {
_mm512_storeu_ps(output, vacc[i]); output += 16;
}
if (remainder) {
const size_t pos = num_full_chunks;
__m512 vout = vacc[pos];
vout = _mm512_maskz_add_ps(vmask, vout, _mm512_maskz_loadu_ps(vmask, output));
_mm512_mask_storeu_ps(output, vmask, vout);
}
}
}