<img src="fcn_model.png">

Reference 1 : https://github.com/shekkizh/FCN.tensorflow/blob/master/FCN.py
<br>Reference 2 : https://modulabs-biomedical.github.io/FCN
<br>Reference 3 : https://stackoverflow.com/questions/38160940/how-to-count-total-number-of-trainable-parameters-in-a-tensorflow-model

In [1]:
import tensorflow as tf
import numpy as np
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets

  from ._conv import register_converters as _register_converters


In [2]:
# image size

height = 224
width = 224
num_of_channels = 3
num_of_classes = 21

keep_prob = 0.5

In [3]:
image = tf.placeholder(tf.float32, [None, height, width, num_of_channels])

In [4]:
def fcn8_with_pretrained_vgg16(image, num_of_classes, keep_prob):
    "Extracting Layer from Pre-trained VGG16"
    print("------------------- Conv Layer -------------------")
    
    vgg16 = nets.vgg.vgg_16(image, num_classes = num_of_classes, dropout_keep_prob = keep_prob)
    
    pool3 = tf.get_default_graph().get_tensor_by_name('vgg_16/pool3/MaxPool:0')
    print(pool3.name + " / shape = " + str(pool3.shape))
    
    pool4 = tf.get_default_graph().get_tensor_by_name('vgg_16/pool4/MaxPool:0')
    print(pool4.name + " / shape = " + str(pool4.shape))
    
    pool5 = tf.get_default_graph().get_tensor_by_name('vgg_16/pool5/MaxPool:0')
    print(pool5.name + " / shape = " + str(pool5.shape))
    
    
    "Feature-level Classificaiton"
    # output shape = [None, H/32, H/32, 4096]
    conv6 = slim.conv2d(pool5, 4096, [7, 7], scope = 'conv6')
    conv6 = tf.nn.dropout(conv6, keep_prob = keep_prob)
    print(conv6.name + " / shape = " + str(conv6.shape))
    
    # output shape = [None, H/32, H/32, 4096]
    conv7 = slim.conv2d(conv6, 4096, [1, 1], scope = 'conv7')
    conv7 = tf.nn.dropout(conv7, keep_prob = keep_prob)
    print(conv7.name + " / shape = " + str(conv7.shape))
    
    # output shape = [None, H/32, H/32, num_of_classes]
    conv8 = slim.conv2d(conv7, num_of_classes, [1, 1], scope = 'conv8')
    print(conv8.name + " / shape = " + str(conv8.shape))
    
    
    "Upsampling"
    print("------------------- Upsampling -------------------")
    conv_t1 = slim.conv2d_transpose(conv8, num_outputs = pool4.get_shape()[3], kernel_size = [4, 4], stride = 2)
    fuse_1 = tf.add(conv_t1, pool4, name = 'fuse_1')
    print(fuse_1.name + " / shape = " + str(fuse_1.shape))
    
    conv_t2 = slim.conv2d_transpose(fuse_1, num_outputs = pool3.get_shape()[3], kernel_size = 4, stride = 2)
    fuse_2 = tf.add(conv_t2, pool3, name = 'fuse_2')
    print(fuse_2.name + " / shape = " + str(fuse_2.shape))
    
    conv_t3 = slim.conv2d_transpose(fuse_2, num_outputs = num_of_channels, kernel_size = 16, stride = 8)
    print(conv_t3.name + " / shape = " + str(conv_t3.shape))
    
    "Segmentation"
    print("------------------ Segmentation ------------------")
    annotation_pred = tf.argmax(conv_t3, dimension = num_of_channels, name="prediction")
    print(annotation_pred.name + " / shape = " + str(annotation_pred.shape))
    fcn8 = tf.expand_dims(annotation_pred, dim = num_of_channels)
    print(fcn8.name + " / shape = " + str(fcn8.shape))
    
    return fcn8

In [5]:
fcn8_with_pretrained_vgg16 = fcn8_with_pretrained_vgg16(image, num_of_classes, keep_prob)

------------------- Conv Layer -------------------
vgg_16/pool3/MaxPool:0 / shape = (?, 28, 28, 256)
vgg_16/pool4/MaxPool:0 / shape = (?, 14, 14, 512)
vgg_16/pool5/MaxPool:0 / shape = (?, 7, 7, 512)
dropout/mul:0 / shape = (?, 7, 7, 4096)
dropout_1/mul:0 / shape = (?, 7, 7, 4096)
conv8/Relu:0 / shape = (?, 7, 7, 21)
------------------- Upsampling -------------------
fuse_1:0 / shape = (?, 14, 14, 512)
fuse_2:0 / shape = (?, 28, 28, 256)
Conv2d_transpose_2/Relu:0 / shape = (?, 224, 224, 3)
------------------ Segmentation ------------------
prediction:0 / shape = (?, 224, 224)
ExpandDims:0 / shape = (?, 224, 224, 1)


In [6]:
total_parameters = np.sum([np.prod(var.get_shape().as_list()) for var in tf.trainable_variables()])
print("Number of Weights : " + format(total_parameters, ','))

Number of Weights : 256,445,037
