In [None]:
import re
import cv2

import nrrd
import random
import os, glob
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
from patchify import patchify, unpatchify

import keras
import tensorflow as tf
from keras import backend as K
from keras.models import Model
from keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, Conv3DTranspose, BatchNormalization, Dropout, Lambda
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, CSVLogger
from keras.layers import Activation, Concatenate

In [None]:
# Expose only the GPU with index 2 to this process (only effective if set before
# TensorFlow first initialises its GPU devices).
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [None]:
def conv_block(input, num_filters):
    """Apply two consecutive Conv3D -> BatchNormalization -> ReLU stages.

    Args:
        input: Tensor fed into the first convolution.
        num_filters: Number of filters used by both 3x3x3 convolutions.

    Returns:
        The tensor produced by the second ReLU activation.
    """
    x = input
    for _ in range(2):
        x = Conv3D(num_filters, 3, padding="same")(x)
        # Batch normalisation is an addition; it is not in the original network.
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x

#Encoder block: Conv block followed by maxpooling

def encoder_block(input, num_filters):
    """Run a conv block and halve every spatial dimension with 2x2x2 max pooling.

    Args:
        input: Tensor entering this encoder stage.
        num_filters: Number of filters for the conv block.

    Returns:
        A ``(features, pooled)`` tuple: the conv-block output (used later as the
        skip connection) and its max-pooled version (fed to the next stage).
    """
    features = conv_block(input, num_filters)
    pooled = MaxPooling3D(pool_size=(2, 2, 2))(features)
    return features, pooled

#Decoder block
#skip features gets input from encoder for concatenation

def decoder_block(input, skip_features, num_filters):
    """Upsample, merge with the encoder skip connection, then run a conv block.

    Args:
        input: Tensor coming from the previous (deeper) stage.
        skip_features: Encoder tensor concatenated with the upsampled input.
        num_filters: Filters for both the transposed convolution and the conv block.

    Returns:
        Output tensor of the conv block applied to the concatenated features.
    """
    # Stride-2 transposed convolution doubles each spatial dimension.
    upsampled = Conv3DTranspose(num_filters, (2, 2, 2), strides=2, padding="same")(input)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

#Build Unet using the blocks
def build_unet(input_shape, n_classes):
    """Assemble a 4-level 3D U-Net out of the encoder/decoder blocks.

    Args:
        input_shape: Shape passed to ``Input`` (without the batch dimension).
        n_classes: Number of output channels. ``1`` selects a sigmoid output
            (binary segmentation); any other value selects softmax.

    Returns:
        A Keras ``Model`` named ``"U-Net"``.
    """
    inputs = Input(input_shape)

    # Contracting path: remember each stage's pre-pooling output for the skips.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    x = conv_block(x, 1024)  # Bridge

    # Expanding path: consume the skips from deepest to shallowest.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = decoder_block(x, skip, filters)

    # Binary output uses sigmoid; otherwise softmax.
    activation = 'sigmoid' if n_classes == 1 else 'softmax'

    outputs = Conv3D(n_classes, 1, padding="same", activation=activation)(x)
    print(activation)

    return Model(inputs, outputs, name="U-Net")

In [None]:
# Root folder of the raw CT data; every split lives in its own sub-folder with
# doubly nested image and mask directories.
data_dir = '/home/tester/jianhoong/jh_fyp_work/ct_scans_data/raw_data/'

# Training split
z_train = os.path.join(data_dir, 'training_data_z')
z_train_image, z_train_mask = (
    os.path.join(z_train, leaf)
    for leaf in ('training_images/training_images', 'training_masks/training_masks')
)

# Validation split
z_valid = os.path.join(data_dir, 'valid_data_z')
z_valid_image, z_valid_mask = (
    os.path.join(z_valid, leaf)
    for leaf in ('valid_images/valid_images', 'valid_masks/valid_masks')
)

# Test split
z_test = os.path.join(data_dir, 'testing_data_z')
z_test_image, z_test_mask = (
    os.path.join(z_test, leaf)
    for leaf in ('testing_images/testing_images', 'testing_masks/testing_masks')
)

In [None]:
def read_nrrd_file(filepath):
    """Load an NRRD file and return at most the first 96 entries along its third axis.

    Args:
        filepath: Path handed to ``nrrd.read``.

    Returns:
        The pixel array sliced as ``[:, :, :96]``; the NRRD header is discarded.
    """
    volume, _header = nrrd.read(filepath)
    depth_limit = slice(None, 96)
    return volume[:, :, depth_limit]

