Skip to content

Commit aec2ab1

Browse files
committed
[BLAS] SYCL-Graph specific unit tests
Create SYCL-graph extension specific tests for blas in `tests/unit_tests/blas/sycl-graph`. Currently only covering `gemm_usm` and `gemm_batch_usm` These are stubbed out for the CT tests variants, and when the SYCL compiler doesn't support the `sycl_ext_oneapi_graph` extension.
1 parent 88c0bfd commit aec2ab1

File tree

8 files changed

+709
-140
lines changed

8 files changed

+709
-140
lines changed

tests/unit_tests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ set(blas_TEST_LIST
5454
blas_level2
5555
blas_level3
5656
blas_batch
57-
blas_extensions)
57+
blas_extensions
58+
blas_sycl_graph)
5859

5960
set(blas_TEST_LINK "")
6061

tests/unit_tests/blas/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ add_subdirectory(level2)
2727
add_subdirectory(level3)
2828
add_subdirectory(batch)
2929
add_subdirectory(extensions)
30+
add_subdirectory(sycl-graph)

tests/unit_tests/blas/batch/gemm_batch_usm.cpp

Lines changed: 20 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ extern std::vector<sycl::device*> devices;
4848
namespace {
4949

5050
template <typename Ta, typename Tb, typename Tc, typename Ts>
51-
int test(device* dev, oneapi::math::layout layout, int64_t group_count, bool graph_record = false) {
51+
int test(device* dev, oneapi::math::layout layout, int64_t group_count) {
5252
// Catch asynchronous exceptions.
5353
auto exception_handler = [](exception_list exceptions) {
5454
for (std::exception_ptr const& e : exceptions) {
@@ -247,15 +247,6 @@ int test(device* dev, oneapi::math::layout layout, int64_t group_count, bool gra
247247

248248
try {
249249
#ifdef CALL_RT_API
250-
#ifdef SYCL_EXT_ONEAPI_GRAPH
251-
namespace sycl_exp = sycl::ext::oneapi::experimental;
252-
using modifiable_graph = sycl_exp::command_graph<sycl_exp::graph_state::modifiable>;
253-
std::unique_ptr<modifiable_graph> graph;
254-
if (graph_record) {
255-
graph = std::make_unique<modifiable_graph>(main_queue);
256-
graph->begin_recording(main_queue);
257-
}
258-
#endif
259250
switch (layout) {
260251
case oneapi::math::layout::col_major:
261252
done = oneapi::math::blas::column_major::gemm_batch(
@@ -271,18 +262,7 @@ int test(device* dev, oneapi::math::layout layout, int64_t group_count, bool gra
271262
break;
272263
default: break;
273264
}
274-
275-
#ifdef SYCL_EXT_ONEAPI_GRAPH
276-
if (graph_record) {
277-
graph->end_recording(main_queue);
278-
auto exec_graph = graph->finalize();
279-
main_queue.ext_oneapi_graph(exec_graph).wait_and_throw();
280-
}
281-
else
282-
#endif
283-
{
284-
done.wait_and_throw();
285-
}
265+
done.wait_and_throw();
286266
#else
287267
switch (layout) {
288268
case oneapi::math::layout::col_major:
@@ -385,65 +365,58 @@ int test(device* dev, oneapi::math::layout layout, int64_t group_count, bool gra
385365
}
386366

387367
class GemmBatchUsmTests
388-
: public ::testing::TestWithParam<std::tuple<sycl::device*, oneapi::math::layout, bool>> {
389-
virtual void SetUp() override {
390-
// Skip test if graph recording variant and device doesn't support sycl_ext_oneapi_graph
391-
if (std::get<2>(GetParam())) {
392-
CHECK_GRAPH_ON_DEVICE(std::get<0>(GetParam()));
393-
}
394-
}
395-
};
368+
: public ::testing::TestWithParam<std::tuple<sycl::device*, oneapi::math::layout>> {};
396369

397370
TEST_P(GemmBatchUsmTests, RealHalfPrecision) {
398371
EXPECT_TRUEORSKIP((test<sycl::half, sycl::half, sycl::half, sycl::half>(
399-
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
372+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
400373
}
401374

402375
TEST_P(GemmBatchUsmTests, HalfHalfFloatPrecision) {
403-
EXPECT_TRUEORSKIP((test<sycl::half, sycl::half, float, float>(
404-
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
376+
EXPECT_TRUEORSKIP((test<sycl::half, sycl::half, float, float>(std::get<0>(GetParam()),
377+
std::get<1>(GetParam()), 5)));
405378
}
406379

407380
TEST_P(GemmBatchUsmTests, Int8Int8SinglePrecision) {
408-
EXPECT_TRUEORSKIP((test<std::int8_t, std::int8_t, float, float>(
409-
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
381+
EXPECT_TRUEORSKIP((test<std::int8_t, std::int8_t, float, float>(std::get<0>(GetParam()),
382+
std::get<1>(GetParam()), 5)));
410383
}
411384

412385
TEST_P(GemmBatchUsmTests, Int8Int8Int32Precision) {
413386
EXPECT_TRUEORSKIP((test<std::int8_t, std::int8_t, std::int32_t, float>(
414-
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
387+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
415388
}
416389

417390
TEST_P(GemmBatchUsmTests, RealSinglePrecision) {
418-
EXPECT_TRUEORSKIP((test<float, float, float, float>(
419-
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
391+
EXPECT_TRUEORSKIP(
392+
(test<float, float, float, float>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
420393
}
421394

422395
TEST_P(GemmBatchUsmTests, RealDoublePrecision) {
423396
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));
424397

425-
EXPECT_TRUEORSKIP((test<double, double, double, double>(
426-
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
398+
EXPECT_TRUEORSKIP((
399+
test<double, double, double, double>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
427400
}
428401

429402
TEST_P(GemmBatchUsmTests, ComplexSinglePrecision) {
430403
EXPECT_TRUEORSKIP(
431404
(test<std::complex<float>, std::complex<float>, std::complex<float>, std::complex<float>>(
432-
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
405+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
433406
}
434407

435408
TEST_P(GemmBatchUsmTests, ComplexDoublePrecision) {
436409
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));
437410

438-
EXPECT_TRUEORSKIP((test<std::complex<double>, std::complex<double>, std::complex<double>,
439-
std::complex<double>>(std::get<0>(GetParam()), std::get<1>(GetParam()),
440-
5, std::get<2>(GetParam()))));
411+
EXPECT_TRUEORSKIP(
412+
(test<std::complex<double>, std::complex<double>, std::complex<double>,
413+
std::complex<double>>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
441414
}
442415

443416
INSTANTIATE_TEST_SUITE_P(GemmBatchUsmTestSuite, GemmBatchUsmTests,
444417
::testing::Combine(testing::ValuesIn(devices),
445418
testing::Values(oneapi::math::layout::col_major,
446-
oneapi::math::layout::row_major),
447-
testing::Values(true, false)),
448-
::LayoutGraphDeviceNamePrint());
419+
oneapi::math::layout::row_major)),
420+
::LayoutDeviceNamePrint());
421+
449422
} // anonymous namespace

0 commit comments

Comments
 (0)