In [None]:
import numpy as np 
import matplotlib.pyplot as plt
import matplotlib
import tensorflow as tf
import os
import sys
#import nrrd
import scipy.ndimage
import scipy.misc
import pickle
import random
from tensorflow.python.framework import ops
import glob

In [None]:
# Locate the pre-processed data (overlapped patches and associated mask patches)
# and split it by position into train / validation / test lists.

patches = './pre_processed dataset/patches/*.npy'
labels = './pre_processed dataset/labels_annotated/*.npy'   

#patches = 'D:/Brin Stroke Detection/Experiment/31.1.2021D/Dataset/Stroke3/Patches/*.npy'
#labels = 'D:/Brin Stroke Detection/Experiment/31.1.2021D/Dataset/Stroke3/Labels_Annotated/*.npy'      

# glob returns paths in arbitrary, platform-dependent order. Patches and labels
# are paired by list index below, so both lists are sorted first.
# NOTE(review): assumes patch and label file names sort in the same order - confirm.
patch_addrs = sorted(glob.glob(patches))
labels_addrs = sorted(glob.glob(labels))

if len(patch_addrs) != len(labels_addrs):
    raise ValueError('found %d patch files but %d label files; they must be paired one-to-one'
                     % (len(patch_addrs), len(labels_addrs)))


# divide the data into 60% for training, 20% for validation and the remainder
# (about 20%) for testing

n_train = int(0.6*len(patch_addrs))
n_val = int(0.2*len(patch_addrs))
n_test = len(patch_addrs) - n_train - n_val

# training split
data_train_dir = patch_addrs[0:n_train]
anns_train_dir = labels_addrs[0:n_train]

# validation split (starts exactly where the training split ends, so no sample is skipped)
data_val_dir = patch_addrs[n_train : n_train + n_val]
anns_val_dir = labels_addrs[n_train : n_train + n_val]

# testing split (everything after the validation split)
data_test_dir = patch_addrs[n_train + n_val :]
anns_test_dir = labels_addrs[n_train + n_val :]

# show one test reference (file name without extension) as a sanity check
if len(data_test_dir) > 1:
    print(os.path.basename(data_test_dir[1]).split('.')[0])

In [None]:
# load the patches and their corresponding mask patches from disk
def get_data(data_dir, anns_dir):
    """Load paired patch / label arrays.

    data_dir: sequence of paths to patch .npy files.
    anns_dir: sequence of paths to label .npy files, indexed in parallel
              with data_dir.

    Returns a list with one (patch, label) tuple per entry of data_dir,
    each element loaded with np.load.
    """
    return [
        (np.load(data_dir[idx]), np.load(anns_dir[idx]))
        for idx in range(len(data_dir))
    ]

def get_data_ref(data_dir, anns_dir):
    """Build reference names for each patch / label pair.

    data_dir: sequence of patch file paths.
    anns_dir: sequence of label file paths, indexed in parallel with data_dir.

    Returns a list of (patch_ref, label_ref) tuples, one per entry of
    data_dir, where each reference is the file's base name truncated at
    its first '.'.
    """
    data_ref = [] # for references
    for i in range(len(data_dir)):
        # index with i: the previous hard-coded [1] returned the second
        # file's name for every entry
        patch_ref = os.path.basename(data_dir[i]).split('.')[0]
        label_ref = os.path.basename(anns_dir[i]).split('.')[0]
        data_ref.append((patch_ref, label_ref))
    return data_ref

In [None]:
# generate the datasets and save them as pickle files

train = get_data(data_train_dir, anns_train_dir)
val = get_data(data_val_dir, anns_val_dir)
test = get_data(data_test_dir, anns_test_dir)
test_ref = get_data_ref(data_test_dir, anns_test_dir) 

# save the data into pickle files; create the target folder first and use
# context managers so every file handle is flushed and closed
os.makedirs('./pickles', exist_ok=True)
for _split_name, _split_data in (('train', train), ('val', val), ('test', test)):
    with open('./pickles/' + _split_name + '.pkl', 'wb') as _pickle_file:
        pickle.dump(_split_data, _pickle_file)


In [None]:
# ---- geometry of the u-net input and output ----
input_size = 64     # width and height of the input 3d patch
input_depth = 64    # depth of the input 3d patch
output_size = 64    # width and height of the output
output_depth = 64   # depth of the output 3d patch
output_classes = 2  # number of output classes (background and infarct lesion)

