In [1]:
#export
from .edges import Edge
from ..layer import Concat, Add
from ..generator import Generator
from graphviz import Digraph

In [1]:
from include.graph.edges import Edge
# from math import ceil, sqrt
from graphviz import Digraph
from include.generator import Generator
from include.layer import Concat, Add

# Graph

In [20]:
#export
class Graph:
    """Directed graph of shapes from which a Generator model is assembled.

    Nodes store shapes and are keyed by integer ids starting at 1; edges are
    keyed by integer ids starting at 1. The input node is created first and the
    output node second, so they are ids 1 and 2 (`self.input`, `self.output`).
    """
    def __init__(self, input_shape, output_shape):
        self.nodes = {}
        self.edges = {}
        self._node_index = 0
        self._edge_index = 0

        self.input = self.add_node(input_shape)
        self.output = self.add_node(output_shape)

    def add_node(self, shape, multi_input=False, layer=None):
        """Create a Node and return its new integer id.

        Args:
            shape: shape stored on the node.
            multi_input: whether the node merges several incoming edges.
            layer: merge layer for multi-input nodes (presumably a layer class
                such as Concat or Add).
        """
        node = Node(shape, multi_input, layer)
        self._node_index += 1
        self.nodes[self._node_index] = node
        return self._node_index

    def _redirect_out_edges(self, src, dest):
        """Move every outgoing edge of node `src` so that it leaves node `dest` instead."""
        self.nodes[dest].out_edge = self.nodes[src].out_edge
        for edge_id in self.nodes[dest].out_edge:
            self.edges[edge_id].src = dest
        self.nodes[src].out_edge = []

    def insert_node(self, src, multi_input=False, edge=None, layer=None):
        """Insert a new node directly after node `src` and return its id.

        All edges that previously left `src` are moved to leave the new node,
        and `src` is connected to the new node:
          * multi_input=True: by an identity edge; the new node is a merge
            node using `layer` (required).
          * otherwise: by `edge` (required), whose src/dest are overwritten.
        The new node starts with a copy of `src`'s shape.

        Raises:
            Exception: if the required `layer` (multi-input) or `edge`
                (single-input) argument is missing. The graph is not modified.
        """
        if multi_input:
            if layer is None:
                raise Exception('Multi-input node requires layer.')
            new_node = self.add_node(self.nodes[src].shape, multi_input, layer)
            self._redirect_out_edges(src, new_node)
            self._add_edge(src, new_node, identical=True)
        else:
            # Check before touching the graph so a missing edge cannot leave
            # redirected out-edges and a dangling node behind.
            if edge is None:
                raise Exception('Single-input node requires edge.')
            new_node = self.add_node(shape=self.nodes[src].shape)
            self._redirect_out_edges(src, new_node)
            edge.src = src
            edge.dest = new_node
            self.add_edge(edge)

        return new_node

    def _add_edge(self, src, dest, layer=None, identical=False):
        """Create an Edge from `src` to `dest`, register it, and return its id."""
        return self.add_edge(Edge(src, dest, layer, identical))

    def add_edge(self, edge):
        """Register `edge` between two existing nodes and return its new id.

        Raises:
            KeyError: if `edge.src` or `edge.dest` is not a node id; the graph
                is left unchanged in that case.
        """
        # Resolve both endpoints first so an invalid id cannot leave a
        # half-registered edge in self.edges / the source node.
        src_node = self.nodes[edge.src]
        dest_node = self.nodes[edge.dest]

        self._edge_index += 1
        self.edges[self._edge_index] = edge
        src_node.add_out_edge(self._edge_index)
        dest_node.add_in_edge(self._edge_index)

        return self._edge_index

    def visualize(self, filename, path):
        """Render the graph as an SVG with graphviz into directory `path`.

        Node labels are "<id><shape>"; edge labels are "id" for identity edges
        and the string form of `edge.as_layer()` otherwise.
        """
        digraph = Digraph(comment="Model")
        for node_id, node in self.nodes.items():
            digraph.node(str(node_id), label=str(node_id) + str(node.shape))
        for edge in self.edges.values():
            label = "id" if edge.identical else str(edge.as_layer())
            digraph.edge(str(edge.src), str(edge.dest), label=label)

        digraph.format = 'svg'
        digraph.filename = filename
        digraph.directory = path
        digraph.render(view=False)

    def _node_as_layer(self, id):
        """Return (layer instance, input node ids, id) for multi-input node `id`."""
        node = self.nodes[id]
        # node.layer holds a callable (presumably a layer class) that is
        # instantiated with no arguments here.
        layer = node.layer()
        inputs = [self.edges[edge_id].src for edge_id in node.in_edge]
        return (layer, inputs, id)

    def _update_node_shape(self, id):
        """Recompute and store the shape of node `id` from its inputs.

        Multi-input nodes: with a Concat layer, the first dimensions of the
        input shapes are summed and the remaining dimensions come from the
        first input; with any other layer the last input's shape is used.
        Single-input nodes take `calculate_output` of their first incoming edge.
        """
        node = self.nodes[id]
        if node.multi_input:
            shape = None
            for edge_id in node.in_edge:
                in_shape = self.nodes[self.edges[edge_id].src].shape
                if shape is None or node.layer is not Concat:
                    shape = in_shape
                else:
                    shape = (shape[0] + in_shape[0],) + shape[1:]
        else:
            edge = self.edges[node.in_edge[0]]
            shape = edge.calculate_output(self.nodes[edge.src].shape)
        node.shape = shape

    def _reverse_traversal(self, id, visited):
        """Walk backwards (depth-first) from node `id` and collect build steps.

        Each step is either (edge layer, src, dest) for a non-identity edge, or
        the `_node_as_layer` tuple for a multi-input node. Steps are emitted in
        post-order, so for an acyclic graph each step follows the steps that
        feed it. `visited` (node id -> bool) is updated in place.
        """
        ts = []
        visited[id] = True
        for edge_id in self.nodes[id].in_edge:
            edge = self.edges[edge_id]
            if not visited[edge.src]:
                ts.extend(self._reverse_traversal(edge.src, visited))
            # Identity edges only link nodes; they contribute no layer.
            if not edge.identical:
                ts.append((edge.as_layer(), edge.src, edge.dest))
        if self.nodes[id].multi_input:
            ts.append(self._node_as_layer(id))
        return ts

    def generate_model(self):
        """Collect the build steps that lead to the output node and wrap them in a Generator."""
        visited = {node_id: False for node_id in self.nodes}
        ts = self._reverse_traversal(self.output, visited)
        return Generator(ts)

