In [1]:
from fastai.vision import *
from include.graph.graph import *
from include.graph_transformer import *

## Vanilla CNN

In [2]:
# Download (if needed) and extract the MNIST image folders; returns the extracted path.
path = untar_data(url=URLs.MNIST)

In [3]:
# Load images as single-channel ("L" mode), split on the `training` / `testing`
# folders, label each image by its parent folder name, and batch 128 at a time.
data = (
    ImageList.from_folder(path, convert_mode="L")
    .split_by_folder(train="training", valid="testing")
    .label_from_folder()
    .databunch(bs=128)
)

In [4]:
# Graph for 1-channel 28x28 inputs (grayscale MNIST, loaded with convert_mode="L") and 10 outputs.
gr = Graph(input_shape=(1, 28, 28), output_shape=(10,))

In [5]:
# First conv block fed from the graph input: 16 filters, kernel size 3.
# learn.summary() shows it as Conv2d -> BatchNorm2d -> ReLU, keeping the 28x28 spatial size.
# NOTE(review): `next` shadows the Python builtin `next()`; consider renaming it
# (e.g. `node`) consistently across this section's cells.
next = add_conv_block(gr, gr.input, nf=16, ks=3)

In [6]:
# Second conv block chained from the previous node: 32 filters, kernel size 3.
next = add_conv_block(gr, next, nf=32, ks=3)

In [7]:
# Flatten the conv feature maps for the linear layers
# (learn.summary() reports 25088 = 32*28*28 features after this layer).
next = add_flatten_layer(gr, next)

In [8]:
# Hidden linear layer with 4*7*7 = 196 output units.
# NOTE(review): the flattened input here is 32*28*28 = 25088 (no pooling in the conv
# blocks), so the 4*7*7 factorisation does not mirror any feature-map size in this
# model — confirm this width is intentional and not copied from a downsampling design.
next = add_linear_layer(gr, next, no=4*7*7)

In [9]:
# Final linear layer wired to the graph output; no=10 matches output_shape=(10,) declared for `gr`.
next = add_linear_layer(gr, next, gr.output, no=10)

In [10]:
# Render the graph under the name 'cnn' in the current directory
# (presumably writes a diagram file there — confirm against Graph.visualize).
gr.visualize('cnn', './')

In [11]:
# Build the trainable model from the graph; learn.summary() reports its class as `Generator`.
module = gr.generate_model()

In [12]:
# Wrap the generated model in a fastai Learner: cross-entropy loss over the 10 outputs,
# tracking accuracy. fastai listifies `metrics`, so a one-element list is equivalent.
learn = Learner(
    data,
    module,
    loss_func=nn.CrossEntropyLoss(),
    metrics=[accuracy],
)

In [13]:
# Layer-by-layer output shapes, parameter counts and trainability of the MNIST model.
learn.summary()

Generator
Layer (type)         Output Shape         Param #    Trainable 
Conv2d               [16, 28, 28]         144        True      
______________________________________________________________________
BatchNorm2d          [16, 28, 28]         32         True      
______________________________________________________________________
ReLU                 [16, 28, 28]         0          False     
______________________________________________________________________
Conv2d               [32, 28, 28]         4,608      True      
______________________________________________________________________
BatchNorm2d          [32, 28, 28]         64         True      
______________________________________________________________________
ReLU                 [32, 28, 28]         0          False     
______________________________________________________________________
Flatten              [25088]              0          False     
____________________________________________________

## ResNet CNN

In [14]:
# Download (if needed) and extract the CIFAR image folders; returns the extracted path.
path_cifar = untar_data(url=URLs.CIFAR)

In [15]:
# Load the CIFAR images, split on the `train` / `test` folders, label by parent
# folder name, and batch 128 at a time.
data_cifar = (
    ImageList.from_folder(path_cifar)
    .split_by_folder(train="train", valid="test")
    .label_from_folder()
    .databunch(bs=128)
)

In [16]:
# Graph for 3-channel 32x32 CIFAR images and 10 outputs (keywords as in the MNIST graph).
gr_cifar = Graph(input_shape=(3, 32, 32), output_shape=(10,))

In [17]:
# Populate the graph with a ResNet between its input and output nodes.
# NOTE(review): the positional `1` and `[2, 2, 2, 2]` are interpreted by add_res_net
# (not shown); the list presumably gives residual blocks per stage, ResNet-18 style —
# confirm, and consider passing both by keyword for readability.
# The trailing `;` stops IPython from echoing the call's bare return value as cell output.
add_res_net(gr_cifar, gr_cifar.input, gr_cifar.output, 1, [2, 2, 2, 2]);