# ---- training hyperparameters (tune these for better performance) ----
learning_rate = 0.001
num_epoches = 20
batch_size = 16

# ---- output locations: model checkpoints and summary logs ----
save_path = "./tf/"
logs_path = "./tf_logs/"

# whether to try restoring a previously trained model from save_path
load_model = True

# make sure both output folders exist
for _output_dir in (save_path, logs_path):
    if not os.path.exists(_output_dir):
        os.makedirs(_output_dir)

model_name = 'model'     # base name of the checkpoint to save / load


In [None]:
def get_data_raw_sample(data, size):
    """Draw `size` distinct (x, y) pairs from `data` with random.sample.

    Returns two parallel lists: the first items of the drawn pairs and
    the second items of the same pairs.
    """
    drawn = random.sample(data, size)
    inputs = []
    targets = []
    for pair in drawn:
        inputs.append(pair[0])
        targets.append(pair[1])
    return inputs, targets


# takes raw data (x, y) and scales it to match the desired input and output sizes to feed into tensorflow
# pads and normalises the input and moves the axes into the orientation expected by the model placeholders
def get_scaled_input(data, min_i = input_size, min_o = output_size, depth = input_depth, 
                    depth_out = output_depth, image_fill = 0, 
                    label_fill = 0, n_classes = output_classes, norm_max = 500):  
    """Rescale, pad, normalise and re-orient one (image, label) pair.

    data:       pair whose element 0 is the image volume and element 1 the
                label volume (both 3D arrays; the last axis is padded).
    min_i:      target size of the image's first axis (also used for the
                in-plane size of the padding block).
    min_o:      target size of the label's first axis.
    depth:      target size of the image's last axis after padding.
    depth_out:  target size of the label's last axis after padding.
    image_fill: value used to pad the image along the last axis.
    label_fill: value used to pad the label along the last axis.
    n_classes:  currently unused.
    norm_max:   the image is rescaled so that its maximum equals this value
                (skipped when the maximum is 0).

    Returns (image, label). Each gets a new leading axis which is then
    swapped with the last axis, so a (H, W, D) volume comes back as
    (D, H, W, 1).

    np.ones raises ValueError if a zoomed volume is already deeper than
    `depth` / `depth_out`, because the padding size would be negative.
    """
    image, label = data[0], data[1]

    input_scale_factor = min_i/image.shape[0]
    # derive the label factor from the label itself so the label always reaches
    # min_o (identical to the old behaviour whenever image and label share a shape)
    output_scale_factor = min_o/label.shape[0]

    if not input_scale_factor == 1:
        # order 1 is (tri)linear interpolation - fast and good enough
        vox_zoom = scipy.ndimage.zoom(image, input_scale_factor, order = 1) 
    else:
        vox_zoom = image

    if not output_scale_factor == 1:
        # order 0 is nearest neighbour: very important, it keeps the labels discrete
        lbl_zoom = scipy.ndimage.zoom(label, output_scale_factor, order = 0) 
    else:
        lbl_zoom = label   

    # pad the label along the last axis up to depth_out, then add a leading axis
    lbl_pad = label_fill*np.ones((min_o, min_o, depth_out - lbl_zoom.shape[-1]))
    lbl_zoom = np.concatenate((lbl_zoom, lbl_pad), 2)
    lbl_zoom = lbl_zoom[np.newaxis, :, :, :]
    
    # pad the image along the last axis up to depth
    vox_pad = image_fill*np.ones((min_i, min_i, depth - vox_zoom.shape[-1]))
    vox_zoom = np.concatenate((vox_zoom, vox_pad), 2)
    
    # normalise so the maximum intensity equals norm_max (all-zero volumes are left alone)
    max_val = np.max(vox_zoom)
    if not max_val == 0:
        vox_zoom = vox_zoom * norm_max/max_val
        
    vox_zoom = vox_zoom[np.newaxis, :, :, :]

    # swap the new leading axis with the last axis: (1, H, W, D) -> (D, H, W, 1)
    vox_zoom = np.swapaxes(vox_zoom, 0, -1)
    lbl_zoom = np.swapaxes(lbl_zoom, 0, -1)
        
    return vox_zoom, lbl_zoom

