Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement biparition_graph_mst and bipartition_tree functions #572

Open
wants to merge 46 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
1cdf76a
Draft bipartition_tree implementation
InnovativeInventor Jan 12, 2022
20281b0
Working bipartition tree impl
InnovativeInventor Jan 12, 2022
f85eb52
Release GIL during most of balanced_edge finding code
InnovativeInventor Jan 13, 2022
79b3d94
Ensure that unused vars get gc'ed on each loop
InnovativeInventor Jan 15, 2022
d5d464b
Lint with cargo fmt
InnovativeInventor Mar 19, 2022
e7409a5
Update to using built-in retworkx macros
InnovativeInventor May 16, 2022
a45c147
Take advantage of built-in retworkx macros
InnovativeInventor May 16, 2022
4d93f0e
Use mem::take to save on memory allocs
InnovativeInventor May 17, 2022
f01201c
Revert "Use mem::take to save on memory allocs"
InnovativeInventor May 17, 2022
fdbeaaf
Fix unnecessary clones; address clippy warnings
InnovativeInventor May 17, 2022
517d697
Remove unnecessary dead/commented out code
InnovativeInventor May 17, 2022
2bd4c33
Reduce number of memory allocs
InnovativeInventor May 18, 2022
658ffa5
Remove unused import
InnovativeInventor May 18, 2022
a360695
Remove no longer relevant todos
InnovativeInventor May 18, 2022
c5d76cb
Switch to using HashSet for seen_nodes tracker
InnovativeInventor May 18, 2022
6643195
Touch up comments to make more sense
InnovativeInventor May 18, 2022
90d5ba3
Lint with cargo fmt
InnovativeInventor May 18, 2022
b522daa
Rename functions to biparition_tree and bipartition_graph; update docs
InnovativeInventor May 23, 2022
1e6290d
Make docstrings more descriptive and fix renaming issue
InnovativeInventor May 23, 2022
c283355
Add release notes for bipartition_tree and bipartition_graph
InnovativeInventor May 23, 2022
cbaed09
Add bipartition_graph and bipartition_tree to API section of docs
InnovativeInventor May 23, 2022
41e503a
Shorten description of weight_fn by pointing to minimum_spanning_tree…
InnovativeInventor May 23, 2022
da1e8a7
Fix end with blank line linting issue
InnovativeInventor May 24, 2022
99e5257
Add bipartition tests
InnovativeInventor May 24, 2022
db87924
Fix indent issues in retworkx bipartition docstrings
InnovativeInventor May 24, 2022
f6d807c
Switch to using hashbrown's HashSet impl
InnovativeInventor May 27, 2022
46404f3
Make tests deterministic
InnovativeInventor May 27, 2022
1ea5cd1
Reorder imports as per cargo fmt
InnovativeInventor May 27, 2022
79b5e17
Wrap in rst Python code block
InnovativeInventor May 27, 2022
4e00b9b
Switch to passing by value for mst
InnovativeInventor May 27, 2022
5ba7cc9
Make test name more accurate
InnovativeInventor May 27, 2022
3685720
Handle holes in graph node indices
InnovativeInventor May 27, 2022
8e513f7
Revert "Switch to passing by value for mst"
InnovativeInventor May 27, 2022
c1f73a2
Remove return reference in _minimum_spanning_tree helper
InnovativeInventor May 27, 2022
f59f7c2
Rename bipartition_graph to bipartition_graph_mst
InnovativeInventor May 27, 2022
4088b76
Create _bipartition_tree internal func
InnovativeInventor May 27, 2022
663d54e
Use numpy PyReadonlyArray to avoid one, unnecessary copy
InnovativeInventor May 27, 2022
a8f2305
Revert "Use numpy PyReadonlyArray to avoid one, unnecessary copy"
InnovativeInventor May 27, 2022
dd42700
Update pyo3 text_signature to reflect args
InnovativeInventor May 28, 2022
7c131a1
Apply suggestions from @georgois-ts
InnovativeInventor May 28, 2022
aa24d61
Update pyo3 text_signature to reflect Rust args
InnovativeInventor May 28, 2022
e7fdffe
Switch to using LinkedList for cheaper appends and remove unnecessary…
InnovativeInventor May 28, 2022
4dc3ba7
Remove LinkedList use
InnovativeInventor May 28, 2022
90df5ac
Merge remote-tracking branch 'origin/main' into feat-balanced-cut-edges
IvanIsCoding Aug 1, 2022
5fd550a
Merge remote-tracking branch 'origin/main' into feat-balanced-cut-edges
IvanIsCoding Aug 1, 2022
930069a
Move test file to rustworkx tests
IvanIsCoding Aug 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ Tree
retworkx.minimum_spanning_edges
retworkx.minimum_spanning_tree
retworkx.steiner_tree
retworkx.bipartition_tree

