Skip to content

Commit

Permalink
F32-RMINMAXSUM - add reduction sum to rminmax
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 634176068
  • Loading branch information
fbarchard authored and xnnpack-bot committed May 16, 2024
1 parent 7fabcac commit c303b6c
Show file tree
Hide file tree
Showing 26 changed files with 1,279 additions and 187 deletions.
4 changes: 4 additions & 0 deletions cmake/gen/rvv_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ SET(ALL_RVV_MICROKERNEL_SRCS
src/f32-rminmax/gen/f32-rminmax-rvv-u2v.c
src/f32-rminmax/gen/f32-rminmax-rvv-u4v.c
src/f32-rminmax/gen/f32-rminmax-rvv-u8v.c
src/f32-rminmaxsum/gen/f32-rminmaxsum-rvv-u1v.c
src/f32-rminmaxsum/gen/f32-rminmaxsum-rvv-u2v.c
src/f32-rminmaxsum/gen/f32-rminmaxsum-rvv-u4v.c
src/f32-rminmaxsum/gen/f32-rminmaxsum-rvv-u8v.c
src/f32-rsum/f32-rsum-rvv-u1v.c
src/f32-vbinary/gen/f32-vadd-minmax-rvv-u4v.c
src/f32-vbinary/gen/f32-vadd-minmax-rvv-u8v.c
Expand Down
175 changes: 0 additions & 175 deletions scripts/generate-f32-rminmax.sh

This file was deleted.

228 changes: 228 additions & 0 deletions scripts/generate-f32-rminmaxsum.sh

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/f32-rminmax/gen/f32-rmax-rvv-u1v.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Auto-generated file. Do not edit!
// Template: src/f32-rminmax/rvv.c.in
// Template: src/f32-rminmaxsum/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2023 SiFive, Inc.
Expand Down
2 changes: 1 addition & 1 deletion src/f32-rminmax/gen/f32-rmax-rvv-u2v.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Auto-generated file. Do not edit!
// Template: src/f32-rminmax/rvv.c.in
// Template: src/f32-rminmaxsum/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2023 SiFive, Inc.
Expand Down
2 changes: 1 addition & 1 deletion src/f32-rminmax/gen/f32-rmax-rvv-u4v.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Auto-generated file. Do not edit!
// Template: src/f32-rminmax/rvv.c.in
// Template: src/f32-rminmaxsum/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2023 SiFive, Inc.
Expand Down
2 changes: 1 addition & 1 deletion src/f32-rminmax/gen/f32-rmax-rvv-u8v.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Auto-generated file. Do not edit!
// Template: src/f32-rminmax/rvv.c.in
// Template: src/f32-rminmaxsum/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2023 SiFive, Inc.
Expand Down
2 changes: 1 addition & 1 deletion src/f32-rminmax/gen/f32-rmin-rvv-u1v.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Auto-generated file. Do not edit!
// Template: src/f32-rminmax/rvv.c.in
// Template: src/f32-rminmaxsum/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2023 SiFive, Inc.
Expand Down
2 changes: 1 addition & 1 deletion src/f32-rminmax/gen/f32-rmin-rvv-u2v.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Auto-generated file. Do not edit!
// Template: src/f32-rminmax/rvv.c.in
// Template: src/f32-rminmaxsum/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2023 SiFive, Inc.
Expand Down
2 changes: 1 addition & 1 deletion src/f32-rminmax/gen/f32-rmin-rvv-u4v.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Auto-generated file. Do not edit!
// Template: src/f32-rminmax/rvv.c.in
// Template: src/f32-rminmaxsum/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2023 SiFive, Inc.
Expand Down
2 changes: 1 addition & 1 deletion src/f32-rminmax/gen/f32-rmin-rvv-u8v.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Auto-generated file. Do not edit!
// Template: src/f32-rminmax/rvv.c.in
// Template: src/f32-rminmaxsum/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2023 SiFive, Inc.
Expand Down
2 changes: 1 addition & 1 deletion src/f32-rminmax/gen/f32-rminmax-rvv-u1v.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Auto-generated file. Do not edit!
// Template: src/f32-rminmax/rvv.c.in
// Template: src/f32-rminmaxsum/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2023 SiFive, Inc.
Expand Down
2 changes: 1 addition & 1 deletion src/f32-rminmax/gen/f32-rminmax-rvv-u2v.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Auto-generated file. Do not edit!
// Template: src/f32-rminmax/rvv.c.in
// Template: src/f32-rminmaxsum/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2023 SiFive, Inc.
Expand Down
2 changes: 1 addition & 1 deletion src/f32-rminmax/gen/f32-rminmax-rvv-u4v.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Auto-generated file. Do not edit!
// Template: src/f32-rminmax/rvv.c.in
// Template: src/f32-rminmaxsum/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2023 SiFive, Inc.
Expand Down
2 changes: 1 addition & 1 deletion src/f32-rminmax/gen/f32-rminmax-rvv-u8v.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Auto-generated file. Do not edit!
// Template: src/f32-rminmax/rvv.c.in
// Template: src/f32-rminmaxsum/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2023 SiFive, Inc.
Expand Down
128 changes: 128 additions & 0 deletions src/f32-rminmaxsum/avx.c.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$SIMD_TILE = BATCH_TILE // 8
$assert ACCUMULATORS <= SIMD_TILE
$assert OP in ["MAX", "MIN", "MINMAX", "MINMAXSUM"]
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/reduce.h>


$ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS
$EMIT_MIN = "MIN" in OP
$EMIT_MAX = "MAX" in OP
$EMIT_SUM = "SUM" in OP
// Reduces `batch` bytes of packed floats at `input` to the scalar result(s)
// named by the kernel (minimum, maximum and/or sum) and writes them to `output`.
//
// batch:  input size in bytes; must be non-zero and a multiple of sizeof(float).
// input:  source floats; read with unaligned 256-bit loads.
// output: receives the results; the layout is described at the stores at the
//         end of the function.
// params: supplies avx.mask_table, used to build the load mask for the final
//         1-7 leftover floats.
//
// Structure: kernels whose tile is wider than 8 floats first run an unrolled
// loop that spreads the loaded vectors over the accumulators round-robin and
// then merges the extra accumulators pairwise into accumulator 0; the remaining
// input is consumed 8 floats at a time, then by one masked load, and finally
// the 8 lanes of accumulator 0 are reduced to a scalar.
void xnn_f32_r${OP.lower()}_ukernel__avx_u${BATCH_TILE}${ACC_SUFFIX}(
size_t batch,
const float* input,
float* output,
const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
{
assert(batch != 0);
assert(batch % sizeof(float) == 0);
assert(input != NULL);
assert(output != NULL);

// Seed the accumulators: min/max start from the first input element broadcast
// to every lane (so no lane holds a value that is not in the input), sum
// starts from zero.
$if EMIT_MIN:
__m256 vmin0 = _mm256_broadcast_ss(input);
$if EMIT_MAX:
__m256 vmax0 = vmin0;
$elif EMIT_MAX:
__m256 vmax0 = _mm256_broadcast_ss(input);
$if EMIT_SUM:
__m256 vsum0 = _mm256_setzero_ps();
$for A in range(1, ACCUMULATORS):
$if EMIT_MIN:
__m256 vmin${A} = vmin0;
$if EMIT_MAX:
__m256 vmax${A} = vmax0;
$if EMIT_SUM:
__m256 vsum${A} = vsum0;
$if BATCH_TILE > 8:
for (; batch >= ${BATCH_TILE} * sizeof(float); batch -= ${BATCH_TILE} * sizeof(float)) {
const __m256 vt0 = _mm256_loadu_ps(input);
$for N in range(1, SIMD_TILE):
const __m256 vt${N} = _mm256_loadu_ps(input + ${N * 8});
input += ${BATCH_TILE};

$for N in range(SIMD_TILE):
$if EMIT_MIN:
vmin${N % ACCUMULATORS} = _mm256_min_ps(vmin${N % ACCUMULATORS}, vt${N});
$if EMIT_MAX:
vmax${N % ACCUMULATORS} = _mm256_max_ps(vmax${N % ACCUMULATORS}, vt${N});
$if EMIT_SUM:
vsum${N % ACCUMULATORS} = _mm256_add_ps(vsum${N % ACCUMULATORS}, vt${N});
}
$if ACCUMULATORS > 1:
$ACC_SLICE = 1
$while ACC_SLICE < ACCUMULATORS:
$for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
$if A + ACC_SLICE < ACCUMULATORS:
$if EMIT_MIN:
vmin${A} = _mm256_min_ps(vmin${A}, vmin${A + ACC_SLICE});
$if EMIT_MAX:
vmax${A} = _mm256_max_ps(vmax${A}, vmax${A + ACC_SLICE});
$if EMIT_SUM:
vsum${A} = _mm256_add_ps(vsum${A}, vsum${A + ACC_SLICE});
$ACC_SLICE *= 2
// Consume the remaining full vectors, 8 floats per iteration, into accumulator 0.
for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) {
const __m256 vt = _mm256_loadu_ps(input);
input += 8;

$if EMIT_MIN:
vmin0 = _mm256_min_ps(vmin0, vt);
$if EMIT_MAX:
vmax0 = _mm256_max_ps(vmax0, vt);
$if EMIT_SUM:
vsum0 = _mm256_add_ps(vsum0, vt);
}
// Tail: 1-7 floats are left.
if XNN_UNLIKELY(batch != 0) {
assert(batch >= 1 * sizeof(float));
assert(batch <= 7 * sizeof(float));
// The lane mask is loaded from `batch` bytes before &mask_table[7].
// NOTE(review): presumably the table holds 7 all-ones entries followed by
// zero entries, so exactly the first batch / sizeof(float) lanes are
// selected - confirm against the definition of avx.mask_table.
const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx.mask_table[7] - batch));

// The masked load does not touch memory for unselected lanes and returns zero there.
const __m256 vt = _mm256_maskload_ps(input, vmask);

// Blend so that only the selected lanes are updated; without this the zeros
// in the unselected lanes would leak into the min/max results.
$if EMIT_MIN:
vmin0 = _mm256_blendv_ps(vmin0, _mm256_min_ps(vmin0, vt), _mm256_castsi256_ps(vmask));
$if EMIT_MAX:
vmax0 = _mm256_blendv_ps(vmax0, _mm256_max_ps(vmax0, vt), _mm256_castsi256_ps(vmask));
$if EMIT_SUM:
vsum0 = _mm256_blendv_ps(vsum0, _mm256_add_ps(vsum0, vt), _mm256_castsi256_ps(vmask));
}
// Horizontal reduction, step 1: fold the upper 128-bit half onto the lower half.
$if EMIT_MIN:
__m128 vmin = _mm_min_ps(_mm256_castps256_ps128(vmin0), _mm256_extractf128_ps(vmin0, 1));
$if EMIT_MAX:
__m128 vmax = _mm_max_ps(_mm256_castps256_ps128(vmax0), _mm256_extractf128_ps(vmax0, 1));
$if EMIT_SUM:
__m128 vsum = _mm_add_ps(_mm256_castps256_ps128(vsum0), _mm256_extractf128_ps(vsum0, 1));