def normalize(volume):
    """Clip intensities to [-1000, 5000] and rescale them to [0, 1] as float32.

    The input is left untouched: clipping is done on a new array instead of
    overwriting the caller's data in place.

    Args:
        volume: Array-like of raw intensity values.

    Returns:
        A new ``float32`` array of the same shape with values in [0, 1].
    """
    lower = -1000  # min value of our data : -1000
    upper = 5000   # max value of our data : 5013 (anything above 5000 is clipped)
    span = upper - lower

    # np.clip returns a fresh array, so the caller's volume is not modified.
    clipped = np.clip(volume, lower, upper)
    return ((clipped - lower) / span).astype("float32")

def process_scan(path):
    """Read the scan at ``path`` and return it intensity-normalised.

    Args:
        path: Path of the NRRD scan, forwarded to ``read_nrrd_file``.

    Returns:
        The result of ``normalize`` applied to the loaded volume.
    """
    return normalize(read_nrrd_file(path))

def sorted_alnum(l):
    """Return the strings in ``l`` in natural order (``"a2"`` before ``"a10"``).

    Each string is split into digit and non-digit runs; digit runs are compared
    as integers and the remaining runs as text.
    """
    def natural_key(item):
        return [int(token) if token.isdigit() else token
                for token in re.split('([0-9]+)', item)]

    return sorted(l, key=natural_key)

In [None]:
# Patient IDs assigned to the training, validation and test splits. The path-building
# cell keeps only the files whose first number appears in the matching list.
# NOTE(review): these three names are rebound to patch arrays by the later process_one(...) calls.
train_data = [4,5,9,10,15,16,18,19,20,21,22,23,28,29,33,34,35,37,44,51,52,58,64,68,69,72,73,76,77,79,82,83,90,91,92,93,94,95,97,103,104,116,123,126,129,131,134,136,139,142,149,151,155,157,159,162,163,164,169,170,171,179,180,186,188,191,192,198,200,201,206,207,208,209]
valid_data = [213,214,215,217,219,220,221,224,226,227,228,231,237,238,240,243,248,249,250,254,255,256,257,264,266,270]
test_data = [271,272,274,277,280,283,286,287,288,293,294,296,297,298,300,303,305,306,314,315,317,320,322,340,342,345,347]

In [None]:
def _collect_paths(directory, patient_ids):
    """List the files in ``directory`` that belong to the given patients, in natural order.

    A file belongs to a patient when the first number in its name is one of
    ``patient_ids``. Entries whose name contains no number at all (for example
    hidden or checkpoint folders) are skipped instead of raising IndexError.

    Args:
        directory: Folder to scan with ``os.listdir``.
        patient_ids: Iterable of integer patient IDs to keep.

    Returns:
        Full paths of the matching files, sorted with ``sorted_alnum``.
    """
    wanted = set(patient_ids)
    selected = []
    for file in os.listdir(directory):
        numbers = re.findall(r'\d+', file)
        if numbers and int(numbers[0]) in wanted:
            selected.append(os.path.join(directory, file))
    return sorted_alnum(selected)


# Scan and mask lists are built the same way so that equal indices refer to the same patient
# (assumes both folders hold one equally named/numbered file per patient -- TODO confirm).
train_path = _collect_paths(z_train_image, train_data)
train_mask_path = _collect_paths(z_train_mask, train_data)

valid_path = _collect_paths(z_valid_image, valid_data)
valid_mask_path = _collect_paths(z_valid_mask, valid_data)

test_path = _collect_paths(z_test_image, test_data)
test_mask_path = _collect_paths(z_test_mask, test_data)

In [None]:
# NOTE(review): the three functions below are written as methods (they take `self`) but are
# defined at module level in this cell; nothing else in the visible notebook uses them.

