In [1]:
import os
import sys
import  time
import random
from itertools import chain
import numpy as np
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
plt.style.use("ggplot")

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.layers import BatchNormalization, Activation, Dropout, Dense
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import concatenate
from tensorflow.keras.metrics import Recall, Precision

from skimage.io import imread
from skimage.transform import resize
from sklearn.model_selection import train_test_split
from skimage.io import imread, imshow

import h5py

In [2]:
# Model / training configuration.
im_width = 128
im_height = 128
n_epochs = 10
batch_size = 32
num_class = 2  # channels of the final 1x1 sigmoid conv in get_unet

# Seed every RNG the notebook relies on (weight init, dropout, batch shuffling)
# so Restart & Run All is repeatable. NOTE: some GPU kernels may still be
# nondeterministic even with fixed seeds.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

def conv2d_block(input_tensor, n_filters, kernel_size = 3, batchnorm = True):
    """Apply two stacked Conv2D -> [BatchNorm] -> ReLU units.

    Args:
        input_tensor: Keras tensor to convolve.
        n_filters: number of filters in each of the two Conv2D layers.
        kernel_size: side length of the square convolution kernel.
        batchnorm: if True, insert BatchNormalization after each Conv2D.

    Returns:
        The output tensor of the second ReLU. 'same' padding keeps the
        spatial size of ``input_tensor``.
    """
    features = input_tensor
    # Two identical units, built in the same order as writing them out by hand.
    for _unit in range(2):
        features = Conv2D(
            filters=n_filters,
            kernel_size=(kernel_size, kernel_size),
            kernel_initializer='he_normal',
            padding='same',
        )(features)
        if batchnorm:
            features = BatchNormalization()(features)
        features = Activation('relu')(features)
    return features


def get_unet(input_img, n_filters = 64, dropout = 0.2, batchnorm = True, n_classes = None, depth = 4):
    """Build a U-Net segmentation model.

    Encoder: ``depth`` levels of conv2d_block -> 2x2 MaxPooling -> Dropout, with
    the filter count doubling at each level. Bottleneck: a conv2d_block with
    ``n_filters * 2**depth`` filters. Decoder: mirrors the encoder with a
    stride-2 Conv2DTranspose, concatenation with the matching encoder output
    (skip connection), Dropout and a conv2d_block. The head is a 1x1 Conv2D
    with sigmoid activation.

    Args:
        input_img: Keras ``Input`` tensor. Its height and width must be
            divisible by ``2**depth`` so the skip connections line up.
        n_filters: filters at the first (shallowest) level.
        dropout: dropout rate used after every pooling and every concatenation.
        batchnorm: forwarded to conv2d_block.
        n_classes: output channels. Defaults to the notebook-level ``num_class``.
        depth: number of pooling levels. The default of 4 reproduces the
            original architecture.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    if n_classes is None:
        n_classes = num_class

    # Contracting path. Keep each level's pre-pooling output for the skips.
    skips = []
    x = input_img
    for level in range(depth):
        c = conv2d_block(x, n_filters * 2 ** level, kernel_size = 3, batchnorm = batchnorm)
        skips.append(c)
        x = MaxPooling2D((2, 2))(c)
        x = Dropout(dropout)(x)

    # Bottleneck.
    x = conv2d_block(x, n_filters = n_filters * 2 ** depth, kernel_size = 3, batchnorm = batchnorm)

    # Expansive path: deepest level first. Layers are created in the same
    # order as the hand-unrolled version, so auto-generated layer names match.
    for level in reversed(range(depth)):
        x = Conv2DTranspose(n_filters * 2 ** level, (3, 3), strides = (2, 2), padding = 'same')(x)
        x = concatenate([x, skips[level]])
        x = Dropout(dropout)(x)
        x = conv2d_block(x, n_filters * 2 ** level, kernel_size = 3, batchnorm = batchnorm)

    outputs = Conv2D(n_classes, (1, 1), activation='sigmoid')(x)
    model = Model(inputs=[input_img], outputs=[outputs])
    return model


def TV_bin_loss(y_true, y_pred, tv_weight=2.4e-7, channel=1):
    """Binary cross-entropy plus a weighted total-variation (TV) penalty.

    Args:
        y_true: ground-truth masks.
        y_pred: model output. Must be a 4-D, channels-last
            ``(batch, H, W, C)`` tensor, as produced by get_unet.
        tv_weight: multiplier on the TV term. The default keeps the original
            hard-coded value.
        channel: index of the output channel whose spatial smoothness is
            penalized.

    Returns:
        Scalar loss: ``bin_loss + tv_weight * (mean |dH| + mean |dW|)``.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    # On 1-D inputs, binary_crossentropy reduces over the only axis -> scalar mean.
    bin_loss = binary_crossentropy(y_true_f, y_pred_f)

    # Pick one channel -> (batch, H, W). The previous slice `y_pred[::, :, 1]`
    # indexed the width axis (column 1), not the channel axis.
    images = y_pred[..., channel]
    tv_h = tf.reduce_mean(tf.abs(images[:, 1:, :] - images[:, :-1, :]))  # vertical neighbours
    tv_w = tf.reduce_mean(tf.abs(images[:, :, 1:] - images[:, :, :-1]))  # horizontal neighbours
    return tv_weight * (tv_h + tv_w) + bin_loss


