Skip to content

Commit 81fc629

Browse files
Add a static method for SyclEventRaw class
1 parent 2fe1384 commit 81fc629

File tree

4 files changed

+35
-4
lines changed

4 files changed

+35
-4
lines changed

dpctl/_backend.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ cdef extern from "dpctl_sycl_event_interface.h":
226226
cdef DPCTLSyclEventRef DPCTLEvent_Create()
227227
cdef DPCTLSyclEventRef DPCTLEvent_Copy(const DPCTLSyclEventRef ERef)
228228
cdef void DPCTLEvent_Wait(DPCTLSyclEventRef ERef)
229+
cdef void DPCTLEvent_WaitAndThrow(DPCTLSyclEventRef ERef)
229230
cdef void DPCTLEvent_Delete(DPCTLSyclEventRef ERef)
230231
cdef _event_status_type DPCTLEvent_GetCommandExecutionStatus(DPCTLSyclEventRef ERef)
231232
cdef _backend_type DPCTLEvent_GetBackend(DPCTLSyclEventRef ERef)

dpctl/_sycl_event.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,5 @@ cdef public class SyclEventRaw(_SyclEventRaw) [object PySyclEventRawObject, type
4646
cdef int _init_event_from__SyclEventRaw(self, _SyclEventRaw other)
4747
cdef int _init_event_from_SyclEvent(self, SyclEvent event)
4848
cdef int _init_event_from_capsule(self, object caps)
49-
cdef DPCTLSyclEventRef get_event_ref (self)
50-
cpdef void wait (self)
49+
cdef DPCTLSyclEventRef get_event_ref (self)
50+
cdef void _wait (SyclEventRaw event)

dpctl/_sycl_event.pyx

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ from ._backend cimport ( # noqa: E211
3333
DPCTLEvent_GetCommandExecutionStatus,
3434
DPCTLEvent_GetWaitList,
3535
DPCTLEvent_Wait,
36+
DPCTLEvent_WaitAndThrow,
3637
DPCTLEventVector_Delete,
3738
DPCTLEventVector_GetAt,
3839
DPCTLEventVector_Size,
@@ -188,8 +189,20 @@ cdef class SyclEventRaw(_SyclEventRaw):
188189
"""
189190
return self._event_ref
190191

191-
cpdef void wait(self):
192-
DPCTLEvent_Wait(self._event_ref)
192+
@staticmethod
193+
cdef void _wait(SyclEventRaw event):
194+
DPCTLEvent_WaitAndThrow(event._event_ref)
195+
196+
@staticmethod
197+
def wait(event):
198+
if isinstance(event, list):
199+
for e in event:
200+
SyclEventRaw._wait(e)
201+
elif isinstance(event, SyclEventRaw):
202+
SyclEventRaw._wait(event)
203+
else:
204+
raise ValueError("The passed argument is not a list \
205+
or a SyclEventRaw type.")
193206

194207
def addressof_ref(self):
195208
""" Returns the address of the C API `DPCTLSyclEventRef` pointer as

dpctl/tests/test_sycl_event.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,23 @@ def test_create_event_raw_from_capsule():
7373
pytest.fail("Failed to create an event from capsule")
7474

7575

76+
def test_wait_with_event():
77+
event = dpctl.SyclEventRaw()
78+
try:
79+
dpctl.SyclEventRaw.wait(event)
80+
except ValueError:
81+
pytest.fail("Failed to wait for the event")
82+
83+
84+
def test_wait_with_list():
85+
event_1 = dpctl.SyclEventRaw()
86+
event_2 = dpctl.SyclEventRaw()
87+
try:
88+
dpctl.SyclEventRaw.wait([event_1, event_2])
89+
except ValueError:
90+
pytest.fail("Failed to wait for events from the list")
91+
92+
7693
def test_execution_status():
7794
event = dpctl.SyclEventRaw()
7895
try:

0 commit comments

Comments
 (0)