In [0]:
import os
import fnmatch
from shutil import copy2
import sys
import random
import warnings

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from itertools import chain
from skimage.io import imread, imshow, concatenate_images
from skimage.transform import resize
from skimage.morphology import label
from sklearn.model_selection import train_test_split

from keras.models import Model, load_model
from keras.layers import Input , Concatenate
from keras.layers.core import Lambda, RepeatVector, Reshape
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D
from keras.layers.merge import concatenate
from keras.layers import *
from keras.initializers import he_normal
from keras.regularizers import l2
from keras.layers import Dropout
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras import backend as K
from keras.optimizers import *
from keras.utils import to_categorical
import tensorflow as tf

from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img

In [0]:
# Every image and mask is resized to im_height x im_width before it reaches the network.
im_height, im_width = 128, 128
im_chan = 1  # number of input channels the network expects

In [0]:
# Destination folders for the split training data. exist_ok=True makes the cell
# safe to re-run and avoids the check-then-create race of os.path.exists().
for _out_dir in ("images", "masks"):
    os.makedirs(_out_dir, exist_ok=True)

In [0]:
# Separate the raw training folder: files whose names end in "mask.tif" go to
# masks/, every other file goes to images/.
for fname in os.listdir("train"):
    destination = "masks" if fnmatch.fnmatch(fname, '*mask.tif') else "images"
    copy2("train/" + fname, destination)

In [0]:
# File names of the training images and the test images. The listings are sorted
# because os.walk returns names in arbitrary, filesystem-dependent order; the
# seeded train/validation split below is only reproducible if the input order
# is fixed as well.
train_ids = sorted(next(os.walk("images"))[2])
test_ids = sorted(next(os.walk("test"))[2])

In [0]:
# Load every training image/mask pair, resize both to the network input size,
# and count positive mask pixels to report the class balance.
X_train = np.zeros((len(train_ids), im_height, im_width, im_chan), dtype=np.uint8)
# `np.bool` was removed in NumPy 1.24; the builtin `bool` is the same dtype.
Y_train = np.zeros((len(train_ids), im_height, im_width, 1), dtype=bool)
print('Getting and resizing train images and masks ... ')
sys.stdout.flush()
sum_whites = 0
for n, id_ in enumerate(train_ids):
    img = img_to_array(load_img('images/' + id_))
    X_train[n] = resize(img, (im_height, im_width, im_chan), mode='constant', preserve_range=True)

    # The mask for images/<name>.tif is masks/<name>_mask.tif.
    mask = img_to_array(load_img('masks/' + id_[:-4] + '_mask.tif'))
    mask_small = resize(mask, (im_height, im_width, 1), mode='constant', preserve_range=True)
    # Downsampling interpolates (and, in recent scikit-image, anti-aliases by
    # default), leaving a fringe of small non-zero values around each mask.
    # Casting that straight to bool marked the whole fringe as positive and
    # inflated the masks, so threshold at half of the mask's peak value instead.
    # An all-zero mask has peak 0 and stays all False.
    Y_train[n] = mask_small > 0.5 * mask.max()
    sum_whites += np.sum(Y_train[n])

print('Done!')

print("nerve percentage in the dataset is {}".format(sum_whites*1.0 / (X_train.shape[0] * im_height * im_width)))

In [0]:
# Hold out 10% of the image/mask pairs for validation; the fixed seed makes the
# split repeatable for a given input order.
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train,
    Y_train,
    train_size=0.9,
    random_state=2019,
)

In [0]:
# Images and masks use identical augmentation settings. Their training
# iterators also share a seed; presumably this is meant to keep each image batch
# and its mask batch transformed identically (the pattern Keras documents for
# segmentation).
augmentation = dict(
    horizontal_flip=True,
    vertical_flip=True,
    zoom_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
)
image_generator = ImageDataGenerator(**augmentation)
mask_generator = ImageDataGenerator(**augmentation)

# Validation data is fed through generators constructed with no augmentation options.
val_image_generator = ImageDataGenerator()
val_mask_generator = ImageDataGenerator()

flow_kwargs = dict(seed=2018, batch_size=16)
train_img_gen = image_generator.flow(X_train, **flow_kwargs)
train_mask_gen = mask_generator.flow(Y_train, **flow_kwargs)

val_img_gen = val_image_generator.flow(X_val, **flow_kwargs)
val_mask_gen = val_mask_generator.flow(Y_val, **flow_kwargs)

