From 488579e48562ebd39ca0940972a0a89b46368a7a Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Mon, 25 Jan 2021 17:18:21 -0500 Subject: [PATCH 1/4] Add universal methods In the rust generated Python API we need to have fixed class inpurts to satisfy the traits used by the pyo3 macro generated FFI functions. This results in duplicate methods like digraph_dfs_edges and graph_dfs_edges with the same implementation just differing input types. To simplify the API for users this commit adds universal functions to the python side of the retworkx package to take in any retworkx graph object and dispatch to the proper function in the rust generated api that relies on strict input types. Fixes #215 --- docs/source/api.rst | 21 +++ retworkx/__init__.py | 353 +++++++++++++++++++++++++++++++++++++++++ tests/test_dispatch.py | 71 +++++++++ 3 files changed, 445 insertions(+) create mode 100644 tests/test_dispatch.py diff --git a/docs/source/api.rst b/docs/source/api.rst index 0ac5fcdb7..a92d0b774 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -88,6 +88,27 @@ Algorithm Functions retworkx.digraph_find_cycle retworkx.digraph_union +Universal Functions +------------------- + +These functions are algorithm functions that wrap per graph object +type functions in the algorithms API but can be run with a +:class:`~retworkx.PyGraph`, :class:`~retworkx.PyDiGraph`, or +:class:`~retworkx.PyDAG` object. + +.. autosummary:: + :toctree: stubs + + retworkx.distance_matrix + retworkx.floyd_warshall_numpy + retworkx.adjacency_matrix + retworkx.all_simple_paths + retworkx.astar_shortest_path + retworkx.dijkstra_shortest_paths + retworkx.dijkstra_shortest_path_lengths + retworkx.k_shortest_path_lengths + retworkx.dfs_edges + Exceptions ---------- diff --git a/retworkx/__init__.py b/retworkx/__init__.py index 9b6fdf3c3..38a133320 100644 --- a/retworkx/__init__.py +++ b/retworkx/__init__.py @@ -7,6 +7,7 @@ # that they have been altered from the originals. import sys +import functools from .retworkx import * sys.modules['retworkx.generators'] = generators @@ -81,3 +82,355 @@ class PyDAG(PyDiGraph): :meth:`PyDAG.add_parent` will avoid this overhead. """ pass + + +@functools.singledispatch +def distance_matrix(graph, parallel_threshold=300): + """Get the distance matrix for a directed graph + + This differs from functions like floyd_warshall_numpy in that the + edge weight/data payload is not used and each edge is treated as a + distance of 1. + + This function is also multithreaded and will run in parallel if the number + of nodes in the graph is above the value of ``parallel_threshold`` (it + defaults to 300). If the function will be running in parallel the env var + ``RAYON_NUM_THREADS`` can be used to adjust how many threads will be used. + + :param graph: The graph to get the distance matrix for, can be either a + :class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph`. + :param int parallel_threshold: The number of nodes to calculate the + the distance matrix in parallel at. It defaults to 300, but this can + be tuned + :param bool as_undirected: If set to ``True`` the input directed graph + will be treat as if each edge was bidirectional/undirected in the + output distance matrix. + + :returns: The distance matrix + :rtype: numpy.ndarray + """ + raise TypeError("Invalid Input Type %s for graph" % type(graph)) + + +@distance_matrix.register(PyDiGraph) +def _digraph_distance_matrix(graph, parallel_threshold=300): + return digraph_distance_matrix(graph, parallel_threshold=parallel_threshold) + + +@distance_matrix.register(PyGraph) +def _graph_distance_matrix(graph, parallel_threshold=300): + return graph_distance_matrix(graph, parallel_threshold=parallel_threshold) + + +@functools.singledispatch +def adjacency_matrix(graph, parallel_threshold=300): + """Return the adjacency matrix for a graph object + + In the case where there are multiple edges between nodes the value in the + output matrix will be the sum of the edges' weights. + + :param graph: The graph used to generate the adjacency matrix from. Can + either be a :class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph` + :param callable weight_fn: A callable object (function, lambda, etc) which + will be passed the edge object and expected to return a ``float``. This + tells retworkx/rust how to extract a numerical weight as a ``float`` + for edge object. Some simple examples are:: + + adjacency_matrix(graph, weight_fn: lambda x: 1) + + to return a weight of 1 for all edges. Also:: + + adjacency_matrix(graph, weight_fn: lambda x: float(x)) + + to cast the edge object as a float as the weight. If this is not + specified a default value (either ``default_weight`` or 1) will be used + for all edges. + :param float default_weight: If ``weight_fn`` is not used this can be + optionally used to specify a default weight to use for all edges. + + :return: The adjacency matrix for the input dag as a numpy array + :rtype: numpy.ndarray + """ + raise TypeError("Invalid Input Type %s for graph" % type(graph)) + + +@adjacency_matrix.register(PyDiGraph) +def _digraph_adjacency_matrix(graph, parallel_threshold=300): + return digraph_adjacency_matrix(graph) + + +@adjacency_matrix.register(PyGraph) +def _graph_adjacency_matrix(graph, parallel_threshold=300): + return graph_adjacency_matrix(graph) + + +@functools.singledispatch +def all_simple_paths(graph, from_, to, min_depth=None, cutoff=None): + """Return all simple paths between 2 nodes in a PyGraph object + + A simple path is a path with no repeated nodes. + + :param graph: The graph to find the path in. Can either be a + class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph` + :param int from_: The node index to find the paths from + :param int to: The node index to find the paths to + :param int min_depth: The minimum depth of the path to include in the output + list of paths. By default all paths are included regardless of depth, + setting to 0 will behave like the default. + :param int cutoff: The maximum depth of path to include in the output list + of paths. By default includes all paths regardless of depth, setting to + 0 will behave like default. + + :returns: A list of lists where each inner list is a path of node indices + :rtype: list + """ + raise TypeError("Invalid Input Type %s for graph" % type(graph)) + + +@all_simple_paths.register(PyDiGraph) +def _digraph_all_simple_paths(graph, from_, to, min_depth=None, cutoff=None): + return digraph_all_simple_paths(graph, from_, to, min_depth=min_depth, + cutoff=cutoff) + + +@all_simple_paths.register(PyGraph) +def _digraph_all_simple_paths(graph, from_, to, min_depth=None, cutoff=None): + return graph_all_simple_paths(graph, from_, to, min_depth=min_depth, + cutoff=cutoff) + + +@functools.singledispatch +def floyd_warshall_numpy(graph, weight_fn=None, default_weight=1.0): + """Return the adjacency matrix for a graph object + + In the case where there are multiple edges between nodes the value in the + output matrix will be the sum of the edges' weights. + + :param graph: The graph used to generate the adjacency matrix from. Can + either be a :class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph` + :param callable weight_fn: A callable object (function, lambda, etc) which + will be passed the edge object and expected to return a ``float``. This + tells retworkx/rust how to extract a numerical weight as a ``float`` + for edge object. Some simple examples are:: + + adjacency_matrix(graph, weight_fn: lambda x: 1) + + to return a weight of 1 for all edges. Also:: + + adjacency_matrix(graph, weight_fn: lambda x: float(x)) + + to cast the edge object as a float as the weight. If this is not + specified a default value (either ``default_weight`` or 1) will be used + for all edges. + :param float default_weight: If ``weight_fn`` is not used this can be + optionally used to specify a default weight to use for all edges. + + :return: The adjacency matrix for the input dag as a numpy array + :rtype: numpy.ndarray + """ + raise TypeError("Invalid Input Type %s for graph" % type(graph)) + + +@floyd_warshall_numpy.register(PyDiGraph) +def _digraph_floyd_warshall_numpy(graph, weight_fn=None, default_weight=1.0): + return digraph_adjacency_matrix(graph, weight_fn=weight_fn, + default_weight=default_weight) + + +@floyd_warshall_numpy.register(PyGraph) +def _graph_floyd_warshall_numpy(graph, weight_fn=None, default_weight=1.0): + return graph_adjacency_matrix(graph, weight_fn=weight_fn, + default_weight=default_weight) + + +@functools.singledispatch +def astar_shortest_path(graph, node, goal_fn, edge_cost_fn, estimate_cost_fn): + """Compute the A* shortest path for a PyGraph + + :param graph: The input graph to use. Can + either be a :class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph` + :param int node: The node index to compute the path from + :param goal_fn: A python callable that will take in 1 parameter, a node's + data object and will return a boolean which will be True if it is the + finish node. + :param edge_cost_fn: A python callable that will take in 1 parameter, an + edge's data object and will return a float that represents the cost + of that edge. It must be non-negative. + :param estimate_cost_fn: A python callable that will take in 1 parameter, a + node's data object and will return a float which represents the + estimated cost for the next node. The return must be non-negative. For + the algorithm to find the actual shortest path, it should be + admissible, meaning that it should never overestimate the actual cost + to get to the nearest goal node. + + :returns: The computed shortest path between node and finish as a list + of node indices. + :rtype: NodeIndices + """ + raise TypeError("Invalid Input Type %s for graph" % type(graph)) + + +@astar_shortest_path.register(PyDiGraph) +def _digraph_astar_shortest_path(graph, node, goal_fn, edge_cost_fn, + estimate_cost_fn): + return digraph_astar_shortest_path(graph, node, goal_fn, edge_cost_fn, + estimate_cost_fn) + + +@astar_shortest_path.register(PyGraph) +def _graph_astar_shortest_path(graph, node, goal_fn, edge_cost_fn, + estimate_cost_fn): + return graph_astar_shortest_path(graph, node, goal_fn, edge_cost_fn, + estimate_cost_fn) + + +@functools.singledispatch +def dijkstra_shortest_paths(graph, source, target=None, weight_fn=None, + default_weight=1.0, as_undirected=False): + """Find the shortest path from a node + + This function will generate the shortest path from a source node using + Dijkstra's algorithm. + + :param graph: The input graph to use. Can either be a + :class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph` + :param int source: The node index to find paths from + :param int target: An optional target to find a path to + :param weight_fn: An optional weight function for an edge. It will accept + a single argument, the edge's weight object and will return a float which + will be used to represent the weight/cost of the edge + :param float default_weight: If ``weight_fn`` isn't specified this optional + float value will be used for the weight/cost of each edge. + :param bool as_undirected: If set to true the graph will be treated as + undirected for finding the shortest path. This only works with a + :class:`~retworkx.PyDiGraph` input for ``graph`` + + :return: Dictionary of paths. The keys are destination node indices and + the dict values are lists of node indices making the path. + :rtype: dict + """ + raise TypeError("Invalid Input Type %s for graph" % type(graph)) + + +@dijkstra_shortest_paths.register(PyDiGraph) +def _digraph_dijkstra_shortest_path(graph, source, target=None, weight_fn=None, + default_weight=1.0, as_undirected=False): + return digraph_dijkstra_shortest_paths(graph, source, target=target, + weight_fn=weight_fn, + default_weight=default_weight, + as_undirected=as_undirected) + + +@dijkstra_shortest_paths.register(PyGraph) +def _graph_dijkstra_shortest_path(graph, source, target=None, weight_fn=None, + default_weight=1.0, as_undirected=False): + if as_undirected: + raise TypeError("The ``as_undirected`` flag kwarg only works with a " + "PyDiGraph input") + return graph_dijkstra_shortest_paths(graph, source, target=target, + weight_fn=weight_fn, + default_weight=default_weight) + + +@functools.singledispatch +def dijkstra_shortest_path_lengths(graph, node, edge_cost_fn, goal=None): + """Compute the lengths of the shortest paths for a PyGraph object using + Dijkstra's algorithm. + + :param graph: The input graph to use. Can either be a + :class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph` + :param int node: The node index to use as the source for finding the + shortest paths from + :param edge_cost_fn: A python callable that will take in 1 parameter, an + edge's data object and will return a float that represents the + cost/weight of that edge. It must be non-negative + :param int goal: An optional node index to use as the end of the path. + When specified the traversal will stop when the goal is reached and + the output dictionary will only have a single entry with the length + of the shortest path to the goal node. + + :returns: A dictionary of the shortest paths from the provided node where + the key is the node index of the end of the path and the value is the + cost/sum of the weights of path + :rtype: dict + """ + raise TypeError("Invalid Input Type %s for graph" % type(graph)) + + +@dijkstra_shortest_path_lengths.register(PyDiGraph) +def _digraph_dijkstra_shortest_path_lengths(graph, node, edge_cost_fn, + goal=None): + return digraph_dijkstra_shortest_path_lengths(graph, node, edge_cost_fn, + goal=goal) + +@dijkstra_shortest_path_lengths.register(PyGraph) +def _digraph_dijkstra_shortest_path_lengths(graph, node, edge_cost_fn, + goal=None): + return graph_dijkstra_shortest_path_lengths(graph, node, edge_cost_fn, + goal=goal) + + +@functools.singledispatch +def k_shortest_path_lengths(graph, start, k, edge_cost, goal=None): + """Compute the length of the kth shortest path + + Computes the lengths of the kth shortest path from ``start`` to every + reachable node. + + Computes in :math:`O(k * (|E| + |V|*log(|V|)))` time (average). + + :param graph: The graph to find the shortest paths in. Can either be a + :class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph` + :param int start: The node index to find the shortest paths from + :param int k: The kth shortest path to find the lengths of + :param edge_cost: A python callable that will receive an edge payload and + return a float for the cost of that eedge + :param int goal: An optional goal node index, if specified the output + dictionary + + :returns: A dict of lengths where the key is the destination node index and + the value is the length of the path. + :rtype: dict + """ + raise TypeError("Invalid Input Type %s for graph" % type(graph)) + + +@k_shortest_path_lengths.register(PyDiGraph) +def _digraph_k_shortest_path_lengths(graph, start, k, edge_cost, goal=None): + return digraph_k_shortest_path_lengths(graph, start, k, edge_cost, + goal=goal) + + +@k_shortest_path_lengths.register(PyGraph) +def _graph_k_shortest_path_lengths(graph, start, k, edge_cost, goal=None): + return graph_k_shortest_path_lengths(graph, start, k, edge_cost, + goal=goal) + + +@functools.singledispatch +def dfs_edges(graph, source): + """Get edge list in depth first order + + :param PyGraph graph: The graph to get the DFS edge list from + :param int source: An optional node index to use as the starting node + for the depth-first search. The edge list will only return edges in + the components reachable from this index. If this is not specified + then a source will be chosen arbitrarly and repeated until all + components of the graph are searched. + + :returns: A list of edges as a tuple of the form ``(source, target)`` in + depth-first order + :rtype: EdgeList + raise TypeError("Invalid Input Type %s for graph" % type(graph)) + """ + raise TypeError("Invalid Input Type %s for graph" % type(graph)) + + +@dfs_edges.register(PyDiGraph) +def _digraph_dfs_edges(graph, source): + return digraph_dfs_edges(graph, source) + + +@dfs_edges.register(PyGraph) +def _graph_dfs_edges(graph, source): + return graph_dfs_edges(graph, source) diff --git a/tests/test_dispatch.py b/tests/test_dispatch.py new file mode 100644 index 000000000..cfa21d8cd --- /dev/null +++ b/tests/test_dispatch.py @@ -0,0 +1,71 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest +import retworkx + +import numpy + + +class TestDispatchPyGraph(unittest.TestCase): + + class_type = "PyGraph" + + def setUp(self): + super().setUp() + if self.class_type == "PyGraph": + self.graph = retworkx.undirected_gnp_random_graph(10, .5, seed=42) + else: + self.graph = retworkx.directed_gnp_random_graph(10, .5, seed=42) + + def test_distance_matrix(self): + res = retworkx.distance_matrix(self.graph) + self.assertIsInstance(res, numpy.ndarray) + + def test_adjacency_matrix(self): + res = retworkx.adjacency_matrix(self.graph) + self.assertIsInstance(res, numpy.ndarray) + + def test_all_simple_paths(self): + res = retworkx.all_simple_paths(self.graph, 0, 1) + self.assertIsInstance(res, list) + + def test_floyd_warshall_numpy(self): + res = retworkx.floyd_warshall_numpy(self.graph) + self.assertIsInstance(res, numpy.ndarray) + + def test_astar_shortest_path(self): + res = retworkx.astar_shortest_path(self.graph, 0, lambda _: True, + lambda _: 1, lambda _: 1) + self.assertIsInstance(list(res), list) + + def test_dijkstra_shortest_paths(self): + res = retworkx.dijkstra_shortest_paths(self.graph, 0) + self.assertIsInstance(res, dict) + + def test_dijkstra_shortest_path_lengths(self): + res = retworkx.dijkstra_shortest_path_lengths(self.graph, 0, + lambda _: 1) + self.assertIsInstance(res, dict) + + def test_k_shortest_path_lengths(self): + res = retworkx.k_shortest_path_lengths(self.graph, 0, 2, lambda _: 1) + self.assertIsInstance(res, dict) + + def test_dfs_edges(self): + res = retworkx.dfs_edges(self.graph, 0) + self.assertIsInstance(list(res), list) + + +class TestDispatchPyDiGraph(TestDispatchPyGraph): + + class_type = "PyDiGraph" From 5587b4959298c8f5c8cfb8826683d4951ae60a66 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Wed, 27 Jan 2021 11:45:10 -0500 Subject: [PATCH 2/4] Add release notes --- docs/source/api.rst | 5 ++++- ...add-universal-functions-1e54351f1f7afa4b.yaml | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 releasenotes/notes/add-universal-functions-1e54351f1f7afa4b.yaml diff --git a/docs/source/api.rst b/docs/source/api.rst index a92d0b774..6f46d5bd0 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -45,6 +45,9 @@ Random Circuit Functions Algorithm Functions ------------------- +Specific Graph Type Methods +''''''''''''''''''''''''''' + .. autosummary:: :toctree: stubs @@ -89,7 +92,7 @@ Algorithm Functions retworkx.digraph_union Universal Functions -------------------- +''''''''''''''''''' These functions are algorithm functions that wrap per graph object type functions in the algorithms API but can be run with a diff --git a/releasenotes/notes/add-universal-functions-1e54351f1f7afa4b.yaml b/releasenotes/notes/add-universal-functions-1e54351f1f7afa4b.yaml new file mode 100644 index 000000000..7b5cd7daa --- /dev/null +++ b/releasenotes/notes/add-universal-functions-1e54351f1f7afa4b.yaml @@ -0,0 +1,16 @@ +--- +features: + - | + New universal functions that can take in a :class:`~retworkx.PyGraph` or + :class:`~retworkx.PyDiGraph` instead of being class specific have been to + the retworkx API. These new functions are: + + * :func:`retworkx.distance_matrix` + * :func:`retworkx.floyd_warshall_numpy` + * :func:`retworkx.adjacency_matrix` + * :func:`retworkx.all_simple_paths` + * :func:`retworkx.astar_shortest_path` + * :func:`retworkx.dijkstra_shortest_paths` + * :func:`retworkx.dijkstra_shortest_path_lengths` + * :func:`retworkx.k_shortest_path_lengths` + * :func:`retworkx.dfs_edges` From b03198056165afced74201b2f8045729fc96e283 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Wed, 27 Jan 2021 12:04:03 -0500 Subject: [PATCH 3/4] Cleanup docs for new functions --- retworkx/__init__.py | 53 ++++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/retworkx/__init__.py b/retworkx/__init__.py index 38a133320..df02c7323 100644 --- a/retworkx/__init__.py +++ b/retworkx/__init__.py @@ -86,10 +86,10 @@ class PyDAG(PyDiGraph): @functools.singledispatch def distance_matrix(graph, parallel_threshold=300): - """Get the distance matrix for a directed graph + """Get the distance matrix for a graph - This differs from functions like floyd_warshall_numpy in that the - edge weight/data payload is not used and each edge is treated as a + This differs from functions like :func:`~retworkx.floyd_warshall_numpy` in + that the edge weight/data payload is not used and each edge is treated as a distance of 1. This function is also multithreaded and will run in parallel if the number @@ -114,7 +114,8 @@ def distance_matrix(graph, parallel_threshold=300): @distance_matrix.register(PyDiGraph) def _digraph_distance_matrix(graph, parallel_threshold=300): - return digraph_distance_matrix(graph, parallel_threshold=parallel_threshold) + return digraph_distance_matrix(graph, + parallel_threshold=parallel_threshold) @distance_matrix.register(PyGraph) @@ -123,7 +124,7 @@ def _graph_distance_matrix(graph, parallel_threshold=300): @functools.singledispatch -def adjacency_matrix(graph, parallel_threshold=300): +def adjacency_matrix(graph, weight_fn=None, default_weight=1.0): """Return the adjacency matrix for a graph object In the case where there are multiple edges between nodes the value in the @@ -155,13 +156,15 @@ def adjacency_matrix(graph, parallel_threshold=300): @adjacency_matrix.register(PyDiGraph) -def _digraph_adjacency_matrix(graph, parallel_threshold=300): - return digraph_adjacency_matrix(graph) +def _digraph_adjacency_matrix(graph, weight_fn=None, default_weight=1.0): + return digraph_adjacency_matrix(graph, weight_fn=weight_fn, + default_weight=default_weight) @adjacency_matrix.register(PyGraph) -def _graph_adjacency_matrix(graph, parallel_threshold=300): - return graph_adjacency_matrix(graph) +def _graph_adjacency_matrix(graph, weight_fn=None, default_weight=1.0): + return graph_adjacency_matrix(graph, weight_fn=weight_fn, + default_weight=default_weight) @functools.singledispatch @@ -174,9 +177,9 @@ def all_simple_paths(graph, from_, to, min_depth=None, cutoff=None): class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph` :param int from_: The node index to find the paths from :param int to: The node index to find the paths to - :param int min_depth: The minimum depth of the path to include in the output - list of paths. By default all paths are included regardless of depth, - setting to 0 will behave like the default. + :param int min_depth: The minimum depth of the path to include in the + output list of paths. By default all paths are included regardless of + depth, setting to 0 will behave like the default. :param int cutoff: The maximum depth of path to include in the output list of paths. By default includes all paths regardless of depth, setting to 0 will behave like default. @@ -194,9 +197,9 @@ def _digraph_all_simple_paths(graph, from_, to, min_depth=None, cutoff=None): @all_simple_paths.register(PyGraph) -def _digraph_all_simple_paths(graph, from_, to, min_depth=None, cutoff=None): +def _graph_all_simple_paths(graph, from_, to, min_depth=None, cutoff=None): return graph_all_simple_paths(graph, from_, to, min_depth=min_depth, - cutoff=cutoff) + cutoff=cutoff) @functools.singledispatch @@ -245,7 +248,7 @@ def _graph_floyd_warshall_numpy(graph, weight_fn=None, default_weight=1.0): @functools.singledispatch def astar_shortest_path(graph, node, goal_fn, edge_cost_fn, estimate_cost_fn): - """Compute the A* shortest path for a PyGraph + """Compute the A* shortest path for a graph :param graph: The input graph to use. Can either be a :class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph` @@ -297,8 +300,8 @@ def dijkstra_shortest_paths(graph, source, target=None, weight_fn=None, :param int source: The node index to find paths from :param int target: An optional target to find a path to :param weight_fn: An optional weight function for an edge. It will accept - a single argument, the edge's weight object and will return a float which - will be used to represent the weight/cost of the edge + a single argument, the edge's weight object and will return a float + which will be used to represent the weight/cost of the edge :param float default_weight: If ``weight_fn`` isn't specified this optional float value will be used for the weight/cost of each edge. :param bool as_undirected: If set to true the graph will be treated as @@ -323,10 +326,7 @@ def _digraph_dijkstra_shortest_path(graph, source, target=None, weight_fn=None, @dijkstra_shortest_paths.register(PyGraph) def _graph_dijkstra_shortest_path(graph, source, target=None, weight_fn=None, - default_weight=1.0, as_undirected=False): - if as_undirected: - raise TypeError("The ``as_undirected`` flag kwarg only works with a " - "PyDiGraph input") + default_weight=1.0): return graph_dijkstra_shortest_paths(graph, source, target=target, weight_fn=weight_fn, default_weight=default_weight) @@ -334,7 +334,7 @@ def _graph_dijkstra_shortest_path(graph, source, target=None, weight_fn=None, @functools.singledispatch def dijkstra_shortest_path_lengths(graph, node, edge_cost_fn, goal=None): - """Compute the lengths of the shortest paths for a PyGraph object using + """Compute the lengths of the shortest paths for a graph object using Dijkstra's algorithm. :param graph: The input graph to use. Can either be a @@ -361,13 +361,14 @@ def dijkstra_shortest_path_lengths(graph, node, edge_cost_fn, goal=None): def _digraph_dijkstra_shortest_path_lengths(graph, node, edge_cost_fn, goal=None): return digraph_dijkstra_shortest_path_lengths(graph, node, edge_cost_fn, - goal=goal) + goal=goal) + @dijkstra_shortest_path_lengths.register(PyGraph) -def _digraph_dijkstra_shortest_path_lengths(graph, node, edge_cost_fn, - goal=None): +def _graph_dijkstra_shortest_path_lengths(graph, node, edge_cost_fn, + goal=None): return graph_dijkstra_shortest_path_lengths(graph, node, edge_cost_fn, - goal=goal) + goal=goal) @functools.singledispatch From fa90541e841f0c9957ec2e6783fb0d77c1529da2 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Mon, 1 Feb 2021 11:50:14 -0500 Subject: [PATCH 4/4] Add as_undirected flag to distance_matrix() --- retworkx/__init__.py | 6 ++++-- tests/test_dispatch.py | 8 ++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/retworkx/__init__.py b/retworkx/__init__.py index df02c7323..45acea4d3 100644 --- a/retworkx/__init__.py +++ b/retworkx/__init__.py @@ -113,9 +113,11 @@ def distance_matrix(graph, parallel_threshold=300): @distance_matrix.register(PyDiGraph) -def _digraph_distance_matrix(graph, parallel_threshold=300): +def _digraph_distance_matrix(graph, parallel_threshold=300, + as_undirected=False): return digraph_distance_matrix(graph, - parallel_threshold=parallel_threshold) + parallel_threshold=parallel_threshold, + as_undirected=as_undirected) @distance_matrix.register(PyGraph) diff --git a/tests/test_dispatch.py b/tests/test_dispatch.py index cfa21d8cd..c0340d905 100644 --- a/tests/test_dispatch.py +++ b/tests/test_dispatch.py @@ -31,6 +31,14 @@ def test_distance_matrix(self): res = retworkx.distance_matrix(self.graph) self.assertIsInstance(res, numpy.ndarray) + def test_distance_matrix_as_undirected(self): + if self.class_type == "PyGraph": + with self.assertRaises(TypeError): + retworkx.distance_matrix(self.graph, as_undirected=True) + else: + res = retworkx.distance_matrix(self.graph, as_undirected=True) + self.assertIsInstance(res, numpy.ndarray) + def test_adjacency_matrix(self): res = retworkx.adjacency_matrix(self.graph) self.assertIsInstance(res, numpy.ndarray)