In [1]:
from __future__ import print_function

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Model
from keras.layers import Input, merge, Conv2D, Convolution2D, MaxPooling2D, BatchNormalization, UpSampling2D, Conv2DTranspose , Dropout, Permute, Reshape, Activation, concatenate
from keras.optimizers import Adam, SGD
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras.utils.np_utils import to_categorical
from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator
from keras.utils.np_utils import to_categorical
from keras.constraints import max_norm
from keras import metrics
from keras.models import load_model

import os
import SimpleITK as sitk

Using TensorFlow backend.


In [4]:
# Keras layers in this notebook index tensors as (batch, rows, cols, channels).
K.set_image_data_format('channels_last')

img_rows, img_cols = 512, 512  # slice size expected by get_unet()
smooth = 0.01  # additive smoothing term in the Dice-coefficient metrics

In [6]:
def get_unet():
    """Build and compile the 2-D U-Net used for multi-class segmentation.

    Architecture:
      * input: one-channel ``(img_rows, img_cols, 1)`` slice;
      * encoder: four stages with 16, 32, 64 and 128 filters, each followed by
        2x2 max pooling;
      * bottleneck: one 256-filter stage;
      * decoder: four stages with 128, 64, 32 and 16 filters. Each stage first
        upsamples with a 2x2 / stride-2 ``Conv2DTranspose`` and concatenates the
        matching encoder output (skip connection);
      * head: a 1x1 linear convolution to 7 channels, reshaped to
        ``(img_rows*img_cols, 7)`` and passed through a softmax.

    The model is compiled with Adam(lr=1e-4, decay=1e-7), ``loss_function`` as
    the loss and ``dice_coef`` as the metric. Both names are looked up when this
    function runs, so the most recently executed definitions are used.

    Returns:
        The compiled ``keras.models.Model``.
    """

    def conv_stage(tensor, filters):
        # Conv(3x3) -> BN -> ReLU -> Dropout(0.4) -> Conv(3x3) -> BN -> ReLU.
        # Every conv uses "same" padding and a max-norm(4) kernel constraint.
        for repeat in range(2):
            tensor = Conv2D(filters, (3, 3), padding="same", kernel_constraint=max_norm(4.))(tensor)
            tensor = BatchNormalization(axis=3)(tensor)
            tensor = Activation("relu")(tensor)
            if repeat == 0:
                tensor = Dropout(0.4)(tensor)
        return tensor

    # Layers are created in a fixed encoder -> decoder order. Keras' default
    # load_weights() matches layers by topological order, so keeping this order
    # keeps saved weight files loadable.
    inputs = Input((img_rows, img_cols, 1))

    skips = []
    tensor = inputs
    for filters in (16, 32, 64, 128):
        tensor = conv_stage(tensor, filters)
        skips.append(tensor)
        tensor = MaxPooling2D(pool_size=(2, 2))(tensor)

    tensor = conv_stage(tensor, 256)  # bottleneck: no pooling afterwards

    for filters, skip in zip((128, 64, 32, 16), reversed(skips)):
        upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same")(tensor)
        tensor = concatenate([upsampled, skip], axis=3)
        tensor = conv_stage(tensor, filters)

    logits = Conv2D(7, (1, 1), activation='linear')(tensor)
    per_pixel = Reshape((img_rows * img_cols, 7))(logits)
    probabilities = Activation("softmax")(per_pixel)

    model = Model(inputs=[inputs], outputs=[probabilities])
    model.compile(optimizer=Adam(lr=1e-4, decay=1e-7), loss=loss_function, metrics=[dice_coef])
    return model

In [7]:
def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient computed over every element of both tensors.

    Dice = (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth), where both
    tensors are flattened first and ``smooth`` is the module-level constant.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return (2. * overlap + smooth) / denominator

In [8]:
def dice_coef_loss(y_true, y_pred):
    """Loss to minimise: the negated value of ``dice_coef`` for the same inputs."""
    score = dice_coef(y_true, y_pred)
    return -score

