Skip to content

Commit 043481b

Browse files
committed
Change DLPack logic so that shape and strides are NULL if and only if ndim == 0
1 parent 68c2167 commit 043481b

File tree

1 file changed

+56
-74
lines changed

1 file changed

+56
-74
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 56 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ from .._backend cimport (
3636
DPCTLSyclDeviceRef,
3737
DPCTLSyclUSMRef,
3838
)
39-
from ._usmarray cimport USM_ARRAY_C_CONTIGUOUS, USM_ARRAY_WRITABLE, usm_ndarray
39+
from ._usmarray cimport USM_ARRAY_WRITABLE, usm_ndarray
4040

4141
import ctypes
4242

@@ -266,7 +266,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
266266
cdef int64_t *shape_strides_ptr = NULL
267267
cdef int i = 0
268268
cdef int device_id = -1
269-
cdef int flags = 0
270269
cdef Py_ssize_t element_offset = 0
271270
cdef Py_ssize_t byte_offset = 0
272271
cdef Py_ssize_t si = 1
@@ -281,22 +280,21 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
281280
raise MemoryError(
282281
"to_dlpack_capsule: Could not allocate memory for DLManagedTensor"
283282
)
284-
shape_strides_ptr = <int64_t *>stdlib.malloc((sizeof(int64_t) * 2) * nd)
285-
if shape_strides_ptr is NULL:
286-
stdlib.free(dlm_tensor)
287-
raise MemoryError(
288-
"to_dlpack_capsule: Could not allocate memory for shape/strides"
289-
)
290-
shape_ptr = usm_ary.get_shape()
291-
for i in range(nd):
292-
shape_strides_ptr[i] = shape_ptr[i]
293-
strides_ptr = usm_ary.get_strides()
294-
flags = usm_ary.flags_
295-
if strides_ptr:
283+
if nd > 0:
284+
shape_strides_ptr = <int64_t *>stdlib.malloc((sizeof(int64_t) * 2) * nd)
285+
if shape_strides_ptr is NULL:
286+
stdlib.free(dlm_tensor)
287+
raise MemoryError(
288+
"to_dlpack_capsule: Could not allocate memory for shape/strides"
289+
)
290+
shape_ptr = usm_ary.get_shape()
296291
for i in range(nd):
297-
shape_strides_ptr[nd + i] = strides_ptr[i]
298-
else:
299-
if not (flags & USM_ARRAY_C_CONTIGUOUS):
292+
shape_strides_ptr[i] = shape_ptr[i]
293+
strides_ptr = usm_ary.get_strides()
294+
if strides_ptr:
295+
for i in range(nd):
296+
shape_strides_ptr[nd + i] = strides_ptr[i]
297+
else:
300298
si = 1
301299
for i in range(0, nd):
302300
shape_strides_ptr[nd + i] = si
@@ -312,11 +310,8 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
312310
dl_tensor.data = <void*>(data_ptr - byte_offset)
313311
dl_tensor.ndim = nd
314312
dl_tensor.byte_offset = <uint64_t>byte_offset
315-
dl_tensor.shape = &shape_strides_ptr[0]
316-
if strides_ptr is NULL:
317-
dl_tensor.strides = NULL
318-
else:
319-
dl_tensor.strides = &shape_strides_ptr[nd]
313+
dl_tensor.shape = &shape_strides_ptr[0] if nd > 0 else NULL
314+
dl_tensor.strides = &shape_strides_ptr[nd] if nd > 0 else NULL
320315
dl_tensor.device.device_type = kDLOneAPI
321316
dl_tensor.device.device_id = device_id
322317
dl_tensor.dtype.lanes = <uint16_t>1
@@ -396,24 +391,24 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
396391
"to_dlpack_versioned_capsule: Could not allocate memory "
397392
"for DLManagedTensorVersioned"
398393
)
399-
shape_strides_ptr = <int64_t *>stdlib.malloc((sizeof(int64_t) * 2) * nd)
400-
if shape_strides_ptr is NULL:
401-
stdlib.free(dlmv_tensor)
402-
raise MemoryError(
403-
"to_dlpack_versioned_capsule: Could not allocate memory "
404-
"for shape/strides"
405-
)
406-
# this can be a separate function for handling shapes and strides
407-
shape_ptr = usm_ary.get_shape()
408-
for i in range(nd):
409-
shape_strides_ptr[i] = shape_ptr[i]
410-
strides_ptr = usm_ary.get_strides()
411-
flags = usm_ary.flags_
412-
if strides_ptr:
394+
if nd > 0:
395+
shape_strides_ptr = <int64_t *>stdlib.malloc((sizeof(int64_t) * 2) * nd)
396+
if shape_strides_ptr is NULL:
397+
stdlib.free(dlmv_tensor)
398+
raise MemoryError(
399+
"to_dlpack_versioned_capsule: Could not allocate memory "
400+
"for shape/strides"
401+
)
402+
# this can be a separate function for handling shapes and strides
403+
shape_ptr = usm_ary.get_shape()
413404
for i in range(nd):
414-
shape_strides_ptr[nd + i] = strides_ptr[i]
415-
else:
416-
if not (flags & USM_ARRAY_C_CONTIGUOUS):
405+
shape_strides_ptr[i] = shape_ptr[i]
406+
strides_ptr = usm_ary.get_strides()
407+
flags = usm_ary.flags_
408+
if strides_ptr:
409+
for i in range(nd):
410+
shape_strides_ptr[nd + i] = strides_ptr[i]
411+
else:
417412
si = 1
418413
for i in range(0, nd):
419414
shape_strides_ptr[nd + i] = si
@@ -431,11 +426,8 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
431426
dl_tensor.data = <void*>(data_ptr - byte_offset)
432427
dl_tensor.ndim = nd
433428
dl_tensor.byte_offset = <uint64_t>byte_offset
434-
dl_tensor.shape = &shape_strides_ptr[0]
435-
if strides_ptr is NULL:
436-
dl_tensor.strides = NULL
437-
else:
438-
dl_tensor.strides = &shape_strides_ptr[nd]
429+
dl_tensor.shape = &shape_strides_ptr[0] if nd > 0 else NULL
430+
dl_tensor.strides = &shape_strides_ptr[nd] if nd > 0 else NULL
439431
dl_tensor.device.device_type = kDLOneAPI
440432
dl_tensor.device.device_id = device_id
441433
dl_tensor.dtype.lanes = <uint16_t>1
@@ -515,10 +507,9 @@ cpdef numpy_to_dlpack_versioned_capsule(ndarray npy_ary, bint copied):
515507
"for DLManagedTensorVersioned"
516508
)
517509

