Skip to content

Commit 385e9f9

Browse files
fbarchard authored and xnnpack-bot committed
neon mlal qs8 rsum use addw instead of mlal
- remove vone which was being lengthened with mlal and then multiplied by the input
- use sliced accumulators for 16 bit accumulation
- rename functions from neon_mlal to neon_addw
- add const to mask variable in remainder handler for addw and neondot microkernels

PiperOrigin-RevId: 634595235
1 parent aee04e1 commit 385e9f9

23 files changed

+335
-321
lines changed

bench/qs8-rsum-minmax-fp32.cc

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,
3838
->UseRealTime();
3939

4040
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
41-
BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u16,
42-
xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16,
41+
BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u16,
42+
xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u16,
4343
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
4444
benchmark::utils::CheckNEON)
4545
->Apply(BenchmarkRSUM)
@@ -48,8 +48,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,
4848

4949

5050
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
51-
BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u32,
52-
xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32,
51+
BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u32,
52+
xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u32,
5353
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
5454
benchmark::utils::CheckNEON)
5555
->Apply(BenchmarkRSUM)
@@ -58,8 +58,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,
5858

5959

6060
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
61-
BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u64,
62-
xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u64,
61+
BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u64,
62+
xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u64,
6363
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
6464
benchmark::utils::CheckNEON)
6565
->Apply(BenchmarkRSUM)
@@ -68,8 +68,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,
6868

6969

7070
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
71-
BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u16_acc2,
72-
xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16_acc2,
71+
BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u16_acc2,
72+
xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u16_acc2,
7373
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
7474
benchmark::utils::CheckNEON)
7575
->Apply(BenchmarkRSUM)
@@ -78,8 +78,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,
7878

7979

8080
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
81-
BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u32_acc2,
82-
xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc2,
81+
BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u32_acc2,
82+
xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u32_acc2,
8383
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
8484
benchmark::utils::CheckNEON)
8585
->Apply(BenchmarkRSUM)
@@ -88,8 +88,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,
8888

8989

9090
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
91-
BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u64_acc2,
92-
xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u64_acc2,
91+
BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u64_acc2,
92+
xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u64_acc2,
9393
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
9494
benchmark::utils::CheckNEON)
9595
->Apply(BenchmarkRSUM)
@@ -98,8 +98,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,
9898

9999

100100
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
101-
BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u32_acc4,
102-
xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc4,
101+
BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u32_acc4,
102+
xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u32_acc4,
103103
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
104104
benchmark::utils::CheckNEON)
105105
->Apply(BenchmarkRSUM)
@@ -108,8 +108,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,
108108

109109

110110
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
111-
BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u64_acc4,
112-
xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u64_acc4,
111+
BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u64_acc4,
112+
xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u64_acc4,
113113
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
114114
benchmark::utils::CheckNEON)
115115
->Apply(BenchmarkRSUM)

