From 1cdf76a977811ff20bf623fa81a3cf0cff5d408d Mon Sep 17 00:00:00 2001 From: Max Fan Date: Tue, 11 Jan 2022 23:36:37 -0500 Subject: [PATCH 01/44] Draft bipartition_tree implementation --- src/lib.rs | 2 + src/tree.rs | 146 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 023b9d059..de703e67f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -412,6 +412,8 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(max_weight_matching))?; m.add_wrapped(wrap_pyfunction!(minimum_spanning_edges))?; m.add_wrapped(wrap_pyfunction!(minimum_spanning_tree))?; + m.add_wrapped(wrap_pyfunction!(balanced_cut_edge))?; + m.add_wrapped(wrap_pyfunction!(bipartition_tree))?; m.add_wrapped(wrap_pyfunction!(graph_transitivity))?; m.add_wrapped(wrap_pyfunction!(digraph_transitivity))?; m.add_wrapped(wrap_pyfunction!(graph_core_number))?; diff --git a/src/tree.rs b/src/tree.rs index 262cb53dc..4b23b6777 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -11,11 +11,13 @@ // under the License. use std::cmp::Ordering; +use std::collections::VecDeque; use super::{graph, weight_callable}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use pyo3::types::{PyFloat, PyList}; use pyo3::Python; use petgraph::prelude::*; @@ -135,3 +137,147 @@ pub fn minimum_spanning_tree( Ok(spanning_tree) } + +/// Find balanced cut edge of the minmum spanning tree of a graph using node +/// contraction. Assumes that the tree is connected and is a spanning tree. +/// +/// :param PyGraph graph: Undirected graph +/// :param weight_fn: A callable object (function, lambda, etc) which +/// will be passed the edge object and expected to return a ``float``. This +/// tells retworkx/rust how to extract a numerical weight as a ``float`` +/// for edge object. Some simple examples are:: +/// +/// minimum_spanning_tree(graph, weight_fn: lambda x: 1) +/// +/// to return a weight of 1 for all edges. 
Also:: +/// +/// minimum_spanning_tree(graph, weight_fn: float) +/// +/// to cast the edge object as a float as the weight. +/// :param float default_weight: If ``weight_fn`` isn't specified this optional +/// float value will be used for the weight/cost of each edge. +/// +/// :returns: A set of nodes in one half of the spanning tree +/// +#[pyfunction] +#[pyo3(text_signature = "(spanning_tree, pop, target_pop, epsilon)")] +pub fn balanced_cut_edge( + _py: Python, + spanning_tree: &graph::PyGraph, + py_pops: &PyList, + py_pop_target: &PyFloat, + py_epsilon: &PyFloat, +) -> PyResult)>> { + let epsilon = py_epsilon.value(); + let pop_target = py_pop_target.value(); + + let mut pops: Vec = vec![]; // not sure if the conversions are needed + for i in 0..py_pops.len() { + pops.push(py_pops.get_item(i).unwrap().extract::().unwrap()); + } + + let mut node_queue: VecDeque = VecDeque::::new(); + for leaf_node in spanning_tree.graph.node_indices() { + // todo: filter expr + if spanning_tree.graph.neighbors(leaf_node).count() == 1 { + node_queue.push_back(leaf_node); + } + } + + // eprintln!("leaf nodes: {}", node_queue.len()); + + // this process can be multithreaded, if the locking overhead isn't too high + // (note: locking may not even be needed given the invariants this is assumed to maintain) + let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; + let mut same_partition_tracker: Vec> = + vec![vec![]; spanning_tree.graph.node_count()]; // keeps track of all all the nodes on the same side of the partition + let mut seen_nodes: Vec = + vec![false; spanning_tree.graph.node_count()]; // todo: perf test this + while node_queue.len() > 0 { + let node = node_queue.pop_front().unwrap(); + let pop = pops[node.index()]; + + // todo: factor out expensive clones + // Mark as seen; push to queue if only one unseen neighbor + let unseen_neighbors: Vec = spanning_tree + .graph + .neighbors(node) + .filter(|node| !seen_nodes[node.index()]) + .collect(); + if unseen_neighbors.len() == 1 
{ + // this may be false if root + let neighbor = unseen_neighbors[0]; + pops[neighbor.index()] += pop.clone(); + same_partition_tracker[node.index()].push(node.index()); + same_partition_tracker[neighbor.index()] = + same_partition_tracker[node.index()].clone(); + // eprintln!("node pushed to queue (pop = {}): {}", pops[neighbor.index()], neighbor.index()); + + if !node_queue.contains(&neighbor) { + node_queue.push_back(neighbor); + } + } + pops[node.index()] = 0.0; + + // Check if balanced + if pop >= pop_target * (1.0 - epsilon) + && pop <= pop_target * (1.0 + epsilon) + { + // slightly different + // eprintln!("balanced node found: {}", node.index()); + balanced_nodes.push(( + node.index(), + same_partition_tracker[node.index()].clone(), + )); + } + + seen_nodes[node.index()] = true; + } + + Ok(balanced_nodes) +} + +/// Bipartition graph into two contiguous, population-balanced components. +/// Assumes that graph is contiguous. +/// +/// :param PyGraph graph: Undirected graph +/// :param weight_fn: A callable object (function, lambda, etc) which +/// will be passed the edge object and expected to return a ``float``. This +/// tells retworkx/rust how to extract a numerical weight as a ``float`` +/// for edge object. Some simple examples are:: +/// +/// minimum_spanning_tree(graph, weight_fn: lambda x: 1) +/// +/// to return a weight of 1 for all edges. Also:: +/// +/// minimum_spanning_tree(graph, weight_fn: float) +/// +/// to cast the edge object as a float as the weight. +/// :param float default_weight: If ``weight_fn`` isn't specified this optional +/// float value will be used for the weight/cost of each edge. 
+/// +/// :returns: A set of nodes in one half of the spanning tree +/// +#[pyfunction] +#[pyo3(text_signature = "(graph, weight_fn, pop, target_pop, epsilon)")] +pub fn bipartition_tree( + py: Python, + graph: &graph::PyGraph, + weight_fn: PyObject, + py_pops: &PyList, + py_pop_target: &PyFloat, + py_epsilon: &PyFloat, +) -> PyResult)>> { + let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; + + while balanced_nodes.len() == 0 { + let mst = + minimum_spanning_tree(py, graph, Some(weight_fn.clone()), 1.0) + .unwrap(); + balanced_nodes = + balanced_cut_edge(py, &mst, py_pops, py_pop_target, py_epsilon) + .unwrap(); + } + + Ok(balanced_nodes) +} From 20281b0260aa0066f5387d910692cb508fa7c72a Mon Sep 17 00:00:00 2001 From: Max Fan Date: Wed, 12 Jan 2022 16:14:52 -0500 Subject: [PATCH 02/44] Working bipartition tree impl --- src/tree.rs | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 4b23b6777..05c7c10e6 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -20,6 +20,7 @@ use pyo3::prelude::*; use pyo3::types::{PyFloat, PyList}; use pyo3::Python; +use petgraph::algo::{connected_components, is_cyclic_undirected}; use petgraph::prelude::*; use petgraph::stable_graph::EdgeReference; use petgraph::unionfind::UnionFind; @@ -176,12 +177,15 @@ pub fn balanced_cut_edge( pops.push(py_pops.get_item(i).unwrap().extract::().unwrap()); } + let mut same_partition_tracker: Vec> = + vec![vec![]; spanning_tree.graph.node_count()]; // keeps track of all all the nodes on the same side of the partition let mut node_queue: VecDeque = VecDeque::::new(); for leaf_node in spanning_tree.graph.node_indices() { // todo: filter expr if spanning_tree.graph.neighbors(leaf_node).count() == 1 { node_queue.push_back(leaf_node); } + same_partition_tracker[leaf_node.index()].push(leaf_node.index()); } // eprintln!("leaf nodes: {}", node_queue.len()); @@ -189,12 +193,15 @@ pub fn balanced_cut_edge( // this process can be 
multithreaded, if the locking overhead isn't too high // (note: locking may not even be needed given the invariants this is assumed to maintain) let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; - let mut same_partition_tracker: Vec> = - vec![vec![]; spanning_tree.graph.node_count()]; // keeps track of all all the nodes on the same side of the partition let mut seen_nodes: Vec = vec![false; spanning_tree.graph.node_count()]; // todo: perf test this while node_queue.len() > 0 { let node = node_queue.pop_front().unwrap(); + if seen_nodes[node.index()] { + // should not need this + // eprintln!("Invalid state! Double vision . . ."); + continue; + } let pop = pops[node.index()]; // todo: factor out expensive clones @@ -204,20 +211,26 @@ pub fn balanced_cut_edge( .neighbors(node) .filter(|node| !seen_nodes[node.index()]) .collect(); + // eprintln!("unseen_neighbors: {}", unseen_neighbors.len()); if unseen_neighbors.len() == 1 { - // this may be false if root + // this will be false if root let neighbor = unseen_neighbors[0]; pops[neighbor.index()] += pop.clone(); - same_partition_tracker[node.index()].push(node.index()); - same_partition_tracker[neighbor.index()] = + let mut current_partition_tracker = same_partition_tracker[node.index()].clone(); - // eprintln!("node pushed to queue (pop = {}): {}", pops[neighbor.index()], neighbor.index()); + same_partition_tracker[neighbor.index()] + .append(&mut current_partition_tracker); + // eprintln!("node pushed to queue (pop = {}, target = {}): {}", pops[neighbor.index()], pop_target, neighbor.index()); if !node_queue.contains(&neighbor) { node_queue.push_back(neighbor); } + } else if unseen_neighbors.len() == 0 { + break; + } else { + continue; } - pops[node.index()] = 0.0; + // pops[node.index()] = 0.0; // not needed? 
// Check if balanced if pop >= pop_target * (1.0 - epsilon) @@ -274,6 +287,8 @@ pub fn bipartition_tree( let mst = minimum_spanning_tree(py, graph, Some(weight_fn.clone()), 1.0) .unwrap(); + // assert_eq!(is_cyclic_undirected(&mst.graph), false); + // assert_eq!(connected_components(&mst.graph), 1); balanced_nodes = balanced_cut_edge(py, &mst, py_pops, py_pop_target, py_epsilon) .unwrap(); From f85eb52632111dad04876708187609f771e5ce16 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Thu, 13 Jan 2022 17:24:43 -0500 Subject: [PATCH 03/44] Release GIL during most of balanced_edge finding code --- src/tree.rs | 129 +++++++++++++++++++++++++++------------------------- 1 file changed, 67 insertions(+), 62 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 05c7c10e6..fd6aa6b1d 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -163,7 +163,7 @@ pub fn minimum_spanning_tree( #[pyfunction] #[pyo3(text_signature = "(spanning_tree, pop, target_pop, epsilon)")] pub fn balanced_cut_edge( - _py: Python, + py: Python, spanning_tree: &graph::PyGraph, py_pops: &PyList, py_pop_target: &PyFloat, @@ -177,75 +177,80 @@ pub fn balanced_cut_edge( pops.push(py_pops.get_item(i).unwrap().extract::().unwrap()); } - let mut same_partition_tracker: Vec> = - vec![vec![]; spanning_tree.graph.node_count()]; // keeps track of all all the nodes on the same side of the partition - let mut node_queue: VecDeque = VecDeque::::new(); - for leaf_node in spanning_tree.graph.node_indices() { - // todo: filter expr - if spanning_tree.graph.neighbors(leaf_node).count() == 1 { - node_queue.push_back(leaf_node); + let spanning_tree_graph = &spanning_tree.graph; + + let balanced_nodes = py.allow_threads( move || { + let mut same_partition_tracker: Vec> = + vec![vec![]; spanning_tree_graph.node_count()]; // keeps track of all all the nodes on the same side of the partition + let mut node_queue: VecDeque = VecDeque::::new(); + for leaf_node in spanning_tree_graph.node_indices() { + // todo: filter expr + if 
spanning_tree_graph.neighbors(leaf_node).count() == 1 { + node_queue.push_back(leaf_node); + } + same_partition_tracker[leaf_node.index()].push(leaf_node.index()); } - same_partition_tracker[leaf_node.index()].push(leaf_node.index()); - } - // eprintln!("leaf nodes: {}", node_queue.len()); + // eprintln!("leaf nodes: {}", node_queue.len()); - // this process can be multithreaded, if the locking overhead isn't too high - // (note: locking may not even be needed given the invariants this is assumed to maintain) - let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; - let mut seen_nodes: Vec = - vec![false; spanning_tree.graph.node_count()]; // todo: perf test this - while node_queue.len() > 0 { - let node = node_queue.pop_front().unwrap(); - if seen_nodes[node.index()] { - // should not need this - // eprintln!("Invalid state! Double vision . . ."); - continue; - } - let pop = pops[node.index()]; + // this process can be multithreaded, if the locking overhead isn't too high + // (note: locking may not even be needed given the invariants this is assumed to maintain) + let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; + let mut seen_nodes: Vec = + vec![false; spanning_tree_graph.node_count()]; // todo: perf test this + while node_queue.len() > 0 { + let node = node_queue.pop_front().unwrap(); + if seen_nodes[node.index()] { + // should not need this + // eprintln!("Invalid state! Double vision . . 
."); + continue; + } + let pop = pops[node.index()]; - // todo: factor out expensive clones - // Mark as seen; push to queue if only one unseen neighbor - let unseen_neighbors: Vec = spanning_tree - .graph - .neighbors(node) - .filter(|node| !seen_nodes[node.index()]) - .collect(); - // eprintln!("unseen_neighbors: {}", unseen_neighbors.len()); - if unseen_neighbors.len() == 1 { - // this will be false if root - let neighbor = unseen_neighbors[0]; - pops[neighbor.index()] += pop.clone(); - let mut current_partition_tracker = - same_partition_tracker[node.index()].clone(); - same_partition_tracker[neighbor.index()] - .append(&mut current_partition_tracker); - // eprintln!("node pushed to queue (pop = {}, target = {}): {}", pops[neighbor.index()], pop_target, neighbor.index()); + // todo: factor out expensive clones + // Mark as seen; push to queue if only one unseen neighbor + let unseen_neighbors: Vec = spanning_tree + .graph + .neighbors(node) + .filter(|node| !seen_nodes[node.index()]) + .collect(); + // eprintln!("unseen_neighbors: {}", unseen_neighbors.len()); + if unseen_neighbors.len() == 1 { + // this will be false if root + let neighbor = unseen_neighbors[0]; + pops[neighbor.index()] += pop.clone(); + let mut current_partition_tracker = + same_partition_tracker[node.index()].clone(); + same_partition_tracker[neighbor.index()] + .append(&mut current_partition_tracker); + // eprintln!("node pushed to queue (pop = {}, target = {}): {}", pops[neighbor.index()], pop_target, neighbor.index()); - if !node_queue.contains(&neighbor) { - node_queue.push_back(neighbor); + if !node_queue.contains(&neighbor) { + node_queue.push_back(neighbor); + } + } else if unseen_neighbors.len() == 0 { + break; + } else { + continue; } - } else if unseen_neighbors.len() == 0 { - break; - } else { - continue; - } - // pops[node.index()] = 0.0; // not needed? + // pops[node.index()] = 0.0; // not needed? 
- // Check if balanced - if pop >= pop_target * (1.0 - epsilon) - && pop <= pop_target * (1.0 + epsilon) - { - // slightly different - // eprintln!("balanced node found: {}", node.index()); - balanced_nodes.push(( - node.index(), - same_partition_tracker[node.index()].clone(), - )); - } + // Check if balanced + if pop >= pop_target * (1.0 - epsilon) + && pop <= pop_target * (1.0 + epsilon) + { + // slightly different + // eprintln!("balanced node found: {}", node.index()); + balanced_nodes.push(( + node.index(), + same_partition_tracker[node.index()].clone(), + )); + } - seen_nodes[node.index()] = true; - } + seen_nodes[node.index()] = true; + } + balanced_nodes + }); Ok(balanced_nodes) } From 79b3d9414814f02697ad0ea8af03a21da57858dd Mon Sep 17 00:00:00 2001 From: Max Fan Date: Fri, 14 Jan 2022 23:04:51 -0500 Subject: [PATCH 04/44] Ensure that unused vars get gc'ed on each loop --- src/tree.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/tree.rs b/src/tree.rs index fd6aa6b1d..db4a5c8ce 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -289,6 +289,11 @@ pub fn bipartition_tree( let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; while balanced_nodes.len() == 0 { + // Wee: https://pyo3.rs/v0.15.1/memory.html#gil-bound-memory + // (workaround to force objects to be gc'ed on each loop) + let pool = unsafe { py.new_pool() }; + let py = pool.python(); + let mst = minimum_spanning_tree(py, graph, Some(weight_fn.clone()), 1.0) .unwrap(); From d5d464b4a6690d0e66651daf86cc26ad156645d4 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Sat, 19 Mar 2022 00:01:41 -0400 Subject: [PATCH 05/44] Lint with cargo fmt --- src/tree.rs | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index db4a5c8ce..bc47869ec 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -178,8 +178,8 @@ pub fn balanced_cut_edge( } let spanning_tree_graph = &spanning_tree.graph; - - let balanced_nodes = py.allow_threads( move || { + + let 
balanced_nodes = py.allow_threads(move || { let mut same_partition_tracker: Vec> = vec![vec![]; spanning_tree_graph.node_count()]; // keeps track of all all the nodes on the same side of the partition let mut node_queue: VecDeque = VecDeque::::new(); @@ -196,8 +196,7 @@ pub fn balanced_cut_edge( // this process can be multithreaded, if the locking overhead isn't too high // (note: locking may not even be needed given the invariants this is assumed to maintain) let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; - let mut seen_nodes: Vec = - vec![false; spanning_tree_graph.node_count()]; // todo: perf test this + let mut seen_nodes: Vec = vec![false; spanning_tree_graph.node_count()]; // todo: perf test this while node_queue.len() > 0 { let node = node_queue.pop_front().unwrap(); if seen_nodes[node.index()] { @@ -219,10 +218,8 @@ pub fn balanced_cut_edge( // this will be false if root let neighbor = unseen_neighbors[0]; pops[neighbor.index()] += pop.clone(); - let mut current_partition_tracker = - same_partition_tracker[node.index()].clone(); - same_partition_tracker[neighbor.index()] - .append(&mut current_partition_tracker); + let mut current_partition_tracker = same_partition_tracker[node.index()].clone(); + same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker); // eprintln!("node pushed to queue (pop = {}, target = {}): {}", pops[neighbor.index()], pop_target, neighbor.index()); if !node_queue.contains(&neighbor) { @@ -236,15 +233,10 @@ pub fn balanced_cut_edge( // pops[node.index()] = 0.0; // not needed? 
// Check if balanced - if pop >= pop_target * (1.0 - epsilon) - && pop <= pop_target * (1.0 + epsilon) - { + if pop >= pop_target * (1.0 - epsilon) && pop <= pop_target * (1.0 + epsilon) { // slightly different // eprintln!("balanced node found: {}", node.index()); - balanced_nodes.push(( - node.index(), - same_partition_tracker[node.index()].clone(), - )); + balanced_nodes.push((node.index(), same_partition_tracker[node.index()].clone())); } seen_nodes[node.index()] = true; @@ -294,14 +286,10 @@ pub fn bipartition_tree( let pool = unsafe { py.new_pool() }; let py = pool.python(); - let mst = - minimum_spanning_tree(py, graph, Some(weight_fn.clone()), 1.0) - .unwrap(); + let mst = minimum_spanning_tree(py, graph, Some(weight_fn.clone()), 1.0).unwrap(); // assert_eq!(is_cyclic_undirected(&mst.graph), false); // assert_eq!(connected_components(&mst.graph), 1); - balanced_nodes = - balanced_cut_edge(py, &mst, py_pops, py_pop_target, py_epsilon) - .unwrap(); + balanced_nodes = balanced_cut_edge(py, &mst, py_pops, py_pop_target, py_epsilon).unwrap(); } Ok(balanced_nodes) From e7409a5c942d82589a883fed63a076b6cc0349ce Mon Sep 17 00:00:00 2001 From: Max Fan Date: Mon, 16 May 2022 11:59:41 -0400 Subject: [PATCH 06/44] Update to using built-in retworkx macros Co-authored-by: Matthew Treinish --- src/tree.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index bc47869ec..69d31f463 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -274,9 +274,9 @@ pub fn bipartition_tree( py: Python, graph: &graph::PyGraph, weight_fn: PyObject, - py_pops: &PyList, - py_pop_target: &PyFloat, - py_epsilon: &PyFloat, + py_pops: Vec, + py_pop_target: f64, + py_epsilon: f64, ) -> PyResult)>> { let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; From a45c147ee77f0b60a6de947798c11cdb131b400a Mon Sep 17 00:00:00 2001 From: Max Fan Date: Mon, 16 May 2022 12:38:47 -0400 Subject: [PATCH 07/44] Take advantage of built-in retworkx macros --- src/tree.rs | 137 
++++++++++++++++++++++++---------------------------- 1 file changed, 63 insertions(+), 74 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 69d31f463..8b617e0b3 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -17,15 +17,15 @@ use super::{graph, weight_callable}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use pyo3::types::{PyFloat, PyList}; use pyo3::Python; -use petgraph::algo::{connected_components, is_cyclic_undirected}; use petgraph::prelude::*; use petgraph::stable_graph::EdgeReference; use petgraph::unionfind::UnionFind; use petgraph::visit::{IntoEdgeReferences, NodeIndexable}; +use numpy::PyReadonlyArray1; + use rayon::prelude::*; use crate::iterators::WeightedEdgeList; @@ -165,84 +165,73 @@ pub fn minimum_spanning_tree( pub fn balanced_cut_edge( py: Python, spanning_tree: &graph::PyGraph, - py_pops: &PyList, - py_pop_target: &PyFloat, - py_epsilon: &PyFloat, + pops: Vec, + pop_target: f64, + epsilon: f64, ) -> PyResult)>> { - let epsilon = py_epsilon.value(); - let pop_target = py_pop_target.value(); - - let mut pops: Vec = vec![]; // not sure if the conversions are needed - for i in 0..py_pops.len() { - pops.push(py_pops.get_item(i).unwrap().extract::().unwrap()); - } - + let mut pops = pops.clone(); let spanning_tree_graph = &spanning_tree.graph; - - let balanced_nodes = py.allow_threads(move || { - let mut same_partition_tracker: Vec> = - vec![vec![]; spanning_tree_graph.node_count()]; // keeps track of all all the nodes on the same side of the partition - let mut node_queue: VecDeque = VecDeque::::new(); - for leaf_node in spanning_tree_graph.node_indices() { - // todo: filter expr - if spanning_tree_graph.neighbors(leaf_node).count() == 1 { - node_queue.push_back(leaf_node); - } - same_partition_tracker[leaf_node.index()].push(leaf_node.index()); + let mut same_partition_tracker: Vec> = + vec![vec![]; spanning_tree_graph.node_count()]; // keeps track of all all the nodes on the same side of the partition + let mut node_queue: 
VecDeque = VecDeque::::new(); + for leaf_node in spanning_tree_graph.node_indices() { + // todo: filter expr + if spanning_tree_graph.neighbors(leaf_node).count() == 1 { + node_queue.push_back(leaf_node); } + same_partition_tracker[leaf_node.index()].push(leaf_node.index()); + } - // eprintln!("leaf nodes: {}", node_queue.len()); - - // this process can be multithreaded, if the locking overhead isn't too high - // (note: locking may not even be needed given the invariants this is assumed to maintain) - let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; - let mut seen_nodes: Vec = vec![false; spanning_tree_graph.node_count()]; // todo: perf test this - while node_queue.len() > 0 { - let node = node_queue.pop_front().unwrap(); - if seen_nodes[node.index()] { - // should not need this - // eprintln!("Invalid state! Double vision . . ."); - continue; - } - let pop = pops[node.index()]; + // eprintln!("leaf nodes: {}", node_queue.len()); - // todo: factor out expensive clones - // Mark as seen; push to queue if only one unseen neighbor - let unseen_neighbors: Vec = spanning_tree - .graph - .neighbors(node) - .filter(|node| !seen_nodes[node.index()]) - .collect(); - // eprintln!("unseen_neighbors: {}", unseen_neighbors.len()); - if unseen_neighbors.len() == 1 { - // this will be false if root - let neighbor = unseen_neighbors[0]; - pops[neighbor.index()] += pop.clone(); - let mut current_partition_tracker = same_partition_tracker[node.index()].clone(); - same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker); - // eprintln!("node pushed to queue (pop = {}, target = {}): {}", pops[neighbor.index()], pop_target, neighbor.index()); + // this process can be multithreaded, if the locking overhead isn't too high + // (note: locking may not even be needed given the invariants this is assumed to maintain) + let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; + let mut seen_nodes: Vec = vec![false; spanning_tree_graph.node_count()]; // todo: perf test 
this + while node_queue.len() > 0 { + let node = node_queue.pop_front().unwrap(); + if seen_nodes[node.index()] { + // should not need this + // eprintln!("Invalid state! Double vision . . ."); + continue; + } + let pop = pops[node.index()]; - if !node_queue.contains(&neighbor) { - node_queue.push_back(neighbor); - } - } else if unseen_neighbors.len() == 0 { - break; - } else { - continue; - } - // pops[node.index()] = 0.0; // not needed? + // todo: factor out expensive clones + // Mark as seen; push to queue if only one unseen neighbor + let unseen_neighbors: Vec = spanning_tree + .graph + .neighbors(node) + .filter(|node| !seen_nodes[node.index()]) + .collect(); + // eprintln!("unseen_neighbors: {}", unseen_neighbors.len()); + if unseen_neighbors.len() == 1 { + // this will be false if root + let neighbor = unseen_neighbors[0]; + pops[neighbor.index()] += pop.clone(); + let mut current_partition_tracker = same_partition_tracker[node.index()].clone(); + same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker); + // eprintln!("node pushed to queue (pop = {}, target = {}): {}", pops[neighbor.index()], pop_target, neighbor.index()); - // Check if balanced - if pop >= pop_target * (1.0 - epsilon) && pop <= pop_target * (1.0 + epsilon) { - // slightly different - // eprintln!("balanced node found: {}", node.index()); - balanced_nodes.push((node.index(), same_partition_tracker[node.index()].clone())); + if !node_queue.contains(&neighbor) { + node_queue.push_back(neighbor); } + } else if unseen_neighbors.len() == 0 { + break; + } else { + continue; + } + // pops[node.index()] = 0.0; // not needed? 
- seen_nodes[node.index()] = true; + // Check if balanced + if pop >= pop_target * (1.0 - epsilon) && pop <= pop_target * (1.0 + epsilon) { + // slightly different + // eprintln!("balanced node found: {}", node.index()); + balanced_nodes.push((node.index(), same_partition_tracker[node.index()].clone())); } - balanced_nodes - }); + + seen_nodes[node.index()] = true; + } Ok(balanced_nodes) } @@ -274,9 +263,9 @@ pub fn bipartition_tree( py: Python, graph: &graph::PyGraph, weight_fn: PyObject, - py_pops: Vec, - py_pop_target: f64, - py_epsilon: f64, + pops: Vec, + pop_target: f64, + epsilon: f64, ) -> PyResult)>> { let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; @@ -289,7 +278,7 @@ pub fn bipartition_tree( let mst = minimum_spanning_tree(py, graph, Some(weight_fn.clone()), 1.0).unwrap(); // assert_eq!(is_cyclic_undirected(&mst.graph), false); // assert_eq!(connected_components(&mst.graph), 1); - balanced_nodes = balanced_cut_edge(py, &mst, py_pops, py_pop_target, py_epsilon).unwrap(); + balanced_nodes = balanced_cut_edge(py, &mst, pops.clone(), pop_target, epsilon).unwrap(); } Ok(balanced_nodes) From 4d93f0ec07a8d9d17f017d2c3638a72cf30d7af1 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Tue, 17 May 2022 16:23:53 -0400 Subject: [PATCH 08/44] Use mem::take to save on memory allocs --- src/tree.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 8b617e0b3..08d12deb2 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -9,6 +9,8 @@ // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the // License for the specific language governing permissions and limitations // under the License. 
+// +use std::mem; use std::cmp::Ordering; use std::collections::VecDeque; @@ -209,8 +211,10 @@ pub fn balanced_cut_edge( // this will be false if root let neighbor = unseen_neighbors[0]; pops[neighbor.index()] += pop.clone(); - let mut current_partition_tracker = same_partition_tracker[node.index()].clone(); + // let mut current_partition_tracker = same_partition_tracker[node.index()].clone(); + let mut current_partition_tracker = mem::take(&mut same_partition_tracker[node.index()]); same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker); + // same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker); // eprintln!("node pushed to queue (pop = {}, target = {}): {}", pops[neighbor.index()], pop_target, neighbor.index()); if !node_queue.contains(&neighbor) { @@ -270,7 +274,7 @@ pub fn bipartition_tree( let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; while balanced_nodes.len() == 0 { - // Wee: https://pyo3.rs/v0.15.1/memory.html#gil-bound-memory + // See: https://pyo3.rs/v0.15.1/memory.html#gil-bound-memory // (workaround to force objects to be gc'ed on each loop) let pool = unsafe { py.new_pool() }; let py = pool.python(); From f01201c6e534f7648101cb3ec600c063a1d077c8 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Tue, 17 May 2022 17:07:57 -0400 Subject: [PATCH 09/44] Revert "Use mem::take to save on memory allocs" This reverts commit 46d0950b9f8d2222cc8ec44557f94ab22d62d164. --- src/tree.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 08d12deb2..8b617e0b3 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -9,8 +9,6 @@ // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the // License for the specific language governing permissions and limitations // under the License. 
-// -use std::mem; use std::cmp::Ordering; use std::collections::VecDeque; @@ -211,10 +209,8 @@ pub fn balanced_cut_edge( // this will be false if root let neighbor = unseen_neighbors[0]; pops[neighbor.index()] += pop.clone(); - // let mut current_partition_tracker = same_partition_tracker[node.index()].clone(); - let mut current_partition_tracker = mem::take(&mut same_partition_tracker[node.index()]); + let mut current_partition_tracker = same_partition_tracker[node.index()].clone(); same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker); - // same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker); // eprintln!("node pushed to queue (pop = {}, target = {}): {}", pops[neighbor.index()], pop_target, neighbor.index()); if !node_queue.contains(&neighbor) { @@ -274,7 +270,7 @@ pub fn bipartition_tree( let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; while balanced_nodes.len() == 0 { - // See: https://pyo3.rs/v0.15.1/memory.html#gil-bound-memory + // Wee: https://pyo3.rs/v0.15.1/memory.html#gil-bound-memory // (workaround to force objects to be gc'ed on each loop) let pool = unsafe { py.new_pool() }; let py = pool.python(); From fdbeaafa42ba6f5032e227a58f08d499fc96f1a8 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Tue, 17 May 2022 17:17:04 -0400 Subject: [PATCH 10/44] Fix unnecessary clones; address clippy warnings --- src/tree.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 8b617e0b3..7a1bb3c7e 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -24,8 +24,6 @@ use petgraph::stable_graph::EdgeReference; use petgraph::unionfind::UnionFind; use petgraph::visit::{IntoEdgeReferences, NodeIndexable}; -use numpy::PyReadonlyArray1; - use rayon::prelude::*; use crate::iterators::WeightedEdgeList; @@ -163,13 +161,13 @@ pub fn minimum_spanning_tree( #[pyfunction] #[pyo3(text_signature = "(spanning_tree, pop, target_pop, epsilon)")] pub fn balanced_cut_edge( - py: 
Python, + _py: Python, spanning_tree: &graph::PyGraph, pops: Vec, pop_target: f64, epsilon: f64, ) -> PyResult)>> { - let mut pops = pops.clone(); + let mut pops = pops; let spanning_tree_graph = &spanning_tree.graph; let mut same_partition_tracker: Vec> = vec![vec![]; spanning_tree_graph.node_count()]; // keeps track of all all the nodes on the same side of the partition @@ -188,7 +186,7 @@ pub fn balanced_cut_edge( // (note: locking may not even be needed given the invariants this is assumed to maintain) let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; let mut seen_nodes: Vec = vec![false; spanning_tree_graph.node_count()]; // todo: perf test this - while node_queue.len() > 0 { + while !node_queue.is_empty() { let node = node_queue.pop_front().unwrap(); if seen_nodes[node.index()] { // should not need this @@ -208,7 +206,7 @@ pub fn balanced_cut_edge( if unseen_neighbors.len() == 1 { // this will be false if root let neighbor = unseen_neighbors[0]; - pops[neighbor.index()] += pop.clone(); + pops[neighbor.index()] += pop; let mut current_partition_tracker = same_partition_tracker[node.index()].clone(); same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker); // eprintln!("node pushed to queue (pop = {}, target = {}): {}", pops[neighbor.index()], pop_target, neighbor.index()); @@ -216,7 +214,7 @@ pub fn balanced_cut_edge( if !node_queue.contains(&neighbor) { node_queue.push_back(neighbor); } - } else if unseen_neighbors.len() == 0 { + } else if unseen_neighbors.is_empty() { break; } else { continue; @@ -269,7 +267,7 @@ pub fn bipartition_tree( ) -> PyResult)>> { let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; - while balanced_nodes.len() == 0 { + while balanced_nodes.is_empty() { // Wee: https://pyo3.rs/v0.15.1/memory.html#gil-bound-memory // (workaround to force objects to be gc'ed on each loop) let pool = unsafe { py.new_pool() }; From 517d697d14a785e13e61d4421b69091abcde7b2d Mon Sep 17 00:00:00 2001 From: Max Fan Date: Tue, 17 May 
2022 17:37:22 -0400 Subject: [PATCH 11/44] Remove unnecessary dead/commented out code --- src/tree.rs | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 7a1bb3c7e..c7311058b 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -180,19 +180,11 @@ pub fn balanced_cut_edge( same_partition_tracker[leaf_node.index()].push(leaf_node.index()); } - // eprintln!("leaf nodes: {}", node_queue.len()); - - // this process can be multithreaded, if the locking overhead isn't too high - // (note: locking may not even be needed given the invariants this is assumed to maintain) + // BFS search for balanced nodes let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; let mut seen_nodes: Vec = vec![false; spanning_tree_graph.node_count()]; // todo: perf test this while !node_queue.is_empty() { let node = node_queue.pop_front().unwrap(); - if seen_nodes[node.index()] { - // should not need this - // eprintln!("Invalid state! Double vision . . ."); - continue; - } let pop = pops[node.index()]; // todo: factor out expensive clones @@ -202,14 +194,13 @@ pub fn balanced_cut_edge( .neighbors(node) .filter(|node| !seen_nodes[node.index()]) .collect(); - // eprintln!("unseen_neighbors: {}", unseen_neighbors.len()); + if unseen_neighbors.len() == 1 { // this will be false if root let neighbor = unseen_neighbors[0]; pops[neighbor.index()] += pop; let mut current_partition_tracker = same_partition_tracker[node.index()].clone(); same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker); - // eprintln!("node pushed to queue (pop = {}, target = {}): {}", pops[neighbor.index()], pop_target, neighbor.index()); if !node_queue.contains(&neighbor) { node_queue.push_back(neighbor); @@ -219,12 +210,10 @@ pub fn balanced_cut_edge( } else { continue; } - // pops[node.index()] = 0.0; // not needed? 
// Check if balanced if pop >= pop_target * (1.0 - epsilon) && pop <= pop_target * (1.0 + epsilon) { // slightly different - // eprintln!("balanced node found: {}", node.index()); balanced_nodes.push((node.index(), same_partition_tracker[node.index()].clone())); } @@ -268,14 +257,12 @@ pub fn bipartition_tree( let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; while balanced_nodes.is_empty() { - // Wee: https://pyo3.rs/v0.15.1/memory.html#gil-bound-memory + // See: https://pyo3.rs/v0.15.1/memory.html#gil-bound-memory // (workaround to force objects to be gc'ed on each loop) let pool = unsafe { py.new_pool() }; let py = pool.python(); let mst = minimum_spanning_tree(py, graph, Some(weight_fn.clone()), 1.0).unwrap(); - // assert_eq!(is_cyclic_undirected(&mst.graph), false); - // assert_eq!(connected_components(&mst.graph), 1); balanced_nodes = balanced_cut_edge(py, &mst, pops.clone(), pop_target, epsilon).unwrap(); } From 2bd4c33a5955ce29d7550d90a04c184decc93428 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Tue, 17 May 2022 21:30:23 -0400 Subject: [PATCH 12/44] Reduce number of memory allocs --- src/tree.rs | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index c7311058b..d8c5ff1f9 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -10,6 +10,8 @@ // License for the specific language governing permissions and limitations // under the License. 
+use std::mem; + use std::cmp::Ordering; use std::collections::VecDeque; @@ -127,6 +129,18 @@ pub fn minimum_spanning_tree( let mut spanning_tree = (*graph).clone(); spanning_tree.graph.clear_edges(); + _minimum_spanning_tree(py, graph, &mut spanning_tree, weight_fn, default_weight)?; + Ok(spanning_tree) +} + +/// Helper function to allow reuse of spanning_tree object to reduce memory allocs +fn _minimum_spanning_tree<'a>( + py: Python, + graph: &graph::PyGraph, + spanning_tree: &'a mut graph::PyGraph, + weight_fn: Option, + default_weight: f64, +) -> PyResult<&'a mut graph::PyGraph> { for edge in minimum_spanning_edges(py, graph, weight_fn, default_weight)? .edges .iter() @@ -255,14 +269,11 @@ pub fn bipartition_tree( epsilon: f64, ) -> PyResult)>> { let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; + let mut mst = (*graph).clone(); while balanced_nodes.is_empty() { - // See: https://pyo3.rs/v0.15.1/memory.html#gil-bound-memory - // (workaround to force objects to be gc'ed on each loop) - let pool = unsafe { py.new_pool() }; - let py = pool.python(); - - let mst = minimum_spanning_tree(py, graph, Some(weight_fn.clone()), 1.0).unwrap(); + mst.graph.clear_edges(); + _minimum_spanning_tree(py, graph, &mut mst, Some(weight_fn.clone()), 1.0)?; balanced_nodes = balanced_cut_edge(py, &mst, pops.clone(), pop_target, epsilon).unwrap(); } From 658ffa5f7dbdd95e43629a418e212a538c052a00 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Tue, 17 May 2022 21:34:58 -0400 Subject: [PATCH 13/44] Remove unused import --- src/tree.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index d8c5ff1f9..1c66e076a 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -10,8 +10,6 @@ // License for the specific language governing permissions and limitations // under the License. 
-use std::mem; - use std::cmp::Ordering; use std::collections::VecDeque; From a3606951fe3f8c8a574503407aa56e153a0222c2 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Tue, 17 May 2022 21:37:13 -0400 Subject: [PATCH 14/44] Remove no longer relevant todos --- src/tree.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 1c66e076a..8dd6ed833 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -183,9 +183,9 @@ pub fn balanced_cut_edge( let spanning_tree_graph = &spanning_tree.graph; let mut same_partition_tracker: Vec> = vec![vec![]; spanning_tree_graph.node_count()]; // keeps track of all all the nodes on the same side of the partition + let mut node_queue: VecDeque = VecDeque::::new(); for leaf_node in spanning_tree_graph.node_indices() { - // todo: filter expr if spanning_tree_graph.neighbors(leaf_node).count() == 1 { node_queue.push_back(leaf_node); } @@ -194,12 +194,11 @@ pub fn balanced_cut_edge( // BFS search for balanced nodes let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; - let mut seen_nodes: Vec = vec![false; spanning_tree_graph.node_count()]; // todo: perf test this + let mut seen_nodes: Vec = vec![false; spanning_tree_graph.node_count()]; while !node_queue.is_empty() { let node = node_queue.pop_front().unwrap(); let pop = pops[node.index()]; - // todo: factor out expensive clones // Mark as seen; push to queue if only one unseen neighbor let unseen_neighbors: Vec = spanning_tree .graph From c5d76cb1644cfc6008676e9c9cd277613b0673a4 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Tue, 17 May 2022 21:45:42 -0400 Subject: [PATCH 15/44] Switch to using HashSet for seen_nodes tracker --- src/tree.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 8dd6ed833..0bb6d6f63 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -10,6 +10,7 @@ // License for the specific language governing permissions and limitations // under the License. 
+use std::collections::HashSet; use std::cmp::Ordering; use std::collections::VecDeque; @@ -194,7 +195,7 @@ pub fn balanced_cut_edge( // BFS search for balanced nodes let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; - let mut seen_nodes: Vec = vec![false; spanning_tree_graph.node_count()]; + let mut seen_nodes = HashSet::with_capacity(spanning_tree_graph.node_count()); while !node_queue.is_empty() { let node = node_queue.pop_front().unwrap(); let pop = pops[node.index()]; @@ -203,7 +204,7 @@ pub fn balanced_cut_edge( let unseen_neighbors: Vec = spanning_tree .graph .neighbors(node) - .filter(|node| !seen_nodes[node.index()]) + .filter(|node| !seen_nodes.contains(&node.index())) .collect(); if unseen_neighbors.len() == 1 { @@ -228,7 +229,7 @@ pub fn balanced_cut_edge( balanced_nodes.push((node.index(), same_partition_tracker[node.index()].clone())); } - seen_nodes[node.index()] = true; + seen_nodes.insert(node.index()); } Ok(balanced_nodes) From 6643195ee39d181249efce45d2bdc4f206d49ed9 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Tue, 17 May 2022 21:48:58 -0400 Subject: [PATCH 16/44] Touch up comments to make more sense --- src/tree.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 0bb6d6f63..6bbf1c392 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -183,7 +183,7 @@ pub fn balanced_cut_edge( let mut pops = pops; let spanning_tree_graph = &spanning_tree.graph; let mut same_partition_tracker: Vec> = - vec![vec![]; spanning_tree_graph.node_count()]; // keeps track of all all the nodes on the same side of the partition + vec![vec![]; spanning_tree_graph.node_count()]; // Keeps track of all all the nodes on the same side of the partition let mut node_queue: VecDeque = VecDeque::::new(); for leaf_node in spanning_tree_graph.node_indices() { @@ -208,7 +208,7 @@ pub fn balanced_cut_edge( .collect(); if unseen_neighbors.len() == 1 { - // this will be false if root + // This will be false if root let neighbor = 
unseen_neighbors[0]; pops[neighbor.index()] += pop; let mut current_partition_tracker = same_partition_tracker[node.index()].clone(); @@ -217,15 +217,14 @@ pub fn balanced_cut_edge( if !node_queue.contains(&neighbor) { node_queue.push_back(neighbor); } - } else if unseen_neighbors.is_empty() { + } else if unseen_neighbors.is_empty() { // root break; - } else { + } else { // Not at the leaves of the unseen subgraph continue; } // Check if balanced if pop >= pop_target * (1.0 - epsilon) && pop <= pop_target * (1.0 + epsilon) { - // slightly different balanced_nodes.push((node.index(), same_partition_tracker[node.index()].clone())); } From 90d5ba339a497313e4086ee481890b337558886a Mon Sep 17 00:00:00 2001 From: Max Fan Date: Tue, 17 May 2022 21:50:47 -0400 Subject: [PATCH 17/44] Lint with cargo fmt --- src/tree.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 6bbf1c392..d6c45e61f 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -10,8 +10,8 @@ // License for the specific language governing permissions and limitations // under the License. 
-use std::collections::HashSet; use std::cmp::Ordering; +use std::collections::HashSet; use std::collections::VecDeque; use super::{graph, weight_callable}; @@ -217,9 +217,11 @@ pub fn balanced_cut_edge( if !node_queue.contains(&neighbor) { node_queue.push_back(neighbor); } - } else if unseen_neighbors.is_empty() { // root + } else if unseen_neighbors.is_empty() { + // root break; - } else { // Not at the leaves of the unseen subgraph + } else { + // Not at the leaves of the unseen subgraph continue; } From b522daae3ae58ba6274cefa97183040e15f87b0c Mon Sep 17 00:00:00 2001 From: Max Fan Date: Mon, 23 May 2022 11:57:02 -0400 Subject: [PATCH 18/44] Rename functions to biparition_tree and bipartition_graph; update docs --- src/lib.rs | 2 +- src/tree.rs | 44 +++++++++++++++++++++----------------------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index de703e67f..f0e0dae8a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -412,8 +412,8 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(max_weight_matching))?; m.add_wrapped(wrap_pyfunction!(minimum_spanning_edges))?; m.add_wrapped(wrap_pyfunction!(minimum_spanning_tree))?; - m.add_wrapped(wrap_pyfunction!(balanced_cut_edge))?; m.add_wrapped(wrap_pyfunction!(bipartition_tree))?; + m.add_wrapped(wrap_pyfunction!(bipartition_graph))?; m.add_wrapped(wrap_pyfunction!(graph_transitivity))?; m.add_wrapped(wrap_pyfunction!(digraph_transitivity))?; m.add_wrapped(wrap_pyfunction!(graph_core_number))?; diff --git a/src/tree.rs b/src/tree.rs index d6c45e61f..7e9b643ef 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -150,30 +150,23 @@ fn _minimum_spanning_tree<'a>( Ok(spanning_tree) } -/// Find balanced cut edge of the minmum spanning tree of a graph using node -/// contraction. Assumes that the tree is connected and is a spanning tree. +/// Find balanced cut edge of a spanning tree using node contraction. 
+/// Assumes that the tree is connected and is a spanning tree. /// -/// :param PyGraph graph: Undirected graph -/// :param weight_fn: A callable object (function, lambda, etc) which -/// will be passed the edge object and expected to return a ``float``. This -/// tells retworkx/rust how to extract a numerical weight as a ``float`` -/// for edge object. Some simple examples are:: +/// :param PyGraph graph: Spanning tree. Must be fully connected +/// :param pops: The populations assigned to each node in the graph. +/// :param float pop_target: The population target to reach when partitioning the +/// graph. +/// :param float epsilon: The maximum percent deviation from the pop_target +/// allowed while still being a valid balanced cut edge. /// -/// minimum_spanning_tree(graph, weight_fn: lambda x: 1) -/// -/// to return a weight of 1 for all edges. Also:: -/// -/// minimum_spanning_tree(graph, weight_fn: float) -/// -/// to cast the edge object as a float as the weight. -/// :param float default_weight: If ``weight_fn`` isn't specified this optional -/// float value will be used for the weight/cost of each edge. -/// -/// :returns: A set of nodes in one half of the spanning tree +/// :returns: A list of tuples, with each tuple representing a distinct +/// balanced edge that can be cut. The tuple contains the root of one of the +/// two partitioned subtrees and the set of nodes making up that subtree. /// #[pyfunction] #[pyo3(text_signature = "(spanning_tree, pop, target_pop, epsilon)")] -pub fn balanced_cut_edge( +pub fn bipartition_tree( _py: Python, spanning_tree: &graph::PyGraph, pops: Vec, @@ -252,14 +245,19 @@ pub fn balanced_cut_edge( /// minimum_spanning_tree(graph, weight_fn: float) /// /// to cast the edge object as a float as the weight. -/// :param float default_weight: If ``weight_fn`` isn't specified this optional -/// float value will be used for the weight/cost of each edge. +/// :param pops: The populations assigned to each node in the graph. 
+/// :param float pop_target: The population target to reach when partitioning the +/// graph. +/// :param float epsilon: The maximum percent deviation from the pop_target +/// allowed while still being a valid balanced cut edge. /// -/// :returns: A set of nodes in one half of the spanning tree +/// :returns: A list of tuples, with each tuple representing a distinct +/// balanced edge that can be cut. The tuple contains the root of one of the +/// two partitioned subtrees and the set of nodes making up that subtree. /// #[pyfunction] #[pyo3(text_signature = "(graph, weight_fn, pop, target_pop, epsilon)")] -pub fn bipartition_tree( +pub fn bipartition_graph( py: Python, graph: &graph::PyGraph, weight_fn: PyObject, From 1e6290de870029798a6c6998d548ab5e094dfa04 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Mon, 23 May 2022 12:06:56 -0400 Subject: [PATCH 19/44] Make docstrings more descriptive and fix renaming issue --- src/tree.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 7e9b643ef..f652eb933 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -150,8 +150,13 @@ fn _minimum_spanning_tree<'a>( Ok(spanning_tree) } -/// Find balanced cut edge of a spanning tree using node contraction. -/// Assumes that the tree is connected and is a spanning tree. +/// Bipartition tree by finding balanced cut edges of a spanning tree using +/// node contraction. Assumes that the tree is connected and is a spanning tree. +/// A balanced edge is defined as an edge that, when cut, will split the +/// population of the tree into two connected subtrees that have population near +/// the population target within some epsilon. The function returns a list of +/// all such possible cuts, represented as the set of nodes in one +/// partition/subtree. /// /// :param PyGraph graph: Spanning tree. Must be fully connected /// :param pops: The populations assigned to each node in the graph. 
@@ -230,7 +235,8 @@ pub fn bipartition_tree( } /// Bipartition graph into two contiguous, population-balanced components. -/// Assumes that graph is contiguous. +/// Assumes that the graph is contiguous. See :func:`~bipartition_tree` for +/// details on how balance is defined. /// /// :param PyGraph graph: Undirected graph /// :param weight_fn: A callable object (function, lambda, etc) which @@ -271,7 +277,7 @@ pub fn bipartition_graph( while balanced_nodes.is_empty() { mst.graph.clear_edges(); _minimum_spanning_tree(py, graph, &mut mst, Some(weight_fn.clone()), 1.0)?; - balanced_nodes = balanced_cut_edge(py, &mst, pops.clone(), pop_target, epsilon).unwrap(); + balanced_nodes = bipartition_tree(py, &mst, pops.clone(), pop_target, epsilon).unwrap(); } Ok(balanced_nodes) From c28335543e495ccec9b9c183774ea85f7d733096 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Mon, 23 May 2022 12:08:27 -0400 Subject: [PATCH 20/44] Add release notes for bipartition_tree and bipartition_graph --- .../bipartition_graph-ccb2204bc7b6c407.yaml | 7 ++++++ .../bipartition_tree-4c1ad080b1fab9e8.yaml | 22 +++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 releasenotes/notes/bipartition_graph-ccb2204bc7b6c407.yaml create mode 100644 releasenotes/notes/bipartition_tree-4c1ad080b1fab9e8.yaml diff --git a/releasenotes/notes/bipartition_graph-ccb2204bc7b6c407.yaml b/releasenotes/notes/bipartition_graph-ccb2204bc7b6c407.yaml new file mode 100644 index 000000000..b7456dd27 --- /dev/null +++ b/releasenotes/notes/bipartition_graph-ccb2204bc7b6c407.yaml @@ -0,0 +1,7 @@ +--- +features: + - | + Added a new function :func:`~.bipartition_graph` that takes in a connected + graph and tries to draw a minimum spanning tree and find a balanced cut + edge to target using :func:`~.bipartition_tree`. If such a corresponding + tree and edge cannnot be found, then it retries. 
diff --git a/releasenotes/notes/bipartition_tree-4c1ad080b1fab9e8.yaml b/releasenotes/notes/bipartition_tree-4c1ad080b1fab9e8.yaml new file mode 100644 index 000000000..74953bec5 --- /dev/null +++ b/releasenotes/notes/bipartition_tree-4c1ad080b1fab9e8.yaml @@ -0,0 +1,22 @@ +--- +features: + - | + Added a new function :func:`~.bipartition_tree` that takes in spanning tree + and a list of populations assigned to each node in the tree and finds all + balanced edges, if they exist. A balanced edge is defined as an edge that, + when cut, will split the population of the tree into two connected subtrees + that have population near the population target within some epsilon. The + function returns a list of all such possible cuts, represented as the set + of nodes in one partition/subtree. For example:: + + balanced_node_choices = retworkx.bipartition_tree( + tree, + pops, + float(pop_target), + float(epsilon) + ) + + returns a list of tuples, with each tuple representing a distinct balanced + edge that can be cut. The tuple contains the root of one of the two + partitioned subtrees and the set of nodes making up that subtree. The other + partition can be recovered by computing the complement of the set of nodes. From cbaed096a3629529ed1abe12a7322c3646d6c995 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Mon, 23 May 2022 12:10:59 -0400 Subject: [PATCH 21/44] Add bipartition_graph and bipartition_tree to API section of docs --- docs/source/api.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/api.rst b/docs/source/api.rst index 310f1b174..8f3ee5766 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -103,6 +103,7 @@ Tree retworkx.minimum_spanning_edges retworkx.minimum_spanning_tree retworkx.steiner_tree + retworkx.bipartition_tree .. _isomorphism: @@ -178,6 +179,7 @@ Other Algorithm Functions retworkx.core_number retworkx.graph_greedy_color retworkx.metric_closure + retworkx.bipartition_graph .. 
_generator_funcs: From 41e503abfdacbc95f3de2eaf03731e262273f102 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Mon, 23 May 2022 12:21:22 -0400 Subject: [PATCH 22/44] Shorten description of weight_fn by pointing to minimum_spanning_tree func --- src/tree.rs | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index f652eb933..48e9bff8f 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -240,17 +240,8 @@ pub fn bipartition_tree( /// /// :param PyGraph graph: Undirected graph /// :param weight_fn: A callable object (function, lambda, etc) which -/// will be passed the edge object and expected to return a ``float``. This -/// tells retworkx/rust how to extract a numerical weight as a ``float`` -/// for edge object. Some simple examples are:: -/// -/// minimum_spanning_tree(graph, weight_fn: lambda x: 1) -/// -/// to return a weight of 1 for all edges. Also:: -/// -/// minimum_spanning_tree(graph, weight_fn: float) -/// -/// to cast the edge object as a float as the weight. +/// will be passed the edge object and expected to return a ``float``. See +/// :func:`~minimum_spanning_tree` for details. /// :param pops: The populations assigned to each node in the graph. /// :param float pop_target: The population target to reach when partitioning the /// graph. From da1e8a7e75db503bc6b83f2beb960c08e0e66cd5 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Tue, 24 May 2022 16:34:49 -0400 Subject: [PATCH 23/44] Fix end with blank line linting issue --- src/tree.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 48e9bff8f..d5435ece3 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -168,7 +168,6 @@ fn _minimum_spanning_tree<'a>( /// :returns: A list of tuples, with each tuple representing a distinct /// balanced edge that can be cut. The tuple contains the root of one of the /// two partitioned subtrees and the set of nodes making up that subtree. 
-/// #[pyfunction] #[pyo3(text_signature = "(spanning_tree, pop, target_pop, epsilon)")] pub fn bipartition_tree( @@ -251,7 +250,6 @@ pub fn bipartition_tree( /// :returns: A list of tuples, with each tuple representing a distinct /// balanced edge that can be cut. The tuple contains the root of one of the /// two partitioned subtrees and the set of nodes making up that subtree. -/// #[pyfunction] #[pyo3(text_signature = "(graph, weight_fn, pop, target_pop, epsilon)")] pub fn bipartition_graph( From 99e52571edf27506198e5cb5a7bf876f448a2e18 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Tue, 24 May 2022 17:43:36 -0400 Subject: [PATCH 24/44] Add bipartition tests --- tests/graph/test_bipartition.py | 129 ++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 tests/graph/test_bipartition.py diff --git a/tests/graph/test_bipartition.py b/tests/graph/test_bipartition.py new file mode 100644 index 000000000..7a35d8fff --- /dev/null +++ b/tests/graph/test_bipartition.py @@ -0,0 +1,129 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +import unittest +import random + +import retworkx + + +class TestBipartition(unittest.TestCase): + def setUp(self): + self.line = retworkx.PyGraph() + a = self.line.add_node(0) + b = self.line.add_node(1) + c = self.line.add_node(2) + d = self.line.add_node(3) + e = self.line.add_node(4) + f = self.line.add_node(5) + + self.line.add_edges_from( + [ + (a, b, 1), + (b, c, 1), + (c, d, 1), + (d, e, 1), + (e, f, 1), + ] + ) + + self.tree = retworkx.PyGraph() + a = self.tree.add_node(0) + b = self.tree.add_node(1) + c = self.tree.add_node(2) + d = self.tree.add_node(3) + e = self.tree.add_node(4) + f = self.tree.add_node(5) + + self.tree.add_edges_from( + [ + (a, b, 1), + (a, d, 1), + (c, d, 1), + (a, f, 1), + (d, e, 1), + ] + ) + + def test_one_balanced_edge_tree(self): + balanced_edges = retworkx.bipartition_tree( + self.tree, + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.2, + ) + self.assertEqual(len(balanced_edges), 1) + + # Since this is already a spanning tree, bipartition_graph should + # behave identically. 
That is, it should be invariant to weight_fn + graph_balanced_edges = retworkx.bipartition_graph( + self.tree, + lambda x: random.random(), + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.2, + ) + self.assertEqual(balanced_edges, graph_balanced_edges) + + def test_two_balanced_edges_tree(self): + balanced_edges = retworkx.bipartition_tree( + self.tree, + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.5, + ) + self.assertEqual(len(balanced_edges), 1) + + graph_balanced_edges = retworkx.bipartition_graph( + self.tree, + lambda x: random.random(), + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.5, + ) + self.assertEqual(balanced_edges, graph_balanced_edges) + + def test_three_balanced_edges_line(self): + balanced_edges = retworkx.bipartition_tree( + self.line, + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.5, + ) + self.assertEqual(len(balanced_edges), 3) + + graph_balanced_edges = retworkx.bipartition_graph( + self.line, + lambda x: random.random(), + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.5, + ) + self.assertEqual(balanced_edges, graph_balanced_edges) + + def test_one_balanced_edges_line(self): + balanced_edges = retworkx.bipartition_tree( + self.line, + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.01, + ) + self.assertEqual(len(balanced_edges), 1) + + graph_balanced_edges = retworkx.bipartition_graph( + self.line, + lambda x: random.random(), + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 3.0, + 0.01, + ) + self.assertEqual(balanced_edges, graph_balanced_edges) From db879244168a9a318b0d9c6ef4e4e3b81219be5f Mon Sep 17 00:00:00 2001 From: Max Fan Date: Tue, 24 May 2022 17:47:29 -0400 Subject: [PATCH 25/44] Fix indent issues in retworkx bipartition docstrings --- src/tree.rs | 12 ++++++------ tests/graph/test_bipartition.py | 18 +++++++++--------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index d5435ece3..c4f08da76 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -166,8 +166,8 @@ fn _minimum_spanning_tree<'a>( /// allowed while still being a 
valid balanced cut edge. /// /// :returns: A list of tuples, with each tuple representing a distinct -/// balanced edge that can be cut. The tuple contains the root of one of the -/// two partitioned subtrees and the set of nodes making up that subtree. +/// balanced edge that can be cut. The tuple contains the root of one of the +/// two partitioned subtrees and the set of nodes making up that subtree. #[pyfunction] #[pyo3(text_signature = "(spanning_tree, pop, target_pop, epsilon)")] pub fn bipartition_tree( @@ -242,14 +242,14 @@ pub fn bipartition_tree( /// will be passed the edge object and expected to return a ``float``. See /// :func:`~minimum_spanning_tree` for details. /// :param pops: The populations assigned to each node in the graph. -/// :param float pop_target: The population target to reach when partitioning the -/// graph. +/// :param float pop_target: The population target to reach when partitioning +/// the graph. /// :param float epsilon: The maximum percent deviation from the pop_target /// allowed while still being a valid balanced cut edge. /// /// :returns: A list of tuples, with each tuple representing a distinct -/// balanced edge that can be cut. The tuple contains the root of one of the -/// two partitioned subtrees and the set of nodes making up that subtree. +/// balanced edge that can be cut. The tuple contains the root of one of the +/// two partitioned subtrees and the set of nodes making up that subtree. 
#[pyfunction] #[pyo3(text_signature = "(graph, weight_fn, pop, target_pop, epsilon)")] pub fn bipartition_graph( diff --git a/tests/graph/test_bipartition.py b/tests/graph/test_bipartition.py index 7a35d8fff..16132ba4e 100644 --- a/tests/graph/test_bipartition.py +++ b/tests/graph/test_bipartition.py @@ -56,17 +56,17 @@ def setUp(self): def test_one_balanced_edge_tree(self): balanced_edges = retworkx.bipartition_tree( - self.tree, + self.tree, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 3.0, 0.2, ) self.assertEqual(len(balanced_edges), 1) - # Since this is already a spanning tree, bipartition_graph should + # Since this is already a spanning tree, bipartition_graph should # behave identically. That is, it should be invariant to weight_fn graph_balanced_edges = retworkx.bipartition_graph( - self.tree, + self.tree, lambda x: random.random(), [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 3.0, @@ -76,7 +76,7 @@ def test_one_balanced_edge_tree(self): def test_two_balanced_edges_tree(self): balanced_edges = retworkx.bipartition_tree( - self.tree, + self.tree, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 3.0, 0.5, @@ -84,7 +84,7 @@ def test_two_balanced_edges_tree(self): self.assertEqual(len(balanced_edges), 1) graph_balanced_edges = retworkx.bipartition_graph( - self.tree, + self.tree, lambda x: random.random(), [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 3.0, @@ -94,7 +94,7 @@ def test_two_balanced_edges_tree(self): def test_three_balanced_edges_line(self): balanced_edges = retworkx.bipartition_tree( - self.line, + self.line, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 3.0, 0.5, @@ -102,7 +102,7 @@ def test_three_balanced_edges_line(self): self.assertEqual(len(balanced_edges), 3) graph_balanced_edges = retworkx.bipartition_graph( - self.line, + self.line, lambda x: random.random(), [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 3.0, @@ -112,7 +112,7 @@ def test_three_balanced_edges_line(self): def test_one_balanced_edges_line(self): balanced_edges = retworkx.bipartition_tree( - self.line, + self.line, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 3.0, 0.01, @@ 
-120,7 +120,7 @@ def test_one_balanced_edges_line(self): self.assertEqual(len(balanced_edges), 1) graph_balanced_edges = retworkx.bipartition_graph( - self.line, + self.line, lambda x: random.random(), [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 3.0, From f6d807c38f9b6e7c328be45775537644347a69b8 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Fri, 27 May 2022 12:32:27 -0400 Subject: [PATCH 26/44] Switch to using hashbrown's HashSet impl Co-authored-by: Matthew Treinish --- src/tree.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tree.rs b/src/tree.rs index c4f08da76..acf4c7c2a 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -11,7 +11,7 @@ // under the License. use std::cmp::Ordering; -use std::collections::HashSet; +use hashbrown::HashSet; use std::collections::VecDeque; use super::{graph, weight_callable}; From 46404f39645f2b06396de80e99a7f33909b1679c Mon Sep 17 00:00:00 2001 From: Max Fan Date: Fri, 27 May 2022 12:36:41 -0400 Subject: [PATCH 27/44] Make tests deterministic --- tests/graph/test_bipartition.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/graph/test_bipartition.py b/tests/graph/test_bipartition.py index 16132ba4e..d8103ec4c 100644 --- a/tests/graph/test_bipartition.py +++ b/tests/graph/test_bipartition.py @@ -11,7 +11,6 @@ # under the License. import unittest -import random import retworkx @@ -67,7 +66,7 @@ def test_one_balanced_edge_tree(self): # behave identically. 
That is, it should be invariant to weight_fn graph_balanced_edges = retworkx.bipartition_graph( self.tree, - lambda x: random.random(), + lambda _: 1, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 3.0, 0.2, @@ -85,7 +84,7 @@ def test_two_balanced_edges_tree(self): graph_balanced_edges = retworkx.bipartition_graph( self.tree, - lambda x: random.random(), + lambda _: 1, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 3.0, 0.5, @@ -103,7 +102,7 @@ def test_three_balanced_edges_line(self): graph_balanced_edges = retworkx.bipartition_graph( self.line, - lambda x: random.random(), + lambda _: 1, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 3.0, 0.5, @@ -121,7 +120,7 @@ def test_one_balanced_edges_line(self): graph_balanced_edges = retworkx.bipartition_graph( self.line, - lambda x: random.random(), + lambda _: 1, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 3.0, 0.01, From 1ea5cd1b31dcb367949af83e30efdef3f64b8fd5 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Fri, 27 May 2022 12:37:56 -0400 Subject: [PATCH 28/44] Reorder imports as per cargo fmt --- src/tree.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tree.rs b/src/tree.rs index acf4c7c2a..4d1949f8f 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -10,8 +10,8 @@ // License for the specific language governing permissions and limitations // under the License. 
-use std::cmp::Ordering; use hashbrown::HashSet; +use std::cmp::Ordering; use std::collections::VecDeque; use super::{graph, weight_callable}; From 79b5e1715b900d3f1d4dfa41705bfb55e3a47ac1 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Fri, 27 May 2022 12:40:55 -0400 Subject: [PATCH 29/44] Wrap in rst Python code block --- .../notes/bipartition_tree-4c1ad080b1fab9e8.yaml | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/releasenotes/notes/bipartition_tree-4c1ad080b1fab9e8.yaml b/releasenotes/notes/bipartition_tree-4c1ad080b1fab9e8.yaml index 74953bec5..0abc0b404 100644 --- a/releasenotes/notes/bipartition_tree-4c1ad080b1fab9e8.yaml +++ b/releasenotes/notes/bipartition_tree-4c1ad080b1fab9e8.yaml @@ -7,14 +7,16 @@ features: when cut, will split the population of the tree into two connected subtrees that have population near the population target within some epsilon. The function returns a list of all such possible cuts, represented as the set - of nodes in one partition/subtree. For example:: + of nodes in one partition/subtree. For example, - balanced_node_choices = retworkx.bipartition_tree( - tree, - pops, - float(pop_target), - float(epsilon) - ) + .. code-block:: python + + balanced_node_choices = retworkx.bipartition_tree( + tree, + pops, + float(pop_target), + float(epsilon) + ) returns a list of tuples, with each tuple representing a distinct balanced edge that can be cut. 
The tuple contains the root of one of the two From 4e00b9b91980e13300d609e5b381ec9d37ab6752 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Fri, 27 May 2022 12:56:06 -0400 Subject: [PATCH 30/44] Switch to passing by value for mst --- src/tree.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 4d1949f8f..2d115e803 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -125,21 +125,20 @@ pub fn minimum_spanning_tree( weight_fn: Option, default_weight: f64, ) -> PyResult { - let mut spanning_tree = (*graph).clone(); + let mut spanning_tree: graph::PyGraph = graph.clone(); spanning_tree.graph.clear_edges(); - _minimum_spanning_tree(py, graph, &mut spanning_tree, weight_fn, default_weight)?; - Ok(spanning_tree) + _minimum_spanning_tree(py, graph, spanning_tree, weight_fn, default_weight) } /// Helper function to allow reuse of spanning_tree object to reduce memory allocs -fn _minimum_spanning_tree<'a>( +fn _minimum_spanning_tree( py: Python, graph: &graph::PyGraph, - spanning_tree: &'a mut graph::PyGraph, + mut spanning_tree: graph::PyGraph, weight_fn: Option, default_weight: f64, -) -> PyResult<&'a mut graph::PyGraph> { +) -> PyResult { for edge in minimum_spanning_edges(py, graph, weight_fn, default_weight)? .edges .iter() @@ -172,11 +171,11 @@ fn _minimum_spanning_tree<'a>( #[pyo3(text_signature = "(spanning_tree, pop, target_pop, epsilon)")] pub fn bipartition_tree( _py: Python, - spanning_tree: &graph::PyGraph, + spanning_tree: graph::PyGraph, pops: Vec, pop_target: f64, epsilon: f64, -) -> PyResult)>> { +) -> PyResult<(graph::PyGraph, Vec<(usize, Vec)>)> { let mut pops = pops; let spanning_tree_graph = &spanning_tree.graph; let mut same_partition_tracker: Vec> = @@ -230,7 +229,7 @@ pub fn bipartition_tree( seen_nodes.insert(node.index()); } - Ok(balanced_nodes) + Ok((spanning_tree, balanced_nodes)) } /// Bipartition graph into two contiguous, population-balanced components. 
@@ -261,12 +260,13 @@ pub fn bipartition_graph( epsilon: f64, ) -> PyResult)>> { let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; - let mut mst = (*graph).clone(); + let mut mst: graph::PyGraph = graph.clone(); while balanced_nodes.is_empty() { mst.graph.clear_edges(); - _minimum_spanning_tree(py, graph, &mut mst, Some(weight_fn.clone()), 1.0)?; - balanced_nodes = bipartition_tree(py, &mst, pops.clone(), pop_target, epsilon).unwrap(); + mst = _minimum_spanning_tree(py, graph, mst, Some(weight_fn.clone()), 1.0)?; + (mst, balanced_nodes) = + bipartition_tree(py, mst, pops.clone(), pop_target, epsilon).unwrap(); } Ok(balanced_nodes) From 5ba7cc906cd3455c99261414272300b657b73681 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Fri, 27 May 2022 13:10:23 -0400 Subject: [PATCH 31/44] Make test name more accurate --- tests/graph/test_bipartition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/graph/test_bipartition.py b/tests/graph/test_bipartition.py index d8103ec4c..c3a7ffc62 100644 --- a/tests/graph/test_bipartition.py +++ b/tests/graph/test_bipartition.py @@ -73,7 +73,7 @@ def test_one_balanced_edge_tree(self): ) self.assertEqual(balanced_edges, graph_balanced_edges) - def test_two_balanced_edges_tree(self): + def test_one_balanced_edge_tree_alt(self): balanced_edges = retworkx.bipartition_tree( self.tree, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], From 36857206cc14a8b64f297a833670ef7c1ecec44e Mon Sep 17 00:00:00 2001 From: Max Fan Date: Fri, 27 May 2022 13:28:50 -0400 Subject: [PATCH 32/44] Handle holes in graph node indices --- src/tree.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tree.rs b/src/tree.rs index 2d115e803..6dc148e59 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -179,7 +179,7 @@ pub fn bipartition_tree( let mut pops = pops; let spanning_tree_graph = &spanning_tree.graph; let mut same_partition_tracker: Vec> = - vec![vec![]; spanning_tree_graph.node_count()]; // Keeps track of all all the nodes on the same side of 
the partition + vec![vec![]; spanning_tree_graph.node_bound()]; // Keeps track of all all the nodes on the same side of the partition let mut node_queue: VecDeque = VecDeque::::new(); for leaf_node in spanning_tree_graph.node_indices() { From 8e513f7f62e3cf1b139f9e81b13df0db85729ae1 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Fri, 27 May 2022 13:45:06 -0400 Subject: [PATCH 33/44] Revert "Switch to passing by value for mst" This reverts commit 8e4ffdfcc94599df560195c6d1179dd8970489ab. --- src/tree.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 6dc148e59..56cc9c0d5 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -125,20 +125,21 @@ pub fn minimum_spanning_tree( weight_fn: Option, default_weight: f64, ) -> PyResult { - let mut spanning_tree: graph::PyGraph = graph.clone(); + let mut spanning_tree = (*graph).clone(); spanning_tree.graph.clear_edges(); - _minimum_spanning_tree(py, graph, spanning_tree, weight_fn, default_weight) + _minimum_spanning_tree(py, graph, &mut spanning_tree, weight_fn, default_weight)?; + Ok(spanning_tree) } /// Helper function to allow reuse of spanning_tree object to reduce memory allocs -fn _minimum_spanning_tree( +fn _minimum_spanning_tree<'a>( py: Python, graph: &graph::PyGraph, - mut spanning_tree: graph::PyGraph, + spanning_tree: &'a mut graph::PyGraph, weight_fn: Option, default_weight: f64, -) -> PyResult { +) -> PyResult<&'a mut graph::PyGraph> { for edge in minimum_spanning_edges(py, graph, weight_fn, default_weight)? 
.edges .iter() @@ -171,11 +172,11 @@ fn _minimum_spanning_tree( #[pyo3(text_signature = "(spanning_tree, pop, target_pop, epsilon)")] pub fn bipartition_tree( _py: Python, - spanning_tree: graph::PyGraph, + spanning_tree: &graph::PyGraph, pops: Vec, pop_target: f64, epsilon: f64, -) -> PyResult<(graph::PyGraph, Vec<(usize, Vec)>)> { +) -> PyResult)>> { let mut pops = pops; let spanning_tree_graph = &spanning_tree.graph; let mut same_partition_tracker: Vec> = @@ -229,7 +230,7 @@ pub fn bipartition_tree( seen_nodes.insert(node.index()); } - Ok((spanning_tree, balanced_nodes)) + Ok(balanced_nodes) } /// Bipartition graph into two contiguous, population-balanced components. @@ -260,13 +261,12 @@ pub fn bipartition_graph( epsilon: f64, ) -> PyResult)>> { let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; - let mut mst: graph::PyGraph = graph.clone(); + let mut mst = (*graph).clone(); while balanced_nodes.is_empty() { mst.graph.clear_edges(); - mst = _minimum_spanning_tree(py, graph, mst, Some(weight_fn.clone()), 1.0)?; - (mst, balanced_nodes) = - bipartition_tree(py, mst, pops.clone(), pop_target, epsilon).unwrap(); + _minimum_spanning_tree(py, graph, &mut mst, Some(weight_fn.clone()), 1.0)?; + balanced_nodes = bipartition_tree(py, &mst, pops.clone(), pop_target, epsilon).unwrap(); } Ok(balanced_nodes) From c1f73a2ec29504447478bbc3b5ccb59dfc0b76b0 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Fri, 27 May 2022 13:51:15 -0400 Subject: [PATCH 34/44] Remove return reference in _minimum_spanning_tree helper --- src/tree.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 56cc9c0d5..c4f76851d 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -133,13 +133,13 @@ pub fn minimum_spanning_tree( } /// Helper function to allow reuse of spanning_tree object to reduce memory allocs -fn _minimum_spanning_tree<'a>( +fn _minimum_spanning_tree( py: Python, graph: &graph::PyGraph, - spanning_tree: &'a mut graph::PyGraph, + spanning_tree: 
&mut graph::PyGraph, weight_fn: Option, default_weight: f64, -) -> PyResult<&'a mut graph::PyGraph> { +) -> PyResult<()> { for edge in minimum_spanning_edges(py, graph, weight_fn, default_weight)? .edges .iter() @@ -147,7 +147,7 @@ fn _minimum_spanning_tree<'a>( spanning_tree.add_edge(edge.0, edge.1, edge.2.clone_ref(py)); } - Ok(spanning_tree) + Ok(()) } /// Bipartition tree by finding balanced cut edges of a spanning tree using From f59f7c2c36fa085ce7dda2615b16586eaa69ca54 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Fri, 27 May 2022 15:37:37 -0400 Subject: [PATCH 35/44] Rename bipartition_graph to bipartition_graph_mst --- docs/source/api.rst | 2 +- ...aml => bipartition_graph_mst-ccb2204bc7b6c407.yaml} | 2 +- src/lib.rs | 2 +- src/tree.rs | 6 +++--- tests/graph/test_bipartition.py | 10 +++++----- 5 files changed, 11 insertions(+), 11 deletions(-) rename releasenotes/notes/{bipartition_graph-ccb2204bc7b6c407.yaml => bipartition_graph_mst-ccb2204bc7b6c407.yaml} (73%) diff --git a/docs/source/api.rst b/docs/source/api.rst index 8f3ee5766..2075b60c5 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -179,7 +179,7 @@ Other Algorithm Functions retworkx.core_number retworkx.graph_greedy_color retworkx.metric_closure - retworkx.bipartition_graph + retworkx.bipartition_graph_mst .. 
_generator_funcs: diff --git a/releasenotes/notes/bipartition_graph-ccb2204bc7b6c407.yaml b/releasenotes/notes/bipartition_graph_mst-ccb2204bc7b6c407.yaml similarity index 73% rename from releasenotes/notes/bipartition_graph-ccb2204bc7b6c407.yaml rename to releasenotes/notes/bipartition_graph_mst-ccb2204bc7b6c407.yaml index b7456dd27..040c90a94 100644 --- a/releasenotes/notes/bipartition_graph-ccb2204bc7b6c407.yaml +++ b/releasenotes/notes/bipartition_graph_mst-ccb2204bc7b6c407.yaml @@ -1,7 +1,7 @@ --- features: - | - Added a new function :func:`~.bipartition_graph` that takes in a connected + Added a new function :func:`~.bipartition_graph_mst` that takes in a connected graph and tries to draw a minimum spanning tree and find a balanced cut edge to target using :func:`~.bipartition_tree`. If such a corresponding tree and edge cannnot be found, then it retries. diff --git a/src/lib.rs b/src/lib.rs index f0e0dae8a..3c9d15e48 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -413,7 +413,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(minimum_spanning_edges))?; m.add_wrapped(wrap_pyfunction!(minimum_spanning_tree))?; m.add_wrapped(wrap_pyfunction!(bipartition_tree))?; - m.add_wrapped(wrap_pyfunction!(bipartition_graph))?; + m.add_wrapped(wrap_pyfunction!(bipartition_graph_mst))?; m.add_wrapped(wrap_pyfunction!(graph_transitivity))?; m.add_wrapped(wrap_pyfunction!(digraph_transitivity))?; m.add_wrapped(wrap_pyfunction!(graph_core_number))?; diff --git a/src/tree.rs b/src/tree.rs index c4f76851d..3488804fe 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -233,8 +233,8 @@ pub fn bipartition_tree( Ok(balanced_nodes) } -/// Bipartition graph into two contiguous, population-balanced components. -/// Assumes that the graph is contiguous. See :func:`~bipartition_tree` for +/// Bipartition graph into two contiguous, population-balanced components using +/// mst. Assumes that the graph is contiguous. 
See :func:`~bipartition_tree` for /// details on how balance is defined. /// /// :param PyGraph graph: Undirected graph @@ -252,7 +252,7 @@ pub fn bipartition_tree( /// two partitioned subtrees and the set of nodes making up that subtree. #[pyfunction] #[pyo3(text_signature = "(graph, weight_fn, pop, target_pop, epsilon)")] -pub fn bipartition_graph( +pub fn bipartition_graph_mst( py: Python, graph: &graph::PyGraph, weight_fn: PyObject, diff --git a/tests/graph/test_bipartition.py b/tests/graph/test_bipartition.py index c3a7ffc62..de528e150 100644 --- a/tests/graph/test_bipartition.py +++ b/tests/graph/test_bipartition.py @@ -62,9 +62,9 @@ def test_one_balanced_edge_tree(self): ) self.assertEqual(len(balanced_edges), 1) - # Since this is already a spanning tree, bipartition_graph should + # Since this is already a spanning tree, bipartition_graph_mst should # behave identically. That is, it should be invariant to weight_fn - graph_balanced_edges = retworkx.bipartition_graph( + graph_balanced_edges = retworkx.bipartition_graph_mst( self.tree, lambda _: 1, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], @@ -82,7 +82,7 @@ def test_one_balanced_edge_tree_alt(self): ) self.assertEqual(len(balanced_edges), 1) - graph_balanced_edges = retworkx.bipartition_graph( + graph_balanced_edges = retworkx.bipartition_graph_mst( self.tree, lambda _: 1, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], @@ -100,7 +100,7 @@ def test_three_balanced_edges_line(self): ) self.assertEqual(len(balanced_edges), 3) - graph_balanced_edges = retworkx.bipartition_graph( + graph_balanced_edges = retworkx.bipartition_graph_mst( self.line, lambda _: 1, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], @@ -118,7 +118,7 @@ def test_one_balanced_edges_line(self): ) self.assertEqual(len(balanced_edges), 1) - graph_balanced_edges = retworkx.bipartition_graph( + graph_balanced_edges = retworkx.bipartition_graph_mst( self.line, lambda _: 1, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], From 4088b76481c9b7dd754947b92c32f87135706180 Mon Sep 17 00:00:00 2001 From: Max Fan 
Date: Fri, 27 May 2022 17:21:09 -0400 Subject: [PATCH 36/44] Create _bipartition_tree internal func --- src/tree.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 3488804fe..c448c3bb2 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -156,7 +156,7 @@ fn _minimum_spanning_tree( /// population of the tree into two connected subtrees that have population near /// the population target within some epsilon. The function returns a list of /// all such possible cuts, represented as the set of nodes in one -/// partition/subtree. +/// partition/subtree. Wraps around ``_bipartition_tree``. /// /// :param PyGraph graph: Spanning tree. Must be fully connected /// :param pops: The populations assigned to each node in the graph. @@ -177,6 +177,16 @@ pub fn bipartition_tree( pop_target: f64, epsilon: f64, ) -> PyResult)>> { + Ok(_bipartition_tree(spanning_tree, pops, pop_target, epsilon)) +} + +/// Internal _bipartition_tree implementation. 
+fn _bipartition_tree( + spanning_tree: &graph::PyGraph, + pops: Vec, + pop_target: f64, + epsilon: f64, +) -> Vec<(usize, Vec)> { let mut pops = pops; let spanning_tree_graph = &spanning_tree.graph; let mut same_partition_tracker: Vec> = @@ -230,7 +240,7 @@ pub fn bipartition_tree( seen_nodes.insert(node.index()); } - Ok(balanced_nodes) + balanced_nodes } /// Bipartition graph into two contiguous, population-balanced components using @@ -266,7 +276,7 @@ pub fn bipartition_graph_mst( while balanced_nodes.is_empty() { mst.graph.clear_edges(); _minimum_spanning_tree(py, graph, &mut mst, Some(weight_fn.clone()), 1.0)?; - balanced_nodes = bipartition_tree(py, &mst, pops.clone(), pop_target, epsilon).unwrap(); + balanced_nodes = _bipartition_tree(&mst, pops.clone(), pop_target, epsilon); } Ok(balanced_nodes) From 663d54e589856a15992c35f691f104820a597edd Mon Sep 17 00:00:00 2001 From: Max Fan Date: Fri, 27 May 2022 18:38:12 -0400 Subject: [PATCH 37/44] Use numpy PyReadonlyArray to avoid one, unnecessary copy --- src/tree.rs | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index c448c3bb2..fc955901b 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -11,6 +11,7 @@ // under the License. use hashbrown::HashSet; +use numpy::PyReadonlyArray1; use std::cmp::Ordering; use std::collections::VecDeque; @@ -173,21 +174,29 @@ fn _minimum_spanning_tree( pub fn bipartition_tree( _py: Python, spanning_tree: &graph::PyGraph, - pops: Vec, + pops: PyReadonlyArray1, pop_target: f64, epsilon: f64, ) -> PyResult)>> { - Ok(_bipartition_tree(spanning_tree, pops, pop_target, epsilon)) + // See: https://github.com/Qiskit/retworkx/pull/572#discussion_r883643134 + // There should be no other views of pops, so this should be fine + unsafe { + Ok(_bipartition_tree( + spanning_tree, + pops.as_slice_mut()?, + pop_target, + epsilon, + )) + } } /// Internal _bipartition_tree implementation. 
fn _bipartition_tree( spanning_tree: &graph::PyGraph, - pops: Vec, + pops: &mut [f64], pop_target: f64, epsilon: f64, ) -> Vec<(usize, Vec)> { - let mut pops = pops; let spanning_tree_graph = &spanning_tree.graph; let mut same_partition_tracker: Vec> = vec![vec![]; spanning_tree_graph.node_bound()]; // Keeps track of all all the nodes on the same side of the partition @@ -266,17 +275,18 @@ pub fn bipartition_graph_mst( py: Python, graph: &graph::PyGraph, weight_fn: PyObject, - pops: Vec, + pops: PyReadonlyArray1, pop_target: f64, epsilon: f64, ) -> PyResult)>> { let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; let mut mst = (*graph).clone(); + let pops_slice: Vec = pops.to_vec()?; while balanced_nodes.is_empty() { mst.graph.clear_edges(); _minimum_spanning_tree(py, graph, &mut mst, Some(weight_fn.clone()), 1.0)?; - balanced_nodes = _bipartition_tree(&mst, pops.clone(), pop_target, epsilon); + balanced_nodes = _bipartition_tree(&mst, &mut pops_slice.clone(), pop_target, epsilon); } Ok(balanced_nodes) From a8f2305aa0b7886702d7575369c55b431f454c2b Mon Sep 17 00:00:00 2001 From: Max Fan Date: Fri, 27 May 2022 19:04:01 -0400 Subject: [PATCH 38/44] Revert "Use numpy PyReadonlyArray to avoid one, unnecessary copy" This reverts commit 53842efad41c41dce5ba1b1418ee7dbafb410734. --- src/tree.rs | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index fc955901b..c448c3bb2 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -11,7 +11,6 @@ // under the License. 
use hashbrown::HashSet; -use numpy::PyReadonlyArray1; use std::cmp::Ordering; use std::collections::VecDeque; @@ -174,29 +173,21 @@ fn _minimum_spanning_tree( pub fn bipartition_tree( _py: Python, spanning_tree: &graph::PyGraph, - pops: PyReadonlyArray1, + pops: Vec, pop_target: f64, epsilon: f64, ) -> PyResult)>> { - // See: https://github.com/Qiskit/retworkx/pull/572#discussion_r883643134 - // There should be no other views of pops, so this should be fine - unsafe { - Ok(_bipartition_tree( - spanning_tree, - pops.as_slice_mut()?, - pop_target, - epsilon, - )) - } + Ok(_bipartition_tree(spanning_tree, pops, pop_target, epsilon)) } /// Internal _bipartition_tree implementation. fn _bipartition_tree( spanning_tree: &graph::PyGraph, - pops: &mut [f64], + pops: Vec, pop_target: f64, epsilon: f64, ) -> Vec<(usize, Vec)> { + let mut pops = pops; let spanning_tree_graph = &spanning_tree.graph; let mut same_partition_tracker: Vec> = vec![vec![]; spanning_tree_graph.node_bound()]; // Keeps track of all all the nodes on the same side of the partition @@ -275,18 +266,17 @@ pub fn bipartition_graph_mst( py: Python, graph: &graph::PyGraph, weight_fn: PyObject, - pops: PyReadonlyArray1, + pops: Vec, pop_target: f64, epsilon: f64, ) -> PyResult)>> { let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; let mut mst = (*graph).clone(); - let pops_slice: Vec = pops.to_vec()?; while balanced_nodes.is_empty() { mst.graph.clear_edges(); _minimum_spanning_tree(py, graph, &mut mst, Some(weight_fn.clone()), 1.0)?; - balanced_nodes = _bipartition_tree(&mst, &mut pops_slice.clone(), pop_target, epsilon); + balanced_nodes = _bipartition_tree(&mst, pops.clone(), pop_target, epsilon); } Ok(balanced_nodes) From dd427006eca671493f563cd9c9ee23d6ed7587b6 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Sat, 28 May 2022 11:55:23 -0400 Subject: [PATCH 39/44] Update pyo3 text_signature to reflect args Co-authored-by: georgios-ts <45130028+georgios-ts@users.noreply.github.com> --- src/tree.rs | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tree.rs b/src/tree.rs index c448c3bb2..a81067acb 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -169,7 +169,7 @@ fn _minimum_spanning_tree( /// balanced edge that can be cut. The tuple contains the root of one of the /// two partitioned subtrees and the set of nodes making up that subtree. #[pyfunction] -#[pyo3(text_signature = "(spanning_tree, pop, target_pop, epsilon)")] +#[pyo3(text_signature = "(spanning_tree, pops, target_pop, epsilon)")] pub fn bipartition_tree( _py: Python, spanning_tree: &graph::PyGraph, From 7c131a11a8c73da615fdf7f7cc01b766a964fb20 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Sat, 28 May 2022 11:57:16 -0400 Subject: [PATCH 40/44] Apply suggestions from @georgios-ts Co-authored-by: georgios-ts <45130028+georgios-ts@users.noreply.github.com> --- src/tree.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index a81067acb..f9823092a 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -171,13 +171,12 @@ fn _minimum_spanning_tree( #[pyfunction] #[pyo3(text_signature = "(spanning_tree, pops, target_pop, epsilon)")] pub fn bipartition_tree( - _py: Python, spanning_tree: &graph::PyGraph, pops: Vec, pop_target: f64, epsilon: f64, -) -> PyResult)>> { - Ok(_bipartition_tree(spanning_tree, pops, pop_target, epsilon)) +) -> Vec<(usize, Vec)> { + _bipartition_tree(spanning_tree, pops, pop_target, epsilon) } /// Internal _bipartition_tree implementation. From aa24d61da61ebe23cf7c2e8ef3503b36e6aeaef3 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Sat, 28 May 2022 12:20:31 -0400 Subject: [PATCH 41/44] Update pyo3 text_signature to reflect Rust args --- src/tree.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tree.rs b/src/tree.rs index f9823092a..a8286f17a 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -260,7 +260,7 @@ fn _bipartition_tree( /// balanced edge that can be cut. 
The tuple contains the root of one of the /// two partitioned subtrees and the set of nodes making up that subtree. #[pyfunction] -#[pyo3(text_signature = "(graph, weight_fn, pop, target_pop, epsilon)")] +#[pyo3(text_signature = "(graph, weight_fn, pops, target_pop, epsilon)")] pub fn bipartition_graph_mst( py: Python, graph: &graph::PyGraph, From e7fdffe394e63e37e5a68f9ce6aa127c928e67bc Mon Sep 17 00:00:00 2001 From: Max Fan Date: Sat, 28 May 2022 15:16:05 -0400 Subject: [PATCH 42/44] Switch to using LinkedList for cheaper appends and remove unnecessary queue traversal --- src/tree.rs | 52 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index a8286f17a..ce3edbcd5 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -12,7 +12,9 @@ use hashbrown::HashSet; use std::cmp::Ordering; +use std::collections::LinkedList; use std::collections::VecDeque; +use std::mem; use super::{graph, weight_callable}; @@ -188,23 +190,25 @@ fn _bipartition_tree( ) -> Vec<(usize, Vec)> { let mut pops = pops; let spanning_tree_graph = &spanning_tree.graph; - let mut same_partition_tracker: Vec> = - vec![vec![]; spanning_tree_graph.node_bound()]; // Keeps track of all all the nodes on the same side of the partition + let mut same_partition_tracker: Vec> = + vec![LinkedList::new(); spanning_tree_graph.node_bound()]; // Keeps track of all all the nodes on the same side of the partition let mut node_queue: VecDeque = VecDeque::::new(); for leaf_node in spanning_tree_graph.node_indices() { if spanning_tree_graph.neighbors(leaf_node).count() == 1 { node_queue.push_back(leaf_node); } - same_partition_tracker[leaf_node.index()].push(leaf_node.index()); + same_partition_tracker[leaf_node.index()].push_back(leaf_node.index()); } - // BFS search for balanced nodes - let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; + // BFS search for balanced nodes using LinkedList since append is O(1) + let mut balanced_nodes: 
Vec<(usize, LinkedList)> = vec![]; let mut seen_nodes = HashSet::with_capacity(spanning_tree_graph.node_count()); while !node_queue.is_empty() { let node = node_queue.pop_front().unwrap(); - let pop = pops[node.index()]; + if seen_nodes.contains(&node.index()) { + continue; + } // Mark as seen; push to queue if only one unseen neighbor let unseen_neighbors: Vec = spanning_tree @@ -214,32 +218,40 @@ fn _bipartition_tree( .collect(); if unseen_neighbors.len() == 1 { - // This will be false if root + // At leaf, will be false at root + let pop = pops[node.index()]; + + // Update neighbor pop let neighbor = unseen_neighbors[0]; pops[neighbor.index()] += pop; - let mut current_partition_tracker = same_partition_tracker[node.index()].clone(); - same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker); - if !node_queue.contains(&neighbor) { - node_queue.push_back(neighbor); + // Check if balanced; mark as seen + if pop >= pop_target * (1.0 - epsilon) && pop <= pop_target * (1.0 + epsilon) { + balanced_nodes.push((node.index(), same_partition_tracker[node.index()].clone())); } + seen_nodes.insert(node.index()); + + // Update neighbor partition tracker + let mut current_partition_tracker = + mem::take(&mut same_partition_tracker[node.index()]); + same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker); + + // Queue neighbor + node_queue.push_back(neighbor); } else if unseen_neighbors.is_empty() { - // root + // Is root break; } else { - // Not at the leaves of the unseen subgraph + // Not a leaf yet continue; } - - // Check if balanced - if pop >= pop_target * (1.0 - epsilon) && pop <= pop_target * (1.0 + epsilon) { - balanced_nodes.push((node.index(), same_partition_tracker[node.index()].clone())); - } - - seen_nodes.insert(node.index()); } + // Convert LinkedList back to vec balanced_nodes + .iter() + .map(|(node, partition_nodes)| (*node, partition_nodes.iter().copied().collect())) + .collect() } /// Bipartition graph into 
two contiguous, population-balanced components using From 4dc3ba7caa2438376ea198bd27cb31e7c68e42c2 Mon Sep 17 00:00:00 2001 From: Max Fan Date: Sat, 28 May 2022 17:02:04 -0400 Subject: [PATCH 43/44] Remove LinkedList use --- src/tree.rs | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index ce3edbcd5..c3545386e 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -12,7 +12,6 @@ use hashbrown::HashSet; use std::cmp::Ordering; -use std::collections::LinkedList; use std::collections::VecDeque; use std::mem; @@ -190,19 +189,19 @@ fn _bipartition_tree( ) -> Vec<(usize, Vec)> { let mut pops = pops; let spanning_tree_graph = &spanning_tree.graph; - let mut same_partition_tracker: Vec> = - vec![LinkedList::new(); spanning_tree_graph.node_bound()]; // Keeps track of all all the nodes on the same side of the partition + let mut same_partition_tracker: Vec> = + vec![Vec::new(); spanning_tree_graph.node_bound()]; // Keeps track of all all the nodes on the same side of the partition let mut node_queue: VecDeque = VecDeque::::new(); for leaf_node in spanning_tree_graph.node_indices() { if spanning_tree_graph.neighbors(leaf_node).count() == 1 { node_queue.push_back(leaf_node); } - same_partition_tracker[leaf_node.index()].push_back(leaf_node.index()); + same_partition_tracker[leaf_node.index()].push(leaf_node.index()); } - // BFS search for balanced nodes using LinkedList since append is O(1) - let mut balanced_nodes: Vec<(usize, LinkedList)> = vec![]; + // BFS search for balanced nodes + let mut balanced_nodes: Vec<(usize, Vec)> = vec![]; let mut seen_nodes = HashSet::with_capacity(spanning_tree_graph.node_count()); while !node_queue.is_empty() { let node = node_queue.pop_front().unwrap(); @@ -247,11 +246,7 @@ fn _bipartition_tree( } } - // Convert LinkedList back to vec balanced_nodes - .iter() - .map(|(node, partition_nodes)| (*node, partition_nodes.iter().copied().collect())) - .collect() } /// Bipartition graph into two 
contiguous, population-balanced components using From 930069adef3cee72c3bce9951a67221d71daca09 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho Date: Mon, 1 Aug 2022 15:43:12 -0700 Subject: [PATCH 44/44] Move test file to rustworkx tests --- .../graph/test_bipartition.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{retworkx_backwards_compat => rustworkx_tests}/graph/test_bipartition.py (100%) diff --git a/tests/retworkx_backwards_compat/graph/test_bipartition.py b/tests/rustworkx_tests/graph/test_bipartition.py similarity index 100% rename from tests/retworkx_backwards_compat/graph/test_bipartition.py rename to tests/rustworkx_tests/graph/test_bipartition.py