Skip to content

Commit 2ad6f3e

Browse files
alankellyxnnpack-bot
authored andcommitted
Add WAsmSIMD rdsum accumulating microkernels
PiperOrigin-RevId: 631333817
1 parent ec3d2b9 commit 2ad6f3e

13 files changed

+1771
-7
lines changed

bench/f32-rdsum.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,33 @@ BENCHMARK_CAPTURE(f32_rsum_discontig, scalar_c4,
112112
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
113113

114114

115+
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
116+
BENCHMARK_CAPTURE(f32_rsum_discontig, wasmsimd_c16,
117+
xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16,
118+
xnn_init_f32_scale_scalar_params)
119+
->Apply(BenchmarkBatch)
120+
->UseRealTime();
121+
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
122+
123+
124+
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
125+
BENCHMARK_CAPTURE(f32_rsum_discontig, wasmsimd_c32,
126+
xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c32,
127+
xnn_init_f32_scale_scalar_params)
128+
->Apply(BenchmarkBatch)
129+
->UseRealTime();
130+
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
131+
132+
133+
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
134+
BENCHMARK_CAPTURE(f32_rsum_discontig, wasmsimd_c64,
135+
xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c64,
136+
xnn_init_f32_scale_scalar_params)
137+
->Apply(BenchmarkBatch)
138+
->UseRealTime();
139+
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
140+
141+
115142
#ifndef XNNPACK_BENCHMARK_NO_MAIN
116143
BENCHMARK_MAIN();
117144
#endif

cmake/microkernels.cmake

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8916,6 +8916,9 @@ SET(ALL_WASMSIMD_MICROKERNEL_SRCS
89168916
src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-wasmsimd-rr2-p5-u20-acc2.c
89178917
src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-wasmsimd-rr2-p5-u20-acc5.c
89188918
src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-wasmsimd-rr2-p5-u20.c
8919+
src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c16.c
8920+
src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c32.c
8921+
src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c64.c
89198922
src/f32-rminmax/gen/f32-rmax-wasmsimd-minmax-u4.c
89208923
src/f32-rminmax/gen/f32-rmax-wasmsimd-minmax-u8-acc2.c
89218924
src/f32-rminmax/gen/f32-rmax-wasmsimd-minmax-u12-acc3.c

microkernels.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8951,6 +8951,9 @@ ALL_WASMSIMD_MICROKERNEL_SRCS = [
89518951
"src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-wasmsimd-rr2-p5-u20-acc2.c",
89528952
"src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-wasmsimd-rr2-p5-u20-acc5.c",
89538953
"src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-wasmsimd-rr2-p5-u20.c",
8954+
"src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c16.c",
8955+
"src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c32.c",
8956+
"src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c64.c",
89548957
"src/f32-rminmax/gen/f32-rmax-wasmsimd-minmax-u4.c",
89558958
"src/f32-rminmax/gen/f32-rmax-wasmsimd-minmax-u8-acc2.c",
89568959
"src/f32-rminmax/gen/f32-rmax-wasmsimd-minmax-u12-acc3.c",

scripts/generate-f32-rdsum.sh

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,24 @@
77
#################################### Scalar ###################################
88
tools/xngen src/f32-rdsum/scalar.c.in -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-scalar.c &
99

10-
#################################### NEON ###################################
10+
#################################### NEON #####################################
1111
tools/xngen src/f32-rdsum/neon.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-neon-c16.c &
1212
tools/xngen src/f32-rdsum/neon.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-neon-c32.c &
1313
tools/xngen src/f32-rdsum/neon.c.in -D CHANNELS=64 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-neon-c64.c &
1414

15-
#################################### SSE ####################################
15+
#################################### SSE ######################################
1616
tools/xngen src/f32-rdsum/sse.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-sse-c16.c &
1717
tools/xngen src/f32-rdsum/sse.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-sse-c32.c &
1818
tools/xngen src/f32-rdsum/sse.c.in -D CHANNELS=64 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-sse-c64.c &
1919

20-
#################################### AVX ####################################
20+
#################################### AVX ######################################
2121
tools/xngen src/f32-rdsum/avx.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c16.c &
2222
tools/xngen src/f32-rdsum/avx.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c32.c &
2323
tools/xngen src/f32-rdsum/avx.c.in -D CHANNELS=64 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-c64.c &
2424

25+
#################################### WAsm SIMD ################################
26+
tools/xngen src/f32-rdsum/wasm-simd.c.in -D CHANNELS=16 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c16.c &
27+
tools/xngen src/f32-rdsum/wasm-simd.c.in -D CHANNELS=32 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c32.c &
28+
tools/xngen src/f32-rdsum/wasm-simd.c.in -D CHANNELS=64 -D ACCUMULATORS=7 -o src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-c64.c &
29+
2530
wait
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
// Auto-generated file. Do not edit!
2+
// Template: src/f32-rdsum/wasm-simd.c.in
3+
// Generator: tools/xngen
4+
//
5+
// Copyright 2024 Google LLC //
6+
// This source code is licensed under the BSD-style license found in the
7+
// LICENSE file in the root directory of this source tree.
8+
9+
#include <assert.h>
10+
11+
#include <wasm_simd128.h>
12+
13+
#include <xnnpack/common.h>
14+
#include <xnnpack/reduce.h>
15+
#include <xnnpack/math.h>
16+
17+
void xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16(
18+
size_t rows,
19+
size_t channels,
20+
const float* input,
21+
size_t input_stride,
22+
const float* zero,
23+
float* output,
24+
const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)])
25+
{
26+
assert(rows != 0);
27+
assert(channels != 0);
28+
assert(input != NULL);
29+
assert(output != NULL);
30+
31+
const v128_t vscale = wasm_v128_load32_splat(&params->scalar.scale);
32+
33+
size_t input_increment = 7 * input_stride;
34+
for (; channels >= 16; channels -= 16) {
35+
const float* i0 = input;
36+
const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride);
37+
const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride);
38+
const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride);
39+
const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride);
40+
const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride);
41+
const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride);
42+
43+
v128_t vacc0 = wasm_i32x4_const_splat(0.f);
44+
v128_t vacc1 = wasm_i32x4_const_splat(0.f);
45+
v128_t vacc2 = wasm_i32x4_const_splat(0.f);
46+
v128_t vacc3 = wasm_i32x4_const_splat(0.f);
47+
48+
for (int r = rows; r > 0; r -= 7) {
49+
if XNN_UNPREDICTABLE(r < 2) {
50+
i1 = zero;
51+
}
52+
if XNN_UNPREDICTABLE(r <= 2) {
53+
i2 = zero;
54+
}
55+
if XNN_UNPREDICTABLE(r < 4) {
56+
i3 = zero;
57+
}
58+
if XNN_UNPREDICTABLE(r <= 4) {
59+
i4 = zero;
60+
}
61+
if XNN_UNPREDICTABLE(r < 6) {
62+
i5 = zero;
63+
}
64+
if XNN_UNPREDICTABLE(r <= 6) {
65+
i6 = zero;
66+
}
67+
v128_t vin0;
68+
v128_t vin1;
69+
v128_t vin2;
70+
v128_t vin3;
71+
vin0 = wasm_v128_load(&i0[0]);
72+
vin1 = wasm_v128_load(&i0[4]);
73+
vin2 = wasm_v128_load(&i0[8]);
74+
vin3 = wasm_v128_load(&i0[12]);
75+
vacc0 = wasm_f32x4_add(vin0, vacc0);
76+
vacc1 = wasm_f32x4_add(vin1, vacc1);
77+
vacc2 = wasm_f32x4_add(vin2, vacc2);
78+
vacc3 = wasm_f32x4_add(vin3, vacc3);
79+
vin0 = wasm_v128_load(&i1[0]);
80+
vin1 = wasm_v128_load(&i1[4]);
81+
vin2 = wasm_v128_load(&i1[8]);
82+
vin3 = wasm_v128_load(&i1[12]);
83+
vacc0 = wasm_f32x4_add(vin0, vacc0);
84+
vacc1 = wasm_f32x4_add(vin1, vacc1);
85+
vacc2 = wasm_f32x4_add(vin2, vacc2);
86+
vacc3 = wasm_f32x4_add(vin3, vacc3);
87+
vin0 = wasm_v128_load(&i2[0]);
88+
vin1 = wasm_v128_load(&i2[4]);
89+
vin2 = wasm_v128_load(&i2[8]);
90+
vin3 = wasm_v128_load(&i2[12]);
91+
vacc0 = wasm_f32x4_add(vin0, vacc0);
92+
vacc1 = wasm_f32x4_add(vin1, vacc1);
93+
vacc2 = wasm_f32x4_add(vin2, vacc2);
94+
vacc3 = wasm_f32x4_add(vin3, vacc3);
95+
vin0 = wasm_v128_load(&i3[0]);
96+
vin1 = wasm_v128_load(&i3[4]);
97+
vin2 = wasm_v128_load(&i3[8]);
98+
vin3 = wasm_v128_load(&i3[12]);
99+
vacc0 = wasm_f32x4_add(vin0, vacc0);
100+
vacc1 = wasm_f32x4_add(vin1, vacc1);
101+
vacc2 = wasm_f32x4_add(vin2, vacc2);
102+
vacc3 = wasm_f32x4_add(vin3, vacc3);
103+
vin0 = wasm_v128_load(&i4[0]);
104+
vin1 = wasm_v128_load(&i4[4]);
105+
vin2 = wasm_v128_load(&i4[8]);
106+
vin3 = wasm_v128_load(&i4[12]);
107+
vacc0 = wasm_f32x4_add(vin0, vacc0);
108+
vacc1 = wasm_f32x4_add(vin1, vacc1);
109+
vacc2 = wasm_f32x4_add(vin2, vacc2);
110+
vacc3 = wasm_f32x4_add(vin3, vacc3);
111+
vin0 = wasm_v128_load(&i5[0]);
112+
vin1 = wasm_v128_load(&i5[4]);
113+
vin2 = wasm_v128_load(&i5[8]);
114+
vin3 = wasm_v128_load(&i5[12]);
115+
vacc0 = wasm_f32x4_add(vin0, vacc0);
116+
vacc1 = wasm_f32x4_add(vin1, vacc1);
117+
vacc2 = wasm_f32x4_add(vin2, vacc2);
118+
vacc3 = wasm_f32x4_add(vin3, vacc3);
119+
vin0 = wasm_v128_load(&i6[0]);
120+
vin1 = wasm_v128_load(&i6[4]);
121+
vin2 = wasm_v128_load(&i6[8]);
122+
vin3 = wasm_v128_load(&i6[12]);
123+
vacc0 = wasm_f32x4_add(vin0, vacc0);
124+
vacc1 = wasm_f32x4_add(vin1, vacc1);
125+
vacc2 = wasm_f32x4_add(vin2, vacc2);
126+
vacc3 = wasm_f32x4_add(vin3, vacc3);
127+
i0 = (const float*) ((uintptr_t) i0 + input_increment);
128+
i1 = (const float*) ((uintptr_t) i1 + input_increment);
129+
i2 = (const float*) ((uintptr_t) i2 + input_increment);
130+
i3 = (const float*) ((uintptr_t) i3 + input_increment);
131+
i4 = (const float*) ((uintptr_t) i4 + input_increment);
132+
i5 = (const float*) ((uintptr_t) i5 + input_increment);
133+
i6 = (const float*) ((uintptr_t) i6 + input_increment);
134+
}
135+
vacc0 = wasm_f32x4_mul(vacc0, vscale);
136+
vacc1 = wasm_f32x4_mul(vacc1, vscale);
137+
vacc2 = wasm_f32x4_mul(vacc2, vscale);
138+
vacc3 = wasm_f32x4_mul(vacc3, vscale);
139+
140+
const float* o = output;
141+
v128_t vo0 = wasm_v128_load(o); o += 4;
142+
v128_t vo1 = wasm_v128_load(o); o += 4;
143+
v128_t vo2 = wasm_v128_load(o); o += 4;
144+
v128_t vo3 = wasm_v128_load(o); o += 4;
145+
vacc0 = wasm_f32x4_add(vo0, vacc0);
146+
vacc1 = wasm_f32x4_add(vo1, vacc1);
147+
vacc2 = wasm_f32x4_add(vo2, vacc2);
148+
vacc3 = wasm_f32x4_add(vo3, vacc3);
149+
wasm_v128_store(output, vacc0); output += 4;
150+
wasm_v128_store(output, vacc1); output += 4;
151+
wasm_v128_store(output, vacc2); output += 4;
152+
wasm_v128_store(output, vacc3); output += 4;
153+
154+
input = (const float*) ((uintptr_t) input + 16 * sizeof(float));
155+
}
156+
if (channels != 0) {
157+
input_increment = 7 * input_stride;
158+
const float* i0 = input;
159+
const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride);
160+
const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride);
161+
const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride);
162+
const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride);
163+
const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride);
164+
const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride);
165+
v128_t vacc[4];
166+
vacc[0] = wasm_i32x4_const_splat(0.f);
167+
vacc[1] = wasm_i32x4_const_splat(0.f);
168+
vacc[2] = wasm_i32x4_const_splat(0.f);
169+
vacc[3] = wasm_i32x4_const_splat(0.f);
170+
171+
const size_t num_chunks = round_up_po2(channels, 4) >> 2;
172+
for (int r = rows; r > 0; r -= 7) {
173+
if XNN_UNPREDICTABLE(r < 2) {
174+
i1 = zero;
175+
}
176+
if XNN_UNPREDICTABLE(r <= 2) {
177+
i2 = zero;
178+
}
179+
if XNN_UNPREDICTABLE(r < 4) {
180+
i3 = zero;
181+
}
182+
if XNN_UNPREDICTABLE(r <= 4) {
183+
i4 = zero;
184+
}
185+
if XNN_UNPREDICTABLE(r < 6) {
186+
i5 = zero;
187+
}
188+
if XNN_UNPREDICTABLE(r <= 6) {
189+
i6 = zero;
190+
}
191+
for (int i = 0; i < num_chunks; ++i) {
192+
vacc[i] = wasm_f32x4_add(wasm_v128_load(&i0[i*4]), vacc[i]);
193+
vacc[i] = wasm_f32x4_add(wasm_v128_load(&i1[i*4]), vacc[i]);
194+
vacc[i] = wasm_f32x4_add(wasm_v128_load(&i2[i*4]), vacc[i]);
195+
vacc[i] = wasm_f32x4_add(wasm_v128_load(&i3[i*4]), vacc[i]);
196+
vacc[i] = wasm_f32x4_add(wasm_v128_load(&i4[i*4]), vacc[i]);
197+
vacc[i] = wasm_f32x4_add(wasm_v128_load(&i5[i*4]), vacc[i]);
198+
vacc[i] = wasm_f32x4_add(wasm_v128_load(&i6[i*4]), vacc[i]);
199+
}
200+
i0 = (const float*) ((uintptr_t) i0 + input_increment);
201+
i1 = (const float*) ((uintptr_t) i1 + input_increment);
202+
i2 = (const float*) ((uintptr_t) i2 + input_increment);
203+
i3 = (const float*) ((uintptr_t) i3 + input_increment);
204+
i4 = (const float*) ((uintptr_t) i4 + input_increment);
205+
i5 = (const float*) ((uintptr_t) i5 + input_increment);
206+
i6 = (const float*) ((uintptr_t) i6 + input_increment);
207+
}
208+
for (int i = 0; i < num_chunks; ++i) {
209+
vacc[i] = wasm_f32x4_mul(vacc[i], vscale);
210+
}
211+
212+
v128_t vo[4];
213+
const float* o = output;
214+
for (int i = 0; i < channels >> 2; ++i) {
215+
vo[i] = wasm_v128_load(o); o += 4;
216+
}
217+
for (int i = 0; i < channels >> 2; ++i) {
218+
vacc[i] = wasm_f32x4_add(vo[i], vacc[i]);
219+
}
220+
for (int i = 0; i < channels >> 2; ++i) {
221+
wasm_v128_store(output, vacc[i]); output += 4;
222+
}
223+
const size_t pos = channels / 4;
224+
v128_t vout = vacc[pos];
225+
if (channels & 2) {
226+
v128_t vo = wasm_f32x4_make(output[0], output[1], 0.f, 0.f);
227+
wasm_v128_store64_lane(output, wasm_f32x4_add(vo, vout), 0);
228+
vout = wasm_v64x2_shuffle(vout, vout, 1, 1);
229+
output += 2;
230+
}
231+
if (channels & 1) {
232+
*output += wasm_f32x4_extract_lane(vout, 0);
233+
}
234+
}
235+
}

0 commit comments

Comments
 (0)