Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always return a cycle in digraph_find_cycle if no node is specified and a cycle exists #1181

Merged
merged 21 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions releasenotes/notes/fix-digraph-find-cycle-141e302ff4a8fcd4.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
fixes:
IvanIsCoding marked this conversation as resolved.
Show resolved Hide resolved
- |
Fixed the behavior of :func:`~rustworkx.digraph_find_cycle` when
no source node was provided. Previously, the function would start looking
for a cycle at an arbitrary node which was not guaranteed to return a cycle.
Now, the function will smartly choose a source node to start the search from
such that if a cycle exists, it will be found.
other:
- |
The `rustworkx-core` function `rustworkx_core::connectivity::find_cycle` now
requires the `petgraph::visit::Visitable` trait.
112 changes: 82 additions & 30 deletions rustworkx-core/src/connectivity/find_cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
// under the License.

use hashbrown::{HashMap, HashSet};
use petgraph::algo;
use petgraph::visit::{
EdgeCount, GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount,
EdgeCount, GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, Visitable,
};
use petgraph::Direction::Outgoing;
use std::hash::Hash;
Expand Down Expand Up @@ -57,22 +58,22 @@ where
G: GraphBase,
G: NodeCount,
G: EdgeCount,
for<'b> &'b G: GraphBase<NodeId = G::NodeId> + IntoNodeIdentifiers + IntoNeighborsDirected,
for<'b> &'b G:
GraphBase<NodeId = G::NodeId> + IntoNodeIdentifiers + IntoNeighborsDirected + Visitable,
G::NodeId: Eq + Hash,
{
// Find a cycle in the given graph and return it as a list of edges
let mut graph_nodes: HashSet<G::NodeId> = graph.node_identifiers().collect();
let mut cycle: Vec<(G::NodeId, G::NodeId)> = Vec::with_capacity(graph.edge_count());
let temp_value: G::NodeId;
// If source is not set get an arbitrary node from the set of graph
// nodes we've not "examined"
// If source is not set get a node in an arbitrary cycle if it exists,
// otherwise return that there is no cycle
let source_index = match source {
Some(source_value) => source_value,
None => {
temp_value = *graph_nodes.iter().next().unwrap();
graph_nodes.remove(&temp_value);
temp_value
}
None => match find_node_in_arbitrary_cycle(&graph) {
Some(node_in_cycle) => node_in_cycle,
None => {
return Vec::new();
}
},
};
// Stack (ie "pushdown list") of vertices already in the spanning tree
let mut stack: Vec<G::NodeId> = vec![source_index];
Expand Down Expand Up @@ -119,11 +120,47 @@ where
cycle
}

fn find_node_in_arbitrary_cycle<G>(graph: &G) -> Option<G::NodeId>
where
G: GraphBase,
G: NodeCount,
G: EdgeCount,
for<'b> &'b G:
GraphBase<NodeId = G::NodeId> + IntoNodeIdentifiers + IntoNeighborsDirected + Visitable,
G::NodeId: Eq + Hash,
{
for scc in algo::kosaraju_scc(&graph) {
if scc.len() > 1 {
return Some(scc[0]);
}
}
for node in graph.node_identifiers() {
for neighbor in graph.neighbors_directed(node, Outgoing) {
if neighbor == node {
return Some(node);
}
}
}
None
}

#[cfg(test)]
mod tests {
use crate::connectivity::find_cycle;
use petgraph::prelude::*;

// Utility to assert cycles in the response
macro_rules! assert_cycle {
($g: expr, $cycle: expr) => {{
for i in 0..$cycle.len() {
let (s, t) = $cycle[i];
assert!($g.contains_edge(s, t));
let (next_s, _) = $cycle[(i + 1) % $cycle.len()];
assert_eq!(t, next_s);
}
}};
}

#[test]
fn test_find_cycle_source() {
let edge_list = vec![
Expand All @@ -141,20 +178,13 @@ mod tests {
(8, 9),
];
let graph = DiGraph::<i32, i32>::from_edges(edge_list);
let mut res: Vec<(usize, usize)> = find_cycle(&graph, Some(NodeIndex::new(0)))
.iter()
.map(|(s, t)| (s.index(), t.index()))
.collect();
assert_eq!(res, [(0, 1), (1, 2), (2, 3), (3, 0)]);
res = find_cycle(&graph, Some(NodeIndex::new(1)))
.iter()
.map(|(s, t)| (s.index(), t.index()))
.collect();
assert_eq!(res, [(1, 2), (2, 3), (3, 0), (0, 1)]);
res = find_cycle(&graph, Some(NodeIndex::new(5)))
.iter()
.map(|(s, t)| (s.index(), t.index()))
.collect();
for i in [0, 1, 2, 3].iter() {
let idx = NodeIndex::new(*i);
let res = find_cycle(&graph, Some(idx));
assert_cycle!(graph, res);
assert_eq!(res[0].0, idx);
}
let res = find_cycle(&graph, Some(NodeIndex::new(5)));
assert_eq!(res, []);
}

Expand All @@ -176,10 +206,32 @@ mod tests {
];
let mut graph = DiGraph::<i32, i32>::from_edges(edge_list);
graph.add_edge(NodeIndex::new(1), NodeIndex::new(1), 0);
let res: Vec<(usize, usize)> = find_cycle(&graph, Some(NodeIndex::new(0)))
.iter()
.map(|(s, t)| (s.index(), t.index()))
.collect();
assert_eq!(res, [(1, 1)]);
let res = find_cycle(&graph, Some(NodeIndex::new(0)));
assert_eq!(res[0].0, NodeIndex::new(1));
assert_cycle!(graph, res);
}

#[test]
fn test_self_loop_no_source() {
let edge_list = vec![(0, 1), (1, 2), (2, 3), (2, 2)];
let graph = DiGraph::<i32, i32>::from_edges(edge_list);
let res = find_cycle(&graph, None);
assert_cycle!(graph, res);
}

#[test]
fn test_cycle_no_source() {
let edge_list = vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 2)];
let graph = DiGraph::<i32, i32>::from_edges(edge_list);
let res = find_cycle(&graph, None);
assert_cycle!(graph, res);
}

