Skip to content

Commit

Permalink
[SYSTEMDS-3033] Fix native BLAS tsmm right (nnz compute, lda param)
Browse files Browse the repository at this point in the history
This patch fixes issues with the rather uncommon native BLAS tsmm-right
(e.g., in dist()), where invalid parameters led to incorrect results and
corrupted nnz (and thus index out of bounds exceptions on dense to
sparse conversion). Furthermore, the nnz computation ran over the size
of the allocated output array, causing occasional segmentation faults.
This patch fixes the issues, which also makes the common tsmm-left
slightly faster (less nnz computation).
  • Loading branch information
mboehm7 committed Jun 18, 2021
1 parent 87fd786 commit effb11b
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 21 deletions.
Binary file modified src/main/cpp/lib/libsystemds_mkl-Linux-x86_64.so
Binary file not shown.
Binary file modified src/main/cpp/lib/libsystemds_openblas-Linux-x86_64.so
Binary file not shown.
2 changes: 1 addition & 1 deletion src/main/cpp/libmatrixmult.cpp
Expand Up @@ -54,6 +54,6 @@ void tsmm(double *m1Ptr, double *retPtr, int m1rlen, int m1clen, bool leftTrans,
int n = leftTrans ? m1clen : m1rlen;
int k = leftTrans ? m1rlen : m1clen;
cblas_dsyrk(CblasRowMajor, CblasUpper,
leftTrans ? CblasTrans : CblasNoTrans, n, k, 1, m1Ptr, n, 0, retPtr, n);
leftTrans ? CblasTrans : CblasNoTrans, n, k, 1, m1Ptr, m1clen, 0, retPtr, n);
}
}
42 changes: 22 additions & 20 deletions src/main/cpp/systemds.cpp
Expand Up @@ -95,10 +95,12 @@ JNIEXPORT jlong JNICALL Java_org_apache_sysds_utils_NativeHelper_tsmm
double* m1Ptr = GET_DOUBLE_ARRAY(env, m1, numThreads);
double* retPtr = GET_DOUBLE_ARRAY(env, ret, numThreads);
if(m1Ptr == NULL || retPtr == NULL)
return -1;
return -1;

tsmm(m1Ptr, retPtr, (int)m1rlen, (int)m1clen, (bool)leftTrans, (int)numThreads);
size_t nnz = computeNNZ<double>(retPtr, m1rlen * m1clen);

int n = leftTrans ? m1clen : m1rlen;
size_t nnz = computeNNZ<double>(retPtr, n * n);

RELEASE_INPUT_ARRAY(env, m1, m1Ptr, numThreads);
RELEASE_ARRAY(env, ret, retPtr, numThreads);
Expand Down Expand Up @@ -201,47 +203,47 @@ JNIEXPORT jlong JNICALL Java_org_apache_sysds_utils_NativeHelper_sconv2dBiasAddD
return -1;

size_t nnz = sconv2dBiasAddDense(inputPtr, biasPtr, filterPtr, retPtr, (int) N, (int) C, (int) H, (int) W, (int) K,
(int) R, (int) S, (int) stride_h, (int) stride_w, (int) pad_h, (int) pad_w, (int) P,
(int) R, (int) S, (int) stride_h, (int) stride_w, (int) pad_h, (int) pad_w, (int) P,
(int) Q, true, (int) numThreads);

return static_cast<jlong>(nnz);
}

JNIEXPORT jlong JNICALL Java_org_apache_sysds_utils_NativeHelper_conv2dBackwardDataDense(
JNIEnv* env, jclass, jdoubleArray filter, jdoubleArray dout,
jdoubleArray ret, jint N, jint C, jint H, jint W, jint K, jint R, jint S,
jint stride_h, jint stride_w, jint pad_h, jint pad_w, jint P, jint Q, jint numThreads) {
JNIEnv* env, jclass, jdoubleArray filter, jdoubleArray dout,
jdoubleArray ret, jint N, jint C, jint H, jint W, jint K, jint R, jint S,
jint stride_h, jint stride_w, jint pad_h, jint pad_w, jint P, jint Q, jint numThreads) {

double* filterPtr = GET_DOUBLE_ARRAY(env, filter, numThreads);
double* doutPtr = GET_DOUBLE_ARRAY(env, dout, numThreads);
double* retPtr = GET_DOUBLE_ARRAY(env, ret, numThreads);
if(doutPtr == NULL || filterPtr == NULL || retPtr == NULL)
return -1;
return -1;

size_t nnz = conv2dBackwardDataDense(filterPtr, doutPtr, retPtr, (int) N, (int) C, (int) H, (int) W, (int) K,
(int) R, (int) S, (int) stride_h, (int) stride_w, (int) pad_h, (int) pad_w,
(int) P, (int) Q, (int) numThreads);
(int) R, (int) S, (int) stride_h, (int) stride_w, (int) pad_h, (int) pad_w,
(int) P, (int) Q, (int) numThreads);

RELEASE_INPUT_ARRAY(env, filter, filterPtr, numThreads);
RELEASE_INPUT_ARRAY(env, dout, doutPtr, numThreads);
RELEASE_ARRAY(env, ret, retPtr, numThreads);
return static_cast<jlong>(nnz);
}

JNIEXPORT jlong JNICALL Java_org_apache_sysds_utils_NativeHelper_conv2dBackwardFilterDense(
JNIEnv* env, jclass, jdoubleArray input, jdoubleArray dout,
jdoubleArray ret, jint N, jint C, jint H, jint W, jint K, jint R, jint S,
jint stride_h, jint stride_w, jint pad_h, jint pad_w, jint P, jint Q, jint numThreads) {
JNIEnv* env, jclass, jdoubleArray input, jdoubleArray dout,
jdoubleArray ret, jint N, jint C, jint H, jint W, jint K, jint R, jint S,
jint stride_h, jint stride_w, jint pad_h, jint pad_w, jint P, jint Q, jint numThreads) {
double* inputPtr = GET_DOUBLE_ARRAY(env, input, numThreads);
double* doutPtr = GET_DOUBLE_ARRAY(env, dout, numThreads);
double* retPtr = GET_DOUBLE_ARRAY(env, ret, numThreads);
if(doutPtr == NULL || inputPtr == NULL || retPtr == NULL)
return -1;
return -1;

size_t nnz = conv2dBackwardFilterDense(inputPtr, doutPtr, retPtr, (int)N, (int) C, (int) H, (int) W, (int) K, (int) R,
(int) S, (int) stride_h, (int) stride_w, (int) pad_h, (int) pad_w, (int) P,
(int) Q, (int) numThreads);
(int) S, (int) stride_h, (int) stride_w, (int) pad_h, (int) pad_w, (int) P,
(int) Q, (int) numThreads);

RELEASE_INPUT_ARRAY(env, input, inputPtr, numThreads);
RELEASE_INPUT_ARRAY(env, dout, doutPtr, numThreads);
RELEASE_ARRAY(env, ret, retPtr, numThreads);
Expand Down

0 comments on commit effb11b

Please sign in to comment.