In [None]:
import re
import nrrd
import zipfile
import os, glob
import numpy as np
from skimage import io
import SimpleITK as sitk
from scipy import ndimage
import matplotlib.pyplot as plt
from patchify import patchify, unpatchify
from ipywidgets import interact, interactive, IntSlider, ToggleButtons

import keras
import tensorflow as tf
from keras import backend as K
import segmentation_models_3D as sm
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, CSVLogger

In [None]:
# Root folder of the raw CT data. This is a machine-specific absolute path.
data_dir = '/home/tester/jianhoong/jh_fyp_work/ct_scans_data/raw_data/'

# Training split: image volumes and their masks sit in doubly nested folders
# that share the same name.
z_train = os.path.join(data_dir, 'training_data_z')
z_train_image, z_train_mask = (
    os.path.join(z_train, nested_dir)
    for nested_dir in ('training_images/training_images',
                       'training_masks/training_masks')
)

# The dedicated validation and testing splits are currently disabled.
# They follow the same layout and are kept here for reference:
# z_valid = os.path.join(data_dir, 'valid_data_z')
# z_valid_image = os.path.join(z_valid, 'valid_images/valid_images')
# z_valid_mask = os.path.join(z_valid, 'valid_masks/valid_masks')
# z_test = os.path.join(data_dir, 'testing_data_z')
# z_test_image = os.path.join(z_test, 'testing_images/testing_images')
# z_test_mask = os.path.join(z_test, 'testing_masks/testing_masks')

In [None]:
def read_nrrd_file(filepath):
    """Load an NRRD file and return only its voxel array.

    ``nrrd.read`` returns the data together with a header; the header is
    discarded here.
    """
    return nrrd.read(filepath)[0]

def normalize(volume, min_value=-1000, max_value=5000):
    """Clip a volume to an intensity window and rescale it to [0, 1].

    Unlike an in-place clamp, this never modifies the array passed in, so the
    caller's raw data stays intact and the call can be repeated safely.

    Parameters
    ----------
    volume : array-like
        Intensity values of the scan.
    min_value, max_value : number, optional
        Lower and upper bound of the window. Values outside are clamped to
        the nearest bound. The defaults (-1000, 5000) are the window the
        notebook was written for (its data reportedly spans -1000 to 5013).

    Returns
    -------
    numpy.ndarray
        New ``float32`` array of the same shape with values in [0, 1].

    Raises
    ------
    ValueError
        If ``max_value`` is not greater than ``min_value``.
    """
    if max_value <= min_value:
        raise ValueError("max_value must be greater than min_value")

    # Converting to float first avoids integer overflow on narrow integer
    # dtypes. np.clip then returns a fresh array, leaving the input untouched.
    clipped = np.clip(np.asarray(volume, dtype=np.float32), min_value, max_value)
    scaled = (clipped - min_value) / (max_value - min_value)
    return scaled.astype(np.float32, copy=False)

def resize_volume(img, desired_width=256, desired_height=256, desired_depth=128):
    """Rotate a 3-D volume by 90 degrees and resample it to a target size.

    All three axes are resized, not only the depth axis.

    Parameters
    ----------
    img : numpy.ndarray
        3-D volume indexed as (width, height, depth).
    desired_width, desired_height, desired_depth : int, optional
        Target size along axis 0, 1 and 2. The defaults (256, 256, 128)
        match the input shape the model in this notebook is built with.

    Returns
    -------
    numpy.ndarray
        The rotated volume, linearly resampled to
        ``(desired_width, desired_height, desired_depth)``.

    Raises
    ------
    ValueError
        If ``img`` is not 3-D or has a zero-length axis.
    """
    if img.ndim != 3 or 0 in img.shape:
        raise ValueError(
            "resize_volume expects a non-empty 3-D volume, got shape %s" % (img.shape,)
        )

    current_width, current_height, current_depth = img.shape
    zoom_factors = (
        desired_width / current_width,
        desired_height / current_height,
        desired_depth / current_depth,
    )

    # Rotate to fix the orientation. reshape=False keeps the array shape, so
    # the zoom factors computed above remain valid.
    img = ndimage.rotate(img, 90, reshape=False)
    # order=1 selects linear interpolation for the resampling.
    return ndimage.zoom(img, zoom_factors, order=1)

