FEAT: sparse-sparse add/sub support
* CPU backend:
    - Use MKL sparse-sparse add/sub when available
    - Fallback CPU implementation also supports mul and div

* CUDA backend: sparse-sparse add/sub
* OpenCL backend: sparse-sparse add/sub/mul/div

The output of sub/div/mul is not guaranteed to contain only
non-zero results of the arithmetic operation. The caller is
responsible for pruning the zero entries from the output
(see the sketch below).
9prady9 committed Oct 17, 2018
1 parent ba73e93 commit cbf1c41
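
As an illustration of the pruning caveat above, the sketch below (not part of the commit) subtracts two CSR matrices with the ArrayFire C++ API and then drops explicit zeros by round-tripping through a dense array. The dense round trip is just one possible pruning strategy, and the values and variable names are made up for the example.

// Sketch only: one way to prune explicit zeros after sparse-sparse subtraction.
// Assumes the public ArrayFire C++ API (af::sparse, af::dense, and operator- on
// sparse arrays as enabled by this commit).
#include <arrayfire.h>

int main() {
    // Two 3x3 matrices whose (0,0) entries cancel under subtraction.
    float hA[] = {1, 0, 0,  0, 2, 0,  0, 0, 3};
    float hB[] = {1, 0, 0,  0, 5, 0,  0, 0, 0};
    af::array A = af::sparse(af::array(3, 3, hA), AF_STORAGE_CSR);
    af::array B = af::sparse(af::array(3, 3, hB), AF_STORAGE_CSR);

    // Sparse-sparse subtraction: the result can keep an explicit zero where
    // 1 - 1 == 0, because the backends do not prune zero results.
    af::array C = A - B;

    // Prune explicit zeros by converting to dense and re-sparsifying;
    // af::sparse() only stores the non-zero entries of a dense input.
    af::array Cpruned = af::sparse(af::dense(C), AF_STORAGE_CSR);

    af_print(af::dense(Cpruned));
    return 0;
}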
Showing 12 changed files with 689 additions and 47 deletions.
84 changes: 51 additions & 33 deletions src/api/c/binary.cpp
@@ -35,6 +35,14 @@ static inline af_array arithOp(const af_array lhs, const af_array rhs,
return res;
}

template<typename T, af_op_t op>
static inline
af_array sparseArithOp(const af_array lhs, const af_array rhs)
{
auto res = arithOp<T, op>(castSparse<T>(lhs), castSparse<T>(rhs));
return getHandle(res);
}

template<typename T, af_op_t op>
static inline af_array arithSparseDenseOp(const af_array lhs, const af_array rhs,
const bool reverse)
@@ -80,10 +88,11 @@ static af_err af_arith(af_array *out, const af_array lhs, const af_array rhs, co
}

template<af_op_t op>
static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode)
static
af_err af_arith_real(af_array *out, const af_array lhs, const af_array rhs,
const bool batchMode)
{
try {

const ArrayInfo& linfo = getInfo(lhs);
const ArrayInfo& rinfo = getInfo(rhs);

@@ -111,38 +120,39 @@ static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rh
return AF_SUCCESS;
}

//template<af_op_t op>
//static af_err af_arith_sparse(af_array *out, const af_array lhs, const af_array rhs)
//{
// try {
// SparseArrayBase linfo = getSparseArrayBase(lhs);
// SparseArrayBase rinfo = getSparseArrayBase(rhs);
//
// dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode);
//
// const af_dtype otype = implicit(linfo.getType(), rinfo.getType());
// af_array res;
// switch (otype) {
// case f32: res = arithOp<float , op>(lhs, rhs, odims); break;
// case f64: res = arithOp<double , op>(lhs, rhs, odims); break;
// case c32: res = arithOp<cfloat , op>(lhs, rhs, odims); break;
// case c64: res = arithOp<cdouble, op>(lhs, rhs, odims); break;
// default: TYPE_ERROR(0, otype);
// }
//
// std::swap(*out, res);
// }
// CATCHALL;
// return AF_SUCCESS;
//}
template<af_op_t op>
static af_err
af_arith_sparse(af_array *out, const af_array lhs, const af_array rhs)
{
try {
common::SparseArrayBase linfo = getSparseArrayBase(lhs);
common::SparseArrayBase rinfo = getSparseArrayBase(rhs);

ARG_ASSERT(1, (linfo.getStorage()==rinfo.getStorage()));

const af_dtype otype = implicit(linfo.getType(), rinfo.getType());
af_array res;
switch (otype) {
case f32: res = sparseArithOp<float , op>(lhs, rhs); break;
case f64: res = sparseArithOp<double , op>(lhs, rhs); break;
case c32: res = sparseArithOp<cfloat , op>(lhs, rhs); break;
case c64: res = sparseArithOp<cdouble, op>(lhs, rhs); break;
default: TYPE_ERROR(0, otype);
}

std::swap(*out, res);
}
CATCHALL;
return AF_SUCCESS;
}

template<af_op_t op>
static af_err af_arith_sparse_dense(af_array *out, const af_array lhs, const af_array rhs,
const bool reverse = false)
{
using namespace common;
try {
SparseArrayBase linfo = getSparseArrayBase(lhs);
common::SparseArrayBase linfo = getSparseArrayBase(lhs);
ArrayInfo rinfo = getInfo(rhs);

const af_dtype otype = implicit(linfo.getType(), rinfo.getType());
@@ -161,18 +171,20 @@ static af_err af_arith_sparse_dense(af_array *out, const af_array lhs, const af_
return AF_SUCCESS;
}

af_err af_add(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode)
af_err af_add(af_array *out, const af_array lhs, const af_array rhs,
const bool batchMode)
{
// Check if inputs are sparse
ArrayInfo linfo = getInfo(lhs, false, true);
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_add_t>(out, lhs, rhs);
return af_arith_sparse<af_add_t>(out, lhs, rhs);
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_add_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
return af_arith_sparse_dense<af_add_t>(out, rhs, lhs, true); // dense should be rhs
// second operand(Array) of af_arith call should be dense
return af_arith_sparse_dense<af_add_t>(out, rhs, lhs, true);
} else {
return af_arith<af_add_t>(out, lhs, rhs, batchMode);
}
@@ -185,7 +197,10 @@ af_err af_mul(af_array *out, const af_array lhs, const af_array rhs, const bool
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_mul_t>(out, lhs, rhs);
//return af_arith_sparse<af_mul_t>(out, lhs, rhs);
//MKL doesn't have mul or div support yet, hence
//this is commented out although alternative cpu code exists
return AF_ERR_NOT_SUPPORTED;
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_mul_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
@@ -202,7 +217,7 @@ af_err af_sub(af_array *out, const af_array lhs, const af_array rhs, const bool
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_sub_t>(out, lhs, rhs);
return af_arith_sparse<af_sub_t>(out, lhs, rhs);
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_sub_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
@@ -219,7 +234,10 @@ af_err af_div(af_array *out, const af_array lhs, const af_array rhs, const bool
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_div_t>(out, lhs, rhs);
//return af_arith_sparse<af_div_t>(out, lhs, rhs);
//MKL doesn't have mul or div support yet, hence
//this is commented out although alternative cpu code exists
return AF_ERR_NOT_SUPPORTED;
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_div_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
103 changes: 103 additions & 0 deletions src/backend/cpu/kernel/sparse_arith.hpp
@@ -11,6 +11,9 @@
#include <Param.hpp>
#include <math.hpp>

#include <cmath>
#include <vector>

namespace cpu
{
namespace kernel
@@ -141,5 +144,105 @@ void sparseArithOpS(Param<T> values, Param<int> rowIdx, Param<int> colIdx,
}
}

// The following functions can handle CSR
// storage format only as of now.
static
void calcOutNNZ(Param<int> outRowIdx, unsigned &nnzC,
const uint M, const uint N,
CParam<int> lRowIdx, CParam<int> lColIdx,
CParam<int> rRowIdx, CParam<int> rColIdx)
{
int *orPtr = outRowIdx.get();
const int *lrPtr = lRowIdx.get();
const int *lcPtr = lColIdx.get();
const int *rrPtr = rRowIdx.get();
const int *rcPtr = rColIdx.get();

unsigned csrOutCount = 0;
for (uint row=0; row<M; ++row) {
const int lEnd = lrPtr[row+1];
const int rEnd = rrPtr[row+1];

uint rowNNZ = 0;
int l = lrPtr[row];
int r = rrPtr[row];
while (l < lEnd && r < rEnd) {
int lci = lcPtr[l];
int rci = rcPtr[r];

l += (lci <= rci);
r += (lci >= rci);
rowNNZ++;
}
// Elements from lhs or rhs are exhausted.
// Just count left over elements
rowNNZ += (lEnd-l);
rowNNZ += (rEnd-r);

orPtr[row] = csrOutCount;
csrOutCount += rowNNZ;
}
//Write out the Rows+1 entry
orPtr[M] = csrOutCount;
nnzC = csrOutCount;
}

template<typename T, af_op_t op>
void sparseArithOp(Param<T> oVals, Param<int> oColIdx,
CParam<int> oRowIdx, const uint Rows,
CParam<T> lvals, CParam<int> lRowIdx, CParam<int> lColIdx,
CParam<T> rvals, CParam<int> rRowIdx, CParam<int> rColIdx)
{
const int *orPtr = oRowIdx.get();
const T *lvPtr = lvals.get();
const int *lrPtr = lRowIdx.get();
const int *lcPtr = lColIdx.get();
const T *rvPtr = rvals.get();
const int *rrPtr = rRowIdx.get();
const int *rcPtr = rColIdx.get();

arith_op<T, op> binOp;

auto ZERO = scalar<T>(0);

for (uint row=0; row<Rows; ++row) {
const int lEnd = lrPtr[row+1];
const int rEnd = rrPtr[row+1];
const int offs = orPtr[row];

T *ovPtr = oVals.get() + offs;
int *ocPtr = oColIdx.get() + offs;

uint rowNNZ = 0;
int l = lrPtr[row];
int r = rrPtr[row];
while (l < lEnd && r < rEnd) {
int lci = lcPtr[l];
int rci = rcPtr[r];

T lhs = (lci <= rci ? lvPtr[l] : ZERO);
T rhs = (lci >= rci ? rvPtr[r] : ZERO);

ovPtr[ rowNNZ ] = binOp(lhs, rhs);
ocPtr[ rowNNZ ] = (lci <= rci) ? lci : rci;

l += (lci <= rci);
r += (lci >= rci);
rowNNZ++;
}
while (l < lEnd) {
ovPtr[ rowNNZ ] = binOp(lvPtr[l], ZERO);
ocPtr[ rowNNZ ] = lcPtr[l];
l++;
rowNNZ++;
}
while (r < rEnd) {
ovPtr[ rowNNZ ] = binOp(ZERO, rvPtr[r]);
ocPtr[ rowNNZ ] = rcPtr[r];
r++;
rowNNZ++;
}
}
}
}
}
