### Imports

In [1]:
import os
import numpy as np

import tensorflow
from tensorflow import keras
from tensorflow.keras.callbacks import History
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Add, concatenate, Activation, BatchNormalization, Conv2D, MaxPooling2D, Conv2DTranspose, UpSampling2D
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger
import tensorflow.keras.backend as K

### Settings

In [2]:
# Put the channel axis last, matching the (rows, cols, 1) arrays built further down.
K.set_image_data_format('channels_last')

# Spatial size of the network input / training arrays.
img_rows, img_cols = 240, 240

# Additive smoothing term read by dice_coeff (added to numerator and denominator).
smooth = 1.0

### Loss Function Definition

In [3]:
def dice_coeff(y_true, y_pred):
    """Smoothed Dice coefficient between two tensors.

    Both tensors are flattened to 1-D, so the score is computed over every
    element passed in at once (not averaged per sample).

    Args:
        y_true: Reference tensor.
        y_pred: Predicted tensor; must multiply element-wise with ``y_true``
            after flattening.

    Returns:
        Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``,
        where ``smooth`` is the module-level constant.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)

    # `smooth` appears on both sides so two all-zero inputs yield 1 rather
    # than 0/0.
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return numerator / denominator


def dice_coeff_loss(y_true, y_pred):
    """Loss form of the Dice coefficient: the negated ``dice_coeff`` score.

    Minimising this value maximises the Dice overlap.

    Args:
        y_true: Reference tensor.
        y_pred: Predicted tensor.

    Returns:
        Scalar tensor equal to ``-dice_coeff(y_true, y_pred)``.
    """
    score = dice_coeff(y_true, y_pred)
    return -score

### ResUNet Architecture Definition

In [4]:
def Conv2DBlock(n_filters, input, padding='same', strides=2):
    """Pre-activation convolution unit: BatchNorm -> ReLU -> 3x3 Conv2D.

    Args:
        n_filters: Number of output filters of the convolution.
        input: Keras tensor fed into the unit.
        padding: Padding mode forwarded to ``Conv2D``.
        strides: Stride forwarded to ``Conv2D`` (defaults to 2).

    Returns:
        The output tensor of the convolution.
    """
    pipeline = (
        BatchNormalization(),
        Activation('relu'),
        Conv2D(n_filters, (3, 3), padding=padding, strides=strides),
    )

    # Thread the tensor through the three layers in order.
    tensor = input
    for layer in pipeline:
        tensor = layer(tensor)
    return tensor

def ContractingBlock2D(n_filters, input):
    """Residual block that applies stride 2 on both of its branches.

    Main path: two ``Conv2DBlock`` units, the first with stride 2 and the
    second with stride 1. Shortcut path: a stride-2 1x1 convolution of the
    block input followed by batch normalisation. The two paths are summed.

    Args:
        n_filters: Number of filters used by every convolution in the block.
        input: Keras tensor entering the block.

    Returns:
        Tensor holding the sum of the shortcut and main paths.
    """
    # Main path (built first so layer creation order is unchanged).
    main = Conv2DBlock(n_filters, input, strides=2)
    main = Conv2DBlock(n_filters, main, strides=1)

    # Shortcut path: the 1x1 conv uses stride 2 and n_filters, mirroring the
    # main path so the two tensors can be added.
    skip = Conv2D(n_filters, (1, 1), padding='same', strides=2)(input)
    skip = BatchNormalization()(skip)

    return Add()([skip, main])

def ExpandingBlock2D(n_filters, upconv_input, concat_input):
    """Decoder residual block: upsample, merge with a skip tensor, refine.

    ``upconv_input`` is upsampled by a factor of 2 in both spatial dimensions
    and concatenated with ``concat_input`` along axis 3. The merged tensor
    then goes through two stride-1 ``Conv2DBlock`` units (main path) and, in
    parallel, through a 1x1 convolution + batch normalisation (shortcut
    path). The two paths are summed.

    Args:
        n_filters: Number of filters used by every convolution in the block.
        upconv_input: Tensor from the coarser level, upsampled before merging.
        concat_input: Skip tensor concatenated with the upsampled tensor; its
            spatial size must equal that of the upsampled tensor.

    Returns:
        Tensor holding the sum of the shortcut and main paths.
    """
    upsamp = UpSampling2D((2, 2))(upconv_input)
    conv_input = concatenate([upsamp, concat_input], axis=3)

    conv_a = Conv2DBlock(n_filters, conv_input, strides=1)
    conv_b = Conv2DBlock(n_filters, conv_a, strides=1)

    # The shortcut must branch off the block *input* (conv_input), as it does
    # in ContractingBlock2D. It previously branched off conv_b, so the "skip"
    # never bypassed the two conv units and the sum was just
    # BN(Conv1x1(conv_b)) + conv_b.
    # NOTE: this changes the input-channel count of the 1x1 kernel, so
    # checkpoints saved with the old wiring will not load into this block.
    shortcut = Conv2D(n_filters, (1, 1), padding='same', strides=1)(conv_input)
    bnorm_shortcut = BatchNormalization()(shortcut)

    output = Add()([bnorm_shortcut, conv_b])
    return output
    
def ResUNet():
    """Assemble the residual U-Net used in this notebook.

    The input size comes from the module-level ``img_rows`` / ``img_cols``
    with a single channel. Layout:

    * stem: 3x3 conv + ``Conv2DBlock`` at 32 filters, with a 1x1-conv/BN
      shortcut taken from the raw input;
    * encoder: three ``ContractingBlock2D`` stages (64, 128, 256 filters);
    * bridge: two ``Conv2DBlock`` units at 512 filters (stride 2, then 1);
    * decoder: four ``ExpandingBlock2D`` stages (256, 128, 64, 32 filters),
      each merged with the encoder output of matching depth;
    * head: 1x1 convolution with a sigmoid activation, one output channel.

    Returns:
        An uncompiled ``Model`` from the input tensor to the sigmoid map.
    """
    net_input = Input((img_rows, img_cols, 1))

    # Stem: main path, then the shortcut from the raw input.
    stem = Conv2D(32, (3, 3), padding='same')(net_input)
    stem = Conv2DBlock(32, stem, padding='same', strides=1)
    stem_skip = Conv2D(32, (1, 1), padding='same', strides=1)(net_input)
    stem_skip = BatchNormalization()(stem_skip)
    level = Add()([stem, stem_skip])

    # Encoder: remember every level's output for the decoder skips.
    encoder_outputs = [level]
    for width in (64, 128, 256):
        level = ContractingBlock2D(width, level)
        encoder_outputs.append(level)

    # Bridge between encoder and decoder.
    level = Conv2DBlock(512, level, strides=2)
    level = Conv2DBlock(512, level, strides=1)

    # Decoder: walk the saved encoder outputs from deepest to shallowest.
    for width, skip in zip((256, 128, 64, 32), reversed(encoder_outputs)):
        level = ExpandingBlock2D(width, level, skip)

    mask = Conv2D(1, (1, 1), activation = 'sigmoid')(level)

    return Model(inputs = [net_input], outputs = [mask])

In [5]:
# Build the network and compile it with the Dice loss; the Dice score itself
# is also reported as a metric.
model = ResUNet()
model.compile(
    optimizer=Adam(learning_rate=1e-3),
    loss=dice_coeff_loss,
    metrics=[dice_coeff],
)

# Both callbacks monitor validation loss: the checkpoint only writes when it
# improves, and early stopping waits 5 epochs without improvement.
model_ckpt = ModelCheckpoint(
    'final_resunet_weights.h5',
    monitor='val_loss',
    save_best_only=True,
)
early_stop = EarlyStopping(monitor='val_loss', patience=5)

### Loading Dataset + Normalization

In [6]:
# Read the training images and masks from disk as float32 arrays.
imgs_train = np.load('imgs_train.npy').astype('float32')
masks_train = np.load('masks_train.npy').astype('float32')

# Give every sample an explicit trailing channel axis: (N, img_rows, img_cols, 1).
sample_shape = (img_rows, img_cols, 1)
imgs_train = imgs_train.reshape((len(imgs_train),) + sample_shape)
masks_train = masks_train.reshape((len(masks_train),) + sample_shape)

In [7]:
# Global intensity statistics over the whole training array (all samples and
# pixels together), used to standardise the images.
mean = imgs_train.mean()
std = imgs_train.std()

In [8]:
# Standardise the images in place: subtract the mean, then divide by the std.
# NOTE: this mutates imgs_train, so re-running the cell without reloading the
# data applies the shift and scale a second time.
np.subtract(imgs_train, mean, out=imgs_train)
np.divide(imgs_train, std, out=imgs_train)

### Model Training

In [9]:
# Fit for 2 epochs in shuffled batches of 128, holding out 20% of the samples
# for validation; the checkpoint and early-stopping callbacks watch val_loss.
history = model.fit(
    imgs_train,
    masks_train,
    batch_size=128,
    epochs=2,
    verbose=1,
    shuffle=True,
    validation_split=0.2,
    callbacks=[model_ckpt, early_stop],
)

Train on 9184 samples, validate on 2297 samples
Epoch 1/2
Epoch 2/2