.. _isomorphism:

Expand Down Expand Up @@ -178,6 +179,7 @@ Other Algorithm Functions
retworkx.core_number
retworkx.graph_greedy_color
retworkx.metric_closure
retworkx.bipartition_graph_mst

.. _generator_funcs:

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
features:
- |
Added a new function :func:`~.bipartition_graph_mst` that takes in a connected
graph and tries to draw a minimum spanning tree and find a balanced cut
edge to target using :func:`~.bipartition_tree`. If such a corresponding
tree and edge cannnot be found, then it retries.
24 changes: 24 additions & 0 deletions releasenotes/notes/bipartition_tree-4c1ad080b1fab9e8.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
---
features:
- |
Added a new function :func:`~.bipartition_tree` that takes in spanning tree
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Added a new function :func:`~.bipartition_tree` that takes in spanning tree
Added a new function :func:`~.bipartition_tree` that takes in a tree

and a list of populations assigned to each node in the tree and finds all
balanced edges, if they exist. A balanced edge is defined as an edge that,
when cut, will split the population of the tree into two connected subtrees
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
when cut, will split the population of the tree into two connected subtrees
when cut, will split the tree into two connected subtrees

that have population near the population target within some epsilon. The
function returns a list of all such possible cuts, represented as the set
of nodes in one partition/subtree. For example,

.. code-block:: python

balanced_node_choices = retworkx.bipartition_tree(
tree,
pops,
float(pop_target),
float(epsilon)
)

returns a list of tuples, with each tuple representing a distinct balanced
edge that can be cut. The tuple contains the root of one of the two
partitioned subtrees and the set of nodes making up that subtree. The other
partition can be recovered by computing the complement of the set of nodes.
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,8 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(max_weight_matching))?;
m.add_wrapped(wrap_pyfunction!(minimum_spanning_edges))?;
m.add_wrapped(wrap_pyfunction!(minimum_spanning_tree))?;
m.add_wrapped(wrap_pyfunction!(bipartition_tree))?;
m.add_wrapped(wrap_pyfunction!(bipartition_graph_mst))?;
m.add_wrapped(wrap_pyfunction!(graph_transitivity))?;
m.add_wrapped(wrap_pyfunction!(digraph_transitivity))?;
m.add_wrapped(wrap_pyfunction!(graph_core_number))?;
Expand Down
154 changes: 153 additions & 1 deletion src/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
// License for the specific language governing permissions and limitations
// under the License.

use hashbrown::HashSet;
use std::cmp::Ordering;
use std::collections::VecDeque;
use std::mem;

use super::{graph, weight_callable};

Expand Down Expand Up @@ -126,12 +129,161 @@ pub fn minimum_spanning_tree(
let mut spanning_tree = (*graph).clone();
spanning_tree.graph.clear_edges();

_minimum_spanning_tree(py, graph, &mut spanning_tree, weight_fn, default_weight)?;
Ok(spanning_tree)
InnovativeInventor marked this conversation as resolved.
Show resolved Hide resolved
}

/// Helper function to allow reuse of spanning_tree object to reduce memory allocs
fn _minimum_spanning_tree(
py: Python,
graph: &graph::PyGraph,
spanning_tree: &mut graph::PyGraph,
weight_fn: Option<PyObject>,
default_weight: f64,
) -> PyResult<()> {
for edge in minimum_spanning_edges(py, graph, weight_fn, default_weight)?
.edges
.iter()
{
spanning_tree.add_edge(edge.0, edge.1, edge.2.clone_ref(py));
}

Ok(spanning_tree)
Ok(())
}

/// Bipartition tree by finding balanced cut edges of a spanning tree using
/// node contraction. Assumes that the tree is connected and is a spanning tree.
Comment on lines +154 to +155
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Bipartition tree by finding balanced cut edges of a spanning tree using
/// node contraction. Assumes that the tree is connected and is a spanning tree.
/// Find all balanced cut edges of a tree.

