In [None]:
import re
import cv2

import nrrd
import random
import os, glob
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
from patchify import patchify, unpatchify

import keras
import tensorflow as tf
from keras import backend as K
from keras.models import Model
from tensorflow.keras import layers
from keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, concatenate, Conv3DTranspose, BatchNormalization, Dropout, Lambda
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, CSVLogger
from keras.layers import Activation, MaxPool2D, Concatenate

from ipywidgets import interact, interactive, IntSlider, ToggleButtons

In [None]:
# Pick the GPU *before* TensorFlow enumerates devices. CUDA reads
# CUDA_VISIBLE_DEVICES when the driver is first initialised, so setting it
# after tf.config.list_physical_devices() has run does not restrict anything.
os.environ["CUDA_VISIBLE_DEVICES"]="3"

print(tf.keras.__version__)
print(tf.__version__)
physical_devices = tf.config.list_physical_devices('GPU')
# Print the whole list rather than indexing a fixed position: indexing [3]
# raises IndexError whenever fewer than four GPUs are visible, which is the
# expected situation once the mask above is in effect.
print(physical_devices if physical_devices else "No GPU visible to TensorFlow")
# Optional: let TensorFlow grow GPU memory on demand instead of reserving it all.
# for device in physical_devices:
#     tf.config.experimental.set_memory_growth(device, True)


In [None]:
# Root folder that holds the z-axis train / validation / test splits.
data_dir = '/home/tester/jianhoong/jh_fyp_work/ct_scans_data/raw_data/'

# Training split: image volumes and their segmentation masks.
z_train = os.path.join(data_dir, 'training_data_z')
z_train_image, z_train_mask = (
    os.path.join(z_train, sub_dir)
    for sub_dir in ('training_images/training_images', 'training_masks/training_masks')
)

# Validation split.
z_valid = os.path.join(data_dir, 'valid_data_z')
z_valid_image, z_valid_mask = (
    os.path.join(z_valid, sub_dir)
    for sub_dir in ('valid_images/valid_images', 'valid_masks/valid_masks')
)

# Test split.
z_test = os.path.join(data_dir, 'testing_data_z')
z_test_image, z_test_mask = (
    os.path.join(z_test, sub_dir)
    for sub_dir in ('testing_images/testing_images', 'testing_masks/testing_masks')
)

In [None]:
# Defining Helper Functions
def read_nrrd_file(filepath):
    '''Load an NRRD file and return its data cropped along the third axis.

    The header returned by ``nrrd.read`` is discarded. Only the first 96
    entries of the third axis are kept; the first two axes are returned whole.
    '''
    slice_limit = 96
    volume, _header = nrrd.read(filepath)
    return volume[:, :, :slice_limit]

def normalize(volume):
    '''Clip intensities to [-1000, 5000] and rescale them to [0, 1] as float32.

    Values at or below -1000 map to 0.0 and values at or above 5000 map to
    1.0. The input array is left untouched: clipping is done on a copy so the
    caller's data is not silently modified.

    Parameters
    ----------
    volume : numpy.ndarray
        Raw intensity values of any shape.

    Returns
    -------
    numpy.ndarray
        New ``float32`` array with the same shape as ``volume``.
    '''
    # Local names deliberately avoid shadowing the builtins min / max / range.
    lower = -1000  # min value of our data : -1000
    upper = 5000   # max value of our data : 5013
    span = upper - lower
    clipped = np.clip(volume, lower, upper)
    scaled = (clipped - lower) / span
    return scaled.astype("float32")

def process_scan(path):
    '''Read the volume stored at ``path`` and return it after ``normalize``.'''
    return normalize(read_nrrd_file(path))

def sorted_alnum(l):
    '''Return the strings of ``l`` as a new list in natural (alphanumeric) order.

    Each string is split into digit and non-digit runs; digit runs are
    compared as integers, so ``"scan2"`` sorts before ``"scan10"``.
    '''
    def natural_key(text):
        pieces = re.split('([0-9]+)', text)
        return [int(piece) if piece.isdigit() else piece for piece in pieces]

    return sorted(l, key=natural_key)