def upscale_segmentation(lbl, shape_desired):
    """Resize a label volume to `shape_desired`; required for the mean IoU calculation.

    lbl:           3D label array.
    shape_desired: target shape (three entries).

    The volume is zoomed on every axis by shape_desired[0] / lbl.shape[0]
    using nearest-neighbour interpolation, then its last axis is cropped
    or zero-padded to shape_desired[-1].

    Returns the resized label array, with the same dtype as the zoomed labels.
    """
    scale_factor = shape_desired[0]/lbl.shape[0]
    # order 0 (nearest neighbour) keeps the labels discrete
    lbl_upscale = scipy.ndimage.zoom(lbl, scale_factor, order = 0)
    # crop the last axis if the zoom overshot the desired depth
    lbl_upscale = lbl_upscale[:, :, :shape_desired[-1]]
    if lbl_upscale.shape[-1] < shape_desired[-1]:
        # zero-pad the missing depth slices; the previous code multiplied the zeros
        # by an undefined name (off_label_fill) and raised NameError here
        pad_zero = np.zeros((shape_desired[0], shape_desired[1], shape_desired[2] - lbl_upscale.shape[-1]),
                            dtype = lbl_upscale.dtype)
        lbl_upscale = np.concatenate((lbl_upscale, pad_zero), axis = -1)
    return lbl_upscale


In [None]:
# functions to calculate the evaluation metrics (IoU, labelling accuracy, Dice) and the Dice loss
def get_pred_iou(predictions, lbl_original, ret_full = False, reswap = False):
    """Average the IoU returned by get_mean_iou over a batch.

    predictions:  indexable batch of predictions; each entry is squeezed
                  before being scored.
    lbl_original: indexable batch of target labels; its length sets the
                  number of samples scored.
    ret_full:     forwarded to get_mean_iou. When True the result is the
                  per-class IoU averaged over the batch; otherwise a
                  single mean value.
    reswap:       forwarded to get_mean_iou.
    """
    iou = []
    for i in range(len(lbl_original)):
        pred_cur = np.squeeze(predictions[i])
        metric = get_mean_iou(pred_cur, lbl_original[i], ret_full = ret_full, reswap = reswap)
        iou.append(metric)
    if ret_full:
        # average over the batch axis, keeping one value per class
        return np.mean(iou, axis = 0)
    else:
        return np.mean(iou)
    
def get_label_accuracy(pred, lbl_original):
    """Percentage of voxels whose predicted label equals the target (demo metric).

    pred is passed through swap_axes and then resized to the shape of
    lbl_original with upscale_segmentation before the comparison.
    NOTE(review): swap_axes is not defined in the visible part of this
    file - confirm where it comes from.
    """
    # swap axes back, then bring the prediction to the label's shape
    pred_swapped = swap_axes(pred)
    pred_resized = upscale_segmentation(pred_swapped, np.shape(lbl_original))
    n_matching = np.sum(np.equal(pred_resized, lbl_original))
    n_voxels = np.prod(lbl_original.shape)
    return 100*n_matching/n_voxels

def get_mean_iou(pred, lbl_original, num_classes = output_classes, ret_full = False, reswap = False):
    """Intersection-over-union between a prediction and its target labels.

    pred:         predicted label volume; passed through swap_axes and
                  resized to the target's shape before scoring.
    lbl_original: target label volume.
    num_classes:  classes 0 .. num_classes-1 are scored.
    ret_full:     when True return the list of per-class IoU values,
                  otherwise their mean.
    reswap:       when True lbl_original is also passed through swap_axes
                  first (for labels stored in the model's axis order).

    A class that occurs in neither volume keeps an IoU of 1.
    NOTE(review): swap_axes is not defined in the visible part of this
    file - confirm where it comes from.
    """
    # swap axes back
    pred = swap_axes(pred)
    if reswap:
        lbl_original = swap_axes(lbl_original)
    pred_upscale = upscale_segmentation(pred, np.shape(lbl_original))

    per_class = []
    for cls in range(num_classes):
        in_label = lbl_original == cls

        # union: voxels carrying this class in the prediction or in the target
        union_mask = np.zeros(np.shape(lbl_original), dtype = bool)
        union_mask[pred_upscale == cls] = True
        union_mask[in_label] = True
        union_count = int(np.count_nonzero(union_mask))

        # intersection: target voxels of this class that were predicted as this class
        hit_count = int(np.count_nonzero(pred_upscale[in_label] == cls))

        if union_count == 0:
            per_class.append(1)
        else:
            per_class.append(hit_count/union_count)

    if ret_full:
        return per_class
    return np.mean(per_class)