In [9]:
def loss_function(y_true, y_pred, class_weights=(1., 200., 200., 200., 200., 200., 200.)):
    """Class-weighted pixel-wise categorical cross-entropy.

    Each pixel's cross-entropy is multiplied by the weight of its true class,
    and the result is averaged over all pixels in the batch.

    Args:
        y_true: one-hot labels reshapeable to (-1, n_classes), as produced by
            ``to_categorical`` in the data generators.
        y_pred: softmax probabilities with the same layout as ``y_true``.
        class_weights: one weight per class. By default the background
            (channel 0) has weight 1 and the other six classes have weight 200.

    Returns:
        Scalar tensor: the mean weighted cross-entropy.
    """
    n_classes = len(class_weights)
    # -1 lets any batch size through. The previous fixed row count (4*512*512)
    # did not match the 8*512*512-row weight constants, so it failed for every batch.
    y_true_f = K.reshape(y_true, (-1, n_classes))
    y_pred_f = K.reshape(y_pred, (-1, n_classes))

    # A (n_classes,) vector broadcasts across all pixel rows, so no per-pixel
    # constant tensor needs to be allocated.
    weights = K.constant(class_weights)
    # Assuming y_true is one-hot, this selects the weight of each pixel's true class.
    weight_map = K.sum(y_true_f * weights, axis=1)

    # Pass keyword arguments: the positional order of K.categorical_crossentropy
    # changed between Keras 2 releases ((output, target) -> (target, output)).
    loss_map = K.categorical_crossentropy(target=y_true_f, output=y_pred_f, from_logits=False)
    return K.mean(loss_map * weight_map)

In [10]:
def loss_function(y_true, y_pred, class_weights=(0.98, 102., 54., 403., 134., 100., 54.)):
    """Class-weighted pixel-wise categorical cross-entropy (one weight per class).

    Each pixel's cross-entropy is multiplied by the weight of its true class,
    and the result is averaged over all pixels in the batch.

    Args:
        y_true: one-hot labels reshapeable to (-1, n_classes), as produced by
            ``to_categorical`` in the data generators.
        y_pred: softmax probabilities with the same layout as ``y_true``.
        class_weights: weights for channels 0..6. The defaults are the values
            this notebook trains with.

    Returns:
        Scalar tensor: the mean weighted cross-entropy.
    """
    n_classes = len(class_weights)
    # -1 instead of a hard-coded 8*512*512 row count, so any batch size works.
    y_true_f = K.reshape(y_true, (-1, n_classes))
    y_pred_f = K.reshape(y_pred, (-1, n_classes))

    # A broadcast (n_classes,) weight vector replaces seven concatenated
    # (8*512*512, 1) constant tensors. The numeric result is the same.
    weights = K.constant(class_weights)
    # Assuming y_true is one-hot, this selects the weight of each pixel's true class.
    weight_map = K.sum(y_true_f * weights, axis=1)

    # Pass keyword arguments: the positional order of K.categorical_crossentropy
    # changed between Keras 2 releases ((output, target) -> (target, output)).
    loss_map = K.categorical_crossentropy(target=y_true_f, output=y_pred_f, from_logits=False)
    return K.mean(loss_map * weight_map)

In [11]:
def dice_coef(y_true, y_pred):
    """Unweighted mean of the smoothed per-class Dice coefficients over 7 channels.

    This notebook names the channels 0 background, 1 spine, 2 hips, 3 sternum,
    4 ribs, 5 sacrum and 6 femur. For each channel c:
    (2 * sum(t_c * p_c) + smooth) / (sum(t_c) + sum(p_c) + smooth).

    Args:
        y_true: one-hot labels reshapeable to (-1, 7).
        y_pred: class probabilities reshapeable to (-1, 7).
    """
    n_classes = 7
    # -1 instead of a hard-coded 8*512*512 row count, so any batch size works.
    y_true = K.reshape(y_true, (-1, n_classes))
    y_pred = K.reshape(y_pred, (-1, n_classes))

    total = 0.
    for c in range(n_classes):
        t = y_true[:, c]
        p = y_pred[:, c]
        intersection = K.sum(t * p)
        total += (2. * intersection + smooth) / (K.sum(t) + K.sum(p) + smooth)
    return total / float(n_classes)


