Skip to content

Commit f09d6cb

Browse files
authored
[BLAS] Remove queue from host_task_internal API (#678)
1 parent 5aa0635 commit f09d6cb

10 files changed

+235
-248
lines changed

src/blas/backends/cublas/cublas_batch.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ inline void gemm_batch_impl(sycl::queue& queue, transpose transa, transpose tran
161161
auto a_acc = a.template get_access<sycl::access::mode::read>(cgh);
162162
auto b_acc = b.template get_access<sycl::access::mode::read>(cgh);
163163
auto c_acc = c.template get_access<sycl::access::mode::read_write>(cgh);
164-
onemath_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler& sc) {
165-
auto handle = sc.get_handle(queue);
164+
onemath_cublas_host_task(cgh, [=](CublasScopedContextHandler& sc) {
165+
auto handle = sc.get_handle();
166166
auto a_ = sc.get_mem<cuTypeA*>(a_acc);
167167
auto b_ = sc.get_mem<cuTypeB*>(b_acc);
168168
auto c_ = sc.get_mem<cuTypeC*>(c_acc);
@@ -514,8 +514,8 @@ inline sycl::event gemv_batch(const char* func_name, Func func, sycl::queue& que
514514
}
515515
auto done = queue.submit([&](sycl::handler& cgh) {
516516
cgh.depends_on(dependencies);
517-
onemath_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler& sc) {
518-
auto handle = sc.get_handle(queue);
517+
onemath_cublas_host_task(cgh, [=](CublasScopedContextHandler& sc) {
518+
auto handle = sc.get_handle();
519519
int64_t offset = 0;
520520
cublasStatus_t err;
521521
auto** a_ = reinterpret_cast<const cuDataType**>(a);
@@ -632,8 +632,8 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue& queue, transpose tra
632632
for (int64_t i = 0; i < num_events; i++) {
633633
cgh.depends_on(dependencies[i]);
634634
}
635-
onemath_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler& sc) {
636-
auto handle = sc.get_handle(queue);
635+
onemath_cublas_host_task(cgh, [=](CublasScopedContextHandler& sc) {
636+
auto handle = sc.get_handle();
637637
cublasStatus_t err;
638638
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
639639
CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err,
@@ -718,8 +718,8 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue& queue, transpose* transa, tr
718718
for (int64_t i = 0; i < num_events; i++) {
719719
cgh.depends_on(dependencies[i]);
720720
}
721-
onemath_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler& sc) {
722-
auto handle = sc.get_handle(queue);
721+
onemath_cublas_host_task(cgh, [=](CublasScopedContextHandler& sc) {
722+
auto handle = sc.get_handle();
723723
int64_t offset = 0;
724724
cublasStatus_t err;
725725
for (int64_t i = 0; i < group_count; i++) {
@@ -832,8 +832,8 @@ inline sycl::event trsm_batch(const char* func_name, Func func, sycl::queue& que
832832
for (int64_t i = 0; i < num_events; i++) {
833833
cgh.depends_on(dependencies[i]);
834834
}
835-
onemath_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler& sc) {
836-
auto handle = sc.get_handle(queue);
835+
onemath_cublas_host_task(cgh, [=](CublasScopedContextHandler& sc) {
836+
auto handle = sc.get_handle();
837837
int64_t offset = 0;
838838
cublasStatus_t err;
839839
for (int64_t i = 0; i < group_count; i++) {

src/blas/backends/cublas/cublas_extensions.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ void omatcopy(const char* func_name, Func func, sycl::queue& queue, transpose tr
9696
auto b_acc = b.template get_access<sycl::access::mode::read_write>(cgh);
9797
const int64_t logical_m = (trans == oneapi::math::transpose::nontrans ? m : n);
9898
const int64_t logical_n = (trans == oneapi::math::transpose::nontrans ? n : m);
99-
onemath_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler& sc) {
100-
auto handle = sc.get_handle(queue);
99+
onemath_cublas_host_task(cgh, [=](CublasScopedContextHandler& sc) {
100+
auto handle = sc.get_handle();
101101
auto a_ = sc.get_mem<cuDataType*>(a_acc);
102102
auto b_ = sc.get_mem<cuDataType*>(b_acc);
103103
cublasStatus_t err;
@@ -172,8 +172,8 @@ void omatadd(const char* func_name, Func func, sycl::queue& queue, transpose tra
172172
auto a_acc = a.template get_access<sycl::access::mode::read>(cgh);
173173
auto b_acc = b.template get_access<sycl::access::mode::read>(cgh);
174174
auto c_acc = c.template get_access<sycl::access::mode::read_write>(cgh);
175-
onemath_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler& sc) {
176-
auto handle = sc.get_handle(queue);
175+
onemath_cublas_host_task(cgh, [=](CublasScopedContextHandler& sc) {
176+
auto handle = sc.get_handle();
177177
auto a_ = sc.get_mem<cuDataType*>(a_acc);
178178
auto b_ = sc.get_mem<cuDataType*>(b_acc);
179179
auto c_ = sc.get_mem<cuDataType*>(c_acc);
@@ -274,8 +274,8 @@ sycl::event omatcopy(const char* func_name, Func func, sycl::queue& queue, trans
274274
cgh.depends_on(dependencies);
275275
const int64_t logical_m = (trans == oneapi::math::transpose::nontrans ? m : n);
276276
const int64_t logical_n = (trans == oneapi::math::transpose::nontrans ? n : m);
277-
onemath_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler& sc) {
278-
auto handle = sc.get_handle(queue);
277+
onemath_cublas_host_task(cgh, [=](CublasScopedContextHandler& sc) {
278+
auto handle = sc.get_handle();
279279
auto a_ = reinterpret_cast<const cuDataType*>(a);
280280
auto b_ = reinterpret_cast<cuDataType*>(b);
281281
cublasStatus_t err;
@@ -356,8 +356,8 @@ inline sycl::event omatadd(const char* func_name, Func func, sycl::queue& queue,
356356
overflow_check(m, n, lda, ldb, ldc);
357357
auto done = queue.submit([&](sycl::handler& cgh) {
358358
cgh.depends_on(dependencies);
359-
onemath_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler& sc) {
360-
auto handle = sc.get_handle(queue);
359+
onemath_cublas_host_task(cgh, [=](CublasScopedContextHandler& sc) {
360+
auto handle = sc.get_handle();
361361
auto a_ = reinterpret_cast<const cuDataType*>(a);
362362
auto b_ = reinterpret_cast<const cuDataType*>(b);
363363
auto c_ = reinterpret_cast<cuDataType*>(c);
@@ -459,8 +459,8 @@ void omatcopy(const char* func_name, Func func, sycl::queue& queue, transpose tr
459459
auto b_acc = b.template get_access<sycl::access::mode::read_write>(cgh);
460460
const int64_t logical_m = (trans == oneapi::math::transpose::nontrans ? n : m);
461461
const int64_t logical_n = (trans == oneapi::math::transpose::nontrans ? m : n);
462-
onemath_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler& sc) {
463-
auto handle = sc.get_handle(queue);
462+
onemath_cublas_host_task(cgh, [=](CublasScopedContextHandler& sc) {
463+
auto handle = sc.get_handle();
464464
auto a_ = sc.get_mem<cuDataType*>(a_acc);
465465
auto b_ = sc.get_mem<cuDataType*>(b_acc);
466466
cublasStatus_t err;
@@ -535,8 +535,8 @@ void omatadd(const char* func_name, Func func, sycl::queue& queue, transpose tra
535535
auto a_acc = a.template get_access<sycl::access::mode::read>(cgh);
536536
auto b_acc = b.template get_access<sycl::access::mode::read>(cgh);
537537
auto c_acc = c.template get_access<sycl::access::mode::read_write>(cgh);
538-
onemath_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler& sc) {
539-
auto handle = sc.get_handle(queue);
538+
onemath_cublas_host_task(cgh, [=](CublasScopedContextHandler& sc) {
539+
auto handle = sc.get_handle();
540540
auto a_ = sc.get_mem<cuDataType*>(a_acc);
541541
auto b_ = sc.get_mem<cuDataType*>(b_acc);
542542
auto c_ = sc.get_mem<cuDataType*>(c_acc);
@@ -637,8 +637,8 @@ sycl::event omatcopy(const char* func_name, Func func, sycl::queue& queue, trans
637637
cgh.depends_on(dependencies);
638638
const int64_t logical_m = (trans == oneapi::math::transpose::nontrans ? n : m);
639639
const int64_t logical_n = (trans == oneapi::math::transpose::nontrans ? m : n);
640-
onemath_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler& sc) {
641-
auto handle = sc.get_handle(queue);
640+
onemath_cublas_host_task(cgh, [=](CublasScopedContextHandler& sc) {
641+
auto handle = sc.get_handle();
642642
auto a_ = reinterpret_cast<const cuDataType*>(a);
643643
auto b_ = reinterpret_cast<cuDataType*>(b);
644644
cublasStatus_t err;
@@ -719,8 +719,8 @@ inline sycl::event omatadd(const char* func_name, Func func, sycl::queue& queue,
719719
overflow_check(m, n, lda, ldb, ldc);
720720
auto done = queue.submit([&](sycl::handler& cgh) {
721721
cgh.depends_on(dependencies);
722-
onemath_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler& sc) {
723-
auto handle = sc.get_handle(queue);
722+
onemath_cublas_host_task(cgh, [=](CublasScopedContextHandler& sc) {
723+
auto handle = sc.get_handle();
724724
auto a_ = reinterpret_cast<const cuDataType*>(a);
725725
auto b_ = reinterpret_cast<const cuDataType*>(b);
726726
auto c_ = reinterpret_cast<cuDataType*>(c);

0 commit comments

Comments
 (0)