In [1]:
'''Mounting Google Drive on the Colab notebook'''
# The HDF5 datasets opened below and every figure / checkpoint written during
# training live under /content/gdrive, so this cell has to run first.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
import h5py
# Dataset notes (from the original author):
#   * 484 patients, 155 slices per patient
#   * 4 imaging channels per slice
#   * one patient is [240, 240, 155, 4]; patients are stacked along axis 2,
#     so patient p occupies slices p*155 .. p*155+154 of the dataset.
# Both files are opened read-only and intentionally left open: the batch
# functions read slices lazily from `train_images` / `train_labels` with
# read_direct() instead of loading everything into memory.
# Training images
image_store = h5py.File("/content/gdrive/My Drive/Brain_Tumour_segmentation/Train_image.hdf5", "r")
# Training labels
label_store = h5py.File("/content/gdrive/My Drive/Brain_Tumour_segmentation/Train_label.hdf5", "r")
train_images = image_store["image"]
train_labels = label_store["label"]

In [0]:
'''IMPORTING LIBRARIES'''
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import tensorflow as tf
import time
import math
'''Clearing tesorflow computation graph'''
# TensorFlow 1.x graph-mode API. Re-running this cell discards every op and
# variable previously added to the default graph, so all later graph-building
# cells (placeholders, filter variables) must be re-run afterwards.
tf.reset_default_graph()

In [0]:
# ---- Global configuration ---------------------------------------------------
batch_size = 1   # number of patient volumes processed per step
n_class = 4      # number of classes in the label

# ---- Network input / target placeholders ------------------------------------
# Volumes are fed as [batch, 160, 160, 192, channels]: the 240x240 slices are
# cropped to 160x192 by crop_image_fit_brain and the 155 slices of a patient
# are padded to 160 by Pre_processing_3D, which repeats the last slice 5 times.
X = tf.placeholder(tf.float32, shape=[None, 160, 160, 192, 4], name='input_image')
y = tf.placeholder(tf.int64, shape=[None, 160, 160, 192, 1], name='hot_encode_label')

# ---- Reusable read buffers ---------------------------------------------------
# Filled in place from the HDF5 datasets by random_h5py_batch and test_batch.
out_img = np.empty((240, 240, batch_size * 155, 4), dtype=np.float32)
out_label = np.empty((240, 240, batch_size * 155, 1), dtype=np.int64)

In [0]:
'''Functions for batch Extraction and pre processing'''

def normalizing_input():
    '''Scale each channel of the shared ``out_img`` buffer in place.

    Every channel (last axis) of the 4-D buffer is divided by a fixed
    constant. According to the original notes these constants are the
    per-channel maxima found with a separate routine called
    "Finding_maximum_to_normalise" (not visible in this notebook section).
    Nothing is returned; the module-level ``out_img`` is modified.
    '''
    global out_img
    # divisor for channel 0, 1, 2 and 3 respectively
    channel_maxima = (6476.0, 9751.0, 11737.0, 5337.0)
    for channel, peak in enumerate(channel_maxima):
        out_img[:, :, :, channel] = out_img[:, :, :, channel] / peak

def crop_image_fit_brain(in_image):
    '''Cut a fixed 160 x 192 window out of the first two axes of a 4-D array.

    The window bounds are hard-coded; per the original note they were
    determined with a separate notebook (finding_brain.ipynb). Axes 2 and 3
    are passed through untouched. The result is a basic-slicing view of
    ``in_image``, not a copy.
    '''
    row_window = slice(38, 198)   # top=38 .. bot-1=198  -> 160 rows
    col_window = slice(19, 211)   # left=19 .. right+1=211 -> 192 columns
    return in_image[row_window, col_window, :, :]


def Pre_processing_3D(a):
    '''Regroup a stack of slices into per-patient volumes padded to depth 160.

    Parameters
    ----------
    a : numpy array shaped [height, width, batch_size*155, channels]
        Slices of ``batch_size`` patients stacked along axis 2
        (``batch_size`` is the module-level setting).

    Returns
    -------
    numpy array shaped [batch_size, 160, height, width, channels] with the
    same dtype as ``a``. Depth positions 0..154 hold the patient's slices and
    positions 155..159 repeat slice 154 so the depth becomes 160.

    Notes
    -----
    The buffer is allocated with the input dtype. Previously it was always
    float64, which doubled the memory of float32 images and turned integer
    labels into floats; the stored values themselves are unchanged.
    '''
    slices_per_volume = 155
    padded_depth = 160

    # Move the slice axis to the front: [slices, height, width, channels].
    slices_first = np.rollaxis(a, 2, 0)

    volumes = np.empty((batch_size, padded_depth) + slices_first.shape[1:],
                       dtype=slices_first.dtype)
    for i in range(batch_size):
        begin = i * slices_per_volume
        volumes[i, :slices_per_volume] = slices_first[begin:begin + slices_per_volume]

    # Pad the depth axis by broadcasting the last real slice over the tail.
    volumes[:, slices_per_volume:] = volumes[:, slices_per_volume - 1:slices_per_volume]
    return volumes