def __init__(self, list_IDs, labels, batch_size=32, dim=(64, 64, 32), n_channels=3, n_classes=1, shuffle=True):
    """Store the generator configuration on ``self`` and build the first index order.

    Args:
        list_IDs: Sample identifiers; only their count is used in the visible code.
        labels: Labels associated with the identifiers (stored, not used here).
        batch_size: Number of samples per batch.
        dim: Shape of one sample without the channel axis.
        n_channels: Number of channels per sample.
        n_classes: Number of target classes (stored, not used here).
        shuffle: Whether ``on_epoch_end`` shuffles the index order.
    """
    self.dim = dim
    # `__data_generation` reads `self.batch_size`; `self.batch` is kept for backward compatibility.
    self.batch_size = batch_size
    self.batch = batch_size
    self.labels = labels
    self.list_IDs = list_IDs
    self.n_channels = n_channels
    self.n_classes = n_classes
    self.shuffle = shuffle
    self.on_epoch_end()  # Triggered at start & end of epoch


def on_epoch_end(self):
    """Rebuild ``self.indexes`` as 0..len(list_IDs)-1, shuffled when ``self.shuffle`` is set."""
    self.indexes = np.arange(len(self.list_IDs))
    if self.shuffle:  # Shuffling makes model more robust
        np.random.shuffle(self.indexes)


def __data_generation(self, list_IDs_temp):
    """Allocate the buffers for one batch and return them.

    NOTE(review): the buffers come from ``np.empty`` and are not filled from
    ``list_IDs_temp``; the loading logic is missing in this notebook.

    Returns:
        ``(X, y)`` with shapes ``(batch_size, *dim, n_channels)`` and ``(batch_size,)``.
    """
    X = np.empty((self.batch_size, *self.dim, self.n_channels))
    y = np.empty((self.batch_size), dtype=int)
    return X, y

In [None]:
def process_one(scan_paths, mask_paths, desired_size=3000, patch_shape=(64, 64, 32), step=32):
    """Randomly sample scans and collect patches whose mask contains foreground.

    Scans are drawn at random *with replacement* until at least ``desired_size``
    patches have been stored. Every qualifying patch of a drawn scan is kept, so
    the result may hold more than ``desired_size`` patches and may contain the
    same patient several times.

    NOTE(review): the loop only ends once enough patches are stored, so it does
    not terminate if no mask patch ever has a positive sum.

    Args:
        scan_paths: Paths of the scan volumes.
        mask_paths: Paths of the masks; index ``i`` must match ``scan_paths[i]``.
        desired_size: Minimum number of patches to collect.
        patch_shape: Shape of each extracted patch (default 64 x 64 x 32).
        step: Stride between neighbouring patches along every axis.

    Returns:
        ``(processed_scan, processed_mask, patients_processed)``: the scan patches
        replicated into 3 channels on a new last axis, the mask patches with a
        singleton axis inserted at position 4, and the patient index of every
        scan that was drawn (in draw order).

    Raises:
        ValueError: If ``scan_paths`` is empty or its length differs from ``mask_paths``.
    """
    if len(scan_paths) == 0:
        raise ValueError("scan_paths is empty; there is nothing to sample from")
    if len(scan_paths) != len(mask_paths):
        raise ValueError(
            f"scan_paths and mask_paths must pair up one-to-one "
            f"(got {len(scan_paths)} scans and {len(mask_paths)} masks)"
        )

    scan_storage = []
    mask_storage = []
    patients_processed = []

    while len(scan_storage) < desired_size:
        random_idx = random.randint(0, len(scan_paths) - 1)
        # Extract numerical patient index from path string.
        # NOTE(review): relies on a fixed-width slice of the path tail -- confirm against the file naming.
        patient_idx = int(re.findall(r'\d+', scan_paths[random_idx][-14:-11])[0])
        patients_processed.append(patient_idx)

        print(f'Processing Patient {patient_idx} data')
        scan_pixelData = process_scan(scan_paths[random_idx])
        mask_pixelData = read_nrrd_file(mask_paths[random_idx])

        # Cut both volumes into a grid of patch_shape blocks taken every `step` voxels.
        scan_patch = patchify(scan_pixelData, patch_shape, step=step)
        mask_patch = patchify(mask_pixelData, patch_shape, step=step)

        # Collapse the three grid axes into one: (n_patches, *patch_shape).
        input_img = np.reshape(scan_patch, (-1, *scan_patch.shape[3:]))
        input_mask = np.reshape(mask_patch, (-1, *mask_patch.shape[3:]))

        for i in range(input_mask.shape[0]):
            # Keep only patches whose mask contains at least some foreground.
            if np.sum(input_mask[i]) > 0:
                print(f'Storing Patient {patient_idx} data, Cube Num: {i}')
                scan_storage.append(input_img[i])
                mask_storage.append(input_mask[i])
                print(f'Current Training Data: {len(scan_storage)}')

    scan_storage = np.array(scan_storage)
    mask_storage = np.array(mask_storage)

    # Replicate the single-channel scan into 3 channels; give the mask a channel axis.
    processed_scan = np.stack((scan_storage,) * 3, axis=-1)
    processed_mask = np.expand_dims(mask_storage, axis=4)

    return processed_scan, processed_mask, patients_processed

