In [None]:
import sys

import os
import time
import random
import numpy as np
import scipy, multiprocessing
import tensorflow as tf
import tensorlayer as tl
from model import get_sr_G, get_dx_G, get_D, SRGAN_d2, get_patch_D, cycle_G
from config import config


import os
import dicom_to_numpy as dtn
import h5py
import numpy as np
import matplotlib.pyplot as plt

from gans import train

**1. Set training hyper-parameters**

In [None]:
###====================== HYPER-PARAMETERS ===========================###
# Training settings are read from `config.TRAIN`; the few literals below are
# notebook-local overrides or additions.

## Adam
batch_size = config.TRAIN.batch_size
# NOTE(review): the configured batch size is immediately replaced by a
# hard-coded 4 -- confirm that this override is intentional.
batch_size = 4
lr_init, beta1 = config.TRAIN.lr_init, config.TRAIN.beta1

## initialize G: number of epochs for the generator-only phases
n_epoch_init = config.TRAIN.n_epoch_init

## adversarial learning (SRGAN): epoch count and learning-rate decay settings
n_epoch = config.TRAIN.n_epoch
lr_decay, decay_every = config.TRAIN.lr_decay, config.TRAIN.decay_every

shuffle_buffer_size = 128

# Folders that receive the result images and the trained models.
save_dir, checkpoint_dir = "samples", "models"
tl.files.exists_or_mkdir(save_dir)
tl.files.exists_or_mkdir(checkpoint_dir)

**2. Load and preprocess training DRRs and X-Rays**

In [None]:
def _log_normalize(imgs):
    """Log-compress a stack of images and rescale the whole stack to [-1, 1].

    Steps, in order:
      1. Each image is log-transformed (``log(1 + x)``) and divided by
         ``log(1 + max)`` of that same image, where the maximum is taken over
         axes 1 and 2.
      2. The whole stack is min-max normalised with its global minimum and
         maximum, then mapped linearly from [0, 1] to [-1, 1].

    Args:
        imgs: 3-D array indexed as (image, axis 1, axis 2).

    Returns:
        Array of the same shape with values in [-1, 1].
    """
    log_max = np.log(1 + np.max(imgs, axis=(1, 2)))
    # An all-zero image has log_max == 0. Dividing by it would give inf, then
    # 0 * inf = NaN, and the global min/max below would spread that NaN over
    # every image. Use a neutral scale of 1 so such an image stays all-zero.
    scale = 1 / np.where(log_max == 0, 1, log_max)
    scaled = np.multiply(scale[..., np.newaxis, np.newaxis], np.log(imgs + 1))
    lo, hi = np.min(scaled), np.max(scaled)
    return (scaled - lo) / (hi - lo) * 2 - 1


load_dir = ''
# os.path.join keeps the paths valid whether or not load_dir ends with a
# separator (an empty load_dir still resolves to the bare file name).
drr_imgs = dtn.load_data(os.path.join(load_dir, 'drrs.hdf5'), 'drrs')
xray_imgs = dtn.load_data(os.path.join(load_dir, 'xrays.hdf5'), 'xrays')

# Keep the first 360 entries along axis 0, drop the first 128 positions of
# axis 1 and 128 positions at both ends of axis 2.
drr_imgs = drr_imgs[:360, 128:, 128:-128]
xray_imgs = xray_imgs[:360, 128:, 128:-128]

drr_imgs = _log_normalize(drr_imgs)
xray_imgs = _log_normalize(xray_imgs)
# NOTE(review): gamma_transform is neither defined nor imported in the cells
# shown here -- confirm where it comes from before running this cell.
xray_imgs = gamma_transform(xray_imgs, 5)

# Append a trailing axis of length 1 to both stacks.
drr_imgs = drr_imgs[..., np.newaxis]
xray_imgs = xray_imgs[..., np.newaxis]

**3. Train SRGAN Generator alone for initialization.**

In [None]:

#---------------------------------INITIALIZE G_SR------------------------------------
# Train the generator G_sr on its own with a pixel-wise absolute-difference
# loss between its output for the low-resolution DRR patches and the
# high-resolution DRR patches. Only the DRR half of each batch is used.
#
# The statements start at column 0: the cell previously began with an indented
# body under a column-0 comment, which Python rejects with an IndentationError.
#
# NOTE(review): G_sr, train_sr and g_sr_optimizer_init are not defined in any
# cell shown here -- presumably they are created elsewhere (e.g. gans.train);
# confirm they exist before running this cell.

