In [None]:
# Load the python-dotenv IPython extension, then read a `.env` file into the
# process environment.
# NOTE(review): no visible cell reads os.environ directly; presumably TF/TFDS
# settings (cache dirs, credentials) are supplied this way — confirm.
%load_ext dotenv
%dotenv

In [None]:
import tensorflow as tf
import numpy as np
import time
import matplotlib.pyplot as plt
import cv2
from tensorflow.keras.utils import to_categorical  # Only for categorical one hot encoding
from tensorflow.keras import layers
from sklearn.metrics import accuracy_score
import tensorflow_datasets as tfds
from tensorflow.keras import backend as K
import tensorboard
import os
import tf_keras as tfk
import keras
from tensorflow_datasets.core.utils import gcs_utils
# Turn off TFDS's Google Cloud Storage integration (presumably so datasets are
# downloaded/prepared locally instead of fetched from GCS).
# NOTE(review): `_is_gcs_disabled` is a private TFDS attribute and may change between versions.
gcs_utils._is_gcs_disabled = True

In [None]:
# Start each run with a clean TensorBoard log tree.
# NOTE(review): this removes *all* of ./logs/, not only this model's TRAIN_LOGDIR.
!rm -rf ./logs/
MODEL_NAME = 'A01_PROGAN'
# Default `modeldir` for trainGan's per-resolution model checkpoints.
MODEL_PATH = os.path.join('models', MODEL_NAME)
TRAIN_LOGDIR = os.path.join("logs", "tensorflow", MODEL_NAME, 'train_data') # Sets up a log directory.
# Creates a file writer for the log directory (used for all TensorBoard summaries below).
file_writer = tf.summary.create_file_writer(TRAIN_LOGDIR)

In [None]:
# CelebA training split with shuffled file order; in the visible cells it is only measured with len().
celeba = tfds.load('celeb_a', shuffle_files=True, split='train')

In [None]:
# LFW training split with shuffled file order; this is the dataset passed to trainGan at the end.
lfw = tfds.load('lfw', shuffle_files=True, split='train')

In [None]:
# Element count of each loaded split.
(len(lfw), len(celeba))

In [None]:
# NOTE(review): scratch arithmetic (= 13200); its purpose is not evident from the notebook — TODO explain or remove.
220*60