def check_patients_processed(processed_patients_list):
    """Print how often each patient was drawn while sampling.

    When every entry is unique a single confirmation line is printed; otherwise
    the total, the number of unique patients and up to the 5 most frequent
    patients are listed. Fewer than 5 unique patients no longer raise IndexError.

    Args:
        processed_patients_list: Patient indices in the order they were processed.
    """
    top = 5
    total = len(processed_patients_list)
    counts = Counter(processed_patients_list)

    if total == len(counts):
        print(f"No duplicates of patients processed detected. {total} patients processed")
        return

    print(f'Total patients processed: {total}')
    print(f"Unique Patients processed: {len(counts)}")
    print(f"Top {top} occuring patients: ")
    # most_common(top) returns at most `top` entries, so short lists are safe.
    for patient, occurrences in counts.most_common(top):
        print(f"Patient {patient} : {occurrences} times ")

In [None]:
# Collect at least 1000 training patches whose mask contains foreground.
# NOTE(review): this rebinds `train_data` from the patient-ID list to the patch array, so the
# path-building cell only filters correctly if the ID-list cell is re-run before it.
train_data, train_mask, train_patients = process_one(train_path, train_mask_path, 1000)

In [None]:
# Collect at least 200 validation patches whose mask contains foreground.
# NOTE(review): this rebinds `valid_data` from the patient-ID list to the patch array.
valid_data, valid_mask, valid_patients = process_one(valid_path, valid_mask_path, 200)

In [None]:
# Loss Function and coefficients to be used during training:
def dice_coefficient(y_true, y_pred):
    """Smoothed Dice coefficient computed over all elements of both tensors.

    Returns ``(2 * sum(t * p) + 1) / (sum(t) + sum(p) + 1)`` where ``t`` and
    ``p`` are the flattened inputs; the ``+ 1`` smoothing term keeps the ratio
    defined when both tensors sum to zero.
    """
    smooth = 1
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    total = K.sum(truth) + K.sum(prediction)
    return (2. * overlap + smooth) / (total + smooth)

def dice_coefficient_loss(y_true, y_pred):
    """Dice loss: one minus the smoothed Dice coefficient of the two tensors."""
    dice = dice_coefficient(y_true, y_pred)
    return 1 - dice

# Define parameters for our model.
n_classes = 1    # one output channel -> build_unet selects a sigmoid output
patch_size = 32  # NOTE(review): not referenced anywhere in the visible notebook
channels = 3     # process_one replicates each scan patch into 3 channels

LR = 0.001
opt = tf.keras.optimizers.Nadam(LR)

# Input is one 64 x 64 x 32 patch with `channels` channels (matches process_one's default patches).
model = build_unet((64, 64, 32, channels), n_classes=n_classes)
# Keras documents `metrics` as a list; a bare callable is not accepted by every release.
model.compile(optimizer=opt, loss=dice_coefficient_loss, metrics=[dice_coefficient])
# summary() prints the table itself and returns None; wrapping it in print() adds a stray "None".
model.summary()

In [None]:
# Sanity check: model input/output shapes next to the data shapes, then the largest scan value.
print(
    model.input_shape,
    train_data.shape,
    model.output_shape,
    train_mask.shape,
    sep="\n",
)
print("-------------------")
print(train_data.max())

In [None]:
csv_path = '/home/tester/jianhoong/jh_fyp_work/3D_UNet/trials/3DUNet_ModelCSVLogs/UNet_Approach3_v3.csv'
model_checkpoint_path = '/home/tester/jianhoong/jh_fyp_work/3D_UNet/ModelCheckpoints/Approach3_v3.hdf5'

