In [2]:
import os
# still use cpu
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
# use gpu, start id from 0
# os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'

import tensorflow as tf
import tensorflow.keras.backend as K
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
cpus = tf.config.experimental.list_physical_devices(device_type='CPU')
print("gpus: " + str(gpus))
print("cpus: " + str(cpus))

gpus: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]
cpus: [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]


In [3]:
import random
import numpy as np 
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import shutil

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

In [4]:
import glob
train_list = glob.glob("/input/TrainSeg/Yes/" + "*_img.npy")
val_list = glob.glob("./ValSeg/Yes/" + "*_img.npy")

In [129]:
import scipy
from scipy import ndimage
from PIL import ImageEnhance

class DataAugmenter(object):
    
    def __init__(self, featurewise_center=False, samplewise_center=False,
                 featurewise_std_normalization=False, samplewise_std_normalization=False,
                 zca_whitening=False, zca_epsilon=1e-6,
                 rotation_range=0, width_shift_range=0.,
                 height_shift_range=0., brightness_range=None,
                 shear_range=0., zoom_range=0.,
                 channel_shift_range=0., fill_mode='nearest',
                 cval=0., horizontal_flip=False,
                 vertical_flip=False, rescale=None,
                 preprocessing_function=None, data_format='channels_last',
                 validation_split=0.0, interpolation_order=1,
                 dtype='float32'):
        
        self.featurewise_center = featurewise_center
        self.samplewise_center = samplewise_center
        self.featurewise_std_normalization = featurewise_std_normalization
        self.samplewise_std_normalization = samplewise_std_normalization
        self.zca_whitening = zca_whitening
        self.zca_epsilon = zca_epsilon
        self.rotation_range = rotation_range
        self.width_shift_range = width_shift_range
        self.height_shift_range = height_shift_range
        self.shear_range = shear_range
        self.brightness_range = brightness_range
        self.zoom_range = zoom_range
        self.channel_shift_range = channel_shift_range
        self.fill_mode = fill_mode
        self.cval = cval
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip
        self.rescale = rescale
        self.preprocessing_function = preprocessing_function
        self.dtype = dtype
        self.interpolation_order = interpolation_order
        
        if isinstance(zoom_range, (float, int)):
            self.zoom_range = [1 - zoom_range, 1 + zoom_range]
        elif (len(zoom_range) == 2 and
              all(isinstance(val, (float, int)) for val in zoom_range)):
            self.zoom_range = [zoom_range[0], zoom_range[1]]
        else:
            raise ValueError('`zoom_range` should be a float or '
                             'a tuple or list of two floats. '
                             'Received: %s' % (zoom_range,))
    
    def random_transform(self, x, seed=None):
        params = self.get_random_transform(x.shape, seed)
        return self.apply_transform(x, params)
    
    def get_random_transform(self, img_shape, seed=None):
        img_row_axis = 0
        img_col_axis = 1

        if seed is not None:
            np.random.seed(seed)

        if self.rotation_range:
            theta = np.random.uniform(
                -self.rotation_range,
                self.rotation_range)
        else:
            theta = 0

        if self.height_shift_range:
            try:  # 1-D array-like or int
                tx = np.random.choice(self.height_shift_range)
                tx *= np.random.choice([-1, 1])
            except ValueError:  # floating point
                tx = np.random.uniform(-self.height_shift_range,
                                       self.height_shift_range)
            if np.max(self.height_shift_range) < 1:
                tx *= img_shape[img_row_axis]
        else:
            tx = 0

        if self.width_shift_range:
            try:  # 1-D array-like or int
                ty = np.random.choice(self.width_shift_range)
                ty *= np.random.choice([-1, 1])
            except ValueError:  # floating point
                ty = np.random.uniform(-self.width_shift_range,
                                       self.width_shift_range)
            if np.max(self.width_shift_range) < 1:
                ty *= img_shape[img_col_axis]
        else:
            ty = 0

        if self.shear_range:
            shear = np.random.uniform(
                -self.shear_range,
                self.shear_range)
        else:
            shear = 0

        if self.zoom_range[0] == 1 and self.zoom_range[1] == 1:
            zx, zy = 1, 1
        else:
            zx, zy = np.random.uniform(
                self.zoom_range[0],
                self.zoom_range[1],
                2)

        flip_horizontal = (np.random.random() < 0.5) * self.horizontal_flip
        flip_vertical = (np.random.random() < 0.5) * self.vertical_flip

        channel_shift_intensity = None
        if self.channel_shift_range != 0:
            channel_shift_intensity = np.random.uniform(-self.channel_shift_range,
                                                        self.channel_shift_range)

        brightness = None
        if self.brightness_range is not None:
            brightness = np.random.uniform(self.brightness_range[0],
                                           self.brightness_range[1])

        transform_parameters = {'theta': theta,
                                'tx': tx,
                                'ty': ty,
                                'shear': shear,
                                'zx': zx,
                                'zy': zy,
                                'flip_horizontal': flip_horizontal,
                                'flip_vertical': flip_vertical,
                                'channel_shift_intensity': channel_shift_intensity,
                                'brightness': brightness}

        return transform_parameters
    
    def apply_affine_transform(self, x, theta=0, tx=0, ty=0, shear=0, zx=1, zy=1,
                           row_axis=0, col_axis=1, channel_axis=2,
                           fill_mode='nearest', cval=0., order=1):
        if scipy is None:
            raise ImportError('Image transformations require SciPy. '
                              'Install SciPy.')
        transform_matrix = None
        if theta != 0:
            theta = np.deg2rad(theta)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
                                        [np.sin(theta), np.cos(theta), 0],
                                        [0, 0, 1]])
            transform_matrix = rotation_matrix

        if tx != 0 or ty != 0:
            shift_matrix = np.array([[1, 0, tx],
                                     [0, 1, ty],
                                     [0, 0, 1]])
            if transform_matrix is None:
                transform_matrix = shift_matrix
            else:
                transform_matrix = np.dot(transform_matrix, shift_matrix)

        if shear != 0:
            shear = np.deg2rad(shear)
            shear_matrix = np.array([[1, -np.sin(shear), 0],
                                     [0, np.cos(shear), 0],
                                     [0, 0, 1]])
            if transform_matrix is None:
                transform_matrix = shear_matrix
            else:
                transform_matrix = np.dot(transform_matrix, shear_matrix)

        if zx != 1 or zy != 1:
            zoom_matrix = np.array([[zx, 0, 0],
                                    [0, zy, 0],
                                    [0, 0, 1]])
            if transform_matrix is None:
                transform_matrix = zoom_matrix
            else:
                transform_matrix = np.dot(transform_matrix, zoom_matrix)

        if transform_matrix is not None:
            h, w = x.shape[row_axis], x.shape[col_axis]
            transform_matrix = self.transform_matrix_offset_center(
                transform_matrix, h, w)
            x = np.rollaxis(x, channel_axis, 0)
            final_affine_matrix = transform_matrix[:2, :2]
            final_offset = transform_matrix[:2, 2]

            channel_images = [ndimage.interpolation.affine_transform(
                x_channel,
                final_affine_matrix,
                final_offset,
                order=order,
                mode=fill_mode,
                cval=cval) for x_channel in x]
            x = np.stack(channel_images, axis=0)
            x = np.rollaxis(x, 0, channel_axis + 1)
        return x
    
    def random_brightness(self, x, brightness_range):
        if len(brightness_range) != 2:
            raise ValueError(
                '`brightness_range should be tuple or list of two floats. '
                'Received: %s' % (brightness_range,))

        u = np.random.uniform(brightness_range[0], brightness_range[1])
        return apply_brightness_shift(x, u)
    
    def apply_brightness_shift(self, x, brightness):
        if ImageEnhance is None:
            raise ImportError('Using brightness shifts requires PIL. '
                              'Install PIL or Pillow.')
        x = array_to_img(x)
        x = imgenhancer_Brightness = ImageEnhance.Brightness(x)
        x = imgenhancer_Brightness.enhance(brightness)
        x = img_to_array(x)
        return x
    
    def transform_matrix_offset_center(self, matrix, x, y):
        o_x = float(x) / 2 + 0.5
        o_y = float(y) / 2 + 0.5
        offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
        reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
        transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix)
        return transform_matrix

    def apply_transform(self, x, transform_parameters):
        # x is a single image, so it doesn't have image number at index 0
        img_row_axis = 0
        img_col_axis = 1
        img_channel_axis = 2

        x = self.apply_affine_transform(x, transform_parameters.get('theta', 0),
                                   transform_parameters.get('tx', 0),
                                   transform_parameters.get('ty', 0),
                                   transform_parameters.get('shear', 0),
                                   transform_parameters.get('zx', 1),
                                   transform_parameters.get('zy', 1),
                                   row_axis=img_row_axis,
                                   col_axis=img_col_axis,
                                   channel_axis=img_channel_axis,
                                   fill_mode=self.fill_mode,
                                   cval=self.cval,
                                   order=self.interpolation_order)

        if transform_parameters.get('channel_shift_intensity') is not None:
            x = apply_channel_shift(x,
                                    transform_parameters['channel_shift_intensity'],
                                    img_channel_axis)

        if transform_parameters.get('flip_horizontal', False):
            x = flip_axis(x, img_col_axis)

        if transform_parameters.get('flip_vertical', False):
            x = flip_axis(x, img_row_axis)

        if transform_parameters.get('brightness') is not None:
            x = self.apply_brightness_shift(x, transform_parameters['brightness'])

        return x