In [0]:
# fit_generator consumes (inputs, targets) tuples. Zipping each image iterator
# with its mask iterator yields exactly that, one batch at a time.
val_gen = zip(val_img_gen, val_mask_gen)
train_gen = zip(train_img_gen, train_mask_gen)

In [0]:
def mean_iou(y_true, y_pred):
    """Keras metric: IoU averaged over several binarisation thresholds.

    For every threshold t in np.arange(0.5, 1.0, 0.05) the prediction is
    binarised as (y_pred > t) and scored with TF's ``tf.metrics.mean_iou``
    using 2 classes. The per-threshold scores are then averaged.

    Args:
        y_true: ground-truth mask tensor supplied by Keras.
        y_pred: model output tensor supplied by Keras (the sigmoid output of
            the U-Net this metric is compiled with).

    Returns:
        Scalar tensor: the mean of the per-threshold IoU scores.

    NOTE(review): uses TF1-only APIs (tf.to_int32, tf.metrics.mean_iou,
    K.get_session). It will not run under TF2 eager execution.
    NOTE(review): tf.metrics.mean_iou is a streaming metric backed by local
    accumulator variables. They are initialised here, once, when this function
    is called to build the graph, and nothing visible resets them per epoch.
    The reported value therefore presumably accumulates over every batch seen
    so far. Confirm before trusting per-epoch numbers.
    """
    prec = []
    for t in np.arange(0.5, 1.0, 0.05):
        # Binarise the prediction at this threshold as integer class labels.
        y_pred_ = tf.to_int32(y_pred > t)
        # `score` reads the running IoU; `up_opt` updates the confusion-matrix accumulators.
        score, up_opt = tf.metrics.mean_iou(y_true, y_pred_, 2)
        # Initialise all local variables (including the accumulators just created).
        # This executes immediately, while the metric graph is being built.
        K.get_session().run(tf.local_variables_initializer())
        # Force the accumulator update to run before the score is read.
        with tf.control_dependencies([up_opt]):
            score = tf.identity(score)
        prec.append(score)
    # Average the per-threshold scores.
    return K.mean(K.stack(prec), axis=0)

In [0]:
def _double_conv(tensor, filters, kernel=(3, 3), dropout=None):
    """Two same-padded ReLU convolutions, with optional dropout between them."""
    tensor = Conv2D(filters, kernel, activation='relu', padding='same')(tensor)
    if dropout is not None:
        tensor = Dropout(dropout)(tensor)
    return Conv2D(filters, kernel, activation='relu', padding='same')(tensor)


# Layers are created in the same order as a hand-unrolled U-Net, so the
# auto-generated layer names and the weight order stay the same.
input_img = Input((im_height, im_width, im_chan), name='img')
x = BatchNormalization()(input_img)

# Contracting path: 4 -> 64 filters, halving the spatial size after each stage.
skips = []
for filters in (4, 8, 16, 32, 64):
    x = _double_conv(x, filters)
    skips.append(x)
    x = MaxPooling2D((2, 2))(x)

# Bottleneck: 1x1 convolutions with dropout.
x = _double_conv(x, 128, kernel=(1, 1), dropout=0.2)

# Expanding path: upsample, join the matching encoder output, then convolve.
for filters, skip in zip((64, 32, 16, 8, 4), reversed(skips)):
    x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
    x = concatenate([x, skip])
    x = _double_conv(x, filters, dropout=0.2)

# Per-pixel probability of the foreground class.
outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

model = Model(inputs=[input_img], outputs=[outputs])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[mean_iou])

model.summary()

In [0]:
# Stop when validation IoU has not improved for 10 epochs.
early_stopper = EarlyStopping(monitor='val_mean_iou', patience=10, mode='max')
# Halve the learning rate when the validation loss plateaus. A loss improves by
# going down, so mode must be 'min'. With mode='max', every decrease in val_loss
# counted as "no improvement" and the LR was cut while the model was improving.
lr_reducer = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6, mode='min')
# Keep only the weights with the best validation IoU.
checkpointer = ModelCheckpoint('unet_model.h5', monitor='val_mean_iou', verbose=2,
                               save_best_only=True, mode='max')

# Validate on the whole held-out split each epoch. A fixed 25 batches covered
# only part of it, and a different part every epoch because the shuffled
# iterator keeps advancing. That made the checkpoint and early-stopping
# criterion noisy.
val_steps = int(np.ceil(len(X_val) / float(val_img_gen.batch_size)))

