# preparations

In [None]:
# interactive gpu session
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Cap this process at 80% of the GPU's memory and let TensorFlow grow its allocation on demand
# (allow_growth) instead of reserving that budget up front; the session is created with this config.
config = ConfigProto()
gpu_options = config.gpu_options
gpu_options.allow_growth = True
gpu_options.per_process_gpu_memory_fraction = 0.8
session = InteractiveSession(config=config)

In [None]:
# seed TensorFlow's and NumPy's global random generators for repeatable runs
from numpy.random import seed
from tensorflow.random import set_seed

set_seed(0)
seed(0)

# imports
from matplotlib import pyplot as plt
from keras.utils.vis_utils import plot_model
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.activations import tanh, sigmoid
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.models import Sequential, Model, save_model
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose
from tensorflow.keras.layers import Activation, BatchNormalization
from tensorflow.keras.layers import Concatenate, UpSampling2D, Add
from tensorflow.keras.layers import LeakyReLU, Dense, ZeroPadding2D
from tensorflow.keras.layers import Layer #, Dropout
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.image import sobel_edges
from tensorflow import squeeze, concat, stack, Variable
import tensorflow.keras.backend as K
import tensorflow as tf

from glob import glob
import numpy as np
import pywt

# import other notebooks
import import_ipynb
from data_feeder import DataFeeder
from utilities import visualize_training, get_start_time, get_formatted_elapsed_time

importing Jupyter notebook from data_feeder.ipynb
importing Jupyter notebook from utilities.ipynb


In [None]:
# general (de)convolution block
def conv_block(input, nr_filters, kernel_size=3, convolution=True, strides=1, batch_normalization=True, activation=True, use_bias=True):
  """
  (Transposed) convolution, optionally followed by batch normalization and a LeakyReLU.

  inputs:
   input                 Keras tensor to process
   nr_filters            number of output filters
   kernel_size           convolution kernel size
   convolution           True -> Conv2D, False -> Conv2DTranspose
   strides               convolution strides ('same' padding is always used)
   batch_normalization   append a BatchNormalization layer
   activation            append a LeakyReLU activation
   use_bias              whether the (transposed) convolution has a bias

  outputs:
   resulting Keras tensor
  """
  layer_cls = Conv2D if convolution else Conv2DTranspose
  result = layer_cls(
    nr_filters,
    kernel_size,
    padding='same',
    strides=strides,
    kernel_initializer=RandomNormal(stddev=0.02),
    use_bias=use_bias,
  )(input)

  if batch_normalization:
    result = BatchNormalization()(result)

  return LeakyReLU()(result) if activation else result

# U-Net generator

In [None]:
# U-Net generator blocks
def decoder_block(input, nr_filters):
  """
  U-Net decoder block: 4x4 stride-2 transposed convolution (BN, LReLU), then 3x3 convolution (BN, LReLU).

  inputs:
   input        Keras tensor from the previous decoder level
   nr_filters   number of filters for both convolutions

  outputs:
   upsampled Keras tensor
  """
  upsampled = conv_block(input, nr_filters, 4, convolution=False, strides=2)
  return conv_block(upsampled, nr_filters, 3)

def multi_scale_block(input, nr_filters):
  """
  Multi-scale branch: two 3x3 conv blocks (BN, LReLU) followed by a plain 3x3 convolution.

  inputs:
   input        Keras tensor taken from a decoder level
   nr_filters   number of filters (the generator's output channel count when used by UNet)

  outputs:
   Keras tensor without normalization or activation on the last convolution
  """
  features = input
  for _ in range(2):
    features = conv_block(features, nr_filters, 3)
  # linear last layer: its output is summed into the generator output in UNet.multi_scale_path
  return conv_block(features, nr_filters, 3, batch_normalization=False, activation=False)

