Skip to content

Commit 8d153ca

Browse files
committed
[BLAS] SYCL-Graph integration for native-command
In order to support applications calling the library with a sycl queue recording to a SYCL-Graph, check if the `ext_codeplay_enqueue_native_command` command-group is being recorded to a graph object. If so, use the native stream recording APIs to add the BLAS calls as nodes in the graph.

In particular, this fixes the llama.cpp unit test `MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0)` on CUDA with SYCL-Graph enabled. Previously this would throw an error:

```sh
$ GGML_SYCL_DISABLE_GRAPH=0 ./bin/test-backend-ops -b SYCL0 -o MUL_MAT -p type_a=f16,type_b=f32,m=16,n=1,k=256,bs=\\[1,1\\],nr=\\[2
UR CUDA ERROR:
Value: 700
Name: CUDA_ERROR_ILLEGAL_ADDRESS
Description: an illegal memory access was encountered
Function: operator()
Source Location: $HOME/dpcpp/unified-runtime/source/adapters/cuda/queue.cpp:154

Native API failed. Native API returns: 2147483646 (UR_RESULT_ERROR_UNKNOWN)
Exception caught at file:$HOME/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp, line:3598, func:operator()
SYCL error: CHECK_TRY_ERROR((stream)->wait()): Meet error in this line code!
in function ggml_backend_sycl_synchronize at $HOME/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp:3598
$HOME/llama.cpp/ggml/src/ggml-sycl/../ggml-sycl/common.hpp:118: SYCL error
Could not attach to process. If your uid matches the uid of the target process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try again as the root user. For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.
No stack.
The program is not being run.
```
1 parent f0b9b9a commit 8d153ca

File tree

5 files changed

+155
-2
lines changed

5 files changed

+155
-2
lines changed

src/blas/backends/cublas/cublas_scope_handle.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,56 @@ cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue)
6060
return nativeHandle;
6161
}
6262

63+
void CublasScopedContextHandler::begin_recording_if_graph(const sycl::queue& queue) {
    // The graph queries on interop_handle exist only from revision 2 of the
    // enqueue-native-command extension; with older revisions this function
    // compiles to an empty body.
#if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
    const bool recording_to_graph = ih.ext_codeplay_has_graph();
    if (recording_to_graph) {
        auto native_stream = get_stream(queue);
        CUresult err;
#if CUDA_VERSION >= 12030
        // From CUDA 12.3 onwards cuStreamBeginCaptureToGraph lets the stream
        // be captured straight into the native graph, so the capture does not
        // have to be instantiated as a separate graph first.
        auto native_graph = ih.ext_codeplay_get_native_graph<sycl::backend::ext_oneapi_cuda>();
        CUDA_ERROR_FUNC(cuStreamBeginCaptureToGraph, err, native_stream, native_graph, nullptr,
                        nullptr, 0, CU_STREAM_CAPTURE_MODE_GLOBAL);
#else
        // Older CUDA: start a plain stream capture.
        CUDA_ERROR_FUNC(cuStreamBeginCapture, err, native_stream, CU_STREAM_CAPTURE_MODE_GLOBAL);
#endif // CUDA_VERSION
    }
#endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
}
84+
85+
void CublasScopedContextHandler::end_recording_if_graph(const sycl::queue& queue) {
    // interop_handle graph methods only available from extension version 2;
    // with older revisions this function compiles to an empty body.
#if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
    // Nothing was captured unless the command-group is being added to a graph.
    if (!ih.ext_codeplay_has_graph()) {
        return;
    }

    auto graph = ih.ext_codeplay_get_native_graph<sycl::backend::ext_oneapi_cuda>();
    auto stream = get_stream(queue);
    CUresult err;
#if CUDA_VERSION >= 12030
    // The capture was started with cuStreamBeginCaptureToGraph, so the nodes
    // were already recorded directly into the native graph.
    CUDA_ERROR_FUNC(cuStreamEndCapture, err, stream, &graph);
#else
    // cuStreamEndCapture returns a new graph, if we overwrite
    // "graph" it won't be picked up by the SYCL runtime, as
    // "ext_codeplay_get_native_graph" returns a passed-by-value pointer.
    CUgraph recorded_graph;
    CUDA_ERROR_FUNC(cuStreamEndCapture, err, stream, &recorded_graph);

    // Add graph to native graph as a child node
    // Need to return a node object for the node to be created,
    // can't be nullptr.
    CUgraphNode node;
    CUDA_ERROR_FUNC(cuGraphAddChildGraphNode, err, &node, graph, nullptr, 0, recorded_graph);

    // cuGraphAddChildGraphNode clones the child graph into the parent, so the
    // captured graph is no longer needed; destroy it to avoid leaking one
    // CUgraph per recorded call.
    CUDA_ERROR_FUNC(cuGraphDestroy, err, recorded_graph);
#endif // CUDA_VERSION
#endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
}
112+
63113
CUstream CublasScopedContextHandler::get_stream(const sycl::queue& queue) {
    // Query the SYCL runtime for the queue's native CUDA object.
    const CUstream native_stream = sycl::get_native<sycl::backend::ext_oneapi_cuda>(queue);
    return native_stream;
}