#[test]
fn test_no_cycle_no_source() {
let edge_list = vec![(0, 1), (1, 2), (2, 3)];
let graph = DiGraph::<i32, i32>::from_edges(edge_list);
let res = find_cycle(&graph, None);
assert_eq!(res, []);
}
}
45 changes: 39 additions & 6 deletions tests/digraph/test_find_cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import unittest

import rustworkx
import rustworkx.generators


class TestFindCycle(unittest.TestCase):
Expand All @@ -36,30 +37,38 @@ def setUp(self):
]
)

def assertCycle(self, first_node, graph, res):
self.assertEqual(first_node, res[0][0])
for i in range(len(res)):
s, t = res[i]
self.assertTrue(graph.has_edge(s, t))
next_s, _ = res[(i + 1) % len(res)]
self.assertEqual(t, next_s)

def test_find_cycle(self):
graph = rustworkx.PyDiGraph()
graph.add_nodes_from(list(range(6)))
graph.add_edges_from_no_data(
[(0, 1), (0, 3), (0, 5), (1, 2), (2, 3), (3, 4), (4, 5), (4, 0)]
)
res = rustworkx.digraph_find_cycle(graph, 0)
self.assertEqual([(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)], res)
self.assertCycle(0, graph, res)

def test_find_cycle_multiple_roots_same_cycles(self):
res = rustworkx.digraph_find_cycle(self.graph, 0)
self.assertEqual(res, [(0, 1), (1, 2), (2, 3), (3, 0)])
self.assertCycle(0, self.graph, res)
res = rustworkx.digraph_find_cycle(self.graph, 1)
self.assertEqual(res, [(1, 2), (2, 3), (3, 0), (0, 1)])
self.assertCycle(1, self.graph, res)
res = rustworkx.digraph_find_cycle(self.graph, 5)
self.assertEqual(res, [])

def test_find_cycle_disconnected_graphs(self):
self.graph.add_nodes_from(["A", "B", "C"])
self.graph.add_edges_from_no_data([(10, 11), (12, 10), (11, 12)])
res = rustworkx.digraph_find_cycle(self.graph, 0)
self.assertEqual(res, [(0, 1), (1, 2), (2, 3), (3, 0)])
self.assertCycle(0, self.graph, res)
res = rustworkx.digraph_find_cycle(self.graph, 10)
self.assertEqual(res, [(10, 11), (11, 12), (12, 10)])
self.assertCycle(10, self.graph, res)

def test_invalid_types(self):
graph = rustworkx.PyGraph()
Expand All @@ -69,4 +78,28 @@ def test_invalid_types(self):
def test_self_loop(self):
self.graph.add_edge(1, 1, None)
res = rustworkx.digraph_find_cycle(self.graph, 0)
self.assertEqual([(1, 1)], res)
self.assertCycle(1, self.graph, res)

def test_no_cycle_no_source(self):
g = rustworkx.generators.directed_grid_graph(10, 10)
res = rustworkx.digraph_find_cycle(g)
self.assertEqual(res, [])

def test_cycle_no_source(self):
g = rustworkx.generators.directed_path_graph(1000)
a = g.add_node(1000)
b = g.node_indices()[-2]
g.add_edge(b, a, None)
g.add_edge(a, b, None)
res = rustworkx.digraph_find_cycle(g)
self.assertEqual(len(res), 2)
self.assertTrue(res[0] == res[1][::-1])

def test_cycle_self_loop(self):
g = rustworkx.generators.directed_path_graph(1000)
a = g.add_node(1000)
b = g.node_indices()[-1]
g.add_edge(b, a, None)
g.add_edge(a, a, None)
res = rustworkx.digraph_find_cycle(g)
self.assertEqual(res, [(a, a)])
Loading