# PyGreentea Network Generator 

### Load the dependencies

In [1]:
%matplotlib inline

from __future__ import print_function
import h5py
import numpy as np
from numpy import float32, int32, uint8, dtype
import sys
import matplotlib.pyplot as plt
import copy


# Relative location of the PyGreentea package that is added to the import path.
pygt_path = '../PyGreentea'
import sys, os
# The appended entry is <parent of the working directory>/../PyGreentea (not
# normalized), i.e. a PyGreentea directory inside the grandparent of the
# working directory.
# NOTE(review): confirm this matches the intended checkout layout.
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), pygt_path))

import math

import PyGreentea as pygt

  from ._caffe import \
  from ._caffe import \
  from ._caffe import \


### Set the memory limits for the GPU

In [2]:
# Start from the default network configuration template
netconf = pygt.netgen.NetConf()

# We use cuDNN, so the convolution buffer is ignored
netconf.ignore_conv_buffer = True

# Memory budget in bytes: 3.5 GiB, i.e. 4 GiB in total minus 0.5 GiB that is
# kept free for implementation dependent buffers.
GPU_MEM_BUDGET_BYTES = 3.5 * 1024 ** 3

# Global memory limit
netconf.mem_global_limit = GPU_MEM_BUDGET_BYTES
# Convolution buffer limit (set to the same 3.5 GiB value)
netconf.mem_buf_limit = GPU_MEM_BUDGET_BYTES

### Set model parameters

#### 3d isotropic block-face network

In [None]:
# Global options: no batch normalization, no dropout, 24 starting feature maps
netconf.dropout = 0.0
netconf.use_batchnorm = False
netconf.fmap_start = 24

# Feature map rules of the first U-Net: the increase rule multiplies by 3,
# the decrease rule divides by 3; both pass the result through ceil() and int()
first_unet = netconf.u_netconfs[0]
first_unet.unet_fmap_inc_rule = lambda fmaps: int(math.ceil(fmaps * 3))
first_unet.unet_fmap_dec_rule = lambda fmaps: int(math.ceil(fmaps / 3))

In [None]:
# Global options: no batch normalization, no dropout
netconf.use_batchnorm = False
netconf.dropout = 0.0

# U-Net specific settings (feature map rules, depth, ...) are applied to
# netconf.u_netconfs[0], matching the other configuration cells of this notebook.
# NOTE(review): assigning them on netconf itself presumably does not reach the
# U-Net configuration object — confirm against pygt.netgen.

# Alternative setup, kept for reference:
# netconf.fmap_start = 20
# netconf.u_netconfs[0].unet_fmap_inc_rule = lambda fmaps: int(math.ceil(fmaps * 1))
# netconf.u_netconfs[0].unet_fmap_dec_rule = lambda fmaps: int(math.ceil(fmaps / 1))
# netconf.u_netconfs[0].unet_depth = 3
# # netconf.u_netconfs[0].unet_downsampling_strategy = [[3,3,3],[3,3,3],[1,1,1]]

# Active setup: 24 starting feature maps, scaled by a factor of 4 per level
netconf.fmap_start = 24
netconf.u_netconfs[0].unet_fmap_inc_rule = lambda fmaps: int(math.ceil(fmaps * 4))
netconf.u_netconfs[0].unet_fmap_dec_rule = lambda fmaps: int(math.ceil(fmaps / 4))
# netconf.u_netconfs[0].unet_depth = 4
# netconf.u_netconfs[0].unet_downsampling_strategy = [[2,2,2],[2,2,2],[1,1,1],[1,1,1]]

#### W net

In [3]:
# Global options: no batch normalization, no dropout, 16 starting feature maps
netconf.dropout = 0.0
netconf.use_batchnorm = False
netconf.fmap_start = 16

# First U-Net: depth 3, feature maps scaled by a factor of 4 per level
unet_a = netconf.u_netconfs[0]
unet_a.unet_fmap_inc_rule = lambda fmaps: int(math.ceil(fmaps * 4))
unet_a.unet_fmap_dec_rule = lambda fmaps: int(math.ceil(fmaps / 4))
unet_a.unet_depth = 3
# Deconvolution up-path switched on (the alternative value is False).
# Original note: create a network with no context loss (at least for training)
unet_a.use_deconvolution_uppath = True