// Step 2: fold lanes 2-3 onto lanes 0-1.
$if EMIT_MIN:
vmin = _mm_min_ps(vmin, _mm_movehl_ps(vmin, vmin));
$if EMIT_MAX:
vmax = _mm_max_ps(vmax, _mm_movehl_ps(vmax, vmax));
$if EMIT_SUM:
vsum = _mm_add_ps(vsum, _mm_movehl_ps(vsum, vsum));
// Step 3: fold lane 1 onto lane 0, which then holds the final scalar.
$if EMIT_MIN:
vmin = _mm_min_ss(vmin, _mm_movehdup_ps(vmin));
$if EMIT_MAX:
vmax = _mm_max_ss(vmax, _mm_movehdup_ps(vmax));
$if EMIT_SUM:
vsum = _mm_add_ss(vsum, _mm_movehdup_ps(vsum));
// Store the results. Kernels that compute a minimum write min to output[0],
// then max to output[1] and sum to output[2] when those are computed; the
// max-only kernel writes max to output[0].
// NOTE(review): this assumes the max/sum stores are nested under the min
// branch of the template - confirm against the template's indentation.
$if EMIT_MIN:
_mm_store_ss(output, vmin);
$if EMIT_MAX:
_mm_store_ss(output + 1, vmax);
$if EMIT_SUM:
_mm_store_ss(output + 2, vsum);
$elif EMIT_MAX:
_mm_store_ss(output, vmax);
}
136 changes: 136 additions & 0 deletions src/f32-rminmaxsum/avx512f.c.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 16 == 0
$assert BATCH_TILE >= 16
$SIMD_TILE = BATCH_TILE // 16
$assert ACCUMULATORS <= SIMD_TILE
$assert OP in ["MAX", "MIN", "MINMAX", "MINMAXSUM"]
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/reduce.h>


$ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS
$EMIT_MIN = "MIN" in OP
$EMIT_MAX = "MAX" in OP
$EMIT_SUM = "SUM" in OP
// Reduces `batch` bytes of packed floats at `input` to the scalar result(s)
// named by the kernel (minimum, maximum and/or sum) and writes them to `output`.
//
// batch:  input size in bytes; must be non-zero and a multiple of sizeof(float).
// input:  source floats; read with unaligned 512-bit loads.
// output: receives the results; the layout is described at the stores at the
//         end of the function.
// params: not referenced by this kernel body.
//
// Structure: kernels whose tile is wider than 16 floats first run an unrolled
// loop that spreads the loaded vectors over the accumulators round-robin and
// then merges the extra accumulators pairwise into accumulator 0; the remaining
// input is consumed 16 floats at a time, then by one masked load, and finally
// the 16 lanes of accumulator 0 are reduced to a scalar.
void xnn_f32_r${OP.lower()}_ukernel__avx512f_u${BATCH_TILE}${ACC_SUFFIX}(
size_t batch,
const float* input,
float* output,
const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
{
assert(batch != 0);
assert(batch % sizeof(float) == 0);
assert(input != NULL);
assert(output != NULL);

// Seed the accumulators: min/max start from the first input element broadcast
// to every lane (so no lane holds a value that is not in the input), sum
// starts from zero.
$if EMIT_MIN:
__m512 vmin0 = _mm512_set1_ps(*input);
$if EMIT_MAX:
__m512 vmax0 = vmin0;
$elif EMIT_MAX:
__m512 vmax0 = _mm512_set1_ps(*input);
$if EMIT_SUM:
__m512 vsum0 = _mm512_setzero_ps();
$for A in range(1, ACCUMULATORS):
$if EMIT_MIN:
__m512 vmin${A} = vmin0;
$if EMIT_MAX:
__m512 vmax${A} = vmax0;
$if EMIT_SUM:
__m512 vsum${A} = vsum0;
$if BATCH_TILE > 16:
for (; batch >= ${BATCH_TILE} * sizeof(float); batch -= ${BATCH_TILE} * sizeof(float)) {
const __m512 vt0 = _mm512_loadu_ps(input);
$for N in range(1, SIMD_TILE):
const __m512 vt${N} = _mm512_loadu_ps(input + ${N * 16});
input += ${BATCH_TILE};

$for N in range(SIMD_TILE):
$if EMIT_MIN:
vmin${N % ACCUMULATORS} = _mm512_min_ps(vmin${N % ACCUMULATORS}, vt${N});
$if EMIT_MAX:
vmax${N % ACCUMULATORS} = _mm512_max_ps(vmax${N % ACCUMULATORS}, vt${N});
$if EMIT_SUM:
vsum${N % ACCUMULATORS} = _mm512_add_ps(vsum${N % ACCUMULATORS}, vt${N});
}
$if ACCUMULATORS > 1:
$ACC_SLICE = 1
$while ACC_SLICE < ACCUMULATORS:
$for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
$if A + ACC_SLICE < ACCUMULATORS:
$if EMIT_MIN:
vmin${A} = _mm512_min_ps(vmin${A}, vmin${A + ACC_SLICE});
$if EMIT_MAX:
vmax${A} = _mm512_max_ps(vmax${A}, vmax${A + ACC_SLICE});
$if EMIT_SUM:
vsum${A} = _mm512_add_ps(vsum${A}, vsum${A + ACC_SLICE});
$ACC_SLICE *= 2
// Consume the remaining full vectors, 16 floats per iteration, into accumulator 0.
for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) {
const __m512 vt = _mm512_loadu_ps(input);
input += 16;

$if EMIT_MIN:
vmin0 = _mm512_min_ps(vmin0, vt);
$if EMIT_MAX:
vmax0 = _mm512_max_ps(vmax0, vt);
$if EMIT_SUM:
vsum0 = _mm512_add_ps(vsum0, vt);
}
// Tail: 1-15 floats are left.
if XNN_UNLIKELY(batch != 0) {
assert(batch >= 1 * sizeof(float));
assert(batch <= 15 * sizeof(float));

// Prepare mask for valid elements (depends on batch).
// `batch` is converted from a byte count to an element count, then the low
// `batch` bits of the mask are set.
batch >>= XNN_LOG2_SIZEOF_FLOAT;
const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1)));

