In [1]:
import numpy as np 
import pandas as pd 
import skimage, os
import SimpleITK as sitk
from scipy import ndimage
import matplotlib.pyplot as plt
import os
import glob
import zarr
from sklearn.utils import shuffle
import time

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import tensorflow as tf
from keras import backend as K
K.set_image_dim_ordering('th') 

from keras.models import Sequential,load_model,Model
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from keras.layers import Input, merge, UpSampling2D
from keras.optimizers import Adam, SGD, RMSprop, Nadam
from keras.preprocessing.image import ImageDataGenerator

from keras.layers.convolutional import Convolution3D, MaxPooling3D, UpSampling3D
from keras.layers import BatchNormalization
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.layers.core import SpatialDropout3D
from keras.models import load_model

import warnings
warnings.filterwarnings('ignore')

from utils_3d import *
from paths import *

def unet_model():
    """Build and compile the 3D U-Net segmentation model.

    Reads the module-level globals ``max_slices``, ``img_size``, ``width`` and
    ``dropout_rate`` at call time, so they must be assigned before calling.
    The input layer is channels-first: ``(1, max_slices, img_size, img_size)``.

    Layout:
      * encoder: three stages of two (conv -> BatchNorm) layers followed by a
        2x2x2 max-pool; filters width->2w, 2w->4w, 4w->8w;
      * bottleneck: three (conv -> BatchNorm) layers with 8w, 8w, 16w filters;
      * decoder: three stages of 2x upsampling, concatenation with the matching
        encoder output on the channel axis, SpatialDropout3D and two convs
        (8w, 4w, 2w filters) -- no BatchNorm on this side;
      * head: 1x1x1 convolution with a sigmoid activation.

    NOTE(review): three poolings followed by upsample+concat presumably need
    ``max_slices`` and ``img_size`` divisible by 8 for the skip shapes to
    line up -- confirm before changing them.

    Returns:
        A Keras ``Model`` compiled with ``Adam(lr=1e-5)``, ``dice_coef_loss``
        as the loss and ``dice_coef`` as a metric (both presumably provided by
        the ``utils_3d`` star-import).
    """
    def conv(filters):
        # 3x3x3, 'same'-padded ReLU convolution (Keras 1 positional kernel dims).
        return Convolution3D(filters, 3, 3, 3, activation='relu', border_mode='same')

    def conv_bn(tensor, filters):
        # Convolution, then BatchNorm over the channel axis (axis 1, 'th' order).
        tensor = conv(filters)(tensor)
        return BatchNormalization(axis=1)(tensor)

    def pool(tensor):
        return MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2), border_mode='same')(tensor)

    def up_block(tensor, skip, filters):
        # Upsample, join with the encoder skip on the channel axis, drop whole
        # feature maps, then two plain convolutions.
        upsampled = UpSampling3D(size=(2, 2, 2))(tensor)
        tensor = merge([upsampled, skip], mode='concat', concat_axis=1)
        tensor = SpatialDropout3D(dropout_rate)(tensor)
        tensor = conv(filters)(tensor)
        return conv(filters)(tensor)

    inputs = Input(shape=(1, max_slices, img_size, img_size))

    # Encoder: remember each stage's output (post-BatchNorm) for the decoder.
    skips = []
    x = inputs
    for first, second in ((width, width * 2), (width * 2, width * 4), (width * 4, width * 8)):
        x = conv_bn(x, first)
        x = conv_bn(x, second)
        skips.append(x)
        x = pool(x)

    # Bottleneck.
    for filters in (width * 8, width * 8, width * 16):
        x = conv_bn(x, filters)

    # Decoder: skips are consumed deepest-first.
    for filters in (width * 8, width * 4, width * 2):
        x = up_block(x, skips.pop(), filters)

    outputs = Convolution3D(1, 1, 1, 1, activation='sigmoid')(x)

    model = Model(input=inputs, output=outputs)
    model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])
    return model


