diff --git a/arrayfire/array.py b/arrayfire/array.py index 57f557485..d9674dbc6 100644 --- a/arrayfire/array.py +++ b/arrayfire/array.py @@ -54,6 +54,9 @@ def _create_strided_array(buf, numdims, idims, dtype, is_device, offset, strides def _create_empty_array(numdims, idims, dtype): out_arr = c_void_ptr_t(0) + + if numdims == 0: return out_arr + c_dims = dim4(idims[0], idims[1], idims[2], idims[3]) safe_call(backend.get().af_create_handle(c_pointer(out_arr), numdims, c_pointer(c_dims), dtype.value)) @@ -382,7 +385,7 @@ class Array(BaseArray): # arrayfire's __radd__() instead of numpy's __add__() __array_priority__ = 30 - def __init__(self, src=None, dims=(0,), dtype=None, is_device=False, offset=None, strides=None): + def __init__(self, src=None, dims=None, dtype=None, is_device=False, offset=None, strides=None): super(Array, self).__init__() @@ -449,10 +452,12 @@ def __init__(self, src=None, dims=(0,), dtype=None, is_device=False, offset=None if type_char is None: type_char = 'f' - numdims = len(dims) + numdims = len(dims) if dims else 0 + idims = [1] * 4 for n in range(numdims): idims[n] = dims[n] + self.arr = _create_empty_array(numdims, idims, to_dtype[type_char]) def as_type(self, ty):