In [None]:
!pip install tensorflow-addons==0.8.3
!pip install elasticdeform
!pip install time
!pip install SimpleITK
!pip install natsort

In [None]:
import numpy as np
import tensorflow as tf
import nibabel as nib
import glob
import time
from tensorflow.keras.utils import to_categorical

from sys import stdout
import matplotlib.pyplot as plt
import matplotlib.image as mpim
import elasticdeform as ed
from scipy.ndimage.interpolation import affine_transform
import concurrent.futures

from tensorflow.keras.preprocessing.image import apply_affine_transform


from random import shuffle

import os

from natsort import natsorted, ns

# Segmentation classes (index order matches the training log: bgr, csf, gm, wm).
Nclasses = 4
classes = np.arange(Nclasses)

# NOTE(review): the recorded output of this cell lists files under
# '../data/iseg-2017-2019_duzenlenmis/...'; confirm that '20199' below is not a typo.
root_data = '../data/iseg-2017-20199_duzenlenmis/iSeg-2017+2019/'
# Experiment tag used as the sub-directory name for every output location.
path_file = 'iseg_data_2017+2019_t1+t2_Dice_Loss_Zskore_Norm'

# Output directories: model checkpoints, test predictions and loss logs.
# os.makedirs creates missing parents (os.mkdir would raise) and exist_ok makes
# the cell safe to re-run. After the loop `path_cv` is the loss directory, as before.
for path_cv in ("results/cv5/" + path_file,
                "test_Islemleri/data/Output/" + path_file,
                "loss/" + path_file):
    os.makedirs(path_cv, exist_ok=True)


def data(data_selected):
    """Collect ``[t1, t2, segmentation]`` file triples for one data split.

    Every sub-directory of ``root_data/<data_selected>/`` is treated as one
    subject; the three files are looked up *inside that directory*, so a triple
    can never mix files from different subjects, whatever the directory names are.

    Parameters
    ----------
    data_selected : str
        Split directory name under ``root_data`` (e.g. ``'train'``).

    Returns
    -------
    list of [str, str, str]
        One ``[t1_path, t2_path, seg_path]`` entry per subject, in random order.

    Raises
    ------
    FileNotFoundError
        If a subject directory has no file matching one of the patterns.
    """
    split_dir = os.path.join(root_data, data_selected)
    # Only directories are subjects; stray files (notes, .DS_Store, ...) are skipped.
    subjects = natsorted(
        [name for name in os.listdir(split_dir)
         if os.path.isdir(os.path.join(split_dir, name))],
        alg=ns.PATH | ns.IGNORECASE)
    shuffle(subjects)

    def _find(subject_dir, pattern):
        # First match (in sorted order) of `pattern` inside one subject directory.
        matches = sorted(glob.glob(os.path.join(subject_dir, pattern)))
        if not matches:
            raise FileNotFoundError(
                'No file matching {!r} in {!r}'.format(pattern, subject_dir))
        return matches[0]

    veri = []
    for name in subjects:
        subject_dir = os.path.join(split_dir, name)
        veri.append([_find(subject_dir, '*t1.hdr'),   # '*t1.nii.gz' for NIfTI data
                     _find(subject_dir, '*t2.hdr'),
                     _find(subject_dir, '*segm.hdr')])
    return veri

# Split registry: only 'train' is populated here; 'valid' and 'test' stay empty.
sets = dict(train=[], valid=[], test=[])
sets['train'] = data(data_selected='train')
print('Data Loaded...')

def MinMax_Normalize(modality):
    """Linearly rescale ``modality`` so its values span [0, 1].

    A constant array (max == min) cannot be rescaled without dividing by
    zero, so it is returned as-is (the very same object).
    """
    lowest = np.min(modality)
    value_range = np.max(modality) - lowest
    if value_range == 0:
        return modality
    return (modality - lowest) / value_range

def Zskore_Normalize(modality):
    """Z-score normalise the non-zero (foreground) voxels of ``modality``.

    Voxels equal to 0 are treated as background and are 0 in the output.
    Foreground voxels become ``(x - mean) / std``, with mean and std taken
    over the foreground only.

    Returns
    -------
    np.ndarray
        Same shape as the input. Floating inputs keep their dtype; integer
        inputs produce float64 so the z-scores are not truncated. If the
        foreground has zero standard deviation, the input object itself is
        returned unchanged.
    """
    X = np.asarray(modality)
    mask = X != 0
    brain = X[mask]
    # An integer buffer would silently truncate the z-scores to integers.
    out_dtype = X.dtype if np.issubdtype(X.dtype, np.floating) else np.float64
    brain_norm = np.zeros(X.shape, dtype=out_dtype)  # background stays at 0
    if brain.size == 0:
        # All-background volume: avoid mean/std of an empty array (nan + warning).
        return brain_norm
    std = np.std(brain)
    if std == 0:
        return modality
    brain_norm[mask] = (brain - np.mean(brain)) / std
    return brain_norm

def load_img(img_files):
    """Load one subject's input modalities and its segmentation target.

    Parameters
    ----------
    img_files : sequence of str
        ``[modality_1, ..., modality_k, segmentation]``; the last path is
        always the label volume, every earlier path is an input channel.

    Returns
    -------
    X_norm : np.ndarray, float32, shape (crop_z, crop_y, crop_x, k)
        Z-score-normalised modalities, cropped to the fixed window below.
    y : np.ndarray
        Label volume with raw values 10 -> 1, 150 -> 2, 250 -> 3, cropped
        with the same window.
    """
    # Fixed crop window shared by inputs and target (128 x 144 x 128 voxels).
    crop = (slice(6, 134), slice(30, 174), slice(80, 208))
    n_modalities = len(img_files) - 1

    y = nib.load(img_files[-1]).get_fdata(dtype='float32')
    y = y.reshape(y.shape[0], y.shape[1], y.shape[2])
    # Remap raw label intensities to consecutive class indices.
    y[y == 10] = 1
    y[y == 150] = 2
    y[y == 250] = 3
    y = y[crop]

    X_norm = np.empty(y.shape + (n_modalities,), dtype='float32')
    for channel in range(n_modalities):
        X = nib.load(img_files[channel]).get_fdata(dtype='float32')
        X = X.reshape(X.shape[0], X.shape[1], X.shape[2])
        # Normalise over the full volume first (so the statistics match the
        # uncropped image), then keep only the crop window.
        X_norm[..., channel] = Zskore_Normalize(X)[crop]
    return X_norm, y

