neon mlal qs8 rsum use addw instead of mlal
- remove vone, which was being widened by mlal and multiplied against the input
- use sliced accumulators for 16-bit accumulation
- rename functions from neon_mlal to neon_addw
- add const to the mask variable in the remainder handler of the addw and neondot microkernels

PiperOrigin-RevId: 634595235
fbarchard authored and xnnpack-bot committed May 18, 2024
1 parent aee04e1 commit 385e9f9
Showing 23 changed files with 335 additions and 321 deletions.
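
The heart of the change is a one-intrinsic swap: the old kernels widened int8 lanes by multiply-accumulating them against a vector of ones (vmlal_s8), while the new ones use the widening add vaddw_s8, which produces the same 16-bit sums without the vone register or the multiply. A minimal standalone sketch of the two patterns, not taken verbatim from the generated kernels:

```c
#include <arm_neon.h>

// Old pattern: widening multiply-accumulate against a vector of ones.
static int16x8_t sum8_mlal(int16x8_t vacc16, int8x8_t vt) {
  const int8x8_t vone = vdup_n_s8(1);
  return vmlal_s8(vacc16, vt, vone);  // vacc16 += (int16_t)vt * 1
}

// New pattern: widening add; no multiplier operand needed.
static int16x8_t sum8_addw(int16x8_t vacc16, int8x8_t vt) {
  return vaddw_s8(vacc16, vt);  // vacc16 += (int16_t)vt
}
```

The per-iteration structure of the kernels is otherwise unchanged, as the diffs below show.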
32 changes: 16 additions & 16 deletions bench/qs8-rsum-minmax-fp32.cc
@@ -38,8 +38,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,
->UseRealTime();

#if XNN_ARCH_ARM || XNN_ARCH_ARM64
-BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u16,
-xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16,
+BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u16,
+xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u16,
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
benchmark::utils::CheckNEON)
->Apply(BenchmarkRSUM)
@@ -48,8 +48,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,


#if XNN_ARCH_ARM || XNN_ARCH_ARM64
-BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u32,
-xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32,
+BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u32,
+xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u32,
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
benchmark::utils::CheckNEON)
->Apply(BenchmarkRSUM)
@@ -58,8 +58,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,


#if XNN_ARCH_ARM || XNN_ARCH_ARM64
-BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u64,
-xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u64,
+BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u64,
+xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u64,
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
benchmark::utils::CheckNEON)
->Apply(BenchmarkRSUM)
@@ -68,8 +68,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,


#if XNN_ARCH_ARM || XNN_ARCH_ARM64
-BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u16_acc2,
-xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16_acc2,
+BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u16_acc2,
+xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u16_acc2,
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
benchmark::utils::CheckNEON)
->Apply(BenchmarkRSUM)
@@ -78,8 +78,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,


#if XNN_ARCH_ARM || XNN_ARCH_ARM64
-BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u32_acc2,
-xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc2,
+BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u32_acc2,
+xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u32_acc2,
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
benchmark::utils::CheckNEON)
->Apply(BenchmarkRSUM)
@@ -88,8 +88,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,


#if XNN_ARCH_ARM || XNN_ARCH_ARM64
-BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u64_acc2,
-xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u64_acc2,
+BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u64_acc2,
+xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u64_acc2,
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
benchmark::utils::CheckNEON)
->Apply(BenchmarkRSUM)
@@ -98,8 +98,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,


#if XNN_ARCH_ARM || XNN_ARCH_ARM64
-BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u32_acc4,
-xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc4,
+BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u32_acc4,
+xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u32_acc4,
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
benchmark::utils::CheckNEON)
->Apply(BenchmarkRSUM)
@@ -108,8 +108,8 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_imagic_u4,


#if XNN_ARCH_ARM || XNN_ARCH_ARM64
-BENCHMARK_CAPTURE(qs8_rsum, neon_mlal_u64_acc4,
-xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u64_acc4,
+BENCHMARK_CAPTURE(qs8_rsum, neon_addw_u64_acc4,
+xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u64_acc4,
xnn_init_qs8_avgpool_minmax_fp32_neon_params,
benchmark::utils::CheckNEON)
->Apply(BenchmarkRSUM)
16 changes: 8 additions & 8 deletions cmake/gen/neon_microkernels.cmake
@@ -684,14 +684,14 @@ SET(ALL_NEON_MICROKERNEL_SRCS
src/qs8-requantization/qs8-requantization-rndna-neon.c
src/qs8-requantization/qs8-requantization-rndnu-neon-mull.c
src/qs8-requantization/qs8-requantization-rndnu-neon-qdmulh.c
-src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16-acc2.c
-src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16.c
-src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32-acc2.c
-src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32-acc4.c
-src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32.c
-src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64-acc2.c
-src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64-acc4.c
-src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64.c
+src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16-acc2.c
+src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16.c
+src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32-acc2.c
+src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32-acc4.c
+src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32.c
+src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64-acc2.c
+src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64-acc4.c
+src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64.c
src/qs8-vadd/gen/qs8-vadd-minmax-neon-ld64-u8.c
src/qs8-vadd/gen/qs8-vadd-minmax-neon-ld64-u16.c
src/qs8-vadd/gen/qs8-vadd-minmax-neon-ld64-u24.c
16 changes: 8 additions & 8 deletions gen/neon_microkernels.bzl
@@ -680,14 +680,14 @@ ALL_NEON_MICROKERNEL_SRCS = [
"src/qs8-requantization/qs8-requantization-rndna-neon.c",
"src/qs8-requantization/qs8-requantization-rndnu-neon-mull.c",
"src/qs8-requantization/qs8-requantization-rndnu-neon-qdmulh.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16-acc2.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32-acc2.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32-acc4.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64-acc2.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64-acc4.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16-acc2.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32-acc2.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32-acc4.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64-acc2.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64-acc4.c",
"src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64.c",
"src/qs8-vadd/gen/qs8-vadd-minmax-neon-ld64-u8.c",
"src/qs8-vadd/gen/qs8-vadd-minmax-neon-ld64-u16.c",
"src/qs8-vadd/gen/qs8-vadd-minmax-neon-ld64-u24.c",
16 changes: 8 additions & 8 deletions scripts/generate-qs8-rsum.sh
@@ -10,16 +10,16 @@ tools/xngen src/qs8-rsum/scalar.c.in -D CHANNEL_TILE=2 -D ACCUMULATORS=1 -D REQU
tools/xngen src/qs8-rsum/scalar.c.in -D CHANNEL_TILE=4 -D ACCUMULATORS=1 -D REQUANTIZATION=FP32 -D VARIANT=IMAGIC -D WASM=0 -o src/qs8-rsum/gen/qs8-rdsum-minmax-fp32-scalar-imagic-u4-acc1.c &

################################## ARM NEON ###################################
-tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=16 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16.c &
-tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=32 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32.c &
-tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=64 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64.c &
+tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=16 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16.c &
+tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=32 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32.c &
+tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=64 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64.c &

-tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=2 -D CHANNEL_TILE=16 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16-acc2.c &
-tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=2 -D CHANNEL_TILE=32 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32-acc2.c &
-tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=2 -D CHANNEL_TILE=64 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64-acc2.c &
+tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=2 -D CHANNEL_TILE=16 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16-acc2.c &
+tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=2 -D CHANNEL_TILE=32 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32-acc2.c &
+tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=2 -D CHANNEL_TILE=64 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64-acc2.c &

-tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=4 -D CHANNEL_TILE=32 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32-acc4.c &
-tools/xngen src/qs8-rsum/neon-mlal.c.in -D ACCUMULATORS=4 -D CHANNEL_TILE=64 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64-acc4.c &
+tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=4 -D CHANNEL_TILE=32 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32-acc4.c &
+tools/xngen src/qs8-rsum/neon-addw.c.in -D ACCUMULATORS=4 -D CHANNEL_TILE=64 -D REQUANTIZATION=FP32 -D ARMV8=0 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u64-acc4.c &

tools/xngen src/qs8-rsum/neondot.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=16 -D REQUANTIZATION=FP32 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neondot-u16.c &
tools/xngen src/qs8-rsum/neondot.c.in -D ACCUMULATORS=1 -D CHANNEL_TILE=32 -D REQUANTIZATION=FP32 -o src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neondot-u32.c &
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16-acc2.c
@@ -1,5 +1,5 @@
// Auto-generated file. Do not edit!
-// Template: src/qs8-rsum/neon-mlal.c.in
+// Template: src/qs8-rsum/neon-addw.c.in
// Generator: tools/xngen
//
// Copyright 2024 Google LLC
@@ -15,7 +15,7 @@
#include <xnnpack/math.h>
#include <xnnpack/reduce.h>

-void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16_acc2(
+void xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u16_acc2(
size_t batch,
const int8_t* input,
int8_t* output,
@@ -25,7 +25,8 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16_acc2(
assert(input != NULL);
assert(output != NULL);

-int8x8_t vone = vdup_n_s8(1);
+// 256 int8s may be summed into an int16 before overflowing.
+// There are 8 lanes in the accumulator register and 2 registers.
int num_batches = batch >> 9;
int32x4_t vacc0 = vmovq_n_s32(0);
int32x4_t vacc1 = vmovq_n_s32(0);
@@ -36,8 +37,8 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16_acc2(
const int8x8_t vt0 = vld1_s8(input); input += 8;
const int8x8_t vt1 = vld1_s8(input); input += 8;

-vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
-vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
+vacc16_0 = vaddw_s8(vacc16_0, vt0);
+vacc16_1 = vaddw_s8(vacc16_1, vt1);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
vacc1 = vaddq_s32(vacc1, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_1)), vmovl_s16(vget_high_s16(vacc16_1))));
@@ -49,18 +50,18 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16_acc2(
for (; batch >= 16; batch -= 16) {
const int8x8_t vt0 = vld1_s8(input); input += 8;
const int8x8_t vt1 = vld1_s8(input); input += 8;
-vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
-vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
+vacc16_0 = vaddw_s8(vacc16_0, vt0);
+vacc16_1 = vaddw_s8(vacc16_1, vt1);
}
vacc16_0 = vaddq_s16(vacc16_0, vacc16_1);
for (; batch >= 8; batch -= 8) {
const int8x8_t vt = vld1_s8(input); input += 8;
-vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
+vacc16_0 = vaddw_s8(vacc16_0, vt);
}
if (XNN_UNLIKELY(batch != 0)) {
-int8x8_t vt = vld1_s8(input);
-vone = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
-vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
+const int8x8_t vt = vld1_s8(input);
+const int8x8_t vmask = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
+vacc16_0 = vmlal_s8(vacc16_0, vt, vmask);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
}
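
The comment introduced above encodes the overflow arithmetic that sizes the outer loop: the worst-case int8 lane is -128, and 256 * (-128) = -32768 is exactly INT16_MIN, so a 16-bit lane can absorb at most 256 signed bytes. With two sliced int16 accumulators, each outer pass can therefore cover 2 * 256 = 512 elements (hence `batch >> 9`) before the 16-bit sums must be drained into the 32-bit accumulators. A sketch of that drain step, mirroring the vmovl/vaddq sequence in the diff:

```c
#include <arm_neon.h>

// Fold an int16x8 accumulator into an int32x4 accumulator: widen the
// low and high halves to 32 bits and add both into the running sum.
static int32x4_t drain_acc16(int32x4_t vacc32, int16x8_t vacc16) {
  return vaddq_s32(vacc32,
                   vaddq_s32(vmovl_s16(vget_low_s16(vacc16)),
                             vmovl_s16(vget_high_s16(vacc16))));
}
```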
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u16.c
@@ -1,5 +1,5 @@
// Auto-generated file. Do not edit!
-// Template: src/qs8-rsum/neon-mlal.c.in
+// Template: src/qs8-rsum/neon-addw.c.in
// Generator: tools/xngen
//
// Copyright 2024 Google LLC
@@ -15,7 +15,7 @@
#include <xnnpack/math.h>
#include <xnnpack/reduce.h>

-void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16(
+void xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u16(
size_t batch,
const int8_t* input,
int8_t* output,
@@ -25,7 +25,8 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16(
assert(input != NULL);
assert(output != NULL);

-int8x8_t vone = vdup_n_s8(1);
+// 256 int8s may be summed into an int16 before overflowing.
+// There are 8 lanes in the accumulator register and 1 register.
int num_batches = batch >> 8;
int32x4_t vacc0 = vmovq_n_s32(0);
for (; num_batches > 0; --num_batches) {
@@ -34,8 +35,8 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16(
const int8x8_t vt0 = vld1_s8(input); input += 8;
const int8x8_t vt1 = vld1_s8(input); input += 8;

-vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
-vacc16_0 = vmlal_s8(vacc16_0, vt1, vone);
+vacc16_0 = vaddw_s8(vacc16_0, vt0);
+vacc16_0 = vaddw_s8(vacc16_0, vt1);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
batch -= 256;
@@ -45,17 +46,17 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16(
for (; batch >= 16; batch -= 16) {
const int8x8_t vt0 = vld1_s8(input); input += 8;
const int8x8_t vt1 = vld1_s8(input); input += 8;
-vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
-vacc16_0 = vmlal_s8(vacc16_0, vt1, vone);
+vacc16_0 = vaddw_s8(vacc16_0, vt0);
+vacc16_0 = vaddw_s8(vacc16_0, vt1);
}
for (; batch >= 8; batch -= 8) {
const int8x8_t vt = vld1_s8(input); input += 8;
-vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
+vacc16_0 = vaddw_s8(vacc16_0, vt);
}
if (XNN_UNLIKELY(batch != 0)) {
-int8x8_t vt = vld1_s8(input);
-vone = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
-vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
+const int8x8_t vt = vld1_s8(input);
+const int8x8_t vmask = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
+vacc16_0 = vmlal_s8(vacc16_0, vt, vmask);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
}
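
The remainder handler is the one place mlal survives: mask_table holds a run of ones followed by zeros, so loading from `&mask_table[15 - batch]` yields a vector whose first `batch` lanes are 1 and the rest 0, and the multiply-accumulate then discards the lanes past the end of the input. A hedged sketch of the idea; the table layout below is inferred from the indexing and is not spelled out in this diff:

```c
#include <arm_neon.h>
#include <stddef.h>

// Hypothetical layout inferred from the `15 - batch` indexing: fifteen
// ones then zeros, so the 8-byte load below sees exactly `batch` ones
// in its low lanes (batch is 1..7 at this point in the kernels).
static const int8_t mask_table[30] = {
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};

// Tail sketch: lanes beyond the valid input are multiplied by 0 and so
// contribute nothing to the 16-bit accumulator.
static int16x8_t sum_tail(int16x8_t vacc16, const int8_t* input, size_t batch) {
  const int8x8_t vt = vld1_s8(input);  // only the first `batch` bytes are valid
  const int8x8_t vmask = vld1_s8(&mask_table[15 - batch]);
  return vmlal_s8(vacc16, vt, vmask);
}
```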
src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-addw-u32-acc2.c
@@ -1,5 +1,5 @@
// Auto-generated file. Do not edit!
-// Template: src/qs8-rsum/neon-mlal.c.in
+// Template: src/qs8-rsum/neon-addw.c.in
// Generator: tools/xngen
//
// Copyright 2024 Google LLC
@@ -15,7 +15,7 @@
#include <xnnpack/math.h>
#include <xnnpack/reduce.h>

-void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc2(
+void xnn_qs8_rsum_minmax_fp32_ukernel__neon_addw_u32_acc2(
size_t batch,
const int8_t* input,
int8_t* output,
@@ -25,7 +25,8 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc2(
assert(input != NULL);
assert(output != NULL);

-int8x8_t vone = vdup_n_s8(1);
+// 256 int8s may be summed into an int16 before overflowing.
+// There are 8 lanes in the accumulator register and 2 registers.
int num_batches = batch >> 9;
int32x4_t vacc0 = vmovq_n_s32(0);
int32x4_t vacc1 = vmovq_n_s32(0);
@@ -38,10 +39,10 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc2(
const int8x8_t vt2 = vld1_s8(input); input += 8;
const int8x8_t vt3 = vld1_s8(input); input += 8;

-vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
-vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
-vacc16_0 = vmlal_s8(vacc16_0, vt2, vone);
-vacc16_1 = vmlal_s8(vacc16_1, vt3, vone);
+vacc16_0 = vaddw_s8(vacc16_0, vt0);
+vacc16_1 = vaddw_s8(vacc16_1, vt1);
+vacc16_0 = vaddw_s8(vacc16_0, vt2);
+vacc16_1 = vaddw_s8(vacc16_1, vt3);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
vacc1 = vaddq_s32(vacc1, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_1)), vmovl_s16(vget_high_s16(vacc16_1))));
@@ -55,20 +56,20 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc2(
const int8x8_t vt1 = vld1_s8(input); input += 8;
const int8x8_t vt2 = vld1_s8(input); input += 8;
const int8x8_t vt3 = vld1_s8(input); input += 8;
-vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
-vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
-vacc16_0 = vmlal_s8(vacc16_0, vt2, vone);
-vacc16_1 = vmlal_s8(vacc16_1, vt3, vone);
+vacc16_0 = vaddw_s8(vacc16_0, vt0);
+vacc16_1 = vaddw_s8(vacc16_1, vt1);
+vacc16_0 = vaddw_s8(vacc16_0, vt2);
+vacc16_1 = vaddw_s8(vacc16_1, vt3);
}
vacc16_0 = vaddq_s16(vacc16_0, vacc16_1);
for (; batch >= 8; batch -= 8) {
const int8x8_t vt = vld1_s8(input); input += 8;
-vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
+vacc16_0 = vaddw_s8(vacc16_0, vt);
}
if (XNN_UNLIKELY(batch != 0)) {
-int8x8_t vt = vld1_s8(input);
-vone = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
-vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
+const int8x8_t vt = vld1_s8(input);
+const int8x8_t vmask = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
+vacc16_0 = vmlal_s8(vacc16_0, vt, vmask);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
}
