Skip to content

Commit

Permalink
properly handle the case of copy=False
Browse files: browse the repository at this point in the history
  • Loading branch information
jorisvandenbossche committed Apr 9, 2024
1 parent 2f1b8cc commit 686865e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
14 changes: 13 additions & 1 deletion python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1544,7 +1544,19 @@ cdef class Array(_PandasConvertible):
return _array_like_to_pandas(self, options, types_mapper=types_mapper)

def __array__(self, dtype=None, copy=None):
# TODO honor the copy keyword
# TODO honor the copy=True case
if copy is False:
try:
values = self.to_numpy(zero_copy_only=True)
except ArrowInvalid as exc:
raise ArrowInvalid(
"Unable to avoid a copy while creating a numpy array as requested.\n"
"If using `np.array(obj, copy=False)` replace it with "
"`np.asarray(obj)` to allow a copy when needed"
)
# values is already a numpy array at this point, but we call np.array(..)
# again to handle the `dtype` keyword with a no-copy guarantee
return np.array(values, dtype=dtype, copy=False)
values = self.to_numpy(zero_copy_only=False)
if dtype is None:
return values
Expand Down
28 changes: 27 additions & 1 deletion python/pyarrow/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import pyarrow as pa
import pyarrow.tests.strategies as past
from pyarrow.vendored.version import Version


def test_total_bytes_allocated():
Expand Down Expand Up @@ -3309,9 +3310,34 @@ def test_numpy_array_protocol():
np.testing.assert_array_equal(result, expected)

# this should not raise a deprecation warning with numpy 2.0+
result = np.asarray(arr, copy=False)
result = np.array(arr, copy=False)
np.testing.assert_array_equal(result, expected)

result = np.array(arr, dtype="int64", copy=False)
np.testing.assert_array_equal(result, expected)

# no zero-copy is possible
arr = pa.array([1, 2, None])
expected = np.array([1, 2, np.nan], dtype="float64")
result = np.asarray(arr)
np.testing.assert_array_equal(result, expected)

if Version(np.__version__) < Version("2.0"):
# copy keyword is not strict and not passed down to __array__
result = np.array(arr, copy=False)
np.testing.assert_array_equal(result, expected)

result = np.array(arr, dtype="float64", copy=False)
np.testing.assert_array_equal(result, expected)
else:
# starting with numpy 2.0, the copy=False keyword is assumed to be strict
with pytest.raises(ValueError, match="Unable to avoid a copy"):
np.array(arr, copy=False)

arr = pa.array([1, 2, 3])
with pytest.raises(ValueError):
np.array(arr, dtype="float64", copy=False)


def test_array_protocol():

Expand Down

0 comments on commit 686865e

Please sign in to comment.