my_callbacks = [
    # Multiply the learning rate by 0.1 after 4 epochs without val_loss improvement.
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=4),
    # Stop after 8 epochs without improvement and restore the best weights in memory.
    EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True),
    # append=True: re-running training adds rows to the existing CSV instead of replacing it.
    CSVLogger(csv_path, separator=',', append=True),
    # save_best_only=True makes `monitor`/`mode` effective: without it the file is
    # overwritten every epoch, so the checkpoint reloaded later would be the last
    # epoch rather than the one with the lowest val_loss.
    ModelCheckpoint(
        filepath=model_checkpoint_path,
        monitor='val_loss',
        mode='min',
        save_best_only=True,
        verbose=1,
    ),
]

In [None]:
#Fit the model
# Train on the sampled patches and validate on the validation patches after every epoch.
history = model.fit(
    x=train_data,
    y=train_mask,
    validation_data=(valid_data, valid_mask),
    batch_size=2,
    epochs=50,
    callbacks=my_callbacks,
    verbose=1,
)

In [None]:
# Plot the training and validation loss and Dice coefficient at each epoch.
def _plot_training_curves(x_values, train_values, val_values, train_label, val_label, title, y_label):
    """Draw one training-vs-validation figure and show it."""
    plt.plot(x_values, train_values, 'y', label=train_label)
    plt.plot(x_values, val_values, 'r', label=val_label)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    plt.legend()
    plt.show()


loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(loss) + 1)
_plot_training_curves(
    epochs, loss, val_loss,
    'Training loss', 'Validation loss', 'Training and validation loss', 'Loss',
)

acc = history.history['dice_coefficient']
val_acc = history.history['val_dice_coefficient']
_plot_training_curves(
    epochs, acc, val_acc,
    'Training Dice', 'Validation Dice', 'Training and validation Dice', 'Dice',
)

In [None]:
# Collect at least 100 test patches whose mask contains foreground.
# NOTE(review): this rebinds `test_data` from the patient-ID list to the patch array.
test_data, test_mask, test_patients = process_one(test_path, test_mask_path, 100)

In [None]:
import SimpleITK as sitk
from ipywidgets import interact, interactive, IntSlider, ToggleButtons
from tensorflow.keras.models import load_model

In [None]:
# Reload the checkpointed model; the custom Dice loss and metric are passed by name
# so that Keras can resolve them while deserialising.
dice_custom_objects = {
    'dice_coefficient_loss': dice_coefficient_loss,
    'dice_coefficient': dice_coefficient,
}
model = tf.keras.models.load_model(model_checkpoint_path, custom_objects=dice_custom_objects)

In [None]:
# Pick one random test patch. random.randint includes its upper bound, so
# randint(0, len(test_data)) could return len(test_data) and raise IndexError;
# randrange(len(test_data)) only yields valid indices.
test_img_number = random.randrange(len(test_data))

test_img = test_data[test_img_number]
ground_truth = test_mask[test_img_number]

# The model expects a batch axis, so wrap the single patch into a batch of one.
test_img_input = np.expand_dims(test_img, 0)
test_pred = model.predict(test_img_input)

In [None]:
# Shapes of the selected patch, its mask and the prediction for that patch.
print(test_img.shape, ground_truth.shape, test_pred[0].shape, sep="\n")

In [None]:
# Interactive slice viewer over the in-memory arrays (test_img, ground_truth, test_pred), not the .nrrd files
@interact
def explore_prediction(layer = (0,31), view = ["axial", "sagittal","coronal"]):
    """Show one slice of the test patch, its mask and the prediction side by side.

    Args:
        layer: Index of the slice to display (slider from 0 to 31).
        view: ``'axial'`` slices along axis 0, ``'coronal'`` along axis 1 and any
            other value (i.e. ``'sagittal'``) along axis 2.

    NOTE(review): the patches are 64 x 64 x 32, so the 0-31 slider reaches only
    half of the slices along axes 0 and 1.
    """
    if view == 'axial':
        axis = 0
    elif view == 'coronal':
        axis = 1
    else:
        axis = 2

    # Index tuple that fixes `layer` on the chosen axis and keeps the other two whole.
    selector = tuple(layer if dim == axis else slice(None) for dim in range(3))

    panels = (
        ('Img', test_img[selector]),
        ('Mask', ground_truth[selector]),
        ('Prediction', test_pred[0][selector]),
    )

    plt.figure(figsize=(18, 6))
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image, cmap='gray', aspect = "auto")
        plt.title(title, fontsize=10)
    plt.axis('off')  # applies to the current (last) subplot only
    plt.show()