In [None]:
import os
import re
import glob 
import nrrd
import shutil
import json
import h5py 
import numpy as np
from scipy import ndimage
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
from datetime import datetime

import tensorflow.keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.callbacks import ModelCheckpoint, Callback, ReduceLROnPlateau, LearningRateScheduler, EarlyStopping, TensorBoard
from tensorflow.keras.callbacks import LambdaCallback, CSVLogger
from tensorflow.keras import backend as K
from tensorflow.keras.utils import plot_model

import tensorflow as tf
import util

## for tensorboard
# %load_ext tensorboard

In [None]:
# Root directory of the raw CT data. Set the CT_DATA_DIR environment variable
# to run the notebook on another machine; when it is unset the original
# location is used, so existing setups keep working unchanged.
data_dir = os.environ.get(
    'CT_DATA_DIR',
    '/home/tester/jianhoong/jh_fyp_work/ct_scans_data/raw_data/',
)

# Images and masks live in parallel sub-directories of the training folder.
z_train = os.path.join(data_dir, 'training_data_z')
z_train_image = os.path.join(z_train, 'training_images/training_images')
z_train_mask = os.path.join(z_train, 'training_masks/training_masks')

In [None]:
def read_nrrd_file(filepath):
    '''Load the NRRD file at `filepath` and return its pixel data only.

    `nrrd.read` returns a (data, header) pair; the header is discarded.
    '''
    return nrrd.read(filepath)[0]

def normalize(volume, min_value=-1000, max_value=5000):
    '''Clip a volume to an intensity window and rescale it to [0, 1].

    Values below `min_value` or above `max_value` are clamped to the window
    edges before scaling. The input is never modified: the work is done on a
    private copy (the previous version clamped the caller's array in place).

    Args:
        volume: array-like of intensities.
        min_value: lower edge of the window. Default -1000, the minimum value
            of our data.
        max_value: upper edge of the window. Default 5000; the maximum value
            of our data is 5013.

    Returns:
        A new float32 numpy array with the same shape as `volume`.

    Raises:
        ValueError: if `max_value` is not greater than `min_value`.
    '''
    if max_value <= min_value:
        raise ValueError("max_value must be greater than min_value")

    # np.array(...) always copies, so the in-place steps below only touch
    # our own buffer and keep the number of temporaries low.
    scaled = np.array(volume, dtype=np.float64)
    np.clip(scaled, min_value, max_value, out=scaled)
    scaled -= min_value
    scaled /= (max_value - min_value)
    return scaled.astype("float32")

def resize_volume(img, desired_width=256, desired_height=256, desired_depth=128):
    '''Rotate a 3-D volume by 90 degrees and resample it to a fixed shape.

    All three axes are resized (not only the z-axis): axis 0 to
    `desired_width`, axis 1 to `desired_height` and the last axis to
    `desired_depth`.

    Args:
        img: 3-D numpy array.
        desired_width: target size of axis 0 (default 256).
        desired_height: target size of axis 1 (default 256).
        desired_depth: target size of the last axis (default 128).

    Returns:
        The rotated and resampled array.

    Raises:
        ValueError: if `img` is not 3-D or has a zero-length axis.
    '''
    if img.ndim != 3 or 0 in img.shape:
        raise ValueError(
            "resize_volume expects a non-empty 3-D array, got shape %s"
            % (img.shape,)
        )

    current_width, current_height, current_depth = img.shape
    zoom_factors = (
        desired_width / current_width,
        desired_height / current_height,
        desired_depth / current_depth,
    )

    # Rotate to fix the orientation. reshape=False keeps the array shape, so
    # the zoom factors computed from the original shape are still valid.
    img = ndimage.rotate(img, 90, reshape=False)
    # order=1: linear interpolation while resampling.
    return ndimage.zoom(img, zoom_factors, order=1)

def process_scan(path):
    '''Load the NRRD volume at `path`, normalize it, then resize it.'''
    return resize_volume(normalize(read_nrrd_file(path)))

def sorted_alnum(l):
    '''Return the strings in `l` sorted in natural (human) order.

    Runs of ASCII digits compare as integers, so 'scan2' sorts before
    'scan10'. Everything else compares as plain text.

    Args:
        l: iterable of strings.

    Returns:
        A new sorted list; `l` itself is not modified.
    '''
    def natural_key(text):
        # re.split with a capturing group alternates text / digit-run pieces
        # and always starts with a (possibly empty) text piece, so the odd
        # positions are exactly the ASCII digit runs. Converting by position
        # rather than by str.isdigit() avoids calling int() on text pieces
        # made of non-ASCII digit characters such as '²', which raises
        # ValueError.
        pieces = re.split('([0-9]+)', text)
        return [int(piece) if index % 2 else piece
                for index, piece in enumerate(pieces)]

    return sorted(l, key=natural_key)

In [None]:
# Scan numbers held out for testing. They are matched against the first
# integer found in each file name when the path lists are built.
test_5_dim = [2]

