In [1]:
import os
import numpy as np
import cv2
from glob import glob
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import Recall, Precision

In [2]:
def load_data(path, split=0.1):
    """Collect image/mask paths and split them into train, validation and test sets.

    Images are globbed from ``<path>/CXR_png/*.png``; left and right lung masks from
    ``<path>/ManualMask/leftMask/*.png`` and ``<path>/ManualMask/rightMask/*.png``.
    The three sorted lists are paired index-by-index, so they are assumed to share
    file names -- NOTE(review): only the counts are checked here, not the names.

    ``int(len(images) * split)`` samples go to validation and the same number to
    test; with the default ``split=0.1`` that is roughly an 80/10/10 split.

    Args:
        path: Dataset root directory.
        split: Fraction of all images used for validation (and, again, for test).

    Returns:
        ``(train_x, train_y_l, train_y_r), (val_x, val_y_l, val_y_r),
        (test_x, test_y_l, test_y_r)`` -- lists of file paths.

    Raises:
        ValueError: if the image/left-mask/right-mask counts differ, no images are
            found, or ``split`` would leave an empty validation, test or training set.
    """
    images = sorted(glob(os.path.join(path, "CXR_png", "*.png")))
    masks_l = sorted(glob(os.path.join(path, "ManualMask", "leftMask", "*.png")))
    masks_r = sorted(glob(os.path.join(path, "ManualMask", "rightMask", "*.png")))

    if not (len(images) == len(masks_l) == len(masks_r)):
        raise ValueError(
            f"Found {len(images)} images, {len(masks_l)} left masks and "
            f"{len(masks_r)} right masks under {path!r}; counts must match."
        )
    if not images:
        raise ValueError(f"No images found in {os.path.join(path, 'CXR_png')!r}.")

    split_size = int(len(images) * split)
    # Both splits below need split_size >= 1 and a non-empty remainder for training.
    if split_size < 1 or 2 * split_size >= len(images):
        raise ValueError(
            f"split={split} gives {split_size} validation/test samples out of "
            f"{len(images)}; need at least one of each and a non-empty training set."
        )

    # Split (image, left mask, right mask) triples together so the three lists can
    # never drift out of alignment. train_test_split's shuffle depends only on the
    # number of samples and random_state, so this yields the same partition as
    # splitting each list separately with random_state=42.
    samples = list(zip(images, masks_l, masks_r))
    train, val = train_test_split(samples, test_size=split_size, random_state=42)
    train, test = train_test_split(train, test_size=split_size, random_state=42)

    def _columns(triples):
        # Transpose a list of (x, y_l, y_r) triples into three parallel lists.
        return tuple([triple[i] for triple in triples] for i in range(3))

    return _columns(train), _columns(val), _columns(test)

