Skip to content

Commit 4a2daa8

Browse files
Added special methods to usm_ndarray implemented to raise NotImplementedError
1 parent ff33cb6 commit 4a2daa8

File tree

2 files changed

+136
-33
lines changed

2 files changed

+136
-33
lines changed

dpctl/tensor/_usmarray.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ cdef api class usm_ndarray [object PyUSMArrayObject, type PyUSMArrayType]:
3636
cdef int typenum_
3737
cdef int flags_
3838
cdef object base_
39+
cdef object array_namespace_
3940
# make usm_ndarray weak-referenceable
4041
cdef object __weakref__
4142

dpctl/tensor/_usmarray.pyx

Lines changed: 135 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ cdef class usm_ndarray:
5555
usm_ndarray(
5656
shape, dtype="|f8", strides=None, buffer='device',
5757
offset=0, order='C',
58-
buffer_ctor_kwargs=dict()
58+
buffer_ctor_kwargs=dict(),
59+
array_namespace=None
5960
)
6061
6162
See :class:`dpctl.memory.MemoryUSMShared` for allowed
@@ -76,6 +77,7 @@ cdef class usm_ndarray:
7677
Initializes member fields
7778
"""
7879
self.base_ = None
80+
self.array_namespace_ = None
7981
self.nd_ = -1
8082
self.data_ = <char *>0
8183
self.shape_ = <Py_ssize_t *>0
@@ -106,13 +108,16 @@ cdef class usm_ndarray:
106108
order=('C' if (self.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
107109
)
108110
res.flags_ = self.flags_
111+
res.array_namespace_ = self.array_namespace_
109112
if (res.data_ != self.data_):
110113
raise InternalUSMArrayError(
111114
"Data pointers of cloned and original objects are different.")
112115
return res
113116

114117
def __cinit__(self, shape, dtype="|f8", strides=None, buffer='device',
115-
Py_ssize_t offset=0, order='C', buffer_ctor_kwargs=dict()):
118+
Py_ssize_t offset=0, order='C',
119+
buffer_ctor_kwargs=dict(),
120+
array_namespace=None):
116121
"""
117122
strides and offset must be given in units of array elements.
118123
buffer can be strings ('device'|'shared'|'host' to allocate new memory)
@@ -208,6 +213,7 @@ cdef class usm_ndarray:
208213
self.typenum_ = typenum
209214
self.flags_ = contig_flag
210215
self.nd_ = nd
216+
self.array_namespace_ = array_namespace
211217

212218
def __dealloc__(self):
213219
self._cleanup()
@@ -489,8 +495,53 @@ cdef class usm_ndarray:
489495
offset=_meta[2]
490496
)
491497
res.flags_ |= (self.flags_ & USM_ARRAY_WRITEABLE)
498+
res.array_namespace_ = self.array_namespace_
492499
return res
493500

501+
def to_device(self, target_device):
502+
"""
503+
Transfer array to target device
504+
"""
505+
d = Device.create_device(target_device)
506+
if (d.sycl_device == self.sycl_device):
507+
return self
508+
elif (d.sycl_context == self.sycl_context):
509+
res = usm_ndarray(
510+
self.shape,
511+
self.dtype,
512+
buffer=self.usm_data,
513+
strides=self.strides,
514+
offset=self.get_offset()
515+
)
516+
res.flags_ = self.flags
517+
return res
518+
else:
519+
nbytes = self.usm_data.nbytes
520+
new_buffer = type(self.usm_data)(
521+
nbytes, queue=d.sycl_queue
522+
)
523+
new_buffer.copy_from_device(self.usm_data)
524+
res = usm_ndarray(
525+
self.shape,
526+
self.dtype,
527+
buffer=new_buffer,
528+
strides=self.strides,
529+
offset=self.get_offset()
530+
)
531+
res.flags_ = self.flags
532+
return res
533+
534+
def _set_namespace(self, mod):
535+
""" Sets array namespace to given module `mod`. """
536+
self.array_namespace_ = mod
537+
538+
def __array_namespace__(self, api_version=None):
539+
"""
540+
Returns array namespace, member functions of which
541+
implement data API.
542+
"""
543+
return self.array_namespace_
544+
494545
def __bool__(self):
495546
if self.size == 1:
496547
mem_view = dpmem.as_usm_memory(self)
@@ -539,38 +590,89 @@ cdef class usm_ndarray:
539590

540591
raise IndexError("only integer arrays are valid indices")
541592

542-
def to_device(self, target_device):
543-
"""
544-
Transfer array to target device
545-
"""
546-
d = Device.create_device(target_device)
547-
if (d.sycl_device == self.sycl_device):
548-
return self
549-
elif (d.sycl_context == self.sycl_context):
550-
res = usm_ndarray(
551-
self.shape,
552-
self.dtype,
553-
buffer=self.usm_data,
554-
strides=self.strides,
555-
offset=self.get_offset()
556-
)
557-
res.flags_ = self.flags
558-
return res
593+
def __abs__(self):
594+
return NotImplemented
595+
596+
def __add__(self, other):
597+
return NotImplemented
598+
599+
def __and__(self, other):
600+
return NotImplemented
601+
602+
def __dlpack__(self, stream=None):
603+
return NotImplemented
604+
605+
def __dlpack_device__(self):
606+
return NotImplemented
607+
608+
def __eq__(self, other):
609+
return NotImplemented
610+
611+
def __floordiv__(self, other):
612+
return NotImplemented
613+
614+
def __ge__(self, other):
615+
return NotImplemented
616+
617+
def __gt__(self, other):
618+
return NotImplemented
619+
620+
def __invert__(self):
621+
return NotImplemented
622+
623+
def __le__(self, other):
624+
return NotImplemented
625+
626+
def __len__(self):
627+
if (self.nd_):
628+
return self.shape[0]
559629
else:
560-
nbytes = self.usm_data.nbytes
561-
new_buffer = type(self.usm_data)(
562-
nbytes, queue=d.sycl_queue
563-
)
564-
new_buffer.copy_from_device(self.usm_data)
565-
res = usm_ndarray(
566-
self.shape,
567-
self.dtype,
568-
buffer=new_buffer,
569-
strides=self.strides,
570-
offset=self.get_offset()
571-
)
572-
res.flags_ = self.flags
573-
return res
630+
raise TypeError("len() of unsized object")
631+
632+
def __lshift__(self, other):
633+
return NotImplemented
634+
635+
def __lt__(self, other):
636+
return NotImplemented
637+
638+
def __matmul__(self, other):
639+
return NotImplemented
640+
641+
def __mod__(self, other):
642+
return NotImplemented
643+
644+
def __mult__(self, other):
645+
return NotImplemented
646+
647+
def __ne__(self, other):
648+
return NotImplemented
649+
650+
def __neg__(self):
651+
return NotImplemented
652+
653+
def __or__(self, other):
654+
return NotImplemented
655+
656+
def __pos__(self):
657+
return NotImplemented
658+
659+
def __pow__(self, other, mod):
660+
return NotImplemented
661+
662+
def __rshift__(self, other):
663+
return NotImplemented
664+
665+
def __setitem__(self, key, val):
666+
raise NotImplementedError
667+
668+
def __sub__(self, other):
669+
return NotImplemented
670+
671+
def __truediv__(self, other):
672+
return NotImplemented
673+
674+
def __xor__(self, other):
675+
return NotImplemented
574676

575677

576678
cdef usm_ndarray _real_view(usm_ndarray ary):

0 commit comments

Comments
 (0)