# UNET SEGMENTATION

Arxiv Link: <a href="https://arxiv.org/abs/1505.04597">U-Net: Convolutional Networks for Biomedical Image Segmentation</a>

<ul>
<li>UNet is a fully convolutional network (FCN) for image segmentation. Its goal is to predict a class for every pixel of the input image.</li>

<li>UNet builds on the FCN architecture and modifies it to yield better segmentation on medical images.</li>
</ul>

## 1.1 Architecture

<img src="images/u-net-architecture.png"/>

<h3>UNet Architecture has 3 parts:</h3>
<ol>
    <li>The Contracting/Downsampling Path</li>
    <li>Bottleneck</li>
    <li>The Expanding/Upsampling Path</li>
</ol>

<h3>Downsampling Path: </h3> 
<ol>
    <li>Each step consists of two 3x3 convolutions (unpadded), each followed by a rectified linear unit (ReLU), and then a 2x2 max pooling operation with stride 2 for downsampling.</li>
    <li>At each downsampling step we double the number of feature channels.</li>
</ol>
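
The effect of these two rules on tensor shapes can be traced with a few lines of plain Python. This is only a sketch: the 572-pixel input and the 64 starting channels are taken from the figure in the U-Net paper, and the function name `contracting_path` is made up for illustration.

```python
def contracting_path(size, channels=64, steps=4):
    """Trace (spatial size, channels) after the two convolutions of each step."""
    trace = []
    for _ in range(steps):
        size -= 4          # two unpadded 3x3 convolutions remove 2 pixels each
        trace.append((size, channels))
        size //= 2         # 2x2 max pooling with stride 2 halves the size
        channels *= 2      # the next step doubles the feature channels
    return trace


for size, channels in contracting_path(572):
    print(size, "x", size, "with", channels, "channels")
```

Each tuple is the feature map that is later handed to the upsampling path through a skip connection.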

<h3>Upsampling Path: </h3> 
<ol>
     <li>Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2 convolution (“up-convolution”) that halves the number of feature channels, a concatenation with the correspondingly cropped feature map from the downsampling path, and two 3x3 convolutions, each followed by a ReLU.</li>
</ol>
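
The same bookkeeping can be sketched for the expansive path. The starting point (a 28-pixel bottleneck with 1024 channels) comes from the figure in the U-Net paper, where the up-convolution halves the channel count; `expanding_path` is a hypothetical helper name.

```python
def expanding_path(size, channels=1024, steps=4):
    """Trace (spatial size, channels) after the two convolutions of each step."""
    trace = []
    for _ in range(steps):
        size *= 2          # upsampling doubles the spatial size
        channels //= 2     # the 2x2 up-convolution halves the channels
        size -= 4          # two unpadded 3x3 convolutions remove 2 pixels each
        trace.append((size, channels))
    return trace


print(expanding_path(28))
```

With unpadded convolutions the output (388 pixels) is smaller than the input (572 pixels), which is why the skip feature maps have to be cropped before concatenation.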

<h3> Skip Connection: </h3>
The feature maps from the downsampling path are concatenated with the feature maps of the upsampling path through skip connections. These skip connections add fine-grained local information to the coarser, global information while upsampling.

<h3> Final Layer: </h3>
At the final layer a 1x1 convolution is used to map each feature vector to the desired number of classes.
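
A 1x1 convolution is simply a linear map applied independently to every pixel. The sketch below applies it to a single pixel and then picks the most likely class; the feature vector, weights and biases are made-up numbers, and the softmax step is an assumption about how the scores are turned into probabilities.

```python
import math


def pointwise_conv(features, weights, biases):
    """1x1 convolution at one pixel: one weighted sum of the features per class."""
    return [sum(w * f for w, f in zip(row, features)) + b
            for row, b in zip(weights, biases)]


def softmax(scores):
    """Turn class scores into probabilities that sum to 1."""
    shifted = [math.exp(s - max(scores)) for s in scores]
    total = sum(shifted)
    return [e / total for e in shifted]


features = [0.5, -1.0, 2.0, 0.0]   # made-up 4-channel feature vector
weights = [[1.0, 0.0, 0.0, 0.0],   # one row per class (3 classes), made-up values
           [0.0, 1.0, 0.0, 0.0],
           [0.0, 0.0, 1.0, 0.0]]
biases = [0.0, 0.0, 0.0]

scores = pointwise_conv(features, weights, biases)
probs = softmax(scores)
predicted_class = probs.index(max(probs))
print(scores, predicted_class)
```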

## 1.2 Advantages
<h3> Advantages: </h3>
<ol>
    <li>UNet combines the location information from the downsampling path with the contextual information from the upsampling path, yielding features that carry both localisation and context, which is necessary to predict a good segmentation map.</li>
    <li>No Dense layer is used, so images of different sizes can be used as input.</li>
</ol>
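
One practical caveat on input sizes: every pooling step has to halve the image cleanly, otherwise the upsampled feature maps no longer match the skip feature maps they are concatenated with. The check below is a sketch that assumes size-preserving ("same"-padded) convolutions, so only the pooling and upsampling steps change the size; `skip_shapes_match` is a hypothetical helper.

```python
def skip_shapes_match(size, depth):
    """True if every encoder feature map matches its upsampled decoder counterpart."""
    encoder_sizes = []
    for _ in range(depth):
        encoder_sizes.append(size)
        size //= 2                 # 2x2 pooling with stride 2 (floor division)
    for expected in reversed(encoder_sizes):
        size *= 2                  # 2x upsampling
        if size != expected:
            return False
    return True


for candidate in (512, 500, 496):
    print(candidate, skip_shapes_match(candidate, depth=4))
```

In this setting a size works exactly when it is divisible by 2 raised to the network depth.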

## 1.3 Dataset
Link: <a href="https://www.kaggle.com/c/data-science-bowl-2018">Data Science Bowl 2018</a>
Find the nuclei in divergent images to advance medical discovery

## 1.4 Code

In [1]:
## Imports
import os
import os.path
import sys
import random


import numpy as np
import cv2
import matplotlib.pyplot as plt


import keras
from keras.layers import *
from keras.models import * 
from keras.preprocessing.image import ImageDataGenerator
from data_prep import *

#Tensorboard for Visualization
from keras.callbacks import TensorBoard 
import time

"""
## Seeding 
seed = 2019
random.seed = seed
np.random.seed = seed
tf.seed = seed"""

Using TensorFlow backend.


'\n## Seeding \nseed = 2019\nrandom.seed = seed\nnp.random.seed = seed\ntf.seed = seed'

In [2]:
import tensorflow as tf
import keras.backend.tensorflow_backend as K

# K.set_session(...) can be called here to attach a custom TensorFlow session.
#K.tensorflow_backend._get_available_gpus()


# Sample Config File

In [3]:
"""import json
with open('/Users/vasudevsharma/Desktop/axondeepseg-master-2/AxonDeepSeg/models/default_SEM_model_v1/config_network.json') as djson:
    training_config = json.load(djson)
print(training_config) """
# Example of network configuration for TEM data (small network trainable on a Titan X GPU card)
training_config = {
    
# General parameters:    
  "n_classes": 3,  # Number of classes. For this application, the number of classes should be set to **3** (i.e. axon pixel, myelin pixel, or background pixel).
  "thresholds": [0, 0.2, 0.8],  # Thresholds for the 3-class classification problem. Do not modify.  
  "trainingset_patchsize": 512,  # Patch size of the training set in pixels (note that the patches have the same size in both dimensions).  
  "trainingset": "TEM_3c_512",  # Name of the training set.
  "batch_size": 8,  # Batch size, i.e. the number of training patches used in one iteration of the training. Note that a larger batch size will take more memory.

# Network architecture parameters:     
  "depth": 4,  # Depth of the network (i.e. number of blocks of the U-net).
  "convolution_per_layer": [2, 2, 2, 2],  # Number of convolution layers used at each block.
  "size_of_convolutions_per_layer": [[5, 5], [3, 3], [3, 3], [3, 3]],  # Kernel size of each convolution layer of the network.
  "features_per_convolution": [[[1, 16], [16, 16]], [[16, 32], [32, 32]], [[32, 64], [64, 64]], [[64, 128], [128, 128]]],  # Number of features of each convolution layer.
  "downsampling": "convolution",  # Type of downsampling to use in the downsampling layers of the network. Option "maxpooling" for standard max pooling layer or option "convolution" for learned convolutional downsampling.
  "dropout": 0.75,  # Dropout to use for the training. Note: In TensorFlow, the keep probability is used instead. For instance, setting this param. to 0.75 means that 75% of the neurons of the network will be kept (i.e. dropout of 25%).
     
# Learning rate parameters:    
  "learning_rate": 0.01,  # Learning rate to use in the training.  
  "learning_rate_decay_activate": True,  # Set to "True" to use a decay on the learning rate.  
  "learning_rate_decay_period": 24000,  # Period of the learning rate decay, expressed in number of images (samples) seen.
  "learning_rate_decay_type": "polynomial",  # Type of decay to use. An exponential decay will be used by default unless this param. is set to "polynomial" (to use a polynomial decay).
  "learning_rate_decay_rate": 0.99,  # Rate of the decay to use for the exponential decay. This only applies when the user does not set the decay type to "polynomial".
    
# Batch normalization parameters:     
  "batch_norm_activate": True,  # Set to "True" to use batch normalization during the training.
  "batch_norm_decay_decay_activate": True,  # Set to "True" to activate an exponential decay for the batch normalization step of the training.  
  "batch_norm_decay_starting_decay": 0.7,  # The starting decay value for the batch normalization. 
  "batch_norm_decay_ending_decay": 0.9,  # The ending decay value for the batch normalization.
  "batch_norm_decay_decay_period": 16000,  # Period of the batch normalization decay, expressed in number of images (samples) seen.
        
# Weighted cost parameters:    
  "weighted_cost-activate": True,  # Set to "True" to use weights based on the class in the cost function for the training.
  "weighted_cost-balanced_activate": True,  # Set to "True" to use weights in the cost function to correct class imbalance. 
  "weighted_cost-balanced_weights": [1.1, 1, 1.3],  # Values of the weights for the class imbalance. Typically, larger weights are assigned to classes with less pixels to add more penalty in the cost function when there is a misclassification. Order of the classes in the weights list: background, myelin, axon.
  "weighted_cost-boundaries_sigma": 2,  # Set to "True" to add weights to the boundaries (e.g. penalize more when misclassification happens in the axon-myelin interface).
  "weighted_cost-boundaries_activate": False,  # Value to control the distribution of the boundary weights (if activated). 
    
# Data augmentation parameters:
  "da-type": "all",  # Type of data augmentation procedure. Option "all" applies all selected data augmentation transformations sequentially, while option "random" only applies one of the selected transformations (randomly) to the sample(s). List of available data augmentation transformations: 'random_rotation', 'noise_addition', 'elastic', 'shifting', 'rescaling' and 'flipping'. 
  "da-0-shifting-activate": True, 
  "da-1-rescaling-activate": False,
  "da-2-random_rotation-activate": False,  
  "da-3-elastic-activate": True, 
  "da-4-flipping-activate": True, 
  "da-5-noise_addition-activate": False
}
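
Because the architecture entries of the config are nested lists that must agree with each other, a small consistency check can catch typos before the model is built. This is a sketch, not part of the notebook: `check_architecture` is a hypothetical helper, and the dictionary below repeats only the architecture entries of the config above.

```python
def check_architecture(config):
    """Return a list of inconsistencies between the architecture entries."""
    problems = []
    depth = config["depth"]
    keys = ("convolution_per_layer",
            "size_of_convolutions_per_layer",
            "features_per_convolution")
    for key in keys:
        if len(config[key]) != depth:
            problems.append(f"{key}: expected {depth} entries, found {len(config[key])}")
    if problems:
        return problems  # the per-block checks below assume consistent lengths

    previous_out = None
    for block in range(depth):
        n_convs = config["convolution_per_layer"][block]
        sizes = config["size_of_convolutions_per_layer"][block]
        features = config["features_per_convolution"][block]
        if len(sizes) != n_convs or len(features) != n_convs:
            problems.append(f"block {block}: expected {n_convs} kernel sizes and feature pairs")
            continue
        for n_in, n_out in features:
            # each convolution should consume what the previous one produced
            if previous_out is not None and n_in != previous_out:
                problems.append(f"block {block}: input {n_in} != previous output {previous_out}")
            previous_out = n_out
    return problems


architecture = {
    "depth": 4,
    "convolution_per_layer": [2, 2, 2, 2],
    "size_of_convolutions_per_layer": [[5, 5], [3, 3], [3, 3], [3, 3]],
    "features_per_convolution": [[[1, 16], [16, 16]], [[16, 32], [32, 32]],
                                 [[32, 64], [64, 64]], [[64, 128], [128, 128]]],
}
print(check_architecture(architecture))
```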




## Different Convolutional Blocks

In [4]:
def conv_relu(x, filters, kernel_size, strides, activation='relu', kernel_initializer='glorot_normal', activate_bn=True, bn_decay=0.999, keep_prob=1.0):
    """Convolution with activation, optional batch normalization, then dropout."""
    net = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same', activation=activation, kernel_initializer=kernel_initializer)(x)

    if activate_bn:
        # Keras `momentum` is the moving-average decay itself (same meaning as
        # `decay` in tf.contrib.layers.batch_norm), so bn_decay is passed as is.
        net = BatchNormalization(axis=3, momentum=bn_decay)(net)

    # keep_prob is a keep probability, Keras expects the drop rate
    net = Dropout(rate=1 - keep_prob)(net)

    return net


def downconv(x, filters, kernel_size=5, strides=2, activation='relu', kernel_initializer='glorot_normal', activate_bn=True, bn_decay=0.999):
    """Strided convolution used as a learned downsampling layer, with optional batch normalization."""
    net = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same', activation=activation, kernel_initializer=kernel_initializer)(x)

    if activate_bn:
        net = BatchNormalization(axis=3, momentum=bn_decay)(net)

    return net

"""def upconv(x, n_out_chan, scope, 
              w_initializer=tf.contrib.layers.xavier_initializer_conv2d(),
              training_phase=True, activate_bn = True, bn_decay = 0.999):
   
    
    with tf.variable_scope(scope):
        if activate_bn == True:
            net = tf.contrib.layers.conv2d(x, num_outputs=n_out_chan, kernel_size=3, stride=1, 
                                       activation_fn=tf.nn.relu, normalizer_fn=tf.contrib.layers.batch_norm,
                                       normalizer_params={'scale':True, 'is_training':training_phase,
                                                          'decay':bn_decay, 'scope':'bn'},
                                       weights_initializer = w_initializer, scope='convolution'
                                      )
        else:
            net = tf.contrib.layers.conv2d(x, num_outputs=n_out_chan, kernel_size=3, stride=1, 
                                       activation_fn=tf.nn.relu, weights_initializer = w_initializer, scope='convolution'
                                      )
        
        tf.add_to_collection('activations',net)
        return net


"""


"def upconv(x, n_out_chan, scope, \n              w_initializer=tf.contrib.layers.xavier_initializer_conv2d(),\n              training_phase=True, activate_bn = True, bn_decay = 0.999):\n   \n    \n    with tf.variable_scope(scope):\n        if activate_bn == True:\n            net = tf.contrib.layers.conv2d(x, num_outputs=n_out_chan, kernel_size=3, stride=1, \n                                       activation_fn=tf.nn.relu, normalizer_fn=tf.contrib.layers.batch_norm,\n                                       normalizer_params={'scale':True, 'is_training':training_phase,\n                                                          'decay':bn_decay, 'scope':'bn'},\n                                       weights_initializer = w_initializer, scope='convolution'\n                                      )\n        else:\n            net = tf.contrib.layers.conv2d(x, num_outputs=n_out_chan, kernel_size=3, stride=1, \n                                       activation_fn=tf.nn.relu, weights_initiali

## UNet Model

In [5]:


# ------------------------ NETWORK STRUCTURE ------------------------ #


def uconv_net(training_config,bn_updated_decay = None, verbose = True):
    """
    Create the U-net.
    Input :
        x : TF object to define, ensemble des patchs des images :graph input
        config : dict : described in the header.
        image_size : int : The image size

    Output :
        The U-net.
    """
    
    # Load the variables
    image_size = training_config["trainingset_patchsize"]
    n_classes = training_config["n_classes"]
    depth = training_config["depth"]
    dropout = training_config["dropout"]
    number_of_convolutions_per_layer = training_config["convolution_per_layer"]
    size_of_convolutions_per_layer = training_config["size_of_convolutions_per_layer"]
    features_per_convolution = training_config["features_per_convolution"]
    downsampling = training_config["downsampling"]
    activate_bn = training_config["batch_norm_activate"]
    if bn_updated_decay is None:
        bn_decay = training_config["batch_norm_decay_starting_decay"]
    else:
        bn_decay = bn_updated_decay

    # Input picture shape is [batch_size, height, width, number_channels_in] (number_channels_in = 1 for the input layer)
    
    data_temp_size = [image_size]
    relu_results = []

    ####################################################################
    ######################### CONTRACTION PHASE ########################
    ####################################################################
    


   # print(data_temp)

    #X = Input((image_size*image_size, 3))
    X = Input((image_size, image_size, 3))
 
    net = X
    data_temp = X
   # print(net.shape)
    #print(depth, number_of_convolutions_per_layer)
    for i in range(depth):

        for conv_number in range(number_of_convolutions_per_layer[i]):
            
            # Optional debug output, to be enabled together with `verbose`:
            #if verbose:
                #print(('Layer: ', i, ' Conv: ', conv_number, 'Features: ', features_per_convolution[i][conv_number]))
                #print(('Size:', size_of_convolutions_per_layer[i][conv_number]))

            # The convolution is applied regardless of `verbose`.
            net = conv_relu(net, filters = features_per_convolution[i][conv_number][1], kernel_size = size_of_convolutions_per_layer[i][conv_number], strides = 1 , activation = 'relu', kernel_initializer = 'glorot_normal', activate_bn = True, bn_decay = 0.999, keep_prob=1.0)
               
        relu_results.append(net) # We keep them for the upconvolutions
    

        if downsampling == 'convolution':
              
            net = downconv(net, filters = features_per_convolution[i][conv_number][1], kernel_size = 5, strides = 2, activation = 'relu', kernel_initializer = 'glorot_normal', activate_bn = True, bn_decay = 0.999)
      
        else: 

            net = MaxPooling2D((2, 2), padding = 'valid', strides = 2,  name ='downmp-d'+str(i))(net)

        data_temp_size.append(data_temp_size[-1] // 2)
        data_temp = net
             


    ####################################################################
    ########################## EXPANSION PHASE #########################
    ####################################################################
    
    for i in range(depth):        
        # Upsampling
        net = UpSampling2D(( 2,  2))(net)
     

        # Convolution
        net = conv_relu(net, filters = features_per_convolution[depth - i - 1][-1][1], kernel_size = 2, strides = 1 , activation = 'relu', kernel_initializer = 'glorot_normal', activate_bn = True, bn_decay = 0.999, keep_prob=1.0)
        
        data_temp_size.append(data_temp_size[-1] * 2)

        # concatenation (see U-net article)
        net = Concatenate(axis = 3)([relu_results[depth-i-1],net])
      

        # Classic convolutions
        for conv_number in range(number_of_convolutions_per_layer[depth - i - 1]):
            
            net = conv_relu(net, filters = features_per_convolution[depth - i - 1][conv_number][1], kernel_size = size_of_convolutions_per_layer[depth - i - 1][conv_number], strides = 1 , activation = 'relu', kernel_initializer = 'glorot_normal', activate_bn = True, bn_decay = 0.999, keep_prob=1.0)
            
            
        data_temp = net

   
    net = Conv2D(filters = n_classes, kernel_size = 1, strides = 1, name = 'finalconv', padding = 'same',   activation = "softmax")(net)

    model = Model(inputs = X, outputs = net)

    

    return model
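
To see what the concatenation in the expansion phase does to the channel count, the indexing used above can be replayed on the config values alone. This is a sketch in plain Python (no Keras needed); `decoder_concat_channels` is a hypothetical helper that mirrors the `depth - i - 1` indexing of `uconv_net`.

```python
features_per_convolution = [[[1, 16], [16, 16]], [[16, 32], [32, 32]],
                            [[32, 64], [64, 64]], [[64, 128], [128, 128]]]


def decoder_concat_channels(features):
    """Channels after each Concatenate layer, from the deepest block upwards."""
    depth = len(features)
    channels = []
    for i in range(depth):
        level = depth - i - 1
        skip = features[level][-1][1]       # channels kept in relu_results[level]
        upsampled = features[level][-1][1]  # filters of the 2x2 convolution after UpSampling2D
        channels.append(skip + upsampled)
    return channels


print(decoder_concat_channels(features_per_convolution))
```

The convolutions that follow each concatenation then bring the channel count back down to the values listed in the config.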

        
    
  



# METRICS

In [6]:
def dice_coef(y_true, y_pred, smooth=1e-3):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return K.mean((2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth))

def dice_myelin(y_true, y_pred, smooth=1e-3):
    y_true_f = K.flatten(y_true[..., 1])
    y_pred_f = K.flatten(y_pred[..., 1])
    intersection = K.sum(y_true_f * y_pred_f)
    return K.mean((2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth))

def dice_axon(y_true, y_pred, smooth=1e-3):
    y_true_f = K.flatten(y_true[..., 2])
    y_pred_f = K.flatten(y_pred[..., 2])
    intersection = K.sum(y_true_f * y_pred_f)
    return K.mean((2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth))
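
The Dice formula used by these metrics is easy to check by hand on tiny masks. The sketch below re-implements it on flat Python lists, with the same `smooth` term; `dice` is a hypothetical stand-in for the Keras functions above.

```python
def dice(y_true, y_pred, smooth=1e-3):
    """Dice coefficient of two flat masks: 2 * overlap / (sum of both masks)."""
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    return (2.0 * intersection + smooth) / (sum(y_true) + sum(y_pred) + smooth)


perfect = dice([1, 1, 0, 0], [1, 1, 0, 0])   # identical masks
partial = dice([1, 1, 0, 0], [1, 0, 0, 0])   # one of two foreground pixels found
empty = dice([0, 0, 0, 0], [0, 0, 0, 0])     # no foreground at all
print(perfect, partial, empty)
```

The `smooth` term is what keeps the empty case from dividing by zero: two empty masks score 1 rather than raising an error.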





# Tensorboard for Visualization

In [7]:
Name = "TEM_sample_dataset-{}".format(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
tensorboard = TensorBoard(log_dir = 'models/{}'.format(Name))

In [8]:
Name

'TEM_sample_dataset-2019-06-10 19:37:22'

# Model Summary

In [9]:
model = uconv_net(training_config,  bn_updated_decay = None, verbose = True)

Instructions for updating:
Colocations handled automatically by placer.


In [10]:
adam = keras.optimizers.Adam(lr=training_config['learning_rate'], beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0 )

In [11]:
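# Pass the optimizer object defined above; the string 'adam' would ignore the configured learning rate.
model.compile(optimizer=adam, loss="categorical_crossentropy", metrics=["accuracy", dice_axon, dice_myelin, dice_coef])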
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 512, 512, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 512, 512, 16) 1216        input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 512, 512, 16) 64          conv2d_1[0][0]                   
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 512, 512, 16) 0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_2 (
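
The parameter counts in the summary can be reproduced by hand. The sketch below uses the standard formulas for a convolution (kernel weights plus one bias per filter) and for batch normalization (four values per channel); the helper names are made up for illustration.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, out_channels):
    """Weights of all kernels plus one bias per output channel."""
    return kernel_h * kernel_w * in_channels * out_channels + out_channels


def batchnorm_params(channels):
    """gamma, beta, moving mean and moving variance for every channel."""
    return 4 * channels


print(conv2d_params(5, 5, 3, 16))   # conv2d_1: 5x5 kernel, 3 input channels, 16 filters
print(batchnorm_params(16))         # batch_normalization_1
```

Neither count depends on the 512x512 image size, which is the reason a network without Dense layers can accept different input sizes.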

## Training the model

In [12]:
from data_prep import *
from data_augmentation import *

In [13]:
train_path = '/home/vasha_local/axondeepseg/TEM_striatum/data/Train/'
test_path = '/home/vasha_local/axondeepseg/TEM_striatum/data/Validation/'

In [14]:
n_classes = training_config["n_classes"]
thresholds = training_config["thresholds"]
patch_size = training_config["trainingset_patchsize"]

batch_size = training_config["batch_size"]

# Data Augmentation

In [15]:
data_aug = True # Boolean Value to indicate whether you want to use Data Augmentation or not


In [16]:
if(not data_aug):
    data_gen_args = dict() 
else: 
    data_gen_args = dict(horizontal_flip = flipping()[1],
                        vertical_flip = flipping()[0], 
                        rotation_range = random_rotation()[0], 
                        width_shift_range = shifting(patch_size, n_classes)[1],
                        height_shift_range = shifting(patch_size, n_classes) [0]
                        )

#Data Augmentation Dictionary for Validation 
data_gen_args_valid = dict()


flipping up-down
flipping left-right
('height shift: ', 0.158203125, ', width shift: ', 0.083984375)
('height shift: ', 0.15234375, ', width shift: ', 0.048828125)


In [17]:
#Sanity Check of augmented data dictionary
data_gen_args

{'horizontal_flip': True,
 'vertical_flip': True,
 'rotation_range': 43.73224759795209,
 'width_shift_range': 0.083984375,
 'height_shift_range': 0.15234375}
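
In Keras' `ImageDataGenerator`, shift ranges below 1 are read as a fraction of the image width or height. A quick sketch converts the values above into the maximum shift in pixels for the 512-pixel patches (the dictionary name `max_shift_px` is made up for illustration):

```python
patch_size = 512
data_gen_args = {"width_shift_range": 0.083984375,
                 "height_shift_range": 0.15234375}

# largest shift, in pixels, that the generator may apply in either direction
max_shift_px = {key: value * patch_size for key, value in data_gen_args.items()}
print(max_shift_px)
```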

In [18]:

 
# Reuse the paths defined above instead of repeating them
train_gen = Generator(batch_size, train_path, 'images', 'masks', data_gen_args, save_to_dir = None)
valid_gen = Generator(batch_size, test_path, 'images', 'masks', data_gen_args_valid, save_to_dir = None)

In [19]:
train_steps =150
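
The decay periods in the sample config are expressed in samples seen, so it can help to translate them into epochs for this choice of `train_steps`. The arithmetic below is only a sketch using the config values; the Keras training call in this notebook does not itself apply those decay schedules.

```python
batch_size = 8             # training_config["batch_size"]
train_steps = 150          # steps per epoch
lr_decay_period = 24000    # training_config["learning_rate_decay_period"], in samples
bn_decay_period = 16000    # training_config["batch_norm_decay_decay_period"], in samples

samples_per_epoch = batch_size * train_steps
epochs_per_lr_period = lr_decay_period / samples_per_epoch
epochs_per_bn_period = bn_decay_period / samples_per_epoch
print(samples_per_epoch, epochs_per_lr_period, epochs_per_bn_period)
```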

In [None]:
model.fit_generator(train_gen, validation_data = valid_gen, steps_per_epoch=train_steps, validation_steps = 2,
                    epochs= 1000,  callbacks = [tensorboard])
 

Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
Epoch 1/1000
Found 16 images belonging to 1 classes.
Found 36 images belonging to 1 classes.
Found 16 images belonging to 1 classes.
Found 36 images belonging to 1 classes.
Epoch 2/1000
...
Epoch 105/1000

## Testing the model on Validation Set

In [None]:
## Save the Weights
model.save_weights("TEM_sample_Model.h5")




In [None]:
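#Plotting Ground Truth
# Note: `y` (ground-truth masks) and `result` (model predictions) are not defined in the cells shown;
# they are assumed to hold one validation batch and the model output for it.
plt.imshow(y[7,:,:,2], cmap = "gray")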

In [None]:
plt.imshow(result[7,:,:,2], cmap = "gray") # Predicted Image

In [None]:
q = tf.Variable(0)

In [None]:
q

## ToDo - Testing the model

In [None]:
# Modify the lines below to use your image
# path_img = Path("./TEM_striatum/data/Testing")
# file_img = "image_819.png"

In [29]:
# In case you want to test the segmentation with a pre-trained model created using this notebook,
# uncomment the lines below.
# path_model = Path("./TEM_striatum/model/TEM_3c_512_2018-11-10_21-32-36/")

# reset the tensorflow graph for new testing
# tf.reset_default_graph()
# prediction = axon_segmentation(path_img, file_img, path_model, config_network, acquired_resolution=0.01, resampled_resolutions=0.01, verbosity_level=3)

In [30]:
# file_img_seg = 'AxonDeepSeg.png'  # axon+myelin segmentation

# img_seg = imageio.imread(path_img / file_img_seg)
# img = imageio.imread(path_img / file_img)
# Note: The arguments of the two function calls above use the pathlib syntax for path concatenation.

# fig, axes = plt.subplots(1,2, figsize=(13,10))
# ax1, ax2 = axes[0], axes[1]
# ax1.set_title('Original image')
# ax1.imshow(img, cmap='gray')
# ax2.set_title('Prediction with the trained model')
# ax2.imshow(img_seg,cmap='gray')
# plt.show()