In [97]:
#export
import torch
from .graph.graph import Graph
from .graph_transformer import *
from .function import ceil_power_of_two, next_channel

ModuleNotFoundError: No module named '__main__.graph'; '__main__' is not a package

In [28]:
import torch
from include.graph.graph import Graph
from include.graph_transformer import *
from include.function import ceil_power_of_two, next_channel
from fastai.vision import *

In [80]:
#export
class AnchorGraph(Graph):
    """Graph organised around "anchor" nodes that can be linked by skip connections.

    An anchor is a multi-input node (created with ``layer=Add`` or
    ``layer=Concat``) registered in ``self.anchor``. The class offers three
    structural transformations: ``add_connection`` (skip connection between two
    nodes), ``deeper_net`` (append a new anchor followed by conv/BN/ReLU) and
    ``wider_net`` / ``expand_channel`` (increase the channel count of a node).

    Attributes:
        anchor: dict mapping an anchor node's ``rank`` to its node id.
        convs: list of node ids returned for every convolution added through
            this class.
    """

    def __init__(self, input_shape, output_shape):
        """Build the initial network.

        Layout: conv block (64 filters, kernel 3, stride 2) -> Add anchor ->
        adaptive average pooling to (1, 1) -> flatten -> linear
        (``category * 64``) -> dropout -> linear (``category``) -> dropout.

        Args:
            input_shape: shape forwarded to ``Graph.__init__``.
            output_shape: shape forwarded to ``Graph.__init__``; its first
                element is used as the number of categories.
        """
        super().__init__(input_shape, output_shape)
        category = output_shape[0]

        self.anchor = {}
        self.convs = []

        node = add_conv_block(self, self.input, nf=64, ks=3, st=2)
        self.convs.append(node)

        node = self.add_anchor(node, layer=Add)
        node = add_adaptive_pooling_layer(self, node, target=(1, 1), method='avg')
        node = add_flatten_layer(self, node)
        node = add_linear_layer(self, node, no=category * 64)
        node = add_dropout_layer(self, node, p=0.5)
        node = add_linear_layer(self, node, no=category)
        # The existing output node is passed as the third positional argument;
        # presumably this makes the final dropout feed it -- confirm against
        # add_dropout_layer.
        self.output = add_dropout_layer(self, node, self.output, p=0.5)

    def _reconstruct(self, target, expanded=None, caller=None):
        """Propagate a shape change at ``target`` through the graph.

        First walks the incoming edges and fixes the shape of every source
        node whose edge no longer verifies, then (for Concat nodes) recomputes
        the node's own channel count, and finally walks the outgoing edges.
        Each shape fix recurses into the node that was changed.

        Args:
            target: id of the node whose shape has just changed.
            expanded: sequence of ``(origin, channel)`` pairs describing the
                newly added channels. A negative origin is kept as ``-1`` when
                re-based. Defaults to an empty list.
            caller: id of the node that triggered this call; used to locate
                the caller's slot among the inputs of a Concat node.
        """
        # A fresh list per call avoids sharing one default object between calls.
        if expanded is None:
            expanded = []

        is_concat = self.nodes[target].layer is Concat
        # -1 is used as the channel placeholder for Concat nodes.
        nf = -1 if is_concat else self.nodes[target].shape[0]
        out_shape = (nf,) + self.nodes[target].shape[1:]

        # Update backwards.
        for edge_id in self.nodes[target].in_edge:
            edge = self.edges[edge_id]
            in_shape = self.nodes[edge.src].shape

            if edge.verify_output(in_shape, out_shape):
                continue

            next_shape, next_expanded = edge.updated_dest(in_shape, out_shape, expanded)
            self.nodes[edge.src].shape = next_shape
            self._reconstruct(edge.src, next_expanded, target)

        # Update self: a Concat node's channel count is the sum of its inputs.
        if is_concat:
            cat_channels = 0
            offset = 0
            new_expanded = []
            for edge_id in self.nodes[target].in_edge:
                edge = self.edges[edge_id]
                num_channels = self.nodes[edge.src].shape[0]
                if offset != 0:
                    # Inputs visited after the caller's slot are shifted by
                    # `offset`: channel i maps back to channel i - offset.
                    start = cat_channels
                    end = cat_channels + num_channels
                    new_expanded.extend((i - offset, i) for i in range(start, end))
                if caller == edge.src:
                    # Re-base the caller's pairs onto this node's channel
                    # numbering; negative origins stay -1.
                    offset = len(expanded)
                    new_expanded.extend(
                        (-1 if o < 0 else o + cat_channels, c + cat_channels)
                        for o, c in expanded
                    )
                cat_channels += num_channels

            self.nodes[target].shape = (cat_channels,) + self.nodes[target].shape[1:]
            expanded = new_expanded

        # Update forwards.
        in_shape = self.nodes[target].shape
        for edge_id in self.nodes[target].out_edge:
            edge = self.edges[edge_id]

            dest_is_concat = self.nodes[edge.dest].layer is Concat
            dest_nf = -1 if dest_is_concat else self.nodes[edge.dest].shape[0]
            dest_shape = (dest_nf,) + self.nodes[edge.dest].shape[1:]
            # Concat destinations are always revisited, even when the edge verifies.
            if edge.verify_output(in_shape, dest_shape) and not dest_is_concat:
                continue

            next_shape, next_expanded = edge.updated_src(in_shape, dest_shape, expanded)
            self.nodes[edge.dest].shape = next_shape
            self._reconstruct(edge.dest, next_expanded, target)

    def add_anchor(self, src, layer):
        """Insert a multi-input node via ``insert_node(src, ...)`` and register it.

        Args:
            src: node id passed to ``insert_node``.
            layer: layer type of the new node (``Add`` or ``Concat`` in this file).

        Returns:
            The id of the new anchor node, also stored in ``self.anchor`` under
            the node's ``rank``.
        """
        anchor = self.insert_node(src, multi_input=True, layer=layer)
        anchor_node = self.nodes[anchor]
        self.anchor[anchor_node.rank] = anchor
        return anchor

    def add_connection(self, src, dest, layer_features=None):
        """Add a skip connection from ``src`` to ``dest``.

        The path consists of one 3x3, stride-1 conv block per entry of
        ``layer_features``, an average pooling layer when the non-channel
        dimensions of ``src`` and ``dest`` differ, and an ``IdenticalEdge``
        into ``dest``. For an ``Add`` destination a 1x1 convolution is added
        first if the channel counts differ; for a ``Concat`` destination the
        graph is reconstructed to account for the extra channels. If ``dest``
        is neither ``Add`` nor ``Concat`` no edge into ``dest`` is created.

        Args:
            src: id of the source node.
            dest: id of the destination node.
            layer_features: iterable of filter counts, one per conv block; a
                non-positive entry reuses the previous channel count.
                Defaults to no conv blocks.

        Returns:
            ``dest``.
        """
        if layer_features is None:
            layer_features = ()

        in_shape = self.nodes[src].shape
        out_shape = self.nodes[dest].shape
        ni, nf = in_shape[0], out_shape[0]

        nh = ni
        node = src

        last = len(layer_features) - 1
        # Generate convolutional layers
        for idx, features in enumerate(layer_features):
            nh = features if features > 0 else nh
            # zero_bn is requested only on the last block of the path.
            node = add_conv_block(self, node, nf=nh, ks=3, st=1,
                                  zero_bn=(idx == last))
            self.convs.append(node)

        # Match the output size
        if in_shape[1:] != out_shape[1:]:
            kernel = tuple(
                ceil_power_of_two(i / o)
                for i, o in zip(in_shape[1:], out_shape[1:])
            )
            node = add_pooling_layer(self, node, ks=kernel, method='avg')

        # Match the output channel
        if self.nodes[dest].layer is Concat:
            nf = self.nodes[node].shape[0]

            self.add_edge(IdenticalEdge(node, dest))
            # Every channel of the new input is marked with origin -1.
            expanded = tuple((-1, x) for x in range(nf))
            self._reconstruct(dest, expanded=expanded, caller=node)
        elif self.nodes[dest].layer is Add:
            if nh != nf:
                node = add_conv_layer(self, node,
                                      nf=nf, ks=1, bias=True)
            self.add_edge(IdenticalEdge(node, dest))
        return dest

    def deeper_net(self, layer):
        """Append a new anchor plus a conv/BN/ReLU chain after the last anchor.

        The last anchor is the one with the highest rank key in
        ``self.anchor``. When it has more than 512 channels an average pooling
        node (kernel size 2) is inserted first. The conv and batch-norm edges
        have ``set_identical()`` called on them before insertion.

        Args:
            layer: layer type of the new anchor.

        Returns:
            The id returned by inserting the ReLU node.

        Raises:
            ValueError: if ``self.anchor`` is empty.
        """
        last_anchor = self.anchor[max(self.anchor)]
        in_shape = self.nodes[last_anchor].shape
        ni = in_shape[0]

        node = last_anchor
        if ni > 512:
            pool = AvgPoolingEdge(0, 0, kernel_size=2)
            pooled_shape = pool.calculate_output(in_shape)
            node = self.insert_node(node, shape=pooled_shape, edge=pool)

        self.add_anchor(node, layer=layer)

        conv = ConvEdge(0, 0, in_channels=ni, out_channels=ni, kernel_size=3, bias=False)
        conv.set_identical()
        node = self.insert_node(node, edge=conv)
        self.convs.append(node)

        batch_norm = BatchNormEdge(0, 0, ni)
        batch_norm.set_identical()
        node = self.insert_node(node, edge=batch_norm)
        node = self.insert_node(node, edge=ReluEdge(0, 0))

        return node

    def wider_net(self, node):
        """Grow ``node`` to ``next_channel(current channel count)`` channels."""
        ni = self.nodes[node].shape[0]
        next_ni = next_channel(ni)
        self.expand_channel(node, next_ni)

    def expand_channel(self, node, channel):
        """Increase the channel count of ``node`` to ``channel``.

        Each new channel is paired with a randomly drawn existing channel
        index, and the pairs are handed to ``_reconstruct`` as ``expanded``.

        Args:
            node: id of the node to widen.
            channel: target channel count; must not be lower than the current one.

        Raises:
            ValueError: if ``channel`` is lower than the current channel count.
        """
        shape = self.nodes[node].shape
        nf_prev = shape[0]
        nf = channel
        if nf_prev > nf:
            raise ValueError('Trying expand to lower channels')
        if nf_prev == nf:
            return
        empty = nf - nf_prev

        # One random existing channel index in [0, nf_prev) per new channel.
        origins = torch.randint(nf_prev, (empty,)).tolist()
        expanded = [(origin, nf_prev + n) for n, origin in enumerate(origins)]

        self.nodes[node].shape = (channel,) + shape[1:]
        self._reconstruct(node, expanded)

    def save(self, dir_name):
        """Save the graph to ``dir_name``, including the anchor and conv bookkeeping."""
        graph_data = {
            'anchor': self.anchor,
            'convs': self.convs
        }
        super().save(dir_name, graph_data=graph_data)

    def load(self, dir_name):
        """Load the graph from ``dir_name`` and restore ``anchor`` and ``convs``.

        The saved node keys are translated to current node ids through the
        ``key_to_id`` mapping returned by ``Graph.load``.

        Returns:
            The ``(graph_data, weight_data, key_to_id)`` tuple from ``Graph.load``.
        """
        graph_data, weight_data, key_to_id = super().load(dir_name)
        anchors = graph_data['anchor']
        convs = graph_data['convs']

        self.anchor = {rank: key_to_id[key] for rank, key in anchors.items()}
        self.convs = [key_to_id[key] for key in convs]
        return graph_data, weight_data, key_to_id

