In [None]:
import tensorflow as tf
import numpy as np

from keras.models import Model
from keras.optimizers import *
from keras.layers import *

In [None]:
# The UNet Architecture

def unet(input_size=(572,572,1)):
    """Build and compile a U-Net segmentation model.

    The network has a four-level contracting path (two unpadded 3x3
    convolutions followed by 2x2 max pooling per level), a 1024-filter
    bottleneck, and a four-level expanding path. Each expanding level
    upsamples with a stride-2 transposed convolution, concatenates the
    centre-cropped feature map of the matching contracting level, and applies
    two more unpadded 3x3 convolutions.

    With the default 572x572x1 input the output map is 388x388x2.

    Args:
        input_size: (height, width, channels) of the input image. The spatial
            dimensions must be static, and the skip connections assume square
            feature maps because `crop` takes a single fraction for both axes.

    Returns:
        The compiled Keras `Model` (Adam with learning rate 1e-4,
        binary cross-entropy loss, accuracy metric).
    """
    def skip_connection(encoder_map, decoder_map):
        # Derive the crop fraction from the static shapes instead of
        # hard-coding it, so the model is not tied to the 572x572 default.
        # float() keeps this a true division under Python 2 as well.
        encoder_size = int(encoder_map.shape[1])
        decoder_size = int(decoder_map.shape[1])
        fraction = float(decoder_size) / encoder_size
        return copy(crop(encoder_map, fraction), decoder_map)

    inputs = Input(input_size)

    # ---- Contracting path ----
    conv1 = conv_relu(inputs, 64)
    conv1 = conv_relu(conv1, 64)

    pool1 = maxpool(conv1)
    conv2 = conv_relu(pool1, 128)
    conv2 = conv_relu(conv2, 128)

    pool2 = maxpool(conv2)
    conv3 = conv_relu(pool2, 256)
    conv3 = conv_relu(conv3, 256)

    pool3 = maxpool(conv3)
    conv4 = conv_relu(pool3, 512)
    conv4 = conv_relu(conv4, 512)

    # ---- Bottleneck ----
    pool4 = maxpool(conv4)
    conv5 = conv_relu(pool4, 1024)
    conv5 = conv_relu(conv5, 1024)

    # ---- Expanding path ----
    # stride=2 is passed explicitly: with upconv's default stride of 1 the
    # transposed convolution only grows the map by one pixel, so the result
    # could not be concatenated with the cropped skip connection.
    upconv1 = upconv(conv5, 512, stride=2)
    merge1 = skip_connection(conv4, upconv1)
    conv6 = conv_relu(merge1, 512)
    conv6 = conv_relu(conv6, 512)

    upconv2 = upconv(conv6, 256, stride=2)
    merge2 = skip_connection(conv3, upconv2)
    conv7 = conv_relu(merge2, 256)
    conv7 = conv_relu(conv7, 256)

    upconv3 = upconv(conv7, 128, stride=2)
    merge3 = skip_connection(conv2, upconv3)
    conv8 = conv_relu(merge3, 128)
    conv8 = conv_relu(conv8, 128)

    upconv4 = upconv(conv8, 64, stride=2)
    merge4 = skip_connection(conv1, upconv4)
    conv9 = conv_relu(merge4, 64)
    conv9 = conv_relu(conv9, 64)

    # 1x1 convolution to two per-pixel class scores. Softmax (not ReLU) keeps
    # the outputs in [0, 1], which the cross-entropy loss below expects.
    output_map = Conv2D(filters=2, kernel_size=1, activation='softmax')(conv9)

    # Keras 2+ names these arguments `inputs` / `outputs`.
    model = Model(inputs=inputs, outputs=output_map)
    # The learning rate is passed positionally because its keyword name
    # differs between Keras versions (`lr` vs `learning_rate`).
    model.compile(optimizer=Adam(1e-4), loss='binary_crossentropy', metrics=['accuracy'])

    return model
            
        

In [None]:
# helper functions

