Skip to content

Commit

Permalink
Implement mapping protocol methods for graph classes (#119)
Browse files Browse the repository at this point in the history
This commit adds implementations for the __getitem__, __setitem__, and
__delitem__ mapping protocol methods to the PyGraph and PyDiGraph
classes. This enables accessing nodes from the graph by using standard
dict/mapping access patterns (ie graph[2]) with the node index. This
makes it easier to work with methods and functions that return node
indexes without having to call helper methods.
  • Loading branch information
mtreinish committed Aug 25, 2020
1 parent 4f86ed9 commit d321588
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 0 deletions.
25 changes: 25 additions & 0 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,31 @@ impl PyMappingProtocol for PyDiGraph {
fn __len__(&self) -> PyResult<usize> {
Ok(self.graph.node_count())
}
fn __getitem__(&'p self, idx: usize) -> PyResult<&'p PyObject> {
match self.graph.node_weight(NodeIndex::new(idx as usize)) {
Some(data) => Ok(data),
None => Err(IndexError::py_err("No node found for index")),
}
}

fn __setitem__(&'p mut self, idx: usize, value: PyObject) -> PyResult<()> {
let data = match self
.graph
.node_weight_mut(NodeIndex::new(idx as usize))
{
Some(node_data) => node_data,
None => return Err(IndexError::py_err("No node found for index")),
};
*data = value;
Ok(())
}

fn __delitem__(&'p mut self, idx: usize) -> PyResult<()> {
match self.graph.remove_node(NodeIndex::new(idx as usize)) {
Some(_) => Ok(()),
None => Err(IndexError::py_err("No node found for index")),
}
}
}

fn is_cycle_check_required(
Expand Down
25 changes: 25 additions & 0 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -712,4 +712,29 @@ impl PyMappingProtocol for PyGraph {
fn __len__(&self) -> PyResult<usize> {
Ok(self.graph.node_count())
}
fn __getitem__(&'p self, idx: usize) -> PyResult<&'p PyObject> {
match self.graph.node_weight(NodeIndex::new(idx as usize)) {
Some(data) => Ok(data),
None => Err(IndexError::py_err("No node found for index")),
}
}

fn __setitem__(&'p mut self, idx: usize, value: PyObject) -> PyResult<()> {
let data = match self
.graph
.node_weight_mut(NodeIndex::new(idx as usize))
{
Some(node_data) => node_data,
None => return Err(IndexError::py_err("No node found for index")),
};
*data = value;
Ok(())
}

fn __delitem__(&'p mut self, idx: usize) -> PyResult<()> {
match self.graph.remove_node(NodeIndex::new(idx as usize)) {
Some(_) => Ok(()),
None => Err(IndexError::py_err("No node found for index")),
}
}
}
54 changes: 54 additions & 0 deletions tests/graph/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,57 @@ def test_add_node_from_empty(self):
graph = retworkx.PyGraph()
res = graph.add_nodes_from([])
self.assertEqual(len(res), 0)

def test_get_node_data_getitem(self):
graph = retworkx.PyGraph()
node_a = graph.add_node('a')
node_b = graph.add_node('b')
graph.add_edge(node_a, node_b, "Edgy")
self.assertEqual('b', graph[node_b])

def test_get_node_data_getitem_bad_index(self):
graph = retworkx.PyGraph()
node_a = graph.add_node('a')
node_b = graph.add_node('b')
graph.add_edge(node_a, node_b, "Edgy")
with self.assertRaises(IndexError):
graph[42]

def test_set_node_data_setitem(self):
graph = retworkx.PyGraph()
node_a = graph.add_node('a')
node_b = graph.add_node('b')
graph.add_edge(node_a, node_b, "Edgy")
graph[node_b] = 'Oh so cool'
self.assertEqual('Oh so cool', graph[node_b])

def test_set_node_data_setitem_bad_index(self):
graph = retworkx.PyGraph()
node_a = graph.add_node('a')
node_b = graph.add_node('b')
graph.add_edge(node_a, node_b, "Edgy")
with self.assertRaises(IndexError):
graph[42] = 'Oh so cool'

def test_remove_node_delitem(self):
graph = retworkx.PyGraph()
node_a = graph.add_node('a')
node_b = graph.add_node('b')
graph.add_edge(node_a, node_b, "Edgy")
node_c = graph.add_node('c')
graph.add_edge(node_b, node_c, "Edgy_mk2")
del graph[node_b]
res = graph.nodes()
self.assertEqual(['a', 'c'], res)
self.assertEqual([0, 2], graph.node_indexes())

def test_remove_node_delitem_invalid_index(self):
graph = retworkx.PyGraph()
graph.add_node('a')
graph.add_node('b')
graph.add_node('c')
with self.assertRaises(IndexError):
del graph[76]
res = graph.nodes()
self.assertEqual(['a', 'b', 'c'], res)
self.assertEqual([0, 1, 2], graph.node_indexes())
48 changes: 48 additions & 0 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,51 @@ def test_add_node_from_empty(self):
dag = retworkx.PyDAG()
res = dag.add_nodes_from([])
self.assertEqual(len(res), 0)

def test_get_node_data_getitem(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
node_b = dag.add_child(node_a, 'b', "Edgy")
self.assertEqual('b', dag[node_b])

def test_get_node_data_getitem_bad_index(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
dag.add_child(node_a, 'b', "Edgy")
with self.assertRaises(IndexError):
dag[42]

def test_set_node_data_setitem(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
node_b = dag.add_child(node_a, 'b', "Edgy")
dag[node_b] = 'Oh so cool'
self.assertEqual('Oh so cool', dag[node_b])

def test_set_node_data_setitem_bad_index(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
dag.add_child(node_a, 'b', "Edgy")
with self.assertRaises(IndexError):
dag[42] = 'Oh so cool'

def test_remove_node_delitem(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
node_b = dag.add_child(node_a, 'b', "Edgy")
dag.add_child(node_b, 'c', "Edgy_mk2")
del dag[node_b]
res = dag.nodes()
self.assertEqual(['a', 'c'], res)
self.assertEqual([0, 2], dag.node_indexes())

def test_remove_node_delitem_invalid_index(self):
graph = retworkx.PyDAG()
graph.add_node('a')
graph.add_node('b')
graph.add_node('c')
with self.assertRaises(IndexError):
del graph[76]
res = graph.nodes()
self.assertEqual(['a', 'b', 'c'], res)
self.assertEqual([0, 1, 2], graph.node_indexes())

0 comments on commit d321588

Please sign in to comment.