In [1]:
import os
import sys
import matplotlib.pyplot as plt

import cv2
# from skimage import morphology
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.applications.vgg19 import VGG19

# Paths are derived from the kernel's working directory, which is expected to be
# the model directory (the one containing `block.py`); data lives in ../data.
MODEL_DIR = os.getcwd()
ROOT_DIR = os.path.split(MODEL_DIR)[0]
DATA_DIR = os.path.join(ROOT_DIR, "data")
# Make the local `block` module importable.
sys.path.append(MODEL_DIR)
from block import res_path, res_block, decoder_block

In [2]:
def TransResUNet(input_size=(512, 512, 3), n_classes=2):
    """
    TransResUNet -- U-Net-style segmentation model with a pretrained VGG19 encoder.

    The first three convolutional stages of an ImageNet-pretrained VGG19 form the
    encoder. Their outputs feed `res_path` skip connections that are concatenated
    with the decoding path built from `decoder_block`.

    Arguments:
    input_size {tuple} -- (height, width, channels) of the input image. Inputs
        whose channel count is not 3 are first projected to 3 channels with a
        learnable 1x1 convolution, because the ImageNet VGG19 weights require
        3-channel input.
    n_classes {int} -- number of per-pixel output channels (default 2).

    Returns:
    model {tf.keras.Model} -- final model producing per-pixel raw logits with
        `n_classes` channels (the last Conv2D has no activation).
    """

    # Input
    inputs = Input(input_size)

    # Handling input channels: the ImageNet VGG19 weights only accept 3 channels,
    # so any other channel count (e.g. 1-channel grayscale) is mapped to 3.
    if input_size[-1] != 3:
        inp = Conv2D(3, 1)(inputs)
        # (height, width, 3) -- width is input_size[1], so non-square inputs are described correctly.
        input_shape = (input_size[0], input_size[1], 3)
    else:
        inp = inputs
        input_shape = input_size

    # VGG19 with ImageNet weights. Only its layer objects are reused below; the
    # encoder model itself is never called. Reused layers keep their pretrained
    # weights and (unless frozen elsewhere) are trained with the rest of the model.
    encoder = VGG19(include_top=False, weights='imagenet', input_shape=input_shape)

    # First encoder block (input resolution)
    enc1 = encoder.get_layer(name='block1_conv1')(inp)
    enc1 = encoder.get_layer(name='block1_conv2')(enc1)
    # Second encoder block (after one 2x2 max-pool)
    enc2 = MaxPooling2D(pool_size=(2, 2))(enc1)
    enc2 = encoder.get_layer(name='block2_conv1')(enc2)
    enc2 = encoder.get_layer(name='block2_conv2')(enc2)
    # Third encoder block (after two 2x2 max-pools)
    enc3 = MaxPooling2D(pool_size=(2, 2))(enc2)
    enc3 = encoder.get_layer(name='block3_conv1')(enc3)
    enc3 = encoder.get_layer(name='block3_conv2')(enc3)
    enc3 = encoder.get_layer(name='block3_conv3')(enc3)

    # Center block: pooled once more; decoder_block presumably upsamples x2 so that
    # its output lines up spatially with the enc3 skip connection below.
    center = MaxPooling2D(pool_size=(2, 2))(enc3)
    center = decoder_block(center, 512, 256)

    # Decoder block corresponding to third encoder
    res_path3 = res_path(enc3, 128, 3)
    dec3 = concatenate([res_path3, center], axis=3)
    dec3 = decoder_block(dec3, 256, 64)
    # Decoder block corresponding to second encoder
    res_path2 = res_path(enc2, 64, 2)
    dec2 = concatenate([res_path2, dec3], axis=3)
    dec2 = decoder_block(dec2, 128, 64)
    # Final block: concatenation with the first encoded feature map
    res_path1 = res_path(enc1, 32, 1)
    dec1 = concatenate([res_path1, dec2], axis=3)
    dec1 = Conv2D(32, 3, padding='same', kernel_initializer='he_normal')(dec1)
    dec1 = ReLU()(dec1)
    # Per-pixel class logits (no activation; pair with a from_logits=True loss).
    out = Conv2D(n_classes, (1, 1), padding='same')(dec1)
    # Final model
    model = Model(inputs=[inputs], outputs=[out])

    return model

In [3]:
model = TransResUNet()
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    # Labels are per-pixel class indices; the network's last Conv2D has no
    # activation, hence from_logits=True.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
    # NOTE(review): eager execution is much slower than graph mode and looks like a
    # debugging setting -- consider dropping it once the pipeline is verified.
    run_eagerly=True,
)