In [3]:
def imageread(path, width=512, height=512):
    """Read an image from disk as a resized, [0, 1]-scaled float32 array.

    Args:
        path: File path of the image.
        width: Target width passed to ``cv2.resize``.
        height: Target height passed to ``cv2.resize``.

    Returns:
        ``np.float32`` array of shape ``(height, width, 3)`` in OpenCV's BGR
        channel order, with values in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising; without
        # this check the error surfaces later as an obscure cv2.resize assertion.
        raise FileNotFoundError(f"Could not read image: {path}")
    image = cv2.resize(image, (width, height))
    return (image / 255.0).astype(np.float32)

In [4]:
def maskread(path_l, path_r, width=512, height=512):
    """Read the left and right lung masks and merge them into one binary mask.

    Args:
        path_l: File path of the left-lung mask image.
        path_r: File path of the right-lung mask image.
        width: Target width passed to ``cv2.resize``.
        height: Target height passed to ``cv2.resize``.

    Returns:
        ``np.float32`` array of shape ``(height, width, 1)`` containing 0.0/1.0.
        A pixel is 1.0 where the resized union exceeds half of its maximum value;
        an all-black union yields an all-zero mask.

    Raises:
        FileNotFoundError: if OpenCV cannot read either mask file.
    """
    mask_l = cv2.imread(path_l, cv2.IMREAD_GRAYSCALE)
    mask_r = cv2.imread(path_r, cv2.IMREAD_GRAYSCALE)
    for mask_path, mask in ((path_l, mask_l), (path_r, mask_r)):
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

    # Union of the two masks. `mask_l + mask_r` on uint8 wraps modulo 256 wherever
    # the masks overlap (e.g. 128 + 128 -> 0), so take the element-wise maximum.
    mask = np.maximum(mask_l, mask_r)
    mask = cv2.resize(mask, (width, height))

    peak = mask.max()
    if peak == 0:
        # Avoid 0/0 -> NaN (and its RuntimeWarning) for an empty mask.
        binary = np.zeros(mask.shape, dtype=bool)
    else:
        binary = (mask / peak) > 0.5
    return np.expand_dims(binary.astype(np.float32), axis=-1)

In [5]:
def tf_parse(x, y_l, y_r):
    """Load one (image, mask) pair inside a ``tf.data`` pipeline.

    Args:
        x: Scalar string tensor with the image path.
        y_l: Scalar string tensor with the left-mask path.
        y_r: Scalar string tensor with the right-mask path.

    Returns:
        ``(image, mask)`` float32 tensors with static shapes ``[512, 512, 3]``
        and ``[512, 512, 1]``.
    """
    def _load(image_path, left_path, right_path):
        # tf.numpy_function hands string tensors over as bytes.
        image_path, left_path, right_path = (
            raw.decode() for raw in (image_path, left_path, right_path)
        )
        return imageread(image_path), maskread(left_path, right_path)

    image, mask = tf.numpy_function(_load, [x, y_l, y_r], [tf.float32, tf.float32])
    # numpy_function drops static shape information; restore it for Keras.
    image.set_shape([512, 512, 3])
    mask.set_shape([512, 512, 1])
    return image, mask

In [6]:
def tf_dataset(X, Y_l, Y_r, batch=8):
    """Build a shuffled, batched, prefetched pipeline of (image, mask) pairs.

    Args:
        X: Image file paths.
        Y_l: Left-mask file paths, parallel to ``X``.
        Y_r: Right-mask file paths, parallel to ``X``.
        batch: Batch size.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, masks)`` batches.
    """
    return (
        tf.data.Dataset.from_tensor_slices((X, Y_l, Y_r))
        .shuffle(buffer_size=200)
        .map(tf_parse)
        .batch(batch)
        .prefetch(4)
    )

In [7]:
# Hyperparameters: images per step, Adam learning rate, number of training epochs.
batch_size, lr, epochs = 2, 1e-5, 30
# File the ModelCheckpoint callback saves the best model to.
model_path = "/content/model.h5"

In [8]:
# Dataset: Montgomery CXR set (chest X-rays + left/right lung masks) on Google Drive.
dataset_path = '/content/drive/MyDrive/NLM-MontgomeryCXRSet/MontgomerySet'
(
    (train_x, train_y_l, train_y_r),
    (val_x, val_y_l, val_y_r),
    (test_x, test_y_l, test_y_r),
) = load_data(dataset_path)

In [9]:
# Input pipelines for training and validation, both batched by `batch_size`.
train_dataset, val_dataset = (
    tf_dataset(train_x, train_y_l, train_y_r, batch=batch_size),
    tf_dataset(val_x, val_y_l, val_y_r, batch=batch_size),
)

In [10]:
def iou(y_true, y_pred):
    """Soft intersection-over-union between ground-truth and predicted masks.

    All elements of the batch are summed together, and the raw predicted
    probabilities are used (no thresholding). The 1e-15 epsilon keeps the ratio
    defined, and equal to 1, when both masks are empty.

    Written with TensorFlow ops rather than ``tf.numpy_function``. This keeps the
    metric in the compiled graph (no Python round trip per step) and gives the
    result a known scalar shape.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted probability tensor of the same shape.

    Returns:
        Scalar float32 tensor.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
    return (intersection + 1e-15) / (union + 1e-15)

