Skip to content

Commit 8efe85f

Browse files
authored
[BLAS] SYCL-Graph integration for native-command (#669)
1 parent beb4d6b commit 8efe85f

File tree

9 files changed

+753
-24
lines changed

9 files changed

+753
-24
lines changed

src/blas/backends/cublas/cublas_scope_handle.cpp

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,36 +32,80 @@ namespace cublas {
3232
*/
3333
thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};
3434

35-
CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {}
35+
CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {
36+
// Initialize streamID member to a CUstream associated with the queue `ih`
37+
// has been submitted to.
38+
streamId = ih.get_native_queue<sycl::backend::ext_oneapi_cuda>();
3639

37-
cublasHandle_t CublasScopedContextHandler::get_handle() {
40+
// Initialize the `cublasHandle_t` member `nativeHandle`
3841
CUdevice device = ih.get_native_device<sycl::backend::ext_oneapi_cuda>();
39-
CUstream streamId = get_stream();
40-
cublasStatus_t err;
41-
4242
auto it = handle_helper.cublas_handle_mapper_.find(device);
4343
if (it != handle_helper.cublas_handle_mapper_.end()) {
44-
cublasHandle_t nativeHandle = it->second;
44+
// Use existing handle if one already exists for the device, but update
45+
// the native stream.
46+
nativeHandle = it->second;
4547
cudaStream_t currentStreamId;
48+
cublasStatus_t err;
4649
CUBLAS_ERROR_FUNC(cublasGetStream, err, nativeHandle, &currentStreamId);
4750
if (currentStreamId != streamId) {
4851
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
4952
}
50-
return nativeHandle;
5153
}
52-
53-
cublasHandle_t nativeHandle;
54-
CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle);
55-
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
56-
57-
auto insert_iter =
54+
else {
55+
// Create a new handle if one doesn't already exist for the device
56+
cublasStatus_t err;
57+
CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle);
58+
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
5859
handle_helper.cublas_handle_mapper_.insert(std::make_pair(device, nativeHandle));
60+
}
61+
}
5962

60-
return nativeHandle;
63+
void CublasScopedContextHandler::begin_recording_if_graph() {
64+
// interop_handle graph methods only available from extension version 2
65+
#if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
66+
if (!ih.ext_codeplay_has_graph()) {
67+
return;
68+
}
69+
70+
CUresult err;
71+
#if CUDA_VERSION >= 12030
72+
// After CUDA 12.3 we can use cuStreamBeginCaptureToGraph to capture
73+
// the stream directly in the native graph, rather than needing to
74+
// instantiate the stream capture as a new graph.
75+
auto graph = ih.ext_codeplay_get_native_graph<sycl::backend::ext_oneapi_cuda>();
76+
CUDA_ERROR_FUNC(cuStreamBeginCaptureToGraph, err, streamId, graph, nullptr, nullptr, 0,
77+
CU_STREAM_CAPTURE_MODE_GLOBAL);
78+
#else
79+
CUDA_ERROR_FUNC(cuStreamBeginCapture, err, streamId, CU_STREAM_CAPTURE_MODE_GLOBAL);
80+
#endif // CUDA_VERSION
81+
#endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
6182
}
6283

63-
CUstream CublasScopedContextHandler::get_stream() {
64-
return ih.get_native_queue<sycl::backend::ext_oneapi_cuda>();
84+
void CublasScopedContextHandler::end_recording_if_graph() {
85+
// interop_handle graph methods only available from extension version 2
86+
#if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
87+
if (!ih.ext_codeplay_has_graph()) {
88+
return;
89+
}
90+
91+
auto graph = ih.ext_codeplay_get_native_graph<sycl::backend::ext_oneapi_cuda>();
92+
CUresult err;
93+
#if CUDA_VERSION >= 12030
94+
CUDA_ERROR_FUNC(cuStreamEndCapture, err, streamId, &graph);
95+
#else
96+
// cuStreamEndCapture returns a new graph, if we overwrite
97+
// "graph" it won't be picked up by the SYCL runtime, as
98+
// "ext_codeplay_get_native_graph" returns a passed-by-value pointer.
99+
CUgraph recorded_graph;
100+
CUDA_ERROR_FUNC(cuStreamEndCapture, err, streamId, &recorded_graph);
101+
102+
// Add graph to native graph as a child node
103+
// Need to return a node object for the node to be created,
104+
// can't be nullptr.
105+
CUgraphNode node;
106+
CUDA_ERROR_FUNC(cuGraphAddChildGraphNode, err, &node, graph, nullptr, 0, recorded_graph);
107+
#endif // CUDA_VERSION
108+
#endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
65109
}
66110
} // namespace cublas
67111
} // namespace blas

