Skip to content

Commit

Permalink
Merge 526e021 into e9edfdc
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanIsCoding committed Aug 9, 2022
2 parents e9edfdc + 526e021 commit d7cbbbc
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 0 deletions.
65 changes: 65 additions & 0 deletions src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ use std::hash::Hasher;
use num_bigint::BigUint;
use rustworkx_core::dictmap::*;

use ndarray::prelude::*;
use numpy::{IntoPyArray, PyArrayDescr};
use pyo3::class::iter::IterNextOutput;
use pyo3::exceptions::{PyIndexError, PyKeyError, PyNotImplementedError};
use pyo3::gc::PyVisit;
Expand Down Expand Up @@ -413,6 +415,63 @@ enum SliceOrInt<'a> {
Int(isize),
}

trait PyConvertToPyArray {
fn convert_to_pyarray(&self, py: Python) -> PyResult<PyObject>;
}

macro_rules! py_convert_to_py_array_impl {
($($t:ty)*) => ($(
impl PyConvertToPyArray for Vec<$t> {
fn convert_to_pyarray(&self, py: Python) -> PyResult<PyObject> {
Ok(self.clone().into_pyarray(py).into())
}
}
)*)
}

macro_rules! py_convert_to_py_array_obj_impl {
($t:ty) => {
impl PyConvertToPyArray for Vec<$t> {
fn convert_to_pyarray(&self, py: Python) -> PyResult<PyObject> {
let pyobj_vec: Vec<PyObject> = self.iter().map(|x| x.clone().into_py(py)).collect();
Ok(pyobj_vec.into_pyarray(py).into())
}
}
};
}

py_convert_to_py_array_impl! {usize u8 u16 u32 u64 isize i8 i16 i32 i64 f32 f64}

py_convert_to_py_array_obj_impl! {EdgeList}
py_convert_to_py_array_obj_impl! {(PyObject, Vec<PyObject>)}

impl PyConvertToPyArray for Vec<(usize, usize)> {
fn convert_to_pyarray(&self, py: Python) -> PyResult<PyObject> {
let mut mat = Array2::<usize>::from_elem((self.len(), 2), 0);

for (index, element) in self.iter().enumerate() {
mat[[index, 0]] = element.0;
mat[[index, 1]] = element.1;
}

Ok(mat.into_pyarray(py).into())
}
}

impl PyConvertToPyArray for Vec<(usize, usize, PyObject)> {
fn convert_to_pyarray(&self, py: Python) -> PyResult<PyObject> {
let mut mat = Array2::<PyObject>::from_elem((self.len(), 3), py.None());

for (index, element) in self.iter().enumerate() {
mat[[index, 0]] = element.0.into_py(py);
mat[[index, 1]] = element.1.into_py(py);
mat[[index, 2]] = element.2.clone();
}

Ok(mat.into_pyarray(py).into())
}
}

macro_rules! custom_vec_iter_impl {
($name:ident, $data:ident, $T:ty, $doc:literal) => {
#[doc = $doc]
Expand Down Expand Up @@ -521,6 +580,12 @@ macro_rules! custom_vec_iter_impl {
}
}

fn __array__(&self, py: Python, _dt: Option<&PyArrayDescr>) -> PyResult<PyObject> {
// Note: we accept the dtype argument on the signature but
// effictively do nothing with it to let Numpy handle the conversion itself
self.$data.convert_to_pyarray(py)
}

fn __traverse__(&self, vis: PyVisit) -> Result<(), PyTraverseError> {
PyGCProtocol::__traverse__(self, vis)
}
Expand Down
23 changes: 23 additions & 0 deletions tests/retworkx_backwards_compat/test_custom_return_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import unittest

import retworkx
import numpy as np


class TestBFSSuccessorsComparisons(unittest.TestCase):
Expand Down Expand Up @@ -173,6 +174,10 @@ def test_slices_negatives(self):
self.assertEqual([2, 3], slice_return)
self.assertEqual([], indices[-1:-2])

def test_numpy_conversion(self):
res = self.dag.node_indexes()
np.testing.assert_array_equal(np.asarray(res, dtype=np.uintp), np.array([0, 1]))


class TestNodesCountMapping(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -405,6 +410,15 @@ def test_slice(self):
slice_return = edges[0:3:2]
self.assertEqual([(0, 1), (0, 1)], slice_return)

@staticmethod
def test_numpy_conversion():
g = retworkx.generators.directed_star_graph(5)
res = g.edge_list()

np.testing.assert_array_equal(
np.asarray(res, dtype=np.uintp), np.array([[0, 1], [0, 2], [0, 3], [0, 4]])
)


class TestWeightedEdgeListComparisons(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -475,6 +489,11 @@ def test_slice(self):
slice_return = edges[0:3:2]
self.assertEqual([(0, 1, "Edgy"), (0, 1, None)], slice_return)

def test_numpy_conversion(self):
np.testing.assert_array_equal(
np.asarray(self.dag.weighted_edge_list()), np.array([(0, 1, "Edgy")], dtype=object)
)


class TestPathMapping(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -1352,6 +1371,10 @@ def test_hash(self):
# Assert hash is stable
self.assertEqual(hash_res, hash(self.chains))

def test_numpy_conversion(self):
# this test assumes the array is 1-dimensional which avoids issues with jagged arrays
self.assertTrue(np.asarray(self.chains).shape, (1,))


class TestProductNodeMap(unittest.TestCase):
def setUp(self):
Expand Down
23 changes: 23 additions & 0 deletions tests/rustworkx_tests/test_custom_return_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import unittest

import rustworkx
import numpy as np


class TestBFSSuccessorsComparisons(unittest.TestCase):
Expand Down Expand Up @@ -173,6 +174,10 @@ def test_slices_negatives(self):
self.assertEqual([2, 3], slice_return)
self.assertEqual([], indices[-1:-2])

def test_numpy_conversion(self):
res = self.dag.node_indexes()
np.testing.assert_array_equal(np.asarray(res, dtype=np.uintp), np.array([0, 1]))


class TestNodesCountMapping(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -405,6 +410,15 @@ def test_slice(self):
slice_return = edges[0:3:2]
self.assertEqual([(0, 1), (0, 1)], slice_return)

@staticmethod
def test_numpy_conversion():
g = rustworkx.generators.directed_star_graph(5)
res = g.edge_list()

np.testing.assert_array_equal(
np.asarray(res, dtype=np.uintp), np.array([[0, 1], [0, 2], [0, 3], [0, 4]])
)


class TestWeightedEdgeListComparisons(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -475,6 +489,11 @@ def test_slice(self):
slice_return = edges[0:3:2]
self.assertEqual([(0, 1, "Edgy"), (0, 1, None)], slice_return)

def test_numpy_conversion(self):
np.testing.assert_array_equal(
np.asarray(self.dag.weighted_edge_list()), np.array([(0, 1, "Edgy")], dtype=object)
)


class TestPathMapping(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -1358,6 +1377,10 @@ def test_hash(self):
# Assert hash is stable
self.assertEqual(hash_res, hash(self.chains))

def test_numpy_conversion(self):
# this test assumes the array is 1-dimensional which avoids issues with jagged arrays
self.assertTrue(np.asarray(self.chains).shape, (1,))


class TestProductNodeMap(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit d7cbbbc

Please sign in to comment.