Skip to content

Commit

Permalink
Fix floyd_warshall_numpy dispatcher (#302)
Browse files Browse the repository at this point in the history
Fix bug found during #291 benchmarking.

floyd_warshall_numpy() was dispatching to the wrong function and calling adjacency_matrix()
instead of graph_floyd_warshall_numpy() and digraph_floyd_warshall_numpy().

* Fix wrong dispatch

* Fix docstrings

* Fix typo

* Improve test for Floyd-Warshall dispatch

* Add bug fix notes

* Move release notes from wrong folder
  • Loading branch information
IvanIsCoding committed Apr 6, 2021
1 parent bb6b487 commit 77ebc40
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 12 deletions.
6 changes: 6 additions & 0 deletions releasenotes/notes/floyd-warshall-fix-d8ec2131dfaeab82.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
fixes:
- |
Fix bug in :func:`~retworkx.floyd_warshall_numpy` in which the dispatcher mistakenly called
:func:`~retworkx.adjacency_matrix` instead of :func:`~retworkx.graph_floyd_warshall_numpy`
and :func:`~retworkx.digraph_floyd_warshall_numpy`.
28 changes: 16 additions & 12 deletions retworkx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,46 +207,50 @@ def _graph_all_simple_paths(graph, from_, to, min_depth=None, cutoff=None):

@functools.singledispatch
def floyd_warshall_numpy(graph, weight_fn=None, default_weight=1.0):
"""Return the adjacency matrix for a graph object
"""Find all-pairs shortest path lengths using Floyd's algorithm
In the case where there are multiple edges between nodes the value in the
output matrix will be the sum of the edges' weights.
Floyd's algorithm is used for finding shortest paths in dense graphs
or graphs with negative weights (where Dijkstra's algorithm fails).
:param graph: The graph used to generate the adjacency matrix from. Can
:param graph: The graph to run Floyd's algorithm on. Can
either be a :class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph`
:param callable weight_fn: A callable object (function, lambda, etc) which
will be passed the edge object and expected to return a ``float``. This
tells retworkx/rust how to extract a numerical weight as a ``float``
for edge object. Some simple examples are::
adjacency_matrix(graph, weight_fn: lambda x: 1)
floyd_warshall_numpy(graph, weight_fn: lambda x: 1)
to return a weight of 1 for all edges. Also::
adjacency_matrix(graph, weight_fn: lambda x: float(x))
floyd_warshall_numpy(graph, weight_fn: lambda x: float(x))
to cast the edge object as a float as the weight. If this is not
specified a default value (either ``default_weight`` or 1) will be used
for all edges.
:param float default_weight: If ``weight_fn`` is not used this can be
optionally used to specify a default weight to use for all edges.
:return: The adjacency matrix for the input dag as a numpy array
:rtype: numpy.ndarray
:returns: A matrix of shortest path distances between nodes. If there is no
path between two nodes then the corresponding matrix entry will be
``np.inf``.
:rtype: numpy.ndarray
"""
raise TypeError("Invalid Input Type %s for graph" % type(graph))


@floyd_warshall_numpy.register(PyDiGraph)
def _digraph_floyd_warshall_numpy(graph, weight_fn=None, default_weight=1.0):
return digraph_adjacency_matrix(graph, weight_fn=weight_fn,
default_weight=default_weight)
return digraph_floyd_warshall_numpy(
graph, weight_fn=weight_fn, default_weight=default_weight
)


@floyd_warshall_numpy.register(PyGraph)
def _graph_floyd_warshall_numpy(graph, weight_fn=None, default_weight=1.0):
return graph_adjacency_matrix(graph, weight_fn=weight_fn,
default_weight=default_weight)
return graph_floyd_warshall_numpy(
graph, weight_fn=weight_fn, default_weight=default_weight
)


@functools.singledispatch
Expand Down
7 changes: 7 additions & 0 deletions tests/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ def test_floyd_warshall_numpy(self):
res = retworkx.floyd_warshall_numpy(self.graph)
self.assertIsInstance(res, numpy.ndarray)

if self.class_type == "PyGraph":
expected_res = retworkx.graph_floyd_warshall_numpy(self.graph)
else:
expected_res = retworkx.digraph_floyd_warshall_numpy(self.graph)

self.assertTrue(numpy.array_equal(expected_res, res))

def test_astar_shortest_path(self):
res = retworkx.astar_shortest_path(self.graph, 0, lambda _: True,
lambda _: 1, lambda _: 1)
Expand Down

0 comments on commit 77ebc40

Please sign in to comment.