cmake/gen/neon_microkernels.cmake

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -684,14 +684,14 @@ SET(ALL_NEON_MICROKERNEL_SRCS
684684
src/qs8-requantization/qs8-requantization-rndna-neon.c
685685
src/qs8-requantization/qs8-requantization-rndnu-neon-mull.c
686686
src/qs8-requantization/qs8-requantization-rndnu-neon-qdmulh.c
687-
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16-acc2.c
688-
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16.c
689-
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32-acc2.c
690-
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32-acc4.c
691-
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32.c
692-
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64-acc2.c
693-
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64-acc4.c
694-
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64.c
687+
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16-acc2.c
688+
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16.c
689+
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32-acc2.c
690+
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32-acc4.c
691+
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32.c
692+
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64-acc2.c
693+
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64-acc4.c
694+
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64.c
695695
src/qs8-vadd/gen/qs8-vadd-minmax-neon-ld64-u8.c
696696
src/qs8-vadd/gen/qs8-vadd-minmax-neon-ld64-u16.c
697697
src/qs8-vadd/gen/qs8-vadd-minmax-neon-ld64-u24.c

gen/neon_microkernels.bzl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -680,14 +680,14 @@ ALL_NEON_MICROKERNEL_SRCS = [
680680
"src/qs8-requantization/qs8-requantization-rndna-neon.c",
681681
"src/qs8-requantization/qs8-requantization-rndnu-neon-mull.c",
682682
"src/qs8-requantization/qs8-requantization-rndnu-neon-qdmulh.c",
683-
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16-acc2.c",
684-
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16.c",
685-
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32-acc2.c",
686-
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32-acc4.c",
687-
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32.c",
688-
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64-acc2.c",
689-
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64-acc4.c",
690-
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64.c",
683+
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16-acc2.c",
684+
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16.c",
685+
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32-acc2.c",
686+
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32-acc4.c",
687+
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32.c",
688+
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64-acc2.c",
689+
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64-acc4.c",
690+
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64.c",
691691
"src/qs8-vadd/gen/qs8-vadd-minmax-neon-ld64-u8.c",
692692
"src/qs8-vadd/gen/qs8-vadd-minmax-neon-ld64-u16.c",
693693
"src/qs8-vadd/gen/qs8-vadd-minmax-neon-ld64-u24.c",

scripts/generate-qs8-rsum.sh

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,16 @@ tools/xngen src/qs8-rsum/scalar.c.in -D CHANNEL_TILE=2 -D ACCUMULATORS=1 -D REQU
1010
tools/xngen src/qs8-rsum/scalar.c.in -D CHANNEL_TILE=4 -D ACCUMULATORS=1 -D REQUANTIZATION=FP32 -D VARIANT=IMAGIC -D WASM=0 -o src/qs8-rsum/gen/qs8-rdsum-minmax-fp32-scalar-imagic-u4-acc1.c &
1111

1212
################################## ARM NEON ###################################
13-
tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=16 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16.c &
14-
tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=32 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32.c &
15-
tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=64 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64.c &
13+
tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=16 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16.c &
14+
tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=32 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32.c &
15+
tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=64 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64.c &
1616

17-
tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=2 -D CHANNEL_TILE=16 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16-acc2.c &
18-
tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=2 -D CHANNEL_TILE=32 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32-acc2.c &
19-
tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=2 -D CHANNEL_TILE=64 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64-acc2.c &
17+
tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=2 -D CHANNEL_TILE=16 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16-acc2.c &
18+
tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=2 -D CHANNEL_TILE=32 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32-acc2.c &
19+
tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=2 -D CHANNEL_TILE=64 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64-acc2.c &
2020

21-
tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=4 -D CHANNEL_TILE=32 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32-acc4.c &
22-
tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=4 -D CHANNEL_TILE=64 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64-acc4.c &
21+
tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=4 -D CHANNEL_TILE=32 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32-acc4.c &
22+
tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=4 -D CHANNEL_TILE=64 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64-acc4.c &
2323

2424
tools/xngen src/qs8-rsum/neondot.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=16 -D REQUANTIZATION=FP32 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neondot-u16.c &
2525
tools/xngen src/qs8-rsum/neondot.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=32 -D REQUANTIZATION=FP32 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neondot-u32.c &

src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16-acc2.c renamed to src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16-acc2.c

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// Auto-generated file. Do not edit!
2-
// Template: src/qs8-rsum/neon-mlal.c.in
2+
// Template: src/qs8-rsum/neon-addw.c.in
33
// Generator: tools/xngen
44
//
55
// Copyright 2024 Google LLC
@@ -15,7 +15,7 @@
1515
#include <xnnpack/math.h>
1616
#include <xnnpack/reduce.h>
1717

18-
void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16_acc2(
18+
void xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u16_acc2(
1919
size_t batch,
2020
const int8_t* input,
2121
int8_t* output,
@@ -25,7 +25,8 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16_acc2(
2525
assert(input != NULL);
2626
assert(output != NULL);
2727

28-
int8x8_t vone = vdup_n_s8(1);
28+
// 256 int8s may be summed into an int16 before overflowing.
29+
// There are 8 lanes in the accumulator register and 2 registers.
2930
int num_batches = batch >> 9;
3031
int32x4_t vacc0 = vmovq_n_s32(0);
3132
int32x4_t vacc1 = vmovq_n_s32(0);
@@ -36,8 +37,8 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16_acc2(
3637
const int8x8_t vt0 = vld1_s8(input); input += 8;
3738
const int8x8_t vt1 = vld1_s8(input); input += 8;
3839

39-
vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
40-
vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
40+
vacc16_0 = vaddw_s8(vacc16_0, vt0);
41+
vacc16_1 = vaddw_s8(vacc16_1, vt1);
4142
}
4243
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
4344
vacc1 = vaddq_s32(vacc1, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_1)), vmovl_s16(vget_high_s16(vacc16_1))));
@@ -49,18 +50,18 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16_acc2(
4950
for (; batch >= 16; batch -= 16) {
5051
const int8x8_t vt0 = vld1_s8(input); input += 8;
5152
const int8x8_t vt1 = vld1_s8(input); input += 8;
52-
vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
53-
vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
53+
vacc16_0 = vaddw_s8(vacc16_0, vt0);
54+
vacc16_1 = vaddw_s8(vacc16_1, vt1);
5455
}
5556
vacc16_0 = vaddq_s16(vacc16_0, vacc16_1);
5657
for (; batch >= 8; batch -= 8) {
5758
const int8x8_t vt = vld1_s8(input); input += 8;
58-
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
59+
vacc16_0 = vaddw_s8(vacc16_0, vt);
5960
}
6061
if (XNN_UNLIKELY(batch != 0)) {
61-
int8x8_t vt = vld1_s8(input);
62-
vone = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
63-
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
62+
const int8x8_t vt = vld1_s8(input);
63+
const int8x8_t vmask = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
64+
vacc16_0 = vmlal_s8(vacc16_0, vt, vmask);
6465
}
6566
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
6667
}

src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16.c renamed to src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16.c

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// Auto-generated file. Do not edit!
2-
// Template: src/qs8-rsum/neon-mlal.c.in
2+
// Template: src/qs8-rsum/neon-addw.c.in
33
// Generator: tools/xngen
44
//
55
// Copyright 2024 Google LLC
@@ -15,7 +15,7 @@
1515
#include <xnnpack/math.h>
1616
#include <xnnpack/reduce.h>
1717

18-
void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16(
18+
void xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u16(
1919
size_t batch,
2020
const int8_t* input,
2121
int8_t* output,
@@ -25,7 +25,8 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16(
2525
assert(input != NULL);
2626
assert(output != NULL);
2727

28-
int8x8_t vone = vdup_n_s8(1);
28+
// 256 int8s may be summed into an int16 before overflowing.
29+
// There are 8 lanes in the accumulator register and 1 registers.
2930
int num_batches = batch >> 8;
3031
int32x4_t vacc0 = vmovq_n_s32(0);
3132
for (; num_batches > 0; --num_batches) {
@@ -34,8 +35,8 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16(
3435
const int8x8_t vt0 = vld1_s8(input); input += 8;
3536
const int8x8_t vt1 = vld1_s8(input); input += 8;
3637

37-
vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
38-
vacc16_0 = vmlal_s8(vacc16_0, vt1, vone);
38+
vacc16_0 = vaddw_s8(vacc16_0, vt0);
39+
vacc16_0 = vaddw_s8(vacc16_0, vt1);
3940
}
4041
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
4142
batch -= 256;
@@ -45,17 +46,17 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16(
4546
for (; batch >= 16; batch -= 16) {
4647
const int8x8_t vt0 = vld1_s8(input); input += 8;
4748
const int8x8_t vt1 = vld1_s8(input); input += 8;
48-
vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
49-
vacc16_0 = vmlal_s8(vacc16_0, vt1, vone);
49+
vacc16_0 = vaddw_s8(vacc16_0, vt0);
50+
vacc16_0 = vaddw_s8(vacc16_0, vt1);
5051
}
5152
for (; batch >= 8; batch -= 8) {
5253
const int8x8_t vt = vld1_s8(input); input += 8;
53-
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
54+
vacc16_0 = vaddw_s8(vacc16_0, vt);
5455
}
5556
if (XNN_UNLIKELY(batch != 0)) {
56-
int8x8_t vt = vld1_s8(input);
57-
vone = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
58-
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
57+
const int8x8_t vt = vld1_s8(input);
58+
const int8x8_t vmask = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
59+
vacc16_0 = vmlal_s8(vacc16_0, vt, vmask);
5960
}
6061
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
6162
}

src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32-acc2.c renamed to src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32-acc2.c

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// Auto-generated file. Do not edit!
2-
// Template: src/qs8-rsum/neon-mlal.c.in
2+
// Template: src/qs8-rsum/neon-addw.c.in
33
// Generator: tools/xngen
44
//
55
// Copyright 2024 Google LLC
@@ -15,7 +15,7 @@
1515
#include <xnnpack/math.h>
1616
#include <xnnpack/reduce.h>
1717

18-
void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc2(
18+
void xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u32_acc2(
1919
size_t batch,
2020
const int8_t* input,
2121
int8_t* output,
@@ -25,7 +25,8 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc2(
2525
assert(input != NULL);
2626
assert(output != NULL);
2727

28-
int8x8_t vone = vdup_n_s8(1);
28+
// 256 int8s may be summed into an int16 before overflowing.
29+
// There are 8 lanes in the accumulator register and 2 registers.
2930
int num_batches = batch >> 9;
3031
int32x4_t vacc0 = vmovq_n_s32(0);
3132
int32x4_t vacc1 = vmovq_n_s32(0);
@@ -38,10 +39,10 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc2(
3839
const int8x8_t vt2 = vld1_s8(input); input += 8;
3940
const int8x8_t vt3 = vld1_s8(input); input += 8;
4041

41-
vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
42-
vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
43-
vacc16_0 = vmlal_s8(vacc16_0, vt2, vone);
44-
vacc16_1 = vmlal_s8(vacc16_1, vt3, vone);
42+
vacc16_0 = vaddw_s8(vacc16_0, vt0);
43+
vacc16_1 = vaddw_s8(vacc16_1, vt1);
44+
vacc16_0 = vaddw_s8(vacc16_0, vt2);
45+
vacc16_1 = vaddw_s8(vacc16_1, vt3);
4546
}
4647
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
4748
vacc1 = vaddq_s32(vacc1, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_1)), vmovl_s16(vget_high_s16(vacc16_1))));
@@ -55,20 +56,20 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc2(
5556
const int8x8_t vt1 = vld1_s8(input); input += 8;
5657
const int8x8_t vt2 = vld1_s8(input); input += 8;
5758
const int8x8_t vt3 = vld1_s8(input); input += 8;
58-
vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
59-
vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
60-
vacc16_0 = vmlal_s8(vacc16_0, vt2, vone);
61-
vacc16_1 = vmlal_s8(vacc16_1, vt3, vone);
59+
vacc16_0 = vaddw_s8(vacc16_0, vt0);
60+
vacc16_1 = vaddw_s8(vacc16_1, vt1);
61+
vacc16_0 = vaddw_s8(vacc16_0, vt2);
62+
vacc16_1 = vaddw_s8(vacc16_1, vt3);
6263
}
6364
vacc16_0 = vaddq_s16(vacc16_0, vacc16_1);
6465
for (; batch >= 8; batch -= 8) {
6566
const int8x8_t vt = vld1_s8(input); input += 8;
66-
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
67+
vacc16_0 = vaddw_s8(vacc16_0, vt);
6768
}
6869
if (XNN_UNLIKELY(batch != 0)) {
69-
int8x8_t vt = vld1_s8(input);
70-
vone = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
71-
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
70+
const int8x8_t vt = vld1_s8(input);
71+
const int8x8_t vmask = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
72+
vacc16_0 = vmlal_s8(vacc16_0, vt, vmask);
7273
}
7374
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
7475
}

0 commit comments

Comments (0)