@@ -55,7 +55,8 @@ cdef class usm_ndarray:
5555 usm_ndarray(
5656 shape, dtype="|f8", strides=None, buffer='device',
5757 offset=0, order='C',
58- buffer_ctor_kwargs=dict()
58+ buffer_ctor_kwargs=dict(),
59+ array_namespace=None
5960 )
6061
6162 See :class:`dpctl.memory.MemoryUSMShared` for allowed
@@ -76,6 +77,7 @@ cdef class usm_ndarray:
7677 Initializes member fields
7778 """
7879 self .base_ = None
80+ self .array_namespace_ = None
7981 self .nd_ = - 1
8082 self .data_ = < char * > 0
8183 self .shape_ = < Py_ssize_t * > 0
@@ -106,13 +108,16 @@ cdef class usm_ndarray:
106108 order = (' C' if (self .flags_ & USM_ARRAY_C_CONTIGUOUS) else ' F' )
107109 )
108110 res.flags_ = self .flags_
111+ res.array_namespace_ = self .array_namespace_
109112 if (res.data_ != self .data_):
110113 raise InternalUSMArrayError(
111114 " Data pointers of cloned and original objects are different." )
112115 return res
113116
114117 def __cinit__ (self , shape , dtype = " |f8" , strides = None , buffer = ' device' ,
115- Py_ssize_t offset = 0 , order = ' C' , buffer_ctor_kwargs = dict ()):
118+ Py_ssize_t offset = 0 , order = ' C' ,
119+ buffer_ctor_kwargs = dict (),
120+ array_namespace = None ):
116121 """
117122 strides and offset must be given in units of array elements.
118123 buffer can be strings ('device'|'shared'|'host' to allocate new memory)
@@ -208,6 +213,7 @@ cdef class usm_ndarray:
208213 self .typenum_ = typenum
209214 self .flags_ = contig_flag
210215 self .nd_ = nd
216+ self .array_namespace_ = array_namespace
211217
212218 def __dealloc__ (self ):
213219 self ._cleanup()
@@ -489,8 +495,53 @@ cdef class usm_ndarray:
489495 offset = _meta[2 ]
490496 )
491497 res.flags_ |= (self .flags_ & USM_ARRAY_WRITEABLE)
498+ res.array_namespace_ = self .array_namespace_
492499 return res
493500
501+ def to_device (self , target_device ):
502+ """
503+ Transfer array to target device
504+ """
505+ d = Device.create_device(target_device)
506+ if (d.sycl_device == self .sycl_device):
507+ return self
508+ elif (d.sycl_context == self .sycl_context):
509+ res = usm_ndarray(
510+ self .shape,
511+ self .dtype,
512+ buffer = self .usm_data,
513+ strides = self .strides,
514+ offset = self .get_offset()
515+ )
516+ res.flags_ = self .flags
517+ return res
518+ else :
519+ nbytes = self .usm_data.nbytes
520+ new_buffer = type (self .usm_data)(
521+ nbytes, queue = d.sycl_queue
522+ )
523+ new_buffer.copy_from_device(self .usm_data)
524+ res = usm_ndarray(
525+ self .shape,
526+ self .dtype,
527+ buffer = new_buffer,
528+ strides = self .strides,
529+ offset = self .get_offset()
530+ )
531+ res.flags_ = self .flags
532+ return res
533+
534+ def _set_namespace (self , mod ):
535+ """ Sets array namespace to given module `mod`. """
536+ self .array_namespace_ = mod
537+
538+ def __array_namespace__ (self , api_version = None ):
539+ """
540+ Returns array namespace, member functions of which
541+ implement data API.
542+ """
543+ return self .array_namespace_
544+
494545 def __bool__ (self ):
495546 if self .size == 1 :
496547 mem_view = dpmem.as_usm_memory(self )
@@ -539,38 +590,89 @@ cdef class usm_ndarray:
539590
540591 raise IndexError (" only integer arrays are valid indices" )
541592
542- def to_device (self , target_device ):
543- """
544- Transfer array to target device
545- """
546- d = Device.create_device(target_device)
547- if (d.sycl_device == self .sycl_device):
548- return self
549- elif (d.sycl_context == self .sycl_context):
550- res = usm_ndarray(
551- self .shape,
552- self .dtype,
553- buffer = self .usm_data,
554- strides = self .strides,
555- offset = self .get_offset()
556- )
557- res.flags_ = self .flags
558- return res
593+ def __abs__ (self ):
594+ return NotImplemented
595+
596+ def __add__ (self , other ):
597+ return NotImplemented
598+
599+ def __and__ (self , other ):
600+ return NotImplemented
601+
602+ def __dlpack__ (self , stream = None ):
603+ return NotImplemented
604+
605+ def __dlpack_device__ (self ):
606+ return NotImplemented
607+
608+ def __eq__ (self , other ):
609+ return NotImplemented
610+
611+ def __floordiv__ (self , other ):
612+ return NotImplemented
613+
614+ def __ge__ (self , other ):
615+ return NotImplemented
616+
617+ def __gt__ (self , other ):
618+ return NotImplemented
619+
620+ def __invert__ (self ):
621+ return NotImplemented
622+
623+ def __le__ (self , other ):
624+ return NotImplemented
625+
626+ def __len__ (self ):
627+ if (self .nd_):
628+ return self .shape[0 ]
559629 else :
560- nbytes = self .usm_data.nbytes
561- new_buffer = type (self .usm_data)(
562- nbytes, queue = d.sycl_queue
563- )
564- new_buffer.copy_from_device(self .usm_data)
565- res = usm_ndarray(
566- self .shape,
567- self .dtype,
568- buffer = new_buffer,
569- strides = self .strides,
570- offset = self .get_offset()
571- )
572- res.flags_ = self .flags
573- return res
630+ raise TypeError (" len() of unsized object" )
631+
632+ def __lshift__ (self , other ):
633+ return NotImplemented
634+
635+ def __lt__ (self , other ):
636+ return NotImplemented
637+
638+ def __matmul__ (self , other ):
639+ return NotImplemented
640+
641+ def __mod__ (self , other ):
642+ return NotImplemented
643+
644+ def __mult__ (self , other ):
645+ return NotImplemented
646+
647+ def __ne__ (self , other ):
648+ return NotImplemented
649+
650+ def __neg__ (self ):
651+ return NotImplemented
652+
653+ def __or__ (self , other ):
654+ return NotImplemented
655+
656+ def __pos__ (self ):
657+ return NotImplemented
658+
659+ def __pow__ (self , other , mod ):
660+ return NotImplemented
661+
662+ def __rshift__ (self , other ):
663+ return NotImplemented
664+
665+ def __setitem__ (self , key , val ):
666+ raise NotImplementedError
667+
668+ def __sub__ (self , other ):
669+ return NotImplemented
670+
671+ def __truediv__ (self , other ):
672+ return NotImplemented
673+
674+ def __xor__ (self , other ):
675+ return NotImplemented
574676
575677
576678cdef usm_ndarray _real_view(usm_ndarray ary):
0 commit comments