diff --git a/releasenotes/notes/with-capacity-6f35d43e256e268e.yaml b/releasenotes/notes/with-capacity-6f35d43e256e268e.yaml new file mode 100644 index 000000000..7b261bb4f --- /dev/null +++ b/releasenotes/notes/with-capacity-6f35d43e256e268e.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + :class:`.PyGraph` and :class:`.PyDiGraph` now each take keyword arguments ``initial_node_count`` + and ``initial_edge_count`` in their constructors, which can be used to pre-allocate space for + the given number of nodes and edges. diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index c1964c35f..c6fe85b7e 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -1245,6 +1245,7 @@ class PyGraph(Generic[_S, _T]): ) -> None: ... def __delitem__(self, idx: int, /) -> None: ... def __getitem__(self, idx: int, /) -> _S: ... + def __getnewargs_ex__(self) -> tuple[tuple[Any, ...], dict[str, Any]]: ... def __getstate__(self) -> Any: ... def __len__(self) -> int: ... def __setitem__(self, idx: int, value: _S, /) -> None: ... @@ -1261,6 +1262,10 @@ class PyDiGraph(Generic[_S, _T]): /, check_cycle: bool = ..., multigraph: bool = ..., + attrs: object = ..., + *, + initial_node_count: int = ..., + initial_edge_count: int = ..., ) -> None: ... def add_child(self, parent: int, obj: _S, edge: _T, /) -> int: ... def add_edge(self, parent: int, child: int, edge: _T, /) -> int: ... @@ -1440,6 +1445,7 @@ class PyDiGraph(Generic[_S, _T]): def reverse(self) -> None: ... def __delitem__(self, idx: int, /) -> None: ... def __getitem__(self, idx: int, /) -> _S: ... + def __getnewargs_ex__(self) -> tuple[tuple[Any, ...], dict[str, Any]]: ... def __getstate__(self) -> Any: ... def __len__(self) -> int: ... def __setitem__(self, idx: int, value: _S, /) -> None: ... 
diff --git a/src/digraph.rs b/src/digraph.rs index c31638188..7baef0268 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -31,7 +31,7 @@ use smallvec::SmallVec; use pyo3::exceptions::PyIndexError; use pyo3::gc::PyVisit; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyDict, PyList, PyString, PyTuple}; +use pyo3::types::{IntoPyDict, PyBool, PyDict, PyList, PyString, PyTuple}; use pyo3::PyTraverseError; use pyo3::Python; @@ -175,6 +175,14 @@ use super::dag_algo::is_directed_acyclic_graph; /// :param attrs: An optional attributes payload to assign to the /// :attr:`~.PyDiGraph.attrs` attribute. This can be any Python object. If /// it is not specified :attr:`~.PyDiGraph.attrs` will be set to ``None``. +/// :param int initial_node_count: The graph will be allocated with enough capacity to store this +/// many nodes before needing to grow. This does not prepopulate any nodes with data; it is +/// only a potential performance optimization if the complete size of the graph is known in +/// advance (default 0: no overallocation). +/// :param int initial_edge_count: The graph will be allocated with enough capacity to store this +/// many edges before needing to grow. This does not prepopulate any edges with data; it is +/// only a potential performance optimization if the complete size of the graph is known in +/// advance (default 0: no overallocation). 
#[pyclass(mapping, module = "rustworkx", subclass)] #[derive(Clone)] pub struct PyDiGraph { @@ -287,10 +295,17 @@ impl PyDiGraph { #[pymethods] impl PyDiGraph { #[new] - #[pyo3(signature=(check_cycle=false, multigraph=true, attrs=None), text_signature="(/, check_cycle=False, multigraph=True, attrs=None)")] - fn new(py: Python, check_cycle: bool, multigraph: bool, attrs: Option) -> Self { + #[pyo3(signature=(/, check_cycle=false, multigraph=true, attrs=None, *, initial_node_count=0, initial_edge_count=0))] + fn new( + py: Python, + check_cycle: bool, + multigraph: bool, + attrs: Option, + initial_node_count: usize, + initial_edge_count: usize, + ) -> Self { PyDiGraph { - graph: StablePyGraph::::new(), + graph: StablePyGraph::::with_capacity(initial_node_count, initial_edge_count), cycle_state: algo::DfsSpace::default(), check_cycle, node_removed: false, @@ -299,8 +314,19 @@ impl PyDiGraph { } } + fn __getnewargs_ex__<'py>(&self, py: Python<'py>) -> (Py, Bound<'py, PyDict>) { + ( + (self.check_cycle, self.multigraph, self.attrs.clone_ref(py)).into_py(py), + [ + ("initial_node_count", self.graph.node_bound()), + ("initial_edge_count", self.graph.edge_bound()), + ] + .into_py_dict_bound(py), + ) + } + fn __getstate__(&self, py: Python) -> PyResult { - let mut nodes: Vec = Vec::with_capacity(self.graph.node_count()); + let mut nodes: Vec = Vec::with_capacity(self.graph.node_bound()); let mut edges: Vec = Vec::with_capacity(self.graph.edge_bound()); // save nodes to a list along with its index @@ -327,9 +353,6 @@ impl PyDiGraph { out_dict.set_item("nodes", nodes_lst)?; out_dict.set_item("edges", edges_lst)?; out_dict.set_item("nodes_removed", self.node_removed)?; - out_dict.set_item("multigraph", self.multigraph)?; - out_dict.set_item("attrs", self.attrs.clone_ref(py))?; - out_dict.set_item("check_cycle", self.check_cycle)?; Ok(out_dict.into()) } @@ -341,26 +364,11 @@ impl PyDiGraph { let edges_lst = binding.downcast::()?; self.graph = StablePyGraph::::new(); let 
dict_state = state.downcast_bound::(py)?; - self.multigraph = dict_state - .get_item("multigraph")? - .unwrap() - .downcast::()? - .extract()?; self.node_removed = dict_state .get_item("nodes_removed")? .unwrap() .downcast::()? .extract()?; - let attrs = match dict_state.get_item("attrs")? { - Some(attr) => attr.into(), - None => py.None(), - }; - self.attrs = attrs; - self.check_cycle = dict_state - .get_item("check_cycle")? - .unwrap() - .downcast::()? - .extract()?; // graph is empty, stop early if nodes_lst.is_empty() { diff --git a/src/graph.rs b/src/graph.rs index 06ba8cc4a..981a4a576 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -26,7 +26,7 @@ use rustworkx_core::graph_ext::*; use pyo3::exceptions::PyIndexError; use pyo3::gc::PyVisit; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyDict, PyList, PyString, PyTuple}; +use pyo3::types::{IntoPyDict, PyBool, PyDict, PyList, PyString, PyTuple}; use pyo3::PyTraverseError; use pyo3::Python; @@ -138,6 +138,14 @@ use petgraph::visit::{ /// :param attrs: An optional attributes payload to assign to the /// :attr:`~.PyGraph.attrs` attribute. This can be any Python object. If /// it is not specified :attr:`~.PyGraph.attrs` will be set to ``None``. +/// :param int initial_node_count: The graph will be allocated with enough capacity to store this +/// many nodes before needing to grow. This does not prepopulate any nodes with data; it is +/// only a potential performance optimization if the complete size of the graph is known in +/// advance (default 0: no overallocation). +/// :param int initial_edge_count: The graph will be allocated with enough capacity to store this +/// many edges before needing to grow. This does not prepopulate any edges with data; it is +/// only a potential performance optimization if the complete size of the graph is known in +/// advance (default 0: no overallocation). 
#[pyclass(mapping, module = "rustworkx", subclass)] #[derive(Clone)] pub struct PyGraph { @@ -183,18 +191,38 @@ impl PyGraph { #[pymethods] impl PyGraph { #[new] - #[pyo3(signature=(multigraph=true, attrs=None), text_signature = "(/, multigraph=True, attrs=None)")] - fn new(py: Python, multigraph: bool, attrs: Option) -> Self { + #[pyo3(signature=(multigraph=true, attrs=None, *, initial_node_count=0, initial_edge_count=0))] + fn new( + py: Python, + multigraph: bool, + attrs: Option, + initial_node_count: usize, + initial_edge_count: usize, + ) -> Self { PyGraph { - graph: StablePyGraph::::default(), + graph: StablePyGraph::::with_capacity( + initial_node_count, + initial_edge_count, + ), node_removed: false, multigraph, attrs: attrs.unwrap_or_else(|| py.None()), } } + fn __getnewargs_ex__<'py>(&self, py: Python<'py>) -> (Py, Bound<'py, PyDict>) { + ( + (self.multigraph, self.attrs.clone_ref(py)).into_py(py), + [ + ("initial_node_count", self.graph.node_bound()), + ("initial_edge_count", self.graph.edge_bound()), + ] + .into_py_dict_bound(py), + ) + } + fn __getstate__(&self, py: Python) -> PyResult { - let mut nodes: Vec = Vec::with_capacity(self.graph.node_count()); + let mut nodes: Vec = Vec::with_capacity(self.graph.node_bound()); let mut edges: Vec = Vec::with_capacity(self.graph.edge_bound()); // save nodes to a list along with its index @@ -222,8 +250,6 @@ impl PyGraph { out_dict.set_item("nodes", nodes_lst)?; out_dict.set_item("edges", edges_lst)?; out_dict.set_item("nodes_removed", self.node_removed)?; - out_dict.set_item("multigraph", self.multigraph)?; - out_dict.set_item("attrs", self.attrs.clone_ref(py))?; Ok(out_dict.into()) } @@ -234,21 +260,11 @@ impl PyGraph { let binding = dict_state.get_item("edges")?.unwrap(); let edges_lst = binding.downcast::()?; - self.graph = StablePyGraph::::default(); - self.multigraph = dict_state - .get_item("multigraph")? - .unwrap() - .downcast::()? 
- .extract()?; self.node_removed = dict_state .get_item("nodes_removed")? .unwrap() .downcast::()? .extract()?; - self.attrs = match dict_state.get_item("attrs")? { - Some(attr) => attr.into(), - None => py.None(), - }; // graph is empty, stop early if nodes_lst.is_empty() { return Ok(()); diff --git a/tests/digraph/test_pickle.py b/tests/digraph/test_pickle.py index 306fd119c..2912a5b96 100644 --- a/tests/digraph/test_pickle.py +++ b/tests/digraph/test_pickle.py @@ -30,7 +30,7 @@ def test_noweight_graph(self): self.assertEqual({1: (1, 2, None), 3: (3, 1, None)}, dict(gprime.edge_index_map())) def test_weight_graph(self): - g = rx.PyDAG() + g = rx.PyDAG(initial_node_count=4, initial_edge_count=4) g.add_nodes_from(["A", "B", "C", "D"]) g.add_edges_from([(0, 1, "A -> B"), (1, 2, "B -> C"), (3, 0, "D -> A"), (3, 1, "D -> B")]) g.remove_node(0) diff --git a/tests/graph/test_pickle.py b/tests/graph/test_pickle.py index 44220f113..b4d8c4cf3 100644 --- a/tests/graph/test_pickle.py +++ b/tests/graph/test_pickle.py @@ -30,7 +30,7 @@ def test_noweight_graph(self): self.assertEqual({1: (1, 2, None), 3: (3, 1, None)}, dict(gprime.edge_index_map())) def test_weight_graph(self): - g = rx.PyGraph() + g = rx.PyGraph(initial_node_count=4, initial_edge_count=4) g.add_nodes_from(["A", "B", "C", "D"]) g.add_edges_from([(0, 1, "A -> B"), (1, 2, "B -> C"), (3, 0, "D -> A"), (3, 1, "D -> B")]) g.remove_node(0)