## initialize learning G_sr
G_sr.train()

# NOTE(review): this value is only printed as the "total steps" figure. It is
# derived from the epoch count rather than from the dataset size, so the
# printed total is probably not the real number of steps per epoch -- confirm.
n_step_epoch = round(n_epoch_init // batch_size)
for epoch in range(n_epoch_init):
    for step, (drr_lr_patchs, drr_hr_patchs, _, _) in enumerate(train_sr):
        if drr_lr_patchs.shape[0] != batch_size:  # if the remaining data in this epoch < batch_size
            break
        step_time = time.time()
        # Only one gradient is taken per step, so a non-persistent tape is
        # enough and its resources are released right after tape.gradient().
        with tf.GradientTape() as tape:
            fake_drr_hr_patchs = G_sr(drr_lr_patchs)
            # Despite the "mse" in the names, this is the mean absolute difference.
            drr_mse_loss_sr = tl.cost.absolute_difference_error(fake_drr_hr_patchs, drr_hr_patchs, is_mean=True)
            mse_loss_sr = drr_mse_loss_sr

        grad = tape.gradient(mse_loss_sr, G_sr.trainable_weights)
        g_sr_optimizer_init.apply_gradients(zip(grad, G_sr.trainable_weights))
        print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse_sr: {:.5f}".format(
            epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, mse_loss_sr))
    # Every 10th epoch (except epoch 0): save a sample image and the weights.
    if (epoch != 0) and (epoch % 10 == 0):
        # Every processed batch holds exactly batch_size images (see the break
        # above), so the grid needs batch_size cells; a fixed [1, 2] grid is
        # too small for the batch size of 4 set in the hyper-parameter cell.
        tl.vis.save_images(fake_drr_hr_patchs.numpy(), [1, batch_size], os.path.join(save_dir, 'train_g_sr_init_{}.png'.format(epoch)))
        G_sr.save_weights(os.path.join(checkpoint_dir, 'g_sr_init.h5'))


**4. Train SRGAN (G + D).**

In [None]:
#------------------------------ G_SR_____D_SR------------------------------------
# Adversarial training of the generator G_sr together with the discriminator
# D_sr on DRR patches. G_sr is optimised on the sum of an adversarial term, an
# absolute-difference term and a VGG content term; D_sr on a real/fake term.
#
# The statements start at column 0: the cell previously began with an indented
# body under a column-0 comment, which Python rejects with an IndentationError.
#
# NOTE(review): G_sr, D_sr, VGG, train_sr, mae_criterion, content_loss,
# d_optimizer, g_sr_optimizer, lr_v and size are not defined in any cell shown
# here -- presumably they are created elsewhere (e.g. gans.train); confirm.

## adversarial learning (G_sr, D_sr)
# Resume from previously saved weights when the checkpoint files exist.
if tl.files.file_exists(os.path.join(checkpoint_dir, 'd_sr.h5')):
    D_sr.load_weights(os.path.join(checkpoint_dir, 'd_sr.h5'))
if tl.files.file_exists(os.path.join(checkpoint_dir, 'g_sr.h5')):
    G_sr.load_weights(os.path.join(checkpoint_dir, 'g_sr.h5'))

D_sr.train()
G_sr.train()