src/blas/backends/cublas/cublas_scope_handle.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,28 @@ class CublasScopedContextHandler {
6969
public:
7070
CublasScopedContextHandler(sycl::interop_handle& ih);
7171

72+
/**
73+
* @brief Start recording cuBlas calls to a graph.
74+
* @details Checks if the command-group associated with \p ih is being added
75+
* to a graph, and if so, begin stream recording of the native CUDA stream
76+
* associated with \p queue to the native cuda-graph object.
77+
* @param queue The sycl queue to start stream recording on native stream
78+
* backing the queue.
79+
*/
80+
void begin_recording_if_graph(const sycl::queue& queue);
81+
82+
/**
83+
* @brief End recording cuBlas calls to a graph.
84+
* @details Checks if the command-group associated with \p ih is being added
85+
* to a graph, and if so, ends stream recording of the native CUDA stream
86+
* associated with \p queue to the native cuda-graph object. Also does any
87+
* extra work to ensure that stream recorded calls get added as nodes to
88+
* the native graph object associated with \p ih.
89+
* @param queue The sycl queue to end stream recording on native stream
90+
* backing the queue.
91+
*/
92+
void end_recording_if_graph(const sycl::queue& queue);
93+
7294
/**
7395
* @brief get_handle: creates the handle by implicitly impose the advice
7496
* given by nvidia for creating a cublas_handle. (e.g. one cuStream per device

src/blas/backends/cublas/cublas_task.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ static inline void host_task_internal(H& cgh, sycl::queue queue, F f) {
6161
cgh.host_task([f, queue](sycl::interop_handle ih) {
6262
#endif
6363
auto sc = CublasScopedContextHandler(ih);
64+
sc.begin_recording_if_graph(queue);
6465
f(sc);
66+
sc.end_recording_if_graph(queue);
6567
});
6668
}
6769
#endif

tests/unit_tests/blas/batch/gemm_batch_usm.cpp

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ extern std::vector<sycl::device*> devices;
4848
namespace {
4949

5050
template <typename Ta, typename Tb, typename Tc, typename Ts>
51-
int test(device* dev, oneapi::math::layout layout, int64_t group_count) {
51+
int test(device* dev, oneapi::math::layout layout, int64_t group_count, bool graph_record = false) {
5252
// Catch asynchronous exceptions.
5353
auto exception_handler = [](exception_list exceptions) {
5454
for (std::exception_ptr const& e : exceptions) {
@@ -247,6 +247,13 @@ int test(device* dev, oneapi::math::layout layout, int64_t group_count) {
247247

248248
try {
249249
#ifdef CALL_RT_API
250+
namespace sycl_exp = sycl::ext::oneapi::experimental;
251+
using modifiable_graph = sycl_exp::command_graph<sycl_exp::graph_state::modifiable>;
252+
std::unique_ptr<modifiable_graph> graph;
253+
if (graph_record) {
254+
graph = std::make_unique<modifiable_graph>(main_queue);
255+
graph->begin_recording(main_queue);
256+
}
250257
switch (layout) {
251258
case oneapi::math::layout::col_major:
252259
done = oneapi::math::blas::column_major::gemm_batch(
@@ -262,7 +269,15 @@ int test(device* dev, oneapi::math::layout layout, int64_t group_count) {
262269
break;
263270
default: break;
264271
}
265-
done.wait_and_throw();
272+
273+
if (graph_record) {
274+
graph->end_recording(main_queue);
275+
auto exec_graph = graph->finalize();
276+
main_queue.ext_oneapi_graph(exec_graph).wait_and_throw();
277+
}
278+
else {
279+
done.wait_and_throw();
280+
}
266281
#else
267282
switch (layout) {
268283
case oneapi::math::layout::col_major:
@@ -419,4 +434,64 @@ INSTANTIATE_TEST_SUITE_P(GemmBatchUsmTestSuite, GemmBatchUsmTests,
419434
oneapi::math::layout::row_major)),
420435
::LayoutDeviceNamePrint());
421436

437+
// Test using sycl_ext_oneapi_graph to record the operations from a sycl::queue
438+
// to a graph, then execute the graph.
439+
class GraphRecordGemmBatchUsmTests
440+
: public ::testing::TestWithParam<std::tuple<sycl::device*, oneapi::math::layout>> {
441+
virtual void SetUp() override {
442+
CHECK_GRAPH_ON_DEVICE(std::get<0>(GetParam()));
443+
}
444+
};
445+
446+
TEST_P(GraphRecordGemmBatchUsmTests, RealHalfPrecision) {
    // All four type parameters are sycl::half; 5 groups, graph recording on.
    sycl::device* const device_ptr = std::get<0>(GetParam());
    const oneapi::math::layout mat_layout = std::get<1>(GetParam());
    EXPECT_TRUEORSKIP(
        (test<sycl::half, sycl::half, sycl::half, sycl::half>(device_ptr, mat_layout, 5, true)));
}
450+
451+
TEST_P(GraphRecordGemmBatchUsmTests, HalfHalfFloatPrecision) {
    // Ta/Tb are sycl::half, Tc/Ts are float; 5 groups, graph recording on.
    sycl::device* const device_ptr = std::get<0>(GetParam());
    const oneapi::math::layout mat_layout = std::get<1>(GetParam());
    EXPECT_TRUEORSKIP(
        (test<sycl::half, sycl::half, float, float>(device_ptr, mat_layout, 5, true)));
}
455+
456+
TEST_P(GraphRecordGemmBatchUsmTests, Int8Int8SinglePrecision) {
    // Ta/Tb are std::int8_t, Tc/Ts are float; 5 groups, graph recording on.
    sycl::device* const device_ptr = std::get<0>(GetParam());
    const oneapi::math::layout mat_layout = std::get<1>(GetParam());
    EXPECT_TRUEORSKIP(
        (test<std::int8_t, std::int8_t, float, float>(device_ptr, mat_layout, 5, true)));
}
460+
461+
TEST_P(GraphRecordGemmBatchUsmTests, Int8Int8Int32Precision) {
    // Ta/Tb are std::int8_t, Tc is std::int32_t, Ts is float; 5 groups,
    // graph recording on.
    sycl::device* const device_ptr = std::get<0>(GetParam());
    const oneapi::math::layout mat_layout = std::get<1>(GetParam());
    EXPECT_TRUEORSKIP(
        (test<std::int8_t, std::int8_t, std::int32_t, float>(device_ptr, mat_layout, 5, true)));
}
465+
466+
TEST_P(GraphRecordGemmBatchUsmTests, RealSinglePrecision) {
    // All four type parameters are float; 5 groups, graph recording on.
    sycl::device* const device_ptr = std::get<0>(GetParam());
    const oneapi::math::layout mat_layout = std::get<1>(GetParam());
    EXPECT_TRUEORSKIP((test<float, float, float, float>(device_ptr, mat_layout, 5, true)));
}
470+
471+
TEST_P(GraphRecordGemmBatchUsmTests, RealDoublePrecision) {
    // Double-precision check comes first, before the test body runs.
    sycl::device* const device_ptr = std::get<0>(GetParam());
    CHECK_DOUBLE_ON_DEVICE(device_ptr);

    // All four type parameters are double; 5 groups, graph recording on.
    const oneapi::math::layout mat_layout = std::get<1>(GetParam());
    EXPECT_TRUEORSKIP((test<double, double, double, double>(device_ptr, mat_layout, 5, true)));
}
477+
478+
TEST_P(GraphRecordGemmBatchUsmTests, ComplexSinglePrecision) {
    // All four type parameters are std::complex<float>; 5 groups, graph
    // recording on.
    sycl::device* const device_ptr = std::get<0>(GetParam());
    const oneapi::math::layout mat_layout = std::get<1>(GetParam());
    EXPECT_TRUEORSKIP(
        (test<std::complex<float>, std::complex<float>, std::complex<float>, std::complex<float>>(
            device_ptr, mat_layout, 5, true)));
}
483+
484+
TEST_P(GraphRecordGemmBatchUsmTests, ComplexDoublePrecision) {
    // Double-precision check comes first, before the test body runs.
    sycl::device* const device_ptr = std::get<0>(GetParam());
    CHECK_DOUBLE_ON_DEVICE(device_ptr);

    // All four type parameters are std::complex<double>; 5 groups, graph
    // recording on.
    const oneapi::math::layout mat_layout = std::get<1>(GetParam());
    EXPECT_TRUEORSKIP(
        (test<std::complex<double>, std::complex<double>, std::complex<double>,
              std::complex<double>>(device_ptr, mat_layout, 5, true)));
}
491+
492+
// Instantiate the graph-recording tests for every entry of `devices`, combined
// with both the column-major and row-major layouts.
INSTANTIATE_TEST_SUITE_P(GraphRecordGemmBatchUsmTestSuite, GraphRecordGemmBatchUsmTests,
                         ::testing::Combine(::testing::ValuesIn(devices),
                                            ::testing::Values(oneapi::math::layout::col_major,
                                                              oneapi::math::layout::row_major)),
                         ::LayoutDeviceNamePrint());
422497
} // anonymous namespace

tests/unit_tests/include/test_helper.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@
7373
if (d->get_info<sycl::info::device::double_fp_config>().size() == 0) \
7474
GTEST_SKIP() << "Double precision is not supported on the device"
7575

76+
// Skips the current test when device pointer `d` does not report the
// ext_oneapi_limited_graph aspect. The aspect is spelled with its full
// sycl:: qualification so the macro does not depend on a using-directive at
// the expansion site, and the argument is parenthesized so any pointer
// expression can be passed.
#define CHECK_GRAPH_ON_DEVICE(d)                           \
    if (!(d)->has(sycl::aspect::ext_oneapi_limited_graph)) \
    GTEST_SKIP() << "SYCL-Graph is not supported on the device"
79+
7680
#if defined(ONEMATH_ENABLE_MKLCPU_BACKEND) || defined(ONEMATH_ENABLE_NETLIB_BACKEND) || \
7781
defined(ONEMATH_ENABLE_ARMPL_BACKEND)
7882
#ifdef ONEMATH_ENABLE_MKLCPU_BACKEND

0 commit comments

Comments
 (0)