In [1]:
import tensorflow as tf
import numpy as np
import os
from tensorflow.keras.preprocessing import image_dataset_from_directory
# pragma warning(disable:4996)
from tqdm import tqdm
import cv2
from PIL import Image



The following code defines some loss functions, using manually annotated masks. One is the dice coefficient, then it turns it into loss by subtracting the result from 1. Then these losses are added together into the Binary-Cross-Entropy-Dice Loss.

In [2]:
def dice_coeff(y_true, y_pred):
    smooth = 1.
    # Flatten
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
    return score


def dice_loss(y_true, y_pred):
    loss = 1 - dice_coeff(y_true, y_pred)
    return loss


def bce_dice_loss(y_true, y_pred):
    loss = tf.keras.losses.binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
    return loss

In [3]:
def cv2_imread(file_path):
    cv_img = cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), -1)
    return cv_img

In [4]:
def check_gpu():
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
                print(gpu)
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(logical_gpus)
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

The following is the function that trains the model. It takes images from the files specified and applies scaling, rotation, shifting and flipping. Padding is set to reflect. The training model is a version of U-net, which takes in 1024x1024 images to process them. 

In [5]:
def Train_Model(ini_data_path, model_export, IMG_WIDTH=1024, IMG_HEIGHT=1024,
                IMG_CHANNELS=3, BATCH_SIZE=8, patience = 100):
    if IMG_CHANNELS == 3:
        using_rgb = True
    else:
        using_rgb = False

    data_gen_args = dict(
        featurewise_center=False,
        featurewise_std_normalization=False,
        rotation_range=90,
        width_shift_range=0.2,
        height_shift_range=0.2,
        rescale=1. / 255,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='reflect'
    )

    test_gen_args = dict(
        featurewise_center=False,
        featurewise_std_normalization=False,
        rescale=1. / 255,
    )

    image_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**data_gen_args)
    mask_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**data_gen_args)

    image_test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**test_gen_args)
    mask_test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**test_gen_args)

    seed = 1

    image_generator = image_datagen.flow_from_directory(
        ini_data_path + 'Train_set/Train_data',
        class_mode=None,
        seed=seed,
        target_size=(1024, 1024),
        color_mode=('grayscale', 'rgb')[using_rgb],
        batch_size=1,
        shuffle=True
    )

    mask_generator = mask_datagen.flow_from_directory(
        'data/Train_set/Train_masks',
        class_mode=None,
        seed=seed,
        target_size=(1024, 1024),
        color_mode='grayscale',
        batch_size=1,
        shuffle=True
    )

    image_test_generator = image_test_datagen.flow_from_directory(
        ini_data_path + 'Test_set/Test_data',
        class_mode=None,
        shuffle=False,
        target_size=(1024, 1024),
        color_mode=('grayscale', 'rgb')[using_rgb],
        batch_size=1
    )
    mask_test_generator = mask_test_datagen.flow_from_directory(
        ini_data_path + 'Test_set/Test_masks',
        class_mode=None,
        shuffle=False,
        target_size=(1024, 1024),
        color_mode='grayscale',
        batch_size=1
    )

    train_generator = zip(image_generator, mask_generator)
    test_generator = zip(image_test_generator, mask_test_generator)

    input1 = tf.keras.layers.Input((IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS))

    s1 = tf.keras.layers.Lambda(lambda x: x / 255)(input1)

    # Contraction Layer
    c1 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(s1)
    c1 = tf.keras.layers.Dropout(.1)(c1)
    c1 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c1)
    p1 = tf.keras.layers.MaxPooling2D((2, 2))(c1)

    c2 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p1)
    c2 = tf.keras.layers.Dropout(.1)(c2)
    c2 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c2)
    p2 = tf.keras.layers.MaxPooling2D((2, 2))(c2)

    c3 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p2)
    c3 = tf.keras.layers.Dropout(.2)(c3)
    c3 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c3)
    p3 = tf.keras.layers.MaxPooling2D((2, 2))(c3)

    c4 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p3)
    c4 = tf.keras.layers.Dropout(.2)(c4)
    c4 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c4)
    p4 = tf.keras.layers.MaxPooling2D((2, 2))(c4)

    c5 = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p4)
    c5 = tf.keras.layers.Dropout(.3)(c5)
    c5 = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c5)

    # Expansion layer
    u6 = tf.keras.layers.Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c5)
    u6 = tf.keras.layers.concatenate([u6, c4])
    c6 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u6)
    c6 = tf.keras.layers.Dropout(.2)(c6)
    c6 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c6)

    u7 = tf.keras.layers.Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c6)
    u7 = tf.keras.layers.concatenate([u7, c3], axis=3)
    c7 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u7)
    c7 = tf.keras.layers.Dropout(.2)(c7)
    c7 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c7)

    u8 = tf.keras.layers.Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(c7)
    u8 = tf.keras.layers.concatenate([u8, c2], axis=3)
    c8 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u8)
    c8 = tf.keras.layers.Dropout(.1)(c8)
    c8 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c8)

    u9 = tf.keras.layers.Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(c8)
    u9 = tf.keras.layers.concatenate([u9, c1], axis=3)
    c9 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u9)
    c9 = tf.keras.layers.Dropout(.1)(c9)
    c9 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c9)

    outputs = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(c9)

    model = tf.keras.Model(inputs=[input1], outputs=[outputs])
    model.compile(optimizer='adam', loss=[bce_dice_loss], metrics=[dice_loss])
    model.summary()

    ####################################################################################################################

    checkpointer = tf.keras.callbacks.ModelCheckpoint('latest_model.h5', verbose=1, save_best_only=True)
    callbacks = [tf.keras.callbacks.EarlyStopping(patience=patience),
                 tf.keras.callbacks.TensorBoard(log_dir='logs', histogram_freq=1)]

    results = model.fit(train_generator, validation_data=test_generator, validation_steps=1, steps_per_epoch=65//BATCH_SIZE,
                        epochs=1500, callbacks=callbacks, batch_size=BATCH_SIZE, validation_batch_size=1)
    model.save(model_export + '.h5', include_optimizer=False)
    print('Done! Model can be found in ' + model_export)
    return True

