Skip to content

Commit

Permalink
Change callbacks to take ids instead of weights, rename some variable…
Browse files Browse the repository at this point in the history
…s in tests for clarity.
  • Loading branch information
ElePT committed Jun 6, 2024
1 parent 23ba36b commit faf73f8
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ features:
- |
Added a new function ``collect_bicolor_runs`` to rustworkx-core's ``dag_algo`` module.
Previously, the ``collect_bicolor_runs`` functionality for DAGs was only exposed
via the Python interface. Now Rust users can take advantage of this functionality in rustworkx-core.
via the Python interface. Now Rust users can take advantage of this functionality in ``rustworkx-core``.
105 changes: 66 additions & 39 deletions rustworkx-core/src/dag_algo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use hashbrown::HashMap;
use petgraph::algo;
use petgraph::data::DataMap;
use petgraph::visit::{
Data, EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected,
EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoEdgeReferences, IntoNeighborsDirected,
IntoNodeIdentifiers, NodeCount, Visitable,
};
use petgraph::Directed;
Expand Down Expand Up @@ -362,13 +362,12 @@ where
/// graph.add_edge(n2, n3, 1);
/// graph.add_edge(n3, n4, 0);
/// graph.add_edge(n3, n5, 1);
/// // filter_fn and color_fn must share error type
/// fn filter_fn(node: &i32) -> Result<Option<bool>, Infallible> {
/// Ok(Some(*node > 0))
/// }
/// fn color_fn(edge: &i32) -> Result<Option<usize>, Infallible> {
/// Ok(Some(*edge as usize))
/// }
/// let filter_fn = |node_id| -> Result<Option<bool>, Infallible> {
/// Ok(Some(*graph.node_weight(node_id).unwrap() > 0))
/// };
/// let color_fn = |edge_id| -> Result<Option<usize>, Infallible> {
/// Ok(Some(*graph.edge_weight(edge_id).unwrap() as usize))
/// };
/// let result = collect_bicolor_runs(&graph, filter_fn, color_fn).unwrap();
/// let expected: Vec<Vec<NodeIndex>> = vec![vec![n2, n3]];
/// assert_eq!(result, Some(expected))
Expand All @@ -379,13 +378,16 @@ pub fn collect_bicolor_runs<G, F, C, E>(
color_fn: C,
) -> Result<Option<Vec<Vec<G::NodeId>>>, E>
where
F: Fn(G::NodeId) -> Result<Option<bool>, E>,
C: Fn(G::EdgeId) -> Result<Option<usize>, E>,
F: Fn(<G as GraphBase>::NodeId) -> Result<Option<bool>, E>,
C: Fn(<G as GraphBase>::EdgeId) -> Result<Option<usize>, E>,
G: IntoNodeIdentifiers // Used in toposort
+ IntoNeighborsDirected // Used in toposort
+ IntoEdgesDirected // Used for .edges_directed
+ IntoEdgeReferences
+ Visitable // Used in toposort
+ DataMap, // Used for .node_weight
<G as GraphBase>::NodeId: Eq + Hash,
<G as GraphBase>::EdgeId: Eq + Hash,
{
let mut pending_list: Vec<Vec<G::NodeId>> = Vec::new();
let mut block_id: Vec<Option<usize>> = Vec::new();
Expand All @@ -407,14 +409,13 @@ where
}

