Implement biparition_graph_mst and bipartition_tree functions #572

Open
wants to merge 46 commits into
base: main

InnovativeInventor
Contributor

Note: this is a follow-up to email correspondence that @mtreinish and I had a few months ago. This is a draft PR, so feedback is welcome and it's not ready for merging (yet). It's fairly non-performant (but fast enough to beat Python/networkx solidly).

This function takes a graph with a population assigned to each node, draws a minimum spanning tree, and finds a cut edge that splits the tree into two partitions whose total populations are within some epsilon of a target.
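As a rough illustration of the idea described above, here is a minimal pure-Python sketch. It is not the Rust code in this PR: the helper names (`random_spanning_tree`, `balanced_cuts`) and the toy 4-cycle graph are assumptions made for the example, and it accumulates subtree populations from the leaves upward rather than reproducing the PR's implementation.

```python
import random


def random_spanning_tree(num_nodes, edges, rng):
    """Kruskal's algorithm over edges ordered by freshly sampled random weights."""
    parent = list(range(num_nodes))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    tree = []
    for u, v in sorted(edges, key=lambda _: rng.random()):
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            tree.append((u, v))
    return tree


def balanced_cuts(num_nodes, tree_edges, pops, pop_target, epsilon):
    """Return (child, parent) tree edges whose child-side subtree population
    lies within epsilon of pop_target."""
    adj = [[] for _ in range(num_nodes)]
    for u, v in tree_edges:
        adj[u].append(v)
        adj[v].append(u)

    # Iterative DFS from node 0: every node appears after its parent in `order`.
    parent = [-1] * num_nodes
    seen = [False] * num_nodes
    seen[0] = True
    order, stack = [], [0]
    while stack:
        node = stack.pop()
        order.append(node)
        for nb in adj[node]:
            if not seen[nb]:
                seen[nb] = True
                parent[nb] = node
                stack.append(nb)

    subtree = list(pops)
    lo, hi = pop_target * (1.0 - epsilon), pop_target * (1.0 + epsilon)
    cuts = []
    for node in reversed(order):  # children are handled before their parents
        if parent[node] == -1:
            continue  # the root has no edge above it
        if lo <= subtree[node] <= hi:
            cuts.append((node, parent[node]))
        subtree[parent[node]] += subtree[node]
    return cuts


rng = random.Random(0)
cycle = [(0, 1), (1, 2), (2, 3), (3, 0)]
tree = random_spanning_tree(4, cycle, rng)
cuts = balanced_cuts(4, tree, [1.0, 1.0, 1.0, 1.0], pop_target=2.0, epsilon=0.1)
```

Every spanning tree of a 4-cycle is a path, so with uniform populations the middle edge is the only balanced cut, whichever tree is drawn.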

To explain the motivation behind this PR, this function is the main workhorse behind the ReCom algorithm detailed in this paper, which has been used in many litigation and civil rights projects to challenge racial and partisan gerrymandered maps and determine VRA compliance (see the recent cases in Alabama, North Carolina, Pennsylvania, etc.). I'm working on a rewrite of MGGG's main gerrymandering analysis software/engine (GerryChain, written in Python) and have achieved a ~15x speedup by naively rewriting the core graph operations in retworkx.

You can still see my messy debugging, but the rough idea is here.

@InnovativeInventor InnovativeInventor marked this pull request as draft March 19, 2022 04:10
Member

@mtreinish mtreinish left a comment

I know this is still a WIP and there is still some work to go on it, but I took a quick look and there are some easy optimizations you can make to improve performance, which I commented on inline. I'll wait until it's closer to ready to do a more detailed review.

@coveralls

coveralls commented May 17, 2022

Pull Request Test Coverage Report for Build 2402224991

  • 94 of 96 (97.92%) changed or added relevant lines in 2 files are covered.
  • 1 unchanged line in 1 file lost coverage.
  • Overall coverage decreased (-0.002%) to 97.159%

Changes Missing Coverage | Covered Lines | Changed/Added Lines | %
src/tree.rs | 92 | 94 | 97.87%

Files with Coverage Reduction | New Missed Lines | %
src/shortest_path/all_pairs_dijkstra.rs | 1 | 98.54%

Totals
Change from base Build 2373810364: -0.002%
Covered Lines: 12277
Relevant Lines: 12636

💛 - Coveralls

@InnovativeInventor InnovativeInventor marked this pull request as ready for review May 18, 2022 01:49
@InnovativeInventor
Contributor Author

InnovativeInventor commented May 18, 2022

@mtreinish I just gave this a rebase and cleaned up the code a bit -- should be ready for your review. Let me know what you think!

Collaborator

@IvanIsCoding IvanIsCoding left a comment

Thanks for contributing this, we're glad retworkx was useful for your research!

Some things are missing before we can merge the PR. An action list would be:

  • Write tests for the functions you added in tests/
  • Add a release note announcing the new method using reno new bipartition_tree
  • Add an entry of the new function on docs/source/api.rst

You can find more details on the steps above in CONTRIBUTING.md. Feel free to ping me or Matthew if you need help.

@InnovativeInventor InnovativeInventor changed the title Implement biparition_tree function Implement biparition_graph function May 23, 2022
@InnovativeInventor InnovativeInventor force-pushed the feat-balanced-cut-edges branch 2 times, most recently from 11f9b9a to c777080 Compare May 24, 2022 22:02
@InnovativeInventor InnovativeInventor changed the title Implement biparition_graph function Implement biparition_graph and bipartition_tree functions May 24, 2022
@InnovativeInventor
Contributor Author

@IvanIsCoding @mtreinish Ok, I've added tests, made a release note with reno, and added a new entry of the function in the docs. Is there anything else you'd like me to do?

Collaborator

@IvanIsCoding IvanIsCoding left a comment

Left some minor comments, but overall the code is looking good.

Member

@mtreinish mtreinish left a comment

Overall this LGTM. I'll have to read the paper in more depth to check that the implementation matches what's expected, but I trust you did that correctly if I don't have time to check. :) I just have a few more mechanical questions and suggestions inline, but I think this is getting close. Thanks for sticking with this and continuing to push it forward.

@InnovativeInventor InnovativeInventor changed the title Implement biparition_graph and bipartition_tree functions Implement biparition_graph_mst and bipartition_tree functions May 27, 2022
@InnovativeInventor
Contributor Author

@mtreinish @IvanIsCoding I think I addressed all of the feedback you two gave (thanks, by the way!). Let me know if I should re-open any of the comments or address any other concerns/feedback you have.

Collaborator

@IvanIsCoding IvanIsCoding left a comment

LGTM


while balanced_nodes.is_empty() {
mst.graph.clear_edges();
_minimum_spanning_tree(py, graph, &mut mst, Some(weight_fn.clone()), 1.0)?;
Collaborator

I don't get why we need to recalculate mst every time in the while loop (or even why different iterations would give different results), since the values of all the arguments stay the same.

Contributor Author

We recalculate mst every time in the while loop because bipartition_graph_mst is a sampling algorithm and reusing mst when _bipartition_tree fails would change the distribution that we're sampling from.

Contributor Author

@InnovativeInventor InnovativeInventor May 28, 2022

To be clear: weight_fn is intended to sample and be non-deterministic (e.g., sampling uniformly from the range [0, 1]).

Collaborator

We might need to rename weight_fn to something else because in every other function weight_fn takes an edge and returns a weight, which is not the case here.

Contributor Author

I'm not sure what you mean by that. This should be the same weight_fn that minimum_spanning_tree takes (that is, it takes in an edge and returns a random weight). E.g., weight_fn can be: lambda _: random.random() (i.e., it picks uniformly from [0, 1]). However, there are other variants of the ReCom Markov chain sampling algorithm that require certain edges to be prioritized/deprioritized, like so: lambda x: random.random() if edge_crosses_county(x) else random.random() + 0.5. This is used to create redistricting plan sampling methods that respect county or city boundaries.
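For concreteness, the two sampling weight functions mentioned above could be written out as follows. This is only a sketch: `edge_crosses_county` is never defined in this thread, so the version here is a hypothetical predicate that assumes each edge payload is a dict recording the county of both endpoints.

```python
import random


def edge_crosses_county(edge):
    # Hypothetical predicate (an assumption for this example): the edge
    # payload is a dict holding the county of each endpoint.
    return edge["county_u"] != edge["county_v"]


def uniform_weight(_edge):
    # A fresh uniform draw from [0, 1) on every call, ignoring the edge.
    return random.random()


def county_weight(edge):
    # Same expression as in the comment above: one branch is shifted by 0.5.
    return random.random() if edge_crosses_county(edge) else random.random() + 0.5


same = {"county_u": "A", "county_v": "A"}
cross = {"county_u": "A", "county_v": "B"}

w_uniform = uniform_weight(same)
w_same = county_weight(same)
w_cross = county_weight(cross)
```

Because both functions draw a new random number on each call, two calls with the same edge generally return different weights. Whichever branch receives the +0.5 offset tends to get heavier weights, so those edges are less likely to be picked by a minimum spanning tree.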

Collaborator

I totally agree with Ivan. This should be at least documented and given a different name since weight_fn is used in other places in retworkx with a different meaning and might be a source of confusion for users.

/// two partitioned subtrees and the set of nodes making up that subtree.
#[pyfunction]
#[pyo3(text_signature = "(graph, weight_fn, pop, target_pop, epsilon)")]
pub fn bipartition_graph_mst(
Collaborator

IMO this function can be omitted from our API since it just calls minimum_spanning_tree and bipartition_tree, and we already provide these functions as part of our API.

Contributor Author

I disagree. This allows reuse of the spanning tree object if no balanced edges are detected and reduces memory allocations (i.e., the number of times the graph is cloned to create a new spanning tree object).

Collaborator

That's a fair point, but I think we can avoid compromising performance and still omit this function with a different design. We can implement a new pyclass, SpanningTreeSampler, that generates random spanning trees and internally reuses the same memory for storing the MST, avoiding unnecessary allocations. Users can then call:

sampler = retworkx.SpanningTreeSampler(graph)
balanced_nodes = []
while not balanced_nodes:
    tree = sampler.sample()
    balanced_nodes = retworkx.bipartition_tree(tree, ..)

to replicate the output of bipartition_graph_mst (and it'll be marginally slower).

My main motivation for the above design is two-fold:

  1. Sampling a spanning tree is an interesting problem on its own with different algorithms in the literature (which we can implement at a later point) and more users will benefit from it.
  2. The output of bipartition_graph_mst feels a bit artificial to me, e.g., cutting an output edge will not necessarily cut the input graph into two connected components; it will only cut the tree that we randomly drew.
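The second point can be checked on a tiny example. The sketch below uses plain Python with a hypothetical helper `count_components` (not part of retworkx) and a hand-picked spanning tree of a 4-cycle. Removing the balanced tree edge splits the tree, but the original graph stays connected through the edge that the tree left out.

```python
def count_components(num_nodes, edges):
    """Count connected components of an undirected graph with a DFS."""
    adj = [[] for _ in range(num_nodes)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    seen = [False] * num_nodes
    count = 0
    for start in range(num_nodes):
        if seen[start]:
            continue
        count += 1
        seen[start] = True
        stack = [start]
        while stack:
            node = stack.pop()
            for nb in adj[node]:
                if not seen[nb]:
                    seen[nb] = True
                    stack.append(nb)
    return count


graph_edges = [(0, 1), (1, 2), (2, 3), (3, 0)]  # a 4-cycle
tree_edges = [(0, 1), (1, 2), (2, 3)]           # one spanning tree of it
cut = (1, 2)                                    # balanced edge of that tree

tree_parts = count_components(4, [e for e in tree_edges if e != cut])
graph_parts = count_components(4, [e for e in graph_edges if e != cut])
```

Here `tree_parts` is 2 while `graph_parts` is 1. To separate the graph itself into the two node sets, every graph edge running between them (here also `(3, 0)`) would have to be removed.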

@InnovativeInventor
Contributor Author

FYI, I just rebased this on main since this PR appeared to be out of date with main.

Comment on lines +257 to +259
/// :param weight_fn: A callable object (function, lambda, etc) which
/// will be passed the edge object and expected to return a ``float``. See
/// :func:`~minimum_spanning_tree` for details.
Collaborator

This might be me being picky, but I'd rename it to something along the lines of weight_sample_fn and describe it as you described it to me in the previous comment. It would make it clearer to users and maintainers that the intended behavior is non-deterministic.

I'm saying so because in general, we expect weight_fn to always return the same result given the same edge. Moreover, wherever possible we also try to cache the calls to weight_fn to avoid calling Python from Rust more than necessary.
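To illustrate why caching and a sampling weight function do not mix, here is a small sketch. The `make_cached` wrapper is a hypothetical stand-in written for this example and does not reflect how retworkx caches calls internally. Once a weight is cached per edge, every later "sample" sees the same weights and would therefore produce the same spanning tree.

```python
import random


def make_cached(weight_fn):
    """Wrap weight_fn so each edge index is only evaluated once."""
    cache = {}

    def cached(edge_index, edge):
        if edge_index not in cache:
            cache[edge_index] = weight_fn(edge)
        return cache[edge_index]

    return cached


rng = random.Random(42)


def sample_weight(_edge):
    # Non-deterministic by design: a new draw on every call.
    return rng.random()


cached_weight = make_cached(sample_weight)

first = [cached_weight(i, None) for i in range(3)]
second = [cached_weight(i, None) for i in range(3)]  # served from the cache
fresh = [sample_weight(None) for _ in range(3)]      # three new draws
```

With the cache in place, `second` repeats `first` exactly, whereas calling the sampling function directly yields new weights each time.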

---
features:
- |
Added a new function :func:`~.bipartition_tree` that takes in spanning tree
Collaborator

Suggested change
- Added a new function :func:`~.bipartition_tree` that takes in spanning tree
+ Added a new function :func:`~.bipartition_tree` that takes in a tree

Added a new function :func:`~.bipartition_tree` that takes in spanning tree
and a list of populations assigned to each node in the tree and finds all
balanced edges, if they exist. A balanced edge is defined as an edge that,
when cut, will split the population of the tree into two connected subtrees
Collaborator

Suggested change
- when cut, will split the population of the tree into two connected subtrees
+ when cut, will split the tree into two connected subtrees

Comment on lines +154 to +155
/// Bipartition tree by finding balanced cut edges of a spanning tree using
/// node contraction. Assumes that the tree is connected and is a spanning tree.
Collaborator

Suggested change
- /// Bipartition tree by finding balanced cut edges of a spanning tree using
- /// node contraction. Assumes that the tree is connected and is a spanning tree.
+ /// Find all balanced cut edges of a tree.

/// population of the tree into two connected subtrees that have population near
/// the population target within some epsilon. The function returns a list of
/// all such possible cuts, represented as the set of nodes in one
/// partition/subtree. Wraps around ``_bipartition_tree``.
Collaborator

Suggested change
- /// partition/subtree. Wraps around ``_bipartition_tree``.
+ /// partition/subtree.

/// all such possible cuts, represented as the set of nodes in one
/// partition/subtree. Wraps around ``_bipartition_tree``.
///
/// :param PyGraph graph: Spanning tree. Must be fully connected
Collaborator

Suggested change
- /// :param PyGraph graph: Spanning tree. Must be fully connected
+ /// :param PyGraph tree: The input tree.

Comment on lines +243 to +246
} else {
// Not a leaf yet
continue;
}
Collaborator

This is not really needed since we'll continue in the iteration anyway.

Suggested change
- } else {
-     // Not a leaf yet
-     continue;
- }
+ } // else node is not a leaf yet

/// balanced edge that can be cut. The tuple contains the root of one of the
/// two partitioned subtrees and the set of nodes making up that subtree.
#[pyfunction]
#[pyo3(text_signature = "(graph, weight_fn, pops, target_pop, epsilon)")]
Collaborator

Suggested change
- #[pyo3(text_signature = "(graph, weight_fn, pops, target_pop, epsilon)")]
+ #[pyo3(text_signature = "(graph, weight_fn, pops, pop_target, epsilon)")]

}

/// Internal _bipartition_tree implementation.
fn _bipartition_tree(
Collaborator

We don't really need to put the code in an internal implementation. We can just define the public bipartition_tree.

pops[neighbor.index()] += pop;

// Check if balanced; mark as seen
if pop >= pop_target * (1.0 - epsilon) && pop <= pop_target * (1.0 + epsilon) {
Collaborator

Reading the docs where a balanced edge is defined, and the linked paper, I'd guess that pop_target should be equal to the total population divided by 2. Is there any reason why we allow users to define a different value of pop_target? I'm asking because, depending on the value of pop_target, this check might fail for the other part of the partition.
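A quick numeric check of this concern, as a sketch: `is_balanced` below mirrors the inequality in the snippet under review, and the populations are made-up values chosen for illustration.

```python
def is_balanced(pop, pop_target, epsilon):
    # Same inequality as the check under review.
    return pop_target * (1.0 - epsilon) <= pop <= pop_target * (1.0 + epsilon)


total = 10.0
side = 4.0
other = total - side  # 6.0

# pop_target != total / 2: one side passes, its complement does not.
lopsided = (is_balanced(side, 4.0, 0.1), is_balanced(other, 4.0, 0.1))

# pop_target == total / 2: the accepted interval is symmetric around
# total / 2, so the two sides pass or fail together.
halved = (is_balanced(side, 5.0, 0.1), is_balanced(other, 5.0, 0.1))
even = (is_balanced(5.0, 5.0, 0.1), is_balanced(total - 5.0, 5.0, 0.1))
```

With a target of 4 and epsilon of 0.1, the accepted range is [3.6, 4.4], so the side with population 4 is accepted while the remaining 6 is not. With a target of 5 the range is [4.5, 5.5], and a side is accepted exactly when its complement is.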


pjrule added a commit to mggg/gerrychain.rs that referenced this pull request Apr 29, 2023
See this PR: Qiskit/rustworkx#572
We thank the rustworkx reviewers for their suggestions.

Co-authored-by: Max Fan <root@max.fan>