# U-Net generator class
class UNet:
  """
  U-Net generator with optional multi-scale output branches.

  Built either from a DataFeeder (stand-alone U-Net: compiled with an MAE loss and trainable
  through `train`) or from an (input_shape, output_shape) pair (generator inside a GAN: not
  compiled, its `model` and `optimizer` are driven by the caller).
  """

  def __init__(self, input, depth=7, nr_filters=64, max_nr_filters=512, nr_add_scales=0):
    """
    inputs:
     input            input/output shapes (GAN training), DataFeeder object (U-Net training)
     depth            number of encoder blocks
     nr_filters       number of filters of the first block
     max_nr_filters   maximum number of filters
     nr_add_scales    number of multi-scale addition blocks, at most depth - 1

    outputs:
     UNet model based on input settings

    raises:
     ValueError       if nr_add_scales exceeds depth - 1 (the number of decoder blocks available)
    """
    # each multi-scale branch taps one of the depth - 1 decoder blocks; fail early with a clear message
    if nr_add_scales > depth - 1:
      raise ValueError(f"nr_add_scales ({nr_add_scales}) must be at most depth - 1 ({depth - 1})")

    # use generator as separate network
    if isinstance(input, DataFeeder):
      self.data_feeder = input
      input_shape = np.load(self.data_feeder.single_energy_dirs[0]).shape
      output_shape = np.load(self.data_feeder.dual_energy_dirs[0]).shape

    # use as generator in gan
    else:
      self.data_feeder = None
      input_shape = input[0]
      output_shape = input[1]

    # number of filters per encoder level: doubles each level, capped at max_nr_filters
    self.nr_filters_array = np.minimum(np.ones(depth) * nr_filters * 2 ** np.arange(depth), max_nr_filters).astype(np.int32)

    # per-batch training losses; empty until train() runs, so save()/plot_loss() work on an untrained model
    self.loss = np.array([])

    # reset seeds so weight initialization is repeatable
    set_seed(0)
    seed(0)

    # input (named model_input so it does not shadow the constructor argument)
    model_input = Input(shape=input_shape)

    # encoder
    encoder, encoder_list = self.encoder_path(model_input)

    # decoder, ending in output_shape[2] channels
    decoder, decoder_list = self.decoder_path(encoder, encoder_list, output_shape[2])

    # multi-scale output
    output = self.multi_scale_path(decoder, decoder_list, output_shape[2], nr_add_scales)

    # final model
    self.model = Model(inputs=model_input, outputs=output, name=f"unet_ms{nr_add_scales}")

    # adam optimizer for both U-Net and GAN training
    self.optimizer = Adam(learning_rate=2e-4, beta_1=0.9)

    # compile only for stand-alone training; in a GAN the gradients are applied manually
    if self.data_feeder is not None:
      self.model.compile(loss='mae', optimizer=self.optimizer)

  def train(self, nr_epochs=100):
    """
    Train the stand-alone U-Net with its compiled MAE loss, one augmented batch per step.

    inputs:
     nr_epochs        number of epochs, each consisting of data_feeder.nr_batches batches

    Per-batch losses are stored in self.loss as a numpy array. Only possible when the UNet was
    constructed from a DataFeeder; otherwise a message is printed and nothing is trained.
    """
    # check if model can be trained as U-Net network
    if self.data_feeder is None:
      print("not trainable, initialize with DataFeeder as input to compile and allow training")
      return

    # show training data examples
    self.data_feeder.show_data_examples()

    # save losses in list
    self.loss = []
    # print one progress '=' roughly every 1/30 of an epoch (at least one batch apart)
    progress_marker = int(max([self.data_feeder.nr_batches / 30, 1]))

    # train
    train_start_time = get_start_time()
    print(f"training for {nr_epochs} epochs with {self.data_feeder.nr_batches} batches per epoch")
    for epoch_nr in range(1, nr_epochs + 1):
      print(f"epoch {epoch_nr} ", end="")
      batch_start_time = get_start_time()
      for batch_nr in range(self.data_feeder.nr_batches):
        # get source and target images
        se, de_real = self.data_feeder.load_augmented_batch()

        # train batch
        self.loss.append(self.model.train_on_batch(se, de_real))

        # update progress bar based on number of batches
        if batch_nr % progress_marker == 0: print("=", end="")

      # print results of epoch (loss of the last batch)
      print(f" {self.loss[-1]:.3e} - {get_formatted_elapsed_time(batch_start_time)}")

      # visualize generator results; 0 is passed for the final epoch (more visualization at the end)
      visualize_training(self.data_feeder, self.model, epoch_nr if epoch_nr != nr_epochs else 0)

    # loss
    self.loss = np.array(self.loss)
    print(f"total training time: {get_formatted_elapsed_time(train_start_time)}")

  def save(self, save_dir, name_suffix = ""):
    """
    Save the Keras model to `{save_dir}/{model name}{name_suffix}` and the training losses
    as `model_loss.npy` inside that model directory.
    """
    # save model
    model_dir = f"{save_dir}/{self.model.name}{name_suffix}"
    save_model(self.model, model_dir)

    # save loss
    np.save(f"{model_dir}/model_loss.npy", self.loss)

  def plot_loss(self):
    """Plot the per-batch training loss in a new figure."""
    plt.figure()
    plt.plot(self.loss)

  def encoder_path(self, input):
    """
    Stack one stride-2 conv block per entry of nr_filters_array.

    outputs:
     (deepest encoder tensor, list of every encoder block output for the skip connections)
    """
    encoder = input
    encoder_list = []
    for nr_filters in self.nr_filters_array:
      encoder = conv_block(encoder, nr_filters, 4, strides=2) # encoder block
      encoder_list.append(encoder)
    return encoder, encoder_list

  def decoder_path(self, input, encoder_list, nr_output_channels):
    """
    Upsample back through the encoder levels with skip connections, then map to the output channels.

    outputs:
     (decoder output tensor with nr_output_channels channels, list of decoder block outputs
      ordered from deepest to shallowest)
    """
    decoder = input
    decoder_list = []
    for i in range(len(self.nr_filters_array) - 2, -1, -1):
      decoder = decoder_block(decoder, self.nr_filters_array[i]) # decoder block
      decoder_list.append(decoder)
      decoder = Concatenate(axis=3)([encoder_list[i], decoder]) # U-net skip connections
    decoder = conv_block(decoder, nr_output_channels, 4, convolution=False, strides=2, batch_normalization=False, activation=False)
    return decoder, decoder_list

  def multi_scale_path(self, input, decoder_list, nr_output_channels, nr_add_scales):
    """
    Add nr_add_scales bilinearly upsampled branches from the shallowest decoder blocks,
    then apply tanh to produce the generator output.
    """
    multi_scale = input
    for depth in range(1, nr_add_scales + 1):
      scale = 2 ** depth
      scaled = multi_scale_block(decoder_list[-depth], nr_output_channels) # multi-scale block
      scaled = UpSampling2D((scale, scale), interpolation="bilinear")(scaled)
      multi_scale = Add()([multi_scale, scaled]) # multiscale elementwise addition
    out = Activation(tanh)(multi_scale)
    return out

