Skip to content

Commit

Permalink
Add node labels to DiGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
oyamad committed Mar 15, 2016
1 parent 859729f commit 152505c
Showing 1 changed file with 64 additions and 8 deletions.
72 changes: 64 additions & 8 deletions quantecon/graph_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ class DiGraph(object):
weighted : bool, optional(default=False)
Whether to treat `adj_matrix` as a weighted adjacency matrix.
node_labels : array_like(ndim=1, default=None)
Array_like of length n containing the label associated with each
node. If None, the labels default to integers 0 through n-1.
Attributes
----------
csgraph : scipy.sparse.csr_matrix
Expand Down Expand Up @@ -70,7 +74,7 @@ class DiGraph(object):
"""

def __init__(self, adj_matrix, weighted=False):
def __init__(self, adj_matrix, weighted=False, node_labels=None):
if weighted:
dtype = None
else:
Expand All @@ -83,6 +87,10 @@ def __init__(self, adj_matrix, weighted=False):

self.n = n # Number of nodes

self._node_labels = None
if node_labels is not None:
self.node_labels = node_labels

self._num_scc = None
self._scc_proj = None
self._sink_scc_labels = None
Expand All @@ -95,6 +103,19 @@ def __repr__(self):
def __str__(self):
return "Directed Graph:\n - n(number of nodes): {n}".format(n=self.n)

@property
def node_labels(self):
return self._node_labels

@node_labels.setter
def node_labels(self, values):
if len(values) != self.n:
raise ValueError('node_labels must be of length n')
self._node_labels = np.asarray(values)

def label_nodes(self, list_of_components):
return [self.node_labels[c] for c in list_of_components]

def _find_scc(self):
"""
Set ``self._num_scc`` and ``self._scc_proj``
Expand Down Expand Up @@ -169,22 +190,42 @@ def sink_scc_labels(self):
def num_sink_strongly_connected_components(self):
return len(self.sink_scc_labels)

@property
def strongly_connected_components(self):
# strongly_connected_components
def _get_strongly_connected_components(self):
if self.is_strongly_connected:
return [np.arange(self.n)]
else:
return [np.where(self.scc_proj == k)[0]
for k in range(self.num_strongly_connected_components)]

def get_strongly_connected_components(self, return_labels=True):
if return_labels:
return self.label_nodes(self._get_strongly_connected_components())
return self._get_strongly_connected_components()

@property
def sink_strongly_connected_components(self):
def strongly_connected_components(self):
return self.get_strongly_connected_components()

# sink_strongly_connected_components
def _get_sink_strongly_connected_components(self):
if self.is_strongly_connected:
return [np.arange(self.n)]
else:
return [np.where(self.scc_proj == k)[0]
for k in self.sink_scc_labels.tolist()]

def get_sink_strongly_connected_components(self, return_labels=True):
if return_labels:
return self.label_nodes(
self._get_sink_strongly_connected_components()
)
return self._get_sink_strongly_connected_components()

@property
def sink_strongly_connected_components(self):
return self.get_sink_strongly_connected_components()

def _compute_period(self):
"""
Set ``self._period`` and ``self._cyclic_components_proj``.
Expand Down Expand Up @@ -255,14 +296,23 @@ def period(self):
def is_aperiodic(self):
return (self.period == 1)

@property
def cyclic_components(self):
# cyclic_components
def _get_cyclic_components(self):
if self.is_aperiodic:
return [np.arange(self.n)]
else:
return [np.where(self._cyclic_components_proj == k)[0]
for k in range(self.period)]

def get_cyclic_components(self, return_labels=True):
if return_labels:
return self.label_nodes(self._get_cyclic_components())
return self._get_cyclic_components()

@property
def cyclic_components(self):
return self.get_cyclic_components()

def subgraph(self, nodes):
"""
Return the subgraph consisting of the given nodes and edges
Expand All @@ -271,7 +321,7 @@ def subgraph(self, nodes):
Parameters
----------
nodes : array_like(int, ndim=1)
Array of nodes.
Array of node indices.
Returns
-------
Expand All @@ -282,7 +332,13 @@ def subgraph(self, nodes):
adj_matrix = self.csgraph[nodes, :][:, nodes]

weighted = True # To copy the dtype
return DiGraph(adj_matrix, weighted=weighted)

if self.node_labels is not None:
node_labels = self.node_labels[nodes]
else:
node_labels = None

return DiGraph(adj_matrix, weighted=weighted, node_labels=node_labels)


def _csr_matrix_indices(S):
Expand Down

0 comments on commit 152505c

Please sign in to comment.