def conv_relu(image, num_filters, ksize=3, stride=1, pad="VALID"):
    """Apply a ReLU-activated 2-D convolution to `image`.

    Args:
        image: Input tensor passed to the convolution layer.
        num_filters: Number of output filters.
        ksize: Kernel size forwarded to `Conv2D`.
        stride: Stride forwarded to `Conv2D`.
        pad: Padding mode forwarded to `Conv2D`.

    Returns:
        The output tensor of the convolution layer.
    """
    layer_options = dict(
        filters=num_filters,
        kernel_size=ksize,
        strides=stride,
        padding=pad,
        activation='relu',
        kernel_initializer='he_normal',
    )
    return Conv2D(**layer_options)(image)


def maxpool(image, kernel=2, stride=2, pad="VALID"):
    """Apply 2-D max pooling to `image`.

    Args:
        image: Input tensor passed to the pooling layer.
        kernel: Pooling window size forwarded to `MaxPool2D`.
        stride: Stride forwarded to `MaxPool2D`.
        pad: Padding mode forwarded to `MaxPool2D`.

    Returns:
        The output tensor of the pooling layer.
    """
    pooling_layer = MaxPool2D(
        pool_size=kernel,
        strides=stride,
        padding=pad,
    )
    return pooling_layer(image)


def upconv(image, num_filters, ksize=2, stride=1, pad="VALID"):
    """Apply a 2-D transposed convolution to `image`.

    Note: per the Keras `Conv2DTranspose` output-shape formula, with VALID
    padding each spatial dimension becomes (input - 1) * stride + ksize.
    With the defaults (ksize=2, stride=1) that is input + 1; pass stride=2
    to double the spatial size.

    Args:
        image: Input tensor passed to the transposed-convolution layer.
        num_filters: Number of output filters.
        ksize: Kernel size forwarded to `Conv2DTranspose`.
        stride: Stride forwarded to `Conv2DTranspose`.
        pad: Padding mode forwarded to `Conv2DTranspose`.

    Returns:
        The output tensor of the transposed-convolution layer.
    """
    transpose_layer = Conv2DTranspose(
        filters=num_filters,
        kernel_size=ksize,
        strides=stride,
        padding=pad,
    )
    return transpose_layer(image)


def copy(map1, map2):
    """Concatenate two feature maps along the last axis.

    A Keras `Concatenate` layer is used instead of a raw `tf.concat` call so
    that the result stays a Keras-tracked tensor and can be used when building
    a functional `Model`.

    Args:
        map1: First feature map; its entries come first in the result.
        map2: Second feature map. It must match `map1` in every dimension
            except the last.

    Returns:
        The concatenated tensor.
    """
    merged_maps = Concatenate(axis=-1)([map1, map2])
    return merged_maps


def crop(image, fraction):
    """Centre-crop the spatial dimensions of a feature map.

    This follows the arithmetic of `tf.image.central_crop`: on each spatial
    axis a margin of int((size - size * fraction) / 2) is removed from both
    sides. It is implemented with a Keras `Cropping2D` layer so the result
    stays a Keras-tracked tensor usable in a functional `Model`.

    Args:
        image: 4-D feature map with static spatial dimensions on axes 1 and 2
            (assumes the channels-last layout, i.e. (batch, height, width,
            channels) -- TODO confirm against the Keras image data format in
            use).
        fraction: Fraction of each spatial dimension to keep, in (0, 1].

    Returns:
        The cropped tensor, or `image` itself when nothing needs removing.

    Raises:
        ValueError: If `fraction` is outside (0, 1] or a spatial dimension of
            `image` is not statically known.
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be within (0, 1], got %r" % (fraction,))

    def margin(dim):
        # TF1 shapes hold Dimension objects (with .value); TF2 / Keras 3
        # shapes hold plain ints or None.
        size = getattr(dim, "value", dim)
        if size is None:
            raise ValueError("crop requires static spatial dimensions")
        size = int(size)
        # The small epsilon absorbs float round-off: e.g. 136 * (104 / 136)
        # may land a hair above 104, which would otherwise truncate the
        # margin from 16 to 15 and leave the crop two pixels too large.
        return int((size - size * fraction) / 2.0 + 1e-6)

    height_margin = margin(image.shape[1])
    width_margin = margin(image.shape[2])
    if height_margin == 0 and width_margin == 0:
        return image

    cropped_image = Cropping2D(
        cropping=((height_margin, height_margin), (width_margin, width_margin))
    )(image)
    return cropped_image


In [None]:
# Build the U-Net with its default input size (572x572x1).
model = unet()