F16-RMINMAX microkernels using AVX512 FP16 arithmetic
- F16-RMINMAX and F16-RMIN microkernels using FP16 arithmetic

PiperOrigin-RevId: 630721869
fbarchard authored and xnnpack-bot committed May 6, 2024
1 parent 0994a40 commit f49c9ce
Showing 21 changed files with 1,277 additions and 7 deletions.
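
For context, each of the new reducers consumes a batch of IEEE binary16 values passed as raw uint16_t storage (batch is given in bytes) and writes a single fp16 minimum (RMIN) or a min/max pair (RMINMAX). A minimal scalar sketch of that contract, assuming the fp16_ieee_to_fp32_value/fp16_ieee_from_fp32_value conversion helpers from the bundled FP16 library, is shown below; it is illustrative only and not part of this commit.

#include <stddef.h>
#include <stdint.h>
#include <fp16.h>  // fp16_ieee_to_fp32_value / fp16_ieee_from_fp32_value (assumed available)

// Scalar reference for the reduction computed by the new microkernels:
// output[0] receives the minimum, output[1] the maximum, both as fp16 bits.
static void f16_rminmax_scalar_reference(size_t batch, const uint16_t* input, uint16_t output[2])
{
  float vmin = fp16_ieee_to_fp32_value(input[0]);
  float vmax = vmin;
  for (size_t n = batch / sizeof(uint16_t); n != 0; n--) {
    const float vt = fp16_ieee_to_fp32_value(*input++);
    vmin = vt < vmin ? vt : vmin;
    vmax = vt > vmax ? vt : vmax;
  }
  output[0] = fp16_ieee_from_fp32_value(vmin);
  output[1] = fp16_ieee_from_fp32_value(vmax);
}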
33 changes: 33 additions & 0 deletions bench/f16-rmin.cc
@@ -98,6 +98,39 @@ static void f16_rmin(
->UseRealTime();
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64

#if XNN_ENABLE_AVX512FP16 && (XNN_ARCH_X86 || XNN_ARCH_X86_64)
BENCHMARK_CAPTURE(f16_rmin, avx512fp16_u32,
xnn_f16_rmin_ukernel__avx512fp16_u32,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512FP16)
->Apply(benchmark::utils::ReductionParameters<uint16_t>)
->UseRealTime();
BENCHMARK_CAPTURE(f16_rmin, avx512fp16_u64_acc2,
xnn_f16_rmin_ukernel__avx512fp16_u64_acc2,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512FP16)
->Apply(benchmark::utils::ReductionParameters<uint16_t>)
->UseRealTime();
BENCHMARK_CAPTURE(f16_rmin, avx512fp16_u96_acc3,
xnn_f16_rmin_ukernel__avx512fp16_u96_acc3,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512FP16)
->Apply(benchmark::utils::ReductionParameters<uint16_t>)
->UseRealTime();
BENCHMARK_CAPTURE(f16_rmin, avx512fp16_u128_acc2,
xnn_f16_rmin_ukernel__avx512fp16_u128_acc2,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512FP16)
->Apply(benchmark::utils::ReductionParameters<uint16_t>)
->UseRealTime();
BENCHMARK_CAPTURE(f16_rmin, avx512fp16_u128_acc4,
xnn_f16_rmin_ukernel__avx512fp16_u128_acc4,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512FP16)
->Apply(benchmark::utils::ReductionParameters<uint16_t>)
->UseRealTime();
#endif // XNN_ENABLE_AVX512FP16 && (XNN_ARCH_X86 || XNN_ARCH_X86_64)

#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f16_rmin, avx512skx_u16,
xnn_f16_rmin_ukernel__avx512skx_u16,
33 changes: 33 additions & 0 deletions bench/f16-rminmax.cc
@@ -98,6 +98,39 @@ static void f16_rminmax(
->UseRealTime();
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64

#if XNN_ENABLE_AVX512FP16 && (XNN_ARCH_X86 || XNN_ARCH_X86_64)
BENCHMARK_CAPTURE(f16_rminmax, avx512fp16_u32,
xnn_f16_rminmax_ukernel__avx512fp16_u32,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512FP16)
->Apply(benchmark::utils::ReductionParameters<uint16_t>)
->UseRealTime();
BENCHMARK_CAPTURE(f16_rminmax, avx512fp16_u64_acc2,
xnn_f16_rminmax_ukernel__avx512fp16_u64_acc2,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512FP16)
->Apply(benchmark::utils::ReductionParameters<uint16_t>)
->UseRealTime();
BENCHMARK_CAPTURE(f16_rminmax, avx512fp16_u96_acc3,
xnn_f16_rminmax_ukernel__avx512fp16_u96_acc3,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512FP16)
->Apply(benchmark::utils::ReductionParameters<uint16_t>)
->UseRealTime();
BENCHMARK_CAPTURE(f16_rminmax, avx512fp16_u128_acc2,
xnn_f16_rminmax_ukernel__avx512fp16_u128_acc2,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512FP16)
->Apply(benchmark::utils::ReductionParameters<uint16_t>)
->UseRealTime();
BENCHMARK_CAPTURE(f16_rminmax, avx512fp16_u128_acc4,
xnn_f16_rminmax_ukernel__avx512fp16_u128_acc4,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512FP16)
->Apply(benchmark::utils::ReductionParameters<uint16_t>)
->UseRealTime();
#endif // XNN_ENABLE_AVX512FP16 && (XNN_ARCH_X86 || XNN_ARCH_X86_64)

#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f16_rminmax, avx512skx_u16,
xnn_f16_rminmax_ukernel__avx512skx_u16,
12 changes: 11 additions & 1 deletion cmake/microkernels.cmake
@@ -1548,7 +1548,17 @@ SET(ALL_AVX512FP16_MICROKERNEL_SRCS
src/f16-rminmax/gen/f16-rmax-avx512fp16-u64-acc2.c
src/f16-rminmax/gen/f16-rmax-avx512fp16-u96-acc3.c
src/f16-rminmax/gen/f16-rmax-avx512fp16-u128-acc2.c
src/f16-rminmax/gen/f16-rmax-avx512fp16-u128-acc4.c)
src/f16-rminmax/gen/f16-rmax-avx512fp16-u128-acc4.c
src/f16-rminmax/gen/f16-rmin-avx512fp16-u32.c
src/f16-rminmax/gen/f16-rmin-avx512fp16-u64-acc2.c
src/f16-rminmax/gen/f16-rmin-avx512fp16-u96-acc3.c
src/f16-rminmax/gen/f16-rmin-avx512fp16-u128-acc2.c
src/f16-rminmax/gen/f16-rmin-avx512fp16-u128-acc4.c
src/f16-rminmax/gen/f16-rminmax-avx512fp16-u32.c
src/f16-rminmax/gen/f16-rminmax-avx512fp16-u64-acc2.c
src/f16-rminmax/gen/f16-rminmax-avx512fp16-u96-acc3.c
src/f16-rminmax/gen/f16-rminmax-avx512fp16-u128-acc2.c
src/f16-rminmax/gen/f16-rminmax-avx512fp16-u128-acc4.c)

SET(ALL_AVX512SKX_MICROKERNEL_SRCS
src/f16-f32-vcvt/gen/f16-f32-vcvt-avx512skx-u16.c
10 changes: 10 additions & 0 deletions microkernels.bzl
@@ -1550,6 +1550,16 @@ ALL_AVX512FP16_MICROKERNEL_SRCS = [
"src/f16-rminmax/gen/f16-rmax-avx512fp16-u96-acc3.c",
"src/f16-rminmax/gen/f16-rmax-avx512fp16-u128-acc2.c",
"src/f16-rminmax/gen/f16-rmax-avx512fp16-u128-acc4.c",
"src/f16-rminmax/gen/f16-rmin-avx512fp16-u32.c",
"src/f16-rminmax/gen/f16-rmin-avx512fp16-u64-acc2.c",
"src/f16-rminmax/gen/f16-rmin-avx512fp16-u96-acc3.c",
"src/f16-rminmax/gen/f16-rmin-avx512fp16-u128-acc2.c",
"src/f16-rminmax/gen/f16-rmin-avx512fp16-u128-acc4.c",
"src/f16-rminmax/gen/f16-rminmax-avx512fp16-u32.c",
"src/f16-rminmax/gen/f16-rminmax-avx512fp16-u64-acc2.c",
"src/f16-rminmax/gen/f16-rminmax-avx512fp16-u96-acc3.c",
"src/f16-rminmax/gen/f16-rminmax-avx512fp16-u128-acc2.c",
"src/f16-rminmax/gen/f16-rminmax-avx512fp16-u128-acc4.c",
]

ALL_AVX512SKX_MICROKERNEL_SRCS = [
12 changes: 12 additions & 0 deletions scripts/generate-f16-rminmax.sh
@@ -30,6 +30,18 @@ tools/xngen src/f16-rminmax/avx512fp16.c.in -D BATCH_TILE=96 -D ACCUMULATORS=3
tools/xngen src/f16-rminmax/avx512fp16.c.in -D BATCH_TILE=128 -D ACCUMULATORS=2 -D OP=MAX -o src/f16-rminmax/gen/f16-rmax-avx512fp16-u128-acc2.c &
tools/xngen src/f16-rminmax/avx512fp16.c.in -D BATCH_TILE=128 -D ACCUMULATORS=4 -D OP=MAX -o src/f16-rminmax/gen/f16-rmax-avx512fp16-u128-acc4.c &

tools/xngen src/f16-rminmax/avx512fp16.c.in -D BATCH_TILE=32 -D ACCUMULATORS=1 -D OP=MIN -o src/f16-rminmax/gen/f16-rmin-avx512fp16-u32.c &
tools/xngen src/f16-rminmax/avx512fp16.c.in -D BATCH_TILE=64 -D ACCUMULATORS=2 -D OP=MIN -o src/f16-rminmax/gen/f16-rmin-avx512fp16-u64-acc2.c &
tools/xngen src/f16-rminmax/avx512fp16.c.in -D BATCH_TILE=96 -D ACCUMULATORS=3 -D OP=MIN -o src/f16-rminmax/gen/f16-rmin-avx512fp16-u96-acc3.c &
tools/xngen src/f16-rminmax/avx512fp16.c.in -D BATCH_TILE=128 -D ACCUMULATORS=2 -D OP=MIN -o src/f16-rminmax/gen/f16-rmin-avx512fp16-u128-acc2.c &
tools/xngen src/f16-rminmax/avx512fp16.c.in -D BATCH_TILE=128 -D ACCUMULATORS=4 -D OP=MIN -o src/f16-rminmax/gen/f16-rmin-avx512fp16-u128-acc4.c &

tools/xngen src/f16-rminmax/avx512fp16.c.in -D BATCH_TILE=32 -D ACCUMULATORS=1 -D OP=MINMAX -o src/f16-rminmax/gen/f16-rminmax-avx512fp16-u32.c &
tools/xngen src/f16-rminmax/avx512fp16.c.in -D BATCH_TILE=64 -D ACCUMULATORS=2 -D OP=MINMAX -o src/f16-rminmax/gen/f16-rminmax-avx512fp16-u64-acc2.c &
tools/xngen src/f16-rminmax/avx512fp16.c.in -D BATCH_TILE=96 -D ACCUMULATORS=3 -D OP=MINMAX -o src/f16-rminmax/gen/f16-rminmax-avx512fp16-u96-acc3.c &
tools/xngen src/f16-rminmax/avx512fp16.c.in -D BATCH_TILE=128 -D ACCUMULATORS=2 -D OP=MINMAX -o src/f16-rminmax/gen/f16-rminmax-avx512fp16-u128-acc2.c &
tools/xngen src/f16-rminmax/avx512fp16.c.in -D BATCH_TILE=128 -D ACCUMULATORS=4 -D OP=MINMAX -o src/f16-rminmax/gen/f16-rminmax-avx512fp16-u128-acc4.c &

################################## x86 AVX512SKX #################################
tools/xngen src/f16-rminmax/avx512skx.c.in -D BATCH_TILE=16 -D ACCUMULATORS=1 -D OP=MAX -o src/f16-rminmax/gen/f16-rmax-avx512skx-u16.c &
tools/xngen src/f16-rminmax/avx512skx.c.in -D BATCH_TILE=32 -D ACCUMULATORS=2 -D OP=MAX -o src/f16-rminmax/gen/f16-rmax-avx512skx-u32-acc2.c &
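
In the xngen invocations above, BATCH_TILE is the number of fp16 elements consumed per main-loop iteration (the uN suffix in the generated kernel names) and ACCUMULATORS is the number of independent running results kept before they are folded together (the accM suffix). A hedged scalar sketch of why extra accumulators help (they break the serial dependency between consecutive min updates) follows; the function name is illustrative and not generated by this script.

#include <stddef.h>

// Scalar analogue of an ACCUMULATORS=2 reduction: two independent running
// minima are updated in the unrolled loop and folded once afterwards.
static float rmin_acc2_sketch(size_t n, const float* x) {
  float acc0 = x[0];
  float acc1 = x[0];
  size_t i = 0;
  for (; i + 2 <= n; i += 2) {
    acc0 = x[i + 0] < acc0 ? x[i + 0] : acc0;  // accumulator 0
    acc1 = x[i + 1] < acc1 ? x[i + 1] : acc1;  // accumulator 1
  }
  float acc = acc0 < acc1 ? acc0 : acc1;       // fold the accumulators
  for (; i < n; i++) {                         // remainder
    acc = x[i] < acc ? x[i] : acc;
  }
  return acc;
}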
24 changes: 18 additions & 6 deletions src/amalgam/gen/avx512skx.c
@@ -1313,10 +1313,18 @@ void xnn_f32_vtanh_ukernel__avx512skx_expm1minus_rr1_lut4_p4h3ts_perm_div_u64(
const __m512 vx3 = _mm512_loadu_ps(input + 48);
input += 64;

const __m512 vz0 = _mm512_range_ps(vsat_cutoff, vx0, 0xA);
const __m512 vz1 = _mm512_range_ps(vsat_cutoff, vx1, 0xA);
const __m512 vz2 = _mm512_range_ps(vsat_cutoff, vx2, 0xA);
const __m512 vz3 = _mm512_range_ps(vsat_cutoff, vx3, 0xA);
const __mmask16 vnan_mask0 = _mm512_cmp_ps_mask(vx0, vx0, _CMP_EQ_OQ);
__m512 vz0 = _mm512_range_ps(vsat_cutoff, vx0, 0xA);
vz0 = _mm512_mask_blend_ps(vnan_mask0, vx0, vz0);
const __mmask16 vnan_mask1 = _mm512_cmp_ps_mask(vx1, vx1, _CMP_EQ_OQ);
__m512 vz1 = _mm512_range_ps(vsat_cutoff, vx1, 0xA);
vz1 = _mm512_mask_blend_ps(vnan_mask1, vx1, vz1);
const __mmask16 vnan_mask2 = _mm512_cmp_ps_mask(vx2, vx2, _CMP_EQ_OQ);
__m512 vz2 = _mm512_range_ps(vsat_cutoff, vx2, 0xA);
vz2 = _mm512_mask_blend_ps(vnan_mask2, vx2, vz2);
const __mmask16 vnan_mask3 = _mm512_cmp_ps_mask(vx3, vx3, _CMP_EQ_OQ);
__m512 vz3 = _mm512_range_ps(vsat_cutoff, vx3, 0xA);
vz3 = _mm512_mask_blend_ps(vnan_mask3, vx3, vz3);
__m512 vn0 = _mm512_fmadd_ps(vz0, vminus_log2e, vmagic_bias);
__m512 vn1 = _mm512_fmadd_ps(vz1, vminus_log2e, vmagic_bias);
__m512 vn2 = _mm512_fmadd_ps(vz2, vminus_log2e, vmagic_bias);
@@ -1399,7 +1407,9 @@ void xnn_f32_vtanh_ukernel__avx512skx_expm1minus_rr1_lut4_p4h3ts_perm_div_u64(
const __m512 vx = _mm512_loadu_ps(input);
input += 16;

const __m512 vz = _mm512_range_ps(vsat_cutoff, vx, 0xA);
const __mmask16 vnan_mask = _mm512_cmp_ps_mask(vx, vx, _CMP_EQ_OQ);
__m512 vz = _mm512_range_ps(vsat_cutoff, vx, 0xA);
vz = _mm512_mask_blend_ps(vnan_mask, vx, vz);
__m512 vn = _mm512_fmadd_ps(vz, vminus_log2e, vmagic_bias);

const __m512i ve = _mm512_slli_epi32(_mm512_castps_si512(vn), 21);
@@ -1438,7 +1448,9 @@ void xnn_f32_vtanh_ukernel__avx512skx_expm1minus_rr1_lut4_p4h3ts_perm_div_u64(

const __m512 vx = _mm512_maskz_loadu_ps(vmask, input);

const __m512 vz = _mm512_range_ps(vsat_cutoff, vx, 0xA);
const __mmask16 vnan_mask = _mm512_cmp_ps_mask(vx, vx, _CMP_EQ_OQ);
__m512 vz = _mm512_range_ps(vsat_cutoff, vx, 0xA);
vz = _mm512_mask_blend_ps(vnan_mask, vx, vz);
__m512 vn = _mm512_fmadd_ps(vz, vminus_log2e, vmagic_bias);

const __m512i ve = _mm512_slli_epi32(_mm512_castps_si512(vn), 21);
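
The avx512skx.c change above adds a NaN fix-up around each _mm512_range_ps saturation: a lane-wise self-comparison (_CMP_EQ_OQ, which is false only for NaN) flags the non-NaN lanes, and _mm512_mask_blend_ps keeps the original input in the NaN lanes so that NaN inputs still produce NaN outputs. A self-contained sketch of that pattern, with an illustrative helper name that does not appear in the amalgam:

#include <immintrin.h>

// Apply the vrangeps-based saturation while preserving NaN lanes of vx.
// Mask bits are set for non-NaN lanes (self-comparison with _CMP_EQ_OQ),
// so those lanes take the clamped value and NaN lanes keep the original vx.
static inline __m512 range_ps_preserving_nan(__m512 vsat_cutoff, __m512 vx) {
  const __mmask16 vnotnan_mask = _mm512_cmp_ps_mask(vx, vx, _CMP_EQ_OQ);
  const __m512 vclamped = _mm512_range_ps(vsat_cutoff, vx, 0xA);
  return _mm512_mask_blend_ps(vnotnan_mask, vx, vclamped);
}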
72 changes: 72 additions & 0 deletions src/f16-rminmax/gen/f16-rmin-avx512fp16-u128-acc2.c
@@ -0,0 +1,72 @@
// Auto-generated file. Do not edit!
// Template: src/f16-rminmax/avx512fp16.c.in
// Generator: tools/xngen
//
// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/reduce.h>


void xnn_f16_rmin_ukernel__avx512fp16_u128_acc2(
size_t batch,
const void* input,
void* output,
const union xnn_f16_default_params params[restrict XNN_MIN_ELEMENTS(1)])
{
assert(batch != 0);
assert(batch % sizeof(uint16_t) == 0);
assert(input != NULL);
assert(output != NULL);

#if defined(__clang__)
const uint16_t* i = (const uint16_t*) input;
__m512h vmin0 = _mm512_castsi512_ph(_mm512_set1_epi16(*i));
__m512h vmin1 = vmin0;
for (; batch >= 128 * sizeof(uint16_t); batch -= 128 * sizeof(uint16_t)) {
const __m512h vt0 = _mm512_loadu_ph(i);
const __m512h vt1 = _mm512_loadu_ph((i + 32));
const __m512h vt2 = _mm512_loadu_ph((i + 64));
const __m512h vt3 = _mm512_loadu_ph((i + 96));
i += 128;

vmin0 = _mm512_min_ph(vmin0, vt0);
vmin1 = _mm512_min_ph(vmin1, vt1);
vmin0 = _mm512_min_ph(vmin0, vt2);
vmin1 = _mm512_min_ph(vmin1, vt3);
}
vmin0 = _mm512_min_ph(vmin0, vmin1);
for (; batch >= 32 * sizeof(uint16_t); batch -= 32 * sizeof(uint16_t)) {
const __m512h vt = _mm512_loadu_ph(i);
i += 32;

vmin0 = _mm512_min_ph(vmin0, vt);
}
if XNN_UNLIKELY(batch != 0) {
assert(batch >= 1 * sizeof(uint16_t));
assert(batch <= 31 * sizeof(uint16_t));

// Prepare mask for valid elements (depends on batch).
batch >>= XNN_LOG2_SIZEOF_HALF;
const __mmask32 vmask = _cvtu32_mask32((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1)));

const __m512h vt = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, i));

vmin0 = _mm512_mask_min_ph(vmin0, vmask, vmin0, vt);
}
__m256h vmin256 = _mm256_min_ph(_mm512_castph512_ph256(vmin0), _mm256_castpd_ph(_mm512_extractf64x4_pd(_mm512_castph_pd(vmin0), 1)));
__m128h vmin = _mm_min_ph(_mm256_castph256_ph128(vmin256), _mm_castps_ph(_mm256_extractf128_ps(_mm256_castph_ps(vmin256), 1)));
vmin = _mm_min_ph(vmin, _mm_castps_ph(_mm_movehl_ps(_mm_castph_ps(vmin), _mm_castph_ps(vmin))));
vmin = _mm_min_ph(vmin, _mm_castps_ph(_mm_movehdup_ps(_mm_castph_ps(vmin))));
vmin = _mm_min_sh(vmin, _mm_castsi128_ph(_mm_srli_epi32(_mm_castph_si128(vmin), 16)));

*((uint16_t*) output) = (uint16_t) _mm_extract_epi16(_mm_castph_si128(vmin), 0);
#endif //defined(__clang__)
}
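
The masked tail in this kernel first converts the leftover byte count to an element count, then builds a 32-lane mask from it. A small standalone example of that arithmetic, assuming XNN_LOG2_SIZEOF_HALF equals 1 (an fp16 element is two bytes); the function name is illustrative only:

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

void remainder_mask_example(void) {
  size_t batch = 5 * sizeof(uint16_t);  // 10 bytes, i.e. 5 fp16 elements, left over
  batch >>= 1;                          // bytes -> elements (XNN_LOG2_SIZEOF_HALF == 1)
  const uint32_t mask_bits = (uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1));
  assert(mask_bits == 0x1F);            // lanes 0..4 enabled, lanes 5..31 masked off
}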
76 changes: 76 additions & 0 deletions src/f16-rminmax/gen/f16-rmin-avx512fp16-u128-acc4.c
@@ -0,0 +1,76 @@
// Auto-generated file. Do not edit!
// Template: src/f16-rminmax/avx512fp16.c.in
// Generator: tools/xngen
//
// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/reduce.h>


void xnn_f16_rmin_ukernel__avx512fp16_u128_acc4(
size_t batch,
const void* input,
void* output,
const union xnn_f16_default_params params[restrict XNN_MIN_ELEMENTS(1)])
{
assert(batch != 0);
assert(batch % sizeof(uint16_t) == 0);
assert(input != NULL);
assert(output != NULL);

#if defined(__clang__)
const uint16_t* i = (const uint16_t*) input;
__m512h vmin0 = _mm512_castsi512_ph(_mm512_set1_epi16(*i));
__m512h vmin1 = vmin0;
__m512h vmin2 = vmin0;
__m512h vmin3 = vmin0;
for (; batch >= 128 * sizeof(uint16_t); batch -= 128 * sizeof(uint16_t)) {
const __m512h vt0 = _mm512_loadu_ph(i);
const __m512h vt1 = _mm512_loadu_ph((i + 32));
const __m512h vt2 = _mm512_loadu_ph((i + 64));
const __m512h vt3 = _mm512_loadu_ph((i + 96));
i += 128;

vmin0 = _mm512_min_ph(vmin0, vt0);
vmin1 = _mm512_min_ph(vmin1, vt1);
vmin2 = _mm512_min_ph(vmin2, vt2);
vmin3 = _mm512_min_ph(vmin3, vt3);
}
vmin0 = _mm512_min_ph(vmin0, vmin1);
vmin2 = _mm512_min_ph(vmin2, vmin3);
vmin0 = _mm512_min_ph(vmin0, vmin2);
for (; batch >= 32 * sizeof(uint16_t); batch -= 32 * sizeof(uint16_t)) {
const __m512h vt = _mm512_loadu_ph(i);
i += 32;

vmin0 = _mm512_min_ph(vmin0, vt);
}
if XNN_UNLIKELY(batch != 0) {
assert(batch >= 1 * sizeof(uint16_t));
assert(batch <= 31 * sizeof(uint16_t));

// Prepare mask for valid elements (depends on batch).
batch >>= XNN_LOG2_SIZEOF_HALF;
const __mmask32 vmask = _cvtu32_mask32((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1)));

const __m512h vt = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, i));

vmin0 = _mm512_mask_min_ph(vmin0, vmask, vmin0, vt);
}
__m256h vmin256 = _mm256_min_ph(_mm512_castph512_ph256(vmin0), _mm256_castpd_ph(_mm512_extractf64x4_pd(_mm512_castph_pd(vmin0), 1)));
__m128h vmin = _mm_min_ph(_mm256_castph256_ph128(vmin256), _mm_castps_ph(_mm256_extractf128_ps(_mm256_castph_ps(vmin256), 1)));
vmin = _mm_min_ph(vmin, _mm_castps_ph(_mm_movehl_ps(_mm_castph_ps(vmin), _mm_castph_ps(vmin))));
vmin = _mm_min_ph(vmin, _mm_castps_ph(_mm_movehdup_ps(_mm_castph_ps(vmin))));
vmin = _mm_min_sh(vmin, _mm_castsi128_ph(_mm_srli_epi32(_mm_castph_si128(vmin), 16)));

*((uint16_t*) output) = (uint16_t) _mm_extract_epi16(_mm_castph_si128(vmin), 0);
#endif //defined(__clang__)
}
58 changes: 58 additions & 0 deletions src/f16-rminmax/gen/f16-rmin-avx512fp16-u32.c
@@ -0,0 +1,58 @@
// Auto-generated file. Do not edit!
// Template: src/f16-rminmax/avx512fp16.c.in
// Generator: tools/xngen
//
// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/reduce.h>


void xnn_f16_rmin_ukernel__avx512fp16_u32(
size_t batch,
const void* input,
void* output,
const union xnn_f16_default_params params[restrict XNN_MIN_ELEMENTS(1)])
{
assert(batch != 0);
assert(batch % sizeof(uint16_t) == 0);
assert(input != NULL);
assert(output != NULL);

#if defined(__clang__)
const uint16_t* i = (const uint16_t*) input;
__m512h vmin0 = _mm512_castsi512_ph(_mm512_set1_epi16(*i));
for (; batch >= 32 * sizeof(uint16_t); batch -= 32 * sizeof(uint16_t)) {
const __m512h vt = _mm512_loadu_ph(i);
i += 32;

vmin0 = _mm512_min_ph(vmin0, vt);
}
if XNN_UNLIKELY(batch != 0) {
assert(batch >= 1 * sizeof(uint16_t));
assert(batch <= 31 * sizeof(uint16_t));

// Prepare mask for valid elements (depends on batch).
batch >>= XNN_LOG2_SIZEOF_HALF;
const __mmask32 vmask = _cvtu32_mask32((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1)));

const __m512h vt = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, i));

vmin0 = _mm512_mask_min_ph(vmin0, vmask, vmin0, vt);
}
__m256h vmin256 = _mm256_min_ph(_mm512_castph512_ph256(vmin0), _mm256_castpd_ph(_mm512_extractf64x4_pd(_mm512_castph_pd(vmin0), 1)));
__m128h vmin = _mm_min_ph(_mm256_castph256_ph128(vmin256), _mm_castps_ph(_mm256_extractf128_ps(_mm256_castph_ps(vmin256), 1)));
vmin = _mm_min_ph(vmin, _mm_castps_ph(_mm_movehl_ps(_mm_castph_ps(vmin), _mm_castph_ps(vmin))));
vmin = _mm_min_ph(vmin, _mm_castps_ph(_mm_movehdup_ps(_mm_castph_ps(vmin))));
vmin = _mm_min_sh(vmin, _mm_castsi128_ph(_mm_srli_epi32(_mm_castph_si128(vmin), 16)));

*((uint16_t*) output) = (uint16_t) _mm_extract_epi16(_mm_castph_si128(vmin), 0);
#endif //defined(__clang__)
}
