Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Gelu x86] Finish intrinsic with elempack merged(fast version) #4144

Merged
merged 6 commits into from Sep 18, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/layer/x86/avx512_mathfun.h
Expand Up @@ -182,6 +182,48 @@ static NCNN_FORCEINLINE __m512 exp512_ps(__m512 x)
return y;
}

_PS512_CONST(tanh_hi, 9.0f);
_PS512_CONST(tanh_lo, -9.0f);

_PS512_CONST(cephes_tanh_p0, -2.76076847742355E-16f);
_PS512_CONST(cephes_tanh_p1, 2.00018790482477E-13f);
_PS512_CONST(cephes_tanh_p2, -8.60467152213735E-11f);
_PS512_CONST(cephes_tanh_p3, 5.12229709037114E-08f);
_PS512_CONST(cephes_tanh_p4, 1.48572235717979E-05f);
_PS512_CONST(cephes_tanh_p5, 6.37261928875436E-04f);
_PS512_CONST(cephes_tanh_p6, 4.89352455891786E-03f);

_PS512_CONST(cephes_tanh_p7, 1.19825839466702e-06f);
_PS512_CONST(cephes_tanh_p8, 1.18534705686654e-04f);
_PS512_CONST(cephes_tanh_p9, 2.26843463243900e-03f);

// an approximation of tanh
static inline __m512 tanh512_ps(const __m512 x)
{
__m512 value = x;
value = _mm512_max_ps(*(__m512*)_ps512_tanh_lo, value);
value = _mm512_min_ps(*(__m512*)_ps512_tanh_hi, value);

__m512 value_squared = _mm512_mul_ps(value, value);

__m512 p;
p = _mm512_fmadd_ps(value_squared, *(__m512*)_ps512_cephes_tanh_p0, *(__m512*)_ps512_cephes_tanh_p1);
p = _mm512_fmadd_ps(p, value_squared, *(__m512*)_ps512_cephes_tanh_p2);
p = _mm512_fmadd_ps(p, value_squared, *(__m512*)_ps512_cephes_tanh_p3);
p = _mm512_fmadd_ps(p, value_squared, *(__m512*)_ps512_cephes_tanh_p4);
p = _mm512_fmadd_ps(p, value_squared, *(__m512*)_ps512_cephes_tanh_p5);
p = _mm512_fmadd_ps(p, value_squared, *(__m512*)_ps512_cephes_tanh_p6);
p = _mm512_mul_ps(p, value);

__m512 q;
q = _mm512_fmadd_ps(value_squared, *(__m512*)_ps512_cephes_tanh_p7, *(__m512*)_ps512_cephes_tanh_p8);
q = _mm512_fmadd_ps(q, value_squared, *(__m512*)_ps512_cephes_tanh_p9);
q = _mm512_fmadd_ps(q, value_squared, *(__m512*)_ps512_cephes_tanh_p6);

__m512 dst = _mm512_div_ps(p, q);
return dst;
}

_PS512_CONST(minus_cephes_DP1, -0.78515625f);
_PS512_CONST(minus_cephes_DP2, -2.4187564849853515625e-4f);
_PS512_CONST(minus_cephes_DP3, -3.77489497744594108e-8f);
Expand Down
42 changes: 42 additions & 0 deletions src/layer/x86/avx_mathfun.h
Expand Up @@ -295,6 +295,48 @@ static NCNN_FORCEINLINE __m256 exp256_ps(__m256 x)
return y;
}

_PS256_CONST(tanh_hi, 9.0f);
_PS256_CONST(tanh_lo, -9.0f);

_PS256_CONST(cephes_tanh_p0, -2.76076847742355E-16f);
_PS256_CONST(cephes_tanh_p1, 2.00018790482477E-13f);
_PS256_CONST(cephes_tanh_p2, -8.60467152213735E-11f);
_PS256_CONST(cephes_tanh_p3, 5.12229709037114E-08f);
_PS256_CONST(cephes_tanh_p4, 1.48572235717979E-05f);
_PS256_CONST(cephes_tanh_p5, 6.37261928875436E-04f);
_PS256_CONST(cephes_tanh_p6, 4.89352455891786E-03f);

