In [1]:
import datetime
import numpy as np
import os 

from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as keras
from keras.preprocessing.image import ImageDataGenerator

Using TensorFlow backend.


In [21]:
ultrasound_fullname = r"j:\Data\ProstateCatheter\Stacked Arrays\Training\stacked_image_array.npy"
segmentation_fullname = r"j:\Data\ProstateCatheter\Stacked Arrays\Training\stacked_segmentation_array.npy"

export_folder = r"j:\Temp"

ultrasound_data = np.load(ultrasound_fullname)
segmentation_data = np.load(segmentation_fullname)

num_ultrasound = ultrasound_data.shape[0]
num_segmentation = segmentation_data.shape[0]

print("\nFound {} ultrasound images and {} segmentations".format(num_ultrasound, num_segmentation))


Found 24 ultrasound images and 24 segmentations


In [3]:
test_ultrasound_fullname = r"j:\Data\ProstateCatheter\Stacked Arrays\Test\test_image_array.npy"
test_segmentation_fullname = r"j:\Data\ProstateCatheter\Stacked Arrays\Test\test_segmentation_array.npy"

print("Reading test ultrasound from: {}".format(test_ultrasound_fullname))
print("Reading test segmentation from : {}".format(test_segmentation_fullname))

test_ultrasound_data = np.load(test_ultrasound_fullname)
test_segmentation_data = np.load(test_segmentation_fullname)

num_test_ultrasound = test_ultrasound_data.shape[0]
num_test_segmentation = test_segmentation_data.shape[0]

print("\nFound {} test ultrasound images and {} segmentations".format(num_test_ultrasound, num_test_segmentation))

Reading test ultrasound from: j:\Data\ProstateCatheter\Stacked Arrays\Test\test_image_array.npy
Reading test segmentation from : j:\Data\ProstateCatheter\Stacked Arrays\Test\test_segmentation_array.npy

Found 6 test ultrasound images and 6 segmentations


In [4]:
# Batch Generator

import keras.utils
import scipy.ndimage

max_rotation_angle = 10
max_shift = 0.2
max_zoom = 0.2

class UltrasoundSegmentationBatchGenerator(keras.utils.Sequence):
    
    def __init__(self,
                 x_set,
                 y_set,
                 batch_size,
                 image_dimensions=(128, 128, 128),
                 shuffle=True,
                 n_channels=1,
                 n_classes=2):
        self.x = x_set
        self.y = y_set
        self.batch_size = batch_size
        self.image_dimensions = image_dimensions
        self.shuffle = shuffle
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.number_of_images = self.x.shape[0]
        self.indexes = np.arange(self.number_of_images)
        if self.shuffle == True:
            np.random.shuffle(self.indexes)
            
    def __len__(self):
        return int(np.floor(self.number_of_images / self.batch_size))
    
    def on_epoch_end(self):
        self.indexes = np.arange(self.number_of_images)
        if self.shuffle == True:
            np.random.shuffle(self.indexes)
    
    def __getitem__(self, index):
        batch_indexes = self.indexes[index*self.batch_size : (index+1)*self.batch_size]
        x = np.empty((self.batch_size, *self.image_dimensions, self.n_channels))
        y = np.empty((self.batch_size, *self.image_dimensions))
        
        for i in range(self.batch_size):
            flip_flag = np.random.randint(2)
            x[i,:,:,:,:] = np.flip(self.x[batch_indexes[i],:,:,:,:])
            y[i,:,:,:]= np.flip(self.y[batch_indexes[i],:,:,:])
            
        angle_x = np.random.randint(-max_rotation_angle, max_rotation_angle)
        # rotate x-axis
        x_rot = scipy.ndimage.interpolation.rotate(x, angle_x, (1,2), False, mode="constant", cval=0, order=0)
        y_rot = scipy.ndimage.interpolation.rotate(y, angle_x, (1,2), False, mode="constant", cval=0, order=0)
        
        angle_y = np.random.randint(-max_rotation_angle, max_rotation_angle)
        #rotate y-axis
        x_rot = scipy.ndimage.interpolation.rotate(x, angle_y, (0,2), False, mode="constant", cval=0, order=0)
        y_rot = scipy.ndimage.interpolation.rotate(y, angle_y, (0,2), False, mode="constant", cval=0, order=0)
        
        angle_z = np.random.randint(-max_rotation_angle, max_rotation_angle)
        #rotate z-axis
        x_rot = scipy.ndimage.interpolation.rotate(x, angle_z, (0,1), False, mode="constant", cval=0, order=0)
        y_rot = scipy.ndimage.interpolation.rotate(y, angle_z, (0,1), False, mode="constant", cval=0, order=0)

        
        #shift = np.random.uniform(-max_shift, max_shift)
        #x_shift = scipy.ndimage.interpolation.shift(x_rot, shift)
        #y_shift = scipy.ndimage.interpolation.shift(y_rot, shift)
        
        # zoom = np.random.uniform(-max_zoom, max_zoom)
        # x_zoom = scipy.ndimage.interpolation.zoom(x_shift, zoom)
        # y_zoom = scipy.ndimage.interpolation.zoom(y_shift, zoom)
    
        x_aug = np.clip(x_rot, 0.0, 1.0)
        y_aug = np.clip(y_rot, 0.0, 1.0)
        
        y_onehot = keras.utils.to_categorical(y_aug, self.n_classes)

        return x_aug, y_onehot
        

