From cbf1c41cef5defc132848c6b7085dabb9c4aca51 Mon Sep 17 00:00:00 2001 From: pradeep Date: Wed, 17 Oct 2018 21:17:39 +0530 Subject: [PATCH] FEAT: sparse-sparse add/sub support * CPU backend: - Use mkl sparse-sparse add/sub when available - Fallback cpu implementation: has support for mul and div too * CUDA backend sparse-sparse add/sub * OpenCL backend sparse-sparse add/sub/div/mul The output of sub/div/mul is not guaranteed to have only non-zero results of the arithmetic operation. The user has to take care of pruning the zero results from the output. --- src/api/c/binary.cpp | 84 ++++++----- src/backend/cpu/kernel/sparse_arith.hpp | 103 ++++++++++++++ src/backend/cpu/sparse_arith.cpp | 130 +++++++++++++++++- src/backend/cpu/sparse_arith.hpp | 6 +- src/backend/cuda/sparse_arith.cu | 94 +++++++++++++ src/backend/cuda/sparse_arith.hpp | 6 +- src/backend/opencl/kernel/sp_sp_arith_csr.cl | 57 ++++++++ src/backend/opencl/kernel/sparse_arith.hpp | 105 ++++++++++++++ .../opencl/kernel/ssarith_calc_out_nnz.cl | 36 +++++ src/backend/opencl/sparse_arith.cpp | 52 ++++++- src/backend/opencl/sparse_arith.hpp | 7 +- test/sparse_arith.cpp | 56 ++++++++ 12 files changed, 689 insertions(+), 47 deletions(-) create mode 100644 src/backend/opencl/kernel/sp_sp_arith_csr.cl create mode 100644 src/backend/opencl/kernel/ssarith_calc_out_nnz.cl diff --git a/src/api/c/binary.cpp b/src/api/c/binary.cpp index f0efc1ba51..041369f3e0 100644 --- a/src/api/c/binary.cpp +++ b/src/api/c/binary.cpp @@ -35,6 +35,14 @@ static inline af_array arithOp(const af_array lhs, const af_array rhs, return res; } +template +static inline +af_array sparseArithOp(const af_array lhs, const af_array rhs) +{ + auto res = arithOp(castSparse(lhs), castSparse(rhs)); + return getHandle(res); +} + template static inline af_array arithSparseDenseOp(const af_array lhs, const af_array rhs, const bool reverse) @@ -80,10 +88,11 @@ static af_err af_arith(af_array *out, const af_array lhs, const af_array rhs, co } template -static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) +static +af_err af_arith_real(af_array *out, const af_array lhs, const af_array rhs, + const bool batchMode) { try { - const ArrayInfo& linfo = getInfo(lhs); const ArrayInfo& rinfo = getInfo(rhs); @@ -111,30 +120,31 @@ static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rh return AF_SUCCESS; } -//template -//static af_err af_arith_sparse(af_array *out, const af_array lhs, const af_array rhs) -//{ -// try { -// SparseArrayBase linfo = getSparseArrayBase(lhs); -// SparseArrayBase rinfo = getSparseArrayBase(rhs); -// -// dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode); -// -// const af_dtype otype = implicit(linfo.getType(), rinfo.getType()); -// af_array res; -// switch (otype) { -// case f32: res = arithOp(lhs, rhs, odims); break; -// case f64: res = arithOp(lhs, rhs, odims); break; -// case c32: res = arithOp(lhs, rhs, odims); break; -// case c64: res = arithOp(lhs, rhs, odims); break; -// default: TYPE_ERROR(0, otype); -// } -// -// std::swap(*out, res); -// } -// CATCHALL; -// return AF_SUCCESS; -//} +template +static af_err +af_arith_sparse(af_array *out, const af_array lhs, const af_array rhs) +{ + try { + common::SparseArrayBase linfo = getSparseArrayBase(lhs); + common::SparseArrayBase rinfo = getSparseArrayBase(rhs); + + ARG_ASSERT(1, (linfo.getStorage()==rinfo.getStorage())); + + const af_dtype otype = implicit(linfo.getType(), rinfo.getType()); + af_array res; + switch (otype) { + case f32: res = sparseArithOp(lhs, rhs); break; + case f64: res = sparseArithOp(lhs, rhs); break; + case c32: res = sparseArithOp(lhs, rhs); break; + case c64: res = sparseArithOp(lhs, rhs); break; + default: TYPE_ERROR(0, otype); + } + + std::swap(*out, res); + } + CATCHALL; + return AF_SUCCESS; +} template static af_err af_arith_sparse_dense(af_array *out, const af_array lhs, const af_array rhs, @@ -142,7 +152,7 @@ static af_err af_arith_sparse_dense(af_array *out, const af_array lhs, const af_ { using namespace common; try { - SparseArrayBase linfo = getSparseArrayBase(lhs); + common::SparseArrayBase linfo = getSparseArrayBase(lhs); ArrayInfo rinfo = getInfo(rhs); const af_dtype otype = implicit(linfo.getType(), rinfo.getType()); @@ -161,18 +171,20 @@ static af_err af_arith_sparse_dense(af_array *out, const af_array lhs, const af_ return AF_SUCCESS; } -af_err af_add(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) +af_err af_add(af_array *out, const af_array lhs, const af_array rhs, + const bool batchMode) { // Check if inputs are sparse ArrayInfo linfo = getInfo(lhs, false, true); ArrayInfo rinfo = getInfo(rhs, false, true); if(linfo.isSparse() && rinfo.isSparse()) { - return AF_ERR_NOT_SUPPORTED; //af_arith_sparse(out, lhs, rhs); + return af_arith_sparse(out, lhs, rhs); } else if(linfo.isSparse() && !rinfo.isSparse()) { return af_arith_sparse_dense(out, lhs, rhs); } else if(!linfo.isSparse() && rinfo.isSparse()) { - return af_arith_sparse_dense(out, rhs, lhs, true); // dense should be rhs + // second operand(Array) of af_arith call should be dense + return af_arith_sparse_dense(out, rhs, lhs, true); } else { return af_arith(out, lhs, rhs, batchMode); } @@ -185,7 +197,10 @@ af_err af_mul(af_array *out, const af_array lhs, const af_array rhs, const bool ArrayInfo rinfo = getInfo(rhs, false, true); if(linfo.isSparse() && rinfo.isSparse()) { - return AF_ERR_NOT_SUPPORTED; //af_arith_sparse(out, lhs, rhs); + //return af_arith_sparse(out, lhs, rhs); + //MKL doesn't have mul or div support yet, hence + //this is commented out although alternative cpu code exists + return AF_ERR_NOT_SUPPORTED; } else if(linfo.isSparse() && !rinfo.isSparse()) { return af_arith_sparse_dense(out, lhs, rhs); } else if(!linfo.isSparse() && rinfo.isSparse()) { @@ -202,7 +217,7 @@ af_err af_sub(af_array *out, const af_array lhs, const af_array rhs, const bool ArrayInfo rinfo = getInfo(rhs, false, true); if(linfo.isSparse() && rinfo.isSparse()) { - return AF_ERR_NOT_SUPPORTED; //af_arith_sparse(out, lhs, rhs); + return af_arith_sparse(out, lhs, rhs); } else if(linfo.isSparse() && !rinfo.isSparse()) { return af_arith_sparse_dense(out, lhs, rhs); } else if(!linfo.isSparse() && rinfo.isSparse()) { @@ -219,7 +234,10 @@ af_err af_div(af_array *out, const af_array lhs, const af_array rhs, const bool ArrayInfo rinfo = getInfo(rhs, false, true); if(linfo.isSparse() && rinfo.isSparse()) { - return AF_ERR_NOT_SUPPORTED; //af_arith_sparse(out, lhs, rhs); + //return af_arith_sparse(out, lhs, rhs); + //MKL doesn't have mul or div support yet, hence + //this is commented out although alternative cpu code exists + return AF_ERR_NOT_SUPPORTED; } else if(linfo.isSparse() && !rinfo.isSparse()) { return af_arith_sparse_dense(out, lhs, rhs); } else if(!linfo.isSparse() && rinfo.isSparse()) { diff --git a/src/backend/cpu/kernel/sparse_arith.hpp b/src/backend/cpu/kernel/sparse_arith.hpp index 9b5a7cc4f1..f8209d455f 100644 --- a/src/backend/cpu/kernel/sparse_arith.hpp +++ b/src/backend/cpu/kernel/sparse_arith.hpp @@ -11,6 +11,9 @@ #include #include +#include +#include + namespace cpu { namespace kernel @@ -141,5 +144,105 @@ void sparseArithOpS(Param values, Param rowIdx, Param colIdx, } } +// The following functions can handle CSR +// storage format only as of now. +static +void calcOutNNZ(Param outRowIdx, unsigned &nnzC, + const uint M, const uint N, + CParam lRowIdx, CParam lColIdx, + CParam rRowIdx, CParam rColIdx) +{ + int *orPtr = outRowIdx.get(); + const int *lrPtr = lRowIdx.get(); + const int *lcPtr = lColIdx.get(); + const int *rrPtr = rRowIdx.get(); + const int *rcPtr = rColIdx.get(); + + unsigned csrOutCount = 0; + for (uint row=0; row= rci); + rowNNZ++; + } + // Elements from lhs or rhs are exhausted. + // Just count left over elements + rowNNZ += (lEnd-l); + rowNNZ += (rEnd-r); + + orPtr[row] = csrOutCount; + csrOutCount += rowNNZ; + } + //Write out the Rows+1 entry + orPtr[M] = csrOutCount; + nnzC = csrOutCount; +} + +template +void sparseArithOp(Param oVals, Param oColIdx, + CParam oRowIdx, const uint Rows, + CParam lvals, CParam lRowIdx, CParam lColIdx, + CParam rvals, CParam rRowIdx, CParam rColIdx) +{ + const int *orPtr = oRowIdx.get(); + const T *lvPtr = lvals.get(); + const int *lrPtr = lRowIdx.get(); + const int *lcPtr = lColIdx.get(); + const T *rvPtr = rvals.get(); + const int *rrPtr = rRowIdx.get(); + const int *rcPtr = rColIdx.get(); + + arith_op binOp; + + auto ZERO = scalar(0); + + for (uint row=0; row= rci ? rvPtr[r] : ZERO); + + ovPtr[ rowNNZ ] = binOp(lhs, rhs); + ocPtr[ rowNNZ ] = (lci <= rci) ? lci : rci; + + l += (lci <= rci); + r += (lci >= rci); + rowNNZ++; + } + while (l < lEnd) { + ovPtr[ rowNNZ ] = binOp(lvPtr[l], ZERO); + ocPtr[ rowNNZ ] = lcPtr[l]; + l++; + rowNNZ++; + } + while (r < rEnd) { + ovPtr[ rowNNZ ] = binOp(ZERO, rvPtr[r]); + ocPtr[ rowNNZ ] = rcPtr[r]; + r++; + rowNNZ++; + } + } +} } } diff --git a/src/backend/cpu/sparse_arith.cpp b/src/backend/cpu/sparse_arith.cpp index 09ede431b0..2afa37b915 100644 --- a/src/backend/cpu/sparse_arith.cpp +++ b/src/backend/cpu/sparse_arith.cpp @@ -7,16 +7,11 @@ * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ +#include #include #include -#include #include - -#include - -#include -#include - +#include #include #include #include @@ -26,6 +21,13 @@ #include #include +#include + +#include +#include +#include +#include + namespace cpu { @@ -115,6 +117,112 @@ SparseArray arithOpS(const SparseArray &lhs, const Array &rhs, const bo return out; } +#ifdef USE_MKL +template +sparse_matrix_t createMKLSparseMat(const SparseArray in) +{ + sparse_index_base_t sib = SPARSE_INDEX_BASE_ZERO; + const dim4 idims = in.dims(); + int rows = idims[0]; + int cols = idims[1]; + int nnz = in.getNNZ(); + T* vPtr = const_cast< T*>(in.getValues().get()); + int* rPtr = const_cast(in.getRowIdx().get()); + int* cPtr = const_cast(in.getColIdx().get()); + + sparse_matrix_t res; + sparse_status_t err = SPARSE_STATUS_SUCCESS; + + switch(in.getStorage()) { + case AF_STORAGE_CSR: + err = create_csr_func()(&res, sib, rows, cols, + rPtr, rPtr+1, cPtr, reinterpret_cast>(vPtr)); + break; + default: + AF_ERROR("Storage not supported", AF_ERR_NOT_SUPPORTED); + } + if (err!=SPARSE_STATUS_SUCCESS) + AF_ERROR("Create sparse matrix using MKL failed", AF_ERR_INTERNAL); + return res; +} + +template +void mklSparseOp(sparse_matrix_t* out, + const sparse_matrix_t &l, const sparse_matrix_t &r) +{ + if (op == af_add_t) { + auto alpha = getScaleValue< scale_type, T>(scalar(1.0)); + csr_add()(SPARSE_OPERATION_NON_TRANSPOSE, l, alpha, r, out); + } else if (op == af_sub_t) { + auto alpha = getScaleValue< scale_type, T>(scalar(-1.0)); + csr_add()(SPARSE_OPERATION_NON_TRANSPOSE, r, alpha, l, out); + } else { + //TODO MKL doesn't have sparse / sparse - elementwise div/mul + AF_ERROR("Elementwise division of sparse matrices not supported", + AF_ERR_NOT_SUPPORTED); + } +} +#endif + +template +SparseArray arithOp(const SparseArray &lhs, const SparseArray &rhs) +{ + lhs.eval(); + rhs.eval(); + af::storage sfmt = lhs.getStorage(); + if (sfmt != AF_STORAGE_CSR) + AF_ERROR("Only CSR format supported currently.", AF_ERR_NOT_SUPPORTED); + +#ifdef USE_MKL + const sparse_matrix_t l = createMKLSparseMat(lhs); + const sparse_matrix_t r = createMKLSparseMat(rhs); + + int rows, cols; + int *rstart, *rend, *colIdx; + ptr_type values; + sparse_index_base_t indexing; + + sparse_matrix_t res; + mklSparseOp(&res, l, r); + + export_csr_func()(res, &indexing, &rows, &cols, + &rstart, &rend, &colIdx, &values); + uint outNnz = rend[rows-1]; + std::vector rIdx(rows+1, 0); + std::copy(rstart, rstart+rows, rIdx.begin()); + rIdx[rows] = outNnz; + + auto out = createHostDataSparseArray(dim4(rows, cols), outNnz, + reinterpret_cast(values), + rIdx.data(), colIdx, sfmt); + mkl_sparse_destroy(res); + mkl_sparse_destroy(l); + mkl_sparse_destroy(r); + return out; +#else + const dim4 ldims = lhs.dims(); + + const uint M = ldims[0]; + const uint N = ldims[1]; + + auto outRowIdx = createValueArray(dim4(M+1), scalar(0)); + outRowIdx.eval(); + + unsigned nnzC = 0; + kernel::calcOutNNZ(outRowIdx, nnzC, M, N, + lhs.getRowIdx(), lhs.getColIdx(), + rhs.getRowIdx(), rhs.getColIdx()); + auto outColIdx = createEmptyArray(dim4(nnzC)); + auto outValues = createEmptyArray(dim4(nnzC)); + + kernel::sparseArithOp(outValues, outColIdx, outRowIdx, M, + lhs.getValues(), lhs.getRowIdx(), lhs.getColIdx(), + rhs.getValues(), rhs.getRowIdx(), rhs.getColIdx()); + + return createArrayDataSparseArray(ldims, outValues, outRowIdx, outColIdx, sfmt); +#endif +} + #define INSTANTIATE(T) \ template Array arithOpD(const SparseArray &lhs, const Array &rhs, \ const bool reverse); \ @@ -132,6 +240,14 @@ SparseArray arithOpS(const SparseArray &lhs, const Array &rhs, const bo const bool reverse); \ template SparseArray arithOpS(const SparseArray &lhs, const Array &rhs, \ const bool reverse); \ + template SparseArray arithOp(const common::SparseArray &lhs, \ + const common::SparseArray &rhs); \ + template SparseArray arithOp(const common::SparseArray &lhs, \ + const common::SparseArray &rhs); \ + template SparseArray arithOp(const common::SparseArray &lhs, \ + const common::SparseArray &rhs); \ + template SparseArray arithOp(const common::SparseArray &lhs, \ + const common::SparseArray &rhs); INSTANTIATE(float ) INSTANTIATE(double ) diff --git a/src/backend/cpu/sparse_arith.hpp b/src/backend/cpu/sparse_arith.hpp index db55154814..1cd1a6911c 100644 --- a/src/backend/cpu/sparse_arith.hpp +++ b/src/backend/cpu/sparse_arith.hpp @@ -7,6 +7,8 @@ * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ +#pragma once + #include #include #include @@ -14,7 +16,6 @@ namespace cpu { - // These two functions cannot be overloaded by return type. // So have to give them separate names. template @@ -25,4 +26,7 @@ template common::SparseArray arithOpS(const common::SparseArray &lhs, const Array &rhs, const bool reverse = false); +template +common::SparseArray arithOp(const common::SparseArray &lhs, + const common::SparseArray &rhs); } diff --git a/src/backend/cuda/sparse_arith.cu b/src/backend/cuda/sparse_arith.cu index 2126234f66..dc37cb5a20 100644 --- a/src/backend/cuda/sparse_arith.cu +++ b/src/backend/cuda/sparse_arith.cu @@ -103,6 +103,92 @@ SparseArray arithOpS(const SparseArray &lhs, const Array &rhs, const bo return out; } +template +using csrgeam_def = cusparseStatus_t (*)(cusparseHandle_t, int, int, + const T*, const cusparseMatDescr_t, int, const T*, const int*, const int*, + const T*, const cusparseMatDescr_t, int, const T*, const int*, const int*, + const cusparseMatDescr_t, T*, int*, int*); + +#define SPARSE_ARITH_OP_FUNC_DEF( FUNC ) \ +template FUNC##_def FUNC##_func(); + +SPARSE_ARITH_OP_FUNC_DEF( csrgeam ); + +#define SPARSE_ARITH_OP_FUNC( FUNC, TYPE, INFIX ) \ +template<> FUNC##_def FUNC##_func() \ +{ return cusparse##INFIX##FUNC; } + +SPARSE_ARITH_OP_FUNC(csrgeam, float , S); +SPARSE_ARITH_OP_FUNC(csrgeam, double , D); +SPARSE_ARITH_OP_FUNC(csrgeam, cfloat , C); +SPARSE_ARITH_OP_FUNC(csrgeam, cdouble, Z); + +template +SparseArray arithOp(const SparseArray &lhs, const SparseArray &rhs) +{ + lhs.eval(); + rhs.eval(); + af::storage sfmt = lhs.getStorage(); + if (sfmt != AF_STORAGE_CSR) + AF_ERROR("Only CSR format supported currently.", + AF_ERR_NOT_SUPPORTED); + + cusparseMatDescr_t desc; + cusparseCreateMatDescr(&desc); + + const dim4 ldims = lhs.dims(); + + const int M = ldims[0]; + const int N = ldims[1]; + + const dim_t nnzA = lhs.getNNZ(); + const dim_t nnzB = rhs.getNNZ(); + + const int* csrRowPtrA = lhs.getRowIdx().get(); + const int* csrColPtrA = lhs.getColIdx().get(); + const int* csrRowPtrB = rhs.getRowIdx().get(); + const int* csrColPtrB = rhs.getColIdx().get(); + + auto outRowIdx = createEmptyArray(dim4(M+1)); + + int* csrRowPtrC = outRowIdx.get(); + int baseC, nnzC; + int* nnzcDevHostPtr = &nnzC; + + cusparseXcsrgeamNnz(sparseHandle(), M, N, + desc, nnzA, csrRowPtrA, csrColPtrA, + desc, nnzB, csrRowPtrB, csrColPtrB, + desc, csrRowPtrC, nnzcDevHostPtr); + if (NULL != nnzcDevHostPtr) { + nnzC = *nnzcDevHostPtr; + } else { + cudaMemcpyAsync(&nnzC, csrRowPtrC+M, sizeof(int), + cudaMemcpyDeviceToHost, cuda::getActiveStream()); + cudaMemcpyAsync(&baseC, csrRowPtrC, sizeof(int), + cudaMemcpyDeviceToHost, cuda::getActiveStream()); + CUDA_CHECK(cudaStreamSynchronize(cuda::getActiveStream())); + nnzC -= baseC; + } + + auto outColIdx = createEmptyArray(dim4(nnzC)); + auto outValues = createEmptyArray(dim4(nnzC)); + + T alpha = scalar(1); + T beta = op==af_sub_t ? scalar(-1) : alpha; + + csrgeam_func()(sparseHandle(), M, N, + &alpha, desc, nnzA, + lhs.getValues().get(), csrRowPtrA, csrColPtrA, + &beta, desc, nnzB, + rhs.getValues().get(), csrRowPtrB, csrColPtrB, + desc, outValues.get(), csrRowPtrC, outColIdx.get()); + + SparseArray retVal = createArrayDataSparseArray(ldims, + outValues, outRowIdx, outColIdx, + sfmt); + return retVal; +} + #define INSTANTIATE(T) \ template Array arithOpD(const SparseArray &lhs, const Array &rhs, \ const bool reverse); \ @@ -120,6 +206,14 @@ SparseArray arithOpS(const SparseArray &lhs, const Array &rhs, const bo const bool reverse); \ template SparseArray arithOpS(const SparseArray &lhs, const Array &rhs, \ const bool reverse); \ + template SparseArray arithOp(const common::SparseArray &lhs, \ + const common::SparseArray &rhs); \ + template SparseArray arithOp(const common::SparseArray &lhs, \ + const common::SparseArray &rhs); \ + template SparseArray arithOp(const common::SparseArray &lhs, \ + const common::SparseArray &rhs); \ + template SparseArray arithOp(const common::SparseArray &lhs, \ + const common::SparseArray &rhs); INSTANTIATE(float ) INSTANTIATE(double ) diff --git a/src/backend/cuda/sparse_arith.hpp b/src/backend/cuda/sparse_arith.hpp index 5ea1e68059..f9ee528ae5 100644 --- a/src/backend/cuda/sparse_arith.hpp +++ b/src/backend/cuda/sparse_arith.hpp @@ -25,5 +25,7 @@ template common::SparseArray arithOpS(const common::SparseArray &lhs, const Array &rhs, const bool reverse = false); -} - +template +common::SparseArray arithOp(const common::SparseArray &lhs, + const common::SparseArray &rhs); +} \ No newline at end of file diff --git a/src/backend/opencl/kernel/sp_sp_arith_csr.cl b/src/backend/opencl/kernel/sp_sp_arith_csr.cl new file mode 100644 index 0000000000..5a668c6bd8 --- /dev/null +++ b/src/backend/opencl/kernel/sp_sp_arith_csr.cl @@ -0,0 +1,57 @@ +/******************************************************* + * Copyright (c) 2018, ArrayFire + * All rights reserved. + * + * This file is distributed under 3-clause BSD license. + * The complete license agreement can be obtained at: + * http://arrayfire.com/licenses/BSD-3-Clause + ********************************************************/ + +kernel +void ssarith_csr_kernel(global T* oVals, global int* oColIdx, + global const int* oRowIdx, + uint M, uint N, + uint nnza, global const T *lVals, + global const int *lRowIdx, global const int *lColIdx, + uint nnzb, global const T *rVals, + global const int *rRowIdx, global const int *rColIdx) +{ + const uint row = get_global_id(0); + const uint lEnd = lRowIdx[row+1]; + const uint rEnd = rRowIdx[row+1]; + const uint offset = oRowIdx[row]; + + global T *ovPtr = oVals + offset; + global int *ocPtr = oColIdx + offset; + + uint l = lRowIdx[row]; + uint r = rRowIdx[row]; + + uint nnz = 0; + while (l < lEnd && r < rEnd) { + uint lci = lColIdx[l]; + uint rci = rColIdx[r]; + + T lhs = (lci <= rci ? lVals[l] : ZERO); + T rhs = (lci >= rci ? rVals[r] : ZERO); + + ovPtr[ nnz ] = OP(lhs, rhs); + ocPtr[ nnz ] = (lci <= rci) ? lci : rci; + + l += (lci <= rci); + r += (lci >= rci); + nnz++; + } + while (l < lEnd) { + ovPtr[nnz] = OP(lVals[l], ZERO); + ocPtr[nnz] = lColIdx[l]; + l++; + nnz++; + } + while (r < rEnd) { + ovPtr[nnz] = OP(ZERO, rVals[r]); + ocPtr[nnz] = rColIdx[r]; + r++; + nnz++; + } +} diff --git a/src/backend/opencl/kernel/sparse_arith.hpp b/src/backend/opencl/kernel/sparse_arith.hpp index 3cef0fdcab..48e50fd233 100644 --- a/src/backend/opencl/kernel/sparse_arith.hpp +++ b/src/backend/opencl/kernel/sparse_arith.hpp @@ -11,6 +11,8 @@ #include #include #include +#include +#include #include #include #include @@ -21,6 +23,8 @@ #include #include #include +#include +#include using cl::Buffer; using cl::Program; @@ -29,6 +33,7 @@ using cl::KernelFunctor; using cl::EnqueueArgs; using cl::NDRange; using std::string; +using af::scalar_to_option; namespace opencl { @@ -272,5 +277,105 @@ namespace opencl CL_DEBUG_FINISH(getQueue()); } + + static + void csrCalcOutNNZ(Param outRowIdx, unsigned &nnzC, + const uint M, const uint N, + uint nnzA, const Param lrowIdx, const Param lcolIdx, + uint nnzB, const Param rrowIdx, const Param rcolIdx) + { + std::string refName = std::string("csr_calc_output_NNZ"); + int device = getActiveDeviceId(); + kc_entry_t entry = kernelCache(device, refName); + + if (entry.prog==0 && entry.ker==0) { + const char *kerStrs[] = { ssarith_calc_out_nnz_cl }; + const int kerLens[] = { ssarith_calc_out_nnz_cl_len }; + + Program prog; + buildProgram(prog, 1, kerStrs, kerLens, std::string("")); + entry.prog = new Program(prog); + entry.ker = new Kernel(*entry.prog, "csr_calc_out_nnz"); + + addKernelToCache(device, refName, entry); + } + auto calcNNZop = KernelFunctor(*entry.ker); + + NDRange local(1, 1); + NDRange global(M, 1, 1); + + nnzC = 0; + cl::Buffer* out = bufferAlloc(sizeof(unsigned)); + getQueue().enqueueWriteBuffer(*out, CL_TRUE, 0, sizeof(unsigned), &nnzC); + + calcNNZop(EnqueueArgs(getQueue(), global, local), + *out, *outRowIdx.data, + *lrowIdx.data, *lcolIdx.data, + *rrowIdx.data, *rcolIdx.data); + getQueue().enqueueReadBuffer(*out, CL_TRUE, 0, sizeof(unsigned), &nnzC); + + CL_DEBUG_FINISH(getQueue()); + } + + template + void ssArithCSR(Param oVals, Param oColIdx, + const Param oRowIdx, const uint M, const uint N, + unsigned nnzA, const Param lVals, const Param lRowIdx, const Param lColIdx, + unsigned nnzB, const Param rVals, const Param rRowIdx, const Param rColIdx) + { + std::string refName = std::string("ss_arith_csr_") + + getOpString() + "_" + + std::string(dtype_traits::getName()); + int device = getActiveDeviceId(); + kc_entry_t entry = kernelCache(device, refName); + + if (entry.prog==0 && entry.ker==0) { + ToNumStr toNumStr; + std::ostringstream options; + options << " -D T=" << dtype_traits::getName() + << " -D OP=" << getOpString() + << " -D ZERO=(T)(" << scalar_to_option(scalar(0)) << ")"; + + if((af_dtype) dtype_traits::af_type == c32 || + (af_dtype) dtype_traits::af_type == c64) { + options << " -D IS_CPLX=1"; + } else { + options << " -D IS_CPLX=0"; + } + if (std::is_same::value || + std::is_same::value) { + options << " -D USE_DOUBLE"; + } + + const char *kerStrs[] = { sparse_arith_common_cl, sp_sp_arith_csr_cl }; + const int kerLens[] = { sparse_arith_common_cl_len, sp_sp_arith_csr_cl_len }; + + Program prog; + buildProgram(prog, 2, kerStrs, kerLens, options.str()); + entry.prog = new Program(prog); + entry.ker = new Kernel(*entry.prog, "ssarith_csr_kernel"); + + addKernelToCache(device, refName, entry); + } + auto arithOp = KernelFunctor(*entry.ker); + + NDRange local(1, 1); + NDRange global(M, 1, 1); + + arithOp(EnqueueArgs(getQueue(), global, local), + *oVals.data, *oColIdx.data, + *oRowIdx.data, M, N, + nnzA, *lVals.data, *lRowIdx.data, *lColIdx.data, + nnzB, *rVals.data, *rRowIdx.data, *rColIdx.data); + + CL_DEBUG_FINISH(getQueue()); + } } } diff --git a/src/backend/opencl/kernel/ssarith_calc_out_nnz.cl b/src/backend/opencl/kernel/ssarith_calc_out_nnz.cl new file mode 100644 index 0000000000..ee6ab0762e --- /dev/null +++ b/src/backend/opencl/kernel/ssarith_calc_out_nnz.cl @@ -0,0 +1,36 @@ +/******************************************************* + * Copyright (c) 2018, ArrayFire + * All rights reserved. + * + * This file is distributed under 3-clause BSD license. + * The complete license agreement can be obtained at: + * http://arrayfire.com/licenses/BSD-3-Clause + ********************************************************/ + +kernel +void csr_calc_out_nnz(global unsigned* nnzc, + global int* oRowIdx, + global const int *lRowIdx, global const int *lColIdx, + global const int *rRowIdx, global const int *rColIdx) +{ + const uint row = get_global_id(0); + const uint lEnd = lRowIdx[row+1]; + const uint rEnd = rRowIdx[row+1]; + + uint nnz = 0; + uint l = lRowIdx[row]; + uint r = rRowIdx[row]; + while (l < lEnd && r < rEnd) { + uint lci = lColIdx[l]; + uint rci = rColIdx[r]; + l += (lci <= rci); + r += (lci >= rci); + nnz++; + } + nnz += (lEnd-l); + nnz += (rEnd-r); + + oRowIdx[row+1] = nnz; + + atomic_add(nnzc, nnz); +} diff --git a/src/backend/opencl/sparse_arith.cpp b/src/backend/opencl/sparse_arith.cpp index a5e269ea2a..6579371709 100644 --- a/src/backend/opencl/sparse_arith.cpp +++ b/src/backend/opencl/sparse_arith.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include namespace opencl @@ -103,6 +104,48 @@ SparseArray arithOpS(const SparseArray &lhs, const Array &rhs, const bo return out; } +template +SparseArray arithOp(const SparseArray &lhs, const SparseArray &rhs) +{ + lhs.eval(); + rhs.eval(); + af::storage sfmt = lhs.getStorage(); + if (sfmt != AF_STORAGE_CSR) + AF_ERROR("Only CSR format supported currently.", + AF_ERR_NOT_SUPPORTED); + + const dim4 ldims = lhs.dims(); + + const uint M = ldims[0]; + const uint N = ldims[1]; + + const dim_t nnzA = lhs.getNNZ(); + const dim_t nnzB = rhs.getNNZ(); + + auto temp = createValueArray(dim4(M+1), scalar(0)); + temp.eval(); + + unsigned nnzC = 0; + kernel::csrCalcOutNNZ(temp, nnzC, M, N, + nnzA, lhs.getRowIdx(), lhs.getColIdx(), + nnzB, rhs.getRowIdx(), rhs.getColIdx()); + + auto outRowIdx = scan(temp, 0); + + auto outColIdx = createEmptyArray(dim4(nnzC)); + auto outValues = createEmptyArray(dim4(nnzC)); + + kernel::ssArithCSR(outValues, outColIdx, + outRowIdx, M, N, + nnzA, lhs.getValues(), lhs.getRowIdx(), lhs.getColIdx(), + nnzB, rhs.getValues(), rhs.getRowIdx(), rhs.getColIdx()); + + SparseArray retVal = createArrayDataSparseArray(ldims, + outValues, outRowIdx, outColIdx, + sfmt); + return retVal; +} + #define INSTANTIATE(T) \ template Array arithOpD(const SparseArray &lhs, const Array &rhs, \ const bool reverse); \ @@ -120,6 +163,14 @@ SparseArray arithOpS(const SparseArray &lhs, const Array &rhs, const bo const bool reverse); \ template SparseArray arithOpS(const SparseArray &lhs, const Array &rhs, \ const bool reverse); \ + template SparseArray arithOp(const common::SparseArray &lhs, \ + const common::SparseArray &rhs); \ + template SparseArray arithOp(const common::SparseArray &lhs, \ + const common::SparseArray &rhs); \ + template SparseArray arithOp(const common::SparseArray &lhs, \ + const common::SparseArray &rhs); \ + template SparseArray arithOp(const common::SparseArray &lhs, \ + const common::SparseArray &rhs); INSTANTIATE(float ) INSTANTIATE(double ) @@ -127,4 +178,3 @@ INSTANTIATE(cfloat ) INSTANTIATE(cdouble) } - diff --git a/src/backend/opencl/sparse_arith.hpp b/src/backend/opencl/sparse_arith.hpp index 4afc799cad..3a54a674d6 100644 --- a/src/backend/opencl/sparse_arith.hpp +++ b/src/backend/opencl/sparse_arith.hpp @@ -25,6 +25,7 @@ template common::SparseArray arithOpS(const common::SparseArray &lhs, const Array &rhs, const bool reverse = false); -} - - +template +common::SparseArray arithOp(const common::SparseArray &lhs, + const common::SparseArray &rhs); +} \ No newline at end of file diff --git a/test/sparse_arith.cpp b/test/sparse_arith.cpp index cd8e98d857..8b7c35b8e7 100644 --- a/test/sparse_arith.cpp +++ b/test/sparse_arith.cpp @@ -328,3 +328,59 @@ ARITH_TESTS(float , 1e-6) ARITH_TESTS(double , 1e-6) ARITH_TESTS(cfloat , 1e-4) // This is mostly for complex division in OpenCL ARITH_TESTS(cdouble, 1e-6) + +// Sparse-Sparse Arithmetic testing function +template +void ssArithmetic(const int m, const int n, int factor, const double eps) +{ + deviceGC(); + + if (noDoubleTests()) return; + +#if 1 + array A = cpu_randu(dim4(m, n)); + array B = cpu_randu(dim4(m, n)); +#else + array A = randu(m, n, (dtype)dtype_traits::af_type); + array B = randu(m, n, (dtype)dtype_traits::af_type); +#endif + + A = makeSparse(A, factor); + B = makeSparse(B, factor); + + array spA = sparse(A, AF_STORAGE_CSR); + array spB = sparse(B, AF_STORAGE_CSR); + + arith_op binOp; + + // Arith Op + array resS = binOp(spA, spB); + array resD = binOp(A, B); + array revS = binOp(spB, spA); + array revD = binOp(B, A); + + ASSERT_ARRAYS_NEAR(resD, dense(resS), eps); + ASSERT_ARRAYS_NEAR(revD, dense(revS), eps); +} + +#define SP_SP_ARITH_TEST(type, m, n, factor, eps) \ +TEST(SparseSparseArith, type##_Addition_##m##_##n) \ +{ \ + ssArithmetic(m, n, factor, eps); \ +} \ +TEST(SparseSparseArith, type##_Subtraction_##m##_##n) \ +{ \ + ssArithmetic(m, n, factor, eps); \ +} + +#define SP_SP_ARITH_TESTS(T, eps) \ + SP_SP_ARITH_TEST(T, 10 , 10 , 5, eps) \ + SP_SP_ARITH_TEST(T, 1024, 1024, 5, eps) \ + SP_SP_ARITH_TEST(T, 100 , 100 , 1, eps) \ + SP_SP_ARITH_TEST(T, 2048, 1000, 6, eps) \ + SP_SP_ARITH_TEST(T, 123 , 278 , 5, eps) \ + +SP_SP_ARITH_TESTS(float , 1e-6) +SP_SP_ARITH_TESTS(double , 1e-6) +SP_SP_ARITH_TESTS(cfloat , 1e-4) // This is mostly for complex division in OpenCL +SP_SP_ARITH_TESTS(cdouble, 1e-6)