def get_dice (pred, lbl_original):
    """Dice coefficient between a prediction and a label array.

    pred:         array of predicted values (a binary mask in practice).
    lbl_original: array of the same shape; entries equal to 1 mark the
                  positive class.

    Returns 2 * sum(pred where lbl_original == 1) / (sum(pred) + sum(lbl_original)).
    When both sums are zero (nothing predicted and nothing annotated) the
    two masks agree perfectly, so 1.0 is returned instead of the NaN that
    0/0 would produce.
    """
    denominator = np.sum(pred) + np.sum(lbl_original)
    if denominator == 0:
        return 1.0
    dice = np.sum(pred[lbl_original==1])*2.0 / denominator
    return dice

def get_dice_loss(y_true, y_pred):
    """Smoothed Dice loss between two tensors with the same number of elements.

    Both tensors are flattened and cast to float32, then
        loss = 1 - (2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1)
    The +1 terms keep the ratio defined when both tensors are all zero.

    Returns a scalar tensor.
    """
    # cast both sides: multiplying an integer label tensor by a float
    # prediction tensor would otherwise fail with a dtype mismatch
    y_true_f = tf.cast(tf.reshape(y_true, [-1]), dtype=tf.float32)
    y_pred_f = tf.cast(tf.reshape(y_pred, [-1]), dtype=tf.float32)
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    dice_loss = 1. - (2. * intersection + 1.) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + 1.)
    return dice_loss
    