# NOTE(review): only printed as the "total steps" figure; it is derived from
# the epoch count rather than from the dataset size -- confirm.
n_step_epoch = round(n_epoch // batch_size)
for epoch in range(n_epoch):
    for step, (lr_batch, hr_batch, _, _) in enumerate(train_sr):
        if lr_batch.shape[0] != batch_size:  # if the remaining data in this epoch < batch_size
            break
        # Bind the batch only after the size check, so that after the loop
        # these names still refer to the last batch that was actually run
        # through G_sr and therefore match fake_drr_hr_patchs in the sample image.
        drr_lr_patchs, drr_hr_patchs = lr_batch, hr_batch
        step_time = time.time()
        # Persistent: two gradients (D_sr and G_sr) are taken from one tape.
        with tf.GradientTape(persistent=True) as tape:
            # generated/fake data
            fake_drr_hr_patchs = G_sr(drr_lr_patchs)

            sr_drr_logits_fake = D_sr(fake_drr_hr_patchs)
            # (x + 1) / 2 -- presumably maps patches from [-1, 1] to [0, 1]
            # for VGG; confirm against the VGG input convention.
            sr_drr_feature_fake = VGG((fake_drr_hr_patchs + 1) / 2.)
            # ground-truth/real data
            sr_drr_logits_real = D_sr(drr_hr_patchs)
            sr_drr_feature_real = VGG((drr_hr_patchs + 1) / 2.)

            # D loss: last D_sr output should be 1 for real and 0 for fake patches
            d_sr_drr_loss1 = mae_criterion(sr_drr_logits_real[-1], tf.ones_like(sr_drr_logits_real[-1]))
            d_sr_drr_loss2 = mae_criterion(sr_drr_logits_fake[-1], tf.zeros_like(sr_drr_logits_fake[-1]))
            d_loss = d_sr_drr_loss1 + d_sr_drr_loss2

            # G_sr super resolution loss
            g_sr_drr_gan_loss = 1e-2 * mae_criterion(sr_drr_logits_fake[-1], tf.ones_like(sr_drr_logits_fake[-1]))
            g_sr_drr_ade_loss = tl.cost.absolute_difference_error(fake_drr_hr_patchs, drr_hr_patchs, is_mean=True)
            g_sr_drr_vgg_loss = content_loss(sr_drr_feature_fake, sr_drr_feature_real, lamda=2e-6)

            g_sr_loss = (g_sr_drr_gan_loss
                         + g_sr_drr_ade_loss
                         + g_sr_drr_vgg_loss)

        grad = tape.gradient(d_loss, D_sr.trainable_weights)
        d_optimizer.apply_gradients(zip(grad, D_sr.trainable_weights))
        grad = tape.gradient(g_sr_loss, G_sr.trainable_weights)
        g_sr_optimizer.apply_gradients(zip(grad, G_sr.trainable_weights))
        # A persistent tape keeps its resources until it is dropped explicitly.
        del tape

        print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_sr_loss(ade:{:.5f}, content:{:.5f}, adv:{:.5f}) d_loss:{:.5f}".format(
            epoch, n_epoch, step, n_step_epoch, time.time() - step_time,
            g_sr_drr_ade_loss,
            g_sr_drr_vgg_loss,
            g_sr_drr_gan_loss,
            d_loss,
            ))

    # update the learning rate
    if epoch > 100 and (epoch % decay_every == 0):
        new_lr_decay = lr_decay**(epoch // decay_every)
        lr_v.assign(lr_init * new_lr_decay)
        log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
        print(log)

    # Sample image for this epoch: real patches, generated patches and the
    # bicubically resized low-resolution input, stacked along the batch axis.
    result = np.concatenate([drr_hr_patchs.numpy(),
                             fake_drr_hr_patchs.numpy(),
                             tf.image.resize(drr_lr_patchs, size=[size, size], method="bicubic").numpy()], axis=0)

    # Three groups of batch_size images each -> 3 rows x batch_size columns.
    # A fixed [3, 2] grid is too small for the batch size of 4 set in the
    # hyper-parameter cell.
    tl.vis.save_images(result, [3, batch_size], os.path.join(save_dir, 'train_g_sr_{}.png'.format(epoch)))
    G_sr.save_weights(os.path.join(checkpoint_dir, 'g_sr.h5'))
    D_sr.save_weights(os.path.join(checkpoint_dir, 'd_sr.h5'))
            
       
        

**5. Train CycleGAN Generators (G1 and G2) for initialization.**

In [None]:
   
#---------------------------------INITIALIZE G_DX------------------------------------
# Train the two generators on their own before adversarial training:
# G1 maps DRR patches towards the X-ray patches of the same batch, G2 maps
# X-ray patches towards the DRR patches, each with a mean-squared-error loss.
#
# The statements start at column 0: the cell previously began with an indented
# body under a column-0 comment, which Python rejects with an IndentationError.
#
# NOTE(review): G1, G2, train_dx, g1_optimizer_init and g2_optimizer_init are
# not defined in any cell shown here -- presumably they are created elsewhere
# (e.g. gans.train); confirm they exist before running this cell.

## initialize learning G_dx
G1.train()
G2.train()

# This phase runs for a fifth of n_epoch_init; the same count is used for the
# loop and for the progress line so the printed total matches the real one.
n_epoch_dx_init = n_epoch_init // 5
# NOTE(review): only printed as the "total steps" figure; it is set to the
# epoch count rather than derived from the dataset size -- confirm.
n_step_epoch = round(n_epoch_init)
for epoch in range(n_epoch_dx_init):
    for step, (_, _, drr_mr_patchs, _, _, xray_mr_patchs) in enumerate(train_dx):
        if drr_mr_patchs.shape[0] != batch_size:  # if the remaining data in this epoch < batch_size
            break
        step_time = time.time()
        # Persistent: two gradients (G1 and G2) are taken from one tape.
        with tf.GradientTape(persistent=True) as tape:
            # generated/fake data
            fake_xray_patchs = G1(drr_mr_patchs)
            fake_drr_patchs = G2(xray_mr_patchs)

            # Named "ade" but computed as mean squared error.
            xray_ade_loss = tl.cost.mean_squared_error(fake_xray_patchs, xray_mr_patchs, is_mean=True)
            drr_ade_loss = tl.cost.mean_squared_error(fake_drr_patchs, drr_mr_patchs, is_mean=True)

            g1_loss = xray_ade_loss
            g2_loss = drr_ade_loss

        grad = tape.gradient(g1_loss, G1.trainable_weights)
        g1_optimizer_init.apply_gradients(zip(grad, G1.trainable_weights))
        grad = tape.gradient(g2_loss, G2.trainable_weights)
        g2_optimizer_init.apply_gradients(zip(grad, G2.trainable_weights))
        # A persistent tape keeps its resources until it is dropped explicitly.
        del tape
        print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g1_ade: {:.5f}, g2_ade: {:.5f}".format(
            epoch, n_epoch_dx_init, step, n_step_epoch, time.time() - step_time,
            xray_ade_loss,
            drr_ade_loss,
            ))

    # After every epoch: save the last generated batches and both generators.
    tl.vis.save_images(fake_xray_patchs.numpy(), [2, 3], os.path.join(save_dir, 'train_g1_dx_init_{}.png'.format(epoch)))
    tl.vis.save_images(fake_drr_patchs.numpy(), [2, 3], os.path.join(save_dir, 'train_g2_dx_init_{}.png'.format(epoch)))
    G1.save_weights(os.path.join(checkpoint_dir, 'g1_dx_init.h5'))
    G2.save_weights(os.path.join(checkpoint_dir, 'g2_dx_init.h5'))

 
    

