diff --git a/docs/source/api.rst b/docs/source/api.rst index ecc6f1237..0a0aa2df8 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -106,6 +106,7 @@ Tree rustworkx.minimum_spanning_edges rustworkx.minimum_spanning_tree rustworkx.steiner_tree + rustworkx.bipartition_tree .. _isomorphism: @@ -181,6 +182,7 @@ Other Algorithm Functions rustworkx.core_number rustworkx.graph_greedy_color rustworkx.metric_closure + rustworkx.bipartition_graph_mst .. _generator_funcs: diff --git a/releasenotes/notes/bipartition_graph_mst-ccb2204bc7b6c407.yaml b/releasenotes/notes/bipartition_graph_mst-ccb2204bc7b6c407.yaml new file mode 100644 index 000000000..040c90a94 --- /dev/null +++ b/releasenotes/notes/bipartition_graph_mst-ccb2204bc7b6c407.yaml @@ -0,0 +1,7 @@ +--- +features: + - | + Added a new function :func:`~.bipartition_graph_mst` that takes in a connected + graph and tries to draw a minimum spanning tree and find a balanced cut + edge to target using :func:`~.bipartition_tree`. If such a corresponding + tree and edge cannot be found, then it retries. diff --git a/releasenotes/notes/bipartition_tree-4c1ad080b1fab9e8.yaml b/releasenotes/notes/bipartition_tree-4c1ad080b1fab9e8.yaml new file mode 100644 index 000000000..0abc0b404 --- /dev/null +++ b/releasenotes/notes/bipartition_tree-4c1ad080b1fab9e8.yaml @@ -0,0 +1,24 @@ +--- +features: + - | + Added a new function :func:`~.bipartition_tree` that takes in a spanning tree + and a list of populations assigned to each node in the tree and finds all + balanced edges, if they exist. A balanced edge is defined as an edge that, + when cut, will split the population of the tree into two connected subtrees + that have population near the population target within some epsilon. The + function returns a list of all such possible cuts, represented as the set + of nodes in one partition/subtree. For example, + + ..
code-block:: python + + balanced_node_choices = rustworkx.bipartition_tree( + tree, + pops, + float(pop_target), + float(epsilon) + ) + + returns a list of tuples, with each tuple representing a distinct balanced + edge that can be cut. The tuple contains the root of one of the two + partitioned subtrees and the set of nodes making up that subtree. The other + partition can be recovered by computing the complement of the set of nodes. diff --git a/src/lib.rs b/src/lib.rs index d34450e01..9bbfc84cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -427,6 +427,8 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(max_weight_matching))?; m.add_wrapped(wrap_pyfunction!(minimum_spanning_edges))?; m.add_wrapped(wrap_pyfunction!(minimum_spanning_tree))?; + m.add_wrapped(wrap_pyfunction!(bipartition_tree))?; + m.add_wrapped(wrap_pyfunction!(bipartition_graph_mst))?; m.add_wrapped(wrap_pyfunction!(graph_transitivity))?; m.add_wrapped(wrap_pyfunction!(digraph_transitivity))?; m.add_wrapped(wrap_pyfunction!(graph_core_number))?; diff --git a/src/tree.rs b/src/tree.rs index edcd3f95f..672885012 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -10,7 +10,10 @@ // License for the specific language governing permissions and limitations // under the License. +use hashbrown::HashSet; use std::cmp::Ordering; +use std::collections::VecDeque; +use std::mem; use super::{graph, weight_callable}; @@ -126,6 +129,18 @@ pub fn minimum_spanning_tree( let mut spanning_tree = (*graph).clone(); spanning_tree.graph.clear_edges(); + _minimum_spanning_tree(py, graph, &mut spanning_tree, weight_fn, default_weight)?; + Ok(spanning_tree) +} + +/// Helper function to allow reuse of spanning_tree object to reduce memory allocs +fn _minimum_spanning_tree( + py: Python, + graph: &graph::PyGraph, + spanning_tree: &mut graph::PyGraph, + weight_fn: Option<PyObject>, + default_weight: f64, +) -> PyResult<()> { for edge in minimum_spanning_edges(py, graph, weight_fn, default_weight)?
.edges .iter() @@ -133,5 +148,142 @@ pub fn minimum_spanning_tree( spanning_tree.add_edge(edge.0, edge.1, edge.2.clone_ref(py)); } - Ok(spanning_tree) + Ok(()) +} + +/// Bipartition tree by finding balanced cut edges of a spanning tree using +/// node contraction. Assumes that the tree is connected and is a spanning tree. +/// A balanced edge is defined as an edge that, when cut, will split the +/// population of the tree into two connected subtrees that have population near +/// the population target within some epsilon. The function returns a list of +/// all such possible cuts, represented as the set of nodes in one +/// partition/subtree. Wraps around ``_bipartition_tree``. +/// +/// :param PyGraph spanning_tree: Spanning tree. Must be fully connected +/// :param pops: The populations assigned to each node in the graph. +/// :param float pop_target: The population target to reach when partitioning the +/// graph. +/// :param float epsilon: The maximum percent deviation from the pop_target +/// allowed while still being a valid balanced cut edge. +/// +/// :returns: A list of tuples, with each tuple representing a distinct +/// balanced edge that can be cut. The tuple contains the root of one of the +/// two partitioned subtrees and the set of nodes making up that subtree. +#[pyfunction] +#[pyo3(text_signature = "(spanning_tree, pops, target_pop, epsilon)")] +pub fn bipartition_tree( + spanning_tree: &graph::PyGraph, + pops: Vec<f64>, + pop_target: f64, + epsilon: f64, +) -> Vec<(usize, Vec<usize>)> { + _bipartition_tree(spanning_tree, pops, pop_target, epsilon) +} + +/// Internal _bipartition_tree implementation.
+fn _bipartition_tree( + spanning_tree: &graph::PyGraph, + pops: Vec<f64>, + pop_target: f64, + epsilon: f64, +) -> Vec<(usize, Vec<usize>)> { + let mut pops = pops; + let spanning_tree_graph = &spanning_tree.graph; + let mut same_partition_tracker: Vec<Vec<usize>> = + vec![Vec::new(); spanning_tree_graph.node_bound()]; // Keeps track of all the nodes on the same side of the partition + + let mut node_queue: VecDeque<NodeIndex> = VecDeque::<NodeIndex>::new(); + for leaf_node in spanning_tree_graph.node_indices() { + if spanning_tree_graph.neighbors(leaf_node).count() == 1 { + node_queue.push_back(leaf_node); + } + same_partition_tracker[leaf_node.index()].push(leaf_node.index()); + } + + // BFS search for balanced nodes + let mut balanced_nodes: Vec<(usize, Vec<usize>)> = vec![]; + let mut seen_nodes = HashSet::with_capacity(spanning_tree_graph.node_count()); + while !node_queue.is_empty() { + let node = node_queue.pop_front().unwrap(); + if seen_nodes.contains(&node.index()) { + continue; + } + + // Mark as seen; push to queue if only one unseen neighbor + let unseen_neighbors: Vec<NodeIndex> = spanning_tree + .graph + .neighbors(node) + .filter(|node| !seen_nodes.contains(&node.index())) + .collect(); + + if unseen_neighbors.len() == 1 { + // At leaf, will be false at root + let pop = pops[node.index()]; + + // Update neighbor pop + let neighbor = unseen_neighbors[0]; + pops[neighbor.index()] += pop; + + // Check if balanced; mark as seen + if pop >= pop_target * (1.0 - epsilon) && pop <= pop_target * (1.0 + epsilon) { + balanced_nodes.push((node.index(), same_partition_tracker[node.index()].clone())); + } + seen_nodes.insert(node.index()); + + // Update neighbor partition tracker + let mut current_partition_tracker = + mem::take(&mut same_partition_tracker[node.index()]); + same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker); + + // Queue neighbor + node_queue.push_back(neighbor); + } else if unseen_neighbors.is_empty() { + // Is root + break; + } else { + // Not a leaf yet + continue; + } +
} + + balanced_nodes +} + +/// Bipartition graph into two contiguous, population-balanced components using +/// mst. Assumes that the graph is contiguous. See :func:`~bipartition_tree` for +/// details on how balance is defined. +/// +/// :param PyGraph graph: Undirected graph +/// :param weight_fn: A callable object (function, lambda, etc) which +/// will be passed the edge object and expected to return a ``float``. See +/// :func:`~minimum_spanning_tree` for details. +/// :param pops: The populations assigned to each node in the graph. +/// :param float pop_target: The population target to reach when partitioning +/// the graph. +/// :param float epsilon: The maximum percent deviation from the pop_target +/// allowed while still being a valid balanced cut edge. +/// +/// :returns: A list of tuples, with each tuple representing a distinct +/// balanced edge that can be cut. The tuple contains the root of one of the +/// two partitioned subtrees and the set of nodes making up that subtree. +#[pyfunction] +#[pyo3(text_signature = "(graph, weight_fn, pops, target_pop, epsilon)")] +pub fn bipartition_graph_mst( + py: Python, + graph: &graph::PyGraph, + weight_fn: PyObject, + pops: Vec<f64>, + pop_target: f64, + epsilon: f64, +) -> PyResult<Vec<(usize, Vec<usize>)>> { + let mut balanced_nodes: Vec<(usize, Vec<usize>)> = vec![]; + let mut mst = (*graph).clone(); + + while balanced_nodes.is_empty() { + mst.graph.clear_edges(); + _minimum_spanning_tree(py, graph, &mut mst, Some(weight_fn.clone()), 1.0)?; + balanced_nodes = _bipartition_tree(&mst, pops.clone(), pop_target, epsilon); + } + + Ok(balanced_nodes) } diff --git a/tests/rustworkx_tests/graph/test_bipartition.py b/tests/rustworkx_tests/graph/test_bipartition.py new file mode 100644 index 000000000..b43804d3c --- /dev/null +++ b/tests/rustworkx_tests/graph/test_bipartition.py @@ -0,0 +1,128 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License.
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import rustworkx + + +class TestBipartition(unittest.TestCase): + def setUp(self): + self.line = rustworkx.PyGraph() + a = self.line.add_node(0) + b = self.line.add_node(1) + c = self.line.add_node(2) + d = self.line.add_node(3) + e = self.line.add_node(4) + f = self.line.add_node(5) + + self.line.add_edges_from( + [ + (a, b, 1), + (b, c, 1), + (c, d, 1), + (d, e, 1), + (e, f, 1), + ] + ) + + self.tree = rustworkx.PyGraph() + a = self.tree.add_node(0) + b = self.tree.add_node(1) + c = self.tree.add_node(2) + d = self.tree.add_node(3) + e = self.tree.add_node(4) + f = self.tree.add_node(5) + + self.tree.add_edges_from( + [ + (a, b, 1), + (a, d, 1), + (c, d, 1), + (a, f, 1), + (d, e, 1), + ] + ) + + def test_one_balanced_edge_tree(self): + balanced_edges = rustworkx.bipartition_tree( + self.tree, + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.2, + ) + self.assertEqual(len(balanced_edges), 1) + + # Since this is already a spanning tree, bipartition_graph_mst should + # behave identically. 
That is, it should be invariant to weight_fn + graph_balanced_edges = rustworkx.bipartition_graph_mst( + self.tree, + lambda _: 1, + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.2, + ) + self.assertEqual(balanced_edges, graph_balanced_edges) + + def test_one_balanced_edge_tree_alt(self): + balanced_edges = rustworkx.bipartition_tree( + self.tree, + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.5, + ) + self.assertEqual(len(balanced_edges), 1) + + graph_balanced_edges = rustworkx.bipartition_graph_mst( + self.tree, + lambda _: 1, + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.5, + ) + self.assertEqual(balanced_edges, graph_balanced_edges) + + def test_three_balanced_edges_line(self): + balanced_edges = rustworkx.bipartition_tree( + self.line, + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.5, + ) + self.assertEqual(len(balanced_edges), 3) + + graph_balanced_edges = rustworkx.bipartition_graph_mst( + self.line, + lambda _: 1, + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.5, + ) + self.assertEqual(balanced_edges, graph_balanced_edges) + + def test_one_balanced_edges_line(self): + balanced_edges = rustworkx.bipartition_tree( + self.line, + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.01, + ) + self.assertEqual(len(balanced_edges), 1) + + graph_balanced_edges = rustworkx.bipartition_graph_mst( + self.line, + lambda _: 1, + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.01, + ) + self.assertEqual(balanced_edges, graph_balanced_edges)