Skip to content

Commit

Permalink
FEAT: sparse-sparse add/sub support
Browse files Browse the repository at this point in the history
* CPU backend:
    - Use mkl sparse-sparse add/sub when available
    - Fallback cpu implementation: has support for mul and div too

* CUDA backend sparse-sparse add/sub
* OpenCL backend sparse-sparse add/sub/div/mul

The output of sub/div/mul is not guaranteed to contain only
non-zero results of the arithmetic operation; the user has
to take care of pruning zero-valued results from the output.

Also moved the MKL interface definitions in the CPU backend into a single source file
  • Loading branch information
9prady9 committed Oct 15, 2018
1 parent fcf1843 commit f515f4d
Show file tree
Hide file tree
Showing 19 changed files with 999 additions and 208 deletions.
84 changes: 51 additions & 33 deletions src/api/c/binary.cpp
Expand Up @@ -35,6 +35,14 @@ static inline af_array arithOp(const af_array lhs, const af_array rhs,
return res;
}

// Casts both sparse operands to type T, applies the element-wise
// arithmetic operation, and wraps the resulting sparse array in a
// handle for the C API.
template<typename T, af_op_t op>
static inline
af_array sparseArithOp(const af_array lhs, const af_array rhs)
{
    return getHandle(arithOp<T, op>(castSparse<T>(lhs), castSparse<T>(rhs)));
}

template<typename T, af_op_t op>
static inline af_array arithSparseDenseOp(const af_array lhs, const af_array rhs,
const bool reverse)
Expand Down Expand Up @@ -80,10 +88,11 @@ static af_err af_arith(af_array *out, const af_array lhs, const af_array rhs, co
}

