In [None]:
# import libs
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2

from random import randint
from PIL import Image
from itertools import product
import os
import pydicom

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import Sequence
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.metrics import BinaryAccuracy
from tensorflow.keras.losses import Loss
from tensorflow.keras.regularizers import l2
from keras.layers import Concatenate, Conv2D, Conv2DTranspose, MaxPool2D, Input, Activation, BatchNormalization, Dropout, concatenate
from keras.models import Model

In [None]:
# Init global variables: absolute paths of the train/validation data and the
# checkpoint directory, resolved against the working directory when this cell runs.
X_TRAIN_DIR, Y_TRAIN_DIR = map(os.path.abspath, ('x_train', 'y_train'))
X_VAL_DIR, Y_VAL_DIR = map(os.path.abspath, ('x_val', 'y_val'))
SAVED_WEIGHTS = os.path.abspath('saved_weights')

In [None]:
# Building data generators
class Data_train_generator(Sequence):
    """Keras ``Sequence`` yielding ``(image, mask)`` batches from two paired directories.

    Inputs are DICOM files in ``data_dir``; masks are image files (read with PIL)
    in ``y_true_dir``. Files are paired by position after sorting both directory
    listings, so the i-th file of each directory must belong together
    (e.g. ``046_SD.dcm`` / ``046_SD.png``).

    Args:
        data_dir: directory containing the input DICOM files.
        y_true_dir: directory containing the ground-truth mask images.
        batch_size: number of samples per batch (the last batch may be smaller).
        new_img_size: ``(rows, cols)`` every image and mask is resized to.

    Raises:
        ValueError: if the two directories contain a different number of entries.
    """

    # Maximum (rows, cols) kept from the top-left of each raw image before resizing.
    # 4608x1920 is exactly 8x of 576x240 -- presumably chosen so resizing to that
    # size is an integer-factor downsample. NOTE(review): confirm intent.
    _CROP_SHAPE = (4608, 1920)

    def __init__(self, data_dir, y_true_dir, batch_size, new_img_size: tuple[int, int]) -> None:
        super().__init__()
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort both listings so inputs and masks line up by index.
        self.data = sorted(os.listdir(data_dir))
        self.labels = sorted(os.listdir(y_true_dir))
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"{data_dir} has {len(self.data)} entries but {y_true_dir} has "
                f"{len(self.labels)}; inputs and masks must be paired one-to-one"
            )
        self.batch_size = batch_size

        self.data_dir = data_dir
        self.y_true_dir = y_true_dir

        self.new_img_size = new_img_size

    def __len__(self):
        """Number of batches per epoch (ceil of samples / batch_size)."""
        return (len(self.data) + self.batch_size - 1) // self.batch_size

    def _crop_resize(self, arr):
        """Crop ``arr`` to at most ``_CROP_SHAPE``, resize to ``new_img_size``, add a channel axis."""
        rows, cols = self._CROP_SHAPE
        # cv2 takes dsize as (width, height), i.e. (cols, rows) -- hence the reversal.
        resized = cv2.resize(arr[:rows, :cols], dsize=self.new_img_size[::-1], interpolation=cv2.INTER_AREA)
        return resized.reshape((*self.new_img_size, 1))

    def __open_png_y(self, file_path):
        """Load one mask as float32 with shape ``(*new_img_size, 1)``.

        INTER_AREA averaging can produce fractional values along mask edges.
        """
        # Context manager guarantees the file handle is released even if decoding fails.
        with Image.open(os.path.join(self.y_true_dir, file_path)) as img:
            data = np.array(img, dtype="float32")
        return self._crop_resize(data)

    def __open_dcm_x(self, file_path):
        """Load one DICOM image as float32 scaled by 1/255, shape ``(*new_img_size, 1)``."""
        dcm = pydicom.dcmread(os.path.join(self.data_dir, file_path))
        # NOTE(review): /255 assumes 8-bit pixel data; many DICOMs are 12/16-bit -- confirm.
        pixels = dcm.pixel_array.astype("float32") / 255
        return self._crop_resize(pixels)

    def __getitem__(self, index):
        """Return batch ``index`` as ``(images, masks)`` numpy arrays."""
        window = slice(index * self.batch_size, (index + 1) * self.batch_size)
        batch_x = np.array([self.__open_dcm_x(name) for name in self.data[window]])
        batch_y = np.array([self.__open_png_y(name) for name in self.labels[window]])
        return batch_x, batch_y

    def on_epoch_end(self):
        """Shuffle samples uniformly (Fisher-Yates), keeping inputs and masks paired."""
        for i in range(len(self.data) - 1, 0, -1):
            j = randint(0, i)
            self.data[i], self.data[j] = self.data[j], self.data[i]
            self.labels[i], self.labels[j] = self.labels[j], self.labels[i]