In [4]:
# Image dataset layout: DATA_DIR/{train,test}/{img,mask}
TRAIN_DIR = os.path.join(DATA_DIR, "train")
TEST_DIR = os.path.join(DATA_DIR, "test")
for _split_root in (TRAIN_DIR, TEST_DIR):
    for _subdir in ("img", "mask"):
        # makedirs also creates the split directory itself when it is missing.
        os.makedirs(os.path.join(_split_root, _subdir), exist_ok=True)

In [5]:
import cv2
import matplotlib.pyplot as plt

MASK_DIR = os.path.join(DATA_DIR, "JSRT", "scr", "masks")
MASK_HEART_DIR = os.path.join(MASK_DIR, "heart")
# Resized masks go to a *sibling* of MASK_HEART_DIR: nesting the output folder inside
# MASK_HEART_DIR would make the listing below treat it as an input mask.
MASK_HEART_RESIZE_DIR = os.path.join(MASK_DIR, "heart_resized")
MASK_SIZE = (512, 512)  # (width, height) passed to cv2.resize

os.makedirs(MASK_HEART_DIR, exist_ok=True)
# The destination must exist before plt.imsave writes into it.
os.makedirs(MASK_HEART_RESIZE_DIR, exist_ok=True)

# Sorted so the order (and anything derived from it) is reproducible;
# only regular files are treated as masks.
heart_mask = sorted(
    os.path.join(MASK_HEART_DIR, name)
    for name in os.listdir(MASK_HEART_DIR)
    if os.path.isfile(os.path.join(MASK_HEART_DIR, name))
)
# Same file name, resized-mask directory (built with os.path.join rather than a
# string replace, which would also rewrite any "heart" elsewhere in the path).
heart_mask_resize = [
    os.path.join(MASK_HEART_RESIZE_DIR, os.path.basename(path)) for path in heart_mask
]

for src_img_path, dst_img_path in zip(heart_mask, heart_mask_resize):
    # Nearest-neighbour resampling never invents new pixel values, so a binary mask
    # stays binary; INTER_AREA averages neighbours and creates intermediate grey
    # levels along borders, which are not valid class labels.
    img = cv2.resize(plt.imread(src_img_path), dsize=MASK_SIZE, interpolation=cv2.INTER_NEAREST)
    plt.imsave(dst_img_path, img, cmap="gray")

In [6]:
import shutil
from sklearn.model_selection import train_test_split

NODULE_PNG_DIR = os.path.join(DATA_DIR, "JSRT", "nodules", "png")
NON_NODULE_PNG_DIR = os.path.join(DATA_DIR, "JSRT", "non_nodules", "png")
nodules_pngs = [os.path.join(NODULE_PNG_DIR, png) for png in os.listdir(NODULE_PNG_DIR)]
non_nodules_pngs = [os.path.join(NON_NODULE_PNG_DIR, png) for png in os.listdir(NON_NODULE_PNG_DIR)]
entire_img = nodules_pngs + non_nodules_pngs

# Sample ids are the mask file names without their ".gif" extension. Sorting plus a
# fixed seed make the train/test split reproducible across kernel restarts.
# NOTE(review): files copied by earlier (unseeded) runs are not removed from
# TRAIN_DIR/TEST_DIR; clear those folders once, or test samples may leak into training.
SPLIT_SEED = 42
samples = sorted(os.path.basename(path).replace(".gif", "") for path in heart_mask_resize)
train_samples, test_samples = train_test_split(samples, random_state=SPLIT_SEED)


def _copy_split_files(sample_names, image_paths, mask_paths, split_dir):
    """Copy each sample's image and mask to split_dir/img/<name>.png and split_dir/mask/<name>.png.

    A file belongs to a sample when the sample name occurs in the file's base name
    (not its full path, so directory names cannot cause false matches). If several
    files match, later ones overwrite earlier copies.
    NOTE(review): masks keep their source encoding, so a copied "<name>.png" mask may
    actually contain GIF data -- confirm the decoder used downstream accepts it.
    """
    for name in sample_names:
        for img_path in image_paths:
            if name in os.path.basename(img_path):
                shutil.copy(img_path, os.path.join(split_dir, "img", f"{name}.png"))
        for mask_path in mask_paths:
            if name in os.path.basename(mask_path):
                shutil.copy(mask_path, os.path.join(split_dir, "mask", f"{name}.png"))


_copy_split_files(train_samples, entire_img, heart_mask_resize, TRAIN_DIR)