def process_scan(path):
    """Load one NRRD volume and prepare it for the network.

    The file is read, passed through ``normalize`` and then through
    ``resize_volume``; the resulting array is returned.
    """
    return resize_volume(normalize(read_nrrd_file(path)))

def sorted_alnum(l):
    """Return the strings in ``l`` sorted in natural (human) order.

    Runs of ASCII digits are compared as integers, so ``'scan2'`` sorts
    before ``'scan10'``.
    """
    def natural_key(name):
        # Splitting on a captured digit group keeps the digit runs as tokens.
        pieces = []
        for token in re.split('([0-9]+)', name):
            pieces.append(int(token) if token.isdigit() else token)
        return pieces

    return sorted(l, key=natural_key)

In [None]:
# Full list of training scan IDs, disabled while prototyping on a single scan.
# NOTE(review): this list contains ID 2, which is the held-out scan below.
# Remove it from one of the two lists before re-enabling.
# train_5_dim = [1, 2, 7, 11, 17, 24, 25, 26, 27, 30, 32, 39, 40, 41, 42, 43, 45, 46, 50, 56, 57, 59, 61, 62, 63, 65, 66, 67, 70, 71, 74, 75, 78, 80, 84, 86, 98, 99, 100, 101, 102, 111, 113, 114, 115, 120, 121, 124, 125, 127, 128, 130, 133, 135, 137, 138, 141, 143, 146, 148, 150, 152, 153, 154, 158, 160, 161, 165, 166, 167, 168, 173, 174, 175, 176, 177, 178, 181, 182, 183, 187, 189, 194, 195, 196, 197, 203, 204]

# Scan IDs used for training and for held-out evaluation, respectively.
train_5_dim, test_5_dim = [1], [2]

In [None]:
def _select_scan_paths(folder, scan_ids):
    """Return naturally sorted paths of the files in ``folder`` whose first
    number in the file name is one of ``scan_ids``.

    Files without any digits in their name (e.g. hidden or system files) are
    skipped instead of raising ``IndexError``.
    """
    wanted = set(scan_ids)
    selected = []
    for file in os.listdir(folder):
        first_number = re.search(r'\d+', file)
        if first_number and int(first_number.group()) in wanted:
            selected.append(os.path.join(folder, file))
    return sorted_alnum(selected)


# Both the training and the held-out scans are drawn from the training
# folders; they are separated only by scan ID.
train_path = _select_scan_paths(z_train_image, train_5_dim)
train_mask_path = _select_scan_paths(z_train_mask, train_5_dim)

test_path = _select_scan_paths(z_train_image, test_5_dim)
test_mask_path = _select_scan_paths(z_train_mask, test_5_dim)

# Images and masks are paired purely by position, so the lists must line up.
assert len(train_path) == len(train_mask_path), "training images and masks do not pair up"
assert len(test_path) == len(test_mask_path), "test images and masks do not pair up"

In [None]:
# Sanity check: number of training image files selected by the ID filter.
len(train_path)

In [None]:
def process_mask(path):
    """Load a segmentation mask and resample it like the images, keeping it binary.

    Masks must not go through ``normalize``: the CT intensity window would
    squash label values 0 and 1 to almost the same number (about 0.1667).
    The mask is binarised, resampled with ``resize_volume`` so it stays
    aligned with the image, and thresholded again because interpolation
    produces fractional values along object borders.
    """
    mask = read_nrrd_file(path)
    # Assumes background voxels are 0 and any non-zero label is foreground.
    # TODO confirm against the mask files.
    mask = (mask > 0).astype("float32")
    mask = resize_volume(mask)
    return (mask > 0.5).astype("float32")