# W-Net (two U-Nets concatenated): append a deep copy of the first U-Net
unet_b = copy.deepcopy(unet_a)
netconf.u_netconfs += [unet_b]

# Run a shortcut (deep residual, additive) over the first and the second U-Net
for unet in netconf.u_netconfs[:2]:
    unet.bridge = True

#### 3d anisotropic serial section network

In [None]:
# NOTE(review): this creates a fresh configuration object, so the
# ignore_conv_buffer / mem_global_limit / mem_buf_limit values assigned to the
# previous netconf object in the memory-limit cell are not carried over —
# confirm this is intended.
netconf = pygt.netgen.NetConf()
netconf.use_batchnorm = False
netconf.dropout = 0.0

# U-Net specific settings (feature map rules, depth, downsampling strategy) are
# applied to netconf.u_netconfs[0], matching the other configuration cells of
# this notebook.
# NOTE(review): assigning them on netconf itself presumably does not reach the
# U-Net configuration object; also confirm that unet_downsampling_strategy is an
# attribute of the per-U-Net configuration in pygt.netgen.

# Alternative setup, kept for reference:
# netconf.fmap_start = 20
# netconf.u_netconfs[0].unet_fmap_inc_rule = lambda fmaps: int(math.ceil(fmaps * 2))
# netconf.u_netconfs[0].unet_fmap_dec_rule = lambda fmaps: int(math.ceil(fmaps / 2))
# netconf.u_netconfs[0].unet_depth = 4
# netconf.u_netconfs[0].unet_downsampling_strategy = [[1,1,1],[1,2,2],[1,2,2],[1,2,2]]

# Active setup: depth 3, factor 3 per level, no downsampling along the first axis
netconf.fmap_start = 20
netconf.u_netconfs[0].unet_fmap_inc_rule = lambda fmaps: int(math.ceil(fmaps * 3))
netconf.u_netconfs[0].unet_fmap_dec_rule = lambda fmaps: int(math.ceil(fmaps / 3))
netconf.u_netconfs[0].unet_depth = 3
netconf.u_netconfs[0].unet_downsampling_strategy = [[1,2,2],[1,2,2],[1,2,2]]

# Alternative setup, kept for reference:
# netconf.fmap_start = 20
# netconf.u_netconfs[0].unet_fmap_inc_rule = lambda fmaps: int(math.ceil(fmaps * 1))
# netconf.u_netconfs[0].unet_fmap_dec_rule = lambda fmaps: int(math.ceil(fmaps / 1))
# netconf.u_netconfs[0].unet_depth = 4
# netconf.u_netconfs[0].unet_downsampling_strategy = [[2,2,2],[2,2,2],[1,1,1],[1,1,1]]

### Explore possible network input/output shapes for the chosen settings

In [4]:
# Memory usage is evaluated for the training phase
mode = pygt.netgen.caffe_pb2.TRAIN

# Lower and upper bound of the shape search, one value per axis
# (a lower bound of [100,100,100] was used as an alternative)
shape_min = [50, 100, 100]
shape_max = [300, 300, 300]

# One entry per axis: the first two are unconstrained, the last one is tied to
# the value at index 1 (original note: Z and Y independent, but X == Y)
constraints = [None, None, lambda x: x[1]]

# Enumerate the valid input/output shapes (can be quite intensive)
inshape, outshape, fmaps = pygt.netgen.compute_valid_io_shapes(
    netconf,
    mode,
    shape_min,
    shape_max,
    constraints=constraints
)

-- Invalid: [50] => []
-- Invalid: [51] => []
++++ Valid: [52] => [52]
-- Invalid: [53] => []
-- Invalid: [54] => []
-- Invalid: [55] => []
-- Invalid: [56] => []
-- Invalid: [57] => []
-- Invalid: [58] => []
-- Invalid: [59] => []
++++ Valid: [60] => [60]
-- Invalid: [61] => []
-- Invalid: [62] => []
-- Invalid: [63] => []
-- Invalid: [64] => []
-- Invalid: [65] => []
-- Invalid: [66] => []
-- Invalid: [67] => []
++++ Valid: [68] => [68]
-- Invalid: [69] => []
-- Invalid: [70] => []
-- Invalid: [71] => []
-- Invalid: [72] => []
-- Invalid: [73] => []
-- Invalid: [74] => []
-- Invalid: [75] => []
++++ Valid: [76] => [76]
-- Invalid: [77] => []
-- Invalid: [78] => []
-- Invalid: [79] => []
-- Invalid: [80] => []
-- Invalid: [81] => []
-- Invalid: [82] => []
-- Invalid: [83] => []
++++ Valid: [84] => [84]
-- Invalid: [85] => []
-- Invalid: [86] => []
-- Invalid: [87] => []
-- Invalid: [88] => []
-- Invalid: [89] => []
-- Invalid: [90] => []
-- Invalid: [91] => []
++++ Valid: [92] => [92]

