diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 7e46f356c07d50..26a2d2892e00e0 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3480,14 +3480,18 @@ tf_kernel_library( tf_kernel_library( name = "cuda_sparse", - srcs = ["cuda_sparse.cc"], + srcs = if_cuda(["cuda_sparse.cc"]) + if_rocm(["rocm_sparse.cc"]), hdrs = ["cuda_sparse.h"], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/kernels:cuda_solvers", + ] + if_cuda([ "//tensorflow/stream_executor/cuda:cusparse_lib", - ] + if_cuda(["@cub_archive//:cub"]), + "@cub_archive//:cub", + ]) + if_rocm([ + "@local_config_rocm//rocm:hipsparse", + ]), ) LINALG_DEPS = [ diff --git a/tensorflow/core/kernels/cuda_sparse.cc b/tensorflow/core/kernels/cuda_sparse.cc index 7825dc5969f0d6..7485bef45a2f74 100644 --- a/tensorflow/core/kernels/cuda_sparse.cc +++ b/tensorflow/core/kernels/cuda_sparse.cc @@ -69,7 +69,7 @@ inline typename CudaComplexT::type* AsCudaComplex(T* p) { } // A set of initialized handles to the underlying Cuda libraries used by -// CudaSparse. We maintain one such set of handles per unique stream. +// GpuSparse. We maintain one such set of handles per unique stream. class CudaSparseHandles { public: explicit CudaSparseHandles(cudaStream_t stream) @@ -96,8 +96,8 @@ class CudaSparseHandles { Status Initialize() { if (initialized_) return Status::OK(); - TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreate(&cusparse_handle_)); - TF_RETURN_IF_CUSPARSE_ERROR(cusparseSetStream(cusparse_handle_, stream_)); + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreate(&cusparse_handle_)); + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSetStream(cusparse_handle_, stream_)); initialized_ = true; return Status::OK(); } @@ -149,7 +149,7 @@ HandleMap* GetHandleMapSingleton() { } // namespace -CudaSparse::CudaSparse(OpKernelContext* context) +GpuSparse::GpuSparse(OpKernelContext* context) : initialized_(false), context_(context) { auto cuda_stream_ptr = reinterpret_cast(context->op_device_context() @@ -157,25 +157,24 @@ CudaSparse::CudaSparse(OpKernelContext* context) ->implementation() ->GpuStreamMemberHack()); DCHECK(cuda_stream_ptr); - cuda_stream_ = *cuda_stream_ptr; + gpu_stream_ = *cuda_stream_ptr; } -Status CudaSparse::Initialize() { +Status GpuSparse::Initialize() { HandleMap* handle_map = GetHandleMapSingleton(); DCHECK(handle_map); mutex_lock lock(handle_map_mutex); - auto it = handle_map->find(cuda_stream_); + auto it = handle_map->find(gpu_stream_); if (it == handle_map->end()) { - LOG(INFO) << "Creating CudaSparse handles for stream " << cuda_stream_; + LOG(INFO) << "Creating CudaSparse handles for stream " << gpu_stream_; // Previously unseen Cuda stream. Initialize a set of Cuda sparse library // handles for it. 
- CudaSparseHandles new_handles(cuda_stream_); + CudaSparseHandles new_handles(gpu_stream_); TF_RETURN_IF_ERROR(new_handles.Initialize()); - it = - handle_map->insert(std::make_pair(cuda_stream_, std::move(new_handles))) - .first; + it = handle_map->insert(std::make_pair(gpu_stream_, std::move(new_handles))) + .first; } - cusparse_handle_ = &it->second.handle(); + gpusparse_handle_ = &it->second.handle(); initialized_ = true; return Status::OK(); } @@ -205,32 +204,32 @@ template static inline Status GtsvImpl(SparseFn op, cusparseHandle_t cusparse_handle, int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* B, int ldb) { - TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl), - AsCudaComplex(d), AsCudaComplex(du), - AsCudaComplex(B), ldb)); + TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl), + AsCudaComplex(d), AsCudaComplex(du), + AsCudaComplex(B), ldb)); return Status::OK(); } -#define GTSV_INSTANCE(Scalar, sparse_prefix) \ - template <> \ - Status CudaSparse::Gtsv(int m, int n, const Scalar* dl, \ - const Scalar* d, const Scalar* du, \ - Scalar* B, int ldb) const { \ - DCHECK(initialized_); \ - return GtsvImpl(SPARSE_FN(gtsv, sparse_prefix), *cusparse_handle_, m, n, \ - dl, d, du, B, ldb); \ +#define GTSV_INSTANCE(Scalar, sparse_prefix) \ + template <> \ + Status GpuSparse::Gtsv(int m, int n, const Scalar* dl, \ + const Scalar* d, const Scalar* du, Scalar* B, \ + int ldb) const { \ + DCHECK(initialized_); \ + return GtsvImpl(SPARSE_FN(gtsv, sparse_prefix), *gpusparse_handle_, m, n, \ + dl, d, du, B, ldb); \ } TF_CALL_LAPACK_TYPES(GTSV_INSTANCE); -#define GTSV_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \ - template <> \ - Status CudaSparse::GtsvNoPivot(int m, int n, const Scalar* dl, \ - const Scalar* d, const Scalar* du, \ - Scalar* B, int ldb) const { \ - DCHECK(initialized_); \ - return GtsvImpl(SPARSE_FN(gtsv_nopivot, sparse_prefix), *cusparse_handle_, \ - m, n, dl, d, du, B, ldb); \ +#define GTSV_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \ + template <> \ + Status GpuSparse::GtsvNoPivot(int m, int n, const Scalar* dl, \ + const Scalar* d, const Scalar* du, \ + Scalar* B, int ldb) const { \ + DCHECK(initialized_); \ + return GtsvImpl(SPARSE_FN(gtsv_nopivot, sparse_prefix), \ + *gpusparse_handle_, m, n, dl, d, du, B, ldb); \ } TF_CALL_LAPACK_TYPES(GTSV_NO_PIVOT_INSTANCE); @@ -242,20 +241,20 @@ static inline Status GtsvStridedBatchImpl(SparseFn op, const Scalar* d, const Scalar* du, Scalar* x, int batchCount, int batchStride) { - TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl), - AsCudaComplex(d), AsCudaComplex(du), - AsCudaComplex(x), batchCount, batchStride)); + TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl), + AsCudaComplex(d), AsCudaComplex(du), + AsCudaComplex(x), batchCount, batchStride)); return Status::OK(); } #define GTSV_STRIDED_BATCH_INSTANCE(Scalar, sparse_prefix) \ template <> \ - Status CudaSparse::GtsvStridedBatch( \ + Status GpuSparse::GtsvStridedBatch( \ int m, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* x, \ int batchCount, int batchStride) const { \ DCHECK(initialized_); \ return GtsvStridedBatchImpl(SPARSE_FN(gtsvStridedBatch, sparse_prefix), \ - *cusparse_handle_, m, dl, d, du, x, \ + *gpusparse_handle_, m, dl, d, du, x, \ batchCount, batchStride); \ } @@ -266,32 +265,32 @@ static inline Status Gtsv2Impl(SparseFn op, cusparseHandle_t cusparse_handle, int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* B, int ldb, void* 
pBuffer) { - TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl), - AsCudaComplex(d), AsCudaComplex(du), - AsCudaComplex(B), ldb, pBuffer)); + TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl), + AsCudaComplex(d), AsCudaComplex(du), + AsCudaComplex(B), ldb, pBuffer)); return Status::OK(); } -#define GTSV2_INSTANCE(Scalar, sparse_prefix) \ - template <> \ - Status CudaSparse::Gtsv2(int m, int n, const Scalar* dl, \ - const Scalar* d, const Scalar* du, \ - Scalar* B, int ldb, void* pBuffer) const { \ - DCHECK(initialized_); \ - return Gtsv2Impl(SPARSE_FN(gtsv2, sparse_prefix), *cusparse_handle_, m, n, \ - dl, d, du, B, ldb, pBuffer); \ +#define GTSV2_INSTANCE(Scalar, sparse_prefix) \ + template <> \ + Status GpuSparse::Gtsv2(int m, int n, const Scalar* dl, \ + const Scalar* d, const Scalar* du, \ + Scalar* B, int ldb, void* pBuffer) const { \ + DCHECK(initialized_); \ + return Gtsv2Impl(SPARSE_FN(gtsv2, sparse_prefix), *gpusparse_handle_, m, \ + n, dl, d, du, B, ldb, pBuffer); \ } TF_CALL_LAPACK_TYPES(GTSV2_INSTANCE); -#define GTSV2_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \ - template <> \ - Status CudaSparse::Gtsv2NoPivot( \ - int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \ - Scalar* B, int ldb, void* pBuffer) const { \ - DCHECK(initialized_); \ - return Gtsv2Impl(SPARSE_FN(gtsv2_nopivot, sparse_prefix), \ - *cusparse_handle_, m, n, dl, d, du, B, ldb, pBuffer); \ +#define GTSV2_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \ + template <> \ + Status GpuSparse::Gtsv2NoPivot( \ + int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \ + Scalar* B, int ldb, void* pBuffer) const { \ + DCHECK(initialized_); \ + return Gtsv2Impl(SPARSE_FN(gtsv2_nopivot, sparse_prefix), \ + *gpusparse_handle_, m, n, dl, d, du, B, ldb, pBuffer); \ } TF_CALL_LAPACK_TYPES(GTSV2_NO_PIVOT_INSTANCE); @@ -303,34 +302,34 @@ static inline Status Gtsv2BufferSizeExtImpl(SparseFn op, const Scalar* d, const Scalar* du, const Scalar* B, int ldb, size_t* bufferSizeInBytes) { - TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl), - AsCudaComplex(d), AsCudaComplex(du), - AsCudaComplex(B), ldb, bufferSizeInBytes)); + TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl), + AsCudaComplex(d), AsCudaComplex(du), + AsCudaComplex(B), ldb, bufferSizeInBytes)); return Status::OK(); } -#define GTSV2_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \ - template <> \ - Status CudaSparse::Gtsv2BufferSizeExt( \ - int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \ - const Scalar* B, int ldb, size_t* bufferSizeInBytes) const { \ - DCHECK(initialized_); \ - return Gtsv2BufferSizeExtImpl( \ - SPARSE_FN(gtsv2_bufferSizeExt, sparse_prefix), *cusparse_handle_, m, \ - n, dl, d, du, B, ldb, bufferSizeInBytes); \ +#define GTSV2_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \ + template <> \ + Status GpuSparse::Gtsv2BufferSizeExt( \ + int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \ + const Scalar* B, int ldb, size_t* bufferSizeInBytes) const { \ + DCHECK(initialized_); \ + return Gtsv2BufferSizeExtImpl( \ + SPARSE_FN(gtsv2_bufferSizeExt, sparse_prefix), *gpusparse_handle_, m, \ + n, dl, d, du, B, ldb, bufferSizeInBytes); \ } TF_CALL_LAPACK_TYPES(GTSV2_BUFFER_SIZE_INSTANCE); #define GTSV2_NO_PIVOT_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \ template <> \ - Status CudaSparse::Gtsv2NoPivotBufferSizeExt( \ + Status GpuSparse::Gtsv2NoPivotBufferSizeExt( \ int m, int n, const Scalar* dl, const Scalar* d, const 
Scalar* du, \ const Scalar* B, int ldb, size_t* bufferSizeInBytes) const { \ DCHECK(initialized_); \ return Gtsv2BufferSizeExtImpl( \ SPARSE_FN(gtsv2_nopivot_bufferSizeExt, sparse_prefix), \ - *cusparse_handle_, m, n, dl, d, du, B, ldb, bufferSizeInBytes); \ + *gpusparse_handle_, m, n, dl, d, du, B, ldb, bufferSizeInBytes); \ } TF_CALL_LAPACK_TYPES(GTSV2_NO_PIVOT_BUFFER_SIZE_INSTANCE); @@ -342,7 +341,7 @@ static inline Status Gtsv2StridedBatchImpl(SparseFn op, const Scalar* d, const Scalar* du, Scalar* x, int batchCount, int batchStride, void* pBuffer) { - TF_RETURN_IF_CUSPARSE_ERROR(op( + TF_RETURN_IF_GPUSPARSE_ERROR(op( cusparse_handle, m, AsCudaComplex(dl), AsCudaComplex(d), AsCudaComplex(du), AsCudaComplex(x), batchCount, batchStride, pBuffer)); return Status::OK(); @@ -350,12 +349,12 @@ static inline Status Gtsv2StridedBatchImpl(SparseFn op, #define GTSV2_STRIDED_BATCH_INSTANCE(Scalar, sparse_prefix) \ template <> \ - Status CudaSparse::Gtsv2StridedBatch( \ + Status GpuSparse::Gtsv2StridedBatch( \ int m, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* x, \ int batchCount, int batchStride, void* pBuffer) const { \ DCHECK(initialized_); \ return Gtsv2StridedBatchImpl(SPARSE_FN(gtsv2StridedBatch, sparse_prefix), \ - *cusparse_handle_, m, dl, d, du, x, \ + *gpusparse_handle_, m, dl, d, du, x, \ batchCount, batchStride, pBuffer); \ } @@ -366,30 +365,30 @@ static inline Status Gtsv2StridedBatchBufferSizeImpl( SparseFn op, cusparseHandle_t cusparse_handle, int m, const Scalar* dl, const Scalar* d, const Scalar* du, const Scalar* x, int batchCount, int batchStride, size_t* bufferSizeInBytes) { - TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl), - AsCudaComplex(d), AsCudaComplex(du), - AsCudaComplex(x), batchCount, batchStride, - bufferSizeInBytes)); + TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl), + AsCudaComplex(d), AsCudaComplex(du), + AsCudaComplex(x), batchCount, batchStride, + bufferSizeInBytes)); return Status::OK(); } #define GTSV2_STRIDED_BATCH_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \ template <> \ - Status CudaSparse::Gtsv2StridedBatchBufferSizeExt( \ + Status GpuSparse::Gtsv2StridedBatchBufferSizeExt( \ int m, const Scalar* dl, const Scalar* d, const Scalar* du, \ const Scalar* x, int batchCount, int batchStride, \ size_t* bufferSizeInBytes) const { \ DCHECK(initialized_); \ return Gtsv2StridedBatchBufferSizeImpl( \ SPARSE_FN(gtsv2StridedBatch_bufferSizeExt, sparse_prefix), \ - *cusparse_handle_, m, dl, d, du, x, batchCount, batchStride, \ + *gpusparse_handle_, m, dl, d, du, x, batchCount, batchStride, \ bufferSizeInBytes); \ } TF_CALL_LAPACK_TYPES(GTSV2_STRIDED_BATCH_BUFFER_SIZE_INSTANCE); -Status CudaSparse::Coo2csr(const int* cooRowInd, int nnz, int m, - int* csrRowPtr) const { +Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m, + int* csrRowPtr) const { // cusparseStatus_t CUSPARSEAPI cusparseXcoo2csr(cusparseHandle_t handle, // const int *cooRowInd, // int nnz, @@ -398,14 +397,14 @@ Status CudaSparse::Coo2csr(const int* cooRowInd, int nnz, int m, // cusparseIndexBase_t // idxBase); DCHECK(initialized_); - TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcoo2csr(*cusparse_handle_, cooRowInd, - nnz, m, csrRowPtr, - CUSPARSE_INDEX_BASE_ZERO)); + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcoo2csr(*gpusparse_handle_, cooRowInd, + nnz, m, csrRowPtr, + CUSPARSE_INDEX_BASE_ZERO)); return Status::OK(); } -Status CudaSparse::Csr2coo(const int* csrRowPtr, int nnz, int m, - int* cooRowInd) const { +Status GpuSparse::Csr2coo(const 
int* csrRowPtr, int nnz, int m, + int* cooRowInd) const { // cusparseStatus_t CUSPARSEAPI cusparseXcsr2coo(cusparseHandle_t handle, // const int *csrRowPtr, // int nnz, @@ -414,26 +413,26 @@ Status CudaSparse::Csr2coo(const int* csrRowPtr, int nnz, int m, // cusparseIndexBase_t // idxBase); DCHECK(initialized_); - TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcsr2coo(*cusparse_handle_, csrRowPtr, - nnz, m, cooRowInd, - CUSPARSE_INDEX_BASE_ZERO)); + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsr2coo(*gpusparse_handle_, csrRowPtr, + nnz, m, cooRowInd, + CUSPARSE_INDEX_BASE_ZERO)); return Status::OK(); } -Status CudaSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA, - int nnzA, const int* csrSortedRowPtrA, - const int* csrSortedColIndA, - const cusparseMatDescr_t descrB, int nnzB, - const int* csrSortedRowPtrB, - const int* csrSortedColIndB, - const cusparseMatDescr_t descrC, - int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) { +Status GpuSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA, + int nnzA, const int* csrSortedRowPtrA, + const int* csrSortedColIndA, + const cusparseMatDescr_t descrB, int nnzB, + const int* csrSortedRowPtrB, + const int* csrSortedColIndB, + const cusparseMatDescr_t descrC, + int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) { DCHECK(initialized_); DCHECK(nnzTotalDevHostPtr != nullptr); - TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcsrgeamNnz( - *cusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA, csrSortedColIndA, - descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB, descrC, - csrSortedRowPtrC, nnzTotalDevHostPtr)); + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeamNnz( + *gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB, + descrC, csrSortedRowPtrC, nnzTotalDevHostPtr)); return Status::OK(); } @@ -452,7 +451,7 @@ static inline Status CsrmmImpl( // const float* csrSortedValA, const int* csrSortedRowPtrA, // const int* csrSortedColIndA, const float* B, int ldb, const float* // beta, float* C, int ldc); - TF_RETURN_IF_CUSPARSE_ERROR(op( + TF_RETURN_IF_GPUSPARSE_ERROR(op( cusparse_handle, transA, transB, m, n, k, nnz, AsCudaComplex(alpha_host), descrA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA, AsCudaComplex(B), ldb, AsCudaComplex(beta_host), AsCudaComplex(C), ldc)); @@ -461,7 +460,7 @@ static inline Status CsrmmImpl( #define CSRMM_INSTANCE(Scalar, sparse_prefix) \ template <> \ - Status CudaSparse::Csrmm( \ + Status GpuSparse::Csrmm( \ cusparseOperation_t transA, cusparseOperation_t transB, int m, int n, \ int k, int nnz, const Scalar* alpha_host, \ const cusparseMatDescr_t descrA, const Scalar* csrSortedValA, \ @@ -470,7 +469,7 @@ static inline Status CsrmmImpl( const { \ DCHECK(initialized_); \ return CsrmmImpl(SPARSE_FN(csrmm2, sparse_prefix), context_, \ - *cusparse_handle_, transA, transB, m, n, k, nnz, \ + *gpusparse_handle_, transA, transB, m, n, k, nnz, \ alpha_host, descrA, csrSortedValA, csrSortedRowPtrA, \ csrSortedColIndA, B, ldb, beta_host, C, ldc); \ } @@ -484,7 +483,7 @@ static inline Status CsrmvImpl( const cusparseMatDescr_t descrA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, const Scalar* x, const Scalar* beta_host, Scalar* y) { - TF_RETURN_IF_CUSPARSE_ERROR( + TF_RETURN_IF_GPUSPARSE_ERROR( op(cusparse_handle, transA, m, n, nnz, AsCudaComplex(alpha_host), descrA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA, AsCudaComplex(x), AsCudaComplex(beta_host), AsCudaComplex(y))); @@ -494,7 +493,7 
@@ static inline Status CsrmvImpl( // TODO(ebrevdo,rmlarsen): Use csrmv_mp for all cases when available in CUDA 9. #define CSRMV_INSTANCE(Scalar, sparse_prefix) \ template <> \ - Status CudaSparse::Csrmv( \ + Status GpuSparse::Csrmv( \ cusparseOperation_t transA, int m, int n, int nnz, \ const Scalar* alpha_host, const cusparseMatDescr_t descrA, \ const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \ @@ -503,12 +502,12 @@ static inline Status CsrmvImpl( DCHECK(initialized_); \ if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) { \ return CsrmvImpl(SPARSE_FN(csrmv_mp, sparse_prefix), context_, \ - *cusparse_handle_, transA, m, n, nnz, alpha_host, \ + *gpusparse_handle_, transA, m, n, nnz, alpha_host, \ descrA, csrSortedValA, csrSortedRowPtrA, \ csrSortedColIndA, x, beta_host, y); \ } else { \ return CsrmvImpl(SPARSE_FN(csrmv, sparse_prefix), context_, \ - *cusparse_handle_, transA, m, n, nnz, alpha_host, \ + *gpusparse_handle_, transA, m, n, nnz, alpha_host, \ descrA, csrSortedValA, csrSortedRowPtrA, \ csrSortedColIndA, x, beta_host, y); \ } \ @@ -526,7 +525,7 @@ static inline Status CsrgeamImpl( const int* csrSortedRowPtrB, const int* csrSortedColIndB, const cusparseMatDescr_t descrC, Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) { - TF_RETURN_IF_CUSPARSE_ERROR( + TF_RETURN_IF_GPUSPARSE_ERROR( op(cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA, AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB), @@ -537,7 +536,7 @@ static inline Status CsrgeamImpl( #define CSRGEAM_INSTANCE(Scalar, sparse_prefix) \ template <> \ - Status CudaSparse::Csrgeam( \ + Status GpuSparse::Csrgeam( \ int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \ int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \ const int* csrSortedColIndA, const Scalar* beta, \ @@ -547,7 +546,7 @@ static inline Status CsrgeamImpl( int* csrSortedRowPtrC, int* csrSortedColIndC) { \ DCHECK(initialized_); \ return CsrgeamImpl(SPARSE_FN(csrgeam, sparse_prefix), context_, \ - *cusparse_handle_, m, n, alpha, descrA, nnzA, \ + *gpusparse_handle_, m, n, alpha, descrA, nnzA, \ csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \ beta, descrB, nnzB, csrSortedValB, csrSortedRowPtrB, \ csrSortedColIndB, descrC, csrSortedValC, \ @@ -556,7 +555,7 @@ static inline Status CsrgeamImpl( TF_CALL_LAPACK_TYPES(CSRGEAM_INSTANCE); -Status CudaSparse::CsrgemmNnz( +Status GpuSparse::CsrgemmNnz( cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n, const cusparseMatDescr_t descrA, int nnzA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, @@ -565,8 +564,8 @@ Status CudaSparse::CsrgemmNnz( int* nnzTotalDevHostPtr) { DCHECK(initialized_); DCHECK(nnzTotalDevHostPtr != nullptr); - TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcsrgemmNnz( - *cusparse_handle_, transA, transB, m, k, n, descrA, nnzA, + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgemmNnz( + *gpusparse_handle_, transA, transB, m, k, n, descrA, nnzA, csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr)); return Status::OK(); @@ -582,7 +581,7 @@ static inline Status CsrgemmImpl( const int* csrSortedRowPtrB, const int* csrSortedColIndB, const cusparseMatDescr_t descrC, Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) { - TF_RETURN_IF_CUSPARSE_ERROR( + TF_RETURN_IF_GPUSPARSE_ERROR( op(cusparse_handle, transA, 
transB, m, k, n, descrA, nnzA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB, @@ -593,7 +592,7 @@ static inline Status CsrgemmImpl( #define CSRGEMM_INSTANCE(Scalar, sparse_prefix) \ template <> \ - Status CudaSparse::Csrgemm( \ + Status GpuSparse::Csrgemm( \ cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, \ int n, const cusparseMatDescr_t descrA, int nnzA, \ const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \ @@ -603,7 +602,7 @@ static inline Status CsrgemmImpl( Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) { \ DCHECK(initialized_); \ return CsrgemmImpl(SPARSE_FN(csrgemm, sparse_prefix), context_, \ - *cusparse_handle_, transA, transB, m, k, n, descrA, \ + *gpusparse_handle_, transA, transB, m, k, n, descrA, \ nnzA, csrSortedValA, csrSortedRowPtrA, \ csrSortedColIndA, descrB, nnzB, csrSortedValB, \ csrSortedRowPtrB, csrSortedColIndB, descrC, \ @@ -620,12 +619,12 @@ static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op, const cusparseMatDescr_t descrA, Scalar* csrVal, const int* csrRowPtr, int* csrColInd) { - CudaSparseCsrSortingConversionInfo info; + GpuSparseCsrSortingConversionInfo info; TF_RETURN_IF_ERROR(info.Initialize()); size_t pBufferSizeInBytes = 0; - TF_RETURN_IF_CUSPARSE_ERROR( + TF_RETURN_IF_GPUSPARSE_ERROR( buffer_size_op(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal), csrRowPtr, csrColInd, info.info(), &pBufferSizeInBytes)); @@ -636,22 +635,22 @@ static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op, auto pBuffer = pBuffer_t.flat(); DCHECK(pBuffer.data() != nullptr); - TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, nnz, descrA, - AsCudaComplex(csrVal), csrRowPtr, csrColInd, - info.info(), pBuffer.data())); + TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, nnz, descrA, + AsCudaComplex(csrVal), csrRowPtr, csrColInd, + info.info(), pBuffer.data())); return Status::OK(); } #define CSRU2CSR_INSTANCE(Scalar, sparse_prefix) \ template <> \ - Status CudaSparse::Csru2csr( \ + Status GpuSparse::Csru2csr( \ int m, int n, int nnz, const cusparseMatDescr_t descrA, Scalar* csrVal, \ const int* csrRowPtr, int* csrColInd) { \ DCHECK(initialized_); \ return Csru2csrImpl(SPARSE_FN(csru2csr, sparse_prefix), \ BUFSIZE_FN(csru2csr, sparse_prefix), context_, \ - *cusparse_handle_, m, n, nnz, descrA, csrVal, \ + *gpusparse_handle_, m, n, nnz, descrA, csrVal, \ csrRowPtr, csrColInd); \ } @@ -664,22 +663,22 @@ static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context, const int* csrRowPtr, const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, const cusparseAction_t copyValues) { - TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, nnz, - AsCudaComplex(csrVal), csrRowPtr, csrColInd, - AsCudaComplex(cscVal), cscRowInd, cscColPtr, - copyValues, CUSPARSE_INDEX_BASE_ZERO)); + TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, nnz, + AsCudaComplex(csrVal), csrRowPtr, csrColInd, + AsCudaComplex(cscVal), cscRowInd, cscColPtr, + copyValues, CUSPARSE_INDEX_BASE_ZERO)); return Status::OK(); } #define CSR2CSC_INSTANCE(Scalar, sparse_prefix) \ template <> \ - Status CudaSparse::Csr2csc( \ + Status GpuSparse::Csr2csc( \ int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \ const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \ const cusparseAction_t copyValues) { \ DCHECK(initialized_); \ return Csr2cscImpl(SPARSE_FN(csr2csc, sparse_prefix), context_, \ 
- *cusparse_handle_, m, n, nnz, csrVal, csrRowPtr, \ + *gpusparse_handle_, m, n, nnz, csrVal, csrRowPtr, \ csrColInd, cscVal, cscRowInd, cscColPtr, copyValues); \ } diff --git a/tensorflow/core/kernels/cuda_sparse.h b/tensorflow/core/kernels/cuda_sparse.h index f2ef99c67e6f7b..6d042cf48c5a2a 100644 --- a/tensorflow/core/kernels/cuda_sparse.h +++ b/tensorflow/core/kernels/cuda_sparse.h @@ -16,15 +16,38 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_ #define TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_ -// This header declares the class CudaSparse, which contains wrappers of +// This header declares the class GpuSparse, which contains wrappers of // cuSparse libraries for use in TensorFlow kernels. -#ifdef GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include #include +#if GOOGLE_CUDA + #include "third_party/gpus/cuda/include/cusparse.h" + +using gpusparseStatus_t = cusparseStatus_t; +using gpusparseOperation_t = cusparseOperation_t; +using gpusparseMatDescr_t = cusparseMatDescr_t; +using gpusparseAction_t = cusparseAction_t; +using gpusparseHandle_t = cusparseHandle_t; +using gpuStream_t = cudaStream_t; + +#elif TENSORFLOW_USE_ROCM + +#include "rocm/include/hipsparse/hipsparse.h" + +using gpusparseStatus_t = hipsparseStatus_t; +using gpusparseOperation_t = hipsparseOperation_t; +using gpusparseMatDescr_t = hipsparseMatDescr_t; +using gpusparseAction_t = hipsparseAction_t; +using gpusparseHandle_t = hipsparseHandle_t; +using gpuStream_t = hipStream_t; + +#endif + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" @@ -40,13 +63,15 @@ limitations under the License. namespace tensorflow { -inline string ConvertCUSparseErrorToString(const cusparseStatus_t status) { +inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) { switch (status) { #define STRINGIZE(q) #q #define RETURN_IF_STATUS(err) \ case err: \ return STRINGIZE(err); +#if GOOGLE_CUDA + RETURN_IF_STATUS(CUSPARSE_STATUS_SUCCESS) RETURN_IF_STATUS(CUSPARSE_STATUS_NOT_INITIALIZED) RETURN_IF_STATUS(CUSPARSE_STATUS_ALLOC_FAILED) @@ -57,27 +82,62 @@ inline string ConvertCUSparseErrorToString(const cusparseStatus_t status) { RETURN_IF_STATUS(CUSPARSE_STATUS_INTERNAL_ERROR) RETURN_IF_STATUS(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED) -#undef RETURN_IF_STATUS -#undef STRINGIZE default: return strings::StrCat("Unknown CUSPARSE error: ", static_cast(status)); +#elif TENSORFLOW_USE_ROCM + + RETURN_IF_STATUS(HIPSPARSE_STATUS_SUCCESS) + RETURN_IF_STATUS(HIPSPARSE_STATUS_NOT_INITIALIZED) + RETURN_IF_STATUS(HIPSPARSE_STATUS_ALLOC_FAILED) + RETURN_IF_STATUS(HIPSPARSE_STATUS_INVALID_VALUE) + RETURN_IF_STATUS(HIPSPARSE_STATUS_ARCH_MISMATCH) + RETURN_IF_STATUS(HIPSPARSE_STATUS_MAPPING_ERROR) + RETURN_IF_STATUS(HIPSPARSE_STATUS_EXECUTION_FAILED) + RETURN_IF_STATUS(HIPSPARSE_STATUS_INTERNAL_ERROR) + RETURN_IF_STATUS(HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED) + RETURN_IF_STATUS(HIPSPARSE_STATUS_ZERO_PIVOT) + + default: + return strings::StrCat("Unknown hipSPARSE error: ", + static_cast(status)); +#endif + +#undef RETURN_IF_STATUS +#undef STRINGIZE } } -#define TF_RETURN_IF_CUSPARSE_ERROR(expr) \ +#if GOOGLE_CUDA + +#define TF_RETURN_IF_GPUSPARSE_ERROR(expr) \ do { \ auto status = (expr); \ if (TF_PREDICT_FALSE(status != CUSPARSE_STATUS_SUCCESS)) { \ return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \ "): cuSparse call failed with status ", \ - ConvertCUSparseErrorToString(status)); \ + 
ConvertGPUSparseErrorToString(status)); \ } \ } while (0) -inline cusparseOperation_t TransposeAndConjugateToCuSparseOp(bool transpose, - bool conjugate, - Status* status) { +#elif TENSORFLOW_USE_ROCM + +#define TF_RETURN_IF_GPUSPARSE_ERROR(expr) \ + do { \ + auto status = (expr); \ + if (TF_PREDICT_FALSE(status != HIPSPARSE_STATUS_SUCCESS)) { \ + return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \ + "): hipSPARSE call failed with status ", \ + ConvertGPUSparseErrorToString(status)); \ + } \ + } while (0) + +#endif + +inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose, + bool conjugate, + Status* status) { +#if GOOGLE_CUDA if (transpose) { return conjugate ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE; @@ -89,25 +149,38 @@ inline cusparseOperation_t TransposeAndConjugateToCuSparseOp(bool transpose, } return CUSPARSE_OPERATION_NON_TRANSPOSE; } +#elif TENSORFLOW_USE_ROCM + if (transpose) { + return conjugate ? HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE + : HIPSPARSE_OPERATION_TRANSPOSE; + } else { + if (conjugate) { + DCHECK(status != nullptr); + *status = errors::InvalidArgument( + "Conjugate == True and transpose == False is not supported."); + } + return HIPSPARSE_OPERATION_NON_TRANSPOSE; + } +#endif } -// The CudaSparse class provides a simplified templated API for cuSparse +// The GpuSparse class provides a simplified templated API for cuSparse // (http://docs.nvidia.com/cuda/cusparse/index.html). // An object of this class wraps static cuSparse instances, // and will launch Cuda kernels on the stream wrapped by the GPU device // in the OpKernelContext provided to the constructor. // // Notice: All the computational member functions are asynchronous and simply -// launch one or more Cuda kernels on the Cuda stream wrapped by the CudaSparse +// launch one or more Cuda kernels on the Cuda stream wrapped by the GpuSparse // object. -class CudaSparse { +class GpuSparse { public: // This object stores a pointer to context, which must outlive it. - explicit CudaSparse(OpKernelContext* context); - virtual ~CudaSparse() {} + explicit GpuSparse(OpKernelContext* context); + virtual ~GpuSparse() {} - // This initializes the CudaSparse class if it hasn't + // This initializes the GpuSparse class if it hasn't // been initialized yet. All following public methods require the // class has been initialized. Can be run multiple times; all // subsequent calls after the first have no effect. @@ -218,9 +291,9 @@ class CudaSparse { // // **NOTE** This is an in-place operation for data in C. template - Status Csrmm(cusparseOperation_t transA, cusparseOperation_t transB, int m, + Status Csrmm(gpusparseOperation_t transA, gpusparseOperation_t transB, int m, int n, int k, int nnz, const Scalar* alpha_host, - const cusparseMatDescr_t descrA, const Scalar* csrSortedValA, + const gpusparseMatDescr_t descrA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C, int ldc) const; @@ -231,8 +304,8 @@ class CudaSparse { // // **NOTE** This is an in-place operation for data in y. 
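// [Illustrative sketch added for exposition; not part of this patch.]
// Shows how a kernel could drive the backend-neutral Csrmv declared just
// below: TransposeAndConjugateToGpuSparseOp() picks the matching
// CUSPARSE_/HIPSPARSE_OPERATION_* value, and GpuSparse dispatches to
// cuSPARSE or hipSPARSE depending on the build. The function name and
// argument plumbing are hypothetical; device buffers, the matrix
// descriptor, and the usual framework includes are assumed to be in place.
Status CsrmvSketch(OpKernelContext* ctx, bool transpose, bool conjugate,
                   int m, int n, int nnz, const float* alpha_host,
                   const gpusparseMatDescr_t descrA, const float* csr_values,
                   const int* csr_row_ptr, const int* csr_col_ind,
                   const float* x, const float* beta_host, float* y) {
  Status status;
  const gpusparseOperation_t transA =
      TransposeAndConjugateToGpuSparseOp(transpose, conjugate, &status);
  TF_RETURN_IF_ERROR(status);  // conjugate-without-transpose is rejected
  GpuSparse gpu_sparse(ctx);
  TF_RETURN_IF_ERROR(gpu_sparse.Initialize());  // per-stream handle lookup
  // Computes y = alpha * op(A) * x + beta * y asynchronously on the GPU
  // stream wrapped by ctx.
  return gpu_sparse.Csrmv(transA, m, n, nnz, alpha_host, descrA, csr_values,
                          csr_row_ptr, csr_col_ind, x, beta_host, y);
}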
template - Status Csrmv(cusparseOperation_t transA, int m, int n, int nnz, - const Scalar* alpha_host, const cusparseMatDescr_t descrA, + Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz, + const Scalar* alpha_host, const gpusparseMatDescr_t descrA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, const Scalar* x, const Scalar* beta_host, Scalar* y) const; @@ -242,11 +315,11 @@ class CudaSparse { // output. csrSortedRowPtrC must be preallocated on device with // m + 1 entries. See: // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam. - Status CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA, int nnzA, + Status CsrgeamNnz(int m, int n, const gpusparseMatDescr_t descrA, int nnzA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, - const cusparseMatDescr_t descrB, int nnzB, + const gpusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB, const int* csrSortedColIndB, - const cusparseMatDescr_t descrC, int* csrSortedRowPtrC, + const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC, int* nnzTotalDevHostPtr); // Computes sparse - sparse matrix addition of matrices @@ -256,12 +329,12 @@ class CudaSparse { // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam. template Status Csrgeam(int m, int n, const Scalar* alpha, - const cusparseMatDescr_t descrA, int nnzA, + const gpusparseMatDescr_t descrA, int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, const Scalar* beta, - const cusparseMatDescr_t descrB, int nnzB, + const gpusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, const int* csrSortedRowPtrB, - const int* csrSortedColIndB, const cusparseMatDescr_t descrC, + const int* csrSortedColIndB, const gpusparseMatDescr_t descrC, Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC); @@ -270,13 +343,13 @@ class CudaSparse { // output. csrSortedRowPtrC must be preallocated on device with // m + 1 entries. See: // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm. - Status CsrgemmNnz(cusparseOperation_t transA, cusparseOperation_t transB, - int m, int k, int n, const cusparseMatDescr_t descrA, + Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB, + int m, int k, int n, const gpusparseMatDescr_t descrA, int nnzA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, - const cusparseMatDescr_t descrB, int nnzB, + const gpusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB, const int* csrSortedColIndB, - const cusparseMatDescr_t descrC, int* csrSortedRowPtrC, + const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC, int* nnzTotalDevHostPtr); // Computes sparse - sparse matrix matmul of matrices @@ -285,19 +358,20 @@ class CudaSparse { // with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See: // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm. 
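// [Illustrative sketch added for exposition; not part of this patch.]
// The two-phase contract documented above for sparse-sparse matmul:
// CsrgemmNnz() fills the preallocated csrSortedRowPtrC (m + 1 entries) and
// reports nnz(C); the caller then allocates value/column-index storage of
// that size and calls Csrgemm(). Only float is shown; the function name and
// parameter names are hypothetical, and the descriptors plus the A/B arrays
// are assumed to be prepared by the caller.
Status CsrgemmSketch(OpKernelContext* ctx, gpusparseOperation_t transA,
                     gpusparseOperation_t transB, int m, int k, int n,
                     const gpusparseMatDescr_t descrA, int nnzA,
                     const float* valA, const int* rowPtrA, const int* colIndA,
                     const gpusparseMatDescr_t descrB, int nnzB,
                     const float* valB, const int* rowPtrB, const int* colIndB,
                     const gpusparseMatDescr_t descrC,
                     int* rowPtrC /* device, m + 1 entries */) {
  GpuSparse gpu_sparse(ctx);
  TF_RETURN_IF_ERROR(gpu_sparse.Initialize());
  // Phase 1: count the output nonzeros and fill rowPtrC.
  int nnzC = -1;  // written through the host pointer below
  TF_RETURN_IF_ERROR(gpu_sparse.CsrgemmNnz(transA, transB, m, k, n, descrA,
                                           nnzA, rowPtrA, colIndA, descrB,
                                           nnzB, rowPtrB, colIndB, descrC,
                                           rowPtrC, &nnzC));
  // Phase 2: allocate csrSortedValC / csrSortedColIndC with nnzC entries,
  // then compute the product values and column indices.
  Tensor valC_t, colIndC_t;
  TF_RETURN_IF_ERROR(
      ctx->allocate_temp(DT_FLOAT, TensorShape({nnzC}), &valC_t));
  TF_RETURN_IF_ERROR(
      ctx->allocate_temp(DT_INT32, TensorShape({nnzC}), &colIndC_t));
  return gpu_sparse.Csrgemm(transA, transB, m, k, n, descrA, nnzA, valA,
                            rowPtrA, colIndA, descrB, nnzB, valB, rowPtrB,
                            colIndB, descrC, valC_t.flat<float>().data(),
                            rowPtrC, colIndC_t.flat<int32>().data());
}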
template - Status Csrgemm(cusparseOperation_t transA, cusparseOperation_t transB, int m, - int k, int n, const cusparseMatDescr_t descrA, int nnzA, - const Scalar* csrSortedValA, const int* csrSortedRowPtrA, - const int* csrSortedColIndA, const cusparseMatDescr_t descrB, - int nnzB, const Scalar* csrSortedValB, - const int* csrSortedRowPtrB, const int* csrSortedColIndB, - const cusparseMatDescr_t descrC, Scalar* csrSortedValC, - int* csrSortedRowPtrC, int* csrSortedColIndC); + Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB, + int m, int k, int n, const gpusparseMatDescr_t descrA, + int nnzA, const Scalar* csrSortedValA, + const int* csrSortedRowPtrA, const int* csrSortedColIndA, + const gpusparseMatDescr_t descrB, int nnzB, + const Scalar* csrSortedValB, const int* csrSortedRowPtrB, + const int* csrSortedColIndB, const gpusparseMatDescr_t descrC, + Scalar* csrSortedValC, int* csrSortedRowPtrC, + int* csrSortedColIndC); // In-place reordering of unsorted CSR to sorted CSR. // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr template - Status Csru2csr(int m, int n, int nnz, const cusparseMatDescr_t descrA, + Status Csru2csr(int m, int n, int nnz, const gpusparseMatDescr_t descrA, Scalar* csrVal, const int* csrRowPtr, int* csrColInd); // Converts from CSR to CSC format (equivalently, transpose). @@ -306,30 +380,30 @@ class CudaSparse { Status Csr2csc(int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, - const cusparseAction_t copyValues); + const gpusparseAction_t copyValues); private: bool initialized_; OpKernelContext *context_; // not owned. - cudaStream_t cuda_stream_; - cusparseHandle_t *cusparse_handle_; // not owned. + gpuStream_t gpu_stream_; + gpusparseHandle_t* gpusparse_handle_; // not owned. - TF_DISALLOW_COPY_AND_ASSIGN(CudaSparse); + TF_DISALLOW_COPY_AND_ASSIGN(GpuSparse); }; // A wrapper class to ensure that a CUDA sparse matrix descriptor is initialized -// only once. For more details on the descriptor (cusparseMatDescr_t), see: +// only once. For more details on the descriptor (gpusparseMatDescr_t), see: // https://docs.nvidia.com/cuda/cusparse/index.html#cusparsematdescrt -class CudaSparseMatrixDescriptor { +class GpuSparseMatrixDescriptor { public: - explicit CudaSparseMatrixDescriptor() : initialized_(false) {} + explicit GpuSparseMatrixDescriptor() : initialized_(false) {} - CudaSparseMatrixDescriptor(CudaSparseMatrixDescriptor&& rhs) + GpuSparseMatrixDescriptor(GpuSparseMatrixDescriptor&& rhs) : initialized_(rhs.initialized_), descr_(std::move(rhs.descr_)) { rhs.initialized_ = false; } - CudaSparseMatrixDescriptor& operator=(CudaSparseMatrixDescriptor&& rhs) { + GpuSparseMatrixDescriptor& operator=(GpuSparseMatrixDescriptor&& rhs) { if (this == &rhs) return *this; Release(); initialized_ = rhs.initialized_; @@ -338,23 +412,27 @@ class CudaSparseMatrixDescriptor { return *this; } - ~CudaSparseMatrixDescriptor() { Release(); } + ~GpuSparseMatrixDescriptor() { Release(); } // Initializes the underlying descriptor. Will fail on the second call if // called more than once. 
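// [Illustrative sketch added for exposition; not part of this patch.]
// Intended lifetime of the move-only GpuSparseMatrixDescriptor wrapper:
// default-construct, call Initialize() exactly once, hand descr() to the
// GpuSparse routines (Csrmm, Csrmv, Csrgeam, ...), and let the destructor
// release the underlying cuSPARSE/hipSPARSE descriptor. The function name
// is hypothetical.
Status DescriptorSketch(OpKernelContext* ctx) {
  GpuSparseMatrixDescriptor descrA;
  TF_RETURN_IF_ERROR(descrA.Initialize());  // *CreateMatDescr under the hood
  GpuSparse gpu_sparse(ctx);
  TF_RETURN_IF_ERROR(gpu_sparse.Initialize());
  const gpusparseMatDescr_t& descr = descrA.descr();  // backend-neutral handle
  (void)descr;  // would be forwarded to Csrmm/Csrmv/Csrgeam/Csrgemm calls
  return Status::OK();
}  // descrA released here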
Status Initialize() { DCHECK(!initialized_); - TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateMatDescr(&descr_)); +#if GOOGLE_CUDA + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_)); +#elif TENSORFLOW_USE_ROCM + TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descr_)); +#endif initialized_ = true; return Status::OK(); } - cusparseMatDescr_t& descr() { + gpusparseMatDescr_t& descr() { DCHECK(initialized_); return descr_; } - const cusparseMatDescr_t& descr() const { + const gpusparseMatDescr_t& descr() const { DCHECK(initialized_); return descr_; } @@ -362,31 +440,37 @@ class CudaSparseMatrixDescriptor { private: void Release() { if (initialized_) { +#if GOOGLE_CUDA cusparseDestroyMatDescr(descr_); +#elif TENSORFLOW_USE_ROCM + hipsparseDestroyMatDescr(descr_); +#endif initialized_ = false; } } bool initialized_; - cusparseMatDescr_t descr_; + gpusparseMatDescr_t descr_; - TF_DISALLOW_COPY_AND_ASSIGN(CudaSparseMatrixDescriptor); + TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseMatrixDescriptor); }; +#if GOOGLE_CUDA + // A wrapper class to ensure that an unsorted/sorted CSR conversion information // struct (csru2csrInfo_t) is initialized only once. See: // https://docs.nvidia.com/cuda/cusparse/index.html#csru2csr -class CudaSparseCsrSortingConversionInfo { +class GpuSparseCsrSortingConversionInfo { public: - explicit CudaSparseCsrSortingConversionInfo() : initialized_(false) {} + explicit GpuSparseCsrSortingConversionInfo() : initialized_(false) {} - CudaSparseCsrSortingConversionInfo(CudaSparseCsrSortingConversionInfo&& rhs) + GpuSparseCsrSortingConversionInfo(GpuSparseCsrSortingConversionInfo&& rhs) : initialized_(rhs.initialized_), info_(std::move(rhs.info_)) { rhs.initialized_ = false; } - CudaSparseCsrSortingConversionInfo& operator=( - CudaSparseCsrSortingConversionInfo&& rhs) { + GpuSparseCsrSortingConversionInfo& operator=( + GpuSparseCsrSortingConversionInfo&& rhs) { if (this == &rhs) return *this; Release(); initialized_ = rhs.initialized_; @@ -395,13 +479,13 @@ class CudaSparseCsrSortingConversionInfo { return *this; } - ~CudaSparseCsrSortingConversionInfo() { Release(); } + ~GpuSparseCsrSortingConversionInfo() { Release(); } // Initializes the underlying info. Will fail on the second call if called // more than once. Status Initialize() { DCHECK(!initialized_); - TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateCsru2csrInfo(&info_)); + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsru2csrInfo(&info_)); initialized_ = true; return Status::OK(); } @@ -427,11 +511,13 @@ class CudaSparseCsrSortingConversionInfo { bool initialized_; csru2csrInfo_t info_; - TF_DISALLOW_COPY_AND_ASSIGN(CudaSparseCsrSortingConversionInfo); + TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseCsrSortingConversionInfo); }; +#endif // GOOGLE_CUDA + } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #endif // TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_ diff --git a/tensorflow/core/kernels/rocm_sparse.cc b/tensorflow/core/kernels/rocm_sparse.cc new file mode 100644 index 00000000000000..97488692bc1253 --- /dev/null +++ b/tensorflow/core/kernels/rocm_sparse.cc @@ -0,0 +1,330 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if TENSORFLOW_USE_ROCM + +#include +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/cuda_solvers.h" +#include "tensorflow/core/kernels/cuda_sparse.h" +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +// A set of initialized handles to the underlying ROCm libraries used by +// GpuSparse. We maintain one such set of handles per unique stream. +class HipSparseHandles { + public: + explicit HipSparseHandles(hipStream_t stream) + : initialized_(false), stream_(stream) {} + + HipSparseHandles(HipSparseHandles&& rhs) + : initialized_(rhs.initialized_), + stream_(std::move(rhs.stream_)), + hipsparse_handle_(rhs.hipsparse_handle_) { + rhs.initialized_ = false; + } + + HipSparseHandles& operator=(HipSparseHandles&& rhs) { + if (this == &rhs) return *this; + Release(); + stream_ = std::move(rhs.stream_); + hipsparse_handle_ = std::move(rhs.hipsparse_handle_); + initialized_ = rhs.initialized_; + rhs.initialized_ = false; + return *this; + } + + ~HipSparseHandles() { Release(); } + + Status Initialize() { + if (initialized_) return Status::OK(); + TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreate(&hipsparse_handle_)); + TF_RETURN_IF_GPUSPARSE_ERROR( + hipsparseSetStream(hipsparse_handle_, stream_)); + initialized_ = true; + return Status::OK(); + } + + hipsparseHandle_t& handle() { + DCHECK(initialized_); + return hipsparse_handle_; + } + + const hipsparseHandle_t& handle() const { + DCHECK(initialized_); + return hipsparse_handle_; + } + + private: + void Release() { + if (initialized_) { + // This should never return anything other than success + auto err = hipsparseDestroy(hipsparse_handle_); + DCHECK(err == HIPSPARSE_STATUS_SUCCESS) + << "Failed to destroy hipSPARSE instance."; + initialized_ = false; + } + } + bool initialized_; + hipStream_t stream_; + hipsparseHandle_t hipsparse_handle_; + + TF_DISALLOW_COPY_AND_ASSIGN(HipSparseHandles); +}; + +// TODO(ebrevdo): Replace global mutex guarding CudaSparseHandles +// lookup with one of: +// 1. Adding the handle to the CudaStream structure; do the lookup there. +// 2. Add a thread-local cusparse, set it to the current stream +// upon each call. +// #1 seems like the cleanest option but will need to wait until this +// is moved into TF core. +static mutex handle_map_mutex(LINKER_INITIALIZED); + +using HandleMap = std::unordered_map; + +// Returns a singleton map used for storing initialized handles for each unique +// cuda stream. 
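// [Illustrative sketch added for exposition; not part of this patch.]
// Because handles are cached per stream in the singleton map below, only
// the first GpuSparse::Initialize() issued on a given GPU stream pays the
// hipsparseCreate()/hipsparseSetStream() cost; later wrappers constructed
// from kernels running on the same stream reuse the cached handle. The
// function name is hypothetical.
Status HandleReuseSketch(OpKernelContext* ctx) {
  GpuSparse first(ctx);
  TF_RETURN_IF_ERROR(first.Initialize());   // creates and caches the handle
  GpuSparse second(ctx);                    // same underlying GPU stream
  TF_RETURN_IF_ERROR(second.Initialize());  // cache hit; no hipsparseCreate()
  return Status::OK();
}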
+HandleMap* GetHandleMapSingleton() { + static HandleMap* cm = new HandleMap; + return cm; +} + +} // namespace + +GpuSparse::GpuSparse(OpKernelContext* context) + : initialized_(false), context_(context) { + auto hip_stream_ptr = + reinterpret_cast(context->op_device_context() + ->stream() + ->implementation() + ->GpuStreamMemberHack()); + DCHECK(hip_stream_ptr); + gpu_stream_ = *hip_stream_ptr; +} + +Status GpuSparse::Initialize() { + HandleMap* handle_map = GetHandleMapSingleton(); + DCHECK(handle_map); + mutex_lock lock(handle_map_mutex); + auto it = handle_map->find(gpu_stream_); + if (it == handle_map->end()) { + LOG(INFO) << "Creating GpuSparse handles for stream " << gpu_stream_; + // Previously unseen ROCm stream. Initialize a set of ROCm sparse library + // handles for it. + HipSparseHandles new_handles(gpu_stream_); + TF_RETURN_IF_ERROR(new_handles.Initialize()); + it = handle_map->insert(std::make_pair(gpu_stream_, std::move(new_handles))) + .first; + } + gpusparse_handle_ = &it->second.handle(); + initialized_ = true; + return Status::OK(); +} + +// Macro that specializes a sparse method for all 4 standard +// numeric types. +#define TF_CALL_HIP_LAPACK_TYPES(m) m(float, S) m(double, D) + +// Macros to construct hipsparse method names. +#define SPARSE_FN(method, sparse_prefix) hipsparse##sparse_prefix##method + +Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m, + int* csrRowPtr) const { + DCHECK(initialized_); + TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcoo2csr(*gpusparse_handle_, cooRowInd, + nnz, m, csrRowPtr, + HIPSPARSE_INDEX_BASE_ZERO)); + return Status::OK(); +} + +Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m, + int* cooRowInd) const { + DCHECK(initialized_); + TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsr2coo(*gpusparse_handle_, csrRowPtr, + nnz, m, cooRowInd, + HIPSPARSE_INDEX_BASE_ZERO)); + return Status::OK(); +} + +template +static inline Status CsrmmImpl( + SparseFnT op, OpKernelContext* context, hipsparseHandle_t hipsparse_handle, + hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n, + int k, int nnz, const Scalar* alpha_host, const hipsparseMatDescr_t descrA, + const Scalar* csrSortedValA, const int* csrSortedRowPtrA, + const int* csrSortedColIndA, const Scalar* B, int ldb, + const Scalar* beta_host, Scalar* C, int ldc) { + TF_RETURN_IF_GPUSPARSE_ERROR(op(hipsparse_handle, transA, transB, m, n, k, + nnz, alpha_host, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, B, ldb, + beta_host, C, ldc)); + return Status::OK(); +} + +#define CSRMM_INSTANCE(Scalar, sparse_prefix) \ + template <> \ + Status GpuSparse::Csrmm( \ + hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n, \ + int k, int nnz, const Scalar* alpha_host, \ + const hipsparseMatDescr_t descrA, const Scalar* csrSortedValA, \ + const int* csrSortedRowPtrA, const int* csrSortedColIndA, \ + const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C, int ldc) \ + const { \ + DCHECK(initialized_); \ + return CsrmmImpl(SPARSE_FN(csrmm2, sparse_prefix), context_, \ + *gpusparse_handle_, transA, transB, m, n, k, nnz, \ + alpha_host, descrA, csrSortedValA, csrSortedRowPtrA, \ + csrSortedColIndA, B, ldb, beta_host, C, ldc); \ + } + +TF_CALL_HIP_LAPACK_TYPES(CSRMM_INSTANCE); + +template +static inline Status CsrmvImpl(SparseFnT op, OpKernelContext* context, + hipsparseHandle_t hipsparse_handle, + hipsparseOperation_t transA, int m, int n, + int nnz, const Scalar* alpha_host, + const hipsparseMatDescr_t descrA, + const Scalar* csrSortedValA, 
+ const int* csrSortedRowPtrA, + const int* csrSortedColIndA, const Scalar* x, + const Scalar* beta_host, Scalar* y) { + TF_RETURN_IF_GPUSPARSE_ERROR( + op(hipsparse_handle, transA, m, n, nnz, alpha_host, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, x, beta_host, y)); + return Status::OK(); +} + +// TODO(ebrevdo,rmlarsen): Use csrmv_mp for all cases when available in CUDA 9. +#define CSRMV_INSTANCE(Scalar, sparse_prefix) \ + template <> \ + Status GpuSparse::Csrmv( \ + hipsparseOperation_t transA, int m, int n, int nnz, \ + const Scalar* alpha_host, const hipsparseMatDescr_t descrA, \ + const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \ + const int* csrSortedColIndA, const Scalar* x, const Scalar* beta_host, \ + Scalar* y) const { \ + DCHECK(initialized_); \ + return CsrmvImpl(SPARSE_FN(csrmv, sparse_prefix), context_, \ + *gpusparse_handle_, transA, m, n, nnz, alpha_host, \ + descrA, csrSortedValA, csrSortedRowPtrA, \ + csrSortedColIndA, x, beta_host, y); \ + } + +TF_CALL_HIP_LAPACK_TYPES(CSRMV_INSTANCE); + +Status GpuSparse::CsrgemmNnz( + hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n, + int k, const hipsparseMatDescr_t descrA, int nnzA, + const int* csrSortedRowPtrA, const int* csrSortedColIndA, + const hipsparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB, + const int* csrSortedColIndB, const hipsparseMatDescr_t descrC, + int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) { + DCHECK(initialized_); + DCHECK(nnzTotalDevHostPtr != nullptr); + TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsrgemmNnz( + *gpusparse_handle_, transA, transB, m, n, k, descrA, nnzA, + csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, + csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr)); + return Status::OK(); +} + +template +static inline Status CsrgemmImpl( + SparseFnT op, OpKernelContext* context, hipsparseHandle_t hipsparse_handle, + hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n, + int k, const hipsparseMatDescr_t descrA, int nnzA, + const Scalar* csrSortedValA, const int* csrSortedRowPtrA, + const int* csrSortedColIndA, const hipsparseMatDescr_t descrB, int nnzB, + const Scalar* csrSortedValB, const int* csrSortedRowPtrB, + const int* csrSortedColIndB, const hipsparseMatDescr_t descrC, + Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) { + TF_RETURN_IF_GPUSPARSE_ERROR( + op(hipsparse_handle, transA, transB, m, n, k, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedValB, + csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC)); + return Status::OK(); +} + +#define CSRGEMM_INSTANCE(Scalar, sparse_prefix) \ + template <> \ + Status GpuSparse::Csrgemm( \ + hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n, \ + int k, const hipsparseMatDescr_t descrA, int nnzA, \ + const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \ + const int* csrSortedColIndA, const hipsparseMatDescr_t descrB, int nnzB, \ + const Scalar* csrSortedValB, const int* csrSortedRowPtrB, \ + const int* csrSortedColIndB, const hipsparseMatDescr_t descrC, \ + Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) { \ + DCHECK(initialized_); \ + return CsrgemmImpl(SPARSE_FN(csrgemm, sparse_prefix), context_, \ + *gpusparse_handle_, transA, transB, m, n, k, descrA, \ + nnzA, csrSortedValA, csrSortedRowPtrA, \ + csrSortedColIndA, descrB, nnzB, csrSortedValB, \ + csrSortedRowPtrB, csrSortedColIndB, descrC, 
\ + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); \ + } + +TF_CALL_HIP_LAPACK_TYPES(CSRGEMM_INSTANCE); + +template +static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context, + hipsparseHandle_t hipsparse_handle, int m, + int n, int nnz, const Scalar* csrVal, + const int* csrRowPtr, const int* csrColInd, + Scalar* cscVal, int* cscRowInd, int* cscColPtr, + const hipsparseAction_t copyValues) { + TF_RETURN_IF_GPUSPARSE_ERROR( + op(hipsparse_handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, cscVal, + cscRowInd, cscColPtr, copyValues, HIPSPARSE_INDEX_BASE_ZERO)); + return Status::OK(); +} + +#define CSR2CSC_INSTANCE(Scalar, sparse_prefix) \ + template <> \ + Status GpuSparse::Csr2csc( \ + int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \ + const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \ + const hipsparseAction_t copyValues) { \ + DCHECK(initialized_); \ + return Csr2cscImpl(SPARSE_FN(csr2csc, sparse_prefix), context_, \ + *gpusparse_handle_, m, n, nnz, csrVal, csrRowPtr, \ + csrColInd, cscVal, cscRowInd, cscColPtr, copyValues); \ + } + +TF_CALL_HIP_LAPACK_TYPES(CSR2CSC_INSTANCE); + +} // namespace tensorflow + +#endif // TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/sparse/BUILD b/tensorflow/core/kernels/sparse/BUILD index befe9c7c5ed64e..6b4dba69ff2f40 100644 --- a/tensorflow/core/kernels/sparse/BUILD +++ b/tensorflow/core/kernels/sparse/BUILD @@ -2,10 +2,10 @@ load( "//tensorflow:tensorflow.bzl", + "if_cuda_or_rocm", "tf_cc_test", "tf_kernel_library", ) -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") package( default_visibility = ["//visibility:public"], @@ -77,7 +77,7 @@ tf_kernel_library( "//tensorflow/core/kernels:scatter_nd_op", "//tensorflow/core/kernels:slice_op", "//tensorflow/core/kernels:transpose_functor", - ] + if_cuda([ + ] + if_cuda_or_rocm([ "//tensorflow/core/kernels:cuda_solvers", "//tensorflow/core/kernels:cuda_sparse", ]), diff --git a/tensorflow/core/kernels/sparse/add_op.cc b/tensorflow/core/kernels/sparse/add_op.cc index 95d69410d45044..81bc7dfdb7de45 100644 --- a/tensorflow/core/kernels/sparse/add_op.cc +++ b/tensorflow/core/kernels/sparse/add_op.cc @@ -15,7 +15,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif @@ -31,7 +31,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/sparse/sparse_matrix.h" #include "tensorflow/core/kernels/fill_functor.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_sparse.h" #endif @@ -233,8 +233,10 @@ class CSRAddOp : public OpKernel { REGISTER_GPU(float) REGISTER_GPU(double) +#if GOOGLE_CUDA REGISTER_GPU(complex64) REGISTER_GPU(complex128) +#endif #undef REGISTER_GPU @@ -246,7 +248,7 @@ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION( #undef REGISTER -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace functor { template struct CSRSparseMatrixAdd @@ -324,10 +326,10 @@ struct CSRSparseMatrixAdd private: OpKernelContext* ctx_; - CudaSparse cuda_sparse_; - CudaSparseMatrixDescriptor descrA_; - CudaSparseMatrixDescriptor descrB_; - CudaSparseMatrixDescriptor descrC_; + GpuSparse cuda_sparse_; + GpuSparseMatrixDescriptor descrA_; + GpuSparseMatrixDescriptor descrB_; + GpuSparseMatrixDescriptor descrC_; const T alpha_; const T beta_; bool initialized_; @@ -337,6 +339,6 @@ struct CSRSparseMatrixAdd } // namespace functor -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse/conj_op.cc b/tensorflow/core/kernels/sparse/conj_op.cc index df1042ab8017df..7275262c1f041b 100644 --- a/tensorflow/core/kernels/sparse/conj_op.cc +++ b/tensorflow/core/kernels/sparse/conj_op.cc @@ -15,7 +15,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/core/kernels/sparse/kernels.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_sparse.h" #endif @@ -92,12 +92,12 @@ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION( CONJ_VARIANT_UNARY_OP, DEVICE_CPU, CSRSparseMatrix, (CSRSparseMatrixUnaryHelper)); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION( CONJ_VARIANT_UNARY_OP, DEVICE_GPU, CSRSparseMatrix, (CSRSparseMatrixUnaryHelper)); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_dense_op.cc b/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_dense_op.cc index 92cb1080ca9c56..9e5a11c4aeb53e 100644 --- a/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_dense_op.cc +++ b/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_dense_op.cc @@ -15,7 +15,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif @@ -33,7 +33,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/sparse/sparse_matrix.h" #include "tensorflow/core/util/work_sharder.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_sparse.h" #endif @@ -220,19 +220,21 @@ REGISTER_CPU(double) REGISTER_CPU(complex64) REGISTER_CPU(complex128) -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM REGISTER_GPU(float) REGISTER_GPU(double) +#if GOOGLE_CUDA REGISTER_GPU(complex64) REGISTER_GPU(complex128) +#endif -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #undef REGISTER_CPU #undef REGISTER_GPU -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace functor { template <> @@ -256,6 +258,6 @@ extern template struct CSRSparseMatrixToCOOSparseMatrix; } // namespace functor -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_sparse_tensor_op.cc b/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_sparse_tensor_op.cc index 237401eaf4b437..55ebfa4fc10f71 100644 --- a/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_sparse_tensor_op.cc +++ b/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_sparse_tensor_op.cc @@ -15,7 +15,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/core/kernels/sparse/sparse_matrix.h" #include "tensorflow/core/util/work_sharder.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_sparse.h" #endif @@ -205,18 +205,20 @@ class CSRSparseMatrixToSparseTensorGPUOp : public OpKernel { .HostMemory("dense_shape"), \ CSRSparseMatrixToSparseTensorGPUOp); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM REGISTER_GPU(float) REGISTER_GPU(double) +#if GOOGLE_CUDA REGISTER_GPU(complex64) REGISTER_GPU(complex128) +#endif -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #undef REGISTER_GPU -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace functor { template <> @@ -240,7 +242,7 @@ extern template struct CSRSparseMatrixToCOOSparseMatrix; } // namespace functor -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_CPU(T) \ REGISTER_KERNEL_BUILDER(Name("CSRSparseMatrixToSparseTensor") \ diff --git a/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc b/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc index 6e0397c8d27022..b42d315789b8b2 100644 --- a/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc +++ b/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc @@ -15,7 +15,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif @@ -32,13 +32,18 @@ limitations under the License. 
#include "tensorflow/core/kernels/sparse/kernels.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_sparse.h" -#include "tensorflow/stream_executor/cuda/cuda_activation.h" +#endif +#if GOOGLE_CUDA +#include "tensorflow/stream_executor/cuda/cuda_activation.h" using ::perftools::gputools::cuda::ScopedActivateExecutorContext; +#elif TENSORFLOW_USE_ROCM +#include "tensorflow/stream_executor/rocm/rocm_activation.h" +using ::perftools::gputools::rocm::ScopedActivateExecutorContext; #endif namespace tensorflow { @@ -138,7 +143,7 @@ REGISTER_CPU(complex128) #undef REGISTER_CPU -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM template class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel { @@ -356,8 +361,10 @@ class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel { REGISTER_GPU(GPU, float) REGISTER_GPU(GPU, double) +#if GOOGLE_CUDA REGISTER_GPU(GPU, complex64) REGISTER_GPU(GPU, complex128) +#endif namespace functor { @@ -380,7 +387,7 @@ struct COOSparseMatrixToCSRSparseMatrix { Status operator()(OpKernelContext* c, const int rows, const int cols, TTypes::UnalignedVec coo_row_ind, TTypes::UnalignedVec csr_row_ptr) { - CudaSparse cuda_sparse(c); + GpuSparse cuda_sparse(c); TF_RETURN_IF_ERROR(cuda_sparse.Initialize()); return cuda_sparse.Coo2csr(coo_row_ind.data(), /*nnz*/ coo_row_ind.size(), @@ -391,7 +398,7 @@ extern template struct COOSparseMatrixToCSRSparseMatrix; } // namespace functor -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #undef REGISTER_GPU diff --git a/tensorflow/core/kernels/sparse/kernels_gpu.cu.cc b/tensorflow/core/kernels/sparse/kernels_gpu.cu.cc index 2890a109b9f0bd..99c6d5b9259325 100644 --- a/tensorflow/core/kernels/sparse/kernels_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse/kernels_gpu.cu.cc @@ -13,15 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#if GOOGLE_CUDA #include "third_party/cub/device/device_histogram.cuh" #include "third_party/cub/iterator/counting_input_iterator.cuh" #include "third_party/cub/iterator/transform_input_iterator.cuh" #include "third_party/gpus/cuda/include/cusparse.h" +#elif TENSORFLOW_USE_ROCM +#include "rocm/include/hipcub/hipcub.hpp" +#endif #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/cuda_sparse.h" @@ -32,6 +36,12 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" +#if GOOGLE_CUDA +namespace gpuprim = ::cub; +#elif TENSORFLOW_USE_ROCM +namespace gpuprim = ::hipcub; +#endif + namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; @@ -65,9 +75,9 @@ Status CalculateNNZPerBatchMatrixFromIndices::operator()( DCHECK_EQ(indices.dimension(1), 3); // batch, row, col const int rank = indices.dimension(1); - cub::CountingInputIterator row_counter(0); - cub::TransformInputIterator> + gpuprim::CountingInputIterator row_counter(0); + gpuprim::TransformInputIterator> indices_first_column(row_counter, StridedDataReader(indices.data(), rank)); @@ -76,7 +86,7 @@ Status CalculateNNZPerBatchMatrixFromIndices::operator()( DCHECK_NE(indices.data(), nullptr); DCHECK_NE(nnz_per_batch.data(), nullptr); - auto first_success = cub::DeviceHistogram::HistogramEven( + auto first_success = gpuprim::DeviceHistogram::HistogramEven( /*d_temp_storage*/ nullptr, /*temp_storage_bytes&*/ temp_storage_bytes, /*d_samples*/ indices_first_column, @@ -87,12 +97,12 @@ Status CalculateNNZPerBatchMatrixFromIndices::operator()( /*num_samples*/ total_nnz, /*stream*/ cu_stream); - if (first_success != cudaSuccess) { + if (first_success != gpuSuccess) { return errors::Internal( "SparseTensorToCSRSparseMatrix: Could not launch " - "cub::DeviceHistogram::HistogramEven " + "gpuprim::DeviceHistogram::HistogramEven " "to calculate temp_storage_bytes, status: ", - cudaGetErrorString(first_success)); + GpuGetErrorString(first_success)); } Tensor temp_storage; @@ -100,7 +110,7 @@ Status CalculateNNZPerBatchMatrixFromIndices::operator()( DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), &temp_storage)); DCHECK_NE(temp_storage.flat().data(), nullptr); - auto second_success = cub::DeviceHistogram::HistogramEven( + auto second_success = gpuprim::DeviceHistogram::HistogramEven( /*d_temp_storage*/ temp_storage.flat().data(), /*temp_storage_bytes&*/ temp_storage_bytes, /*d_samples*/ indices_first_column, @@ -111,12 +121,12 @@ Status CalculateNNZPerBatchMatrixFromIndices::operator()( /*num_samples*/ total_nnz, /*stream*/ cu_stream); - if (second_success != cudaSuccess) { + if (second_success != gpuSuccess) { return errors::Internal( "SparseTensorToCSRSparseMatrix: Could not launch " - "cub::DeviceHistogram::HistogramEven " + "gpuprim::DeviceHistogram::HistogramEven " "to count nnz entries per batch. temp_storage_bytes: ", - temp_storage_bytes, ", status: ", cudaGetErrorString(second_success)); + temp_storage_bytes, ", status: ", GpuGetErrorString(second_success)); } return Status::OK(); @@ -128,11 +138,11 @@ template <> Status CSRSparseMatrixToCOOSparseMatrix::operator()( OpKernelContext* c, TTypes::UnalignedVec csr_row_ptr, TTypes::UnalignedVec coo_row_ind) { - CudaSparse cuda_sparse(c); + GpuSparse gpu_sparse(c); const int nnz = coo_row_ind.size(); - TF_RETURN_IF_ERROR(cuda_sparse.Initialize()); + TF_RETURN_IF_ERROR(gpu_sparse.Initialize()); const int m = csr_row_ptr.size() - 1; // rows - return cuda_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data()); + return gpu_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data()); } template @@ -140,7 +150,7 @@ __global__ void SparseTensorToCOOMatrixKernel(const int64* indices, int* coo_rows_out, int* coo_cols_out, int size) { const int offset = (stride == 3) ? 
1 : 0; - CUDA_1D_KERNEL_LOOP(i, size) { + GPU_1D_KERNEL_LOOP(i, size) { coo_rows_out[i] = static_cast(ldg(indices + i * stride + offset)); coo_cols_out[i] = static_cast(ldg(indices + i * stride + offset + 1)); } @@ -157,20 +167,22 @@ void SparseTensorToCOOSparseMatrix::operator()( const int size = coo_row_ind.dimension(0); GpuLaunchConfig config = GetGpuLaunchConfig(size, d); if (stride == 2) { - SparseTensorToCOOMatrixKernel<2> - <<>>( - indices.data(), coo_row_ind.data(), coo_col_ind.data(), size); + TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<2>, + config.block_count, config.thread_per_block, 0, + d.stream(), indices.data(), coo_row_ind.data(), + coo_col_ind.data(), size)); } else { - SparseTensorToCOOMatrixKernel<3> - <<>>( - indices.data(), coo_row_ind.data(), coo_col_ind.data(), size); + TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<3>, + config.block_count, config.thread_per_block, 0, + d.stream(), indices.data(), coo_row_ind.data(), + coo_col_ind.data(), size)); } } __global__ void COOMatrixToSparseTensorKernel2D(const int* coo_rows, const int* coo_cols, int64* indices_out, int size) { - CUDA_1D_KERNEL_LOOP(i, size) { + GPU_1D_KERNEL_LOOP(i, size) { indices_out[i * 2] = static_cast(ldg(coo_rows + i)); indices_out[i * 2 + 1] = static_cast(ldg(coo_cols + i)); } @@ -203,7 +215,7 @@ __global__ void COOMatrixToSparseTensorKernel3D( } __syncthreads(); - CUDA_1D_KERNEL_LOOP(i, size) { + GPU_1D_KERNEL_LOOP(i, size) { // TODO(ebrevdo): Consider special casing batch_size <= 3, // alternatively doing linear instead of binary search. Requires // some benchmarks. @@ -231,9 +243,10 @@ Status COOSparseMatrixToSparseTensor::operator()( DCHECK_EQ(size, indices.dimension(0)); if (ndims == 2) { GpuLaunchConfig config = GetGpuLaunchConfig(size, d); - COOMatrixToSparseTensorKernel2D<<>>( - coo_row_ind.data(), coo_col_ind.data(), indices.data(), size); + TF_CHECK_OK(GpuLaunchKernel(COOMatrixToSparseTensorKernel2D, + config.block_count, config.thread_per_block, 0, + d.stream(), coo_row_ind.data(), + coo_col_ind.data(), indices.data(), size)); return Status::OK(); } else { const int batch_size = host_dense_shape(0); @@ -246,11 +259,11 @@ Status COOSparseMatrixToSparseTensor::operator()( GpuLaunchConfig config = GetGpuLaunchConfig(size, d); // shared memory stores the batch pointers. const size_t shared_memory_size = sizeof(int) * (batch_size + 1); - COOMatrixToSparseTensorKernel3D<<>>( - coo_row_ind.data(), coo_col_ind.data(), indices.data(), - batch_ptr_copy.data(), batch_size, size); + TF_CHECK_OK( + GpuLaunchKernel(COOMatrixToSparseTensorKernel3D, config.block_count, + config.thread_per_block, shared_memory_size, d.stream(), + coo_row_ind.data(), coo_col_ind.data(), indices.data(), + batch_ptr_copy.data(), batch_size, size)); return Status::OK(); } } @@ -274,7 +287,7 @@ __global__ void CSRSparseMatrixBatchMulVecKernel3D( } __syncthreads(); - CUDA_1D_KERNEL_LOOP(i, total_nnz) { + GPU_1D_KERNEL_LOOP(i, total_nnz) { const int b = BinarySearchRange(local_batch_ptr, batch_size, i); c_values[i] = ldg(a_values + i) * local_batch_values[b]; } @@ -316,10 +329,10 @@ Status CSRSparseMatrixBatchMulVecImpl(OpKernelContext* ctx, const size_t shared_memory_size = (sizeof(int) * (batch_size + 1) // local batch_pointers. + sizeof(T) * batch_size); // local copy of b. 
- CSRSparseMatrixBatchMulVecKernel3D - <<>>(a_values.data(), b.data(), c_values.data(), - batch_ptr_copy.data(), batch_size, total_nnz); + TF_CHECK_OK(GpuLaunchKernel( + CSRSparseMatrixBatchMulVecKernel3D, config.block_count, + config.thread_per_block, shared_memory_size, d.stream(), a_values.data(), + b.data(), c_values.data(), batch_ptr_copy.data(), batch_size, total_nnz)); return Status::OK(); } @@ -374,7 +387,7 @@ __global__ void CSRSparseMatrixSoftmaxKernel2D(const int rows, // algorithm to distribute the work in case the row sizes are // uneven: // http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf - CUDA_1D_KERNEL_LOOP(row, rows) { + GPU_1D_KERNEL_LOOP(row, rows) { CalculateRowSoftmax(ldg(row_ptr + row), ldg(row_ptr + row + 1), logits, softmax); } @@ -382,7 +395,7 @@ __global__ void CSRSparseMatrixSoftmaxKernel2D(const int rows, EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void CopyFromGpuDeviceArrayToLocal( GpuDeviceArrayStruct cuda_ptr_s, int* local_ptr, int length) { -#ifdef __CUDA_ARCH__ +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) const int* cuda_ptr = GetGpuDeviceArrayOnDevice(&cuda_ptr_s); for (int i = threadIdx.x; i < length; i += blockDim.x) { local_ptr[i] = cuda_ptr[i]; @@ -404,7 +417,7 @@ __global__ void CSRSparseMatrixSoftmaxKernel3D( CopyFromGpuDeviceArrayToLocal(std::move(batch_ptr_s), local_batch_ptr, batch_size + 1); - CUDA_1D_KERNEL_LOOP(i, size) { + GPU_1D_KERNEL_LOOP(i, size) { const int batch = i / rows; const int row = i % rows; const int batch_offset = local_batch_ptr[batch]; @@ -431,10 +444,10 @@ Status CSRSparseMatrixSoftmaxGPUImpl(OpKernelContext* ctx, const int rows = host_dense_shape(0); DCHECK_EQ(rows, row_ptr.size() - 1); GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d); - CSRSparseMatrixSoftmaxKernel2D - <<>>( - rows /*size*/, row_ptr.data(), logits_values.data(), - softmax_values.data()); + TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel2D, + config.block_count, config.thread_per_block, 0, + d.stream(), rows /*size*/, row_ptr.data(), + logits_values.data(), softmax_values.data())); } else { const int batch_size = host_dense_shape(0); const int rows = host_dense_shape(1); @@ -452,10 +465,11 @@ Status CSRSparseMatrixSoftmaxGPUImpl(OpKernelContext* ctx, GpuLaunchConfig config = GetGpuLaunchConfig(size, d); // shared memory stores the batch pointers. 
const size_t shared_memory_size = sizeof(int) * (batch_size + 1); - CSRSparseMatrixSoftmaxKernel3D - <<>>(size, rows, batch_ptr_copy.data(), row_ptr.data(), - logits_values.data(), softmax_values.data()); + TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel3D, + config.block_count, config.thread_per_block, + shared_memory_size, d.stream(), size, rows, + batch_ptr_copy.data(), row_ptr.data(), + logits_values.data(), softmax_values.data())); } return Status::OK(); @@ -549,7 +563,7 @@ __global__ void CSRSparseMatrixSoftmaxGradKernel2D( // algorithm to distribute the work in case the row sizes are // uneven: // http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf - CUDA_1D_KERNEL_LOOP(row, rows) { + GPU_1D_KERNEL_LOOP(row, rows) { CalculateRowSoftmaxGrad( ldg(softmax_row_ptr + row) /*softmax_begin*/, ldg(softmax_row_ptr + row + 1) /*softmax_end*/, softmax_col_ind, @@ -579,7 +593,7 @@ __global__ void CSRSparseMatrixSoftmaxGradKernel3D( #define SOFTMAX_BATCH_PTR(i) local_batch_ptr[i]; #define GRAD_SOFTMAX_BATCH_PTR(i) local_batch_ptr[batch_size + 1 + i]; - CUDA_1D_KERNEL_LOOP(i, size) { + GPU_1D_KERNEL_LOOP(i, size) { const int batch = i / rows; const int row = i % rows; const int softmax_batch_offset = SOFTMAX_BATCH_PTR(batch); @@ -625,12 +639,12 @@ Status CSRSparseMatrixSoftmaxGradGPUImpl( DCHECK_EQ(rows + 1, softmax_row_ptr.size()); DCHECK_EQ(rows + 1, grad_softmax_row_ptr.size()); GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d); - CSRSparseMatrixSoftmaxGradKernel2D - <<>>( - rows /*size*/, softmax_row_ptr.data(), softmax_col_ind.data(), - softmax_values.data(), grad_softmax_row_ptr.data(), - grad_softmax_col_ind.data(), grad_softmax_values.data(), - gradient_values.data()); + TF_CHECK_OK(GpuLaunchKernel( + CSRSparseMatrixSoftmaxGradKernel2D, config.block_count, + config.thread_per_block, 0, d.stream(), rows /*size*/, + softmax_row_ptr.data(), softmax_col_ind.data(), softmax_values.data(), + grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(), + grad_softmax_values.data(), gradient_values.data())); } else { const int batch_size = host_dense_shape(0); const int rows = host_dense_shape(1); @@ -656,13 +670,13 @@ Status CSRSparseMatrixSoftmaxGradGPUImpl( // shared memory stores two copies of batch pointers: one for the // softmax CSR matrix, one for the grad_softmax CSR matrix. 
const size_t shared_memory_size = 2 * sizeof(int) * (batch_size + 1); - CSRSparseMatrixSoftmaxGradKernel3D - <<>>(size, rows, softmax_and_grad_batch_ptr_copy.data(), - softmax_row_ptr.data(), softmax_col_ind.data(), - softmax_values.data(), grad_softmax_row_ptr.data(), - grad_softmax_col_ind.data(), - grad_softmax_values.data(), gradient_values.data()); + TF_CHECK_OK(GpuLaunchKernel( + CSRSparseMatrixSoftmaxGradKernel3D, config.block_count, + config.thread_per_block, shared_memory_size, d.stream(), size, rows, + softmax_and_grad_batch_ptr_copy.data(), softmax_row_ptr.data(), + softmax_col_ind.data(), softmax_values.data(), + grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(), + grad_softmax_values.data(), gradient_values.data())); } return Status::OK(); @@ -687,4 +701,4 @@ DEFINE_SOFTMAX_GRAD_GPU(double); } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/sparse/mat_mul_op.cc b/tensorflow/core/kernels/sparse/mat_mul_op.cc index c279c9f0314efe..a57d97b7a730c6 100644 --- a/tensorflow/core/kernels/sparse/mat_mul_op.cc +++ b/tensorflow/core/kernels/sparse/mat_mul_op.cc @@ -15,7 +15,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif @@ -36,7 +36,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/threadpool.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_sparse.h" #endif @@ -694,7 +694,7 @@ REGISTER_CPU(complex128) #undef REGISTER_CPU -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU(T) \ REGISTER_KERNEL_BUILDER( \ @@ -703,14 +703,16 @@ REGISTER_CPU(complex128) REGISTER_GPU(float) REGISTER_GPU(double) +#if GOOGLE_CUDA REGISTER_GPU(complex64) REGISTER_GPU(complex128) +#endif #undef REGISTER_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace functor { @@ -723,7 +725,7 @@ class CSRSparseMatrixMatMul { Status Compute(OpKernelContext* ctx, const ConstCSRComponent& a, typename TTypes::UnalignedConstMatrix b, typename TTypes::UnalignedMatrix c) { - CudaSparse cuda_sparse(ctx); + GpuSparse cuda_sparse(ctx); TF_RETURN_IF_ERROR(cuda_sparse.Initialize()); { // Use Csrmm to calculate: @@ -741,19 +743,34 @@ class CSRSparseMatrixMatMul { // transA must be non-transpose if transB is transpose (cusparse // limitation). - const cusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE; +#if GOOGLE_CUDA + const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE; +#elif TENSORFLOW_USE_ROCM + const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE; +#endif // transB: b is row-major, and cusparse requires col-major b (or // equivalently transB == transpose). this version is actually more // efficient. 
- const cusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE; +#if GOOGLE_CUDA + const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE; - cusparseMatDescr_t descrA; - TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateMatDescr(&descrA)); - TF_RETURN_IF_CUSPARSE_ERROR( + gpusparseMatDescr_t descrA; + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA)); + TF_RETURN_IF_GPUSPARSE_ERROR( cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL)); - TF_RETURN_IF_CUSPARSE_ERROR( + TF_RETURN_IF_GPUSPARSE_ERROR( cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO)); +#elif TENSORFLOW_USE_ROCM + const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE; + + gpusparseMatDescr_t descrA; + TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA)); + TF_RETURN_IF_GPUSPARSE_ERROR( + hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL)); + TF_RETURN_IF_GPUSPARSE_ERROR( + hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO)); +#endif // A is (m, k), Bt is (ldb, k) and Ct is (ldc, n) const int k = b.dimension(0); @@ -796,13 +813,13 @@ template class CSRSparseMatrixMatVec { public: CSRSparseMatrixMatVec(bool transpose_a, bool conjugate_a) - : transA_(TransposeAndConjugateToCuSparseOp(transpose_a, conjugate_a, - &status_)) {} + : transA_(TransposeAndConjugateToGpuSparseOp(transpose_a, conjugate_a, + &status_)) {} Status Compute(OpKernelContext* ctx, const ConstCSRComponent& a, const T* x, T* y) { TF_RETURN_IF_ERROR(status_); - CudaSparse cuda_sparse(ctx); + GpuSparse cuda_sparse(ctx); TF_RETURN_IF_ERROR(cuda_sparse.Initialize()); { // Use Csrmv to calculate: @@ -815,12 +832,20 @@ class CSRSparseMatrixMatVec { const T alpha = 1; const T beta = 0; - cusparseMatDescr_t descrA; - TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateMatDescr(&descrA)); - TF_RETURN_IF_CUSPARSE_ERROR( + gpusparseMatDescr_t descrA; +#if GOOGLE_CUDA + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA)); + TF_RETURN_IF_GPUSPARSE_ERROR( cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL)); - TF_RETURN_IF_CUSPARSE_ERROR( + TF_RETURN_IF_GPUSPARSE_ERROR( cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO)); +#elif TENSORFLOW_USE_ROCM + TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA)); + TF_RETURN_IF_GPUSPARSE_ERROR( + hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL)); + TF_RETURN_IF_GPUSPARSE_ERROR( + hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO)); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM const int m = a.dense_shape_host(0); const int n = a.dense_shape_host(1); @@ -836,11 +861,11 @@ class CSRSparseMatrixMatVec { private: Status status_; - const cusparseOperation_t transA_; + const gpusparseOperation_t transA_; }; } // namespace functor -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse/mul_op.cc b/tensorflow/core/kernels/sparse/mul_op.cc index d63512252f765e..f6cf369626c853 100644 --- a/tensorflow/core/kernels/sparse/mul_op.cc +++ b/tensorflow/core/kernels/sparse/mul_op.cc @@ -15,7 +15,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif @@ -28,7 +28,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/sparse/kernels.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cuda_sparse.h" #endif @@ -101,22 +101,24 @@ class CSRMulOp : public OpKernel { Name("SparseMatrixMul").Device(DEVICE_##DEV).TypeConstraint("T"), \ CSRMulOp); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU(T) REGISTER(GPU, T) REGISTER_GPU(float) REGISTER_GPU(double) +#if GOOGLE_CUDA REGISTER_GPU(complex64) REGISTER_GPU(complex128) +#endif #undef REGISTER_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #undef REGISTER -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace functor { @@ -159,13 +161,15 @@ class CSRSparseMatrixMulScalar { DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); +#if GOOGLE_CUDA DECLARE_GPU_SPEC(std::complex); DECLARE_GPU_SPEC(std::complex); +#endif #undef DECLARE_GPU_SPEC } // namespace functor -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse/nnz_op.cc b/tensorflow/core/kernels/sparse/nnz_op.cc index e38b39916c3002..ebc48c3e9a48d0 100644 --- a/tensorflow/core/kernels/sparse/nnz_op.cc +++ b/tensorflow/core/kernels/sparse/nnz_op.cc @@ -15,7 +15,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/core/kernels/sparse/kernels.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_sparse.h" #endif @@ -67,11 +67,11 @@ class CSRNNZOp : public OpKernel { REGISTER(CPU) -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM REGISTER(GPU) -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #undef REGISTER diff --git a/tensorflow/core/kernels/sparse/softmax_op.cc b/tensorflow/core/kernels/sparse/softmax_op.cc index 0195eb474e95e8..25025bfe2a62fc 100644 --- a/tensorflow/core/kernels/sparse/softmax_op.cc +++ b/tensorflow/core/kernels/sparse/softmax_op.cc @@ -19,7 +19,7 @@ limitations under the License. 
#define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cuda_sparse.h" #define EIGEN_USE_GPU #endif @@ -84,7 +84,7 @@ class CSRSoftmaxOp : public OpKernel { } }; -#ifdef GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER(DEV, T) \ REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmax") \ .Device(DEVICE_##DEV) \ @@ -110,7 +110,7 @@ DECLARE_GPU_SPEC(double); #undef DECLARE_GPU_SPEC } // namespace functor -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM template class CSRSoftmaxGradOp : public OpKernel { @@ -193,7 +193,7 @@ class CSRSoftmaxGradOp : public OpKernel { } }; -#ifdef GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER(DEV, T) \ REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmaxGrad") \ .Device(DEVICE_##DEV) \ @@ -220,6 +220,6 @@ DECLARE_GPU_SPEC(double); #undef DECLARE_GPU_SPEC } // namespace functor -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc b/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc index a03d60ed155daa..e06dbcb0242543 100644 --- a/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc +++ b/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc @@ -15,7 +15,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif @@ -35,7 +35,7 @@ limitations under the License. #include "tensorflow/core/kernels/sparse/sparse_matrix.h" #include "tensorflow/core/util/work_sharder.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_sparse.h" #endif @@ -498,22 +498,24 @@ REGISTER_CPU(complex128) .TypeConstraint("type"), \ CSRSparseMatMulGPUOp); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU(T) REGISTER(GPU, T) REGISTER_GPU(float) REGISTER_GPU(double) +#if GOOGLE_CUDA REGISTER_GPU(complex64) REGISTER_GPU(complex128) +#endif // GOOGLE_CUDA #undef REGISTER_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #undef REGISTER -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace functor { template struct CSRSparseSparseMatrixMatMul @@ -527,11 +529,20 @@ struct CSRSparseSparseMatrixMatMul adjoint_a_(adjoint_a), transpose_b_(transpose_b) { // TODO(ebrevdo): Figure out why transposed implementations crash cuSparse. +#if GOOGLE_CUDA transA_ = transpose_a ? (adjoint_a ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE) : CUSPARSE_OPERATION_NON_TRANSPOSE; transB_ = transpose_b ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE; +#elif TENSORFLOW_USE_ROCM + transA_ = transpose_a + ? (adjoint_a ? HIPSPARSE_OPERATION_TRANSPOSE + : HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE) + : HIPSPARSE_OPERATION_NON_TRANSPOSE; + transB_ = transpose_b ? 
HIPSPARSE_OPERATION_TRANSPOSE + : HIPSPARSE_OPERATION_NON_TRANSPOSE; +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } Status Initialize() { @@ -630,20 +641,20 @@ struct CSRSparseSparseMatrixMatMul private: OpKernelContext* ctx_; - CudaSparse cuda_sparse_; + GpuSparse cuda_sparse_; bool initialized_; bool transpose_a_; bool adjoint_a_; bool transpose_b_; - CudaSparseMatrixDescriptor descrA_; - CudaSparseMatrixDescriptor descrB_; - CudaSparseMatrixDescriptor descrC_; - cusparseOperation_t transA_; - cusparseOperation_t transB_; + GpuSparseMatrixDescriptor descrA_; + GpuSparseMatrixDescriptor descrB_; + GpuSparseMatrixDescriptor descrC_; + gpusparseOperation_t transA_; + gpusparseOperation_t transB_; }; } // namespace functor -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse/sparse_matrix.cc b/tensorflow/core/kernels/sparse/sparse_matrix.cc index 0871ba2b121b19..98ee8458c65c4b 100644 --- a/tensorflow/core/kernels/sparse/sparse_matrix.cc +++ b/tensorflow/core/kernels/sparse/sparse_matrix.cc @@ -15,7 +15,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif diff --git a/tensorflow/core/kernels/sparse/sparse_matrix.h b/tensorflow/core/kernels/sparse/sparse_matrix.h index 482e5978c9e9b0..8fec9f42fbde85 100644 --- a/tensorflow/core/kernels/sparse/sparse_matrix.h +++ b/tensorflow/core/kernels/sparse/sparse_matrix.h @@ -18,7 +18,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif diff --git a/tensorflow/core/kernels/sparse/sparse_matrix_components_op.cc b/tensorflow/core/kernels/sparse/sparse_matrix_components_op.cc index e72c85184d1a32..9cbe88bde6c84a 100644 --- a/tensorflow/core/kernels/sparse/sparse_matrix_components_op.cc +++ b/tensorflow/core/kernels/sparse/sparse_matrix_components_op.cc @@ -15,7 +15,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/core/kernels/sparse/kernels.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_sparse.h" #endif @@ -116,12 +116,14 @@ REGISTER(CPU, double) REGISTER(CPU, complex64) REGISTER(CPU, complex128) -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM REGISTER(GPU, float) REGISTER(GPU, double) +#if GOOGLE_CUDA REGISTER(GPU, complex64) REGISTER(GPU, complex128) +#endif #undef REGISTER @@ -139,12 +141,14 @@ namespace functor { DECLARE_GPU_SPEC(int32); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); +#if GOOGLE_CUDA DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); +#endif #undef DECLARE_GPU_SPEC } // namespace functor -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc b/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc index 3ecebfe0ac7fbe..47efd24f83a338 100644 --- a/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc +++ b/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc @@ -15,7 +15,7 @@ limitations under the License. 
#define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif @@ -30,13 +30,18 @@ limitations under the License. #include "tensorflow/core/kernels/sparse/kernels.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_sparse.h" -#include "tensorflow/stream_executor/cuda/cuda_activation.h" +#endif +#if GOOGLE_CUDA +#include "tensorflow/stream_executor/cuda/cuda_activation.h" using ::perftools::gputools::cuda::ScopedActivateExecutorContext; +#elif TENSORFLOW_USE_ROCM +#include "tensorflow/stream_executor/rocm/rocm_activation.h" +using ::perftools::gputools::rocm::ScopedActivateExecutorContext; #endif namespace tensorflow { @@ -104,7 +109,7 @@ class SparseTensorToCSRSparseMatrixCPUOp : public OpKernel { } }; -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM template class SparseTensorToCSRSparseMatrixGPUOp : public AsyncOpKernel { @@ -302,7 +307,7 @@ struct COOSparseMatrixToCSRSparseMatrix { Status operator()(OpKernelContext* c, const int rows, const int cols, TTypes::UnalignedVec coo_row_ind, TTypes::UnalignedVec csr_row_ptr) { - CudaSparse cuda_sparse(c); + GpuSparse cuda_sparse(c); TF_RETURN_IF_ERROR(cuda_sparse.Initialize()); return cuda_sparse.Coo2csr(coo_row_ind.data(), /*nnz*/ coo_row_ind.size(), @@ -322,12 +327,14 @@ extern template struct COOSparseMatrixToCSRSparseMatrix; REGISTER_GPU(float) REGISTER_GPU(double) +#if GOOGLE_CUDA REGISTER_GPU(complex64) REGISTER_GPU(complex128) +#endif #undef REGISTER_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_CPU(T) \ REGISTER_KERNEL_BUILDER(Name("SparseTensorToCSRSparseMatrix") \ diff --git a/tensorflow/core/kernels/sparse/transpose_op.cc b/tensorflow/core/kernels/sparse/transpose_op.cc index 137e285ec067d2..f9ddb1d8d97f80 100644 --- a/tensorflow/core/kernels/sparse/transpose_op.cc +++ b/tensorflow/core/kernels/sparse/transpose_op.cc @@ -19,7 +19,7 @@ limitations under the License. 
#define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cuda_sparse.h" #define EIGEN_USE_GPU #endif @@ -132,9 +132,12 @@ REGISTER_TRANSPOSE(CPU, double) REGISTER_TRANSPOSE(CPU, complex64) REGISTER_TRANSPOSE(CPU, complex128) -#ifdef GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM REGISTER_TRANSPOSE(GPU, float) REGISTER_TRANSPOSE(GPU, double) +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#if GOOGLE_CUDA REGISTER_TRANSPOSE(GPU, complex64) REGISTER_TRANSPOSE(GPU, complex128) #endif // GOOGLE_CUDA @@ -250,16 +253,20 @@ struct CSRSparseMatrixTransposeComponent { } }; -#ifdef GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM template struct CSRSparseMatrixTransposeComponent { Status operator()(OpKernelContext* ctx, const ConstCSRComponent& x, CSRComponent* y) { TF_RETURN_IF_ERROR(ValidateTransposeInputs(x, *y)); - CudaSparse cuda_sparse(ctx); + GpuSparse cuda_sparse(ctx); TF_RETURN_IF_ERROR(cuda_sparse.Initialize()); - const cusparseAction_t copyValues = CUSPARSE_ACTION_NUMERIC; +#if GOOGLE_CUDA + const gpusparseAction_t copyValues = CUSPARSE_ACTION_NUMERIC; +#elif TENSORFLOW_USE_ROCM + const gpusparseAction_t copyValues = HIPSPARSE_ACTION_NUMERIC; +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM const int rank = x.dense_shape_host.size(); const int m = x.row_ptr.size() - 1; const int n = x.dense_shape_host(rank - 1); @@ -279,7 +286,7 @@ struct CSRSparseMatrixTransposeComponent { return Status::OK(); } }; -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse/zeros_op.cc b/tensorflow/core/kernels/sparse/zeros_op.cc index 2eb1a768364043..924221b66e51da 100644 --- a/tensorflow/core/kernels/sparse/zeros_op.cc +++ b/tensorflow/core/kernels/sparse/zeros_op.cc @@ -15,7 +15,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif @@ -74,7 +74,7 @@ Status CSRSparseMatrixZerosLikeHelper(OpKernelContext* ctx, } // namespace -#ifdef GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER(DEV) \ REGISTER_KERNEL_BUILDER(Name("SparseMatrixZeros") \ .Device(DEVICE_##DEV) \ @@ -88,6 +88,6 @@ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION( CSRSparseMatrixZerosLikeHelper); #undef REGISTER -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse/zeros_op.h b/tensorflow/core/kernels/sparse/zeros_op.h index 66cba071c94c14..85ea9c0c448a66 100644 --- a/tensorflow/core/kernels/sparse/zeros_op.h +++ b/tensorflow/core/kernels/sparse/zeros_op.h @@ -18,7 +18,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif diff --git a/tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc b/tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc index 4899cd8642f250..3825e29189a6c9 100644 --- a/tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc +++ b/tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc @@ -156,7 +156,7 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp { k); return; } - std::unique_ptr cusparse_solver(new CudaSparse(context)); + std::unique_ptr cusparse_solver(new GpuSparse(context)); OP_REQUIRES_OK(context, cusparse_solver->Initialize()); if (k == 1) { // rhs is copied into x, then gtsv replaces x with solution. 
@@ -196,20 +196,20 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp { } void SolveWithGtsv(OpKernelContext* context, - std::unique_ptr& cusparse_solver, + std::unique_ptr& cusparse_solver, const Scalar* superdiag, const Scalar* diag, const Scalar* subdiag, Scalar* rhs, const int num_eqs, const int num_rhs) const { #if CUDA_VERSION < 9000 - auto function = pivoting_ ? &CudaSparse::Gtsv - : &CudaSparse::GtsvNoPivot; + auto function = + pivoting_ ? &GpuSparse::Gtsv : &GpuSparse::GtsvNoPivot; OP_REQUIRES_OK( context, (cusparse_solver.get()->*function)( num_eqs, num_rhs, subdiag, diag, superdiag, rhs, num_eqs)); #else auto buffer_function = pivoting_ - ? &CudaSparse::Gtsv2BufferSizeExt - : &CudaSparse::Gtsv2NoPivotBufferSizeExt; + ? &GpuSparse::Gtsv2BufferSizeExt + : &GpuSparse::Gtsv2NoPivotBufferSizeExt; size_t buffer_size; OP_REQUIRES_OK(context, (cusparse_solver.get()->*buffer_function)( num_eqs, num_rhs, subdiag, diag, superdiag, rhs, @@ -220,8 +220,8 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp { context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor)); void* buffer = temp_tensor.flat().data(); - auto solver_function = pivoting_ ? &CudaSparse::Gtsv2 - : &CudaSparse::Gtsv2NoPivot; + auto solver_function = pivoting_ ? &GpuSparse::Gtsv2 + : &GpuSparse::Gtsv2NoPivot; OP_REQUIRES_OK(context, (cusparse_solver.get()->*solver_function)( num_eqs, num_rhs, subdiag, diag, superdiag, rhs, num_eqs, buffer)); @@ -315,7 +315,7 @@ class TridiagonalSolveOpGpu : public OpKernel { rhs.flat().size()); Scalar* x = output->flat().data(); - std::unique_ptr cusparse_solver(new CudaSparse(context)); + std::unique_ptr cusparse_solver(new GpuSparse(context)); OP_REQUIRES_OK(context, cusparse_solver->Initialize()); #if CUDA_VERSION < 9000 diff --git a/tensorflow/python/kernel_tests/linalg/sparse/BUILD b/tensorflow/python/kernel_tests/linalg/sparse/BUILD index e5a8a93fbf7c6d..af9113f02d64c8 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/BUILD +++ b/tensorflow/python/kernel_tests/linalg/sparse/BUILD @@ -28,7 +28,6 @@ cuda_py_test( size = "medium", srcs = ["csr_sparse_matrix_test.py"], main = "csr_sparse_matrix_test.py", - tags = ["no_rocm"], deps = [ "//tensorflow/python/ops/linalg/sparse", ], @@ -40,7 +39,6 @@ cuda_py_test( srcs = ["csr_sparse_matrix_ops_test.py"], main = "csr_sparse_matrix_ops_test.py", shard_count = 10, - tags = ["no_rocm"], deps = [ "//tensorflow/python/ops/linalg/sparse", "//tensorflow/python/ops/linalg/sparse:gen_sparse_csr_matrix_ops", @@ -53,7 +51,6 @@ cuda_py_test( srcs = ["csr_sparse_matrix_grad_test.py"], main = "csr_sparse_matrix_grad_test.py", shard_count = 50, - tags = ["no_rocm"], deps = [ "//tensorflow/python/ops/linalg/sparse", ], @@ -65,7 +62,6 @@ cuda_py_test( srcs = ["csr_sparse_matrix_dense_mat_mul_grad_test.py"], main = "csr_sparse_matrix_dense_mat_mul_grad_test.py", shard_count = 50, - tags = ["no_rocm"], deps = [ "//tensorflow/python/ops/linalg/sparse", ], @@ -77,7 +73,6 @@ cuda_py_test( srcs = ["csr_sparse_matrix_sparse_mat_mul_grad_test.py"], main = "csr_sparse_matrix_sparse_mat_mul_grad_test.py", shard_count = 50, - tags = ["no_rocm"], deps = [ "//tensorflow/python/ops/linalg/sparse", ], diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py index c56ac88249f375..5cd206ccbc1f6d 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py +++ 
b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py @@ -106,7 +106,11 @@ def _testLargeBatchSparseMatrixMatMulGrad(self, datatype, transpose_a, # These tests are refactored from sparse_csr_matrix_grad_test to keep its size # "medium". -for dtype in (np.float32, np.complex64): +dtypes_to_test = [np.float32] +if not test.is_built_with_rocm(): + # complex type is not supported on the ROCm platform + dtypes_to_test += [np.complex64] +for dtype in dtypes_to_test: for (t_a, t_b, adj_a, adj_b, t_out, conj_out) in itertools.product(*(([False, True],) * 6)): diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_grad_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_grad_test.py index e6425fcdc9416a..0cda66a63ad248 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_grad_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_grad_test.py @@ -84,6 +84,9 @@ def testLargeBatchSparseMatrixAddGrad(self): if not self._gpu_available: return + if test.is_built_with_rocm(): + self.skipTest("sparse-matrix-add op not supported on ROCm") + sparsify = lambda m: m * (m > 0) for dense_shape in ([53, 65, 127], [127, 65]): a_mats_val = sparsify(np.random.randn(*dense_shape)) diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py index c05e50664b26fa..517578029688b9 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py @@ -432,6 +432,9 @@ def testSparseMatrixAdd(self): if not self._gpu_available: return + if test.is_built_with_rocm(): + self.skipTest("sparse-matrix-add op not supported on ROCm") + a_indices = np.array([[0, 0], [2, 3]]) a_values = np.array([1.0, 5.0]).astype(np.float32) a_dense_shape = [5, 6] @@ -469,6 +472,9 @@ def testLargeBatchSparseMatrixAdd(self): if not self._gpu_available: return + if test.is_built_with_rocm(): + self.skipTest("sparse-matrix-add op not supported on ROCm") + sparsify = lambda m: m * (m > 0) dense_shape = [53, 65, 127] a_mats = sparsify(np.random.randn(*dense_shape)).astype(np.float32) @@ -511,6 +517,9 @@ def testSparseMatrixMatMul(self): @test_util.run_in_graph_and_eager_modes def testSparseMatrixMatMulConjugateOutput(self): + if test.is_built_with_rocm(): + self.skipTest("complex type not supported on ROCm") + for shapes in [[(5, 6), (6, 1)], [(5, 6), (6, 2)]]: a_indices = np.array([[0, 0], [2, 3]]) a_values = np.array([1.0 + 1.j, 5.0 - 2.j]).astype(np.complex64) @@ -533,8 +542,19 @@ def testSparseMatrixMatMulConjugateOutput(self): @test_util.run_in_graph_and_eager_modes def testLargeBatchSparseMatrixMatMul(self): + dtypes_to_test = [np.float32] + if not test.is_built_with_rocm(): + # complex types are not supported on the ROCm platform + dtypes_to_test += [np.complex64] + + if test.is_built_with_rocm(): + # TODO(rocm): fix this + # This test is currently failing on the ROCm platform + # Re-enable it once the fix is available + self.skipTest("hipSPARSE all failure on the ROCm platform") + sparsify = lambda m: m * (m > 0) - for dtype in np.float32, np.complex64: + for dtype in dtypes_to_test: for (transpose_a, transpose_b) in ((False, False), (False, True), (True, False), (True, True)): for (adjoint_a, adjoint_b) in ((False, False), (False, True), @@ -584,8 +604,19 @@ def testLargeBatchSparseMatrixMatMul(self):
@test_util.run_in_graph_and_eager_modes def testLargeBatchSparseMatrixMatMulTransposed(self): + dtypes_to_test = [np.float32] + if not test.is_built_with_rocm(): + # complex types are not supported on the ROCm platform + dtypes_to_test += [np.complex64] + + if test.is_built_with_rocm(): + # TODO(rocm): fix this + # This test is currently failing on the ROCm platform + # Re-enable it once the fix is available + self.skipTest("hipSPARSE all failure on the ROCm platform") + sparsify = lambda m: m * (m > 0) - for dtype in np.float32, np.complex64: + for dtype in dtypes_to_test: for (transpose_a, transpose_b) in ((False, False), (False, True), (True, False), (True, True)): for (adjoint_a, adjoint_b) in ((False, False), (False, True), @@ -636,6 +667,10 @@ def testLargeBatchSparseMatrixMatMulTransposed(self): @test_util.run_in_graph_and_eager_modes def testLargeBatchSparseMatrixMatMulConjugate(self): + if test.is_built_with_rocm(): + # complex types are not yet supported on the ROCm platform + self.skipTest("complex type not supported on ROCm") + sparsify = lambda m: m * (m > 0) a_dense_shape = [53, 65, 127] b_dense_shape = [53, 127, 67] @@ -767,6 +802,10 @@ def testLargeBatchRegisteredAddN(self): if not self._gpu_available: return + if test.is_built_with_rocm(): + # sparse-matrix-add op is not yet supported on the ROCm platform + self.skipTest("sparse-matrix-add op not supported on ROCm") + sparsify = lambda m: m * (m > 0) dense_shape = [53, 65, 127] matrices = [ @@ -1154,9 +1193,10 @@ def testBatchSparseCholesky(self): ] # ]).astype(np.complex128) - data_types = [ - dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128 - ] + data_types = [dtypes.float32, dtypes.float64] + if not test.is_built_with_rocm(): + # complex type is not supported on the ROCm platform + data_types += [dtypes.complex64, dtypes.complex128] for dtype in data_types: sparse_matrix = dense_to_csr_sparse_matrix( math_ops.cast(dense_mat, dtype)) diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py index 74456229b49759..66077f5b2d20d1 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py @@ -154,7 +154,11 @@ def _testSparseSparse(self, transpose_a, transpose_b, adjoint_a, adjoint_b): sparsify = lambda m: m * (m > 0) dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13] dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15] - for dtype in np.float32, np.complex64: + dtypes_to_test = [np.float32] + if not test.is_built_with_rocm(): + # complex type is not supported on the ROCm platform + dtypes_to_test += [np.complex64] + for dtype in dtypes_to_test: a_mats = sparsify((np.random.randn(*dense_shape_a) + 1.j * np.random.randn(*dense_shape_a))).astype(dtype) b_mats = sparsify((np.random.randn(*dense_shape_b) + @@ -194,7 +198,11 @@ def _testSparseDense(self, transpose_a, transpose_b, adjoint_a, adjoint_b): sparsify = lambda m: m * (m > 0) dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13] dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15] - for dtype in np.float32, np.complex64: + dtypes_to_test = [np.float32] + if not test.is_built_with_rocm(): + # complex type is not supported on the ROCm platform + dtypes_to_test += [np.complex64] + for dtype in dtypes_to_test: a_mats = sparsify((np.random.randn(*dense_shape_a) + 1.j *
np.random.randn(*dense_shape_a))).astype(dtype) b_mats = (np.random.randn(*dense_shape_b) + @@ -231,7 +239,11 @@ def _testDenseSparse(self, transpose_a, transpose_b, adjoint_a, adjoint_b): sparsify = lambda m: m * (m > 0) dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13] dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15] - for dtype in np.float32, np.complex64: + dtypes_to_test = [np.float32] + if not test.is_built_with_rocm(): + # complex type is not supported on the ROCm platform + dtypes_to_test += [np.complex64] + for dtype in dtypes_to_test: a_mats = (np.random.randn(*dense_shape_a) + 1.j * np.random.randn(*dense_shape_a)).astype(dtype) b_mats = sparsify((np.random.randn(*dense_shape_b) + diff --git a/third_party/gpus/rocm/BUILD.tpl b/third_party/gpus/rocm/BUILD.tpl index 5a225af1d1568b..cf8950b5bc7c9c 100644 --- a/third_party/gpus/rocm/BUILD.tpl +++ b/third_party/gpus/rocm/BUILD.tpl @@ -137,4 +137,11 @@ cc_library( ], ) +cc_import( + name = "hipsparse", + hdrs = glob(["rocm/include/hipsparse/**",]), + shared_library = "rocm/lib/%{hipsparse_lib}", + visibility = ["//visibility:public"], +) + %{copy_rules} diff --git a/third_party/gpus/rocm/rocm_config.h.tpl b/third_party/gpus/rocm/rocm_config.h.tpl index c5f25a845cae13..957413b9acd734 100644 --- a/third_party/gpus/rocm/rocm_config.h.tpl +++ b/third_party/gpus/rocm/rocm_config.h.tpl @@ -16,6 +16,6 @@ limitations under the License. #ifndef ROCM_ROCM_CONFIG_H_ #define ROCM_ROCM_CONFIG_H_ -#define TF_ROCM_TOOLKIT_PATH "/opt/rocm" +#define TF_ROCM_TOOLKIT_PATH "%{rocm_toolkit_path}" #endif // ROCM_ROCM_CONFIG_H_ diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 9d6331df6b10b7..c7474273fd8f45 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -191,50 +191,50 @@ def _rocm_include_path(repository_ctx, rocm_config): inc_dirs.append(rocm_config.rocm_toolkit_path + "/include") # Add HSA headers - inc_dirs.append("/opt/rocm/hsa/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/hsa/include") # Add HIP headers - inc_dirs.append("/opt/rocm/include/hip") - inc_dirs.append("/opt/rocm/include/hip/hcc_detail") - inc_dirs.append("/opt/rocm/hip/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/hip") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/hip/hcc_detail") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/hip/include") # Add HIP-Clang headers - inc_dirs.append("/opt/rocm/llvm/lib/clang/8.0/include") - inc_dirs.append("/opt/rocm/llvm/lib/clang/9.0.0/include") - inc_dirs.append("/opt/rocm/llvm/lib/clang/10.0.0/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/llvm/lib/clang/8.0/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/llvm/lib/clang/9.0.0/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/llvm/lib/clang/10.0.0/include") # Add rocrand and hiprand headers - inc_dirs.append("/opt/rocm/rocrand/include") - inc_dirs.append("/opt/rocm/hiprand/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/rocrand/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/hiprand/include") # Add rocfft headers - inc_dirs.append("/opt/rocm/rocfft/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/rocfft/include") # Add rocBLAS headers - inc_dirs.append("/opt/rocm/rocblas/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/rocblas/include") # Add MIOpen headers - inc_dirs.append("/opt/rocm/miopen/include") + 
inc_dirs.append(rocm_config.rocm_toolkit_path + "/miopen/include") # Add RCCL headers - inc_dirs.append("/opt/rocm/rccl/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/rccl/include") # Add hcc headers - inc_dirs.append("/opt/rocm/hcc/include") - inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/7.0.0/include/") - inc_dirs.append("/opt/rocm/hcc/lib/clang/7.0.0/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/compiler/lib/clang/7.0.0/include/") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/lib/clang/7.0.0/include") # Newer hcc builds use/are based off of clang 8.0.0. - inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/8.0.0/include/") - inc_dirs.append("/opt/rocm/hcc/lib/clang/8.0.0/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/compiler/lib/clang/8.0.0/include/") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/lib/clang/8.0.0/include") # Support hcc based off clang 9.0.0, included in ROCm2.2 - inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/9.0.0/include/") - inc_dirs.append("/opt/rocm/hcc/lib/clang/9.0.0/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/compiler/lib/clang/9.0.0/include/") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/lib/clang/9.0.0/include") # Support hcc based off clang 10.0.0, included in ROCm2.8 - inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/10.0.0/include/") - inc_dirs.append("/opt/rocm/hcc/lib/clang/10.0.0/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/lib/clang/10.0.0/include") return inc_dirs @@ -300,11 +300,12 @@ def _hipcc_env(repository_ctx): repository_ctx.os.environ[name].strip() + "\";") return hipcc_env.strip() -def _hipcc_is_hipclang(repository_ctx): +def _hipcc_is_hipclang(repository_ctx, rocm_config): """Returns if hipcc is based on hip-clang toolchain. Args: repository_ctx: The repository context. + rocm_config: The path to the hip compiler. Returns: A string "True" if hipcc is based on hip-clang toolchain. @@ -319,7 +320,7 @@ def _hipcc_is_hipclang(repository_ctx): # grep for "HIP_COMPILER=clang" in /opt/rocm/hip/lib/.hipInfo grep_result = _execute( repository_ctx, - ["grep", "HIP_COMPILER=clang", "/opt/rocm/hip/lib/.hipInfo"], + ["grep", "HIP_COMPILER=clang", rocm_config.rocm_toolkit_path + "/hip/lib/.hipInfo"], empty_stdout_fine = True, ) result = grep_result.stdout.strip() @@ -327,13 +328,14 @@ def _hipcc_is_hipclang(repository_ctx): return "True" return "False" -def _if_hipcc_is_hipclang(repository_ctx, if_true, if_false = []): +def _if_hipcc_is_hipclang(repository_ctx, rocm_config, if_true, if_false = []): """ Returns either the if_true or if_false arg based on whether hipcc is based on the hip-clang toolchain Args : repository_ctx: The repository context. + rocm_config: The path to the hip compiler. 
if_true : value to return if hipcc is hip-clang based if_false : value to return if hipcc is not hip-clang based (optional, defaults to empty list) @@ -341,7 +343,7 @@ def _if_hipcc_is_hipclang(repository_ctx, if_true, if_false = []): Returns : either the if_true arg or the of_False arg """ - if _hipcc_is_hipclang(repository_ctx) == "True": + if _hipcc_is_hipclang(repository_ctx, rocm_config) == "True": return if_true return if_false @@ -478,6 +480,11 @@ def _find_libs(repository_ctx, rocm_config): repository_ctx, rocm_config.rocm_toolkit_path + "/rccl", ), + "hipsparse": _find_rocm_lib( + "hipsparse", + repository_ctx, + rocm_config.rocm_toolkit_path + "/hipsparse", + ), } def _get_rocm_config(repository_ctx): @@ -558,6 +565,7 @@ def _create_dummy_repository(repository_ctx): "%{rccl_lib}": _lib_name("rccl"), "%{rocfft_lib}": _lib_name("rocfft"), "%{hiprand_lib}": _lib_name("hiprand"), + "%{hipsparse_lib}": _lib_name("hipsparse"), "%{copy_rules}": "", "%{rocm_headers}": "", }, @@ -703,6 +711,12 @@ def _create_local_rocm_repository(repository_ctx): src_dir = rocm_toolkit_path + "/rccl/include", out_dir = "rocm/include/rccl", ), + make_copy_dir_rule( + repository_ctx, + name = "hipsparse-include", + src_dir = rocm_toolkit_path + "/hipsparse/include", + out_dir = "rocm/include/hipsparse", + ), ] rocm_libs = _find_libs(repository_ctx, rocm_config) @@ -740,16 +754,19 @@ def _create_local_rocm_repository(repository_ctx): "%{hiprand_lib}": rocm_libs["hiprand"].file_name, "%{miopen_lib}": rocm_libs["miopen"].file_name, "%{rccl_lib}": rocm_libs["rccl"].file_name, + "%{hipsparse_lib}": rocm_libs["hipsparse"].file_name, "%{copy_rules}": "\n".join(copy_rules), "%{rocm_headers}": ('":rocm-include",\n' + '":rocfft-include",\n' + '":rocblas-include",\n' + '":miopen-include",\n' + - '":rccl-include",'), + '":rccl-include",\n' + + '":hipsparse-include",'), }, ) # Set up crosstool/ + cc = find_cc(repository_ctx) host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc) @@ -762,7 +779,7 @@ def _create_local_rocm_repository(repository_ctx): rocm_defines["%{host_compiler_prefix}"] = host_compiler_prefix - rocm_defines["%{linker_bin_path}"] = "/opt/rocm/hcc/compiler/bin" + rocm_defines["%{linker_bin_path}"] = rocm_config.rocm_toolkit_path + "/hcc/compiler/bin" # For gcc, do not canonicalize system header paths; some versions of gcc # pick the shortest possible path for system includes when creating the @@ -775,7 +792,7 @@ def _create_local_rocm_repository(repository_ctx): "-DTENSORFLOW_USE_ROCM=1", "-D__HIP_PLATFORM_HCC__", "-DEIGEN_USE_HIP", - ] + _if_hipcc_is_hipclang(repository_ctx, [ + ] + _if_hipcc_is_hipclang(repository_ctx, rocm_config, [ # # define "TENSORFLOW_COMPILER_IS_HIP_CLANG" when we are using clang # based hipcc to compile/build tensorflow @@ -815,14 +832,14 @@ def _create_local_rocm_repository(repository_ctx): "crosstool:clang/bin/crosstool_wrapper_driver_rocm", { "%{cpu_compiler}": str(cc), - "%{hipcc_path}": "/opt/rocm/bin/hipcc", + "%{hipcc_path}": rocm_config.rocm_toolkit_path + "/bin/hipcc", "%{hipcc_env}": _hipcc_env(repository_ctx), - "%{hipcc_is_hipclang}": _hipcc_is_hipclang(repository_ctx), - "%{rocr_runtime_path}": "/opt/rocm/lib", + "%{hipcc_is_hipclang}": _hipcc_is_hipclang(repository_ctx, rocm_config), + "%{rocr_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", "%{rocr_runtime_library}": "hsa-runtime64", - "%{hip_runtime_path}": "/opt/rocm/hip/lib", + "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/hip/lib", "%{hip_runtime_library}": "hip_hcc", - 
"%{hcc_runtime_path}": "/opt/rocm/hcc/lib", + "%{hcc_runtime_path}": rocm_config.rocm_toolkit_path + "/hcc/lib", "%{hcc_runtime_library}": "mcwamp", "%{crosstool_verbose}": _crosstool_verbose(repository_ctx), "%{gcc_host_compiler_path}": str(cc), diff --git a/third_party/toolchains/preconfig/generate/containers.bzl b/third_party/toolchains/preconfig/generate/containers.bzl index 7c77399240bd9f..6a3a481bd6bdc3 100644 --- a/third_party/toolchains/preconfig/generate/containers.bzl +++ b/third_party/toolchains/preconfig/generate/containers.bzl @@ -9,5 +9,5 @@ container_digests = { "cuda10.1-cudnn7-centos6": "sha256:454b899657e87893ee5e68dc0f87df59b6a0a7418ae09cafcc3dd65ac71feca9", "cuda10.0-cudnn7-ubuntu16.04-manylinux2010": "sha256:5812d9d0ef0a3276fc5faaf4cd01f3d6e03d635893a6e2d2e04f6f01d626c432", "cuda10.1-cudnn7-ubuntu16.04-manylinux2010": "sha256:f8e15f08cb501e5f2de3dc450f614609fd3ed19bde74b153fa66d14b2307610c", - "rocm-ubuntu16.04": "sha256:d5cd4120cff3d2a452378aad03746ff5f24699d86cf695c20ee96f366e42975f", + "rocm-ubuntu16.04": "sha256:e645447dd6127325f3e97b8bf23424f637a8579d963b34fcc6772cf7cfaa0ebe", } diff --git a/third_party/toolchains/preconfig/generate/generate.bzl b/third_party/toolchains/preconfig/generate/generate.bzl index 08b4ef3d44ff46..1c8a4dfb052f7d 100644 --- a/third_party/toolchains/preconfig/generate/generate.bzl +++ b/third_party/toolchains/preconfig/generate/generate.bzl @@ -72,7 +72,7 @@ def _tensorflow_rbe_config(name, compiler, python_version, os, rocm_version = No docker_toolchain_autoconfig( name = name, base = base, - bazel_version = "0.29.1", + bazel_version = "1.2.1", build_bazel_src = build_bazel_src, config_repos = config_repos, env = env, diff --git a/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/BUILD b/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/BUILD index a1ecadf2e2932a..a8217711803e92 100755 --- a/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/BUILD +++ b/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/BUILD @@ -15,6 +15,7 @@ cc_library( name = "rocm_headers", hdrs = [ "rocm/rocm_config.h", + ":hipsparse-include", ":miopen-include", ":rccl-include", ":rocblas-include", @@ -141,6 +142,13 @@ cc_library( ], ) +cc_import( + name = "hipsparse", + hdrs = glob(["rocm/include/hipsparse/**"]), + shared_library = "rocm/lib/libhipsparse.so", + visibility = ["//visibility:public"], +) + genrule( name = "rocm-include", outs = [ @@ -175,6 +183,7 @@ genrule( "rocm/include/hcc/clang-c/CXErrorCode.h", "rocm/include/hcc/clang-c/CXString.h", "rocm/include/hcc/clang-c/Documentation.h", + "rocm/include/hcc/clang-c/FatalErrorHandler.h", "rocm/include/hcc/clang-c/Index.h", "rocm/include/hcc/clang-c/Platform.h", "rocm/include/hcc/coordinate", @@ -275,12 +284,14 @@ genrule( "rocm/include/hip/hcc_detail/hip_prof_str.h", "rocm/include/hip/hcc_detail/hip_runtime.h", "rocm/include/hip/hcc_detail/hip_runtime_api.h", + "rocm/include/hip/hcc_detail/hip_runtime_prof.h", "rocm/include/hip/hcc_detail/hip_surface_types.h", "rocm/include/hip/hcc_detail/hip_texture_types.h", "rocm/include/hip/hcc_detail/hip_vector_types.h", "rocm/include/hip/hcc_detail/hiprtc.h", "rocm/include/hip/hcc_detail/host_defines.h", "rocm/include/hip/hcc_detail/hsa_helpers.hpp", + "rocm/include/hip/hcc_detail/library_types.h", "rocm/include/hip/hcc_detail/llvm_intrinsics.h", "rocm/include/hip/hcc_detail/macro_based_grid_launch.hpp", "rocm/include/hip/hcc_detail/math_functions.h", @@ -292,6 +303,7 @@ genrule( "rocm/include/hip/hip_common.h", "rocm/include/hip/hip_complex.h", 
"rocm/include/hip/hip_cooperative_groups.h", + "rocm/include/hip/hip_ext.h", "rocm/include/hip/hip_fp16.h", "rocm/include/hip/hip_hcc.h", "rocm/include/hip/hip_profile.h", @@ -300,6 +312,7 @@ genrule( "rocm/include/hip/hip_texture_types.h", "rocm/include/hip/hip_vector_types.h", "rocm/include/hip/hiprtc.h", + "rocm/include/hip/library_types.h", "rocm/include/hip/math_functions.h", "rocm/include/hip/nvcc_detail/channel_descriptor.h", "rocm/include/hip/nvcc_detail/hip_complex.h", @@ -441,7 +454,6 @@ genrule( "rocm/include/ocml.h", "rocm/include/opencl1.2-c.pch", "rocm/include/opencl2.0-c.pch", - "rocm/include/profiler/CXLActivityLogger/CXLActivityLogger.h", "rocm/include/rccl.h", "rocm/include/rocalution.hpp", "rocm/include/rocblas-auxiliary.h", @@ -583,6 +595,7 @@ genrule( "rocm/include/rocrand/rocrand_xorwow.h", "rocm/include/rocrand/rocrand_xorwow_precomputed.h", "rocm/include/rocsparse-auxiliary.h", + "rocm/include/rocsparse-complex-types.h", "rocm/include/rocsparse-export.h", "rocm/include/rocsparse-functions.h", "rocm/include/rocsparse-types.h", @@ -1468,6 +1481,16 @@ genrule( cmd = """cp -rLf "/opt/rocm/rccl/include/." "$(@D)/" """, ) +genrule( + name = "hipsparse-include", + outs = [ + "rocm/include/hipsparse/hipsparse-export.h", + "rocm/include/hipsparse/hipsparse-version.h", + "rocm/include/hipsparse/hipsparse.h", + ], + cmd = """cp -rLf "/opt/rocm/hipsparse/include/." "$(@D)/rocm/include/hipsparse/" """, +) + genrule( name = "rocm-lib", outs = [ @@ -1477,11 +1500,13 @@ genrule( "rocm/lib/libhiprand.so", "rocm/lib/libMIOpen.so", "rocm/lib/librccl.so", + "rocm/lib/libhipsparse.so", ], cmd = """cp -f "/opt/rocm/hip/lib/libhip_hcc.so" "$(location rocm/lib/libhip_hcc.so)" && \ cp -f "/opt/rocm/rocblas/lib/librocblas.so.0.1" "$(location rocm/lib/librocblas.so)" && \ cp -f "/opt/rocm/rocfft/lib/librocfft.so.0.1" "$(location rocm/lib/librocfft.so)" && \ cp -f "/opt/rocm/hiprand/lib/libhiprand.so.1.1" "$(location rocm/lib/libhiprand.so)" && \ cp -f "/opt/rocm/miopen/lib/libMIOpen.so.1" "$(location rocm/lib/libMIOpen.so)" && \ -cp -f "/opt/rocm/rccl/lib/librccl.so" "$(location rocm/lib/librccl.so)" """, +cp -f "/opt/rocm/rccl/lib/librccl.so" "$(location rocm/lib/librccl.so)" && \ +cp -f "/opt/rocm/hipsparse/lib/libhipsparse.so.0.1" "$(location rocm/lib/libhipsparse.so)" """, )