template<af_op_t op>
static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode)
static
af_err af_arith_real(af_array *out, const af_array lhs, const af_array rhs,
const bool batchMode)
{
try {

const ArrayInfo& linfo = getInfo(lhs);
const ArrayInfo& rinfo = getInfo(rhs);

Expand Down Expand Up @@ -111,38 +120,39 @@ static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rh
return AF_SUCCESS;
}

//template<af_op_t op>
//static af_err af_arith_sparse(af_array *out, const af_array lhs, const af_array rhs)
//{
// try {
// SparseArrayBase linfo = getSparseArrayBase(lhs);
// SparseArrayBase rinfo = getSparseArrayBase(rhs);
//
// dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode);
//
// const af_dtype otype = implicit(linfo.getType(), rinfo.getType());
// af_array res;
// switch (otype) {
// case f32: res = arithOp<float , op>(lhs, rhs, odims); break;
// case f64: res = arithOp<double , op>(lhs, rhs, odims); break;
// case c32: res = arithOp<cfloat , op>(lhs, rhs, odims); break;
// case c64: res = arithOp<cdouble, op>(lhs, rhs, odims); break;
// default: TYPE_ERROR(0, otype);
// }
//
// std::swap(*out, res);
// }
// CATCHALL;
// return AF_SUCCESS;
//}
// Element-wise arithmetic between two sparse arrays. Both inputs
// must use the same sparse storage format; the result element type
// follows the usual implicit-conversion rules.
template<af_op_t op>
static af_err
af_arith_sparse(af_array *out, const af_array lhs, const af_array rhs)
{
    try {
        common::SparseArrayBase lBase = getSparseArrayBase(lhs);
        common::SparseArrayBase rBase = getSparseArrayBase(rhs);

        // Mixed storage formats (e.g. CSR with COO) are not supported.
        ARG_ASSERT(1, (lBase.getStorage()==rBase.getStorage()));

        const af_dtype type = implicit(lBase.getType(), rBase.getType());

        af_array output;
        switch (type) {
            case f32: output = sparseArithOp<float  , op>(lhs, rhs); break;
            case f64: output = sparseArithOp<double , op>(lhs, rhs); break;
            case c32: output = sparseArithOp<cfloat , op>(lhs, rhs); break;
            case c64: output = sparseArithOp<cdouble, op>(lhs, rhs); break;
            default : TYPE_ERROR(0, type);
        }

        std::swap(*out, output);
    }
    CATCHALL;
    return AF_SUCCESS;
}

template<af_op_t op>
static af_err af_arith_sparse_dense(af_array *out, const af_array lhs, const af_array rhs,
const bool reverse = false)
{
using namespace common;
try {
SparseArrayBase linfo = getSparseArrayBase(lhs);
common::SparseArrayBase linfo = getSparseArrayBase(lhs);
ArrayInfo rinfo = getInfo(rhs);

const af_dtype otype = implicit(linfo.getType(), rinfo.getType());
Expand All @@ -161,18 +171,20 @@ static af_err af_arith_sparse_dense(af_array *out, const af_array lhs, const af_
return AF_SUCCESS;
}

af_err af_add(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode)
af_err af_add(af_array *out, const af_array lhs, const af_array rhs,
const bool batchMode)
{
// Check if inputs are sparse
ArrayInfo linfo = getInfo(lhs, false, true);
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_add_t>(out, lhs, rhs);
return af_arith_sparse<af_add_t>(out, lhs, rhs);
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_add_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
return af_arith_sparse_dense<af_add_t>(out, rhs, lhs, true); // dense should be rhs
// second operand(Array) of af_arith call should be dense
return af_arith_sparse_dense<af_add_t>(out, rhs, lhs, true);
} else {
return af_arith<af_add_t>(out, lhs, rhs, batchMode);
}
Expand All @@ -185,7 +197,10 @@ af_err af_mul(af_array *out, const af_array lhs, const af_array rhs, const bool
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_mul_t>(out, lhs, rhs);
//return af_arith_sparse<af_mul_t>(out, lhs, rhs);
//MKL doesn't have mul or div support yet, hence
//this is commented out although alternative cpu code exists
return AF_ERR_NOT_SUPPORTED;
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_mul_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
Expand All @@ -202,7 +217,7 @@ af_err af_sub(af_array *out, const af_array lhs, const af_array rhs, const bool
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_sub_t>(out, lhs, rhs);
return af_arith_sparse<af_sub_t>(out, lhs, rhs);
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_sub_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
Expand All @@ -219,7 +234,10 @@ af_err af_div(af_array *out, const af_array lhs, const af_array rhs, const bool
ArrayInfo rinfo = getInfo(rhs, false, true);

if(linfo.isSparse() && rinfo.isSparse()) {
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_div_t>(out, lhs, rhs);
//return af_arith_sparse<af_div_t>(out, lhs, rhs);
//MKL doesn't have mul or div support yet, hence
//this is commented out although alternative cpu code exists
return AF_ERR_NOT_SUPPORTED;
} else if(linfo.isSparse() && !rinfo.isSparse()) {
return af_arith_sparse_dense<af_div_t>(out, lhs, rhs);
} else if(!linfo.isSparse() && rinfo.isSparse()) {
Expand Down
3 changes: 3 additions & 0 deletions src/backend/cpu/CMakeLists.txt
Expand Up @@ -295,6 +295,9 @@ target_include_directories(afcpu
if(USE_CPU_MKL)
dependency_check(MKL_FOUND "MKL not found")
target_compile_definitions(afcpu PRIVATE USE_MKL)
target_sources(afcpu
PRIVATE
mkl_sparse_interface.cpp)

if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR
(CMAKE_CXX_COMPILER_ID STREQUAL "Intel" AND (UNIX AND NOT APPLE)))
Expand Down
129 changes: 129 additions & 0 deletions src/backend/cpu/kernel/sparse_arith.hpp
Expand Up @@ -11,6 +11,12 @@
#include <Param.hpp>
#include <math.hpp>

#include <cmath>
#include <vector>

#include <iterator>
#include <iostream>

namespace cpu
{
namespace kernel
Expand Down Expand Up @@ -141,5 +147,128 @@ void sparseArithOpS(Param<T> values, Param<int> rowIdx, Param<int> colIdx,
}
}

// The following functions can handle CSR
// storage format only as of now.
// The following functions can handle CSR
// storage format only as of now.
//
// Computes the CSR row-offset array of the output of a sparse-sparse
// element-wise operation, and the total number of structural non-zeros
// (returned through nnzC). The output sparsity pattern is the union of
// the two input patterns: per row, the sorted column lists of lhs and
// rhs are merged, counting a column present in both inputs only once.
//
// outRowIdx must have room for M+1 entries; entry M receives the total
// count. N is unused here — kept for interface symmetry with callers.
static
void calcOutNNZ(Param<int> outRowIdx, unsigned &nnzC,
                const uint M, const uint N,
                CParam<int> lRowIdx, CParam<int> lColIdx,
                CParam<int> rRowIdx, CParam<int> rColIdx)
{
    int *orPtr = outRowIdx.get();
    const int *lrPtr = lRowIdx.get();
    const int *lcPtr = lColIdx.get();
    const int *rrPtr = rRowIdx.get();
    const int *rcPtr = rColIdx.get();

    unsigned csrOutCount = 0;
    for (uint row=0; row<M; ++row) {
        const int lEnd = lrPtr[row+1];
        const int rEnd = rrPtr[row+1];

        uint rowNNZ = 0;
        int l = lrPtr[row];
        int r = rrPtr[row];
        while (l < lEnd && r < rEnd) {
            int lci = lcPtr[l];
            int rci = rcPtr[r];

            if (lci < rci) {
                // lhs has an entry in a column where rhs is zero
                l++;
            } else if (lci > rci) {
                // rhs has an entry in a column where lhs is zero
                r++;
            } else {
                // both operands have an entry in this column;
                // it contributes a single output element
                l++;
                r++;
            }
            rowNNZ++;
        }
        // One operand's row is exhausted here; every remaining entry of
        // the other contributes exactly one output element, so count the
        // leftovers arithmetically instead of stepping through them.
        rowNNZ += (lEnd - l) + (rEnd - r);

        orPtr[row] = csrOutCount;
        csrOutCount += rowNNZ;
    }
    //Write out the Rows+1 entry
    orPtr[M] = csrOutCount;
    nnzC = csrOutCount;
}

// Fills in the values and column indices of the output of a
// sparse-sparse element-wise operation, given output row offsets
// (oRowIdx) computed beforehand (see calcOutNNZ). CSR storage only.
// Each row's sorted column lists are merged in one pass; a column
// present in only one operand is combined with zero, so the output
// may contain explicit zero values (e.g. for sub) that the caller
// is responsible for pruning.
template<typename T, af_op_t op>
void sparseArithOp(Param<T> oVals, Param<int> oColIdx,
                   CParam<int> oRowIdx, const uint Rows,
                   CParam<T> lvals, CParam<int> lRowIdx, CParam<int> lColIdx,
                   CParam<T> rvals, CParam<int> rRowIdx, CParam<int> rColIdx)
{
    const int *orPtr = oRowIdx.get();
    const T *lvPtr = lvals.get();
    const int *lrPtr = lRowIdx.get();
    const int *lcPtr = lColIdx.get();
    const T *rvPtr = rvals.get();
    const int *rrPtr = rRowIdx.get();
    const int *rcPtr = rColIdx.get();

    // Functor implementing the requested binary operation `op` on T.
    arith_op<T, op> binOp;

    // Stand-in operand for columns present in only one of the inputs.
    auto ZERO = scalar<T>(0);

    for (uint row=0; row<Rows; ++row) {
        const int lEnd = lrPtr[row+1];
        const int rEnd = rrPtr[row+1];
        // Start of this row's slice in the output arrays.
        const int offs = orPtr[row];


        T *ovPtr = oVals.get() + offs;
        int *ocPtr = oColIdx.get() + offs;

        // rowNNZ doubles as the write cursor within this row's slice.
        uint rowNNZ = 0;
        int l = lrPtr[row];
        int r = rrPtr[row];
        while (l < lEnd && r < rEnd) {
            int lci = lcPtr[l];
            int rci = rcPtr[r];

            if (lci < rci) {
                // lhs-only column: rhs operand is zero here.
                ovPtr[ rowNNZ ] = binOp(lvPtr[l], ZERO);
                ocPtr[ rowNNZ ] = lci;
                l++;
            } else if (lci > rci) {
                // rhs-only column: lhs operand is zero here.
                ovPtr[ rowNNZ ] = binOp(ZERO, rvPtr[r]);
                ocPtr[ rowNNZ ] = rci;
                r++;
            } else {
                // Column present in both inputs: combine the two values.
                ovPtr[ rowNNZ ] = binOp(lvPtr[l], rvPtr[r]);
                ocPtr[ rowNNZ ] = lci; // or rci
                l++;
                r++;
            }
            rowNNZ++;
        }
        // Drain whichever operand still has entries in this row,
        // pairing each leftover value with zero.
        while (l < lEnd) {
            ovPtr[ rowNNZ ] = binOp(lvPtr[l], ZERO);
            ocPtr[ rowNNZ ] = lcPtr[l];
            l++;
            rowNNZ++;
        }
        while (r < rEnd) {
            ovPtr[ rowNNZ ] = binOp(ZERO, rvPtr[r]);
            ocPtr[ rowNNZ ] = rcPtr[r];
            r++;
            rowNNZ++;
        }
    }
}
}
}
55 changes: 55 additions & 0 deletions src/backend/cpu/mkl_interface_types.hpp
@@ -0,0 +1,55 @@
/*******************************************************
* Copyright (c) 2018, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/

#pragma once

#include <type_traits>

#ifdef USE_MKL
#include <mkl_spblas.h>
#endif

#include <types.hpp>

namespace cpu
{

// Complex element types as seen by the sparse BLAS interface.
// When building against MKL these must be MKL's own complex structs;
// otherwise the backend's native complex types are used directly.
// (typedefs normalized to `using` aliases for consistency with the
// alias templates below.)
#ifdef USE_MKL
using sp_cfloat  = MKL_Complex8;
using sp_cdouble = MKL_Complex16;
#else
using sp_cfloat  = cfloat;
using sp_cdouble = cdouble;
#endif

// Maps an element type T to the equivalent type expected by the
// sparse BLAS interface. Primary template: real types pass through
// unchanged.
template<typename T, class Enable = void>
struct blas_base {
    using type = T;
};

// Specialization for complex element types:
// cfloat -> sp_cfloat, cdouble -> sp_cdouble.
template<typename T>
struct blas_base <T, typename std::enable_if< is_complex<T>::value>::type> {
    using type = typename std::conditional<std::is_same<T, cdouble>::value,
                                           sp_cdouble, sp_cfloat>::type;
};

// Pointer-to-const argument type for BLAS-style calls taking T data.
template<typename T>
using cptr_type  = typename std::conditional< is_complex<T>::value,
                                              const typename blas_base<T>::type *,
                                              const T*>::type;

// Mutable pointer argument type for BLAS-style calls taking T data.
template<typename T>
using ptr_type   = typename std::conditional< is_complex<T>::value,
                                              typename blas_base<T>::type *,
                                              T*>::type;

// Scalar (alpha/beta style) argument type for BLAS-style calls.
template<typename T>
using scale_type = typename std::conditional< is_complex<T>::value,
                                              const typename blas_base<T>::type,
                                              const T>::type;
}

0 comments on commit f515f4d

Please sign in to comment.