In [12]:
def dice_coef_back(y_true, y_pred):
    """Smoothed Dice coefficient for channel 0 (background).

    Args:
        y_true: one-hot labels reshapeable to (-1, 7).
        y_pred: class probabilities reshapeable to (-1, 7).
    """
    # -1 instead of a hard-coded 8*512*512 row count, so any batch size works.
    y_true = K.reshape(y_true, (-1, 7))
    y_pred = K.reshape(y_pred, (-1, 7))

    y_true_back = y_true[:, 0]
    y_pred_back = y_pred[:, 0]
    intersection_back = K.sum(y_true_back * y_pred_back)
    back = (2. * intersection_back + smooth) / (K.sum(y_true_back) + K.sum(y_pred_back) + smooth)
    return back

In [13]:
def dice_coef_spine(y_true, y_pred):
    """Smoothed Dice coefficient for channel 1 (spine).

    Args:
        y_true: one-hot labels reshapeable to (-1, 7).
        y_pred: class probabilities reshapeable to (-1, 7).
    """
    # -1 instead of a hard-coded 8*512*512 row count, so any batch size works.
    y_true = K.reshape(y_true, (-1, 7))
    y_pred = K.reshape(y_pred, (-1, 7))

    y_true_spine = y_true[:, 1]
    y_pred_spine = y_pred[:, 1]
    intersection_spine = K.sum(y_true_spine * y_pred_spine)
    spine = (2. * intersection_spine + smooth) / (K.sum(y_true_spine) + K.sum(y_pred_spine) + smooth)
    return spine

In [14]:
def dice_coef_hips(y_true, y_pred):
    """Smoothed Dice coefficient for channel 2 (hips).

    Args:
        y_true: one-hot labels reshapeable to (-1, 7).
        y_pred: class probabilities reshapeable to (-1, 7).
    """
    # -1 instead of a hard-coded 8*512*512 row count, so any batch size works.
    y_true = K.reshape(y_true, (-1, 7))
    y_pred = K.reshape(y_pred, (-1, 7))

    y_true_hips = y_true[:, 2]
    y_pred_hips = y_pred[:, 2]
    intersection_hips = K.sum(y_true_hips * y_pred_hips)
    hips = (2. * intersection_hips + smooth) / (K.sum(y_true_hips) + K.sum(y_pred_hips) + smooth)
    return hips

In [15]:
def dice_coef_sternum(y_true, y_pred):
    """Smoothed Dice coefficient for channel 3 (sternum).

    Args:
        y_true: one-hot labels reshapeable to (-1, 7).
        y_pred: class probabilities reshapeable to (-1, 7).
    """
    # -1 instead of a hard-coded 8*512*512 row count, so any batch size works.
    y_true = K.reshape(y_true, (-1, 7))
    y_pred = K.reshape(y_pred, (-1, 7))

    y_true_sternum = y_true[:, 3]
    y_pred_sternum = y_pred[:, 3]
    intersection_sternum = K.sum(y_true_sternum * y_pred_sternum)
    sternum = (2. * intersection_sternum + smooth) / (K.sum(y_true_sternum) + K.sum(y_pred_sternum) + smooth)
    return sternum

In [16]:
def dice_coef_ribs(y_true, y_pred):
    """Smoothed Dice coefficient for channel 4 (ribs).

    Args:
        y_true: one-hot labels reshapeable to (-1, 7).
        y_pred: class probabilities reshapeable to (-1, 7).
    """
    # -1 instead of a hard-coded 8*512*512 row count, so any batch size works.
    y_true = K.reshape(y_true, (-1, 7))
    y_pred = K.reshape(y_pred, (-1, 7))

    y_true_ribs = y_true[:, 4]
    y_pred_ribs = y_pred[:, 4]
    intersection_ribs = K.sum(y_true_ribs * y_pred_ribs)
    ribs = (2. * intersection_ribs + smooth) / (K.sum(y_true_ribs) + K.sum(y_pred_ribs) + smooth)
    return ribs

In [17]:
def dice_coef_sacrum(y_true, y_pred):
    """Smoothed Dice coefficient for channel 5 (sacrum).

    Args:
        y_true: one-hot labels reshapeable to (-1, 7).
        y_pred: class probabilities reshapeable to (-1, 7).
    """
    # -1 instead of a hard-coded 8*512*512 row count, so any batch size works.
    y_true = K.reshape(y_true, (-1, 7))
    y_pred = K.reshape(y_pred, (-1, 7))

    y_true_sacrum = y_true[:, 5]
    y_pred_sacrum = y_pred[:, 5]
    intersection_sacrum = K.sum(y_true_sacrum * y_pred_sacrum)
    sacrum = (2. * intersection_sacrum + smooth) / (K.sum(y_true_sacrum) + K.sum(y_pred_sacrum) + smooth)
    return sacrum

In [18]:
def dice_coef_femur(y_true, y_pred):
    """Smoothed Dice coefficient for channel 6 (femur).

    Args:
        y_true: one-hot labels reshapeable to (-1, 7).
        y_pred: class probabilities reshapeable to (-1, 7).
    """
    # -1 instead of a hard-coded 8*512*512 row count, so any batch size works.
    y_true = K.reshape(y_true, (-1, 7))
    y_pred = K.reshape(y_pred, (-1, 7))

    y_true_femur = y_true[:, 6]
    y_pred_femur = y_pred[:, 6]
    intersection_femur = K.sum(y_true_femur * y_pred_femur)
    femur = (2. * intersection_femur + smooth) / (K.sum(y_true_femur) + K.sum(y_pred_femur) + smooth)
    return femur

In [19]:
def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient over all non-background channels pooled together.

    Channels 1..6 (every class except background channel 0) are treated as a
    single foreground and scored with one Dice value:
    (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth).

    Args:
        y_true: one-hot labels reshapeable to (-1, 7).
        y_pred: class probabilities reshapeable to (-1, 7).
    """
    # -1 instead of a hard-coded 8*512*512 row count, so any batch size works.
    y_true_r = K.reshape(y_true, (-1, 7))
    y_pred_r = K.reshape(y_pred, (-1, 7))

    # Drop the background column. K.sum over a 2-D tensor already reduces every
    # element, so no explicit flatten is needed.
    y_true_bones = y_true_r[:, 1:]
    y_pred_bones = y_pred_r[:, 1:]

    intersection_bones = K.sum(y_true_bones * y_pred_bones)
    bones = (2. * intersection_bones + smooth) / (K.sum(y_true_bones) + K.sum(y_pred_bones) + smooth)
    return bones

In [20]:
def dice_coef_loss(y_true, y_pred):
    """Loss to minimise: the negated value of ``dice_coef`` for the same inputs."""
    score = dice_coef(y_true, y_pred)
    return -score

In [21]:
from keras.callbacks import Callback
class AdamLearningRateTracker(Callback):
    """Print lr / (1 + decay * iterations) for the model's optimizer after each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        optimizer = self.model.optimizer
        # In newer Keras 2 releases optimizer.iterations is an int64 variable;
        # multiplying it with the float `decay` variable then fails. An explicit
        # cast works whether the variable is int or float.
        iterations = K.cast(optimizer.iterations, K.floatx())
        lr = K.eval(optimizer.lr * (1. / (1. + optimizer.decay * iterations)))
        print('\nLR: {:.9f}\n'.format(lr))
        
class LossHistory(Callback):
    """Record the loss after every epoch and redraw the loss curves.

    Attributes (reset in on_train_begin):
        losses: training loss per epoch.
        e_losses: training loss per epoch (same values as ``losses``).
        val_losses: ``logs['val_loss']`` per epoch (None if absent from logs).
    """

    def on_train_begin(self, logs={}):
        self.losses, self.e_losses, self.val_losses = [], [], []

    @staticmethod
    def _plot_curves(curves):
        # Draw one 8x4 figure containing one line per (label, values) pair.
        plt.figure(figsize=(8, 4))
        for label, values in curves:
            plt.plot(np.array(values), label=label)
        plt.ylabel('loss')
        plt.legend()
        plt.show()

    def on_epoch_end(self, batch, logs={}):
        # `batch` receives the epoch index from Keras.
        self.losses.append(logs.get('loss'))
        LossHistory._plot_curves([("loss", self.losses)])

        self.val_losses.append(logs.get('val_loss'))
        self.e_losses.append(logs.get("loss"))
        LossHistory._plot_curves([("loss", self.e_losses), ("val_loss", self.val_losses)])
       
import keras

# TensorBoard callback writing logs to Data/tensorboard. It is not in the
# callbacks list that train_and_predict() passes to fit_generator.
# NOTE(review): the Keras docs say histogram_freq > 0 needs validation data;
# confirm before using this callback in a run without validation.
tbCallBack = keras.callbacks.TensorBoard(
    log_dir='Data/tensorboard',
    histogram_freq=1,
    write_graph=False,
    write_images=False,
)

In [22]:
def generator2(image_path, mask_path, batch_size, n_total=15648, n_files=15653):
    """Yield an endless stream of un-augmented (image, one-hot mask) training batches.

    Image ``k`` is read from ``image_path/O_k.mhd`` and its mask from
    ``mask_path/M_k.mhd``, both as int16. Each pass over the data uses
    ``n_total`` distinct indices drawn at random from ``range(n_files)``.
    When a pass is used up, a fresh random draw starts the next one.

    Args:
        image_path: directory containing the ``O_<k>.mhd`` images.
        mask_path: directory containing the ``M_<k>.mhd`` label maps.
        batch_size: images per yielded batch.
        n_total: indices used per pass (default 15648 = 1956 steps * 8).
        n_files: size of the index pool to draw from.

    Yields:
        (image, mask_c): ``image`` is a float array (batch_size, 512, 512, 1),
        each image standardised to zero mean and unit std. ``mask_c`` has shape
        (batch_size, 512*512, 7) and holds one-hot labels from ``to_categorical``.
    """
    rows, cols, n_classes = 512, 512, 7
    pool = np.arange(0, n_files)
    order = np.random.choice(pool, n_total, replace=False)
    sample = 0

    while True:
        image = np.zeros((batch_size, rows, cols, 1))
        mask_c = np.zeros((batch_size, rows * cols, n_classes))

        for i in range(batch_size):
            if sample == n_total:
                # Pass exhausted: reshuffle and start over.
                sample = 0
                order = np.random.choice(pool, n_total, replace=False)
            index = order[sample]
            sample += 1

            image_name = "O_" + str(index) + ".mhd"  # check name!
            mask_name = "M_" + str(index) + ".mhd"  # check name!
            img = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(image_path, image_name), sitk.sitkInt16))
            mas = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(mask_path, mask_name), sitk.sitkInt16))

            # Use an explicit reshape instead of np.expand_dims(..., axis=3):
            # for a 2-D array, axis=3 is out of range and raises AxisError on
            # NumPy >= 1.18. The reshape also accepts a (1, rows, cols) array.
            image[i] = np.reshape(img, (rows, cols, 1))
            mask_c[i] = to_categorical(np.reshape(mas, (rows * cols,)), n_classes)

            image[i] -= np.mean(image[i])
            std = np.std(image[i])
            if std > 0:  # a constant image would otherwise turn into all-NaN
                image[i] /= std

        yield (image, mask_c)
   

Real-Time Data Augmentation

