diff --git a/releasenotes/notes/lexicographical-topo-sort-core-e85fba409d612600.yaml b/releasenotes/notes/lexicographical-topo-sort-core-e85fba409d612600.yaml new file mode 100644 index 000000000..7aa6d5f6a --- /dev/null +++ b/releasenotes/notes/lexicographical-topo-sort-core-e85fba409d612600.yaml @@ -0,0 +1,7 @@ +--- +features: + - | + Added a new function ``lexicographical_topological_sort`` to the + ``rustworkx_core::dag_algo`` module. That is a gneric Rust implementation + for the core rust library that provides the + :func:`.lexicographical_topological_sort` function to Rust users. diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index b29389a73..fb629069c 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -10,19 +10,208 @@ // License for the specific language governing permissions and limitations // under the License. -use std::cmp::Eq; +use std::cmp::{Eq, Ordering}; +use std::collections::BinaryHeap; use std::hash::Hash; use hashbrown::HashMap; use petgraph::algo; use petgraph::visit::{ - EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNodeIdentifiers, Visitable, + EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers, + NodeCount, Visitable, }; use petgraph::Directed; use num_traits::{Num, Zero}; +/// Return a pair of [`petgraph::Direction`] values corresponding to the "forwards" and "backwards" +/// direction of graph traversal, based on whether the graph is being traved forwards (following +/// the edges) or backward (reversing along edges). The order of returns is (forwards, backwards). +#[inline(always)] +pub fn traversal_directions(reverse: bool) -> (petgraph::Direction, petgraph::Direction) { + if reverse { + (petgraph::Direction::Outgoing, petgraph::Direction::Incoming) + } else { + (petgraph::Direction::Incoming, petgraph::Direction::Outgoing) + } +} + +/// Get the lexicographical topological sorted nodes from the provided DAG +/// +/// This function returns a list of nodes data in a graph lexicographically +/// topologically sorted using the provided key function. A topological sort +/// is a linear ordering of vertices such that for every directed edge from +/// node :math:`u` to node :math:`v`, :math:`u` comes before :math:`v` +/// in the ordering. If ``reverse`` is set to ``False``, the edges are treated +/// as if they pointed in the opposite direction. +/// +/// This function differs from :func:`~rustworkx.topological_sort` because +/// when there are ties between nodes in the sort order this function will +/// use the string returned by the ``key`` argument to determine the output +/// order used. The ``reverse`` argument does not affect the ordering of keys +/// from this function, only the edges of the graph. +/// +/// # Arguments: +/// +/// * `dag`: The DAG to get the topological sorted nodes from +/// * `key`: A function that gets passed a single argument, the node id from +/// `dag` and is expected to return a `String` which will be used for +/// resolving ties in the sorting order. +/// * `reverse`: If `false`, perform a regular topological ordering. If `true`, +/// return the lexicographical topological order that would have been found +/// if all the edges in the graph were reversed. This does not affect the +/// comparisons from the `key`. +/// * `initial`: If given, the initial node indices to start the topological +/// ordering from. If not given, the topological ordering will certainly contain every node in +/// the graph. If given, only the `initial` nodes and nodes that are dominated by the +/// `initial` set will be in the ordering. Notably, any node that has a natural in degree of +/// zero will not be in the output ordering if `initial` is given and the zero-in-degree node +/// is not in it. It is not supported to give an `initial` set where the nodes have even +/// a partial topological order between themselves and `None` will be returned in this case +/// +/// # Returns +/// +/// * `None` if the graph contains a cycle or `initial` is invalid +/// * `Some(Vec)` representing the topological ordering of nodes. +/// * `Err(E)` if there is an error computing the key for any node +/// +/// # Example +/// +/// ```rust +/// use std::convert::Infallible; +/// +/// use rustworkx_core::dag_algo::lexicographical_topological_sort; +/// use rustworkx_core::petgraph::stable_graph::{StableDiGraph, NodeIndex}; +/// +/// let mut graph: StableDiGraph = StableDiGraph::new(); +/// let mut nodes: Vec = Vec::new(); +/// for weight in 0..9 { +/// nodes.push(graph.add_node(weight)); +/// } +/// let edges = [ +/// (nodes[0], nodes[1]), +/// (nodes[0], nodes[2]), +/// (nodes[1], nodes[3]), +/// (nodes[2], nodes[4]), +/// (nodes[3], nodes[4]), +/// (nodes[4], nodes[5]), +/// (nodes[5], nodes[6]), +/// (nodes[4], nodes[7]), +/// (nodes[6], nodes[8]), +/// (nodes[7], nodes[8]), +/// ]; +/// for (source, target) in edges { +/// graph.add_edge(source, target, ()); +/// } +/// let sort_fn = |index: NodeIndex| -> Result { Ok(graph[index].to_string()) }; +/// let initial = [nodes[6], nodes[7]]; +/// let result = lexicographical_topological_sort(&graph, sort_fn, true, Some(&initial)); +/// let expected = vec![ +/// nodes[6], +/// nodes[5], +/// nodes[7], +/// nodes[4], +/// nodes[2], +/// nodes[3], +/// nodes[1], +/// nodes[0] +/// ]; +/// assert_eq!(result, Ok(Some(expected))); +/// +/// ``` +pub fn lexicographical_topological_sort( + dag: G, + mut key: F, + reverse: bool, + initial: Option<&[G::NodeId]>, +) -> Result>, E> +where + G: GraphProp + + IntoNodeIdentifiers + + IntoNeighborsDirected + + IntoEdgesDirected + + NodeCount, + ::NodeId: Hash + Eq + Ord, + F: FnMut(G::NodeId) -> Result, +{ + // HashMap of node_index indegree + let node_count = dag.node_count(); + let (in_dir, out_dir) = traversal_directions(reverse); + + #[derive(Clone, Eq, PartialEq)] + struct State { + key: String, + node: N, + } + + impl Ord for State { + fn cmp(&self, other: &State) -> Ordering { + // Notice that the we flip the ordering on costs. + // In case of a tie we compare positions - this step is necessary + // to make implementations of `PartialEq` and `Ord` consistent. + other + .key + .cmp(&self.key) + .then_with(|| other.node.cmp(&self.node)) + } + } + + // `PartialOrd` needs to be implemented as well. + impl PartialOrd for State { + fn partial_cmp(&self, other: &State) -> Option { + Some(self.cmp(other)) + } + } + + let mut in_degree_map: HashMap = HashMap::with_capacity(node_count); + if let Some(initial) = initial { + // In this case, we don't iterate through all the nodes in the graph, and most nodes aren't + // in `in_degree_map`; we'll fill in the relevant edge counts lazily. + for node in initial { + in_degree_map.insert(*node, 0); + } + } else { + for node in dag.node_identifiers() { + in_degree_map.insert(node, dag.edges_directed(node, in_dir).count()); + } + } + + let mut zero_indegree = BinaryHeap::with_capacity(node_count); + for (node, degree) in in_degree_map.iter() { + if *degree == 0 { + let map_key: String = key(*node)?; + zero_indegree.push(State { + key: map_key, + node: *node, + }); + } + } + let mut out_list: Vec = Vec::with_capacity(node_count); + while let Some(State { node, .. }) = zero_indegree.pop() { + let neighbors = dag.neighbors_directed(node, out_dir); + for child in neighbors { + let child_degree = in_degree_map + .entry(child) + .or_insert_with(|| dag.edges_directed(child, in_dir).count()); + if *child_degree == 0 { + return Ok(None); + } else if *child_degree == 1 { + let map_key: String = key(child)?; + zero_indegree.push(State { + key: map_key, + node: child, + }); + in_degree_map.remove(&child); + } else { + *child_degree -= 1; + } + } + out_list.push(node) + } + Ok(Some(out_list)) +} + // Type aliases for readability type NodeId = ::NodeId; type LongestPathResult = Result>, T)>, E>; @@ -231,3 +420,182 @@ mod test_longest_path { assert_eq!(result, Err("Error: edge weight is 2")); } } + +// pub fn lexicographical_topological_sort( +// dag: G, +// mut key: F, +// reverse: bool, +// initial: Option<&[G::NodeId]>, +// ) -> Result>, E> + +#[cfg(test)] +mod test_lexicographical_topological_sort { + use super::*; + use petgraph::graph::{DiGraph, NodeIndex}; + use petgraph::stable_graph::StableDiGraph; + use std::convert::Infallible; + + #[test] + fn test_empty_graph() { + let graph: DiGraph<(), ()> = DiGraph::new(); + let sort_fn = |_: NodeIndex| -> Result { Ok("a".to_string()) }; + let result = lexicographical_topological_sort(&graph, sort_fn, false, None); + assert_eq!(result, Ok(Some(vec![]))); + } + + #[test] + fn test_empty_stable_graph() { + let graph: StableDiGraph<(), ()> = StableDiGraph::new(); + let sort_fn = |_: NodeIndex| -> Result { Ok("a".to_string()) }; + let result = lexicographical_topological_sort(&graph, sort_fn, false, None); + assert_eq!(result, Ok(Some(vec![]))); + } + + #[test] + fn test_simple_layer() { + let mut graph: DiGraph = DiGraph::new(); + let mut nodes: Vec = Vec::new(); + nodes.push(graph.add_node("a".to_string())); + for i in 0..5 { + nodes.push(graph.add_node(i.to_string())); + } + nodes.push(graph.add_node("A parent".to_string())); + for (source, target) in [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (6, 3)] { + graph.add_edge(nodes[source], nodes[target], ()); + } + let sort_fn = |index: NodeIndex| -> Result { Ok(graph[index].clone()) }; + let result = lexicographical_topological_sort(&graph, sort_fn, false, None); + assert_eq!( + result, + Ok(Some(vec![ + NodeIndex::new(6), + NodeIndex::new(0), + NodeIndex::new(1), + NodeIndex::new(2), + NodeIndex::new(3), + NodeIndex::new(4), + NodeIndex::new(5) + ])) + ) + } + + #[test] + fn test_simple_layer_stable() { + let mut graph: StableDiGraph = StableDiGraph::new(); + let mut nodes: Vec = Vec::new(); + nodes.push(graph.add_node("a".to_string())); + for i in 0..5 { + nodes.push(graph.add_node(i.to_string())); + } + nodes.push(graph.add_node("A parent".to_string())); + for (source, target) in [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (6, 3)] { + graph.add_edge(nodes[source], nodes[target], ()); + } + let sort_fn = |index: NodeIndex| -> Result { Ok(graph[index].clone()) }; + let result = lexicographical_topological_sort(&graph, sort_fn, false, None); + assert_eq!( + result, + Ok(Some(vec![ + NodeIndex::new(6), + NodeIndex::new(0), + NodeIndex::new(1), + NodeIndex::new(2), + NodeIndex::new(3), + NodeIndex::new(4), + NodeIndex::new(5) + ])) + ) + } + + #[test] + fn test_reverse_graph() { + let mut graph: DiGraph = DiGraph::new(); + let mut nodes: Vec = Vec::new(); + for weight in ["a", "b", "c", "d", "e", "f"] { + nodes.push(graph.add_node(weight.to_string())); + } + let edges = [ + (nodes[0], nodes[1]), + (nodes[0], nodes[2]), + (nodes[1], nodes[3]), + (nodes[2], nodes[3]), + (nodes[1], nodes[4]), + (nodes[2], nodes[5]), + ]; + + for (source, target) in edges { + graph.add_edge(source, target, ()); + } + let sort_fn = |index: NodeIndex| -> Result { Ok(graph[index].clone()) }; + let result = lexicographical_topological_sort(&graph, sort_fn, true, None); + graph.reverse(); + let sort_fn = |index: NodeIndex| -> Result { Ok(graph[index].clone()) }; + let expected = lexicographical_topological_sort(&graph, sort_fn, false, None); + assert_eq!(result, expected,) + } + + #[test] + fn test_reverse_graph_stable() { + let mut graph: StableDiGraph = StableDiGraph::new(); + let mut nodes: Vec = Vec::new(); + for weight in ["a", "b", "c", "d", "e", "f"] { + nodes.push(graph.add_node(weight.to_string())); + } + let edges = [ + (nodes[0], nodes[1]), + (nodes[0], nodes[2]), + (nodes[1], nodes[3]), + (nodes[2], nodes[3]), + (nodes[1], nodes[4]), + (nodes[2], nodes[5]), + ]; + + for (source, target) in edges { + graph.add_edge(source, target, ()); + } + let sort_fn = |index: NodeIndex| -> Result { Ok(graph[index].clone()) }; + let result = lexicographical_topological_sort(&graph, sort_fn, true, None); + graph.reverse(); + let sort_fn = |index: NodeIndex| -> Result { Ok(graph[index].clone()) }; + let expected = lexicographical_topological_sort(&graph, sort_fn, false, None); + assert_eq!(result, expected,) + } + + #[test] + fn test_initial() { + let mut graph: StableDiGraph = StableDiGraph::new(); + let mut nodes: Vec = Vec::new(); + for weight in 0..9 { + nodes.push(graph.add_node(weight)); + } + let edges = [ + (nodes[0], nodes[1]), + (nodes[0], nodes[2]), + (nodes[1], nodes[3]), + (nodes[2], nodes[4]), + (nodes[3], nodes[4]), + (nodes[4], nodes[5]), + (nodes[5], nodes[6]), + (nodes[4], nodes[7]), + (nodes[6], nodes[8]), + (nodes[7], nodes[8]), + ]; + for (source, target) in edges { + graph.add_edge(source, target, ()); + } + let sort_fn = + |index: NodeIndex| -> Result { Ok(graph[index].to_string()) }; + let initial = [nodes[6], nodes[7]]; + let result = lexicographical_topological_sort(&graph, sort_fn, false, Some(&initial)); + assert_eq!(result, Ok(Some(vec![nodes[6], nodes[7], nodes[8]]))); + let initial = [nodes[0]]; + let result = lexicographical_topological_sort(&graph, sort_fn, false, Some(&initial)); + assert_eq!( + result, + lexicographical_topological_sort(&graph, sort_fn, false, None) + ); + let initial = [nodes[7]]; + let result = lexicographical_topological_sort(&graph, sort_fn, false, Some(&initial)); + assert_eq!(result, Ok(Some(vec![nodes[7]]))); + } +} diff --git a/src/dag_algo/mod.rs b/src/dag_algo/mod.rs index 206fa9b45..f3138384c 100644 --- a/src/dag_algo/mod.rs +++ b/src/dag_algo/mod.rs @@ -14,12 +14,11 @@ use super::DictMap; use hashbrown::{HashMap, HashSet}; use indexmap::IndexSet; use rustworkx_core::dictmap::InitWithHasher; -use std::cmp::Ordering; -use std::collections::BinaryHeap; use super::iterators::NodeIndices; use crate::{digraph, DAGHasCycle, InvalidNode, StablePyGraph}; +use rustworkx_core::dag_algo::lexicographical_topological_sort as core_lexico_topo_sort; use rustworkx_core::dag_algo::longest_path as core_longest_path; use rustworkx_core::traversal::dfs_edges; @@ -433,100 +432,35 @@ pub fn lexicographical_topological_sort( reverse: bool, initial: Option<&Bound>, ) -> PyResult { - let key_callable = |a: &PyObject| -> PyResult { - let res = key.call1(py, (a,))?; - Ok(res.to_object(py)) + let key_callable = |a: NodeIndex| -> PyResult { + let weight = &dag.graph[a]; + let res: String = key.call1(py, (weight,))?.extract(py)?; + Ok(res) }; - // HashMap of node_index indegree - let node_count = dag.node_count(); - let (in_dir, out_dir) = traversal_directions(reverse); - - #[derive(Clone, Eq, PartialEq)] - struct State { - key: String, - node: NodeIndex, - } - - impl Ord for State { - fn cmp(&self, other: &State) -> Ordering { - // Notice that the we flip the ordering on costs. - // In case of a tie we compare positions - this step is necessary - // to make implementations of `PartialEq` and `Ord` consistent. - other - .key - .cmp(&self.key) - .then_with(|| other.node.index().cmp(&self.node.index())) - } - } - - // `PartialOrd` needs to be implemented as well. - impl PartialOrd for State { - fn partial_cmp(&self, other: &State) -> Option { - Some(self.cmp(other)) - } - } - - let mut in_degree_map: HashMap = HashMap::with_capacity(node_count); - if let Some(initial) = initial { - // In this case, we don't iterate through all the nodes in the graph, and most nodes aren't - // in `in_degree_map`; we'll fill in the relevant edge counts lazily. - for maybe_index in initial.iter()? { - let node = NodeIndex::new(maybe_index?.extract::()?); - if dag.graph.contains_node(node) { - // It's not necessarily actually zero, but we treat it as if it is. If the node is - // reachable from another we visit during the iteration, then there was a defined - // topological order between the `initial` set, and we'll throw an error. - in_degree_map.insert(node, 0); - } else { - return Err(PyValueError::new_err(format!( - "node index {} is not in this graph", - node.index() - ))); + let initial: Option> = match initial { + Some(initial) => { + let mut initial_vec: Vec = Vec::new(); + for maybe_index in initial.iter()? { + let node = NodeIndex::new(maybe_index?.extract::()?); + initial_vec.push(node); } + Some(initial_vec) } - } else { - for node in dag.graph.node_indices() { - in_degree_map.insert(node, dag.graph.edges_directed(node, in_dir).count()); - } - } - - let mut zero_indegree = BinaryHeap::with_capacity(node_count); - for (node, degree) in in_degree_map.iter() { - if *degree == 0 { - let map_key_raw = key_callable(&dag.graph[*node])?; - let map_key: String = map_key_raw.extract(py)?; - zero_indegree.push(State { - key: map_key, - node: *node, - }); - } - } - let mut out_list: Vec<&PyObject> = Vec::with_capacity(node_count); - while let Some(State { node, .. }) = zero_indegree.pop() { - let neighbors = dag.graph.neighbors_directed(node, out_dir); - for child in neighbors { - let child_degree = in_degree_map - .entry(child) - .or_insert_with(|| dag.graph.edges_directed(child, in_dir).count()); - if *child_degree == 0 { - return Err(PyValueError::new_err( - "at least one initial node is reachable from another", - )); - } else if *child_degree == 1 { - let map_key_raw = key_callable(&dag.graph[child])?; - let map_key: String = map_key_raw.extract(py)?; - zero_indegree.push(State { - key: map_key, - node: child, - }); - in_degree_map.remove(&child); - } else { - *child_degree -= 1; - } - } - out_list.push(&dag.graph[node]) + None => None, + }; + let out_list = core_lexico_topo_sort(&dag.graph, key_callable, reverse, initial.as_deref())?; + match out_list { + Some(out_list) => Ok(PyList::new_bound( + py, + out_list + .into_iter() + .map(|node| dag.graph[node].clone_ref(py)), + ) + .into()), + None => Err(PyValueError::new_err( + "at least one initial node is reachable from another", + )), } - Ok(PyList::new_bound(py, out_list).into()) } /// Return the topological generations of a DAG