/// A balanced edge is defined as an edge that, when cut, will split the
/// population of the tree into two connected subtrees that have population near
/// the population target within some epsilon. The function returns a list of
/// all such possible cuts, represented as the set of nodes in one
/// partition/subtree. Wraps around ``_bipartition_tree``.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// partition/subtree. Wraps around ``_bipartition_tree``.
/// partition/subtree.

///
/// :param PyGraph graph: Spanning tree. Must be fully connected
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// :param PyGraph graph: Spanning tree. Must be fully connected
/// :param PyGraph tree: The input tree.

/// :param pops: The populations assigned to each node in the graph.
/// :param float pop_target: The population target to reach when partitioning the
/// graph.
/// :param float epsilon: The maximum percent deviation from the pop_target
/// allowed while still being a valid balanced cut edge.
///
/// :returns: A list of tuples, with each tuple representing a distinct
/// balanced edge that can be cut. The tuple contains the root of one of the
/// two partitioned subtrees and the set of nodes making up that subtree.
#[pyfunction]
#[pyo3(text_signature = "(spanning_tree, pops, target_pop, epsilon)")]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#[pyo3(text_signature = "(spanning_tree, pops, target_pop, epsilon)")]
#[pyo3(text_signature = "(tree, pops, pop_target, epsilon)")]

pub fn bipartition_tree(
spanning_tree: &graph::PyGraph,
pops: Vec<f64>,
InnovativeInventor marked this conversation as resolved.
Show resolved Hide resolved
pop_target: f64,
epsilon: f64,
) -> Vec<(usize, Vec<usize>)> {
_bipartition_tree(spanning_tree, pops, pop_target, epsilon)
}
Comment on lines +174 to +181
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pub fn bipartition_tree(
spanning_tree: &graph::PyGraph,
pops: Vec<f64>,
pop_target: f64,
epsilon: f64,
) -> Vec<(usize, Vec<usize>)> {
_bipartition_tree(spanning_tree, pops, pop_target, epsilon)
}
pub fn bipartition_tree(
tree: &graph::PyGraph,
pops: Vec<f64>,
pop_target: f64,
epsilon: f64,
) -> Vec<(usize, Vec<usize>)> {
_bipartition_tree(tree, pops, pop_target, epsilon)
}


/// Internal _bipartition_tree implementation.
fn _bipartition_tree(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't really need to put the code in an internal implementation. We can just define the public bipartition_tree.

spanning_tree: &graph::PyGraph,
pops: Vec<f64>,
pop_target: f64,
epsilon: f64,
) -> Vec<(usize, Vec<usize>)> {
let mut pops = pops;
let spanning_tree_graph = &spanning_tree.graph;
let mut same_partition_tracker: Vec<Vec<usize>> =
vec![Vec::new(); spanning_tree_graph.node_bound()]; // Keeps track of all all the nodes on the same side of the partition
Comment on lines +192 to +193
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A union-find structure is usually used for holding disjoint partition of a set. https://en.wikipedia.org/wiki/Disjoint-set_data_structure. This has the benefit of fast unions but we'll pay O(|V|) every time we need to find all nodes on the same side of the partition. If the tree has few balanced edges union-find might perform faster so it's worth doing some benchmarks. Note that petgraph crate provides an implementation https://docs.rs/petgraph/latest/petgraph/unionfind/struct.UnionFind.html


let mut node_queue: VecDeque<NodeIndex> = VecDeque::<NodeIndex>::new();
for leaf_node in spanning_tree_graph.node_indices() {
if spanning_tree_graph.neighbors(leaf_node).count() == 1 {
node_queue.push_back(leaf_node);
}
same_partition_tracker[leaf_node.index()].push(leaf_node.index());
}

// BFS search for balanced nodes
let mut balanced_nodes: Vec<(usize, Vec<usize>)> = vec![];
let mut seen_nodes = HashSet::with_capacity(spanning_tree_graph.node_count());
while !node_queue.is_empty() {
let node = node_queue.pop_front().unwrap();
if seen_nodes.contains(&node.index()) {
continue;
}

// Mark as seen; push to queue if only one unseen neighbor
let unseen_neighbors: Vec<NodeIndex> = spanning_tree
.graph
.neighbors(node)
.filter(|node| !seen_nodes.contains(&node.index()))
.collect();

if unseen_neighbors.len() == 1 {
// At leaf, will be false at root
let pop = pops[node.index()];

// Update neighbor pop
let neighbor = unseen_neighbors[0];
pops[neighbor.index()] += pop;

// Check if balanced; mark as seen
if pop >= pop_target * (1.0 - epsilon) && pop <= pop_target * (1.0 + epsilon) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reading the docs where a balanced edge is defined and the linked paper I'd guess that pop_target should be equal to the total sum of the population divided by 2. Is there any reason why we allow users to define a different value of pop_target? I'm asking since depending of the value of pop_target this check might fail for the other part of the partition.

balanced_nodes.push((node.index(), same_partition_tracker[node.index()].clone()));
}
seen_nodes.insert(node.index());

// Update neighbor partition tracker
let mut current_partition_tracker =
mem::take(&mut same_partition_tracker[node.index()]);
same_partition_tracker[neighbor.index()].append(&mut current_partition_tracker);

// Queue neighbor
node_queue.push_back(neighbor);
} else if unseen_neighbors.is_empty() {
// Is root
break;
} else {
// Not a leaf yet
continue;
}
Comment on lines +243 to +246
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not really needed since we'll continue in the iteration anyway.

Suggested change
} else {
// Not a leaf yet
continue;
}
} // else node is not a leaf yet

}