In [23]:
def generator(image_path, mask_path, batch_size):
    """Yield an endless stream of augmented (image, one-hot mask) training batches.

    Image ``k`` is read from ``image_path/O_k.mhd`` and its mask from
    ``mask_path/M_k.mhd``. Each image/mask pair goes through random_shift,
    random_rotation and random_zoom, then through flip_axis with probability
    0.5. The image is then standardised and the mask one-hot encoded into
    7 channels.

    NOTE(review): random_shift, random_rotation, random_zoom and flip_axis are
    not defined in this notebook's visible cells. They take an (image, mask)
    pair, so they are presumably custom helpers rather than the single-array
    keras.preprocessing.image functions. Verify their return shapes.

    Yields:
        (image, mask_c): ``image`` has shape (batch_size, 512, 512, 1).
        ``mask_c`` has shape (batch_size, 512*512, 7).
    """
    sample=0

    # Each pass uses n_total distinct indices drawn at random from 0..15652.
    n_total= 15648
    b=np.arange(0,15653) 
    a = np.random.choice(b,n_total, replace=False) 

    while True:
        image = np.zeros((batch_size, 512, 512,1))
        mask = np.zeros((batch_size, 512, 512,1))
        mask_c = np.zeros((batch_size, 512*512,7))
        
        for i in range(0, batch_size):
            
            # Pass exhausted: reshuffle and start over.
            if sample == n_total:
                   sample=0
                   a = np.random.choice(b,n_total, replace=False)
        
            index=a[sample]
            image_name="O_" +str(index)+ ".mhd" #check name!
            mask_name="M_" +str(index)+ ".mhd" #check name!
     
            img= sitk.ReadImage(os.path.join(image_path, image_name),sitk.sitkInt16)
            img = sitk.GetArrayFromImage(img)
            mas= sitk.ReadImage(os.path.join(mask_path, mask_name),sitk.sitkInt16)
            mas = sitk.GetArrayFromImage(mas)
     
            #mas = to_categorical(mas,2)
            # NOTE(review): if GetArrayFromImage returns a 2-D array, axis=3 is out
            # of range and np.expand_dims raises AxisError on NumPy >= 1.18.
            t = np.expand_dims(img, axis=3)
            m = np.expand_dims(mas, axis=3)
            augmented=random_shift(t,m, 0.05, 0.05)
            augmented=random_rotation(augmented[0],augmented[1],5)
            # NOTE(review): the zoom range is written (0.95, 0.90), with the lower
            # bound above the upper one. Confirm this is intended.
            augmented=random_zoom(augmented[0],augmented[1],(0.95,0.90))
            
            # 50% chance of a flip. The reshape assumes `augmented` holds exactly
            # two 512x512 arrays (image, mask).
            if np.random.random() < 0.5:
                f = augmented
                f =np.reshape(f,(2,512,512))
                flipped=flip_axis(f[0],f[1])
                flipped= np.expand_dims(flipped,axis=3)
                augmented=flipped
             
            image[i]=augmented[0]
            mask[i]=augmented[1]
            sample+=1
            
            # Flatten the label map and one-hot encode it into 7 classes.
            m=np.reshape(mask[i],(512*512))
            mask_c[i] = to_categorical(m,7)
            
            # Per-image standardisation (zero mean, unit std).
            image[i] -= np.mean(image[i])
            image[i] /= np.std(image[i])
            
            
            # VISUALIZE AUGMENTED IMAGES!
            #y=np.reshape(image[i],(512,512))
            #yy=np.reshape(mask[i],(512,512))
            
            #fig = plt.figure(figsize=(16, 12))
            #z=fig.add_subplot(1,4,1)
            #imgplot1 = plt.imshow(img, cmap="gray")
            #zz=fig.add_subplot(1,4,2)
            #imgplot2 = plt.imshow(mas, cmap="gray")
            #zzz=fig.add_subplot(1,4,3)
            #imgplot1 = plt.imshow(y, cmap="gray")
            #zzzz=fig.add_subplot(1,4,4)
            #imgplot2 = plt.imshow(yy, cmap="gray")
            
            
       
        yield (image, mask_c)
   

In [24]:
def generator_val(batch_size, path="Train/2DIm/Validation/", n_total=424, n_files=425):
    """Yield an endless stream of validation (image, one-hot mask) batches.

    Image ``k`` is read from ``path/O/O_k.mhd`` and its mask from
    ``path/M/M_k.mhd``, both as int16. Each pass uses ``n_total`` distinct
    indices drawn at random from ``range(n_files)``, reshuffled when exhausted.

    Args:
        batch_size: images per yielded batch.
        path: validation root directory containing the ``O/`` and ``M/`` subfolders.
        n_total: indices used per pass.
        n_files: size of the index pool to draw from.

    Yields:
        (image, mask): ``image`` is a float array (batch_size, 512, 512, 1),
        standardised per image. ``mask`` has shape (batch_size, 512*512, 7)
        and holds one-hot labels.
    """
    rows, cols, n_classes = 512, 512, 7
    pool = np.arange(0, n_files)
    order = np.random.choice(pool, n_total, replace=False)
    sample = 0

    while True:
        image = np.zeros((batch_size, rows, cols, 1))
        mask = np.zeros((batch_size, rows * cols, n_classes))

        for i in range(batch_size):
            if sample == n_total:
                # Pass exhausted: reshuffle and start over.
                sample = 0
                order = np.random.choice(pool, n_total, replace=False)
            index = order[sample]
            sample += 1

            image_name = "O/O_" + str(index) + ".mhd"  # check name!
            mask_name = "M/M_" + str(index) + ".mhd"  # check name!
            img = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(path, image_name), sitk.sitkInt16))
            mas = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(path, mask_name), sitk.sitkInt16))

            mask[i] = to_categorical(np.reshape(mas, (rows * cols,)), n_classes)
            # Use an explicit reshape instead of np.expand_dims(..., axis=3), which
            # raises AxisError for a 2-D array on NumPy >= 1.18.
            image[i] = np.reshape(img, (rows, cols, 1))

            image[i] -= np.mean(image[i])
            std = np.std(image[i])
            if std > 0:  # a constant image would otherwise turn into all-NaN
                image[i] /= std

        yield (image, mask)

