In [1]:
from collections import deque

import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt

from graphviz import Digraph

In [2]:
# -------------------- utils ---------------------------------

class Stack(object):
    """LIFO stack; the top of the stack is the end of an internal list."""

    def __init__(self):
        self.__data = list()

    def __repr__(self):
        # e.g. "Stack [1, 2": trailing ']' characters are stripped so the top
        # of the stack reads as "open".
        text = repr(self.__data)
        return 'Stack ' + text.strip(']')

    def put(self, item):
        """Push ``item`` on top of the stack."""
        self.__data.append(item)

    def pop(self):
        """Remove and return the top item (IndexError when empty)."""
        top = self.__data.pop()
        return top

    def peep(self):
        """Return the top item without removing it (IndexError when empty)."""
        top = self.__data[-1]
        return top

    @property
    def empty(self):
        """True when the stack holds no items."""
        return not self.__data

class Queue(object):
    """FIFO queue with O(1) ``put`` and ``pop``."""

    def __init__(self):
        # deque gives O(1) removal from the front; list.pop(0) is O(n), which
        # made breadth-first traversal quadratic in the number of queued items.
        self.__data = deque()

    def __repr__(self):
        # Items comma-separated without surrounding brackets, e.g. "Queue\n1, 2".
        return 'Queue\n' + repr(list(self.__data)).strip('[|]')

    def put(self, item):
        """Append ``item`` at the back of the queue."""
        self.__data.append(item)

    def pop(self):
        """Remove and return the front item (IndexError when empty)."""
        return self.__data.popleft()

    @property
    def empty(self):
        """True when the queue holds no items."""
        return len(self.__data) == 0



class Graph(object):
    """Directed, weighted graph stored as a dense adjacency matrix.

    ``adjacent_matrix[i, j]`` is the weight of the arc i -> j, and ``np.inf``
    means "no arc".  Vertices are the integers ``0 .. n_vertex - 1``.
    """

    def __init__(self, mtx=None, tb=None):
        """Create a graph, optionally from a matrix or an edge table.

        Parameters
        ----------
        mtx : np.ndarray, optional
            Square adjacency matrix; see ``from_matrix``.  Takes precedence
            over ``tb`` when both are given.
        tb : optional
            Edge table; see ``from_table``.  That method is not implemented
            yet, so passing only ``tb`` currently leaves the graph empty.

        Without arguments the graph is empty: ``n_vertex`` and
        ``adjacent_matrix`` are both ``None``.
        """
        self.n_vertex = None
        self.adjacent_matrix = None

        if mtx is not None:
            self.from_matrix(mtx)
        elif tb is not None:
            self.from_table(tb)

    # -------------- Construction ------------------------------------

    def from_matrix(self, mtx):
        """Initialise the graph from a square adjacency matrix.

        Parameters
        ----------
        mtx : np.ndarray, shape (n, n)
            ``mtx[i, j]`` is the weight of arc i -> j; ``np.inf`` or ``NaN``
            means "no arc".  All weights must be non-negative.

        Raises
        ------
        AssertionError
            If ``mtx`` is not an ``np.ndarray``, is not a square 2-D array,
            or contains a negative weight.
        """
        assert isinstance(mtx, np.ndarray), "`mtx` must be np.ndarray"
        # Check ndim explicitly: len() of a 0-d array raises TypeError rather
        # than the intended assertion.
        assert mtx.ndim == 2 and mtx.shape[0] == mtx.shape[1], "invalid mtx.shape"
        n_vertex = mtx.shape[0]

        # astype() returns a copy, so the caller's array is never modified.
        # Missing values (NaN) are normalised to "no arc" (inf).
        mtx = mtx.astype(float)
        mtx[np.isnan(mtx)] = np.inf

        # Negative weights are not supported.
        assert (mtx >= 0).all(), "all item mtx must be >= 0"

        self.n_vertex = n_vertex
        self.adjacent_matrix = mtx

    def from_table(self, tb):
        """Initialise the graph from an edge table.

        TODO: not implemented yet; currently does nothing.
        """

    def add_vertex(self, idx):
        """Add a vertex.  TODO: not implemented yet; currently does nothing."""

    def add_arc(self, tail, head):
        """Add the arc tail -> head.  TODO: not implemented yet; currently does nothing."""

    # -------------- Traversal ------------------------------------

    def _iter_out_arc(self, vertex_id):
        """Yield ``(vertex_id, head, weight)`` for each out-arc of ``vertex_id``.

        Heads are produced in increasing order.  Only row entries smaller than
        ``np.inf`` count as arcs.
        """
        for j in range(self.n_vertex):
            weight = self.adjacent_matrix[vertex_id, j]
            if weight < np.inf:
                yield (vertex_id, j, weight)

    def _check_start(self, start):
        """Validate the start vertex of a traversal.

        Raises ValueError when the graph holds no data yet, and IndexError when
        ``start`` is outside ``range(n_vertex)``.  Negative ids are rejected
        explicitly because list indexing would otherwise wrap them around.
        """
        if self.n_vertex is None:
            raise ValueError("graph is empty; initialise it from a matrix first")
        if not 0 <= start < self.n_vertex:
            raise IndexError(
                "start vertex %r is out of range [0, %d)" % (start, self.n_vertex))

    def DepthFirstTraverse(self, start=0):
        """Depth-first traversal from ``start``, driven by an explicit stack.

        Yields the ids of the vertices reachable from ``start`` in pre-order,
        meaning each vertex is yielded when it is first discovered.

        Raises
        ------
        ValueError
            If the graph has not been initialised (``n_vertex`` is None).
        IndexError
            If ``start`` is not in ``range(n_vertex)``.
        """
        self._check_start(start)

        stack = Stack()
        visited = [False] * self.n_vertex
        # One persistent out-arc iterator per vertex.  When the traversal
        # returns to a vertex on the stack, it resumes scanning that vertex's
        # arcs where it left off.
        arc_iters = [self._iter_out_arc(i) for i in range(self.n_vertex)]

        visited[start] = True
        stack.put(start)
        yield start

        while not stack.empty:
            vertex_id = stack.peep()
            try:
                _, v, _ = next(arc_iters[vertex_id])
            except StopIteration:
                # Every out-arc of the top vertex is explored: backtrack.
                stack.pop()
                continue
            if not visited[v]:
                visited[v] = True
                stack.put(v)
                yield v

    def BreadthFristTraverse(self, start=0):
        """Breadth-first traversal from ``start``, driven by a FIFO queue.

        Yields the ids of the vertices reachable from ``start`` in BFS order.
        A vertex is marked visited when it is *enqueued*, so each vertex
        enters the queue at most once and its out-arcs are scanned once.

        Raises
        ------
        ValueError
            If the graph has not been initialised (``n_vertex`` is None).
        IndexError
            If ``start`` is not in ``range(n_vertex)``.
        """
        self._check_start(start)

        q = Queue()
        visited = [False] * self.n_vertex

        visited[start] = True
        q.put(start)
        while not q.empty:
            vertex = q.pop()
            yield vertex
            for _, v, _ in self._iter_out_arc(vertex):
                if not visited[v]:
                    visited[v] = True
                    q.put(v)

    # Correctly spelled alias.  The misspelled name above is kept so existing
    # callers keep working.
    BreadthFirstTraverse = BreadthFristTraverse

    # -------------- Algorithms ------------------------------------

    def MinCostSpanningTree_Prim(self):
        """Minimum-cost spanning tree, Prim's algorithm.

        TODO: not implemented yet; currently returns None.
        """

    def MinCostSpanningTree_Kruskal(self):
        """Minimum-cost spanning tree, Kruskal's algorithm.

        Edge-based: repeatedly pick the currently cheapest arc that connects
        two different connected components.

        TODO: not implemented yet; currently returns None.
        """

    def ShortestPath_Dijkstra(self):
        """Shortest paths, Dijkstra's algorithm.

        TODO: not implemented yet; currently returns None.
        """

    def ShortestPath_Floyd(self):
        """All-pairs shortest paths, Floyd's algorithm.

        TODO: not implemented yet; currently returns None.
        """

    # -------------- plot use graphviz ------------------------------------

    def plot(self):
        """Draw the graph with graphviz.

        TODO: not implemented yet; currently returns None.
        """
    

