@@ -32,36 +32,80 @@ namespace cublas {
32
32
*/
33
33
thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};
34
34
35
- CublasScopedContextHandler::CublasScopedContextHandler (sycl::interop_handle& ih) : ih(ih) {}
35
+ CublasScopedContextHandler::CublasScopedContextHandler (sycl::interop_handle& ih) : ih(ih) {
36
+ // Initialize streamID member to a CUstream associated with the queue `ih`
37
+ // has been submitted to.
38
+ streamId = ih.get_native_queue <sycl::backend::ext_oneapi_cuda>();
36
39
37
- cublasHandle_t CublasScopedContextHandler::get_handle () {
40
+ // Initialize the ` cublasHandle_t` member `nativeHandle`
38
41
CUdevice device = ih.get_native_device <sycl::backend::ext_oneapi_cuda>();
39
- CUstream streamId = get_stream ();
40
- cublasStatus_t err;
41
-
42
42
auto it = handle_helper.cublas_handle_mapper_ .find (device);
43
43
if (it != handle_helper.cublas_handle_mapper_ .end ()) {
44
- cublasHandle_t nativeHandle = it->second ;
44
+ // Use existing handle if one already exists for the device, but update
45
+ // the native stream.
46
+ nativeHandle = it->second ;
45
47
cudaStream_t currentStreamId;
48
+ cublasStatus_t err;
46
49
CUBLAS_ERROR_FUNC (cublasGetStream, err, nativeHandle, ¤tStreamId);
47
50
if (currentStreamId != streamId) {
48
51
CUBLAS_ERROR_FUNC (cublasSetStream, err, nativeHandle, streamId);
49
52
}
50
- return nativeHandle;
51
53
}
52
-
53
- cublasHandle_t nativeHandle;
54
- CUBLAS_ERROR_FUNC (cublasCreate, err, &nativeHandle);
55
- CUBLAS_ERROR_FUNC (cublasSetStream, err, nativeHandle, streamId);
56
-
57
- auto insert_iter =
54
+ else {
55
+ // Create a new handle if one doesn't already exist for the device
56
+ cublasStatus_t err;
57
+ CUBLAS_ERROR_FUNC (cublasCreate, err, &nativeHandle);
58
+ CUBLAS_ERROR_FUNC (cublasSetStream, err, nativeHandle, streamId);
58
59
handle_helper.cublas_handle_mapper_ .insert (std::make_pair (device, nativeHandle));
60
+ }
61
+ }
59
62
60
- return nativeHandle;
63
+ void CublasScopedContextHandler::begin_recording_if_graph () {
64
+ // interop_handle graph methods only available from extension version 2
65
+ #if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
66
+ if (!ih.ext_codeplay_has_graph ()) {
67
+ return ;
68
+ }
69
+
70
+ CUresult err;
71
+ #if CUDA_VERSION >= 12030
72
+ // After CUDA 12.3 we can use cuStreamBeginCaptureToGraph to capture
73
+ // the stream directly in the native graph, rather than needing to
74
+ // instantiate the stream capture as a new graph.
75
+ auto graph = ih.ext_codeplay_get_native_graph <sycl::backend::ext_oneapi_cuda>();
76
+ CUDA_ERROR_FUNC (cuStreamBeginCaptureToGraph, err, streamId, graph, nullptr , nullptr , 0 ,
77
+ CU_STREAM_CAPTURE_MODE_GLOBAL);
78
+ #else
79
+ CUDA_ERROR_FUNC (cuStreamBeginCapture, err, streamId, CU_STREAM_CAPTURE_MODE_GLOBAL);
80
+ #endif // CUDA_VERSION
81
+ #endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
61
82
}
62
83
63
- CUstream CublasScopedContextHandler::get_stream () {
64
- return ih.get_native_queue <sycl::backend::ext_oneapi_cuda>();
84
+ void CublasScopedContextHandler::end_recording_if_graph () {
85
+ // interop_handle graph methods only available from extension version 2
86
+ #if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
87
+ if (!ih.ext_codeplay_has_graph ()) {
88
+ return ;
89
+ }
90
+
91
+ auto graph = ih.ext_codeplay_get_native_graph <sycl::backend::ext_oneapi_cuda>();
92
+ CUresult err;
93
+ #if CUDA_VERSION >= 12030
94
+ CUDA_ERROR_FUNC (cuStreamEndCapture, err, streamId, &graph);
95
+ #else
96
+ // cuStreamEndCapture returns a new graph, if we overwrite
97
+ // "graph" it won't be picked up by the SYCL runtime, as
98
+ // "ext_codeplay_get_native_graph" returns a passed-by-value pointer.
99
+ CUgraph recorded_graph;
100
+ CUDA_ERROR_FUNC (cuStreamEndCapture, err, streamId, &recorded_graph);
101
+
102
+ // Add graph to native graph as a child node
103
+ // Need to return a node object for the node to be created,
104
+ // can't be nullptr.
105
+ CUgraphNode node;
106
+ CUDA_ERROR_FUNC (cuGraphAddChildGraphNode, err, &node, graph, nullptr , 0 , recorded_graph);
107
+ #endif // CUDA_VERSION
108
+ #endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
65
109
}
66
110
} // namespace cublas
67
111
} // namespace blas
0 commit comments