diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 97f925fd1..4fa06758d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -25,7 +25,7 @@ jobs: - uses: actions/setup-python@v5 with: python-version: 3.8 - - run: pip install -U ruff black~=22.0 + - run: pip install -U ruff==0.4.1 black~=22.0 - uses: dtolnay/rust-toolchain@stable with: components: rustfmt diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index 991631460..2edcdde67 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -993,7 +993,7 @@ class _RustworkxCustomVecIter(Generic[_T_co], Sequence[_T_co], ABC): def __len__(self) -> int: ... def __ne__(self, other: object) -> bool: ... def __setstate__(self, state: Sequence[_T_co]) -> None: ... - def __array__(self, _dt: np.dtype | None = ...) -> np.ndarray: ... + def __array__(self, dtype: np.dtype | None = ..., copy: bool | None = ...) -> np.ndarray: ... def __iter__(self) -> Iterator[_T_co]: ... def __reversed__(self) -> Iterator[_T_co]: ... diff --git a/setup.py b/setup.py index 82a390d7e..40bf25ca8 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ def readme(): PKG_NAME = os.getenv("RUSTWORKX_PKG_NAME", "rustworkx") PKG_VERSION = "0.15.0" PKG_PACKAGES = ["rustworkx", "rustworkx.visualization"] -PKG_INSTALL_REQUIRES = ["numpy>=1.16.0,<2"] +PKG_INSTALL_REQUIRES = ["numpy>=1.16.0,<3"] RUST_EXTENSIONS = [RustExtension("rustworkx.rustworkx", "Cargo.toml", binding=Binding.PyO3, debug=rustworkx_debug)] RUST_OPTS ={"bdist_wheel": {"py_limited_api": "cp38"}} diff --git a/src/iterators.rs b/src/iterators.rs index 5ed425342..f766dfc3b 100644 --- a/src/iterators.rs +++ b/src/iterators.rs @@ -46,10 +46,11 @@ use num_bigint::BigUint; use rustworkx_core::dictmap::*; use ndarray::prelude::*; -use numpy::{IntoPyArray, PyArrayDescr}; -use pyo3::exceptions::{PyIndexError, PyKeyError, PyNotImplementedError}; +use numpy::IntoPyArray; +use pyo3::exceptions::{PyIndexError, PyKeyError, PyNotImplementedError, PyValueError}; use pyo3::gc::PyVisit; use pyo3::prelude::*; +use pyo3::types::IntoPyDict; use pyo3::types::PySlice; use pyo3::PyTraverseError; @@ -601,11 +602,26 @@ macro_rules! custom_vec_iter_impl { fn __array__( &self, py: Python, - _dt: Option<&Bound>, + dtype: Option, + copy: Option, ) -> PyResult { - // Note: we accept the dtype argument on the signature but - // effictively do nothing with it to let Numpy handle the conversion itself - self.$data.convert_to_pyarray(py) + if copy == Some(false) { + return Err(PyValueError::new_err( + "A copy is needed to return an array from this object.", + )); + } + let res = self.$data.convert_to_pyarray(py)?; + Ok(match dtype { + Some(dtype) => { + let numpy_mod = py.import_bound("numpy")?; + let args = (res,); + let kwargs = [("dtype", dtype)].into_py_dict_bound(py); + numpy_mod + .call_method("asarray", args, Some(&kwargs))? + .into() + } + None => res, + }) } fn __traverse__(&self, vis: PyVisit) -> Result<(), PyTraverseError> { diff --git a/tests/test_custom_return_types.py b/tests/test_custom_return_types.py index 504a9f734..725cf73ed 100644 --- a/tests/test_custom_return_types.py +++ b/tests/test_custom_return_types.py @@ -198,6 +198,16 @@ def test_numpy_conversion(self): res = self.dag.node_indexes() np.testing.assert_array_equal(np.asarray(res, dtype=np.uintp), np.array([0, 1])) + def test_numpy_conversion_copy_false(self): + res = self.dag.node_indices() + with self.assertRaises(ValueError): + res.__array__(copy=False) + + def test_numpy_conversion_dtype_complex(self): + res = self.dag.node_indices() + array = res.__array__(dtype=complex) + self.assertEqual(np.dtype(complex), array.dtype) + class TestNodesCountMapping(unittest.TestCase): def setUp(self):