_PS256_CONST(cephes_tanh_p7, 1.19825839466702e-06f);
_PS256_CONST(cephes_tanh_p8, 1.18534705686654e-04f);
_PS256_CONST(cephes_tanh_p9, 2.26843463243900e-03f);

// an approximation of tanh
static inline __m256 tanh256_ps(const __m256 x)
{
__m256 value = x;
value = _mm256_max_ps(*(__m256*)_ps256_tanh_lo, value);
value = _mm256_min_ps(*(__m256*)_ps256_tanh_hi, value);

__m256 value_squared = _mm256_mul_ps(value, value);

__m256 p;
p = _mm256_comp_fmadd_ps(value_squared, *(__m256*)_ps256_cephes_tanh_p0, *(__m256*)_ps256_cephes_tanh_p1);
p = _mm256_comp_fmadd_ps(p, value_squared, *(__m256*)_ps256_cephes_tanh_p2);
p = _mm256_comp_fmadd_ps(p, value_squared, *(__m256*)_ps256_cephes_tanh_p3);
p = _mm256_comp_fmadd_ps(p, value_squared, *(__m256*)_ps256_cephes_tanh_p4);
p = _mm256_comp_fmadd_ps(p, value_squared, *(__m256*)_ps256_cephes_tanh_p5);
p = _mm256_comp_fmadd_ps(p, value_squared, *(__m256*)_ps256_cephes_tanh_p6);
p = _mm256_mul_ps(p, value);

__m256 q;
q = _mm256_comp_fmadd_ps(value_squared, *(__m256*)_ps256_cephes_tanh_p7, *(__m256*)_ps256_cephes_tanh_p8);
q = _mm256_comp_fmadd_ps(q, value_squared, *(__m256*)_ps256_cephes_tanh_p9);
q = _mm256_comp_fmadd_ps(q, value_squared, *(__m256*)_ps256_cephes_tanh_p6);

__m256 dst = _mm256_div_ps(p, q);
return dst;
}

_PS256_CONST(minus_cephes_DP1, -0.78515625f);
_PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4f);
_PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8f);
Expand Down
141 changes: 141 additions & 0 deletions src/layer/x86/gelu_x86.cpp
@@ -0,0 +1,141 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
LRY89757 marked this conversation as resolved.
Show resolved Hide resolved
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "gelu_x86.h"

#if __SSE2__
#include <emmintrin.h>
#include "sse_mathfun.h"
#if __AVX__
#include <immintrin.h>
#include "avx_mathfun.h"
#if __AVX512F__
#include "avx512_mathfun.h"
#endif // __AVX512F__
#endif // __AVX__
#endif // __SSE2__