# NOTE(review): 2525 is a fixed batch count, not derived from len(X_train);
# confirm this is the intended epoch length.
results = model.fit_generator(train_gen, steps_per_epoch=2525, epochs=50,
                              validation_data=val_gen, validation_steps=val_steps,
                              callbacks=[checkpointer, lr_reducer, early_stopper], verbose=1)

In [0]:
# Load and resize the test images. Each image's original (height, width) is
# recorded in sizes_test so predictions can be resized back afterwards.
# (The unused `path = path_test` line was removed: path_test is never defined,
# so it raised NameError on the first iteration.)
X_test = np.zeros((len(test_ids), im_height, im_width, im_chan), dtype=np.uint8)
sizes_test = []
print('Getting and resizing test images ... ')
sys.stdout.flush()
for n, id_ in enumerate(test_ids):
    x = img_to_array(load_img('test/' + id_))
    sizes_test.append([x.shape[0], x.shape[1]])
    X_test[n] = resize(x, (im_height, im_width, im_chan), mode='constant', preserve_range=True)

print('Done!')

In [0]:
# Restore the best-scoring checkpoint written during training.
model.load_weights(filepath='unet_model.h5')

In [0]:
# Loss and mean IoU of the restored weights on the held-out split.
# (The split is named X_val / Y_val; X_valid / y_valid were never defined.)
model.evaluate(X_val, Y_val, verbose=1)

In [0]:
# Sigmoid probabilities for every split (the validation split is X_val).
preds_train = model.predict(X_train, verbose=1)
preds_val = model.predict(X_val, verbose=1)
preds_test = model.predict(X_test, verbose=1)

# Binary masks: pixels with probability above 0.5 are marked 1.
preds_train_t = (preds_train > 0.5).astype(np.uint8)
preds_val_t = (preds_val > 0.5).astype(np.uint8)
preds_test_t = (preds_test > 0.5).astype(np.uint8)

In [0]:
# Resize each test prediction back to the size of its source image, as
# recorded in sizes_test. (tnrange was never imported; a plain zip suffices.)
preds_test_upsampled = [
    resize(np.squeeze(pred), (height, width), mode='constant', preserve_range=True)
    for pred, (height, width) in zip(preds_test, sizes_test)
]

In [0]:
def plot_sample(X, y, preds, ix=None, cmap='gray'):
    """Plot one sample: every input channel, its true mask and the prediction.

    Each input-channel panel and the prediction panel are overlaid with the
    true-mask contour when the mask is non-empty.

    Args:
        X: image array indexed as X[ix, ..., channel] (channels last).
        y: ground-truth masks, one per image.
        preds: predicted probabilities, one per image (drawn on a 0-1 scale).
        ix: index of the sample to plot; a random valid index is drawn when None.
        cmap: colormap used for the input-channel panels.
    """
    if ix is None:
        # random.randint's upper bound is inclusive, so len(X) itself would be
        # out of range; the last valid index is len(X) - 1.
        ix = random.randint(0, len(X) - 1)

    mask = y[ix].squeeze()
    has_mask = mask.max() > 0
    # One panel per input channel instead of a hard-coded two: indexing
    # channel 1 raised IndexError on the single-channel data used here.
    n_channels = X.shape[-1]
    n_panels = n_channels + 2

    fig, ax = plt.subplots(1, n_panels, figsize=(5 * n_panels, 10))

    for ch in range(n_channels):
        ax[ch].imshow(X[ix, ..., ch], cmap=cmap)
        if has_mask:
            ax[ch].contour(mask, colors='k', levels=[0.5])
        ax[ch].set_title('Image' if n_channels == 1 else 'Image (channel {})'.format(ch))

    ax[n_channels].imshow(mask)
    ax[n_channels].set_title('Nerve mask')

    ax[n_channels + 1].imshow(preds[ix].squeeze(), vmin=0, vmax=1)
    if has_mask:
        ax[n_channels + 1].contour(mask, colors='k', levels=[0.5])
    ax[n_channels + 1].set_title('Nerve prediction');

In [0]:
# Random training sample vs. its prediction (the training masks are Y_train).
plot_sample(X_train, Y_train, preds_train)

In [0]:
# Random validation sample vs. its prediction (the validation split is X_val / Y_val).
plot_sample(X_val, Y_val, preds_val)