**6. Train CycleGAN (G1 + D1 + G2 + D2).**

In [None]:

#------------------------------ G_DX_____D_DX------------------------------------
# Adversarial training of both generators and both discriminators:
# G1 turns DRR patches into X-ray-like patches (judged by D1), G2 turns X-ray
# patches into DRR-like patches (judged by D2). Each generator loss sums an
# adversarial, a cycle, an absolute-difference, a perceptual, a style and a
# content term.
#
# The statements start at column 0: the cell previously began with an indented
# body under a column-0 comment, which Python rejects with an IndentationError.
#
# NOTE(review): G1, G2, D1, D2, VGG, train_dx, mae_criterion, content_loss,
# style_loss, perceptual_loss, the four optimizers and lr_v are not defined in
# any cell shown here -- presumably they are created elsewhere (e.g.
# gans.train); confirm they exist before running this cell.

## adversarial learning (G_dx, D_dx)
# Resume from previously saved weights when the checkpoint files exist.
if tl.files.file_exists(os.path.join(checkpoint_dir, 'd1_dx.h5')):
    D1.load_weights(os.path.join(checkpoint_dir, 'd1_dx.h5'))
if tl.files.file_exists(os.path.join(checkpoint_dir, 'd2_dx.h5')):
    D2.load_weights(os.path.join(checkpoint_dir, 'd2_dx.h5'))
