Commit 5be5be8

gonnet authored and xnnpack-bot committed
Switch to the new rational_9_6 microkernels for f32-vtanh.
PiperOrigin-RevId: 629396487
1 parent 7189ed9 commit 5be5be8
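
For context, the `rational_9_6` kernels added below no longer go through an `exp`-style range reduction; they approximate `tanh` directly on a clamped interval as the ratio of a degree-9 odd polynomial and a degree-6 even polynomial (the coefficient values live in the kernel params and are not part of this diff):

$$\tanh(x) \approx \frac{p(x)}{q(x)}, \qquad p(x) = \alpha_1 x + \alpha_3 x^3 + \alpha_5 x^5 + \alpha_7 x^7 + \alpha_9 x^9, \qquad q(x) = \beta_0 + \beta_2 x^2 + \beta_4 x^4 + \beta_6 x^6.$$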

File tree

9 files changed (+492, -1762 lines)

src/amalgam/gen/avx.c

Lines changed: 100 additions & 250 deletions
Large diffs are not rendered by default.

src/amalgam/gen/avx2.c

Lines changed: 0 additions & 219 deletions
@@ -4,7 +4,6 @@
 // LICENSE file in the root directory of this source tree.
 
 #include <assert.h>
-#include <stddef.h>
 #include <stdint.h>
 #include <string.h>
 
@@ -17,7 +16,6 @@
 #include <xnnpack/intrinsics-polyfill.h>
 #include <xnnpack/lut.h>
 #include <xnnpack/math.h>
-#include <xnnpack/microparams.h>
 #include <xnnpack/packw.h>
 #include <xnnpack/pavgpool.h>
 #include <xnnpack/prefetch.h>
@@ -2972,223 +2970,6 @@ void xnn_f32_vsigmoid_ukernel__avx2_rr1_p5_div_u40(
   }
 }
 
-void xnn_f32_vtanh_ukernel__avx2_expm1minus_rr1_lut4_p4h3ts_perm_div_u32(
-    size_t batch,
-    const float* input,
-    float* output,
-    const union xnn_f32_tanh_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
-{
-  assert(batch != 0);
-  assert(batch % sizeof(float) == 0);
-  assert(input != NULL);
-  assert(output != NULL);
-
-  const __m256 vsign_mask = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.sign_mask);
-  const __m256 vsat_cutoff = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.sat_cutoff);
-  const __m256 vlog2e = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.log2e);
-  const __m256 vmagic_bias = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.magic_bias);
-  const __m256 vtable = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.table);
-  const __m256 vminus_ln2 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.minus_ln2);
-  const __m256 vc4 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.c4);
-  const __m256 vc3 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.c3);
-  const __m256 vc2 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.c2);
-  const __m256 vtwo = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.two);
-  const __m256 vminus_one = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.minus_one);
-
-  for (; batch >= 32 * sizeof(float); batch -= 32 * sizeof(float)) {
-    const __m256 vx0 = _mm256_loadu_ps(input);
-    const __m256 vx1 = _mm256_loadu_ps(input + 8);
-    const __m256 vx2 = _mm256_loadu_ps(input + 16);
-    const __m256 vx3 = _mm256_loadu_ps(input + 24);
-    input += 32;
-
-    __m256 vz0 = _mm256_or_ps(vx0, vsign_mask);
-    __m256 vz1 = _mm256_or_ps(vx1, vsign_mask);
-    __m256 vz2 = _mm256_or_ps(vx2, vsign_mask);
-    __m256 vz3 = _mm256_or_ps(vx3, vsign_mask);
-
-    const __m256 vinvsignx0 = _mm256_xor_ps(vx0, vz0);
-    vz0 = _mm256_max_ps(vsat_cutoff, vz0);
-    const __m256 vinvsignx1 = _mm256_xor_ps(vx1, vz1);
-    vz1 = _mm256_max_ps(vsat_cutoff, vz1);
-    const __m256 vinvsignx2 = _mm256_xor_ps(vx2, vz2);
-    vz2 = _mm256_max_ps(vsat_cutoff, vz2);
-    const __m256 vinvsignx3 = _mm256_xor_ps(vx3, vz3);
-    vz3 = _mm256_max_ps(vsat_cutoff, vz3);
-
-    __m256 vn0 = _mm256_fmadd_ps(vz0, vlog2e, vmagic_bias);
-    __m256 vn1 = _mm256_fmadd_ps(vz1, vlog2e, vmagic_bias);
-    __m256 vn2 = _mm256_fmadd_ps(vz2, vlog2e, vmagic_bias);
-    __m256 vn3 = _mm256_fmadd_ps(vz3, vlog2e, vmagic_bias);
-
-    const __m256i ve0 = _mm256_slli_epi32(_mm256_castps_si256(vn0), 21);
-    const __m256i ve1 = _mm256_slli_epi32(_mm256_castps_si256(vn1), 21);
-    const __m256i ve2 = _mm256_slli_epi32(_mm256_castps_si256(vn2), 21);
-    const __m256i ve3 = _mm256_slli_epi32(_mm256_castps_si256(vn3), 21);
-
-    const __m256i vl0 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn0)));
-    const __m256i vl1 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn1)));
-    const __m256i vl2 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn2)));
-    const __m256i vl3 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn3)));
-
-    const __m256 vs0 = _mm256_castsi256_ps(_mm256_add_epi32(vl0, ve0));
-    const __m256 vs1 = _mm256_castsi256_ps(_mm256_add_epi32(vl1, ve1));
-    const __m256 vs2 = _mm256_castsi256_ps(_mm256_add_epi32(vl2, ve2));
-    const __m256 vs3 = _mm256_castsi256_ps(_mm256_add_epi32(vl3, ve3));
-
-    vn0 = _mm256_sub_ps(vn0, vmagic_bias);
-    vn1 = _mm256_sub_ps(vn1, vmagic_bias);
-    vn2 = _mm256_sub_ps(vn2, vmagic_bias);
-    vn3 = _mm256_sub_ps(vn3, vmagic_bias);
-
-    const __m256 vt0 = _mm256_fmadd_ps(vn0, vminus_ln2, vz0);
-    const __m256 vt1 = _mm256_fmadd_ps(vn1, vminus_ln2, vz1);
-    const __m256 vt2 = _mm256_fmadd_ps(vn2, vminus_ln2, vz2);
-    const __m256 vt3 = _mm256_fmadd_ps(vn3, vminus_ln2, vz3);
-
-    __m256 vp0 = vc4;
-    __m256 vp1 = vc4;
-    __m256 vp2 = vc4;
-    __m256 vp3 = vc4;
-    vp0 = _mm256_fmadd_ps(vp0, vt0, vc3);
-    vp1 = _mm256_fmadd_ps(vp1, vt1, vc3);
-    vp2 = _mm256_fmadd_ps(vp2, vt2, vc3);
-    vp3 = _mm256_fmadd_ps(vp3, vt3, vc3);
-    vp0 = _mm256_fmadd_ps(vp0, vt0, vc2);
-    vp1 = _mm256_fmadd_ps(vp1, vt1, vc2);
-    vp2 = _mm256_fmadd_ps(vp2, vt2, vc2);
-    vp3 = _mm256_fmadd_ps(vp3, vt3, vc2);
-    vp0 = _mm256_fmadd_ps(vp0, vt0, vtwo);
-    vp1 = _mm256_fmadd_ps(vp1, vt1, vtwo);
-    vp2 = _mm256_fmadd_ps(vp2, vt2, vtwo);
-    vp3 = _mm256_fmadd_ps(vp3, vt3, vtwo);
-
-    const __m256 vts0 = _mm256_mul_ps(vt0, vs0);
-    const __m256 vsmo0 = _mm256_add_ps(vs0, vminus_one);
-    const __m256 vts1 = _mm256_mul_ps(vt1, vs1);
-    const __m256 vsmo1 = _mm256_add_ps(vs1, vminus_one);
-    const __m256 vts2 = _mm256_mul_ps(vt2, vs2);
-    const __m256 vsmo2 = _mm256_add_ps(vs2, vminus_one);
-    const __m256 vts3 = _mm256_mul_ps(vt3, vs3);
-    const __m256 vsmo3 = _mm256_add_ps(vs3, vminus_one);
-    const __m256 vemo0 = _mm256_fmadd_ps(vp0, vts0, vsmo0);
-    const __m256 vemo1 = _mm256_fmadd_ps(vp1, vts1, vsmo1);
-    const __m256 vemo2 = _mm256_fmadd_ps(vp2, vts2, vsmo2);
-    const __m256 vemo3 = _mm256_fmadd_ps(vp3, vts3, vsmo3);
-    const __m256 vepo0 = _mm256_add_ps(vemo0, vtwo);
-    const __m256 vepo1 = _mm256_add_ps(vemo1, vtwo);
-    const __m256 vepo2 = _mm256_add_ps(vemo2, vtwo);
-    const __m256 vepo3 = _mm256_add_ps(vemo3, vtwo);
-
-    __m256 vy0 = _mm256_div_ps(vemo0, vepo0);
-    __m256 vy1 = _mm256_div_ps(vemo1, vepo1);
-    __m256 vy2 = _mm256_div_ps(vemo2, vepo2);
-    __m256 vy3 = _mm256_div_ps(vemo3, vepo3);
-
-    vy0 = _mm256_xor_ps(vy0, vinvsignx0);
-    vy1 = _mm256_xor_ps(vy1, vinvsignx1);
-    vy2 = _mm256_xor_ps(vy2, vinvsignx2);
-    vy3 = _mm256_xor_ps(vy3, vinvsignx3);
-
-    _mm256_storeu_ps(output, vy0);
-    _mm256_storeu_ps(output + 8, vy1);
-    _mm256_storeu_ps(output + 16, vy2);
-    _mm256_storeu_ps(output + 24, vy3);
-    output += 32;
-  }
-  for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) {
-    const __m256 vx = _mm256_loadu_ps(input);
-    input += 8;
-
-    __m256 vz = _mm256_or_ps(vx, vsign_mask);
-
-    const __m256 vinvsignx = _mm256_xor_ps(vx, vz);
-    vz = _mm256_max_ps(vsat_cutoff, vz);
-
-    __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
-
-    const __m256i ve = _mm256_slli_epi32(_mm256_castps_si256(vn), 21);
-
-    const __m256i vl = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn)));
-
-    const __m256 vs = _mm256_castsi256_ps(_mm256_add_epi32(vl, ve));
-
-    vn = _mm256_sub_ps(vn, vmagic_bias);
-
-    const __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
-
-    __m256 vp = vc4;
-    vp = _mm256_fmadd_ps(vp, vt, vc3);
-    vp = _mm256_fmadd_ps(vp, vt, vc2);
-    vp = _mm256_fmadd_ps(vp, vt, vtwo);
-
-    const __m256 vts = _mm256_mul_ps(vt, vs);
-    const __m256 vsmo = _mm256_add_ps(vs, vminus_one);
-    const __m256 vemo = _mm256_fmadd_ps(vp, vts, vsmo);
-    const __m256 vepo = _mm256_add_ps(vemo, vtwo);
-
-    __m256 vy = _mm256_div_ps(vemo, vepo);
-
-    vy = _mm256_xor_ps(vy, vinvsignx);
-
-    _mm256_storeu_ps(output, vy);
-    output += 8;
-  }
-  if XNN_UNLIKELY(batch != 0) {
-    assert(batch >= 1 * sizeof(float));
-    assert(batch <= 7 * sizeof(float));
-    const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx_expm1minus_rr1_lut4_p4h3_perm.mask_table[7] - batch));
-
-    const __m256 vx = _mm256_maskload_ps(input, vmask);
-
-    __m256 vz = _mm256_or_ps(vx, vsign_mask);
-
-    const __m256 vinvsignx = _mm256_xor_ps(vx, vz);
-    vz = _mm256_max_ps(vsat_cutoff, vz);
-
-    __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
-
-    const __m256i ve = _mm256_slli_epi32(_mm256_castps_si256(vn), 21);
-
-    const __m256i vl = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn)));
-
-    const __m256 vs = _mm256_castsi256_ps(_mm256_add_epi32(vl, ve));
-
-    vn = _mm256_sub_ps(vn, vmagic_bias);
-
-    const __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
-
-    __m256 vp = vc4;
-    vp = _mm256_fmadd_ps(vp, vt, vc3);
-    vp = _mm256_fmadd_ps(vp, vt, vc2);
-    vp = _mm256_fmadd_ps(vp, vt, vtwo);
-
-    const __m256 vts = _mm256_mul_ps(vt, vs);
-    const __m256 vsmo = _mm256_add_ps(vs, vminus_one);
-    const __m256 vemo = _mm256_fmadd_ps(vp, vts, vsmo);
-    const __m256 vepo = _mm256_add_ps(vemo, vtwo);
-
-    __m256 vy = _mm256_div_ps(vemo, vepo);
-
-    vy = _mm256_xor_ps(vy, vinvsignx);
-
-    __m128 vy_lo = _mm256_castps256_ps128(vy);
-    if (batch & (4 * sizeof(float))) {
-      _mm_storeu_ps(output, vy_lo);
-      vy_lo = _mm256_extractf128_ps(vy, 1);
-      output += 4;
-    }
-    if (batch & (2 * sizeof(float))) {
-      _mm_storel_pi((__m64*) output, vy_lo);
-      vy_lo = _mm_movehl_ps(vy_lo, vy_lo);
-      output += 2;
-    }
-    if (batch & (1 * sizeof(float))) {
-      _mm_store_ss(output, vy_lo);
-    }
-  }
-}
-
 void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2(
     size_t mr,
     size_t nc,
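
For reference, the deleted kernel appears to implement tanh through the expm1 identity: with the sign folded out via z = -|x|,

$$\tanh(z) = \frac{e^{2z} - 1}{e^{2z} + 1} = \frac{\operatorname{expm1}(2z)}{\operatorname{expm1}(2z) + 2},$$

which matches the `vemo` / `vepo = vemo + 2` division above, with the sign of `x` restored afterwards by the XOR with `vinvsignx`.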

src/amalgam/gen/avx512f.c

Lines changed: 103 additions & 0 deletions
@@ -15,6 +15,7 @@
 #include <xnnpack/igemm.h>
 #include <xnnpack/intrinsics-polyfill.h>
 #include <xnnpack/math.h>
+#include <xnnpack/microparams.h>
 #include <xnnpack/packw.h>
 #include <xnnpack/prefetch.h>
 #include <xnnpack/prelu.h>
@@ -3922,6 +3923,108 @@ void xnn_f32_vsqrt_ukernel__avx512f_rsqrt_u16(
   }
 }
 
+void xnn_f32_vtanh_ukernel__avx512f_rational_9_6_nr_u16(
+    size_t batch,
+    const float* input,
+    float* output,
+    const union xnn_f32_tanh_params params[restrict XNN_MIN_ELEMENTS(1)])
+{
+  assert(batch != 0);
+  assert(batch % sizeof(float) == 0);
+  assert(input != NULL);
+  assert(output != NULL);
+
+  // Cap the inputs to this value as `tanh(x)` will always be `+/-1.0f` beyond
+  // this point. This value is chosen as the first floating point number as of
+  // which the interpolation returns 1.0f.
+  const __m512 vmax_x = _mm512_set1_ps(params->avx512_rational_9_6.max_abs_x);
+  const __m512 vmin_x = _mm512_set1_ps(-params->avx512_rational_9_6.max_abs_x);
+
+  // The monomial coefficients of the numerator polynomial (odd).
+  const __m512 valpha_1 = _mm512_set1_ps(params->avx512_rational_9_6.alpha_1);
+  const __m512 valpha_3 = _mm512_set1_ps(params->avx512_rational_9_6.alpha_3);
+  const __m512 valpha_5 = _mm512_set1_ps(params->avx512_rational_9_6.alpha_5);
+  const __m512 valpha_7 = _mm512_set1_ps(params->avx512_rational_9_6.alpha_7);
+  const __m512 valpha_9 = _mm512_set1_ps(params->avx512_rational_9_6.alpha_9);
+
+  // The monomial coefficients of the denominator polynomial (even).
+  const __m512 vbeta_0 = _mm512_set1_ps(params->avx512_rational_9_6.beta_0);
+  const __m512 vbeta_2 = _mm512_set1_ps(params->avx512_rational_9_6.beta_2);
+  const __m512 vbeta_4 = _mm512_set1_ps(params->avx512_rational_9_6.beta_4);
+  const __m512 vbeta_6 = _mm512_set1_ps(params->avx512_rational_9_6.beta_6);
+
+  // Constant needed for the Newton-Raphson iteration of the reciprocal.
+  const __m512 vtwo = _mm512_set1_ps(params->avx512_rational_9_6.two);
+
+  for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) {
+    __m512 vx = _mm512_loadu_ps(input);
+    input += 16;
+
+    // Clamp the inputs to the interpolation range.
+    vx = _mm512_min_ps(vmax_x, vx);
+    vx = _mm512_max_ps(vmin_x, vx);
+
+    // Since the polynomials are odd/even, we need x^2.
+    const __m512 vx2 = _mm512_mul_ps(vx, vx);
+
+    // Evaluate the numerator polynomial p.
+    __m512 vp = _mm512_fmadd_ps(vx2, valpha_9, valpha_7);
+    vp = _mm512_fmadd_ps(vx2, vp, valpha_5);
+    vp = _mm512_fmadd_ps(vx2, vp, valpha_3);
+    vp = _mm512_fmadd_ps(vx2, vp, valpha_1);
+    vp = _mm512_mul_ps(vx, vp);
+
+    // Evaluate the denominator polynomial q.
+    __m512 vq = _mm512_fmadd_ps(vx2, vbeta_6, vbeta_4);
+    vq = _mm512_fmadd_ps(vx2, vq, vbeta_2);
+    vq = _mm512_fmadd_ps(vx2, vq, vbeta_0);
+
+    // Divide the numerator by the denominator.
+    const __m512 vt0 = _mm512_rcp14_ps(vq);
+    const __m512 vt1 = _mm512_mul_ps(vt0, _mm512_fnmadd_ps(vt0, vq, vtwo));
+    const __m512 vy = _mm512_mul_ps(vp, vt1);
+
+    _mm512_storeu_ps(output, vy);
+    output += 16;
+  }
+  if XNN_UNLIKELY(batch != 0) {
+    assert(batch >= 1 * sizeof(float));
+    assert(batch <= 15 * sizeof(float));
+
+    // Prepare mask for valid 32-bit elements (depends on batch).
+    batch >>= XNN_LOG2_SIZEOF_FLOAT;
+    const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1)));
+
+    __m512 vx = _mm512_maskz_loadu_ps(vmask, input);
+
+    // Clamp the inputs to the interpolation range.
+    vx = _mm512_min_ps(vmax_x, vx);
+    vx = _mm512_max_ps(vmin_x, vx);
+
+    // Since the polynomials are odd/even, we need x^2.
+    const __m512 vx2 = _mm512_mul_ps(vx, vx);
+
+    // Evaluate the numerator polynomial p.
+    __m512 vp = _mm512_fmadd_ps(vx2, valpha_9, valpha_7);
+    vp = _mm512_fmadd_ps(vx2, vp, valpha_5);
+    vp = _mm512_fmadd_ps(vx2, vp, valpha_3);
+    vp = _mm512_fmadd_ps(vx2, vp, valpha_1);
+    vp = _mm512_mul_ps(vx, vp);
+
+    // Evaluate the denominator polynomial q.
+    __m512 vq = _mm512_fmadd_ps(vx2, vbeta_6, vbeta_4);
+    vq = _mm512_fmadd_ps(vx2, vq, vbeta_2);
+    vq = _mm512_fmadd_ps(vx2, vq, vbeta_0);
+
+    // Divide the numerator by the denominator.
+    const __m512 vt0 = _mm512_rcp14_ps(vq);
+    const __m512 vt1 = _mm512_mul_ps(vt0, _mm512_fnmadd_ps(vt0, vq, vtwo));
+    const __m512 vy = _mm512_mul_ps(vp, vt1);
+
+    _mm512_mask_storeu_ps(output, vmask, vy);
+  }
+}
+
 void xnn_f32_vabs_ukernel__avx512f_u16(
     size_t batch,
     const float* input,
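
To make the new algorithm easy to follow outside the intrinsics, here is a minimal scalar C sketch of the same rational 9/6 evaluation. It is an illustration only: the struct and field names are hypothetical stand-ins for `params->avx512_rational_9_6` and carry no real coefficient values, and the Newton-Raphson reciprocal refinement used above (`vt1 = vt0 * (2 - vt0 * vq)` after the `_mm512_rcp14_ps` estimate) is described in a comment, since scalar code has no cheap reciprocal estimate and simply divides.

#include <math.h>

// Hypothetical scalar model of the rational 9/6 tanh kernel above.
// The fields mirror the kernel params; the actual coefficient values
// are NOT shown in this diff and are not reproduced here.
typedef struct {
  float max_abs_x;                                    // clamp bound
  float alpha_1, alpha_3, alpha_5, alpha_7, alpha_9;  // numerator (odd)
  float beta_0, beta_2, beta_4, beta_6;               // denominator (even)
} tanh_rational_9_6_params;

static float tanh_rational_9_6(float x, const tanh_rational_9_6_params* p) {
  // Clamp to the interpolation range; tanh saturates to +/-1.0f beyond it.
  x = fminf(p->max_abs_x, fmaxf(-p->max_abs_x, x));

  // The polynomials are odd/even, so evaluate them in x^2 (Horner's scheme).
  const float x2 = x * x;

  float num = fmaf(x2, p->alpha_9, p->alpha_7);
  num = fmaf(x2, num, p->alpha_5);
  num = fmaf(x2, num, p->alpha_3);
  num = fmaf(x2, num, p->alpha_1);
  num *= x;

  float den = fmaf(x2, p->beta_6, p->beta_4);
  den = fmaf(x2, den, p->beta_2);
  den = fmaf(x2, den, p->beta_0);

  // The vector kernel avoids a full division: it takes a reciprocal
  // estimate t0 of den and refines it with one Newton-Raphson step,
  // t1 = t0 * (2 - t0 * den), then returns num * t1. Scalar code has
  // no cheap estimate, so it divides directly.
  return num / den;
}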