src/blas/backends/cublas/cublas_scope_handle.hpp

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,49 @@ the handle must be destroyed when the context goes out of scope. This will bind
6363
class CublasScopedContextHandler {
6464
sycl::interop_handle& ih;
6565
static thread_local cublas_handle handle_helper;
66-
CUstream get_stream();
66+
cublasHandle_t nativeHandle;
67+
// Cache the native CU stream when the `CublasScopedContextHandler`object
68+
// is constructed. This avoids calling `get_native_queue(ih)` multiple
69+
// times which isn't guaranteed to return the same CUstream handle each
70+
// time. A scenario that causes problems when trying to start/end cuda
71+
// stream recording to a graph.
72+
CUstream streamId;
6773

6874
public:
75+
/**
76+
* @brief Constructor
77+
* @detail Creates the cublasHandle_t by implicitly impose the advice
78+
* given by nvidia for creating a cublas_handle. (e.g. one cuStream per device
79+
* per thread).
80+
*/
6981
CublasScopedContextHandler(sycl::interop_handle& ih);
7082

7183
/**
72-
* @brief get_handle: creates the handle by implicitly impose the advice
73-
* given by nvidia for creating a cublas_handle. (e.g. one cuStream per device
74-
* per thread).
75-
* @return cublasHandle_t a handle to construct cublas routines
76-
*/
77-
cublasHandle_t get_handle();
84+
* @brief Start recording cuBlas calls to a graph.
85+
* @detail Checks if the command-group associated with \p ih is being added
86+
* to a graph, and if so, begin stream recording of the native CUDA stream
87+
* associated with \p queue to the native cuda-graph object.
88+
*/
89+
void begin_recording_if_graph();
90+
91+
/**
92+
* @brief End recording cuBlas calls to a graph.
93+
* @detail Checks if the command-group associated with \p ih is being added
94+
* to a graph, and if so, ends stream recording of the native CUDA stream
95+
* associated with \p queue to the native cuda-graph object. Doing any
96+
* extra work to ensure that stream recorded calls get added as nodes to
97+
* the native graph object associated with \p ih.
98+
* @param queue The sycl queue to end stream recording on native stream
99+
* backing the queue.
100+
*/
101+
void end_recording_if_graph();
102+
103+
/// @brief Query the cuBLAS handle created on construction
104+
/// @return cublasHandle_t a handle to construct cublas routines
105+
cublasHandle_t get_handle() const {
106+
return nativeHandle;
107+
}
108+
78109
// This is a work-around function for reinterpret_casting the memory. This
79110
// will be fixed when SYCL-2020 has been implemented for Pi backend.
80111
template <typename T, typename U>

src/blas/backends/cublas/cublas_task.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ static inline void host_task_internal(H& cgh, F f) {
6161
cgh.host_task([f](sycl::interop_handle ih) {
6262
#endif
6363
auto sc = CublasScopedContextHandler(ih);
64+
sc.begin_recording_if_graph();
6465
f(sc);
66+
sc.end_recording_if_graph();
6567
});
6668
}
6769
#endif

tests/unit_tests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ set(blas_TEST_LIST
5454
blas_level2
5555
blas_level3
5656
blas_batch
57-
blas_extensions)
57+
blas_extensions
58+
blas_sycl_graph)
5859

5960
set(blas_TEST_LINK "")
6061

tests/unit_tests/blas/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ add_subdirectory(level2)
2727
add_subdirectory(level3)
2828
add_subdirectory(batch)
2929
add_subdirectory(extensions)
30+
add_subdirectory(sycl-graph)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#===============================================================================
2+
# Copyright 2025 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing,
11+
# software distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions
14+
# and limitations under the License.
15+
#
16+
#
17+
# SPDX-License-Identifier: Apache-2.0
18+
#===============================================================================
19+
20+
# Build object from all test sources
21+
set(SYCL_GRAPH_SOURCES)
22+
23+
set(SYCL_GRAPH_SOURCES_W_CBLAS "gemm_usm.cpp" "gemm_batch_usm.cpp")
24+
25+
if(CBLAS_FOUND)
26+
list(APPEND SYCL_GRAPH_SOURCES ${SYCL_GRAPH_SOURCES_W_CBLAS})
27+
endif()
28+
29+
if(BUILD_SHARED_LIBS)
30+
add_library(blas_sycl_graph_rt OBJECT ${SYCL_GRAPH_SOURCES})
31+
target_compile_options(blas_sycl_graph_rt PRIVATE -DCALL_RT_API -DNOMINMAX)
32+
target_include_directories(blas_sycl_graph_rt
33+
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../include
34+
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include
35+
PUBLIC ${PROJECT_SOURCE_DIR}/include
36+
PUBLIC ${PROJECT_SOURCE_DIR}/deps/googletest/include
37+
PUBLIC ${CMAKE_BINARY_DIR}/bin
38+
$<$<BOOL:${CBLAS_FOUND}>:${CBLAS_INCLUDE}>
39+
)
40+
if (USE_ADD_SYCL_TO_TARGET_INTEGRATION)
41+
add_sycl_to_target(TARGET blas_sycl_graph_rt SOURCES ${SYCL_GRAPH_SOURCES})
42+
else()
43+
target_link_libraries(blas_sycl_graph_rt PUBLIC ONEMATH::SYCL::SYCL)
44+
endif()
45+
endif()
46+
47+
add_library(blas_sycl_graph_ct OBJECT ${SYCL_GRAPH_SOURCES})
48+
target_compile_options(blas_sycl_graph_ct PRIVATE -DNOMINMAX)
49+
target_include_directories(blas_sycl_graph_ct
50+
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../include
51+
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include
52+
PUBLIC ${PROJECT_SOURCE_DIR}/include
53+
PUBLIC ${PROJECT_SOURCE_DIR}/deps/googletest/include
54+
PUBLIC ${CMAKE_BINARY_DIR}/bin
55+
$<$<BOOL:${CBLAS_FOUND}>:${CBLAS_INCLUDE}>
56+
)
57+
if (USE_ADD_SYCL_TO_TARGET_INTEGRATION)
58+
add_sycl_to_target(TARGET blas_sycl_graph_ct SOURCES ${SYCL_GRAPH_SOURCES})
59+
else()
60+
target_link_libraries(blas_sycl_graph_ct PUBLIC ONEMATH::SYCL::SYCL)
61+
endif()

0 commit comments

Comments
 (0)