In [130]:
import time

class DataGenerator(tf.keras.utils.Sequence):
    def __init__(self, list_IDs, batch_size=4, dim=(240,240), n_channels=3,
                 n_classes=2, shuffle=True, augment=True):
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.augment = augment
        self.on_epoch_end()
        
        if augment:
            self.data_augmenter = DataAugmenter(
                rotation_range=20,
                brightness_range=[-0.05, 0.05],
                dtype='float32'
            )
            self.label_augmenter = DataAugmenter(
                rotation_range=20,
                brightness_range=[-0.05, 0.05],
                dtype='float32'
            )
        
    def __len__(self):
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        X, y = self.__data_generation(list_IDs_temp)

        return X, y

    def on_epoch_end(self):
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        # X : (n_samples, *dim, n_channels)
        # Initialization
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size, *self.dim))

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # Store sample
            # Add data augmentation here
            if self.augment:
                seed = int(time.time())
                X[i, ] = self.data_augmenter.random_transform(x=np.load(ID), seed=seed)
            else:
                X[i,] = np.load(ID)

            # Store segmentation map
            if self.augment:
                img_2_dim = np.expand_dims(np.load(ID[:-8] +'_seg.npy'), -1)
                y[i, ] = np.squeeze(self.label_augmenter.random_transform(x=img_2_dim, seed=seed))
            else:
                y[i] = np.load(ID[:-8] +'_seg.npy')

        return X, y