class Test_train_generator(Sequence):
    """Keras ``Sequence`` that repeats a single ``(image, mask)`` pair.

    Every sample is the DICOM ``data_file`` from ``data_dir`` paired with the mask
    ``y_file`` from ``y_true_dir``; useful for smoke tests or for loading one
    example through the same preprocessing as training.

    Args:
        data_file: DICOM file name inside ``data_dir``.
        y_file: mask image file name inside ``y_true_dir``.
        data_dir: directory containing ``data_file``.
        y_true_dir: directory containing ``y_file``.
        batch_size: number of samples per batch (the last batch may be smaller).
        new_img_size: ``(rows, cols)`` the image and mask are resized to.
        n_copies: how many times the pair is repeated (defaults to 500).
    """

    # Maximum (rows, cols) kept from the top-left of each raw image before resizing.
    # 4608x1920 is exactly 8x of 576x240. NOTE(review): confirm intent.
    _CROP_SHAPE = (4608, 1920)

    def __init__(self, data_file, y_file, data_dir, y_true_dir, batch_size, new_img_size: tuple[int, int], n_copies=500) -> None:
        super().__init__()
        self.data = [data_file] * n_copies
        self.labels = [y_file] * n_copies
        self.batch_size = batch_size

        self.data_dir = data_dir
        self.y_true_dir = y_true_dir

        self.new_img_size = new_img_size

    def __len__(self):
        """Number of batches per epoch (ceil of samples / batch_size)."""
        return (len(self.data) + self.batch_size - 1) // self.batch_size

    def _crop_resize(self, arr):
        """Crop ``arr`` to at most ``_CROP_SHAPE``, resize to ``new_img_size``, add a channel axis."""
        rows, cols = self._CROP_SHAPE
        # cv2 takes dsize as (width, height), i.e. (cols, rows) -- hence the reversal.
        resized = cv2.resize(arr[:rows, :cols], dsize=self.new_img_size[::-1], interpolation=cv2.INTER_AREA)
        return resized.reshape((*self.new_img_size, 1))

    def __open_png_y(self, file_path):
        """Load the mask as float32 with shape ``(*new_img_size, 1)``."""
        # Context manager guarantees the file handle is released even if decoding fails.
        with Image.open(os.path.join(self.y_true_dir, file_path)) as img:
            data = np.array(img, dtype="float32")
        return self._crop_resize(data)

    def __open_dcm_x(self, file_path):
        """Load the DICOM image as float32 scaled by 1/255, shape ``(*new_img_size, 1)``."""
        dcm = pydicom.dcmread(os.path.join(self.data_dir, file_path))
        # NOTE(review): /255 assumes 8-bit pixel data; many DICOMs are 12/16-bit -- confirm.
        pixels = dcm.pixel_array.astype("float32") / 255
        return self._crop_resize(pixels)

    def __getitem__(self, index):
        """Return batch ``index`` as ``(images, masks)`` numpy arrays."""
        window = slice(index * self.batch_size, (index + 1) * self.batch_size)
        batch_x = np.array([self.__open_dcm_x(name) for name in self.data[window]])
        batch_y = np.array([self.__open_png_y(name) for name in self.labels[window]])
        return batch_x, batch_y

    def on_epoch_end(self):
        """No-op: every entry is the same pair, so shuffling cannot change anything."""

In [None]:
dropout_rate = 0.1


def standard_unit(input_tensor, stage, nb_filter, kernel_size=3):
    """Two ELU convolutions, each followed by dropout.

    Layers are named 'conv<stage>_1', 'dp<stage>_1', 'conv<stage>_2', 'dp<stage>_2',
    so ``stage`` must be a string that is unique within the model.

    Args:
        input_tensor: Keras tensor to transform.
        stage: string tag identifying this node (e.g. '11', '23').
        nb_filter: number of filters in both convolutions.
        kernel_size: side length of the square convolution kernel.

    Returns:
        Output tensor of the second dropout layer.
    """
    out = input_tensor
    for suffix in ('_1', '_2'):
        out = Conv2D(
            nb_filter,
            (kernel_size, kernel_size),
            activation='elu',
            name='conv' + stage + suffix,
            kernel_initializer='he_normal',
            padding='same',
            kernel_regularizer=l2(1e-4),
        )(out)
        out = Dropout(dropout_rate, name='dp' + stage + suffix)(out)
    return out