class AOV(Graph):
    """AOV (activity-on-vertex) network: a Graph that will support topological sorting."""

    def TopologicalSeq(self):
        """Generate one (or all) topological sequences, checking for cycles on the way.

        Planned approach: find a vertex whose in-degree is zero, delete it
        together with its out-arcs, and repeat until no vertex with in-degree
        zero is left.

        TODO: not implemented yet; currently returns None.
        """
        return None


class AOE(Graph):
    """AOE (activity-on-edge) network: a Graph that will support critical-path analysis."""

    def CriticalPath(self):
        """Critical path of the network.

        Planned steps:

        - For every vertex compute ``etv`` and ``ltv``:
            + first obtain a topological sequence;
            + ``etv``: computed forward, starting from the source;
            + ``ltv``: computed backward, starting from the sink.
        - For every arc <tail, head> compute ``ete`` and ``lte``:
            + ete = etv(tail)
            + lte = ltv(head) - len<tail, head>
        - Arcs with ete == lte lie on the critical path.
          TODO: can there be more than one critical path?

        TODO: not implemented yet; currently returns None.
        """
        return None

In [5]:
# Demo graph with 6 vertices: each listed arc (tail, head) has weight 1, and
# every other entry of the adjacency matrix is +inf ("no arc").
edges = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 4), (1, 5)]
x = np.full((6, 6), np.inf)
tails, heads = zip(*edges)
x[list(tails), list(heads)] = 1

g = Graph(x)

# One vertex id per line, DFS order first, then BFS order.
print(*g.DepthFirstTraverse(), sep='\n')
print('=============')
print(*g.BreadthFristTraverse(), sep='\n')

0
1
2
4
5
3
0
1
2
3
4
5


In [12]:
gg = Digraph('x', format='png')

# np.argwhere walks the matrix in row-major order, so arcs are added in the
# same order a nested (row, column) loop would add them.
for tail, head in np.argwhere(x < np.inf):
    gg.edge(str(tail), str(head))

gg.render()

'x.gv.png'

![](x.gv.png)

## TODO

- Also support traversing edges, not just vertices.

- Support attributes on vertices and edges.