-
Notifications
You must be signed in to change notification settings - Fork 326
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
F32-RMINMAXSUM - add reduction sum to rminmax
PiperOrigin-RevId: 634176068
- Loading branch information
1 parent
7fabcac
commit c303b6c
Showing
26 changed files
with
1,279 additions
and
187 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
// Copyright 2023 Google LLC | ||
// | ||
// This source code is licensed under the BSD-style license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
$assert BATCH_TILE % 8 == 0 | ||
$assert BATCH_TILE >= 8 | ||
$SIMD_TILE = BATCH_TILE // 8 | ||
$assert ACCUMULATORS <= SIMD_TILE | ||
$assert OP in ["MAX", "MIN", "MINMAX", "MINMAXSUM"] | ||
#include <assert.h> | ||
|
||
#include <immintrin.h> | ||
|
||
#include <xnnpack/common.h> | ||
#include <xnnpack/reduce.h> | ||
|
||
|
||
// AVX reduction micro-kernel template: scans `batch` bytes of floats starting
// at `input` and computes the minimum, maximum and/or sum selected by OP.
// Results go to `output`: a kernel that emits a minimum stores it at output[0],
// followed by the maximum at output[1] and the sum at output[2] when those are
// emitted; a max-only kernel stores the maximum at output[0].
// `params` is read only for the lane mask applied to a partial final vector.
$ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS | ||
$EMIT_MIN = "MIN" in OP | ||
$EMIT_MAX = "MAX" in OP | ||
$EMIT_SUM = "SUM" in OP | ||
void xnn_f32_r${OP.lower()}_ukernel__avx_u${BATCH_TILE}${ACC_SUFFIX}( | ||
size_t batch, | ||
const float* input, | ||
float* output, | ||
const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)]) | ||
{ | ||
assert(batch != 0); | ||
assert(batch % sizeof(float) == 0); | ||
assert(input != NULL); | ||
assert(output != NULL); | ||
|
||
// Seed the min/max accumulators with the first input element broadcast to all
// lanes, so no lane starts from a value that is absent from the input. The sum
// accumulator starts at zero.
$if EMIT_MIN: | ||
__m256 vmin0 = _mm256_broadcast_ss(input); | ||
$if EMIT_MAX: | ||
__m256 vmax0 = vmin0; | ||
$elif EMIT_MAX: | ||
__m256 vmax0 = _mm256_broadcast_ss(input); | ||
$if EMIT_SUM: | ||
__m256 vsum0 = _mm256_setzero_ps(); | ||
// Additional accumulator sets start as copies of accumulator 0.
$for A in range(1, ACCUMULATORS): | ||
$if EMIT_MIN: | ||
__m256 vmin${A} = vmin0; | ||
$if EMIT_MAX: | ||
__m256 vmax${A} = vmax0; | ||
$if EMIT_SUM: | ||
__m256 vsum${A} = vsum0; | ||
// Main loop: consume BATCH_TILE floats per iteration, distributing the loaded
// vectors round-robin over the accumulator sets.
$if BATCH_TILE > 8: | ||
for (; batch >= ${BATCH_TILE} * sizeof(float); batch -= ${BATCH_TILE} * sizeof(float)) { | ||
const __m256 vt0 = _mm256_loadu_ps(input); | ||
$for N in range(1, SIMD_TILE): | ||
const __m256 vt${N} = _mm256_loadu_ps(input + ${N * 8}); | ||
input += ${BATCH_TILE}; | ||
|
||
$for N in range(SIMD_TILE): | ||
$if EMIT_MIN: | ||
vmin${N % ACCUMULATORS} = _mm256_min_ps(vmin${N % ACCUMULATORS}, vt${N}); | ||
$if EMIT_MAX: | ||
vmax${N % ACCUMULATORS} = _mm256_max_ps(vmax${N % ACCUMULATORS}, vt${N}); | ||
$if EMIT_SUM: | ||
vsum${N % ACCUMULATORS} = _mm256_add_ps(vsum${N % ACCUMULATORS}, vt${N}); | ||
} | ||
// Fold the extra accumulator sets pairwise until everything is in set 0.
$if ACCUMULATORS > 1: | ||
$ACC_SLICE = 1 | ||
$while ACC_SLICE < ACCUMULATORS: | ||
$for A in range(0, ACCUMULATORS, ACC_SLICE * 2): | ||
$if A + ACC_SLICE < ACCUMULATORS: | ||
$if EMIT_MIN: | ||
vmin${A} = _mm256_min_ps(vmin${A}, vmin${A + ACC_SLICE}); | ||
$if EMIT_MAX: | ||
vmax${A} = _mm256_max_ps(vmax${A}, vmax${A + ACC_SLICE}); | ||
$if EMIT_SUM: | ||
vsum${A} = _mm256_add_ps(vsum${A}, vsum${A + ACC_SLICE}); | ||
$ACC_SLICE *= 2 | ||
// Remainder loop: one 8-float vector at a time into accumulator set 0.
for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) { | ||
const __m256 vt = _mm256_loadu_ps(input); | ||
input += 8; | ||
|
||
$if EMIT_MIN: | ||
vmin0 = _mm256_min_ps(vmin0, vt); | ||
$if EMIT_MAX: | ||
vmax0 = _mm256_max_ps(vmax0, vt); | ||
$if EMIT_SUM: | ||
vsum0 = _mm256_add_ps(vsum0, vt); | ||
} | ||
// Tail: 1 to 7 floats remain.
if XNN_UNLIKELY(batch != 0) { | ||
assert(batch >= 1 * sizeof(float)); | ||
assert(batch <= 7 * sizeof(float)); | ||
// Load the lane mask from a window that starts `batch` bytes before
// mask_table[7]. NOTE(review): presumably the table holds 7 enabled entries
// followed by disabled ones, so the window enables exactly the remaining
// lanes -- confirm against the params initializer.
const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx.mask_table[7] - batch)); | ||
|
||
const __m256 vt = _mm256_maskload_ps(input, vmask); | ||
|
||
// Masked-off lanes of vt are not real input, so the blend keeps the previous
// accumulator value in those lanes and takes the updated value elsewhere.
$if EMIT_MIN: | ||
vmin0 = _mm256_blendv_ps(vmin0, _mm256_min_ps(vmin0, vt), _mm256_castsi256_ps(vmask)); | ||
$if EMIT_MAX: | ||
vmax0 = _mm256_blendv_ps(vmax0, _mm256_max_ps(vmax0, vt), _mm256_castsi256_ps(vmask)); | ||
$if EMIT_SUM: | ||
vsum0 = _mm256_blendv_ps(vsum0, _mm256_add_ps(vsum0, vt), _mm256_castsi256_ps(vmask)); | ||
} | ||
// Horizontal reduction, step 1: combine the low and high 128-bit halves.
$if EMIT_MIN: | ||
__m128 vmin = _mm_min_ps(_mm256_castps256_ps128(vmin0), _mm256_extractf128_ps(vmin0, 1)); | ||
$if EMIT_MAX: | ||
__m128 vmax = _mm_max_ps(_mm256_castps256_ps128(vmax0), _mm256_extractf128_ps(vmax0, 1)); | ||
$if EMIT_SUM: | ||
__m128 vsum = _mm_add_ps(_mm256_castps256_ps128(vsum0), _mm256_extractf128_ps(vsum0, 1)); | ||
|
||
// Step 2: combine the upper two lanes into the lower two.
$if EMIT_MIN: | ||
vmin = _mm_min_ps(vmin, _mm_movehl_ps(vmin, vmin)); | ||
$if EMIT_MAX: | ||
vmax = _mm_max_ps(vmax, _mm_movehl_ps(vmax, vmax)); | ||
$if EMIT_SUM: | ||
vsum = _mm_add_ps(vsum, _mm_movehl_ps(vsum, vsum)); | ||
// Step 3: combine lane 1 into lane 0, leaving the final scalar in lane 0.
$if EMIT_MIN: | ||
vmin = _mm_min_ss(vmin, _mm_movehdup_ps(vmin)); | ||
$if EMIT_MAX: | ||
vmax = _mm_max_ss(vmax, _mm_movehdup_ps(vmax)); | ||
$if EMIT_SUM: | ||
vsum = _mm_add_ss(vsum, _mm_movehdup_ps(vsum)); | ||
// Store lane 0 of each emitted result into its output slot.
$if EMIT_MIN: | ||
_mm_store_ss(output, vmin); | ||
$if EMIT_MAX: | ||
_mm_store_ss(output + 1, vmax); | ||
$if EMIT_SUM: | ||
_mm_store_ss(output + 2, vsum); | ||
$elif EMIT_MAX: | ||
_mm_store_ss(output, vmax); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
// Copyright 2023 Google LLC | ||
// | ||
// This source code is licensed under the BSD-style license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
$assert BATCH_TILE % 16 == 0 | ||
$assert BATCH_TILE >= 16 | ||
$SIMD_TILE = BATCH_TILE // 16 | ||
$assert ACCUMULATORS <= SIMD_TILE | ||
$assert OP in ["MAX", "MIN", "MINMAX", "MINMAXSUM"] | ||
#include <assert.h> | ||
|
||
#include <immintrin.h> | ||
|
||
#include <xnnpack/common.h> | ||
#include <xnnpack/reduce.h> | ||
|
||
|
||
// AVX-512F reduction micro-kernel template: scans `batch` bytes of floats
// starting at `input` and computes the minimum, maximum and/or sum selected by
// OP. Results go to `output`: a kernel that emits a minimum stores it at
// output[0], followed by the maximum at output[1] and the sum at output[2] when
// those are emitted; a max-only kernel stores the maximum at output[0].
// `params` is not referenced in this body.
$ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS | ||
$EMIT_MIN = "MIN" in OP | ||
$EMIT_MAX = "MAX" in OP | ||
$EMIT_SUM = "SUM" in OP | ||
void xnn_f32_r${OP.lower()}_ukernel__avx512f_u${BATCH_TILE}${ACC_SUFFIX}( | ||
size_t batch, | ||
const float* input, | ||
float* output, | ||
const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)]) | ||
{ | ||
assert(batch != 0); | ||
assert(batch % sizeof(float) == 0); | ||
assert(input != NULL); | ||
assert(output != NULL); | ||
|
||
// Seed the min/max accumulators with the first input element replicated to
// all lanes, so no lane starts from a value that is absent from the input.
// The sum accumulator starts at zero.
$if EMIT_MIN: | ||
__m512 vmin0 = _mm512_set1_ps(*input); | ||
$if EMIT_MAX: | ||
__m512 vmax0 = vmin0; | ||
$elif EMIT_MAX: | ||
__m512 vmax0 = _mm512_set1_ps(*input); | ||
$if EMIT_SUM: | ||
__m512 vsum0 = _mm512_setzero_ps(); | ||
// Additional accumulator sets start as copies of accumulator 0.
$for A in range(1, ACCUMULATORS): | ||
$if EMIT_MIN: | ||
__m512 vmin${A} = vmin0; | ||
$if EMIT_MAX: | ||
__m512 vmax${A} = vmax0; | ||
$if EMIT_SUM: | ||
__m512 vsum${A} = vsum0; | ||
// Main loop: consume BATCH_TILE floats per iteration, distributing the loaded
// vectors round-robin over the accumulator sets.
$if BATCH_TILE > 16: | ||
for (; batch >= ${BATCH_TILE} * sizeof(float); batch -= ${BATCH_TILE} * sizeof(float)) { | ||
const __m512 vt0 = _mm512_loadu_ps(input); | ||
$for N in range(1, SIMD_TILE): | ||
const __m512 vt${N} = _mm512_loadu_ps(input + ${N * 16}); | ||
input += ${BATCH_TILE}; | ||
|
||
$for N in range(SIMD_TILE): | ||
$if EMIT_MIN: | ||
vmin${N % ACCUMULATORS} = _mm512_min_ps(vmin${N % ACCUMULATORS}, vt${N}); | ||
$if EMIT_MAX: | ||
vmax${N % ACCUMULATORS} = _mm512_max_ps(vmax${N % ACCUMULATORS}, vt${N}); | ||
$if EMIT_SUM: | ||
vsum${N % ACCUMULATORS} = _mm512_add_ps(vsum${N % ACCUMULATORS}, vt${N}); | ||
} | ||
// Fold the extra accumulator sets pairwise until everything is in set 0.
$if ACCUMULATORS > 1: | ||
$ACC_SLICE = 1 | ||
$while ACC_SLICE < ACCUMULATORS: | ||
$for A in range(0, ACCUMULATORS, ACC_SLICE * 2): | ||
$if A + ACC_SLICE < ACCUMULATORS: | ||
$if EMIT_MIN: | ||
vmin${A} = _mm512_min_ps(vmin${A}, vmin${A + ACC_SLICE}); | ||
$if EMIT_MAX: | ||
vmax${A} = _mm512_max_ps(vmax${A}, vmax${A + ACC_SLICE}); | ||
$if EMIT_SUM: | ||
vsum${A} = _mm512_add_ps(vsum${A}, vsum${A + ACC_SLICE}); | ||
$ACC_SLICE *= 2 | ||
// Remainder loop: one 16-float vector at a time into accumulator set 0.
for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { | ||
const __m512 vt = _mm512_loadu_ps(input); | ||
input += 16; | ||
|
||
$if EMIT_MIN: | ||
vmin0 = _mm512_min_ps(vmin0, vt); | ||
$if EMIT_MAX: | ||
vmax0 = _mm512_max_ps(vmax0, vt); | ||
$if EMIT_SUM: | ||
vsum0 = _mm512_add_ps(vsum0, vt); | ||
} | ||
// Tail: 1 to 15 floats remain.
if XNN_UNLIKELY(batch != 0) { | ||
assert(batch >= 1 * sizeof(float)); | ||
assert(batch <= 15 * sizeof(float)); | ||
|
||
// The shift turns the byte count into an element count (assuming
// XNN_LOG2_SIZEOF_FLOAT is log2(sizeof(float)) -- defined outside this
// file); the mask then has that many low bits set.
// Prepare mask for valid elements (depends on batch). | ||
batch >>= XNN_LOG2_SIZEOF_FLOAT; | ||
const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); | ||
|
||
// Lanes outside the mask are zero-filled by the load rather than read.
const __m512 vt = _mm512_maskz_loadu_ps(vmask, input); | ||
|
||
// The masked operations pass the current accumulator as the source operand,
// so lanes outside the mask keep their previous value.
$if EMIT_MIN: | ||
vmin0 = _mm512_mask_min_ps(vmin0, vmask, vmin0, vt); | ||
$if EMIT_MAX: | ||
vmax0 = _mm512_mask_max_ps(vmax0, vmask, vmax0, vt); | ||
$if EMIT_SUM: | ||
vsum0 = _mm512_mask_add_ps(vsum0, vmask, vsum0, vt); | ||
} | ||
// Horizontal reduction, step 1: combine the low and high 256-bit halves.
$if EMIT_MIN: | ||
__m256 vmin256 = _mm256_min_ps(_mm512_castps512_ps256(vmin0), _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vmin0), 1))); | ||
$if EMIT_MAX: | ||
__m256 vmax256 = _mm256_max_ps(_mm512_castps512_ps256(vmax0), _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vmax0), 1))); | ||
$if EMIT_SUM: | ||
__m256 vsum256 = _mm256_add_ps(_mm512_castps512_ps256(vsum0), _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vsum0), 1))); | ||
// Step 2: combine the low and high 128-bit halves.
$if EMIT_MIN: | ||
__m128 vmin = _mm_min_ps(_mm256_castps256_ps128(vmin256), _mm256_extractf128_ps(vmin256, 1)); | ||
$if EMIT_MAX: | ||
__m128 vmax = _mm_max_ps(_mm256_castps256_ps128(vmax256), _mm256_extractf128_ps(vmax256, 1)); | ||
$if EMIT_SUM: | ||
__m128 vsum = _mm_add_ps(_mm256_castps256_ps128(vsum256), _mm256_extractf128_ps(vsum256, 1)); | ||
// Step 3: combine the upper two lanes into the lower two.
$if EMIT_MIN: | ||
vmin = _mm_min_ps(vmin, _mm_movehl_ps(vmin, vmin)); | ||
$if EMIT_MAX: | ||
vmax = _mm_max_ps(vmax, _mm_movehl_ps(vmax, vmax)); | ||
$if EMIT_SUM: | ||
vsum = _mm_add_ps(vsum, _mm_movehl_ps(vsum, vsum)); | ||
// Step 4: combine lane 1 into lane 0, leaving the final scalar in lane 0.
$if EMIT_MIN: | ||
vmin = _mm_min_ss(vmin, _mm_movehdup_ps(vmin)); | ||
$if EMIT_MAX: | ||
vmax = _mm_max_ss(vmax, _mm_movehdup_ps(vmax)); | ||
$if EMIT_SUM: | ||
vsum = _mm_add_ss(vsum, _mm_movehdup_ps(vsum)); | ||
// Store lane 0 of each emitted result into its output slot.
$if EMIT_MIN: | ||
_mm_store_ss(output, vmin); | ||
$if EMIT_MAX: | ||
_mm_store_ss(output + 1, vmax); | ||
$if EMIT_SUM: | ||
_mm_store_ss(output + 2, vsum); | ||
$elif EMIT_MAX: | ||
_mm_store_ss(output, vmax); | ||
} |
Oops, something went wrong.