Skip to content

Commit

Permalink
Fix panic error at shortest paths (#1134)
Browse files Browse the repository at this point in the history
* lint suggestion UP032

* Test of dijkstra raising IndexError

* Test Bellman Ford and A*

* fix: shortest path
  • Loading branch information
JPena-code authored Mar 8, 2024
1 parent de6684f commit 058b0c7
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 12 deletions.
14 changes: 14 additions & 0 deletions releasenotes/notes/fix-panic-3962ad36788cab00.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
---
fixes:
- |
Fixed an issue with the Dijkstra path functions:
* :func:`rustworkx.dijkstra_shortest_paths`
* :func:`rustworkx.dijkstra_shortest_path_lengths`
* :func:`rustworkx.bellman_ford_shortest_path_lengths`
* :func:`rustworkx.bellman_ford_shortest_paths`
* :func:`rustworkx.astar_shortest_path`
where a `Pyo3.PanicException`were raise with no much detail at the moment
of pass in the `source` argument the index of an out of bound node.
Fixed `#1117 <https://github.com/Qiskit/rustworkx/issues/1117>`__
64 changes: 63 additions & 1 deletion src/shortest_path/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ pub fn graph_dijkstra_shortest_paths(
default_weight: f64,
) -> PyResult<PathMapping> {
let start = NodeIndex::new(source);
if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{source}\" out of graph bound"
)));
}

let goal_index: Option<NodeIndex> = target.map(NodeIndex::new);
let mut paths: DictMap<NodeIndex, Vec<NodeIndex>> = DictMap::with_capacity(graph.node_count());

Expand Down Expand Up @@ -217,6 +223,12 @@ pub fn digraph_dijkstra_shortest_paths(
as_undirected: bool,
) -> PyResult<PathMapping> {
let start = NodeIndex::new(source);
if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{source}\" out of graph bound"
)));
}

let goal_index: Option<NodeIndex> = target.map(NodeIndex::new);
let mut paths: DictMap<NodeIndex, Vec<NodeIndex>> = DictMap::with_capacity(graph.node_count());
let cost_fn = CostFn::try_from((weight_fn, default_weight))?;
Expand Down Expand Up @@ -371,10 +383,16 @@ pub fn graph_dijkstra_shortest_path_lengths(
edge_cost_fn: PyObject,
goal: Option<usize>,
) -> PyResult<PathLengthMapping> {
let edge_cost_callable = CostFn::from(edge_cost_fn);
let start = NodeIndex::new(node);
let edge_cost_callable = CostFn::from(edge_cost_fn);
let goal_index: Option<NodeIndex> = goal.map(NodeIndex::new);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{node}\" out of graph bound"
)));
}

let res: Vec<Option<f64>> = dijkstra(
&graph.graph,
start,
Expand Down Expand Up @@ -445,6 +463,12 @@ pub fn digraph_dijkstra_shortest_path_lengths(
let start = NodeIndex::new(node);
let goal_index: Option<NodeIndex> = goal.map(NodeIndex::new);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{node}\" out of graph bound"
)));
}

let res: Vec<Option<f64>> = dijkstra(
&graph.graph,
start,
Expand Down Expand Up @@ -671,6 +695,12 @@ pub fn digraph_astar_shortest_path(
let estimate_cost_callable = CostFn::from(estimate_cost_fn);
let start = NodeIndex::new(node);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{node}\" out of graph bound"
)));
}

let astar_res = astar(
&graph.graph,
start,
Expand Down Expand Up @@ -729,6 +759,12 @@ pub fn graph_astar_shortest_path(
let estimate_cost_callable = CostFn::from(estimate_cost_fn);
let start = NodeIndex::new(node);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{node}\" out of graph bound"
)));
}

let astar_res = astar(
&graph.graph,
start,
Expand Down Expand Up @@ -1561,6 +1597,12 @@ pub fn digraph_bellman_ford_shortest_path_lengths(

let start = NodeIndex::new(node);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{node}\" out of graph bound"
)));
}

let res: Option<Vec<Option<f64>>> =
bellman_ford(&graph.graph, start, |e| edge_cost(e.id()), None)?;

Expand Down Expand Up @@ -1640,6 +1682,12 @@ pub fn graph_bellman_ford_shortest_path_lengths(

let start = NodeIndex::new(node);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{node}\" out of graph bound"
)));
}

let res: Option<Vec<Option<f64>>> =
bellman_ford(&graph.graph, start, |e| edge_cost(e.id()), None)?;

Expand Down Expand Up @@ -1715,6 +1763,13 @@ pub fn graph_bellman_ford_shortest_paths(
default_weight: f64,
) -> PyResult<PathMapping> {
let start = NodeIndex::new(source);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{source}\" out of graph bound"
)));
}

let mut paths: DictMap<NodeIndex, Vec<NodeIndex>> = DictMap::with_capacity(graph.node_count());

let edge_weights: Vec<Option<f64>> =
Expand Down Expand Up @@ -1801,6 +1856,13 @@ pub fn digraph_bellman_ford_shortest_paths(
}

let start = NodeIndex::new(source);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{source}\" out of graph bound"
)));
}

let mut paths: DictMap<NodeIndex, Vec<NodeIndex>> = DictMap::with_capacity(graph.node_count());

let edge_weights: Vec<Option<f64>> =
Expand Down
14 changes: 14 additions & 0 deletions tests/digraph/test_astar.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,17 @@ def test_astar_with_invalid_weights(self):
edge_cost_fn=lambda _: invalid_weight,
estimate_cost_fn=lambda _: 0,
)

