# U-Net Implementation

In [1]:
import tensorflow as tf
import keras.backend as K
from keras.models import Sequential, Model
from keras.layers import *
from keras.applications import *
from collections import defaultdict

  return f(*args, **kwds)
Using TensorFlow backend.


In [2]:
def down_block(x, filters, kernel_size=3, padding='same', strides=1, activation='relu'):
    """Encoder (down-sampling) stage of the U-Net.

    Applies two stacked Conv2D layers, then a 2x2 max pool with stride 2.

    Args:
        x: input tensor for this stage.
        filters: number of filters used by both Conv2D layers.
        kernel_size, padding, strides, activation: forwarded unchanged to both
            Conv2D layers.

    Returns:
        ``(skip, pooled)``: ``skip`` is the output of the second convolution (the
        tensor later concatenated in the mirrored up block), ``pooled`` is that
        tensor after max pooling.
    """
    conv_kwargs = dict(padding=padding, strides=strides, activation=activation)
    features = x
    # Two convolutions, created and applied in order (keeps layer naming stable).
    for _ in range(2):
        features = Conv2D(filters, kernel_size, **conv_kwargs)(features)
    downsampled = MaxPool2D((2, 2), (2, 2))(features)
    return features, downsampled


def up_block(x, skip, filters, kernel_size=3, padding='same', strides=1, activation='relu'):
    """Decoder (up-sampling) stage of the U-Net.

    Doubles the spatial size of ``x`` with 2x2 up-sampling, concatenates the
    result with the encoder skip tensor (last axis), then applies two stacked
    Conv2D layers.

    Args:
        x: tensor coming from the deeper stage.
        skip: skip-connection tensor from the matching down block; its spatial
            size must equal that of the up-sampled ``x``.
        filters: number of filters used by both Conv2D layers.
        kernel_size, padding, strides, activation: forwarded unchanged to both
            Conv2D layers.

    Returns:
        Output tensor of the second convolution.
    """
    upsampled = UpSampling2D((2, 2))(x)
    merged = Concatenate()([upsampled, skip])
    conv_kwargs = dict(padding=padding, strides=strides, activation=activation)
    out = merged
    for _ in range(2):
        out = Conv2D(filters, kernel_size, **conv_kwargs)(out)
    return out


def bottleneck(x, filters, kernel_size=3, padding='same', strides=1, activation='relu'):
    """Bottleneck between the down-sampling and up-sampling halves of the U-Net.

    Two stacked Conv2D layers with no pooling or up-sampling.

    Args:
        x: output of the last down block's pooling (or the network input when
            there are no down blocks).
        filters: number of filters used by both Conv2D layers.
        kernel_size, padding, strides, activation: forwarded unchanged to both
            Conv2D layers.

    Returns:
        Output tensor of the second convolution.
    """
    first = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=activation)
    hidden = first(x)
    second = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=activation)
    return second(hidden)

In [3]:
def UNet(img_shape, filters_per_block):
    """Build a U-Net with a single-channel sigmoid output.

    Args:
        img_shape: shape of one input sample, passed straight to ``Input``
            (e.g. ``(256, 256, 3)``). Spatial dims presumably need to be divisible
            by ``2 ** (len(filters_per_block) - 1)`` for the skip concatenations to
            line up — TODO confirm for non-square / odd inputs.
        filters_per_block: filter counts from the top of the network downwards.
            Every entry except the last defines one down block and its mirrored
            up block; the last entry is the bottleneck width. Must be non-empty;
            a single entry builds bottleneck -> 1x1 output only.

    Returns:
        An uncompiled keras ``Model`` whose output is a 1x1 Conv2D with one
        filter and sigmoid activation.

    Raises:
        ValueError: if ``filters_per_block`` is empty.
    """
    filters_per_block = list(filters_per_block)
    if not filters_per_block:
        raise ValueError('filters_per_block must contain at least one entry (the bottleneck width)')

    num_down_blocks = len(filters_per_block) - 1
    inputs = Input(img_shape)

    # Intermediate tensors by role:
    #   'p{i}' pooled output of down block i ('p0' is the raw input),
    #   'c{i}' pre-pooling output of down block i (the skip connection),
    #   'u{i}' output of decoder stage i ('u0' is the bottleneck, so the decoder
    #          loop needs no special case and a zero-down-block net still works).
    tensors = {'p0': inputs}

    for index, num_filters in enumerate(filters_per_block[:-1], start=1):
        tensors['c{}'.format(index)], tensors['p{}'.format(index)] = down_block(
            tensors['p{}'.format(index - 1)], num_filters)

    tensors['u0'] = bottleneck(tensors['p{}'.format(num_down_blocks)], filters_per_block[-1])

    for index in range(num_down_blocks):
        # Walk back up: decoder stage index+1 merges the skip from encoder `level`.
        level = num_down_blocks - index
        tensors['u{}'.format(index + 1)] = up_block(
            tensors['u{}'.format(index)],
            tensors['c{}'.format(level)],
            filters_per_block[level - 1])

    outputs = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(tensors['u{}'.format(num_down_blocks)])
    return Model(inputs, outputs)

In [4]:
smooth = 1.


def dice_coef(y_true, y_pred):
    """Soft Dice coefficient over the whole batch, usable as a Keras metric.

    Both tensors are flattened to 1-D and every sum runs over all elements, so
    the score is pooled over the batch rather than averaged per sample. The
    module-level ``smooth`` is added to numerator and denominator so that two
    all-zero inputs score 1.0 instead of 0/0.

    Args:
        y_true: ground-truth tensor; cast to ``y_pred``'s dtype so integer or
            boolean masks can be used.
        y_pred: predicted values with the same number of elements as ``y_true``.

    Returns:
        Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    # Plain reshape instead of instantiating a new Flatten layer on every call.
    y_pred_f = tf.reshape(y_pred, [-1])
    # Integer/bool labels would otherwise fail in the multiply (dtype mismatch).
    y_true_f = tf.reshape(tf.cast(y_true, y_pred_f.dtype), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    """Dice loss: ``1 - dice_coef(y_true, y_pred)``, so lower is better."""
    score = dice_coef(y_true, y_pred)
    return 1.0 - score

In [5]:
# 256x256 RGB input; three down/up stages (32, 64, 128 filters) around a 256-filter bottleneck.
model = UNet(
    img_shape=(256, 256, 3),
    filters_per_block=[32, 64, 128, 256],
)
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 256, 256, 32) 896         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 256, 256, 32) 9248        conv2d_1[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 128, 128, 32) 0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (