In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=False)
%cd 'drive/My Drive'
%cd 'Vol2SegGAN' ## project path 


Mounted at /content/drive
/content/drive/My Drive
/content/drive/My Drive/IBSR/3DUnet_GAN_MIPS_Patch/15_Github_Vol2SegGAN


In [None]:

!pip install tensorflow-gpu==2.1
!pip install tensorflow-addons==0.9.1
!pip install time
!pip install SimpleITK
!pip install natsort

In [None]:

import tensorflow.compat.v1 as tf
from random import shuffle
from natsort import natsorted, ns
from data_load import DataGenerator
import glob
import os
import config
# Global flag namespace (data_dir, batch_size, patch_size, ...); the flags are
# presumably registered as a side effect of `import config` — TODO confirm.
FLAGS = tf.app.flags.FLAGS


def data(data_selected, data_dir=None):
    """Collect shuffled ``[t1_path, seg_path]`` file pairs for one data split.

    Every subject is a sub-directory of ``<data_dir>/<data_selected>/`` that
    must contain exactly one ``*t1_strip_registration.nii.gz`` image and one
    ``*segm_registration_round_class.nii.gz`` segmentation.

    Args:
        data_selected: split folder name, e.g. ``'train'``, ``'valid'``, ``'test'``.
        data_dir: root data directory; defaults to ``FLAGS.data_dir``.

    Returns:
        A list of ``[t1_path, seg_path]`` pairs in shuffled subject order.

    Raises:
        FileNotFoundError: if a subject folder does not hold exactly one image
            and exactly one segmentation file.
    """
    if data_dir is None:
        data_dir = FLAGS.data_dir
    split_dir = os.path.join(data_dir, data_selected)

    # Only sub-directories are subjects; stray files (e.g. .DS_Store) are ignored.
    subjects = [name for name in os.listdir(split_dir)
                if os.path.isdir(os.path.join(split_dir, name))]
    # Deterministic order before shuffling, so a seeded shuffle is reproducible.
    subjects = natsorted(subjects, alg=ns.PATH | ns.IGNORECASE)
    shuffle(subjects)

    # Pair files inside each subject folder. Mapping the folder name to an index
    # into a globally sorted file list only works for folders named 1..N, and
    # lexicographic order ('10' < '2') pairs names with the wrong subject.
    veri = []
    for subject in subjects:
        subject_dir = os.path.join(split_dir, subject)
        t1_files = glob.glob(os.path.join(subject_dir, '*t1_strip_registration.nii.gz'))
        seg_files = glob.glob(os.path.join(subject_dir, '*segm_registration_round_class.nii.gz'))
        if len(t1_files) != 1 or len(seg_files) != 1:
            raise FileNotFoundError(
                'Expected exactly one T1 and one segmentation file in {}, '
                'found {} and {}'.format(subject_dir, len(t1_files), len(seg_files)))
        veri.append([t1_files[0], seg_files[0]])
    return veri

# Split registry: only the training split is populated in this notebook;
# 'valid' and 'test' stay as empty placeholders.
sets = {split: [] for split in ('train', 'valid', 'test')}
sets['train'] = data(data_selected='train')

# NOTE: GAN_train.train() builds its own DataGenerator for every pass, so this
# module-level generator is not consumed by the training loop.
train_gen = DataGenerator(sets['train'], batch_size=FLAGS.batch_size)


In [None]:
from tensorflow.keras.models import Model,load_model,model_from_json
from tensorflow.keras.layers import Input, Activation
from tensorflow.keras.optimizers import Adam
#import tensorflow.keras.backend as K
import os
import numpy as np
from utils2 import  get_dsc
import tensorflow.compat.v1 as tf
import time
from model import Vol2SegGAN
from random import randint

# Per-class weights for the weighted Dice loss used by GAN_train. Loaded by a
# relative path, so this depends on the working directory set by the %cd cell.
class_weights = np.load('class_weights.npy')