balanced_nodes
}

/// Bipartition graph into two contiguous, population-balanced components using
/// mst. Assumes that the graph is contiguous. See :func:`~bipartition_tree` for
InnovativeInventor marked this conversation as resolved.
Show resolved Hide resolved
/// details on how balance is defined.
///
/// :param PyGraph graph: Undirected graph
/// :param weight_fn: A callable object (function, lambda, etc) which
/// will be passed the edge object and expected to return a ``float``. See
/// :func:`~minimum_spanning_tree` for details.
Comment on lines +257 to +259
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be me bing picky, but I'd rename to something along the lines of weight_sample_fn and describe it as you described to me in the previous comment. It would make it clearer to the users and maintainers that the intended is non deterministic.

I'm saying so because in general, we expect weight_fn to always return the same result given the same edge. Moreover, wherever possible we also try to cache the calls to weight_fn to avoid calling Python from Rust more than necessary.

/// :param pops: The populations assigned to each node in the graph.
/// :param float pop_target: The population target to reach when partitioning
/// the graph.
/// :param float epsilon: The maximum percent deviation from the pop_target
/// allowed while still being a valid balanced cut edge.
///
/// :returns: A list of tuples, with each tuple representing a distinct
/// balanced edge that can be cut. The tuple contains the root of one of the
/// two partitioned subtrees and the set of nodes making up that subtree.
#[pyfunction]
#[pyo3(text_signature = "(graph, weight_fn, pops, target_pop, epsilon)")]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#[pyo3(text_signature = "(graph, weight_fn, pops, target_pop, epsilon)")]
#[pyo3(text_signature = "(graph, weight_fn, pops, pop_target, epsilon)")]