In [100]:
!python nb2py.py anchored_graph.ipynb

Converted anchored_graph.ipynb to exp/nb_anchored_graph.py


In [99]:
# Fresh graph for input_shape (3, 32, 32) and output_shape (10,), i.e. category = 10.
gr = AnchorGraph((3, 32, 32), (10,))

In [82]:
# Append a Concat anchor after the last anchor; the cell output is the id returned for the inserted ReLU node.
gr.deeper_net(Concat)

16

In [83]:
# Snapshot after deepening. visualize() comes from the Graph base class;
# presumably the arguments are an output name and a directory -- confirm.
gr.visualize('deeper', 'transform/')

In [84]:
# Skip connection from the rank-1 anchor to the rank-2 anchor through three
# conv blocks (64, 128, 256 filters); add_connection returns the destination id.
gr.add_connection(gr.anchor[1], gr.anchor[2], layer_features=[64, 128, 256])

target13
target8
target9
target10


13

In [85]:
# Snapshot after adding the skip connection.
gr.visualize('skip', 'transform/')

In [86]:
# Node ids recorded for every convolution added so far.
gr.convs

[6, 14, 20, 24, 28]

In [87]:
# Widen node 28 (the last id listed in gr.convs above) to next_channel(current channels).
gr.wider_net(28)

