Skip to content

Commit

Permalink
added GPU solve kernel for Cholesky factor
Browse files Browse the repository at this point in the history
  • Loading branch information
flopez committed May 11, 2020
1 parent a6334c7 commit eb8d734
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/StarPU/codelets_posdef.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,18 @@ namespace starpu {

// solve_block StarPU codelet
// Put the codelet structure into its initial state before the fields
// below are filled in.
starpu_codelet_init(&cl_solve_block);
// Worker types allowed to execute this codelet: CPU and CUDA when the
// build defines SPLDLT_USE_GPU, CPU only otherwise.
#if defined(SPLDLT_USE_GPU)
cl_solve_block.where = STARPU_CPU | STARPU_CUDA;
#else
cl_solve_block.where = STARPU_CPU;
#endif
// The buffer count is not fixed in the codelet (presumably supplied per
// task at submission time -- see the StarPU codelet documentation).
cl_solve_block.nbuffers = STARPU_VARIABLE_NBUFFERS;
cl_solve_block.name = "SolveBlk";
cl_solve_block.cpu_funcs[0] = solve_block_cpu_func<T>;
#if defined(SPLDLT_USE_GPU)
// Register the CUDA implementation of the task. The STARPU_CUDA_ASYNC
// flag marks that implementation as asynchronous (NOTE(review): per the
// StarPU documentation the kernel must then only enqueue work on the
// worker's local stream and must not synchronise itself -- confirm).
cl_solve_block.cuda_funcs[0] = solve_block_cuda_func<T>;
cl_solve_block.cuda_flags[0] = STARPU_CUDA_ASYNC;
#endif

// solve_contrib_block StarPU codelet
starpu_codelet_init(&cl_solve_contrib_block);
Expand Down
26 changes: 26 additions & 0 deletions src/StarPU/cuda/kernels.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,37 @@
#if defined(SPLDLT_USE_GPU)
#include <starpu_cublas_v2.h>
#endif
#include "sylver/kernels/cuda/llt.hxx"

namespace sylver {
namespace spldlt {
namespace starpu {

////////////////////////////////////////////////////////////
// solve_block StarPU task

template<typename T>
void solve_block_cuda_func(void *buffers[], void *cl_arg) {

// Get diagonal block pointer and info
T *akk = (T *)STARPU_MATRIX_GET_PTR(buffers[0]);
unsigned ld_akk = STARPU_MATRIX_GET_LD(buffers[0]);

// Get sub diagonal block data pointer and info
T *aik = (T *)STARPU_MATRIX_GET_PTR(buffers[1]);
unsigned m = STARPU_MATRIX_GET_NX(buffers[1]);
unsigned n = STARPU_MATRIX_GET_NY(buffers[1]);
unsigned ld_aik = STARPU_MATRIX_GET_LD(buffers[1]);

// Retrieve cuBLAS handle associated with local stream
cublasHandle_t cuhandle = starpu_cublas_get_local_handle();

using FactorType = sylver::spldlt::cuda::Chol<T>;
FactorType::solve(
cuhandle, m, n, akk, ld_akk, aik, ld_aik);

}

////////////////////////////////////////////////////////////
// update_block StarPU task

Expand Down

0 comments on commit eb8d734

Please sign in to comment.