In [5]:
# Prepare dilated output

def dilateStack(segmentation_data, iterations):
    return np.array([scipy.ndimage.binary_dilation(y, iterations=iterations) for y in segmentation_data])

width = 2
segmentation_dilated = dilateStack(segmentation_data[:, :, :, :, 0], width)

In [6]:
# Uncomment this if you don't want dilation

segmentation_dilated[:, :, :, :] = segmentation_data[:, :, :, :, 0]

In [7]:
num_classes = 2

def nvidia_unet(patch_size=128, num_classes=num_classes):
    input_ = Input((128, 128, 128, 1))
    skips = []
    output = input_
    c = num_classes
    
    for shape, filters in zip([5, 3, 3, 3, 3, 3, 3], [8, 16, 32, 32, 32, 32, 32]):
        skips.append(output)
        #print("pre_skip")
        #print(output)
        #print(shape)
        output= Conv3D(filters, (3, 3, 3), strides=2, padding="same", activation="relu")(output)
        #print("output3d")
        #print(output)
    
    # output = keras.layers.UpSampling3D(size=(1, 2, 2))(output)
    for shape, filters in zip([4, 4, 4, 4, 4, 4, 4], [32, 32, 32, 32, 16, 8, 2]):
        #print(output.shape)
        output = keras.layers.UpSampling3D()(output)
        #print("output2.0:")
        #print(output)
        skip_output = skips.pop()
        output = concatenate([output, skip_output], axis=4)

        if filters != c:
            activation = "relu"
        else:
            activation = "softmax"
        output = Conv3D(filters, (3, 3, 3), activation=activation, padding="same")(output)
        if filters != c:
            output = BatchNormalization(momentum=.9)(output)
        
        print(output)
    
    assert len(skips) == 0
    return Model([input_], [output])

model = nvidia_unet(128, num_classes)

# model.summary()

Instructions for updating:
Colocations handled automatically by placer.
Tensor("batch_normalization_1/cond/Merge:0", shape=(?, 2, 2, 2, 32), dtype=float32)
Tensor("batch_normalization_2/cond/Merge:0", shape=(?, 4, 4, 4, 32), dtype=float32)
Tensor("batch_normalization_3/cond/Merge:0", shape=(?, 8, 8, 8, 32), dtype=float32)
Tensor("batch_normalization_4/cond/Merge:0", shape=(?, 16, 16, 16, 32), dtype=float32)
Tensor("batch_normalization_5/cond/Merge:0", shape=(?, 32, 32, 32, 16), dtype=float32)
Tensor("batch_normalization_6/cond/Merge:0", shape=(?, 64, 64, 64, 8), dtype=float32)
Tensor("conv3d_14/truediv:0", shape=(?, 128, 128, 128, 2), dtype=float32)