namespace ncnn {

GELU_x86::GELU_x86()
{
#if __SSE2__
support_packing = true;
#endif // __SSE2__
}

int GELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
LRY89757 marked this conversation as resolved.
Show resolved Hide resolved
int w = bottom_top_blob.w;
int h = bottom_top_blob.h;
int elempack = bottom_top_blob.elempack;
int channels = bottom_top_blob.c;
int size = w * h * elempack;

#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < channels; q++)
{
float* ptr = bottom_top_blob.channel(q);

int i = 0;

#if __SSE2__
__m128 _half128 = _mm_set1_ps(0.5f);
__m128 _one128 = _mm_set1_ps(1.f);
__m128 _fast1c128 = _mm_set1_ps(0.79788452f);
__m128 _fast2c128 = _mm_set1_ps(0.044715f);
#if __AVX__
__m256 _half256 = _mm256_set1_ps(0.5f);
__m256 _one256 = _mm256_set1_ps(1.f);
__m256 _fast1c256 = _mm256_set1_ps(0.79788452f);
__m256 _fast2c256 = _mm256_set1_ps(0.044715f);
LRY89757 marked this conversation as resolved.
Show resolved Hide resolved
#if __AVX512F__
__m512 _half512 = _mm512_set1_ps(0.5f);
__m512 _one512 = _mm512_set1_ps(1.f);
__m512 _fast1c512 = _mm512_set1_ps(0.79788452f);
__m512 _fast2c512 = _mm512_set1_ps(0.044715f);

for (; i + 15 < size; i += 16)
{
__m512 _pLoad = _mm512_loadu_ps(ptr);

__m512 _cube = _mm512_mul_ps(_pLoad, _pLoad);
_cube = _mm512_mul_ps(_pLoad, _cube);

__m512 _blob = _mm512_mul_ps(_fast2c512, _cube);
_blob = _mm512_add_ps(_pLoad, _blob);
_blob = _mm512_mul_ps(_fast1c512, _blob);
_blob = tanh512_ps(_blob);
_blob = _mm512_add_ps(_one512, _blob);

_blob = _mm512_mul_ps(_half512, _mm512_mul_ps(_blob, _pLoad));

_mm512_storeu_ps(ptr, _blob);

ptr += 16;
}
#endif // __AVX512F__
for (; i + 7 < size; i += 8)
{
__m256 _pLoad = _mm256_loadu_ps(ptr);

__m256 _cube = _mm256_mul_ps(_pLoad, _pLoad);
_cube = _mm256_mul_ps(_pLoad, _cube);

__m256 _blob = _mm256_mul_ps(_fast2c256, _cube);
_blob = _mm256_add_ps(_pLoad, _blob);
_blob = _mm256_mul_ps(_fast1c256, _blob);
_blob = tanh256_ps(_blob);
_blob = _mm256_add_ps(_one256, _blob);

_blob = _mm256_mul_ps(_half256, _mm256_mul_ps(_blob, _pLoad));

_mm256_storeu_ps(ptr, _blob);

ptr += 8;
}
#endif // __AVX__
for (; i + 3 < size; i += 4)
{
__m128 _pLoad = _mm_loadu_ps(ptr);

__m128 _cube = _mm_mul_ps(_pLoad, _pLoad);
_cube = _mm_mul_ps(_pLoad, _cube);

__m128 _blob = _mm_mul_ps(_fast2c128, _cube);
_blob = _mm_add_ps(_pLoad, _blob);
_blob = _mm_mul_ps(_fast1c128, _blob);
_blob = tanh_ps(_blob);
_blob = _mm_add_ps(_one128, _blob);

_blob = _mm_mul_ps(_half128, _mm_mul_ps(_blob, _pLoad));

_mm_storeu_ps(ptr, _blob);

ptr += 4;
}
#endif // __SSE2__
for (; i < size; i++)
{
// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
*ptr = 0.5f * *ptr * (1.0f + tanhf(0.79788452f * (*ptr + 0.044715f * *ptr * *ptr * *ptr)));

ptr++;
}
}

return 0;
}

} // namespace ncnn
32 changes: 32 additions & 0 deletions src/layer/x86/gelu_x86.h
@@ -0,0 +1,32 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
LRY89757 marked this conversation as resolved.
Show resolved Hide resolved
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef LAYER_GELU_X86_H
#define LAYER_GELU_X86_H

#include "gelu.h"

namespace ncnn {

class GELU_x86 : virtual public GELU
{
public:
GELU_x86();

virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;
};

} // namespace ncnn

#endif // LAYER_GELU_X86_H
41 changes: 41 additions & 0 deletions src/layer/x86/sse_mathfun.h
Expand Up @@ -286,6 +286,47 @@ static NCNN_FORCEINLINE v4sf exp_ps(v4sf x)
return y;
}

_PS_CONST(tanh_hi, 9.0f);
_PS_CONST(tanh_lo, -9.0f);

