Skip to content
105 changes: 25 additions & 80 deletions src/lapack/backends/cusolver/cusolver_batch.cpp
Original file line number Diff line number Diff line change
@@ -184,26 +184,25 @@ inline void getrf_batch(const char *func_name, Func func, sycl::queue &queue, st
// Create new buffer with 32-bit ints then copy over results
std::uint64_t ipiv_size = stride_ipiv * batch_size;
sycl::buffer<int> ipiv32(sycl::range<1>{ ipiv_size });
sycl::buffer<int> devInfo{ batch_size };

queue.submit([&](sycl::handler &cgh) {
auto a_acc = a.template get_access<sycl::access::mode::read_write>(cgh);
auto ipiv32_acc = ipiv32.template get_access<sycl::access::mode::write>(cgh);
auto devInfo_acc = devInfo.template get_access<sycl::access::mode::write>(cgh);
auto scratch_acc = scratchpad.template get_access<sycl::access::mode::write>(cgh);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = sc.get_mem<cuDataType *>(a_acc);
auto ipiv_ = sc.get_mem<int *>(ipiv32_acc);
auto devInfo_ = sc.get_mem<int *>(devInfo_acc);
auto scratch_ = sc.get_mem<cuDataType *>(scratch_acc);
cusolverStatus_t err;
int *dev_info_d = create_dev_info(batch_size);

// Uses scratch so sync between each cuSolver call
for (std::int64_t i = 0; i < batch_size; ++i) {
CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i,
lda, scratch_, ipiv_ + stride_ipiv * i, devInfo_ + i);
lda, scratch_, ipiv_ + stride_ipiv * i, dev_info_d + i);
}
lapack_info_check_and_free(dev_info_d, __func__, func_name, batch_size);
});
});

@@ -215,7 +214,6 @@ inline void getrf_batch(const char *func_name, Func func, sycl::queue &queue, st
[=](sycl::id<1> index) { ipiv_acc[index] = ipiv32_acc[index]; });
});

lapack_info_check(queue, devInfo, __func__, func_name, batch_size);
}

#define GETRF_STRIDED_BATCH_LAUNCHER(TYPE, CUSOLVER_ROUTINE) \
@@ -459,10 +457,7 @@ inline sycl::event geqrf_batch(const char *func_name, Func func, sycl::queue &qu
overflow_check(m, n, lda, stride_a, stride_tau, batch_size, scratchpad_size);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType *>(a);
@@ -513,10 +508,7 @@ inline sycl::event geqrf_batch(const char *func_name, Func func, sycl::queue &qu
overflow_check(m[i], n[i], lda[i], group_sizes[i]);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType **>(a);
@@ -571,26 +563,22 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
// Allocate memory with 32-bit ints then copy over results
std::uint64_t ipiv_size = stride_ipiv * batch_size;
int *ipiv32 = (int *)malloc_device(sizeof(int) * ipiv_size, queue);
int *devInfo = (int *)malloc_device(sizeof(int) * batch_size, queue);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType *>(a);
auto devInfo_ = reinterpret_cast<int *>(devInfo);
auto scratchpad_ = reinterpret_cast<cuDataType *>(scratchpad);
auto ipiv_ = reinterpret_cast<int *>(ipiv32);
cusolverStatus_t err;
int *dev_info_d = create_dev_info(batch_size);

// Uses scratch so sync between each cuSolver call
for (int64_t i = 0; i < batch_size; ++i) {
CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i,
lda, scratchpad_, ipiv_ + stride_ipiv * i, devInfo_ + i);
lda, scratchpad_, ipiv32 + stride_ipiv * i, dev_info_d + i);
}
lapack_info_check_and_free(dev_info_d, __func__, func_name, batch_size);
});
});

@@ -607,10 +595,6 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
cgh.host_task([=](sycl::interop_handle ih) { sycl::free(ipiv32, queue); });
});

// lapack_info_check calls queue.wait()
lapack_info_check(queue, devInfo, __func__, func_name, batch_size);
sycl::free(devInfo, queue);

