@@ -4,7 +4,6 @@
 // LICENSE file in the root directory of this source tree.
 
 #include <assert.h>
-#include <stddef.h>
 #include <stdint.h>
 #include <string.h>
 
@@ -17,7 +16,6 @@
 #include <xnnpack/intrinsics-polyfill.h>
 #include <xnnpack/lut.h>
 #include <xnnpack/math.h>
-#include <xnnpack/microparams.h>
 #include <xnnpack/packw.h>
 #include <xnnpack/pavgpool.h>
 #include <xnnpack/prefetch.h>
@@ -2972,223 +2970,6 @@ void xnn_f32_vsigmoid_ukernel__avx2_rr1_p5_div_u40(
   }
 }
 
-void xnn_f32_vtanh_ukernel__avx2_expm1minus_rr1_lut4_p4h3ts_perm_div_u32(
-    size_t batch,
-    const float* input,
-    float* output,
-    const union xnn_f32_tanh_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
-{
-  assert(batch != 0);
-  assert(batch % sizeof(float) == 0);
-  assert(input != NULL);
-  assert(output != NULL);
-
-  const __m256 vsign_mask = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.sign_mask);
-  const __m256 vsat_cutoff = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.sat_cutoff);
-  const __m256 vlog2e = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.log2e);
-  const __m256 vmagic_bias = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.magic_bias);
-  const __m256 vtable = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.table);
-  const __m256 vminus_ln2 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.minus_ln2);
-  const __m256 vc4 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.c4);
-  const __m256 vc3 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.c3);
-  const __m256 vc2 = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.c2);
-  const __m256 vtwo = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.two);
-  const __m256 vminus_one = _mm256_load_ps(params->avx_expm1minus_rr1_lut4_p4h3_perm.minus_one);
-
-  for (; batch >= 32 * sizeof(float); batch -= 32 * sizeof(float)) {
-    const __m256 vx0 = _mm256_loadu_ps(input);
-    const __m256 vx1 = _mm256_loadu_ps(input + 8);
-    const __m256 vx2 = _mm256_loadu_ps(input + 16);
-    const __m256 vx3 = _mm256_loadu_ps(input + 24);
-    input += 32;
-
-    __m256 vz0 = _mm256_or_ps(vx0, vsign_mask);
-    __m256 vz1 = _mm256_or_ps(vx1, vsign_mask);
-    __m256 vz2 = _mm256_or_ps(vx2, vsign_mask);
-    __m256 vz3 = _mm256_or_ps(vx3, vsign_mask);
-
-    const __m256 vinvsignx0 = _mm256_xor_ps(vx0, vz0);
-    vz0 = _mm256_max_ps(vsat_cutoff, vz0);
-    const __m256 vinvsignx1 = _mm256_xor_ps(vx1, vz1);
-    vz1 = _mm256_max_ps(vsat_cutoff, vz1);
-    const __m256 vinvsignx2 = _mm256_xor_ps(vx2, vz2);
-    vz2 = _mm256_max_ps(vsat_cutoff, vz2);
-    const __m256 vinvsignx3 = _mm256_xor_ps(vx3, vz3);
-    vz3 = _mm256_max_ps(vsat_cutoff, vz3);
-
-    __m256 vn0 = _mm256_fmadd_ps(vz0, vlog2e, vmagic_bias);
-    __m256 vn1 = _mm256_fmadd_ps(vz1, vlog2e, vmagic_bias);
-    __m256 vn2 = _mm256_fmadd_ps(vz2, vlog2e, vmagic_bias);
-    __m256 vn3 = _mm256_fmadd_ps(vz3, vlog2e, vmagic_bias);
-
-    const __m256i ve0 = _mm256_slli_epi32(_mm256_castps_si256(vn0), 21);
-    const __m256i ve1 = _mm256_slli_epi32(_mm256_castps_si256(vn1), 21);
-    const __m256i ve2 = _mm256_slli_epi32(_mm256_castps_si256(vn2), 21);
-    const __m256i ve3 = _mm256_slli_epi32(_mm256_castps_si256(vn3), 21);
-
-    const __m256i vl0 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn0)));
-    const __m256i vl1 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn1)));
-    const __m256i vl2 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn2)));
-    const __m256i vl3 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn3)));
-
-    const __m256 vs0 = _mm256_castsi256_ps(_mm256_add_epi32(vl0, ve0));
-    const __m256 vs1 = _mm256_castsi256_ps(_mm256_add_epi32(vl1, ve1));
-    const __m256 vs2 = _mm256_castsi256_ps(_mm256_add_epi32(vl2, ve2));
-    const __m256 vs3 = _mm256_castsi256_ps(_mm256_add_epi32(vl3, ve3));
-
-    vn0 = _mm256_sub_ps(vn0, vmagic_bias);
-    vn1 = _mm256_sub_ps(vn1, vmagic_bias);
-    vn2 = _mm256_sub_ps(vn2, vmagic_bias);
-    vn3 = _mm256_sub_ps(vn3, vmagic_bias);
-
-    const __m256 vt0 = _mm256_fmadd_ps(vn0, vminus_ln2, vz0);
-    const __m256 vt1 = _mm256_fmadd_ps(vn1, vminus_ln2, vz1);
-    const __m256 vt2 = _mm256_fmadd_ps(vn2, vminus_ln2, vz2);
-    const __m256 vt3 = _mm256_fmadd_ps(vn3, vminus_ln2, vz3);
-
-    __m256 vp0 = vc4;
-    __m256 vp1 = vc4;
-    __m256 vp2 = vc4;
-    __m256 vp3 = vc4;
-    vp0 = _mm256_fmadd_ps(vp0, vt0, vc3);
-    vp1 = _mm256_fmadd_ps(vp1, vt1, vc3);
-    vp2 = _mm256_fmadd_ps(vp2, vt2, vc3);
-    vp3 = _mm256_fmadd_ps(vp3, vt3, vc3);
-    vp0 = _mm256_fmadd_ps(vp0, vt0, vc2);
-    vp1 = _mm256_fmadd_ps(vp1, vt1, vc2);
-    vp2 = _mm256_fmadd_ps(vp2, vt2, vc2);
-    vp3 = _mm256_fmadd_ps(vp3, vt3, vc2);
-    vp0 = _mm256_fmadd_ps(vp0, vt0, vtwo);
-    vp1 = _mm256_fmadd_ps(vp1, vt1, vtwo);
-    vp2 = _mm256_fmadd_ps(vp2, vt2, vtwo);
-    vp3 = _mm256_fmadd_ps(vp3, vt3, vtwo);
-
-    const __m256 vts0 = _mm256_mul_ps(vt0, vs0);
-    const __m256 vsmo0 = _mm256_add_ps(vs0, vminus_one);
-    const __m256 vts1 = _mm256_mul_ps(vt1, vs1);
-    const __m256 vsmo1 = _mm256_add_ps(vs1, vminus_one);
-    const __m256 vts2 = _mm256_mul_ps(vt2, vs2);
-    const __m256 vsmo2 = _mm256_add_ps(vs2, vminus_one);
-    const __m256 vts3 = _mm256_mul_ps(vt3, vs3);
-    const __m256 vsmo3 = _mm256_add_ps(vs3, vminus_one);
-    const __m256 vemo0 = _mm256_fmadd_ps(vp0, vts0, vsmo0);
-    const __m256 vemo1 = _mm256_fmadd_ps(vp1, vts1, vsmo1);
-    const __m256 vemo2 = _mm256_fmadd_ps(vp2, vts2, vsmo2);
-    const __m256 vemo3 = _mm256_fmadd_ps(vp3, vts3, vsmo3);
-    const __m256 vepo0 = _mm256_add_ps(vemo0, vtwo);
-    const __m256 vepo1 = _mm256_add_ps(vemo1, vtwo);
-    const __m256 vepo2 = _mm256_add_ps(vemo2, vtwo);
-    const __m256 vepo3 = _mm256_add_ps(vemo3, vtwo);
-
-    __m256 vy0 = _mm256_div_ps(vemo0, vepo0);
-    __m256 vy1 = _mm256_div_ps(vemo1, vepo1);
-    __m256 vy2 = _mm256_div_ps(vemo2, vepo2);
-    __m256 vy3 = _mm256_div_ps(vemo3, vepo3);
-
-    vy0 = _mm256_xor_ps(vy0, vinvsignx0);
-    vy1 = _mm256_xor_ps(vy1, vinvsignx1);
-    vy2 = _mm256_xor_ps(vy2, vinvsignx2);
-    vy3 = _mm256_xor_ps(vy3, vinvsignx3);
-
-    _mm256_storeu_ps(output, vy0);
-    _mm256_storeu_ps(output + 8, vy1);
-    _mm256_storeu_ps(output + 16, vy2);
-    _mm256_storeu_ps(output + 24, vy3);
-    output += 32;
-  }
-  for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) {
-    const __m256 vx = _mm256_loadu_ps(input);
-    input += 8;
-
-    __m256 vz = _mm256_or_ps(vx, vsign_mask);
-
-    const __m256 vinvsignx = _mm256_xor_ps(vx, vz);
-    vz = _mm256_max_ps(vsat_cutoff, vz);
-
-    __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
-
-    const __m256i ve = _mm256_slli_epi32(_mm256_castps_si256(vn), 21);
-
-    const __m256i vl = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn)));
-
-    const __m256 vs = _mm256_castsi256_ps(_mm256_add_epi32(vl, ve));
-
-    vn = _mm256_sub_ps(vn, vmagic_bias);
-
-    const __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
-
-    __m256 vp = vc4;
-    vp = _mm256_fmadd_ps(vp, vt, vc3);
-    vp = _mm256_fmadd_ps(vp, vt, vc2);
-    vp = _mm256_fmadd_ps(vp, vt, vtwo);
-
-    const __m256 vts = _mm256_mul_ps(vt, vs);
-    const __m256 vsmo = _mm256_add_ps(vs, vminus_one);
-    const __m256 vemo = _mm256_fmadd_ps(vp, vts, vsmo);
-    const __m256 vepo = _mm256_add_ps(vemo, vtwo);
-
-    __m256 vy = _mm256_div_ps(vemo, vepo);
-
-    vy = _mm256_xor_ps(vy, vinvsignx);
-
-    _mm256_storeu_ps(output, vy);
-    output += 8;
-  }
-  if XNN_UNLIKELY(batch != 0) {
-    assert(batch >= 1 * sizeof(float));
-    assert(batch <= 7 * sizeof(float));
-    const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx_expm1minus_rr1_lut4_p4h3_perm.mask_table[7] - batch));
-
-    const __m256 vx = _mm256_maskload_ps(input, vmask);
-
-    __m256 vz = _mm256_or_ps(vx, vsign_mask);
-
-    const __m256 vinvsignx = _mm256_xor_ps(vx, vz);
-    vz = _mm256_max_ps(vsat_cutoff, vz);
-
-    __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
-
-    const __m256i ve = _mm256_slli_epi32(_mm256_castps_si256(vn), 21);
-
-    const __m256i vl = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn)));
-
-    const __m256 vs = _mm256_castsi256_ps(_mm256_add_epi32(vl, ve));
-
-    vn = _mm256_sub_ps(vn, vmagic_bias);
-
-    const __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
-
-    __m256 vp = vc4;
-    vp = _mm256_fmadd_ps(vp, vt, vc3);
-    vp = _mm256_fmadd_ps(vp, vt, vc2);
-    vp = _mm256_fmadd_ps(vp, vt, vtwo);
-
-    const __m256 vts = _mm256_mul_ps(vt, vs);
-    const __m256 vsmo = _mm256_add_ps(vs, vminus_one);
-    const __m256 vemo = _mm256_fmadd_ps(vp, vts, vsmo);
-    const __m256 vepo = _mm256_add_ps(vemo, vtwo);
-
-    __m256 vy = _mm256_div_ps(vemo, vepo);
-
-    vy = _mm256_xor_ps(vy, vinvsignx);
-
-    __m128 vy_lo = _mm256_castps256_ps128(vy);
-    if (batch & (4 * sizeof(float))) {
-      _mm_storeu_ps(output, vy_lo);
-      vy_lo = _mm256_extractf128_ps(vy, 1);
-      output += 4;
-    }
-    if (batch & (2 * sizeof(float))) {
-      _mm_storel_pi((__m64*) output, vy_lo);
-      vy_lo = _mm_movehl_ps(vy_lo, vy_lo);
-      output += 2;
-    }
-    if (batch & (1 * sizeof(float))) {
-      _mm_store_ss(output, vy_lo);
-    }
-  }
-}
-
 void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2(
     size_t mr,
     size_t nc,
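
Note: the vtanh kernel removed above appears to evaluate tanh(x) through an expm1-based identity. It forces the argument negative (vz = -|x| via the sign-bit OR), clamps it at vsat_cutoff, reconstructs an approximation of expm1(2z), i.e. expm1(-2|x|), from a small permute-based lookup table plus a short polynomial, forms tanh(z) = expm1(2z) / (expm1(2z) + 2) with a single division, and restores the sign of x with the final XOR against vinvsignx. The following is a minimal scalar sketch of that identity only, assuming plain libm; tanh_via_expm1minus is a hypothetical helper for illustration, not part of XNNPACK, and it does not reproduce the kernel's vectorized approximation.

#include <math.h>

// Scalar illustration of the identity (assumption: illustrative only):
//   for z = -|x|,  tanh(x) = copysign(expm1(2z) / (expm1(2z) + 2), x)
static float tanh_via_expm1minus(float x) {
  const float z = -fabsf(x);           // non-positive argument (tanh is odd)
  const float emo = expm1f(2.0f * z);  // e^(2z) - 1, lies in [-1, 0]
  const float y = emo / (emo + 2.0f);  // tanh(z), lies in (-1, 0]
  return copysignf(y, x);              // restore the original sign of x
}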
|
|