def random_h5py_batch(current_batch_no,permute_mat):
    '''Read, normalise, crop and reshape the next randomly ordered training batch.

    Parameters
    ----------
    current_batch_no : int
        Position inside the current permutation. Passing 0 starts a new
        epoch: a fresh random permutation is drawn and ``permute_mat`` is
        ignored.
    permute_mat : array of batch indices returned by the previous call.

    Returns
    -------
    (images, labels, next_batch_no, permute_mat, last) where ``last`` is 1
    when the returned batch was the final one of the permutation, else 0.

    Side effects: overwrites the module-level ``out_img`` / ``out_label``
    buffers and consumes numpy's global random state when a new permutation
    is drawn.
    '''
    global out_img
    global out_label

    # Patients 0..299 are sampled for training.
    # NOTE(review): test_batch starts its range at patient 380, so patients
    # 300..379 are never read by either function - confirm this is intended.
    patients_for_training = 300
    slices_in_batch = batch_size * 155

    if current_batch_no == 0:
        permute_mat = np.random.permutation(patients_for_training // batch_size)

    first_slice = permute_mat[current_batch_no] * slices_in_batch
    window = np.s_[:, :, first_slice:first_slice + slices_in_batch, :]
    train_images.read_direct(out_img, window)
    train_labels.read_direct(out_label, window)

    # Only the images are normalised; labels keep their class ids.
    normalizing_input()

    images = Pre_processing_3D(crop_image_fit_brain(out_img))
    labels = Pre_processing_3D(crop_image_fit_brain(out_label))

    next_batch_no = current_batch_no + 1
    last = 1 if next_batch_no == len(permute_mat) else 0
    return (images, labels, next_batch_no, permute_mat, last)

def test_batch():
    '''Read, normalise, crop and reshape one randomly chosen held-out batch.

    The batch is drawn from patients 380..483 (484 patients in total, the
    first 380 are skipped). A random permutation of the available batches is
    generated and its first entry selects the batch.

    Returns
    -------
    (images, labels) shaped as produced by Pre_processing_3D.

    Side effects: overwrites the module-level ``out_img`` / ``out_label``
    buffers and consumes numpy's global random state.
    '''
    global out_img
    global out_label

    first_test_patient = 380
    test_patients = 484 - first_test_patient
    slices_in_batch = batch_size * 155

    order = np.random.permutation(test_patients // batch_size)
    first_slice = (order[0] * slices_in_batch) + (first_test_patient * 155)
    window = np.s_[:, :, first_slice:first_slice + slices_in_batch, :]
    train_images.read_direct(out_img, window)
    train_labels.read_direct(out_label, window)

    # Only the images are normalised; labels keep their class ids.
    normalizing_input()

    images = Pre_processing_3D(crop_image_fit_brain(out_img))
    labels = Pre_processing_3D(crop_image_fit_brain(out_label))
    return (images, labels)



In [0]:
'''COMPUTATION GRAPH Function Definitions'''


def left_filter_def(ker_size,in_chan,out_chan,name='left_filter'):
    '''Create a trainable cubic 3-D convolution kernel.

    The variable is shaped [ker_size, ker_size, ker_size, in_chan, out_chan]
    and initialised from a normal distribution with stddev 0.05.
    '''
    kernel_shape = [ker_size] * 3 + [in_chan, out_chan]
    initial_weights = tf.random_normal(kernel_shape, stddev=0.05)
    return tf.Variable(initial_weights, name=name)


def right_filter_def(ker_size,in_chan,out_chan,name='right_filter'):
    '''Create a trainable cubic kernel for a 3-D transpose convolution.

    Compared with left_filter_def the two channel axes are swapped: the
    variable is shaped [ker_size, ker_size, ker_size, out_chan, in_chan].
    Values are initialised from a normal distribution with stddev 0.05.
    '''
    kernel_shape = [ker_size] * 3 + [out_chan, in_chan]
    initial_weights = tf.random_normal(kernel_shape, stddev=0.05)
    return tf.Variable(initial_weights, name=name)


def Conv_layer(input_im,filter_mask,stride,activation='None',name='conv'):
    '''3-D convolution with SAME padding followed by an optional activation.

    Parameters
    ----------
    input_im    : 5-D input tensor
    filter_mask : convolution kernel (see left_filter_def)
    stride      : stride applied on the three spatial axes
    activation  : 'relu', 'softmax' (over the last axis) or 'elu'; any other
                  value, including the default 'None', returns the raw
                  convolution output
    name        : name given to the convolution op
    '''
    step = [1, stride, stride, stride, 1]
    conv = tf.nn.conv3d(input_im, filter_mask, strides=step, padding="SAME", name=name)

    nonlinearities = {
        'relu': tf.nn.relu,
        'softmax': lambda tensor: tf.nn.softmax(tensor, axis=-1),
        'elu': tf.nn.elu,
    }
    apply_activation = nonlinearities.get(activation)
    if apply_activation is None:
        # unrecognised / 'None': no activation
        return conv
    return apply_activation(conv)
    

def Deconv_layer(input_im,filter_mask,stride,activation='None',name='De_conv'):
    '''3-D transpose convolution (SAME padding) followed by an optional activation.

    Parameters
    ----------
    input_im    : 5-D input tensor [batch, depth, height, width, channels];
                  the three spatial sizes must be statically known
    filter_mask : kernel shaped [k, k, k, out_channels, in_channels]
                  (see right_filter_def)
    stride      : upsampling factor applied on the three spatial axes
    activation  : 'relu', 'softmax' (over the last axis) or 'elu'; any other
                  value, including the default 'None', returns the raw output
    name        : name given to the transpose-convolution op

    The output shape is derived from the arguments instead of being
    hard-coded: the batch size is taken from the input at run time (so it no
    longer has to equal the global ``batch_size``), the spatial sizes are
    multiplied by ``stride`` (previously always 2) and the channel count is
    read from the filter (previously assumed to be input channels / 2).
    For stride 2 and filters that halve the channels the result is unchanged.
    '''
    in_dims = input_im.get_shape().as_list()
    out_channels = filter_mask.get_shape().as_list()[3]
    spatial = [dim * stride for dim in in_dims[1:4]]

    # Batch size is only known at run time, so it comes from tf.shape().
    out_shape = tf.stack([tf.shape(input_im)[0]] + spatial + [out_channels])

    conv = tf.nn.conv3d_transpose(input_im, filter_mask, out_shape, strides = [1,stride,stride,stride,1], padding = "SAME",name=name)
    # Restore the static shape so later layers can still read the sizes.
    conv.set_shape([in_dims[0]] + spatial + [out_channels])

    if activation == 'relu':
        return tf.nn.relu(conv)
    elif activation == 'softmax':
        return tf.nn.softmax(conv, axis=-1)
    elif activation == 'elu':
        return tf.nn.elu(conv)
    # unrecognised / 'None': no activation
    return conv




In [0]:
'''loss function definition'''


###############################################################################################################################
#WEIGHTED MULTICLASS DICE LOSS
###############################################################################################################################
def dice_coeff(y_true, y_pred):
    '''Return the two building blocks of a single-class soft Dice score.

    Returns
    -------
    (intersection, union) where ``intersection`` is sum(y_true * y_pred) and
    ``union`` is sum(y_true^2) + sum(y_pred^2). The caller combines them.
    '''
    overlap = tf.math.reduce_sum(y_true * y_pred)
    true_energy = tf.math.reduce_sum(y_true * y_true)
    pred_energy = tf.math.reduce_sum(y_pred * y_pred)
    return (overlap, true_energy + pred_energy)

  
def dice_coef_multilabel(y_true, y_pred, numLabels = n_class):
    '''Negative class-weighted multi-class Dice score.

    Parameters
    ----------
    y_true    : one-hot labels indexed as [batch, d, h, w, class, 1]
    y_pred    : class probabilities indexed as [batch, d, h, w, class]
    numLabels : number of classes to accumulate over

    Each class gets a weight 1 / (number of true voxels + 1), normalised so
    that the weights sum to one. The result is
    -(2 * sum_c(weight_c * intersection_c)) / sum_c(union_c).

    NOTE(review): the class weights scale only the numerator; the
    denominator is the unweighted sum of the per-class unions. Confirm that
    this is intended before interpreting 1 - loss as a Dice score. All
    classes, including class 0, take part in the sum.
    '''
    # +1 keeps the weight finite for classes absent from the batch
    class_weight = 1 / (tf.reduce_sum(y_true, axis=[0, 1, 2, 3]) + 1)

    numerator = 0
    denominator = 0
    for index in range(numLabels):
        share = class_weight[index] / (tf.reduce_sum(class_weight))
        overlap, union = dice_coeff(y_true[:, :, :, :, index, 0], y_pred[:, :, :, :, index])
        denominator += union
        numerator += share * (overlap)
    return -((2 * numerator) / denominator)
#############################################################################################################################



In [0]:
def hot_encode(check_image,depth=n_class,name='hot_encode'):
    '''One-hot encode a 5-D label tensor and swap its last two axes.

    tf.one_hot appends the class axis at the end, giving a 6-D tensor; the
    transpose then exchanges axes 4 and 5 so the class axis comes before the
    original last axis.
    '''
    encoded = tf.one_hot(indices=check_image, depth=depth, name=name)
    return tf.transpose(encoded, perm=[0, 1, 2, 3, 5, 4])


In [9]:
'''MODEL1 Filter definition'''
# All kernels below are module-level tf.Variables that predict_model1 picks
# up by global name. left_filter_def(k, in, out) builds a k*k*k convolution
# kernel; right_filter_def(k, in, out) builds a transpose-convolution kernel.
# NOTE(review): the `name=` strings decide the variable names stored in a
# checkpoint, so renaming them presumably breaks loading of earlier
# checkpoints - confirm before cleaning up the inconsistencies noted below.

'''LEFT'''

'''Input scaling'''
# 4 input channels -> 8 feature channels, shared by both encoder branches
Scaling_filter0 = left_filter_def(3,4,8,name='preparing_input')
'''RESNET PART OF YNET'''

# Note: the TF names are shifted by one relative to the Python names
# (res_filter1 is named 'res_filter2', and so on).
#res_filter1 = left_filter_def(3,4,8,name='res_filter1')
res_filter1 = left_filter_def(3,8,8,name='res_filter2')
res_filter2 = left_filter_def(3,8,8,name='res_filter3')
res_filter3 = left_filter_def(3,8,8,name='res_filter4')

# From here on each group starts with a filter that doubles the channels.
res_filter4 = left_filter_def(3,8,16,name='res_filter5')
res_filter5 = left_filter_def(3,16,16,name='res_filter6')
res_filter6 = left_filter_def(3,16,16,name='res_filter7')
res_filter7 = left_filter_def(3,16,16,name='res_filter8')

res_filter8 = left_filter_def(3,16,32,name='res_filter9')
res_filter9 = left_filter_def(3,32,32,name='res_filter10')
res_filter10= left_filter_def(3,32,32,name='res_filter11')
res_filter11= left_filter_def(3,32,32,name='res_filter12')

res_filter12= left_filter_def(3,32,64,name='res_filter13')
res_filter13= left_filter_def(3,64,64,name='res_filter14')
res_filter14= left_filter_def(3,64,64,name='res_filter15')
res_filter15= left_filter_def(3,64,64,name='res_filter16')
'''preparing resnet to combine with inception network'''
res_prep_filter = left_filter_def(3,64,128,name='res_prep_filter')


'''INCEPTION PART OF YNET'''
# Every inception block has three lines (branches) for a block width C:
#   line1: 1x1x1 (C -> 2C), 3x3x3 (2C -> 4C)
#   line2: 1x1x1 (C -> 2C), 3x3x3 (2C -> 4C), 3x3x3 (4C -> 4C)
#   line3: 1x1x1 (C -> 2C)
# The final filter takes the concatenation (4C + 4C + 2C = 10C) back to C.


'''1st inception block'''
inception_filter1 = left_filter_def(1,8,(2*8),name='inception_layer1_line1_filter1')
inception_filter2 = left_filter_def(3,(2*8),(4*8),name='inception_layer1_line1_filter2')
inception_filter3 = left_filter_def(1,8,(2*8),name='inception_layer1_line2_filter1')
inception_filter4 = left_filter_def(3,(2*8),(4*8),name='inception_layer1_line2_filter2')
inception_filter5 = left_filter_def(3,(4*8),(4*8),name='inception_layer1_line2_filter3')
inception_filter6 = left_filter_def(1,8,(2*8),name='inception_layer1_line3_filter1')
inception_filter7 = left_filter_def(3,(8*10),8,name='inception_layer1_final_filter1')  # in channels = sum of the out channels of all three lines
'''Scaling FIlter(used after pooling)'''
Scaling_filter1 = left_filter_def(3,8,16,name='inception_Sfilter1')

'''2nd inception block'''
inception_filter8 = left_filter_def(1,16,(2*16),name='inception_layer2_line1_filter1')
inception_filter9 = left_filter_def(3,(2*16),(4*16),name='inception_layer2_line1_filter2')
inception_filter10= left_filter_def(1,16,(2*16),name='inception_layer2_line2_filter1')
inception_filter11= left_filter_def(3,(2*16),(4*16),name='inception_layer2_line2_filter2')
inception_filter12= left_filter_def(3,(4*16),(4*16),name='inception_layer2_line2_filter3')
inception_filter13= left_filter_def(1,16,(2*16),name='inception_layer2_line3_filter1')
inception_filter14= left_filter_def(3,(16*10),16,name='inception_layer2_final_filter1')
'''Scaling FIlter(used after pooling)'''
Scaling_filter2 = left_filter_def(3,16,32,name='inception_Sfilter2')

'''3rd inception block'''
inception_filter15= left_filter_def(1,32,(2*32),name='inception_layer3_line1_filter1')
inception_filter16= left_filter_def(3,(2*32),(4*32),name='inception_layer3_line1_filter2')
inception_filter17= left_filter_def(1,(32),(2*32),name='inception_layer3_line2_filter1')
inception_filter18= left_filter_def(3,(2*32),(4*32),name='inception_layer3_line2_filter2')
inception_filter19= left_filter_def(3,(4*32),(4*32),name='inception_layer3_line2_filter3')
inception_filter20= left_filter_def(1,32,(2*32),name='inception_layer3_line3_filter1')
inception_filter21= left_filter_def(3,(10*32),32,name='inception_layer3_final_filter1')
'''Scaling FIlter(used after pooling)'''
Scaling_filter3 = left_filter_def(3,32,64,name='Sfilter3')

'''4th inception block'''
inception_filter22= left_filter_def(1,64,(2*64),name='inception_layer4_line1_filter1')
inception_filter23= left_filter_def(3,(2*64),(4*64),name='inception_layer4_line1_filter2')
inception_filter24= left_filter_def(1,64,(2*64),name='inception_layer4_line2_filter1')
inception_filter25= left_filter_def(3,(2*64),(4*64),name='inception_layer4_line2_filter2')
inception_filter26= left_filter_def(3,(4*64),(4*64),name='inception_layer4_line2_filter3')
inception_filter27= left_filter_def(1,(64),(2*64),name='inception_layer4_line3_filter1')
inception_filter28= left_filter_def(3,(10*64),64,name='inception_layer4_final_filter1')
'''preparing inception net to combine with resnet'''
incident_prep_filter = left_filter_def(3,64,128,name='incident_prep_filter')


'''COMBINING '''
# NOTE(review): predict_model1 (as visible in this notebook) only uses
# combo_filter1; combo_filter2..5 are created, initialised and saved but the
# decoder uses UP_filter2/5/8/11 instead. They look like leftovers.
combo_filter1 = left_filter_def(1,(2*128),128,name='combo_filter1')             # after concatenating inception and resnet
combo_filter2 = left_filter_def(1,(3*64),64,name='combo_filter2')               # after concatenating deconv, resnet and inception
combo_filter3 = left_filter_def(1,(3*32),32,name='combo_filter3')               # after concatenating deconv, resnet and inception
combo_filter4 = left_filter_def(1,(3*16),16,name='combo_filter4')               # after concatenating deconv, resnet and inception
combo_filter5 = left_filter_def(1,(3*8),8,name='combo_filter5')                 # after concatenating deconv, resnet and inception
'''Right'''
# Each decoder level uses three filters: a transpose-convolution kernel that
# halves the channels, a 1x1x1 kernel that squeezes the concatenation of
# (resnet skip, inception skip, upsampled) back to one width, and a 3x3x3
# kernel. UP_filter2/5/8/11 reuse the names 'combo_filter2..5' already given
# to the variables above, so their final TF variable names will differ from
# the strings written here (TF makes duplicate names unique).

UP_filter1 = right_filter_def(3,128,64,name='UP_filter11')
UP_filter2 = left_filter_def(1,(3*64),64,name='combo_filter2')
UP_filter3 = left_filter_def(3,64,64,name='UP_filter13')

UP_filter4 = right_filter_def(3,64,32,name='UP_filter14')
UP_filter5 = left_filter_def(1,(3*32),32,name='combo_filter3')
UP_filter6 = left_filter_def(3,32,32,name='UP_filter16')

UP_filter7 = right_filter_def(3,32,16,name='UP_filter17')
UP_filter8 = left_filter_def(1,(3*16),16,name='combo_filter4')
UP_filter9 = left_filter_def(3,16,16,name='UP_filter19')

UP_filter10 = right_filter_def(3,16,8,name='UP_filter20')
UP_filter11 = left_filter_def(1,(3*8),8,name='combo_filter5')
# last filter maps to n_class channels (softmax is applied in predict_model1)
UP_filter12 = left_filter_def(3,8,n_class,name='UP_filter22')


Instructions for updating:
Colocations handled automatically by placer.


In [0]:
def _halve(tensor, name):
    '''2x2x2 max pooling with stride 2 (halves the three spatial axes).'''
    return tf.nn.max_pool3d(tensor, ksize=[1,2,2,2,1], strides=[1,2,2,2,1], padding='VALID', name=name)


def _inception_block(block_in, filters, block_no):
    '''Three parallel conv lines, concatenated, fused and added to the input.

    ``filters`` holds the seven kernels of one block in definition order:
    (line1 1x1, line1 3x3, line2 1x1, line2 3x3, line2 3x3, line3 1x1, final).
    '''
    f11, f12, f21, f22, f23, f31, f_final = filters
    tag = 'layer%d' % block_no
    with tf.name_scope("INCEPTION_BLOCK%d" % block_no):
        line1 = Conv_layer(block_in, f11, stride=1, activation='relu', name=tag + '_line1_CNN1')
        line1 = Conv_layer(line1, f12, stride=1, activation='relu', name=tag + '_line1_CNN2')
        line2 = Conv_layer(block_in, f21, stride=1, activation='relu', name=tag + '_line2_CNN1')
        line2 = Conv_layer(line2, f22, stride=1, activation='relu', name=tag + '_line2_CNN2')
        line2 = Conv_layer(line2, f23, stride=1, activation='relu', name=tag + '_line2_CNN3')
        line3 = Conv_layer(block_in, f31, stride=1, activation='relu', name=tag + '_line3_CNN1')
        merged = tf.concat([line1, line2, line3], axis=4, name='COMBO%d' % block_no)
        fused = Conv_layer(merged, f_final, stride=1, activation='relu', name=tag + '_final_CNN1')
        # residual connection around the whole block
        return tf.add(fused, block_in, name='inception%d' % block_no)


def _resnet_block(block_in, filters, block_no, widen_filter=None):
    '''Residual block followed by max pooling.

    When ``widen_filter`` is given, the input first passes through it and the
    widened tensor becomes the shortcut. ``filters`` are the three kernels of
    the residual path. Returns (block output before pooling, pooled output).
    '''
    with tf.name_scope("RESNET_BLOCK%d" % block_no):
        shortcut = block_in
        if widen_filter is not None:
            shortcut = Conv_layer(block_in, widen_filter, stride=1, activation='relu', name='CNN_widen')
        hidden = shortcut
        for position, kernel in enumerate(filters[:-1], start=1):
            hidden = Conv_layer(hidden, kernel, stride=1, activation='relu', name='CNN%d' % position)
        residual = Conv_layer(hidden, filters[-1], stride=1, activation='relu', name='CNNR%d' % block_no)
        block_out = tf.add(residual, shortcut, name='Resnet%d' % block_no)
        pooled = _halve(block_out, 'POOL%d' % block_no)
    return block_out, pooled


def _up_block(below, skips, filters, block_no, final_activation='relu'):
    '''Upsample ``below``, concatenate it after ``skips`` and apply two convolutions.

    ``filters`` is (transpose-conv kernel, 1x1x1 squeeze kernel, 3x3x3 kernel).
    '''
    up_filter, squeeze_filter, out_filter = filters
    with tf.name_scope("UP_block%d" % block_no):
        upsampled = Deconv_layer(below, up_filter, stride=2, activation='relu', name='DE_CONV%d' % block_no)
        merged = tf.concat(list(skips) + [upsampled], axis=4, name='CONCAT%d' % (block_no + 1))
        squeezed = Conv_layer(merged, squeeze_filter, stride=1, activation='relu', name='upCNN%d' % (2 * block_no - 1))
        return Conv_layer(squeezed, out_filter, stride=1, activation=final_activation, name='upCNN%d' % (2 * block_no))


def predict_model1(X):
    '''Build the Y-Net graph: two encoders sharing one decoder.

    The input is first mapped to 8 channels. An inception-style encoder and a
    ResNet-style encoder then process that tensor in parallel through four
    pooled levels each. Their bottlenecks are concatenated and fused, and a
    four-level decoder upsamples the result, concatenating the outputs of
    both encoders at the matching resolution on every level.

    Parameters
    ----------
    X : input tensor [batch, depth, height, width, 4]

    Returns
    -------
    Tensor with n_class channels holding per-voxel softmax probabilities.

    The kernels are the module-level filter variables defined in the filter
    cell. The wiring is the same as in the earlier hand-unrolled version;
    only some op names differ, because repeated names ('CNN3', 'POOL1') were
    replaced by unique ones.
    '''
    Prep_Conv = Conv_layer(X, Scaling_filter0, stride=1, activation='relu', name='CNN_preparing_input')

    with tf.name_scope("INCEPTION_NET"):
        inception_stages = [
            ((inception_filter1, inception_filter2, inception_filter3, inception_filter4,
              inception_filter5, inception_filter6, inception_filter7),
             Scaling_filter1, 'Scaling_CNN1'),
            ((inception_filter8, inception_filter9, inception_filter10, inception_filter11,
              inception_filter12, inception_filter13, inception_filter14),
             Scaling_filter2, 'Scaling_CNN2'),
            ((inception_filter15, inception_filter16, inception_filter17, inception_filter18,
              inception_filter19, inception_filter20, inception_filter21),
             Scaling_filter3, 'Scaling_CNN3'),
            # after the last block the "scaling" conv prepares the bottleneck
            ((inception_filter22, inception_filter23, inception_filter24, inception_filter25,
              inception_filter26, inception_filter27, inception_filter28),
             incident_prep_filter, 'incident_prep'),
        ]
        inception_skips = []
        stage_in = Prep_Conv
        for block_no, (filters, scale_filter, scale_name) in enumerate(inception_stages, start=1):
            block_out = _inception_block(stage_in, filters, block_no)
            inception_skips.append(block_out)
            pooled = _halve(block_out, 'Inception_pool%d' % block_no)
            # channel-doubling convolution applied after every pooling step
            stage_in = Conv_layer(pooled, scale_filter, stride=1, activation='relu', name=scale_name)
        ineption_prep = stage_in

    with tf.name_scope("RESNET"):
        resnet_stages = [
            (None, (res_filter1, res_filter2, res_filter3)),
            (res_filter4, (res_filter5, res_filter6, res_filter7)),
            (res_filter8, (res_filter9, res_filter10, res_filter11)),
            (res_filter12, (res_filter13, res_filter14, res_filter15)),
        ]
        resnet_skips = []
        stage_in = Prep_Conv
        for block_no, (widen_filter, filters) in enumerate(resnet_stages, start=1):
            block_out, stage_in = _resnet_block(stage_in, filters, block_no, widen_filter=widen_filter)
            resnet_skips.append(block_out)
        resnet_prep = Conv_layer(stage_in, res_prep_filter, stride=1, activation='relu', name='res_prep')

    with tf.name_scope("Flat_Layer"):
        concat_res_incep = tf.concat([resnet_prep, ineption_prep], axis=4, name='CONCAT1')
        decoded = Conv_layer(concat_res_incep, combo_filter1, stride=1, activation='relu', name='capturing_depth1')

    with tf.name_scope("UP_LAYER"):
        up_stages = [
            (UP_filter1, UP_filter2, UP_filter3),
            (UP_filter4, UP_filter5, UP_filter6),
            (UP_filter7, UP_filter8, UP_filter9),
            (UP_filter10, UP_filter11, UP_filter12),
        ]
        for block_no, filters in enumerate(up_stages, start=1):
            level = len(up_stages) - block_no   # deepest encoder level first
            is_output = block_no == len(up_stages)
            decoded = _up_block(
                decoded,
                (resnet_skips[level], inception_skips[level]),
                filters,
                block_no,
                final_activation='softmax' if is_output else 'relu',
            )
    return(decoded)
        
        



In [0]:
def _save_slice_figure(image, title, path):
    '''Render one 2-D slice in grayscale without axes and write it to ``path``.'''
    plt.figure()
    plt.imshow(np.squeeze(image), cmap='gray')
    plt.axis('off')
    plt.title(title)
    plt.savefig(path)


def train_unet(learning_rate =0.0001,n_epochs = 100):
    '''Build the training graph, train the network and save the final model.

    Parameters
    ----------
    learning_rate : step size for the Adam optimizer
    n_epochs      : number of passes over the training batches

    Returns
    -------
    Wall-clock duration in seconds of the session (training plus saving).

    Side effects
    ------------
    * prints the score (1 - loss) on the current training batch and on a
      random test batch every 10th iteration, and on a test batch after
      every epoch
    * after every epoch writes three PNG previews (input, ground truth,
      prediction) of depth slice 100 of the test volume
    * saves the trained variables once, after the last epoch
    '''
    import os  # local import: only needed to prepare the figure folder

    figure_dir = '/content/gdrive/My Drive/Brain_Tumour_segmentation/Ynet/YNET_without_FOCAL_LOSS/'
    checkpoint_path = "/content/gdrive/My Drive/Brain_Tumour_segmentation/Ynet/final_madel_graph_model1"
    preview_slice = 100

    # Create the output folder up front; otherwise the first savefig call
    # would fail only after a complete epoch of training.
    if not os.path.isdir(figure_dir):
        os.makedirs(figure_dir)

    prediction = predict_model1(X)  # per-voxel class probabilities

    with tf.name_scope("LOSS_FUNCTION"):
        hot_y = hot_encode(y)
        # dice_coef_multilabel returns a negative score; 1 + score is the loss
        dice = 1 + dice_coef_multilabel(hot_y, prediction)

    with tf.name_scope("COST_FUNCTION"):
        loss = tf.reduce_mean(dice, name="loss")

    with tf.name_scope("OPTIMIZER"):
        optimize = tf.train.AdamOptimizer(learning_rate = learning_rate)
        training_output = optimize.minimize(loss)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    start = time.time()
    with tf.Session() as sess:
        init.run()
        for epoch in range(n_epochs):
            # batch number 0 makes random_h5py_batch draw a new permutation,
            # so the placeholder value of permute_mat is never read
            current_batch_no = 0
            permute_mat = 0
            iteration = 0
            while True:
                epoch_x, epoch_y, current_batch_no, permute_mat, last = random_h5py_batch(current_batch_no, permute_mat)
                sess.run(training_output, feed_dict={X: epoch_x, y: epoch_y})

                if iteration % 10 == 0:
                    acc_train = 1 - (dice.eval(feed_dict={X: epoch_x, y: epoch_y}))
                    test_images, test_labels = test_batch()
                    acc_test = 1 - (dice.eval(feed_dict={X: test_images, y: test_labels}))
                    print("Minibatch at","Epoch", epoch+1,"batch",iteration+1, "Train accuracy:", acc_train, "Test accuracy:", acc_test)
                if last == 1:
                    break
                iteration += 1

            test_images, test_labels = test_batch()
            acc_test = 1 - dice.eval(feed_dict={X: test_images, y: test_labels})
            print("-------------------------------------------------------------------------------------------------------")
            print("After Epoch", epoch+1, "Test accuracy:", acc_test)
            print("-------------------------------------------------------------------------------------------------------")

            # Preview of one slice of the test volume used above.
            probabilities = sess.run(prediction, feed_dict={X: test_images})
            # probability-weighted class index (classes 1..3), not an argmax
            generated_mask = (probabilities[0, preview_slice, :, :, 1]
                              + (2 * probabilities[0, preview_slice, :, :, 2])
                              + (3 * probabilities[0, preview_slice, :, :, 3]))
            input_slice = test_images[0, preview_slice, :, :, 3]   # channel 3 only
            truth_slice = test_labels[0, preview_slice, :, :, :]

            _save_slice_figure(input_slice, 'Original Image', figure_dir + str(epoch) + "a_Original_Image.png")
            _save_slice_figure(truth_slice, 'Ground Truth Mask', figure_dir + str(epoch) + "b_Original_Mask.png")
            _save_slice_figure(generated_mask, 'Generated Mask', figure_dir + str(epoch) + "c_Generated_Mask.png")
            # release the figures so they do not accumulate across epochs
            plt.close('all')

        # Persist the trained variables.
        saver.save(sess, checkpoint_path)
    end = time.time()
    total_time = end - start
    return (total_time)

In [0]:
# Start training with the default settings (learning_rate=0.0001, n_epochs=100).
# The returned training time in seconds is displayed as the cell output.
train_unet()