-
Notifications
You must be signed in to change notification settings - Fork 153
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement biparition_graph_mst and bipartition_tree functions #572
base: main
Are you sure you want to change the base?
Implement biparition_graph_mst and bipartition_tree functions #572
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know this is still a WIP and there is still some work to go on it, but I took a quick look and there are some easy optimizations you can make to improve the performance I commented on inline. I'll wait till it's closer to ready to do a more detailed review.
9a821f6
to
46d0950
Compare
Pull Request Test Coverage Report for Build 2402224991
💛 - Coveralls |
1d70fab
to
a33a77c
Compare
@mtreinish I just gave this a rebase and cleaned up the code a bit -- should be ready for your review. Let me know what you think! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for contributing this, we're glad retworkx was useful for your research!
Some things are missing before we are able to merge the PR, an action list would be:
- Write tests for the functions you added in
tests/
- Add a release note announcing the new method using
reno new bipartition_tree
- Add an entry of the new function on
docs/source/api.rst
You can find more details of the steps above in CONTRIBUTING.md. Feel free to ping me or Matthew if you need help
27010c9
to
f2fc8b4
Compare
11f9b9a
to
c777080
Compare
@IvanIsCoding @mtreinish Ok, I've added tests, made a release note with |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left some minor comments but overall the code is looking good
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall this LGTM, I'll have to read the paper in more depth to check the implementation matches what's expected but I trust you did that correctly if I don't have to time to check that. :) I just have a few more mechanical questions and suggestion inline but I think this is getting close. Thanks for sticking with this and continuing to push it forward.
877ba74
to
889dff5
Compare
4cbcf84
to
a025a07
Compare
@mtreinish @IvanIsCoding I think I addressed all of the feedback you two gave (thanks, by the way!). Let me know if I should re-open any of the comments or address any other concerns/feedback you have. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
||
while balanced_nodes.is_empty() { | ||
mst.graph.clear_edges(); | ||
_minimum_spanning_tree(py, graph, &mut mst, Some(weight_fn.clone()), 1.0)?; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't get why we need to recalculate mst
every time in the while loop (or even why different loops we'll give different results) since the values of all arguments stays the same.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We recalculate mst
every time in the while loop because bipartition_graph_mst
is a sampling algorithm and reusing mst
when _bipartition_tree
fails would change the distribution that we're sampling from.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To be clear: weight_fn
is intended to sample and be non-deterministic (e.g. uniformly sampling between the range [0, 1]
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might need to rename weight_fn to something else because on every other function weight_fn takes an edge and returns a weight. Which is not the case here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure what you mean by that. This should be the same weight_fn
that minimum_spanning_tree
takes (that is, it takes in an edge and return a random weight). E.g. weight_fn
can be: lambda _: random.random()
(i.e. uniformly picks between [0,1]
). However, there are other variants of the ReCom Markov Chain sampling algorithm that require certain edges to be prioritized/deprioritized like so: lambda x: random.random() if edge_crosses_county(x) else random.random() + 0.5
. This is used to create redistricting plan sampling methods that respect county or city boundaries.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I totally agree with Ivan. This should be at least documented and given a different name since weight_fn
is used in other places in retworkx with a different meaning and might be a source of confusion for users.
/// two partitioned subtrees and the set of nodes making up that subtree. | ||
#[pyfunction] | ||
#[pyo3(text_signature = "(graph, weight_fn, pop, target_pop, epsilon)")] | ||
pub fn bipartition_graph_mst( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO this function can be omitted from our API since it's just calling minimum_spanning_tree
and bipartion_tree
but we already provide these functions as part of our API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I disagree. This allows reuse of the spanning tree object if there are no balanced edges detected and reduces the memory allocs (i.e. the amount of times the graph is cloned to create a new spanning tree object).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a fair point but I think we can avoid compromising performance and still omit this function with a different design here. We can implement a new pyclass SpanningTreeSampler
that generates random spanning trees and reuse internally the same memory for stroring the MST and avoid unnecessary allocations. Users can then call:
sampler = retworkx.SpanningTreeSampler(graph)
balanced_nodes = []
while not balanced_nodes:
tree = sampler.sample()
balanced_nodes = retworkx.bipartition_tree(tree, ..)
to replicate the output of bipartition_graph_mst
(and it'll be marginally slower).
My main motivation for the above design is two-fold:
- Sampling a spanning tree is an interesting problem on its own with different algorithms in the literature (which we can implement at a later point) and more users will benefit from it.
- The output of
bipartition_graph_mst
feels a bit artificial to me, e.g cutting an output edge will not necessarily cut the input graph into two connected components but it'll only cut the tree that we randomly drew.
6b811ca
to
1edfa62
Compare
1edfa62
to
e7fdffe
Compare
FYI, I just rebased this off |
/// :param weight_fn: A callable object (function, lambda, etc) which | ||
/// will be passed the edge object and expected to return a ``float``. See | ||
/// :func:`~minimum_spanning_tree` for details. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be me bing picky, but I'd rename to something along the lines of weight_sample_fn
and describe it as you described to me in the previous comment. It would make it clearer to the users and maintainers that the intended is non deterministic.
I'm saying so because in general, we expect weight_fn
to always return the same result given the same edge. Moreover, wherever possible we also try to cache the calls to weight_fn
to avoid calling Python from Rust more than necessary.
--- | ||
features: | ||
- | | ||
Added a new function :func:`~.bipartition_tree` that takes in spanning tree |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a new function :func:`~.bipartition_tree` that takes in spanning tree | |
Added a new function :func:`~.bipartition_tree` that takes in a tree |
Added a new function :func:`~.bipartition_tree` that takes in spanning tree | ||
and a list of populations assigned to each node in the tree and finds all | ||
balanced edges, if they exist. A balanced edge is defined as an edge that, | ||
when cut, will split the population of the tree into two connected subtrees |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
when cut, will split the population of the tree into two connected subtrees | |
when cut, will split the tree into two connected subtrees |
/// Bipartition tree by finding balanced cut edges of a spanning tree using | ||
/// node contraction. Assumes that the tree is connected and is a spanning tree. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// Bipartition tree by finding balanced cut edges of a spanning tree using | |
/// node contraction. Assumes that the tree is connected and is a spanning tree. | |
/// Find all balanced cut edges of a tree. |
/// population of the tree into two connected subtrees that have population near | ||
/// the population target within some epsilon. The function returns a list of | ||
/// all such possible cuts, represented as the set of nodes in one | ||
/// partition/subtree. Wraps around ``_bipartition_tree``. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// partition/subtree. Wraps around ``_bipartition_tree``. | |
/// partition/subtree. |
/// all such possible cuts, represented as the set of nodes in one | ||
/// partition/subtree. Wraps around ``_bipartition_tree``. | ||
/// | ||
/// :param PyGraph graph: Spanning tree. Must be fully connected |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// :param PyGraph graph: Spanning tree. Must be fully connected | |
/// :param PyGraph tree: The input tree. |
} else { | ||
// Not a leaf yet | ||
continue; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not really needed since we'll continue in the iteration anyway.
} else { | |
// Not a leaf yet | |
continue; | |
} | |
} // else node is not a leaf yet |
/// balanced edge that can be cut. The tuple contains the root of one of the | ||
/// two partitioned subtrees and the set of nodes making up that subtree. | ||
#[pyfunction] | ||
#[pyo3(text_signature = "(graph, weight_fn, pops, target_pop, epsilon)")] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#[pyo3(text_signature = "(graph, weight_fn, pops, target_pop, epsilon)")] | |
#[pyo3(text_signature = "(graph, weight_fn, pops, pop_target, epsilon)")] |
} | ||
|
||
/// Internal _bipartition_tree implementation. | ||
fn _bipartition_tree( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't really need to put the code in an internal implementation. We can just define the public bipartition_tree
.
pops[neighbor.index()] += pop; | ||
|
||
// Check if balanced; mark as seen | ||
if pop >= pop_target * (1.0 - epsilon) && pop <= pop_target * (1.0 + epsilon) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reading the docs where a balanced edge is defined and the linked paper I'd guess that pop_target
should be equal to the total sum of the population divided by 2. Is there any reason why we allow users to define a different value of pop_target
? I'm asking since depending of the value of pop_target
this check might fail for the other part of the partition.
|
||
while balanced_nodes.is_empty() { | ||
mst.graph.clear_edges(); | ||
_minimum_spanning_tree(py, graph, &mut mst, Some(weight_fn.clone()), 1.0)?; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I totally agree with Ivan. This should be at least documented and given a different name since weight_fn
is used in other places in retworkx with a different meaning and might be a source of confusion for users.
See this PR: Qiskit/rustworkx#572 We thank the rustworkx reviewers for their suggestions. Co-authored-by: Max Fan <root@max.fan>
Note: this is a followup from email correspondence that @mtreinish and I had a few months ago. This is a draft PR, so feedback is welcome and it's not ready for merging (yet). It's fairly non-performant (but enough to beat Python/networkx solidly).
This function takes in a graph with population assigned to each node and draws a minimum spanning tree and finds a cut edge that splits the tree into two partitions that have total populations within some epsilon.
To explain the motivation behind this PR, this function is the main workhorse behind the ReCom algorithm detailed in this paper, which has been used in many litigation and civil rights projects to challenge racial and partisan gerrymandered maps and determine VRA compliance (see the recent cases in Alabama, North Carolina, Pennsylvania, etc.). I'm working on a rewrite of MGGG's main gerrymandering analysis software/engine (GerryChain, written in Python) and have achieved a ~15x speedup by naively rewriting the core graph operations in retworkx.
You can still see my messy debugging, but the rough idea is here.