// Zero-masked load: unselected lanes are not read from memory and come back as zero.
const __m512 vt = _mm512_maskz_loadu_ps(vmask, input);

// Merge-masked updates: unselected lanes keep their previous accumulator
// value, so the zeros from the masked load never reach the min/max results.
$if EMIT_MIN:
vmin0 = _mm512_mask_min_ps(vmin0, vmask, vmin0, vt);
$if EMIT_MAX:
vmax0 = _mm512_mask_max_ps(vmax0, vmask, vmax0, vt);
$if EMIT_SUM:
vsum0 = _mm512_mask_add_ps(vsum0, vmask, vsum0, vt);
}
// Horizontal reduction, step 1: fold the upper 256-bit half onto the lower
// half. The upper half is extracted through a cast to double-precision
// vectors and back.
$if EMIT_MIN:
__m256 vmin256 = _mm256_min_ps(_mm512_castps512_ps256(vmin0), _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vmin0), 1)));
$if EMIT_MAX:
__m256 vmax256 = _mm256_max_ps(_mm512_castps512_ps256(vmax0), _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vmax0), 1)));
$if EMIT_SUM:
__m256 vsum256 = _mm256_add_ps(_mm512_castps512_ps256(vsum0), _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vsum0), 1)));
// Step 2: fold the upper 128-bit half onto the lower half.
$if EMIT_MIN:
__m128 vmin = _mm_min_ps(_mm256_castps256_ps128(vmin256), _mm256_extractf128_ps(vmin256, 1));
$if EMIT_MAX:
__m128 vmax = _mm_max_ps(_mm256_castps256_ps128(vmax256), _mm256_extractf128_ps(vmax256, 1));
$if EMIT_SUM:
__m128 vsum = _mm_add_ps(_mm256_castps256_ps128(vsum256), _mm256_extractf128_ps(vsum256, 1));
// Step 3: fold lanes 2-3 onto lanes 0-1.
$if EMIT_MIN:
vmin = _mm_min_ps(vmin, _mm_movehl_ps(vmin, vmin));
$if EMIT_MAX:
vmax = _mm_max_ps(vmax, _mm_movehl_ps(vmax, vmax));
$if EMIT_SUM:
vsum = _mm_add_ps(vsum, _mm_movehl_ps(vsum, vsum));
// Step 4: fold lane 1 onto lane 0, which then holds the final scalar.
$if EMIT_MIN:
vmin = _mm_min_ss(vmin, _mm_movehdup_ps(vmin));
$if EMIT_MAX:
vmax = _mm_max_ss(vmax, _mm_movehdup_ps(vmax));
$if EMIT_SUM:
vsum = _mm_add_ss(vsum, _mm_movehdup_ps(vsum));
// Store the results. Kernels that compute a minimum write min to output[0],
// then max to output[1] and sum to output[2] when those are computed; the
// max-only kernel writes max to output[0].
// NOTE(review): this assumes the max/sum stores are nested under the min
// branch of the template - confirm against the template's indentation.
$if EMIT_MIN:
_mm_store_ss(output, vmin);
$if EMIT_MAX:
_mm_store_ss(output + 1, vmax);
$if EMIT_SUM:
_mm_store_ss(output + 2, vsum);
$elif EMIT_MAX:
_mm_store_ss(output, vmax);
}

0 comments on commit c303b6c

Please sign in to comment.