def _paired_mask_paths(start, end):
    """Return the sorted lung-mask and nodule-mask .npy paths in ``[start:end]``.

    Files under ``src + 'lung_mask/'`` and ``src + 'nodule_mask/'`` are paired
    by their position in sorted order.
    NOTE(review): this assumes both directories list the same cases under
    matching file names -- only equal counts are checked here.

    Raises:
        ValueError: if the slice selects no lung files (the infinite generators
            would otherwise loop forever without yielding) or if the two
            lists differ in length.
    """
    lungs = sorted(glob.glob(src + 'lung_mask/*.npy'))[start:end]
    nods = sorted(glob.glob(src + 'nodule_mask/*.npy'))[start:end]
    if not lungs:
        raise ValueError('no lung_mask .npy files in slice [{}:{}] under {!r}'
                         .format(start, end, src))
    if len(lungs) != len(nods):
        raise ValueError('{} lung masks but {} nodule masks in slice [{}:{}]'
                         .format(len(lungs), len(nods), start, end))
    return lungs, nods


def _fit_depth(volume, size_3d, size, fill_value):
    """Centre ``volume`` along axis 1 inside a ``(1, 1, size_3d, size, size)`` array.

    ``volume`` is indexed as ``(C, depth, H, W)`` and must broadcast into
    ``(1, depth, size, size)``.
    - If ``depth < size_3d``, the volume is padded with ``fill_value``.
    - If ``depth > size_3d``, the central ``size_3d`` slices are kept.
    The start offset is ``int(np.round(diff / 2))``. NumPy rounds halves to
    even, so an odd difference of 1 puts all padding/cropping at the end.
    """
    out = np.full((1, 1, size_3d, size, size), fill_value, dtype='float32')
    num_slices = volume.shape[1]
    offset = size_3d - num_slices
    if offset >= 0:
        begin = int(np.round(offset / 2))
        out[0, :, begin:begin + num_slices] = volume
    else:
        begin = int(np.round(-offset / 2))
        out[0] = volume[:, begin:begin + size_3d]
    return out


def _load_pair(lung_path, nod_path, size_3d, size):
    """Load one lung/nodule mask pair and fit both to ``size_3d`` slices.

    Each file is cast to float32, and its first two axes are swapped before
    fitting. The lung volume is padded with 7. and the nodule volume with 0.
    These are the constants used by the original code; their meaning is not
    established in this file.
    A message is printed when slices are cropped away.

    Raises:
        ValueError: if the two volumes do not have the same shape.
    """
    lung = np.load(lung_path).astype('float32').swapaxes(1, 0)
    nod = np.load(nod_path).astype('float32').swapaxes(1, 0)
    if lung.shape != nod.shape:
        raise ValueError('shape mismatch: {} {} vs {} {}'
                         .format(lung_path, lung.shape, nod_path, nod.shape))
    offset = size_3d - lung.shape[1]
    if offset < 0:
        print('{} slices lost due to size restrictions'.format(offset))
    return (_fit_depth(lung, size_3d, size, 7.),
            _fit_depth(nod, size_3d, size, 0.))


def generate_train(start, end, seed=None, size_3d=128, size=128):
    """Yield ``(lung, nodule)`` training pairs forever, reshuffled every pass.

    Args:
        start, end: slice bounds into the sorted mask file lists.
        seed: optional seed for the per-pass shuffle. ``None`` keeps the
            previous behaviour of drawing from NumPy's global RNG.
        size_3d: number of slices every yielded volume is padded/cropped to.
        size: in-plane size of the output arrays.
            NOTE(review): this must equal the stored H/W.

    Yields:
        Two float32 arrays of shape ``(1, 1, size_3d, size, size)``, i.e. one
        sample per step.

    Raises:
        ValueError: on the first ``next()`` if the slice is empty or the
            lung/nodule lists (or volume shapes) do not match.
    """
    lungs, nods = _paired_mask_paths(start, end)
    # One RNG for the generator's lifetime so successive passes differ yet are
    # reproducible for a given seed.
    rng = None if seed is None else np.random.RandomState(seed)
    while True:
        print('Shuffling data')
        lungs, nods = shuffle(lungs, nods, random_state=rng)
        for lung_path, nod_path in zip(lungs, nods):
            yield _load_pair(lung_path, nod_path, size_3d, size)


def generate_val(start, end, seed=None, size_3d=128, size=128):
    """Yield ``(lung, nodule)`` validation pairs forever, in sorted file order.

    Takes the same arguments as ``generate_train``. ``seed`` is accepted only
    for signature symmetry and is unused, because validation data is never
    shuffled.

    Yields:
        Two float32 arrays of shape ``(1, 1, size_3d, size, size)``.

    Raises:
        ValueError: on the first ``next()`` if the slice is empty or the
            lung/nodule lists (or volume shapes) do not match.
    """
    lungs, nods = _paired_mask_paths(start, end)
    while True:
        for lung_path, nod_path in zip(lungs, nods):
            yield _load_pair(lung_path, nod_path, size_3d, size)