class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding whole, cropped subject volumes.

    Each entry of ``list_IDs`` is a list of file paths accepted by
    ``load_img`` (input modalities first, segmentation last).
    """

    def __init__(self, list_IDs, batch_size=1, dim=(128, 144, 128), n_channels=1,
                 n_classes=4, shuffle=True, patch_size=64, n_patches=8,
                 augmentation=False):
        """Store the configuration and build the first (optionally shuffled) ordering.

        Parameters
        ----------
        augmentation : bool
            When True, ``__getitem__`` returns the raw class-index label volumes;
            when False (default) labels are one-hot encoded with ``n_classes``.
        patch_size, n_patches
            Stored on the instance but not used by this class.
        """
        super().__init__()
        self.list_IDs = list_IDs
        self.batch_size = batch_size
        self.dim = dim
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.augmentation = augmentation
        self.patch_size = patch_size
        self.n_patches = n_patches
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (the last batch may be smaller)."""
        return int(np.ceil(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        """Load and return batch ``index`` as ``(X, y)``."""
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        X, y = self.__data_generation(list_IDs_temp)

        # Reshuffle after serving the last batch so that plain iteration
        # (without Keras calling on_epoch_end) also gets a new order per pass.
        if index == self.__len__() - 1:
            self.on_epoch_end()

        return X, y

    def on_epoch_end(self):
        """Reset the sample ordering, shuffling it when ``self.shuffle`` is set."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        """Load the subjects in ``list_IDs_temp`` into one batch array pair."""
        # Size by the actual batch length: the last batch can be shorter than
        # batch_size, and np.empty rows that are never filled would be garbage.
        n_samples = len(list_IDs_temp)
        X = np.empty((n_samples, *self.dim, self.n_channels))
        y = np.empty((n_samples, *self.dim))
        for i, IDs in enumerate(list_IDs_temp):
            X[i], y[i] = load_img(IDs)
        if self.augmentation:
            return X.astype('float32'), y
        return X.astype('float32'), to_categorical(y, self.n_classes)


# Smoke-test generator over the training split: one whole 2-channel volume per batch.
train_gen = DataGenerator(
    sets['train'],
    batch_size=1,
    dim=(128, 144, 128),
    n_classes=4,
    n_channels=2,
    n_patches=1,
    patch_size=1,
)


['../data/iseg-2017-2019_duzenlenmis/iSeg-2017+2019/train/1/t1.hdr', '../data/iseg-2017-2019_duzenlenmis/iSeg-2017+2019/train/10/t1.hdr', '../data/iseg-2017-2019_duzenlenmis/iSeg-2017+2019/train/11/t1.hdr', '../data/iseg-2017-2019_duzenlenmis/iSeg-2017+2019/train/12/t1.hdr', '../data/iseg-2017-2019_duzenlenmis/iSeg-2017+2019/train/13/t1.hdr', '../data/iseg-2017-2019_duzenlenmis/iSeg-2017+2019/train/14/t1.hdr', '../data/iseg-2017-2019_duzenlenmis/iSeg-2017+2019/train/15/t1.hdr', '../data/iseg-2017-2019_duzenlenmis/iSeg-2017+2019/train/16/t1.hdr', '../data/iseg-2017-2019_duzenlenmis/iSeg-2017+2019/train/2/t1.hdr', '../data/iseg-2017-2019_duzenlenmis/iSeg-2017+2019/train/3/t1.hdr', '../data/iseg-2017-2019_duzenlenmis/iSeg-2017+2019/train/4/t1.hdr', '../data/iseg-2017-2019_duzenlenmis/iSeg-2017+2019/train/5/t1.hdr', '../data/iseg-2017-2019_duzenlenmis/iSeg-2017+2019/train/6/t1.hdr', '../data/iseg-2017-2019_duzenlenmis/iSeg-2017+2019/train/7/t1.hdr', '../data/iseg-2017-2019_duzenlenmis/iSeg

'\na=0\nfor Xbatch, Ybatch in train_gen:\n    print(Xbatch.shape)\n    print(Ybatch.shape)\n    a=a+1\nplt.imshow(Xbatch[1,:,:,30,0])    \nprint(a)\n'

In [None]:
from tensorflow.keras.models import Model,load_model,model_from_json
from tensorflow.keras.layers import Input, Conv3D,UpSampling3D, Conv3DTranspose, Dropout,Activation, ReLU, LeakyReLU, Concatenate,BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
import os
from tensorflow_addons.layers import InstanceNormalization
import numpy as np
from utils2 import  get_dsc
import tensorflow.compat.v1 as tf
import config


from random import randint
# Prefix prepended to the 'results/...' and 'loss/...' paths used by the model
# class; the empty string makes those paths relative to the working directory.
root_1 = ""

class iSegWNet():
    """Two cascaded 3D encoder/bottleneck/decoder stacks ("W-Net") trained
    with a soft Dice loss.

    The Keras model is stored in ``self.generator``. If
    ``<self.path>/generator.h5`` already exists, the architecture and weights
    are restored from disk instead of being built from scratch.
    """

    def __init__(self, img_shape, seg_shape, Nclasses=4, Nfilter_start=8, depth=3, LAMBDA=5):
        """
        Parameters
        ----------
        img_shape : tuple
            Input shape without the batch axis, e.g. ``(128, 144, 128, 2)``.
        seg_shape : tuple
            Target shape; stored only.
        Nclasses : int
            Number of channels of the softmax output.
        Nfilter_start : int
            Filter count of the first encoder step; doubled per depth level.
        depth : int
            Each W-Net half has ``depth - 2`` encoder and decoder steps.
        LAMBDA : float
            Factor applied to the *logged* training loss in ``train``.
        """
        self.img_shape = img_shape
        self.seg_shape = seg_shape
        self.Nfilter_start = Nfilter_start
        self.depth = depth
        self.Nclasses = Nclasses
        self.LAMBDA = LAMBDA

        def diceLoss(y_true, y_pred):
            """Soft Dice loss 1 - 2*sum(t*p) / (sum(t + p) + eps).

            Sums run over axes 0-3, leaving one value per class channel.
            """
            y_true = tf.convert_to_tensor(y_true, 'float32')
            y_pred = tf.convert_to_tensor(y_pred, y_true.dtype)
            num = tf.math.reduce_sum(tf.math.multiply(y_true, y_pred), axis=[0, 1, 2, 3])
            # eps keeps the ratio finite for a class absent from both tensors.
            den = tf.math.reduce_sum(tf.math.add(y_true, y_pred), axis=[0, 1, 2, 3]) + 1e-5
            return 1 - 2 * num / den

        self.path = root_1 + "results/cv5/" + str(path_file)
        # makedirs also creates a missing 'results/cv5' parent; os.mkdir would raise.
        os.makedirs(self.path, exist_ok=True)
        if os.path.exists(self.path + '/generator.h5'):
            import json
            with open(self.path + "/generator.json", "r") as f:
                model_json = json.load(f)
            # The architecture uses the tfa InstanceNormalization layer; make it
            # resolvable during deserialisation (harmless if already registered).
            self.generator = model_from_json(
                model_json, custom_objects={'InstanceNormalization': InstanceNormalization})
            self.generator.load_weights(self.path + "/generator.h5")
        else:
            self.generator = self.WNet()
            self.generator.summary()
        self.generator.compile(loss=[diceLoss], optimizer=Adam(1e-4), metrics=['accuracy'])

    def WNet(self):
        """Build the uncompiled model: two encoder/bottleneck/decoder passes in
        series, the second ending in a per-voxel softmax over ``Nclasses``
        channels (layer name ``'output2'``)."""
        inputs = Input(self.img_shape, name='input_image')

        def encoder_step(layer, Nf, inorm=True):
            # Strided conv halves each spatial axis; the second conv uses 2*Nf filters.
            # `inorm` is accepted but unused: instance norm is always applied.
            Nf_1 = Nf * 2
            x = Conv3D(Nf, kernel_size=3, strides=2, kernel_initializer='he_normal', padding='same')(layer)
            x = InstanceNormalization()(x)
            x = LeakyReLU()(x)
            x = Dropout(0.2)(x)
            x = Conv3D(Nf_1, kernel_size=3, kernel_initializer='he_normal', padding='same')(x)
            x = InstanceNormalization()(x)
            x = LeakyReLU()(x)
            x = Dropout(0.2)(x)
            return x

        def bottlenek(layer, Nf):
            # Strided 5x5x5 conv, then four blocks that concatenate each conv output
            # with its normalised/activated version (2*Nf output channels).
            x = Conv3D(Nf, kernel_size=5, strides=2, kernel_initializer='he_normal', padding='same')(layer)
            x = InstanceNormalization()(x)
            x = LeakyReLU()(x)
            for _ in range(4):
                y = Conv3D(Nf, kernel_size=5, strides=1, kernel_initializer='he_normal', padding='same')(x)
                x = InstanceNormalization()(y)
                x = LeakyReLU()(x)
                x = Concatenate()([x, y])
            return x

        def decoder_step(layer, layer_to_concatenate, Nf):
            # Transposed conv doubles each spatial axis, then merges the skip connection.
            x = Conv3DTranspose(Nf, kernel_size=5, strides=2, padding='same', kernel_initializer='he_normal')(layer)
            x = InstanceNormalization()(x)
            x = LeakyReLU()(x)
            x = Concatenate()([x, layer_to_concatenate])
            x = Dropout(0.2)(x)

            x = Conv3D(Nf, kernel_size=3, kernel_initializer='he_normal', padding='same')(x)
            x = InstanceNormalization()(x)
            x = LeakyReLU()(x)
            x = Dropout(0.2)(x)
            return x

        # Layer creation order is kept stable so saved weights stay compatible.
        x = inputs
        for wnet in range(2):
            # encoder
            layers_to_concatenate = []
            for d in range(self.depth - 2):
                if d == 0:
                    x = encoder_step(x, self.Nfilter_start * np.power(2, d), False)
                else:
                    x = encoder_step(x, self.Nfilter_start * np.power(2, d))
                layers_to_concatenate.append(x)
            # bottleneck
            x = bottlenek(x, self.Nfilter_start * np.power(2, self.depth - 2))
            # decoder
            for d in range(self.depth - 3, -1, -1):
                x = decoder_step(x, layers_to_concatenate.pop(), self.Nfilter_start * np.power(2, d))
            # Restore full resolution: features for the second pass, softmax at the end.
            if wnet == 0:
                x = Conv3DTranspose(filters=32, kernel_size=3, strides=2, padding='same', kernel_initializer='he_normal')(x)
                x = InstanceNormalization()(x)
                x = LeakyReLU()(x)
            else:
                last = Conv3DTranspose(filters=self.Nclasses, kernel_size=3, strides=2, padding='same', activation='softmax', name='output2')(x)
        return Model(inputs=inputs, outputs=last, name='3D_Wnet')

    def train_step(self, Xbatch, Ybatch, mp=True, n_workers=16):
        """Fit one batch, then predict it again to compute per-class Dice.

        Returns the Keras ``History`` from ``fit`` and ``get_dsc``'s result.
        ``mp`` and ``n_workers`` are unused.
        """
        gen_loss = self.generator.fit(Xbatch, Ybatch, verbose=0)
        gen_output = self.generator.predict(Xbatch)
        dsc = get_dsc(labels=Ybatch, predictions=gen_output)
        return gen_loss, dsc

    def valid_step(self, Xbatch, Ybatch, mp=True, n_workers=16):
        """Return ``generator.evaluate`` on one batch (``mp``/``n_workers`` unused)."""
        gen_loss = self.generator.evaluate(Xbatch, Ybatch, verbose=0)
        return gen_loss

    def save_model(self, path, durum='', itr=1):
        """Write the architecture (JSON) and weights (HDF5) twice under ``path``:
        as ``<durum>generator.*`` (overwritten each time) and as
        ``<itr>_generator.*`` (a numbered snapshot). Returns 5 (kept for callers)."""
        import json
        model_generator = self.generator.to_json()
        with open((path + "/{}generator.json").format(durum), "w") as json_file:
            json.dump(model_generator, json_file)
        self.generator.save_weights(path + "/{}generator.h5".format(durum))
        with open((path + "/{}_generator.json").format(itr), "w") as json_file:
            json.dump(model_generator, json_file)
        self.generator.save_weights(path + "/{}_generator.h5".format(itr))
        return 5

    def run_toc(self, start_time):
        """Return the whole minutes elapsed since ``start_time`` (a ``time.time()`` value)."""
        return round(time.time() - start_time) // 60

    @staticmethod
    def _load_start_step(epoch_file):
        """Return the step recorded in ``epoch_file`` by a previous checkpoint, or 0."""
        if not os.path.exists(epoch_file):
            return 0
        with open(epoch_file) as f:
            content = f.read().strip()
        return int(content) if content else 0

    def train(self, patch_size, train_time, nEpochs):
        """Train with a wall-clock budget, checkpointing once per hour.

        Each outer step is one full pass over ``sets['train']``, fitting one
        batch at a time. After every 60 minutes the model is saved and the
        current step is written to ``loss/<path_file>/epoch.txt``; a later call
        resumes after that step. Training stops after ``train_time`` hours of
        checkpoints or ``max_steps`` passes, whichever comes first.

        Parameters
        ----------
        patch_size : unused; kept for interface compatibility.
        train_time : int or float
            Budget in hours.
        nEpochs : unused. NOTE(review): the pass limit is the hard-coded
            ``max_steps`` below; confirm whether ``nEpochs`` should replace it.

        Returns
        -------
        (History, History)
            Train / validation trend holders; nothing appends to them here.
        """
        print('Training process:')
        trends_train = tf.keras.callbacks.History()
        trends_train.epoch = []
        # Set explicitly: some Keras versions only create `.history` in on_train_begin.
        trends_train.history = {}

        trends_valid = tf.keras.callbacks.History()
        trends_valid.epoch = []
        trends_valid.history = {}

        loss_dir = root_1 + "loss/" + path_file
        start_step = self._load_start_step(loss_dir + "/epoch.txt")
        PERIOD_OF_TIME = time.time()
        cnt = 0  # minutes covered by saved checkpoints
        max_steps = 2000
        train_status = ('- step: {}/{} : loss: {:0.4f} - bgr:{:0.3f} csf:{:0.3f} '
                        'gm:{:0.3f} wm:{:0.3f} ')
        # One generator for all passes; it reshuffles itself after its last batch.
        train_gen = DataGenerator(sets['train'], augmentation=False, batch_size=1,
                                  n_channels=2, patch_size=1, n_patches=1)
        for itr in range(start_step + 1, max_steps + 1):
            txt = None
            for Xbatch, Ybatch in train_gen:
                gan_losses, s_dsc = self.train_step(Xbatch, Ybatch)
                # Scales only the reported value; the optimiser saw the unscaled loss.
                gan_losses.history['loss'][0] *= self.LAMBDA
                txt = train_status.format(itr, max_steps, gan_losses.history['loss'][0],
                                          s_dsc[0], s_dsc[1], s_dsc[2], s_dsc[3])
                print(txt)

            # Only the last batch's status line of each pass is appended to the log.
            if txt is not None:
                with open(loss_dir + "/train_loss.txt", "a+") as f:
                    f.write(txt + "\n")

            if self.run_toc(PERIOD_OF_TIME) >= 60:
                PERIOD_OF_TIME = time.time()
                cnt += 60
                self.save_model(path=self.path, durum='', itr=cnt)
                with open(loss_dir + "/epoch.txt", "w") as f:
                    f.write(str(itr))
                print("Save checkpoint")
                print(itr)
            # >= (not ==) so a non-integer hour budget still terminates.
            if cnt >= train_time * 60:
                break

        np.save(self.path + '/history_train', trends_train.history)
        np.save(self.path + '/history_valid', trends_valid.history)
        return trends_train, trends_valid

# Crop size used by load_img / DataGenerator: (z, y, x).
patch_size = [128, 144, 128]
pz, py, px = patch_size

imShape = (pz, py, px, 2)         # two input channels (T1, T2)
gtShape = (pz, py, px, Nclasses)  # one-hot target
# Pass Nclasses by keyword: an extra positional argument here would bind to
# the `Nclasses` parameter of the constructor.
# NOTE(review): this rebinds the class name to the instance (kept for any later
# cells that use it), so re-running this cell requires re-running the class cell.
iSegWNet = iSegWNet(imShape, gtShape, Nclasses=Nclasses, Nfilter_start=64, depth=4)

trends_train, trends_valid = iSegWNet.train(patch_size, 3, 300)