In [1]:
import torch
from torchsummary import summary
from exp.nb_layer import *
from math import ceil

# Graph

In [23]:
#export
class Graph:
    """Directed graph of activation nodes joined by layer-carrying edges.

    Nodes and edges live in dicts keyed by 1-based integer ids handed out in
    insertion order. The constructor registers the input node first and the
    output node second, so they always receive ids 1 and 2 (also exposed as
    ``input_id`` and ``output_id``).

    ``Node`` and ``Edge`` are defined in this notebook; ``Flatten`` and
    ``Generator`` are not defined here -- presumably they come from
    ``exp.nb_layer`` (TODO confirm).
    """

    def __init__(self, input_shape, output_shape):
        """Create an empty graph holding only the input and output nodes."""
        self.nodes = {}
        self.edges = {}
        self._node_index = 0
        self._edge_index = 0

        # Keep the ids around so later code does not rely on the literals 1/2.
        self.input_id = self.add_node(Node(input_shape))
        self.output_id = self.add_node(Node(output_shape))

    def add_node(self, node):
        """Register ``node`` and return its newly assigned integer id."""
        self._node_index += 1
        self.nodes[self._node_index] = node
        return self._node_index

    def add_edge(self, src, dest, edge):
        """Register ``edge`` between node ids ``src`` and ``dest``.

        The new edge id is recorded as an outgoing edge of ``src`` and an
        incoming edge of ``dest``. Returns the edge id.
        """
        self._edge_index += 1
        self.edges[self._edge_index] = edge

        self.nodes[src].add_out_edge(self._edge_index)
        self.nodes[dest].add_in_edge(self._edge_index)

        return self._edge_index

    def add_conv_edge(self, src, dest, ks=3):
        """Connect ``src`` to ``dest`` with Conv2d -> ReLU -> BatchNorm2d.

        Channel counts are read from element 0 of each node's shape. Padding
        is ``ks // 2`` and the stride is the rounded-up ratio of element 1 of
        the source shape to element 1 of the destination shape.
        Returns the edge id.
        """
        input_shape = self.nodes[src].shape
        output_shape = self.nodes[dest].shape

        ni, nf = input_shape[0], output_shape[0]
        pd = ks // 2
        st = ceil(input_shape[1] / output_shape[1])

        layers = [
            torch.nn.Conv2d(ni, nf, ks, st, pd),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(nf),
        ]

        edge = Edge(src, dest, torch.nn.Sequential(*layers))
        return self.add_edge(src, dest, edge)

    def add_linear_edge(self, src, dest):
        """Connect ``src`` to ``dest`` with a Linear layer; returns the edge id.

        Feature counts are element 0 of each node's shape.
        """
        ni = self.nodes[src].shape[0]
        no = self.nodes[dest].shape[0]

        edge = Edge(src, dest, torch.nn.Linear(ni, no))
        return self.add_edge(src, dest, edge)

    def add_flatten_edge(self, src, dest):
        """Connect ``src`` to ``dest`` with a ``Flatten`` layer; returns the edge id."""
        edge = Edge(src, dest, Flatten())
        return self.add_edge(src, dest, edge)

    def add_identical_edge(self, src, dest):
        """Connect ``src`` to ``dest`` with a layer-less (identity) edge.

        Returns the edge id, like the other ``add_*_edge`` helpers.
        """
        # The edge must be flagged identical *and* registered; previously it
        # was built without the flag, discarded, and None was returned.
        edge = Edge(src, dest, None, identical=True)
        return self.add_edge(src, dest, edge)

    def _node_as_layer(self, id):
        """Return ``(layer, source_node_ids, id)`` for node ``id``.

        ``source_node_ids`` lists the ``src`` of every incoming edge, in the
        order the edges were added.
        """
        node = self.nodes[id]
        inputs = [self.edges[edge_id].src for edge_id in node.in_edge]
        return (node.layer, inputs, id)

    def _reverse_traversal(self, id, visited):
        """Collect layer entries needed to compute node ``id``, dependencies first.

        Walks incoming edges depth-first. Each non-identical edge contributes
        ``(layer, src, dest)``; a node flagged ``multi_input`` contributes
        ``(layer, [src ids], id)`` after all of its incoming edges.
        ``visited`` maps node id -> bool and is updated in place.
        """
        ts = []
        visited[id] = True
        for edge_id in self.nodes[id].in_edge:
            edge = self.edges[edge_id]
            # A source that was already visited has had its sub-graph emitted;
            # only unvisited sources need to be expanded.
            if not visited[edge.src]:
                ts.extend(self._reverse_traversal(edge.src, visited))

            if not edge.identical:
                ts.append((edge.as_layer(), edge.src, edge.dest))
        if self.nodes[id].multi_input:
            ts.append(self._node_as_layer(id))
        return ts

    def generate_model(self):
        """Build a ``Generator`` from the layer list ending at the output node."""
        visited = dict.fromkeys(self.nodes, False)
        ts = self._reverse_traversal(self.output_id, visited)
        return Generator(ts)
    