for node in nodes {
if let Some(is_match) = filter_fn(graph.node_weight(node).expect("Invalid NodeId"))? {
if let Some(is_match) = filter_fn(node)? {
let raw_edges = graph.edges_directed(node, petgraph::Direction::Outgoing);

// Remove all edges that yield errors from color_fn
let colors = raw_edges
.map(|edge| {
let edge_weight = edge.weight();
color_fn(edge_weight)
color_fn(edge.id())
})
.collect::<Result<Vec<Option<usize>>, _>>()?;

Expand Down Expand Up @@ -760,27 +761,25 @@ mod test_lexicographical_topological_sort {
mod test_collect_bicolor_runs {

use super::*;
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::graph::{DiGraph, NodeIndex, EdgeIndex};
use std::error::Error;

fn test_filter_fn(node: &i32) -> Result<Option<bool>, Box<dyn Error>> {
Ok(Some(*node > 1))
}

fn test_color_fn(edge: &i32) -> Result<Option<usize>, Box<dyn Error>> {
Ok(Some(*edge as usize))
}

#[test]
fn test_cycle() {
let mut graph = DiGraph::new();
let n0 = graph.add_node(2);
let n1 = graph.add_node(2);
let n2 = graph.add_node(2);
let n0 = graph.add_node(0);
let n1 = graph.add_node(0);
let n2 = graph.add_node(0);
graph.add_edge(n0, n1, 1);
graph.add_edge(n1, n2, 1);
graph.add_edge(n2, n0, 1);

let test_filter_fn = |_node_id: NodeIndex| -> Result<Option<bool>, Box<dyn Error>> {
Ok(Some(true))
};
let test_color_fn = |_edge_id: EdgeIndex| -> Result<Option<usize>, Box<dyn Error>>{
Ok(Some(1))
};
let result = match collect_bicolor_runs(&graph, test_filter_fn, test_color_fn) {
Ok(Some(_value)) => "Not None",
Ok(None) => "None",
Expand All @@ -794,14 +793,18 @@ mod test_collect_bicolor_runs {
let mut graph = DiGraph::new();
graph.add_node(0);

fn fail_function(node: &i32) -> Result<Option<bool>, Box<dyn Error>> {
if *node > 0 {
let fail_function = |node_id: NodeIndex| -> Result<Option<bool>, Box<dyn Error>> {
let node_weight: &i32 = graph.node_weight(node_id).expect("Invalid NodeId");
if *node_weight > 0 {
Ok(Some(true))
} else {
Err(Box::from("Failed!"))
}
}

};
let test_color_fn = |edge_id: EdgeIndex| -> Result<Option<usize>, Box<dyn Error>>{
let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge");
Ok(Some(*edge_weight as usize))
};
let result = match collect_bicolor_runs(&graph, fail_function, test_color_fn) {
Ok(Some(_value)) => "Not None",
Ok(None) => "None",
Expand All @@ -813,6 +816,14 @@ mod test_collect_bicolor_runs {
#[test]
fn test_empty() {
let graph = DiGraph::new();
let test_filter_fn = |node_id: NodeIndex| -> Result<Option<bool>, Box<dyn Error>> {
let node_weight: &i32 = graph.node_weight(node_id).expect("Invalid NodeId");
Ok(Some(*node_weight > 1))
};
let test_color_fn = |edge_id: EdgeIndex| -> Result<Option<usize>, Box<dyn Error>>{
let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge");
Ok(Some(*edge_weight as usize))
};
let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap();
let expected: Vec<Vec<NodeIndex>> = vec![];
assert_eq!(result, Some(expected))
Expand Down Expand Up @@ -850,8 +861,6 @@ mod test_collect_bicolor_runs {
Expected: [[cx, cz]]
*/
let mut graph = DiGraph::new();
// The node weight will correspond to the type of node
// All edges have the same weight in this example
let n0 = graph.add_node(0); //q0
let n1 = graph.add_node(1); //q1
let n2 = graph.add_node(2); //cx
Expand All @@ -865,6 +874,16 @@ mod test_collect_bicolor_runs {
graph.add_edge(n3, n4, 0); //cz -> q0_1
graph.add_edge(n3, n5, 1); //cz -> q1_1

// Filter out q0, q1, q0_1 and q1_1
let test_filter_fn = |node_id: NodeIndex| -> Result<Option<bool>, Box<dyn Error>> {
let node_weight: &i32 = graph.node_weight(node_id).expect("Invalid NodeId");
Ok(Some(*node_weight > 0))
};
// The edge color will match its weight
let test_color_fn = |edge_id: EdgeIndex| -> Result<Option<usize>, Box<dyn Error>>{
let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge");
Ok(Some(*edge_weight as usize))
};
let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap();
let expected: Vec<Vec<NodeIndex>> = vec![vec![n2, n3]]; //[[cx, cz]]
assert_eq!(result, Some(expected))
Expand Down Expand Up @@ -922,16 +941,14 @@ mod test_collect_bicolor_runs {
Expected: [[h, cx, cz, y]]
*/
let mut graph = DiGraph::new();
// The node weight will correspond to the type of node
// All edges have the same weight in this example
let n0 = graph.add_node(0); //q0
let n1 = graph.add_node(1); //q1
let n2 = graph.add_node(2); //h
let n3 = graph.add_node(3); //cx
let n4 = graph.add_node(4); //cz
let n5 = graph.add_node(5); //y
let n1 = graph.add_node(0); //q1
let n2 = graph.add_node(1); //h
let n3 = graph.add_node(1); //cx
let n4 = graph.add_node(1); //cz
let n5 = graph.add_node(1); //y
let n6 = graph.add_node(0); //q0_1
let n7 = graph.add_node(1); //q1_1
let n7 = graph.add_node(0); //q1_1
graph.add_edge(n0, n2, 0); //q0 -> h
graph.add_edge(n2, n3, 0); //h -> cx
graph.add_edge(n1, n3, 1); //q1 -> cx
Expand All @@ -941,6 +958,16 @@ mod test_collect_bicolor_runs {
graph.add_edge(n4, n5, 1); //cz -> y
graph.add_edge(n5, n7, 1); //y -> q1_1

// Filter out q0, q1, q0_1 and q1_1
let test_filter_fn = |node_id: NodeIndex| -> Result<Option<bool>, Box<dyn Error>> {
let node_weight: &i32 = graph.node_weight(node_id).expect("Invalid NodeId");
Ok(Some(*node_weight > 0))
};
// The edge color will match its weight
let test_color_fn = |edge_id: EdgeIndex| -> Result<Option<usize>, Box<dyn Error>>{
let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge");
Ok(Some(*edge_weight as usize))
};
let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap();
let expected: Vec<Vec<NodeIndex>> = vec![vec![n2, n3, n4, n5]]; //[[h, cx, cz, y]]
assert_eq!(result, Some(expected))
Expand Down
10 changes: 6 additions & 4 deletions src/dag_algo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -636,13 +636,15 @@ pub fn collect_bicolor_runs(
) -> PyResult<Vec<Vec<PyObject>>> {
let dag = &graph.graph;

let filter_fn_wrapper = |node: &PyObject| -> PyResult<Option<bool>> {
let res = filter_fn.call1(py, (node,))?;
let filter_fn_wrapper = |node_index| -> Result<Option<bool>, PyErr> {
let node_weight = dag.node_weight(node_index).expect("Invalid NodeId");
let res = filter_fn.call1(py, (node_weight,))?;
res.extract(py)
};

let color_fn_wrapper = |edge: &PyObject| -> PyResult<Option<usize>> {
let res = color_fn.call1(py, (edge,))?;
let color_fn_wrapper = |edge_index| -> Result<Option<usize>, PyErr> {
let edge_weight = dag.edge_weight(edge_index).expect("Invalid EdgeId");
let res = color_fn.call1(py, (edge_weight,))?;
res.extract(py)
};

Expand Down
24 changes: 12 additions & 12 deletions tests/digraph/test_collect_bicolor_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ def filter_function(node):
else:
return None

def color_function(node):
if "q" in node:
return int(node[1:])
def color_function(edge):
if "q" in edge:
return int(edge[1:])
else:
return None

Expand Down Expand Up @@ -187,9 +187,9 @@ def filter_function(node):
else:
return None

def color_function(node):
if "q" in node:
return int(node[1:])
def color_function(edge):
if "q" in edge:
return int(edge[1:])
else:
return None

Expand Down Expand Up @@ -264,9 +264,9 @@ def filter_function(node):
else:
return None

def color_function(node):
if "q" in node:
return int(node[1:])
def color_function(edge):
if "q" in edge:
return int(edge[1:])
else:
return None

Expand Down Expand Up @@ -338,9 +338,9 @@ def filter_function(node):
else:
return None

def color_function(node):
if "q" in node:
return int(node[1:])
def color_function(edge):
if "q" in edge:
return int(edge[1:])
else:
return None

Expand Down

0 comments on commit faf73f8

Please sign in to comment.