_PS_CONST(cephes_tanh_p0, -2.76076847742355E-16f);
_PS_CONST(cephes_tanh_p1, 2.00018790482477E-13f);
_PS_CONST(cephes_tanh_p2, -8.60467152213735E-11f);
_PS_CONST(cephes_tanh_p3, 5.12229709037114E-08f);
_PS_CONST(cephes_tanh_p4, 1.48572235717979E-05f);
_PS_CONST(cephes_tanh_p5, 6.37261928875436E-04f);
_PS_CONST(cephes_tanh_p6, 4.89352455891786E-03f);
_PS_CONST(cephes_tanh_p7, 1.19825839466702e-06f);
_PS_CONST(cephes_tanh_p8, 1.18534705686654e-04f);
_PS_CONST(cephes_tanh_p9, 2.26843463243900e-03f);

// an approximation of tanh
static inline v4sf tanh_ps(const v4sf x)
{
v4sf value = x;
value = _mm_max_ps(*(v4sf*)_ps_tanh_lo, value);
value = _mm_min_ps(*(v4sf*)_ps_tanh_hi, value);

v4sf value_squared = _mm_mul_ps(value, value);

v4sf p;
p = _mm_comp_fmadd_ps(value_squared, *(v4sf*)_ps_cephes_tanh_p0, *(v4sf*)_ps_cephes_tanh_p1);
p = _mm_comp_fmadd_ps(p, value_squared, *(v4sf*)_ps_cephes_tanh_p2);
p = _mm_comp_fmadd_ps(p, value_squared, *(v4sf*)_ps_cephes_tanh_p3);
p = _mm_comp_fmadd_ps(p, value_squared, *(v4sf*)_ps_cephes_tanh_p4);
p = _mm_comp_fmadd_ps(p, value_squared, *(v4sf*)_ps_cephes_tanh_p5);
p = _mm_comp_fmadd_ps(p, value_squared, *(v4sf*)_ps_cephes_tanh_p6);
p = _mm_mul_ps(p, value);

v4sf q;
q = _mm_comp_fmadd_ps(value_squared, *(v4sf*)_ps_cephes_tanh_p7, *(v4sf*)_ps_cephes_tanh_p8);
q = _mm_comp_fmadd_ps(q, value_squared, *(v4sf*)_ps_cephes_tanh_p9);
q = _mm_comp_fmadd_ps(q, value_squared, *(v4sf*)_ps_cephes_tanh_p6);

v4sf dst = _mm_div_ps(p, q);
return dst;
}

_PS_CONST(minus_cephes_DP1, -0.78515625f);
_PS_CONST(minus_cephes_DP2, -2.4187564849853515625e-4f);
_PS_CONST(minus_cephes_DP3, -3.77489497744594108e-8f);
Expand Down
8 changes: 7 additions & 1 deletion tests/test_gelu.cpp
Expand Up @@ -34,6 +34,8 @@ static int test_gelu(const ncnn::Mat& a, bool fast_gelu)
static int test_gelu_0()
{
return 0
|| test_gelu(RandomMat(9, 7, 32), false)
|| test_gelu(RandomMat(9, 7, 32), true)
|| test_gelu(RandomMat(5, 7, 24), false)
|| test_gelu(RandomMat(5, 7, 24), true)
|| test_gelu(RandomMat(7, 9, 12), false)
Expand All @@ -45,6 +47,8 @@ static int test_gelu_0()
static int test_gelu_1()
{
return 0
|| test_gelu(RandomMat(13, 32), false)
|| test_gelu(RandomMat(13, 32), true)
|| test_gelu(RandomMat(15, 24), false)
|| test_gelu(RandomMat(15, 24), true)
|| test_gelu(RandomMat(17, 12), false)
Expand All @@ -61,7 +65,9 @@ static int test_gelu_2()
|| test_gelu(RandomMat(124), false)
|| test_gelu(RandomMat(124), true)
|| test_gelu(RandomMat(127), false)
|| test_gelu(RandomMat(127), true);
|| test_gelu(RandomMat(127), true)
|| test_gelu(RandomMat(120), false)
|| test_gelu(RandomMat(120), true);
}

int main()
Expand Down