if tl.files.file_exists(os.path.join(checkpoint_dir, 'g1_dx.h5')):
    G1.load_weights(os.path.join(checkpoint_dir, 'g1_dx.h5'))
if tl.files.file_exists(os.path.join(checkpoint_dir, 'g2_dx.h5')):
    G2.load_weights(os.path.join(checkpoint_dir, 'g2_dx.h5'))

D1.train()
D2.train()
G1.train()
G2.train()

# Frequencies that were previously written as bare `% 1` conditions.
d_update_every = 1   # update D1/D2 every this many steps
save_every = 1       # save samples and checkpoints every this many epochs
# This phase decays twice as often as decay_every; clamp to 1 so the modulo
# below cannot divide by zero when decay_every is 1.
decay_every_dx = max(1, decay_every // 2)

# NOTE(review): only printed as the "total steps" figure; it is derived from
# the epoch count rather than from the dataset size -- confirm.
n_step_epoch = round(n_epoch // batch_size)
for epoch in range(n_epoch):
    for step, (_, _, drr_mr_patchs, _, _, xray_mr_patchs) in enumerate(train_dx):
        if drr_mr_patchs.shape[0] != batch_size:  # if the remaining data in this epoch < batch_size
            break
        step_time = time.time()

        # Persistent: up to four gradients (G1, G2, D1, D2) come from one tape.
        with tf.GradientTape(persistent=True) as tape:
            # generated/fake data
            fake_xray_patchs = G1(drr_mr_patchs)
            cycled_drr_patchs = G2(fake_xray_patchs)
            fake_drr_patchs = G2(xray_mr_patchs)
            cycled_xray_patchs = G1(fake_drr_patchs)

            # Discriminator outputs are sequences: the last entry feeds the
            # adversarial terms, the preceding entries feed perceptual_loss.
            vgg_fake_xray_logits = D1(fake_xray_patchs)
            vgg_fake_drr_logits = D2(fake_drr_patchs)
            fake_xray_features = VGG((fake_xray_patchs + 1) / 2.)
            fake_drr_features = VGG((fake_drr_patchs + 1) / 2.)

            vgg_real_xray_logits = D1(xray_mr_patchs)
            vgg_real_drr_logits = D2(drr_mr_patchs)
            real_xray_features = VGG((xray_mr_patchs + 1) / 2.)
            real_drr_features = VGG((drr_mr_patchs + 1) / 2.)

            # cycle loss
            cycle_drr_loss = 10 * tl.cost.absolute_difference_error(cycled_drr_patchs, drr_mr_patchs, is_mean=True)
            cycle_xray_loss = 10 * tl.cost.absolute_difference_error(cycled_xray_patchs, xray_mr_patchs, is_mean=True)
            cycle_loss = cycle_drr_loss + cycle_xray_loss

            # D1 loss
            d1_xray_loss1 = mae_criterion(vgg_real_xray_logits[-1:], tf.ones_like(vgg_real_xray_logits[-1:]))
            d1_xray_loss2 = mae_criterion(vgg_fake_xray_logits[-1:], tf.zeros_like(vgg_fake_xray_logits[-1:]))
            d1_loss = (d1_xray_loss1 + d1_xray_loss2) / 2.

            # D2 loss
            d2_drr_loss1 = mae_criterion(vgg_real_drr_logits[-1:], tf.ones_like(vgg_real_drr_logits[-1:]))
            d2_drr_loss2 = mae_criterion(vgg_fake_drr_logits[-1:], tf.zeros_like(vgg_fake_drr_logits[-1:]))
            d2_loss = (d2_drr_loss1 + d2_drr_loss2) / 2.

            # G1 loss
            g1_gan_loss = mae_criterion(vgg_fake_xray_logits[-1:], tf.ones_like(vgg_fake_xray_logits[-1:]))
            g1_xray_ade_loss = tl.cost.absolute_difference_error(fake_xray_patchs, xray_mr_patchs, is_mean=True)
            g1_content_vgg_loss = content_loss(fake_xray_features, real_xray_features, lamda=1e-6)
            g1_style_vgg_loss = style_loss(fake_xray_features, real_xray_features, lamda=1e-10)
            g1_perceptual_vgg_loss = perceptual_loss(vgg_fake_xray_logits[:-1], vgg_real_xray_logits[:-1], lamda=1e-1)
            # Parenthesised sum instead of trailing backslashes: the old form
            # ended in a backslash followed by a whitespace-only line, which
            # breaks as soon as that line is removed or edited.
            g1_loss = (g1_gan_loss
                       + cycle_loss
                       + g1_xray_ade_loss
                       + g1_perceptual_vgg_loss
                       + g1_style_vgg_loss
                       + g1_content_vgg_loss)

            # G2 loss
            g2_gan_loss = mae_criterion(vgg_fake_drr_logits[-1:], tf.ones_like(vgg_fake_drr_logits[-1:]))
            g2_drr_ade_loss = tl.cost.absolute_difference_error(fake_drr_patchs, drr_mr_patchs, is_mean=True)
            g2_content_vgg_loss = content_loss(fake_drr_features, real_drr_features, lamda=1e-6)
            g2_style_vgg_loss = style_loss(fake_drr_features, real_drr_features, lamda=1e-10)
            g2_perceptual_vgg_loss = perceptual_loss(vgg_fake_drr_logits[:-1], vgg_real_drr_logits[:-1], lamda=1e-1)
            g2_loss = (g2_gan_loss
                       + cycle_loss
                       + g2_drr_ade_loss
                       + g2_perceptual_vgg_loss
                       + g2_style_vgg_loss
                       + g2_content_vgg_loss)

        grad = tape.gradient(g1_loss, G1.trainable_weights)
        g1_optimizer.apply_gradients(zip(grad, G1.trainable_weights))
        grad = tape.gradient(g2_loss, G2.trainable_weights)
        g2_optimizer.apply_gradients(zip(grad, G2.trainable_weights))
        if step % d_update_every == 0:
            grad = tape.gradient(d1_loss, D1.trainable_weights)
            d1_optimizer.apply_gradients(zip(grad, D1.trainable_weights))
            grad = tape.gradient(d2_loss, D2.trainable_weights)
            d2_optimizer.apply_gradients(zip(grad, D2.trainable_weights))
        # A persistent tape keeps its resources until it is dropped explicitly.
        del tape

        print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g1_loss(cycle:{:.5f}, ade:{:.5f}, content:{:.5f}, style:{:.5f}, perceptual:{:.5f}, adv:{:.5f}) d1_loss:{:.5f}".format(
            epoch, n_epoch, step, n_step_epoch, time.time() - step_time,
            cycle_drr_loss,
            g1_xray_ade_loss,
            g1_content_vgg_loss,
            g1_style_vgg_loss,
            g1_perceptual_vgg_loss,
            g1_gan_loss,
            d1_loss))
        print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g2_loss(cycle:{:.5f}, ade:{:.5f}, content:{:.5f}, style:{:.5f}, perceptual:{:.5f}, adv:{:.5f}) d2_loss:{:.5f}".format(
            epoch, n_epoch, step, n_step_epoch, time.time() - step_time,
            cycle_xray_loss,
            g2_drr_ade_loss,
            g2_content_vgg_loss,
            g2_style_vgg_loss,
            g2_perceptual_vgg_loss,
            g2_gan_loss,
            d2_loss))

    # update the learning rate
    if epoch > 100 and (epoch % decay_every_dx == 0):
        new_lr_decay = lr_decay**(epoch // decay_every_dx)
        lr_v.assign(lr_init * new_lr_decay)
        log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
        print(log)

    # Save the last generated batches and all four networks.
    if epoch % save_every == 0:
        tl.vis.save_images(fake_xray_patchs.numpy(), [2, 2], os.path.join(save_dir, 'train_g1_xrays_{}.png'.format(epoch)))
        tl.vis.save_images(fake_drr_patchs.numpy(), [2, 2], os.path.join(save_dir, 'train_g2_drrs_{}.png'.format(epoch)))
        G1.save_weights(os.path.join(checkpoint_dir, 'g1_dx.h5'))
        D1.save_weights(os.path.join(checkpoint_dir, 'd1_dx.h5'))
        G2.save_weights(os.path.join(checkpoint_dir, 'g2_dx.h5'))
        D2.save_weights(os.path.join(checkpoint_dir, 'd2_dx.h5'))
            
            
            
            