2

In [18]:
# Render the ResNet graph under the name 'res' in the current directory
# (presumably writes a diagram file there — confirm against Graph.visualize).
gr_cifar.visualize('res', './')

In [19]:
# Build the trainable ResNet model from the graph (rebinds `module` from the MNIST section).
module = gr_cifar.generate_model()

In [20]:
# Learner for the ResNet on CIFAR: cross-entropy loss, accuracy metric.
# fastai listifies `metrics`, so a one-element list is equivalent to passing `accuracy`.
learn = Learner(
    data_cifar,
    module,
    loss_func=nn.CrossEntropyLoss(),
    metrics=[accuracy],
)

In [21]:
# Layer-by-layer output shapes, parameter counts and trainability of the ResNet model.
learn.summary()

Generator
Layer (type)         Output Shape         Param #    Trainable 
Conv2d               [64, 16, 16]         9,408      True      
______________________________________________________________________
BatchNorm2d          [64, 16, 16]         128        True      
______________________________________________________________________
ReLU                 [64, 16, 16]         0          False     
______________________________________________________________________
MaxPool2d            [64, 8, 8]           0          False     
______________________________________________________________________
Conv2d               [64, 8, 8]           36,864     True      
______________________________________________________________________
BatchNorm2d          [64, 8, 8]           128        True      
______________________________________________________________________
ReLU                 [64, 8, 8]           0          False     
____________________________________________________

## DenseNet CNN

In [22]:
# Same CIFAR location as in the ResNet section (this cell repeats that setup).
path_cifar = untar_data(url=URLs.CIFAR)

In [23]:
# Rebuilds the same CIFAR DataBunch as the ResNet section: `train` / `test` folder
# split, labels from parent folder names, batch size 128.
data_cifar = (
    ImageList.from_folder(path_cifar)
    .split_by_folder(train="train", valid="test")
    .label_from_folder()
    .databunch(bs=128)
)

In [24]:
# Fresh graph for the DenseNet: 3-channel 32x32 inputs, 10 outputs.
# NOTE(review): reusing the name `gr_cifar` discards the ResNet graph built earlier;
# a distinct name (e.g. `gr_dense`) would keep both inspectable.
gr_cifar = Graph(input_shape=(3, 32, 32), output_shape=(10,))

In [25]:
# Populate the graph with a DenseNet between its input and output nodes,
# using add_dense_net's default configuration.
# The trailing `;` stops IPython from echoing the call's bare return value as cell output.
add_dense_net(gr_cifar, gr_cifar.input, gr_cifar.output);

2

In [26]:
# Render the DenseNet graph under the name 'dense' in the current directory.
# NOTE(review): in the saved run this cell was stopped with KeyboardInterrupt, so the
# rendering never finished; it may be very slow for a graph this size — consider
# skipping it or guarding it behind a flag for Restart & Run All.
gr_cifar.visualize('dense', './')

KeyboardInterrupt: 

In [27]:
# Build the trainable DenseNet model from the graph (rebinds `module` from the ResNet section).
module = gr_cifar.generate_model()

In [28]:
# Learner for the DenseNet on CIFAR: cross-entropy loss, accuracy metric.
# fastai listifies `metrics`, so a one-element list is equivalent to passing `accuracy`.
learn = Learner(
    data_cifar,
    module,
    loss_func=nn.CrossEntropyLoss(),
    metrics=[accuracy],
)

In [29]:
# Layer-by-layer output shapes, parameter counts and trainability of the DenseNet model.
learn.summary()

Generator
Layer (type)         Output Shape         Param #    Trainable 
Conv2d               [64, 16, 16]         9,408      True      
______________________________________________________________________
MaxPool2d            [64, 8, 8]           0          False     
______________________________________________________________________
BatchNorm2d          [64, 8, 8]           128        True      
______________________________________________________________________
ReLU                 [64, 8, 8]           0          False     
______________________________________________________________________
Conv2d               [128, 8, 8]          8,192      True      
______________________________________________________________________
BatchNorm2d          [128, 8, 8]          256        True      
______________________________________________________________________
ReLU                 [128, 8, 8]          0          False     
____________________________________________________