In [1]:
import os
import numpy as np
import cv2
from glob import glob
from tqdm import tqdm
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import Recall, Precision
from metrics import dice_loss, dice_coef, iou

In [2]:
def iou(y_true, y_pred):
    def f(y_true, y_pred):
        intersection = (y_true * y_pred).sum()
        union = y_true.sum() + y_pred.sum() - intersection
        x = (intersection + 1e-15) / (union + 1e-15)
        x = x.astype(np.float32)
        return x
    return tf.numpy_function(f, [y_true, y_pred], tf.float32)

smooth = 1e-15


def dice_coef(y_true, y_pred):
    y_true = tf.keras.layers.Flatten()(y_true)
    y_pred = tf.keras.layers.Flatten()(y_pred)
    intersection = tf.reduce_sum(y_true * y_pred)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth)

def dice_loss(y_true, y_pred):
    return 1.0 - dice_coef(y_true, y_pred)

In [3]:
""" Global parameters """
H = 512
W = 512

In [4]:
def create_dir(path):
    """ Create a directory. """
    if not os.path.exists(path):
        os.makedirs(path)

In [5]:
def load_data(path, split=0.1):
    images = sorted(glob(os.path.join(path, "CXR_png", "*.png")))
    masks = sorted(glob(os.path.join(path, "masks", "*.png")))


    split_size = int(len(images) * split)

    train_x, valid_x = train_test_split(images, test_size=split_size, random_state=42)
    train_y, valid_y = train_test_split(masks, test_size=split_size, random_state=42)

    train_x, test_x = train_test_split(train_x, test_size=split_size, random_state=42)
    train_y, test_y = train_test_split(train_y, test_size=split_size, random_state=42)


    return (train_x, train_y), (valid_x, valid_y), (test_x, test_y)

In [6]:
def read_image(path):
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    x = cv2.resize(x, (W, H))
    x = x/255.0
    x = x.astype(np.float32)
    return x

In [7]:
def read_mask(path):
    x = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    x = cv2.resize(x, (W, H))
    x = x/np.max(x)
    x = x > 0.5
    x = x.astype(np.float32)
    x = np.expand_dims(x, axis=-1)
    return x

In [8]:
def tf_parse(x, y):
    def _parse(x, y):
        x = x.decode()
        y = y.decode()

        x = read_image(x)
        y = read_mask(y)
        return x, y

    x, y = tf.numpy_function(_parse, [x, y], [tf.float32, tf.float32])
    x.set_shape([H, W, 3])
    y.set_shape([H, W, 1])
    return x, y

In [9]:
def tf_dataset(X, Y, batch=8):
    dataset = tf.data.Dataset.from_tensor_slices((X, Y))
    dataset = dataset.shuffle(buffer_size=200)
    dataset = dataset.map(tf_parse)
    dataset = dataset.batch(batch)
    dataset = dataset.prefetch(4)
    return dataset

In [10]:
""" Seeding """
np.random.seed(42)
tf.random.set_seed(42)

In [11]:
""" Directory for storing files """
create_dir("MaskUPTfiles")

In [12]:
""" Hyperparameters """
batch_size = 2
lr = 1e-5
num_epochs = 10
model_path = os.path.join("MaskUPTfiles", "model.h5")
csv_path = os.path.join("MaskUPTfiles", "data.csv")

In [13]:
""" Dataset """
dataset_path = "LungSegmentation"
(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = load_data(dataset_path)

print(f"Train: {len(train_x)} - {len(train_y)} ")
print(f"Valid: {len(valid_x)} - {len(valid_y)} ")
print(f"Test: {len(test_x)} - {len(test_y)} ")


Train: 640 - 544 
Valid: 80 - 80 
Test: 80 - 80 


In [14]:
train_x = train_x[:544]

In [15]:
train_dataset = tf_dataset(train_x, train_y, batch=batch_size)
valid_dataset = tf_dataset(valid_x, valid_y, batch=batch_size)

2022-09-22 16:29:49.936068: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-09-22 16:29:50.452511: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30616 MB memory:  -> device: 0, name: Quadro GV100, pci bus id: 0000:65:00.0, compute capability: 7.0


## Bulid Model

In [16]:
len(train_dataset), len(valid_dataset)

(272, 40)

In [17]:

def conv_block(input, num_filters):
    x = Conv2D(num_filters, 3, padding="same")(input)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv2D(num_filters, 3, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    return x

def encoder_block(input, num_filters):
    x = conv_block(input, num_filters)
    p = MaxPool2D((2, 2))(x)
    return x, p

def decoder_block(input, skip_features, num_filters):
    x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
    x = Concatenate()([x, skip_features])
    x = conv_block(x, num_filters)
    return x

def build_unet(input_shape):
    inputs = Input(input_shape)

    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)
    s4, p4 = encoder_block(p3, 512)

    b1 = conv_block(p4, 1024)

    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)

    model = Model(inputs, outputs, name="U-Net")
    return model

In [18]:
input_shape = (512, 512, 3)
model = build_unet(input_shape)
model.summary()

Model: "U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 512, 512, 64  1792        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 512, 512, 64  256        ['conv2d[0][0]']                 
 alization)                     )                                                             

                                                                                                  
 conv2d_8 (Conv2D)              (None, 32, 32, 1024  4719616     ['max_pooling2d_3[0][0]']        
                                )                                                                 
                                                                                                  
 batch_normalization_8 (BatchNo  (None, 32, 32, 1024  4096       ['conv2d_8[0][0]']               
 rmalization)                   )                                                                 
                                                                                                  
 activation_8 (Activation)      (None, 32, 32, 1024  0           ['batch_normalization_8[0][0]']  
                                )                                                                 
                                                                                                  
 conv2d_9 

                                                                                                  
 activation_15 (Activation)     (None, 256, 256, 12  0           ['batch_normalization_15[0][0]'] 
                                8)                                                                
                                                                                                  
 conv2d_transpose_3 (Conv2DTran  (None, 512, 512, 64  32832      ['activation_15[0][0]']          
 spose)                         )                                                                 
                                                                                                  
 concatenate_3 (Concatenate)    (None, 512, 512, 12  0           ['conv2d_transpose_3[0][0]',     
                                8)                                'activation_1[0][0]']           
                                                                                                  
 conv2d_16

In [19]:
from tensorflow.keras.utils import CustomObjectScope

In [20]:
#with CustomObjectScope({'iou': iou, 'dice_coef': dice_coef, 'dice_loss': dice_loss}):
model = build_unet((H, W, 3))
metrics = [dice_coef, iou, Recall(), Precision()]

In [21]:
mse_loss = tf.keras.losses.MeanSquaredError()

In [22]:
model.compile(loss=dice_loss, optimizer=Adam(lr), metrics=metrics)


In [23]:
callbacks = [
    ModelCheckpoint(model_path, verbose=1, save_best_only=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=1e-7, verbose=1),
    CSVLogger(csv_path)
]


In [24]:
model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=valid_dataset,
    callbacks=callbacks
)

Epoch 1/10


2022-09-22 16:29:55.583624: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8204


Epoch 1: val_loss improved from inf to 0.51495, saving model to MaskUPTfiles/model.h5
Epoch 2/10
Epoch 2: val_loss improved from 0.51495 to 0.33298, saving model to MaskUPTfiles/model.h5
Epoch 3/10
Epoch 3: val_loss improved from 0.33298 to 0.30733, saving model to MaskUPTfiles/model.h5
Epoch 4/10
Epoch 4: val_loss improved from 0.30733 to 0.29369, saving model to MaskUPTfiles/model.h5
Epoch 5/10
Epoch 5: val_loss improved from 0.29369 to 0.29051, saving model to MaskUPTfiles/model.h5
Epoch 6/10
Epoch 6: val_loss improved from 0.29051 to 0.28065, saving model to MaskUPTfiles/model.h5
Epoch 7/10
Epoch 7: val_loss improved from 0.28065 to 0.27984, saving model to MaskUPTfiles/model.h5
Epoch 8/10
Epoch 8: val_loss did not improve from 0.27984
Epoch 9/10
Epoch 9: val_loss did not improve from 0.27984
Epoch 10/10
Epoch 10: val_loss did not improve from 0.27984


<keras.callbacks.History at 0x7f7d4866e520>

