-
Notifications
You must be signed in to change notification settings - Fork 141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Transitive Reduction Function #923
Changes from 33 commits
b88741d
a7de5ad
ca01739
06119be
1a0f219
a5301ef
8ed4220
7f8b655
363dcb4
9e13637
496e1e8
ece5717
193e55a
fa60e89
284b4d0
f3ba571
a9df47d
0be3566
ef783a5
07fdde9
131759d
d8763f7
f2cfc58
ae0515e
30c4687
80c1c99
d8e35eb
de61f7a
c06559a
b28f1ac
506baaf
f89f547
9a78924
b961f4f
2ac4498
7a4ea26
250876f
6aa24dc
64dc66b
2dc72e5
32b9abf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
--- | ||
features: | ||
- | | ||
Added a new function, :func:`~.transitive_reduction` which returns the transtive reduction | ||
of a given :class:`~rustworkx.PyDiGraph`. This graph must be a Directed Acyclic Graph (DAG). | ||
|
||
Ref: https://en.wikipedia.org/wiki/Transitive_reduction | ||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -17,7 +17,9 @@ use std::cmp::Ordering; | |||||||
use std::collections::BinaryHeap; | ||||||||
|
||||||||
use super::iterators::NodeIndices; | ||||||||
use crate::{digraph, DAGHasCycle, InvalidNode}; | ||||||||
use crate::{digraph, DAGHasCycle, InvalidNode, StablePyGraph}; | ||||||||
|
||||||||
use rustworkx_core::traversal::dfs_edges; | ||||||||
|
||||||||
use pyo3::exceptions::PyValueError; | ||||||||
use pyo3::prelude::*; | ||||||||
|
@@ -637,3 +639,74 @@ pub fn collect_bicolor_runs( | |||||||
|
||||||||
Ok(block_list) | ||||||||
} | ||||||||
|
||||||||
/// Returns the transitive reduction of a directed acyclic graph | ||||||||
/// | ||||||||
/// The transitive reduction of G = (V,E) is a graph G- = (V,E-) | ||||||||
/// such that for all v,w in V there is an edge (v,w) in E- if and only if (v,w) is in E | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
/// and there is no path from v to w in G with length greater than 1. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
/// | ||||||||
/// :param PyDiGraph graph: A directed acyclic graph | ||||||||
/// | ||||||||
/// :returns: a directed acyclic graph representing the transitive reduction | ||||||||
/// :rtype: PyDiGraph | ||||||||
/// | ||||||||
/// :raises PyValueError: if graph is not a DAG | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
#[pyfunction] | ||||||||
#[pyo3(text_signature = "(graph, /)")] | ||||||||
pub fn transitive_reduction( | ||||||||
graph: &digraph::PyDiGraph, | ||||||||
py: Python, | ||||||||
) -> PyResult<digraph::PyDiGraph> { | ||||||||
let g = &graph.graph; | ||||||||
if !is_directed_acyclic_graph(graph) { | ||||||||
return Err(PyValueError::new_err( | ||||||||
"Directed Acyclic Graph required for transitive_reduction", | ||||||||
)); | ||||||||
} | ||||||||
let mut tr = StablePyGraph::<Directed>::new(); | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
let mut descendants = HashMap::new(); | ||||||||
danielleodigie marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
let mut check_count = HashMap::new(); | ||||||||
danielleodigie marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
for (i, node) in g.node_indices().enumerate() { | ||||||||
tr.add_node(graph.get_node_data(i).unwrap().clone_ref(py)); | ||||||||
check_count.insert(node, graph.in_degree(i)); | ||||||||
} | ||||||||
|
||||||||
for u in g.node_indices() { | ||||||||
let mut u_nbrs: HashSet<NodeIndex> = g.neighbors(u).collect(); | ||||||||
danielleodigie marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
for v in g.neighbors(u) { | ||||||||
if u_nbrs.contains(&v) { | ||||||||
if !descendants.contains_key(&v) { | ||||||||
let dfs = dfs_edges(&g, Some(v)); | ||||||||
descendants.insert(v, dfs); | ||||||||
} | ||||||||
for desc in &descendants[&v] { | ||||||||
u_nbrs.remove(&NodeIndex::new(desc.1)); | ||||||||
} | ||||||||
} | ||||||||
*check_count.get_mut(&v).unwrap() -= 1; | ||||||||
if check_count[&v] == 0 { | ||||||||
descendants.remove(&v); | ||||||||
} | ||||||||
danielleodigie marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
} | ||||||||
for v in u_nbrs { | ||||||||
tr.add_edge( | ||||||||
u, | ||||||||
v, | ||||||||
graph | ||||||||
.get_edge_data(u.index(), v.index()) | ||||||||
.unwrap() | ||||||||
.clone_ref(py), | ||||||||
); | ||||||||
} | ||||||||
} | ||||||||
Ok(digraph::PyDiGraph { | ||||||||
graph: tr, | ||||||||
node_removed: false, | ||||||||
danielleodigie marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
multigraph: graph.multigraph, | ||||||||
attrs: py.None(), | ||||||||
cycle_state: algo::DfsSpace::default(), | ||||||||
check_cycle: graph.check_cycle, | ||||||||
}) | ||||||||
} |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You might want to add more tests, just to make sure all bases are covered. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Licensed under the Apache License, Version 2.0 (the "License"); you may | ||
# not use this file except in compliance with the License. You may obtain | ||
# a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
# License for the specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import unittest | ||
|
||
import rustworkx | ||
|
||
|
||
class TestTransitiveReduction(unittest.TestCase): | ||
def test_tr(self): | ||
graph = rustworkx.PyDiGraph() | ||
a = graph.add_node("a") | ||
b = graph.add_node("b") | ||
c = graph.add_node("c") | ||
d = graph.add_node("d") | ||
e = graph.add_node("e") | ||
|
||
graph.add_edges_from( | ||
[(a, b, 1), (a, d, 1), (a, c, 1), (a, e, 1), (b, d, 1), (c, d, 1), (c, e, 1), (d, e, 1)] | ||
) | ||
|
||
tr = rustworkx.transitive_reduction(graph) | ||
self.assertCountEqual(list(tr.edge_list()), [(0, 2), (0, 1), (1, 3), (2, 3), (3, 4)]) | ||
mtreinish marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
graph2 = rustworkx.PyDiGraph() | ||
a = graph2.add_node("a") | ||
b = graph2.add_node("b") | ||
c = graph2.add_node("c") | ||
|
||
graph2.add_edges_from( | ||
[ | ||
(a, b, 1), | ||
(b, c, 1), | ||
(a, c, 1), | ||
] | ||
) | ||
|
||
tr2 = rustworkx.transitive_reduction(graph2) | ||
self.assertCountEqual(list(tr2.edge_list()), [(0, 1), (1, 2)]) | ||
|
||
def test_tr_error(self): | ||
digraph = rustworkx.generators.directed_cycle_graph(1000) | ||
with self.assertRaises(ValueError): | ||
rustworkx.transitive_reduction(digraph) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.