@@ -36,7 +36,7 @@ from .._backend cimport (
3636 DPCTLSyclDeviceRef,
3737 DPCTLSyclUSMRef,
3838)
39- from ._usmarray cimport USM_ARRAY_C_CONTIGUOUS, USM_ARRAY_WRITABLE, usm_ndarray
39+ from ._usmarray cimport USM_ARRAY_WRITABLE, usm_ndarray
4040
4141import ctypes
4242
@@ -266,7 +266,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
266266 cdef int64_t * shape_strides_ptr = NULL
267267 cdef int i = 0
268268 cdef int device_id = - 1
269- cdef int flags = 0
270269 cdef Py_ssize_t element_offset = 0
271270 cdef Py_ssize_t byte_offset = 0
272271 cdef Py_ssize_t si = 1
@@ -281,22 +280,21 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
281280 raise MemoryError (
282281 " to_dlpack_capsule: Could not allocate memory for DLManagedTensor"
283282 )
284- shape_strides_ptr = < int64_t * > stdlib.malloc((sizeof(int64_t) * 2 ) * nd)
285- if shape_strides_ptr is NULL :
286- stdlib.free(dlm_tensor)
287- raise MemoryError (
288- " to_dlpack_capsule: Could not allocate memory for shape/strides"
289- )
290- shape_ptr = usm_ary.get_shape()
291- for i in range (nd):
292- shape_strides_ptr[i] = shape_ptr[i]
293- strides_ptr = usm_ary.get_strides()
294- flags = usm_ary.flags_
295- if strides_ptr:
283+ if nd > 0 :
284+ shape_strides_ptr = < int64_t * > stdlib.malloc((sizeof(int64_t) * 2 ) * nd)
285+ if shape_strides_ptr is NULL :
286+ stdlib.free(dlm_tensor)
287+ raise MemoryError (
288+ " to_dlpack_capsule: Could not allocate memory for shape/strides"
289+ )
290+ shape_ptr = usm_ary.get_shape()
296291 for i in range (nd):
297- shape_strides_ptr[nd + i] = strides_ptr[i]
298- else :
299- if not (flags & USM_ARRAY_C_CONTIGUOUS):
292+ shape_strides_ptr[i] = shape_ptr[i]
293+ strides_ptr = usm_ary.get_strides()
294+ if strides_ptr:
295+ for i in range (nd):
296+ shape_strides_ptr[nd + i] = strides_ptr[i]
297+ else :
300298 si = 1
301299 for i in range (0 , nd):
302300 shape_strides_ptr[nd + i] = si
@@ -312,11 +310,8 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
312310 dl_tensor.data = < void * > (data_ptr - byte_offset)
313311 dl_tensor.ndim = nd
314312 dl_tensor.byte_offset = < uint64_t> byte_offset
315- dl_tensor.shape = & shape_strides_ptr[0 ]
316- if strides_ptr is NULL :
317- dl_tensor.strides = NULL
318- else :
319- dl_tensor.strides = & shape_strides_ptr[nd]
313+ dl_tensor.shape = & shape_strides_ptr[0 ] if nd > 0 else NULL
314+ dl_tensor.strides = & shape_strides_ptr[nd] if nd > 0 else NULL
320315 dl_tensor.device.device_type = kDLOneAPI
321316 dl_tensor.device.device_id = device_id
322317 dl_tensor.dtype.lanes = < uint16_t> 1
@@ -396,24 +391,24 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
396391 " to_dlpack_versioned_capsule: Could not allocate memory "
397392 " for DLManagedTensorVersioned"
398393 )
399- shape_strides_ptr = < int64_t * > stdlib.malloc((sizeof(int64_t) * 2 ) * nd)
400- if shape_strides_ptr is NULL :
401- stdlib.free(dlmv_tensor)
402- raise MemoryError (
403- " to_dlpack_versioned_capsule: Could not allocate memory "
404- " for shape/strides"
405- )
406- # this can be a separate function for handling shapes and strides
407- shape_ptr = usm_ary.get_shape()
408- for i in range (nd):
409- shape_strides_ptr[i] = shape_ptr[i]
410- strides_ptr = usm_ary.get_strides()
411- flags = usm_ary.flags_
412- if strides_ptr:
394+ if nd > 0 :
395+ shape_strides_ptr = < int64_t * > stdlib.malloc((sizeof(int64_t) * 2 ) * nd)
396+ if shape_strides_ptr is NULL :
397+ stdlib.free(dlmv_tensor)
398+ raise MemoryError (
399+ " to_dlpack_versioned_capsule: Could not allocate memory "
400+ " for shape/strides"
401+ )
402+ # this can be a separate function for handling shapes and strides
403+ shape_ptr = usm_ary.get_shape()
413404 for i in range (nd):
414- shape_strides_ptr[nd + i] = strides_ptr[i]
415- else :
416- if not (flags & USM_ARRAY_C_CONTIGUOUS):
405+ shape_strides_ptr[i] = shape_ptr[i]
406+ strides_ptr = usm_ary.get_strides()
407+ flags = usm_ary.flags_
408+ if strides_ptr:
409+ for i in range (nd):
410+ shape_strides_ptr[nd + i] = strides_ptr[i]
411+ else :
417412 si = 1
418413 for i in range (0 , nd):
419414 shape_strides_ptr[nd + i] = si
@@ -431,11 +426,8 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
431426 dl_tensor.data = < void * > (data_ptr - byte_offset)
432427 dl_tensor.ndim = nd
433428 dl_tensor.byte_offset = < uint64_t> byte_offset
434- dl_tensor.shape = & shape_strides_ptr[0 ]
435- if strides_ptr is NULL :
436- dl_tensor.strides = NULL
437- else :
438- dl_tensor.strides = & shape_strides_ptr[nd]
429+ dl_tensor.shape = & shape_strides_ptr[0 ] if nd > 0 else NULL
430+ dl_tensor.strides = & shape_strides_ptr[nd] if nd > 0 else NULL
439431 dl_tensor.device.device_type = kDLOneAPI
440432 dl_tensor.device.device_id = device_id
441433 dl_tensor.dtype.lanes = < uint16_t> 1
@@ -515,10 +507,9 @@ cpdef numpy_to_dlpack_versioned_capsule(ndarray npy_ary, bint copied):
515507 " for DLManagedTensorVersioned"
516508 )
517509
518- is_c_contiguous = npy_ary.flags[" C" ]
519510 shape = npy_ary.ctypes.shape_as(ctypes.c_int64)
520511 strides = npy_ary.ctypes.strides_as(ctypes.c_int64)
521- if not is_c_contiguous :
512+ if nd > 0 :
522513 if npy_ary.size != 1 :
523514 for i in range (nd):
524515 if shape[i] != 1 and strides[i] % itemsize != 0 :
@@ -529,18 +520,14 @@ cpdef numpy_to_dlpack_versioned_capsule(ndarray npy_ary, bint copied):
529520 " itemsize"
530521 )
531522 shape_strides_ptr = < int64_t * > stdlib.malloc((sizeof(int64_t) * 2 ) * nd)
532- else :
533- # no need to pass strides in this case
534- shape_strides_ptr = < int64_t * > stdlib.malloc(sizeof(int64_t) * nd)
535- if shape_strides_ptr is NULL :
536- stdlib.free(dlmv_tensor)
537- raise MemoryError (
538- " numpy_to_dlpack_versioned_capsule: Could not allocate memory "
539- " for shape/strides"
540- )
541- for i in range (nd):
542- shape_strides_ptr[i] = shape[i]
543- if not is_c_contiguous:
523+ if shape_strides_ptr is NULL :
524+ stdlib.free(dlmv_tensor)
525+ raise MemoryError (
526+ " numpy_to_dlpack_versioned_capsule: Could not allocate memory "
527+ " for shape/strides"
528+ )
529+ for i in range (nd):
530+ shape_strides_ptr[i] = shape[i]
544531 shape_strides_ptr[nd + i] = strides[i] // itemsize
545532
546533 writable_flag = npy_ary.flags[" W" ]
@@ -552,11 +539,8 @@ cpdef numpy_to_dlpack_versioned_capsule(ndarray npy_ary, bint copied):
552539 dl_tensor.data = < void * > npy_ary.data
553540 dl_tensor.ndim = nd
554541 dl_tensor.byte_offset = < uint64_t> byte_offset
555- dl_tensor.shape = & shape_strides_ptr[0 ]
556- if is_c_contiguous:
557- dl_tensor.strides = NULL
558- else :
559- dl_tensor.strides = & shape_strides_ptr[nd]
542+ dl_tensor.shape = & shape_strides_ptr[0 ] if nd > 0 else NULL
543+ dl_tensor.strides = & shape_strides_ptr[nd] if nd > 0 else NULL
560544 dl_tensor.device.device_type = kDLCPU
561545 dl_tensor.device.device_id = 0
562546 dl_tensor.dtype.lanes = < uint16_t> 1
@@ -828,12 +812,8 @@ cpdef object from_dlpack_capsule(object py_caps):
828812 raise BufferError(
829813 " Can not import DLPack tensor with lanes != 1"
830814 )
831- offset_min = 0
832- if dl_tensor.strides is NULL :
833- for i in range (dl_tensor.ndim):
834- sz = sz * dl_tensor.shape[i]
835- offset_max = sz - 1
836- else :
815+ if dl_tensor.ndim > 0 :
816+ offset_min = 0
837817 offset_max = 0
838818 for i in range (dl_tensor.ndim):
839819 stride_i = dl_tensor.strides[i]
@@ -888,15 +868,17 @@ cpdef object from_dlpack_capsule(object py_caps):
888868 (< c_dpctl.SyclQueue> q).get_queue_ref(),
889869 memory_owner = tmp
890870 )
871+
891872 py_shape = list ()
892- for i in range (dl_tensor.ndim):
893- py_shape.append(dl_tensor.shape[i])
894- if (dl_tensor.strides is NULL ):
895- py_strides = None
896- else :
873+ if (dl_tensor.shape is not NULL ):
874+ for i in range (dl_tensor.ndim):
875+ py_shape.append(dl_tensor.shape[i])
876+ if (dl_tensor.strides is not NULL ):
897877 py_strides = list ()
898878 for i in range (dl_tensor.ndim):
899879 py_strides.append(dl_tensor.strides[i])
880+ else :
881+ py_strides = None
900882 if (dl_tensor.dtype.code == kDLUInt):
901883 ary_dt = np.dtype(" u" + str (element_bytesize))
902884 elif (dl_tensor.dtype.code == kDLInt):
0 commit comments