FEAT: sparse-sparse add/sub support
* CPU backend:
    - Uses MKL sparse-sparse add/sub when available
    - Fallback CPU implementation also supports mul and div

* CUDA backend: sparse-sparse add/sub
* OpenCL backend: sparse-sparse add/sub/div/mul

The output of sub/div/mul is not guaranteed to contain only
non-zero results of the arithmetic operation; the user is
responsible for pruning the zero results from the output.
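
A minimal usage sketch (not part of this commit) of how sparse-sparse
arithmetic and the zero-pruning step could look with the public ArrayFire
C++ API; the dense round-trip is only one possible way to drop explicit
zeros, and the variable names are illustrative:

    #include <arrayfire.h>

    int main() {
        // Build two dense matrices with plenty of zeros, then convert to CSR.
        af::array A = af::randu(5, 5);
        A = A * (A > 0.5f);
        af::array B = af::randu(5, 5);
        B = B * (B > 0.5f);

        af::array sA = af::sparse(A, AF_STORAGE_CSR);
        af::array sB = af::sparse(B, AF_STORAGE_CSR);

        af::array sSum  = sA + sB;   // sparse-sparse add (new in this commit)
        af::array sDiff = sA - sB;   // sparse-sparse sub: may hold explicit zeros

        // Prune explicit zeros by round-tripping through a dense array.
        af::array sDiffPruned = af::sparse(af::dense(sDiff), AF_STORAGE_CSR);
        return 0;
    }

The dense round-trip is wasteful for very large matrices; any approach that
rebuilds the CSR arrays without the stored zeros works just as well.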
9prady9 committed Oct 29, 2018
1 parent e3f1466 commit 83b9101
Showing 15 changed files with 706 additions and 65 deletions.
84 changes: 51 additions & 33 deletions src/api/c/binary.cpp
@@ -35,6 +35,14 @@ static inline af_array arithOp(const af_array lhs, const af_array rhs,
return res;
}

template<typename T, af_op_t op>
static inline
af_array sparseArithOp(const af_array lhs, const af_array rhs)
{
auto res = arithOp<T, op>(castSparse<T>(lhs), castSparse<T>(rhs));
return getHandle(res);
}

template<typename T, af_op_t op>
static inline af_array arithSparseDenseOp(const af_array lhs, const af_array rhs,
const bool reverse)
@@ -80,10 +88,11 @@ static af_err af_arith(af_array *out, const af_array lhs, const af_array rhs, co
}

template<af_op_t op>
static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode)
static
af_err af_arith_real(af_array *out, const af_array lhs, const af_array rhs,
const bool batchMode)
{
try {

const ArrayInfo& linfo = getInfo(lhs);
const ArrayInfo& rinfo = getInfo(rhs);

@@ -111,38 +120,39 @@ static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rh
return AF_SUCCESS;
}

//template<af_op_t op>
//static af_err af_arith_sparse(af_array *out, const af_array lhs, const af_array rhs)
//{
// try {
// SparseArrayBase linfo = getSparseArrayBase(lhs);
// SparseArrayBase rinfo = getSparseArrayBase(rhs);
//
// dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode);
//
// const af_dtype otype = implicit(linfo.getType(), rinfo.getType());
// af_array res;
// switch (otype) {
// case f32: res = arithOp<float , op>(lhs, rhs, odims); break;
// case f64: res = arithOp<double , op>(lhs, rhs, odims); break;
// case c32: res = arithOp<cfloat , op>(lhs, rhs, odims); break;
// case c64: res = arithOp<cdouble, op>(lhs, rhs, odims); break;
// default: TYPE_ERROR(0, otype);
// }
//
// std::swap(*out, res);
// }
// CATCHALL;
// return AF_SUCCESS;
//}
template<af_op_t op>
static af_err
af_arith_sparse(af_array *out, const af_array lhs, const af_array rhs)
{
try {
common::SparseArrayBase linfo = getSparseArrayBase(lhs);
common::SparseArrayBase rinfo = getSparseArrayBase(rhs);

ARG_ASSERT(1, (linfo.getStorage()==rinfo.getStorage()));

const af_dtype otype = implicit(linfo.getType(), rinfo.getType());
af_array res;
switch (otype) {
case f32: res = sparseArithOp<float , op>(lhs, rhs); break;
case f64: res = sparseArithOp<double , op>(lhs, rhs); break;
case c32: res = sparseArithOp<cfloat , op>(lhs, rhs); break;
case c64: res = sparseArithOp<cdouble, op>(lhs, rhs); break;
default: TYPE_ERROR(0, otype);
}

std::swap(*out, res);
}
CATCHALL;
return AF_SUCCESS;
}

template<af_op_t op>
static af_err af_arith_sparse_dense(af_array *out, const af_array lhs, const af_array rhs,
const bool reverse = false)
{
using namespace common;
try {
SparseArrayBase linfo = getSparseArrayBase(lhs);
common::SparseArrayBase linfo = getSparseArrayBase(lhs);
ArrayInfo rinfo = getInfo(rhs);

const af_dtype otype = implicit(linfo.getType(), rinfo.getType());
@@ -161,18 +171,20 @@ static af_err af_arith_sparse_dense(af_array *out, const af_array lhs, const af_
return AF_SUCCESS;
}

af_err af_add(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode)
af_err af_add(af_array *out, const af_array lhs, const af_array rhs,
const bool batchMode)
{
// Check if inputs are sparse
ArrayInfo linfo = getInfo(lhs, false, true);
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_add_t>(out, lhs, rhs);
return af_arith_sparse<af_add_t>(out, lhs, rhs);
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_add_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
return af_arith_sparse_dense<af_add_t>(out, rhs, lhs, true); // dense should be rhs
// af_arith_sparse_dense expects the dense array as its second operand, hence the swap
return af_arith_sparse_dense<af_add_t>(out, rhs, lhs, true);
} else {
return af_arith<af_add_t>(out, lhs, rhs, batchMode);
}
@@ -185,7 +197,10 @@ af_err af_mul(af_array *out, const af_array lhs, const af_array rhs, const bool
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_mul_t>(out, lhs, rhs);
//return af_arith_sparse<af_mul_t>(out, lhs, rhs);
//MKL doesn't have sparse-sparse mul or div support yet, so this
//path is disabled even though a fallback CPU implementation exists
return AF_ERR_NOT_SUPPORTED;
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_mul_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
@@ -202,7 +217,7 @@ af_err af_sub(af_array *out, const af_array lhs, const af_array rhs, const bool
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_sub_t>(out, lhs, rhs);
return af_arith_sparse<af_sub_t>(out, lhs, rhs);
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_sub_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
@@ -219,7 +234,10 @@ af_err af_div(af_array *out, const af_array lhs, const af_array rhs, const bool
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_div_t>(out, lhs, rhs);
//return af_arith_sparse<af_div_t>(out, lhs, rhs);
//MKL doesn't have sparse-sparse mul or div support yet, so this
//path is disabled even though a fallback CPU implementation exists
return AF_ERR_NOT_SUPPORTED;
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_div_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
102 changes: 102 additions & 0 deletions src/backend/cpu/kernel/sparse_arith.hpp
@@ -11,6 +11,9 @@
#include <Param.hpp>
#include <math.hpp>

#include <cmath>
#include <vector>

namespace cpu
{
namespace kernel
@@ -141,5 +144,104 @@ void sparseArithOpS(Param<T> values, Param<int> rowIdx, Param<int> colIdx,
}
}

// The following functions can handle CSR
// storage format only as of now.
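// First pass (calcOutNNZ): walk each row's lhs/rhs column lists with two
// pointers, count the non-zeros in the union of the two sparsity patterns,
// and write the running (exclusive) prefix sum into outRowIdx. The final
// entry outRowIdx[M] holds the total output NNZ, i.e. the usual CSR row
// offsets.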
static
void calcOutNNZ(Param<int> outRowIdx,
const uint M, const uint N,
CParam<int> lRowIdx, CParam<int> lColIdx,
CParam<int> rRowIdx, CParam<int> rColIdx)
{
int *orPtr = outRowIdx.get();
const int *lrPtr = lRowIdx.get();
const int *lcPtr = lColIdx.get();
const int *rrPtr = rRowIdx.get();
const int *rcPtr = rColIdx.get();

unsigned csrOutCount = 0;
for (uint row=0; row<M; ++row) {
const int lEnd = lrPtr[row+1];
const int rEnd = rrPtr[row+1];

uint rowNNZ = 0;
int l = lrPtr[row];
int r = rrPtr[row];
while (l < lEnd && r < rEnd) {
int lci = lcPtr[l];
int rci = rcPtr[r];

l += (lci <= rci);
r += (lci >= rci);
rowNNZ++;
}
// Elements from lhs or rhs are exhausted.
// Just count left over elements
rowNNZ += (lEnd-l);
rowNNZ += (rEnd-r);

orPtr[row] = csrOutCount;
csrOutCount += rowNNZ;
}
//Write out the Rows+1 entry
orPtr[M] = csrOutCount;
}

template<typename T, af_op_t op>
void sparseArithOp(Param<T> oVals, Param<int> oColIdx,
CParam<int> oRowIdx, const uint Rows,
CParam<T> lvals, CParam<int> lRowIdx, CParam<int> lColIdx,
CParam<T> rvals, CParam<int> rRowIdx, CParam<int> rColIdx)
{
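// Second pass: merge each row's lhs/rhs entries in column order, using the
// row offsets produced by calcOutNNZ. A column index present on only one
// side is paired with an implicit zero operand, which is why sub/div/mul
// can leave explicit zeros in the output values.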
const int *orPtr = oRowIdx.get();
const T *lvPtr = lvals.get();
const int *lrPtr = lRowIdx.get();
const int *lcPtr = lColIdx.get();
const T *rvPtr = rvals.get();
const int *rrPtr = rRowIdx.get();
const int *rcPtr = rColIdx.get();

arith_op<T, op> binOp;

auto ZERO = scalar<T>(0);

for (uint row=0; row<Rows; ++row) {
const int lEnd = lrPtr[row+1];
const int rEnd = rrPtr[row+1];
const int offs = orPtr[row];

T *ovPtr = oVals.get() + offs;
int *ocPtr = oColIdx.get() + offs;

uint rowNNZ = 0;
int l = lrPtr[row];
int r = rrPtr[row];
while (l < lEnd && r < rEnd) {
int lci = lcPtr[l];
int rci = rcPtr[r];

T lhs = (lci <= rci ? lvPtr[l] : ZERO);
T rhs = (lci >= rci ? rvPtr[r] : ZERO);

ovPtr[ rowNNZ ] = binOp(lhs, rhs);
ocPtr[ rowNNZ ] = (lci <= rci) ? lci : rci;

l += (lci <= rci);
r += (lci >= rci);
rowNNZ++;
}
while (l < lEnd) {
ovPtr[ rowNNZ ] = binOp(lvPtr[l], ZERO);
ocPtr[ rowNNZ ] = lcPtr[l];
l++;
rowNNZ++;
}
while (r < rEnd) {
ovPtr[ rowNNZ ] = binOp(ZERO, rvPtr[r]);
ocPtr[ rowNNZ ] = rcPtr[r];
r++;
rowNNZ++;
}
}
}
}
}
26 changes: 22 additions & 4 deletions src/backend/cpu/mkl_sparse_interface.cpp
@@ -32,13 +32,9 @@ template<> const double getScaleValue<const double, double>(double val) { return
{ return &mkl_sparse_##PREFIX##_##FUNC; }

SPARSE_FUNC(create_csr , float , s)
SPARSE_FUNC(create_coo , float , s)
SPARSE_FUNC(create_csr , double , d)
SPARSE_FUNC(create_coo , double , d)
SPARSE_FUNC(create_csr , cfloat , c)
SPARSE_FUNC(create_coo , cfloat , c)
SPARSE_FUNC(create_csr , cdouble , z)
SPARSE_FUNC(create_coo , cdouble , z)

SPARSE_FUNC(export_csr , float , s)
SPARSE_FUNC(export_csr , double , d)
@@ -76,5 +72,27 @@ const sp_cdouble getScaleValue<const sp_cdouble, cdouble>(cdouble val)
return ret;
}

const char* errorString(sparse_status_t code)
{
switch (code) {
case SPARSE_STATUS_NOT_INITIALIZED:
return "The routine encountered an empty handle or matrix array.";
case SPARSE_STATUS_ALLOC_FAILED:
return "Internal memory allocation failed.";
case SPARSE_STATUS_INVALID_VALUE:
return "The input parameters contain an invalid value.";
case SPARSE_STATUS_EXECUTION_FAILED:
return "Execution failed.";
case SPARSE_STATUS_INTERNAL_ERROR:
return "An error in algorithm implementation occurred.";
case SPARSE_STATUS_NOT_SUPPORTED:
return "The requested operation is not supported.";
case SPARSE_STATUS_SUCCESS:
return "The operation was successful.";
default:
return "Unkown error";
}
}

#endif
}
25 changes: 17 additions & 8 deletions src/backend/cpu/mkl_sparse_interface.hpp
@@ -51,13 +51,6 @@ using create_csr_func_def = sparse_status_t (*)
int *, int *, int*,
ptr_type<T>);
template<typename T>
using create_coo_func_def = sparse_status_t (*)
(sparse_matrix_t *,
sparse_index_base_t,
int, int, int,
int *, int *,
ptr_type<T>);
template<typename T>
using export_csr_func_def = sparse_status_t (*)
(sparse_matrix_t,
sparse_index_base_t*,
@@ -69,7 +62,6 @@ using export_csr_func_def = sparse_status_t (*)
template<typename T> FUNC##_func_def<T> FUNC##_func();

SPARSE_FUNC_DEF( create_csr )
SPARSE_FUNC_DEF( create_coo )
SPARSE_FUNC_DEF( export_csr )

#undef SPARSE_FUNC_DEF
@@ -90,6 +82,23 @@ template<typename T> FMT##_add_def<T> FMT##_add();

SPARSE_OP_FUNC_DEF(csr)

const char* errorString(sparse_status_t code);
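// errorString() maps an MKL sparse_status_t to a readable message; the
// MKL_SPARSE_CHECK macro below raises any non-success status as an ArrayFire
// internal error carrying that message.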

#define MKL_SPARSE_CHECK(fn) \
do { \
sparse_status_t _error = fn; \
if (_error != SPARSE_STATUS_SUCCESS) { \
char _err_msg[1024]; \
snprintf(_err_msg, \
sizeof(_err_msg), \
"MKL Sparse Error (%d): %s\n", \
(int)(_error), \
errorString(_error)); \
AF_ERROR(_err_msg, \
AF_ERR_INTERNAL); \
} \
} while(0)

#else // USE_MKL

// From mkl_spblas.h