In [8]:
print("Model built with {} parameters".format(model.count_params()))

Model built with 376624 parameters


In [9]:
max_learning_rate = 0.01
min_learning_rate = 0.00001
num_epochs = 250

learning_rate_decay = (max_learning_rate - min_learning_rate) / num_epochs

model.compile(optimizer=keras.optimizers.adam(lr=max_learning_rate, decay=learning_rate_decay),
               loss= "binary_crossentropy",
               metrics=["accuracy"])

print("Learning rate decay = {}".format(learning_rate_decay))

Learning rate decay = 3.9960000000000004e-05


In [10]:
batch_size = 3

training_generator = UltrasoundSegmentationBatchGenerator(ultrasound_data, segmentation_data[:, :, :, :, 0], batch_size)
test_generator = UltrasoundSegmentationBatchGenerator(test_ultrasound_data, test_segmentation_data[:, :, :, :, 0], batch_size)

training_time_start = datetime.datetime.now()

training_log = model.fit_generator(training_generator,
                                   validation_data=test_generator,
                                   epochs=num_epochs,
                                   verbose=1)

Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
Epoch 1/250
Epoch 2/250
Epoch 3/250
Epoch 4/250
Epoch 5/250
Epoch 6/250
Epoch 7/250
Epoch 8/250
Epoch 9/250
Epoch 10/250
Epoch 11/250
Epoch 12/250
Epoch 13/250
Epoch 14/250
Epoch 15/250
Epoch 16/250
Epoch 17/250
Epoch 18/250
Epoch 19/250
Epoch 20/250
Epoch 21/250
Epoch 22/250
Epoch 23/250
Epoch 24/250
Epoch 25/250
Epoch 26/250
Epoch 27/250
Epoch 28/250
Epoch 29/250
Epoch 30/250
Epoch 31/250
Epoch 32/250
Epoch 33/250
Epoch 34/250
Epoch 35/250
Epoch 36/250
Epoch 37/250
Epoch 38/250
Epoch 39/250
Epoch 40/250
Epoch 41/250
Epoch 42/250
Epoch 43/250
Epoch 44/250
Epoch 45/250
Epoch 46/250
Epoch 47/250
Epoch 48/250
Epoch 49/250
Epoch 50/250
Epoch 51/250
Epoch 52/250
Epoch 53/250
Epoch 54/250
Epoch 55/250
Epoch 56/250
Epoch 57/250
Epoch 58/250
Epoch 59/250
Epoch 60/250


Epoch 61/250
Epoch 62/250
Epoch 63/250
Epoch 64/250
Epoch 65/250
Epoch 66/250
Epoch 67/250
Epoch 68/250
Epoch 69/250
Epoch 70/250
Epoch 71/250
Epoch 72/250
Epoch 73/250
Epoch 74/250
Epoch 75/250
Epoch 76/250
Epoch 77/250
Epoch 78/250
Epoch 79/250
Epoch 80/250
Epoch 81/250
Epoch 82/250
Epoch 83/250
Epoch 84/250
Epoch 85/250
Epoch 86/250
Epoch 87/250
Epoch 88/250
Epoch 89/250
Epoch 90/250
Epoch 91/250
Epoch 92/250
Epoch 93/250
Epoch 94/250
Epoch 95/250
Epoch 96/250
Epoch 97/250
Epoch 98/250
Epoch 99/250
Epoch 100/250
Epoch 101/250
Epoch 102/250
Epoch 103/250
Epoch 104/250
Epoch 105/250
Epoch 106/250
Epoch 107/250
Epoch 108/250
Epoch 109/250
Epoch 110/250
Epoch 111/250
Epoch 112/250
Epoch 113/250
Epoch 114/250
Epoch 115/250
Epoch 116/250
Epoch 117/250
Epoch 118/250
Epoch 119/250
Epoch 120/250
Epoch 121/250
Epoch 122/250


