Skip to content

Commit

Permalink
fix sve dtrsm kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
binebrank committed Jan 15, 2022
1 parent 8071e17 commit aaa2b1a
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 80 deletions.
20 changes: 11 additions & 9 deletions kernel/arm64/trsm_kernel_LN_sve.c
Expand Up @@ -182,8 +182,8 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1,

i = m % sve_size;
if (i) {
aa = a + ((m & ~(i - 1)) - i) * k * COMPSIZE;
cc = c + ((m & ~(i - 1)) - i) * COMPSIZE;
aa = a + (m - i) * k * COMPSIZE;
cc = c + (m - i) * COMPSIZE;

if (k - kk > 0) {
GEMM_KERNEL(i, GEMM_UNROLL_N, k - kk, dm1,
Expand All @@ -205,10 +205,11 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1,

}

int mod = i;
i = sve_size;
if (i <= m) {
aa = a + ((m & ~(sve_size - 1)) - sve_size) * k * COMPSIZE;
cc = c + ((m & ~(sve_size - 1)) - sve_size) * COMPSIZE;
aa = a + (m - mod - sve_size) * k * COMPSIZE;
cc = c + (m - mod - sve_size) * COMPSIZE;

do {
if (k - kk > 0) {
Expand All @@ -217,7 +218,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1,
ZERO,
#endif
aa + sve_size * kk * COMPSIZE,
b + sve_size * kk * COMPSIZE,
b + GEMM_UNROLL_N * kk * COMPSIZE,
cc,
ldc);
}
Expand Down Expand Up @@ -251,8 +252,8 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1,

i = m % sve_size;
if (i) {
aa = a + ((m & ~(i - 1)) - i) * k * COMPSIZE;
cc = c + ((m & ~(i - 1)) - i) * COMPSIZE;
aa = a + (m - i) * k * COMPSIZE;
cc = c + (m - i) * COMPSIZE;

if (k - kk > 0) {
GEMM_KERNEL(i, j, k - kk, dm1,
Expand All @@ -273,10 +274,11 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1,

}

int mod = i;
i = sve_size;
if (i <= m) {
aa = a + ((m & ~(sve_size - 1)) - sve_size) * k * COMPSIZE;
cc = c + ((m & ~(sve_size - 1)) - sve_size) * COMPSIZE;
aa = a + (m - mod - sve_size) * k * COMPSIZE;
cc = c + (m - mod - sve_size) * COMPSIZE;

do {
if (k - kk > 0) {
Expand Down
2 changes: 1 addition & 1 deletion kernel/arm64/trsm_kernel_LT_sve.c
Expand Up @@ -257,7 +257,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1,
i += sve_size;
}

i = sve_size % m;
i = m % sve_size;
if (i) {
if (kk > 0) {
GEMM_KERNEL(i, j, kk, dm1,
Expand Down
12 changes: 6 additions & 6 deletions kernel/arm64/trsm_kernel_RT_sve.c
Expand Up @@ -258,23 +258,23 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1,
if (i <= m) {
do {
if (k - kk > 0) {
GEMM_KERNEL(GEMM_UNROLL_M, GEMM_UNROLL_N, k - kk, dm1,
GEMM_KERNEL(sve_size, GEMM_UNROLL_N, k - kk, dm1,
#ifdef COMPLEX
ZERO,
#endif
aa + GEMM_UNROLL_M * kk * COMPSIZE,
aa + sve_size * kk * COMPSIZE,
b + GEMM_UNROLL_N * kk * COMPSIZE,
cc,
ldc);
}

solve(GEMM_UNROLL_M, GEMM_UNROLL_N,
aa + (kk - GEMM_UNROLL_N) * GEMM_UNROLL_M * COMPSIZE,
solve(sve_size, GEMM_UNROLL_N,
aa + (kk - GEMM_UNROLL_N) * sve_size * COMPSIZE,
b + (kk - GEMM_UNROLL_N) * GEMM_UNROLL_N * COMPSIZE,
cc, ldc);

aa += GEMM_UNROLL_M * k * COMPSIZE;
cc += GEMM_UNROLL_M * COMPSIZE;
aa += sve_size * k * COMPSIZE;
cc += sve_size * COMPSIZE;
i += sve_size;
} while (i <= m);
}
Expand Down
30 changes: 15 additions & 15 deletions kernel/arm64/trsm_lncopy_sve.c
Expand Up @@ -48,17 +48,18 @@

int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){

BLASLONG i, ii, j, jj;
BLASLONG i, ii, jj;

FLOAT *ao;

jj = offset;
int js = 0;
#ifdef DOUBLE
int64_t js = 0;
svint64_t index = svindex_s64(0LL, lda);
svbool_t pn = svwhilelt_b64(js, n);
int n_active = svcntp_b64(svptrue_b64(), pn);
#else
int32_t js = 0;
svint32_t index = svindex_s32(0, lda);
svbool_t pn = svwhilelt_b32(js, n);
int n_active = svcntp_b32(svptrue_b32(), pn);
Expand All @@ -74,25 +75,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT
if (ii == jj) {
for (int j = 0; j < n_active; j++) {
for (int k = 0; k < j; k++) {
*(b + j * n_active + k) = *(a + k * lda + j);
*(b + j * n_active + k) = *(ao + k * lda + j);
}
*(b + j * n_active + j) = INV(*(a + j * lda + j));
*(b + j * n_active + j) = INV(*(ao + j * lda + j));
}
}

if (ii > jj) {
for (int j = 0; j < n_active; j++) {
ao += n_active;
b += n_active * n_active;
i += n_active;
ii += n_active;
} else {
if (ii > jj) {
svfloat64_t aj_vec = svld1_gather_index(pn, ao, index);
svst1(pn, b, aj_vec);
ao++;
}

ao++;
b += n_active;
i++;
ii++;
}

b += n_active * n_active;

i += n_active;
ii += n_active;
} while (i < m);


Expand Down
32 changes: 15 additions & 17 deletions kernel/arm64/trsm_ltcopy_sve.c
Expand Up @@ -48,18 +48,17 @@

int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){

BLASLONG i, ii, j, jj;
BLASLONG i, ii, jj;

FLOAT *ao;

jj = offset;
int js = 0;
#ifdef DOUBLE
svint64_t index = svindex_s64(0LL, lda);
int64_t js = 0;
svbool_t pn = svwhilelt_b64(js, n);
int n_active = svcntp_b64(svptrue_b64(), pn);
#else
svint32_t index = svindex_s32(0, lda);
int32_t js = 0;
svbool_t pn = svwhilelt_b32(js, n);
int n_active = svcntp_b32(svptrue_b32(), pn);
#endif
Expand All @@ -73,26 +72,25 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT

if (ii == jj) {
for (int j = 0; j < n_active; j++) {
*(b + j * n_active + j) = INV(*(a + j * lda + j));
*(b + j * n_active + j) = INV(*(ao + j * lda + j));
for (int k = j+1; k < n_active; k++) {
*(b + j * n_active + k) = *(a + j * lda + k);
*(b + j * n_active + k) = *(ao + j * lda + k);
}
}
}

if (ii < jj) {
for (int j = 0; j < n_active; j++) {
b += n_active * n_active;
ao += lda * n_active;
i += n_active;
ii += n_active;
} else {
if (ii < jj) {
svfloat64_t aj_vec = svld1(pn, ao);
svst1(pn, b, aj_vec);
ao += lda;
}

ao += lda;
b += n_active;
i ++;
ii ++;
}

b += n_active * n_active;

i += n_active;
ii += n_active;
} while (i < m);


Expand Down
29 changes: 15 additions & 14 deletions kernel/arm64/trsm_uncopy_sve.c
Expand Up @@ -48,17 +48,18 @@

int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){

BLASLONG i, ii, j, jj;
BLASLONG i, ii, jj;

FLOAT *ao;

jj = offset;
int js = 0;
#ifdef DOUBLE
int64_t js = 0;
svint64_t index = svindex_s64(0LL, lda);
svbool_t pn = svwhilelt_b64(js, n);
int n_active = svcntp_b64(svptrue_b64(), pn);
#else
int32_t js = 0;
svint32_t index = svindex_s32(0, lda);
svbool_t pn = svwhilelt_b32(js, n);
int n_active = svcntp_b32(svptrue_b32(), pn);
Expand All @@ -73,25 +74,25 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT

if (ii == jj) {
for (int j = 0; j < n_active; j++) {
*(b + j * n_active + j) = INV(*(a + j * lda + j));
*(b + j * n_active + j) = INV(*(ao + j * lda + j));
for (int k = j+1; k < n_active; k++) {
*(b + j * n_active + k) = *(a + k * lda + j);
*(b + j * n_active + k) = *(ao + k * lda + j);
}
}
}

if (ii < jj) {
for (int j = 0; j < n_active; j++) {
ao += n_active;
b += n_active * n_active;
i += n_active;
ii += n_active;
} else {
if (ii < jj) {
svfloat64_t aj_vec = svld1_gather_index(pn, ao, index);
svst1(pn, b, aj_vec);
ao++;
}
ao++;
b += n_active;
i++;
ii++;
}

b += n_active * n_active;

i += n_active;
ii += n_active;
} while (i < m);


Expand Down
34 changes: 16 additions & 18 deletions kernel/arm64/trsm_utcopy_sve.c
Expand Up @@ -48,18 +48,17 @@

int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){

BLASLONG i, ii, j, jj;
BLASLONG i, ii, jj;

FLOAT *ao;

jj = offset;
int js = 0;
#ifdef DOUBLE
svint64_t index = svindex_s64(0LL, lda);
int64_t js = 0;
svbool_t pn = svwhilelt_b64(js, n);
int n_active = svcntp_b64(svptrue_b64(), pn);
#else
svint32_t index = svindex_s32(0, lda);
int32_t js = 0;
svbool_t pn = svwhilelt_b32(js, n);
int n_active = svcntp_b32(svptrue_b32(), pn);
#endif
Expand All @@ -74,25 +73,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT
if (ii == jj) {
for (int j = 0; j < n_active; j++) {
for (int k = 0; k < j; k++) {
*(b + j * n_active + k) = *(a + j * lda + k);
*(b + j * n_active + k) = *(ao + j * lda + k);
}
*(b + j * n_active + j) = INV(*(a + j * lda + j));
*(b + j * n_active + j) = INV(*(ao + j * lda + j));
}
}

if (ii > jj) {
for (int j = 0; j < n_active; j++) {
ao += lda * n_active;
b += n_active * n_active;
i += n_active;
ii += n_active;
} else {
if (ii > jj) {
svfloat64_t aj_vec = svld1(pn, ao);
svst1(pn, b, aj_vec);
ao += lda;
}

}

b += n_active * n_active;

i += n_active;
ii += n_active;
ao += lda;
b += n_active;
i ++;
ii ++;
}
} while (i < m);


Expand Down

0 comments on commit aaa2b1a

Please sign in to comment.