In [3]:
#export
class Node:
    """A graph vertex: a shape plus the ids of edges entering and leaving it.

    Args:
        shape: shape associated with this node.
        multi_input: True if the node merges several incoming edges via `layer`.
        layer: merge layer for multi-input nodes (None otherwise).
    """
    def __init__(self, shape, multi_input=False, layer=None):
        self.shape, self.multi_input, self.layer = shape, multi_input, layer
        # Edge ids; each node gets its own fresh lists.
        self.in_edge, self.out_edge = [], []

    def add_in_edge(self, edge):
        """Record edge id `edge` as entering this node."""
        self.in_edge += [edge]

    def add_out_edge(self, edge):
        """Record edge id `edge` as leaving this node."""
        self.out_edge += [edge]

    def num_output(self):
        """Return how many edges leave this node."""
        return len(self.out_edge)

    def set_shape(self, shape):
        """Replace the stored shape."""
        self.shape = shape

# Export

In [23]:
!python nb2py.py graph.ipynb

Converted graph.ipynb to exp/nb_graph.py


In [5]:
# Build a minimal graph: node 1 is the input (shape (1, 28, 28)) and node 2 the output (shape (10,)); no edges yet.
gr = Graph((1, 28, 28), (10,))

In [6]:
# Render the graph as SVG via graphviz into the current directory.
# NOTE(review): 'asdfasf' looks like a placeholder filename — consider a descriptive name.
gr.visualize('asdfasf', './')

In [7]:
# With no edges the output node has no inputs, so the Generator is built with an empty step list (see output below).
gr.generate_model()

Generator(
  (layers): ModuleDict()
)