In [None]:
# create a class for unet
class unetwork():
    """3D U-Net segmentation graph built with the TensorFlow 1.x graph API.

    Constructing an instance adds the whole model to the default graph
    under the variable scope '3dunet' and exposes:

      training        - bool placeholder switching batch-norm / dropout mode
      model_input     - float32 placeholder (batch, in_depth, in_size, in_size, 1)
      model_labels    - int32 placeholder (batch, out_depth, out_size, out_size, 1)
      predictions     - per-voxel argmax class index, with a trailing axis of 1
      loss            - scalar soft Dice loss
      trainer         - the Adam optimizer
      extra_update_ops- the UPDATE_OPS collection (batch-norm statistics)
      train_op        - one optimisation step, run after extra_update_ops
    """
    
    def conv_batch_relu(self, tensor, filters, kernel = [3,3,3], stride = [1,1,1], is_training = True):
        """3D convolution followed by batch normalisation and ReLU.

        Padding is 'same' when self.should_pad is set, otherwise 'valid'.
        is_training selects batch-norm training or inference behaviour.
        """
        padding = 'same' if self.should_pad else 'valid'
    
        conv = tf.layers.conv3d(tensor, filters, kernel_size = kernel, strides = stride, padding = padding,
                                kernel_initializer = self.base_init, kernel_regularizer = self.reg_init)
        conv = tf.layers.batch_normalization(conv, training = is_training)
        conv = tf.nn.relu(conv) 
        return conv

    def upconvolve(self, tensor, filters, kernel = 2, stride = 2, scale = 4, activation = None):
        """Up-convolution implemented as a transposed 3D convolution.

        `scale` and `activation` are accepted for interface compatibility but
        are not used. An alternative (upsampling followed by a regular conv3d)
        was considered by the original author; whether it is equivalent was
        left undetermined.
        """
        padding = 'same' if self.should_pad else 'valid'
        # use_bias is disabled: the original author notes this works around a tensorflow bug
        conv = tf.layers.conv3d_transpose(tensor, filters, kernel_size = kernel, strides = stride, padding = padding, use_bias= False, 
                                          kernel_initializer = self.base_init,  kernel_regularizer = self.reg_init)
        return conv

    def centre_crop_and_concat(self, prev_conv, up_conv):
        """Centre-crop prev_conv to up_conv's spatial size and concatenate on the channel axis.

        Needed when the model has no padding, because the skip connection is
        then larger than the up-convolved tensor. With equal sizes the crop
        is a no-op.
        """
        # static shapes as plain python ints (the batch entry may be None and is not used)
        p_c_s = prev_conv.get_shape().as_list()
        u_c_s = up_conv.get_shape().as_list()
        offsets =  np.array([0, (p_c_s[1] - u_c_s[1]) // 2, (p_c_s[2] - u_c_s[2]) // 2, 
                             (p_c_s[3] - u_c_s[3]) // 2, 0], dtype = np.int32)
        # -1 keeps the whole batch axis; all channels of prev_conv are kept
        size = np.array([-1, u_c_s[1], u_c_s[2], u_c_s[3], p_c_s[4]], np.int32)
        prev_conv_crop = tf.slice(prev_conv, offsets, size)
        up_concat = tf.concat((prev_conv_crop, up_conv), 4)
        return up_concat
        
    def __init__(self, base_filt = 8, in_depth = input_depth, out_depth = output_depth,
                 in_size = input_size, out_size = output_size, num_classes = output_classes,
                 learning_rate = learning_rate, print_shapes = True, drop = 0.2, should_pad = False):
        """Build the model graph.

        base_filt:     number of base conv filters; deeper levels use multiples of it.
        in_depth, in_size:   depth and width/height of the input placeholder.
        out_depth, out_size: depth and width/height of the label placeholder.
        num_classes:   number of output classes.
        learning_rate: learning rate of the Adam optimizer.
        print_shapes:  print a few tensor shapes while building (debug aid).
        drop:          dropout rate (proportion of dropped activations).
        should_pad:    use 'same' padding in the conv layers instead of 'valid'.

        NOTE(review): with should_pad = False the 'valid' convolutions shrink
        the output, so it will not match the label placeholder's size and the
        loss cannot be evaluated - confirm intended usage.
        NOTE(review): kernel_regularizer is set on the layers, but the
        regularisation losses are not added to self.loss.
        """
        self.base_init = tf.compat.v1.truncated_normal_initializer(stddev=0.1) # weight initialiser
        self.reg_init = tf.contrib.layers.l2_regularizer(scale=0.1) # l2 kernel regulariser
        
        self.should_pad = should_pad
        self.drop = drop # dropout rate
        
        with tf.variable_scope('3dunet'):
            self.training = tf.placeholder(tf.bool)
            self.do_print = print_shapes
            # placeholders for feed_dict
            self.model_input = tf.placeholder(tf.float32, shape = (None, in_depth, in_size, in_size, 1))  
            self.model_labels = tf.placeholder(tf.int32, shape = (None, out_depth, out_size, out_size, 1))
            # one-hot over a new last axis, then drop the singleton label-channel axis
            labels_one_hot = tf.squeeze(tf.one_hot(self.model_labels, num_classes, axis = -1), axis = -2)
            
            if self.do_print: 
                print('input features shape', self.model_input.get_shape())
                print('labels shape', labels_one_hot.get_shape())
                
            # ---- analysis path ----
            # level zero
            conv_0_1 = self.conv_batch_relu(self.model_input, base_filt, is_training = self.training)
            conv_0_2 = self.conv_batch_relu(conv_0_1, base_filt*2, is_training = self.training)
            # level one
            max_1_1 = tf.layers.max_pooling3d(conv_0_2, [2,2,2], [2,2,2])
            conv_1_1 = self.conv_batch_relu(max_1_1, base_filt*2, is_training = self.training)
            conv_1_2 = self.conv_batch_relu(conv_1_1, base_filt*4, is_training = self.training)
            conv_1_2 = tf.layers.dropout(conv_1_2, rate = self.drop, training = self.training)
            # level two
            max_2_1 = tf.layers.max_pooling3d(conv_1_2, [2,2,2], [2,2,2])
            conv_2_1 = self.conv_batch_relu(max_2_1, base_filt*4, is_training = self.training)
            conv_2_2 = self.conv_batch_relu(conv_2_1, base_filt*8, is_training = self.training)
            conv_2_2 = tf.layers.dropout(conv_2_2, rate = self.drop, training = self.training)
            # level three (bottom of the U)
            max_3_1 = tf.layers.max_pooling3d(conv_2_2, [2,2,2], [2,2,2])
            conv_3_1 = self.conv_batch_relu(max_3_1, base_filt*8, is_training = self.training)
            conv_3_2 = self.conv_batch_relu(conv_3_1, base_filt*16, is_training = self.training)
            conv_3_2 = tf.layers.dropout(conv_3_2, rate = self.drop, training = self.training)

            # ---- synthesis path ----
            # level two
            up_conv_3_2 = self.upconvolve(conv_3_2, base_filt*16, kernel = 2, stride = [2,2,2])
            concat_2_1 = self.centre_crop_and_concat(conv_2_2, up_conv_3_2)
            conv_2_3 = self.conv_batch_relu(concat_2_1, base_filt*8, is_training = self.training)
            conv_2_4 = self.conv_batch_relu(conv_2_3, base_filt*8, is_training = self.training)
            conv_2_4 = tf.layers.dropout(conv_2_4, rate = self.drop, training = self.training)
            # level one
            up_conv_2_1 = self.upconvolve(conv_2_4, base_filt*8, kernel = 2, stride = [2,2,2])
            concat_1_1 = self.centre_crop_and_concat(conv_1_2, up_conv_2_1)
            conv_1_3 = self.conv_batch_relu(concat_1_1, base_filt*4, is_training = self.training)
            conv_1_4 = self.conv_batch_relu(conv_1_3, base_filt*4, is_training = self.training)
            conv_1_4 = tf.layers.dropout(conv_1_4, rate = self.drop, training = self.training)
            # level zero
            up_conv_1_0 = self.upconvolve(conv_1_4, base_filt*4, kernel = 2, stride = [2,2,2])
            concat_0_1 = self.centre_crop_and_concat(conv_0_2, up_conv_1_0)
            conv_0_3 = self.conv_batch_relu(concat_0_1, base_filt*2, is_training = self.training)
            conv_0_4 = self.conv_batch_relu(conv_0_3, base_filt*2, is_training = self.training)
            conv_0_4 = tf.layers.dropout(conv_0_4, rate = self.drop, training = self.training)
            # 1x1x1 convolution producing one logit per class (uses the num_classes
            # argument rather than the module-level constant)
            conv_out = tf.layers.conv3d(conv_0_4, num_classes, [1,1,1], [1,1,1], padding = 'same')
            self.predictions = tf.expand_dims(tf.argmax(conv_out, axis = -1), -1)
            
            # note, this can be more easily visualised in a tool like tensorboard
            
            if self.do_print: 
                print('model convolution output shape', conv_out.get_shape())
                print('model argmax output shape', self.predictions.get_shape())
            
            # Soft Dice loss. It is computed on the softmax probabilities, not on
            # self.predictions: argmax has no gradient, so minimising a loss built on
            # it cannot update any variable, and its shape does not match the one-hot
            # labels either. Channel 0 is dropped so the loss covers the non-background
            # classes only.
            # NOTE(review): assumes class 0 is the background - confirm.
            probabilities = tf.nn.softmax(conv_out)
            y_true = labels_one_hot[..., 1:]
            y_pred = probabilities[..., 1:]
            dice_loss = get_dice_loss(y_true, y_pred)                                          
            self.loss = tf.reduce_mean(dice_loss) # assign model loss as dice_loss            
            self.trainer = tf.train.AdamOptimizer(learning_rate=learning_rate)           
            
            # run the batch-norm moving-average updates before every training step
            self.extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(self.extra_update_ops):
                self.train_op = self.trainer.minimize(self.loss)


In [None]:
# Build the model graph and train it.
ops.reset_default_graph()
unet = unetwork(drop = 0.2, base_filt = 10, should_pad = True) # model definition 
init = tf.global_variables_initializer() 
saver = tf.train.Saver(tf.global_variables())
config = tf.ConfigProto()

report_every = 20   # iterations between progress reports
save_every = 400    # iterations between periodic checkpoints
iou_size = 5        # number of samples drawn for each IoU report


def _sample_scaled_batch(data, size):
    """Draw `size` random (patch, label) pairs and scale them for the model.

    Returns (x, y, orig_y): the scaled inputs, the scaled labels and the
    untouched original labels of the same samples.
    """
    raw_x, raw_y = get_data_raw_sample(data, size)
    scaled = [get_scaled_input(pair) for pair in zip(raw_x, raw_y)]
    return [s[0] for s in scaled], [s[1] for s in scaled], raw_y


with tf.Session(config=config) as sess:
    writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
    # always initialise first so training also works when nothing is restored
    sess.run(init)
    if load_model:
        print('trying to load saved model...')
        # restore straight into the graph built above. Importing the .meta file
        # would add a second copy of the graph, and checkpoints are written with
        # a global-step suffix, so 'model.meta' itself never exists.
        checkpoint = tf.train.latest_checkpoint(save_path)
        if checkpoint is not None:
            print('loading from: ', checkpoint)
            saver.restore(sess, checkpoint)
            print("model sucessfully restored")
        else:
            print("no previous model found, running default init") 
    t_loss = []
    # NOTE: num_epoches counts training iterations (one random batch each)
    for i in range(num_epoches):
        print('current iter: ', i, end='\r')
        # draw a random batch and scale it to the placeholder shapes
        x, y, _ = _sample_scaled_batch(train, batch_size)
        train_dict = {
            unet.training: True,
            unet.model_input: x,
            unet.model_labels: y
        }
        _, loss = sess.run([unet.train_op, unet.loss], feed_dict = train_dict) # get loss
        t_loss.append(loss) # loss store
        if i % save_every == 0 and i > 0:
            print('saving model at iter: ', i) # save periodically
            saver.save(sess, save_path + model_name, global_step = i)
        # periodic progress report; the previous test `i == 20` could never be
        # true inside range(20)
        if (i + 1) % report_every == 0:
            print('iteration', i, 'loss: ', np.mean(t_loss))
            t_loss = []

            # training IoU, scored against the scaled annotations
            x, y, orig_y = _sample_scaled_batch(train, iou_size)
            train_dict = {
                unet.training: False,
                unet.model_input: x,
                unet.model_labels: y
            }
            preds = np.squeeze(sess.run([unet.predictions], feed_dict = train_dict))
            iou = get_pred_iou(preds, y, ret_full = True, reswap = True)
            print('train iou (on scaled anns): ', iou, 'mean: ', np.mean(iou[:output_classes-1]))
            
            # validation IoU, scored against the original annotations
            x, y, orig_y = _sample_scaled_batch(val, iou_size)
            train_dict = {
                unet.training: False,
                unet.model_input: x,
                unet.model_labels: y
            }
            preds = np.squeeze(sess.run([unet.predictions], feed_dict = train_dict))
            iou = get_pred_iou(preds, orig_y, ret_full = True)
            print('validation iou (on original anns): ', iou, 'mean: ', np.mean(iou[:output_classes-1]))
            print('######################')            
          
    saver.save(sess,save_path + model_name, global_step = num_epoches) # final save
    writer.close()


In [None]:
# Testing of the model: run every test patch through the network and report IoU / Dice

# NOTE: pickle executes arbitrary code while loading; only load files this notebook wrote itself
with open('./pickles/test.pkl', 'rb') as _test_file:
    test = pickle.load(_test_file)
test_model_name = 'model.meta' # just for consistency - can ofc. be changed

config = tf.ConfigProto()
test_predictions = []
with tf.Session(config=config) as sess:
    print('loading saved model ...')
    # restore the latest checkpoint into the graph that `unet` belongs to. Importing
    # a .meta file would build a second copy of the graph and leave the variables
    # behind unet.predictions uninitialised in this session.
    checkpoint = tf.train.latest_checkpoint(save_path)
    if checkpoint is None:
        raise IOError('no checkpoint found in ' + save_path)
    tf.train.Saver().restore(sess, checkpoint)
    print("model sucessfully restored")
    pred_out = []   # one prediction volume per test sample, in the model's axis order
    y_orig = []     # original annotations
    x_orig = []     # original patches
    x_in = []       # scaled inputs fed to the model
    y_in = []       # scaled labels fed to the model
    iou_out = []
    dice_out = []

    for i in range(0, len(test), batch_size):
        x_batch = []
        y_batch = []
        for j in range(i, min(len(test), i + batch_size)):
            y_orig.append(np.copy(test[j][1]))
            x_orig.append(np.copy(test[j][0]))
            x_cur, y_cur = get_scaled_input(test[j])
            x_batch.append(x_cur)
            y_batch.append(y_cur)
        print('processing ', i)
        x_in = x_in + x_batch
        y_in = y_in + y_batch
        test_dict = {
            unet.training: False, # inference mode for batch-norm and dropout
            unet.model_input: x_batch,
            unet.model_labels: y_batch
        }
        batch_predictions = sess.run(unet.predictions, feed_dict = test_dict)
        test_predictions = np.squeeze(batch_predictions)
        # drop only the trailing channel axis, so a batch of one is handled like any other
        pred_out.extend(batch_predictions[z, ..., 0] for z in range(len(x_batch)))

    for i in range(len(y_orig)):
        iou = get_mean_iou(pred_out[i], y_orig[i], ret_full = True)    
        print('test iou: ', iou, 'mean: ', np.mean(iou[:output_classes-1]))
        iou_out.append(np.mean(iou[:output_classes-1]))

        # get_scaled_input moves the last axis of a patch to the front, so move it
        # back before comparing voxel-by-voxel with the original annotation
        pred_native = np.moveaxis(pred_out[i], 0, -1)
        intersection = np.sum(y_orig[i] * pred_native)
        if (np.sum(y_orig[i])==0) and (np.sum(pred_native)==0):
            # nothing annotated and nothing predicted: perfect agreement
            dic_tmp = 1
        else: 
            dic_tmp = (2*intersection) / (np.sum(y_orig[i]) + np.sum(pred_native))
        dice_out.append(dic_tmp)
        print('dice similarity: ', dic_tmp)
       

    print('mean test iou', np.mean(iou_out), 'var iou', np.var(iou_out))
    print('mean test dice', np.mean(dice_out), 'var dice', np.var(dice_out))


In [None]:
# post processing: stitch the predicted patches of each exam back into a full volume
# NOTE(review): save_lesion is not defined in the visible part of this file; it is
# called with the same argument pattern as before - confirm its signature.
patch_size = 64
volume_shape = (256, 256, 256)


def _flush_exam(exam_id, prediction, annotation):
    """Save the stitched prediction and annotation volumes of one exam."""
    save_lesion('Test Result/Prediction_3/test_',exam_id, '-', prediction)
    save_lesion('Test Result/Annotation_3/ann_',exam_id,'_', annotation)


final_prediction = np.zeros(volume_shape)
final_annotation = np.zeros(volume_shape)
exam_ref = None
pre_exam_ref = None   # exam currently being assembled (None before the first patch)

for i in range(len(test_ref)):
    # get_data_ref stores (patch_ref, label_ref) tuples; the patch reference is used here
    current_test_ref = test_ref[i]
    if not isinstance(current_test_ref, str):
        current_test_ref = current_test_ref[0]

    # reference layout: <part0>_<part1>_<row><col><depth>; the first two parts identify the exam
    ref_parts = current_test_ref.split('_')
    exam_ref = ref_parts[0] + ref_parts[1]

    if exam_ref != pre_exam_ref:
        # a new exam starts: store the finished one and begin with empty volumes
        if pre_exam_ref is not None:
            _flush_exam(pre_exam_ref, final_prediction, final_annotation)
        final_prediction = np.zeros(volume_shape)
        final_annotation = np.zeros(volume_shape)
        pre_exam_ref = exam_ref

    # single-digit, 1-based patch position along each axis
    ref_number = ref_parts[2]
    row_ref = int(ref_number[0])
    col_ref = int(ref_number[1])
    dep_ref = int(ref_number[2])

    row_start = patch_size * (row_ref-1)
    row_end = row_start + patch_size

    col_start = patch_size * (col_ref-1)
    col_end = col_start + patch_size

    dep_start = patch_size * (dep_ref-1)
    dep_end = dep_start + patch_size

    # get_scaled_input moves the last axis of a patch to the front; move it back so
    # the prediction is placed in the same orientation as the annotation
    final_prediction[row_start:row_end, col_start:col_end, dep_start:dep_end] = np.moveaxis(pred_out[i], 0, -1)
    final_annotation[row_start:row_end, col_start:col_end, dep_start:dep_end] = test[i][1][:]

# store the last exam
if pre_exam_ref is not None:
    _flush_exam(pre_exam_ref, final_prediction, final_annotation)