In [5]:
# Print every (input shape, output shape, feature maps) candidate together
# with its position in the result lists
for i, o in enumerate(zip(inshape, outshape, fmaps)):
    print('i=' + str(i))
    print(o)

i=0
([52, 100, 100], [52, 100, 100], 24)
i=1
([60, 100, 100], [60, 100, 100], 21)
i=2
([68, 100, 100], [68, 100, 100], 19)
i=3
([76, 100, 100], [76, 100, 100], 17)
i=4
([84, 100, 100], [84, 100, 100], 16)
i=5
([92, 100, 100], [92, 100, 100], 14)
i=6
([100, 100, 100], [100, 100, 100], 13)
i=7
([108, 100, 100], [108, 100, 100], 12)
i=8
([116, 100, 100], [116, 100, 100], 11)
i=9
([124, 100, 100], [124, 100, 100], 11)
i=10
([132, 100, 100], [132, 100, 100], 10)
i=11
([140, 100, 100], [140, 100, 100], 9)
i=12
([148, 100, 100], [148, 100, 100], 9)
i=13
([156, 100, 100], [156, 100, 100], 8)
i=14
([164, 100, 100], [164, 100, 100], 8)
i=15
([172, 100, 100], [172, 100, 100], 8)
i=16
([180, 100, 100], [180, 100, 100], 7)
i=17
([188, 100, 100], [188, 100, 100], 7)
i=18
([196, 100, 100], [196, 100, 100], 7)
i=19
([204, 100, 100], [204, 100, 100], 6)
i=20
([212, 100, 100], [212, 100, 100], 6)
i=21
([220, 100, 100], [220, 100, 100], 6)
i=22
([228, 100, 100], [228, 100, 100], 6)
i=23
([236, 100, 100],

### Visualization

In [None]:
# Product of the first three output dimensions for every candidate shape
combined_output_sizes = [dims[0] * dims[1] * dims[2] for dims in outshape]

# Scatter plot: combined output size (x) versus feature map count (y)
plt.figure()
plt.scatter(combined_output_sizes, fmaps, alpha=0.5)
plt.xlabel('Combined output size')
plt.ylabel('Feature maps')
plt.show()

### Pick parameters, actually generate and store the network

In [None]:
# Position of the chosen candidate in the inshape / outshape / fmaps lists.
# NOTE(review): hard-coded; it has to be a valid index into the lists produced
# by the shape search — re-check after changing the search settings.
i = 604
netconf.input_shape = inshape[i]
netconf.output_shape = outshape[i]
netconf.fmap_start = fmaps[i]

# Echo the selected parameters
print('netconf.input_shape = %s' % netconf.input_shape)
print('netconf.output_shape = %s' % netconf.output_shape)
print('netconf.fmap_start = %s' % netconf.fmap_start)

# Generate the networks once per loss function. test_net_conf is assigned by
# both calls, so only the one returned for "malis" is kept.
netconf.loss_function = "euclid"
train_net_conf_euclid, test_net_conf = pygt.netgen.create_nets(netconf)
netconf.loss_function = "malis"
train_net_conf_malis, test_net_conf = pygt.netgen.create_nets(netconf)

# Store each network definition in its own prototxt file
for prototxt_path, net_conf in (
        ('net_train_euclid.prototxt', train_net_conf_euclid),
        ('net_train_malis.prototxt', train_net_conf_malis),
        ('net_test.prototxt', test_net_conf)):
    with open(prototxt_path, 'w') as f:
        print(net_conf, file=f)

In [None]:
netconf.