diff --git a/dpctl-capi/include/dpctl_sycl_event_interface.h b/dpctl-capi/include/dpctl_sycl_event_interface.h index 64505885bd..32ebd03226 100644 --- a/dpctl-capi/include/dpctl_sycl_event_interface.h +++ b/dpctl-capi/include/dpctl_sycl_event_interface.h @@ -31,6 +31,7 @@ #include "dpctl_data_types.h" #include "dpctl_sycl_enum_types.h" #include "dpctl_sycl_types.h" +#include "dpctl_vector.h" DPCTL_C_EXTERN_C_BEGIN @@ -38,6 +39,10 @@ DPCTL_C_EXTERN_C_BEGIN * @defgroup EventInterface Event class C wrapper */ +// Declares a set of types and functions to deal with vectors of +// DPCTLSyclEventRef. Refer dpctl_vector_macros.h +DPCTL_DECLARE_VECTOR(Event) + /*! * @brief A wrapper for ``sycl::event`` contructor to construct a new event. * @@ -150,4 +155,16 @@ uint64_t DPCTLEvent_GetProfilingInfoStart(__dpctl_keep DPCTLSyclEventRef ERef); DPCTL_API uint64_t DPCTLEvent_GetProfilingInfoEnd(__dpctl_keep DPCTLSyclEventRef ERef); +/*! + * @brief C-API wrapper for ``sycl::event::get_wait_list``. + * Returns a vector of events that this event still waits for. + * + * @param ERef Opaque pointer to a ``sycl::event``. + * @return A DPCTLEventVectorRef of DPCTLSyclEventRef objects. + * @ingroup EventInterface + */ +DPCTL_API +__dpctl_give DPCTLEventVectorRef +DPCTLEvent_GetWaitList(__dpctl_keep DPCTLSyclEventRef ERef); + DPCTL_C_EXTERN_C_END diff --git a/dpctl-capi/source/dpctl_sycl_event_interface.cpp b/dpctl-capi/source/dpctl_sycl_event_interface.cpp index c7800eeda7..7e09e60f48 100644 --- a/dpctl-capi/source/dpctl_sycl_event_interface.cpp +++ b/dpctl-capi/source/dpctl_sycl_event_interface.cpp @@ -37,6 +37,11 @@ namespace DEFINE_SIMPLE_CONVERSION_FUNCTIONS(event, DPCTLSyclEventRef) } /* end of anonymous namespace */ +#undef EL +#define EL Event +#include "dpctl_vector_templ.cpp" +#undef EL + __dpctl_give DPCTLSyclEventRef DPCTLEvent_Create() { DPCTLSyclEventRef ERef = nullptr; @@ -182,3 +187,39 @@ uint64_t DPCTLEvent_GetProfilingInfoEnd(__dpctl_keep DPCTLSyclEventRef ERef) } return profilingInfoEnd; } + +__dpctl_give DPCTLEventVectorRef +DPCTLEvent_GetWaitList(__dpctl_keep DPCTLSyclEventRef ERef) +{ + auto E = unwrap(ERef); + if (!E) { + std::cerr << "Cannot get wait list as input is a nullptr\n"; + return nullptr; + } + vector_class *EventsVectorPtr = nullptr; + try { + EventsVectorPtr = new vector_class(); + } catch (std::bad_alloc const &ba) { + // \todo log error + std::cerr << ba.what() << '\n'; + return nullptr; + } + try { + auto Events = E->get_wait_list(); + EventsVectorPtr->reserve(Events.size()); + for (const auto &Ev : Events) { + EventsVectorPtr->emplace_back(wrap(new event(Ev))); + } + return wrap(EventsVectorPtr); + } catch (std::bad_alloc const &ba) { + delete EventsVectorPtr; + // \todo log error + std::cerr << ba.what() << '\n'; + return nullptr; + } catch (const runtime_error &re) { + delete EventsVectorPtr; + // \todo log error + std::cerr << re.what() << '\n'; + return nullptr; + } +} diff --git a/dpctl-capi/tests/test_sycl_event_interface.cpp b/dpctl-capi/tests/test_sycl_event_interface.cpp index 7c40913fe1..a41f2dc31b 100644 --- a/dpctl-capi/tests/test_sycl_event_interface.cpp +++ b/dpctl-capi/tests/test_sycl_event_interface.cpp @@ -26,6 +26,7 @@ #include "Support/CBindingWrapping.h" #include "dpctl_sycl_event_interface.h" +#include "dpctl_sycl_types.h" #include #include @@ -33,7 +34,31 @@ using namespace cl::sycl; namespace { -DEFINE_SIMPLE_CONVERSION_FUNCTIONS(event, DPCTLSyclEventRef) +DEFINE_SIMPLE_CONVERSION_FUNCTIONS(event, DPCTLSyclEventRef); +DEFINE_SIMPLE_CONVERSION_FUNCTIONS(vector_class, + DPCTLEventVectorRef); + +#ifndef DPCTL_COVERAGE +sycl::event produce_event(sycl::queue &Q, sycl::buffer &data) +{ + int N = data.get_range()[0]; + + auto e1 = Q.submit([&](sycl::handler &h) { + sycl::accessor a{data, h, sycl::write_only, sycl::noinit}; + h.parallel_for(N, [=](sycl::id<1> i) { a[i] = 1; }); + }); + + auto e2 = Q.submit([&](sycl::handler &h) { + sycl::accessor a{data, h}; + h.single_task([=]() { + for (int i = 1; i < N; i++) + a[0] += a[i]; + }); + }); + + return e2; +} +#endif } // namespace struct TestDPCTLSyclEventInterface : public ::testing::Test @@ -163,3 +188,28 @@ TEST_F(TestDPCTLSyclEventInterface, CheckGetProfiling_Invalid) EXPECT_FALSE(eEnd); EXPECT_FALSE(eSubmit); } + +TEST_F(TestDPCTLSyclEventInterface, CheckGetWaitList) +{ + DPCTLEventVectorRef EVRef = nullptr; + EXPECT_NO_FATAL_FAILURE(EVRef = DPCTLEvent_GetWaitList(ERef)); + ASSERT_TRUE(EVRef); + EXPECT_NO_FATAL_FAILURE(DPCTLEventVector_Clear(EVRef)); + EXPECT_NO_FATAL_FAILURE(DPCTLEventVector_Delete(EVRef)); +} + +#ifndef DPCTL_COVERAGE +TEST_F(TestDPCTLSyclEventInterface, CheckGetWaitListSYCL) +{ + sycl::queue q; + sycl::buffer data{42}; + sycl::event eD; + DPCTLEventVectorRef EVRef = nullptr; + + EXPECT_NO_FATAL_FAILURE(eD = produce_event(q, data)); + DPCTLSyclEventRef ERef = reinterpret_cast(&eD); + EXPECT_NO_FATAL_FAILURE(EVRef = DPCTLEvent_GetWaitList(ERef)); + ASSERT_TRUE(DPCTLEventVector_Size(EVRef) > 0); + DPCTLEventVector_Delete(EVRef); +} +#endif