# Scan numbers used for training.
train_5_dim = [1, 2, 7, 11, 17, 24, 25, 26, 27, 30, 32, 39, 40, 41, 42, 43, 45, 46, 50]
# Scan 2 was listed for both training and testing, and the test files are
# read from the same directory as the training files, so the model would be
# evaluated on data it was trained on. Drop every held-out scan from the
# training list so the two sets stay disjoint.
# NOTE(review): confirm that scan 2 is meant to be held out rather than trained on.
train_5_dim = [scan_id for scan_id in train_5_dim if scan_id not in test_5_dim]

In [None]:
def _select_scan_files(directory, scan_ids):
    '''Return naturally sorted paths of the files in `directory` whose first
    number in the file name is one of `scan_ids`.

    File names that contain no digits at all are skipped; previously they
    raised IndexError.
    '''
    selected = []
    for file in os.listdir(directory):
        numbers = re.findall(r'\d+', file)
        if numbers and int(numbers[0]) in scan_ids:
            selected.append(os.path.join(directory, file))
    return sorted_alnum(selected)


train_path = _select_scan_files(z_train_image, train_5_dim)
train_mask_path = _select_scan_files(z_train_mask, train_5_dim)

test_path = _select_scan_files(z_train_image, test_5_dim)
test_mask_path = _select_scan_files(z_train_mask, test_5_dim)

# Images and masks are paired purely by position after sorting, and zip()
# silently truncates to the shorter list, so a count mismatch would pair
# scans with the wrong masks without any error.
if len(train_path) != len(train_mask_path) or len(test_path) != len(test_mask_path):
    raise ValueError('Image/mask file counts differ; scans and masks would be mis-paired.')

In [None]:
# Print each training image path next to its mask path to eyeball the pairing.
for path_pair in zip(train_path, train_mask_path):
    print(*path_pair)

In [None]:
def training_generator(train_paths, mask_paths):
    '''Yield (scan, mask) float32 array pairs, one per path pair.

    Args:
        train_paths: iterable of scan file paths.
        mask_paths: iterable of mask file paths, in the same order as
            `train_paths`.

    Yields:
        (scan_pixels, mask_pixels): the scan normalized and resized by
        `process_scan`, and the mask resized the same way and binarized.
    '''
    for scan_path, label_path in zip(train_paths, mask_paths):
        scan_pixels = process_scan(scan_path)

        # Masks hold labels, not CT intensities. Sending them through
        # process_scan() applied the [-1000, 5000] intensity window, which
        # maps labels 0 and 1 to about 0.1667 and 0.1668 and makes the
        # target useless. Resize only, on a float copy so interpolation
        # cannot wrap around in an integer dtype.
        mask_pixels = resize_volume(read_nrrd_file(label_path).astype("float32"))
        # Interpolation blurs label edges; threshold back to hard 0/1 labels.
        # NOTE(review): assumes binary foreground/background masks - confirm.
        mask_pixels = (mask_pixels > 0.5).astype("float32")

        yield scan_pixels, mask_pixels
    

In [None]:
# Run the full preprocessing pipeline (read, normalize, resize) on the first
# training scan as a quick check.
volume = process_scan(train_path[0])

In [None]:
# Display the processed volume's shape; resize_volume targets 256 x 256 x 128.
volume.shape

In [None]:
# Add a leading batch axis: (W, H, D) -> (1, W, H, D). Use axis=3 instead to
# append a trailing channel axis (channels last).
# Guarded so that re-running this cell does not keep stacking extra axes.
if len(volume.shape) == 3:
    volume = tf.expand_dims(volume, axis=0)

In [None]:
# from_generator() calls its first argument with no arguments, but
# training_generator requires the two path lists. Wrap it in a zero-argument
# callable that binds them; passing the bare function fails as soon as the
# dataset is iterated.
train_loader = tf.data.Dataset.from_generator(
    lambda: training_generator(train_path, train_mask_path),
    output_types=(tf.float32, tf.float32),
)

In [None]:
# Group the loader's (scan, mask) pairs into batches of up to 10.
# `ds` is not referenced by the other cells shown in this part of the notebook.
ds = train_loader.batch(10)

In [None]:
# Training input pipeline: shuffle with a buffer as large as the training
# set, skip elements that raise an error, batch in pairs (dropping an
# incomplete final batch) and prefetch up to 8 batches.
train_dataset = (
    train_loader
    .shuffle(len(train_path))
    .apply(tf.data.experimental.ignore_errors())
    .batch(2, drop_remainder=True)
    .prefetch(8)
)

In [None]:
# `K` (tensorflow.keras.backend) is already imported in the import cell at the
# top of the notebook, so it is not re-imported here.
# Select the channels-first layout before the model is built.
K.set_image_data_format('channels_first')

model = util.unet_model_3d(
    loss_function=util.weighted_bce_dice_loss,
    metrics=[util.dice_coefficient],
)
model.summary()
# Save a diagram of the architecture next to the notebook.
plot_model(model, to_file='MyModel.png')