In [131]:
train_generator = DataGenerator(train_list)
validation_generator = DataGenerator(val_list)
IMG_SIZE = (240,240)
RANDOM_SEED = 100

In [84]:
def dice_score(y_true, y_pred, smooth=1):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

In [85]:
from keras_unet_collection import models

strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))

with strategy.scope():
#     model = models.unet_plus_2d((240, 240, 3), [64, 128, 256, 512, 1024], n_labels=1,
#                                 stack_num_down=2, stack_num_up=2,
#                                 activation='ReLU', output_activation='Sigmoid', 
#                                 batch_norm=True, pool='max', unpool=False, deep_supervision=True, name='unet2plus')

    model = models.unet_2d((240, 240, 3), [64, 128, 256, 512, 1024], n_labels=1,
                      stack_num_down=2, stack_num_up=1,
                      backbone="VGG19", weights='pretrain/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5',
                      activation='ReLU', output_activation='Sigmoid',
                      batch_norm=True, pool='max', unpool=False, name='unet')
    
#     model = models.unet_3plus_2d((240, 240, 3), n_labels=1, filter_num_down=[64, 128, 256, 512, 1024],
#                             activation='ReLU', output_activation='Sigmoid', batch_norm=True,
#                             backbone="VGG19", weights='pretrain/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5', 
#                             deep_supervision=True, name='unet3plus')
    

    model.compile(optimizer = Adam(lr = 1e-4), loss =['binary_crossentropy']*5, 
                  loss_weights = [0.25,0.25,0.25,0.25,1.00], metrics = [dice_score])

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
Number of devices: 2



Backbone VGG19 does not use batch norm, but other layers received batch_norm=True


In [86]:
checkpoint = tf.keras.callbacks.ModelCheckpoint(monitor="val_dice_score",
                                                filepath="seg_tmp/weights-{epoch:02d}-{val_dice_score:04f}.hdf5",
                                                save_best_only=True,
                                                verbose=1,
                                                mode="max",
                                                save_weights_only=False)

# earlystopping = tf.keras.callbacks.EarlyStopping(monitor="val_dice_score",
#                                                  min_delta=0.03,
#                                                  patience=1,
#                                                  verbose=1,
#                                                  mode="auto",
#                                                  restore_best_weights=False)

tensorboard = tf.keras.callbacks.TensorBoard(log_dir='/output/logs')

In [53]:
model = load_model("seg_tmp/unet3plus_vgg19_pretrain_epoch_60_9351_8016.h5", {'dice_score': dice_score})

In [None]:
history = model.fit(
    train_generator,
    epochs=30,
    validation_data=validation_generator,
    callbacks=[checkpoint, tensorboard]
)

Epoch 1/30
INFO:tensorflow:batch_all_reduce: 42 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 42 all-reduces with algorithm = nccl, num_packs = 1

In [62]:
model.save("seg_tmp/unet3plus_vgg19_pretrain_epoch_210_9196_.h5")

In [79]:
test_dir = 'ValSeg/'
#load your model here
dependencies = {
    'dice_score': dice_score
}


model = load_model('seg_tmp/weights-69-0.919960.hdf5', custom_objects=dependencies)

test_list = []

CLASS = 'Yes'
all_files = os.listdir(test_dir + CLASS)
files = [item for item in all_files if "img" in item]
for file_name in files:
    test_list.append(test_dir + CLASS + '/' + file_name)        
test_generator = DataGenerator(test_list[:100], batch_size=1)

predictions = []
x_test = []
y_test = []
accuracy = []
for i in range(test_generator.__len__()):
    x, y = test_generator.__getitem__(i)
    x_test.append(x)
    y_test.append(y[0])
    prediction = model.predict(x)
    prediction = prediction[0]
    prediction[prediction>0.5] = 1
    prediction[prediction<=0.5] = 0
    predictions.append(prediction[0])
    accuracy.append(dice_score(y[0], prediction[0].astype('float64')))
print('Test Score = %.4f' % np.mean(accuracy))

Test Score = 0.7996