def unet_fit(name, start_t, end_t, start_v, end_v, check_name=None,
             model_dir='/Volumes/solo/ali/Data/model', nb_epoch=150, patience=15,
             samples_per_epoch=None, nb_val_samples=None):
    """Train the 3D U-Net on a slice of the mask files, checkpointing the best model.

    Args:
        name: basename of the checkpoint ``<model_dir>/<name>.h5``. Only the
            model with the best ``val_loss`` so far is saved.
        start_t, end_t: slice bounds passed to ``generate_train``.
        start_v, end_v: slice bounds passed to ``generate_val``.
        check_name: if given, resume from ``<model_dir>/<check_name>.h5``
            instead of building a fresh model with ``unet_model()``.
        model_dir: checkpoint directory. The default is the original
            hard-coded path.
        nb_epoch: maximum number of epochs; ``EarlyStopping`` may stop sooner.
        patience: epochs without ``val_loss`` improvement before stopping.
        samples_per_epoch, nb_val_samples: samples drawn per training epoch
            and per validation run.
            - They default to ``end_t - start_t`` and ``end_v - start_v``.
              The generators yield one file pair per step, so this is one pass
              over each selected range. The previous hard-coded 551/50 were
              only right for the ``(0, 551, 551, 601)`` call.
            - Pass them explicitly when using negative or out-of-range bounds.

    Returns:
        Whatever ``model.fit_generator`` returns (a Keras ``History``).

    Raises:
        ValueError: if the (derived) sample counts are not positive.
    """
    if samples_per_epoch is None:
        samples_per_epoch = end_t - start_t
    if nb_val_samples is None:
        nb_val_samples = end_v - start_v
    if samples_per_epoch <= 0 or nb_val_samples <= 0:
        raise ValueError('non-positive sample counts ({}, {}); pass samples_per_epoch '
                         'and nb_val_samples explicitly'.format(samples_per_epoch, nb_val_samples))

    callbacks = [
        EarlyStopping(monitor='val_loss', patience=patience, verbose=1),
        ModelCheckpoint(os.path.join(model_dir, '{}.h5'.format(name)),
                        monitor='val_loss', verbose=0, save_best_only=True),
    ]

    if check_name is not None:
        # The custom loss/metric must be supplied to deserialize the compiled model.
        model = load_model(os.path.join(model_dir, '{}.h5'.format(check_name)),
                           custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef})
    else:
        model = unet_model()

    history = model.fit_generator(generate_train(start_t, end_t), nb_epoch=nb_epoch, verbose=1,
                                  validation_data=generate_val(start_v, end_v),
                                  callbacks=callbacks,
                                  samples_per_epoch=samples_per_epoch,
                                  nb_val_samples=nb_val_samples)
    return history

Using TensorFlow backend.


In [2]:
# Module-level settings: unet_model() reads max_slices, img_size, width and
# dropout_rate; the data generators read src. mask_train presumably comes from
# the `paths` star-import.
src = mask_train
max_slices = 128
img_size = 128
dropout_rate = 0.5
width = 8

img_rows = img_cols = img_size

# Earlier run: unet_fit('3DUNet_genfulldata_patients_merged', 0, 551, 551, 601)
# NOTE(review): the '_cont' suffix suggests resuming from the run above, but
# check_name=None builds a fresh model -- confirm this is intended.
# Train on files [0, 551), validate on [551, 601).
unet_fit('3DUNet_genfulldata_patients_merged_cont',
         start_t=0, end_t=551, start_v=551, end_v=601,
         check_name=None)

Epoch 1/150
Shuffling data
-43 slices lost due to size restrictions
Epoch 2/150
 65/551 [==>...........................] - ETA: 673s - loss: -0.0175 - dice_coef: 0.0175-15 slices lost due to size restrictions
105/551 [====>.........................] - ETA: 618s - loss: -0.0203 - dice_coef: 0.0203-33 slices lost due to size restrictions
110/551 [====>.........................] - ETA: 611s - loss: -0.0195 - dice_coef: 0.0195-79 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 3/150
-43 slices lost due to size restrictions
Epoch 4/150
120/551 [=====>........................] - ETA: 607s - loss: -0.1082 - dice_coef: 0.1082-73 slices lost due to size restrictions
-43 slices lost due to size restrictions
-38 slices lost due to size restrictions
Epoch 5/150
 24/551 [>.............................] - ETA: 738s - loss: -0.2464 - dice_coef: 0.2464-33 slices lost due to size restrictions