In [6]:
def Use_Model(model_path, data_path, img_strs, export_path='data/Nuclei_masks/', Zlevel=1):
    model = tf.keras.models.load_model(model_path + '.h5', compile=False)
    test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        rescale=1. / 255
    )
    file_list = os.listdir(data_path + 'Validate_set/Input')
    validation_generator = test_datagen.flow_from_directory(data_path + 'Validate_set/',
                                                            target_size=(1024, 1024),
                                                            batch_size=4,
                                                            shuffle=False,
                                                            color_mode='rgb')

    output = model.predict(validation_generator)
    for i, pic in enumerate(output):
        cv2.imwrite(export_path + str(Zlevel) + '/' + '{}_{}.png'.format(str(Zlevel), img_strs[i]), pic * 255)
    return True

In [7]:
if __name__ == '__main__':
    # train_strs = ['1_2_3', '2_1_3', '2_2_3', '3_2_3', '3_1_3', '1_4_3', '2_4_3', '3_4_3', '4_4_3', '4_3_3', '4_2_3',
    #               '3_3_3']
    # val_strs = ['1_3_3', '2_3_3']
    ini_data_path = 'C:/Users/night_3ns60sk/OneDrive/Documenten/TU_algemeen/GPU_BEP_PRACTICE/data/'
    Zlevel = 1
    Train_Model(ini_data_path, 'Models/latest_model', IMG_CHANNELS=3, BATCH_SIZE=4, patience=150)
    #
#     img_strs = data_augments.gen_input_from_img_coords(ini_data_path, (1, 1, 4, 4), Z=Zlevel, use_predicted_data=False, only_EM=False)
    #
#     Use_Model('Models/latest_model', ini_data_path, img_strs, Zlevel=Zlevel)

    # particle_analysis.ShowResults('data/Nuclei_masks/' + str(Zlevel) + '/', ini_data_path, img_strs, Zlevel=Zlevel,
    #                               upscaleTo=0, threshold_masks=True)

#     for img in img_strs:
#         mask_img = cv2_imread(ini_data_path + 'Nuclei_masks/' + str(Zlevel) + '/{}_'.format(Zlevel) + img + '.png') / 255
#         EM_img = cv2_imread(ini_data_path + 'EM/' + str(Zlevel) + '/' + img + '.png')
#         masked_img = np.dstack((EM_img, EM_img, EM_img*(1-mask_img)))
#         # cv2.imshow('{}_{}'.format(Zlevel, img), masked_img/255)
#         # cv2.waitKey()
#         # cv2.destroyAllWindows()
#         cv2.imwrite(ini_data_path + 'Nuclei_masks/' + str(Zlevel) + '/mask_multiply/' + img + '.png', masked_img)

Found 65 images belonging to 1 classes.
Found 65 images belonging to 1 classes.
Found 2 images belonging to 1 classes.
Found 2 images belonging to 1 classes.
Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 1024, 1024,  0                                            
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 1024, 1024, 3 0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 1024, 1024, 1 448         lambda[0][0]                     
__________________________________________________________________________________________________
dropout (Dropout)           

Epoch 1/1500
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
Epoch 2/1500
Epoch 3/1500
Epoch 4/1500
Epoch 5/1500
Epoch 6/1500
Epoch 7/1500
Epoch 8/1500
Epoch 9/1500
Epoch 10/1500
Epoch 11/1500
Epoch 12/1500
Epoch 13/1500
Epoch 14/1500
Epoch 15/1500
Epoch 16/1500
Epoch 17/1500

KeyboardInterrupt: 