Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions kernel/x86_64/sgemv_n_4.c
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
FLOAT * xbuffer_align = x;
FLOAT * ybuffer_align = y;

-FLOAT * xbuffer = NULL;
-FLOAT * ybuffer = NULL;
-
if (inc_x != 1) {
xbuffer_align = buffer;
for(BLASLONG i=0; i<n; i++) {
Expand Down
2 changes: 1 addition & 1 deletion kernel/x86_64/sgemv_t_4.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "sgemv_t_microk_haswell-4.c"
#elif defined (SKYLAKEX) || defined (COOPERLAKE)
#include "sgemv_t_microk_haswell-4.c"
-/*#include "sgemv_t_microk_skylakex.c"*/
+#include "sgemv_t_microk_skylakex.c"
#endif

#if defined(STEAMROLLER) || defined(EXCAVATOR)
Expand Down
27 changes: 14 additions & 13 deletions kernel/x86_64/sgemv_t_microk_skylakex_template.c
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ static int sgemv_kernel_t_1(BLASLONG m, float alpha, float *a, float *x, float *
}

if (tag_m_32x != m) {
-for (BLASLONG idx_m = tag_m_64x; idx_m < tag_m_16x; idx_m+=32) {
+for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m+=16) {
matrixArray_0 = _mm512_loadu_ps(&a[idx_m + 0]);

_mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(matrixArray_0, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
Expand Down Expand Up @@ -145,8 +145,8 @@ static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float *
}
if (tag_m_32x != m) {
for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m+=16) {
-m0 = _mm512_loadu_ps(&a[idx_m]);
-m1 = _mm512_loadu_ps(&a[idx_m + 16]);
+m0 = _mm512_loadu_ps(&a[idx_m*2]);
+m1 = _mm512_loadu_ps(&a[idx_m*2 + 16]);
col1_1 = _mm512_permutex2var_ps(m0, idx_base_0, m1);
col1_2 = _mm512_permutex2var_ps(m0, idx_base_1, m1);
_mm512_storeu_ps(&y[idx_m], _mm512_add_ps(_mm512_fmadd_ps(x2Array, col1_2, _mm512_mul_ps(col1_1, x1Array)), _mm512_loadu_ps(&y[idx_m])));
Expand All @@ -157,7 +157,7 @@ static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float *
__mmask8 load_mask = *((__mmask8*) &load_mask_value);
x1Array = _mm512_broadcast_f32x2(_mm_maskz_loadu_ps(load_mask, x));
for (BLASLONG idx_m = tag_m_16x; idx_m < tag_m_8x; idx_m+=8) {
-m0 = _mm512_loadu_ps(&a[idx_m]);
+m0 = _mm512_loadu_ps(&a[idx_m*2]);
m1 = _mm512_mul_ps(_mm512_mul_ps(m0, x1Array), ALPHAVECTOR);
m2 = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), m1);
__m256 ret = _mm256_add_ps(_mm512_extractf32x8_ps(m2, 1), _mm512_extractf32x8_ps(m2, 0));
Expand All @@ -166,12 +166,12 @@ static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float *
}

if (tag_m_8x != m) {
-unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-((m-tag_m_8x)*2)&15));
+unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-(((m-tag_m_8x)*2)&15)));
__mmask16 a_mask = *((__mmask16*) &tail_mask_value);
unsigned char y_mask_value = (((unsigned char)0xff) >> (8-(m-tag_m_8x)));
__mmask8 y_mask = *((__mmask8*) &y_mask_value);

