Skip to content

Commit 62bc971

Browse files
alankelly, xnnpack-bot
authored and committed
Add AVX512F rdsum accumulating microkernels
PiperOrigin-RevId: 631354913
1 parent 2ad6f3e commit 62bc971

13 files changed

+1781
-2
lines changed

bench/f32-rdsum.cc

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -112,6 +112,36 @@ BENCHMARK_CAPTURE(f32_rsum_discontig, scalar_c4,
112112
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
113113

114114

115+
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
116+
BENCHMARK_CAPTURE(f32_rsum_discontig, avx512f_c16,
117+
xnn_f32_rdsum_ukernel_7p7x__avx512f_c16,
118+
xnn_init_f32_scale_scalar_params,
119+
benchmark::utils::CheckAVX512F)
120+
->Apply(BenchmarkBatch)
121+
->UseRealTime();
122+
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
123+
124+
125+
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
126+
BENCHMARK_CAPTURE(f32_rsum_discontig, avx512f_c32,
127+
xnn_f32_rdsum_ukernel_7p7x__avx512f_c32,
128+
xnn_init_f32_scale_scalar_params,
129+
benchmark::utils::CheckAVX512F)
130+
->Apply(BenchmarkBatch)
131+
->UseRealTime();
132+
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
133+
134+
135+
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
136+
BENCHMARK_CAPTURE(f32_rsum_discontig, avx512f_c64,
137+
xnn_f32_rdsum_ukernel_7p7x__avx512f_c64,
138+
xnn_init_f32_scale_scalar_params,
139+
benchmark::utils::CheckAVX512F)
140+
->Apply(BenchmarkBatch)
141+
->UseRealTime();
142+
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
143+
144+
115145
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
116146
BENCHMARK_CAPTURE(f32_rsum_discontig, wasmsimd_c16,
117147
xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16,

cmake/microkernels.cmake

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1324,6 +1324,10 @@ SET(ALL_AVX512F_MICROKERNEL_SRCS
13241324
src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192-acc3.c
13251325
src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192-acc6.c
13261326
src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192.c
1327+
src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c16.c
1328+
src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c32.c
1329+
src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c
1330+
src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c128.c
13271331
src/f32-rminmax/gen/f32-rmax-avx512f-u16.c
13281332
src/f32-rminmax/gen/f32-rmax-avx512f-u32-acc2.c
13291333
src/f32-rminmax/gen/f32-rmax-avx512f-u48-acc3.c

microkernels.bzl

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1324,6 +1324,10 @@ ALL_AVX512F_MICROKERNEL_SRCS = [
13241324
"src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192-acc3.c",
13251325
"src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192-acc6.c",
13261326
"src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr1-p5-scalef-u192.c",
1327+
"src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c16.c",
1328+
"src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c32.c",
1329+
"src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c",
1330+
"src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c128.c",
13271331
"src/f32-rminmax/gen/f32-rmax-avx512f-u16.c",
13281332
"src/f32-rminmax/gen/f32-rmax-avx512f-u32-acc2.c",
13291333
"src/f32-rminmax/gen/f32-rmax-avx512f-u48-acc3.c",

scripts/generate-f32-rdsum.sh

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,12 @@ tools/xngen src/f32-rdsum/avx.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-r
2222
tools/xngen src/f32-rdsum/avx.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c &
2323
tools/xngen src/f32-rdsum/avx.c.in -D CHANNELS=64 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c &
2424

25+
#################################### AVX512F ###################################
26+
tools/xngen src/f32-rdsum/avx512.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c16.c &
27+
tools/xngen src/f32-rdsum/avx512.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c32.c &
28+
tools/xngen src/f32-rdsum/avx512.c.in -D CHANNELS=64 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c &
29+
tools/xngen src/f32-rdsum/avx512.c.in -D CHANNELS=128 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c128.c &
30+
2531
#################################### WAsm SIMD ################################
2632
tools/xngen src/f32-rdsum/wasm-simd.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c16.c &
2733
tools/xngen src/f32-rdsum/wasm-simd.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c32.c &

src/f32-rdsum/avx512.c.in

Lines changed: 134 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,134 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// This source code is licensed under the BSD-style license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
6+
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
7+
#include <assert.h>
8+
9+
#include <immintrin.h>
10+
11+
#include <xnnpack/common.h>
12+
#include <xnnpack/reduce.h>
13+
#include <xnnpack/math.h>
14+
15+
16+
$UNROLL = CHANNELS >> 4
17+
// Column-wise ("rdsum") reduction template: for every channel, sums the values
// of all `rows` rows, multiplies the sum by params->scalar.scale, and ADDS the
// result to the value already stored in `output` (output is accumulated into,
// not overwritten).
//
//   rows          number of rows to reduce; asserted non-zero.
//   channels      number of float channels per row; asserted non-zero.
//   input         first row; asserted non-NULL.
//   input_stride  distance between consecutive rows IN BYTES (it is added to
//                 the pointer after a cast to uintptr_t).
//   zero          pointer substituted for row pointers that fall past the last
//                 row -- presumably a zero-filled buffer at least as long as
//                 the widest load; TODO confirm against callers.
//   output        per-channel accumulator, read and then written.
//   params        supplies the scalar scale factor.
//
// Rows are consumed ACCUMULATORS at a time; channels are consumed CHANNELS at
// a time in the main loop (UNROLL vectors of 16 floats), and the leftover
// channels (fewer than CHANNELS) are handled by the tail block.
void xnn_f32_rdsum_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__avx512f_c${CHANNELS}(
18+
size_t rows,
19+
size_t channels,
20+
const float* input,
21+
size_t input_stride,
22+
const float* zero,
23+
float* output,
24+
const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)])
25+
{
26+
assert(rows != 0);
27+
assert(channels != 0);
28+
assert(input != NULL);
29+
assert(output != NULL);
30+
31+
// Broadcast the scale factor to all 16 lanes.
const __m512 vscale = _mm512_set1_ps(params->scalar.scale);
32+
33+
// Byte offset that advances a row pointer by ACCUMULATORS rows.
size_t input_increment = ${ACCUMULATORS} * input_stride;
34+
// Main loop: one full group of CHANNELS channels per iteration.
for (; channels >= ${CHANNELS}; channels -= ${CHANNELS}) {
35+
// i0..i(ACCUMULATORS-1) point at ACCUMULATORS consecutive rows.
const float* i0 = input;
36+
$for i in range(1, ACCUMULATORS):
37+
const float* i${i} = (const float*) ((uintptr_t) input + ${i} * input_stride);
38+
39+
$for i in range(UNROLL):
40+
__m512 vacc${i} = _mm512_setzero_ps();
41+
42+
// NOTE(review): `rows` is size_t but `r` is int, so the initializer narrows;
// looks unsafe for row counts above INT_MAX -- confirm whether that can occur.
for (int r = rows; r > 0; r -= ${ACCUMULATORS}) {
43+
// When fewer than ACCUMULATORS rows remain, redirect the row pointers that
// would lie past the last row to `zero`.
$for N in range(1, ACCUMULATORS, 2):
44+
if XNN_UNPREDICTABLE(r < ${N+1}) {
45+
i${N} = zero;
46+
}
47+
if XNN_UNPREDICTABLE(r <= ${N+1}) {
48+
i${N+1} = zero;
49+
}
50+
$for c in range(UNROLL):
51+
__m512 vin${c};
52+
// For each of the ACCUMULATORS rows: load UNROLL vectors, add them in.
$for j in range(ACCUMULATORS):
53+
$for c in range(UNROLL):
54+
vin${c} = _mm512_loadu_ps(&i${j}[${c*16}]);
55+
$for c in range(UNROLL):
56+
vacc${c} = _mm512_add_ps(vin${c}, vacc${c});
57+
$for N in range(0, ACCUMULATORS):
58+
i${N} = (const float*) ((uintptr_t) i${N} + input_increment);
59+
}
60+
$for i in range(UNROLL):
61+
vacc${i} = _mm512_mul_ps(vacc${i}, vscale);
62+
63+
// Read the existing output through `o`, add, then store through `output`;
// `output` ends up advanced past this channel group.
const float* o = output;
64+
$for i in range(0, UNROLL):
65+
const __m512 vo${i} = _mm512_loadu_ps(o); o += 16;
66+
$for i in range(0, UNROLL):
67+
vacc${i} = _mm512_add_ps(vo${i}, vacc${i});
68+
$for i in range(0, UNROLL):
69+
_mm512_storeu_ps(output, vacc${i}); output += 16;
70+
71+
// Move to the next group of CHANNELS channels within the same rows.
input = (const float*) ((uintptr_t) input + ${CHANNELS} * sizeof(float));
72+
}
73+
// Tail: fewer than CHANNELS channels remain.
if (channels != 0) {
74+
// NOTE(review): reassigns the value input_increment already holds.
input_increment = ${ACCUMULATORS} * input_stride;
75+
const float* i0 = input;
76+
$for i in range(1, ACCUMULATORS):
77+
const float* i${i} = (const float*) ((uintptr_t) input + ${i} * input_stride);
78+
// channels < CHANNELS here, so every index used below (at most
// num_full_chunks) is smaller than UNROLL.
__m512 vacc[${UNROLL}];
79+
$for i in range(UNROLL):
80+
vacc[${i}] = _mm512_setzero_ps();
81+
82+
// Number of complete 16-float vectors in the tail.
const size_t num_full_chunks = channels >> 4;
83+
// Vectors including a trailing partial one -- presumably round_up_po2 rounds
// up to a multiple of 16; verify against xnnpack/math.h.
const size_t num_chunks = round_up_po2(channels, 16) >> 4;
84+
// NOTE(review): `remainder` and `batch` hold the same value (channels % 16).
const size_t remainder = channels & 0xF;
85+
const size_t batch = channels & 0xF;
86+
// Mask with the low `batch` bits set: selects the valid lanes of the final
// partial vector.
__mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1)));
87+
if (remainder) {
88+
assert(batch >= 1);
89+
assert(batch <= 15);
90+
// NOTE(review): recomputes the identical mask assigned just above.
vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1)));
91+
}
92+
// NOTE(review): same size_t -> int narrowing of `rows` as in the main loop.
for (int r = rows; r > 0; r -= ${ACCUMULATORS}) {
93+
// Redirect row pointers that lie past the last row to `zero`.
$for N in range(1, ACCUMULATORS, 2):
94+
if XNN_UNPREDICTABLE(r < ${N+1}) {
95+
i${N} = zero;
96+
}
97+
if XNN_UNPREDICTABLE(r <= ${N+1}) {
98+
i${N+1} = zero;
99+
}
100+
// Full vectors: accumulate all ACCUMULATORS rows into vacc[i].
// NOTE(review): int `i` is compared against size_t num_full_chunks.
for (int i = 0; i < num_full_chunks; ++i) {
101+
$for c in range(ACCUMULATORS):
102+
vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i${c}[i*16]), vacc[i]);
103+
}
104+
105+
// Partial vector: masked load and masked add restricted to the lanes
// selected by vmask.
if (remainder) {
106+
$for c in range(ACCUMULATORS):
107+
vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i${c}[num_full_chunks*16]));
108+
}
109+
$for N in range(ACCUMULATORS):
110+
i${N} = (const float*) ((uintptr_t) i${N} + input_increment);
111+
}
112+
// Apply the scale to every accumulated vector, including the partial one.
for (size_t i = 0; i < num_chunks; ++i) {
113+
vacc[i] = _mm512_mul_ps(vacc[i], vscale);
114+
}
115+
116+
// Accumulate the full vectors into output: load all, add, then store.
__m512 vo[${UNROLL}];
117+
const float* o = output;
118+
for (int i = 0; i < channels >> 4; ++i) {
119+
vo[i] = _mm512_loadu_ps(o); o += 16;
120+
}
121+
for (int i = 0; i < channels >> 4; ++i) {
122+
vacc[i] = _mm512_add_ps(vo[i], vacc[i]);
123+
}
124+
for (int i = 0; i < channels >> 4; ++i) {
125+
_mm512_storeu_ps(output, vacc[i]); output += 16;
126+
}
127+
// Final partial vector: masked load, add and store use only the lanes
// selected by vmask (the first `batch` floats at `output`).
if (remainder) {
128+
const size_t pos = num_full_chunks;
129+
__m512 vout = vacc[pos];
130+
vout = _mm512_maskz_add_ps(vmask, vout, _mm512_maskz_loadu_ps(vmask, output));
131+
_mm512_mask_storeu_ps(output, vmask, vout);
132+
}
133+
}
134+
}

0 commit comments

Comments
 (0)