train_scans = np.array([process_scan(path) for path in train_path])
train_mask_scans = np.array([process_mask(path) for path in train_mask_path])

test_scans = np.array([process_scan(path) for path in test_path])
test_mask_scans = np.array([process_mask(path) for path in test_mask_path])

In [None]:
# Shape of the stacked training volumes; the first axis is the scan index
# because the array was built from a list with one entry per scan.
train_scans.shape

In [None]:
# Shape of the stacked training masks; it should match the image array's shape
# at this point so that every voxel has a label.
train_mask_scans.shape

### Setting up 3D UNet

In [None]:
# Encoder backbone used for the U-Net, and the matching input preprocessing
# function provided by segmentation_models_3D for that backbone.
BACKBONE = 'mobilenetv2'
preprocess_input = sm.get_preprocessing(BACKBONE)

In [None]:
# The segmentation_models_3D backbones expect 3-channel input, so each
# grayscale volume is replicated along a new trailing channel axis.
# NOTE: this overwrites train_scans, so run the cell only once per kernel session.
train_scans = np.stack([train_scans, train_scans, train_scans], axis = -1)
# Masks only need a single trailing channel.
train_msk = np.expand_dims(train_mask_scans, 4)

In [None]:
# Give the held-out scans the same channel layout as the training data:
# 3 identical channels for the images and 1 channel for the masks.
# This cell previously stacked train_scans a second time, which produced a
# 6-D training array and left test_scans without a channel axis.
test_scans = np.stack((test_scans,) * 3, axis = -1)
test_msk = np.expand_dims(test_mask_scans, axis = 4)

In [None]:
# Shape check after channel stacking; the last axis should now have size 3.
train_scans.shape

In [None]:
# Apply the backbone-specific preprocessing returned by sm.get_preprocessing.
# NOTE(review): both assignments overwrite their inputs, so re-running this
# cell preprocesses the data a second time. Re-run the loading cells first.
# NOTE(review): normalize() already scaled the scans to [0, 1]; confirm this is
# the input range the backbone preprocessing expects.
train_scans = preprocess_input(train_scans)
test_scans = preprocess_input(test_scans)

In [None]:
# Optimiser: Nadam with a fixed learning rate.
LR = 0.0001
opt = tf.keras.optimizers.Nadam(learning_rate = LR)

# Training objective: Dice loss plus binary cross-entropy.
dice_loss = sm.losses.DiceLoss()
CE_loss = sm.losses.BinaryCELoss()
total_loss = dice_loss + CE_loss

# Both metrics binarise the sigmoid output at 0.5.
metrics = [sm.metrics.IOUScore(threshold = 0.5), sm.metrics.FScore(threshold = 0.5)]

# Single-class (binary) 3D U-Net. input_shape is (width, height, depth, channels)
# and must match what resize_volume and the channel stacking produce.
model = sm.Unet(
    BACKBONE,
    classes = 1,
    input_shape = (256, 256, 128, 3),
    encoder_weights = 'imagenet',
    activation = 'sigmoid')

model.compile(optimizer = opt, loss = total_loss, metrics = metrics)
# summary() prints the table itself and returns None, so wrapping it in
# print() only added a stray "None" line to the output.
model.summary()

In [None]:
# Make tf.function-wrapped code run eagerly instead of as a compiled graph.
# NOTE(review): this is normally a debugging aid and slows training; confirm it
# is still needed, and that enabling it after model.compile() has the intended effect.
tf.config.run_functions_eagerly(True)

In [None]:
# Final shape check before training; apart from the leading scan axis it
# should equal the model's input_shape of (256, 256, 128, 3).
train_scans.shape

In [None]:
# Targets need an explicit trailing channel axis to match the model's
# single-channel output. The training masks already have one in train_msk;
# the held-out masks get it here.
history = model.fit(
    train_scans,
    train_msk,
    batch_size = 1,
    epochs = 25,
    verbose = 1,
    validation_data = (test_scans, np.expand_dims(test_mask_scans, axis = -1))
)