def training_generator():
    '''Yield ``(image, mask)`` pairs for the selected training scans.

    Files in ``z_train_image`` and ``z_train_mask`` are kept when the first
    integer in their name is one of the scan ids listed below. Both lists are
    ordered with ``sorted_alnum`` so images and masks line up by position.

    Yields
    ------
    tuple
        ``(x, y)`` where ``x`` is ``process_scan(image_path)`` and ``y`` is
        ``read_nrrd_file(mask_path)``.

    Raises
    ------
    ValueError
        If the number of selected image files and mask files differ, since
        positional pairing would then match scans with the wrong masks.
    '''
    # Ids of the scans to use (presumably those with at least 96 slices,
    # judging by the variable name - TODO confirm). A set gives O(1) lookups.
    train_data_with_96 = {15, 16, 20, 21, 23, 33, 37, 52, 64, 69, 72, 76, 77, 83, 92, 94, 95, 97, 103, 116, 123, 129, 157, 159, 162, 163, 170, 171, 180, 191, 192, 198, 200, 207, 209}

    def _selected_paths(directory):
        # Entries without any digit in their name (e.g. `.ipynb_checkpoints`)
        # are skipped instead of crashing with IndexError.
        selected = []
        for name in os.listdir(directory):
            first_number = re.search(r'\d+', name)
            if first_number is not None and int(first_number.group()) in train_data_with_96:
                selected.append(os.path.join(directory, name))
        return sorted_alnum(selected)

    # Reuse the notebook-level path constants rather than repeating the literals.
    image_paths = _selected_paths(z_train_image)
    mask_paths = _selected_paths(z_train_mask)

    if len(image_paths) != len(mask_paths):
        raise ValueError(
            "training images and masks do not match: "
            f"{len(image_paths)} images vs {len(mask_paths)} masks"
        )

    for image_path, mask_path in zip(image_paths, mask_paths):
        yield process_scan(image_path), read_nrrd_file(mask_path)
        
def validation_generator(valid_path=None, valid_mask=None):
    '''Yield ``(image, mask)`` pairs for the selected validation scans.

    Files in ``z_valid_image`` and ``z_valid_mask`` are kept when the first
    integer in their name is one of the scan ids listed below. Both lists are
    ordered with ``sorted_alnum`` so images and masks line up by position.

    Parameters
    ----------
    valid_path, valid_mask : ignored
        Kept only for backward compatibility; the function never used the
        values passed in. They now default to ``None`` because
        ``tf.data.Dataset.from_generator`` calls the generator with no
        arguments, which previously failed with ``TypeError``.

    Yields
    ------
    tuple
        ``(x, y)`` where ``x`` is ``process_scan(image_path)`` and ``y`` is
        ``read_nrrd_file(mask_path)``.

    Raises
    ------
    ValueError
        If the number of selected image files and mask files differ, since
        positional pairing would then match scans with the wrong masks.
    '''
    # Ids of the scans to use (presumably those with at least 96 slices,
    # judging by the variable name - TODO confirm). A set gives O(1) lookups.
    valid_data_with_96 = {213, 219, 221, 227, 231, 238, 240, 243, 249, 250, 255, 256, 257, 270}

    def _selected_paths(directory):
        # Entries without any digit in their name (e.g. `.ipynb_checkpoints`)
        # are skipped instead of crashing with IndexError.
        selected = []
        for name in os.listdir(directory):
            first_number = re.search(r'\d+', name)
            if first_number is not None and int(first_number.group()) in valid_data_with_96:
                selected.append(os.path.join(directory, name))
        return sorted_alnum(selected)

    # Reuse the notebook-level path constants rather than repeating the literals.
    image_paths = _selected_paths(z_valid_image)
    mask_paths = _selected_paths(z_valid_mask)

    if len(image_paths) != len(mask_paths):
        raise ValueError(
            "validation images and masks do not match: "
            f"{len(image_paths)} images vs {len(mask_paths)} masks"
        )

    for image_path, mask_path in zip(image_paths, mask_paths):
        yield process_scan(image_path), read_nrrd_file(mask_path)
        