Epoch 123/250
Epoch 124/250
Epoch 125/250
Epoch 126/250
Epoch 127/250
Epoch 128/250
Epoch 129/250
Epoch 130/250
Epoch 131/250
Epoch 132/250
Epoch 133/250
Epoch 134/250
Epoch 135/250
Epoch 136/250
Epoch 137/250
Epoch 138/250
Epoch 139/250
Epoch 140/250
Epoch 141/250
Epoch 142/250
Epoch 143/250
Epoch 144/250
Epoch 145/250
Epoch 146/250
Epoch 147/250
Epoch 148/250
Epoch 149/250
Epoch 150/250
Epoch 151/250
Epoch 152/250
Epoch 153/250
Epoch 154/250
Epoch 155/250
Epoch 156/250
Epoch 157/250
Epoch 158/250
Epoch 159/250
Epoch 160/250
Epoch 161/250
Epoch 162/250
Epoch 163/250
Epoch 164/250
Epoch 165/250
Epoch 166/250
Epoch 167/250
Epoch 168/250
Epoch 169/250
Epoch 170/250
Epoch 171/250
Epoch 172/250
Epoch 173/250
Epoch 174/250
Epoch 175/250
Epoch 176/250
Epoch 177/250
Epoch 178/250
Epoch 179/250
Epoch 180/250
Epoch 181/250
Epoch 182/250
Epoch 183/250
Epoch 184/250
Epoch 185/250


Epoch 186/250
Epoch 187/250
Epoch 188/250
Epoch 189/250
Epoch 190/250
Epoch 191/250
Epoch 192/250
Epoch 193/250
Epoch 194/250
Epoch 195/250
Epoch 196/250
Epoch 197/250
Epoch 198/250
Epoch 199/250
Epoch 200/250
Epoch 201/250
Epoch 202/250
Epoch 203/250
Epoch 204/250
Epoch 205/250
Epoch 206/250
Epoch 207/250
Epoch 208/250
Epoch 209/250
Epoch 210/250
Epoch 211/250
Epoch 212/250
Epoch 213/250
Epoch 214/250
Epoch 215/250
Epoch 216/250
Epoch 217/250
Epoch 218/250
Epoch 219/250
Epoch 220/250
Epoch 221/250
Epoch 222/250
Epoch 223/250
Epoch 224/250
Epoch 225/250
Epoch 226/250
Epoch 227/250
Epoch 228/250
Epoch 229/250
Epoch 230/250
Epoch 231/250
Epoch 232/250
Epoch 233/250
Epoch 234/250
Epoch 235/250
Epoch 236/250
Epoch 237/250
Epoch 238/250
Epoch 239/250
Epoch 240/250
Epoch 241/250
Epoch 242/250
Epoch 243/250
Epoch 244/250
Epoch 245/250
Epoch 246/250
Epoch 247/250


Epoch 248/250
Epoch 249/250
Epoch 250/250


In [11]:
training_time_stop = datetime.datetime.now()
print("Training started at: {}".format(training_time_start))
print("Training stopped at: {}".format(training_time_stop))
print("Total training time: {}".format(training_time_stop-training_time_start))

Training started at: 2019-07-26 00:27:53.292133
Training stopped at: 2019-07-26 02:09:31.485691
Total training time: 1:41:38.193558


In [12]:
import datetime

timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

saved_models_folder = r"j:\Data\SavedModels"
model_file_name = "model_" + timestamp + ".h5"
weights_file_path = os.path.join(saved_models_folder, model_file_name)

model.save(weights_file_path)
print("Model saved to: {}".format(weights_file_path))

Model saved to: j:\Data\SavedModels\model_2019-07-26_02-09-31.h5


In [13]:
trained_model = load_model(r"j:\Data\SavedModels\model_2019-07-26_00-23-49.h5")
# print(weights_file_path)