-m0 = _mm512_maskz_loadu_ps(a_mask, &a[tag_m_8x]);
+m0 = _mm512_maskz_loadu_ps(a_mask, &a[tag_m_8x*2]);
m1 = _mm512_mul_ps(_mm512_mul_ps(m0, x1Array), ALPHAVECTOR);
m2 = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), m1);
__m256 ret = _mm256_add_ps(_mm512_extractf32x8_ps(m2, 1), _mm512_extractf32x8_ps(m2, 0));
Expand Down Expand Up @@ -322,7 +322,7 @@ static int sgemv_kernel_t_4(BLASLONG m, float alpha, float *a, float *x, float *
{
BLASLONG tag_m_4x = m & (~3);
BLASLONG tag_m_2x = m & (~1);
-__m512 m0, m1, m2;
+__m512 m0, m1;
__m256 m256_0, m256_1, c256_1, c256_2;
__m128 c1, c2, c3, c4, ret;
__m128 xarray = _mm_maskz_loadu_ps(0x0f, x);
Expand All @@ -346,7 +346,7 @@ static int sgemv_kernel_t_4(BLASLONG m, float alpha, float *a, float *x, float *
c3 = _mm256_extractf32x4_ps(c256_2, 0);
c4 = _mm256_extractf32x4_ps(c256_2, 1);

-ret = _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, c1, c2), _mm_maskz_add_ps(0xff, c3, c4)), _mm_maskz_loadu_ps(0xff, y));
+ret = _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, c1, c2), _mm_maskz_add_ps(0xff, c3, c4)), _mm_maskz_loadu_ps(0xff, &y[idx_m]));
_mm_mask_storeu_ps(&y[idx_m], 0xff, ret);
}

Expand Down Expand Up @@ -958,6 +958,7 @@ static int sgemv_kernel_t_7(BLASLONG m, float alpha, float *a, float *x, float *
c256_1 = _mm512_extractf32x8_ps(tmp0, 1);

c256_0 = _mm256_add_ps(c256_0, c256_1);
+c256_0 = _mm256_mul_ps(c256_0, alpha256);

__m128 c128_0 = _mm256_extractf32x4_ps(c256_0, 0);
__m128 c128_1 = _mm256_extractf32x4_ps(c256_0, 1);
Expand Down Expand Up @@ -1016,9 +1017,10 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
__m512 m0, m1, m2, m3;
__m256 r0, r1, r2, r3, r4, r5, r6, r7, tmp0, tmp1, tmp2, tmp3;
__m128 c128_0, c128_1, c128_2, c128_3;
-__m128 alpha128 = _mm_set1_ps(alpha);
+__m256 alpha256 = _mm256_set1_ps(alpha);

__m256 x256 = _mm256_loadu_ps(x);
+x256 = _mm256_mul_ps(x256, alpha256);
__m512 x512 = _mm512_broadcast_f32x8(x256);

for(BLASLONG idx_m=0; idx_m<tag_m_8x; idx_m+=8) {
Expand Down Expand Up @@ -1053,8 +1055,8 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *

c128_0 = _mm_add_ps(c128_0, c128_1);
c128_2 = _mm_add_ps(c128_2, c128_3);
-_mm_storeu_ps(&y[idx_m], _mm_fmadd_ps(c128_0, alpha128, _mm_loadu_ps(&y[idx_m])));
-_mm_storeu_ps(&y[idx_m+4], _mm_fmadd_ps(c128_2, alpha128, _mm_loadu_ps(&y[idx_m+4])));
+_mm_storeu_ps(&y[idx_m], _mm_add_ps(c128_0, _mm_loadu_ps(&y[idx_m])));
+_mm_storeu_ps(&y[idx_m+4], _mm_add_ps(c128_2, _mm_loadu_ps(&y[idx_m+4])));
}

if (tag_m_8x !=m ){
Expand All @@ -1078,7 +1080,7 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
c128_1 = _mm256_extractf32x4_ps(tmp1, 1);

c128_0 = _mm_add_ps(c128_0, c128_1);
-_mm_storeu_ps(&y[idx_m], _mm_fmadd_ps(c128_0, alpha128, _mm_loadu_ps(&y[idx_m])));
+_mm_storeu_ps(&y[idx_m], _mm_add_ps(c128_0, _mm_loadu_ps(&y[idx_m])));

}

Expand All @@ -1094,7 +1096,6 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
c128_1 = _mm256_extractf32x4_ps(tmp0, 1);

c128_0 = _mm_add_ps(c128_0, c128_1);
-c128_0 = _mm_mul_ps(c128_0, alpha128);

_mm_storeu_ps(ret, c128_0);
y[idx_m] += (ret[0]+ret[1]);
Expand Down