pub fn bipartition_graph_mst(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this function can be omitted from our API since it's just calling minimum_spanning_tree and bipartion_tree but we already provide these functions as part of our API.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I disagree. This allows reuse of the spanning tree object if there are no balanced edges detected and reduces the memory allocs (i.e. the amount of times the graph is cloned to create a new spanning tree object).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a fair point but I think we can avoid compromising performance and still omit this function with a different design here. We can implement a new pyclass SpanningTreeSampler that generates random spanning trees and reuse internally the same memory for stroring the MST and avoid unnecessary allocations. Users can then call:

sampler = retworkx.SpanningTreeSampler(graph)
balanced_nodes = []
while not balanced_nodes:
    tree = sampler.sample()
    balanced_nodes = retworkx.bipartition_tree(tree, ..)

to replicate the output of bipartition_graph_mst (and it'll be marginally slower).

My main motivation for the above design is two-fold:

  1. Sampling a spanning tree is an interesting problem on its own with different algorithms in the literature (which we can implement at a later point) and more users will benefit from it.
  2. The output of bipartition_graph_mst feels a bit artificial to me, e.g cutting an output edge will not necessarily cut the input graph into two connected components but it'll only cut the tree that we randomly drew.

py: Python,
graph: &graph::PyGraph,
weight_fn: PyObject,
pops: Vec<f64>,
pop_target: f64,
epsilon: f64,
) -> PyResult<Vec<(usize, Vec<usize>)>> {
let mut balanced_nodes: Vec<(usize, Vec<usize>)> = vec![];
let mut mst = (*graph).clone();

while balanced_nodes.is_empty() {
mst.graph.clear_edges();
_minimum_spanning_tree(py, graph, &mut mst, Some(weight_fn.clone()), 1.0)?;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get why we need to recalculate mst every time in the while loop (or even why different loops we'll give different results) since the values of all arguments stays the same.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We recalculate mst every time in the while loop because bipartition_graph_mst is a sampling algorithm and reusing mst when _bipartition_tree fails would change the distribution that we're sampling from.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be clear: weight_fn is intended to sample and be non-deterministic (e.g. uniformly sampling between the range [0, 1]).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might need to rename weight_fn to something else because on every other function weight_fn takes an edge and returns a weight. Which is not the case here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what you mean by that. This should be the same weight_fn that minimum_spanning_tree takes (that is, it takes in an edge and return a random weight). E.g. weight_fn can be: lambda _: random.random() (i.e. uniformly picks between [0,1]). However, there are other variants of the ReCom Markov Chain sampling algorithm that require certain edges to be prioritized/deprioritized like so: lambda x: random.random() if edge_crosses_county(x) else random.random() + 0.5. This is used to create redistricting plan sampling methods that respect county or city boundaries.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I totally agree with Ivan. This should be at least documented and given a different name since weight_fn is used in other places in retworkx with a different meaning and might be a source of confusion for users.

balanced_nodes = _bipartition_tree(&mst, pops.clone(), pop_target, epsilon);
}

Ok(balanced_nodes)
}
128 changes: 128 additions & 0 deletions tests/graph/test_bipartition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest

import retworkx


class TestBipartition(unittest.TestCase):
def setUp(self):
self.line = retworkx.PyGraph()
a = self.line.add_node(0)
b = self.line.add_node(1)
c = self.line.add_node(2)
d = self.line.add_node(3)
e = self.line.add_node(4)
f = self.line.add_node(5)

self.line.add_edges_from(
[
(a, b, 1),
(b, c, 1),
(c, d, 1),
(d, e, 1),
(e, f, 1),
]
)

self.tree = retworkx.PyGraph()
a = self.tree.add_node(0)
b = self.tree.add_node(1)
c = self.tree.add_node(2)
d = self.tree.add_node(3)
e = self.tree.add_node(4)
f = self.tree.add_node(5)

self.tree.add_edges_from(
[
(a, b, 1),
(a, d, 1),
(c, d, 1),
(a, f, 1),
(d, e, 1),
]
)

def test_one_balanced_edge_tree(self):
balanced_edges = retworkx.bipartition_tree(
self.tree,
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
3.0,
0.2,
)
self.assertEqual(len(balanced_edges), 1)

# Since this is already a spanning tree, bipartition_graph_mst should
# behave identically. That is, it should be invariant to weight_fn
graph_balanced_edges = retworkx.bipartition_graph_mst(
self.tree,
lambda _: 1,
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
3.0,
0.2,
)
self.assertEqual(balanced_edges, graph_balanced_edges)

def test_one_balanced_edge_tree_alt(self):
balanced_edges = retworkx.bipartition_tree(
self.tree,
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
3.0,
0.5,
)
self.assertEqual(len(balanced_edges), 1)

graph_balanced_edges = retworkx.bipartition_graph_mst(
self.tree,
lambda _: 1,
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
3.0,
0.5,
)
self.assertEqual(balanced_edges, graph_balanced_edges)

def test_three_balanced_edges_line(self):
balanced_edges = retworkx.bipartition_tree(
self.line,
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
3.0,
0.5,
)
self.assertEqual(len(balanced_edges), 3)

graph_balanced_edges = retworkx.bipartition_graph_mst(
self.line,
lambda _: 1,
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
3.0,
0.5,
)
self.assertEqual(balanced_edges, graph_balanced_edges)

def test_one_balanced_edges_line(self):
balanced_edges = retworkx.bipartition_tree(
self.line,
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
3.0,
0.01,
)
self.assertEqual(len(balanced_edges), 1)

graph_balanced_edges = retworkx.bipartition_graph_mst(
self.line,
lambda _: 1,
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
3.0,
0.01,
)
self.assertEqual(balanced_edges, graph_balanced_edges)