From 5be5be89d703285c592330db463dba48cfff3c15 Mon Sep 17 00:00:00 2001 From: Pedro Gonnet Date: Tue, 30 Apr 2024 06:43:03 -0700 Subject: [PATCH] Switch to the new `rational_9_6` microkernels for `f32-vtanh`. PiperOrigin-RevId: 629396487 --- src/amalgam/gen/avx.c | 350 ++++++-------------- src/amalgam/gen/avx2.c | 219 ------------- src/amalgam/gen/avx512f.c | 103 ++++++ src/amalgam/gen/avx512skx.c | 189 ----------- src/amalgam/gen/fma3.c | 327 ++++++------------ src/amalgam/gen/scalar.c | 211 +++--------- src/amalgam/gen/sse2.c | 438 +++++++------------------ src/amalgam/gen/sse41.c | 367 --------------------- src/configs/unary-elementwise-config.c | 50 ++- 9 files changed, 492 insertions(+), 1762 deletions(-) diff --git a/src/amalgam/gen/avx.c b/src/amalgam/gen/avx.c index 01aafbc98d4..dc395ef7c47 100644 --- a/src/amalgam/gen/avx.c +++ b/src/amalgam/gen/avx.c @@ -6054,243 +6054,106 @@ void xnn_f32_vsqrt_ukernel__avx_rsqrt_u16( } } -void xnn_f32_vtanh_ukernel__avx_expm1minus_rr1_lut4_p4h2ts_perm_div_u48( +void xnn_f32_vtanh_ukernel__avx_rational_9_6_div_u16( size_t batch, const float* input, float* output, - const union xnn_f32_tanh_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS + const union xnn_f32_tanh_params params[restrict XNN_MIN_ELEMENTS(1)]) { assert(batch != 0); assert(batch % sizeof(float) == 0); assert(input != NULL); assert(output != NULL); - const __m256 vsign_mask = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h2_perm.sign_mask); - const __m256 vsat_cutoff = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h2_perm.sat_cutoff); - const __m256 vlog2e = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h2_perm.log2e); - const __m256 vmagic_bias = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h2_perm.magic_bias); - const __m128 vtable = _mm_load_ps(params->avx_expm1minus_rr1_lut4_p4h2_perm.table); - const __m256 vminus_ln2 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h2_perm.minus_ln2); - const __m256 vc4 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h2_perm.c4); - const __m256 vc3 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h2_perm.c3); - const __m256 vc2 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h2_perm.c2); - const __m256 vtwo = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h2_perm.two); - const __m256 vminus_one = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h2_perm.minus_one); - - for (; batch >= 48 * sizeof(float); batch -= 48 * sizeof(float)) { - const __m256 vx0 = _mm256_loadu_ps(input); - const __m256 vx1 = _mm256_loadu_ps(input + 8); - const __m256 vx2 = _mm256_loadu_ps(input + 16); - const __m256 vx3 = _mm256_loadu_ps(input + 24); - const __m256 vx4 = _mm256_loadu_ps(input + 32); - const __m256 vx5 = _mm256_loadu_ps(input + 40); - input += 48; - - __m256 vz0 = _mm256_or_ps(vx0, vsign_mask); - __m256 vz1 = _mm256_or_ps(vx1, vsign_mask); - __m256 vz2 = _mm256_or_ps(vx2, vsign_mask); - __m256 vz3 = _mm256_or_ps(vx3, vsign_mask); - __m256 vz4 = _mm256_or_ps(vx4, vsign_mask); - __m256 vz5 = _mm256_or_ps(vx5, vsign_mask); - - const __m256 vinvsignx0 = _mm256_xor_ps(vx0, vz0); - vz0 = _mm256_max_ps(vsat_cutoff, vz0); - const __m256 vinvsignx1 = _mm256_xor_ps(vx1, vz1); - vz1 = _mm256_max_ps(vsat_cutoff, vz1); - const __m256 vinvsignx2 = _mm256_xor_ps(vx2, vz2); - vz2 = _mm256_max_ps(vsat_cutoff, vz2); - const __m256 vinvsignx3 = _mm256_xor_ps(vx3, vz3); - vz3 = _mm256_max_ps(vsat_cutoff, vz3); - const __m256 vinvsignx4 = _mm256_xor_ps(vx4, vz4); - vz4 = _mm256_max_ps(vsat_cutoff, vz4); - const __m256 vinvsignx5 = 
_mm256_xor_ps(vx5, vz5); - vz5 = _mm256_max_ps(vsat_cutoff, vz5); - - __m256 vn0 = _mm256_add_ps(_mm256_mul_ps(vz0, vlog2e), vmagic_bias); - __m256 vn1 = _mm256_add_ps(_mm256_mul_ps(vz1, vlog2e), vmagic_bias); - __m256 vn2 = _mm256_add_ps(_mm256_mul_ps(vz2, vlog2e), vmagic_bias); - __m256 vn3 = _mm256_add_ps(_mm256_mul_ps(vz3, vlog2e), vmagic_bias); - __m256 vn4 = _mm256_add_ps(_mm256_mul_ps(vz4, vlog2e), vmagic_bias); - __m256 vn5 = _mm256_add_ps(_mm256_mul_ps(vz5, vlog2e), vmagic_bias); - - const __m128 vn0_hi = _mm256_extractf128_ps(vn0, 1); - __m128i ve0_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn0)), 21); - const __m128 vn1_hi = _mm256_extractf128_ps(vn1, 1); - __m128i ve1_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn1)), 21); - const __m128 vn2_hi = _mm256_extractf128_ps(vn2, 1); - __m128i ve2_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn2)), 21); - const __m128 vn3_hi = _mm256_extractf128_ps(vn3, 1); - __m128i ve3_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn3)), 21); - const __m128 vn4_hi = _mm256_extractf128_ps(vn4, 1); - __m128i ve4_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn4)), 21); - const __m128 vn5_hi = _mm256_extractf128_ps(vn5, 1); - __m128i ve5_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn5)), 21); - - __m128i ve0_hi = _mm_slli_epi32(_mm_castps_si128(vn0_hi), 21); - const __m128i vl0_lo = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(_mm256_castps256_ps128(vn0)))); - const __m128i vl0_hi = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(vn0_hi))); - __m128i ve1_hi = _mm_slli_epi32(_mm_castps_si128(vn1_hi), 21); - const __m128i vl1_lo = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(_mm256_castps256_ps128(vn1)))); - const __m128i vl1_hi = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(vn1_hi))); - __m128i ve2_hi = _mm_slli_epi32(_mm_castps_si128(vn2_hi), 21); - const __m128i vl2_lo = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(_mm256_castps256_ps128(vn2)))); - const __m128i vl2_hi = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(vn2_hi))); - __m128i ve3_hi = _mm_slli_epi32(_mm_castps_si128(vn3_hi), 21); - const __m128i vl3_lo = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(_mm256_castps256_ps128(vn3)))); - const __m128i vl3_hi = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(vn3_hi))); - __m128i ve4_hi = _mm_slli_epi32(_mm_castps_si128(vn4_hi), 21); - const __m128i vl4_lo = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(_mm256_castps256_ps128(vn4)))); - const __m128i vl4_hi = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(vn4_hi))); - __m128i ve5_hi = _mm_slli_epi32(_mm_castps_si128(vn5_hi), 21); - const __m128i vl5_lo = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(_mm256_castps256_ps128(vn5)))); - const __m128i vl5_hi = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(vn5_hi))); - - const __m128 vs0_lo = _mm_castsi128_ps(_mm_add_epi32(ve0_lo, vl0_lo)); - const __m128 vs0_hi = _mm_castsi128_ps(_mm_add_epi32(ve0_hi, vl0_hi)); - const __m128 vs1_lo = _mm_castsi128_ps(_mm_add_epi32(ve1_lo, vl1_lo)); - const __m128 vs1_hi = _mm_castsi128_ps(_mm_add_epi32(ve1_hi, vl1_hi)); - const __m128 vs2_lo = _mm_castsi128_ps(_mm_add_epi32(ve2_lo, vl2_lo)); - const __m128 vs2_hi = _mm_castsi128_ps(_mm_add_epi32(ve2_hi, vl2_hi)); - const __m128 vs3_lo = _mm_castsi128_ps(_mm_add_epi32(ve3_lo, vl3_lo)); - const __m128 vs3_hi = 
_mm_castsi128_ps(_mm_add_epi32(ve3_hi, vl3_hi)); - const __m128 vs4_lo = _mm_castsi128_ps(_mm_add_epi32(ve4_lo, vl4_lo)); - const __m128 vs4_hi = _mm_castsi128_ps(_mm_add_epi32(ve4_hi, vl4_hi)); - const __m128 vs5_lo = _mm_castsi128_ps(_mm_add_epi32(ve5_lo, vl5_lo)); - const __m128 vs5_hi = _mm_castsi128_ps(_mm_add_epi32(ve5_hi, vl5_hi)); - - const __m256 vs0 = _mm256_insertf128_ps(_mm256_castps128_ps256(vs0_lo), vs0_hi, 1); - const __m256 vs1 = _mm256_insertf128_ps(_mm256_castps128_ps256(vs1_lo), vs1_hi, 1); - const __m256 vs2 = _mm256_insertf128_ps(_mm256_castps128_ps256(vs2_lo), vs2_hi, 1); - const __m256 vs3 = _mm256_insertf128_ps(_mm256_castps128_ps256(vs3_lo), vs3_hi, 1); - const __m256 vs4 = _mm256_insertf128_ps(_mm256_castps128_ps256(vs4_lo), vs4_hi, 1); - const __m256 vs5 = _mm256_insertf128_ps(_mm256_castps128_ps256(vs5_lo), vs5_hi, 1); - - vn0 = _mm256_sub_ps(vn0, vmagic_bias); - vn1 = _mm256_sub_ps(vn1, vmagic_bias); - vn2 = _mm256_sub_ps(vn2, vmagic_bias); - vn3 = _mm256_sub_ps(vn3, vmagic_bias); - vn4 = _mm256_sub_ps(vn4, vmagic_bias); - vn5 = _mm256_sub_ps(vn5, vmagic_bias); + // Cap the inputs to this value as `tanh(x)` will always be `+/-1.0f` beyond + // this point. This value is chosen as the first floating point number as of + // which the interpolation returns 1.0f. + const __m256 vmax_x = _mm256_set1_ps(params->avx_rational_9_6.max_abs_x); + const __m256 vmin_x = _mm256_set1_ps(-params->avx_rational_9_6.max_abs_x); + + // The monomial coefficients of the numerator polynomial (odd). + const __m256 valpha_1 = _mm256_set1_ps(params->avx_rational_9_6.alpha_1); + const __m256 valpha_3 = _mm256_set1_ps(params->avx_rational_9_6.alpha_3); + const __m256 valpha_5 = _mm256_set1_ps(params->avx_rational_9_6.alpha_5); + const __m256 valpha_7 = _mm256_set1_ps(params->avx_rational_9_6.alpha_7); + const __m256 valpha_9 = _mm256_set1_ps(params->avx_rational_9_6.alpha_9); + + // The monomial coefficients of the denominator polynomial (even). 
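+  // q(x) = beta_0 + beta_2 * x^2 + beta_4 * x^4 + beta_6 * x^6.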
+ const __m256 vbeta_0 = _mm256_set1_ps(params->avx_rational_9_6.beta_0); + const __m256 vbeta_2 = _mm256_set1_ps(params->avx_rational_9_6.beta_2); + const __m256 vbeta_4 = _mm256_set1_ps(params->avx_rational_9_6.beta_4); + const __m256 vbeta_6 = _mm256_set1_ps(params->avx_rational_9_6.beta_6); - const __m256 vt0 = _mm256_add_ps(_mm256_mul_ps(vn0, vminus_ln2), vz0); - const __m256 vt1 = _mm256_add_ps(_mm256_mul_ps(vn1, vminus_ln2), vz1); - const __m256 vt2 = _mm256_add_ps(_mm256_mul_ps(vn2, vminus_ln2), vz2); - const __m256 vt3 = _mm256_add_ps(_mm256_mul_ps(vn3, vminus_ln2), vz3); - const __m256 vt4 = _mm256_add_ps(_mm256_mul_ps(vn4, vminus_ln2), vz4); - const __m256 vt5 = _mm256_add_ps(_mm256_mul_ps(vn5, vminus_ln2), vz5); - __m256 vp0 = _mm256_add_ps(_mm256_mul_ps(vc4, vt0), vc3); - __m256 vp1 = _mm256_add_ps(_mm256_mul_ps(vc4, vt1), vc3); - __m256 vp2 = _mm256_add_ps(_mm256_mul_ps(vc4, vt2), vc3); - __m256 vp3 = _mm256_add_ps(_mm256_mul_ps(vc4, vt3), vc3); - __m256 vp4 = _mm256_add_ps(_mm256_mul_ps(vc4, vt4), vc3); - __m256 vp5 = _mm256_add_ps(_mm256_mul_ps(vc4, vt5), vc3); - vp0 = _mm256_add_ps(_mm256_mul_ps(vp0, vt0), vc2); - vp1 = _mm256_add_ps(_mm256_mul_ps(vp1, vt1), vc2); - vp2 = _mm256_add_ps(_mm256_mul_ps(vp2, vt2), vc2); - vp3 = _mm256_add_ps(_mm256_mul_ps(vp3, vt3), vc2); - vp4 = _mm256_add_ps(_mm256_mul_ps(vp4, vt4), vc2); - vp5 = _mm256_add_ps(_mm256_mul_ps(vp5, vt5), vc2); - vp0 = _mm256_mul_ps(vp0, vt0); - vp1 = _mm256_mul_ps(vp1, vt1); - vp2 = _mm256_mul_ps(vp2, vt2); - vp3 = _mm256_mul_ps(vp3, vt3); - vp4 = _mm256_mul_ps(vp4, vt4); - vp5 = _mm256_mul_ps(vp5, vt5); - - const __m256 vts0 = _mm256_mul_ps(vt0, vs0); - const __m256 vsmo0 = _mm256_add_ps(vs0, vminus_one); - const __m256 vts1 = _mm256_mul_ps(vt1, vs1); - const __m256 vsmo1 = _mm256_add_ps(vs1, vminus_one); - const __m256 vts2 = _mm256_mul_ps(vt2, vs2); - const __m256 vsmo2 = _mm256_add_ps(vs2, vminus_one); - const __m256 vts3 = _mm256_mul_ps(vt3, vs3); - const __m256 vsmo3 = _mm256_add_ps(vs3, vminus_one); - const __m256 vts4 = _mm256_mul_ps(vt4, vs4); - const __m256 vsmo4 = _mm256_add_ps(vs4, vminus_one); - const __m256 vts5 = _mm256_mul_ps(vt5, vs5); - const __m256 vsmo5 = _mm256_add_ps(vs5, vminus_one); - - vp0 = _mm256_add_ps(_mm256_mul_ps(vp0, vts0), vts0); - vp1 = _mm256_add_ps(_mm256_mul_ps(vp1, vts1), vts1); - vp2 = _mm256_add_ps(_mm256_mul_ps(vp2, vts2), vts2); - vp3 = _mm256_add_ps(_mm256_mul_ps(vp3, vts3), vts3); - vp4 = _mm256_add_ps(_mm256_mul_ps(vp4, vts4), vts4); - vp5 = _mm256_add_ps(_mm256_mul_ps(vp5, vts5), vts5); - const __m256 vemo0 = _mm256_add_ps(_mm256_mul_ps(vp0, vtwo), vsmo0); - const __m256 vemo1 = _mm256_add_ps(_mm256_mul_ps(vp1, vtwo), vsmo1); - const __m256 vemo2 = _mm256_add_ps(_mm256_mul_ps(vp2, vtwo), vsmo2); - const __m256 vemo3 = _mm256_add_ps(_mm256_mul_ps(vp3, vtwo), vsmo3); - const __m256 vemo4 = _mm256_add_ps(_mm256_mul_ps(vp4, vtwo), vsmo4); - const __m256 vemo5 = _mm256_add_ps(_mm256_mul_ps(vp5, vtwo), vsmo5); - - const __m256 vepo0 = _mm256_add_ps(vemo0, vtwo); - const __m256 vepo1 = _mm256_add_ps(vemo1, vtwo); - const __m256 vepo2 = _mm256_add_ps(vemo2, vtwo); - const __m256 vepo3 = _mm256_add_ps(vemo3, vtwo); - const __m256 vepo4 = _mm256_add_ps(vemo4, vtwo); - const __m256 vepo5 = _mm256_add_ps(vemo5, vtwo); - __m256 vy0 = _mm256_div_ps(vemo0, vepo0); - __m256 vy1 = _mm256_div_ps(vemo1, vepo1); - __m256 vy2 = _mm256_div_ps(vemo2, vepo2); - __m256 vy3 = _mm256_div_ps(vemo3, vepo3); - __m256 vy4 = _mm256_div_ps(vemo4, vepo4); - __m256 vy5 = _mm256_div_ps(vemo5, vepo5); 
- - vy0 = _mm256_xor_ps(vy0, vinvsignx0); - vy1 = _mm256_xor_ps(vy1, vinvsignx1); - vy2 = _mm256_xor_ps(vy2, vinvsignx2); - vy3 = _mm256_xor_ps(vy3, vinvsignx3); - vy4 = _mm256_xor_ps(vy4, vinvsignx4); - vy5 = _mm256_xor_ps(vy5, vinvsignx5); + for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { + __m256 vx_0 = _mm256_loadu_ps(input); + __m256 vx_1 = _mm256_loadu_ps(input + 8); + input += 16; - _mm256_storeu_ps(output, vy0); - _mm256_storeu_ps(output + 8, vy1); - _mm256_storeu_ps(output + 16, vy2); - _mm256_storeu_ps(output + 24, vy3); - _mm256_storeu_ps(output + 32, vy4); - _mm256_storeu_ps(output + 40, vy5); - output += 48; + // Clamp the inputs to the interpolation range. + vx_0 = _mm256_min_ps(vmax_x, vx_0); + vx_1 = _mm256_min_ps(vmax_x, vx_1); + vx_0 = _mm256_max_ps(vmin_x, vx_0); + vx_1 = _mm256_max_ps(vmin_x, vx_1); + + // Since the polynomials are odd/even, we need x^2. + const __m256 vx2_0 = _mm256_mul_ps(vx_0, vx_0); + const __m256 vx2_1 = _mm256_mul_ps(vx_1, vx_1); + + // Evaluate the numerator polynomial p. + __m256 vp_0 = _mm256_add_ps(_mm256_mul_ps(vx2_0, valpha_9), valpha_7); + __m256 vp_1 = _mm256_add_ps(_mm256_mul_ps(vx2_1, valpha_9), valpha_7); + vp_0 = _mm256_add_ps(_mm256_mul_ps(vx2_0, vp_0), valpha_5); + vp_1 = _mm256_add_ps(_mm256_mul_ps(vx2_1, vp_1), valpha_5); + vp_0 = _mm256_add_ps(_mm256_mul_ps(vx2_0, vp_0), valpha_3); + vp_1 = _mm256_add_ps(_mm256_mul_ps(vx2_1, vp_1), valpha_3); + vp_0 = _mm256_add_ps(_mm256_mul_ps(vx2_0, vp_0), valpha_1); + vp_1 = _mm256_add_ps(_mm256_mul_ps(vx2_1, vp_1), valpha_1); + vp_0 = _mm256_mul_ps(vx_0, vp_0); + vp_1 = _mm256_mul_ps(vx_1, vp_1); + + // Evaluate the denominator polynomial q. + __m256 vq_0 = _mm256_add_ps(_mm256_mul_ps(vx2_0, vbeta_6), vbeta_4); + __m256 vq_1 = _mm256_add_ps(_mm256_mul_ps(vx2_1, vbeta_6), vbeta_4); + vq_0 = _mm256_add_ps(_mm256_mul_ps(vx2_0, vq_0), vbeta_2); + vq_1 = _mm256_add_ps(_mm256_mul_ps(vx2_1, vq_1), vbeta_2); + vq_0 = _mm256_add_ps(_mm256_mul_ps(vx2_0, vq_0), vbeta_0); + vq_1 = _mm256_add_ps(_mm256_mul_ps(vx2_1, vq_1), vbeta_0); + + // Divide the numerator by the denominator. + const __m256 vy_0 = _mm256_div_ps(vp_0, vq_0); + const __m256 vy_1 = _mm256_div_ps(vp_1, vq_1); + + + _mm256_storeu_ps(output, vy_0); + _mm256_storeu_ps(output + 8, vy_1); + output += 16; } for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) { - const __m256 vx = _mm256_loadu_ps(input); + __m256 vx = _mm256_loadu_ps(input); input += 8; - __m256 vz = _mm256_or_ps(vx, vsign_mask); - - const __m256 vinvsignx = _mm256_xor_ps(vx, vz); - vz = _mm256_max_ps(vsat_cutoff, vz); - - __m256 vn = _mm256_add_ps(_mm256_mul_ps(vz, vlog2e), vmagic_bias); - - const __m128 vn_hi = _mm256_extractf128_ps(vn, 1); - __m128i ve_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn)), 21); - __m128i ve_hi = _mm_slli_epi32(_mm_castps_si128(vn_hi), 21); + // Clamp the inputs to the interpolation range. 
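+    // Inputs saturated to +/-max_abs_x evaluate to exactly +/-1.0f below,
+    // matching tanh(x) for large |x|.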
+ vx = _mm256_min_ps(vmax_x, vx); + vx = _mm256_max_ps(vmin_x, vx); - const __m128i vl_lo = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(_mm256_castps256_ps128(vn)))); - const __m128i vl_hi = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(vn_hi))); - - const __m128 vs_lo = _mm_castsi128_ps(_mm_add_epi32(ve_lo, vl_lo)); - const __m128 vs_hi = _mm_castsi128_ps(_mm_add_epi32(ve_hi, vl_hi)); - const __m256 vs = _mm256_insertf128_ps(_mm256_castps128_ps256(vs_lo), vs_hi, 1); - - vn = _mm256_sub_ps(vn, vmagic_bias); - - const __m256 vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2), vz); - - __m256 vp = _mm256_add_ps(_mm256_mul_ps(vc4, vt), vc3); - vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vc2); - vp = _mm256_mul_ps(vp, vt); + // Since the polynomials are odd/even, we need x^2. + const __m256 vx2 = _mm256_mul_ps(vx, vx); - const __m256 vts = _mm256_mul_ps(vt, vs); - const __m256 vsmo = _mm256_add_ps(vs, vminus_one); - vp = _mm256_add_ps(_mm256_mul_ps(vp, vts), vts); - const __m256 vemo = _mm256_add_ps(_mm256_mul_ps(vp, vtwo), vsmo); + // Evaluate the numerator polynomial p. + __m256 vp = _mm256_add_ps(_mm256_mul_ps(vx2, valpha_9), valpha_7); + vp = _mm256_add_ps(_mm256_mul_ps(vx2, vp), valpha_5); + vp = _mm256_add_ps(_mm256_mul_ps(vx2, vp), valpha_3); + vp = _mm256_add_ps(_mm256_mul_ps(vx2, vp), valpha_1); + vp = _mm256_mul_ps(vx, vp); - const __m256 vepo = _mm256_add_ps(vemo, vtwo); - __m256 vy = _mm256_div_ps(vemo, vepo); + // Evaluate the denominator polynomial q. + __m256 vq = _mm256_add_ps(_mm256_mul_ps(vx2, vbeta_6), vbeta_4); + vq = _mm256_add_ps(_mm256_mul_ps(vx2, vq), vbeta_2); + vq = _mm256_add_ps(_mm256_mul_ps(vx2, vq), vbeta_0); - vy = _mm256_xor_ps(vy, vinvsignx); + // Divide the numerator by the denominator. + const __m256 vy = _mm256_div_ps(vp, vq); _mm256_storeu_ps(output, vy); output += 8; @@ -6298,45 +6161,32 @@ void xnn_f32_vtanh_ukernel__avx_expm1minus_rr1_lut4_p4h2ts_perm_div_u48( if XNN_UNLIKELY(batch != 0) { assert(batch >= 1 * sizeof(float)); assert(batch <= 7 * sizeof(float)); - const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) ¶ms->avx_expm1minus_rr1_lut4_p4h2_perm.mask_table[7] - batch)); - - const __m256 vx = _mm256_maskload_ps(input, vmask); - - __m256 vz = _mm256_or_ps(vx, vsign_mask); - - const __m256 vinvsignx = _mm256_xor_ps(vx, vz); - vz = _mm256_max_ps(vsat_cutoff, vz); - - __m256 vn = _mm256_add_ps(_mm256_mul_ps(vz, vlog2e), vmagic_bias); - - const __m128 vn_hi = _mm256_extractf128_ps(vn, 1); - __m128i ve_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn)), 21); - __m128i ve_hi = _mm_slli_epi32(_mm_castps_si128(vn_hi), 21); - - const __m128i vl_lo = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(_mm256_castps256_ps128(vn)))); - const __m128i vl_hi = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(vn_hi))); - - const __m128 vs_lo = _mm_castsi128_ps(_mm_add_epi32(ve_lo, vl_lo)); - const __m128 vs_hi = _mm_castsi128_ps(_mm_add_epi32(ve_hi, vl_hi)); - const __m256 vs = _mm256_insertf128_ps(_mm256_castps128_ps256(vs_lo), vs_hi, 1); + const __m256i vmask = _mm256_loadu_si256( + (const __m256i*) ((uintptr_t) ¶ms->avx_rational_9_6.mask_table[7] - batch)); - vn = _mm256_sub_ps(vn, vmagic_bias); + __m256 vx = _mm256_maskload_ps(input, vmask); - const __m256 vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2), vz); + // Clamp the inputs to the interpolation range. 
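+    // Lanes zeroed by the masked load above simply evaluate to tanh(0) = 0;
+    // only the valid lanes are stored below.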
+ vx = _mm256_min_ps(vmax_x, vx); + vx = _mm256_max_ps(vmin_x, vx); - __m256 vp = _mm256_add_ps(_mm256_mul_ps(vc4, vt), vc3); - vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vc2); - vp = _mm256_mul_ps(vp, vt); + // Since the polynomials are odd/even, we need x^2. + const __m256 vx2 = _mm256_mul_ps(vx, vx); - const __m256 vts = _mm256_mul_ps(vt, vs); - const __m256 vsmo = _mm256_add_ps(vs, vminus_one); - vp = _mm256_add_ps(_mm256_mul_ps(vp, vts), vts); - const __m256 vemo = _mm256_add_ps(_mm256_mul_ps(vp, vtwo), vsmo); + // Evaluate the numerator polynomial p. + __m256 vp = _mm256_add_ps(_mm256_mul_ps(vx2, valpha_9), valpha_7); + vp = _mm256_add_ps(_mm256_mul_ps(vx2, vp), valpha_5); + vp = _mm256_add_ps(_mm256_mul_ps(vx2, vp), valpha_3); + vp = _mm256_add_ps(_mm256_mul_ps(vx2, vp), valpha_1); + vp = _mm256_mul_ps(vx, vp); - const __m256 vepo = _mm256_add_ps(vemo, vtwo); - __m256 vy = _mm256_div_ps(vemo, vepo); + // Evaluate the denominator polynomial q. + __m256 vq = _mm256_add_ps(_mm256_mul_ps(vx2, vbeta_6), vbeta_4); + vq = _mm256_add_ps(_mm256_mul_ps(vx2, vq), vbeta_2); + vq = _mm256_add_ps(_mm256_mul_ps(vx2, vq), vbeta_0); - vy = _mm256_xor_ps(vy, vinvsignx); + // Divide the numerator by the denominator. + const __m256 vy = _mm256_div_ps(vp, vq); __m128 vy_lo = _mm256_castps256_ps128(vy); if (batch & (4 * sizeof(float))) { diff --git a/src/amalgam/gen/avx2.c b/src/amalgam/gen/avx2.c index 9c07d583a35..40e70e67382 100644 --- a/src/amalgam/gen/avx2.c +++ b/src/amalgam/gen/avx2.c @@ -4,7 +4,6 @@ // LICENSE file in the root directory of this source tree. #include -#include #include #include @@ -17,7 +16,6 @@ #include #include #include -#include #include #include #include @@ -2972,223 +2970,6 @@ void xnn_f32_vsigmoid_ukernel__avx2_rr1_p5_div_u40( } } -void xnn_f32_vtanh_ukernel__avx2_expm1minus_rr1_lut4_p4h3ts_perm_div_u32( - size_t batch, - const float* input, - float* output, - const union xnn_f32_tanh_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - - const __m256 vsign_mask = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.sign_mask); - const __m256 vsat_cutoff = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.sat_cutoff); - const __m256 vlog2e = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.log2e); - const __m256 vmagic_bias = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.magic_bias); - const __m256 vtable = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.table); - const __m256 vminus_ln2 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.minus_ln2); - const __m256 vc4 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.c4); - const __m256 vc3 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.c3); - const __m256 vc2 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.c2); - const __m256 vtwo = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.two); - const __m256 vminus_one = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.minus_one); - - for (; batch >= 32 * sizeof(float); batch -= 32 * sizeof(float)) { - const __m256 vx0 = _mm256_loadu_ps(input); - const __m256 vx1 = _mm256_loadu_ps(input + 8); - const __m256 vx2 = _mm256_loadu_ps(input + 16); - const __m256 vx3 = _mm256_loadu_ps(input + 24); - input += 32; - - __m256 vz0 = _mm256_or_ps(vx0, vsign_mask); - __m256 vz1 = _mm256_or_ps(vx1, vsign_mask); - __m256 vz2 = _mm256_or_ps(vx2, vsign_mask); - __m256 vz3 = 
_mm256_or_ps(vx3, vsign_mask); - - const __m256 vinvsignx0 = _mm256_xor_ps(vx0, vz0); - vz0 = _mm256_max_ps(vsat_cutoff, vz0); - const __m256 vinvsignx1 = _mm256_xor_ps(vx1, vz1); - vz1 = _mm256_max_ps(vsat_cutoff, vz1); - const __m256 vinvsignx2 = _mm256_xor_ps(vx2, vz2); - vz2 = _mm256_max_ps(vsat_cutoff, vz2); - const __m256 vinvsignx3 = _mm256_xor_ps(vx3, vz3); - vz3 = _mm256_max_ps(vsat_cutoff, vz3); - - __m256 vn0 = _mm256_fmadd_ps(vz0, vlog2e, vmagic_bias); - __m256 vn1 = _mm256_fmadd_ps(vz1, vlog2e, vmagic_bias); - __m256 vn2 = _mm256_fmadd_ps(vz2, vlog2e, vmagic_bias); - __m256 vn3 = _mm256_fmadd_ps(vz3, vlog2e, vmagic_bias); - - const __m256i ve0 = _mm256_slli_epi32(_mm256_castps_si256(vn0), 21); - const __m256i ve1 = _mm256_slli_epi32(_mm256_castps_si256(vn1), 21); - const __m256i ve2 = _mm256_slli_epi32(_mm256_castps_si256(vn2), 21); - const __m256i ve3 = _mm256_slli_epi32(_mm256_castps_si256(vn3), 21); - - const __m256i vl0 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn0))); - const __m256i vl1 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn1))); - const __m256i vl2 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn2))); - const __m256i vl3 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn3))); - - const __m256 vs0 = _mm256_castsi256_ps(_mm256_add_epi32(vl0, ve0)); - const __m256 vs1 = _mm256_castsi256_ps(_mm256_add_epi32(vl1, ve1)); - const __m256 vs2 = _mm256_castsi256_ps(_mm256_add_epi32(vl2, ve2)); - const __m256 vs3 = _mm256_castsi256_ps(_mm256_add_epi32(vl3, ve3)); - - vn0 = _mm256_sub_ps(vn0, vmagic_bias); - vn1 = _mm256_sub_ps(vn1, vmagic_bias); - vn2 = _mm256_sub_ps(vn2, vmagic_bias); - vn3 = _mm256_sub_ps(vn3, vmagic_bias); - - const __m256 vt0 = _mm256_fmadd_ps(vn0, vminus_ln2, vz0); - const __m256 vt1 = _mm256_fmadd_ps(vn1, vminus_ln2, vz1); - const __m256 vt2 = _mm256_fmadd_ps(vn2, vminus_ln2, vz2); - const __m256 vt3 = _mm256_fmadd_ps(vn3, vminus_ln2, vz3); - - __m256 vp0 = vc4; - __m256 vp1 = vc4; - __m256 vp2 = vc4; - __m256 vp3 = vc4; - vp0 = _mm256_fmadd_ps(vp0, vt0, vc3); - vp1 = _mm256_fmadd_ps(vp1, vt1, vc3); - vp2 = _mm256_fmadd_ps(vp2, vt2, vc3); - vp3 = _mm256_fmadd_ps(vp3, vt3, vc3); - vp0 = _mm256_fmadd_ps(vp0, vt0, vc2); - vp1 = _mm256_fmadd_ps(vp1, vt1, vc2); - vp2 = _mm256_fmadd_ps(vp2, vt2, vc2); - vp3 = _mm256_fmadd_ps(vp3, vt3, vc2); - vp0 = _mm256_fmadd_ps(vp0, vt0, vtwo); - vp1 = _mm256_fmadd_ps(vp1, vt1, vtwo); - vp2 = _mm256_fmadd_ps(vp2, vt2, vtwo); - vp3 = _mm256_fmadd_ps(vp3, vt3, vtwo); - - const __m256 vts0 = _mm256_mul_ps(vt0, vs0); - const __m256 vsmo0 = _mm256_add_ps(vs0, vminus_one); - const __m256 vts1 = _mm256_mul_ps(vt1, vs1); - const __m256 vsmo1 = _mm256_add_ps(vs1, vminus_one); - const __m256 vts2 = _mm256_mul_ps(vt2, vs2); - const __m256 vsmo2 = _mm256_add_ps(vs2, vminus_one); - const __m256 vts3 = _mm256_mul_ps(vt3, vs3); - const __m256 vsmo3 = _mm256_add_ps(vs3, vminus_one); - const __m256 vemo0 = _mm256_fmadd_ps(vp0, vts0, vsmo0); - const __m256 vemo1 = _mm256_fmadd_ps(vp1, vts1, vsmo1); - const __m256 vemo2 = _mm256_fmadd_ps(vp2, vts2, vsmo2); - const __m256 vemo3 = _mm256_fmadd_ps(vp3, vts3, vsmo3); - const __m256 vepo0 = _mm256_add_ps(vemo0, vtwo); - const __m256 vepo1 = _mm256_add_ps(vemo1, vtwo); - const __m256 vepo2 = _mm256_add_ps(vemo2, vtwo); - const __m256 vepo3 = _mm256_add_ps(vemo3, vtwo); - - __m256 vy0 = _mm256_div_ps(vemo0, vepo0); - __m256 vy1 = _mm256_div_ps(vemo1, vepo1); - __m256 vy2 = _mm256_div_ps(vemo2, 
vepo2); - __m256 vy3 = _mm256_div_ps(vemo3, vepo3); - - vy0 = _mm256_xor_ps(vy0, vinvsignx0); - vy1 = _mm256_xor_ps(vy1, vinvsignx1); - vy2 = _mm256_xor_ps(vy2, vinvsignx2); - vy3 = _mm256_xor_ps(vy3, vinvsignx3); - - _mm256_storeu_ps(output, vy0); - _mm256_storeu_ps(output + 8, vy1); - _mm256_storeu_ps(output + 16, vy2); - _mm256_storeu_ps(output + 24, vy3); - output += 32; - } - for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) { - const __m256 vx = _mm256_loadu_ps(input); - input += 8; - - __m256 vz = _mm256_or_ps(vx, vsign_mask); - - const __m256 vinvsignx = _mm256_xor_ps(vx, vz); - vz = _mm256_max_ps(vsat_cutoff, vz); - - __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias); - - const __m256i ve = _mm256_slli_epi32(_mm256_castps_si256(vn), 21); - - const __m256i vl = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn))); - - const __m256 vs = _mm256_castsi256_ps(_mm256_add_epi32(vl, ve)); - - vn = _mm256_sub_ps(vn, vmagic_bias); - - const __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz); - - __m256 vp = vc4; - vp = _mm256_fmadd_ps(vp, vt, vc3); - vp = _mm256_fmadd_ps(vp, vt, vc2); - vp = _mm256_fmadd_ps(vp, vt, vtwo); - - const __m256 vts = _mm256_mul_ps(vt, vs); - const __m256 vsmo = _mm256_add_ps(vs, vminus_one); - const __m256 vemo = _mm256_fmadd_ps(vp, vts, vsmo); - const __m256 vepo = _mm256_add_ps(vemo, vtwo); - - __m256 vy = _mm256_div_ps(vemo, vepo); - - vy = _mm256_xor_ps(vy, vinvsignx); - - _mm256_storeu_ps(output, vy); - output += 8; - } - if XNN_UNLIKELY(batch != 0) { - assert(batch >= 1 * sizeof(float)); - assert(batch <= 7 * sizeof(float)); - const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) ¶ms->avx_expm1minus_rr1_lut4_p4h3_perm.mask_table[7] - batch)); - - const __m256 vx = _mm256_maskload_ps(input, vmask); - - __m256 vz = _mm256_or_ps(vx, vsign_mask); - - const __m256 vinvsignx = _mm256_xor_ps(vx, vz); - vz = _mm256_max_ps(vsat_cutoff, vz); - - __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias); - - const __m256i ve = _mm256_slli_epi32(_mm256_castps_si256(vn), 21); - - const __m256i vl = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn))); - - const __m256 vs = _mm256_castsi256_ps(_mm256_add_epi32(vl, ve)); - - vn = _mm256_sub_ps(vn, vmagic_bias); - - const __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz); - - __m256 vp = vc4; - vp = _mm256_fmadd_ps(vp, vt, vc3); - vp = _mm256_fmadd_ps(vp, vt, vc2); - vp = _mm256_fmadd_ps(vp, vt, vtwo); - - const __m256 vts = _mm256_mul_ps(vt, vs); - const __m256 vsmo = _mm256_add_ps(vs, vminus_one); - const __m256 vemo = _mm256_fmadd_ps(vp, vts, vsmo); - const __m256 vepo = _mm256_add_ps(vemo, vtwo); - - __m256 vy = _mm256_div_ps(vemo, vepo); - - vy = _mm256_xor_ps(vy, vinvsignx); - - __m128 vy_lo = _mm256_castps256_ps128(vy); - if (batch & (4 * sizeof(float))) { - _mm_storeu_ps(output, vy_lo); - vy_lo = _mm256_extractf128_ps(vy, 1); - output += 4; - } - if (batch & (2 * sizeof(float))) { - _mm_storel_pi((__m64*) output, vy_lo); - vy_lo = _mm_movehl_ps(vy_lo, vy_lo); - output += 2; - } - if (batch & (1 * sizeof(float))) { - _mm_store_ss(output, vy_lo); - } - } -} - void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2( size_t mr, size_t nc, diff --git a/src/amalgam/gen/avx512f.c b/src/amalgam/gen/avx512f.c index 4cdbcb61f6a..647974fe9b8 100644 --- a/src/amalgam/gen/avx512f.c +++ b/src/amalgam/gen/avx512f.c @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -3922,6 +3923,108 @@ void xnn_f32_vsqrt_ukernel__avx512f_rsqrt_u16( 
} } +void xnn_f32_vtanh_ukernel__avx512f_rational_9_6_nr_u16( + size_t batch, + const float* input, + float* output, + const union xnn_f32_tanh_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + // Cap the inputs to this value as `tanh(x)` will always be `+/-1.0f` beyond + // this point. This value is chosen as the first floating point number as of + // which the interpolation returns 1.0f. + const __m512 vmax_x = _mm512_set1_ps(params->avx512_rational_9_6.max_abs_x); + const __m512 vmin_x = _mm512_set1_ps(-params->avx512_rational_9_6.max_abs_x); + + // The monomial coefficients of the numerator polynomial (odd). + const __m512 valpha_1 = _mm512_set1_ps(params->avx512_rational_9_6.alpha_1); + const __m512 valpha_3 = _mm512_set1_ps(params->avx512_rational_9_6.alpha_3); + const __m512 valpha_5 = _mm512_set1_ps(params->avx512_rational_9_6.alpha_5); + const __m512 valpha_7 = _mm512_set1_ps(params->avx512_rational_9_6.alpha_7); + const __m512 valpha_9 = _mm512_set1_ps(params->avx512_rational_9_6.alpha_9); + + // The monomial coefficients of the denominator polynomial (even). + const __m512 vbeta_0 = _mm512_set1_ps(params->avx512_rational_9_6.beta_0); + const __m512 vbeta_2 = _mm512_set1_ps(params->avx512_rational_9_6.beta_2); + const __m512 vbeta_4 = _mm512_set1_ps(params->avx512_rational_9_6.beta_4); + const __m512 vbeta_6 = _mm512_set1_ps(params->avx512_rational_9_6.beta_6); + + // Constant needed for the Newton-Raphson iteration of the reciprocal. + const __m512 vtwo = _mm512_set1_ps(params->avx512_rational_9_6.two); + + for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { + __m512 vx = _mm512_loadu_ps(input); + input += 16; + + // Clamp the inputs to the interpolation range. + vx = _mm512_min_ps(vmax_x, vx); + vx = _mm512_max_ps(vmin_x, vx); + + // Since the polynomials are odd/even, we need x^2. + const __m512 vx2 = _mm512_mul_ps(vx, vx); + + // Evaluate the numerator polynomial p. + __m512 vp = _mm512_fmadd_ps(vx2, valpha_9, valpha_7); + vp = _mm512_fmadd_ps(vx2, vp, valpha_5); + vp = _mm512_fmadd_ps(vx2, vp, valpha_3); + vp = _mm512_fmadd_ps(vx2, vp, valpha_1); + vp = _mm512_mul_ps(vx, vp); + + // Evaluate the denominator polynomial q. + __m512 vq = _mm512_fmadd_ps(vx2, vbeta_6, vbeta_4); + vq = _mm512_fmadd_ps(vx2, vq, vbeta_2); + vq = _mm512_fmadd_ps(vx2, vq, vbeta_0); + + // Divide the numerator by the denominator. + const __m512 vt0 = _mm512_rcp14_ps(vq); + const __m512 vt1 = _mm512_mul_ps(vt0, _mm512_fnmadd_ps(vt0, vq, vtwo)); + const __m512 vy = _mm512_mul_ps(vp, vt1); + + _mm512_storeu_ps(output, vy); + output += 16; + } + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(float)); + assert(batch <= 15 * sizeof(float)); + + // Prepare mask for valid 32-bit elements (depends on batch). + batch >>= XNN_LOG2_SIZEOF_FLOAT; + const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + + __m512 vx = _mm512_maskz_loadu_ps(vmask, input); + + // Clamp the inputs to the interpolation range. + vx = _mm512_min_ps(vmax_x, vx); + vx = _mm512_max_ps(vmin_x, vx); + + // Since the polynomials are odd/even, we need x^2. + const __m512 vx2 = _mm512_mul_ps(vx, vx); + + // Evaluate the numerator polynomial p. 
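+    // p(x) = x * (alpha_1 + alpha_3 * x^2 + alpha_5 * x^4 + alpha_7 * x^6 + alpha_9 * x^8),
+    // evaluated in x^2 using Horner's scheme.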
+ __m512 vp = _mm512_fmadd_ps(vx2, valpha_9, valpha_7); + vp = _mm512_fmadd_ps(vx2, vp, valpha_5); + vp = _mm512_fmadd_ps(vx2, vp, valpha_3); + vp = _mm512_fmadd_ps(vx2, vp, valpha_1); + vp = _mm512_mul_ps(vx, vp); + + // Evaluate the denominator polynomial q. + __m512 vq = _mm512_fmadd_ps(vx2, vbeta_6, vbeta_4); + vq = _mm512_fmadd_ps(vx2, vq, vbeta_2); + vq = _mm512_fmadd_ps(vx2, vq, vbeta_0); + + // Divide the numerator by the denominator. + const __m512 vt0 = _mm512_rcp14_ps(vq); + const __m512 vt1 = _mm512_mul_ps(vt0, _mm512_fnmadd_ps(vt0, vq, vtwo)); + const __m512 vy = _mm512_mul_ps(vp, vt1); + + _mm512_mask_storeu_ps(output, vmask, vy); + } +} + void xnn_f32_vabs_ukernel__avx512f_u16( size_t batch, const float* input, diff --git a/src/amalgam/gen/avx512skx.c b/src/amalgam/gen/avx512skx.c index fc47b8b0177..5198496fced 100644 --- a/src/amalgam/gen/avx512skx.c +++ b/src/amalgam/gen/avx512skx.c @@ -4,8 +4,6 @@ // LICENSE file in the root directory of this source tree. #include -#include -#include #include @@ -16,13 +14,11 @@ #include #include #include -#include #include #include #include #include #include -#include void xnn_f16_f32_vcvt_ukernel__avx512skx_u16( @@ -1283,191 +1279,6 @@ void xnn_f32_qu8_vcvt_ukernel__avx512skx_u128( } } -void xnn_f32_vtanh_ukernel__avx512skx_expm1minus_rr1_lut4_p4h3ts_perm_div_u64( - size_t batch, - const float* input, - float* output, - const union xnn_f32_tanh_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - - const __m512 vsat_cutoff = _mm512_set1_ps(params->avx512_expm1minus_rr1_lut4_p4h3_perm.sat_cutoff); - const __m512 vminus_log2e = _mm512_set1_ps(params->avx512_expm1minus_rr1_lut4_p4h3_perm.minus_log2e); - const __m512 vmagic_bias = _mm512_set1_ps(params->avx512_expm1minus_rr1_lut4_p4h3_perm.magic_bias); - const __m512 vtable = _mm512_load_ps(params->avx512_expm1minus_rr1_lut4_p4h3_perm.table); - const __m512 vln2 = _mm512_set1_ps(params->avx512_expm1minus_rr1_lut4_p4h3_perm.ln2); - const __m512 vc4 = _mm512_set1_ps(params->avx512_expm1minus_rr1_lut4_p4h3_perm.c4); - const __m512 vc3 = _mm512_set1_ps(params->avx512_expm1minus_rr1_lut4_p4h3_perm.c3); - const __m512 vc2 = _mm512_set1_ps(params->avx512_expm1minus_rr1_lut4_p4h3_perm.c2); - const __m512 vminus_two = _mm512_set1_ps(params->avx512_expm1minus_rr1_lut4_p4h3_perm.minus_two); - const __m512 vone = _mm512_set1_ps(params->avx512_expm1minus_rr1_lut4_p4h3_perm.one); - const __m512i vsign_mask = _mm512_set1_epi32((int) params->avx512_expm1minus_rr1_lut4_p4h3_perm.sign_mask); - - for (; batch >= 64 * sizeof(float); batch -= 64 * sizeof(float)) { - const __m512 vx0 = _mm512_loadu_ps(input); - const __m512 vx1 = _mm512_loadu_ps(input + 16); - const __m512 vx2 = _mm512_loadu_ps(input + 32); - const __m512 vx3 = _mm512_loadu_ps(input + 48); - input += 64; - - const __m512 vz0 = _mm512_range_ps(vsat_cutoff, vx0, 0xA); - const __m512 vz1 = _mm512_range_ps(vsat_cutoff, vx1, 0xA); - const __m512 vz2 = _mm512_range_ps(vsat_cutoff, vx2, 0xA); - const __m512 vz3 = _mm512_range_ps(vsat_cutoff, vx3, 0xA); - __m512 vn0 = _mm512_fmadd_ps(vz0, vminus_log2e, vmagic_bias); - __m512 vn1 = _mm512_fmadd_ps(vz1, vminus_log2e, vmagic_bias); - __m512 vn2 = _mm512_fmadd_ps(vz2, vminus_log2e, vmagic_bias); - __m512 vn3 = _mm512_fmadd_ps(vz3, vminus_log2e, vmagic_bias); - - const __m512i ve0 = _mm512_slli_epi32(_mm512_castps_si512(vn0), 21); - const __m512i ve1 = 
_mm512_slli_epi32(_mm512_castps_si512(vn1), 21); - const __m512i ve2 = _mm512_slli_epi32(_mm512_castps_si512(vn2), 21); - const __m512i ve3 = _mm512_slli_epi32(_mm512_castps_si512(vn3), 21); - - const __m512i vl0 = _mm512_castps_si512(_mm512_permutevar_ps(vtable, _mm512_castps_si512(vn0))); - const __m512i vl1 = _mm512_castps_si512(_mm512_permutevar_ps(vtable, _mm512_castps_si512(vn1))); - const __m512i vl2 = _mm512_castps_si512(_mm512_permutevar_ps(vtable, _mm512_castps_si512(vn2))); - const __m512i vl3 = _mm512_castps_si512(_mm512_permutevar_ps(vtable, _mm512_castps_si512(vn3))); - - const __m512 vs0 = _mm512_castsi512_ps(_mm512_add_epi32(vl0, ve0)); - vn0 = _mm512_sub_ps(vn0, vmagic_bias); - const __m512 vs1 = _mm512_castsi512_ps(_mm512_add_epi32(vl1, ve1)); - vn1 = _mm512_sub_ps(vn1, vmagic_bias); - const __m512 vs2 = _mm512_castsi512_ps(_mm512_add_epi32(vl2, ve2)); - vn2 = _mm512_sub_ps(vn2, vmagic_bias); - const __m512 vs3 = _mm512_castsi512_ps(_mm512_add_epi32(vl3, ve3)); - vn3 = _mm512_sub_ps(vn3, vmagic_bias); - - const __m512 vt0 = _mm512_fmadd_ps(vn0, vln2, vz0); - const __m512 vt1 = _mm512_fmadd_ps(vn1, vln2, vz1); - const __m512 vt2 = _mm512_fmadd_ps(vn2, vln2, vz2); - const __m512 vt3 = _mm512_fmadd_ps(vn3, vln2, vz3); - - __m512 vp0 = vc4; - __m512 vp1 = vc4; - __m512 vp2 = vc4; - __m512 vp3 = vc4; - vp0 = _mm512_fmadd_ps(vp0, vt0, vc3); - vp1 = _mm512_fmadd_ps(vp1, vt1, vc3); - vp2 = _mm512_fmadd_ps(vp2, vt2, vc3); - vp3 = _mm512_fmadd_ps(vp3, vt3, vc3); - vp0 = _mm512_fmadd_ps(vp0, vt0, vc2); - vp1 = _mm512_fmadd_ps(vp1, vt1, vc2); - vp2 = _mm512_fmadd_ps(vp2, vt2, vc2); - vp3 = _mm512_fmadd_ps(vp3, vt3, vc2); - vp0 = _mm512_fmadd_ps(vp0, vt0, vminus_two); - vp1 = _mm512_fmadd_ps(vp1, vt1, vminus_two); - vp2 = _mm512_fmadd_ps(vp2, vt2, vminus_two); - vp3 = _mm512_fmadd_ps(vp3, vt3, vminus_two); - - const __m512 vts0 = _mm512_mul_ps(vt0, vs0); - const __m512 vsmo0 = _mm512_sub_ps(vs0, vone); - const __m512 vts1 = _mm512_mul_ps(vt1, vs1); - const __m512 vsmo1 = _mm512_sub_ps(vs1, vone); - const __m512 vts2 = _mm512_mul_ps(vt2, vs2); - const __m512 vsmo2 = _mm512_sub_ps(vs2, vone); - const __m512 vts3 = _mm512_mul_ps(vt3, vs3); - const __m512 vsmo3 = _mm512_sub_ps(vs3, vone); - const __m512 vemo0 = _mm512_fmadd_ps(vp0, vts0, vsmo0); - const __m512 vemo1 = _mm512_fmadd_ps(vp1, vts1, vsmo1); - const __m512 vemo2 = _mm512_fmadd_ps(vp2, vts2, vsmo2); - const __m512 vemo3 = _mm512_fmadd_ps(vp3, vts3, vsmo3); - const __m512 vepo0 = _mm512_sub_ps(vemo0, vminus_two); - const __m512 vepo1 = _mm512_sub_ps(vemo1, vminus_two); - const __m512 vepo2 = _mm512_sub_ps(vemo2, vminus_two); - const __m512 vepo3 = _mm512_sub_ps(vemo3, vminus_two); - - __m512 vy0 = _mm512_div_ps(vemo0, vepo0); - __m512 vy1 = _mm512_div_ps(vemo1, vepo1); - __m512 vy2 = _mm512_div_ps(vemo2, vepo2); - __m512 vy3 = _mm512_div_ps(vemo3, vepo3); - vy0 = _mm512_castsi512_ps(_mm512_ternarylogic_epi32(_mm512_castps_si512(vy0), _mm512_castps_si512(vx0), vsign_mask, 0xD8)); - vy1 = _mm512_castsi512_ps(_mm512_ternarylogic_epi32(_mm512_castps_si512(vy1), _mm512_castps_si512(vx1), vsign_mask, 0xD8)); - vy2 = _mm512_castsi512_ps(_mm512_ternarylogic_epi32(_mm512_castps_si512(vy2), _mm512_castps_si512(vx2), vsign_mask, 0xD8)); - vy3 = _mm512_castsi512_ps(_mm512_ternarylogic_epi32(_mm512_castps_si512(vy3), _mm512_castps_si512(vx3), vsign_mask, 0xD8)); - - _mm512_storeu_ps(output, vy0); - _mm512_storeu_ps(output + 16, vy1); - _mm512_storeu_ps(output + 32, vy2); - _mm512_storeu_ps(output + 48, vy3); - output += 64; - } - for (; batch 
>= 16 * sizeof(float); batch -= 16 * sizeof(float)) { - const __m512 vx = _mm512_loadu_ps(input); - input += 16; - - const __m512 vz = _mm512_range_ps(vsat_cutoff, vx, 0xA); - __m512 vn = _mm512_fmadd_ps(vz, vminus_log2e, vmagic_bias); - - const __m512i ve = _mm512_slli_epi32(_mm512_castps_si512(vn), 21); - - const __m512i vl = _mm512_castps_si512(_mm512_permutevar_ps(vtable, _mm512_castps_si512(vn))); - - const __m512 vs = _mm512_castsi512_ps(_mm512_add_epi32(vl, ve)); - - vn = _mm512_sub_ps(vn, vmagic_bias); - - const __m512 vt = _mm512_fmadd_ps(vn, vln2, vz); - - __m512 vp = vc4; - vp = _mm512_fmadd_ps(vp, vt, vc3); - vp = _mm512_fmadd_ps(vp, vt, vc2); - vp = _mm512_fmadd_ps(vp, vt, vminus_two); - - const __m512 vts = _mm512_mul_ps(vt, vs); - const __m512 vsmo = _mm512_sub_ps(vs, vone); - const __m512 vemo = _mm512_fmadd_ps(vp, vts, vsmo); - const __m512 vepo = _mm512_sub_ps(vemo, vminus_two); - - __m512 vy = _mm512_div_ps(vemo, vepo); - vy = _mm512_castsi512_ps(_mm512_ternarylogic_epi32(_mm512_castps_si512(vy), _mm512_castps_si512(vx), vsign_mask, 0xD8)); - - _mm512_storeu_ps(output, vy); - output += 16; - } - if XNN_UNLIKELY(batch != 0) { - assert(batch >= 1 * sizeof(float)); - assert(batch <= 15 * sizeof(float)); - - // Prepare mask for valid 32-bit elements (depends on batch). - batch >>= XNN_LOG2_SIZEOF_FLOAT; - const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); - - const __m512 vx = _mm512_maskz_loadu_ps(vmask, input); - - const __m512 vz = _mm512_range_ps(vsat_cutoff, vx, 0xA); - __m512 vn = _mm512_fmadd_ps(vz, vminus_log2e, vmagic_bias); - - const __m512i ve = _mm512_slli_epi32(_mm512_castps_si512(vn), 21); - - const __m512i vl = _mm512_castps_si512(_mm512_permutevar_ps(vtable, _mm512_castps_si512(vn))); - - const __m512 vs = _mm512_castsi512_ps(_mm512_add_epi32(vl, ve)); - - vn = _mm512_sub_ps(vn, vmagic_bias); - - const __m512 vt = _mm512_fmadd_ps(vn, vln2, vz); - - __m512 vp = vc4; - vp = _mm512_fmadd_ps(vp, vt, vc3); - vp = _mm512_fmadd_ps(vp, vt, vc2); - vp = _mm512_fmadd_ps(vp, vt, vminus_two); - - const __m512 vts = _mm512_mul_ps(vt, vs); - const __m512 vsmo = _mm512_sub_ps(vs, vone); - const __m512 vemo = _mm512_fmadd_ps(vp, vts, vsmo); - const __m512 vepo = _mm512_sub_ps(vemo, vminus_two); - - __m512 vy = _mm512_div_ps(vemo, vepo); - vy = _mm512_castsi512_ps(_mm512_ternarylogic_epi32(_mm512_castps_si512(vy), _mm512_castps_si512(vx), vsign_mask, 0xD8)); - - _mm512_mask_storeu_ps(output, vmask, vy); - } -} - void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx512skx( size_t mr, size_t nc, diff --git a/src/amalgam/gen/fma3.c b/src/amalgam/gen/fma3.c index a39943a155f..60b9f583900 100644 --- a/src/amalgam/gen/fma3.c +++ b/src/amalgam/gen/fma3.c @@ -5556,217 +5556,105 @@ void xnn_f32_vsqrt_ukernel__fma3_rsqrt_u16( } } -void xnn_f32_vtanh_ukernel__fma3_expm1minus_rr1_lut4_p4h3ts_perm_div_u40( +void xnn_f32_vtanh_ukernel__fma3_rational_9_6_div_u16( size_t batch, const float* input, float* output, - const union xnn_f32_tanh_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS + const union xnn_f32_tanh_params params[restrict XNN_MIN_ELEMENTS(1)]) { assert(batch != 0); assert(batch % sizeof(float) == 0); assert(input != NULL); assert(output != NULL); - const __m256 vsign_mask = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.sign_mask); - const __m256 vsat_cutoff = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.sat_cutoff); - const __m256 vlog2e = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.log2e); - 
const __m256 vmagic_bias = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.magic_bias); - const __m128 vtable = _mm_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.table); - const __m256 vminus_ln2 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.minus_ln2); - const __m256 vc4 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.c4); - const __m256 vc3 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.c3); - const __m256 vc2 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.c2); - const __m256 vtwo = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.two); - const __m256 vminus_one = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.minus_one); - - for (; batch >= 40 * sizeof(float); batch -= 40 * sizeof(float)) { - const __m256 vx0 = _mm256_loadu_ps(input); - const __m256 vx1 = _mm256_loadu_ps(input + 8); - const __m256 vx2 = _mm256_loadu_ps(input + 16); - const __m256 vx3 = _mm256_loadu_ps(input + 24); - const __m256 vx4 = _mm256_loadu_ps(input + 32); - input += 40; - - __m256 vz0 = _mm256_or_ps(vx0, vsign_mask); - __m256 vz1 = _mm256_or_ps(vx1, vsign_mask); - __m256 vz2 = _mm256_or_ps(vx2, vsign_mask); - __m256 vz3 = _mm256_or_ps(vx3, vsign_mask); - __m256 vz4 = _mm256_or_ps(vx4, vsign_mask); - - const __m256 vinvsignx0 = _mm256_xor_ps(vx0, vz0); - vz0 = _mm256_max_ps(vsat_cutoff, vz0); - const __m256 vinvsignx1 = _mm256_xor_ps(vx1, vz1); - vz1 = _mm256_max_ps(vsat_cutoff, vz1); - const __m256 vinvsignx2 = _mm256_xor_ps(vx2, vz2); - vz2 = _mm256_max_ps(vsat_cutoff, vz2); - const __m256 vinvsignx3 = _mm256_xor_ps(vx3, vz3); - vz3 = _mm256_max_ps(vsat_cutoff, vz3); - const __m256 vinvsignx4 = _mm256_xor_ps(vx4, vz4); - vz4 = _mm256_max_ps(vsat_cutoff, vz4); - - __m256 vn0 = _mm256_fmadd_ps(vz0, vlog2e, vmagic_bias); - __m256 vn1 = _mm256_fmadd_ps(vz1, vlog2e, vmagic_bias); - __m256 vn2 = _mm256_fmadd_ps(vz2, vlog2e, vmagic_bias); - __m256 vn3 = _mm256_fmadd_ps(vz3, vlog2e, vmagic_bias); - __m256 vn4 = _mm256_fmadd_ps(vz4, vlog2e, vmagic_bias); - - const __m128 vn0_hi = _mm256_extractf128_ps(vn0, 1); - __m128i ve0_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn0)), 21); - const __m128 vn1_hi = _mm256_extractf128_ps(vn1, 1); - __m128i ve1_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn1)), 21); - const __m128 vn2_hi = _mm256_extractf128_ps(vn2, 1); - __m128i ve2_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn2)), 21); - const __m128 vn3_hi = _mm256_extractf128_ps(vn3, 1); - __m128i ve3_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn3)), 21); - const __m128 vn4_hi = _mm256_extractf128_ps(vn4, 1); - __m128i ve4_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn4)), 21); - - __m128i ve0_hi = _mm_slli_epi32(_mm_castps_si128(vn0_hi), 21); - const __m128i vl0_lo = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(_mm256_castps256_ps128(vn0)))); - const __m128i vl0_hi = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(vn0_hi))); - __m128i ve1_hi = _mm_slli_epi32(_mm_castps_si128(vn1_hi), 21); - const __m128i vl1_lo = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(_mm256_castps256_ps128(vn1)))); - const __m128i vl1_hi = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(vn1_hi))); - __m128i ve2_hi = _mm_slli_epi32(_mm_castps_si128(vn2_hi), 21); - const __m128i vl2_lo = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(_mm256_castps256_ps128(vn2)))); - const __m128i vl2_hi = _mm_castps_si128(_mm_permutevar_ps(vtable, 
_mm_castps_si128(vn2_hi))); - __m128i ve3_hi = _mm_slli_epi32(_mm_castps_si128(vn3_hi), 21); - const __m128i vl3_lo = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(_mm256_castps256_ps128(vn3)))); - const __m128i vl3_hi = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(vn3_hi))); - __m128i ve4_hi = _mm_slli_epi32(_mm_castps_si128(vn4_hi), 21); - const __m128i vl4_lo = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(_mm256_castps256_ps128(vn4)))); - const __m128i vl4_hi = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(vn4_hi))); - - const __m128 vs0_lo = _mm_castsi128_ps(_mm_add_epi32(ve0_lo, vl0_lo)); - const __m128 vs0_hi = _mm_castsi128_ps(_mm_add_epi32(ve0_hi, vl0_hi)); - const __m128 vs1_lo = _mm_castsi128_ps(_mm_add_epi32(ve1_lo, vl1_lo)); - const __m128 vs1_hi = _mm_castsi128_ps(_mm_add_epi32(ve1_hi, vl1_hi)); - const __m128 vs2_lo = _mm_castsi128_ps(_mm_add_epi32(ve2_lo, vl2_lo)); - const __m128 vs2_hi = _mm_castsi128_ps(_mm_add_epi32(ve2_hi, vl2_hi)); - const __m128 vs3_lo = _mm_castsi128_ps(_mm_add_epi32(ve3_lo, vl3_lo)); - const __m128 vs3_hi = _mm_castsi128_ps(_mm_add_epi32(ve3_hi, vl3_hi)); - const __m128 vs4_lo = _mm_castsi128_ps(_mm_add_epi32(ve4_lo, vl4_lo)); - const __m128 vs4_hi = _mm_castsi128_ps(_mm_add_epi32(ve4_hi, vl4_hi)); - - const __m256 vs0 = _mm256_insertf128_ps(_mm256_castps128_ps256(vs0_lo), vs0_hi, 1); - const __m256 vs1 = _mm256_insertf128_ps(_mm256_castps128_ps256(vs1_lo), vs1_hi, 1); - const __m256 vs2 = _mm256_insertf128_ps(_mm256_castps128_ps256(vs2_lo), vs2_hi, 1); - const __m256 vs3 = _mm256_insertf128_ps(_mm256_castps128_ps256(vs3_lo), vs3_hi, 1); - const __m256 vs4 = _mm256_insertf128_ps(_mm256_castps128_ps256(vs4_lo), vs4_hi, 1); - - vn0 = _mm256_sub_ps(vn0, vmagic_bias); - vn1 = _mm256_sub_ps(vn1, vmagic_bias); - vn2 = _mm256_sub_ps(vn2, vmagic_bias); - vn3 = _mm256_sub_ps(vn3, vmagic_bias); - vn4 = _mm256_sub_ps(vn4, vmagic_bias); - - const __m256 vt0 = _mm256_fmadd_ps(vn0, vminus_ln2, vz0); - const __m256 vt1 = _mm256_fmadd_ps(vn1, vminus_ln2, vz1); - const __m256 vt2 = _mm256_fmadd_ps(vn2, vminus_ln2, vz2); - const __m256 vt3 = _mm256_fmadd_ps(vn3, vminus_ln2, vz3); - const __m256 vt4 = _mm256_fmadd_ps(vn4, vminus_ln2, vz4); - - __m256 vp0 = vc4; - __m256 vp1 = vc4; - __m256 vp2 = vc4; - __m256 vp3 = vc4; - __m256 vp4 = vc4; - vp0 = _mm256_fmadd_ps(vp0, vt0, vc3); - vp1 = _mm256_fmadd_ps(vp1, vt1, vc3); - vp2 = _mm256_fmadd_ps(vp2, vt2, vc3); - vp3 = _mm256_fmadd_ps(vp3, vt3, vc3); - vp4 = _mm256_fmadd_ps(vp4, vt4, vc3); - vp0 = _mm256_fmadd_ps(vp0, vt0, vc2); - vp1 = _mm256_fmadd_ps(vp1, vt1, vc2); - vp2 = _mm256_fmadd_ps(vp2, vt2, vc2); - vp3 = _mm256_fmadd_ps(vp3, vt3, vc2); - vp4 = _mm256_fmadd_ps(vp4, vt4, vc2); - vp0 = _mm256_fmadd_ps(vp0, vt0, vtwo); - vp1 = _mm256_fmadd_ps(vp1, vt1, vtwo); - vp2 = _mm256_fmadd_ps(vp2, vt2, vtwo); - vp3 = _mm256_fmadd_ps(vp3, vt3, vtwo); - vp4 = _mm256_fmadd_ps(vp4, vt4, vtwo); - - const __m256 vts0 = _mm256_mul_ps(vt0, vs0); - const __m256 vsmo0 = _mm256_add_ps(vs0, vminus_one); - const __m256 vts1 = _mm256_mul_ps(vt1, vs1); - const __m256 vsmo1 = _mm256_add_ps(vs1, vminus_one); - const __m256 vts2 = _mm256_mul_ps(vt2, vs2); - const __m256 vsmo2 = _mm256_add_ps(vs2, vminus_one); - const __m256 vts3 = _mm256_mul_ps(vt3, vs3); - const __m256 vsmo3 = _mm256_add_ps(vs3, vminus_one); - const __m256 vts4 = _mm256_mul_ps(vt4, vs4); - const __m256 vsmo4 = _mm256_add_ps(vs4, vminus_one); - - const __m256 vemo0 = _mm256_fmadd_ps(vp0, vts0, vsmo0); - const __m256 
vemo1 = _mm256_fmadd_ps(vp1, vts1, vsmo1); - const __m256 vemo2 = _mm256_fmadd_ps(vp2, vts2, vsmo2); - const __m256 vemo3 = _mm256_fmadd_ps(vp3, vts3, vsmo3); - const __m256 vemo4 = _mm256_fmadd_ps(vp4, vts4, vsmo4); - - const __m256 vepo0 = _mm256_add_ps(vemo0, vtwo); - const __m256 vepo1 = _mm256_add_ps(vemo1, vtwo); - const __m256 vepo2 = _mm256_add_ps(vemo2, vtwo); - const __m256 vepo3 = _mm256_add_ps(vemo3, vtwo); - const __m256 vepo4 = _mm256_add_ps(vemo4, vtwo); - __m256 vy0 = _mm256_div_ps(vemo0, vepo0); - __m256 vy1 = _mm256_div_ps(vemo1, vepo1); - __m256 vy2 = _mm256_div_ps(vemo2, vepo2); - __m256 vy3 = _mm256_div_ps(vemo3, vepo3); - __m256 vy4 = _mm256_div_ps(vemo4, vepo4); - - vy0 = _mm256_xor_ps(vy0, vinvsignx0); - vy1 = _mm256_xor_ps(vy1, vinvsignx1); - vy2 = _mm256_xor_ps(vy2, vinvsignx2); - vy3 = _mm256_xor_ps(vy3, vinvsignx3); - vy4 = _mm256_xor_ps(vy4, vinvsignx4); + // Cap the inputs to this value as `tanh(x)` will always be `+/-1.0f` beyond + // this point. This value is chosen as the first floating point number as of + // which the interpolation returns 1.0f. + const __m256 vmax_x = _mm256_set1_ps(params->fma3_rational_9_6.max_abs_x); + const __m256 vmin_x = _mm256_set1_ps(-params->fma3_rational_9_6.max_abs_x); + + // The monomial coefficients of the numerator polynomial (odd). + const __m256 valpha_1 = _mm256_set1_ps(params->fma3_rational_9_6.alpha_1); + const __m256 valpha_3 = _mm256_set1_ps(params->fma3_rational_9_6.alpha_3); + const __m256 valpha_5 = _mm256_set1_ps(params->fma3_rational_9_6.alpha_5); + const __m256 valpha_7 = _mm256_set1_ps(params->fma3_rational_9_6.alpha_7); + const __m256 valpha_9 = _mm256_set1_ps(params->fma3_rational_9_6.alpha_9); + + // The monomial coefficients of the denominator polynomial (even). + const __m256 vbeta_0 = _mm256_set1_ps(params->fma3_rational_9_6.beta_0); + const __m256 vbeta_2 = _mm256_set1_ps(params->fma3_rational_9_6.beta_2); + const __m256 vbeta_4 = _mm256_set1_ps(params->fma3_rational_9_6.beta_4); + const __m256 vbeta_6 = _mm256_set1_ps(params->fma3_rational_9_6.beta_6); - _mm256_storeu_ps(output, vy0); - _mm256_storeu_ps(output + 8, vy1); - _mm256_storeu_ps(output + 16, vy2); - _mm256_storeu_ps(output + 24, vy3); - _mm256_storeu_ps(output + 32, vy4); - output += 40; + + for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { + __m256 vx_0 = _mm256_loadu_ps(input); + __m256 vx_1 = _mm256_loadu_ps(input + 8); + input += 16; + + // Clamp the inputs to the interpolation range. + vx_0 = _mm256_min_ps(vmax_x, vx_0); + vx_1 = _mm256_min_ps(vmax_x, vx_1); + vx_0 = _mm256_max_ps(vmin_x, vx_0); + vx_1 = _mm256_max_ps(vmin_x, vx_1); + + // Since the polynomials are odd/even, we need x^2. + const __m256 vx2_0 = _mm256_mul_ps(vx_0, vx_0); + const __m256 vx2_1 = _mm256_mul_ps(vx_1, vx_1); + + // Evaluate the numerator polynomial p. + __m256 vp_0 = _mm256_fmadd_ps(vx2_0, valpha_9, valpha_7); + __m256 vp_1 = _mm256_fmadd_ps(vx2_1, valpha_9, valpha_7); + vp_0 = _mm256_fmadd_ps(vx2_0, vp_0, valpha_5); + vp_1 = _mm256_fmadd_ps(vx2_1, vp_1, valpha_5); + vp_0 = _mm256_fmadd_ps(vx2_0, vp_0, valpha_3); + vp_1 = _mm256_fmadd_ps(vx2_1, vp_1, valpha_3); + vp_0 = _mm256_fmadd_ps(vx2_0, vp_0, valpha_1); + vp_1 = _mm256_fmadd_ps(vx2_1, vp_1, valpha_1); + vp_0 = _mm256_mul_ps(vx_0, vp_0); + vp_1 = _mm256_mul_ps(vx_1, vp_1); + + // Evaluate the denominator polynomial q. 
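+    // q(x) = beta_0 + x^2 * (beta_2 + x^2 * (beta_4 + x^2 * beta_6)), one FMA per step.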
+ __m256 vq_0 = _mm256_fmadd_ps(vx2_0, vbeta_6, vbeta_4); + __m256 vq_1 = _mm256_fmadd_ps(vx2_1, vbeta_6, vbeta_4); + vq_0 = _mm256_fmadd_ps(vx2_0, vq_0, vbeta_2); + vq_1 = _mm256_fmadd_ps(vx2_1, vq_1, vbeta_2); + vq_0 = _mm256_fmadd_ps(vx2_0, vq_0, vbeta_0); + vq_1 = _mm256_fmadd_ps(vx2_1, vq_1, vbeta_0); + + // Divide the numerator by the denominator. + const __m256 vy_0 = _mm256_div_ps(vp_0, vq_0); + const __m256 vy_1 = _mm256_div_ps(vp_1, vq_1); + + _mm256_storeu_ps(output, vy_0); + _mm256_storeu_ps(output + 8, vy_1); + output += 16; } for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) { - const __m256 vx = _mm256_loadu_ps(input); + __m256 vx = _mm256_loadu_ps(input); input += 8; - __m256 vz = _mm256_or_ps(vx, vsign_mask); - - const __m256 vinvsignx = _mm256_xor_ps(vx, vz); - vz = _mm256_max_ps(vsat_cutoff, vz); - - __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias); + // Clamp the inputs to the interpolation range. + vx = _mm256_min_ps(vmax_x, vx); + vx = _mm256_max_ps(vmin_x, vx); - const __m128 vn_hi = _mm256_extractf128_ps(vn, 1); - __m128i ve_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn)), 21); - __m128i ve_hi = _mm_slli_epi32(_mm_castps_si128(vn_hi), 21); + // Since the polynomials are odd/even, we need x^2. + const __m256 vx2 = _mm256_mul_ps(vx, vx); - const __m128i vl_lo = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(_mm256_castps256_ps128(vn)))); - const __m128i vl_hi = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(vn_hi))); + // Evaluate the numerator polynomial p. + __m256 vp = _mm256_fmadd_ps(vx2, valpha_9, valpha_7); + vp = _mm256_fmadd_ps(vx2, vp, valpha_5); + vp = _mm256_fmadd_ps(vx2, vp, valpha_3); + vp = _mm256_fmadd_ps(vx2, vp, valpha_1); + vp = _mm256_mul_ps(vx, vp); - const __m128 vs_lo = _mm_castsi128_ps(_mm_add_epi32(ve_lo, vl_lo)); - const __m128 vs_hi = _mm_castsi128_ps(_mm_add_epi32(ve_hi, vl_hi)); - const __m256 vs = _mm256_insertf128_ps(_mm256_castps128_ps256(vs_lo), vs_hi, 1); + // Evaluate the denominator polynomial q. + __m256 vq = _mm256_fmadd_ps(vx2, vbeta_6, vbeta_4); + vq = _mm256_fmadd_ps(vx2, vq, vbeta_2); + vq = _mm256_fmadd_ps(vx2, vq, vbeta_0); - vn = _mm256_sub_ps(vn, vmagic_bias); - - const __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz); - - __m256 vp = vc4; - vp = _mm256_fmadd_ps(vp, vt, vc3); - vp = _mm256_fmadd_ps(vp, vt, vc2); - vp = _mm256_fmadd_ps(vp, vt, vtwo); - - const __m256 vts = _mm256_mul_ps(vt, vs); - const __m256 vsmo = _mm256_add_ps(vs, vminus_one); - const __m256 vemo = _mm256_fmadd_ps(vp, vts, vsmo); - - const __m256 vepo = _mm256_add_ps(vemo, vtwo); - __m256 vy = _mm256_div_ps(vemo, vepo); - - vy = _mm256_xor_ps(vy, vinvsignx); + // Divide the numerator by the denominator. 
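+    // This is the `div` variant, which uses a true division; the AVX512 `nr` variant
+    // instead refines `_mm512_rcp14_ps` with a single Newton-Raphson step.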
+ const __m256 vy = _mm256_div_ps(vp, vq); _mm256_storeu_ps(output, vy); output += 8; @@ -5774,45 +5662,32 @@ void xnn_f32_vtanh_ukernel__fma3_expm1minus_rr1_lut4_p4h3ts_perm_div_u40( if XNN_UNLIKELY(batch != 0) { assert(batch >= 1 * sizeof(float)); assert(batch <= 7 * sizeof(float)); - const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) ¶ms->avx_expm1minus_rr1_lut4_p4h3_perm.mask_table[7] - batch)); - - const __m256 vx = _mm256_maskload_ps(input, vmask); - - __m256 vz = _mm256_or_ps(vx, vsign_mask); - - const __m256 vinvsignx = _mm256_xor_ps(vx, vz); - vz = _mm256_max_ps(vsat_cutoff, vz); - - __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias); - - const __m128 vn_hi = _mm256_extractf128_ps(vn, 1); - __m128i ve_lo = _mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn)), 21); - __m128i ve_hi = _mm_slli_epi32(_mm_castps_si128(vn_hi), 21); - - const __m128i vl_lo = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(_mm256_castps256_ps128(vn)))); - const __m128i vl_hi = _mm_castps_si128(_mm_permutevar_ps(vtable, _mm_castps_si128(vn_hi))); - - const __m128 vs_lo = _mm_castsi128_ps(_mm_add_epi32(ve_lo, vl_lo)); - const __m128 vs_hi = _mm_castsi128_ps(_mm_add_epi32(ve_hi, vl_hi)); - const __m256 vs = _mm256_insertf128_ps(_mm256_castps128_ps256(vs_lo), vs_hi, 1); + const __m256i vmask = _mm256_loadu_si256( + (const __m256i*) ((uintptr_t) ¶ms->fma3_rational_9_6.mask_table[7] - batch)); - vn = _mm256_sub_ps(vn, vmagic_bias); + __m256 vx = _mm256_maskload_ps(input, vmask); - const __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz); + // Clamp the inputs to the interpolation range. + vx = _mm256_min_ps(vmax_x, vx); + vx = _mm256_max_ps(vmin_x, vx); - __m256 vp = vc4; - vp = _mm256_fmadd_ps(vp, vt, vc3); - vp = _mm256_fmadd_ps(vp, vt, vc2); - vp = _mm256_fmadd_ps(vp, vt, vtwo); + // Since the polynomials are odd/even, we need x^2. + const __m256 vx2 = _mm256_mul_ps(vx, vx); - const __m256 vts = _mm256_mul_ps(vt, vs); - const __m256 vsmo = _mm256_add_ps(vs, vminus_one); - const __m256 vemo = _mm256_fmadd_ps(vp, vts, vsmo); + // Evaluate the numerator polynomial p. + __m256 vp = _mm256_fmadd_ps(vx2, valpha_9, valpha_7); + vp = _mm256_fmadd_ps(vx2, vp, valpha_5); + vp = _mm256_fmadd_ps(vx2, vp, valpha_3); + vp = _mm256_fmadd_ps(vx2, vp, valpha_1); + vp = _mm256_mul_ps(vx, vp); - const __m256 vepo = _mm256_add_ps(vemo, vtwo); - __m256 vy = _mm256_div_ps(vemo, vepo); + // Evaluate the denominator polynomial q. + __m256 vq = _mm256_fmadd_ps(vx2, vbeta_6, vbeta_4); + vq = _mm256_fmadd_ps(vx2, vq, vbeta_2); + vq = _mm256_fmadd_ps(vx2, vq, vbeta_0); - vy = _mm256_xor_ps(vy, vinvsignx); + // Divide the numerator by the denominator. 
+ const __m256 vy = _mm256_div_ps(vp, vq); __m128 vy_lo = _mm256_castps256_ps128(vy); if (batch & (4 * sizeof(float))) { diff --git a/src/amalgam/gen/scalar.c b/src/amalgam/gen/scalar.c index e5536d46eb9..252573b4e7b 100644 --- a/src/amalgam/gen/scalar.c +++ b/src/amalgam/gen/scalar.c @@ -14206,176 +14206,65 @@ void xnn_f32_vsqrt_ukernel__scalar_sqrt_u1( } } -extern XNN_INTERNAL const uint32_t xnn_table_exp2minus_k_over_8[8]; - -void xnn_f32_vtanh_ukernel__scalar_expm1minus_rr1_lut8_p4h3ts_div_u4( +void xnn_f32_vtanh_ukernel__scalar_rational_9_6_u1( size_t batch, const float* input, float* output, - const union xnn_f32_tanh_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS + const union xnn_f32_tanh_params params[restrict XNN_MIN_ELEMENTS(1)]) { assert(batch != 0); assert(batch % sizeof(float) == 0); assert(input != NULL); assert(output != NULL); - const float vsat_cutoff = params->scalar_expm1minus_rr1_lut8_p4h3.sat_cutoff; - const float vminus_log2e = params->scalar_expm1minus_rr1_lut8_p4h3.minus_log2e; - const float vmagic_bias = params->scalar_expm1minus_rr1_lut8_p4h3.magic_bias; - const uint32_t vindex_mask = UINT32_C(0x7); - const float vln2 = params->scalar_expm1minus_rr1_lut8_p4h3.ln2; - const float vc4 = params->scalar_expm1minus_rr1_lut8_p4h3.c4; - const float vc3 = params->scalar_expm1minus_rr1_lut8_p4h3.c3; - const float vc2 = params->scalar_expm1minus_rr1_lut8_p4h3.c2; - const float vminus_two = params->scalar_expm1minus_rr1_lut8_p4h3.minus_two; - const float vone = params->scalar_expm1minus_rr1_lut8_p4h3.one; - - for (; batch >= 4 * sizeof(float); batch -= 4 * sizeof(float)) { - const float vx0 = input[0]; - const float vx1 = input[1]; - const float vx2 = input[2]; - const float vx3 = input[3]; - input += 4; - - float vz0 = fabsf(vx0); - float vz1 = fabsf(vx1); - float vz2 = fabsf(vx2); - float vz3 = fabsf(vx3); - - vz0 = math_pmin_f32(vz0, vsat_cutoff); - vz1 = math_pmin_f32(vz1, vsat_cutoff); - vz2 = math_pmin_f32(vz2, vsat_cutoff); - vz3 = math_pmin_f32(vz3, vsat_cutoff); - - float vn0 = vz0 * vminus_log2e + vmagic_bias; - float vn1 = vz1 * vminus_log2e + vmagic_bias; - float vn2 = vz2 * vminus_log2e + vmagic_bias; - float vn3 = vz3 * vminus_log2e + vmagic_bias; - - const uint32_t vb0 = float_as_uint32(vn0); - vn0 -= vmagic_bias; - const uint32_t vb1 = float_as_uint32(vn1); - vn1 -= vmagic_bias; - const uint32_t vb2 = float_as_uint32(vn2); - vn2 -= vmagic_bias; - const uint32_t vb3 = float_as_uint32(vn3); - vn3 -= vmagic_bias; - - const uint32_t vidx0 = vb0 & vindex_mask; - const uint32_t vidx1 = vb1 & vindex_mask; - const uint32_t vidx2 = vb2 & vindex_mask; - const uint32_t vidx3 = vb3 & vindex_mask; - - const uint32_t vl0 = xnn_table_exp2minus_k_over_8[vidx0]; - uint32_t ve0 = vb0 << 20; - const uint32_t vl1 = xnn_table_exp2minus_k_over_8[vidx1]; - uint32_t ve1 = vb1 << 20; - const uint32_t vl2 = xnn_table_exp2minus_k_over_8[vidx2]; - uint32_t ve2 = vb2 << 20; - const uint32_t vl3 = xnn_table_exp2minus_k_over_8[vidx3]; - uint32_t ve3 = vb3 << 20; - - ve0 += vl0; - ve1 += vl1; - ve2 += vl2; - ve3 += vl3; - - const float vt0 = vn0 * vln2 + vz0; - const float vs0 = uint32_as_float(ve0); - const float vt1 = vn1 * vln2 + vz1; - const float vs1 = uint32_as_float(ve1); - const float vt2 = vn2 * vln2 + vz2; - const float vs2 = uint32_as_float(ve2); - const float vt3 = vn3 * vln2 + vz3; - const float vs3 = uint32_as_float(ve3); - - float vp0 = vc4 * vt0 + vc3; - float vp1 = vc4 * vt1 + vc3; - float vp2 = vc4 * vt2 + vc3; - float vp3 = vc4 * vt3 + vc3; - vp0 = vp0 * vt0 + 
vc2; - vp1 = vp1 * vt1 + vc2; - vp2 = vp2 * vt2 + vc2; - vp3 = vp3 * vt3 + vc2; - vp0 = vp0 * vt0 + vminus_two; - vp1 = vp1 * vt1 + vminus_two; - vp2 = vp2 * vt2 + vminus_two; - vp3 = vp3 * vt3 + vminus_two; - - const float vts0 = vt0 * vs0; - const float vsmo0 = vs0 - vone; - const float vts1 = vt1 * vs1; - const float vsmo1 = vs1 - vone; - const float vts2 = vt2 * vs2; - const float vsmo2 = vs2 - vone; - const float vts3 = vt3 * vs3; - const float vsmo3 = vs3 - vone; - - const float vemo0 = vp0 * vts0 + vsmo0; - const float vemo1 = vp1 * vts1 + vsmo1; - const float vemo2 = vp2 * vts2 + vsmo2; - const float vemo3 = vp3 * vts3 + vsmo3; - - const float vepo0 = vemo0 - vminus_two; - const float vepo1 = vemo1 - vminus_two; - const float vepo2 = vemo2 - vminus_two; - const float vepo3 = vemo3 - vminus_two; - - float vy0 = vemo0 / vepo0; - float vy1 = vemo1 / vepo1; - float vy2 = vemo2 / vepo2; - float vy3 = vemo3 / vepo3; - - vy0 = copysignf(vy0, vx0); - vy1 = copysignf(vy1, vx1); - vy2 = copysignf(vy2, vx2); - vy3 = copysignf(vy3, vx3); - - output[0] = vy0; - output[1] = vy1; - output[2] = vy2; - output[3] = vy3; - output += 4; - } - if XNN_UNLIKELY(batch != 0) { - do { - const float vx = *input++; - - float vz = fabsf(vx); - - vz = math_pmin_f32(vz, vsat_cutoff); - - float vn = vz * vminus_log2e + vmagic_bias; - - const uint32_t vb = float_as_uint32(vn); - vn -= vmagic_bias; - - const uint32_t vidx = vb & vindex_mask; - const uint32_t vl = xnn_table_exp2minus_k_over_8[vidx]; - uint32_t ve = vb << 20; - ve += vl; - const float vs = uint32_as_float(ve); - - const float vt = vn * vln2 + vz; - - float vp = vc4 * vt + vc3; - vp = vp * vt + vc2; - vp = vp * vt + vminus_two; - - const float vts = vt * vs; - const float vsmo = vs - vone; - const float vemo = vp * vts + vsmo; - - const float vepo = vemo - vminus_two; - - float vy = vemo / vepo; - - vy = copysignf(vy, vx); - - *output++ = vy; - - batch -= sizeof(float); - } while (batch != 0); + // Cap the inputs to this value as `tanh(x)` will always be `+/-1.0f` beyond + // this point. This value is chosen as the first floating point number as of + // which the interpolation returns 1.0f. + const float max_x = 7.623543739319f; + const float min_x = -7.623543739319f; + + // The monomial coefficients of the numerator polynomial (odd). + const float alpha_1 = -9.022999554873e-03f; + const float alpha_3 = -1.146968104877e-03f; + const float alpha_5 = -2.432360815874e-05f; + const float alpha_7 = -6.458659385089e-08f; + const float alpha_9 = 5.535878699892e-11f; + + // The monomial coefficients of the denominator polynomial (even). + const float beta_0 = -9.023001417518e-03f; + const float beta_2 = -4.154618829489e-03f; + const float beta_4 = -2.061512641376e-04f; + const float beta_6 = -1.774490101525e-06f; + + for (; batch >= sizeof(float); batch -= sizeof(float)) { + float x = *input; + input++; + + // Clamp the inputs to the interpolation range. Note that we don't use + //`fminf` or `fmaxf` since they let `NaN`s through. + x = max_x < x ? max_x : x; + x = x < min_x ? min_x : x; + + // Since the polynomials are odd/even, we need x^2. + const float x2 = x * x; + + // Evaluate the numerator polynomial p. + float p = x2 * alpha_9 + alpha_7; + p = x2 * p + alpha_5; + p = x2 * p + alpha_3; + p = x2 * p + alpha_1; + p = x * p; + + // Evaluate the denominator polynomial q. + float q = x2 * beta_6 + beta_4; + q = x2 * q + beta_2; + q = x2 * q + beta_0; + + // Divide the numerator by the denominator. 
+ const float y = p / q; + + *output = y; + output++; } } diff --git a/src/amalgam/gen/sse2.c b/src/amalgam/gen/sse2.c index 0afdf46926a..0bc47a4deeb 100644 --- a/src/amalgam/gen/sse2.c +++ b/src/amalgam/gen/sse2.c @@ -2648,9 +2648,7 @@ void xnn_f32_vsigmoid_ukernel__sse2_rr2_lut64_p2_div_u8( } } -extern XNN_INTERNAL const uint32_t xnn_table_exp2minus_k_over_8[8]; - -void xnn_f32_vtanh_ukernel__sse2_expm1minus_rr1_lut8_p4h3ts_div_u16( +void xnn_f32_vtanh_ukernel__sse2_rational_9_6_div_u8( size_t batch, const float* input, float* output, @@ -2661,344 +2659,142 @@ void xnn_f32_vtanh_ukernel__sse2_expm1minus_rr1_lut8_p4h3ts_div_u16( assert(input != NULL); assert(output != NULL); - const __m128 vsign_mask = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.sign_mask); - const __m128 vsat_cutoff = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.sat_cutoff); - const __m128 vlog2e = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.log2e); - const __m128 vmagic_bias = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.magic_bias); - const __m128i vindex_mask = _mm_load_si128((const __m128i*) params->sse_expm1minus_rr1_lut8_p4h3.index_mask); - const __m128 vminus_ln2 = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.minus_ln2); - const __m128 vc4 = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.c4); - const __m128 vc3 = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.c3); - const __m128 vc2 = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.c2); - const __m128 vminus_two = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.minus_two); - const __m128 vminus_one = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.minus_one); - - for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { - const __m128 vx0123 = _mm_loadu_ps(input); - const __m128 vx4567 = _mm_loadu_ps(input + 4); - const __m128 vx89AB = _mm_loadu_ps(input + 8); - const __m128 vxCDEF = _mm_loadu_ps(input + 12); - input += 16; - - __m128 vz0123 = _mm_or_ps(vx0123, vsign_mask); - __m128 vz4567 = _mm_or_ps(vx4567, vsign_mask); - __m128 vz89AB = _mm_or_ps(vx89AB, vsign_mask); - __m128 vzCDEF = _mm_or_ps(vxCDEF, vsign_mask); - - const __m128 vinvsignx0123 = _mm_xor_ps(vx0123, vz0123); - const __m128 vinvsignx4567 = _mm_xor_ps(vx4567, vz4567); - const __m128 vinvsignx89AB = _mm_xor_ps(vx89AB, vz89AB); - const __m128 vinvsignxCDEF = _mm_xor_ps(vxCDEF, vzCDEF); - - vz0123 = _mm_max_ps(vsat_cutoff, vz0123); - vz4567 = _mm_max_ps(vsat_cutoff, vz4567); - vz89AB = _mm_max_ps(vsat_cutoff, vz89AB); - vzCDEF = _mm_max_ps(vsat_cutoff, vzCDEF); - - __m128 vn0123 = _mm_add_ps(_mm_mul_ps(vz0123, vlog2e), vmagic_bias); - __m128 vn4567 = _mm_add_ps(_mm_mul_ps(vz4567, vlog2e), vmagic_bias); - __m128 vn89AB = _mm_add_ps(_mm_mul_ps(vz89AB, vlog2e), vmagic_bias); - __m128 vnCDEF = _mm_add_ps(_mm_mul_ps(vzCDEF, vlog2e), vmagic_bias); - - const __m128i ve0123 = _mm_slli_epi32(_mm_castps_si128(vn0123), 20); - const __m128i ve4567 = _mm_slli_epi32(_mm_castps_si128(vn4567), 20); - const __m128i ve89AB = _mm_slli_epi32(_mm_castps_si128(vn89AB), 20); - const __m128i veCDEF = _mm_slli_epi32(_mm_castps_si128(vnCDEF), 20); - - #if XNN_ARCH_X86_64 - __m128i vidx0123 = _mm_and_si128(_mm_castps_si128(vn0123), vindex_mask); - __m128i vidx4567 = _mm_and_si128(_mm_castps_si128(vn4567), vindex_mask); - __m128i vidx89AB = _mm_and_si128(_mm_castps_si128(vn89AB), vindex_mask); - __m128i vidxCDEF = _mm_and_si128(_mm_castps_si128(vnCDEF), vindex_mask); - - const uint64_t vidx01 = (uint64_t) _mm_cvtsi128_si64(vidx0123); - vidx0123 = _mm_unpackhi_epi64(vidx0123, vidx0123); - 
const uint64_t vidx45 = (uint64_t) _mm_cvtsi128_si64(vidx4567); - vidx4567 = _mm_unpackhi_epi64(vidx4567, vidx4567); - const uint64_t vidx89 = (uint64_t) _mm_cvtsi128_si64(vidx89AB); - vidx89AB = _mm_unpackhi_epi64(vidx89AB, vidx89AB); - const uint64_t vidxCD = (uint64_t) _mm_cvtsi128_si64(vidxCDEF); - vidxCDEF = _mm_unpackhi_epi64(vidxCDEF, vidxCDEF); - - const uint64_t vidx23 = (uint64_t) _mm_cvtsi128_si64(vidx0123); - const uint64_t vidx67 = (uint64_t) _mm_cvtsi128_si64(vidx4567); - const uint64_t vidxAB = (uint64_t) _mm_cvtsi128_si64(vidx89AB); - const uint64_t vidxEF = (uint64_t) _mm_cvtsi128_si64(vidxCDEF); - - const __m128i vl0 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx01]); - const __m128i vl1 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx01 >> 32)]); - const __m128i vl4 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx45]); - const __m128i vl5 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx45 >> 32)]); - const __m128i vl8 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx89]); - const __m128i vl9 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx89 >> 32)]); - const __m128i vlC = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidxCD]); - const __m128i vlD = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidxCD >> 32)]); - - const __m128i vl01 = _mm_unpacklo_epi32(vl0, vl1); - const __m128i vl45 = _mm_unpacklo_epi32(vl4, vl5); - const __m128i vl89 = _mm_unpacklo_epi32(vl8, vl9); - const __m128i vlCD = _mm_unpacklo_epi32(vlC, vlD); - - const __m128i vl2 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx23]); - const __m128i vl3 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx23 >> 32)]); - const __m128i vl6 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx67]); - const __m128i vl7 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx67 >> 32)]); - const __m128i vlA = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidxAB]); - const __m128i vlB = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidxAB >> 32)]); - const __m128i vlE = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidxEF]); - const __m128i vlF = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidxEF >> 32)]); - - const __m128i vl23 = _mm_unpacklo_epi32(vl2, vl3); - const __m128i vl67 = _mm_unpacklo_epi32(vl6, vl7); - const __m128i vlAB = _mm_unpacklo_epi32(vlA, vlB); - const __m128i vlEF = _mm_unpacklo_epi32(vlE, vlF); - #else - const __m128i vidx0123 = _mm_and_si128(_mm_castps_si128(vn0123), vindex_mask); - const __m128i vidx4567 = _mm_and_si128(_mm_castps_si128(vn4567), vindex_mask); - const __m128i vidx89AB = _mm_and_si128(_mm_castps_si128(vn89AB), vindex_mask); - const __m128i vidxCDEF = _mm_and_si128(_mm_castps_si128(vnCDEF), vindex_mask); - - const uint32_t vidx0 = (uint32_t) _mm_cvtsi128_si32(vidx0123); - const uint32_t vidx4 = (uint32_t) _mm_cvtsi128_si32(vidx4567); - const uint32_t vidx8 = (uint32_t) _mm_cvtsi128_si32(vidx89AB); - const uint32_t vidxC = (uint32_t) _mm_cvtsi128_si32(vidxCDEF); - - const __m128i vl0 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx0]); - const __m128i vl4 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx4]); - const __m128i vl8 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx8]); - const __m128i vlC = _mm_cvtsi32_si128((int) 
xnn_table_exp2minus_k_over_8[vidxC]); - - const uint32_t vidx1 = (uint32_t) _mm_extract_epi16(vidx0123, 2); - const uint32_t vidx5 = (uint32_t) _mm_extract_epi16(vidx4567, 2); - const uint32_t vidx9 = (uint32_t) _mm_extract_epi16(vidx89AB, 2); - const uint32_t vidxD = (uint32_t) _mm_extract_epi16(vidxCDEF, 2); - - const __m128i vl1 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx1]); - const __m128i vl5 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx5]); - const __m128i vl9 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx9]); - const __m128i vlD = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidxD]); - - const __m128i vl01 = _mm_unpacklo_epi32(vl0, vl1); - const __m128i vl45 = _mm_unpacklo_epi32(vl4, vl5); - const __m128i vl89 = _mm_unpacklo_epi32(vl8, vl9); - const __m128i vlCD = _mm_unpacklo_epi32(vlC, vlD); - - const uint32_t vidx2 = (uint32_t) _mm_extract_epi16(vidx0123, 4); - const uint32_t vidx6 = (uint32_t) _mm_extract_epi16(vidx4567, 4); - const uint32_t vidxA = (uint32_t) _mm_extract_epi16(vidx89AB, 4); - const uint32_t vidxE = (uint32_t) _mm_extract_epi16(vidxCDEF, 4); - - const __m128i vl2 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx2]); - const __m128i vl6 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx6]); - const __m128i vlA = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidxA]); - const __m128i vlE = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidxE]); - - const uint32_t vidx3 = (uint32_t) _mm_extract_epi16(vidx0123, 6); - const uint32_t vidx7 = (uint32_t) _mm_extract_epi16(vidx4567, 6); - const uint32_t vidxB = (uint32_t) _mm_extract_epi16(vidx89AB, 6); - const uint32_t vidxF = (uint32_t) _mm_extract_epi16(vidxCDEF, 6); - - const __m128i vl3 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx3]); - const __m128i vl7 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx7]); - const __m128i vlB = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidxB]); - const __m128i vlF = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidxF]); - - const __m128i vl23 = _mm_unpacklo_epi32(vl2, vl3); - const __m128i vl67 = _mm_unpacklo_epi32(vl6, vl7); - const __m128i vlAB = _mm_unpacklo_epi32(vlA, vlB); - const __m128i vlEF = _mm_unpacklo_epi32(vlE, vlF); - #endif - const __m128i vl0123 = _mm_unpacklo_epi64(vl01, vl23); - const __m128i vl4567 = _mm_unpacklo_epi64(vl45, vl67); - const __m128i vl89AB = _mm_unpacklo_epi64(vl89, vlAB); - const __m128i vlCDEF = _mm_unpacklo_epi64(vlCD, vlEF); - - const __m128 vs0123 = _mm_castsi128_ps(_mm_add_epi32(vl0123, ve0123)); - const __m128 vs4567 = _mm_castsi128_ps(_mm_add_epi32(vl4567, ve4567)); - const __m128 vs89AB = _mm_castsi128_ps(_mm_add_epi32(vl89AB, ve89AB)); - const __m128 vsCDEF = _mm_castsi128_ps(_mm_add_epi32(vlCDEF, veCDEF)); - - vn0123 = _mm_sub_ps(vn0123, vmagic_bias); - vn4567 = _mm_sub_ps(vn4567, vmagic_bias); - vn89AB = _mm_sub_ps(vn89AB, vmagic_bias); - vnCDEF = _mm_sub_ps(vnCDEF, vmagic_bias); - - const __m128 vt0123 = _mm_add_ps(_mm_mul_ps(vn0123, vminus_ln2), vz0123); - const __m128 vt4567 = _mm_add_ps(_mm_mul_ps(vn4567, vminus_ln2), vz4567); - const __m128 vt89AB = _mm_add_ps(_mm_mul_ps(vn89AB, vminus_ln2), vz89AB); - const __m128 vtCDEF = _mm_add_ps(_mm_mul_ps(vnCDEF, vminus_ln2), vzCDEF); + // Cap the inputs to this value as `tanh(x)` will always be `+/-1.0f` beyond + // this point. This value is chosen as the first floating point number as of + // which the interpolation returns 1.0f. 
+#if XNN_ARCH_X86 + const __m128 vmax_x = _mm_load_ps(params->sse_rational_9_6.max_x); + const __m128 vmin_x = _mm_load_ps(params->sse_rational_9_6.min_x); +#else + const __m128 vmax_x = _mm_set1_ps(7.623543739319f); + const __m128 vmin_x = _mm_set1_ps(-7.623543739319f); +#endif // XNN_ARCH_X86 + + // The monomial coefficients of the numerator polynomial (odd). +#if XNN_ARCH_X86 + const __m128 valpha_1 = _mm_load_ps(params->sse_rational_9_6.alpha_1); + const __m128 valpha_3 = _mm_load_ps(params->sse_rational_9_6.alpha_3); + const __m128 valpha_5 = _mm_load_ps(params->sse_rational_9_6.alpha_5); + const __m128 valpha_7 = _mm_load_ps(params->sse_rational_9_6.alpha_7); + const __m128 valpha_9 = _mm_load_ps(params->sse_rational_9_6.alpha_9); +#else + const __m128 valpha_1 = _mm_set1_ps(-9.022999554873e-03f); + const __m128 valpha_3 = _mm_set1_ps(-1.146968104877e-03f); + const __m128 valpha_5 = _mm_set1_ps(-2.432360815874e-05f); + const __m128 valpha_7 = _mm_set1_ps(-6.458659385089e-08f); + const __m128 valpha_9 = _mm_set1_ps(5.535878699892e-11f); +#endif // XNN_ARCH_X86 + + // The monomial coefficients of the denominator polynomial (even). +#if XNN_ARCH_X86 + const __m128 vbeta_0 = _mm_load_ps(params->sse_rational_9_6.beta_0); + const __m128 vbeta_2 = _mm_load_ps(params->sse_rational_9_6.beta_2); + const __m128 vbeta_4 = _mm_load_ps(params->sse_rational_9_6.beta_4); + const __m128 vbeta_6 = _mm_load_ps(params->sse_rational_9_6.beta_6); +#else + const __m128 vbeta_0 = _mm_set1_ps(-9.023001417518e-03f); + const __m128 vbeta_2 = _mm_set1_ps(-4.154618829489e-03f); + const __m128 vbeta_4 = _mm_set1_ps(-2.061512641376e-04f); + const __m128 vbeta_6 = _mm_set1_ps(-1.774490101525e-06f); +#endif // XNN_ARCH_X86 + - __m128 vp0123 = _mm_add_ps(_mm_mul_ps(vc4, vt0123), vc3); - __m128 vp4567 = _mm_add_ps(_mm_mul_ps(vc4, vt4567), vc3); - __m128 vp89AB = _mm_add_ps(_mm_mul_ps(vc4, vt89AB), vc3); - __m128 vpCDEF = _mm_add_ps(_mm_mul_ps(vc4, vtCDEF), vc3); - vp0123 = _mm_add_ps(_mm_mul_ps(vp0123, vt0123), vc2); - vp4567 = _mm_add_ps(_mm_mul_ps(vp4567, vt4567), vc2); - vp89AB = _mm_add_ps(_mm_mul_ps(vp89AB, vt89AB), vc2); - vpCDEF = _mm_add_ps(_mm_mul_ps(vpCDEF, vtCDEF), vc2); - vp0123 = _mm_sub_ps(_mm_mul_ps(vp0123, vt0123), vminus_two); - vp4567 = _mm_sub_ps(_mm_mul_ps(vp4567, vt4567), vminus_two); - vp89AB = _mm_sub_ps(_mm_mul_ps(vp89AB, vt89AB), vminus_two); - vpCDEF = _mm_sub_ps(_mm_mul_ps(vpCDEF, vtCDEF), vminus_two); - - const __m128 vts0123 = _mm_mul_ps(vt0123, vs0123); - const __m128 vsmo0123 = _mm_add_ps(vs0123, vminus_one); - const __m128 vts4567 = _mm_mul_ps(vt4567, vs4567); - const __m128 vsmo4567 = _mm_add_ps(vs4567, vminus_one); - const __m128 vts89AB = _mm_mul_ps(vt89AB, vs89AB); - const __m128 vsmo89AB = _mm_add_ps(vs89AB, vminus_one); - const __m128 vtsCDEF = _mm_mul_ps(vtCDEF, vsCDEF); - const __m128 vsmoCDEF = _mm_add_ps(vsCDEF, vminus_one); - const __m128 vemo0123 = _mm_add_ps(_mm_mul_ps(vp0123, vts0123), vsmo0123); - const __m128 vemo4567 = _mm_add_ps(_mm_mul_ps(vp4567, vts4567), vsmo4567); - const __m128 vemo89AB = _mm_add_ps(_mm_mul_ps(vp89AB, vts89AB), vsmo89AB); - const __m128 vemoCDEF = _mm_add_ps(_mm_mul_ps(vpCDEF, vtsCDEF), vsmoCDEF); - - const __m128 vepo0123 = _mm_sub_ps(vemo0123, vminus_two); - const __m128 vepo4567 = _mm_sub_ps(vemo4567, vminus_two); - const __m128 vepo89AB = _mm_sub_ps(vemo89AB, vminus_two); - const __m128 vepoCDEF = _mm_sub_ps(vemoCDEF, vminus_two); - - __m128 vy0123 = _mm_div_ps(vemo0123, vepo0123); - __m128 vy4567 = _mm_div_ps(vemo4567, vepo4567); - __m128 
vy89AB = _mm_div_ps(vemo89AB, vepo89AB); - __m128 vyCDEF = _mm_div_ps(vemoCDEF, vepoCDEF); - - - vy0123 = _mm_xor_ps(vy0123, vinvsignx0123); - vy4567 = _mm_xor_ps(vy4567, vinvsignx4567); - vy89AB = _mm_xor_ps(vy89AB, vinvsignx89AB); - vyCDEF = _mm_xor_ps(vyCDEF, vinvsignxCDEF); + for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) { + __m128 vx_0 = _mm_loadu_ps(input); + __m128 vx_1 = _mm_loadu_ps(input + 4); + input += 8; - _mm_storeu_ps(output, vy0123); - _mm_storeu_ps(output + 4, vy4567); - _mm_storeu_ps(output + 8, vy89AB); - _mm_storeu_ps(output + 12, vyCDEF); - output += 16; + // Clamp the inputs to the interpolation range. + vx_0 = _mm_min_ps(vmax_x, vx_0); + vx_1 = _mm_min_ps(vmax_x, vx_1); + vx_0 = _mm_max_ps(vmin_x, vx_0); + vx_1 = _mm_max_ps(vmin_x, vx_1); + + // Since the polynomials are odd/even, we need x^2. + const __m128 vx2_0 = _mm_mul_ps(vx_0, vx_0); + const __m128 vx2_1 = _mm_mul_ps(vx_1, vx_1); + + // Evaluate the numerator polynomial p. + __m128 vp_0 = _mm_add_ps(_mm_mul_ps(vx2_0, valpha_9), valpha_7); + __m128 vp_1 = _mm_add_ps(_mm_mul_ps(vx2_1, valpha_9), valpha_7); + vp_0 = _mm_add_ps(_mm_mul_ps(vx2_0, vp_0), valpha_5); + vp_1 = _mm_add_ps(_mm_mul_ps(vx2_1, vp_1), valpha_5); + vp_0 = _mm_add_ps(_mm_mul_ps(vx2_0, vp_0), valpha_3); + vp_1 = _mm_add_ps(_mm_mul_ps(vx2_1, vp_1), valpha_3); + vp_0 = _mm_add_ps(_mm_mul_ps(vx2_0, vp_0), valpha_1); + vp_1 = _mm_add_ps(_mm_mul_ps(vx2_1, vp_1), valpha_1); + vp_0 = _mm_mul_ps(vx_0, vp_0); + vp_1 = _mm_mul_ps(vx_1, vp_1); + + // Evaluate the denominator polynomial q. + __m128 vq_0 = _mm_add_ps(_mm_mul_ps(vx2_0, vbeta_6), vbeta_4); + __m128 vq_1 = _mm_add_ps(_mm_mul_ps(vx2_1, vbeta_6), vbeta_4); + vq_0 = _mm_add_ps(_mm_mul_ps(vx2_0, vq_0), vbeta_2); + vq_1 = _mm_add_ps(_mm_mul_ps(vx2_1, vq_1), vbeta_2); + vq_0 = _mm_add_ps(_mm_mul_ps(vx2_0, vq_0), vbeta_0); + vq_1 = _mm_add_ps(_mm_mul_ps(vx2_1, vq_1), vbeta_0); + + // Divide the numerator by the denominator. 
+ const __m128 vy_0 = _mm_div_ps(vp_0, vq_0); + const __m128 vy_1 = _mm_div_ps(vp_1, vq_1); + + _mm_storeu_ps(output, vy_0); + _mm_storeu_ps(output + 4, vy_1); + output += 8; } for (; batch >= 4 * sizeof(float); batch -= 4 * sizeof(float)) { - const __m128 vx = _mm_loadu_ps(input); + __m128 vx = _mm_loadu_ps(input); input += 4; - __m128 vz = _mm_or_ps(vx, vsign_mask); - - const __m128 vinvsignx = _mm_xor_ps(vx, vz); - - vz = _mm_max_ps(vsat_cutoff, vz); - - __m128 vn = _mm_add_ps(_mm_mul_ps(vz, vlog2e), vmagic_bias); - - const __m128i ve = _mm_slli_epi32(_mm_castps_si128(vn), 20); - - #if XNN_ARCH_X86_64 - __m128i vidx = _mm_and_si128(_mm_castps_si128(vn), vindex_mask); - const uint64_t vidx_lo = (uint64_t) _mm_cvtsi128_si64(vidx); - vidx = _mm_unpackhi_epi64(vidx, vidx); - const uint64_t vidx_hi = (uint64_t) _mm_cvtsi128_si64(vidx); - const __m128i vl0 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx_lo]); - const __m128i vl1 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx_lo >> 32)]); - const __m128i vl_lo = _mm_unpacklo_epi32(vl0, vl1); - const __m128i vl2 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx_hi]); - const __m128i vl3 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx_hi >> 32)]); - const __m128i vl_hi = _mm_unpacklo_epi32(vl2, vl3); - #else - const __m128i vidx = _mm_and_si128(_mm_castps_si128(vn), vindex_mask); - const uint32_t vidx0 = (uint32_t) _mm_cvtsi128_si32(vidx); - const __m128i vl0 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx0]); - const uint32_t vidx1 = (uint32_t) _mm_extract_epi16(vidx, 2); - const __m128i vl1 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx1]); - const __m128i vl_lo = _mm_unpacklo_epi32(vl0, vl1); - const uint32_t vidx2 = (uint32_t) _mm_extract_epi16(vidx, 4); - const __m128i vl2 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx2]); - const uint32_t vidx3 = (uint32_t) _mm_extract_epi16(vidx, 6); - const __m128i vl3 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx3]); - const __m128i vl_hi = _mm_unpacklo_epi32(vl2, vl3); - #endif - const __m128i vl = _mm_unpacklo_epi64(vl_lo, vl_hi); - - const __m128 vs = _mm_castsi128_ps(_mm_add_epi32(vl, ve)); - - vn = _mm_sub_ps(vn, vmagic_bias); - - const __m128 vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2), vz); + // Clamp the inputs to the interpolation range. + vx = _mm_min_ps(vmax_x, vx); + vx = _mm_max_ps(vmin_x, vx); - __m128 vp = _mm_add_ps(_mm_mul_ps(vc4, vt), vc3); - vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc2); - vp = _mm_sub_ps(_mm_mul_ps(vp, vt), vminus_two); + // Since the polynomials are odd/even, we need x^2. + const __m128 vx2 = _mm_mul_ps(vx, vx); - const __m128 vts = _mm_mul_ps(vt, vs); - const __m128 vsmo = _mm_add_ps(vs, vminus_one); - const __m128 vemo = _mm_add_ps(_mm_mul_ps(vp, vts), vsmo); + // Evaluate the numerator polynomial p. + __m128 vp = _mm_add_ps(_mm_mul_ps(vx2, valpha_9), valpha_7); + vp = _mm_add_ps(_mm_mul_ps(vx2, vp), valpha_5); + vp = _mm_add_ps(_mm_mul_ps(vx2, vp), valpha_3); + vp = _mm_add_ps(_mm_mul_ps(vx2, vp), valpha_1); + vp = _mm_mul_ps(vx, vp); - const __m128 vepo = _mm_sub_ps(vemo, vminus_two); + // Evaluate the denominator polynomial q. + __m128 vq = _mm_add_ps(_mm_mul_ps(vx2, vbeta_6), vbeta_4); + vq = _mm_add_ps(_mm_mul_ps(vx2, vq), vbeta_2); + vq = _mm_add_ps(_mm_mul_ps(vx2, vq), vbeta_0); - __m128 vy = _mm_div_ps(vemo, vepo); - - - vy = _mm_xor_ps(vy, vinvsignx); + // Divide the numerator by the denominator. 
+ const __m128 vy = _mm_div_ps(vp, vq); _mm_storeu_ps(output, vy); output += 4; } if XNN_UNLIKELY(batch != 0) { - const __m128 vx = _mm_loadu_ps(input); - - __m128 vz = _mm_or_ps(vx, vsign_mask); - - const __m128 vinvsignx = _mm_xor_ps(vx, vz); - - vz = _mm_max_ps(vsat_cutoff, vz); - - __m128 vn = _mm_add_ps(_mm_mul_ps(vz, vlog2e), vmagic_bias); - - const __m128i ve = _mm_slli_epi32(_mm_castps_si128(vn), 20); - - #if XNN_ARCH_X86_64 - __m128i vidx = _mm_and_si128(_mm_castps_si128(vn), vindex_mask); - const uint64_t vidx_lo = (uint64_t) _mm_cvtsi128_si64(vidx); - vidx = _mm_unpackhi_epi64(vidx, vidx); - const uint64_t vidx_hi = (uint64_t) _mm_cvtsi128_si64(vidx); - const __m128i vl0 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx_lo]); - const __m128i vl1 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx_lo >> 32)]); - const __m128i vl_lo = _mm_unpacklo_epi32(vl0, vl1); - const __m128i vl2 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx_hi]); - const __m128i vl3 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx_hi >> 32)]); - const __m128i vl_hi = _mm_unpacklo_epi32(vl2, vl3); - #else - const __m128i vidx = _mm_and_si128(_mm_castps_si128(vn), vindex_mask); - const uint32_t vidx0 = (uint32_t) _mm_cvtsi128_si32(vidx); - const __m128i vl0 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx0]); - const uint32_t vidx1 = (uint32_t) _mm_extract_epi16(vidx, 2); - const __m128i vl1 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx1]); - const __m128i vl_lo = _mm_unpacklo_epi32(vl0, vl1); - const uint32_t vidx2 = (uint32_t) _mm_extract_epi16(vidx, 4); - const __m128i vl2 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx2]); - const uint32_t vidx3 = (uint32_t) _mm_extract_epi16(vidx, 6); - const __m128i vl3 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx3]); - const __m128i vl_hi = _mm_unpacklo_epi32(vl2, vl3); - #endif - const __m128i vl = _mm_unpacklo_epi64(vl_lo, vl_hi); - - const __m128 vs = _mm_castsi128_ps(_mm_add_epi32(vl, ve)); - - vn = _mm_sub_ps(vn, vmagic_bias); - - const __m128 vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2), vz); - - __m128 vp = _mm_add_ps(_mm_mul_ps(vc4, vt), vc3); - vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc2); - vp = _mm_sub_ps(_mm_mul_ps(vp, vt), vminus_two); + __m128 vx = _mm_loadu_ps(input); - const __m128 vts = _mm_mul_ps(vt, vs); - const __m128 vsmo = _mm_add_ps(vs, vminus_one); - const __m128 vemo = _mm_add_ps(_mm_mul_ps(vp, vts), vsmo); + // Clamp the inputs to the interpolation range. + vx = _mm_min_ps(vmax_x, vx); + vx = _mm_max_ps(vmin_x, vx); - const __m128 vepo = _mm_sub_ps(vemo, vminus_two); + // Since the polynomials are odd/even, we need x^2. + const __m128 vx2 = _mm_mul_ps(vx, vx); - __m128 vy = _mm_div_ps(vemo, vepo); + // Evaluate the numerator polynomial p. + __m128 vp = _mm_add_ps(_mm_mul_ps(vx2, valpha_9), valpha_7); + vp = _mm_add_ps(_mm_mul_ps(vx2, vp), valpha_5); + vp = _mm_add_ps(_mm_mul_ps(vx2, vp), valpha_3); + vp = _mm_add_ps(_mm_mul_ps(vx2, vp), valpha_1); + vp = _mm_mul_ps(vx, vp); + // Evaluate the denominator polynomial q. + __m128 vq = _mm_add_ps(_mm_mul_ps(vx2, vbeta_6), vbeta_4); + vq = _mm_add_ps(_mm_mul_ps(vx2, vq), vbeta_2); + vq = _mm_add_ps(_mm_mul_ps(vx2, vq), vbeta_0); - vy = _mm_xor_ps(vy, vinvsignx); + // Divide the numerator by the denominator. 
+ __m128 vy = _mm_div_ps(vp, vq); if (batch & (2 * sizeof(float))) { _mm_storel_pi((__m64*) output, vy); diff --git a/src/amalgam/gen/sse41.c b/src/amalgam/gen/sse41.c index b292c7fd859..3437d15f684 100644 --- a/src/amalgam/gen/sse41.c +++ b/src/amalgam/gen/sse41.c @@ -4,8 +4,6 @@ // LICENSE file in the root directory of this source tree. #include -#include -#include #include @@ -18,7 +16,6 @@ #include #include #include -#include #include #include #include @@ -1976,370 +1973,6 @@ void xnn_f32_vsigmoid_ukernel__sse41_rr2_lut64_p2_div_u8( } } -extern XNN_INTERNAL const uint32_t xnn_table_exp2minus_k_over_8[8]; - -void xnn_f32_vtanh_ukernel__sse41_expm1minus_rr1_lut8_p4h3ts_div_u20( - size_t batch, - const float* input, - float* output, - const union xnn_f32_tanh_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - - const __m128 vsign_mask = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.sign_mask); - const __m128 vsat_cutoff = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.sat_cutoff); - const __m128 vlog2e = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.log2e); - const __m128 vmagic_bias = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.magic_bias); - const __m128i vindex_mask = _mm_load_si128((const __m128i*) params->sse_expm1minus_rr1_lut8_p4h3.index_mask); - const __m128 vminus_ln2 = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.minus_ln2); - const __m128 vc4 = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.c4); - const __m128 vc3 = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.c3); - const __m128 vc2 = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.c2); - const __m128 vminus_two = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.minus_two); - const __m128 vminus_one = _mm_load_ps(params->sse_expm1minus_rr1_lut8_p4h3.minus_one); - - for (; batch >= 20 * sizeof(float); batch -= 20 * sizeof(float)) { - const __m128 vx0123 = _mm_loadu_ps(input); - const __m128 vx4567 = _mm_loadu_ps(input + 4); - const __m128 vx89AB = _mm_loadu_ps(input + 8); - const __m128 vxCDEF = _mm_loadu_ps(input + 12); - const __m128 vxGHIJ = _mm_loadu_ps(input + 16); - input += 20; - - __m128 vz0123 = _mm_or_ps(vx0123, vsign_mask); - __m128 vz4567 = _mm_or_ps(vx4567, vsign_mask); - __m128 vz89AB = _mm_or_ps(vx89AB, vsign_mask); - __m128 vzCDEF = _mm_or_ps(vxCDEF, vsign_mask); - __m128 vzGHIJ = _mm_or_ps(vxGHIJ, vsign_mask); - - const __m128 vinvsignx0123 = _mm_xor_ps(vx0123, vz0123); - const __m128 vinvsignx4567 = _mm_xor_ps(vx4567, vz4567); - const __m128 vinvsignx89AB = _mm_xor_ps(vx89AB, vz89AB); - const __m128 vinvsignxCDEF = _mm_xor_ps(vxCDEF, vzCDEF); - const __m128 vinvsignxGHIJ = _mm_xor_ps(vxGHIJ, vzGHIJ); - - vz0123 = _mm_max_ps(vsat_cutoff, vz0123); - vz4567 = _mm_max_ps(vsat_cutoff, vz4567); - vz89AB = _mm_max_ps(vsat_cutoff, vz89AB); - vzCDEF = _mm_max_ps(vsat_cutoff, vzCDEF); - vzGHIJ = _mm_max_ps(vsat_cutoff, vzGHIJ); - - __m128 vn0123 = _mm_add_ps(_mm_mul_ps(vz0123, vlog2e), vmagic_bias); - __m128 vn4567 = _mm_add_ps(_mm_mul_ps(vz4567, vlog2e), vmagic_bias); - __m128 vn89AB = _mm_add_ps(_mm_mul_ps(vz89AB, vlog2e), vmagic_bias); - __m128 vnCDEF = _mm_add_ps(_mm_mul_ps(vzCDEF, vlog2e), vmagic_bias); - __m128 vnGHIJ = _mm_add_ps(_mm_mul_ps(vzGHIJ, vlog2e), vmagic_bias); - - const __m128i ve0123 = _mm_slli_epi32(_mm_castps_si128(vn0123), 20); - const __m128i ve4567 = _mm_slli_epi32(_mm_castps_si128(vn4567), 20); - const __m128i ve89AB = 
_mm_slli_epi32(_mm_castps_si128(vn89AB), 20); - const __m128i veCDEF = _mm_slli_epi32(_mm_castps_si128(vnCDEF), 20); - const __m128i veGHIJ = _mm_slli_epi32(_mm_castps_si128(vnGHIJ), 20); - - #if XNN_ARCH_X86_64 - __m128i vidx0123 = _mm_and_si128(_mm_castps_si128(vn0123), vindex_mask); - __m128i vidx4567 = _mm_and_si128(_mm_castps_si128(vn4567), vindex_mask); - __m128i vidx89AB = _mm_and_si128(_mm_castps_si128(vn89AB), vindex_mask); - __m128i vidxCDEF = _mm_and_si128(_mm_castps_si128(vnCDEF), vindex_mask); - __m128i vidxGHIJ = _mm_and_si128(_mm_castps_si128(vnGHIJ), vindex_mask); - - const uint64_t vidx01 = (uint64_t) _mm_cvtsi128_si64(vidx0123); - vidx0123 = _mm_unpackhi_epi64(vidx0123, vidx0123); - const uint64_t vidx45 = (uint64_t) _mm_cvtsi128_si64(vidx4567); - vidx4567 = _mm_unpackhi_epi64(vidx4567, vidx4567); - const uint64_t vidx89 = (uint64_t) _mm_cvtsi128_si64(vidx89AB); - vidx89AB = _mm_unpackhi_epi64(vidx89AB, vidx89AB); - const uint64_t vidxCD = (uint64_t) _mm_cvtsi128_si64(vidxCDEF); - vidxCDEF = _mm_unpackhi_epi64(vidxCDEF, vidxCDEF); - const uint64_t vidxGH = (uint64_t) _mm_cvtsi128_si64(vidxGHIJ); - vidxGHIJ = _mm_unpackhi_epi64(vidxGHIJ, vidxGHIJ); - - const uint64_t vidx23 = (uint64_t) _mm_cvtsi128_si64(vidx0123); - const uint64_t vidx67 = (uint64_t) _mm_cvtsi128_si64(vidx4567); - const uint64_t vidxAB = (uint64_t) _mm_cvtsi128_si64(vidx89AB); - const uint64_t vidxEF = (uint64_t) _mm_cvtsi128_si64(vidxCDEF); - const uint64_t vidxIJ = (uint64_t) _mm_cvtsi128_si64(vidxGHIJ); - - __m128i vl0123 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx01]); - vl0123 = _mm_insert_epi32(vl0123, (int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx01 >> 32)], 1); - vl0123 = _mm_insert_epi32(vl0123, (int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx23], 2); - vl0123 = _mm_insert_epi32(vl0123, (int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx23 >> 32)], 3); - __m128i vl4567 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx45]); - vl4567 = _mm_insert_epi32(vl4567, (int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx45 >> 32)], 1); - vl4567 = _mm_insert_epi32(vl4567, (int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx67], 2); - vl4567 = _mm_insert_epi32(vl4567, (int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx67 >> 32)], 3); - __m128i vl89AB = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx89]); - vl89AB = _mm_insert_epi32(vl89AB, (int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx89 >> 32)], 1); - vl89AB = _mm_insert_epi32(vl89AB, (int) xnn_table_exp2minus_k_over_8[(uint32_t) vidxAB], 2); - vl89AB = _mm_insert_epi32(vl89AB, (int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidxAB >> 32)], 3); - __m128i vlCDEF = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidxCD]); - vlCDEF = _mm_insert_epi32(vlCDEF, (int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidxCD >> 32)], 1); - vlCDEF = _mm_insert_epi32(vlCDEF, (int) xnn_table_exp2minus_k_over_8[(uint32_t) vidxEF], 2); - vlCDEF = _mm_insert_epi32(vlCDEF, (int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidxEF >> 32)], 3); - __m128i vlGHIJ = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidxGH]); - vlGHIJ = _mm_insert_epi32(vlGHIJ, (int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidxGH >> 32)], 1); - vlGHIJ = _mm_insert_epi32(vlGHIJ, (int) xnn_table_exp2minus_k_over_8[(uint32_t) vidxIJ], 2); - vlGHIJ = _mm_insert_epi32(vlGHIJ, (int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidxIJ >> 32)], 3); - #else - const __m128i vidx0123 = 
_mm_and_si128(_mm_castps_si128(vn0123), vindex_mask); - const __m128i vidx4567 = _mm_and_si128(_mm_castps_si128(vn4567), vindex_mask); - const __m128i vidx89AB = _mm_and_si128(_mm_castps_si128(vn89AB), vindex_mask); - const __m128i vidxCDEF = _mm_and_si128(_mm_castps_si128(vnCDEF), vindex_mask); - const __m128i vidxGHIJ = _mm_and_si128(_mm_castps_si128(vnGHIJ), vindex_mask); - - const uint32_t vidx0 = (uint32_t) _mm_cvtsi128_si32(vidx0123); - const uint32_t vidx4 = (uint32_t) _mm_cvtsi128_si32(vidx4567); - const uint32_t vidx8 = (uint32_t) _mm_cvtsi128_si32(vidx89AB); - const uint32_t vidxC = (uint32_t) _mm_cvtsi128_si32(vidxCDEF); - const uint32_t vidxG = (uint32_t) _mm_cvtsi128_si32(vidxGHIJ); - - __m128i vl0123 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx0]); - __m128i vl4567 = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx4]); - __m128i vl89AB = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx8]); - __m128i vlCDEF = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidxC]); - __m128i vlGHIJ = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidxG]); - - const uint32_t vidx1 = (uint32_t) _mm_extract_epi16(vidx0123, 2); - const uint32_t vidx5 = (uint32_t) _mm_extract_epi16(vidx4567, 2); - const uint32_t vidx9 = (uint32_t) _mm_extract_epi16(vidx89AB, 2); - const uint32_t vidxD = (uint32_t) _mm_extract_epi16(vidxCDEF, 2); - const uint32_t vidxH = (uint32_t) _mm_extract_epi16(vidxGHIJ, 2); - - vl0123 = _mm_insert_epi32(vl0123, (int) xnn_table_exp2minus_k_over_8[vidx1], 1); - vl4567 = _mm_insert_epi32(vl4567, (int) xnn_table_exp2minus_k_over_8[vidx5], 1); - vl89AB = _mm_insert_epi32(vl89AB, (int) xnn_table_exp2minus_k_over_8[vidx9], 1); - vlCDEF = _mm_insert_epi32(vlCDEF, (int) xnn_table_exp2minus_k_over_8[vidxD], 1); - vlGHIJ = _mm_insert_epi32(vlGHIJ, (int) xnn_table_exp2minus_k_over_8[vidxH], 1); - - const uint32_t vidx2 = (uint32_t) _mm_extract_epi16(vidx0123, 4); - const uint32_t vidx6 = (uint32_t) _mm_extract_epi16(vidx4567, 4); - const uint32_t vidxA = (uint32_t) _mm_extract_epi16(vidx89AB, 4); - const uint32_t vidxE = (uint32_t) _mm_extract_epi16(vidxCDEF, 4); - const uint32_t vidxI = (uint32_t) _mm_extract_epi16(vidxGHIJ, 4); - - vl0123 = _mm_insert_epi32(vl0123, (int) xnn_table_exp2minus_k_over_8[vidx2], 2); - vl4567 = _mm_insert_epi32(vl4567, (int) xnn_table_exp2minus_k_over_8[vidx6], 2); - vl89AB = _mm_insert_epi32(vl89AB, (int) xnn_table_exp2minus_k_over_8[vidxA], 2); - vlCDEF = _mm_insert_epi32(vlCDEF, (int) xnn_table_exp2minus_k_over_8[vidxE], 2); - vlGHIJ = _mm_insert_epi32(vlGHIJ, (int) xnn_table_exp2minus_k_over_8[vidxI], 2); - - const uint32_t vidx3 = (uint32_t) _mm_extract_epi16(vidx0123, 6); - const uint32_t vidx7 = (uint32_t) _mm_extract_epi16(vidx4567, 6); - const uint32_t vidxB = (uint32_t) _mm_extract_epi16(vidx89AB, 6); - const uint32_t vidxF = (uint32_t) _mm_extract_epi16(vidxCDEF, 6); - const uint32_t vidxJ = (uint32_t) _mm_extract_epi16(vidxGHIJ, 6); - - vl0123 = _mm_insert_epi32(vl0123, (int) xnn_table_exp2minus_k_over_8[vidx3], 3); - vl4567 = _mm_insert_epi32(vl4567, (int) xnn_table_exp2minus_k_over_8[vidx7], 3); - vl89AB = _mm_insert_epi32(vl89AB, (int) xnn_table_exp2minus_k_over_8[vidxB], 3); - vlCDEF = _mm_insert_epi32(vlCDEF, (int) xnn_table_exp2minus_k_over_8[vidxF], 3); - vlGHIJ = _mm_insert_epi32(vlGHIJ, (int) xnn_table_exp2minus_k_over_8[vidxJ], 3); - #endif - - const __m128 vs0123 = _mm_castsi128_ps(_mm_add_epi32(vl0123, ve0123)); - const __m128 vs4567 = _mm_castsi128_ps(_mm_add_epi32(vl4567, ve4567)); - const 
__m128 vs89AB = _mm_castsi128_ps(_mm_add_epi32(vl89AB, ve89AB)); - const __m128 vsCDEF = _mm_castsi128_ps(_mm_add_epi32(vlCDEF, veCDEF)); - const __m128 vsGHIJ = _mm_castsi128_ps(_mm_add_epi32(vlGHIJ, veGHIJ)); - - vn0123 = _mm_sub_ps(vn0123, vmagic_bias); - vn4567 = _mm_sub_ps(vn4567, vmagic_bias); - vn89AB = _mm_sub_ps(vn89AB, vmagic_bias); - vnCDEF = _mm_sub_ps(vnCDEF, vmagic_bias); - vnGHIJ = _mm_sub_ps(vnGHIJ, vmagic_bias); - - const __m128 vt0123 = _mm_add_ps(_mm_mul_ps(vn0123, vminus_ln2), vz0123); - const __m128 vt4567 = _mm_add_ps(_mm_mul_ps(vn4567, vminus_ln2), vz4567); - const __m128 vt89AB = _mm_add_ps(_mm_mul_ps(vn89AB, vminus_ln2), vz89AB); - const __m128 vtCDEF = _mm_add_ps(_mm_mul_ps(vnCDEF, vminus_ln2), vzCDEF); - const __m128 vtGHIJ = _mm_add_ps(_mm_mul_ps(vnGHIJ, vminus_ln2), vzGHIJ); - - __m128 vp0123 = _mm_add_ps(_mm_mul_ps(vc4, vt0123), vc3); - __m128 vp4567 = _mm_add_ps(_mm_mul_ps(vc4, vt4567), vc3); - __m128 vp89AB = _mm_add_ps(_mm_mul_ps(vc4, vt89AB), vc3); - __m128 vpCDEF = _mm_add_ps(_mm_mul_ps(vc4, vtCDEF), vc3); - __m128 vpGHIJ = _mm_add_ps(_mm_mul_ps(vc4, vtGHIJ), vc3); - vp0123 = _mm_add_ps(_mm_mul_ps(vp0123, vt0123), vc2); - vp4567 = _mm_add_ps(_mm_mul_ps(vp4567, vt4567), vc2); - vp89AB = _mm_add_ps(_mm_mul_ps(vp89AB, vt89AB), vc2); - vpCDEF = _mm_add_ps(_mm_mul_ps(vpCDEF, vtCDEF), vc2); - vpGHIJ = _mm_add_ps(_mm_mul_ps(vpGHIJ, vtGHIJ), vc2); - vp0123 = _mm_sub_ps(_mm_mul_ps(vp0123, vt0123), vminus_two); - vp4567 = _mm_sub_ps(_mm_mul_ps(vp4567, vt4567), vminus_two); - vp89AB = _mm_sub_ps(_mm_mul_ps(vp89AB, vt89AB), vminus_two); - vpCDEF = _mm_sub_ps(_mm_mul_ps(vpCDEF, vtCDEF), vminus_two); - vpGHIJ = _mm_sub_ps(_mm_mul_ps(vpGHIJ, vtGHIJ), vminus_two); - - const __m128 vts0123 = _mm_mul_ps(vt0123, vs0123); - const __m128 vsmo0123 = _mm_add_ps(vs0123, vminus_one); - const __m128 vts4567 = _mm_mul_ps(vt4567, vs4567); - const __m128 vsmo4567 = _mm_add_ps(vs4567, vminus_one); - const __m128 vts89AB = _mm_mul_ps(vt89AB, vs89AB); - const __m128 vsmo89AB = _mm_add_ps(vs89AB, vminus_one); - const __m128 vtsCDEF = _mm_mul_ps(vtCDEF, vsCDEF); - const __m128 vsmoCDEF = _mm_add_ps(vsCDEF, vminus_one); - const __m128 vtsGHIJ = _mm_mul_ps(vtGHIJ, vsGHIJ); - const __m128 vsmoGHIJ = _mm_add_ps(vsGHIJ, vminus_one); - const __m128 vemo0123 = _mm_add_ps(_mm_mul_ps(vp0123, vts0123), vsmo0123); - const __m128 vemo4567 = _mm_add_ps(_mm_mul_ps(vp4567, vts4567), vsmo4567); - const __m128 vemo89AB = _mm_add_ps(_mm_mul_ps(vp89AB, vts89AB), vsmo89AB); - const __m128 vemoCDEF = _mm_add_ps(_mm_mul_ps(vpCDEF, vtsCDEF), vsmoCDEF); - const __m128 vemoGHIJ = _mm_add_ps(_mm_mul_ps(vpGHIJ, vtsGHIJ), vsmoGHIJ); - - const __m128 vepo0123 = _mm_sub_ps(vemo0123, vminus_two); - const __m128 vepo4567 = _mm_sub_ps(vemo4567, vminus_two); - const __m128 vepo89AB = _mm_sub_ps(vemo89AB, vminus_two); - const __m128 vepoCDEF = _mm_sub_ps(vemoCDEF, vminus_two); - const __m128 vepoGHIJ = _mm_sub_ps(vemoGHIJ, vminus_two); - - __m128 vy0123 = _mm_div_ps(vemo0123, vepo0123); - __m128 vy4567 = _mm_div_ps(vemo4567, vepo4567); - __m128 vy89AB = _mm_div_ps(vemo89AB, vepo89AB); - __m128 vyCDEF = _mm_div_ps(vemoCDEF, vepoCDEF); - __m128 vyGHIJ = _mm_div_ps(vemoGHIJ, vepoGHIJ); - - - vy0123 = _mm_xor_ps(vy0123, vinvsignx0123); - vy4567 = _mm_xor_ps(vy4567, vinvsignx4567); - vy89AB = _mm_xor_ps(vy89AB, vinvsignx89AB); - vyCDEF = _mm_xor_ps(vyCDEF, vinvsignxCDEF); - vyGHIJ = _mm_xor_ps(vyGHIJ, vinvsignxGHIJ); - - _mm_storeu_ps(output, vy0123); - _mm_storeu_ps(output + 4, vy4567); - _mm_storeu_ps(output + 8, vy89AB); - 
_mm_storeu_ps(output + 12, vyCDEF); - _mm_storeu_ps(output + 16, vyGHIJ); - output += 20; - } - for (; batch >= 4 * sizeof(float); batch -= 4 * sizeof(float)) { - const __m128 vx = _mm_loadu_ps(input); - input += 4; - - __m128 vz = _mm_or_ps(vx, vsign_mask); - - const __m128 vinvsignx = _mm_xor_ps(vx, vz); - - vz = _mm_max_ps(vsat_cutoff, vz); - - __m128 vn = _mm_add_ps(_mm_mul_ps(vz, vlog2e), vmagic_bias); - - const __m128i ve = _mm_slli_epi32(_mm_castps_si128(vn), 20); - - #if XNN_ARCH_X86_64 - __m128i vidx = _mm_and_si128(_mm_castps_si128(vn), vindex_mask); - const uint64_t vidx_lo = (uint64_t) _mm_cvtsi128_si64(vidx); - vidx = _mm_unpackhi_epi64(vidx, vidx); - const uint64_t vidx_hi = (uint64_t) _mm_cvtsi128_si64(vidx); - __m128i vl = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx_lo]); - vl = _mm_insert_epi32(vl, (int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx_lo >> 32)], 1); - vl = _mm_insert_epi32(vl, (int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx_hi], 2); - vl = _mm_insert_epi32(vl, (int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx_hi >> 32)], 3); - #else - const __m128i vidx = _mm_and_si128(_mm_castps_si128(vn), vindex_mask); - const uint32_t vidx0 = (uint32_t) _mm_cvtsi128_si32(vidx); - __m128i vl = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[vidx0]); - const uint32_t vidx1 = (uint32_t) _mm_extract_epi16(vidx, 2); - vl = _mm_insert_epi32(vl, (int) xnn_table_exp2minus_k_over_8[vidx1], 1); - const uint32_t vidx2 = (uint32_t) _mm_extract_epi16(vidx, 4); - vl = _mm_insert_epi32(vl, (int) xnn_table_exp2minus_k_over_8[vidx2], 2); - const uint32_t vidx3 = (uint32_t) _mm_extract_epi16(vidx, 6); - vl = _mm_insert_epi32(vl, (int) xnn_table_exp2minus_k_over_8[vidx3], 3); - #endif - - const __m128 vs = _mm_castsi128_ps(_mm_add_epi32(vl, ve)); - - vn = _mm_sub_ps(vn, vmagic_bias); - - const __m128 vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2), vz); - - __m128 vp = _mm_add_ps(_mm_mul_ps(vc4, vt), vc3); - vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc2); - vp = _mm_sub_ps(_mm_mul_ps(vp, vt), vminus_two); - - const __m128 vts = _mm_mul_ps(vt, vs); - const __m128 vsmo = _mm_add_ps(vs, vminus_one); - const __m128 vemo = _mm_add_ps(_mm_mul_ps(vp, vts), vsmo); - - const __m128 vepo = _mm_sub_ps(vemo, vminus_two); - - __m128 vy = _mm_div_ps(vemo, vepo); - - - vy = _mm_xor_ps(vy, vinvsignx); - - _mm_storeu_ps(output, vy); - output += 4; - } - if XNN_UNLIKELY(batch != 0) { - const __m128 vx = _mm_loadu_ps(input); - - __m128 vz = _mm_or_ps(vx, vsign_mask); - - const __m128 vinvsignx = _mm_xor_ps(vx, vz); - - vz = _mm_max_ps(vsat_cutoff, vz); - - __m128 vn = _mm_add_ps(_mm_mul_ps(vz, vlog2e), vmagic_bias); - - const __m128i ve = _mm_slli_epi32(_mm_castps_si128(vn), 20); - - #if XNN_ARCH_X86_64 - __m128i vidx = _mm_and_si128(_mm_castps_si128(vn), vindex_mask); - const uint64_t vidx_lo = (uint64_t) _mm_cvtsi128_si64(vidx); - vidx = _mm_unpackhi_epi64(vidx, vidx); - const uint64_t vidx_hi = (uint64_t) _mm_cvtsi128_si64(vidx); - __m128i vl = _mm_cvtsi32_si128((int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx_lo]); - vl = _mm_insert_epi32(vl, (int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx_lo >> 32)], 1); - vl = _mm_insert_epi32(vl, (int) xnn_table_exp2minus_k_over_8[(uint32_t) vidx_hi], 2); - vl = _mm_insert_epi32(vl, (int) xnn_table_exp2minus_k_over_8[(uint32_t) (vidx_hi >> 32)], 3); - #else - const __m128i vidx = _mm_and_si128(_mm_castps_si128(vn), vindex_mask); - const uint32_t vidx0 = (uint32_t) _mm_cvtsi128_si32(vidx); - __m128i vl = _mm_cvtsi32_si128((int) 
xnn_table_exp2minus_k_over_8[vidx0]); - const uint32_t vidx1 = (uint32_t) _mm_extract_epi16(vidx, 2); - vl = _mm_insert_epi32(vl, (int) xnn_table_exp2minus_k_over_8[vidx1], 1); - const uint32_t vidx2 = (uint32_t) _mm_extract_epi16(vidx, 4); - vl = _mm_insert_epi32(vl, (int) xnn_table_exp2minus_k_over_8[vidx2], 2); - const uint32_t vidx3 = (uint32_t) _mm_extract_epi16(vidx, 6); - vl = _mm_insert_epi32(vl, (int) xnn_table_exp2minus_k_over_8[vidx3], 3); - #endif - - const __m128 vs = _mm_castsi128_ps(_mm_add_epi32(vl, ve)); - - vn = _mm_sub_ps(vn, vmagic_bias); - - const __m128 vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2), vz); - - __m128 vp = _mm_add_ps(_mm_mul_ps(vc4, vt), vc3); - vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc2); - vp = _mm_sub_ps(_mm_mul_ps(vp, vt), vminus_two); - - const __m128 vts = _mm_mul_ps(vt, vs); - const __m128 vsmo = _mm_add_ps(vs, vminus_one); - const __m128 vemo = _mm_add_ps(_mm_mul_ps(vp, vts), vsmo); - - const __m128 vepo = _mm_sub_ps(vemo, vminus_two); - - __m128 vy = _mm_div_ps(vemo, vepo); - - - vy = _mm_xor_ps(vy, vinvsignx); - - if (batch & (2 * sizeof(float))) { - _mm_storel_pi((__m64*) output, vy); - vy = _mm_movehl_ps(vy, vy); - output += 2; - } - if (batch & (1 * sizeof(float))) { - _mm_store_ss(output, vy); - } - } -} - void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse41_ld128( size_t mr, size_t nc, diff --git a/src/configs/unary-elementwise-config.c b/src/configs/unary-elementwise-config.c index 2a0339ec54f..1fe8093968f 100644 --- a/src/configs/unary-elementwise-config.c +++ b/src/configs/unary-elementwise-config.c @@ -1452,9 +1452,9 @@ static void init_f32_tanh_config(void) { f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_neon_expm1minus_rr1_p6h5_params; f32_tanh_config.element_tile = 8; } else if (!XNN_PLATFORM_MOBILE) { - f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__scalar_expm1minus_rr1_lut8_p4h3ts_div_u4; - f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_scalar_expm1minus_rr1_lut8_p4h3_params; - f32_tanh_config.element_tile = 4; + f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__scalar_rational_9_6_u1; + f32_tanh_config.init.f32_tanh = NULL; + f32_tanh_config.element_tile = 1; } #elif XNN_ARCH_ARM64 f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__aarch64_neonfma_expm1minus_rr1_p6h5ts_div_u16; @@ -1464,29 +1464,21 @@ static void init_f32_tanh_config(void) { const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512skx) { - f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__avx512skx_expm1minus_rr1_lut4_p4h3ts_perm_div_u64; - f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_avx512_expm1minus_rr1_lut4_p4h3_perm_params; - f32_tanh_config.element_tile = 64; - } else if (hardware_config->use_x86_avx2) { - f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__avx2_expm1minus_rr1_lut4_p4h3ts_perm_div_u32; - f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_avx_expm1minus_rr1_lut4_p4h3_perm_params; - f32_tanh_config.element_tile = 32; + f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__avx512f_rational_9_6_nr_u16; + f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_avx512_rational_9_6_params; + f32_tanh_config.element_tile = 16; } else if (hardware_config->use_x86_fma3) { - f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__fma3_expm1minus_rr1_lut4_p4h3ts_perm_div_u40; - 
f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_avx_expm1minus_rr1_lut4_p4h2_perm_params; - f32_tanh_config.element_tile = 40; + f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__fma3_rational_9_6_div_u16; + f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_fma3_rational_9_6_params; + f32_tanh_config.element_tile = 16; } else if (hardware_config->use_x86_avx) { - f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__avx_expm1minus_rr1_lut4_p4h2ts_perm_div_u48; - f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_avx_expm1minus_rr1_lut4_p4h2_perm_params; - f32_tanh_config.element_tile = 48; - } else if (hardware_config->use_x86_sse4_1) { - f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__sse41_expm1minus_rr1_lut8_p4h3ts_div_u20; - f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_sse_expm1minus_rr1_lut8_p4h3_params; - f32_tanh_config.element_tile = 20; - } else { - f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__sse2_expm1minus_rr1_lut8_p4h3ts_div_u16; - f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_sse_expm1minus_rr1_lut8_p4h3_params; + f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__avx_rational_9_6_div_u16; + f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_avx_rational_9_6_params; f32_tanh_config.element_tile = 16; + } else { + f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__sse2_rational_9_6_div_u8; + f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_sse_rational_9_6_params; + f32_tanh_config.element_tile = 8; } #elif XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); @@ -1504,9 +1496,9 @@ static void init_f32_tanh_config(void) { const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); if (hardware_config->is_x86) { - f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__scalar_expm1minus_rr1_lut8_p4h3ts_div_u4; - f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_scalar_expm1minus_rr1_lut8_p4h3_params; - f32_tanh_config.element_tile = 4; + f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__scalar_rational_9_6_u1; + f32_tanh_config.init.f32_tanh = NULL; + f32_tanh_config.element_tile = 1; } else { f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__wasm_expm1minus_rr1_p6h5ts_div_u4; f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_scalar_expm1minus_rr1_p6h5_params; @@ -1517,9 +1509,9 @@ static void init_f32_tanh_config(void) { f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_scalar_expm1minus_rr1_lut8_p4h3_params; f32_tanh_config.element_tile = 4; #else - f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__scalar_expm1minus_rr1_lut8_p4h3ts_div_u4; - f32_tanh_config.init.f32_tanh = xnn_init_f32_tanh_scalar_expm1minus_rr1_lut8_p4h3_params; - f32_tanh_config.element_tile = 4; + f32_tanh_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vtanh_ukernel__scalar_rational_9_6_u1; + f32_tanh_config.init.f32_tanh = NULL; + f32_tanh_config.element_tile = 1; #endif }
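For reference, below is a minimal standalone sketch of the rational 9/6 approximation these kernels implement. The clamp bound (7.623543739319f) and the alpha/beta coefficients are copied verbatim from the new scalar kernel in this patch; the helper name tanh_rational_9_6 and the small main() harness comparing against libm's tanhf are illustrative additions and are not part of this change.

#include <math.h>
#include <stddef.h>
#include <stdio.h>

// Scalar reference of the rational 9/6 tanh approximation (degree-9 odd
// numerator, degree-6 even denominator), evaluated with Horner's scheme in
// the same order as xnn_f32_vtanh_ukernel__scalar_rational_9_6_u1.
static float tanh_rational_9_6(float x) {
  // Clamp to the interpolation range; per the kernel comments, tanh(x) is
  // +/-1.0f in float32 at and beyond this point.
  const float max_x = 7.623543739319f;
  const float min_x = -7.623543739319f;
  x = max_x < x ? max_x : x;
  x = x < min_x ? min_x : x;

  // The monomial coefficients of the numerator polynomial (odd).
  const float alpha_1 = -9.022999554873e-03f;
  const float alpha_3 = -1.146968104877e-03f;
  const float alpha_5 = -2.432360815874e-05f;
  const float alpha_7 = -6.458659385089e-08f;
  const float alpha_9 = 5.535878699892e-11f;

  // The monomial coefficients of the denominator polynomial (even).
  const float beta_0 = -9.023001417518e-03f;
  const float beta_2 = -4.154618829489e-03f;
  const float beta_4 = -2.061512641376e-04f;
  const float beta_6 = -1.774490101525e-06f;

  // Since the polynomials are odd/even, only x^2 is needed.
  const float x2 = x * x;

  // Numerator p(x) = x * (alpha_1 + alpha_3*x^2 + ... + alpha_9*x^8).
  float p = x2 * alpha_9 + alpha_7;
  p = x2 * p + alpha_5;
  p = x2 * p + alpha_3;
  p = x2 * p + alpha_1;
  p = x * p;

  // Denominator q(x) = beta_0 + beta_2*x^2 + beta_4*x^4 + beta_6*x^6.
  float q = x2 * beta_6 + beta_4;
  q = x2 * q + beta_2;
  q = x2 * q + beta_0;

  return p / q;
}

int main(void) {
  // Spot-check the approximation against tanhf at a few points.
  const float samples[] = {-10.0f, -3.0f, -0.5f, 0.0f, 0.5f, 3.0f, 10.0f};
  for (size_t i = 0; i < sizeof(samples) / sizeof(samples[0]); i++) {
    printf("x=%+6.2f  rational_9_6=%+.7f  tanhf=%+.7f\n",
           samples[i], tanh_rational_9_6(samples[i]), tanhf(samples[i]));
  }
  return 0;
}

Because the numerator is odd in x and the denominator even, p/q already carries the sign of x, which is why the new kernels no longer need the sign-mask/xor handling of the replaced expm1minus variants.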