Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: sparse-sparse add/sub support #2312

Merged
merged 4 commits into from Dec 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
86 changes: 53 additions & 33 deletions src/api/c/binary.cpp
Expand Up @@ -35,6 +35,14 @@ static inline af_array arithOp(const af_array lhs, const af_array rhs,
return res;
}

template<typename T, af_op_t op>
static inline
af_array sparseArithOp(const af_array lhs, const af_array rhs)
{
auto res = arithOp<T, op>(getSparseArray<T>(lhs), getSparseArray<T>(rhs));
return getHandle(res);
}

template<typename T, af_op_t op>
static inline af_array arithSparseDenseOp(const af_array lhs, const af_array rhs,
const bool reverse)
Expand Down Expand Up @@ -80,10 +88,11 @@ static af_err af_arith(af_array *out, const af_array lhs, const af_array rhs, co
}

template<af_op_t op>
static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode)
static
af_err af_arith_real(af_array *out, const af_array lhs, const af_array rhs,
const bool batchMode)
{
try {

const ArrayInfo& linfo = getInfo(lhs);
const ArrayInfo& rinfo = getInfo(rhs);

Expand Down Expand Up @@ -111,38 +120,41 @@ static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rh
return AF_SUCCESS;
}

//template<af_op_t op>
//static af_err af_arith_sparse(af_array *out, const af_array lhs, const af_array rhs)
//{
// try {
// SparseArrayBase linfo = getSparseArrayBase(lhs);
// SparseArrayBase rinfo = getSparseArrayBase(rhs);
//
// dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode);
//
// const af_dtype otype = implicit(linfo.getType(), rinfo.getType());
// af_array res;
// switch (otype) {
// case f32: res = arithOp<float , op>(lhs, rhs, odims); break;
// case f64: res = arithOp<double , op>(lhs, rhs, odims); break;
// case c32: res = arithOp<cfloat , op>(lhs, rhs, odims); break;
// case c64: res = arithOp<cdouble, op>(lhs, rhs, odims); break;
// default: TYPE_ERROR(0, otype);
// }
//
// std::swap(*out, res);
// }
// CATCHALL;
// return AF_SUCCESS;
//}
template<af_op_t op>
static af_err
af_arith_sparse(af_array *out, const af_array lhs, const af_array rhs)
9prady9 marked this conversation as resolved.
Show resolved Hide resolved
{
try {
common::SparseArrayBase linfo = getSparseArrayBase(lhs);
common::SparseArrayBase rinfo = getSparseArrayBase(rhs);

ARG_ASSERT(1, (linfo.getStorage()==rinfo.getStorage()));
9prady9 marked this conversation as resolved.
Show resolved Hide resolved
ARG_ASSERT(1, (linfo.dims()==rinfo.dims()));
ARG_ASSERT(1, (linfo.getStorage()==AF_STORAGE_CSR));

const af_dtype otype = implicit(linfo.getType(), rinfo.getType());
af_array res;
switch (otype) {
case f32: res = sparseArithOp<float , op>(lhs, rhs); break;
case f64: res = sparseArithOp<double , op>(lhs, rhs); break;
case c32: res = sparseArithOp<cfloat , op>(lhs, rhs); break;
case c64: res = sparseArithOp<cdouble, op>(lhs, rhs); break;
default: TYPE_ERROR(0, otype);
}

std::swap(*out, res);
}
CATCHALL;
return AF_SUCCESS;
}

template<af_op_t op>
static af_err af_arith_sparse_dense(af_array *out, const af_array lhs, const af_array rhs,
const bool reverse = false)
{
using namespace common;
try {
SparseArrayBase linfo = getSparseArrayBase(lhs);
common::SparseArrayBase linfo = getSparseArrayBase(lhs);
ArrayInfo rinfo = getInfo(rhs);

const af_dtype otype = implicit(linfo.getType(), rinfo.getType());
Expand All @@ -161,18 +173,20 @@ static af_err af_arith_sparse_dense(af_array *out, const af_array lhs, const af_
return AF_SUCCESS;
}

af_err af_add(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode)
af_err af_add(af_array *out, const af_array lhs, const af_array rhs,
const bool batchMode)
{
// Check if inputs are sparse
ArrayInfo linfo = getInfo(lhs, false, true);
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_add_t>(out, lhs, rhs);
return af_arith_sparse<af_add_t>(out, lhs, rhs);
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_add_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
return af_arith_sparse_dense<af_add_t>(out, rhs, lhs, true); // dense should be rhs
// second operand(Array) of af_arith call should be dense
return af_arith_sparse_dense<af_add_t>(out, rhs, lhs, true);
} else {
return af_arith<af_add_t>(out, lhs, rhs, batchMode);
}
Expand All @@ -185,7 +199,10 @@ af_err af_mul(af_array *out, const af_array lhs, const af_array rhs, const bool
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_mul_t>(out, lhs, rhs);
//return af_arith_sparse<af_mul_t>(out, lhs, rhs);
//MKL doesn't have mul or div support yet, hence
//this is commented out although alternative cpu code exists
return AF_ERR_NOT_SUPPORTED;
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_mul_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
Expand All @@ -202,7 +219,7 @@ af_err af_sub(af_array *out, const af_array lhs, const af_array rhs, const bool
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_sub_t>(out, lhs, rhs);
return af_arith_sparse<af_sub_t>(out, lhs, rhs);
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_sub_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
Expand All @@ -219,7 +236,10 @@ af_err af_div(af_array *out, const af_array lhs, const af_array rhs, const bool
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_div_t>(out, lhs, rhs);
//return af_arith_sparse<af_div_t>(out, lhs, rhs);
//MKL doesn't have mul or div support yet, hence
//this is commented out although alternative cpu code exists
return AF_ERR_NOT_SUPPORTED;
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_div_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
Expand Down
10 changes: 5 additions & 5 deletions src/backend/cpu/kernel/sparse.hpp
Expand Up @@ -45,7 +45,7 @@ void coo2dense(Param<T> output,
}

template<typename T>
void dense_csr(Param<T> values, Param<int> rowIdx, Param<int> colIdx,
void dense2csr(Param<T> values, Param<int> rowIdx, Param<int> colIdx,
CParam<T> in)
{
const T * iPtr = in.get();
Expand All @@ -70,8 +70,8 @@ void dense_csr(Param<T> values, Param<int> rowIdx, Param<int> colIdx,
}

template<typename T>
void csr_dense(Param<T> out,
CParam<T> values, CParam<int> rowIdx, CParam<int> colIdx)
void csr2dense(Param<T> out,
CParam<T> values, CParam<int> rowIdx, CParam<int> colIdx)
{
T *oPtr = out.get();
const T *vPtr = values.get();
Expand Down Expand Up @@ -107,7 +107,7 @@ struct SpKIPCompareK
};

template<typename T>
void csr_coo(Param<T> ovalues, Param<int> orowIdx, Param<int> ocolIdx,
void csr2coo(Param<T> ovalues, Param<int> orowIdx, Param<int> ocolIdx,
CParam<T> ivalues, CParam<int> irowIdx, CParam<int> icolIdx)
{
// First calculate the linear index
Expand Down Expand Up @@ -143,7 +143,7 @@ void csr_coo(Param<T> ovalues, Param<int> orowIdx, Param<int> ocolIdx,
}

template<typename T>
void coo_csr(Param<T> ovalues, Param<int> orowIdx, Param<int> ocolIdx,
void coo2csr(Param<T> ovalues, Param<int> orowIdx, Param<int> ocolIdx,
CParam<T> ivalues, CParam<int> irowIdx, CParam<int> icolIdx)
{
T * ovPtr = ovalues.get();
Expand Down
101 changes: 101 additions & 0 deletions src/backend/cpu/kernel/sparse_arith.hpp
Expand Up @@ -11,6 +11,8 @@
#include <Param.hpp>
#include <math.hpp>

#include <cmath>

namespace cpu
{
namespace kernel
Expand Down Expand Up @@ -143,5 +145,104 @@ void sparseArithOpS(Param<T> values, Param<int> rowIdx, Param<int> colIdx,
}
}

// The following functions can handle CSR
// storage format only as of now.
static
void calcOutNNZ(Param<int> outRowIdx,
const uint M, const uint N,
CParam<int> lRowIdx, CParam<int> lColIdx,
CParam<int> rRowIdx, CParam<int> rColIdx)
{
int *orPtr = outRowIdx.get();
const int *lrPtr = lRowIdx.get();
const int *lcPtr = lColIdx.get();
const int *rrPtr = rRowIdx.get();
const int *rcPtr = rColIdx.get();

unsigned csrOutCount = 0;
for (uint row=0; row<M; ++row) {
const int lEnd = lrPtr[row+1];
const int rEnd = rrPtr[row+1];

uint rowNNZ = 0;
int l = lrPtr[row];
umar456 marked this conversation as resolved.
Show resolved Hide resolved
int r = rrPtr[row];
while (l < lEnd && r < rEnd) {
int lci = lcPtr[l];
int rci = rcPtr[r];

l += (lci <= rci);
r += (lci >= rci);
rowNNZ++;
}
// Elements from lhs or rhs are exhausted.
// Just count left over elements
rowNNZ += (lEnd-l);
rowNNZ += (rEnd-r);

orPtr[row] = csrOutCount;
csrOutCount += rowNNZ;
}
//Write out the Rows+1 entry
orPtr[M] = csrOutCount;
}

template<typename T, af_op_t op>
void sparseArithOp(Param<T> oVals, Param<int> oColIdx,
CParam<int> oRowIdx, const uint Rows,
CParam<T> lvals, CParam<int> lRowIdx, CParam<int> lColIdx,
CParam<T> rvals, CParam<int> rRowIdx, CParam<int> rColIdx)
{
const int *orPtr = oRowIdx.get();
const T *lvPtr = lvals.get();
const int *lrPtr = lRowIdx.get();
const int *lcPtr = lColIdx.get();
const T *rvPtr = rvals.get();
const int *rrPtr = rRowIdx.get();
const int *rcPtr = rColIdx.get();

arith_op<T, op> binOp;

auto ZERO = scalar<T>(0);

for (uint row=0; row<Rows; ++row) {
const int lEnd = lrPtr[row+1];
const int rEnd = rrPtr[row+1];
const int offs = orPtr[row];

T *ovPtr = oVals.get() + offs;
int *ocPtr = oColIdx.get() + offs;

uint rowNNZ = 0;
int l = lrPtr[row];
int r = rrPtr[row];
while (l < lEnd && r < rEnd) {
int lci = lcPtr[l];
int rci = rcPtr[r];

T lhs = (lci <= rci ? lvPtr[l] : ZERO);
T rhs = (lci >= rci ? rvPtr[r] : ZERO);

ovPtr[ rowNNZ ] = binOp(lhs, rhs);
ocPtr[ rowNNZ ] = (lci <= rci) ? lci : rci;

l += (lci <= rci);
r += (lci >= rci);
rowNNZ++;
}
while (l < lEnd) {
ovPtr[ rowNNZ ] = binOp(lvPtr[l], ZERO);
ocPtr[ rowNNZ ] = lcPtr[l];
l++;
rowNNZ++;
}
while (r < rEnd) {
ovPtr[ rowNNZ ] = binOp(ZERO, rvPtr[r]);
ocPtr[ rowNNZ ] = rcPtr[r];
r++;
rowNNZ++;
}
}
}
}
}