518-
is_c_contiguous = npy_ary.flags["C"]
519510
shape = npy_ary.ctypes.shape_as(ctypes.c_int64)
520511
strides = npy_ary.ctypes.strides_as(ctypes.c_int64)
521-
if not is_c_contiguous:
512+
if nd > 0:
522513
if npy_ary.size != 1:
523514
for i in range(nd):
524515
if shape[i] != 1 and strides[i] % itemsize != 0:
@@ -529,18 +520,14 @@ cpdef numpy_to_dlpack_versioned_capsule(ndarray npy_ary, bint copied):
529520
"itemsize"
530521
)
531522
shape_strides_ptr = <int64_t *>stdlib.malloc((sizeof(int64_t) * 2) * nd)
532-
else:
533-
# no need to pass strides in this case
534-
shape_strides_ptr = <int64_t *>stdlib.malloc(sizeof(int64_t) * nd)
535-
if shape_strides_ptr is NULL:
536-
stdlib.free(dlmv_tensor)
537-
raise MemoryError(
538-
"numpy_to_dlpack_versioned_capsule: Could not allocate memory "
539-
"for shape/strides"
540-
)
541-
for i in range(nd):
542-
shape_strides_ptr[i] = shape[i]
543-
if not is_c_contiguous:
523+
if shape_strides_ptr is NULL:
524+
stdlib.free(dlmv_tensor)
525+
raise MemoryError(
526+
"numpy_to_dlpack_versioned_capsule: Could not allocate memory "
527+
"for shape/strides"
528+
)
529+
for i in range(nd):
530+
shape_strides_ptr[i] = shape[i]
544531
shape_strides_ptr[nd + i] = strides[i] // itemsize
545532

546533
writable_flag = npy_ary.flags["W"]
@@ -552,11 +539,8 @@ cpdef numpy_to_dlpack_versioned_capsule(ndarray npy_ary, bint copied):
552539
dl_tensor.data = <void *> npy_ary.data
553540
dl_tensor.ndim = nd
554541
dl_tensor.byte_offset = <uint64_t>byte_offset
555-
dl_tensor.shape = &shape_strides_ptr[0]
556-
if is_c_contiguous:
557-
dl_tensor.strides = NULL
558-
else:
559-
dl_tensor.strides = &shape_strides_ptr[nd]
542+
dl_tensor.shape = &shape_strides_ptr[0] if nd > 0 else NULL
543+
dl_tensor.strides = &shape_strides_ptr[nd] if nd > 0 else NULL
560544
dl_tensor.device.device_type = kDLCPU
561545
dl_tensor.device.device_id = 0
562546
dl_tensor.dtype.lanes = <uint16_t>1
@@ -828,12 +812,8 @@ cpdef object from_dlpack_capsule(object py_caps):
828812
raise BufferError(
829813
"Can not import DLPack tensor with lanes != 1"
830814
)
831-
offset_min = 0
832-
if dl_tensor.strides is NULL:
833-
for i in range(dl_tensor.ndim):
834-
sz = sz * dl_tensor.shape[i]
835-
offset_max = sz - 1
836-
else:
815+
if dl_tensor.ndim > 0:
816+
offset_min = 0
837817
offset_max = 0
838818
for i in range(dl_tensor.ndim):
839819
stride_i = dl_tensor.strides[i]
@@ -888,15 +868,17 @@ cpdef object from_dlpack_capsule(object py_caps):
888868
(<c_dpctl.SyclQueue>q).get_queue_ref(),
889869
memory_owner=tmp
890870
)
871+
891872
py_shape = list()
892-
for i in range(dl_tensor.ndim):
893-
py_shape.append(dl_tensor.shape[i])
894-
if (dl_tensor.strides is NULL):
895-
py_strides = None
896-
else:
873+
if (dl_tensor.shape is not NULL):
874+
for i in range(dl_tensor.ndim):
875+
py_shape.append(dl_tensor.shape[i])
876+
if (dl_tensor.strides is not NULL):
897877
py_strides = list()
898878
for i in range(dl_tensor.ndim):
899879
py_strides.append(dl_tensor.strides[i])
880+
else:
881+
py_strides = None
900882
if (dl_tensor.dtype.code == kDLUInt):
901883
ary_dt = np.dtype("u" + str(element_bytesize))
902884
elif (dl_tensor.dtype.code == kDLInt):

0 commit comments

Comments
 (0)