Skip to content

Commit

Permalink
Add goal and target to functions
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanIsCoding committed May 5, 2022
1 parent d5aa128 commit e7a7cfc
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 16 deletions.
26 changes: 18 additions & 8 deletions retworkx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2084,6 +2084,7 @@ def _graph_dijkstra_search(graph, source, weight_fn, visitor):
def bellman_ford_shortest_paths(
graph,
source,
target=None,
weight_fn=None,
default_weight=1.0,
as_undirected=False,
Expand All @@ -2096,6 +2097,7 @@ def bellman_ford_shortest_paths(
:param graph: The input graph to use. Can either be a
:class:`~retworkx.PyGraph` or :class:`~retworkx.PyDiGraph`
:param int source: The node index to find paths from
:param int target: An optional target to find a path to
:param weight_fn: An optional weight function for an edge. It will accept
a single argument, the edge's weight object and will return a float
which will be used to represent the weight/cost of the edge
Expand All @@ -2109,7 +2111,7 @@ def bellman_ford_shortest_paths(
and the dict values are lists of node indices making the path.
:rtype: PathMapping
:raises :class:`~retworkx.NegativeCycle`: when there is a negative cycle and the shortest
:raises: :class:`~retworkx.NegativeCycle`: when there is a negative cycle and the shortest
path is not defined
"""
raise TypeError("Invalid Input Type %s for graph" % type(graph))
Expand All @@ -2119,31 +2121,36 @@ def bellman_ford_shortest_paths(
def _digraph_bellman_ford_shortest_path(
graph,
source,
target=None,
weight_fn=None,
default_weight=1.0,
as_undirected=False,
):
return digraph_bellman_ford_shortest_paths(
graph,
source,
target=target,
weight_fn=weight_fn,
default_weight=default_weight,
as_undirected=as_undirected,
)


@bellman_ford_shortest_paths.register(PyGraph)
def _bellman_ford_shortest_path(graph, source, weight_fn=None, default_weight=1.0):
def _graph_bellman_ford_shortest_path(
graph, source, target=None, weight_fn=None, default_weight=1.0
):
return graph_bellman_ford_shortest_paths(
graph,
source,
target=target,
weight_fn=weight_fn,
default_weight=default_weight,
)


@functools.singledispatch
def bellman_ford_shortest_path_lengths(graph, node, edge_cost_fn):
def bellman_ford_shortest_path_lengths(graph, node, edge_cost_fn, goal=None):
"""Compute the lengths of the shortest paths for a graph object using
the Bellman-Ford algorithm with the SPFA heuristic.
Expand All @@ -2154,23 +2161,26 @@ def bellman_ford_shortest_path_lengths(graph, node, edge_cost_fn):
:param edge_cost_fn: A python callable that will take in 1 parameter, an
edge's data object and will return a float that represents the
cost/weight of that edge. It can be negative.
:param int goal: An optional node index to use as the end of the path.
When specified the output dictionary will only have a single entry with
the length of the shortest path to the goal node.
:returns: A read-only dictionary of the shortest paths from the provided node
where the key is the node index of the end of the path and the value is the
cost/sum of the weights of path
:rtype: PathLengthMapping
:raises :class:`~retworkx.NegativeCycle`: when there is a negative cycle and the shortest
:raises: :class:`~retworkx.NegativeCycle`: when there is a negative cycle and the shortest
path is not defined
"""
raise TypeError("Invalid Input Type %s for graph" % type(graph))


@bellman_ford_shortest_path_lengths.register(PyDiGraph)
def _digraph_bellman_ford_shortest_path_lengths(graph, node, edge_cost_fn):
return digraph_bellman_ford_shortest_path_lengths(graph, node, edge_cost_fn)
def _digraph_bellman_ford_shortest_path_lengths(graph, node, edge_cost_fn, goal=None):
return digraph_bellman_ford_shortest_path_lengths(graph, node, edge_cost_fn, goal=goal)


@bellman_ford_shortest_path_lengths.register(PyGraph)
def _graph_bellman_ford_shortest_path_lengths(graph, node, edge_cost_fn):
return graph_bellman_ford_shortest_path_lengths(graph, node, edge_cost_fn)
def _graph_bellman_ford_shortest_path_lengths(graph, node, edge_cost_fn, goal=None):
return graph_bellman_ford_shortest_path_lengths(graph, node, edge_cost_fn, goal=goal)
58 changes: 50 additions & 8 deletions src/shortest_path/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1208,13 +1208,16 @@ pub fn graph_unweighted_average_shortest_path_length(
/// :param edge_cost_fn: A python callable that will take in 1 parameter, an
/// edge's data object and will return a float that represents the
/// cost/weight of that edge. It can be negative.
/// :param int goal: An optional node index to use as the end of the path.
/// When specified the output dictionary will only have a single entry with
/// the length of the shortest path to the goal node.
///
/// :returns: A read-only dictionary of the shortest paths from the provided node where
/// the key is the node index of the end of the path and the value is the
/// cost/sum of the weights of path
/// :rtype: PathLengthMapping
///
/// :raises :class:`~retworkx.NegativeCycle`: when there is a negative cycle and the shortest
/// :raises: :class:`~retworkx.NegativeCycle`: when there is a negative cycle and the shortest
/// path is not defined.
#[pyfunction]
#[pyo3(text_signature = "(graph, node, edge_cost_fn, /)")]
Expand All @@ -1223,6 +1226,7 @@ pub fn digraph_bellman_ford_shortest_path_lengths(
graph: &digraph::PyDiGraph,
node: usize,
edge_cost_fn: PyObject,
goal: Option<usize>,
) -> PyResult<PathLengthMapping> {
let edge_weights: Vec<Option<f64>> =
edge_weights_from_callable(py, &graph.graph, &Some(edge_cost_fn), 1.0)?;
Expand All @@ -1246,6 +1250,19 @@ pub fn digraph_bellman_ford_shortest_path_lengths(

let res = res.unwrap();

if let Some(goal_usize) = goal {
return Ok(PathLengthMapping {
path_lengths: match res[goal_usize] {
Some(goal_length) => {
let mut ans = DictMap::new();
ans.insert(goal_usize, goal_length);
ans
}
None => DictMap::new(),
},
});
}

Ok(PathLengthMapping {
path_lengths: res
.into_iter()
Expand All @@ -1270,13 +1287,16 @@ pub fn digraph_bellman_ford_shortest_path_lengths(
/// :param edge_cost_fn: A python callable that will take in 1 parameter, an
/// edge's data object and will return a float that represents the
/// cost/weight of that edge. It can be negative.
/// :param int goal: An optional node index to use as the end of the path.
/// When specified the output dictionary will only have a single entry with
/// the length of the shortest path to the goal node.
///
/// :returns: A read-only dictionary of the shortest paths from the provided node where
/// the key is the node index of the end of the path and the value is the
/// cost/sum of the weights of path
/// :rtype: PathLengthMapping
///
/// :raises :class:`~retworkx.NegativeCycle`: when there is a negative cycle and the shortest
/// :raises: :class:`~retworkx.NegativeCycle`: when there is a negative cycle and the shortest
/// path is not defined.
#[pyfunction]
#[pyo3(text_signature = "(graph, node, edge_cost_fn, /)")]
Expand All @@ -1285,6 +1305,7 @@ pub fn graph_bellman_ford_shortest_path_lengths(
graph: &graph::PyGraph,
node: usize,
edge_cost_fn: PyObject,
goal: Option<usize>,
) -> PyResult<PathLengthMapping> {
let edge_weights: Vec<Option<f64>> =
edge_weights_from_callable(py, &graph.graph, &Some(edge_cost_fn), 1.0)?;
Expand All @@ -1308,6 +1329,19 @@ pub fn graph_bellman_ford_shortest_path_lengths(

let res = res.unwrap();

if let Some(goal_usize) = goal {
return Ok(PathLengthMapping {
path_lengths: match res[goal_usize] {
Some(goal_length) => {
let mut ans = DictMap::new();
ans.insert(goal_usize, goal_length);
ans
}
None => DictMap::new(),
},
});
}

Ok(PathLengthMapping {
path_lengths: res
.into_iter()
Expand All @@ -1330,6 +1364,7 @@ pub fn graph_bellman_ford_shortest_path_lengths(
///
/// :param PyGraph graph: The input graph to use
/// :param int source: The node index to find paths from
/// :param int target: An optional target to find a path to
/// :param weight_fn: An optional weight function for an edge. It will accept
/// a single argument, the edge's weight object and will return a float which
/// will be used to represent the weight/cost of the edge
Expand All @@ -1342,14 +1377,15 @@ pub fn graph_bellman_ford_shortest_path_lengths(
/// the dict values are lists of node indices making the path.
/// :rtype: PathMapping
///
/// :raises :class:`~retworkx.NegativeCycle`: when there is a negative cycle and the shortest
/// :raises: :class:`~retworkx.NegativeCycle`: when there is a negative cycle and the shortest
/// path is not defined.
#[pyfunction(default_weight = "1.0", as_undirected = "false")]
#[pyo3(text_signature = "(graph, source, /, weight_fn=None, default_weight=1.0)")]
pub fn graph_bellman_ford_shortest_paths(
py: Python,
graph: &graph::PyGraph,
source: usize,
target: Option<usize>,
weight_fn: Option<PyObject>,
default_weight: f64,
) -> PyResult<PathMapping> {
Expand Down Expand Up @@ -1379,7 +1415,7 @@ pub fn graph_bellman_ford_shortest_paths(
.iter()
.filter_map(|(k, v)| {
let k_int = k.index();
if k_int == source {
if k_int == source || target.is_some() && target.unwrap() != k_int {
None
} else {
Some((
Expand All @@ -1399,6 +1435,7 @@ pub fn graph_bellman_ford_shortest_paths(
///
/// :param PyDiGraph graph: The input graph to use
/// :param int source: The node index to find paths from
/// :param int target: An optional target to find a path to
/// :param weight_fn: An optional weight function for an edge. It will accept
/// a single argument, the edge's weight object and will return a float which
/// will be used to represent the weight/cost of the edge
Expand All @@ -1411,7 +1448,7 @@ pub fn graph_bellman_ford_shortest_paths(
/// the dict values are lists of node indices making the path.
/// :rtype: PathMapping
///
/// :raises :class:`~retworkx.NegativeCycle`: when there is a negative cycle and the shortest
/// :raises: :class:`~retworkx.NegativeCycle`: when there is a negative cycle and the shortest
/// path is not defined.
#[pyfunction(default_weight = "1.0", as_undirected = "false")]
#[pyo3(
Expand All @@ -1421,6 +1458,7 @@ pub fn digraph_bellman_ford_shortest_paths(
py: Python,
graph: &digraph::PyDiGraph,
source: usize,
target: Option<usize>,
weight_fn: Option<PyObject>,
default_weight: f64,
as_undirected: bool,
Expand All @@ -1430,6 +1468,7 @@ pub fn digraph_bellman_ford_shortest_paths(
py,
&graph.to_undirected(py, true, None)?,
source,
target,
weight_fn.map(|x| x.clone_ref(py)),
default_weight,
);
Expand Down Expand Up @@ -1461,10 +1500,13 @@ pub fn digraph_bellman_ford_shortest_paths(
.iter()
.filter_map(|(k, v)| {
let k_int = k.index();
if k_int == source {
if k_int == source || target.is_some() && target.unwrap() != k_int {
None
} else {
Some((k_int, v.iter().map(|x| x.index()).collect::<Vec<usize>>()))
Some((
k.index(),
v.iter().map(|x| x.index()).collect::<Vec<usize>>(),
))
}
})
.collect(),
Expand Down Expand Up @@ -1517,7 +1559,7 @@ pub fn negative_edge_cycle(
/// :return: A list of the nodes in an arbitrary negative cycle, if it exists
/// :rtype: NodeIndices
///
/// :raises ValueError: when there is no cycle in the graph provided
/// :raises: ValueError: when there is no cycle in the graph provided
#[pyfunction]
#[pyo3(text_signature = "(graph, edge_cost_fn, /)")]
pub fn find_negative_cycle(
Expand Down
28 changes: 28 additions & 0 deletions tests/digraph/test_bellman_ford.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,34 @@ def test_bellman_ford_with_no_goal_set(self):
expected = {1: 1.0, 2: 2.0, 3: 1.0, 4: 2.0, 5: 2.0}
self.assertEqual(expected, path)

def test_bellman_path(self):
path = retworkx.digraph_bellman_ford_shortest_paths(
self.graph, self.a, weight_fn=lambda x: float(x), target=self.e
)
expected = retworkx.digraph_dijkstra_shortest_paths(
self.graph, self.a, weight_fn=lambda x: float(x), target=self.e
)
self.assertEqual(expected, path)

def test_bellman_path_lengths(self):
path = retworkx.digraph_bellman_ford_shortest_path_lengths(
self.graph, self.a, lambda x: float(x), goal=self.e
)
expected = retworkx.digraph_dijkstra_shortest_path_lengths(
self.graph, self.a, lambda x: float(x), goal=self.e
)
self.assertEqual(expected, path)

def test_bellman_ford_length_with_no_path_and_goal(self):
g = retworkx.PyDiGraph()
a = g.add_node("A")
b = g.add_node("B")
path_lenghts = retworkx.digraph_bellman_ford_shortest_path_lengths(
g, a, edge_cost_fn=float, goal=b
)
expected = retworkx.digraph_dijkstra_shortest_path_lengths(g, a, edge_cost_fn=float, goal=b)
self.assertEqual(expected, path_lenghts)

def test_bellman_ford_with_no_path(self):
g = retworkx.PyDiGraph()
a = g.add_node("A")
Expand Down
28 changes: 28 additions & 0 deletions tests/graph/test_bellman_ford.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,34 @@ def test_bellman_ford_with_no_goal_set(self):
expected = retworkx.graph_dijkstra_shortest_path_lengths(self.graph, self.a, lambda x: 1)
self.assertEqual(expected, path)

def test_bellman_path(self):
path = retworkx.graph_bellman_ford_shortest_paths(
self.graph, self.a, weight_fn=lambda x: float(x), target=self.e
)
expected = retworkx.graph_dijkstra_shortest_paths(
self.graph, self.a, weight_fn=lambda x: float(x), target=self.e
)
self.assertEqual(expected, path)

def test_bellman_path_lengths(self):
path = retworkx.graph_bellman_ford_shortest_path_lengths(
self.graph, self.a, lambda x: float(x), goal=self.e
)
expected = retworkx.graph_dijkstra_shortest_path_lengths(
self.graph, self.a, lambda x: float(x), goal=self.e
)
self.assertEqual(expected, path)

def test_bellman_ford_length_with_no_path_and_goal(self):
g = retworkx.PyGraph()
a = g.add_node("A")
b = g.add_node("B")
path_lenghts = retworkx.graph_bellman_ford_shortest_path_lengths(
g, a, edge_cost_fn=float, goal=b
)
expected = retworkx.graph_dijkstra_shortest_path_lengths(g, a, edge_cost_fn=float, goal=b)
self.assertEqual(expected, path_lenghts)

def test_bellman_ford_length_with_no_path(self):
g = retworkx.PyGraph()
a = g.add_node("A")
Expand Down

0 comments on commit e7a7cfc

Please sign in to comment.