In [None]:
# matplotlib plots within notebook
%matplotlib inline

# Report the Python interpreter version in use
import platform
print("python: "+platform.python_version())


import numpy as np
import matplotlib.pyplot as plt

import time

import os, shutil, sys


# Put the current working directory first on the module search path so the
# local DeepAttCorr_lib package is importable
sys.path.insert(0, '.')
from DeepAttCorr_lib import GAN_3D_lib as GAN
from DeepAttCorr_lib import data_handling as DH
from DeepAttCorr_lib import file_manage_utils as File_mng

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as K
# Report the TensorFlow version in use
print('Using TensorFlow version: '+tf.__version__)

# Print the devices TensorFlow can see on this machine
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

# Paths and Definitions

### Basics

In [None]:
# Identifier of this network; it also names the output folder and the saved model files
NETWORK_NAME = 'DeepAttCorr_3D_Unet_Network'

# Folder that holds the TFRecord datasets
DATASET_PATH = './datasets/'

# Folder that receives the checkpoints and figures of this network
CHECKPOINT_PATH = ''.join(["./Outputs/", NETWORK_NAME, "/"])

# Forwarded as `clear_folder` when the checkpoint folder is prepared in the set-up section
CLEAR_OUTS = True

### Network Parameters

In [None]:
# Input volume size (voxels along X, Y and Z)
voxels_X, voxels_Y, voxels_Z = 128, 128, 32
input_size = (voxels_X, voxels_Y, voxels_Z)

# Convolutional channels used at each resolution level of the network
USE_GEN_net_conv_Channels = [10, 20, 40, 80, 160]
# Convolutional layers used at each resolution level
USE_GEN_net_conv_Layers = 2

# Generator output activation: sigmoid selected, hyperbolic tangent disabled
USE_GEN_SIGMOID_OUT = True
USE_GEN_TANH_OUT = False

# Enable the segmentation path
USE_GEN_SEGMENTATION = True
# Number of target classes
USE_GEN_OBJECTIVE_SEGMENTATION_CLASES = 4
# Number of fully connected segmentation layers
USE_GEN_OBJECTIVE_SEGMENTATION_LAYERS = 2
# Number of convolutional segmentation layers
USE_GEN_OBJECTIVE_CONV_SEGMENTATION_LAYERS = 4
# Kernel size of the segmentation convolutions
USE_GEN_SEGM_KERNEL_SIZE = 3

# Generator normalization: pixel normalization on, batch normalization off
USE_GEN_PIXEL_NORM = True
USE_GEN_BATCH_NORM = False

# He scaling of the generator weights
USE_GEN_HE_SCALLING = True
# Standard deviation used for weight initialization
USE_GEN_INI_STD = 1.0

### Training

In [None]:
# Size of the full dataset volumes (before slicing)
DATASET_X_size, DATASET_Y_size, DATASET_Z_size = 128, 128, 256

# Training schedule
CICLES_TRAIN = 100       # number of train cycles
CICLES_PER_SAVE = 5      # cycles between model checkpoints
EPOCHS_PER_PLOTS = 10    # epochs run in each cycle, before plotting
STEPS_PER_EPOCH = 100    # steps per epoch

# Mini-batch and shuffle-buffer sizes
BATCH_SIZE_TRAIN, BUFFER_SIZE_TRAIN = 4, 4
BATCH_SIZE_VALIDATION, BUFFER_SIZE_VALIDATION = 4, 4

# Initial step size (optimizer learning rate)
step_size = 1.0e-3

# Sampling of the input FOV:
#   True  -> the input sample is sliced with uniform probability
#   False -> the Cumulative Density Function stored in CDF_PATH controls the sampling
UNIFORM_SAMPLING = False
CDF_PATH = "./DeepAttCorr_lib/cdf_coef.npy"

# Set-up

In [None]:
# Prepare the output folder before anything is written to it.
# NOTE(review): presumably creates CHECKPOINT_PATH when missing and, when clear_folder
# is True, removes its previous contents -- confirm in file_manage_utils.
File_mng.check_create_path('CHECKPOINT_PATH', CHECKPOINT_PATH, clear_folder=CLEAR_OUTS)

In [None]:
# Sampler coefficients passed to the dataset reader: the single-element default for
# uniform sampling, otherwise the array stored in CDF_PATH
cdf_coef = [1.0] if UNIFORM_SAMPLING else np.load(CDF_PATH)

### Dataset Reader

In [None]:
# Network input (patch) dimensions, forced to plain ints
shape_X, shape_Y, shape_Z = (int(n_voxels) for n_voxels in (voxels_X, voxels_Y, voxels_Z))
input_size_this = (shape_X, shape_Y, shape_Z)

# Full (unsliced) volume dimensions of the stored samples
data_size = np.array([int(n_voxels) for n_voxels in (DATASET_X_size, DATASET_Y_size, DATASET_Z_size)])

# The TFRecord file names embed the full volume size
train_dataset_name = 'Train_Dataset_%dx%dx%d.tfrecord' % tuple(data_size)
validation_dataset_name = 'Validation_Dataset_%dx%dx%d.tfrecord' % tuple(data_size)

# Open both record files
PATH_TFRECORD_TRAIN = os.path.join(DATASET_PATH, train_dataset_name)
PATH_TFRECORD_VALIDATION = os.path.join(DATASET_PATH, validation_dataset_name)

dataset_train = tf.data.TFRecordDataset(PATH_TFRECORD_TRAIN)
dataset_validation = tf.data.TFRecordDataset(PATH_TFRECORD_VALIDATION)

# Only inputs of at most 32 voxels along X are cached
if not shape_X > 32:
    dataset_train = dataset_train.cache()
    dataset_validation = dataset_validation.cache()
    print('Using cache for dataset: %s'%train_dataset_name)


