In [None]:
# matplotlib plots within notebook
%matplotlib inline

import platform
print("python: "+platform.python_version())


import numpy as np
import matplotlib.pyplot as plt

import time

import os, shutil, sys


sys.path.insert(0, '.')
from DeepAttCorr_lib import GAN_3D_lib as GAN
from DeepAttCorr_lib import data_handling as DH
from DeepAttCorr_lib import file_manage_utils as File_mng

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as K
print('Using TensorFlow version: '+tf.__version__)

from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

# Paths and Definitions

### Basics

In [None]:
# Experiment name; used to build the output folders below.
NETWORK_NAME = 'DeepAttCorr_GAN_Network'

# Folder holding the train/validation TFRecord files.
DATASET_PATH = './datasets/'

# Output folder for checkpoints and architecture diagrams.
CHECKPOINT_PATH = f"./Outputs/{NETWORK_NAME}/"

# TensorBoard logs: one shared base folder, one sub-folder per network.
TENSORBOARD_BASE_PATH = "./TensorBoard_output"
TENSORBOARD_OUT_PATH = os.path.join(TENSORBOARD_BASE_PATH, NETWORK_NAME)

# Passed as `clear_folder` when the output folders are (re)created in the set-up cell.
CLEAR_OUTS = True

### Network Parameters

In [None]:
# ================================================================================
# Training-mode switches
# ================================================================================
# Pre_Train: when True, the generator is not built from scratch. Instead it is
# loaded from a checkpoint produced by the "./Train_Standard_Model.ipynb"
# notebook, stored in PRE_TRAINED_GENERATOR_PATH under the name
# PRE_TRAINED_GENERATOR_NAME.
Pre_Train = True
PRE_TRAINED_GENERATOR_PATH = './trained_models/Sup_Pre_Train_No_GAN_Loss/'
PRE_TRAINED_GENERATOR_NAME = 'Sup_Pre_Train_No_GAN_Loss'

# When True, the adversarial gradient is restricted to the conditional generator network.
RestrictedGrad = True
# When True, the segmentation network losses (DICE and L2) are used to regularize training.
SupervisedLoss = True

# Input volume size, in voxels.
voxels_X, voxels_Y, voxels_Z = 128, 128, 32
input_size = (voxels_X, voxels_Y, voxels_Z)

# ================================================================================
# Generator / segmentation network
# ================================================================================
USE_GEN_net_conv_Channels = [10, 20, 40, 80, 160]   # conv channels per resolution level
USE_GEN_net_conv_Layers = 2                         # conv layers per resolution level
USE_GEN_TANH_OUT = False                            # hyperbolic-tangent output
USE_GEN_SIGMOID_OUT = True                          # sigmoid output
USE_GEN_SEGMENTATION = True                         # enable the segmentation path
USE_GEN_OBJECTIVE_SEGMENTATION_CLASES = 4           # number of target classes
USE_GEN_OBJECTIVE_SEGMENTATION_LAYERS = 2           # fully connected segmentation layers
USE_GEN_OBJECTIVE_CONV_SEGMENTATION_LAYERS = 4      # convolutional segmentation layers
USE_GEN_SEGM_KERNEL_SIZE = 3                        # segmentation kernel size
USE_GEN_BATCH_NORM = False                          # batch normalization
USE_GEN_PIXEL_NORM = True                           # pixel normalization
USE_GEN_HE_SCALLING = True                          # He scaling of the weights
USE_GEN_INI_STD = 1.0                               # weight-initialization standard deviation

# ================================================================================
# Conditional generator
# ================================================================================
USE_COND_GEN_net_conv_Channels = [8, 8, 8, 8, 8]    # conv channels per resolution level
USE_COND_GEN_net_conv_Layers = 5                    # convolutional layers
USE_COND_GEN_TANH_OUT = False                       # hyperbolic-tangent output
USE_COND_GEN_SIGMOID_OUT = False                    # sigmoid output
USE_COND_GEN_BATCH_NORM = False                     # batch normalization
USE_COND_GEN_PIXEL_NORM = True                      # pixel normalization
USE_COND_GEN_HE_SCALLING = True                     # He scaling of the weights
USE_COND_GEN_INI_STD = 1.0                          # weight-initialization standard deviation

# ================================================================================
# Discriminator
# ================================================================================
# When True, the discriminator receives the generator output AND its input.
IS_CONDITIONAL_DISC = True
USE_DISC_net_conv_Channels = [4, 8, 16, 32, 64, 128, 256]  # conv channels per resolution level
USE_DISC_NORM_CONSTRAINT_SCALE = False              # weight-norm constraint
USE_DISC_MINI_BATCH_STD = False                     # mini-batch standard deviation layer



### Training

In [None]:
# ---- Full (un-sliced) size of the stored volumes ----
DATASET_X_size, DATASET_Y_size, DATASET_Z_size = 128, 128, 256

# ---- Mini-batch and shuffle-buffer sizes ----
BATCH_SIZE_TRAIN, BUFFER_SIZE_TRAIN = 4, 4
BATCH_SIZE_VALIDATION, BUFFER_SIZE_VALIDATION = 4, 4

# ---- Initial step sizes (learning rates) ----
step_size_gen = 1e-4
step_size_disc = 5e-4
step_size_segm = 1e-3

# ---- Training-step functions for each network ----
TRAINING_FUNCTION_DISC = GAN.train_support.train_step_discriminator_conditional_3D_GAN_tf
TRAINING_FUNCTION_GEN = GAN.train_support.train_step_generator_and_segmentator_conditional_3D_GAN_tf

# ---- Sampling of the input FOV ----
# True: slices are drawn with uniform probability.
# False: the cumulative density function stored at CDF_PATH controls the sampling.
UNIFORM_SAMPLING = False
CDF_PATH = "./DeepAttCorr_lib/cdf_coef.npy"

# ---- Schedule ----
STEPS_RUN = 100001              # total number of training steps
STEPS_PER_PRINT = 100           # steps between console summaries
STEPS_PER_PLOTS = 500           # steps between validation/plots
STEPS_PER_SAVE = 500            # steps between checkpoints
steps_disc_per_gen_ini = 500    # discriminator steps before the first generator step
steps_disc_per_gen_loop = 5     # discriminator steps per generator step afterwards



# Set-up

In [None]:
# Create (or check) every output folder. The per-run folders receive
# clear_folder=CLEAR_OUTS; the shared TensorBoard base folder uses the helper's default.
for _label, _path, _extra_kwargs in (
        ('CHECKPOINT_PATH', CHECKPOINT_PATH, {'clear_folder': CLEAR_OUTS}),
        ('TENSORBOARD_BASE_PATH', TENSORBOARD_BASE_PATH, {}),
        ('TENSORBOARD_OUT_PATH', TENSORBOARD_OUT_PATH, {'clear_folder': CLEAR_OUTS})):
    File_mng.check_create_path(_label, _path, **_extra_kwargs)

In [None]:
# Sampling CDF coefficients: a [1.0] placeholder for uniform sampling,
# otherwise the array stored at CDF_PATH.
cdf_coef = [1.0] if UNIFORM_SAMPLING else np.load(CDF_PATH)

In [None]:
# Multi-GPU mirrored distribution strategy; the models, optimizers and metrics
# below are created inside its scope and the datasets are distributed with it.
strategy = tf.distribute.MirroredStrategy()

### Dataset reader

In [None]:
# Build the train/validation tf.data pipelines.

# Size of the patches fed to the networks.
shape_X, shape_Y, shape_Z = (int(n_vox) for n_vox in (voxels_X, voxels_Y, voxels_Z))
input_size_this = (shape_X, shape_Y, shape_Z)

# Full (un-sliced) volume size; it is also encoded in the TFRecord file names.
data_size = np.array([DATASET_X_size, DATASET_Y_size, DATASET_Z_size])
size_tag = '%dx%dx%d' % (data_size[0], data_size[1], data_size[2])

train_dataset_name = 'Train_Dataset_%s.tfrecord' % size_tag
validation_dataset_name = 'Validation_Dataset_%s.tfrecord' % size_tag

PATH_TFRECORD_TRAIN = os.path.join(DATASET_PATH, train_dataset_name)
PATH_TFRECORD_VALIDATION = os.path.join(DATASET_PATH, validation_dataset_name)

dataset_train_GAN = tf.data.TFRecordDataset(PATH_TFRECORD_TRAIN)
dataset_validation_GAN = tf.data.TFRecordDataset(PATH_TFRECORD_VALIDATION)

# Cache the raw records only for small patch sizes (presumably so they fit in memory).
if shape_X <= 32:
    dataset_train_GAN = dataset_train_GAN.cache()
    dataset_validation_GAN = dataset_validation_GAN.cache()
    print('Using cache for dataset: %s' % train_dataset_name)

# Train: decode + slice each volume (sampling controlled by cdf_coef), then
# shuffle, repeat forever and batch.
dataset_train_GAN = (
    dataset_train_GAN
    .map(lambda record: DH.tf_read_sample_file(record,
                                               data_size,
                                               input_size_this,
                                               not_transformed=True,
                                               cdf_sampler_coef=cdf_coef))
    .shuffle(buffer_size=BUFFER_SIZE_TRAIN, reshuffle_each_iteration=True)
    .repeat(-1)
    .batch(batch_size=BATCH_SIZE_TRAIN)
)

# Validation: whole volumes, batched.
dataset_validation_GAN = (
    dataset_validation_GAN
    .map(lambda record: DH.tf_read_raw_sample(record, data_size))
    .batch(batch_size=BATCH_SIZE_VALIDATION)
)

# Distribute both pipelines across the replicas of the strategy.
dist_dataset_train_GAN = strategy.experimental_distribute_dataset(dataset_train_GAN)
dist_dataset_validation_GAN = strategy.experimental_distribute_dataset(dataset_validation_GAN)



# Model Creation -- Keras API

### Segmentation V-Net

In [None]:
with strategy.scope():

    if Pre_Train:
        # Reload the pre-trained generator. The project's custom layers must be
        # registered so the saved model can be deserialized.
        custom_layers_dict = {
            'HeScale': GAN.layers.HeScale,
            'BiasLayer': GAN.layers.BiasLayer,
            'PixelNormalization': GAN.layers.PixelNormalization,
        }
        segm_model = GAN.train_support.load_model(PRE_TRAINED_GENERATOR_PATH,
                                                  PRE_TRAINED_GENERATOR_NAME,
                                                  custom_obj_dict=custom_layers_dict)
    else:
        # Build the 3D V-Net generator/segmentator from the configuration cell.
        param_segm = GAN.topologies.Gen_param_structure()
        segm_settings = {
            'block_conv_layers': USE_GEN_net_conv_Layers,
            'block_conv_channels': USE_GEN_net_conv_Channels,
            'n_blocks': len(USE_GEN_net_conv_Channels),
            'latent_dim': (voxels_X, voxels_Y, voxels_Z, 1),

            'use_tanh_out': USE_GEN_TANH_OUT,
            'use_sigmoid_out': USE_GEN_SIGMOID_OUT,
            'segmentation_output': USE_GEN_SEGMENTATION,
            'segmentation_classes': USE_GEN_OBJECTIVE_SEGMENTATION_CLASES,
            'segmentation_layers': USE_GEN_OBJECTIVE_SEGMENTATION_LAYERS,
            # NOTE(review): a "channels" field is fed from the "...CONV_SEGMENTATION_LAYERS"
            # constant; confirm which meaning Gen_param_structure expects.
            'conv_segmentation_channels': USE_GEN_OBJECTIVE_CONV_SEGMENTATION_LAYERS,
            'segmentation_kernel_size': USE_GEN_SEGM_KERNEL_SIZE,

            'use_BatchNorm': USE_GEN_BATCH_NORM,
            'use_PixelNorm': USE_GEN_PIXEL_NORM,
            'use_He_scale': USE_GEN_HE_SCALLING,
            'initializer_std': USE_GEN_INI_STD,
        }
        for attr_name, attr_value in segm_settings.items():
            setattr(param_segm, attr_name, attr_value)

        # Create the model instance.
        segm_model = GAN.topologies.define_3D_Vnet_generator(param_segm)


In [None]:
# Print a layer-by-layer summary of the segmentation network.
segm_model.summary()

In [None]:
# Save a diagram of the segmentation network (with tensor shapes) next to the checkpoints.
segm_model_png = os.path.join(CHECKPOINT_PATH, 'segmentator_model.png')
tf.keras.utils.plot_model(segm_model,
                          to_file=segm_model_png,
                          show_shapes=True,
                          show_layer_names=True)

### Conditional Generator Net

In [None]:
with strategy.scope():
    param_gen = GAN.topologies.Gen_param_structure()

    # Convolutional layout of the texture (conditional generator) network.
    param_gen.block_conv_layers = USE_COND_GEN_net_conv_Layers
    param_gen.block_conv_channels = USE_COND_GEN_net_conv_Channels
    param_gen.n_blocks = len(USE_COND_GEN_net_conv_Channels)

    # Input channels: the synthetic CT (1 channel), plus one channel per
    # segmentation class when the supervised segmentation output is used.
    n_input_channels = 1 + (USE_GEN_OBJECTIVE_SEGMENTATION_CLASES if SupervisedLoss else 0)
    param_gen.latent_dim = (voxels_X, voxels_Y, voxels_Z, n_input_channels)

    # Output activation.
    param_gen.use_tanh_out = USE_COND_GEN_TANH_OUT
    param_gen.use_sigmoid_out = USE_COND_GEN_SIGMOID_OUT

    # Normalization and weight initialization.
    param_gen.use_BatchNorm = USE_COND_GEN_BATCH_NORM
    param_gen.use_PixelNorm = USE_COND_GEN_PIXEL_NORM
    param_gen.use_He_scale = USE_COND_GEN_HE_SCALLING
    param_gen.initializer_std = USE_COND_GEN_INI_STD

    # Create the model instance.
    gen_model = GAN.topologies.define_3D_Convolutional_generator(param_gen, name_appendix='_Texturator')


In [None]:
# Print a layer-by-layer summary of the conditional (texture) generator.
gen_model.summary()

In [None]:
# Save a diagram of the texture generator (with tensor shapes) next to the checkpoints.
gen_model_png = os.path.join(CHECKPOINT_PATH, 'texturator_model.png')
tf.keras.utils.plot_model(gen_model,
                          to_file=gen_model_png,
                          show_shapes=True,
                          show_layer_names=True)

### Composed Generator Network

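This model simply chains the two previous networks: the segmentation V-Net followed by the conditional (texture) generator. Composing them into a single model makes it easier to control the gradient flow.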

In [None]:
with strategy.scope():

    # Input of the composed model: a PET volume with the same shape as the
    # segmentation network input (batch axis excluded).
    inputRefImage = keras.Input(shape=segm_model.input.shape[1:], name="Input_PET_GAN")

    # Affine map x -> x/2 + 0.5 (sends [-1, 1] to [0, 1]) into the range used by
    # the segmentation network.
    in_segm_range = (inputRefImage/2.0)+0.5

    # Apply the 3D V-Net (synthetic CT + segmentation channels).
    out_v_net_segm = segm_model(in_segm_range)

    # Apply the texture network.
    if SupervisedLoss:
        # The texture network sees the synthetic CT and the segmentation channels.
        out_texture_gan = gen_model(out_v_net_segm)
    else:
        # Keep only the first channel (synthetic CT). The size of the remaining
        # segmentation block is inferred (-1) rather than hard-coded, so this
        # works for any number of segmentation classes.
        out_texture, out_segm = tf.split(out_v_net_segm, [1, -1], axis=-1)
        out_texture_gan = gen_model(out_texture)

    comp_gen_model = keras.Model(inputRefImage, out_texture_gan, name='Composed_Generator')


In [None]:
# Print a layer-by-layer summary of the composed generator (V-Net + texture network).
comp_gen_model.summary()

In [None]:
# Save a diagram of the composed generator (with tensor shapes) next to the checkpoints.
comp_gen_model_png = os.path.join(CHECKPOINT_PATH, 'generator_model.png')
tf.keras.utils.plot_model(comp_gen_model,
                          to_file=comp_gen_model_png,
                          show_shapes=True,
                          show_layer_names=True)

### Conditional Discriminator Net

In [None]:
with strategy.scope():
    # Conditional critic configuration.
    param_disc = GAN.topologies.Disc_param_structure()
    param_disc.conditional = IS_CONDITIONAL_DISC
    param_disc.block_conv_channels = USE_DISC_net_conv_Channels

    # Number of pooling operations.
    LOW_RES_POW = len(USE_DISC_net_conv_Channels) - 3
    # Lowest resolution inside the convolutional discriminator, before the
    # fully connected layers.
    low_res_factor = 2 ** LOW_RES_POW
    voxels_X_low_res, voxels_Y_low_res, voxels_Z_low_res = (
        int(n_vox / low_res_factor) for n_vox in (voxels_X, voxels_Y, voxels_Z))
    # Input of the first block; only used to simplify the construction code.
    param_disc.input_shape = (voxels_X_low_res, voxels_Y_low_res, voxels_Z_low_res, 1)

    param_disc.use_norm_contrain_scale = USE_DISC_NORM_CONSTRAINT_SCALE
    param_disc.use_minibatch_stdev = USE_DISC_MINI_BATCH_STD

    param_disc.n_blocks = LOW_RES_POW + 1
    param_disc.downSampling_layers = len(param_disc.block_conv_channels)

    # define_discriminator_3D returns a list of models for different resolutions
    # (used by ProGAN in another notebook); keep the first model of the last entry.
    disc_model_list = GAN.topologies.define_discriminator_3D(param_disc)
    disc_model = disc_model_list[-1][0]
    


In [None]:
# Print a layer-by-layer summary of the discriminator (critic).
disc_model.summary()

In [None]:
# Save a diagram of the discriminator (with tensor shapes) next to the checkpoints.
disc_model_png = os.path.join(CHECKPOINT_PATH, 'discrimnator_model.png')
tf.keras.utils.plot_model(disc_model,
                          to_file=disc_model_png,
                          show_shapes=True,
                          show_layer_names=True)

# Train!

In [None]:
# Create optimizers and metrics for training.
with strategy.scope():

    # Use `learning_rate`, not the deprecated `lr` alias. Recent Keras
    # optimizers only warn about `lr` and fall back to the default rate, or
    # reject it outright.
    # NOTE(review): optimizer_segm is not passed to any training step in this notebook.
    optimizer_segm = keras.optimizers.Adam(learning_rate=step_size_segm, beta_1=0.9, beta_2=0.99, epsilon=1.0e-8)
    optimizer_gen = keras.optimizers.Adam(learning_rate=step_size_gen, beta_1=0, beta_2=0.99, epsilon=1.0e-8)
    optimizer_disc = keras.optimizers.RMSprop(learning_rate=step_size_disc)

    # Discriminator metrics: real score, fake score, gradient penalty.
    train_d1_loss = keras.metrics.Mean(name='train_d1_loss')
    train_d2_loss = keras.metrics.Mean(name='train_d2_loss')
    train_dgrad_loss = keras.metrics.Mean(name='train_dgrad_loss')

    # Generator metrics: adversarial and L2 (comparison) losses.
    train_g_loss = keras.metrics.Mean(name='train_g_loss')
    train_g_comp_loss = keras.metrics.Mean(name='train_g_comp_loss')

    # Segmentation metrics: DICE and L2 losses.
    train_s1_loss = keras.metrics.Mean(name='train_s1_loss')
    train_s2_loss = keras.metrics.Mean(name='train_s2_loss')

# History buffers.
# loss_plot is indexed by the training step, so it needs STEPS_RUN rows.
# The validation buffers are indexed by idx_plot, which grows by one every
# STEPS_PER_PLOTS steps, so one row per validation round (+1 slack) is enough.
n_val_rounds = STEPS_RUN // STEPS_PER_PLOTS + 1
loss_plot = np.zeros((STEPS_RUN, 6))
loss_val_plot = np.zeros((n_val_rounds, 5))
loss_std_val_plot = np.zeros((n_val_rounds, 5))
axis_val_plot = np.zeros((n_val_rounds, 1))
idx_plot = 0

# Compile the training steps as tf.functions.
train_step_gen_tf_this = tf.function(TRAINING_FUNCTION_GEN)
train_step_disc_tf_this = tf.function(TRAINING_FUNCTION_DISC)

# Titles of the loss_plot columns, used by plot_loss.
titulos_use = ['disc. Real','disc. Fake','gen. Wasserstein','gen. MSE','Segm. DICE','Segm. L2']

# Accumulators for the periodic console summary.
step_mean_d1_loss = 0
step_mean_d2_loss = 0
step_mean_dgrad_loss = 0
step_mean_g_loss = 0
step_mean_g_comp_loss = 0
step_mean_s1_loss = 0
step_mean_s2_loss = 0

# Create the TensorBoard writer.
tb_writer = tf.summary.create_file_writer(TENSORBOARD_OUT_PATH)

In [None]:
# Training loop: several critic (discriminator) updates per generator update,
# console/TensorBoard logging, periodic whole-volume validation and checkpoints.

# Columns of the validation history buffers (loss_val_plot / loss_std_val_plot):
# (console label, TensorBoard tag, plot title).
_VAL_METRICS = (
    ('\tCritic Score:\t', 'valid_disc_score', 'Critic score'),
    ('\tPSNR:\t\t', 'valid_PSNR', 'PSNR'),
    ('\tME:\t\t', 'valid_ME', 'ME'),
    ('\tNMSE:\t\t', 'valid_NMSE', 'NMSE'),
    ('\tNCC:\t\t', 'valid_NCC', 'NCC'),
)


def _log_segmentation_images(labels_onehot, segmented, n_classes):
    """Write reference and predicted segmentations to the active TensorBoard writer.

    Args:
        labels_onehot: reference labels, presumably one channel per class on the last axis.
        segmented: segmentation output, presumably one channel per class on the last axis.
        n_classes: number of class channels to log.
    """
    max_label = float(n_classes - 1)
    # Collapse the class axis into a label map (argmax) and restore a trailing
    # channel axis so the map can be logged as channel 0.
    labels_map = np.expand_dims(np.argmax(labels_onehot, axis=-1), axis=-1)
    GAN.train_support.add_tensorboard_3Dimage(labels_map, "LABELS", 0, channel=0, min_val=0.0, max_val=max_label)
    synth_map = np.expand_dims(np.argmax(segmented, axis=-1), axis=-1)
    GAN.train_support.add_tensorboard_3Dimage(synth_map, "LABELS_SYNTH", 0, channel=0, min_val=0.0, max_val=max_label)
    # Individual class channels of the reference and predicted segmentations.
    for c in range(n_classes):
        GAN.train_support.add_tensorboard_3Dimage(labels_onehot, "LABELS_INPUT_%d" % c, 0, channel=c)
    for c in range(n_classes):
        GAN.train_support.add_tensorboard_3Dimage(segmented, "SEGMENTED_%d" % c, 0, channel=c)


def _plot_validation_history(axis_vals, means, stds, n_points):
    """Plot mean +/- std of every validation metric recorded so far (3x2 grid).

    Args:
        axis_vals: (N, 1) array with the training step of each validation round.
        means, stds: (N, 5) arrays whose columns follow _VAL_METRICS.
        n_points: number of leading rows that hold valid data.
    """
    plt.figure(dpi=100)
    for col, (_, _, title) in enumerate(_VAL_METRICS):
        plt.subplot(3, 2, col + 1)
        plt.errorbar(axis_vals[:n_points, 0], means[:n_points, col], yerr=stds[:n_points, col], fmt='-o')
        plt.grid(True)
        plt.title(title)
    plt.tight_layout()
    plt.show()
    plt.close('all')


print('Discrimnator metrics:\n\t d1:\t  Real volume critic score.\n\t d2:\t  Fake volume critic score.\n\t dg:\t  Gradient penalty.\n\t d:\t  Total loss.')
print('Generator metrics:\n\t g:\t  Critic loss of generated sample (WGAN loss).\n\t g_comp:   L2 loss of generated sample.\n\t s1:\t  DICE segementation loss.\n\t s2:\t  L2 loss of segmentator net.')

# Use one persistent iterator over the (infinitely repeated) training pipeline.
# A `for ... in dist_dataset_train_GAN` loop would create a new iterator every
# time, restarting from the first TFRecord entries. With the small shuffle
# buffer, only the first few volumes would then ever be drawn.
train_iterator = iter(dist_dataset_train_GAN)

with tb_writer.as_default():
    for step in range(STEPS_RUN):

        disc_loss_d1_aux = 0
        disc_loss_d2_aux = 0
        disc_loss_dgrad_aux = 0

        # Longer critic warm-up on the very first step.
        if step == 0:
            steps_disc_per_gen = steps_disc_per_gen_ini
        else:
            steps_disc_per_gen = steps_disc_per_gen_loop

        # ---------------------------------------------------------------------------------
        # --------------------------- TRAIN DISCRIMINATOR ---------------------------------
        # ---------------------------------------------------------------------------------
        for idx_disc_train in range(steps_disc_per_gen):
            in_NAC_PET, in_LABELS, in_CT = next(train_iterator)

            train_step_disc_tf_this(in_NAC_PET,
                                    in_CT,
                                    comp_gen_model,
                                    disc_model,
                                    train_d1_loss,
                                    train_d2_loss,
                                    train_dgrad_loss,
                                    optimizer_disc,
                                    strategy,
                                    K_grad=10.0,
                                    cross_sample_loss = False)

            # NOTE(review): the metrics are only reset after the generator step,
            # so result() is a running mean and these sums average running means.
            disc_loss_d1_aux += train_d1_loss.result()
            disc_loss_d2_aux += train_d2_loss.result()
            disc_loss_dgrad_aux += train_dgrad_loss.result()

            print_l_real = disc_loss_d1_aux/(idx_disc_train+1)
            print_l_fake = disc_loss_d2_aux/(idx_disc_train+1)
            print_l_grad = disc_loss_dgrad_aux/(idx_disc_train+1)
            print_l_tot = print_l_fake+print_l_real+print_l_grad
            print('(Steps: %d) Training Discriminator: %d/%d\tLoss: %.4f (Real:%.4f Fake:%.4f Grad.:%.4f)      '%(step,
                                                                                                                  idx_disc_train+1,
                                                                                                                  steps_disc_per_gen,
                                                                                                                  print_l_tot,
                                                                                                                  print_l_real,
                                                                                                                  print_l_fake,
                                                                                                                  print_l_grad), end='')
            print('', end='\r')

        # Magnitude of the averaged fake-sample critic score. It scales the weight
        # of the generator's L2 (comparison) loss below.
        last_fake_score_module = tf.abs(print_l_fake)

        # ---------------------------------------------------------------------------------
        # --------------------------- TRAIN GENERATOR -------------------------------------
        # ---------------------------------------------------------------------------------
        in_NAC_PET, in_LABELS, in_CT = next(train_iterator)

        print('(Steps: %d) Training Generator...                                                                             '%(step), end='')
        print('', end='\r')

        train_step_gen_tf_this(in_NAC_PET,
                               in_CT,
                               in_LABELS,
                               comp_gen_model,
                               disc_model,
                               segm_model,
                               train_g_loss,
                               train_g_comp_loss,
                               train_s1_loss,
                               train_s2_loss,
                               optimizer_gen,
                               strategy,
                               K_comp_loss = last_fake_score_module*10.0,
                               K_comp_segm_loss = GAN.losses.k_coupling,
                               norm_by_size = False,
                               split_train = RestrictedGrad,
                               train_segm = SupervisedLoss)

        # ---------------------------------------------------------------------------------
        # --------------------------- TRAIN METRICS ---------------------------------------
        # ---------------------------------------------------------------------------------
        loss_plot[step,0] = disc_loss_d1_aux/steps_disc_per_gen
        loss_plot[step,1] = disc_loss_d2_aux/steps_disc_per_gen
        loss_plot[step,2] = train_g_loss.result()
        loss_plot[step,3] = train_g_comp_loss.result()
        loss_plot[step,4] = train_s1_loss.result()
        loss_plot[step,5] = train_s2_loss.result()

        tf.summary.scalar("disc_loss_d1", loss_plot[step,0], step=step)
        tf.summary.scalar("disc_loss_d2", loss_plot[step,1], step=step)
        tf.summary.scalar("train_g_loss", loss_plot[step,2], step=step)
        tf.summary.scalar("train_g_comp_loss", loss_plot[step,3], step=step)
        tf.summary.scalar("train_s1_loss", loss_plot[step,4], step=step)
        tf.summary.scalar("train_s2_loss", loss_plot[step,5], step=step)

        train_d1_loss.reset_states()
        train_d2_loss.reset_states()
        train_dgrad_loss.reset_states()
        train_g_loss.reset_states()
        train_g_comp_loss.reset_states()
        train_s1_loss.reset_states()
        train_s2_loss.reset_states()

        step_mean_d1_loss += loss_plot[step,0]
        step_mean_d2_loss += loss_plot[step,1]
        step_mean_dgrad_loss += disc_loss_dgrad_aux/steps_disc_per_gen
        step_mean_g_loss += loss_plot[step,2]
        step_mean_g_comp_loss += loss_plot[step,3]
        step_mean_s1_loss += loss_plot[step,4]
        step_mean_s2_loss += loss_plot[step,5]

        if step%STEPS_PER_PRINT == 0:

            # Average over the steps accumulated since the last summary (none before step 0).
            if step != 0:
                step_mean_d1_loss = step_mean_d1_loss/float(STEPS_PER_PRINT)
                step_mean_d2_loss = step_mean_d2_loss/float(STEPS_PER_PRINT)
                step_mean_g_loss = step_mean_g_loss/float(STEPS_PER_PRINT)
                step_mean_g_comp_loss = step_mean_g_comp_loss/float(STEPS_PER_PRINT)
                step_mean_s1_loss = step_mean_s1_loss/float(STEPS_PER_PRINT)
                step_mean_s2_loss = step_mean_s2_loss/float(STEPS_PER_PRINT)
                step_mean_dgrad_loss = step_mean_dgrad_loss/float(STEPS_PER_PRINT)

            step_mean_dtot_loss = step_mean_d1_loss+step_mean_d2_loss+step_mean_dgrad_loss

            print('>%d, d1=%.4f, d2=%.4f, dg=%.4f (d=%.4f) ; g=%.4f g_comp=%.4f s_1=%.4f s_2=%.4f ' % (step,
                                                                                                       step_mean_d1_loss,
                                                                                                       step_mean_d2_loss,
                                                                                                       step_mean_dgrad_loss,
                                                                                                       step_mean_dtot_loss,
                                                                                                       step_mean_g_loss,
                                                                                                       step_mean_g_comp_loss,
                                                                                                       step_mean_s1_loss,
                                                                                                       step_mean_s2_loss))

            step_mean_d1_loss = 0
            step_mean_d2_loss = 0
            step_mean_dgrad_loss = 0
            step_mean_g_loss = 0
            step_mean_g_comp_loss = 0
            step_mean_s1_loss = 0
            step_mean_s2_loss = 0

        # ---------------------------------------------------------------------------------
        # --------------------------- VALIDATION ------------------------------------------
        # ---------------------------------------------------------------------------------
        if step%STEPS_PER_PLOTS == 0 and step > 0:

            GAN.train_support.plot_loss(loss_plot, step, forma='2_cols', dpi_use=100, titulos=titulos_use, skip_n_initial=10)

            (PET_INPUT_images,
             CT_INPUT_images,
             LABELS_INPUT_images,
             CT_SYNTH_images,
             SEGMENTED_images,
             SCORE_images,
             PSNR_images,
             ME_images,
             NMSE_images,
             NCC_images) = GAN.train_support.validate_whole_volume(comp_gen_model,
                                                                   disc_model,
                                                                   dataset_validation_GAN,
                                                                   [shape_X,shape_Y,shape_Z],
                                                                   segm_net = SupervisedLoss,
                                                                   s_model=segm_model,
                                                                   single_image = False)

            # NOTE(review): the third positional argument (0) is kept as-is;
            # confirm whether it is a TensorBoard step or a sample index.
            GAN.train_support.add_tensorboard_3Dimage(PET_INPUT_images, "PET_INPUT", 0, channel=0, min_val=-1.0, max_val=1.0)
            GAN.train_support.add_tensorboard_3Dimage(CT_INPUT_images, "CT_INPUT", 0, channel=0, min_val=-1.0, max_val=1.0)
            GAN.train_support.add_tensorboard_3Dimage(CT_SYNTH_images, "CT_SYNTH", 0, channel=0, min_val=-1.0, max_val=1.0)

            if SupervisedLoss:
                # SEGMENTED_images is not modified here, so the plot below
                # receives the raw network output.
                _log_segmentation_images(LABELS_INPUT_images, SEGMENTED_images, USE_GEN_OBJECTIVE_SEGMENTATION_CLASES)

            print('')
            GAN.train_support.plot_images3D_conditional(PET_INPUT_images,
                                                       CT_INPUT_images,
                                                       CT_SYNTH_images,
                                                       segm_net = SupervisedLoss,
                                                       X_segm=SEGMENTED_images,
                                                       X_real_segm=LABELS_INPUT_images,
                                                       dpi_use = 150)

            print('Metrics:                              ')
            metric_values = (SCORE_images, PSNR_images, ME_images, NMSE_images, NCC_images)
            for col, ((console_label, tb_tag, _plot_title), raw_values) in enumerate(zip(_VAL_METRICS, metric_values)):
                values_arr = np.array(raw_values)
                loss_val_plot[idx_plot, col] = values_arr.mean()
                loss_std_val_plot[idx_plot, col] = values_arr.std()
                print('%s mean=%0.5f ; std=%0.5f' % (console_label, values_arr.mean(), values_arr.std()))
                # Log at the current training step so TensorBoard keeps the history.
                tf.summary.scalar(tb_tag, values_arr.mean(), step=step)

            axis_val_plot[idx_plot, 0] = step
            idx_plot += 1

            _plot_validation_history(axis_val_plot, loss_val_plot, loss_std_val_plot, idx_plot)

            tb_writer.flush()

        if step%STEPS_PER_SAVE == 0 and step > 0:
            GAN.train_support.save_multiple_models([disc_model,comp_gen_model,segm_model],
                                                         ['discrimnator','compound_generator','segmentatior'],
                                                         CHECKPOINT_PATH,
                                                         NETWORK_NAME,
                                                         name_prefix='_%d'%step)