#         it = 1
#         module = []
#         while(it != 2):
#             print('current', it)
#             node = self.nodes[it]
#             print(node.out_edge)
#             if(len(node.out_edge) == 1):
#                 edge = self.edges[node.out_edge[0]]
#                 module.append(edge.as_layer())
#                 it = edge.dest
#             else:
#                 (submodule, it) = self._generate_submodel(node)
        
#         return torch.nn.Sequential(*module)


        

In [24]:
#export
class Node:
    """A graph vertex: an activation shape plus the ids of attached edges.

    ``in_edge`` / ``out_edge`` hold edge ids in the order they were attached.
    ``layer`` is an optional object stored on the node and ``multi_input`` is
    a flag stored as given. ``index`` exists only after ``set_index`` is called.
    """

    def __init__(self, shape, multi_input=False, layer=None):
        self.in_edge, self.out_edge = [], []
        self.shape, self.layer = shape, layer
        self.multi_input = multi_input

    # --- simple setters -------------------------------------------------
    def set_shape(self, shape):
        """Replace the stored shape."""
        self.shape = shape

    def set_index(self, index):
        """Store ``index`` on the node."""
        self.index = index

    # --- edge bookkeeping -----------------------------------------------
    def add_out_edge(self, edge):
        """Record ``edge`` as leaving this node."""
        outgoing = self.out_edge
        outgoing.append(edge)

    def add_in_edge(self, edge):
        """Record ``edge`` as entering this node."""
        incoming = self.in_edge
        incoming.append(edge)

    def num_output(self):
        """Return how many outgoing edges have been recorded."""
        outgoing = self.out_edge
        return len(outgoing)

In [38]:
#export
class Edge:
    """A directed connection from node id ``src`` to node id ``dest``.

    A regular edge wraps a callable ``layer``. An edge created with
    ``identical=True`` carries no layer and passes its input through as-is.
    """

    def __init__(self, src, dest, layer, identical=False):
        self.src = src
        self.dest = dest
        self.identical = identical
        # Always define ``layer`` so as_layer()/forward() cannot hit a missing
        # attribute; an identical edge ignores whatever layer was passed.
        self.layer = None if identical else layer

    def as_layer(self):
        """Return the wrapped layer, or ``None`` for an identical edge."""
        return self.layer

    def forward(self, x):
        """Apply the wrapped layer to ``x``; identical edges return ``x`` unchanged."""
        if self.identical:
            return x
        return self.layer(x)

# Additional functions

In [None]:
def add_convolution_array(gr, ni, nf, ks, pd, str):
    """Unfinished helper; calling it always raises ``NotImplementedError``.

    The original cell held only the signature, which is a syntax error, so
    this stub keeps the notebook runnable until the body is written.

    NOTE(review): the argument names suggest a graph plus in/out channels,
    kernel size, padding and stride -- confirm the intended meaning before
    implementing. ``str`` shadows the builtin and is kept only to preserve
    the declared signature.
    """
    raise NotImplementedError("add_convolution_array is not implemented yet")

# Export

In [36]:
!python ../nb2py.py graph.ipynb

Converted graph.ipynb to exp/nb_graph.py


# Test
Vanilla CNN: build a small convolutional network with `Graph` and train it on MNIST.

