In [None]:
# matplotlib plots within notebook
%matplotlib inline

import platform
print("python: "+platform.python_version())


import numpy as np
import matplotlib.pyplot as plt

import time

import os, shutil, sys

sys.path.insert(0, '.')
from DeepAttCorr_lib import GAN_3D_lib as GAN
from DeepAttCorr_lib import data_handling as DH
from DeepAttCorr_lib import file_manage_utils as File_mng

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as K
print('Using TensorFlow version: '+tf.__version__)

from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

# Paths and Definitions

### Basics

In [None]:
# Identifier of this network; it names the checkpoint and TensorBoard folders.
NETWORK_NAME = 'DeepAttCorr_ProGAN_Network'

# Folder containing the .tfrecord dataset files.
DATASET_PATH = './datasets/'

# Folder for checkpoints and model diagrams (keeps a trailing separator).
CHECKPOINT_PATH = "".join(["./Outputs/", NETWORK_NAME, "/"])

# TensorBoard output: a base folder, one sub-folder per network, and inside it
# separate train / validation ("test") run folders.
TENSORBOARD_BASE_PATH = "./TensorBoard_output"
TENSORBOARD_OUT_PATH = os.path.join(TENSORBOARD_BASE_PATH, NETWORK_NAME)
TENSORBOARD_OUT_PATH_TRAIN, TENSORBOARD_OUT_PATH_VALIDATION = (
    os.path.join(TENSORBOARD_OUT_PATH, run_folder) for run_folder in ("train", "test"))

# When True, the output folders are cleared before training starts.
CLEAR_OUTS = True

### Network Parameters

In [None]:
# Input volume size (voxels)
voxels_X = 128
voxels_Y = 128
voxels_Z = 32
input_size = (voxels_X,voxels_Y,voxels_Z)

# Number of resolution levels of the progressive GAN.
# Resolution level k works at full_resolution/2**k (k = 0 .. LOW_RES_POW-1),
# so the lowest resolution (the start of the progressive GAN) is
# full_resolution/2**(LOW_RES_POW-1).
LOW_RES_POW =4



# --------------------------------------------------------------------------------
# ---------------------- GENERATOR -----------------------------------------------
# --------------------------------------------------------------------------------
# Network convolutional channels by resolution level
USE_GEN_net_conv_Channels = [10, 20, 40, 80, 160]
# Convolutional layers by resolution level
USE_GEN_net_conv_Layers = 2
# Use batch normalization for generator training
USE_GEN_BATCH_NORM = False
# Use pixel normalization for generator training
USE_GEN_PIXEL_NORM = True
# Use He scaling of weights for generator
USE_GEN_HE_SCALLING = True
# Weight initialization standard deviation
USE_GEN_INI_STD = 1.0


# --------------------------------------------------------------------------------
# ---------------------- DISCRIMINATOR -------------------------------------------
# --------------------------------------------------------------------------------
# If true the discriminator receives the generator output AND input latent space
IS_CONDITIONAL_DISC = True
# Network convolutional channels by resolution level
USE_DISC_net_conv_Channels = [16, 32, 64, 128, 256]
# Use weight norm constraint
USE_DISC_NORM_CONSTRAINT_SCALE = False
# Use mini batch std
USE_DISC_MINI_BATCH_STD = False

# Fail fast on an input size that cannot be halved exactly down to the lowest
# resolution: later cells derive every level's size by dividing by 2**k, and a
# remainder would silently truncate the volume.
_lowest_res_divisor = 2**(LOW_RES_POW-1)
assert all(dim % _lowest_res_divisor == 0 for dim in input_size), \
    'input_size must be divisible by 2**(LOW_RES_POW-1)'



### Training

In [None]:
# Full size of the stored dataset volumes (before any slicing)
DATASET_X_size, DATASET_Y_size, DATASET_Z_size = 128, 128, 256