In [14]:
y_pred = trained_model.predict(test_ultrasound_data)

In [15]:
%matplotlib notebook

In [16]:
# Display training loss and accuracy curves over epochs

import matplotlib.pyplot as plt

plt.plot(training_log.history['loss'], 'bo--')
plt.plot(training_log.history['val_loss'], 'ro-')
plt.ylabel('Loss')
plt.xlabel('Epochs (n)')
plt.legend(['Training loss', 'Validation loss'])
plt.show()

plt.plot(training_log.history['acc'], 'bo--')
plt.plot(training_log.history['val_acc'], 'ro-')
plt.ylabel('Accuracy')
plt.xlabel('Epochs (n)')
plt.legend(['Training accuracy', 'Validation accuracy'])
plt.show()

<IPython.core.display.Javascript object>

In [17]:
# Multi-slice view code extracted and adapted from: https://www.datacamp.com/community/tutorials/matplotlib-3d-volumetric-data
import matplotlib.pyplot as plt

def multi_slice_viewer(volume):
    remove_keymap_conflicts({'j', 'k'})
    fig, ax = plt.subplots()
    ax.volume = volume
    ax.index = volume.shape[0] // 2
    ax.imshow(volume[ax.index])
    fig.canvas.mpl_connect('key_press_event', process_key)

def process_key(event):
    fig = event.canvas.figure
    ax = fig.axes[0]
    if event.key == 'j':
        previous_slice(ax)
    elif event.key == 'k':
        next_slice(ax)
    fig.canvas.draw()

def previous_slice(ax):
    volume = ax.volume
    ax.index = (ax.index - 1) % volume.shape[0]  # wrap around using %
    ax.images[0].set_array(volume[ax.index])

def next_slice(ax):
    volume = ax.volume
    ax.index = (ax.index + 1) % volume.shape[0]
    ax.images[0].set_array(volume[ax.index])

def remove_keymap_conflicts(new_keys_set):
    for prop in plt.rcParams:
        if prop.startswith('keymap.'):
            keys = plt.rcParams[prop]
            remove_list = set(keys) & new_keys_set
            for key in remove_list:
                keys.remove(key)

In [18]:
# ultrasound
ultrasound_img = test_ultrasound_data[0]
print(ultrasound_img.shape)
multi_slice_viewer(ultrasound_img[:, :, :, 0])

(128, 128, 128, 1)


<IPython.core.display.Javascript object>

In [19]:
# Segmentation
segmentation_img = test_segmentation_data[0]
print(segmentation_img.shape)
multi_slice_viewer(segmentation_img[:, :, :, 0])

(128, 128, 128, 1)


<IPython.core.display.Javascript object>

In [20]:
# Prediction
predicted_img = y_pred[0]
print(predicted_img.shape)
multi_slice_viewer(predicted_img[:, :, :, 1])

(128, 128, 128, 2)


<IPython.core.display.Javascript object>

In [23]:
# Exporting volumes

export_index = 0

ultrasound_fullname = os.path.join(export_folder, "ultraosund_" + str(export_index))
segmentation_fullname = os.path.join(export_folder, "segmentation_" + str(export_index))
prediction_fullname = os.path.join(export_folder, "prediction_" + str(export_index))

np.save(ultrasound_fullname, ultrasound_img[:, :, :, 0])
np.save(segmentation_fullname, segmentation_img[:, :, :, 0])
np.save(prediction_fullname, predicted_img[:, :, :, 1])

print("Sample ultrasound saved to:   {}".format(ultrasound_fullname))
print("Sample segmentation saved to: {}".format(segmentation_fullname))
print("Sample prediction saved to    {}".format(prediction_fullname))

Sample ultrasound saved to:   j:\Temp\ultraosund_0
Sample segmentation saved to: j:\Temp\segmentation_0
Sample prediction saved to    j:\Temp\prediction_0


In [27]:
predicted_img.dtype

dtype('float32')