In [26]:
# The constructor registers the input node (id 1) and the output node (id 2).
gr = Graph(input_shape=(1, 28, 28), output_shape=(10,))

In [27]:
# Ids of the intermediate nodes, created in the order the network visits them:
# two convolutional feature maps, the flattened vector, and one hidden vector.
acts = [
    gr.add_node(Node(shape))
    for shape in ((16, 14, 14), (32, 7, 7), (32*7*7,), (8*7*7,))
]

In [28]:
# Input node (id 1) -> (16, 14, 14); add_conv_edge derives stride ceil(28 / 14) = 2.
gr.add_conv_edge(1, acts[0])

1

In [29]:
# (16, 14, 14) -> (32, 7, 7); stride is again ceil(14 / 7) = 2.
gr.add_conv_edge(acts[0], acts[1])

2

In [30]:
# Flatten edge from the (32, 7, 7) node to the (32*7*7,) node.
gr.add_flatten_edge(acts[1], acts[2])

3

In [31]:
# Linear edge 1568 -> 392 (sizes are element 0 of each node's shape).
gr.add_linear_edge(acts[2], acts[3])

4

In [32]:
# Linear edge 392 -> 10 into the output node (id 2).
gr.add_linear_edge(acts[3], 2)

5

In [33]:
# Walk the graph backwards from the output node and wrap the layer list in a Generator.
module = gr.generate_model()

In [34]:
# Inspect the generated sub-modules to check that the layers match the edges added above.
list(module.children())

[ModuleDict(
   (0): Sequential(
     (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
     (1): ReLU()
     (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
   (1): Sequential(
     (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
     (1): ReLU()
     (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
   (2): Flatten()
   (3): Linear(in_features=1568, out_features=392, bias=True)
   (4): Linear(in_features=392, out_features=10, bias=True)
 )]

In [35]:
from fastai.vision import *

In [15]:
# Location of the MNIST dataset as returned by fastai's untar_data.
path = untar_data(URLs.MNIST)

In [16]:
# Data pipeline for MNIST: images read with convert_mode="L", split by the
# "training"/"testing" folders, labelled from folder names, batch size 128.
data = (
    ImageList.from_folder(path, convert_mode="L")
    .split_by_folder(train="training", valid="testing")
    .label_from_folder()
    .databunch(bs=128)
)

In [17]:
# Train the generated module with cross-entropy loss, reporting accuracy.
learn = Learner(data, module, loss_func=nn.CrossEntropyLoss(), metrics=accuracy)

In [18]:
# Per-layer output shapes and parameter counts; they should match the node shapes declared above.
learn.summary()

Generator
Layer (type)         Output Shape         Param #    Trainable 
Conv2d               [16, 14, 14]         160        True      
______________________________________________________________________
ReLU                 [16, 14, 14]         0          False     
______________________________________________________________________
BatchNorm2d          [16, 14, 14]         32         True      
______________________________________________________________________
Conv2d               [32, 7, 7]           4,640      True      
______________________________________________________________________
ReLU                 [32, 7, 7]           0          False     
______________________________________________________________________
BatchNorm2d          [32, 7, 7]           64         True      
______________________________________________________________________
Flatten              [1568]               0          False     
____________________________________________________

In [19]:
# Train for 5 epochs (the table below lists epochs 0-4).
learn.fit_one_cycle(5)

epoch,train_loss,valid_loss,accuracy,time
0,0.130171,0.101749,0.9686,00:05
1,0.088903,0.075537,0.9771,00:05
2,0.059901,0.059065,0.9814,00:05
3,0.033939,0.03907,0.988,00:05
4,0.013286,0.036617,0.9891,00:05


In [21]:
# Location of the CIFAR dataset as returned by fastai's untar_data.
path_cifar = untar_data(URLs.CIFAR)

In [22]:
# Data pipeline for CIFAR: split by the "train"/"test" folders, labelled from
# folder names, batch size 128.
data_cifar = (
    ImageList.from_folder(path_cifar)
    .split_by_folder(train="train", valid="test")
    .label_from_folder()
    .databunch(bs=128)
)