def Nest_Net(img_rows, img_cols, color_type=1, num_class=1, deep_supervision=False):
    """Build a nested U-Net (UNet++-style) segmentation model.

    Node ``X(i, j)`` (variables ``convi_j``) sits at depth ``i`` of ``j``-th
    decoder column; every decoder node concatenates an upsampled deeper node
    with all earlier nodes at its own depth.

    Args:
        img_rows: input height; must be divisible by 16.
        img_cols: input width; must be divisible by 16.
        color_type: number of input channels.
        num_class: number of sigmoid output channels per head.
        deep_supervision: if True the model has four outputs
            ('output_1'..'output_4', from nodes conv1_2..conv1_5); otherwise
            only 'output_4' (from conv1_5).

    Returns:
        An uncompiled Keras ``Model``.

    Raises:
        ValueError: if img_rows or img_cols is not a multiple of 16.
    """
    # Four stride-2 max-pools followed by stride-2 transposed convs only give
    # matching spatial sizes at every skip concatenation when each dimension
    # is a multiple of 2**4; otherwise Keras fails later with an opaque shape error.
    downsample = 2 ** 4
    if img_rows % downsample or img_cols % downsample:
        raise ValueError(
            f"img_rows and img_cols must be multiples of {downsample}, "
            f"got ({img_rows}, {img_cols})"
        )

    nb_filter = [32, 64, 128, 256, 512]
    # Tensors are channels-last, so skip connections are concatenated on axis 3.
    concat_axis = 3
    img_input = Input(shape=(img_rows, img_cols, color_type), name='main_input')

    conv1_1 = standard_unit(img_input, stage='11', nb_filter=nb_filter[0])
    pool1 = MaxPool2D((2, 2), strides=(2, 2), name='pool1')(conv1_1)

    conv2_1 = standard_unit(pool1, stage='21', nb_filter=nb_filter[1])
    pool2 = MaxPool2D((2, 2), strides=(2, 2), name='pool2')(conv2_1)

    up1_2 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up12', padding='same')(conv2_1)
    conv1_2 = concatenate([up1_2, conv1_1], name='merge12', axis=concat_axis)
    conv1_2 = standard_unit(conv1_2, stage='12', nb_filter=nb_filter[0])

    conv3_1 = standard_unit(pool2, stage='31', nb_filter=nb_filter[2])
    pool3 = MaxPool2D((2, 2), strides=(2, 2), name='pool3')(conv3_1)

    up2_2 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up22', padding='same')(conv3_1)
    conv2_2 = concatenate([up2_2, conv2_1], name='merge22', axis=concat_axis)
    conv2_2 = standard_unit(conv2_2, stage='22', nb_filter=nb_filter[1])

    up1_3 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up13', padding='same')(conv2_2)
    conv1_3 = concatenate([up1_3, conv1_1, conv1_2], name='merge13', axis=concat_axis)
    conv1_3 = standard_unit(conv1_3, stage='13', nb_filter=nb_filter[0])

    conv4_1 = standard_unit(pool3, stage='41', nb_filter=nb_filter[3])
    pool4 = MaxPool2D((2, 2), strides=(2, 2), name='pool4')(conv4_1)

    up3_2 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up32', padding='same')(conv4_1)
    conv3_2 = concatenate([up3_2, conv3_1], name='merge32', axis=concat_axis)
    conv3_2 = standard_unit(conv3_2, stage='32', nb_filter=nb_filter[2])

    up2_3 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up23', padding='same')(conv3_2)
    conv2_3 = concatenate([up2_3, conv2_1, conv2_2], name='merge23', axis=concat_axis)
    conv2_3 = standard_unit(conv2_3, stage='23', nb_filter=nb_filter[1])

    up1_4 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up14', padding='same')(conv2_3)
    conv1_4 = concatenate([up1_4, conv1_1, conv1_2, conv1_3], name='merge14', axis=concat_axis)
    conv1_4 = standard_unit(conv1_4, stage='14', nb_filter=nb_filter[0])

    conv5_1 = standard_unit(pool4, stage='51', nb_filter=nb_filter[4])

    up4_2 = Conv2DTranspose(nb_filter[3], (2, 2), strides=(2, 2), name='up42', padding='same')(conv5_1)
    conv4_2 = concatenate([up4_2, conv4_1], name='merge42', axis=concat_axis)
    conv4_2 = standard_unit(conv4_2, stage='42', nb_filter=nb_filter[3])

    up3_3 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up33', padding='same')(conv4_2)
    conv3_3 = concatenate([up3_3, conv3_1, conv3_2], name='merge33', axis=concat_axis)
    conv3_3 = standard_unit(conv3_3, stage='33', nb_filter=nb_filter[2])

    up2_4 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up24', padding='same')(conv3_3)
    conv2_4 = concatenate([up2_4, conv2_1, conv2_2, conv2_3], name='merge24', axis=concat_axis)
    conv2_4 = standard_unit(conv2_4, stage='24', nb_filter=nb_filter[1])

    up1_5 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up15', padding='same')(conv2_4)
    conv1_5 = concatenate([up1_5, conv1_1, conv1_2, conv1_3, conv1_4], name='merge15', axis=concat_axis)
    conv1_5 = standard_unit(conv1_5, stage='15', nb_filter=nb_filter[0])

    # One 1x1 sigmoid head per top-row decoder node: output_1 .. output_4.
    nestnet_outputs = [
        Conv2D(num_class, (1, 1), activation='sigmoid', name=f'output_{i}',
               kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(1e-4))(features)
        for i, features in enumerate((conv1_2, conv1_3, conv1_4, conv1_5), start=1)
    ]

    if deep_supervision:
        return Model(img_input, nestnet_outputs)
    return Model(img_input, [nestnet_outputs[-1]])