In [25]:
from keras.callbacks import History 
from keras.callbacks import Callback
import matplotlib.pyplot as plt

In [28]:
def train_and_predict(epochs=100, batch_size=8, steps_per_epoch=1956,
                      image_dir="../Orig2D", mask_dir="../FinalMasks2D",
                      checkpoint_path='Models/Last/unet1.h5'):
    """Build the U-Net, train it on generator2 batches and plot the training loss.

    Args:
        epochs: number of training epochs.
        batch_size: images per batch requested from generator2.
        steps_per_epoch: batches per epoch. The default 1956 * 8 = 15648 matches
            the number of indices generator2 uses per pass by default.
        image_dir: directory passed to generator2 as the image path.
        mask_dir: directory passed to generator2 as the mask path.
        checkpoint_path: file that ModelCheckpoint overwrites with the full model
            after every epoch.

    Side effects:
        Writes the checkpoint file, prints progress, and shows per-epoch loss
        plots (LossHistory) plus a final training-loss figure.
    """
    print('-'*30)
    print('Loading and preprocessing train data...')
    print('-'*30)
    print('Creating and compiling model...')
    print('-'*30)
    model = get_unet()

    # With save_best_only=False the model is saved every epoch whatever `monitor` says.
    model_checkpoint = ModelCheckpoint(checkpoint_path, monitor='loss', save_weights_only=False, save_best_only=False)
    history = History()
    history_loss = LossHistory()
    print('-'*30)
    print('Fitting model...')
    print('-'*30)

    # NOTE(review): max_q_size is the older Keras 2 name; newer Keras 2 releases
    # call it max_queue_size.
    model.fit_generator(generator2(image_dir, mask_dir, batch_size),
                        steps_per_epoch=steps_per_epoch, epochs=epochs, max_q_size=4,
                        verbose=1, initial_epoch=0,
                        callbacks=[model_checkpoint, history, history_loss, AdamLearningRateTracker()])

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(history.history["loss"], label="loss")
    ax.set_xlabel('epochs')
    ax.set_ylabel('error')
    ax.set_title('training error')
    ax.legend()
    plt.show()


In [0]:
# Train the U-Net with train_and_predict()'s default settings (epochs, data
# directories and checkpoint file are defined in that function).
train_and_predict()

In [0]:
# Testing the trained model: rebuild the architecture and load saved weights.
# NOTE(review): train_and_predict() checkpoints to 'Models/Last/unet1.h5', but
# this cell loads '../Models/unet.hdf5'. Confirm which weights file is intended.

model = get_unet()
model.load_weights('../Models/unet.hdf5')

In [0]:
#Load scan converted to a numpy array and add a trailing channel axis.
# NOTE(review): expand_dims(axis=3) requires `t` to be 3-D, presumably
# (n_slices, 512, 512). Confirm the layout of E1.npy.

t= np.load('../Data/E1.npy')
t = np.expand_dims(t, axis=3)

In [0]:
# Per-pixel class probabilities. The model's last Reshape layer gives each input
# image an output of shape (img_rows*img_cols, 7).
image_t = model.predict(t, batch_size=6, verbose=1)
print(image_t.shape)

In [0]:
#Select class image to save.
# predict() returns (n_images, img_rows*img_cols, 7). Reshape with -1 so scans
# with more than one slice work; a fixed (512, 512, 7) target only fits a
# single image. The saved array has shape (n_images, img_rows, img_cols).
# NOTE(review): the dice_coef_* metrics name channel 1 "spine" and channel 2
# "hips". Confirm channel 2 is the one intended for spine.npy.

class_index = 2  # select class
image_t = np.reshape(image_t, (-1, img_rows, img_cols, 7))
spine = image_t[..., class_index]
np.save('../Test/spine.npy', spine)