-38 slices lost due to size restrictions
-43 slices lost due to size restri

-43 slices lost due to size restrictions
-38 slices lost due to size restrictions
Epoch 10/150
 69/551 [==>...........................] - ETA: 698s - loss: -0.1349 - dice_coef: 0.1349-79 slices lost due to size restrictions
-38 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 11/150
 23/551 [>.............................] - ETA: 769s - loss: -0.3150 - dice_coef: 0.3150-15 slices lost due to size restrictions
 89/551 [===>..........................] - ETA: 673s - loss: -0.2307 - dice_coef: 0.2307-33 slices lost due to size restrictions
126/551 [=====>........................] - ETA: 624s - loss: -0.2203 - dice_coef: 0.2203-73 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 12/150
 99/551 [====>.........................] - ETA: 665s - loss: -0.2482 - dice_coef: 0.2482-79 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 13/150
109/551 [====>.........................] - ETA: 647s - lo

-43 slices lost due to size restrictions
Epoch 19/150
 95/551 [====>.........................] - ETA: 628s - loss: -0.2798 - dice_coef: 0.2798-73 slices lost due to size restrictions
-43 slices lost due to size restrictions
-38 slices lost due to size restrictions
Epoch 20/150
 69/551 [==>...........................] - ETA: 665s - loss: -0.2393 - dice_coef: 0.2393-33 slices lost due to size restrictions
 70/551 [==>...........................] - ETA: 664s - loss: -0.2359 - dice_coef: 0.2359-79 slices lost due to size restrictions
-38 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 21/150
 78/551 [===>..........................] - ETA: 657s - loss: -0.3537 - dice_coef: 0.3537-33 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 22/150
 75/551 [===>..........................] - ETA: 657s - loss: -0.3662 - dice_coef: 0.3662-15 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 23/150
 34

Epoch 27/150
116/551 [=====>........................] - ETA: 600s - loss: -0.4653 - dice_coef: 0.4653-15 slices lost due to size restrictions
119/551 [=====>........................] - ETA: 596s - loss: -0.4536 - dice_coef: 0.4536-33 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 28/150
 16/551 [..............................] - ETA: 738s - loss: -0.2854 - dice_coef: 0.2854-73 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 29/150
-43 slices lost due to size restrictions
-38 slices lost due to size restrictions
Epoch 30/150
102/551 [====>.........................] - ETA: 619s - loss: -0.3990 - dice_coef: 0.3990-79 slices lost due to size restrictions
-38 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 31/150
-43 slices lost due to size restrictions
Epoch 32/150
-43 slices lost due to size restrictions
Epoch 33/150
107/551 [====>.........................] - ETA: 612s - loss: -0.4

-38 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 36/150
-43 slices lost due to size restrictions
Epoch 37/150
 90/551 [===>..........................] - ETA: 635s - loss: -0.3182 - dice_coef: 0.3182-15 slices lost due to size restrictions
 95/551 [====>.........................] - ETA: 628s - loss: -0.3336 - dice_coef: 0.3336-79 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 38/150
 71/551 [==>...........................] - ETA: 661s - loss: -0.6239 - dice_coef: 0.6239-79 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 39/150
-43 slices lost due to size restrictions
-38 slices lost due to size restrictions
Epoch 40/150
-33 slices lost due to size restrictions
124/551 [=====>........................] - ETA: 588s - loss: -0.4743 - dice_coef: 0.4743-73 slices lost due to size restrictions
-38 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 41/

-43 slices lost due to size restrictions
-38 slices lost due to size restrictions
Epoch 45/150
-38 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 46/150
 77/551 [===>..........................] - ETA: 653s - loss: -0.6445 - dice_coef: 0.6445-79 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 47/150
 40/551 [=>............................] - ETA: 704s - loss: -0.6224 - dice_coef: 0.6224-79 slices lost due to size restrictions
128/551 [=====>........................] - ETA: 583s - loss: -0.6809 - dice_coef: 0.6809-73 slices lost due to size restrictions
-43 slices lost due to size restrictions
Epoch 00046: early stopping
