diff --git a/releasenotes/notes/fix-panic-3962ad36788cab00.yaml b/releasenotes/notes/fix-panic-3962ad36788cab00.yaml new file mode 100644 index 000000000..74156a42a --- /dev/null +++ b/releasenotes/notes/fix-panic-3962ad36788cab00.yaml @@ -0,0 +1,14 @@ +--- +fixes: + - | + Fixed an issue with the Dijkstra path functions: + + * :func:`rustworkx.dijkstra_shortest_paths` + * :func:`rustworkx.dijkstra_shortest_path_lengths` + * :func:`rustworkx.bellman_ford_shortest_path_lengths` + * :func:`rustworkx.bellman_ford_shortest_paths` + * :func:`rustworkx.astar_shortest_path` + + where a `Pyo3.PanicException`were raise with no much detail at the moment + of pass in the `source` argument the index of an out of bound node. + Fixed `#1117 `__ diff --git a/src/shortest_path/mod.rs b/src/shortest_path/mod.rs index 959b1cbdd..26429f1ff 100644 --- a/src/shortest_path/mod.rs +++ b/src/shortest_path/mod.rs @@ -78,6 +78,12 @@ pub fn graph_dijkstra_shortest_paths( default_weight: f64, ) -> PyResult { let start = NodeIndex::new(source); + if !graph.graph.contains_node(start) { + return Err(PyIndexError::new_err(format!( + "Node source index \"{source}\" out of graph bound" + ))); + } + let goal_index: Option = target.map(NodeIndex::new); let mut paths: DictMap> = DictMap::with_capacity(graph.node_count()); @@ -217,6 +223,12 @@ pub fn digraph_dijkstra_shortest_paths( as_undirected: bool, ) -> PyResult { let start = NodeIndex::new(source); + if !graph.graph.contains_node(start) { + return Err(PyIndexError::new_err(format!( + "Node source index \"{source}\" out of graph bound" + ))); + } + let goal_index: Option = target.map(NodeIndex::new); let mut paths: DictMap> = DictMap::with_capacity(graph.node_count()); let cost_fn = CostFn::try_from((weight_fn, default_weight))?; @@ -371,10 +383,16 @@ pub fn graph_dijkstra_shortest_path_lengths( edge_cost_fn: PyObject, goal: Option, ) -> PyResult { - let edge_cost_callable = CostFn::from(edge_cost_fn); let start = NodeIndex::new(node); + let edge_cost_callable = CostFn::from(edge_cost_fn); let goal_index: Option = goal.map(NodeIndex::new); + if !graph.graph.contains_node(start) { + return Err(PyIndexError::new_err(format!( + "Node source index \"{node}\" out of graph bound" + ))); + } + let res: Vec> = dijkstra( &graph.graph, start, @@ -445,6 +463,12 @@ pub fn digraph_dijkstra_shortest_path_lengths( let start = NodeIndex::new(node); let goal_index: Option = goal.map(NodeIndex::new); + if !graph.graph.contains_node(start) { + return Err(PyIndexError::new_err(format!( + "Node source index \"{node}\" out of graph bound" + ))); + } + let res: Vec> = dijkstra( &graph.graph, start, @@ -671,6 +695,12 @@ pub fn digraph_astar_shortest_path( let estimate_cost_callable = CostFn::from(estimate_cost_fn); let start = NodeIndex::new(node); + if !graph.graph.contains_node(start) { + return Err(PyIndexError::new_err(format!( + "Node source index \"{node}\" out of graph bound" + ))); + } + let astar_res = astar( &graph.graph, start, @@ -729,6 +759,12 @@ pub fn graph_astar_shortest_path( let estimate_cost_callable = CostFn::from(estimate_cost_fn); let start = NodeIndex::new(node); + if !graph.graph.contains_node(start) { + return Err(PyIndexError::new_err(format!( + "Node source index \"{node}\" out of graph bound" + ))); + } + let astar_res = astar( &graph.graph, start, @@ -1561,6 +1597,12 @@ pub fn digraph_bellman_ford_shortest_path_lengths( let start = NodeIndex::new(node); + if !graph.graph.contains_node(start) { + return Err(PyIndexError::new_err(format!( + "Node source index \"{node}\" out of graph bound" + ))); + } + let res: Option>> = bellman_ford(&graph.graph, start, |e| edge_cost(e.id()), None)?; @@ -1640,6 +1682,12 @@ pub fn graph_bellman_ford_shortest_path_lengths( let start = NodeIndex::new(node); + if !graph.graph.contains_node(start) { + return Err(PyIndexError::new_err(format!( + "Node source index \"{node}\" out of graph bound" + ))); + } + let res: Option>> = bellman_ford(&graph.graph, start, |e| edge_cost(e.id()), None)?; @@ -1715,6 +1763,13 @@ pub fn graph_bellman_ford_shortest_paths( default_weight: f64, ) -> PyResult { let start = NodeIndex::new(source); + + if !graph.graph.contains_node(start) { + return Err(PyIndexError::new_err(format!( + "Node source index \"{source}\" out of graph bound" + ))); + } + let mut paths: DictMap> = DictMap::with_capacity(graph.node_count()); let edge_weights: Vec> = @@ -1801,6 +1856,13 @@ pub fn digraph_bellman_ford_shortest_paths( } let start = NodeIndex::new(source); + + if !graph.graph.contains_node(start) { + return Err(PyIndexError::new_err(format!( + "Node source index \"{source}\" out of graph bound" + ))); + } + let mut paths: DictMap> = DictMap::with_capacity(graph.node_count()); let edge_weights: Vec> = diff --git a/tests/digraph/test_astar.py b/tests/digraph/test_astar.py index 5a23f74d1..b3ee80997 100644 --- a/tests/digraph/test_astar.py +++ b/tests/digraph/test_astar.py @@ -112,3 +112,17 @@ def test_astar_with_invalid_weights(self): edge_cost_fn=lambda _: invalid_weight, estimate_cost_fn=lambda _: 0, ) + + def test_astar_with_invalid_source_node(self): + g = rustworkx.PyDAG() + a = g.add_node("A") + b = g.add_node("B") + g.add_edge(a, b, 7) + with self.assertRaises(IndexError): + rustworkx.digraph_astar_shortest_path( + g, + len(g.node_indices()) + 1, + goal_fn=lambda goal: goal == "B", + edge_cost_fn=lambda x: float(x), + estimate_cost_fn=lambda _: 0, + ) diff --git a/tests/digraph/test_bellman_ford.py b/tests/digraph/test_bellman_ford.py index 7bab70a56..a502a6956 100644 --- a/tests/digraph/test_bellman_ford.py +++ b/tests/digraph/test_bellman_ford.py @@ -442,3 +442,15 @@ def test_raises_negative_cycle_all_pairs_bellman_ford_path_lenghts(self): with self.assertRaises(rustworkx.NegativeCycle): rustworkx.all_pairs_bellman_ford_path_lengths(graph, float) + + def test_raises_index_error_bellman_ford_paths(self): + with self.assertRaises(IndexError): + rustworkx.digraph_bellman_ford_shortest_paths( + self.graph, len(self.graph.node_indices()) + 1, weight_fn=lambda x: float(x) + ) + + def test_raises_index_error_bellman_ford_path_lenghts(self): + with self.assertRaises(IndexError): + rustworkx.digraph_bellman_ford_shortest_path_lengths( + self.graph, len(self.graph.node_indices()) + 1, edge_cost_fn=lambda x: float(x) + ) diff --git a/tests/digraph/test_dijkstra.py b/tests/digraph/test_dijkstra.py index 7ec146589..614503cbe 100644 --- a/tests/digraph/test_dijkstra.py +++ b/tests/digraph/test_dijkstra.py @@ -326,3 +326,13 @@ def all_pairs_dijkstra_lenghts_with_invalid_weights(self): rustworkx.digraph_all_pairs_dijkstra_path_lengths( graph, edge_cost_fn=lambda _: invalid_weight ) + + def test_dijkstra_path_digraph_with_invalid_source(self): + with self.assertRaises(IndexError): + rustworkx.dijkstra_shortest_paths(self.graph, len(self.graph.node_indices()) + 1) + + def test_dijkstra_path_digraph_lengths_with_invalid_source(self): + with self.assertRaises(IndexError): + rustworkx.dijkstra_shortest_path_lengths( + self.graph, len(self.graph.node_indices()) + 1, edge_cost_fn=lambda x: x + ) diff --git a/tests/graph/test_astar.py b/tests/graph/test_astar.py index 2728c39e5..62bdf9b29 100644 --- a/tests/graph/test_astar.py +++ b/tests/graph/test_astar.py @@ -112,3 +112,17 @@ def test_astar_with_invalid_weights(self): edge_cost_fn=lambda _: invalid_weight, estimate_cost_fn=lambda _: 0, ) + + def test_astar_with_invalid_source_node(self): + g = rustworkx.PyGraph() + a = g.add_node("A") + b = g.add_node("B") + g.add_edge(a, b, 7) + with self.assertRaises(IndexError): + rustworkx.graph_astar_shortest_path( + g, + len(g.node_indices()) + 1, + goal_fn=lambda goal: goal == "B", + edge_cost_fn=lambda x: float(x), + estimate_cost_fn=lambda _: 0, + ) diff --git a/tests/graph/test_bellman_ford.py b/tests/graph/test_bellman_ford.py index 12d87a2b8..d0dd3f106 100644 --- a/tests/graph/test_bellman_ford.py +++ b/tests/graph/test_bellman_ford.py @@ -306,3 +306,15 @@ def test_raises_negative_cycle_all_pairs_bellman_ford_path_lenghts(self): with self.assertRaises(rustworkx.NegativeCycle): rustworkx.all_pairs_bellman_ford_path_lengths(graph, float) + + def test_raises_index_error_bellman_ford_paths(self): + with self.assertRaises(IndexError): + rustworkx.graph_bellman_ford_shortest_paths( + self.graph, len(self.graph.node_indices()) + 1, weight_fn=lambda x: float(x) + ) + + def test_raises_index_error_bellman_ford_path_lenghts(self): + with self.assertRaises(IndexError): + rustworkx.graph_bellman_ford_shortest_path_lengths( + self.graph, len(self.graph.node_indices()) + 1, edge_cost_fn=lambda x: float(x) + ) diff --git a/tests/graph/test_dijkstra.py b/tests/graph/test_dijkstra.py index 3405036fe..745502761 100644 --- a/tests/graph/test_dijkstra.py +++ b/tests/graph/test_dijkstra.py @@ -225,6 +225,16 @@ def dijkstra_with_invalid_weights(self): as_undirected=as_undirected, ) + def test_dijkstra_path_with_invalid_source(self): + with self.assertRaises(IndexError): + rustworkx.dijkstra_shortest_paths(self.graph, len(self.graph.node_indices()) + 1) + + def test_dijkstra_path_lengths_with_invalid_source(self): + with self.assertRaises(IndexError): + rustworkx.dijkstra_shortest_path_lengths( + self.graph, len(self.graph.node_indices()) + 1, edge_cost_fn=float + ) + def dijkstra_lengths_with_invalid_weights(self): graph = rustworkx.generators.path_graph(2) for invalid_weight in [float("nan"), -1]: diff --git a/tests/graph/test_max_weight_matching.py b/tests/graph/test_max_weight_matching.py index 429036649..f95860cf1 100644 --- a/tests/graph/test_max_weight_matching.py +++ b/tests/graph/test_max_weight_matching.py @@ -53,17 +53,11 @@ def get_nx_weight(edge): if (u, v) not in nx_matches: if (v, u) not in nx_matches: print( - "seed {} failed. Element {} and it's " - "reverse {} not found in networkx output.\nrustworkx" - " output: {}\nnetworkx output: {}\nedge list: {}\n" - "falling back to checking for a valid solution".format( - seed, - (u, v), - (v, u), - rx_matches, - nx_matches, - list(rx_graph.weighted_edge_list()), - ) + f"seed {seed} failed. Element {(u, v)} and it's " + f"reverse {(v, u)} not found in networkx output.\nrustworkx" + f" output: {rx_matches}\nnetworkx output: {nx_matches}" + f"\nedge list: {list(rx_graph.weighted_edge_list())}\n" + "falling back to checking for a valid solution" ) not_match = True break