def test_astar_with_invalid_source_node(self):
g = rustworkx.PyDAG()
a = g.add_node("A")
b = g.add_node("B")
g.add_edge(a, b, 7)
with self.assertRaises(IndexError):
rustworkx.digraph_astar_shortest_path(
g,
len(g.node_indices()) + 1,
goal_fn=lambda goal: goal == "B",
edge_cost_fn=lambda x: float(x),
estimate_cost_fn=lambda _: 0,
)
12 changes: 12 additions & 0 deletions tests/digraph/test_bellman_ford.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,3 +442,15 @@ def test_raises_negative_cycle_all_pairs_bellman_ford_path_lenghts(self):

with self.assertRaises(rustworkx.NegativeCycle):
rustworkx.all_pairs_bellman_ford_path_lengths(graph, float)

def test_raises_index_error_bellman_ford_paths(self):
with self.assertRaises(IndexError):
rustworkx.digraph_bellman_ford_shortest_paths(
self.graph, len(self.graph.node_indices()) + 1, weight_fn=lambda x: float(x)
)

def test_raises_index_error_bellman_ford_path_lenghts(self):
with self.assertRaises(IndexError):
rustworkx.digraph_bellman_ford_shortest_path_lengths(
self.graph, len(self.graph.node_indices()) + 1, edge_cost_fn=lambda x: float(x)
)
10 changes: 10 additions & 0 deletions tests/digraph/test_dijkstra.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,13 @@ def all_pairs_dijkstra_lenghts_with_invalid_weights(self):
rustworkx.digraph_all_pairs_dijkstra_path_lengths(
graph, edge_cost_fn=lambda _: invalid_weight
)

def test_dijkstra_path_digraph_with_invalid_source(self):
with self.assertRaises(IndexError):
rustworkx.dijkstra_shortest_paths(self.graph, len(self.graph.node_indices()) + 1)

def test_dijkstra_path_digraph_lengths_with_invalid_source(self):
with self.assertRaises(IndexError):
rustworkx.dijkstra_shortest_path_lengths(
self.graph, len(self.graph.node_indices()) + 1, edge_cost_fn=lambda x: x
)
14 changes: 14 additions & 0 deletions tests/graph/test_astar.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,17 @@ def test_astar_with_invalid_weights(self):
edge_cost_fn=lambda _: invalid_weight,
estimate_cost_fn=lambda _: 0,
)

def test_astar_with_invalid_source_node(self):
g = rustworkx.PyGraph()
a = g.add_node("A")
b = g.add_node("B")
g.add_edge(a, b, 7)
with self.assertRaises(IndexError):
rustworkx.graph_astar_shortest_path(
g,
len(g.node_indices()) + 1,
goal_fn=lambda goal: goal == "B",
edge_cost_fn=lambda x: float(x),
estimate_cost_fn=lambda _: 0,
)
12 changes: 12 additions & 0 deletions tests/graph/test_bellman_ford.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,15 @@ def test_raises_negative_cycle_all_pairs_bellman_ford_path_lenghts(self):

with self.assertRaises(rustworkx.NegativeCycle):
rustworkx.all_pairs_bellman_ford_path_lengths(graph, float)

def test_raises_index_error_bellman_ford_paths(self):
with self.assertRaises(IndexError):
rustworkx.graph_bellman_ford_shortest_paths(
self.graph, len(self.graph.node_indices()) + 1, weight_fn=lambda x: float(x)
)

def test_raises_index_error_bellman_ford_path_lenghts(self):
with self.assertRaises(IndexError):
rustworkx.graph_bellman_ford_shortest_path_lengths(
self.graph, len(self.graph.node_indices()) + 1, edge_cost_fn=lambda x: float(x)
)
10 changes: 10 additions & 0 deletions tests/graph/test_dijkstra.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,16 @@ def dijkstra_with_invalid_weights(self):
as_undirected=as_undirected,
)

def test_dijkstra_path_with_invalid_source(self):
with self.assertRaises(IndexError):
rustworkx.dijkstra_shortest_paths(self.graph, len(self.graph.node_indices()) + 1)

def test_dijkstra_path_lengths_with_invalid_source(self):
with self.assertRaises(IndexError):
rustworkx.dijkstra_shortest_path_lengths(
self.graph, len(self.graph.node_indices()) + 1, edge_cost_fn=float
)

def dijkstra_lengths_with_invalid_weights(self):
graph = rustworkx.generators.path_graph(2)
for invalid_weight in [float("nan"), -1]:
Expand Down
16 changes: 5 additions & 11 deletions tests/graph/test_max_weight_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,11 @@ def get_nx_weight(edge):
if (u, v) not in nx_matches:
if (v, u) not in nx_matches:
print(
"seed {} failed. Element {} and it's "
"reverse {} not found in networkx output.\nrustworkx"
" output: {}\nnetworkx output: {}\nedge list: {}\n"
"falling back to checking for a valid solution".format(
seed,
(u, v),
(v, u),
rx_matches,
nx_matches,
list(rx_graph.weighted_edge_list()),
)
f"seed {seed} failed. Element {(u, v)} and it's "
f"reverse {(v, u)} not found in networkx output.\nrustworkx"
f" output: {rx_matches}\nnetworkx output: {nx_matches}"
f"\nedge list: {list(rx_graph.weighted_edge_list())}\n"
"falling back to checking for a valid solution"
)
not_match = True
break
Expand Down

0 comments on commit 058b0c7

Please sign in to comment.