In [None]:
import os

# These switches are read by TensorFlow's native runtime when it starts up, so
# they have to be in the environment BEFORE tensorflow is imported for the first
# time (directly, or indirectly through the project modules below).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # hide TF C++ INFO/WARNING/ERROR logs
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"  # turn off oneDNN custom operations

import tensorflow as tf
from constants import __CHUNK_SAMPLE__, __DATASET_DENOISING_PATH__,__PREPROC_DATA_DENOISING_PATH__, __RESULT_DIR__
from dataloader import get_dataset
from models import cycleGAN, Inception_Unet,  Discriminator
from callbacks import setup_callbacks
import matplotlib.pyplot as plt

In [None]:
# Physical GPUs visible to TensorFlow (empty list on CPU-only machines).
gpus = tf.config.list_physical_devices('GPU')

# First GPU device, or None when no GPU is available. Binding the name in both
# cases avoids a NameError (or a stale value after a cell re-run) on CPU hosts.
gpu_name = gpus[0] if gpus else None

In [None]:
# --- Data split / training schedule ------------------------------------------
# Handed to get_dataset as `percentages`; presumably the train/val/test
# fractions in that order -- confirm against the dataloader.
__TRAIN_PCT__ = [0.80, 0.10, 0.10]
__BATCH_SIZE__ = 32
__N_EPOCHS__ = 100  # epochs to run on top of __INITIAL_EPOCHS__
__INITIAL_EPOCHS__ = 0  # forwarded to fit() as initial_epoch
__SHOW_MODELS_PLOT__ = True
__N_CHANNEL__ = 3  # n_channels for both generators and both discriminators

# --- Generator (Inception_Unet) settings -------------------------------------
gen_n_layers, gen_filters = 3, 16
gen_inception_layers = 0
gen_inception_receptive_fields = [5, 9, 13]
gen_traditional_block_kernel = 9
gen_recurrent_skip = gen_recurrent_exit = False
gen_exit_recurrent_units = None
variable_kernel_size_g = False

# --- Discriminator settings --------------------------------------------------
disc_n_layers, disc_filters, disc_kernel_size = 5, 8, 7
variable_kernel_size_d = False

# --- cycleGAN loss weights (cycle, identity, adversarial) --------------------
lambda_cycle, lambda_identity, lambda_adv = 10.0, 5.0, 1.0

In [None]:
# Loader arguments gathered in one place: source directory, split fractions
# and batch size.
dataset_options = {
    "data_dir": __DATASET_DENOISING_PATH__,
    "percentages": __TRAIN_PCT__,
    "batch_size": __BATCH_SIZE__,
}

# get_dataset yields three splits, unpacked here as train / validation / test.
train_ds, val_ds, test_ds = get_dataset(**dataset_options)

In [None]:
# Both generators use exactly the same architecture. Keeping the arguments in a
# single mapping means genA and genB cannot silently drift apart when a
# hyper-parameter is edited.
generator_kwargs = dict(
    signal_len=__CHUNK_SAMPLE__,
    n_channels=__N_CHANNEL__,
    n_layers=gen_n_layers,
    filters=gen_filters,
    inception_layers=gen_inception_layers,
    inception_receptive_fields=gen_inception_receptive_fields,
    traditional_block_kernel=gen_traditional_block_kernel,
    recurrent_skip=gen_recurrent_skip,
    recurrent_exit=gen_recurrent_exit,
    exit_recurrent_units=gen_exit_recurrent_units,
    variable_kernel_size=variable_kernel_size_g,
)

# genB is created before genA on purpose: auto-generated Keras model/layer
# names may depend on creation order, so keep this sequence stable.
genB = Inception_Unet(**generator_kwargs)
genA = Inception_Unet(**generator_kwargs)

# Same idea for the two discriminators: one shared argument set.
discriminator_kwargs = dict(
    signal_len=__CHUNK_SAMPLE__,
    n_channels=__N_CHANNEL__,
    n_layers=disc_n_layers,
    filters=disc_filters,
    kernel_size=disc_kernel_size,
    variable_kernel_size=variable_kernel_size_d,
)

discX = Discriminator(**discriminator_kwargs)
discY = Discriminator(**discriminator_kwargs)

# Wire the two generators and two discriminators into the cycleGAN wrapper,
# together with the weights of the cycle, identity and adversarial loss terms.
GANmodel = cycleGAN(generator_A=genA,
                    generator_B=genB,
                    discriminator_X=discX,
                    discriminator_Y=discY,
                    lambda_cycle=lambda_cycle,
                    lambda_identity=lambda_identity,
                    lambda_adv=lambda_adv
                    )

# compile() is called with no arguments, so optimizers/losses are whatever the
# cycleGAN class sets up by default.
GANmodel.compile()

# plot_model renders the architecture diagram to a file and does not draw
# through matplotlib, so toggling matplotlib's interactive mode cannot show or
# hide it. In a notebook the call returns an image object, which is only
# rendered when it is displayed explicitly; the PNG files are written either way.
plot_options = dict(show_shapes=True,
                    show_layer_names=True,
                    show_layer_activations=True,
                    show_trainable=False)

discriminator_plot = tf.keras.utils.plot_model(model=discX,
                                               to_file="model_discriminator.png",
                                               expand_nested=True,
                                               **plot_options)

generator_plot = tf.keras.utils.plot_model(model=genA,
                                           to_file="model_generator.png",
                                           expand_nested=False,
                                           **plot_options)

if __SHOW_MODELS_PLOT__:
    for model_plot in (discriminator_plot, generator_plot):
        # plot_model returns None when no notebook image object can be built.
        if model_plot is not None:
            # `display` is the builtin provided by the IPython/Jupyter kernel.
            display(model_plot)

In [None]:
# Output folders, all rooted at the shared results directory.
logs_dir = os.path.join(__RESULT_DIR__, "logs")
ckpts_dir = os.path.join(__RESULT_DIR__, "models")
plots_dir = os.path.join(__RESULT_DIR__, "plots")

# Legend entries forwarded to the plotting callback.
plot_labels = ["Chest", "Finger", "Pred"]

# Logging, checkpointing and denoising plots are all switched on.
# NOTE(review): ckpt_warmup=95 looks tied to the 100-epoch schedule
# (presumably no checkpoints before epoch 95) -- confirm in callbacks.py.
callbacks = setup_callbacks(
    model=GANmodel,
    validation_data=val_ds,
    save_logs=True,
    logs_dir=logs_dir,
    save_checkpoints=True,
    ckpts_dir=ckpts_dir,
    ckpt_monitor="val_loss_A",
    ckpt_mode="always",
    ckpt_warmup=95,
    save_plots_denoising=True,
    plots_dir=plots_dir,
    plots_num=1,
    plots_freq=1,
    plot_labels=plot_labels,
    clear_every=5,
)

In [None]:

# Train the cycleGAN.
# - `batch_size` is intentionally NOT passed: get_dataset already received the
#   batch size, so the datasets are presumably pre-batched, and Keras documents
#   that `batch_size` must not be specified for dataset/generator inputs.
# - Keras treats `epochs` as the index of the final epoch, not a count, hence
#   the sum: training runs __N_EPOCHS__ epochs starting at __INITIAL_EPOCHS__.
history = GANmodel.fit(train_ds,
                       validation_data=val_ds,
                       epochs=__N_EPOCHS__ + __INITIAL_EPOCHS__,
                       initial_epoch=__INITIAL_EPOCHS__,
                       callbacks=callbacks
                       )