def dice_coef(y_true, y_pred, smooth=1e-4):
    """Soft Dice coefficient between two (mask-like) tensors.

    Both inputs are flattened and cast to Keras' float type, so integer or
    boolean masks work. The formula is symmetric in its two arguments.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted probabilities, same number of elements as y_true.
        smooth: additive constant in numerator and denominator. It avoids 0/0
            for empty masks; the default keeps the original value.

    Returns:
        Scalar ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    # Cast both sides so a uint8/bool mask can be multiplied with float predictions.
    y_true_f = K.cast(K.flatten(y_true), K.floatx())
    y_pred_f = K.cast(K.flatten(y_pred), K.floatx())
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_loss(y_true, y_pred):
    """Return one minus the soft Dice coefficient of ``y_true`` and ``y_pred``."""
    coefficient = dice_coef(y_true, y_pred)
    return 1 - coefficient


def custom_loss(y_true, y_pred):
    """Training loss passed to ``model.compile``: BCE plus the TV penalty.

    Args:
        y_true: ground-truth masks.
        y_pred: model output tensor.

    Returns:
        ``TV_bin_loss(y_true, y_pred)``.
    """
    # A loop over the layers of the global `model` would always finish on the
    # last layer and so always return TV_bin_loss. Calling it directly gives the
    # same result without depending on the global `model`, and without failing
    # when `model` has no layers.
    return TV_bin_loss(y_true, y_pred)

In [3]:
# Read the pre-split train/validation arrays fully into memory. The `with`
# block closes the HDF5 file once the copies are made.
with h5py.File("h5_datasets/combined_CT_datasets.h5", "r") as combined_data:
    X_train = np.array(combined_data["X_train"])
    X_valid = np.array(combined_data["X_valid"])
    y_train = np.array(combined_data["y_train"])
    y_valid = np.array(combined_data["y_valid"])

# Sanity checks: every image must have a matching mask.
assert X_train.shape[0] == y_train.shape[0], "X_train / y_train length mismatch"
assert X_valid.shape[0] == y_valid.shape[0], "X_valid / y_valid length mismatch"

In [4]:
# Build and compile the U-Net; `custom_loss` is BCE plus a small TV penalty.
input_img = Input((im_height, im_width, 1), name='img')
model = get_unet(input_img, n_filters=64, dropout=0.2, batchnorm=True)
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss=[custom_loss],
    metrics=['accuracy', dice_loss, Recall(name='recall_1'), Precision(name='pre_1')],
)

checkpoint_path = 'model-TV-UNet1.h5'
callbacks = [
    # NOTE(review): patience=50 is larger than n_epochs (10), so early stopping
    # cannot trigger at the current setting.
    EarlyStopping(patience=50, verbose=1),
    ReduceLROnPlateau(factor=0.1, patience=5, min_lr=0.00001, verbose=1),
    ModelCheckpoint(checkpoint_path, verbose=1, save_best_only=True, save_weights_only=True),
]

results = model.fit(X_train, y_train, batch_size=batch_size, epochs=n_epochs,
                    callbacks=callbacks, validation_data=(X_valid, y_valid))

# After fit, the in-memory model holds the LAST epoch's weights, while the
# checkpoint holds the best val_loss weights. Reload the checkpoint so the
# exported model is the best one.
if os.path.exists(checkpoint_path):
    model.load_weights(checkpoint_path)

Epoch 1/10
Epoch 00001: val_loss improved from inf to 7.07172, saving model to model-TV-UNet1.h5
Epoch 2/10
Epoch 00002: val_loss improved from 7.07172 to 2.19302, saving model to model-TV-UNet1.h5
Epoch 3/10
Epoch 00003: val_loss improved from 2.19302 to 0.27677, saving model to model-TV-UNet1.h5
Epoch 4/10
Epoch 00004: val_loss improved from 0.27677 to 0.23784, saving model to model-TV-UNet1.h5
Epoch 5/10
Epoch 00005: val_loss did not improve from 0.23784
Epoch 6/10
Epoch 00006: val_loss improved from 0.23784 to 0.13462, saving model to model-TV-UNet1.h5
Epoch 7/10
Epoch 00007: val_loss improved from 0.13462 to 0.11066, saving model to model-TV-UNet1.h5
Epoch 8/10
Epoch 00008: val_loss did not improve from 0.11066
Epoch 9/10
Epoch 00009: val_loss did not improve from 0.11066
Epoch 10/10
Epoch 00010: val_loss improved from 0.11066 to 0.05789, saving model to model-TV-UNet1.h5


In [5]:
## save model as ONNX
# Step 1: export the Keras model as a TF SavedModel in tf_model/.
# The following %%bash cell converts it to onnx_trained_model/ct_seg_model.onnx.
for out_dir in ('tf_model', 'onnx_trained_model'):
    # exist_ok avoids the check-then-create race and makes the cell idempotent.
    os.makedirs(out_dir, exist_ok=True)

tf.saved_model.save(model, 'tf_model')

2022-04-07 10:21:21.662114: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


INFO:tensorflow:Assets written to: tf_model/assets


In [6]:
%%bash

# Convert the SavedModel in tf_model/ to ONNX (opset 14) with tf2onnx.
python -m tf2onnx.convert \
    --saved-model tf_model \
    --output onnx_trained_model/ct_seg_model.onnx \
    --opset 14

2022-04-07 10:21:45,306 - INFO - Signatures found in model: [serving_default].
2022-04-07 10:21:45,307 - INFO - Output names: ['conv2d_18']
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
2022-04-07 10:21:51,788 - INFO - Using tensorflow=2.7.0, onnx=1.11.0, tf2onnx=1.9.3/1190aa
2022-04-07 10:21:51,788 - INFO - Using opset <onnx, 14>
2022-04-07 10:21:59,408 - INFO - Computed 0 values for constant folding
2022-04-07 10:22:05,274 - INFO - Optimizing ONNX model
2022-04-07 10:22:11,834 - INFO - After optimization: BatchNormalization -18 (18->0), Cast -4 (4->0), Concat -4 (8->4), Const -131 (178->47), Identity -13 (13->0), Reshape +1 (0->1), Shape -4 (4->0), Slice -4 (4->0), Squeeze -4 (4->0), Transpose -89 (90->1), Unsqueeze -16 (16->0)
2022-04-07 10:22:12,456 - INFO - 
2022-04-07 10:22:12,456 - INFO - Successfully converted TensorFlow model tf_model to ONNX
2022-04-07 10:22:12,456 - INFO -

In [7]:
# Release GPU memory by closing the CUDA context on device 0 via numba.
# NOTE(review): after this, further TensorFlow GPU work in this kernel will
# likely fail; restart the kernel before re-running earlier cells.
gpu = tf.config.list_physical_devices('GPU')
if gpu:
    # numba is only needed (and only imported) when a GPU is present.
    from numba import cuda
    cuda.select_device(0)
    cuda.close()