In [None]:
#export
import torch
from .graph.graph import Graph
from .graph_transformer import *
from .function import ceil_power_of_two

In [1]:
from fastai.vision import *
from include.graph.graph import Graph
from include.graph_transformer import *
from include.function import ceil_power_of_two

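# Anchored-Graph
Grow a network by attaching branches that merge (via Add or Concat) at anchor nodes.
## Required functions
- Add an anchor
- Add threads (branches connecting into anchors)
- Verify each edge's output and update node shapes accordingly
## Limitations
- Keep kernel sizes at 1, 3 or 5
- Keep spatial shapes of the form `x // 2**k`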
    

In [2]:
#export
class AnchorGraph(Graph):
    """Graph whose topology is grown around "anchor" merge nodes.

    An anchor is a multi-input node (e.g. ``Add`` or ``Concat``) created by
    :meth:`add_anchor`. Branches are attached to anchors with
    :meth:`add_connection`, the network is deepened with :meth:`deeper_net`,
    channels are widened with :meth:`expand_channel`, and
    :meth:`_reconstruct` propagates a shape change through neighbouring
    nodes until every edge verifies against its endpoint shapes.

    Attributes:
        anchor: node ids returned by :meth:`add_anchor`, in creation order.
        convs: whatever the stem ``add_conv_block`` call in ``__init__``
            returned, collected with ``list.extend``.
    """

    def __init__(self, input_shape, output_shape):
        super(AnchorGraph, self).__init__(input_shape, output_shape)
        self.anchor = []
        self.convs = []

        # Stem: a single 3x3 conv block with 64 filters and stride 2.
        node = add_conv_block(self, self.input, nf=64, ks=3, st=2)
        # NOTE(review): `extend` iterates the value returned by add_conv_block,
        # while the same value is used as a single node id just below —
        # confirm the intended return type.
        self.convs.extend(node)

        # First anchor right after the stem, then the classifier head.
        node = self.add_anchor(node, layer=Add)
        node = add_flatten_layer(self, node)
        self.output = add_linear_layer(self, node, self.output)

    def _reconstruct(self, target, expanded=None, caller=None):
        """Propagate a shape change at node ``target`` to its neighbours.

        Every incoming and then every outgoing edge of ``target`` is checked
        with ``edge.verify_output``. An edge that no longer matches has its
        other endpoint resized via ``edge.updated_dest`` / ``edge.updated_src``,
        which also return the channel mapping to propagate, and the method
        recurses into that endpoint.

        Args:
            target: id of the node whose shape changed.
            expanded: sequence of ``(origin, channel)`` pairs for the newly
                added channels; a negative ``origin`` is kept as -1 (no
                source channel). Defaults to no expanded channels.
            caller: id of the node this call was triggered from; used to
                offset ``expanded`` into a ``Concat`` node's channel range.
        """
        if expanded is None:
            expanded = []

        is_concat = self.nodes[target].layer is Concat
        # A Concat node's channel count is derived from its inputs, so its
        # channel dimension is left unconstrained (-1) when verifying edges.
        nf = -1 if is_concat else self.nodes[target].shape[0]
        out_shape = (nf,) + self.nodes[target].shape[1:]

        # Update backwards: resize producers whose edges into `target` fail.
        for edge_id in self.nodes[target].in_edge:
            edge = self.edges[edge_id]
            in_shape = self.nodes[edge.src].shape
            if edge.verify_output(in_shape, out_shape):
                continue
            next_shape, next_expanded = edge.updated_dest(in_shape, out_shape, expanded)
            self.nodes[edge.src].shape = next_shape
            self._reconstruct(edge.src, next_expanded, target)

        # Update self: a Concat node's width is the sum of its inputs' widths.
        # Channel indices coming from `caller` are shifted by the widths of the
        # inputs listed before it (assumes `in_edge` order == concat order).
        if is_concat:
            offset = 0
            for edge_id in self.nodes[target].in_edge:
                edge = self.edges[edge_id]
                if caller == edge.src:
                    expanded = [
                        (-1 if o < 0 else o + offset, c + offset) for (o, c) in expanded
                    ]
                offset += self.nodes[edge.src].shape[0]
            self.nodes[target].shape = (offset,) + self.nodes[target].shape[1:]

        # Update forwards: resize consumers whose edges from `target` fail.
        in_shape = self.nodes[target].shape
        for edge_id in self.nodes[target].out_edge:
            edge = self.edges[edge_id]
            dest = self.nodes[edge.dest]
            dest_nf = -1 if dest.layer is Concat else dest.shape[0]
            dest_shape = (dest_nf,) + dest.shape[1:]
            if edge.verify_output(in_shape, dest_shape):
                continue
            next_shape, next_expanded = edge.updated_src(in_shape, dest_shape, expanded)
            self.nodes[edge.dest].shape = next_shape
            self._reconstruct(edge.dest, next_expanded, target)

    def add_anchor(self, src, layer):
        """Insert a multi-input node of type ``layer`` after ``src``.

        The new node id is recorded in ``self.anchor`` and returned.
        """
        anchor = self.insert_node(src, multi_input=True, layer=layer)
        self.anchor.append(anchor)
        return anchor

    def add_connection(self, src, dest, layer_features=()):
        """Add a branch from node ``src`` into anchor node ``dest``.

        The branch is a chain of 3x3 stride-1 conv blocks (one per entry of
        ``layer_features``), an average-pooling layer when the spatial sizes
        differ, and an identity edge into ``dest``. For a ``Concat`` anchor the
        branch's channels are added to ``dest`` and the change is propagated
        with :meth:`_reconstruct`; otherwise a 1x1 conv is inserted when the
        branch's channel count differs from ``dest``'s.

        Args:
            src: id of the node the branch starts from.
            dest: id of the anchor node the branch merges into.
            layer_features: output channels of each conv block; a
                non-positive entry keeps the current channel count. The last
                block is created with ``zero_bn=True``.

        Returns:
            ``dest``.
        """
        in_shape = self.nodes[src].shape
        out_shape = self.nodes[dest].shape
        nf = out_shape[0]

        # `nh` tracks the branch's current channel count.
        nh = in_shape[0]
        node = src

        last = len(layer_features) - 1
        for i, features in enumerate(layer_features):
            if features > 0:
                nh = features
            node = add_conv_block(self, node, nf=nh, ks=3, st=1, zero_bn=(i == last))

        # Match the spatial size with average pooling; the per-dimension
        # kernel is the size ratio passed through ceil_power_of_two.
        if in_shape[1:] != out_shape[1:]:
            shapes = zip(in_shape[1:], out_shape[1:])
            kernel = tuple(ceil_power_of_two(i / o) for i, o in shapes)
            node = add_pooling_layer(self, node, ks=kernel, method='avg')

        # Match the channel count.
        if self.nodes[dest].layer is Concat:
            branch_nf = self.nodes[node].shape[0]
            self.add_edge(IdenticalEdge(node, dest))
            # Every branch channel is new to `dest` and has no source (-1).
            expanded = tuple((-1, x) for x in range(branch_nf))
            self._reconstruct(dest, expanded=expanded, caller=node)
        else:
            # Compare the branch's actual width (which `layer_features` may
            # have changed), not the width of `src`.
            if nh != nf:
                node = add_conv_layer(self, node, nf=nf, ks=1, bias=True, identical=True)
            self.add_edge(IdenticalEdge(node, dest))
        return dest

    def deeper_net(self, layer):
        """Deepen the network after the most recent anchor.

        A new anchor of type ``layer`` is added after the current last anchor,
        then a 3x3 conv -> batch-norm -> ReLU block is inserted right after the
        old anchor. The conv and batch-norm edges are marked with
        ``set_identical()`` (presumably identity initialisation — confirm).

        Args:
            layer: merge layer type for the new anchor (e.g. ``Add``).

        Returns:
            id of the inserted ReLU node.
        """
        last_anchor = self.anchor[-1]
        self.add_anchor(last_anchor, layer=layer)
        # Use the anchor's real channel count instead of assuming 64.
        channels = self.nodes[last_anchor].shape[0]

        # The (0, 0) endpoints are placeholders; presumably insert_node rewires
        # them — confirm.
        edge = ConvEdge(0, 0, in_channels=channels, out_channels=channels,
                        kernel_size=3, bias=False)
        edge.set_identical()
        node = self.insert_node(last_anchor, edge=edge)

        edge = BatchNormEdge(0, 0, channels)
        edge.set_identical()
        # Insert into this graph, not into a notebook-global `gr`.
        node = self.insert_node(node, edge=edge)
        node = self.insert_node(node, edge=ReluEdge(0, 0))
        return node

    def expand_channel(self, node, channel):
        """Widen node ``node`` to ``channel`` channels and propagate the change.

        Each new channel is paired with a randomly chosen existing channel as
        ``(source_index, new_index)``, and the mapping is passed to
        :meth:`_reconstruct`.

        Raises:
            ValueError: if ``channel`` is not larger than the current count.
        """
        shape = self.nodes[node].shape
        nf_prev = shape[0]
        if nf_prev >= channel:
            raise ValueError('Trying expand to lower channels')
        empty = channel - nf_prev

        sources = torch.randint(nf_prev, (empty,))
        expanded = [(src.item(), nf_prev + n) for n, src in enumerate(sources)]

        self.nodes[node].shape = (channel,) + shape[1:]
        self._reconstruct(node, expanded)

In [16]:
# Seed torch's global RNG so the random input batch is reproducible on re-run.
torch.manual_seed(0)
x = torch.rand(1, 3, 32, 32)  # one random 3x32x32 image

In [17]:
# Build the base graph: 3x32x32 images in, 10 class scores out.
gr = AnchorGraph(input_shape=(3, 32, 32), output_shape=(10,))

In [18]:
# Render the freshly built graph (stem conv block -> Add anchor -> flatten -> linear).
# The arguments go straight to Graph.visualize (presumably a name and an output directory — confirm).
gr.visualize('basic', './transform')

In [19]:
# Add a second Add anchor plus a conv -> BN -> ReLU block; shows the new ReLU node id.
gr.deeper_net(layer=Add)

11

In [20]:
# Render the graph again after deeper_net added the second anchor and its conv block.
gr.visualize('deeper', './transform')

In [31]:
# Branch from the first anchor into the second through one 64-channel conv block.
gr.add_connection(src=gr.anchor[0], dest=gr.anchor[1], layer_features=[64])

8

In [32]:
# Smoke test: the generated model maps the (1, 3, 32, 32) input to 10 class scores.
gr.generate_model()(x)

tensor([[ 0.7845,  0.6499,  0.4445,  1.1837, -0.2896, -0.9197, -0.7860,  0.7082,
         -0.3143,  0.4005]], grad_fn=<AddmmBackward>)

In [33]:
# Render the graph after the branch from anchor 0 into anchor 1 was added.
gr.visualize('add_connection', './transform')

# Test: train the generated model on CIFAR

In [41]:
# Fetch the CIFAR dataset with fastai's untar_data and keep its local path.
path_cifar = untar_data(URLs.CIFAR)

In [42]:
# Images from folders: "train" is the training split, "test" is used for
# validation, labels come from each image's parent folder; batch size 128.
data_cifar = (
    ImageList.from_folder(path_cifar)
    .split_by_folder(train="train", valid="test")
    .label_from_folder()
    .databunch(bs=128)
)

In [45]:
# Wrap a freshly generated model in a fastai Learner and show its layer summary.
learn = Learner(
    data_cifar,
    gr.generate_model(),
    loss_func=nn.CrossEntropyLoss(),
    metrics=accuracy,
)
learn.summary()

Generator
Layer (type)         Output Shape         Param #    Trainable 
Conv2d               [64, 16, 16]         1,728      True      
______________________________________________________________________
BatchNorm2d          [64, 16, 16]         128        True      
______________________________________________________________________
ReLU                 [64, 16, 16]         0          False     
______________________________________________________________________
Add                  [64, 16, 16]         0          False     
______________________________________________________________________
Conv2d               [64, 16, 16]         36,864     True      
______________________________________________________________________
BatchNorm2d          [64, 16, 16]         128        True      
______________________________________________________________________
ReLU                 [64, 16, 16]         0          False     
____________________________________________________

In [46]:
# Train for 5 epochs; no learning rate is passed, so fastai's default is used.
learn.fit(5)

epoch,train_loss,valid_loss,accuracy,time
0,0.519591,1.863322,0.5278,00:07
1,0.423097,1.27215,0.6294,00:07
2,0.360294,1.472648,0.6226,00:07
3,0.297328,1.38793,0.6489,00:08
4,0.250599,1.424109,0.6518,00:07
