### Net only

In [1]:
import torch
import torch.nn as nn

In [2]:
'''
Segnet network, default VGG16 encoder - VGG16 decoder

Encoder given to front_layer
Decoder given to back_layer

upsampling default unmax-pooling
may also try deconv, upsampling (preserved)


'''

# size of the max pooling need to be optimize when can not divided by 2


class SegNet(nn.Module):
    def __init__(self, front_layer = [(64,),(64,),'M',(128,),(128,),'M',(256,),(256,),(256,),\
                                      'M', (512,),(512,),(512,),'M',(512,),(512,),(512,),'M'],
                 back_layer = [('I', 512),'UM',(512,), (512,), (512,), 'UM', (512,), (512,), (256,), 'UM',\
                              (256,), (256,),(128,), 'UM', (128,),(64,),'UM',(64,), (64,) ],
         class_num = 20,
         use_BN = True,
         upsampling = 'UM' # 'Deconv', 'USample'
         ):
        super(SegNet, self).__init__()
        
        self.class_num = class_num
        self.use_BN = use_BN
        
        self.front_process = self.make_cnn_layers(front_layer, batch_norm= self.use_BN)
        self.back_process = self.make_cnn_layers(back_layer, batch_norm = self.use_BN)
        self.last_process = self.make_cnn_layers([('I',back_layer[-1][0]), (class_num,)])
        
    def forward(self, x):
        idx_lst = []
        for layers in self.front_process:
            if isinstance(layers, nn.MaxPool2d):
                x,idx = layers(x)
                idx_lst.append(idx)
                print(x.size())
            else:
                x = layers(x)
            
            
        print('\n\nmiddle \n\n')
                
        for layers in self.back_process:
            if isinstance(layers, nn.MaxUnpool2d):
                x = layers(x, idx_lst.pop(-1))
                print(x.size())
            else:
                x = layers(x)
            
        x = self.last_process(x)
        print(x.size())
        

    def make_cnn_layers(self, cfg, batch_norm=True):

        '''
        The input should be a list, the elements should be tuples. 
        based on pytorch docs:
        Parameters:
            in_channels (int) – Number of channels in the input image
            out_channels (int) – Number of channels produced by the convolution
            kernel_size (int or tuple) – Size of the convolving kernel
            stride (int or tuple, optional) – Stride of the convolution. Default: 1
            padding (int or tuple, optional) – Zero-padding added to both sides of the input. Default: 1
            dilation (int or tuple, optional) – Spacing between kernel elements. Default: 1
            groups (int, optional) – Number of blocked connections from input channels to output channels. Default: 1
            bias (bool, optional) – If True, adds a learnable bias to the output. Default: True

        For convolution layers, the order is:
        (out_channels (int), kernel_size (int or tuple, optional, default = 3), 
        stride (int or tuple, optional), padding (int or tuple, optional), dilation (int or tuple, optional))
        if input is less than 0 (ie. -1) then will use default value.

        For maxpooling layer:
        'M', if need more argument, modify as needed.

        5/3/2018 C.

        '''
        
        layers = []
        in_channels = 3
        for v in cfg:

            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2,return_indices=True)]
            
            elif v == 'UM':
                layers += [nn.MaxUnpool2d(kernel_size=2, stride=2)]
                
            elif v == 'Deconv':
                raise Exception('Deconvolution network not written yet')
                
            elif v == 'Usample':
                raise Exception('Upsampling network not written yet')
            
            elif v[0] == 'I':
                in_channels = v[1]

            else:
                v_len = len(v)
                ker_size = 3
                stride_val = 1
                padding_val = 1
                dialtion_val = 1

                out_channels = v[0]
                if v_len >= 2:
                    if v[1] > 0:
                        ker_size = v[1]
                    if v_len >= 3:
                        if v[2] > 0:
                            stride_val = v[2]
                        if v_len >= 4:
                            if v[3] > 0:
                                padding_val = v[3]
                            if v_len >= 5:
                                if v[4] > 0:
                                    dialtion_val = v[4]


                conv2d = nn.Conv2d(in_channels, out_channels, kernel_size = ker_size, stride = stride_val,
                                   padding = padding_val, dilation = dialtion_val)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v[0]), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v[0]
        return nn.Sequential(*layers)


In [3]:
torch.cuda.empty_cache()
test_net  = SegNet().cuda()
input_test = torch.zeros((1,3,800,800)).cuda()
print(input_test.size())
out_test = test_net(input_test)

torch.Size([1, 3, 800, 800])
torch.Size([1, 64, 400, 400])
1
torch.Size([1, 128, 200, 200])
2
torch.Size([1, 256, 100, 100])
3
torch.Size([1, 512, 50, 50])
4
torch.Size([1, 512, 25, 25])
5


middle 


torch.Size([1, 512, 50, 50])
4
torch.Size([1, 512, 100, 100])
3
torch.Size([1, 256, 200, 200])
2
torch.Size([1, 128, 400, 400])
1
torch.Size([1, 64, 800, 800])
0
torch.Size([1, 20, 800, 800])
