In [1]:
from utils.imports import *

Using TensorFlow backend.


## 训练模型

In [2]:
# Root directory of the training data. unet_fit() reads images from its
# 'lung' and 'nodule' class sub-folders.
# NOTE(review): PATH presumably comes from `utils.imports` -- confirm.
src = PATH['model_train']
# Prefix for model checkpoint files. unet_fit() concatenates it directly with
# '<name>.h5', so it presumably has to end with a path separator -- verify.
model_paths = PATH['model_paths']

In [3]:
def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient between two Keras tensors.

    Both tensors are flattened, then the score is computed as
    (2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1).
    The +1 terms keep the ratio defined when both sums are zero.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    numerator = 2. * overlap + 1
    denominator = K.sum(truth) + K.sum(pred) + 1
    return numerator / denominator
def dice_coef_loss(y_true, y_pred):
    """Keras loss: the negated Dice coefficient (lower loss = larger overlap)."""
    score = dice_coef(y_true, y_pred)
    return -score

def dice_coef_np(y_true, y_pred, smooth=1.):
    """Smoothed Dice coefficient of two array-likes, computed with NumPy.

    Mirrors the formula used by ``dice_coef``:
    (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth).

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted masks. They are flattened, so they only
        need to contain the same number of elements.
    smooth : float, optional
        Smoothing term added to numerator and denominator (default 1.0) so
        the ratio stays defined when both inputs are all zeros.

    Returns
    -------
    numpy.float64
        The smoothed Dice score.
    """
    # Convert to float64 before multiplying: this accepts plain lists and
    # prevents integer wrap-around for uint8 masks (e.g. 255 * 255 -> 1).
    y_true_f = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred_f = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth)