smooth = 1e-15  # keeps the Dice ratio defined (== 1) when both masks are empty


def dice_coef(y_true, y_pred):
    """Soft Dice coefficient computed over all elements of the batch.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted probability tensor of the same shape.

    Returns:
        Scalar float32 tensor, ``(2*|A.B| + smooth) / (|A| + |B| + smooth)``.
    """
    # Flatten everything (batch included) with a plain reshape. The previous code
    # instantiated a new Keras Flatten layer on every call; the sums are identical.
    y_true = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    intersection = tf.reduce_sum(y_true * y_pred)
    return (2.0 * intersection + smooth) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth)


def dice_loss(y_true, y_pred):
    """Dice loss, ``1 - dice_coef(y_true, y_pred)``."""
    return 1.0 - dice_coef(y_true, y_pred)

In [11]:
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model

In [12]:
def conv_block(input, num_filters):
    """Apply two successive 3x3 Conv2D -> BatchNormalization -> ReLU stages.

    Args:
        input: Input Keras tensor.
        num_filters: Output channels of each convolution.

    Returns:
        Output tensor; spatial size is preserved by ``padding="same"``.
    """
    out = input
    for _ in range(2):
        out = Conv2D(num_filters, 3, padding="same")(out)
        out = BatchNormalization()(out)
        out = Activation("relu")(out)
    return out

def encoder_block(input, num_filters):
    """U-Net encoder stage: a conv block followed by 2x2 max-pooling.

    Returns:
        ``(skip, pooled)``: the pre-pooling features, used as the skip connection,
        and the downsampled tensor fed to the next stage.
    """
    skip = conv_block(input, num_filters)
    return skip, MaxPool2D((2, 2))(skip)

def decoder_block(input, skip_features, num_filters):
    """U-Net decoder stage: 2x transposed-conv upsampling, skip concat, conv block."""
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

def build_unet(input_shape):
    """Assemble a 4-level U-Net with a single-channel sigmoid output.

    Encoder widths are 64, 128, 256 and 512 filters and the bottleneck has 1024.
    The decoder mirrors the encoder, concatenating each level's skip features.
    After four 2x2 poolings the spatial dims should be divisible by 16 for the
    skip shapes to line up.

    Args:
        input_shape: Shape of one input image, e.g. ``(512, 512, 3)``.

    Returns:
        An uncompiled Keras ``Model`` named ``"U-Net"``.
    """
    widths = (64, 128, 256, 512)
    inputs = Input(input_shape)

    skips = []
    x = inputs
    for width in widths:
        skip, x = encoder_block(x, width)
        skips.append(skip)

    x = conv_block(x, 1024)  # bottleneck

    for width, skip in zip(reversed(widths), reversed(skips)):
        x = decoder_block(x, skip, width)

    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(x)
    return Model(inputs, outputs, name="U-Net")

In [13]:
model = build_unet((512, 512, 3))
# Optimise the Dice loss; also report Dice, IoU, recall and precision.
metrics = [dice_coef, iou, Recall(), Precision()]
model.compile(optimizer=Adam(lr), loss=dice_loss, metrics=metrics)

In [14]:
# Print the layer-by-layer architecture with output shapes and parameter counts.
model.summary()

Model: "U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 512, 512, 64  1792        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 512, 512, 64  256        ['conv2d[0][0]']                 
 alization)                     )                                                             

In [15]:
# Save only the best model (lowest val_loss), and divide the learning rate by 10
# after 5 epochs without val_loss improvement (floor 1e-8).
callbacks = [
    ModelCheckpoint(model_path, monitor="val_loss", save_best_only=True, verbose=1),
    ReduceLROnPlateau(monitor="val_loss", factor=0.1, patience=5, min_lr=1e-8, verbose=1),
]

In [16]:
# Train; per-epoch loss and metric values end up in `history.history`.
history = model.fit(train_dataset, validation_data=val_dataset, epochs=epochs, callbacks=callbacks)