return done_casting;
}

@@ -656,29 +640,27 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
for (int64_t group_id = 0; group_id < group_count; ++group_id)
for (int64_t local_id = 0; local_id < group_sizes[group_id]; ++local_id, ++global_id)
ipiv32[global_id] = (int *)malloc_device(sizeof(int) * n[group_id], queue);
int *devInfo = (int *)malloc_device(sizeof(int) * batch_size, queue);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType **>(a);
auto scratch_ = reinterpret_cast<cuDataType *>(scratchpad);
int64_t global_id = 0;
cusolverStatus_t err;
int *dev_info_d = create_dev_info(batch_size);

// Uses scratch so sync between each cuSolver call
for (int64_t group_id = 0; group_id < group_count; ++group_id) {
for (int64_t local_id = 0; local_id < group_sizes[group_id];
++local_id, ++global_id) {
CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m[group_id],
n[group_id], a_[global_id], lda[group_id], scratch_,
ipiv32[global_id], devInfo + global_id);
ipiv32[global_id], dev_info_d + global_id);
}
}
lapack_info_check_and_free(dev_info_d, __func__, func_name, batch_size);
});
});

@@ -712,10 +694,6 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
});
});

// lapack_info_check calls queue.wait()
lapack_info_check(queue, devInfo, __func__, func_name, batch_size);
sycl::free(devInfo, queue);

return done_freeing;
}

@@ -814,22 +792,18 @@ inline sycl::event getrs_batch(const char *func_name, Func func, sycl::queue &qu
});

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
cgh.depends_on(done_casting);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType *>(a);
auto ipiv_ = reinterpret_cast<int *>(ipiv32);
auto b_ = reinterpret_cast<cuDataType *>(b);
cusolverStatus_t err;

// Does not use scratch so call cuSolver asynchronously and sync at end
for (int64_t i = 0; i < batch_size; ++i) {
CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_operation(trans), n,
nrhs, a_ + stride_a * i, lda, ipiv_ + stride_ipiv * i,
nrhs, a_ + stride_a * i, lda, ipiv32 + stride_ipiv * i,
b_ + stride_b * i, ldb, nullptr);
}
CUSOLVER_SYNC(err, handle)
@@ -902,13 +876,8 @@ inline sycl::event getrs_batch(const char *func_name, Func func, sycl::queue &qu
}

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
for (int64_t i = 0; i < batch_size; i++) {
cgh.depends_on(casting_dependencies[i]);
}
depends_on_events(cgh, dependencies);
depends_on_events(cgh, casting_dependencies);

onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
@@ -967,10 +936,7 @@ inline sycl::event orgqr_batch(const char *func_name, Func func, sycl::queue &qu
overflow_check(m, n, k, lda, stride_a, stride_tau, batch_size, scratchpad_size);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType *>(a);
@@ -1020,10 +986,7 @@ inline sycl::event orgqr_batch(const char *func_name, Func func, sycl::queue &qu
overflow_check(m[i], n[i], k[i], lda[i], group_sizes[i]);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType **>(a);
@@ -1074,10 +1037,7 @@ inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &qu
overflow_check(n, lda, stride_a, batch_size, scratchpad_size);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
CUdeviceptr a_dev;
@@ -1135,10 +1095,7 @@ inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &qu
}

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
int64_t offset = 0;
@@ -1199,10 +1156,7 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu
throw unimplemented("lapack", "potrs_batch", "cusolver potrs_batch only supports nrhs = 1");

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
CUresult cuda_result;
@@ -1283,10 +1237,7 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu
queue.submit([&](sycl::handler &h) { h.memcpy(b_dev, b, batch_size * sizeof(T *)); });

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
cgh.depends_on(done_cpy_a);
cgh.depends_on(done_cpy_b);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
@@ -1340,10 +1291,7 @@ inline sycl::event ungqr_batch(const char *func_name, Func func, sycl::queue &qu
overflow_check(m, n, k, lda, stride_a, stride_tau, batch_size, scratchpad_size);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType *>(a);
@@ -1393,10 +1341,7 @@ inline sycl::event ungqr_batch(const char *func_name, Func func, sycl::queue &qu
overflow_check(m[i], n[i], k[i], lda[i], group_sizes[i]);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType **>(a);
56 changes: 36 additions & 20 deletions src/lapack/backends/cusolver/cusolver_helper.hpp
Original file line number Diff line number Diff line change
@@ -280,30 +280,46 @@ struct CudaEquivalentType<std::complex<double>> {

/* devinfo */

inline void get_cusolver_devinfo(sycl::queue &queue, sycl::buffer<int> &devInfo,
std::vector<int> &dev_info_) {
sycl::host_accessor<int, 1, sycl::access::mode::read> dev_info_acc{ devInfo };
for (unsigned int i = 0; i < dev_info_.size(); ++i)
dev_info_[i] = dev_info_acc[i];
// Accepts an int* device pointer (from create_dev_info()), copies the
// `num_elements` cuSolver dev_info values from device to host, frees the
// device allocation, and throws oneapi::mkl::lapack::computation_error if
// any entry is positive (cuSolver reports failure via a positive info value).
inline void lapack_info_check_and_free(int *dev_info_d, const char *func_name,
                                       const char *cufunc_name, int num_elements = 1) {
    // Use a std::vector for the host copy: the previous raw malloc'd buffer
    // was never freed (leaked on every call, including the throw path).
    std::vector<int> dev_info_h(num_elements);
    cuMemcpyDtoH(dev_info_h.data(), reinterpret_cast<CUdeviceptr>(dev_info_d),
                 sizeof(int) * num_elements);
    // Free the device buffer *before* the error check so it is not leaked
    // when a computation_error is thrown below.
    cuMemFree(reinterpret_cast<CUdeviceptr>(dev_info_d));
    for (int i = 0; i < num_elements; ++i) {
        if (dev_info_h[i] > 0)
            throw oneapi::mkl::lapack::computation_error(
                func_name,
                std::string(cufunc_name) + " failed with info = " + std::to_string(dev_info_h[i]),
                dev_info_h[i]);
    }
}

inline void get_cusolver_devinfo(sycl::queue &queue, const int *devInfo,
std::vector<int> &dev_info_) {
queue.wait();
queue.memcpy(dev_info_.data(), devInfo, sizeof(int));
// Allocates and returns a CUDA device pointer for cuSolver dev_info
inline int *create_dev_info(int num_elements = 1) {
CUdeviceptr dev_info_d;
cuMemAlloc(&dev_info_d, sizeof(int) * num_elements);
return reinterpret_cast<int *>(dev_info_d);
}

template <typename DEVINFO_T>
inline void lapack_info_check(sycl::queue &queue, DEVINFO_T devinfo, const char *func_name,
const char *cufunc_name, int dev_info_size = 1) {
std::vector<int> dev_info_(dev_info_size);
get_cusolver_devinfo(queue, devinfo, dev_info_);
for (const auto &val : dev_info_) {
if (val > 0)
throw oneapi::mkl::lapack::computation_error(
func_name, std::string(cufunc_name) + " failed with info = " + std::to_string(val),
val);
}
// Registers every event in `dependencies` as a dependency of the command
// group currently being built on `cgh`.
inline void depends_on_events(sycl::handler &cgh,
                              const std::vector<sycl::event> &dependencies = {}) {
    for (const auto &evt : dependencies)
        cgh.depends_on(evt);
}

// Asynchronously frees the sycl USM allocation `ptr` once every event in
// `dependencies` has completed. Returns the event of the submitted host task
// so callers can chain further work after the free.
template <typename T>
inline sycl::event free_async(sycl::queue &queue, T *ptr,
                              const std::vector<sycl::event> &dependencies = {}) {
    return queue.submit([&](sycl::handler &cgh) {
        depends_on_events(cgh, dependencies);
        cgh.host_task([=](sycl::interop_handle) { sycl::free(ptr, queue); });
    });
}

/* batched helpers */
Loading
Oops, something went wrong.