def unet_model(dropout_rate, learn_rate, width):
    """Build and compile a five-level U-Net for (1, 512, 512) channels-first input.

    Parameters
    ----------
    dropout_rate : float
        Rate passed to the SpatialDropout2D layer at the start of every
        decoder stage.
    learn_rate : float
        Learning rate handed to the Adam optimizer.
    width : int
        Number of filters in the first level; each deeper level doubles it
        (width, 2*width, ..., 16*width at the bottleneck).

    Returns
    -------
    keras Model
        Model compiled with ``dice_coef_loss`` as loss and ``dice_coef`` as
        metric, ending in a 1x1 sigmoid convolution with one output channel.
    """
    # `concatenate` replaces the deprecated Keras 1 `merge(mode='concat')`.
    # Imported locally because the notebook's star import is not known to
    # provide it.
    from keras.layers import concatenate

    def conv3x3(filters, tensor):
        # 3x3 'same' convolution with ReLU, the basic unit of every stage.
        return Convolution2D(filters, (3, 3), padding="same", activation="relu")(tensor)

    def encoder_stage(filters, tensor):
        # conv -> batch norm over the channel axis (axis 1) -> conv.
        tensor = conv3x3(filters, tensor)
        tensor = BatchNormalization(axis=1)(tensor)
        return conv3x3(filters, tensor)

    def decoder_stage(filters, tensor, skip):
        # Upsample, join with the matching encoder output along the channel
        # axis, apply spatial dropout, then two convolutions.
        merged = concatenate([UpSampling2D(size=(2, 2))(tensor), skip], axis=1)
        merged = SpatialDropout2D(dropout_rate)(merged)
        merged = conv3x3(filters, merged)
        return conv3x3(filters, merged)

    inputs = Input((1, 512, 512))

    # Contracting path: each stage is followed by 2x2 max pooling.
    conv1 = encoder_stage(width, inputs)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = encoder_stage(width * 2, pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = encoder_stage(width * 4, pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = encoder_stage(width * 8, pool3)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # Bottleneck (no pooling afterwards).
    conv5 = encoder_stage(width * 16, pool4)

    # Expanding path with skip connections to the encoder outputs.
    conv6 = decoder_stage(width * 8, conv5, conv4)
    conv7 = decoder_stage(width * 4, conv6, conv3)
    conv8 = decoder_stage(width * 2, conv7, conv2)
    conv9 = decoder_stage(width, conv8, conv1)

    conv10 = Convolution2D(1, (1, 1), activation="sigmoid")(conv9)

    # Keras 2 keyword names (`inputs` / `outputs`).
    model = Model(inputs=inputs, outputs=conv10)
    model.compile(optimizer=Adam(lr=learn_rate), loss=dice_coef_loss, metrics=[dice_coef])
    return model


def get_unet_small():
    """Build, summarise and compile a U-Net for (1, 512, 512) channels-first input.

    The network has five levels with 32, 64, 128, 256 and 512 filters, ELU
    activations and a Dropout(0.2) between the two convolutions of each
    stage. The second convolution of every stage is named ``conv_<n>``, the
    pooling layers ``pool_<n>`` and the output layer ``sigmoid``.

    Calling this function prints ``model.summary()``.

    Returns
    -------
    keras Model
        Model compiled with Adam (lr=1e-3, clipvalue=1, clipnorm=1), using
        ``dice_coef_loss`` as loss and ``dice_coef`` as metric.
    """
    # `concatenate` replaces the deprecated Keras 1 `merge(mode='concat')`.
    # Imported locally because the notebook's star import is not known to
    # provide it.
    from keras.layers import concatenate

    def double_conv(tensor, filters, name):
        # conv -> dropout -> conv; only the second convolution is named.
        tensor = Convolution2D(filters, (3, 3), activation='elu', padding='same')(tensor)
        tensor = Dropout(0.2)(tensor)
        return Convolution2D(filters, (3, 3), activation='elu', padding='same', name=name)(tensor)

    def channel_norm(tensor):
        # Feature maps are channels-first (Input is (1, 512, 512) and skip
        # connections are joined on axis 1), so normalise over axis 1. The
        # Keras default axis=-1 would normalise along image width instead.
        return BatchNormalization(axis=1)(tensor)

    def upsample_concat(tensor, skip):
        # 2x upsampling followed by concatenation along the channel axis.
        return concatenate([UpSampling2D(size=(2, 2))(tensor), skip], axis=1)

    inputs = Input((1, 512, 512))

    # Levels 1-3: batch norm is applied after pooling, so the skip
    # connections below carry the un-normalised convolution outputs.
    conv1 = double_conv(inputs, 32, 'conv_1')
    pool1 = MaxPooling2D(pool_size=(2, 2), name='pool_1')(conv1)
    pool1 = channel_norm(pool1)

    conv2 = double_conv(pool1, 64, 'conv_2')
    pool2 = MaxPooling2D(pool_size=(2, 2), name='pool_2')(conv2)
    pool2 = channel_norm(pool2)

    conv3 = double_conv(pool2, 128, 'conv_3')
    pool3 = MaxPooling2D(pool_size=(2, 2), name='pool_3')(conv3)
    pool3 = channel_norm(pool3)

    # Level 4: batch norm comes before pooling, so its skip connection
    # carries the normalised output.
    conv4 = double_conv(pool3, 256, 'conv_4')
    conv4 = channel_norm(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2), name='pool_4')(conv4)

    # Bottleneck.
    conv5 = double_conv(pool4, 512, 'conv_5')

    # Expanding path; the first decoder stage has no batch norm.
    conv6 = double_conv(upsample_concat(conv5, conv4), 256, 'conv_6')

    conv7 = double_conv(upsample_concat(conv6, conv3), 128, 'conv_7')
    conv7 = channel_norm(conv7)

    conv8 = double_conv(upsample_concat(conv7, conv2), 64, 'conv_8')
    conv8 = channel_norm(conv8)

    conv9 = double_conv(upsample_concat(conv8, conv1), 32, 'conv_9')
    conv9 = channel_norm(conv9)

    conv10 = Convolution2D(1, (1, 1), activation='sigmoid', name='sigmoid')(conv9)

    # Keras 2 keyword names (`inputs` / `outputs`).
    model = Model(inputs=inputs, outputs=conv10)
    model.summary()
    model.compile(optimizer=Adam(lr=1e-3, clipvalue=1., clipnorm=1.), loss=dice_coef_loss, metrics=[dice_coef])

    return model

def unet_fit(name, check_name=None, epochs=100, steps_per_epoch=20, validation_steps=8):
    """Train a segmentation U-Net on augmented image/mask pairs read from ``src``.

    Images are read from the 'lung' sub-folder of ``src`` and masks from the
    'nodule' sub-folder, both as 512x512 grayscale with batch size 1. The two
    generators share augmentation settings and seed so that every image is
    paired with the identically transformed mask.

    Parameters
    ----------
    name : str
        Checkpoint name; the best model (lowest ``val_loss``) is saved to
        ``model_paths + name + '.h5'``.
    check_name : str or None, optional
        If given, training resumes from ``model_paths + check_name + '.h5'``;
        otherwise a fresh ``get_unet_small()`` model is trained.
    epochs : int, optional
        Maximum number of epochs (default 100). Early stopping ends training
        after 20 epochs without ``val_loss`` improvement.
    steps_per_epoch : int, optional
        Batches drawn per training epoch (default 20).
    validation_steps : int, optional
        Batches drawn per validation pass (default 8).

    Returns
    -------
    None
    """
    from keras.preprocessing.image import ImageDataGenerator

    data_gen_args = dict(rotation_range=25.,
                         width_shift_range=0.3,
                         height_shift_range=0.3,
                         horizontal_flip=True,
                         vertical_flip=True)
    image_datagen = ImageDataGenerator(**data_gen_args)
    mask_datagen = ImageDataGenerator(**data_gen_args)

    # Identical seed and flow arguments keep the image and mask streams in
    # lock-step (same file order, same random transforms).
    seed = 1
    flow_args = dict(class_mode=None,
                     seed=seed,
                     target_size=(512, 512),
                     color_mode="grayscale",
                     batch_size=1)
    image_generator = image_datagen.flow_from_directory(src, classes=['lung'], **flow_args)
    mask_generator = mask_datagen.flow_from_directory(src, classes=['nodule'], **flow_args)

    # Combine the generators into one that yields (image, mask) pairs. The
    # generators are endless, so the zip must be lazy: itertools.izip on
    # Python 2, the built-in zip on Python 3.
    lazy_zip = getattr(itertools, 'izip', zip)
    train_generator = lazy_zip(image_generator, mask_generator)

    callbacks = [
        EarlyStopping(monitor='val_loss', patience=20, verbose=1),
        ModelCheckpoint(model_paths + '{}.h5'.format(name),
                        monitor='val_loss',
                        verbose=0, save_best_only=True),
    ]

    if check_name is not None:
        check_model = model_paths + '{}.h5'.format(check_name)
        # The custom loss/metric must be supplied so Keras can deserialize
        # the compiled model.
        model = load_model(check_model,
                           custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef})
    else:
        model = get_unet_small()

    # NOTE(review): validation batches are drawn from the same augmented
    # training stream, so val_loss is not a held-out estimate. Supply a
    # separate validation generator once a validation split exists.
    model.fit_generator(
        train_generator,
        epochs=epochs,
        verbose=1,
        callbacks=callbacks,
        steps_per_epoch=steps_per_epoch,
        validation_data=train_generator,
        validation_steps=validation_steps)

In [4]:
# Train a fresh get_unet_small() model; the best checkpoint (lowest val_loss)
# is written to model_paths + 'final_fenge_170622_test.h5'.
unet_fit('final_fenge_170622_test')
# Alternative: resume training from the existing 'final_fenge' checkpoint.
#unet_fit('final_fenge_170622_test','final_fenge')

Found 3711 images belonging to 1 classes.
Found 3711 images belonging to 1 classes.
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_1 (InputLayer)             (None, 1, 512, 512)   0                                            
____________________________________________________________________________________________________
conv2d_1 (Conv2D)                (None, 32, 512, 512)  320         input_1[0][0]                    
____________________________________________________________________________________________________
dropout_1 (Dropout)              (None, 32, 512, 512)  0           conv2d_1[0][0]                   
____________________________________________________________________________________________________
conv_1 (Conv2D)                  (None, 32, 512, 512)  9248        dropout_1[0][0]                  
_______

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100

KeyboardInterrupt: 