In [26]:
""" Predicting the mask """
for x, y in tqdm(zip(test_x, test_y), total=len(test_x)):
    """ Extracing the image name. """
    image_name = x.split("/")[-1]

    """ Reading the image """
    ori_x = cv2.imread(x, cv2.IMREAD_COLOR)
    ori_x = cv2.resize(ori_x, (W, H))
    x = ori_x/255.0
    x = x.astype(np.float32)
    x = np.expand_dims(x, axis=0)

    """ Reading the mask """
    ori_y = cv2.imread(y, cv2.IMREAD_GRAYSCALE)
    ori_y = cv2.resize(ori_y, (W, H))
    ori_y = np.expand_dims(ori_y, axis=-1)  ## (512, 512, 1)
    ori_y = np.concatenate([ori_y, ori_y, ori_y], axis=-1)  ## (512, 512, 3)

    """ Predicting the mask. """
    y_pred = model.predict(x)[0] > 0.5
    y_pred = y_pred.astype(np.int32)

    """ Saving the predicted mask along with the image and GT """
    save_image_path = f"MaskUPTfiles/results/{image_name}"
    y_pred = np.concatenate([y_pred, y_pred, y_pred], axis=-1)

    sep_line = np.ones((H, 10, 3)) * 255

    cat_image = np.concatenate([ori_x, sep_line, ori_y, sep_line, y_pred*255], axis=1)
    cv2.imwrite(save_image_path, cat_image)

  0%|                                                    | 0/80 [00:00<?, ?it/s]



  1%|▌                                           | 1/80 [00:00<00:17,  4.56it/s]



  2%|█                                           | 2/80 [00:00<00:15,  4.93it/s]



  4%|█▋                                          | 3/80 [00:00<00:13,  5.55it/s]



  5%|██▏                                         | 4/80 [00:00<00:14,  5.22it/s]



  6%|██▊                                         | 5/80 [00:01<00:18,  4.00it/s]



  8%|███▎                                        | 6/80 [00:01<00:17,  4.29it/s]



  9%|███▊                                        | 7/80 [00:01<00:17,  4.17it/s]



 10%|████▍                                       | 8/80 [00:01<00:15,  4.52it/s]



 11%|████▉                                       | 9/80 [00:01<00:15,  4.60it/s]



 12%|█████▍                                     | 10/80 [00:02<00:15,  4.41it/s]



 14%|█████▉                                     | 11/80 [00:02<00:15,  4.43it/s]



 15%|██████▍                                    | 12/80 [00:02<00:14,  4.70it/s]



 16%|██████▉                                    | 13/80 [00:02<00:14,  4.74it/s]



 18%|███████▌                                   | 14/80 [00:03<00:13,  4.98it/s]



 19%|████████                                   | 15/80 [00:03<00:12,  5.04it/s]



 20%|████████▌                                  | 16/80 [00:03<00:12,  5.14it/s]



 21%|█████████▏                                 | 17/80 [00:03<00:12,  5.08it/s]



 22%|█████████▋                                 | 18/80 [00:03<00:12,  5.08it/s]



 24%|██████████▏                                | 19/80 [00:04<00:12,  4.91it/s]



 25%|██████████▊                                | 20/80 [00:04<00:11,  5.02it/s]



 26%|███████████▎                               | 21/80 [00:04<00:13,  4.41it/s]



 28%|███████████▊                               | 22/80 [00:04<00:13,  4.33it/s]



 29%|████████████▎                              | 23/80 [00:04<00:12,  4.47it/s]



 30%|████████████▉                              | 24/80 [00:05<00:12,  4.52it/s]



 31%|█████████████▍                             | 25/80 [00:05<00:11,  4.83it/s]



 32%|█████████████▉                             | 26/80 [00:05<00:11,  4.62it/s]



 34%|██████████████▌                            | 27/80 [00:05<00:12,  4.08it/s]



 35%|███████████████                            | 28/80 [00:06<00:11,  4.42it/s]



 36%|███████████████▌                           | 29/80 [00:06<00:10,  4.74it/s]



 38%|████████████████▏                          | 30/80 [00:06<00:10,  4.64it/s]



 39%|████████████████▋                          | 31/80 [00:06<00:10,  4.65it/s]



 40%|█████████████████▏                         | 32/80 [00:06<00:09,  4.87it/s]



 41%|█████████████████▋                         | 33/80 [00:07<00:11,  3.97it/s]



 42%|██████████████████▎                        | 34/80 [00:07<00:11,  3.95it/s]



 44%|██████████████████▊                        | 35/80 [00:07<00:11,  4.00it/s]



 45%|███████████████████▎                       | 36/80 [00:07<00:10,  4.21it/s]



 46%|███████████████████▉                       | 37/80 [00:08<00:09,  4.46it/s]



 48%|████████████████████▍                      | 38/80 [00:08<00:09,  4.34it/s]



 49%|████████████████████▉                      | 39/80 [00:08<00:09,  4.50it/s]



 50%|█████████████████████▌                     | 40/80 [00:08<00:10,  3.93it/s]



 51%|██████████████████████                     | 41/80 [00:09<00:09,  4.24it/s]



 52%|██████████████████████▌                    | 42/80 [00:09<00:08,  4.40it/s]



 54%|███████████████████████                    | 43/80 [00:09<00:09,  4.07it/s]



 55%|███████████████████████▋                   | 44/80 [00:09<00:08,  4.30it/s]



 56%|████████████████████████▏                  | 45/80 [00:09<00:07,  4.43it/s]



 57%|████████████████████████▋                  | 46/80 [00:10<00:08,  3.97it/s]



 59%|█████████████████████████▎                 | 47/80 [00:10<00:07,  4.21it/s]



 60%|█████████████████████████▊                 | 48/80 [00:10<00:07,  4.43it/s]



 61%|██████████████████████████▎                | 49/80 [00:10<00:06,  4.47it/s]



 62%|██████████████████████████▉                | 50/80 [00:11<00:07,  3.98it/s]



 64%|███████████████████████████▍               | 51/80 [00:11<00:06,  4.37it/s]



 65%|███████████████████████████▉               | 52/80 [00:11<00:06,  4.52it/s]



 66%|████████████████████████████▍              | 53/80 [00:11<00:05,  4.62it/s]



 68%|█████████████████████████████              | 54/80 [00:12<00:05,  4.64it/s]



 69%|█████████████████████████████▌             | 55/80 [00:12<00:05,  4.51it/s]



 70%|██████████████████████████████             | 56/80 [00:12<00:05,  4.56it/s]



 71%|██████████████████████████████▋            | 57/80 [00:12<00:06,  3.63it/s]



 72%|███████████████████████████████▏           | 58/80 [00:13<00:05,  3.90it/s]



 74%|███████████████████████████████▋           | 59/80 [00:13<00:05,  4.14it/s]



 75%|████████████████████████████████▎          | 60/80 [00:13<00:04,  4.05it/s]



 76%|████████████████████████████████▊          | 61/80 [00:13<00:04,  4.35it/s]



 78%|█████████████████████████████████▎         | 62/80 [00:14<00:04,  4.29it/s]



 79%|█████████████████████████████████▊         | 63/80 [00:14<00:03,  4.59it/s]



 80%|██████████████████████████████████▍        | 64/80 [00:14<00:03,  4.79it/s]



 81%|██████████████████████████████████▉        | 65/80 [00:14<00:03,  4.96it/s]



 82%|███████████████████████████████████▍       | 66/80 [00:14<00:03,  4.33it/s]



 84%|████████████████████████████████████       | 67/80 [00:15<00:02,  4.34it/s]



 85%|████████████████████████████████████▌      | 68/80 [00:15<00:02,  4.24it/s]



 86%|█████████████████████████████████████      | 69/80 [00:15<00:02,  4.49it/s]



 88%|█████████████████████████████████████▋     | 70/80 [00:15<00:02,  4.62it/s]



 89%|██████████████████████████████████████▏    | 71/80 [00:15<00:01,  4.74it/s]



 90%|██████████████████████████████████████▋    | 72/80 [00:16<00:01,  4.05it/s]



 91%|███████████████████████████████████████▏   | 73/80 [00:16<00:01,  4.39it/s]



 92%|███████████████████████████████████████▊   | 74/80 [00:16<00:01,  4.61it/s]



 94%|████████████████████████████████████████▎  | 75/80 [00:16<00:01,  3.85it/s]



 95%|████████████████████████████████████████▊  | 76/80 [00:17<00:00,  4.20it/s]



 96%|█████████████████████████████████████████▍ | 77/80 [00:17<00:00,  4.38it/s]



 98%|█████████████████████████████████████████▉ | 78/80 [00:17<00:00,  4.50it/s]



 99%|██████████████████████████████████████████▍| 79/80 [00:17<00:00,  4.02it/s]



100%|███████████████████████████████████████████| 80/80 [00:18<00:00,  4.42it/s]