In [7]:
# Copy each held-out sample's image and resized heart mask into
# TEST_DIR/{img,mask}/<sample>.png. A file belongs to a sample when the sample name
# occurs anywhere in its path; with several matches, later copies overwrite earlier ones.
for test_sample in test_samples:
    img_dst = os.path.join(TEST_DIR, "img", f"{test_sample}.png")
    mask_dst = os.path.join(TEST_DIR, "mask", f"{test_sample}.png")
    for img in entire_img:
        if test_sample in img:
            shutil.copy(img, img_dst)
    for mask in heart_mask_resize:
        if test_sample in mask:
            shutil.copy(mask, mask_dst)

In [8]:
def _paired_paths(img_dir, mask_dir):
    """Return (image_paths, mask_paths) where the i-th image and i-th mask share a file name.

    Pairing two independent directory listings by position is unsafe because
    os.listdir returns entries in arbitrary order. Instead, take the file names present
    in both directories and sort them. Files present in only one directory are skipped.
    """
    common_names = sorted(set(os.listdir(img_dir)) & set(os.listdir(mask_dir)))
    image_paths = [os.path.join(img_dir, name) for name in common_names]
    mask_paths = [os.path.join(mask_dir, name) for name in common_names]
    return image_paths, mask_paths


TRAIN_IMG_DIR = os.path.join(TRAIN_DIR, "img")
TRAIN_MASK_DIR = os.path.join(TRAIN_DIR, "mask")
train_imgs, train_masks = _paired_paths(TRAIN_IMG_DIR, TRAIN_MASK_DIR)

TEST_IMG_DIR = os.path.join(TEST_DIR, "img")
TEST_MASK_DIR = os.path.join(TEST_DIR, "mask")
test_imgs, test_masks = _paired_paths(TEST_IMG_DIR, TEST_MASK_DIR)

In [9]:
# One dataset element per (image path, mask path) pair, zipped element-wise.
train_img_files, train_mask_files = tf.constant(train_imgs), tf.constant(train_masks)
dataset = tf.data.Dataset.from_tensor_slices(
    (train_img_files, train_mask_files)
)

In [10]:
def process_path(image_path, mask_path):
    """
    Read and decode one (image, mask) pair from disk.

    Returns the image as a 3-channel float32 tensor scaled to [0, 1], and the mask
    as a single-channel tensor holding the per-pixel maximum over its 3 decoded
    channels (dtype as produced by decode_png).
    NOTE(review): the pretrained VGG19 encoder was trained on inputs prepared with
    `tf.keras.applications.vgg19.preprocess_input`; [0, 1] RGB differs from that.
    """
    raw_image = tf.io.read_file(image_path)
    image = tf.image.convert_image_dtype(
        tf.image.decode_png(raw_image, channels=3), tf.float32
    )

    raw_mask = tf.io.read_file(mask_path)
    mask_rgb = tf.image.decode_png(raw_mask, channels=3)
    # Collapse colour channels to one label channel.
    mask = tf.math.reduce_max(mask_rgb, axis=-1, keepdims=True)
    return image, mask

def preprocess(image, mask):
    """
    Bring an (image, mask) pair to the fixed 512x512 training resolution.

    Both tensors are resized with nearest-neighbour sampling (which keeps the input
    dtype). The mask is then divided by 255 to map 8-bit intensities into [0, 1].
    """
    target_hw = (512, 512)
    resized_image = tf.image.resize(image, target_hw, method='nearest')
    resized_mask = tf.image.resize(mask, target_hw, method='nearest')
    scaled_mask = resized_mask / 255
    return resized_image, scaled_mask

# Train dataset: file paths -> decoded tensors -> fixed-size tensors.
# Decoding and resizing run in parallel; Dataset.map keeps element order
# deterministic by default, so image/mask pairing is unaffected.
image_ds = dataset.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)
train_processed_image_ds = image_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)


In [11]:
EPOCHS = 50
VAL_SUBSPLITS = 3  # NOTE(review): not referenced in the visible cells
BUFFER_SIZE = 500
BATCH_SIZE = 2

# cache() before shuffle() keeps decoded, resized pairs in memory after the first epoch
# while the shuffle order still changes between epochs. prefetch() overlaps input
# preparation with training steps. (The former standalone `.batch(...)` call discarded
# its result and had no effect, so it was removed.)
train_dataset = (
    train_processed_image_ds
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
print(train_processed_image_ds.element_spec)

(TensorSpec(shape=(512, 512, 3), dtype=tf.float32, name=None), TensorSpec(shape=(512, 512, 1), dtype=tf.float32, name=None))


In [12]:
# Train on the cached, shuffled, batched pipeline; keep the returned History object.
model_history = model.fit(train_dataset, epochs=EPOCHS, verbose=True)

Epoch 1/50