In [None]:
# training
def make_dice_loss(smooth=1e-6, gama=2):
    """Return a soft Dice loss function suitable for ``Model.compile``.

    loss = 1 - (2 * sum(p * t) + smooth) / (sum(p ** gama) + sum(t ** gama) + smooth)

    The sums run over every element of the batch at once (not per sample).

    Args:
        smooth: constant added to numerator and denominator to avoid 0/0.
        gama: exponent applied to predictions and targets in the denominator.
    """
    def dice_loss(y_true, y_pred):
        truth = tf.cast(y_true, dtype=tf.float32)
        pred = tf.cast(y_pred, tf.float32)
        overlap = tf.reduce_sum(tf.multiply(pred, truth))
        denom = tf.reduce_sum(pred ** gama) + tf.reduce_sum(truth ** gama)
        return 1 - tf.divide(2 * overlap + smooth, denom + smooth)

    return dice_loss

BATCH_SIZE = 16
IMG_SIZE = (576, 240)  # (rows, cols) fed to the network; both must be multiples of 16
EPOCHS = 20
OUTPUT_NAMES = ['output_1', 'output_2', 'output_3', 'output_4']

data_gen = Data_train_generator(X_TRAIN_DIR, Y_TRAIN_DIR, BATCH_SIZE, IMG_SIZE)
val_gen = Data_train_generator(X_VAL_DIR, Y_VAL_DIR, BATCH_SIZE, IMG_SIZE)

# Write checkpoints into the configured directory and make sure it exists
# before the first save.
os.makedirs(SAVED_WEIGHTS, exist_ok=True)
model_checkpoint = ModelCheckpoint(
    # NOTE(review): Keras 3 requires a '.weights.h5' suffix when
    # save_weights_only=True; '.hdf5' is only accepted by Keras 2.
    filepath=os.path.join(SAVED_WEIGHTS, '{val_loss:.3f}-weights_unetpp.hdf5'),
    save_best_only=True,
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
)

model_unet = Nest_Net(*IMG_SIZE, color_type=1, num_class=1, deep_supervision=True)

# Every deep-supervision head gets the same Dice loss with equal weight;
# keying both dicts by output name keeps them independent of output order.
model_unet.compile(
    optimizer='Adam',
    loss={name: make_dice_loss() for name in OUTPUT_NAMES},
    loss_weights={name: 1.0 for name in OUTPUT_NAMES},
)

history_unet = model_unet.fit(x=data_gen, epochs=EPOCHS, validation_data=val_gen, callbacks=[model_checkpoint])

In [None]:
# testing: run the trained model on one validation scan and compare with its mask.
# Must match the (rows, cols) size the model was built and trained with.
img_rows, img_cols = 576, 240

# batch_size=1 so only a single (image, mask) pair is read from disk.
sample_gen = Test_train_generator('046_SD.dcm', '046_SD.png', X_VAL_DIR, Y_VAL_DIR, 1, (img_rows, img_cols))
test_img, test_true_img = sample_gen[0]  # test_img: (1, rows, cols, 1)
test_true_img = test_true_img[0]          # drop batch axis -> (rows, cols, 1)

preds = model_unet.predict(test_img, batch_size=1)
# With deep supervision predict() returns one probability map per output head;
# average them. A single-output model returns the array directly.
if isinstance(preds, (list, tuple)):
    preds = np.mean(preds, axis=0)
res = preds.reshape((img_rows, img_cols))

fig, axes = plt.subplots(nrows=1, ncols=4)
fig.set_figwidth(12)
panels = (
    ('probability', res),
    ('prediction', np.where(res > 0.5, 1, 0)),
    ('input', test_img.reshape((img_rows, img_cols))),
    ('mask', test_true_img.reshape((img_rows, img_cols))),
)
for ax, (title, image) in zip(axes, panels):
    sns.heatmap(image, ax=ax)
    ax.set_title(title)
plt.show()