@@ -48,7 +48,7 @@ extern std::vector<sycl::device*> devices;
48
48
namespace {
49
49
50
50
template <typename Ta, typename Tb, typename Tc, typename Ts>
51
- int test (device* dev, oneapi::math::layout layout, int64_t group_count, bool graph_record = false ) {
51
+ int test (device* dev, oneapi::math::layout layout, int64_t group_count) {
52
52
// Catch asynchronous exceptions.
53
53
auto exception_handler = [](exception_list exceptions) {
54
54
for (std::exception_ptr const & e : exceptions) {
@@ -247,15 +247,6 @@ int test(device* dev, oneapi::math::layout layout, int64_t group_count, bool gra
247
247
248
248
try {
249
249
#ifdef CALL_RT_API
250
- #ifdef SYCL_EXT_ONEAPI_GRAPH
251
- namespace sycl_exp = sycl::ext::oneapi::experimental;
252
- using modifiable_graph = sycl_exp::command_graph<sycl_exp::graph_state::modifiable>;
253
- std::unique_ptr<modifiable_graph> graph;
254
- if (graph_record) {
255
- graph = std::make_unique<modifiable_graph>(main_queue);
256
- graph->begin_recording (main_queue);
257
- }
258
- #endif
259
250
switch (layout) {
260
251
case oneapi::math::layout::col_major:
261
252
done = oneapi::math::blas::column_major::gemm_batch (
@@ -271,18 +262,7 @@ int test(device* dev, oneapi::math::layout layout, int64_t group_count, bool gra
271
262
break ;
272
263
default : break ;
273
264
}
274
-
275
- #ifdef SYCL_EXT_ONEAPI_GRAPH
276
- if (graph_record) {
277
- graph->end_recording (main_queue);
278
- auto exec_graph = graph->finalize ();
279
- main_queue.ext_oneapi_graph (exec_graph).wait_and_throw ();
280
- }
281
- else
282
- #endif
283
- {
284
- done.wait_and_throw ();
285
- }
265
+ done.wait_and_throw ();
286
266
#else
287
267
switch (layout) {
288
268
case oneapi::math::layout::col_major:
@@ -385,65 +365,58 @@ int test(device* dev, oneapi::math::layout layout, int64_t group_count, bool gra
385
365
}
386
366
387
367
class GemmBatchUsmTests
388
- : public ::testing::TestWithParam<std::tuple<sycl::device*, oneapi::math::layout, bool >> {
389
- virtual void SetUp () override {
390
- // Skip test if graph recording variant and device doesn't support sycl_ext_oneapi_graph
391
- if (std::get<2 >(GetParam ())) {
392
- CHECK_GRAPH_ON_DEVICE (std::get<0 >(GetParam ()));
393
- }
394
- }
395
- };
368
+ : public ::testing::TestWithParam<std::tuple<sycl::device*, oneapi::math::layout>> {};
396
369
397
370
TEST_P (GemmBatchUsmTests, RealHalfPrecision) {
398
371
EXPECT_TRUEORSKIP ((test<sycl::half, sycl::half, sycl::half, sycl::half>(
399
- std::get<0 >(GetParam ()), std::get<1 >(GetParam ()), 5 , std::get< 2 >( GetParam ()) )));
372
+ std::get<0 >(GetParam ()), std::get<1 >(GetParam ()), 5 )));
400
373
}
401
374
402
375
TEST_P (GemmBatchUsmTests, HalfHalfFloatPrecision) {
403
- EXPECT_TRUEORSKIP ((test<sycl::half, sycl::half, float , float >(
404
- std::get< 0 >( GetParam ()), std::get<1 >(GetParam ()), 5 , std::get< 2 >( GetParam ()) )));
376
+ EXPECT_TRUEORSKIP ((test<sycl::half, sycl::half, float , float >(std::get< 0 >( GetParam ()),
377
+ std::get<1 >(GetParam ()), 5 )));
405
378
}
406
379
407
380
TEST_P (GemmBatchUsmTests, Int8Int8SinglePrecision) {
408
- EXPECT_TRUEORSKIP ((test<std::int8_t , std::int8_t , float , float >(
409
- std::get< 0 >( GetParam ()), std::get<1 >(GetParam ()), 5 , std::get< 2 >( GetParam ()) )));
381
+ EXPECT_TRUEORSKIP ((test<std::int8_t , std::int8_t , float , float >(std::get< 0 >( GetParam ()),
382
+ std::get<1 >(GetParam ()), 5 )));
410
383
}
411
384
412
385
TEST_P (GemmBatchUsmTests, Int8Int8Int32Precision) {
413
386
EXPECT_TRUEORSKIP ((test<std::int8_t , std::int8_t , std::int32_t , float >(
414
- std::get<0 >(GetParam ()), std::get<1 >(GetParam ()), 5 , std::get< 2 >( GetParam ()) )));
387
+ std::get<0 >(GetParam ()), std::get<1 >(GetParam ()), 5 )));
415
388
}
416
389
417
390
TEST_P (GemmBatchUsmTests, RealSinglePrecision) {
418
- EXPECT_TRUEORSKIP ((test< float , float , float , float >(
419
- std::get<0 >(GetParam ()), std::get<1 >(GetParam ()), 5 , std::get< 2 >( GetParam ()) )));
391
+ EXPECT_TRUEORSKIP (
392
+ (test< float , float , float , float >( std::get<0 >(GetParam ()), std::get<1 >(GetParam ()), 5 )));
420
393
}
421
394
422
395
TEST_P (GemmBatchUsmTests, RealDoublePrecision) {
423
396
CHECK_DOUBLE_ON_DEVICE (std::get<0 >(GetParam ()));
424
397
425
- EXPECT_TRUEORSKIP ((test< double , double , double , double >(
426
- std::get<0 >(GetParam ()), std::get<1 >(GetParam ()), 5 , std::get< 2 >( GetParam ()) )));
398
+ EXPECT_TRUEORSKIP ((
399
+ test< double , double , double , double >( std::get<0 >(GetParam ()), std::get<1 >(GetParam ()), 5 )));
427
400
}
428
401
429
402
TEST_P (GemmBatchUsmTests, ComplexSinglePrecision) {
430
403
EXPECT_TRUEORSKIP (
431
404
(test<std::complex<float >, std::complex<float >, std::complex<float >, std::complex<float >>(
432
- std::get<0 >(GetParam ()), std::get<1 >(GetParam ()), 5 , std::get< 2 >( GetParam ()) )));
405
+ std::get<0 >(GetParam ()), std::get<1 >(GetParam ()), 5 )));
433
406
}
434
407
435
408
TEST_P (GemmBatchUsmTests, ComplexDoublePrecision) {
436
409
CHECK_DOUBLE_ON_DEVICE (std::get<0 >(GetParam ()));
437
410
438
- EXPECT_TRUEORSKIP ((test<std::complex< double >, std::complex< double >, std::complex< double >,
439
- std::complex<double >>( std::get< 0 >( GetParam ()) , std::get< 1 >( GetParam ()) ,
440
- 5 , std::get<2 >(GetParam ()))));
411
+ EXPECT_TRUEORSKIP (
412
+ (test< std::complex<double >, std::complex< double > , std::complex< double > ,
413
+ std::complex< double >>(std::get< 0 >( GetParam ()) , std::get<1 >(GetParam ()), 5 )));
441
414
}
442
415
443
416
INSTANTIATE_TEST_SUITE_P (GemmBatchUsmTestSuite, GemmBatchUsmTests,
444
417
::testing::Combine (testing::ValuesIn(devices),
445
418
testing::Values(oneapi::math::layout::col_major,
446
- oneapi::math::layout::row_major),
447
- testing::Values( true , false )),
448
- ::LayoutGraphDeviceNamePrint());
419
+ oneapi::math::layout::row_major)) ,
420
+ ::LayoutDeviceNamePrint());
421
+
449
422
} // anonymous namespace
0 commit comments