# PatchGAN discriminator

In [None]:
class SobelGradient(Layer):
  """
  Keras layer that appends Sobel gradients of its input as extra channels.

  For a (batch, h, w, c) input the output has 3 * c channels, concatenated on axis 3 as
  [inputs, first Sobel component, second Sobel component]. According to the
  tf.image.sobel_edges documentation those components are ordered [dy, dx].
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def call(self, inputs):
    gradients = sobel_edges(inputs)  # both Sobel components stacked on a trailing axis of size 2
    first_component = gradients[..., 0]
    second_component = gradients[..., 1]
    return concat([inputs, first_component, second_component], axis=3)

In [None]:
class Discriminator:
  """
  PatchGAN discriminator over (single-energy, dual-energy) image pairs.

  Both inputs are concatenated on the channel axis, optionally extended with Sobel gradients,
  passed through stride-2 conv blocks and a final Dense(1) applied at every spatial position,
  giving a grid of unnormalized scores (Gan's loss uses from_logits=True).
  """

  def __init__(self, input_shape, depth=4, nr_filters=64, max_nr_filters=512, sobel_gradient=False, condition_shape=None):
    """
    inputs:
     input_shape      shape of the judged (real or generated) image; also used for the
                      conditioning image unless condition_shape is given
     depth            number of encoder blocks
     nr_filters       number of filters of the first block
     max_nr_filters   maximum number of filters
     sobel_gradient   Sobel gradient layer on inputs
     condition_shape  shape of the conditioning (single-energy) image when it differs from
                      input_shape (e.g. other channel count); spatial size must match

    outputs:
     Discriminator model based on input settings
    """
    # number of filters per block: doubles each block, capped at max_nr_filters
    self.nr_filters_array = np.minimum(np.ones(depth) * nr_filters * 2 ** np.arange(depth), max_nr_filters).astype(np.int32)

    # reset seeds so weight initialization is repeatable
    set_seed(0)
    seed(0)

    # inputs: conditioning single-energy image and the (real or generated) dual-energy image
    se_input = Input(shape=input_shape if condition_shape is None else condition_shape)
    de_input = Input(shape=input_shape)

    # output
    output = self.discriminator_path([se_input, de_input], sobel_gradient)

    # build the model; it is not compiled because Gan applies self.optimizer to its gradients manually
    self.model = Model(inputs=[se_input, de_input], outputs=output, name=f"dis_sg{int(sobel_gradient)}")
    self.optimizer = Adam(learning_rate=2e-6, beta_1=0.9)

  def discriminator_path(self, inputs, sobel_gradient):
    """
    Concatenate the inputs on the channel axis, optionally append Sobel gradients,
    downsample with stride-2 conv blocks and score every remaining spatial position.
    """
    out = Concatenate(axis=3)(inputs)
    if sobel_gradient:
      out = SobelGradient()(out)
    for nr_filters in self.nr_filters_array:
      out = conv_block(out, nr_filters, 4, strides=2, use_bias=True) # encoder block

    # final dense layer (acts on the channel axis at every spatial position)
    out = Dense(1)(out)
    return out

# GAN

In [None]:
class Gan:
  """
  Conditional GAN pairing a UNet generator with a PatchGAN Discriminator.

  The generator is trained on adversarial loss + gen_lambda * L1 loss, the discriminator on
  real/fake binary cross-entropy; both are updated every batch in gradient_tape.
  """

  def __init__(self, data_feeder, gen_nr_add_scales=0, dis_sobel_gradient=False):
    """
    inputs:
     data_feeder          DataFeeder object
     gen_nr_add_scales    number of multi-scale blocks for the generator
     dis_sobel_gradient   Sobel gradient layer on discriminator inputs

    outputs:
     GAN model based on input settings
    """

    self.data_feeder = data_feeder
    input_shape = np.load(self.data_feeder.single_energy_dirs[0]).shape
    output_shape = np.load(self.data_feeder.dual_energy_dirs[0]).shape
    self.generator = UNet((input_shape, output_shape), nr_add_scales=gen_nr_add_scales)
    self.discriminator = Discriminator(output_shape, sobel_gradient=dis_sobel_gradient)
    self.name = f"gan({self.generator.model.name}+{self.discriminator.model.name})"

    # training initialization
    self.loss_object = BinaryCrossentropy(from_logits=True)
    self.gen_lambda = Variable(200.0)

    # per-batch losses, one row per loss returned by gradient_tape; empty until train() runs,
    # so save()/plot_losses() also work on an untrained GAN
    self.losses = np.empty((6, 0))

  def train(self, nr_epochs=100):
    """
    Train generator and discriminator jointly, one augmented batch per step.

    inputs:
     nr_epochs        number of epochs, each consisting of data_feeder.nr_batches batches

    Afterwards self.losses is a numpy array with one row per loss (order as in gradient_tape's
    return value) and one column per batch.
    """
    # show training data examples
    self.data_feeder.show_data_examples()

    # initialize/initial calculations
    # print one progress '=' roughly every 1/30 of an epoch (at least one batch apart)
    progress_marker = int(max([self.data_feeder.nr_batches / 30, 1]))
    self.losses = []

    # train
    train_start_time = get_start_time()
    print(f"training for {nr_epochs} epochs with {self.data_feeder.nr_batches} batches per epoch")
    for epoch_nr in range(1, nr_epochs + 1):
      print(f"epoch {epoch_nr} ", end="")
      batch_start_time = get_start_time()
      for batch_nr in range(self.data_feeder.nr_batches):
        # train with gradient tape step and save losses
        self.train_step_tape()

        # update progress bar based on number of batches
        if batch_nr % progress_marker == 0: print("=", end="")

      # print losses of the last batch: discriminator (total, real, fake) | generator (total, adversarial, L1)
      print(f" - {self.losses[-1][0]:.3f} {self.losses[-1][1]:.3f} {self.losses[-1][2]:.3f} | {self.losses[-1][3]:.4f} {self.losses[-1][4]:.3f} {self.losses[-1][5]:.2e} - {get_formatted_elapsed_time(batch_start_time)}")

      # visualize generator results; 0 is passed for the final epoch
      visualize_training(self.data_feeder, self.generator.model, epoch_nr if epoch_nr != nr_epochs else 0)

    # loss lists to numpy arrays (rows = loss types, columns = batches)
    self.losses = np.array(self.losses).T
    print(f"total training time: {get_formatted_elapsed_time(train_start_time)}")

  def train_step_tape(self):
    """Run one training step on a fresh augmented batch and record its six losses."""
    # get real images
    se, de_real = self.data_feeder.load_augmented_batch()

    # train
    losses = self.gradient_tape(se, de_real)

    # save losses
    self.losses.append(np.array([loss.numpy() for loss in losses]))

  def get_gen_loss(self, dis_fake_output, de_fake, de_real):
    """
    Generator loss.

    outputs:
     (adv_loss + gen_lambda * l1_loss, adversarial loss, L1/mean absolute error)
    """
    # adversarial loss: generator wants the discriminator to label fakes as real
    adv_loss = self.loss_object(tf.ones_like(dis_fake_output), dis_fake_output)

    # L1 loss/mean absolute error
    l1_loss = tf.reduce_mean(tf.abs(de_fake - de_real))

    # adversarial + lambda * L1
    total_gen_loss = adv_loss + self.gen_lambda * l1_loss

    return total_gen_loss, adv_loss, l1_loss

  def get_dis_loss(self, dis_real_output, dis_fake_output):
    """
    Discriminator loss.

    outputs:
     (real_loss + fake_loss, loss on real pairs labelled 1, loss on fake pairs labelled 0)
    """
    # real and fake loss
    real_loss = self.loss_object(tf.ones_like(dis_real_output), dis_real_output)
    fake_loss = self.loss_object(tf.zeros_like(dis_fake_output), dis_fake_output)

    # real + fake loss
    total_dis_loss = real_loss + fake_loss

    return total_dis_loss, real_loss, fake_loss

  @tf.function # tensorflow function for speed
  def gradient_tape(self, se, de_real):
    """
    One simultaneous generator/discriminator update.

    outputs:
     (dis_tot, dis_real, dis_fake, gen_tot, gen_adv, gen_l1) loss tensors
    """
    # use gradient tape
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
      # get fake images
      de_fake = self.generator.model(se, training=True)

      # discriminator output on real and fake images
      dis_real_output = self.discriminator.model([se, de_real], training=True)
      dis_fake_output = self.discriminator.model([se, de_fake], training=True)

      # get losses
      gen_tot_loss, gen_adv_loss, gen_l1_loss = self.get_gen_loss(dis_fake_output, de_fake, de_real)
      dis_tot_loss, dis_real_loss, dis_fake_loss = self.get_dis_loss(dis_real_output, dis_fake_output)

    # get gradients
    gen_gradients = gen_tape.gradient(gen_tot_loss, self.generator.model.trainable_variables)
    dis_gradients = dis_tape.gradient(dis_tot_loss, self.discriminator.model.trainable_variables)

    # apply gradients
    self.generator.optimizer.apply_gradients(zip(gen_gradients, self.generator.model.trainable_variables))
    self.discriminator.optimizer.apply_gradients(zip(dis_gradients, self.discriminator.model.trainable_variables))

    return dis_tot_loss, dis_real_loss, dis_fake_loss, gen_tot_loss, gen_adv_loss, gen_l1_loss

  def save(self, save_dir, name_suffix = ""):
    """
    Save the generator model to `{save_dir}/{self.name}{name_suffix}` and the training losses
    as `model_loss.npy` inside that model directory.
    """
    # save model (only the generator is saved)
    model_dir = f"{save_dir}/{self.name}{name_suffix}"
    save_model(self.generator.model, model_dir)

    # save losses
    np.save(f"{model_dir}/model_loss.npy", self.losses)

  def plot_losses(self):
    """Plot each of the six per-batch training losses in its own subplot of a 2 x 3 grid."""
    titles = ["total discriminator loss", "discriminator real loss", "discriminator fake loss", "total generator loss", "generator adverserial loss", "generator L1/MAE loss"]
    _, axs = plt.subplots(2, 3, figsize=(30, 20))
    for ax, loss, title in zip(axs.flat, self.losses, titles):
      ax.plot(loss)
      ax.set_title(title, fontsize=24)

# visualization

In [None]:
def visualize(plot_save_dir):
  """
  Print summaries and save architecture plots of example generator and discriminator variants.

  inputs:
   plot_save_dir    directory for the `<model name>_plot.png` files; created if it does not exist
  """
  import os

  # plot_model writes straight to the target file path, so make sure its directory exists
  os.makedirs(plot_save_dir, exist_ok=True)

  # define example input and output shapes (alternative example: (256, 256, 4))
  input_shape = (512, 512, 1)
  output_shape = (512, 512, 1)

  # UNet generators with/without multi-scale outputs and PatchGAN discriminators with/without
  # Sobel gradient inputs (built in this order, which fixes their auto-generated layer names)
  networks = [
    UNet((input_shape, output_shape)),
    UNet((input_shape, output_shape), nr_add_scales=3),
    Discriminator(input_shape),
    Discriminator(input_shape, sobel_gradient=True),
  ]

  # print model summaries and save model plots
  for network in networks:
    network.model.summary()
    plot_model(network.model, f"{plot_save_dir}/{network.model.name}_plot.png", show_shapes=True, expand_nested=True)

In [None]:
def main():
  """Entry point: print model summaries and save architecture plots to models/plots."""
  plot_dir = "models/plots"
  visualize(plot_dir)

# only run when executed as the main notebook, not when imported (e.g. via import_ipynb)
if __name__ == "__main__":
  main()

Model: "unet_ms0"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 512, 512, 1) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 256, 256, 64) 1088        input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 256, 256, 64) 256         conv2d[0][0]                     
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 256, 256, 64) 0           batch_normalization[0][0]        
___________________________________________________________________________________________