class GAN_train():
    """Conditional-GAN trainer for 3-D segmentation (Vol2SegGAN).

    Holds a generator (image -> segmentation), a discriminator that scores
    ``[segmentation, image]`` pairs, and a combined model that updates the
    generator against a frozen discriminator. If ``generator.h5`` exists in
    ``FLAGS.save_model`` the generator and discriminator are restored from
    that checkpoint; otherwise new networks are built with ``Vol2SegGAN``.

    Args:
        img_shape: image patch shape fed to the generator.
        seg_shape: segmentation patch shape produced by the generator.
        class_weights: per-class weights for the weighted Dice loss.
        Nfilter_start: forwarded to ``Vol2SegGAN``.
        depth: forwarded to ``Vol2SegGAN``.
        batch_size: batch size used by ``train``.
        LAMBDA: weight of the Dice term relative to the adversarial MSE term.
    """

    def __init__(self, img_shape, seg_shape, class_weights, Nfilter_start=32, depth=4, batch_size=1, LAMBDA=5):
        self.img_shape = img_shape
        self.seg_shape = seg_shape
        self.class_weights = class_weights
        self.Nfilter_start = Nfilter_start
        self.depth = depth
        self.batch_size = batch_size
        self.LAMBDA = LAMBDA
        self.path = FLAGS.save_model
        self._start_time = None

        os.makedirs(self.path, exist_ok=True)

        if os.path.exists(os.path.join(self.path, 'generator.h5')):
            # Resume: restore G and D. The combined model is rebuilt from these
            # same instances (see _compile_networks) so that training it updates
            # the generator whose predictions train_step uses. Loading a separate
            # combined model from JSON would create unshared copies of G and D.
            self.generator = self._load_network('generator')
            self.discriminator = self._load_network('discriminator')
            status = 'loaded'
        else:
            gan = Vol2SegGAN(self.img_shape, self.seg_shape, Nfilter_start=self.Nfilter_start, depth=self.depth)
            self.discriminator = gan.Discriminator()
            self.discriminator.summary()
            self.generator = gan.Generator()
            self.generator.summary()
            status = 'New Model'

        self._compile_networks()
        print(status)

    def _dice_loss(self, y_true, y_pred):
        """Weighted soft Dice loss: 1 - 2*sum_c(w_c*sum(y*p)) / sum_c(w_c*sum(y+p)).

        Sums run over axes 0-3, leaving one value per class on the last axis.
        This assumes 5-D (batch, x, y, z, class) tensors.
        """
        y_true = tf.convert_to_tensor(y_true, 'float32')
        y_pred = tf.convert_to_tensor(y_pred, y_true.dtype)
        # Cast the (numpy) weights so the multiply never mixes float64 with float32.
        w = tf.cast(self.class_weights, y_true.dtype)
        num = tf.math.reduce_sum(tf.math.multiply(w, tf.math.reduce_sum(tf.math.multiply(y_true, y_pred), axis=[0, 1, 2, 3])))
        den = tf.math.reduce_sum(tf.math.multiply(w, tf.math.reduce_sum(tf.math.add(y_true, y_pred), axis=[0, 1, 2, 3]))) + 1e-5
        return 1 - 2 * num / den

    def _load_network(self, name):
        """Rebuild ``<name>.json`` from ``self.path`` and load ``<name>.h5`` weights."""
        import json
        # save_model json.dump()s the string returned by to_json(), so
        # json.load() gives that architecture string back.
        with open(os.path.join(self.path, name + '.json'), 'r') as f:
            architecture = json.load(f)
        network = model_from_json(architecture)
        network.load_weights(os.path.join(self.path, name + '.h5'))
        return network

    def _compile_networks(self):
        """Compile the discriminator, then build and compile the combined model."""
        # A discriminator restored from JSON may have been serialized while
        # frozen (trainable is set to False below before any save), so make it
        # trainable again before compiling its own update step.
        self.discriminator.trainable = True
        self.discriminator.compile(loss='mse', optimizer=Adam(1e-4, beta_1=0.5), metrics=['accuracy'])

        # Input segmentations and their conditioning images
        seg = Input(shape=self.seg_shape)
        img = Input(shape=self.img_shape)
        seg_pred = self.generator(img)

        # Freeze D inside the combined model: only the generator trains there.
        # Keras keeps the trainable state captured at each model's compile time.
        self.discriminator.trainable = False
        valid = self.discriminator([seg_pred, img])

        self.combined = Model(inputs=[seg, img], outputs=[valid, seg_pred])
        self.combined.compile(loss=['mse', self._dice_loss], loss_weights=[1, self.LAMBDA], optimizer=Adam(1e-4, beta_1=0.5))

    def train_step(self, Xbatch, Ybatch):
        """Run one discriminator update (real + fake) and one generator update.

        Returns:
            ``(gen_loss, dsc)``: the Keras ``History`` from fitting the combined
            model, and ``get_dsc`` scores for the generator output computed
            before this step's generator update (presumably per-class Dice).
        """
        # Generator output (before this step's generator update).
        # workers/use_multiprocessing only apply to generator/Sequence inputs,
        # so they are not passed for array batches.
        gen_output = self.generator.predict(Xbatch)

        # Discriminator targets shaped like its output, one per sample.
        disc_output_shape = (gen_output.shape[0], *self.discriminator.output_shape[1:])
        real = tf.ones(disc_output_shape)
        fake = tf.zeros(disc_output_shape)

        # Train discriminator on real pairs, then on generated pairs.
        self.discriminator.fit([Ybatch, Xbatch], real, verbose=0)
        self.discriminator.fit([gen_output, Xbatch], fake, verbose=0)

        # Train generator through the combined model (discriminator frozen).
        gen_loss = self.combined.fit([Ybatch, Xbatch], [real, Ybatch], verbose=0)
        dsc = get_dsc(labels=Ybatch, predictions=gen_output)
        return gen_loss, dsc

    def save_model(self, path):
        """Write the JSON architecture and HDF5 weights of G, D and the combined model to ``path``."""
        import json
        for name, network in (('generator', self.generator),
                              ('discriminator', self.discriminator),
                              ('Vol2SegGAN', self.combined)):
            with open(os.path.join(path, name + '.json'), 'w') as json_file:
                json.dump(network.to_json(), json_file)
            network.save_weights(os.path.join(path, name + '.h5'))

    def tic(self):
        """Start the wall-clock timer reported by ``toc``."""
        self._start_time = time.time()

    def toc(self):
        """Print the wall-clock time since ``tic`` (saat/dk/sn = hours/minutes/seconds)."""
        t_sec = round(time.time() - self._start_time)
        t_min, t_sec = divmod(t_sec, 60)
        t_hour, t_min = divmod(t_min, 60)
        print('Time: {}saat:{}dk:{}sn'.format(t_hour, t_min, t_sec))

    def run_toc(self, start_time):
        """Return the whole minutes elapsed since ``start_time`` (a ``time.time()`` value)."""
        return round(time.time() - start_time) // 60

    def train(self, train_data=None):
        """Train for ``FLAGS.max_steps`` passes over the training pairs.

        Args:
            train_data: list of ``[t1_path, seg_path]`` pairs; defaults to the
                notebook-level ``sets['train']``.

        Returns:
            ``(trends_train, trends_valid)``. ``trends_train`` holds one dict per
            batch with keys ``'step'``, ``'loss'`` (combined-model loss) and
            ``'dsc'``. ``trends_valid`` is empty because no validation runs here.
        """
        if train_data is None:
            train_data = sets['train']

        print('Training process:')
        train_status = '- step: {}/{} : loss: {:0.4f} - bgr:{:0.3f} csf:{:0.3f} gm:{:0.3f} wm:{:0.3f} '
        trends_train, trends_valid = [], []

        self.tic()
        # 1-based and inclusive, so the log actually reaches max_steps/max_steps.
        for itr in range(1, FLAGS.max_steps + 1):
            # A fresh data generator for every pass over the training set.
            train_gen = DataGenerator(train_data, batch_size=self.batch_size)
            for Xbatch, Ybatch in train_gen:
                gan_losses, s_dsc = self.train_step(Xbatch, Ybatch)
                loss = gan_losses.history['loss'][0]
                trends_train.append({'step': itr, 'loss': loss, 'dsc': list(s_dsc)})
                print(train_status.format(itr, FLAGS.max_steps, loss, s_dsc[0], s_dsc[1], s_dsc[2], s_dsc[3]))
            if itr % FLAGS.steps_to_save_checkpoint == 0:
                self.save_model(path=self.path)
                print("Save checkpoint")
                print(itr)
        self.toc()
        return trends_train, trends_valid
# Patch size arrives as a comma-separated flag string; keep the parsed list as
# `patch_size` and take the first three entries as the spatial dimensions.
patch_size = [int(dim) for dim in FLAGS.patch_size.split(",")]
px, py, pz = patch_size[0], patch_size[1], patch_size[2]

# Generator input: single-channel image patch. Target: 4-channel segmentation
# patch (one channel per class in the training log: bgr/csf/gm/wm).
imShape = (px, py, pz, 1)
gtShape = (px, py, pz, 4)

gan = GAN_train(
    imShape,
    gtShape,
    class_weights,
    Nfilter_start=FLAGS.Nfilter_start,
    depth=FLAGS.depth,
    batch_size=FLAGS.batch_size,
    LAMBDA=FLAGS.LAMBDA,
)

trends_train, trends_valid = gan.train()