def _decode_record(serialized_record):
    """Turn one serialized TFRecord entry into a Keras sample via DH.tf_get_keras_sample.

    Shared by the train and validation pipelines, which use identical arguments.
    NOTE(review): not_transformed=True is used for the train set as well, so the
    train samples appear to be read without transformations -- confirm that this
    is intended against DH.tf_get_keras_sample.
    """
    return DH.tf_get_keras_sample(serialized_record,
                                  data_size,
                                  input_size_this,
                                  not_transformed = True,
                                  cdf_sampler_coef=cdf_coef)


# Decode the records of both datasets
dataset_train = dataset_train.map(_decode_record)
dataset_validation = dataset_validation.map(_decode_record)

# Train pipeline: shuffle (reshuffled on every iteration), repeat indefinitely, then batch
dataset_train = (
    dataset_train
    .shuffle(buffer_size=BUFFER_SIZE_TRAIN, reshuffle_each_iteration=True)
    .repeat(-1)
    .batch(batch_size=BATCH_SIZE_TRAIN)
)

# Validation pipeline: batch only
dataset_validation = dataset_validation.batch(batch_size=BATCH_SIZE_VALIDATION)
 

# Model Creation -- Keras API

In [None]:
# Distribution strategy; the generator is created and compiled inside its scope below.
# MirroredStrategy performs synchronous training with the model replicated on the
# GPUs visible to TensorFlow.
strategy = tf.distribute.MirroredStrategy()

### Generator

In [None]:
with strategy.scope():

    param_gen = GAN.topologies.Gen_param_structure()

    # Network layout: channels and layers per resolution level, and the
    # single-channel input volume shape
    param_gen.block_conv_channels = USE_GEN_net_conv_Channels
    param_gen.block_conv_layers = USE_GEN_net_conv_Layers
    param_gen.n_blocks = len(param_gen.block_conv_channels)
    param_gen.latent_dim = (voxels_X, voxels_Y, voxels_Z) + (1,)

    # Remaining options, applied in the listed order.
    # NOTE(review): conv_segmentation_channels receives a constant named
    # *_CONV_SEGMENTATION_LAYERS -- confirm which quantity the topology expects.
    generator_settings = [
        # Output activation and segmentation path
        ('use_tanh_out', USE_GEN_TANH_OUT),
        ('use_sigmoid_out', USE_GEN_SIGMOID_OUT),
        ('segmentation_output', USE_GEN_SEGMENTATION),
        ('segmentation_classes', USE_GEN_OBJECTIVE_SEGMENTATION_CLASES),
        ('segmentation_layers', USE_GEN_OBJECTIVE_SEGMENTATION_LAYERS),
        ('conv_segmentation_channels', USE_GEN_OBJECTIVE_CONV_SEGMENTATION_LAYERS),
        ('segmentation_kernel_size', USE_GEN_SEGM_KERNEL_SIZE),
        # Normalization and weight initialization
        ('use_BatchNorm', USE_GEN_BATCH_NORM),
        ('use_PixelNorm', USE_GEN_PIXEL_NORM),
        ('use_He_scale', USE_GEN_HE_SCALLING),
        ('initializer_std', USE_GEN_INI_STD),
    ]
    for setting_name, setting_value in generator_settings:
        setattr(param_gen, setting_name, setting_value)

    # Create an instance of the model
    gen_model = GAN.topologies.define_3D_Vnet_generator(param_gen)

    

In [None]:
# Print the layer-by-layer summary of the generator
gen_model.summary()

In [None]:
# Write a diagram of the generator (with tensor shapes and layer names) into the
# checkpoint folder; the call stays last so the notebook displays its result
generator_plot_path = os.path.join(CHECKPOINT_PATH,'generator_model.png')
tf.keras.utils.plot_model(
    gen_model,
    to_file=generator_plot_path,
    show_shapes=True,
    show_layer_names=True,
)

### Compile model

In [None]:
# Compile the generator inside the distribution scope so the optimizer state is
# created under the same strategy as the model.
with strategy.scope():

    # `learning_rate` is the supported keyword; the former `lr` alias is deprecated
    # in tf.keras optimizers and rejected by newer Keras releases.
    gen_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=step_size,
                                        beta_1=0.9,
                                        beta_2=0.99,
                                        epsilon=1.0e-8),
        loss=GAN.losses.Vnet_compound_loss,
        metrics=[keras.metrics.mse],
    )


# Train!

In [None]:
# Base name of the saved model files (constant for the whole run)
FILENAME_SAVE_USE = NETWORK_NAME+'_%dx%dx%d'%(voxels_X,voxels_Y,voxels_Z)

# Each cycle trains for EPOCHS_PER_PLOTS epochs, shows validation images and,
# on the save schedule, writes the model to CHECKPOINT_PATH.
for idx_cicle in range(CICLES_TRAIN):

    gen_model.fit(dataset_train, steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS_PER_PLOTS)

    # NOTE(review): show_images is not defined in the visible part of this notebook;
    # confirm it is defined or imported before this cell runs on a fresh kernel.
    show_images(dataset_validation)

    # Release the figures so they do not accumulate across cycles
    plt.close('all')
    print('Cicle %d/%d'%((idx_cicle+1), CICLES_TRAIN))

    # Keep the original schedule (cycles 0, CICLES_PER_SAVE, ...) and additionally
    # save after the last cycle, so the final training state is never lost.
    is_last_cicle = (idx_cicle == CICLES_TRAIN - 1)
    if idx_cicle%CICLES_PER_SAVE == 0 or is_last_cicle:
        GAN.train_support.save_model(gen_model, CHECKPOINT_PATH, FILENAME_SAVE_USE)
    