Skip to content

Commit

Permalink
neon mlal qs8 rsum use addw instead of mlal
Browse files Browse the repository at this point in the history
- remove vone which was being lengthened with mlal and then multiplied by the input
- use sliced accumulators for 16 bit accumulation
- add const to mask variable in remainder handler

PiperOrigin-RevId: 634595235
  • Loading branch information
fbarchard authored and xnnpack-bot committed May 18, 2024
1 parent 22ff33b commit 62d8873
Show file tree
Hide file tree
Showing 10 changed files with 146 additions and 132 deletions.
19 changes: 10 additions & 9 deletions src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16-acc2.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16_acc2(
assert(input != NULL);
assert(output != NULL);

int8x8_t vone = vdup_n_s8(1);
int num_batches = batch >> 9;
// 256 int8s may be summed into an int16 before overflowing.
// There are 8 lanes in the accumulator register and 2 registers.
int num_batches = batch >> 12;
int32x4_t vacc0 = vmovq_n_s32(0);
int32x4_t vacc1 = vmovq_n_s32(0);
for (; num_batches > 0; --num_batches) {
Expand All @@ -36,8 +37,8 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16_acc2(
const int8x8_t vt0 = vld1_s8(input); input += 8;
const int8x8_t vt1 = vld1_s8(input); input += 8;

vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt0);
vacc16_1 = vaddw_s8(vacc16_1, vt1);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
vacc1 = vaddq_s32(vacc1, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_1)), vmovl_s16(vget_high_s16(vacc16_1))));
Expand All @@ -49,17 +50,17 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16_acc2(
for (; batch >= 16; batch -= 16) {
const int8x8_t vt0 = vld1_s8(input); input += 8;
const int8x8_t vt1 = vld1_s8(input); input += 8;
vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt0);
vacc16_1 = vaddw_s8(vacc16_1, vt1);
}
vacc16_0 = vaddq_s16(vacc16_0, vacc16_1);
for (; batch >= 8; batch -= 8) {
const int8x8_t vt = vld1_s8(input); input += 8;
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt);
}
if (XNN_UNLIKELY(batch != 0)) {
int8x8_t vt = vld1_s8(input);
vone = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
const int8x8_t vt = vld1_s8(input);
const int8x8_t vone = vld1_s8(&params->fp32_neon.mask_table[7 - batch]);
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
Expand Down
19 changes: 10 additions & 9 deletions src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u16.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,18 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16(
assert(input != NULL);
assert(output != NULL);

int8x8_t vone = vdup_n_s8(1);
int num_batches = batch >> 8;
// 256 int8s may be summed into an int16 before overflowing.
// There are 8 lanes in the accumulator register and 1 register.
int num_batches = batch >> 11;
int32x4_t vacc0 = vmovq_n_s32(0);
for (; num_batches > 0; --num_batches) {
int16x8_t vacc16_0 = vmovq_n_s16(0);
for (size_t current_batch = 256; current_batch > 0; current_batch -= 16) {
const int8x8_t vt0 = vld1_s8(input); input += 8;
const int8x8_t vt1 = vld1_s8(input); input += 8;

vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt1, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt0);
vacc16_0 = vaddw_s8(vacc16_0, vt1);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
batch -= 256;
Expand All @@ -45,16 +46,16 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u16(
for (; batch >= 16; batch -= 16) {
const int8x8_t vt0 = vld1_s8(input); input += 8;
const int8x8_t vt1 = vld1_s8(input); input += 8;
vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt1, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt0);
vacc16_0 = vaddw_s8(vacc16_0, vt1);
}
for (; batch >= 8; batch -= 8) {
const int8x8_t vt = vld1_s8(input); input += 8;
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt);
}
if (XNN_UNLIKELY(batch != 0)) {
int8x8_t vt = vld1_s8(input);
vone = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
const int8x8_t vt = vld1_s8(input);
const int8x8_t vone = vld1_s8(&params->fp32_neon.mask_table[7 - batch]);
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
Expand Down
27 changes: 14 additions & 13 deletions src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32-acc2.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc2(
assert(input != NULL);
assert(output != NULL);

int8x8_t vone = vdup_n_s8(1);
int num_batches = batch >> 9;
// 256 int8s may be summed into an int16 before overflowing.
// There are 8 lanes in the accumulator register and 2 registers.
int num_batches = batch >> 12;
int32x4_t vacc0 = vmovq_n_s32(0);
int32x4_t vacc1 = vmovq_n_s32(0);
for (; num_batches > 0; --num_batches) {
Expand All @@ -38,10 +39,10 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc2(
const int8x8_t vt2 = vld1_s8(input); input += 8;
const int8x8_t vt3 = vld1_s8(input); input += 8;

vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt2, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt3, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt0);
vacc16_1 = vaddw_s8(vacc16_1, vt1);
vacc16_0 = vaddw_s8(vacc16_0, vt2);
vacc16_1 = vaddw_s8(vacc16_1, vt3);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
vacc1 = vaddq_s32(vacc1, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_1)), vmovl_s16(vget_high_s16(vacc16_1))));
Expand All @@ -55,19 +56,19 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc2(
const int8x8_t vt1 = vld1_s8(input); input += 8;
const int8x8_t vt2 = vld1_s8(input); input += 8;
const int8x8_t vt3 = vld1_s8(input); input += 8;
vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt2, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt3, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt0);
vacc16_1 = vaddw_s8(vacc16_1, vt1);
vacc16_0 = vaddw_s8(vacc16_0, vt2);
vacc16_1 = vaddw_s8(vacc16_1, vt3);
}
vacc16_0 = vaddq_s16(vacc16_0, vacc16_1);
for (; batch >= 8; batch -= 8) {
const int8x8_t vt = vld1_s8(input); input += 8;
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt);
}
if (XNN_UNLIKELY(batch != 0)) {
int8x8_t vt = vld1_s8(input);
vone = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
const int8x8_t vt = vld1_s8(input);
const int8x8_t vone = vld1_s8(&params->fp32_neon.mask_table[7 - batch]);
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
Expand Down
29 changes: 15 additions & 14 deletions src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32-acc4.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc4(
assert(input != NULL);
assert(output != NULL);

int8x8_t vone = vdup_n_s8(1);
int num_batches = batch >> 10;
// 256 int8s may be summed into an int16 before overflowing.
// There are 8 lanes in the accumulator register and 4 registers.
int num_batches = batch >> 13;
int32x4_t vacc0 = vmovq_n_s32(0);
int32x4_t vacc1 = vmovq_n_s32(0);
int32x4_t vacc2 = vmovq_n_s32(0);
Expand All @@ -42,10 +43,10 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc4(
const int8x8_t vt2 = vld1_s8(input); input += 8;
const int8x8_t vt3 = vld1_s8(input); input += 8;

vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
vacc16_2 = vmlal_s8(vacc16_2, vt2, vone);
vacc16_3 = vmlal_s8(vacc16_3, vt3, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt0);
vacc16_1 = vaddw_s8(vacc16_1, vt1);
vacc16_2 = vaddw_s8(vacc16_2, vt2);
vacc16_3 = vaddw_s8(vacc16_3, vt3);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
vacc1 = vaddq_s32(vacc1, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_1)), vmovl_s16(vget_high_s16(vacc16_1))));
Expand All @@ -63,21 +64,21 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32_acc4(
const int8x8_t vt1 = vld1_s8(input); input += 8;
const int8x8_t vt2 = vld1_s8(input); input += 8;
const int8x8_t vt3 = vld1_s8(input); input += 8;
vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
vacc16_2 = vmlal_s8(vacc16_2, vt2, vone);
vacc16_3 = vmlal_s8(vacc16_3, vt3, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt0);
vacc16_1 = vaddw_s8(vacc16_1, vt1);
vacc16_2 = vaddw_s8(vacc16_2, vt2);
vacc16_3 = vaddw_s8(vacc16_3, vt3);
}
vacc16_0 = vaddq_s16(vacc16_0, vacc16_1);
vacc16_2 = vaddq_s16(vacc16_2, vacc16_3);
vacc16_0 = vaddq_s16(vacc16_0, vacc16_2);
vacc16_0 = vaddq_s16(vacc16_0, vacc16_3);
for (; batch >= 8; batch -= 8) {
const int8x8_t vt = vld1_s8(input); input += 8;
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt);
}
if (XNN_UNLIKELY(batch != 0)) {
int8x8_t vt = vld1_s8(input);
vone = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
const int8x8_t vt = vld1_s8(input);
const int8x8_t vone = vld1_s8(&params->fp32_neon.mask_table[7 - batch]);
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
Expand Down
27 changes: 14 additions & 13 deletions src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u32.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32(
assert(input != NULL);
assert(output != NULL);

int8x8_t vone = vdup_n_s8(1);
int num_batches = batch >> 8;
// 256 int8s may be summed into an int16 before overflowing.
// There are 8 lanes in the accumulator register and 1 register.
int num_batches = batch >> 11;
int32x4_t vacc0 = vmovq_n_s32(0);
for (; num_batches > 0; --num_batches) {
int16x8_t vacc16_0 = vmovq_n_s16(0);
Expand All @@ -36,10 +37,10 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32(
const int8x8_t vt2 = vld1_s8(input); input += 8;
const int8x8_t vt3 = vld1_s8(input); input += 8;

vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt1, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt2, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt3, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt0);
vacc16_0 = vaddw_s8(vacc16_0, vt1);
vacc16_0 = vaddw_s8(vacc16_0, vt2);
vacc16_0 = vaddw_s8(vacc16_0, vt3);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
batch -= 256;
Expand All @@ -51,18 +52,18 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u32(
const int8x8_t vt1 = vld1_s8(input); input += 8;
const int8x8_t vt2 = vld1_s8(input); input += 8;
const int8x8_t vt3 = vld1_s8(input); input += 8;
vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt1, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt2, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt3, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt0);
vacc16_0 = vaddw_s8(vacc16_0, vt1);
vacc16_0 = vaddw_s8(vacc16_0, vt2);
vacc16_0 = vaddw_s8(vacc16_0, vt3);
}
for (; batch >= 8; batch -= 8) {
const int8x8_t vt = vld1_s8(input); input += 8;
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt);
}
if (XNN_UNLIKELY(batch != 0)) {
int8x8_t vt = vld1_s8(input);
vone = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
const int8x8_t vt = vld1_s8(input);
const int8x8_t vone = vld1_s8(&params->fp32_neon.mask_table[7 - batch]);
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
Expand Down
43 changes: 22 additions & 21 deletions src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64-acc2.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u64_acc2(
assert(input != NULL);
assert(output != NULL);

int8x8_t vone = vdup_n_s8(1);
int num_batches = batch >> 9;
// 256 int8s may be summed into an int16 before overflowing.
// There are 8 lanes in the accumulator register and 2 registers.
int num_batches = batch >> 12;
int32x4_t vacc0 = vmovq_n_s32(0);
int32x4_t vacc1 = vmovq_n_s32(0);
for (; num_batches > 0; --num_batches) {
Expand All @@ -42,14 +43,14 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u64_acc2(
const int8x8_t vt6 = vld1_s8(input); input += 8;
const int8x8_t vt7 = vld1_s8(input); input += 8;

vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt2, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt3, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt4, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt5, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt6, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt7, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt0);
vacc16_1 = vaddw_s8(vacc16_1, vt1);
vacc16_0 = vaddw_s8(vacc16_0, vt2);
vacc16_1 = vaddw_s8(vacc16_1, vt3);
vacc16_0 = vaddw_s8(vacc16_0, vt4);
vacc16_1 = vaddw_s8(vacc16_1, vt5);
vacc16_0 = vaddw_s8(vacc16_0, vt6);
vacc16_1 = vaddw_s8(vacc16_1, vt7);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
vacc1 = vaddq_s32(vacc1, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_1)), vmovl_s16(vget_high_s16(vacc16_1))));
Expand All @@ -67,23 +68,23 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u64_acc2(
const int8x8_t vt5 = vld1_s8(input); input += 8;
const int8x8_t vt6 = vld1_s8(input); input += 8;
const int8x8_t vt7 = vld1_s8(input); input += 8;
vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt2, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt3, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt4, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt5, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt6, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt7, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt0);
vacc16_1 = vaddw_s8(vacc16_1, vt1);
vacc16_0 = vaddw_s8(vacc16_0, vt2);
vacc16_1 = vaddw_s8(vacc16_1, vt3);
vacc16_0 = vaddw_s8(vacc16_0, vt4);
vacc16_1 = vaddw_s8(vacc16_1, vt5);
vacc16_0 = vaddw_s8(vacc16_0, vt6);
vacc16_1 = vaddw_s8(vacc16_1, vt7);
}
vacc16_0 = vaddq_s16(vacc16_0, vacc16_1);
for (; batch >= 8; batch -= 8) {
const int8x8_t vt = vld1_s8(input); input += 8;
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt);
}
if (XNN_UNLIKELY(batch != 0)) {
int8x8_t vt = vld1_s8(input);
vone = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
const int8x8_t vt = vld1_s8(input);
const int8x8_t vone = vld1_s8(&params->fp32_neon.mask_table[7 - batch]);
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
Expand Down
45 changes: 23 additions & 22 deletions src/qs8-rsum/gen/qs8-rsum-minmax-fp32-neon-u64-acc4.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u64_acc4(
assert(input != NULL);
assert(output != NULL);

int8x8_t vone = vdup_n_s8(1);
int num_batches = batch >> 10;
// 256 int8s may be summed into an int16 before overflowing.
// There are 8 lanes in the accumulator register and 4 registers.
int num_batches = batch >> 13;
int32x4_t vacc0 = vmovq_n_s32(0);
int32x4_t vacc1 = vmovq_n_s32(0);
int32x4_t vacc2 = vmovq_n_s32(0);
Expand All @@ -46,14 +47,14 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u64_acc4(
const int8x8_t vt6 = vld1_s8(input); input += 8;
const int8x8_t vt7 = vld1_s8(input); input += 8;

vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
vacc16_2 = vmlal_s8(vacc16_2, vt2, vone);
vacc16_3 = vmlal_s8(vacc16_3, vt3, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt4, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt5, vone);
vacc16_2 = vmlal_s8(vacc16_2, vt6, vone);
vacc16_3 = vmlal_s8(vacc16_3, vt7, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt0);
vacc16_1 = vaddw_s8(vacc16_1, vt1);
vacc16_2 = vaddw_s8(vacc16_2, vt2);
vacc16_3 = vaddw_s8(vacc16_3, vt3);
vacc16_0 = vaddw_s8(vacc16_0, vt4);
vacc16_1 = vaddw_s8(vacc16_1, vt5);
vacc16_2 = vaddw_s8(vacc16_2, vt6);
vacc16_3 = vaddw_s8(vacc16_3, vt7);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
vacc1 = vaddq_s32(vacc1, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_1)), vmovl_s16(vget_high_s16(vacc16_1))));
Expand All @@ -75,25 +76,25 @@ void xnn_qs8_rsum_minmax_fp32_ukernel__neon_mlal_u64_acc4(
const int8x8_t vt5 = vld1_s8(input); input += 8;
const int8x8_t vt6 = vld1_s8(input); input += 8;
const int8x8_t vt7 = vld1_s8(input); input += 8;
vacc16_0 = vmlal_s8(vacc16_0, vt0, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt1, vone);
vacc16_2 = vmlal_s8(vacc16_2, vt2, vone);
vacc16_3 = vmlal_s8(vacc16_3, vt3, vone);
vacc16_0 = vmlal_s8(vacc16_0, vt4, vone);
vacc16_1 = vmlal_s8(vacc16_1, vt5, vone);
vacc16_2 = vmlal_s8(vacc16_2, vt6, vone);
vacc16_3 = vmlal_s8(vacc16_3, vt7, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt0);
vacc16_1 = vaddw_s8(vacc16_1, vt1);
vacc16_2 = vaddw_s8(vacc16_2, vt2);
vacc16_3 = vaddw_s8(vacc16_3, vt3);
vacc16_0 = vaddw_s8(vacc16_0, vt4);
vacc16_1 = vaddw_s8(vacc16_1, vt5);
vacc16_2 = vaddw_s8(vacc16_2, vt6);
vacc16_3 = vaddw_s8(vacc16_3, vt7);
}
vacc16_0 = vaddq_s16(vacc16_0, vacc16_1);
vacc16_2 = vaddq_s16(vacc16_2, vacc16_3);
vacc16_0 = vaddq_s16(vacc16_0, vacc16_2);
vacc16_0 = vaddq_s16(vacc16_0, vacc16_3);
for (; batch >= 8; batch -= 8) {
const int8x8_t vt = vld1_s8(input); input += 8;
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
vacc16_0 = vaddw_s8(vacc16_0, vt);
}
if (XNN_UNLIKELY(batch != 0)) {
int8x8_t vt = vld1_s8(input);
vone = vld1_s8(&params->fp32_neon.mask_table[15 - batch]);
const int8x8_t vt = vld1_s8(input);
const int8x8_t vone = vld1_s8(&params->fp32_neon.mask_table[7 - batch]);
vacc16_0 = vmlal_s8(vacc16_0, vt, vone);
}
vacc0 = vaddq_s32(vacc0, vaddq_s32(vmovl_s16(vget_low_s16(vacc16_0)), vmovl_s16(vget_high_s16(vacc16_0))));
Expand Down

0 comments on commit 62d8873

Please sign in to comment.