In [None]:
# Wrap the Python generators as tf.data sources; both tuple elements are
# declared as float32 tensors.
train_loader = tf.data.Dataset.from_generator(
    training_generator, output_types=(tf.float32, tf.float32)
)
validation_loader = tf.data.Dataset.from_generator(
    validation_generator, output_types=(tf.float32, tf.float32)
)

ds = train_loader.batch(10)


def _shuffled_batches(loader):
    '''Shuffle (buffer of 4), drop elements that raise, batch by 2, prefetch 8.

    ``ignore_errors`` drops failing elements without reporting them, so a
    broken generator shows up as missing data rather than as an exception.
    '''
    return (
        loader.shuffle(4)
        .apply(tf.data.experimental.ignore_errors())
        .batch(2, drop_remainder=True)
        .prefetch(8)
    )


train_dataset = _shuffled_batches(train_loader)
valid_dataset = _shuffled_batches(validation_loader)

In [None]:
def conv_block(input, num_filters):
    '''Apply two Conv3D -> BatchNormalization -> ReLU stages to ``input``.

    Each convolution uses ``num_filters`` filters, kernel size 3 and "same"
    padding. Returns the output tensor of the second stage.
    '''
    features = input
    for _ in range(2):
        features = Conv3D(num_filters, 3, padding="same")(features)
        features = BatchNormalization()(features)  # Not in the original network
        features = Activation("relu")(features)
    return features

# Encoder block: a conv block followed by 2x2x2 max pooling; returns both the pre-pooling features and the pooled output.

def encoder_block(input, num_filters):
    '''Run ``conv_block`` on ``input`` and max-pool the result by (2, 2, 2).

    Returns ``(skip, pooled)``: the conv-block output before pooling and the
    pooled tensor.
    '''
    skip = conv_block(input, num_filters)
    pooled = MaxPooling3D(pool_size=(2, 2, 2))(skip)
    return skip, pooled

# Decoder block
# skip_features is the encoder output that gets concatenated with the upsampled input

def decoder_block(input, skip_features, num_filters):
    '''Upsample ``input``, concatenate it with ``skip_features``, then apply ``conv_block``.

    Upsampling is a Conv3DTranspose with a (2, 2, 2) kernel, stride 2 and
    "same" padding, producing ``num_filters`` channels.
    '''
    upsampled = Conv3DTranspose(num_filters, (2, 2, 2), strides=2, padding="same")(input)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

#Build Unet using the blocks
def build_unet(input_shape, n_classes):
    '''Assemble a 3D U-Net and return it as a Keras ``Model`` named "U-Net".

    Four encoder stages (64, 128, 256, 512 filters) feed a 1024-filter bridge,
    followed by four decoder stages (512, 256, 128, 64 filters) that consume
    the encoder skip tensors in reverse order. The final 1x1x1 convolution has
    ``n_classes`` filters and uses a sigmoid activation when ``n_classes`` is
    1, softmax otherwise. The chosen activation name is printed.
    '''
    inputs = Input(input_shape)

    # Contracting path: remember each stage's pre-pooling output for the skips.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    x = conv_block(x, 1024)  # Bridge

    # Expanding path: the deepest skip is consumed first.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = decoder_block(x, skip, filters)

    # Binary segmentation -> sigmoid, multi-class -> softmax.
    activation = 'sigmoid' if n_classes == 1 else 'softmax'

    outputs = Conv3D(n_classes, 1, padding="same", activation=activation)(x)
    print(activation)

    return Model(inputs, outputs, name="U-Net")

