Switch to the new rational_9_6 microkernels for f32-vtanh.
PiperOrigin-RevId: 629396487
gonnet authored and xnnpack-bot committed May 6, 2024
1 parent 7189ed9 commit 5be5be8
Showing 9 changed files with 492 additions and 1,762 deletions.
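The new kernels evaluate tanh(x) as a rational function whose numerator is an odd polynomial of degree 9 and whose denominator is an even polynomial of degree 6 (hence "rational_9_6"), as seen in the avx512f.c hunk below. A minimal scalar sketch of that scheme, assuming placeholder coefficient arrays that mirror the alpha_*/beta_* fields of xnn_f32_tanh_params (the actual values are set elsewhere in XNNPACK):

#include <math.h>

// Scalar sketch of the rational 9/6 tanh approximation; the coefficients and
// the clamping threshold are illustrative parameters, not the library's values.
static float tanh_rational_9_6(float x,
                               const float alpha[5],  // alpha_1, alpha_3, ..., alpha_9
                               const float beta[4],   // beta_0, beta_2, beta_4, beta_6
                               float max_abs_x) {
  // Clamp to the interpolation range; tanh is +/-1.0f beyond it.
  x = fminf(fmaxf(x, -max_abs_x), max_abs_x);
  const float x2 = x * x;
  // Odd numerator p(x) = x * (alpha_1 + alpha_3*x^2 + ... + alpha_9*x^8), in Horner form.
  float p = alpha[4];
  p = p * x2 + alpha[3];
  p = p * x2 + alpha[2];
  p = p * x2 + alpha[1];
  p = p * x2 + alpha[0];
  p *= x;
  // Even denominator q(x) = beta_0 + beta_2*x^2 + beta_4*x^4 + beta_6*x^6.
  float q = beta[3];
  q = q * x2 + beta[2];
  q = q * x2 + beta[1];
  q = q * x2 + beta[0];
  return p / q;
}

The vectorized kernel below replaces the final division with an approximate reciprocal refined by one Newton-Raphson step (see the avx512f.c addition).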
350 changes: 100 additions & 250 deletions src/amalgam/gen/avx.c

Large diffs are not rendered by default.

219 changes: 0 additions & 219 deletions src/amalgam/gen/avx2.c
@@ -4,7 +4,6 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

@@ -17,7 +16,6 @@
#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/lut.h>
#include <xnnpack/math.h>
#include <xnnpack/microparams.h>
#include <xnnpack/packw.h>
#include <xnnpack/pavgpool.h>
#include <xnnpack/prefetch.h>
@@ -2972,223 +2970,6 @@ void xnn_f32_vsigmoid_ukernel__avx2_rr1_p5_div_u40(
}
}

void xnn_f32_vtanh_ukernel__avx2_expm1minus_rr1_lut4_p4h3ts_perm_div_u32(
size_t batch,
const float* input,
float* output,
const union xnn_f32_tanh_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
assert(batch != 0);
assert(batch % sizeof(float) == 0);
assert(input != NULL);
assert(output != NULL);
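// Sketch of the (now replaced) expm1-based algorithm, with all constants taken
// from params: for z = -|x|, the kernel approximates emo = expm1(2*z) using a
// 4-entry lookup table (lut4) and a degree-4 polynomial (p4), then recovers
// tanh(x) from emo / (emo + 2) with the sign restored via vinvsignx below.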

const __m256 vsign_mask = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.sign_mask);
const __m256 vsat_cutoff = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.sat_cutoff);
const __m256 vlog2e = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.log2e);
const __m256 vmagic_bias = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.magic_bias);
const __m256 vtable = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.table);
const __m256 vminus_ln2 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.minus_ln2);
const __m256 vc4 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.c4);
const __m256 vc3 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.c3);
const __m256 vc2 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.c2);
const __m256 vtwo = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.two);
const __m256 vminus_one = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.minus_one);

for (; batch >= 32 * sizeof(float); batch -= 32 * sizeof(float)) {
const __m256 vx0 = _mm256_loadu_ps(input);
const __m256 vx1 = _mm256_loadu_ps(input + 8);
const __m256 vx2 = _mm256_loadu_ps(input + 16);
const __m256 vx3 = _mm256_loadu_ps(input + 24);
input += 32;

__m256 vz0 = _mm256_or_ps(vx0, vsign_mask);
__m256 vz1 = _mm256_or_ps(vx1, vsign_mask);
__m256 vz2 = _mm256_or_ps(vx2, vsign_mask);
__m256 vz3 = _mm256_or_ps(vx3, vsign_mask);

const __m256 vinvsignx0 = _mm256_xor_ps(vx0, vz0);
vz0 = _mm256_max_ps(vsat_cutoff, vz0);
const __m256 vinvsignx1 = _mm256_xor_ps(vx1, vz1);
vz1 = _mm256_max_ps(vsat_cutoff, vz1);
const __m256 vinvsignx2 = _mm256_xor_ps(vx2, vz2);
vz2 = _mm256_max_ps(vsat_cutoff, vz2);
const __m256 vinvsignx3 = _mm256_xor_ps(vx3, vz3);
vz3 = _mm256_max_ps(vsat_cutoff, vz3);

__m256 vn0 = _mm256_fmadd_ps(vz0, vlog2e, vmagic_bias);
__m256 vn1 = _mm256_fmadd_ps(vz1, vlog2e, vmagic_bias);
__m256 vn2 = _mm256_fmadd_ps(vz2, vlog2e, vmagic_bias);
__m256 vn3 = _mm256_fmadd_ps(vz3, vlog2e, vmagic_bias);

const __m256i ve0 = _mm256_slli_epi32(_mm256_castps_si256(vn0), 21);
const __m256i ve1 = _mm256_slli_epi32(_mm256_castps_si256(vn1), 21);
const __m256i ve2 = _mm256_slli_epi32(_mm256_castps_si256(vn2), 21);
const __m256i ve3 = _mm256_slli_epi32(_mm256_castps_si256(vn3), 21);

const __m256i vl0 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn0)));
const __m256i vl1 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn1)));
const __m256i vl2 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn2)));
const __m256i vl3 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn3)));

const __m256 vs0 = _mm256_castsi256_ps(_mm256_add_epi32(vl0, ve0));
const __m256 vs1 = _mm256_castsi256_ps(_mm256_add_epi32(vl1, ve1));
const __m256 vs2 = _mm256_castsi256_ps(_mm256_add_epi32(vl2, ve2));
const __m256 vs3 = _mm256_castsi256_ps(_mm256_add_epi32(vl3, ve3));

vn0 = _mm256_sub_ps(vn0, vmagic_bias);
vn1 = _mm256_sub_ps(vn1, vmagic_bias);
vn2 = _mm256_sub_ps(vn2, vmagic_bias);
vn3 = _mm256_sub_ps(vn3, vmagic_bias);

const __m256 vt0 = _mm256_fmadd_ps(vn0, vminus_ln2, vz0);
const __m256 vt1 = _mm256_fmadd_ps(vn1, vminus_ln2, vz1);
const __m256 vt2 = _mm256_fmadd_ps(vn2, vminus_ln2, vz2);
const __m256 vt3 = _mm256_fmadd_ps(vn3, vminus_ln2, vz3);

__m256 vp0 = vc4;
__m256 vp1 = vc4;
__m256 vp2 = vc4;
__m256 vp3 = vc4;
vp0 = _mm256_fmadd_ps(vp0, vt0, vc3);
vp1 = _mm256_fmadd_ps(vp1, vt1, vc3);
vp2 = _mm256_fmadd_ps(vp2, vt2, vc3);
vp3 = _mm256_fmadd_ps(vp3, vt3, vc3);
vp0 = _mm256_fmadd_ps(vp0, vt0, vc2);
vp1 = _mm256_fmadd_ps(vp1, vt1, vc2);
vp2 = _mm256_fmadd_ps(vp2, vt2, vc2);
vp3 = _mm256_fmadd_ps(vp3, vt3, vc2);
vp0 = _mm256_fmadd_ps(vp0, vt0, vtwo);
vp1 = _mm256_fmadd_ps(vp1, vt1, vtwo);
vp2 = _mm256_fmadd_ps(vp2, vt2, vtwo);
vp3 = _mm256_fmadd_ps(vp3, vt3, vtwo);

const __m256 vts0 = _mm256_mul_ps(vt0, vs0);
const __m256 vsmo0 = _mm256_add_ps(vs0, vminus_one);
const __m256 vts1 = _mm256_mul_ps(vt1, vs1);
const __m256 vsmo1 = _mm256_add_ps(vs1, vminus_one);
const __m256 vts2 = _mm256_mul_ps(vt2, vs2);
const __m256 vsmo2 = _mm256_add_ps(vs2, vminus_one);
const __m256 vts3 = _mm256_mul_ps(vt3, vs3);
const __m256 vsmo3 = _mm256_add_ps(vs3, vminus_one);
const __m256 vemo0 = _mm256_fmadd_ps(vp0, vts0, vsmo0);
const __m256 vemo1 = _mm256_fmadd_ps(vp1, vts1, vsmo1);
const __m256 vemo2 = _mm256_fmadd_ps(vp2, vts2, vsmo2);
const __m256 vemo3 = _mm256_fmadd_ps(vp3, vts3, vsmo3);
const __m256 vepo0 = _mm256_add_ps(vemo0, vtwo);
const __m256 vepo1 = _mm256_add_ps(vemo1, vtwo);
const __m256 vepo2 = _mm256_add_ps(vemo2, vtwo);
const __m256 vepo3 = _mm256_add_ps(vemo3, vtwo);

__m256 vy0 = _mm256_div_ps(vemo0, vepo0);
__m256 vy1 = _mm256_div_ps(vemo1, vepo1);
__m256 vy2 = _mm256_div_ps(vemo2, vepo2);
__m256 vy3 = _mm256_div_ps(vemo3, vepo3);

vy0 = _mm256_xor_ps(vy0, vinvsignx0);
vy1 = _mm256_xor_ps(vy1, vinvsignx1);
vy2 = _mm256_xor_ps(vy2, vinvsignx2);
vy3 = _mm256_xor_ps(vy3, vinvsignx3);

_mm256_storeu_ps(output, vy0);
_mm256_storeu_ps(output + 8, vy1);
_mm256_storeu_ps(output + 16, vy2);
_mm256_storeu_ps(output + 24, vy3);
output += 32;
}
for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) {
const __m256 vx = _mm256_loadu_ps(input);
input += 8;

__m256 vz = _mm256_or_ps(vx, vsign_mask);

const __m256 vinvsignx = _mm256_xor_ps(vx, vz);
vz = _mm256_max_ps(vsat_cutoff, vz);

__m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);

const __m256i ve = _mm256_slli_epi32(_mm256_castps_si256(vn), 21);

const __m256i vl = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn)));

const __m256 vs = _mm256_castsi256_ps(_mm256_add_epi32(vl, ve));

vn = _mm256_sub_ps(vn, vmagic_bias);

const __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);

__m256 vp = vc4;
vp = _mm256_fmadd_ps(vp, vt, vc3);
vp = _mm256_fmadd_ps(vp, vt, vc2);
vp = _mm256_fmadd_ps(vp, vt, vtwo);

const __m256 vts = _mm256_mul_ps(vt, vs);
const __m256 vsmo = _mm256_add_ps(vs, vminus_one);
const __m256 vemo = _mm256_fmadd_ps(vp, vts, vsmo);
const __m256 vepo = _mm256_add_ps(vemo, vtwo);

__m256 vy = _mm256_div_ps(vemo, vepo);

vy = _mm256_xor_ps(vy, vinvsignx);

_mm256_storeu_ps(output, vy);
output += 8;
}
if XNN_UNLIKELY(batch != 0) {
assert(batch >= 1 * sizeof(float));
assert(batch <= 7 * sizeof(float));
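// Build the load mask for the 1..7 remaining floats. mask_table is assumed to
// hold seven all-ones entries followed by zeros, so reading 8 entries starting
// `batch` bytes before the first zero yields ones in exactly the leading
// batch / sizeof(float) lanes.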
const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx_expm1minus_rr1_lut4_p4h3_perm.mask_table[7] - batch));

const __m256 vx = _mm256_maskload_ps(input, vmask);

__m256 vz = _mm256_or_ps(vx, vsign_mask);

const __m256 vinvsignx = _mm256_xor_ps(vx, vz);
vz = _mm256_max_ps(vsat_cutoff, vz);

__m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);

const __m256i ve = _mm256_slli_epi32(_mm256_castps_si256(vn), 21);

const __m256i vl = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn)));

const __m256 vs = _mm256_castsi256_ps(_mm256_add_epi32(vl, ve));

vn = _mm256_sub_ps(vn, vmagic_bias);

const __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);

__m256 vp = vc4;
vp = _mm256_fmadd_ps(vp, vt, vc3);
vp = _mm256_fmadd_ps(vp, vt, vc2);
vp = _mm256_fmadd_ps(vp, vt, vtwo);

const __m256 vts = _mm256_mul_ps(vt, vs);
const __m256 vsmo = _mm256_add_ps(vs, vminus_one);
const __m256 vemo = _mm256_fmadd_ps(vp, vts, vsmo);
const __m256 vepo = _mm256_add_ps(vemo, vtwo);

__m256 vy = _mm256_div_ps(vemo, vepo);

vy = _mm256_xor_ps(vy, vinvsignx);

__m128 vy_lo = _mm256_castps256_ps128(vy);
if (batch & (4 * sizeof(float))) {
_mm_storeu_ps(output, vy_lo);
vy_lo = _mm256_extractf128_ps(vy, 1);
output += 4;
}
if (batch & (2 * sizeof(float))) {
_mm_storel_pi((__m64*) output, vy_lo);
vy_lo = _mm_movehl_ps(vy_lo, vy_lo);
output += 2;
}
if (batch & (1 * sizeof(float))) {
_mm_store_ss(output, vy_lo);
}
}
}

void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2(
size_t mr,
size_t nc,
103 changes: 103 additions & 0 deletions src/amalgam/gen/avx512f.c
@@ -15,6 +15,7 @@
#include <xnnpack/igemm.h>
#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/math.h>
#include <xnnpack/microparams.h>
#include <xnnpack/packw.h>
#include <xnnpack/prefetch.h>
#include <xnnpack/prelu.h>
@@ -3922,6 +3923,108 @@ void xnn_f32_vsqrt_ukernel__avx512f_rsqrt_u16(
}
}

void xnn_f32_vtanh_ukernel__avx512f_rational_9_6_nr_u16(
size_t batch,
const float* input,
float* output,
const union xnn_f32_tanh_params params[restrict XNN_MIN_ELEMENTS(1)])
{
assert(batch != 0);
assert(batch % sizeof(float) == 0);
assert(input != NULL);
assert(output != NULL);

// Cap the inputs to this value, as `tanh(x)` is always `+/-1.0f` beyond this
// point. The value is chosen as the first floating-point number at which the
// interpolation returns exactly 1.0f.
const __m512 vmax_x = _mm512_set1_ps(params->avx512_rational_9_6.max_abs_x);
const __m512 vmin_x = _mm512_set1_ps(-params->avx512_rational_9_6.max_abs_x);

// The monomial coefficients of the numerator polynomial (odd).
const __m512 valpha_1 = _mm512_set1_ps(params->avx512_rational_9_6.alpha_1);
const __m512 valpha_3 = _mm512_set1_ps(params->avx512_rational_9_6.alpha_3);
const __m512 valpha_5 = _mm512_set1_ps(params->avx512_rational_9_6.alpha_5);
const __m512 valpha_7 = _mm512_set1_ps(params->avx512_rational_9_6.alpha_7);
const __m512 valpha_9 = _mm512_set1_ps(params->avx512_rational_9_6.alpha_9);

// The monomial coefficients of the denominator polynomial (even).
const __m512 vbeta_0 = _mm512_set1_ps(params->avx512_rational_9_6.beta_0);
const __m512 vbeta_2 = _mm512_set1_ps(params->avx512_rational_9_6.beta_2);
const __m512 vbeta_4 = _mm512_set1_ps(params->avx512_rational_9_6.beta_4);
const __m512 vbeta_6 = _mm512_set1_ps(params->avx512_rational_9_6.beta_6);

// Constant needed for the Newton-Raphson iteration of the reciprocal.
const __m512 vtwo = _mm512_set1_ps(params->avx512_rational_9_6.two);

for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) {
__m512 vx = _mm512_loadu_ps(input);
input += 16;

// Clamp the inputs to the interpolation range.
vx = _mm512_min_ps(vmax_x, vx);
vx = _mm512_max_ps(vmin_x, vx);

// Since the polynomials are odd/even, we need x^2.
const __m512 vx2 = _mm512_mul_ps(vx, vx);

// Evaluate the numerator polynomial p.
__m512 vp = _mm512_fmadd_ps(vx2, valpha_9, valpha_7);
vp = _mm512_fmadd_ps(vx2, vp, valpha_5);
vp = _mm512_fmadd_ps(vx2, vp, valpha_3);
vp = _mm512_fmadd_ps(vx2, vp, valpha_1);
vp = _mm512_mul_ps(vx, vp);

// Evaluate the denominator polynomial q.
__m512 vq = _mm512_fmadd_ps(vx2, vbeta_6, vbeta_4);
vq = _mm512_fmadd_ps(vx2, vq, vbeta_2);
vq = _mm512_fmadd_ps(vx2, vq, vbeta_0);

// Divide the numerator by the denominator.
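// _mm512_rcp14_ps returns a reciprocal estimate with at most 2^-14 relative
// error; the Newton-Raphson step t1 = t0 * (2 - q*t0) roughly doubles the
// number of correct bits before the multiply by the numerator.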
const __m512 vt0 = _mm512_rcp14_ps(vq);
const __m512 vt1 = _mm512_mul_ps(vt0, _mm512_fnmadd_ps(vt0, vq, vtwo));
const __m512 vy = _mm512_mul_ps(vp, vt1);

_mm512_storeu_ps(output, vy);
output += 16;
}
if XNN_UNLIKELY(batch != 0) {
assert(batch >= 1 * sizeof(float));
assert(batch <= 15 * sizeof(float));

// Prepare mask for valid 32-bit elements (depends on batch).
batch >>= XNN_LOG2_SIZEOF_FLOAT;
const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1)));
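// For example, a tail of 3 floats gives vmask == 0b0000000000000111, so only
// the first three lanes are loaded and stored.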

__m512 vx = _mm512_maskz_loadu_ps(vmask, input);

// Clamp the inputs to the interpolation range.
vx = _mm512_min_ps(vmax_x, vx);
vx = _mm512_max_ps(vmin_x, vx);

// Since the polynomials are odd/even, we need x^2.
const __m512 vx2 = _mm512_mul_ps(vx, vx);

// Evaluate the numerator polynomial p.
__m512 vp = _mm512_fmadd_ps(vx2, valpha_9, valpha_7);
vp = _mm512_fmadd_ps(vx2, vp, valpha_5);
vp = _mm512_fmadd_ps(vx2, vp, valpha_3);
vp = _mm512_fmadd_ps(vx2, vp, valpha_1);
vp = _mm512_mul_ps(vx, vp);

// Evaluate the denominator polynomial q.
__m512 vq = _mm512_fmadd_ps(vx2, vbeta_6, vbeta_4);
vq = _mm512_fmadd_ps(vx2, vq, vbeta_2);
vq = _mm512_fmadd_ps(vx2, vq, vbeta_0);

// Divide the numerator by the denominator.
const __m512 vt0 = _mm512_rcp14_ps(vq);
const __m512 vt1 = _mm512_mul_ps(vt0, _mm512_fnmadd_ps(vt0, vq, vtwo));
const __m512 vy = _mm512_mul_ps(vp, vt1);

_mm512_mask_storeu_ps(output, vmask, vy);
}
}

void xnn_f32_vabs_ukernel__avx512f_u16(
size_t batch,
const float* input,
