Skip to content

Commit

Permalink
Return adjacent nodes in both directions in PyDiGraph.adj (#437)
Browse files Browse the repository at this point in the history
* Return adjacent nodes in both directions in `PyDiGraph.adj`
Documentation of `PyDiGraph.adj` states that the ouput dictionary
contains adjacent nodes in either direction. But previously, only
outbound neighbors were included. This has been fixed. At the same time,
this commit simplifies the source code of `PyDiGraph.adj_direction` and `PyGraph.adj`.

* cargo fmt
  • Loading branch information
georgios-ts committed Sep 6, 2021
1 parent a0e38ef commit f2f3a09
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 52 deletions.
7 changes: 7 additions & 0 deletions releasenotes/notes/bugfix-digraph-adj-b8a911fab80effec.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
fixes:
- |
Fixes the output of :meth:`~retworkx.PyDiGraph.adj` to include
neighbors that have an edge between them and the specified node,
in either direction. Previously, only outbound nodes were included.
61 changes: 19 additions & 42 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1553,18 +1553,15 @@ impl PyDiGraph {
#[pyo3(text_signature = "(self, node, /)")]
pub fn adj(&mut self, node: usize) -> HashMap<usize, &PyObject> {
let index = NodeIndex::new(node);
let neighbors = self.graph.neighbors(index);
let mut out_map: HashMap<usize, &PyObject> = HashMap::new();
for neighbor in neighbors {
let mut edge = self.graph.find_edge(index, neighbor);
// If there is no edge then it must be a parent neighbor
if edge.is_none() {
edge = self.graph.find_edge(neighbor, index);
}
let edge_w = self.graph.edge_weight(edge.unwrap());
out_map.insert(neighbor.index(), edge_w.unwrap());
}
out_map
self.graph
.edges_directed(index, petgraph::Direction::Incoming)
.map(|edge| (edge.source().index(), edge.weight()))
.chain(
self.graph
.edges_directed(index, petgraph::Direction::Outgoing)
.map(|edge| (edge.target().index(), edge.weight())),
)
.collect()
}

/// Get the index and data for either the parent or children of a node.
Expand All @@ -1588,39 +1585,19 @@ impl PyDiGraph {
&mut self,
node: usize,
direction: bool,
) -> PyResult<HashMap<usize, &PyObject>> {
) -> HashMap<usize, &PyObject> {
let index = NodeIndex::new(node);
let dir = if direction {
petgraph::Direction::Incoming
if direction {
self.graph
.edges_directed(index, petgraph::Direction::Incoming)
.map(|edge| (edge.source().index(), edge.weight()))
.collect()
} else {
petgraph::Direction::Outgoing
};
let neighbors = self.graph.neighbors_directed(index, dir);
let mut out_map: HashMap<usize, &PyObject> = HashMap::new();
for neighbor in neighbors {
let edge = if direction {
match self.graph.find_edge(neighbor, index) {
Some(edge) => edge,
None => {
return Err(NoEdgeBetweenNodes::new_err(
"No edge found between nodes",
))
}
}
} else {
match self.graph.find_edge(index, neighbor) {
Some(edge) => edge,
None => {
return Err(NoEdgeBetweenNodes::new_err(
"No edge found between nodes",
))
}
}
};
let edge_w = self.graph.edge_weight(edge);
out_map.insert(neighbor.index(), edge_w.unwrap());
self.graph
.edges_directed(index, petgraph::Direction::Outgoing)
.map(|edge| (edge.target().index(), edge.weight()))
.collect()
}
Ok(out_map)
}

/// Get the neighbors (i.e. successors) of a node.
Expand Down
15 changes: 5 additions & 10 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1035,17 +1035,12 @@ impl PyGraph {
/// edge with the specified node.
/// :rtype: dict
#[pyo3(text_signature = "(self, node, /)")]
pub fn adj(&mut self, node: usize) -> PyResult<HashMap<usize, &PyObject>> {
pub fn adj(&mut self, node: usize) -> HashMap<usize, &PyObject> {
let index = NodeIndex::new(node);
let neighbors = self.graph.neighbors(index);
let mut out_map: HashMap<usize, &PyObject> = HashMap::new();

for neighbor in neighbors {
let edge = self.graph.find_edge(index, neighbor);
let edge_w = self.graph.edge_weight(edge.unwrap());
out_map.insert(neighbor.index(), edge_w.unwrap());
}
Ok(out_map)
self.graph
.edges_directed(index, petgraph::Direction::Outgoing)
.map(|edge| (edge.target().index(), edge.weight()))
.collect()
}

/// Get the neighbors of a node.
Expand Down
6 changes: 6 additions & 0 deletions tests/digraph/test_adj.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def test_single_neighbor(self):
res = dag.adj(node_a)
self.assertEqual({node_b: {"a": 1}, node_c: {"a": 2}}, res)

def test_in_and_out_adj_neighbor(self):
dag = retworkx.PyDAG()
dag.extend_from_weighted_edge_list([(0, 1, "a"), (1, 2, "b")])
res = dag.adj(1)
self.assertEqual({0: "a", 2: "b"}, res)

def test_single_neighbor_dir(self):
dag = retworkx.PyDAG()
node_a = dag.add_node("a")
Expand Down

0 comments on commit f2f3a09

Please sign in to comment.