Epoch 1/30
Epoch 1: val_loss improved from inf to 0.65390, saving model to /content/model.h5
Epoch 2/30
Epoch 2: val_loss improved from 0.65390 to 0.62336, saving model to /content/model.h5
Epoch 3/30
Epoch 3: val_loss improved from 0.62336 to 0.61061, saving model to /content/model.h5
Epoch 4/30
Epoch 4: val_loss improved from 0.61061 to 0.60634, saving model to /content/model.h5
Epoch 5/30
Epoch 5: val_loss did not improve from 0.60634
Epoch 6/30
Epoch 6: val_loss did not improve from 0.60634
Epoch 7/30
Epoch 7: val_loss improved from 0.60634 to 0.47755, saving model to /content/model.h5
Epoch 8/30
Epoch 8: val_loss improved from 0.47755 to 0.25108, saving model to /content/model.h5
Epoch 9/30
Epoch 9: val_loss did not improve from 0.25108
Epoch 10/30
Epoch 10: val_loss did not improve from 0.25108
Epoch 11/30
Epoch 11: val_loss improved from 0.25108 to 0.14831, saving model to /content/model.h5
Epoch 12/30
Epoch 12: val_loss improved from 0.14831 to 0.13534, saving model to /content

In [17]:
from tensorflow.keras.utils import CustomObjectScope
# Reload the best checkpoint saved by ModelCheckpoint. The custom loss and metric
# functions must be registered by name so Keras can deserialize the compiled model.
# Use `model_path` instead of repeating the literal, so changing the hyperparameter
# cannot make this load a stale or missing file.
model = tf.keras.models.load_model(
    model_path,
    custom_objects={'iou': iou, 'dice_coef': dice_coef, 'dice_loss': dice_loss},
)

In [18]:
    # Predict lung masks for the held-out test split and save, for each image, a strip
    # "input | ground truth | prediction" to /content/<index>.png.
    from tqdm import tqdm
    import matplotlib.pyplot as plt
    size = (512, 512)  # (width, height); must match the model's input size
    sep_line = np.full((size[1], 10, 3), 255, dtype=np.uint8)  # white 10-px divider
    samples = zip(test_x, test_y_l, test_y_r)
    for ct, (image_path, mask_l_path, mask_r_path) in enumerate(tqdm(samples, total=len(test_x))):
        # Input image: resized and scaled to [0, 1] float32, as imageread does for training.
        ori_x = cv2.resize(cv2.imread(image_path, cv2.IMREAD_COLOR), size)
        batch = np.expand_dims((ori_x / 255.0).astype(np.float32), axis=0)

        # Ground truth: union of the left and right masks. np.maximum avoids the
        # uint8 wrap-around that `+` produces wherever the two masks overlap.
        ori_y_l = cv2.imread(mask_l_path, cv2.IMREAD_GRAYSCALE)
        ori_y_r = cv2.imread(mask_r_path, cv2.IMREAD_GRAYSCALE)
        ori_y = cv2.resize(np.maximum(ori_y_l, ori_y_r), size)
        ori_y = np.repeat(ori_y[..., np.newaxis], 3, axis=-1)  # (512, 512, 3)

        # Binarise the predicted probabilities at 0.5 and scale to 0/255 for display.
        # verbose=0 keeps Keras' per-call progress bar from interleaving with tqdm.
        y_pred = model.predict(batch, verbose=0)[0] > 0.5
        y_pred = np.repeat(y_pred.astype(np.uint8) * 255, 3, axis=-1)

        # Build the strip explicitly as uint8 rather than handing cv2.imwrite a
        # float64 array and relying on its implicit conversion.
        cat_image = np.concatenate([ori_x, sep_line, ori_y, sep_line, y_pred], axis=1)
        cv2.imwrite(f"/content/{ct}.png", cat_image)

100%|██████████| 13/13 [00:28<00:00,  2.17s/it]