In [None]:
def plotImages(imgs, cols=8):
    """Display a batch of images on a grid and show the figure.

    Args:
      imgs: a batch indexable as ``imgs[i]`` -> (H, W, C), with values presumably
        in [-1, 1] (they are mapped back to pixels with ``* 127.5 + 127.5``).
      cols: number of grid columns. The grid is at least ``cols x cols`` (8x8 by
        default, the original layout) and grows extra rows for larger batches
        instead of raising from ``plt.subplot`` once more than 64 images are passed.
    """
    n = int(imgs.shape[0])
    rows = max(cols, -(-n // cols))  # ceil(n / cols), never fewer rows than columns
    fig = plt.figure(figsize=(8, 8 * rows / cols))

    for i in range(n):
      ax = fig.add_subplot(rows, cols, i + 1)
      # Clip before the uint8 cast: out-of-range floats would otherwise wrap or
      # saturate unpredictably instead of showing as black/white.
      pixels = tf.clip_by_value(imgs[i] * 127.5 + 127.5, 0.0, 255.0)
      ax.imshow(tf.cast(pixels, tf.uint8))
      ax.axis('off')
    plt.show()

In [None]:
class EqualizeLearningRate(layers.Wrapper):
    """
    Wrapper intended to apply ProGAN-style equalized learning rate: the wrapped
    layer's kernel is meant to be scaled at runtime by the He constant
    ``1 / sqrt(fan_in)``.

    Adapted from the WeightNormalization implementation of TF Addons. Written
    for keras CNN and Dense layers (RNN not tested).

    NOTE(review): the He-constant multiplication in `call` is commented out, so
    as written the wrapped layer runs with its unscaled kernel and this wrapper
    is effectively a pass-through (`he_constant` is computed in `build` but not
    used). The model builders in the visible code do not use this class.

    ```python
      net = EqualizeLearningRate(
          layers.Conv2D(2, 2, activation='relu'))(x)
      net = EqualizeLearningRate(
          layers.Conv2D(16, 5, activation='relu'))(net)
      net = EqualizeLearningRate(
          layers.Dense(120, activation='relu'))(net)
      net = EqualizeLearningRate(
          layers.Dense(n_classes))(net)
    ```
    Arguments:
      layer: a layer instance.
      **kwargs: forwarded to `layers.Wrapper`.
    Raises:
      ValueError: If `Layer` does not contain a `kernel` of weights
    """

    def __init__(self, layer, **kwargs):
        super(EqualizeLearningRate, self).__init__(layer, **kwargs)
        # Register the wrapped layer as a trackable dependency named 'layer'.
        self._track_trackable(layer, name='layer')
        self.is_rnn = isinstance(self.layer, layers.RNN)

    def build(self, input_shape):
        """Build the wrapped layer if needed and compute the He constant from its kernel shape."""
        input_shape = tf.TensorShape(input_shape)
        # Leave the batch dimension unconstrained; the rest is fixed from the first build.
        self.input_spec = layers.InputSpec(
            shape=[None] + input_shape[1:])

        if not self.layer.built:
            self.layer.build(input_shape)

        # For RNNs the weights live on the cell, not the layer itself.
        kernel_layer = self.layer.cell if self.is_rnn else self.layer

        if not hasattr(kernel_layer, 'kernel'):
            raise ValueError('`EqualizeLearningRate` must wrap a layer that'
                             ' contains a `kernel` for weights')

        if self.is_rnn:
            kernel = kernel_layer.recurrent_kernel
        else:
            kernel = kernel_layer.kernel

        # He constant 1/sqrt(fan_in); currently not applied in `call` (see class docstring).
        self.fan_in, self.fan_out= self._compute_fans(kernel.shape)
        self.he_constant = tf.Variable(1.0 / np.sqrt(self.fan_in), dtype=tf.float32, trainable=False)

        # Handle to the wrapped layer's kernel variable.
        self.v = kernel
        self.built = True
    
    def call(self, inputs, training=True):
        """Reassign the wrapped layer's kernel, then run the wrapped layer on `inputs`."""
        # Equalized LR would multiply the kernel by the He constant here; that
        # scaling is commented out, so `kernel` is the raw variable.
        kernel = self.v #* self.he_constant
            
        if self.is_rnn:
            print(self.is_rnn)  # NOTE(review): leftover debug print
            self.layer.cell.recurrent_kernel = kernel
            update_kernel = tf.identity(self.layer.cell.recurrent_kernel)  # unused (control_dependencies below is commented out)
        else:
            self.layer.kernel = kernel
            # update_kernel = tf.identity(self.layer.kernel)

        # Ensure we calculate result after updating kernel.
        # with tf.control_dependencies([update_kernel]):
        outputs = self.layer(inputs)
        return outputs

    def compute_output_shape(self, input_shape):
        """Delegate output-shape inference to the wrapped layer."""
        return tf.TensorShape(
            self.layer.compute_output_shape(input_shape).as_list())
    
    def _compute_fans(self, shape, data_format='channels_last'):
        """
        From Official Keras implementation
        Computes the number of input and output units for a weight shape.
        Ranks other than 2-5 fall back to sqrt(prod(shape)) for both fans.
        # Arguments
            shape: Integer shape tuple.
            data_format: Image data format to use for convolution kernels.
                Note that all kernels in Keras are standardized on the
                `channels_last` ordering (even when inputs are set
                to `channels_first`).
        # Returns
            A tuple of scalars, `(fan_in, fan_out)`.
        # Raises
            ValueError: in case of invalid `data_format` argument.
        """
        if len(shape) == 2:
            fan_in = shape[0]
            fan_out = shape[1]
        elif len(shape) in {3, 4, 5}:
            # Assuming convolution kernels (1D, 2D or 3D).
            # TH kernel shape: (depth, input_depth, ...)
            # TF kernel shape: (..., input_depth, depth)
            if data_format == 'channels_first':
                receptive_field_size = np.prod(shape[2:])
                fan_in = shape[1] * receptive_field_size
                fan_out = shape[0] * receptive_field_size
            elif data_format == 'channels_last':
                receptive_field_size = np.prod(shape[:-2])
                fan_in = shape[-2] * receptive_field_size
                fan_out = shape[-1] * receptive_field_size
            else:
                raise ValueError('Invalid data_format: ' + data_format)
        else:
            # No specific assumptions.
            fan_in = np.sqrt(np.prod(shape))
            fan_out = np.sqrt(np.prod(shape))
        return fan_in, fan_out

# pixel-wise feature vector normalization layer
class PixelNormalization(layers.Layer):
    """Pixel-wise feature normalization (ProGAN).

    Scales each feature vector to unit RMS over the last (channel) axis:
    ``x / sqrt(mean(x**2, axis=-1) + 1e-8)``. Implemented with plain TF ops
    rather than ``keras.backend`` helpers (``K.mean``/``K.sqrt``), which Keras 3
    no longer exports. Same purpose as `PixelNorm`, which the model builders use.
    """
    def __init__(self, **kwargs):
        super(PixelNormalization, self).__init__(**kwargs)

    def call(self, inputs):
        # Mean of squared activations over channels, kept for broadcasting.
        mean_values = tf.reduce_mean(tf.square(inputs), axis=-1, keepdims=True)
        # Epsilon guards against division by zero for all-zero feature vectors.
        l2 = tf.sqrt(mean_values + 1.0e-8)
        return inputs / l2

    # The output has the same shape as the input.
    def compute_output_shape(self, input_shape):
        return input_shape

class PixelNorm(layers.Layer):
  """Pixel-wise feature normalization: ``x * rsqrt(mean(x**2, axis=-1) + epsilon)``.

  Args:
    epsilon: small constant added before the rsqrt to avoid division by zero.
    name: optional layer name (the builders use it, e.g. ``'final_4'``, to find
      the layer again with ``get_layer``).
    **kwargs: other base-``Layer`` arguments (``trainable``, ``dtype``...).
      Keras passes these back from ``get_config()`` when a saved model is
      loaded, so they must be accepted.
  """
  def __init__(self, epsilon=1e-8, name=None, **kwargs):
    super(PixelNorm, self).__init__(name=name, **kwargs)
    self.epsilon = epsilon

  def call(self, x):
    return x * tf.math.rsqrt(tf.reduce_mean(tf.math.square(x), axis=-1, keepdims=True) + self.epsilon)

  def get_config(self):
    config = super(PixelNorm, self).get_config()
    config['epsilon'] = self.epsilon
    return config

class FadeAdd(layers.Layer):
  """Fade-in blend ``new * alpha + old * (1 - alpha)`` used when growing a stage.

  Args:
    alpha: ``None`` (a fresh non-trainable variable starting at 0); an existing
      variable-like object with ``assign`` (shared as-is, so several layers can
      fade together); or a plain number used as the initial value. A number is
      what ``get_config`` stores, so a reloaded model resumes at the saved blend.
    **kwargs: base-``Layer`` arguments (``name``, ``trainable``, ``dtype``...),
      accepted so Keras can rebuild the layer from its config.
  """
  def __init__(self, alpha : tf.Variable = None, **kwargs):
    super(FadeAdd, self).__init__(**kwargs)
    if alpha is None:
      self.alpha = tf.Variable(initial_value=0., trainable=False)
    elif hasattr(alpha, 'assign'):
      self.alpha = alpha
    else:
      # Plain numbers (e.g. from a saved config) get their own variable;
      # storing the raw number would break assign() below.
      self.alpha = tf.Variable(initial_value=float(alpha), trainable=False)

  def incrementAlpha(self, step=0.1):
    """Advance alpha by ``step`` (Python number or tensor), capped at 1."""
    self.alpha.assign(tf.minimum(self.alpha+step, 1.))

  def call(self, input):
    new, old = input
    # Re-clamp in case alpha was assigned directly above 1.
    self.alpha.assign(tf.minimum(self.alpha, 1.))
    return (new*self.alpha) + (old*(1-self.alpha))

  def get_config(self):
    config = super(FadeAdd, self).get_config()
    # Store the current value, not the variable object (which isn't serializable).
    config['alpha'] = float(self.alpha.numpy())
    return config

class MinibatchStddev(layers.Layer):
  """Append a minibatch standard-deviation feature map.

  The batch is split into ``min(group_size, batch)`` groups. The per-feature
  stddev across each group is averaged to one scalar per group member and
  broadcast as a single extra channel. Inputs must be rank 4 (N, H, W, C).
  The batch size must be divisible by the effective group size, or the
  reshape fails.

  Args:
    group_size: maximum group size.
    name: optional layer name.
    **kwargs: other base-``Layer`` arguments, accepted so the layer can be
      rebuilt from ``get_config()`` when a saved model is loaded.
  """
  def __init__(self, group_size=4, name=None, **kwargs):
    super(MinibatchStddev, self).__init__(name=name, **kwargs)
    self.group_size = group_size

  def call(self, inputs):
    group_size = tf.minimum(self.group_size, tf.shape(inputs)[0])
    shape = tf.shape(inputs)
    # [G, M, H, W, C]: sample n = g*M + m lands in group slot m.
    minibatch = tf.reshape(inputs, (group_size, -1, shape[1], shape[2], shape[3]))
    stddev = tf.sqrt(tf.reduce_mean(tf.square(minibatch - tf.reduce_mean(minibatch, axis=0)), axis=0) + 1e-8)
    stddev = tf.reduce_mean(stddev, axis=[1, 2, 3], keepdims=True)  # [M, 1, 1, 1]
    stddev = tf.tile(stddev, [group_size, shape[1], shape[2], 1])   # [N, H, W, 1]
    return tf.concat([inputs, stddev], axis=-1)

  def get_config(self):
    config = super(MinibatchStddev, self).get_config()
    config['group_size'] = self.group_size
    return config

class SelfAttention(layers.Layer):
  """Self-attention over spatial positions with a learned residual gate.

  Output is ``gamma * attention(x) + x``. ``gamma`` is initialized to 0, so the
  layer starts as the identity.

  Args:
    channelReduce: divisor applied to the channel count for the f/g (query/key)
      1x1 projections. The h (value) projection keeps all channels.
    name: optional layer name.
    **kwargs: other base-``Layer`` arguments, accepted so the layer can be
      rebuilt from ``get_config()``.
  """
  def __init__(self, channelReduce=1, name=None, **kwargs):
    super(SelfAttention, self).__init__(name=name, **kwargs)
    self.channelReduce = channelReduce

  def get_config(self):
    config = super(SelfAttention, self).get_config()
    # channelReduce determines the kernel shapes, so it must round-trip through the config.
    config['channelReduce'] = self.channelReduce
    return config

  def build(self, input_shape):
    self.channels = input_shape[-1]
    self.filters_f_g = self.channels // self.channelReduce
    self.filters_h = self.channels

    kernel_shape_f_g = (1, 1) + (self.channels, self.filters_f_g)
    kernel_shape_h = (1, 1) + (self.channels, self.filters_h)

    # Create a trainable weight variable for this layer:
    self.gamma = self.add_weight(name='gamma', shape=[1], initializer='zeros', trainable=True)
    self.kernel_f = self.add_weight(shape=kernel_shape_f_g,
                                        initializer='glorot_uniform',
                                        name='kernel_f',
                                        trainable=True)
    self.kernel_g = self.add_weight(shape=kernel_shape_f_g,
                                        initializer='glorot_uniform',
                                        name='kernel_g',
                                        trainable=True)
    self.kernel_h = self.add_weight(shape=kernel_shape_h,
                                        initializer='glorot_uniform',
                                        name='kernel_h',
                                        trainable=True)

    super(SelfAttention, self).build(input_shape)
    self.built = True

  def call(self, input):
    def hw_flatten(x):
      # [N, H, W, C] -> [N, H*W, C]
      inp_shape = tf.shape(x)
      return tf.reshape(x, shape=[inp_shape[0], inp_shape[1] * inp_shape[2], inp_shape[3]])

    # 1x1 projections. tf ops are used instead of keras.backend helpers, which Keras 3 dropped.
    f_x = tf.nn.conv2d(input, self.kernel_f, strides=1, padding='SAME')
    g_x = tf.nn.conv2d(input, self.kernel_g, strides=1, padding='SAME')
    h_x = tf.nn.conv2d(input, self.kernel_h, strides=1, padding='SAME')

    f_x_flat = hw_flatten(f_x)  # [N, HW, C']
    g_x_flat = hw_flatten(g_x)  # [N, HW, C']

    # Attention logits between every pair of positions: [N, HW, HW].
    s = tf.matmul(g_x_flat, f_x_flat, transpose_b=True)
    beta = tf.nn.softmax(s, axis=-1)
    o = tf.matmul(beta, hw_flatten(h_x))  # [N, HW, C]

    o = tf.reshape(o, shape=tf.shape(input))  # [N, H, W, C]
    return self.gamma * o + input

def layer_init_stddev(shape, gain=np.sqrt(2)):
  """Return the He-style init scale ``gain / sqrt(fan_in)`` for a weight ``shape``.

  ``fan_in`` is the product of every dimension except the last (output) one.
  """
  input_dims = shape[:-1]
  return gain / np.sqrt(np.prod(input_dims))


In [None]:
fmap_base = 8192  # channel budget that nf() halves once per stage
fmap_max = 512    # upper bound on channels per layer
fmap_decay = 1.   # exponent multiplier on the per-stage halving

def nf(stage):
  """Channel count for `stage`: ``int(fmap_base / 2**(stage * fmap_decay))``, capped at `fmap_max`."""
  channels = int(fmap_base / (2.0 ** (stage * fmap_decay)))
  return channels if channels < fmap_max else fmap_max

# Kernel initializer shared by the model builders: N(0, 0.02).
weight_init = tf.keras.initializers.RandomNormal(stddev=0.02, mean=0.0)
# weight_init = tf.keras.initializers.HeNormal()
# Kernel constraint passed to every builder layer; None disables it
# (alternative tried: tf.keras.constraints.max_norm(1.0)).
const = None

class FakeLayer(layers.Layer):
  """Pass-through wrapper that simply calls the wrapped `layer`.

  The generator builders use it in the "Normalizer" slot (the discriminator
  uses `layers.SpectralNormalization` there), which also lets an explicit
  `name` such as ``'out_4'`` be attached to the wrapped computation.

  Args:
    layer: the Keras layer to wrap.
    name: optional layer name.
    **kwargs: other base-``Layer`` arguments, accepted so the wrapper can be
      rebuilt from ``get_config()`` when a saved model is loaded.
  """
  def __init__(self, layer, name=None, **kwargs):
    super(FakeLayer, self).__init__(name=name, **kwargs)
    self.layer = layer

  def build(self, input_shape):
    # Build the wrapped layer eagerly so its weights exist when a saved
    # model's weights are restored into a freshly deserialized wrapper.
    if not self.layer.built:
      self.layer.build(input_shape)
    super(FakeLayer, self).build(input_shape)

  def call(self, input):
    return self.layer(input)

  def get_config(self):
    config = super(FakeLayer, self).get_config()
    # A live Layer object is neither JSON-serializable nor usable as an
    # __init__ argument after loading, so store its serialized config instead.
    config['layer'] = layers.serialize(self.layer)
    return config

  @classmethod
  def from_config(cls, config):
    config = dict(config)
    config['layer'] = layers.deserialize(config['layer'])
    return cls(**config)

def generatorBase():
  """Build the initial 4x4 generator: 512-dim latent -> 4x4x3 tanh image.

  Two layers are looked up by name when the next stage is added:
  ``'final_4'`` (the last 4x4 feature map, which gets upsampled) and
  ``'out_4'`` (the RGB head, whose output is upsampled and blended in
  during fade-in).

  Returns:
    A functional `tf.keras.Model`.
  """
  # FakeLayer only calls the wrapped layer, so no weight normalization is
  # applied here; it fills the "Normalizer" slot and carries explicit names.
  Normalizer = FakeLayer

  resolution = 4
  res = int(np.log2(resolution))

  inputLayer = layers.Input((512,))
  x = inputLayer
  x = PixelNorm()(x)
  # Project the latent to a 4x4 feature map with nf(res-1) channels.
  x = Normalizer(layers.Dense(4*4*nf(res-1), kernel_constraint=const, kernel_initializer=weight_init, use_bias=False))(x)
  x = layers.LeakyReLU(0.2)(x)
  x = layers.Reshape((4, 4, nf(res-1)))(x)
  x = PixelNorm()(x)
  
  x = Normalizer(layers.Conv2D(nf(res-1), kernel_size=(3, 3), padding='same',
                                         kernel_constraint=const, kernel_initializer=weight_init, use_bias=False))(x)
  x = layers.LeakyReLU(0.2)(x)
  # Named so generatorAddStage can branch the next resolution off this feature map.
  x = PixelNorm(name='final_4')(x)

  # 1x1 conv to 3 channels with tanh ("toRGB"); named so the next stage can fade it out.
  out = Normalizer(layers.Conv2D(3, (1, 1), strides=(1, 1), padding='same',
                                           use_bias=False, activation='tanh', kernel_constraint=const,
                                           kernel_initializer=weight_init), name='out_4')(x)
  print(out.shape)
  model = tf.keras.models.Model(inputs=inputLayer, outputs=out)
  return model

def generatorAddStage(gen, resolution=0, freeze=False, alpha:tf.Variable=None):
  """Grow the generator by one resolution stage with a fade-in blend.

  Branches off ``'final_<resolution//2>'``, upsamples 2x (bicubic) and applies
  two 3x3 conv + LeakyReLU + PixelNorm blocks. The second PixelNorm is named
  ``'final_<resolution>'``. A new tanh RGB head ``'out_<resolution>'`` follows.
  Its output is blended via FadeAdd with the upsampled output of the previous
  head ``'out_<resolution//2>'``.

  Args:
    gen: the current generator model (built by generatorBase / this function).
    resolution: target resolution; 0 means twice the current output height.
      The named-layer lookups assume the previous stage is exactly resolution//2.
    freeze: if True, sets ``gen.trainable = False`` before extending it.
    alpha: passed to FadeAdd; None makes FadeAdd create its own variable.

  Returns:
    ``(model, fade_layer)`` — the grown model and its FadeAdd layer, whose
    ``alpha`` controls the blend (0 = only the upsampled old output).
  """
  print("Current Shape: ", gen.output.shape)
  Normalizer = FakeLayer

  if resolution == 0:
    resolution = gen.output.shape[1] * 2

  res = int(np.log2(resolution))

  if freeze:
    print("Freezing")
    gen.trainable = False

  newDepth = nf(res - 1)
  x = gen.get_layer('final_'+str(resolution // 2)).output
  print("Choosing layer ", x)

  print("New Depth: ", newDepth)

  x = layers.UpSampling2D(size=(2,2), interpolation='bicubic')(x)
  # x = layers.Resizing(resolution, resolution, interpolation='nearest')(x)
  x = Normalizer(layers.Conv2D(newDepth, (3, 3), strides=(1, 1), padding='same',
                                         kernel_constraint=const, kernel_initializer=weight_init, use_bias=False))(x)
  x = layers.LeakyReLU(0.2)(x)
  x = PixelNorm()(x)

  x = Normalizer(layers.Conv2D(newDepth, (3, 3), strides=(1, 1), padding='same',
                                         kernel_constraint=const, kernel_initializer=weight_init, use_bias=False))(x)
  x = layers.LeakyReLU(0.2)(x)
  # Named so the following stage can branch off here.
  x = PixelNorm(name='final_'+str(resolution))(x)

  out = Normalizer(layers.Conv2D(3, (1, 1), strides=(1, 1), padding='same',
                                           activation='tanh', kernel_constraint=const, kernel_initializer=weight_init, use_bias=False), name='out_'+str(resolution))(x)
  print("New Shape: ", out.shape)

  # Add prev output: the old RGB head, upsampled to the new resolution, is the fade-in "old" path.
  lastOut = gen.get_layer('out_'+str(resolution // 2)).output
  up = layers.UpSampling2D((2,2), interpolation='bicubic')(lastOut)
  # up = layers.Resizing(resolution, resolution, interpolation='nearest')(lastOut)

  # `alpha` is rebound from the variable argument to the FadeAdd layer itself.
  alpha = FadeAdd(alpha)
  out = alpha([out, up])

  inputLayer = gen.input
  model = tf.keras.models.Model(inputs=inputLayer, outputs=out)
  return model, alpha

def reBaseModel(layers, inpTensor):
  """Thread ``inpTensor`` through each callable in ``layers`` in order and return the final output.

  Used to re-apply a shared list of Keras layers on a new input tensor.
  """
  out = inpTensor
  for fn in layers:
    out = fn(out)
  return out

def descriminatorBase():
  """Build the initial 4x4 discriminator and its encoder view.

  Input stem (named so later stages can reuse it for the fade-in path):
  ``'sup_mbatch_4'`` (MinibatchStddev) -> ``'sup_conv_4'`` (spectral-normalized
  1x1 conv) -> ``'sup_act_4'``. ``baseLayers`` follows the stem: 3x3 conv
  ``'depth_4'``, LeakyReLU, Flatten, Dense(1).

  NOTE(review): MinibatchStddev is applied to the raw input image here. The
  ProGAN paper places it near the end of the discriminator — confirm this
  placement is intended.

  Returns:
    ``(dis, enc, baseLayers)``: the discriminator (scalar score), an encoder
    model that stops before the final Dense, and the shared layer list that
    descriminatorAddStage prepends to.
  """
  Normalizer = layers.SpectralNormalization
  resolution = 4
  res = int(np.log2(resolution))

  inputLayer = layers.Input((4, 4, 3))
  # x = layers.GaussianNoise(0.00)(inputLayer)
  x = inputLayer

  x = MinibatchStddev(name='sup_mbatch_4')(x)
  x = Normalizer(layers.Conv2D(nf(res-1), (1, 1), padding='same',
                                                  kernel_constraint=const, kernel_initializer=weight_init, use_bias=False), name='sup_conv_4')(x)
  x = layers.LeakyReLU(0.2, name='sup_act_4')(x)

  baseLayers = []
  baseLayers.append(
      Normalizer(
          layers.Conv2D(nf(res-1), kernel_size=(3, 3), strides=(1, 1), padding='same',
                        kernel_constraint=const, kernel_initializer=weight_init, use_bias=False),
                        name='depth_4'
          )
  )
  baseLayers.append(layers.LeakyReLU(0.2))

  # baseLayers.append(layers.BatchNormalization())
  baseLayers.append(layers.Flatten())
  baseLayers.append(layers.Dense(1, kernel_constraint=const, kernel_initializer=weight_init, use_bias=False))

  # The shared layers are applied twice (once per output). The two models share
  # weights but build separate call paths through the same layers.
  encOut = reBaseModel(baseLayers[:-1], x)
  desOut = reBaseModel(baseLayers, x)

  dis = tf.keras.models.Model(inputs=inputLayer, outputs=desOut)
  enc = tf.keras.models.Model(inputs=inputLayer, outputs=encOut)
  return dis, enc, baseLayers

def descriminatorAddStage(des, enc, baseLayers, resolution=0, freeze=False, alpha:tf.Variable=None):
  """Grow the discriminator by one resolution stage with a fade-in blend.

  New path: a new input at ``resolution`` goes through a fresh stem
  (``'sup_mbatch_<r>'``, ``'sup_conv_<r>'``, ``'sup_act_<r>'``). Then come two
  spectral-normalized 3x3 convs (``'depth_<r>'``, ``'depth2_<r>'``) with
  LeakyReLU, then 2x average pooling.
  Old path: the input is average-pooled to half resolution and passed through
  the previous stage's stem layers (looked up by name in ``des``).
  The two paths are blended with FadeAdd, then run through ``baseLayers``.

  Args:
    des, enc: the previous stage's discriminator / encoder models.
    baseLayers: the previous stage's shared layer list (from
      descriminatorBase or a prior call).
    resolution: new input resolution; 0 means twice ``enc``'s input height.
    freeze: if True, sets ``trainable = False`` on the passed ``des`` and ``enc``.
    alpha: passed to FadeAdd; None makes FadeAdd create its own variable.

  Returns:
    ``(des, enc, baseLayers, beta)``: the grown discriminator and encoder, the
    layer list with this stage's conv/pool layers prepended, and the FadeAdd layer.
  """
  Normalizer = layers.SpectralNormalization
  print("Current Shape: ", enc.input.shape)

  if freeze:
    print("Freezing")
    des.trainable = False
    enc.trainable = False

  print("Previous Input layer ", enc.input)

  if resolution == 0:
    resolution = enc.input.shape[1] * 2

  newSize = resolution

  res = int(np.log2(resolution))

  print("New input ", newSize, newSize)

  inputLayer = layers.Input((newSize, newSize, 3))
  # inp = layers.GaussianNoise(0.00)(inputLayer)
  inp = inputLayer

  # New full-resolution input stem.
  processingLayers = []
  processingLayers.append(MinibatchStddev(name='sup_mbatch_'+str(resolution)))
  processingLayers.append(Normalizer(layers.Conv2D(nf(res-1), (1, 1), padding='same',
                                                                      kernel_constraint=const, kernel_initializer=weight_init, use_bias=False), name='sup_conv_'+str(resolution))
  )
  processingLayers.append(layers.LeakyReLU(name='sup_act_'+str(resolution), alpha=0.2))

  x = reBaseModel(processingLayers, inp)

  # New stage body; the second conv's nf(res-2) channels match what the
  # previous stage's first shared layer receives.
  newLayers = []
  newLayers.append(Normalizer(layers.Conv2D(nf(res-1), (3, 3), strides=(1, 1), padding='same',
                                                      kernel_constraint=const, kernel_initializer=weight_init, use_bias=False), name='depth_'+str(resolution)))
  # newLayers.append(layers.BatchNormalization(momentum=0.8))
  newLayers.append(layers.LeakyReLU(0.2))

  newLayers.append(Normalizer(layers.Conv2D(nf(res-2), (3, 3), strides=(1, 1), padding='same',
                                                      kernel_constraint=const, kernel_initializer=weight_init, use_bias=False), name='depth2_'+str(resolution)))
  # newLayers.append(layers.BatchNormalization(momentum=0.8))
  newLayers.append(layers.LeakyReLU(0.2))
  newLayers.append(layers.AveragePooling2D(2, 2))
  # newLayers.append(layers.Resizing(resolution // 2, resolution // 2, interpolation='nearest'))  

  newInp = reBaseModel(newLayers, x)
  # print(newInp, baseLayers)

  # Fade-in "old" path: downsampled input through the previous stage's stem.
  small = layers.AveragePooling2D((2, 2))(inp)
  # small = layers.Resizing(resolution // 2, resolution // 2, interpolation='nearest')(inp)
  sup = des.get_layer('sup_mbatch_'+str(resolution // 2))(small)
  sup = des.get_layer('sup_conv_'+str(resolution // 2))(sup)
  sup = des.get_layer('sup_act_'+str(resolution // 2))(sup)

  print("====>", sup)

  print(newInp.shape, sup.shape)
  beta = FadeAdd(alpha=alpha)
  out = beta([newInp, sup])

  print("==>", out)

  # Shared layers applied twice (once per output), as in descriminatorBase.
  desOut = reBaseModel(baseLayers, out)
  encOut = reBaseModel(baseLayers[:-1], out)

  des = tf.keras.models.Model(inputs=inputLayer, outputs=desOut)
  enc = tf.keras.models.Model(inputs=inputLayer, outputs=encOut)

  baseLayers = newLayers + baseLayers
  return des, enc, baseLayers, beta

def generateBaseModels():
  """Create fresh 4x4 starting models; returns ``(gen, des, enc, baseLayers)``."""
  generator = generatorBase()
  discriminator, encoder, disc_base_layers = descriminatorBase()
  return generator, discriminator, encoder, disc_base_layers

In [None]:
gen, des, enc, baseLayers = generateBaseModels()
# Model.summary() prints the table itself and returns None, so call it directly
# instead of wrapping it in print() (which adds a stray "None None" line).
gen.summary()
des.summary()

# Smoke test: grow both networks through five more stages (4x4 -> 128x128).
for i in range(5):
  print("Adding Stage ", i, " to generator")
  gen, alpha = generatorAddStage(gen, freeze=False)
  print("Adding Stage ", i, " to descriminator")
  des, enc, baseLayers, beta = descriminatorAddStage(des, enc, baseLayers, freeze=False)
  gen.summary()
  des.summary()

In [None]:
# summary() prints and returns None; calling it directly avoids a stray "None None" line.
gen.summary()
des.summary()

In [None]:
# Render the grown generator graph (nested wrappers expanded, trainable flags shown).
tf.keras.utils.plot_model(gen, to_file="generator.png", show_shapes=True, show_layer_names=True, show_trainable=True, expand_nested=True, dpi=96)

In [None]:
# Render the grown discriminator graph (nested wrappers expanded, trainable flags shown).
tf.keras.utils.plot_model(des, to_file="descriminator.png", show_shapes=True, show_layer_names=True, show_trainable=True, expand_nested=True, dpi=96)

In [None]:
def lerp(a, b, t):
    """Linear interpolation: ``a`` at ``t == 0``, ``b`` at ``t == 1`` (broadcasts for arrays/tensors)."""
    delta = b - a
    return a + delta * t

def gradientPenalty(des, reals, fakes, batch_size, epsilon=1e-12):
    """WGAN-GP gradient penalty ``mean((||grad_x D(x_hat)||_2 - 1)^2)``.

    ``x_hat`` is a random per-sample interpolation between ``reals`` and
    ``fakes``. The caller adds the result to the discriminator loss.

    Args:
      des: discriminator model.
      reals, fakes: image batches of identical shape. The mixing factors and
        the norm reduction below assume rank 4.
      batch_size: number of samples in ``reals`` / ``fakes``.
      epsilon: added under the square root. The derivative of sqrt is infinite
        at 0, so an exactly-zero gradient would otherwise turn the
        discriminator update into NaNs.
    """
    # Get the interpolated image.
    mixing_factors = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0, dtype=tf.float32)
    mixed_images = lerp(reals, fakes, mixing_factors)

    with tf.GradientTape() as gp_tape:
        gp_tape.watch(mixed_images)
        mixed_output = des(mixed_images, training=True)

    # Gradient of the critic score w.r.t. the interpolated images, then its per-sample L2 norm.
    mixed_gradients = gp_tape.gradient(mixed_output, mixed_images)
    mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_gradients), axis=[1, 2, 3]) + epsilon)
    return tf.reduce_mean((mixed_norms - 1.0) ** 2)

def descriminator_WGANGPloss(reals, fakes, des, batch_size, smooth=1, wgan_target=1., wgan_lambda=10., wgan_epsilon=0.001):
    """WGAN-GP critic loss with a drift penalty.

    ``loss = mean(D(fake)) - mean(D(real))
             + wgan_lambda / wgan_target**2 * GP
             + wgan_epsilon * mean(D(real)**2)``

    Args:
      reals, fakes: image batches.
      des: discriminator model.
      batch_size: forwarded to the gradient penalty.
      smooth: unused; kept for call compatibility.
      wgan_target: target gradient norm used to scale the penalty weight.
      wgan_lambda: gradient-penalty weight.
      wgan_epsilon: drift-penalty weight.

    Returns:
      ``(total_loss, real_output, fake_output)``. ``total_loss`` is a scalar, so
      it can be differentiated directly and logged with ``tf.summary.scalar``.
    """
    real_output = des(reals, training=True)
    fake_output = des(fakes, training=True)
    loss = discriminator_loss(real_output, fake_output)

    gradient_penalty = gradientPenalty(des, reals, fakes, batch_size)
    # Drift term keeps real scores from wandering far from 0. It is squared and
    # averaged (as in ProGAN), so it penalizes |D(real)| and keeps the loss scalar.
    drift = wgan_epsilon * tf.reduce_mean(tf.square(real_output))
    total_loss = loss + (gradient_penalty * (wgan_lambda / (wgan_target**2))) + drift
    return total_loss, real_output, fake_output

def discriminator_loss(real_output, fake_output):
    """Wasserstein critic loss: mean score on fakes minus mean score on reals (a scalar)."""
    return tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)

def generator_loss(fake_output):
    """Wasserstein generator loss: the negated mean critic score of generated samples."""
    return tf.negative(tf.reduce_mean(fake_output))

def generator_enc_loss(real, fake):
  """Element-wise L1 reconstruction error ``|real - fake|``; not reduced over any axis."""
  difference = real - fake
  return tf.abs(difference)

def generator_hinge_loss(fake_output):
  """Hinge-GAN generator loss ``-mean(D(fake))`` (a scalar)."""
  mean_fake_score = tf.reduce_mean(fake_output)
  return -mean_fake_score

def descriminator_hinge_loss(reals, fakes, des, batch_size, apply_penalty=True, wgan_target=1., penalty_lambda=10):
  """Hinge discriminator loss with an optional gradient penalty on real samples.

  ``loss = mean(relu(1 - D(real))) + mean(relu(1 + D(fake)))
           [+ penalty_lambda * mean_batch(sum ||dD(real)/d real||^2)]``

  ``batch_size`` and ``wgan_target`` are unused. They are kept so the positional
  signature matches ``descriminator_WGANGPloss``.

  Returns:
    ``(des_loss, real_output, fake_output)``.
  """
  with tf.GradientTape() as gp_tape:
    gp_tape.watch(reals)
    real_output = des(reals, training=True)
    fake_output = des(fakes, training=True)

  real_loss = tf.reduce_mean(tf.nn.relu(1.0 - real_output))
  fake_loss = tf.reduce_mean(tf.nn.relu(1.0 + fake_output))
  des_loss = real_loss + fake_loss

  if apply_penalty:
    gradient = gp_tape.gradient(real_output, reals)
    # Squared gradient norm per sample (sum over all non-batch axes), then batch mean.
    # tf ops instead of keras.backend K.mean/K.sum, which Keras 3 no longer provides.
    per_sample = tf.reduce_sum(tf.square(gradient), axis=list(range(1, len(gradient.shape))))
    des_loss += tf.reduce_mean(per_sample) * penalty_lambda
  return des_loss, real_output, fake_output

In [None]:
# @tf.function
def trainGenEnc(gen, enc, real, batch_size, coeff=1, generator_optimizer=None, enc_optimizer=None):
  """One generator update on an encoder -> generator reconstruction loss.

  Encodes ``real`` with ``enc``, decodes the code with ``gen`` and applies one
  ``generator_optimizer`` step on the element-wise L1 error scaled by ``coeff``.
  The encoder update is currently disabled (commented out below).

  Returns:
    The unreduced per-element loss tensor.

  NOTE(review): ``gen`` takes a (512,) latent, while ``enc`` appears to end in a
  Flatten over a 4x4 discriminator feature map. Confirm the widths match
  before using this step.
  """
  with tf.GradientTape() as enc_tape, tf.GradientTape() as gen_tape:
    real_enc = enc(real, training=True)
    enc_fake = gen(real_enc, training=True)
    # The loss must be computed while the tapes are recording. Otherwise the
    # ops linking the weights to the loss are never recorded and every
    # gradient comes back None.
    gen_loss = generator_enc_loss(real, enc_fake) * coeff

  gradients_of_generator = gen_tape.gradient(gen_loss, gen.trainable_variables)
  generator_optimizer.apply_gradients(zip(gradients_of_generator, gen.trainable_variables))

  # gradients_of_enc = enc_tape.gradient(gen_loss, enc.trainable_variables)
  # enc_optimizer.apply_gradients(zip(gradients_of_enc, enc.trainable_variables))
  return gen_loss
 
# @tf.function
def trainDes(gen, des, real, batch_size, hinge=False, discriminator_optimizer=None):
  """Run one discriminator update against a freshly generated fake batch.

  Samples ``batch_size`` 512-dim latents, generates fakes and evaluates the
  hinge or WGAN-GP discriminator loss, then applies one
  ``discriminator_optimizer`` step to ``des``. Returns the discriminator loss.
  """
  loss_fn = descriminator_hinge_loss if hinge else descriminator_WGANGPloss
  with tf.GradientTape() as disc_tape:
    latent = tf.random.normal([batch_size, 512])
    fake = gen(latent, training=True)
    des_loss, _, _ = loss_fn(real, fake, des, batch_size)

  des_vars = des.trainable_variables
  grads = disc_tape.gradient(des_loss, des_vars)
  discriminator_optimizer.apply_gradients(zip(grads, des_vars))
  return des_loss
 
# @tf.function
def trainDesGen(gen, des, real, batch_size, hinge=False, generator_optimizer=None, discriminator_optimizer=None):
  """Update the discriminator, then the generator, from one shared forward pass.

  Both losses come from the same fake batch (``batch_size`` 512-dim latents).
  The discriminator step is applied before the generator gradients are computed.

  Returns:
    ``(des_loss, gen_loss)``.
  """
  with tf.GradientTape() as disc_tape, tf.GradientTape() as gen_tape:
    latent = tf.random.normal([batch_size, 512])
    fake = gen(latent, training=True)

    if not hinge:
      des_loss, _, fake_output = descriminator_WGANGPloss(real, fake, des, batch_size)
      gen_loss = generator_loss(fake_output)
    else:
      des_loss, _, fake_output = descriminator_hinge_loss(real, fake, des, batch_size)
      gen_loss = generator_hinge_loss(fake_output)

  des_vars = des.trainable_variables
  discriminator_optimizer.apply_gradients(zip(disc_tape.gradient(des_loss, des_vars), des_vars))

  gen_vars = gen.trainable_variables
  generator_optimizer.apply_gradients(zip(gen_tape.gradient(gen_loss, gen_vars), gen_vars))
  return des_loss, gen_loss

# @tf.function
def trainGen(gen, des, batch_size, hinge=False, generator_optimizer=None):
  """Run one generator update.

  Generates ``batch_size`` fakes from 512-dim latents, scores them with ``des``
  and applies one ``generator_optimizer`` step on the hinge or WGAN generator
  loss. Returns that loss.
  """
  loss_fn = generator_hinge_loss if hinge else generator_loss
  with tf.GradientTape() as gen_tape:
    latent = tf.random.normal([batch_size, 512])
    fake = gen(latent, training=True)
    fake_output = des(fake, training=True)
    gen_loss = loss_fn(fake_output)

  gen_vars = gen.trainable_variables
  generator_optimizer.apply_gradients(zip(gen_tape.gradient(gen_loss, gen_vars), gen_vars))
  return gen_loss
 
def evalGan(gen, des, data, batches, batch_size):
  """Rough discriminator accuracy and generator loss on a fixed real batch.

  For each of ``batches`` rounds, ``batch_size`` fresh fakes and the real batch
  ``data`` are scored with ``des`` in inference mode. Accuracy counts outputs
  above 0.5 as "real". NOTE(review): WGAN critic scores are unbounded, so the
  0.5 threshold is only a heuristic.

  Args:
    gen, des: generator and discriminator models.
    data: one batch of real images.
    batches: number of fake batches to average over (must be positive).
    batch_size: number of fakes generated per round.

  Returns:
    ``(mean accuracy, mean generator loss)`` over the rounds.

  Raises:
    ValueError: if ``batches`` is not positive.
  """
  if batches <= 0:
    raise ValueError("batches must be positive")
  desAcc = 0.0
  genLoss = 0.0
  # The real batch is the same every round, so score it once.
  real_output = des(data, training=False)
  for _ in range(batches):
    fake = gen(tf.random.normal([batch_size, 512]), training=False)
    fake_output = des(fake, training=False)

    output = tf.reshape(tf.concat((fake_output, real_output), axis=0), [-1])
    labels = tf.reshape(tf.concat((tf.zeros_like(fake_output), tf.ones_like(real_output)), axis=0), [-1])

    desAcc += tf.keras.metrics.binary_accuracy(labels, output, threshold=0.5).numpy()
    # generator_loss already returns the batch mean; no further division by batch_size.
    genLoss += float(generator_loss(fake_output))
  return desAcc / batches, genLoss / batches

from IPython.display import clear_output

def augmenter(size, alpha=None, method='area'):
  """Build a tf.data map function for ``{'image': ...}`` samples.

  The returned ``tf.function`` resizes ``sample['image']`` to ``size x size``,
  rescales pixels with ``(x - 127.5) / 127.5`` and applies a random
  horizontal flip.

  When ``alpha`` is given (an object with an ``alpha`` variable, e.g. a FadeAdd
  layer), the output blends the ``size`` image with a ``size//2`` version
  resized back up to ``size``, weighted by ``alpha.alpha``.
  """
  if alpha is None:
    @tf.function
    def augment(sample):
      resized = tf.image.resize(sample['image'], [size, size], method=method, antialias=True)
      normalized = (tf.cast(resized, tf.float32) - 127.5) / 127.5
      return {'image': tf.image.random_flip_left_right(normalized)}
    return augment

  @tf.function
  def augment(sample):
    image = (tf.cast(sample['image'], tf.float32) - 127.5) / 127.5
    half = size // 2
    big = tf.image.resize(image, [size, size], method=method, antialias=True)
    small = tf.image.resize(image, [half, half], method=method, antialias=True)
    small = tf.image.resize(small, [size, size], method='area')
    blended = (big * alpha.alpha) + (small * (1 - alpha.alpha))
    return {'image': tf.image.random_flip_left_right(blended)}
  return augment
    
# Not a tf.function: this drives Python-side work (plots, saving, TensorBoard);
# the per-step work is compiled by the inner tf.functions.
def trainGan(
  genModel, desModel, encModel, desBaseLayers,
  data, 
  name=MODEL_NAME, 
  modeldir=MODEL_PATH, 
  trainingConf={
    4:{"epochs":7, "batch_size":64, "alpha_step":0.1, "alpha_delay":0, "alpha_multiplier":1.2}, 
    8:{"epochs":20, "batch_size":64, "alpha_step":0.07, "alpha_delay":0.2, "alpha_multiplier":1.2}, 
    16:{"epochs":26, "batch_size":64, "alpha_step":0.07, "alpha_delay":0.2, "alpha_multiplier":1.2}, 
    32:{"epochs":32, "batch_size":48, "alpha_step":0.07, "alpha_delay":0.2, "alpha_multiplier":1.2}, 
    64:{"epochs":40, "batch_size":32, "alpha_step":0.07, "alpha_delay":0.3, "alpha_multiplier":1.2}, 
    128:{"epochs":50, "batch_size":32, "alpha_step":0.07, "alpha_delay":0.3, "alpha_multiplier":1.2}, 
    256:{"epochs":55, "batch_size":32, "alpha_step":0.07, "alpha_delay":0.3, "alpha_multiplier":1.2}
  }, 
  des_steps=2, 
  gen_steps=1,
):
  """Progressively train the GAN, one resolution stage at a time.

  All stage models are built up front from the 4x4 base models. They share
  layers, so training a stage also updates the layers it reuses. For each
  resolution in ``trainingConf`` (ascending), the data is resized to that
  resolution. The stage then trains for ``epochs`` epochs; each step does
  ``des_steps`` discriminator updates and ``gen_steps`` generator updates.
  Once the epoch index exceeds ``int(alpha_delay * epochs)``, both networks'
  fade-in factors grow by ``alpha_step / steps`` per step. ``alpha_step`` is
  multiplied by ``alpha_multiplier`` after every epoch.

  Losses, accuracies and sample images go to ``file_writer``. After each stage
  its models are saved to ``<modeldir>/<name>_<size>_{gen,des,enc}.keras``.

  Args:
    genModel, desModel, encModel, desBaseLayers: output of ``generateBaseModels()``.
    data: un-batched tf.data dataset of ``{'image': ...}`` samples (presumably
      uint8 in [0, 255], given the augmenter's normalization).
    name: file-name prefix for saved models.
    modeldir: directory for saved models (created if missing).
    trainingConf: per-resolution dict with keys ``epochs``, ``batch_size``,
      ``alpha_step``, ``alpha_delay`` (fraction of epochs) and
      ``alpha_multiplier``. The smallest key must be 4 and the keys must be
      consecutive powers of two.
    des_steps, gen_steps: updates per training step.
  """
  realData = data
  # Fixed latents so sample grids are comparable across epochs (64 -> one 8x8 grid).
  noise = tf.random.normal([64, 512])
  results = []
  gen_alpha, des_alpha = None, None
  initialCoeff = 1.
  scaleSizes = sorted(trainingConf.keys())

  globalIter = tf.Variable(0, dtype=tf.int64)
  # Epoch counter across all stages, so per-epoch TensorBoard points from
  # different resolutions get distinct steps instead of overwriting each other.
  globalEpoch = 0

  # Stage N is grown from (and shares weights with) stage N/2.
  modelMap = {
    4 : {"gen":genModel, "des":desModel, "enc":encModel, "gen_alpha":None, "des_alpha":None, "baseLayers":desBaseLayers},
  }

  for size in scaleSizes[1:]:
    print("Adding Stage ", size, " to generator")
    lastModel = modelMap[size // 2]
    des, enc, baseLayers, des_alpha = descriminatorAddStage(des=lastModel["des"], enc=lastModel["enc"], baseLayers=lastModel["baseLayers"], freeze=False)
    gen, gen_alpha = generatorAddStage(gen=lastModel["gen"], freeze=False)
    modelMap[size] = {"gen":gen, "des":des, "enc":enc, "gen_alpha":gen_alpha, "des_alpha":des_alpha, "baseLayers":baseLayers}

  for size in scaleSizes:
    conf = trainingConf[size]
    epochs = conf['epochs']
    batch_size = conf['batch_size']
    alphaStep = conf['alpha_step']
    alphaDelay = conf['alpha_delay']
    alphaMultiplier = conf['alpha_multiplier']

    # Fade-in starts once the epoch index exceeds this.
    alphaDelay = int(alphaDelay * epochs)
    shouldApplyAlpha = tf.Variable(False, dtype=tf.bool)
    alphaStep = tf.Variable(alphaStep, dtype=tf.float32)

    models = modelMap[size]
    gen, des, enc, gen_alpha, des_alpha, desBaseLayers = models["gen"], models["des"], models["enc"], models["gen_alpha"], models["des_alpha"], models["baseLayers"]

    ideal_batch_size = int((nf(int(np.log2(size)) - 2) / 512) * batch_size)
    print("Current Batch Size", batch_size, "ideal batch size", ideal_batch_size, "alpha delay", alphaDelay)
    coeff = initialCoeff
    print("Input shape: ",des.input.shape, "Current scale: ", size)
    currentData = realData.map(augmenter(size, alpha=None)).shuffle(4096).batch(batch_size, drop_remainder=True).repeat().prefetch(tf.data.experimental.AUTOTUNE)
    iterData = iter(currentData)
    REAL = next(iterData)['image']

    # NOTE(review): each step consumes des_steps batches, so one "epoch" covers
    # roughly des_steps passes over the data.
    steps = len(realData) // batch_size

    generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1 = 0, beta_2 = 0.999)
    discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1 = 0, beta_2 = 0.999)
    enc_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4, beta_1 = 0, beta_2 = 0.999)
    def getTrainers():
      def _trainDes(real):
        return trainDes(gen=gen, des=des, real=real, batch_size=batch_size, discriminator_optimizer=discriminator_optimizer, hinge=False)

      def _trainGen():
        return trainGen(gen=gen, des=des, batch_size=batch_size, hinge=False, generator_optimizer=generator_optimizer)

      def _trainDesGen(real):
        return trainDesGen(gen=gen, des=des, real=real, batch_size=batch_size,
                    generator_optimizer=generator_optimizer, discriminator_optimizer=discriminator_optimizer, hinge=True)

      def _trainGenEnc(real, coeff):
        # trainGenEnc takes the encoder (not the discriminator) and no hinge/discriminator args.
        return trainGenEnc(gen=gen, enc=enc, real=real, batch_size=batch_size, coeff=coeff,
                    generator_optimizer=generator_optimizer, enc_optimizer=enc_optimizer)

      return tf.function(_trainDes), tf.function(_trainGen), tf.function(_trainGenEnc), tf.function(_trainDesGen)

    _trainDes, _trainGen, _trainGenEnc, _trainDesGen = getTrainers()

    currentStep = tf.Variable(0, dtype=tf.int64)

    @tf.function
    def trainStep(dataIterator):
      globalIter.assign_add(1)
      for _ in range(des_steps):
        batch = next(dataIterator)
        real = batch['image']
        des_loss = _trainDes(real)
        currentStep.assign_add(1)

      for _ in range(gen_steps):
        gen_loss = _trainGen()

      if des_alpha is not None and gen_alpha is not None and shouldApplyAlpha.read_value():
        # Keep this a tensor: float() on a symbolic tensor fails while tracing.
        alphaIncr = alphaStep.read_value() / float(steps)
        # Slow down alpha increment
        des_alpha.incrementAlpha(alphaIncr)
        gen_alpha.incrementAlpha(alphaIncr)

      return des_loss, gen_loss

    def trainSteps(dataIterator, steps):
      for _ in range(steps):
        try:
          des_loss, gen_loss = trainStep(dataIterator)
        except (StopIteration, tf.errors.OutOfRangeError) as e:
          # Only running out of data ends an epoch early. Any other error
          # propagates instead of silently skipping the rest of training.
          print("Data exhausted:", e)
          break
        if globalIter % 10 == 0:
          with file_writer.as_default():
            # reduce_mean guarantees the scalar tf.summary.scalar requires.
            tf.summary.scalar('Generator Loss', tf.reduce_mean(gen_loss), step=globalIter)
            tf.summary.scalar('Descriminator Loss', tf.reduce_mean(des_loss), step=globalIter)

    fake = gen(noise, training=False)
    real = REAL
    print("Evaluating at the start of epoch:")
    desAcc, genLoss = evalGan(gen, des, real, 10, batch_size)
    print("Real: ")
    plotImages(real)

    print("Fake: ")
    plotImages(fake)

    for epoch in range(epochs):
      print("Running epoch ", epoch)
      t = time.time()

      shouldApplyAlpha.assign(epoch > alphaDelay)

      trainSteps(iterData, steps)

      alphaStep.assign(alphaStep * alphaMultiplier)
      print("Alpha step: ", alphaStep.read_value())
      coeff *= 0.9
      fake = gen(noise, training=False)
      real = REAL

      print("Evaluating:", time.time() - t)
      desAcc, genLoss = evalGan(gen, des, real, 10, batch_size)
      results.append({'desAcc':desAcc, 'genLoss':genLoss})
      print("Epoch ", epoch, "Descriminator Accuracy ", desAcc, "Generator Loss ", genLoss, "of ", epochs, "Epochs")
      with file_writer.as_default():
        tf.summary.scalar('Descriminator Accuracy at Epoch', desAcc, step=globalEpoch)
        tf.summary.scalar('Generator Loss at Epoch', genLoss, step=globalEpoch)
        # tf.summary.image expects floats in [0, 1]; generator output is tanh in [-1, 1].
        tf.summary.image('Fake Images', tf.clip_by_value(fake * 0.5 + 0.5, 0.0, 1.0), step=globalEpoch)
      globalEpoch += 1

      print("Real: ")
      plotImages(real)

      print("Fake: ")
      plotImages(fake)

      if desAcc > 0.8:
        coeff *= 2
        coeff = min(coeff, 1.)

      if des_alpha is not None and gen_alpha is not None:
        print("Final Alpha, Beta: ", gen_alpha.alpha, des_alpha.alpha)

    # Save this stage's models inside `modeldir` (created on first use).
    os.makedirs(modeldir, exist_ok=True)
    prefix = os.path.join(modeldir, name + '_' + str(size))
    gen.save(prefix + '_gen.keras')
    des.save(prefix + '_des.keras')
    enc.save(prefix + '_enc.keras')
    initialCoeff *= 0.6
    clear_output(wait=True)

In [None]:
# Rebuild fresh 4x4 models (an earlier cell rebinds gen/des/enc/baseLayers to
# grown models), then train every stage on LFW.
gen, des, enc, baseLayers = generateBaseModels()
trainGan(
  gen, des, enc, baseLayers,
  data=lfw,
  name=MODEL_NAME,
  des_steps=3,
  gen_steps=1,
  trainingConf={
    4:{"epochs":7, "batch_size":64, "alpha_step":0.1, "alpha_delay":0, "alpha_multiplier":1.2}, 
    8:{"epochs":23, "batch_size":64, "alpha_step":0.015, "alpha_delay":0.15, "alpha_multiplier":1.15}, 
    16:{"epochs":30, "batch_size":64, "alpha_step":0.010, "alpha_delay":0.15, "alpha_multiplier":1.1}, 
    32:{"epochs":30, "batch_size":64, "alpha_step":0.010, "alpha_delay":0.15, "alpha_multiplier":1.1}, 
    64:{"epochs":30, "batch_size":64, "alpha_step":0.01, "alpha_delay":0.15, "alpha_multiplier":1.1}, 
    128:{"epochs":30, "batch_size":64, "alpha_step":0.01, "alpha_delay":0.15, "alpha_multiplier":1.1}, 
    256:{"epochs":30, "batch_size":64, "alpha_step":0.01, "alpha_delay":0.15, "alpha_multiplier":1.1}
  },
)