Skip to content

Commit

Permalink
added test for overwriting of graphs with using keywords
Browse files Browse the repository at this point in the history
  • Loading branch information
eileen-kuehn committed Jul 31, 2018
1 parent cb39fdd commit 0d78f2a
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions graphi_unittests/types_unittests/test_adjacency_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import graphi.types.adjacency_graph
import graphi.abc
from graphi import operators
from graphi.types.decorator import undirectable

from . import _graph_interface_mixins as mixins

Expand All @@ -16,6 +17,32 @@ class TestAdjacencyGraphInterface(mixins.Mixin.GraphInitMixin, mixins.Mixin.Grap
graph_cls = graphi.types.adjacency_graph.AdjacencyGraph


@undirectable
class DistanceGraph(graphi.types.adjacency_graph.AdjacencyGraph):
def __init__(self, *source, distance=lambda *args: 1, **kwargs):
self.distance = distance
super(DistanceGraph, self).__init__(*source, **kwargs)

def __getitem__(self, item):
if isinstance(item, slice):
assert item.step is None, '%s does not support stride argument for edges' % self.__class__.__name__
node_from, node_to = item.start, item.stop
if node_from not in self._adjacency:
raise graphi.abc.EdgeError # first edge node
elif node_to not in self._adjacency:
raise graphi.abc.EdgeError # second edge node
# Since we don't know the type of nodes, we cannot test
# node_to > node_from to detect swapped pairs. Since we
# *do* store nodes in a `set`, they must support hash.
if self.undirected and hash(node_to) > hash(node_from):
node_to, node_from = node_from, node_to
return self.distance(node_from, node_to)
else:
if item not in self:
raise graphi.abc.NodeError
return {candidate: self[item:candidate] for candidate in self if candidate != item}


class TestAdjacencyGraph(unittest.TestCase):
# distance graph class to test
graph_cls = graphi.types.adjacency_graph.AdjacencyGraph
Expand Down Expand Up @@ -316,3 +343,8 @@ def test_value_view_directed(self):
self.assertTrue(value in value_view)
self.assertFalse(3 in value_view)
self.assertEqual(len(nodes) * 2, len(value_view))

def test_graph_customisation(self):
graph = DistanceGraph([1, 2], undirected=True)
self.assertEquals(1, graph[1:2])
self.assertEquals(1, graph.distance(1, 2))

0 comments on commit 0d78f2a

Please sign in to comment.