target28
target27
target26
target25
target24
target13
target8
target9
target10


In [88]:
# Two further widening steps on the same node; each call advances to the
# channel count given by next_channel.
gr.wider_net(28)
gr.wider_net(28)

target28
target27
target26
target25
target24
target13
target8
target9
target10


In [89]:
# Snapshot after the three widening steps.
gr.visualize('wider', 'transform/')

In [90]:
# Deepen once more, now starting from the highest-rank anchor created above.
gr.deeper_net(Concat)

33

In [91]:
# Snapshot after the second deepening.
# NOTE(review): the other cells pass 'transform/' while this one passes
# './transform' -- confirm both resolve to the same output directory.
gr.visualize('deeper_again', './transform')

In [92]:
# Persist the graph together with its anchor/convs bookkeeping under 'temp'.
gr.save('temp')

In [93]:
# Round trip: load the saved graph into a freshly constructed instance.
# load() returns (graph_data, weight_data, key_to_id), which is discarded here.
gr2 = AnchorGraph((3, 32, 32), (10,))
_ = gr2.load('temp')

In [95]:
# Conv node ids as restored by load(), remapped through key_to_id.
gr2.convs

[6, 14, 20, 24, 28, 31]

In [79]:
# Download/extract CIFAR, then build the data bunch: the split follows the
# "train" and "test" folders, labels come from the folder names, batch size 128.
path_cifar = untar_data(URLs.CIFAR)
data_cifar = (
    ImageList.from_folder(path_cifar)
    .split_by_folder(train="train", valid="test")
    .label_from_folder()
    .databunch(bs=128)
)

In [18]:
# Train the model generated from the reloaded graph (gr2), using cross-entropy
# loss and reporting accuracy.
learn = Learner(data_cifar, gr2.generate_model(), loss_func=nn.CrossEntropyLoss(), metrics=accuracy)

In [19]:
# Five epochs; per-epoch losses and accuracy are in the table below.
learn.fit(5)

epoch,train_loss,valid_loss,accuracy,time
0,1.582388,1.486299,0.4775,00:10
1,1.340784,1.526956,0.4765,00:11
2,1.249703,1.439815,0.507,00:11
3,1.150209,1.164656,0.6058,00:11
4,1.086445,1.029783,0.6441,00:11


In [40]:
# Scratch dict with integer keys, used below to check how sorted() treats a dict.
a = {1: 'a', 3: 'b'}

In [46]:
# Scratch check: sorted() on a dict iterates its keys, so this yields the keys in ascending order.
sorted(a)

[1, 3]