From 6690475c5f027406887df0c2e93bd7551dea5c2a Mon Sep 17 00:00:00 2001 From: Julio Quintero Date: Mon, 4 Mar 2024 13:15:51 -0500 Subject: [PATCH] Fix type hint for search functions --- .../fix-type-search-6abecbc064e1dbf0.yaml | 7 + rustworkx-core/src/lib.rs | 2 + rustworkx-core/src/steiner_tree.rs | 572 ++++++++++++++++++ rustworkx/__init__.pyi | 10 +- rustworkx/rustworkx.pyi | 12 +- src/steiner_tree.rs | 288 ++------- tests/graph/test_steiner_tree.py | 2 +- 7 files changed, 660 insertions(+), 233 deletions(-) create mode 100644 releasenotes/notes/fix-type-search-6abecbc064e1dbf0.yaml create mode 100644 rustworkx-core/src/steiner_tree.rs diff --git a/releasenotes/notes/fix-type-search-6abecbc064e1dbf0.yaml b/releasenotes/notes/fix-type-search-6abecbc064e1dbf0.yaml new file mode 100644 index 000000000..9341bfdb1 --- /dev/null +++ b/releasenotes/notes/fix-type-search-6abecbc064e1dbf0.yaml @@ -0,0 +1,7 @@ +--- +fixes: + - | + Fixed the bug type hint for the :func:`~rustworkx.bfs_search`, + :func:`~rustworkx.dfs_search` and :func:`~rustworkx.dijkstra_search`. + Refer to `#1130 `__ for + more information. diff --git a/rustworkx-core/src/lib.rs b/rustworkx-core/src/lib.rs index a432d891a..e5d38eb58 100644 --- a/rustworkx-core/src/lib.rs +++ b/rustworkx-core/src/lib.rs @@ -92,6 +92,8 @@ mod min_scored; pub mod token_swapper; pub mod utils; +pub mod steiner_tree; + // re-export petgraph so there is a consistent version available to users and // then only need to require rustworkx-core in their dependencies pub use petgraph; diff --git a/rustworkx-core/src/steiner_tree.rs b/rustworkx-core/src/steiner_tree.rs new file mode 100644 index 000000000..bca9e40a6 --- /dev/null +++ b/rustworkx-core/src/steiner_tree.rs @@ -0,0 +1,572 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use std::cmp::{Eq, Ordering}; +use std::convert::Infallible; +use std::hash::Hash; + +use hashbrown::{HashMap, HashSet}; +use rayon::prelude::*; + +use petgraph::stable_graph::{EdgeIndex, NodeIndex, StableGraph}; +use petgraph::unionfind::UnionFind; +use petgraph::visit::{ + EdgeCount, EdgeIndexable, EdgeRef, GraphProp, IntoEdgeReferences, IntoEdges, + IntoNodeIdentifiers, IntoNodeReferences, NodeCount, NodeIndexable, NodeRef, Visitable, +}; +use petgraph::Undirected; + +use crate::dictmap::*; +use crate::shortest_path::dijkstra; +use crate::utils::pairwise; + +type AllPairsDijkstraReturn = HashMap>, DictMap)>; + +fn all_pairs_dijkstra_shortest_paths( + graph: G, + mut weight_fn: F, +) -> Result +where + G: NodeIndexable + + IntoNodeIdentifiers + + EdgeCount + + NodeCount + + EdgeIndexable + + Visitable + + Sync + + IntoEdges, + G::NodeId: Eq + Hash + Send, + G::EdgeId: Eq + Hash + Send, + F: FnMut(G::EdgeRef) -> Result, +{ + if graph.node_count() == 0 { + return Ok(HashMap::new()); + } else if graph.edge_count() == 0 { + return Ok(graph + .node_identifiers() + .map(|x| { + ( + NodeIndexable::to_index(&graph, x), + (DictMap::new(), DictMap::new()), + ) + }) + .collect()); + } + let mut edge_weights: Vec> = vec![None; graph.edge_bound()]; + for edge in graph.edge_references() { + let index = EdgeIndexable::to_index(&graph, edge.id()); + edge_weights[index] = Some(weight_fn(edge)?); + } + let edge_cost = |e: G::EdgeRef| -> Result { + Ok(edge_weights[EdgeIndexable::to_index(&graph, e.id())].unwrap()) + }; + + let node_indices: Vec = graph + .node_identifiers() + .map(|n| NodeIndexable::to_index(&graph, n)) + .collect(); + Ok(node_indices + .into_par_iter() + .map(|x| { + let mut paths: DictMap> = + DictMap::with_capacity(graph.node_count()); + let distances: DictMap = dijkstra( + graph, + NodeIndexable::from_index(&graph, x), + None, + edge_cost, + Some(&mut paths), + ) + .unwrap(); + ( + x, + ( + paths + .into_iter() + .map(|(k, v)| { + ( + NodeIndexable::to_index(&graph, k), + v.into_iter() + .map(|n| NodeIndexable::to_index(&graph, n)) + .collect(), + ) + }) + .collect(), + distances + .into_iter() + .map(|(k, v)| (NodeIndexable::to_index(&graph, k), v)) + .collect(), + ), + ) + }) + .collect()) +} + +struct MetricClosureEdge { + source: usize, + target: usize, + distance: f64, + path: Vec, +} + +/// Return the metric closure of a graph +/// +/// The metric closure of a graph is the complete graph in which each edge is +/// weighted by the shortest path distance between the nodes in the graph. +/// +/// Arguments: +/// `graph`: The input graph to compute the metric closure for +/// `weight_fn`: A callable weight function that will be passed an edge reference +/// for each edge in the graph and it is expected to return a `Result` +/// which if it doesn't error represents the weight of that edge. +/// `default_weight`: A blind callable that returns a default weight to use for +/// edges added to the output +/// +/// Returns a `StableGraph` with the input graph node ids for node weights and edge weights with a +/// tuple of the numeric weight (found via `weight_fn`) and the path. The output will be `None` +/// if `graph` is disconnected. +/// +/// # Example +/// ```rust +/// use std::convert::Infallible; +/// +/// use rustworkx_core::petgraph::Graph; +/// use rustworkx_core::petgraph::Undirected; +/// use rustworkx_core::petgraph::graph::EdgeReference; +/// use rustworkx_core::petgraph::visit::{IntoEdgeReferences, EdgeRef}; +/// +/// use rustworkx_core::steiner_tree::metric_closure; +/// +/// let input_graph = Graph::<(), u8, Undirected>::from_edges(&[ +/// (0, 1, 10), +/// (1, 2, 10), +/// (2, 3, 10), +/// (3, 4, 10), +/// (4, 5, 10), +/// (1, 6, 1), +/// (6, 4, 1), +/// ]); +/// +/// let weight_fn = |e: EdgeReference| -> Result { +/// Ok(*e.weight() as f64) +/// }; +/// +/// let closure = metric_closure(&input_graph, weight_fn).unwrap().unwrap(); +/// let mut output_edge_list: Vec<(usize, usize, (f64, Vec))> = closure.edge_references().map(|edge| (edge.source().index(), edge.target().index(), edge.weight().clone())).collect(); +/// let mut expected_edges: Vec<(usize, usize, (f64, Vec))> = vec![ +/// (0, 1, (10.0, vec![0, 1])), +/// (0, 2, (20.0, vec![0, 1, 2])), +/// (0, 3, (22.0, vec![0, 1, 6, 4, 3])), +/// (0, 4, (12.0, vec![0, 1, 6, 4])), +/// (0, 5, (22.0, vec![0, 1, 6, 4, 5])), +/// (0, 6, (11.0, vec![0, 1, 6])), +/// (1, 2, (10.0, vec![1, 2])), +/// (1, 3, (12.0, vec![1, 6, 4, 3])), +/// (1, 4, (2.0, vec![1, 6, 4])), +/// (1, 5, (12.0, vec![1, 6, 4, 5])), +/// (1, 6, (1.0, vec![1, 6])), +/// (2, 3, (10.0, vec![2, 3])), +/// (2, 4, (12.0, vec![2, 1, 6, 4])), +/// (2, 5, (22.0, vec![2, 1, 6, 4, 5])), +/// (2, 6, (11.0, vec![2, 1, 6])), +/// (3, 4, (10.0, vec![3, 4])), +/// (3, 5, (20.0, vec![3, 4, 5])), +/// (3, 6, (11.0, vec![3, 4, 6])), +/// (4, 5, (10.0, vec![4, 5])), +/// (4, 6, (1.0, vec![4, 6])), +/// (5, 6, (11.0, vec![5, 4, 6])), +/// ]; +/// output_edge_list.sort_by_key(|x| [x.0, x.1]); +/// expected_edges.sort_by_key(|x| [x.0, x.1]); +/// assert_eq!(output_edge_list, expected_edges); +/// +/// ``` +#[allow(clippy::type_complexity)] +pub fn metric_closure( + graph: G, + weight_fn: F, +) -> Result), Undirected>>, E> +where + G: NodeIndexable + + EdgeIndexable + + Sync + + EdgeCount + + NodeCount + + Visitable + + IntoNodeReferences + + IntoEdges + + Visitable + + GraphProp, + G::NodeId: Eq + Hash + NodeRef + Send, + G::EdgeId: Eq + Hash + Send, + G::NodeWeight: Clone, + F: FnMut(G::EdgeRef) -> Result, +{ + let mut out_graph: StableGraph), Undirected> = + StableGraph::with_capacity(graph.node_count(), graph.edge_count()); + let node_map: HashMap = graph + .node_references() + .map(|node| { + ( + NodeIndexable::to_index(&graph, node.id()), + out_graph.add_node(node.id()), + ) + }) + .collect(); + let edges = metric_closure_edges(graph, weight_fn)?; + if edges.is_none() { + return Ok(None); + } + for edge in edges.unwrap() { + out_graph.add_edge( + node_map[&edge.source], + node_map[&edge.target], + (edge.distance, edge.path), + ); + } + Ok(Some(out_graph)) +} + +fn metric_closure_edges( + graph: G, + weight_fn: F, +) -> Result>, E> +where + G: NodeIndexable + + Sync + + Visitable + + IntoNodeReferences + + IntoEdges + + Visitable + + NodeIndexable + + NodeCount + + EdgeCount + + EdgeIndexable, + G::NodeId: Eq + Hash + Send, + G::EdgeId: Eq + Hash + Send, + F: FnMut(G::EdgeRef) -> Result, +{ + let node_count = graph.node_count(); + if node_count == 0 { + return Ok(Some(Vec::new())); + } + let mut out_vec = Vec::with_capacity(node_count * (node_count - 1) / 2); + let paths = all_pairs_dijkstra_shortest_paths(graph, weight_fn)?; + let mut nodes: HashSet = graph + .node_identifiers() + .map(|x| NodeIndexable::to_index(&graph, x)) + .collect(); + let first_node = graph + .node_identifiers() + .map(|x| NodeIndexable::to_index(&graph, x)) + .next() + .unwrap(); + let path_keys: HashSet = paths[&first_node].0.keys().copied().collect(); + // first_node will always be missing from path_keys so if the difference + // is > 1 with nodes that means there is another node in the graph that + // first_node doesn't have a path to. + if nodes.difference(&path_keys).count() > 1 { + return Ok(None); + } + // Iterate over node indices for a deterministic order + for node in graph + .node_identifiers() + .map(|x| NodeIndexable::to_index(&graph, x)) + { + let path_map = &paths[&node].0; + nodes.remove(&node); + let distance = &paths[&node].1; + for v in &nodes { + out_vec.push(MetricClosureEdge { + source: node, + target: *v, + distance: distance[v], + path: path_map[v].clone(), + }); + } + } + Ok(Some(out_vec)) +} + +/// Computes the shortest path between all pairs `(s, t)` of the given `terminal_nodes` +/// *provided* that: +/// - there is an edge `(u, v)` in the graph and path pass through this edge. +/// - node `s` is the closest node to `u` among all `terminal_nodes` +/// - node `t` is the closest node to `v` among all `terminal_nodes` +/// and wraps the result inside a `MetricClosureEdge` +/// +/// For example, if all vertices are terminals, it returns the original edges of the graph. +fn fast_metric_edges( + in_graph: G, + terminal_nodes: &[G::NodeId], + mut weight_fn: F, +) -> Result, E> +where + G: IntoEdges + + NodeIndexable + + EdgeIndexable + + Sync + + EdgeCount + + Visitable + + IntoNodeReferences + + NodeCount, + G::NodeId: Eq + Hash + Send, + G::EdgeId: Eq + Hash + Send, + F: FnMut(G::EdgeRef) -> Result, +{ + let mut graph: StableGraph<(), (), Undirected> = StableGraph::with_capacity( + in_graph.node_count() + 1, + in_graph.edge_count() + terminal_nodes.len(), + ); + let node_map: HashMap = in_graph + .node_references() + .map(|n| (n.id(), graph.add_node(()))) + .collect(); + let reverse_node_map: HashMap = + node_map.iter().map(|(k, v)| (*v, *k)).collect(); + let edge_map: HashMap = in_graph + .edge_references() + .map(|e| { + ( + graph.add_edge(node_map[&e.source()], node_map[&e.target()], ()), + e, + ) + }) + .collect(); + + // temporarily add a ``dummy`` node, connect it with + // all the terminal nodes and find all the shortest paths + // starting from ``dummy`` node. + let dummy = graph.add_node(()); + for node in terminal_nodes { + graph.add_edge(dummy, node_map[node], ()); + } + + let mut paths = DictMap::with_capacity(graph.node_count()); + + let mut wrapped_weight_fn = + |e: <&StableGraph<(), ()> as IntoEdgeReferences>::EdgeRef| -> Result { + if let Some(edge_ref) = edge_map.get(&e.id()) { + weight_fn(*edge_ref) + } else { + Ok(0.0) + } + }; + + let mut distance: DictMap = dijkstra( + &graph, + dummy, + None, + &mut wrapped_weight_fn, + Some(&mut paths), + )?; + paths.swap_remove(&dummy); + distance.swap_remove(&dummy); + + // ``partition[u]`` holds the terminal node closest to node ``u``. + let mut partition: Vec = vec![std::usize::MAX; graph.node_bound()]; + for (u, path) in paths.iter() { + let u = NodeIndexable::to_index(&in_graph, reverse_node_map[u]); + partition[u] = NodeIndexable::to_index(&in_graph, reverse_node_map[&path[1]]); + } + + let mut out_edges: Vec = Vec::with_capacity(graph.edge_count()); + + for edge in graph.edge_references() { + let source = edge.source(); + let target = edge.target(); + // assert that ``source`` is reachable from a terminal node. + if distance.contains_key(&source) { + let weight = distance[&source] + wrapped_weight_fn(edge)? + distance[&target]; + let mut path: Vec = paths[&source] + .iter() + .skip(1) + .map(|x| NodeIndexable::to_index(&in_graph, reverse_node_map[x])) + .collect(); + path.append( + &mut paths[&target] + .iter() + .skip(1) + .rev() + .map(|x| NodeIndexable::to_index(&in_graph, reverse_node_map[x])) + .collect(), + ); + + let source = NodeIndexable::to_index(&in_graph, reverse_node_map[&source]); + let target = NodeIndexable::to_index(&in_graph, reverse_node_map[&target]); + + let mut source = partition[source]; + let mut target = partition[target]; + + match source.cmp(&target) { + Ordering::Equal => continue, + Ordering::Greater => std::mem::swap(&mut source, &mut target), + _ => {} + } + + out_edges.push(MetricClosureEdge { + source, + target, + distance: weight, + path, + }); + } + } + + // if parallel edges, keep the edge with minimum distance. + out_edges.par_sort_unstable_by(|a, b| { + let weight_a = (a.source, a.target, a.distance); + let weight_b = (b.source, b.target, b.distance); + weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less) + }); + + out_edges.dedup_by(|edge_a, edge_b| { + edge_a.source == edge_b.source && edge_a.target == edge_b.target + }); + + Ok(out_edges) +} + +pub struct SteinerTreeResult { + pub used_node_indices: HashSet, + pub used_edge_endpoints: HashSet<(usize, usize)>, +} + +/// Return an approximation to the minimum Steiner tree of a graph. +/// +/// The minimum tree of ``graph`` with regard to a set of ``terminal_nodes`` +/// is a tree within ``graph`` that spans those nodes and has a minimum size +/// (measured as the sum of edge weights) amoung all such trees. +/// +/// The minimum steiner tree can be approximated by computing the minimum +/// spanning tree of the subgraph of the metric closure of ``graph`` induced +/// by the terminal nodes, where the metric closure of ``graph`` is the +/// complete graph in which each edge is weighted by the shortest path distance +/// between nodes in ``graph``. +/// +/// This algorithm [1]_ produces a tree whose weight is within a +/// :math:`(2 - (2 / t))` factor of the weight of the optimal Steiner tree +/// where :math:`t` is the number of terminal nodes. The algorithm implemented +/// here is due to [2]_ . It avoids computing all pairs shortest paths but rather +/// reduces the problem to a single source shortest path and a minimum spanning tree +/// problem. +/// +/// Arguments: +/// `graph`: The input graph to compute the steiner tree of +/// `terminal_nodes`: The terminal nodes of the steiner tree +/// `weight_fn`: A callable weight function that will be passed an edge reference +/// for each edge in the graph and it is expected to return a `Result` +/// which if it doesn't error represents the weight of that edge. +/// +/// Returns a custom struct that contains a set of nodes and edges and `None` +/// if the graph is disconnected relative to the terminal nodes. +/// +/// # Example +/// +/// ```rust +/// use std::convert::Infallible; +/// +/// use rustworkx_core::petgraph::Graph; +/// use rustworkx_core::petgraph::graph::NodeIndex; +/// use rustworkx_core::petgraph::Undirected; +/// use rustworkx_core::petgraph::graph::EdgeReference; +/// use rustworkx_core::petgraph::visit::{IntoEdgeReferences, EdgeRef}; +/// +/// use rustworkx_core::steiner_tree::steiner_tree; +/// +/// let input_graph = Graph::<(), u8, Undirected>::from_edges(&[ +/// (0, 1, 10), +/// (1, 2, 10), +/// (2, 3, 10), +/// (3, 4, 10), +/// (4, 5, 10), +/// (1, 6, 1), +/// (6, 4, 1), +/// ]); +/// +/// let weight_fn = |e: EdgeReference| -> Result { +/// Ok(*e.weight() as f64) +/// }; +/// let terminal_nodes = vec![ +/// NodeIndex::new(0), +/// NodeIndex::new(1), +/// NodeIndex::new(2), +/// NodeIndex::new(3), +/// NodeIndex::new(4), +/// NodeIndex::new(5), +/// ]; +/// +/// let tree = steiner_tree(&input_graph, &terminal_nodes, weight_fn).unwrap().unwrap(); +/// ``` +/// +/// .. [1] Kou, Markowsky & Berman, +/// "A fast algorithm for Steiner trees" +/// Acta Informatica 15, 141–145 (1981). +/// https://link.springer.com/article/10.1007/BF00288961 +/// .. [2] Kurt Mehlhorn, +/// "A faster approximation algorithm for the Steiner problem in graphs" +/// https://doi.org/10.1016/0020-0190(88)90066-X +pub fn steiner_tree( + graph: G, + terminal_nodes: &[G::NodeId], + weight_fn: F, +) -> Result, E> +where + G: IntoEdges + + NodeIndexable + + Sync + + EdgeCount + + IntoNodeReferences + + EdgeIndexable + + Visitable + + NodeCount, + G::NodeId: Eq + Hash + Send, + G::EdgeId: Eq + Hash + Send, + F: FnMut(G::EdgeRef) -> Result, +{ + let node_bound = graph.node_bound(); + let mut edge_list = fast_metric_edges(graph, terminal_nodes, weight_fn)?; + let mut subgraphs = UnionFind::::new(node_bound); + edge_list.par_sort_unstable_by(|a, b| { + let weight_a = (a.distance, a.source, a.target); + let weight_b = (b.distance, b.source, b.target); + weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less) + }); + let mut mst_edges: Vec = Vec::new(); + for float_edge_pair in edge_list { + let u = float_edge_pair.source; + let v = float_edge_pair.target; + if subgraphs.union(u, v) { + mst_edges.push(float_edge_pair); + } + } + // assert that the terminal nodes are connected. + if !terminal_nodes.is_empty() && mst_edges.len() != terminal_nodes.len() - 1 { + return Ok(None); + } + // Generate the output graph from the MST + let out_edge_list: Vec<[usize; 2]> = mst_edges + .into_iter() + .flat_map(|edge| pairwise(edge.path)) + .filter_map(|x| x.0.map(|a| [a, x.1])) + .collect(); + let out_edges: HashSet<(usize, usize)> = out_edge_list.iter().map(|x| (x[0], x[1])).collect(); + let out_nodes: HashSet = out_edge_list + .iter() + .flat_map(|x| x.iter()) + .copied() + .collect(); + Ok(Some(SteinerTreeResult { + used_node_indices: out_nodes, + used_edge_endpoints: out_edges, + })) +} diff --git a/rustworkx/__init__.pyi b/rustworkx/__init__.pyi index 21da9d6f0..6507cc8a0 100644 --- a/rustworkx/__init__.pyi +++ b/rustworkx/__init__.pyi @@ -9,9 +9,9 @@ # This file contains only type annotations for PyO3 functions and classes # For implementation details, see __init__.py and src/lib.rs -import numpy as np +from typing import Generic, TypeVar, Any, Callable, Iterator, overload, Sequence -from typing import Generic, TypeVar, Any, Callable, Iterator, overload +import numpy as np # Re-Exports of rust native functions in rustworkx.rustworkx # To workaround limitations in mypy around re-exporting objects from the inner @@ -551,17 +551,17 @@ def cartesian_product( ) -> tuple[PyDiGraph, ProductNodeMap]: ... def bfs_search( graph: PyGraph | PyDiGraph, - source: int, + source: Sequence[int] | None, visitor: _BFSVisitor, ) -> None: ... def dfs_search( graph: PyGraph | PyDiGraph, - source: int, + source: Sequence[int] | None, visitor: _DFSVisitor, ) -> None: ... def dijkstra_search( graph: PyGraph | PyDiGraph, - source: int, + source: Sequence[int] | None, weight_fn: Callable[[Any], float], visitor: _DijkstraVisitor, ) -> None: ... diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index 6139ff369..765085e90 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -894,33 +894,33 @@ _DijkstraVisitor = TypeVar("_DijkstraVisitor", bound=DijkstraVisitor) def digraph_bfs_search( graph: PyDiGraph, - source: int | None = ..., + source: Sequence[int] | None = ..., visitor: _BFSVisitor | None = ..., ) -> None: ... def graph_bfs_search( graph: PyGraph, - source: int | None = ..., + source: Sequence[int] | None = ..., visitor: _BFSVisitor | None = ..., ) -> None: ... def digraph_dfs_search( graph: PyDiGraph, - source: int | None = ..., + source: Sequence[int] | None = ..., visitor: _DFSVisitor | None = ..., ) -> None: ... def graph_dfs_search( graph: PyGraph, - source: int | None = ..., + source: Sequence[int] | None = ..., visitor: _DFSVisitor | None = ..., ) -> None: ... def digraph_dijkstra_search( graph: PyDiGraph, - source: int | None = ..., + source: Sequence[int] | None = ..., weight_fn: Callable[[Any], float] | None = ..., visitor: _DijkstraVisitor | None = ..., ) -> None: ... def graph_dijkstra_search( graph: PyGraph, - source: int | None = ..., + source: Sequence[int] | None = ..., weight_fn: Callable[[Any], float] | None = ..., visitor: _DijkstraVisitor | None = ..., ) -> None: ... diff --git a/src/steiner_tree.rs b/src/steiner_tree.rs index c8df908da..57819b992 100644 --- a/src/steiner_tree.rs +++ b/src/steiner_tree.rs @@ -12,7 +12,7 @@ use std::cmp::Ordering; -use hashbrown::{HashMap, HashSet}; +use hashbrown::HashMap; use rayon::prelude::*; use pyo3::exceptions::PyValueError; @@ -20,23 +20,12 @@ use pyo3::prelude::*; use pyo3::Python; use petgraph::stable_graph::{EdgeIndex, EdgeReference, NodeIndex}; -use petgraph::unionfind::UnionFind; -use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeIndexable}; +use petgraph::visit::{EdgeRef, IntoEdgeReferences}; -use crate::graph; -use crate::is_valid_weight; -use crate::shortest_path::all_pairs_dijkstra::all_pairs_dijkstra_shortest_paths; +use crate::{graph, is_valid_weight}; -use rustworkx_core::dictmap::*; -use rustworkx_core::shortest_path::dijkstra; -use rustworkx_core::utils::pairwise; - -struct MetricClosureEdge { - source: usize, - target: usize, - distance: f64, - path: Vec, -} +use rustworkx_core::steiner_tree::metric_closure as core_metric_closure; +use rustworkx_core::steiner_tree::steiner_tree as core_steiner_tree; /// Return the metric closure of a graph /// @@ -59,165 +48,31 @@ pub fn metric_closure( graph: &graph::PyGraph, weight_fn: PyObject, ) -> PyResult { - let mut out_graph = graph.clone(); - out_graph.graph.clear_edges(); - let edges = _metric_closure_edges(py, graph, weight_fn)?; - for edge in edges { - out_graph.graph.add_edge( - NodeIndex::new(edge.source), - NodeIndex::new(edge.target), - (edge.distance, edge.path).to_object(py), - ); - } - Ok(out_graph) -} - -fn _metric_closure_edges( - py: Python, - graph: &graph::PyGraph, - weight_fn: PyObject, -) -> PyResult> { - let node_count = graph.graph.node_count(); - if node_count == 0 { - return Ok(Vec::new()); - } - let mut out_vec = Vec::with_capacity(node_count * (node_count - 1) / 2); - let mut distances = HashMap::with_capacity(graph.graph.node_count()); - let paths = - all_pairs_dijkstra_shortest_paths(py, &graph.graph, weight_fn, Some(&mut distances))?.paths; - let mut nodes: HashSet = graph.graph.node_indices().map(|x| x.index()).collect(); - let first_node = graph - .graph - .node_indices() - .map(|x| x.index()) - .next() - .unwrap(); - let path_keys: HashSet = paths[&first_node].paths.keys().copied().collect(); - // first_node will always be missing from path_keys so if the difference - // is > 1 with nodes that means there is another node in the graph that - // first_node doesn't have a path to. - if nodes.difference(&path_keys).count() > 1 { - return Err(PyValueError::new_err( - "The input graph must be a connected graph. The metric closure is \ - not defined for a graph with unconnected nodes", - )); - } - // Iterate over node indices for a deterministic order - for node in graph.graph.node_indices().map(|x| x.index()) { - let path_map = &paths[&node].paths; - nodes.remove(&node); - let distance = &distances[&node]; - for v in &nodes { - let v_index = NodeIndex::new(*v); - out_vec.push(MetricClosureEdge { - source: node, - target: *v, - distance: distance[&v_index], - path: path_map[v].clone(), - }); - } - } - Ok(out_vec) -} - -/// Computes the shortest path between all pairs `(s, t)` of the given `terminal_nodes` -/// *provided* that: -/// - there is an edge `(u, v)` in the graph and path pass through this edge. -/// - node `s` is the closest node to `u` among all `terminal_nodes` -/// - node `t` is the closest node to `v` among all `terminal_nodes` -/// and wraps the result inside a `MetricClosureEdge` -/// -/// For example, if all vertices are terminals, it returns the original edges of the graph. -fn fast_metric_edges( - py: Python, - graph: &mut graph::PyGraph, - terminal_nodes: &[usize], - weight_fn: &PyObject, -) -> PyResult> { - // temporarily add a ``dummy`` node, connect it with - // all the terminal nodes and find all the shortest paths - // starting from ``dummy`` node. - let dummy = graph.graph.add_node(py.None()); - for node in terminal_nodes { - graph - .graph - .add_edge(dummy, NodeIndex::new(*node), py.None()); - } - - let cost_fn = |edge: EdgeReference<'_, PyObject>| -> PyResult { - if edge.source() != dummy && edge.target() != dummy { - let weight: f64 = weight_fn.call1(py, (edge.weight(),))?.extract(py)?; - is_valid_weight(weight) - } else { - Ok(0.0) - } + let callable = |e: EdgeReference| -> PyResult { + let data = e.weight(); + let raw = weight_fn.call1(py, (data,))?; + let weight = raw.extract(py)?; + is_valid_weight(weight) }; - - let mut paths = DictMap::with_capacity(graph.graph.node_count()); - let mut distance: DictMap = - dijkstra(&graph.graph, dummy, None, cost_fn, Some(&mut paths))?; - paths.swap_remove(&dummy); - distance.swap_remove(&dummy); - graph.graph.remove_node(dummy); - - // ``partition[u]`` holds the terminal node closest to node ``u``. - let mut partition: Vec = vec![std::usize::MAX; graph.graph.node_bound()]; - for (u, path) in paths.iter() { - let u = u.index(); - partition[u] = path[1].index(); - } - - let mut out_edges: Vec = Vec::with_capacity(graph.graph.edge_count()); - - for edge in graph.graph.edge_references() { - let source = edge.source(); - let target = edge.target(); - // assert that ``source`` is reachable from a terminal node. - if distance.contains_key(&source) { - let weight = distance[&source] + cost_fn(edge)? + distance[&target]; - let mut path: Vec = paths[&source].iter().skip(1).map(|x| x.index()).collect(); - path.append( - &mut paths[&target] - .iter() - .skip(1) - .rev() - .map(|x| x.index()) - .collect(), + if let Some(result_graph) = core_metric_closure(&graph.graph, callable)? { + let mut out_graph = graph.clone(); + out_graph.graph.clear_edges(); + for edge in result_graph.edge_indices() { + let (source, target) = result_graph.edge_endpoints(edge).unwrap(); + let weight = result_graph.edge_weight(edge).unwrap(); + out_graph.graph.add_edge( + *result_graph.node_weight(source).unwrap(), + *result_graph.node_weight(target).unwrap(), + weight.to_object(py), ); - - let source = source.index(); - let target = target.index(); - - let mut source = partition[source]; - let mut target = partition[target]; - - match source.cmp(&target) { - Ordering::Equal => continue, - Ordering::Greater => std::mem::swap(&mut source, &mut target), - _ => {} - } - - out_edges.push(MetricClosureEdge { - source, - target, - distance: weight, - path, - }); } + Ok(out_graph) + } else { + Err(PyValueError::new_err( + "The input graph must be a connected graph. The metric closure is \ + not defined for a graph with unconnected nodes", + )) } - - // if parallel edges, keep the edge with minimum distance. - out_edges.par_sort_unstable_by(|a, b| { - let weight_a = (a.source, a.target, a.distance); - let weight_b = (b.source, b.target, b.distance); - weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less) - }); - - out_edges.dedup_by(|edge_a, edge_b| { - edge_a.source == edge_b.source && edge_a.target == edge_b.target - }); - - Ok(out_edges) } /// Return an approximation to the minimum Steiner tree of a graph. @@ -267,60 +122,51 @@ pub fn steiner_tree( terminal_nodes: Vec, weight_fn: PyObject, ) -> PyResult { - let mut edge_list = fast_metric_edges(py, graph, &terminal_nodes, &weight_fn)?; - let mut subgraphs = UnionFind::::new(graph.graph.node_bound()); - edge_list.par_sort_unstable_by(|a, b| { - let weight_a = (a.distance, a.source, a.target); - let weight_b = (b.distance, b.source, b.target); - weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less) - }); - let mut mst_edges: Vec = Vec::new(); - for float_edge_pair in edge_list { - let u = float_edge_pair.source; - let v = float_edge_pair.target; - if subgraphs.union(u, v) { - mst_edges.push(float_edge_pair); + let callable = |e: EdgeReference| -> PyResult { + let data = e.weight(); + let raw = weight_fn.call1(py, (data,))?; + raw.extract(py) + }; + let mut terminal_n: Vec = Vec::with_capacity(terminal_nodes.len()); + for n in &terminal_nodes { + let index = NodeIndex::new(*n); + if graph.graph.node_weight(index).is_none() { + return Err(PyValueError::new_err(format!( + "Provided terminal node index {} is not present in graph", + n + ))); } + terminal_n.push(index); } - // assert that the terminal nodes are connected. - if !terminal_nodes.is_empty() && mst_edges.len() != terminal_nodes.len() - 1 { - return Err(PyValueError::new_err( + let result = core_steiner_tree(&graph.graph, &terminal_n, callable)?; + if let Some(result) = result { + let mut out_graph = graph.clone(); + for node in graph + .graph + .node_indices() + .filter(|node| !result.used_node_indices.contains(&node.index())) + { + out_graph.graph.remove_node(node); + } + for edge in graph.graph.edge_references().filter(|edge| { + let source = edge.source().index(); + let target = edge.target().index(); + !result.used_edge_endpoints.contains(&(source, target)) + && !result.used_edge_endpoints.contains(&(target, source)) + }) { + out_graph.graph.remove_edge(edge.id()); + } + deduplicate_edges(py, &mut out_graph, &weight_fn)?; + if out_graph.graph.node_count() != graph.graph.node_count() { + out_graph.node_removed = true; + } + Ok(out_graph) + } else { + Err(PyValueError::new_err( "The terminal nodes in the input graph must belong to the same connected component. \ - The steiner tree is not defined for a graph with unconnected terminal nodes", - )); - } - // Generate the output graph from the MST - let out_edge_list: Vec<[usize; 2]> = mst_edges - .into_iter() - .flat_map(|edge| pairwise(edge.path)) - .filter_map(|x| x.0.map(|a| [a, x.1])) - .collect(); - let out_edges: HashSet<(usize, usize)> = out_edge_list.iter().map(|x| (x[0], x[1])).collect(); - let mut out_graph = graph.clone(); - let out_nodes: HashSet = out_edge_list - .iter() - .flat_map(|x| x.iter()) - .copied() - .map(NodeIndex::new) - .collect(); - for node in graph - .graph - .node_indices() - .filter(|node| !out_nodes.contains(node)) - { - out_graph.graph.remove_node(node); - out_graph.node_removed = true; - } - for edge in graph.graph.edge_references().filter(|edge| { - let source = edge.source().index(); - let target = edge.target().index(); - !out_edges.contains(&(source, target)) && !out_edges.contains(&(target, source)) - }) { - out_graph.graph.remove_edge(edge.id()); + The steiner tree is not defined for a graph with unconnected terminal nodes", + )) } - // Deduplicate potential duplicate edges - deduplicate_edges(py, &mut out_graph, &weight_fn)?; - Ok(out_graph) } fn deduplicate_edges( diff --git a/tests/graph/test_steiner_tree.py b/tests/graph/test_steiner_tree.py index 74ea5de82..0d144c0b4 100644 --- a/tests/graph/test_steiner_tree.py +++ b/tests/graph/test_steiner_tree.py @@ -151,7 +151,7 @@ def test_steiner_graph_multigraph(self): def test_not_connected_steiner_tree(self): self.graph.add_node(None) with self.assertRaises(ValueError): - rustworkx.steiner_tree(self.graph, [1, 2, 8], weight_fn=float) + rustworkx.steiner_tree(self.graph, [1, 2, 0], weight_fn=float) def test_steiner_tree_empty_graph(self): graph = rustworkx.PyGraph()