In [None]:
# Loss Function and coefficients to be used during training:
def dice_coefficient(y_true, y_pred):
    '''Smoothed Dice score computed over every element of both tensors.

    Both tensors are flattened, so the score is
    ``(2 * sum(t * p) + 1) / (sum(t) + sum(p) + 1)`` with ``y_true`` cast to
    float32 first.
    '''
    smooth = 1.0
    truth = K.flatten(tf.cast(y_true, tf.float32))
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    total = K.sum(truth) + K.sum(prediction)
    return (2. * overlap + smooth) / (total + smooth)

def dice_coefficient_loss(y_true, y_pred):
    '''Dice loss: one minus ``dice_coefficient`` of the float32-cast labels and ``y_pred``.'''
    truth = tf.cast(y_true, tf.float32)
    score = dice_coefficient(truth, y_pred)
    return 1.0 - score

#Define parameters for our model.
n_classes = 1    # a single output channel -> build_unet picks a sigmoid head
patch_size = 32  # not referenced by the visible code - TODO confirm it is still needed
channels=3       # not referenced by the visible code; matches the last input dimension below

LR = 0.001
opt = tf.keras.optimizers.Nadam(LR)

# NOTE(review): the data generators yield volumes cropped to 96 along the
# third axis with no explicit channel axis, whereas the network input here is
# (64, 64, 32, 3) - confirm that a patching / channel step exists elsewhere.
# Pass the n_classes variable rather than repeating the literal so the
# parameter declared above actually drives the model.
model = build_unet((64,64,32,3), n_classes = n_classes)
# `metrics` is given as a list, the form the Keras compile API documents.
model.compile(optimizer = opt, loss=dice_coefficient_loss, metrics=[dice_coefficient])
# summary() prints the table itself and returns None, so wrapping it in
# print() only appended a stray "None" line.
model.summary()

In [None]:
# Where the per-epoch metrics log and the model checkpoint are written.
csv_path = '/home/tester/jianhoong/jh_fyp_work/3D_UNet/trials/3DUNet_ModelCSVLogs/UNet_Approach3_v5.csv'
model_checkpoint_path = '/home/tester/jianhoong/jh_fyp_work/3D_UNet/ModelCheckpoints/Approach3_v5.hdf5'

my_callbacks = [
    # Multiply the learning rate by 0.1 after 4 epochs without val_loss improvement.
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=4),
    # NOTE(review): patience=50 equals the 50 epochs passed to model.fit, so
    # this callback cannot stop the run early - confirm the intended patience.
    EarlyStopping(monitor='val_loss', patience=50, restore_best_weights=True),
    # append=True keeps adding rows to the same CSV across runs.
    CSVLogger(csv_path, separator = ',', append = True),
    # save_best_only=True is what makes `monitor` / `mode` take effect: without
    # it the file is overwritten after every epoch and ends up holding the
    # last epoch rather than the one with the lowest val_loss.
    ModelCheckpoint(filepath = model_checkpoint_path,
    monitor = 'val_loss',
    mode = 'min',
    save_best_only = True,
    verbose = 1)
]

In [None]:
# data = train_dataset.take(1)
# images, labels = list(data)[0]
# images = images.numpy()
# image = images[0]
# print("Dimension of the CT scan is:", image.shape)
# plt.imshow(np.squeeze(image[:, :, 30]), cmap="gray")

In [None]:
# Show the input and output shapes of the network, followed by a separator line.
print(model.input_shape, model.output_shape, "-------------------", sep="\n")
# print(train_data.shape)
# print(train_mask.shape)
# print(train_data.max())

In [None]:
#Fit the model
# Training options collected in one place; the call below is otherwise unchanged.
# NOTE(review): Keras documents `shuffle` as ignored for tf.data.Dataset
# inputs - confirm; ordering would then come from the dataset's own shuffle().
_fit_options = {
    "validation_data": valid_dataset,
    "epochs": 50,
    "shuffle": True,
    "verbose": 2,  # one log line per epoch
    "callbacks": my_callbacks,
}
history = model.fit(train_dataset, **_fit_options)