# Sampling of the input FOV:
#   UNIFORM_SAMPLING = True  -> the input sample is sliced with uniform probability
#   UNIFORM_SAMPLING = False -> the cumulative density function stored in
#                               CDF_PATH controls the sampling
UNIFORM_SAMPLING = False
CDF_PATH = "./DeepAttCorr_lib/cdf_coef.npy"


# Mini-batch and shuffle-buffer sizes, in volumes
BATCH_SIZE_TRAIN, BUFFER_SIZE_TRAIN = 4, 4
BATCH_SIZE_VALIDATION, BUFFER_SIZE_VALIDATION = 4, 4


# Initial learning rates (generator, discriminator)
step_size_gen, step_size_disc = 0.0001, 0.0005

# Per-step training functions for the conditional 3D GAN
TRAINING_FUNCTION_DISC = (
    GAN.train_support.train_step_discriminator_conditional_3D_GAN_tf)
TRAINING_FUNCTION_GEN = (
    GAN.train_support.train_step_generator_conditional_3D_GAN_tf)

# Training steps of each resolution phase, indexed by resolution level:
# full models ...
STEPS_PER_PHASE = [1501, 3001, 6001, 8001, 18001]
# ... and Fade-In models
STEPS_PER_PHASE_FADEIN = [1, 1001, 2001, 4001, 5001]

# Reporting / checkpoint periods, in training steps
STEPS_PER_PRINT, STEPS_PER_PLOTS, STEPS_PER_SAVE = 100, 500, 500

# Discriminator updates per generator update: steps_disc_per_gen_ini on the
# first step of a phase, steps_disc_per_gen_loop afterwards
steps_disc_per_gen_ini = 500
steps_disc_per_gen_loop = 5

# Set-up

In [None]:
# Create the output folders in order. The checkpoint and per-network TensorBoard
# folders get clear_folder=CLEAR_OUTS; the TensorBoard base folder uses the
# helper's default clear_folder behaviour.
for _folder_label, _folder_path, _folder_options in (
        ('CHECKPOINT_PATH', CHECKPOINT_PATH, {'clear_folder': CLEAR_OUTS}),
        ('TENSORBOARD_BASE_PATH', TENSORBOARD_BASE_PATH, {}),
        ('TENSORBOARD_OUT_PATH', TENSORBOARD_OUT_PATH, {'clear_folder': CLEAR_OUTS})):
    File_mng.check_create_path(_folder_label, _folder_path, **_folder_options)

In [None]:
# Sampling coefficients: the trivial [1.0] for uniform sampling, otherwise the
# cumulative density function coefficients stored at CDF_PATH.
cdf_coef = [1.0] if UNIFORM_SAMPLING else np.load(CDF_PATH)

In [None]:
# Multi-GPU mirrored strategy; devices=None lets TensorFlow pick every
# visible device.
strategy = tf.distribute.MirroredStrategy(devices=None)

### Dataset reader

In [None]:
# Create one input pipeline per resolution level.
# Index k of every list holds the data at full_resolution/2**k, i.e. the lists
# are ordered from the highest to the lowest resolution.
train_datasets_list = list()
validation_datasets_list = list()

train_datasets_dist_list = list()
validation_datasets_dist_list = list()

for idx_resol in range(LOW_RES_POW):

    # Down-sampling factor of this resolution level (exact integer arithmetic,
    # avoiding the float round-trip of int(a/b))
    resol_divisor = 2**idx_resol

    # Size of the samples fed to the networks at this level
    voxels_X_low_res = voxels_X//resol_divisor
    voxels_Y_low_res = voxels_Y//resol_divisor
    voxels_Z_low_res = voxels_Z//resol_divisor
    input_size_low_res = (voxels_X_low_res,voxels_Y_low_res,voxels_Z_low_res)
    input_size_this = input_size_low_res
    print('Setting up dataset for resolution: (%d,%d,%d)'%input_size_low_res)

    # Size of the stored (un-sliced) volumes at this level
    data_size = np.array([DATASET_X_size//resol_divisor,
                          DATASET_Y_size//resol_divisor,
                          DATASET_Z_size//resol_divisor])

    train_dataset_name = 'Train_Dataset_%dx%dx%d.tfrecord'%tuple(data_size)
    validation_dataset_name = 'Validation_Dataset_%dx%dx%d.tfrecord'%tuple(data_size)
    print('\tLoading train dataset: %s'%train_dataset_name)
    print('\tLoading validation dataset: %s'%validation_dataset_name)

    PATH_TFRECORD_TRAIN = os.path.join(DATASET_PATH, train_dataset_name)
    PATH_TFRECORD_VALIDATION = os.path.join(DATASET_PATH, validation_dataset_name)

    dataset_train_GAN = tf.data.TFRecordDataset(PATH_TFRECORD_TRAIN)
    dataset_validation_GAN = tf.data.TFRecordDataset(PATH_TFRECORD_VALIDATION)

    # Small (low resolution) datasets are cached in memory, raw records only
    if voxels_X_low_res <= 32:
        dataset_train_GAN = dataset_train_GAN.cache()
        dataset_validation_GAN = dataset_validation_GAN.cache()
        print('\tUsing cache for dataset: %s'%train_dataset_name)

    # Train samples: read and crop/sample to input_size_this using the CDF coefficients
    dataset_train_GAN = dataset_train_GAN.map(lambda x: DH.tf_read_sample_file(x,
                                                                               data_size,
                                                                               input_size_this,
                                                                               not_transformed = True,
                                                                               cdf_sampler_coef=cdf_coef))

    # Validation samples: whole volumes
    dataset_validation_GAN = dataset_validation_GAN.map(lambda x: DH.tf_read_raw_sample(x, data_size))

    # Shuffle the train dataset and repeat it indefinitely
    dataset_train_GAN = dataset_train_GAN.shuffle(buffer_size=BUFFER_SIZE_TRAIN, reshuffle_each_iteration=True).repeat(-1)

    # Set batch size
    dataset_train_GAN = dataset_train_GAN.batch(batch_size=BATCH_SIZE_TRAIN)
    dataset_validation_GAN = dataset_validation_GAN.batch(batch_size=BATCH_SIZE_VALIDATION)

    train_datasets_list.append(dataset_train_GAN)
    validation_datasets_list.append(dataset_validation_GAN)

    # Distributed version for the multi-GPU strategy
    dist_dataset_train_GAN = strategy.experimental_distribute_dataset(dataset_train_GAN)
    dist_dataset_validation_GAN = strategy.experimental_distribute_dataset(dataset_validation_GAN)

    train_datasets_dist_list.append(dist_dataset_train_GAN)
    validation_datasets_dist_list.append(dist_dataset_validation_GAN)
    

# Model Creation -- Keras API

### Progressive Pix2Pix 3D Generator

In [None]:
with strategy.scope():

    # Create parameter structure
    param_generator = GAN.topologies.ProGAN_param_structure()

    param_generator.block_conv_channels = USE_GEN_net_conv_Channels
    param_generator.n_blocks = len(param_generator.block_conv_channels)
    # The generator input has the lowest training resolution,
    # full size / 2**(LOW_RES_POW-1), plus one channel. Integer division keeps
    # the shape integral; true division produced floats such as 16.0, which
    # are not valid Keras shape dimensions.
    gen_lowest_res_divisor = 2**(LOW_RES_POW-1)
    param_generator.latent_dim = (voxels_X//gen_lowest_res_divisor,
                                  voxels_Y//gen_lowest_res_divisor,
                                  voxels_Z//gen_lowest_res_divisor,1)
    param_generator.block_conv_layers = USE_GEN_net_conv_Layers

    param_generator.use_BatchNorm = USE_GEN_BATCH_NORM
    param_generator.use_PixelNorm = USE_GEN_PIXEL_NORM
    param_generator.use_He_scale = USE_GEN_HE_SCALLING
    param_generator.initializer_std = USE_GEN_INI_STD

    # Create the list of progressive growing models; entry [k] holds the
    # (full, Fade-In) model pair of resolution level k
    generator_model_list = GAN.topologies.define_3D_prog_Vnet_generator(param_generator)


In [None]:
# Layer summary of the Fade-In generator (slot 1) of resolution level 1
inspected_generator = generator_model_list[1][1]
inspected_generator.summary()

In [None]:
# Save all models topologies
for i in range(LOW_RES_POW):
    tf.keras.utils.plot_model(generator_model_list[i][0], to_file=os.path.join(CHECKPOINT_PATH,'ProgPix2Pix3D_generator_model_dilat_%d.png'%i), show_shapes=True, show_layer_names=True)
    tf.keras.utils.plot_model(generator_model_list[i][1], to_file=os.path.join(CHECKPOINT_PATH,'ProgPix2Pix3D_generator_model_dilat_%d_FadeIn.png'%i), show_shapes=True, show_layer_names=True)

### Conditional Progressive Critic

In [None]:
with strategy.scope():

    # Create parameter structure
    param_disc = GAN.topologies.Disc_param_structure()
    param_disc.conditional = IS_CONDITIONAL_DISC
    param_disc.block_conv_channels = USE_DISC_net_conv_Channels

    # Input shape at the lowest training resolution, full size /
    # 2**(LOW_RES_POW-1), plus one channel. Integer division keeps the shape
    # integral; true division produced floats such as 16.0, which are not
    # valid Keras shape dimensions.
    disc_lowest_res_divisor = 2**(LOW_RES_POW-1)
    param_disc.input_shape = (voxels_X//disc_lowest_res_divisor,
                              voxels_Y//disc_lowest_res_divisor,
                              voxels_Z//disc_lowest_res_divisor,1)

    param_disc.use_norm_contrain_scale = USE_DISC_NORM_CONSTRAINT_SCALE
    param_disc.use_minibatch_stdev = USE_DISC_MINI_BATCH_STD

    param_disc.n_blocks = len(param_disc.block_conv_channels)

    param_disc.downSampling_layers  = len(param_disc.block_conv_channels)

    # Create the list of progressive growing models; entry [k] holds the
    # (full, Fade-In) model pair of resolution level k
    disc_model_list = GAN.topologies.define_discriminator_3D(param_disc)
    

In [None]:
# Layer summary of the full discriminator (slot 0) of resolution level 1
inspected_discriminator = disc_model_list[1][0]
inspected_discriminator.summary()

In [None]:
# Save all models topologies
for i in range(LOW_RES_POW):
    tf.keras.utils.plot_model(disc_model_list[i][0], to_file=os.path.join(CHECKPOINT_PATH,'ProgPix2Pix3D_discriminator_model_dilat_%d.png'%i), show_shapes=True, show_layer_names=True)
    tf.keras.utils.plot_model(disc_model_list[i][1], to_file=os.path.join(CHECKPOINT_PATH,'ProgPix2Pix3D_discriminator_model_dilat_%d_FadeIn.png'%i), show_shapes=True, show_layer_names=True)

# Train

In [None]:
# Create optimizers and metrics for training 

with strategy.scope():

    # Generator optimizer: Adam with beta_1=0, beta_2=0.99. `learning_rate` is
    # the supported keyword; the old `lr` alias is deprecated and no longer
    # accepted by Keras 3.
    optimizer_gen = keras.optimizers.Adam(learning_rate=step_size_gen, beta_1=0, beta_2=0.99, epsilon=1.0e-8)
    # Discriminator (critic) optimizer
    optimizer_disc = keras.optimizers.RMSprop(learning_rate=step_size_disc)

    # Running-mean loss metrics (the training loop re-creates them per phase)
    train_d1_loss = keras.metrics.Mean(name='train_d1_loss')
    train_d2_loss = keras.metrics.Mean(name='train_d2_loss')
    train_dgrad_loss = keras.metrics.Mean(name='train_dgrad_loss')

    train_g_loss = keras.metrics.Mean(name='train_g_loss')
    train_g_comp_loss = keras.metrics.Mean(name='train_g_comp_loss')


# Plot titles of the training-loss columns (decoration only)
titulos_use = ['disc. Real','disc. Fake','gen. Wasserstein','gen. MSE']

# Accumulators of the periodic console report
step_mean_d1_loss = 0
step_mean_d2_loss = 0
step_mean_dgrad_loss = 0
step_mean_g_loss = 0
step_mean_g_comp_loss = 0

# Resolution level at which training starts (0 = lowest resolution)
RESOL_INI = 0

# Create tensorboard writer
tb_writer = tf.summary.create_file_writer(TENSORBOARD_OUT_PATH)

In [None]:

# Progressive training: resolution levels run from the lowest (idx_resol == 0)
# to the highest. For each level the Fade-In model is trained first, then the
# full model. Each step runs several critic updates followed by one generator update.


def _plot_validation_history(axis_values, mean_values, std_values, titles):
    """Draw the validation-metric history of the current phase as error bars.

    Parameters
    ----------
    axis_values : array-like, shape (n,)
        Training step of every validation run.
    mean_values, std_values : np.ndarray, shape (n, len(titles))
        Per-run mean / standard deviation of each metric; column ``k`` is drawn
        on subplot ``k+1`` of a 3x2 grid with title ``titles[k]``.
    titles : sequence of str
        One subplot title per metric (at most 6).
    """
    plt.figure(dpi=100)
    for col, title in enumerate(titles):
        plt.subplot(3, 2, col+1)
        plt.errorbar(axis_values, mean_values[:, col], yerr=std_values[:, col], fmt='-o')
        plt.grid(True)
        plt.title(title)
    plt.tight_layout()
    plt.show()
    plt.close('all')


# Subplot titles of the validation history, in the column order of loss_val_plot
VALIDATION_PLOT_TITLES = ['Critic score', 'PSNR', 'ME', 'NMSE', 'NCC']

for idx_resol in range(RESOL_INI,LOW_RES_POW):

    # Set current resolution shape: level idx_resol is full size / 2**(LOW_RES_POW-idx_resol-1)
    phase_divisor = 2**(LOW_RES_POW-idx_resol-1)
    current_shape = [dim//phase_divisor for dim in input_size]
    current_shape_string = '%dx%dx%d'%(current_shape[0],current_shape[1],current_shape[2])

    # Select Dataset. The dataset lists are built from full resolution
    # downwards, so their index runs opposite to idx_resol.
    idx_dataset = LOW_RES_POW-idx_resol-1
    dist_dataset_train_GAN = train_datasets_dist_list[idx_dataset]
    dist_dataset_validation_GAN = validation_datasets_dist_list[idx_dataset]

    dataset_train_GAN = train_datasets_list[idx_dataset]
    dataset_validation_GAN = validation_datasets_list[idx_dataset]


    for idx_type_model in range(2):

        # idx_type_model == 0 -> Fade-In model (slot 1 of the model lists)
        # idx_type_model == 1 -> full model    (slot 0 of the model lists)
        # The first trained resolution has nothing to fade in from, so its
        # Fade-In phase is skipped.
        if idx_resol == RESOL_INI and idx_type_model == 0:
            continue
        FADE_IN = (idx_type_model == 0)
        disc_model = disc_model_list[idx_resol][1-idx_type_model]
        gen_model = generator_model_list[idx_resol][1-idx_type_model]

        # Get current number of steps
        if FADE_IN:
            STEPS_THIS_PHASE = STEPS_PER_PHASE_FADEIN[idx_resol]
        else:
            STEPS_THIS_PHASE = STEPS_PER_PHASE[idx_resol]

        print('---------------------------------------------------------------------')
        print('---------------------------------------------------------------------')
        print('\t\t Resolution %s'%current_shape_string)
        if FADE_IN:
            print('\t\t\t --- FADEIN ---')
        print('\t\t Trainig for %d steps'%STEPS_THIS_PHASE)
        print('---------------------------------------------------------------------')
        print('---------------------------------------------------------------------')

        # Per-phase history buffers. loss_plot gets one row per step; the
        # validation buffers get one row per validation run (every
        # STEPS_PER_PLOTS steps). The previous sizes multiplied the step count
        # by the reporting period, allocating hundreds of MB never written.
        # NOTE(review): assumes plot_loss only reads rows up to `step` — confirm.
        loss_plot = np.zeros((STEPS_THIS_PHASE,4))
        n_validation_rows = STEPS_THIS_PHASE//STEPS_PER_PLOTS + 1
        loss_val_plot = np.zeros((n_validation_rows,5))
        loss_std_val_plot = np.zeros((n_validation_rows,5))
        axis_val_plot = np.zeros((n_validation_rows,1))
        idx_plot = 0

        # Start each phase with empty report accumulators so a partial report
        # window from the previous phase cannot leak into this one
        step_mean_d1_loss = 0
        step_mean_d2_loss = 0
        step_mean_dgrad_loss = 0
        step_mean_g_loss = 0
        step_mean_g_comp_loss = 0

        # Fresh tf.function wrappers for this phase's models
        train_step_gen_tf_this = tf.function(TRAINING_FUNCTION_GEN)
        train_step_disc_tf_this = tf.function(TRAINING_FUNCTION_DISC)

        # Set tensorboard names
        tb_plots_namespace = 'Resolution_'+current_shape_string+'_cuts'
        tb_validation_metrics_namespace = 'Resolution_'+current_shape_string+'_Validation_metrics'
        tb_training_metrics_namespace = 'Resolution_'+current_shape_string+'_Training_metrics'
        if FADE_IN:
            tb_plots_namespace = tb_plots_namespace+'_fadeIn'
            tb_validation_metrics_namespace = tb_validation_metrics_namespace+'_fadeIn'
            tb_training_metrics_namespace = tb_training_metrics_namespace+'_fadeIn'


        with strategy.scope():

            train_d1_loss = keras.metrics.Mean(name='train_d1_loss')
            train_d2_loss = keras.metrics.Mean(name='train_d2_loss')
            train_dgrad_loss = keras.metrics.Mean(name='train_dgrad_loss')

            train_g_loss = keras.metrics.Mean(name='train_g_loss')
            train_g_comp_loss = keras.metrics.Mean(name='train_g_comp_loss')


        with tb_writer.as_default():

            step_ini = 0

            for step in range(step_ini, STEPS_THIS_PHASE):

                if step == 0:
                    steps_disc_per_gen = steps_disc_per_gen_ini
                else:
                    steps_disc_per_gen = steps_disc_per_gen_loop

                # ---------------------------------------------------------------------------------
                # --------------------------- TRAIN DISCRIMINATOR ---------------------------------
                # ---------------------------------------------------------------------------------
                idx_disc_train = 0
                disc_loss_d1_aux = 0
                disc_loss_d2_aux = 0
                disc_loss_dgrad_aux = 0
                for in_NAC_PET, in_LABELS, in_CT in dist_dataset_train_GAN:

                    train_step_disc_tf_this(in_NAC_PET,
                                            in_CT,
                                            gen_model,
                                            disc_model,
                                            train_d1_loss,
                                            train_d2_loss,
                                            train_dgrad_loss,
                                            optimizer_disc,
                                            strategy,
                                            K_grad=10.0,
                                            cross_sample_loss = False)

                    disc_loss_d1_aux += train_d1_loss.result()
                    disc_loss_d2_aux += train_d2_loss.result()
                    disc_loss_dgrad_aux += train_dgrad_loss.result()


                    print_l_real = disc_loss_d1_aux/(idx_disc_train+1)
                    print_l_fake = disc_loss_d2_aux/(idx_disc_train+1)
                    print_l_grad = disc_loss_dgrad_aux/(idx_disc_train+1)
                    print_l_tot = print_l_fake+print_l_real+print_l_grad
                    print('(Step: %d) Training Discriminator: %d/%d\tLoss: %.4f (Real:%.4f Fake:%.4f Grad.:%.4f)      '%(step,
                                                                                                                          idx_disc_train+1,
                                                                                                                          steps_disc_per_gen,
                                                                                                                          print_l_tot,
                                                                                                                          print_l_real,
                                                                                                                          print_l_fake,
                                                                                                                          print_l_grad), end='')
                    print('', end='\r')


                    idx_disc_train += 1
                    if idx_disc_train == steps_disc_per_gen:
                        last_fake_score_module = tf.abs(print_l_fake)
                        break

                with tf.name_scope(tb_training_metrics_namespace):
                    tf.summary.scalar("Disc. Real", print_l_real, step=step)
                    tf.summary.scalar("Disc. Fake", print_l_fake, step=step)
                    tf.summary.scalar("Disc. Grad", print_l_grad, step=step)
                    tf.summary.scalar("Disc. Full", print_l_tot, step=step)

                # ---------------------------------------------------------------------------------
                # --------------------------- TRAIN GENERATOR -------------------------------------
                # ---------------------------------------------------------------------------------
                # A single generator update on the next training batch
                for in_NAC_PET, in_LABELS, in_CT in dist_dataset_train_GAN:

                    print('(Step: %d) Training Generator...                                                                             '%(step), end='')
                    print('', end='\r')

                    train_step_gen_tf_this(in_NAC_PET,
                                           in_CT,
                                           gen_model,
                                           disc_model,
                                           train_g_loss,
                                           train_g_comp_loss,
                                           optimizer_gen,
                                           strategy,
                                           K_comp_loss = 1.0/100.0,
                                           norm_by_size = False)

                    break


                loss_plot[step,0] = disc_loss_d1_aux/steps_disc_per_gen
                loss_plot[step,1] = disc_loss_d2_aux/steps_disc_per_gen
                loss_plot[step,2] = train_g_loss.result()
                loss_plot[step,3] = train_g_comp_loss.result()


                with tf.name_scope(tb_training_metrics_namespace):
                    tf.summary.scalar("Gen. WGAN loss", train_g_loss.result(), step=step)
                    tf.summary.scalar("Gen. L2 loss", train_g_comp_loss.result(), step=step)


                train_d1_loss.reset_states()
                train_d2_loss.reset_states()
                train_dgrad_loss.reset_states()
                train_g_loss.reset_states()
                train_g_comp_loss.reset_states()

                step_mean_d1_loss += loss_plot[step,0]
                step_mean_d2_loss += loss_plot[step,1]
                step_mean_dgrad_loss += disc_loss_dgrad_aux/steps_disc_per_gen
                step_mean_g_loss += loss_plot[step,2]
                step_mean_g_comp_loss += loss_plot[step,3]


                # ---------------------------------------------------------------------------------
                # --------------------------- TRAIN METRICS ---------------------------------------
                # ---------------------------------------------------------------------------------

                if step%STEPS_PER_PRINT == 0:

                    # Step 0 reports the single first step; later reports
                    # average the last STEPS_PER_PRINT steps
                    if step != 0:
                        step_mean_d1_loss = step_mean_d1_loss/float(STEPS_PER_PRINT)
                        step_mean_d2_loss = step_mean_d2_loss/float(STEPS_PER_PRINT)
                        step_mean_g_loss = step_mean_g_loss/float(STEPS_PER_PRINT)
                        step_mean_g_comp_loss = step_mean_g_comp_loss/float(STEPS_PER_PRINT)
                        step_mean_dgrad_loss = step_mean_dgrad_loss/float(STEPS_PER_PRINT)

                    step_mean_dtot_loss = step_mean_d1_loss+step_mean_d2_loss+step_mean_dgrad_loss

                    print('>%d, d1=%.4f, d2=%.4f, dg=%.4f (d=%.4f) ; g=%.4f g_comp=%.4f ' % (step,
                                                                                             step_mean_d1_loss,
                                                                                             step_mean_d2_loss,
                                                                                             step_mean_dgrad_loss,
                                                                                             step_mean_dtot_loss,
                                                                                             step_mean_g_loss,
                                                                                             step_mean_g_comp_loss))

                    step_mean_d1_loss = 0
                    step_mean_d2_loss = 0
                    step_mean_dgrad_loss = 0
                    step_mean_g_loss = 0
                    step_mean_g_comp_loss = 0


                if step%STEPS_PER_PLOTS == 0 and step > 0:

                    GAN.train_support.plot_loss(loss_plot, step, forma='2_cols', dpi_use=100, titulos=titulos_use, skip_n_initial=10)

                    (PET_INPUT_images,
                     CT_INPUT_images,
                     LABELS_INPUT_images,
                     CT_SYNTH_images,
                     SEGMENTED_images,
                     SCORE_images,
                     PSNR_images,
                     ME_images,
                     NMSE_images,
                     NCC_images) = GAN.train_support.validate_whole_volume(gen_model,
                                                                           disc_model,
                                                                           dataset_validation_GAN,
                                                                           current_shape,
                                                                           segm_net = False)
                    print('')
                    with tf.name_scope(tb_plots_namespace):
                        GAN.train_support.plot_images3D_conditional(PET_INPUT_images,
                                                                    CT_INPUT_images,
                                                                    CT_SYNTH_images,
                                                                    segm_net = False,
                                                                    dpi_use = 150,
                                                                    add_tensorboard=True,
                                                                    epoch=step)

                    # Convert each validation metric to an array once; the
                    # order matches the columns of loss_val_plot
                    val_metric_arrays = [np.array(values) for values in (SCORE_images,
                                                                         PSNR_images,
                                                                         ME_images,
                                                                         NMSE_images,
                                                                         NCC_images)]
                    score_arr, psnr_arr, me_arr, nmse_arr, ncc_arr = val_metric_arrays

                    print('Metrics:                              ')
                    print('\tCritic Score:\t mean=%0.5f ; std=%0.5f'%(score_arr.mean(),score_arr.std()))
                    print('\tPSNR:\t\t mean=%0.5f ; std=%0.5f'%(psnr_arr.mean(),psnr_arr.std()))
                    print('\tME:\t\t mean=%0.5f ; std=%0.5f'%(me_arr.mean(),me_arr.std()))
                    print('\tNMSE:\t\t mean=%0.5f ; std=%0.5f'%(nmse_arr.mean(),nmse_arr.std()))
                    print('\tNCC:\t\t mean=%0.5f ; std=%0.5f'%(ncc_arr.mean(),ncc_arr.std()))


                    with tf.name_scope(tb_validation_metrics_namespace):
                        tf.summary.scalar("Critic_Score", score_arr.mean(), step=step)
                        tf.summary.scalar("PSNR", psnr_arr.mean(), step=step)
                        tf.summary.scalar("ME", me_arr.mean(), step=step)
                        tf.summary.scalar("NMSE", nmse_arr.mean(), step=step)
                        tf.summary.scalar("NCC", ncc_arr.mean(), step=step)


                    for col, metric_arr in enumerate(val_metric_arrays):
                        loss_val_plot[idx_plot, col] = metric_arr.mean()
                        loss_std_val_plot[idx_plot, col] = metric_arr.std()

                    axis_val_plot[idx_plot, 0] = step
                    idx_plot += 1

                    _plot_validation_history(axis_val_plot[:idx_plot,0],
                                             loss_val_plot[:idx_plot],
                                             loss_std_val_plot[:idx_plot],
                                             VALIDATION_PLOT_TITLES)


                tb_writer.flush()


                if step%STEPS_PER_SAVE == 0 and step > 0:
                    CHECKPOINT_PATH_NET_RESOLUTION = os.path.join(CHECKPOINT_PATH, 'Res_'+current_shape_string)
                    File_mng.check_create_path('CHECKPOINT_PATH_NET_RESOLUTION', CHECKPOINT_PATH_NET_RESOLUTION, clear_folder=False)

                    GAN.train_support.save_progressive_model(generator_model_list,
                                                             disc_model_list,
                                                             input_size,
                                                             NETWORK_NAME,
                                                             CHECKPOINT_PATH_NET_RESOLUTION,
                                                             0,
                                                             save_limit=idx_resol)