In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
from keras.utils import to_categorical  # Only for categorical one hot encoding
from tensorflow.keras import layers
from sklearn.metrics import accuracy_score
import tensorflow_datasets as tfds
from tensorflow.keras import backend as K

In [None]:
# TFDS 'lfw' dataset, train split, with shuffled file order.
lfw = tfds.load(name='lfw', split='train', shuffle_files=True)

In [None]:
def plotImages(imgs, grid=8):
    """Show a batch of images in a `grid` x `grid` figure.

    Images are expected in [-1, 1] (the inverse of the (x - 127.5) / 127.5
    scaling used for training data) and are mapped back to uint8 for display.
    At most grid * grid images are drawn; previously a batch larger than 64
    made `plt.subplot` raise.

    Args:
        imgs: indexable batch shaped [N, H, W, C].
        grid: rows/columns of the subplot grid (default 8, as before).
    """
    plt.figure(figsize=(8, 8))
    count = min(int(imgs.shape[0]), grid * grid)
    for i in range(count):
        plt.subplot(grid, grid, i + 1)
        plt.imshow(tf.cast(imgs[i, :, :, :] * 127.5 + 127.5, tf.uint8))
        plt.axis('off')
    plt.show()

In [None]:
# Mount Google Drive (Colab only). trainGan's default `modeldir` lives under
# /content/gdrive, so this must run before training saves any models.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
def lerp(a, b, t):
    """Linear interpolation between `a` (t == 0) and `b` (t == 1)."""
    delta = b - a
    return a + delta * t

def discriminator_loss(real_output, fake_output, smooth=1, wgan_target=1., wgan_lambda=10.):
    """Per-sample WGAN critic loss D(fake) - D(real).

    `smooth`, `wgan_target` and `wgan_lambda` are accepted for signature
    compatibility but are not used here.
    """
    return fake_output - real_output

def descriminator_WGANGPloss(reals, fakes, des, batch_size, smooth=1, wgan_target=1., wgan_lambda=10., wgan_epsilon=0.001):
    """WGAN-GP critic loss with a drift penalty (Progressive GAN formulation).

        loss = mean(D(fake) - D(real))
             + wgan_lambda / wgan_target**2 * mean((||grad D(x_hat)|| - wgan_target)**2)
             + wgan_epsilon * mean(D(real)**2)

    where x_hat are per-sample random interpolations between `reals` and `fakes`.

    Args:
        reals, fakes: image batches of the same shape. The mixing factors are
            shaped [batch_size, 1, 1, 1], so rank-4 batches of `batch_size`
            samples are assumed.
        des: critic model, called in training mode.
        batch_size: number of samples in `reals` / `fakes`.
        smooth: unused.
        wgan_target: target gradient norm.
        wgan_lambda: gradient-penalty weight.
        wgan_epsilon: drift-penalty weight.

    Returns:
        Scalar loss tensor.
    """
    real_output = des(reals, training=True)
    fake_output = des(fakes, training=True)
    wgan_loss = discriminator_loss(real_output, fake_output)

    mixing_factors = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0, dtype=tf.float32)
    mixed_images = lerp(reals, fakes, mixing_factors)

    with tf.GradientTape() as gp_tape:
        gp_tape.watch(mixed_images)
        # Critic output for the interpolated images.
        mixed_output = des(mixed_images, training=True)

    mixed_gradients = gp_tape.gradient(mixed_output, [mixed_images])[0]
    # One gradient norm per sample.
    mixed_norms = tf.sqrt(tf.reduce_sum(tf.math.square(mixed_gradients), axis=[1, 2, 3]))
    gradient_penalty = tf.math.square(mixed_norms - wgan_target)

    # Drift penalty must square the real scores so it pulls them toward 0.
    # The previous linear `wgan_epsilon * real_output` only rescaled the real
    # term of the WGAN loss and penalised nothing.
    drift_penalty = tf.math.square(real_output)

    # Average each term on its own. With the Dense(1) critic, wgan_loss is
    # [batch, 1] and gradient_penalty is [batch], so summing them first
    # broadcast to a [batch, batch] matrix: same mean, but wasteful and fragile.
    return (tf.reduce_mean(wgan_loss)
            + tf.reduce_mean(gradient_penalty) * (wgan_lambda / (wgan_target ** 2))
            + wgan_epsilon * tf.reduce_mean(drift_penalty))

def generator_loss(fake_output, smooth=1):
    """WGAN generator loss: the negated mean critic score on fakes (`smooth` is unused)."""
    mean_score = tf.reduce_mean(fake_output)
    return tf.math.negative(mean_score)

def generator_enc_loss(real, fake):
  """Element-wise L1 error |real - fake|.

  The result is not reduced; it has the broadcast shape of the inputs.
  """
  residual = real - fake
  return tf.abs(residual)

In [None]:
class PixelNorm(layers.Layer):
  """Scale each vector along the last axis to unit RMS: x / sqrt(mean(x^2) + epsilon)."""
  def __init__(self, epsilon=1e-8):
    super(PixelNorm, self).__init__()
    self.epsilon = epsilon

  def call(self, x):
    mean_square = tf.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
    inv_rms = tf.math.rsqrt(mean_square + self.epsilon)
    return x * inv_rms

class FadeAdd(layers.Layer):
  """Blend a pair of tensors: alpha * new + (1 - alpha) * old.

  `alpha` is a non-trainable scalar variable. It starts at 0 and is advanced
  externally with `incrementAlpha`; values above 1 are treated as 1.
  """
  def __init__(self, **kwargs):
    # Accept base-layer kwargs (name, trainable, dtype, ...) so Keras can
    # rebuild the layer from its saved config. The old no-argument __init__
    # rejected them.
    super(FadeAdd, self).__init__(**kwargs)
    self.alpha = tf.Variable(initial_value=0., trainable=False)

  def incrementAlpha(self, step=0.1):
    """Advance alpha by `step`, saturating at 1."""
    self.alpha.assign(tf.minimum(self.alpha + step, 1.))

  def call(self, input):
    """`input` is a pair (new, old); presumably same-shaped tensors."""
    new, old = input
    # Clamp on read instead of assigning the variable on every forward pass.
    # The output is the same, but call() no longer mutates layer state.
    alpha = tf.minimum(self.alpha, 1.)
    return (new * alpha) + (old * (1 - alpha))

class MinibatchStddev(layers.Layer):
  """Append a minibatch standard-deviation feature map (Progressive GAN).

  Samples are split into groups of `group_size` (clamped to the batch size).
  The per-position stddev across each group is averaged to a single scalar per
  group, which is broadcast back to every sample of the group as one extra
  channel. The batch size must be divisible by the effective group size.

  Args:
    group_size: samples per statistics group.
    epsilon: added to the variance before the square root for stability.
  """
  def __init__(self, group_size=4, epsilon=1e-8, **kwargs):
    super(MinibatchStddev, self).__init__(**kwargs)
    self.group_size = group_size
    self.epsilon = epsilon

  def get_config(self):
    config = super(MinibatchStddev, self).get_config()
    config.update({'group_size': self.group_size, 'epsilon': self.epsilon})
    return config

  def call(self, layer):
    shape = tf.shape(layer)
    group_size = tf.minimum(self.group_size, shape[0])
    # [G, M, H, W, C]: statistics are taken across the G axis.
    grouped = tf.reshape(layer, (group_size, -1, shape[1], shape[2], shape[3]))
    grouped -= tf.reduce_mean(grouped, axis=0, keepdims=True)
    variance = tf.reduce_mean(tf.math.square(grouped), axis=0)          # [M, H, W, C]
    # Was `+ 1e8`, which made the feature a constant ~1e4 regardless of the
    # data. The stabiliser must be tiny (ProGAN uses 1e-8).
    stddev = tf.math.sqrt(variance + self.epsilon)
    stddev = tf.reduce_mean(stddev, axis=[1, 2, 3], keepdims=True)      # [M, 1, 1, 1]
    stddev = tf.tile(stddev, [group_size, shape[1], shape[2], 1])       # [N, H, W, 1]
    return K.concatenate([layer, stddev], axis=3)

class SelfAttention(layers.Layer):
  """Self-attention over all spatial positions of an NHWC feature map.

  f, g and h are 1x1 convolutions of the input. The attention weights are
  softmax(g . f^T) over the flattened H*W positions, and the attended values
  are scaled by a learned scalar `gamma` and added to the input. `gamma` is
  initialised to zero, so a new layer starts out as the identity.

  Args:
    channelReduce: divisor for the number of f/g channels.
    name: layer name, passed straight to Keras.
    **kwargs: forwarded to `layers.Layer` (e.g. `trainable`, `dtype` when the
      layer is rebuilt from its config).
  """
  def __init__(self, channelReduce=1, name='', **kwargs):
    super(SelfAttention, self).__init__(name=name, **kwargs)
    self.channelReduce = channelReduce

  def get_config(self):
    # Include every constructor argument so from_config / load_model can
    # rebuild the layer. The base config supplies name, trainable and dtype.
    config = {'name': self.name, 'channelReduce': self.channelReduce}
    base_config = super(SelfAttention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    self.channels = input_shape[-1]
    self.filters_f_g = self.channels // self.channelReduce
    self.filters_h = self.channels

    kernel_shape_f_g = (1, 1) + (self.channels, self.filters_f_g)
    kernel_shape_h = (1, 1) + (self.channels, self.filters_h)

    # Residual scale, initialised to 0.
    self.gamma = self.add_weight(name='gamma', shape=[1], initializer='zeros', trainable=True)
    self.kernel_f = self.add_weight(shape=kernel_shape_f_g,
                                    initializer='glorot_uniform',
                                    name='kernel_f',
                                    trainable=True)
    self.kernel_g = self.add_weight(shape=kernel_shape_f_g,
                                    initializer='glorot_uniform',
                                    name='kernel_g',
                                    trainable=True)
    self.kernel_h = self.add_weight(shape=kernel_shape_h,
                                    initializer='glorot_uniform',
                                    name='kernel_h',
                                    trainable=True)

    # Base build() marks the layer as built.
    super(SelfAttention, self).build(input_shape)

  def call(self, input):
    def hw_flatten(x):
      # [N, H, W, C] -> [N, H*W, C]
      inp_shape = tf.shape(x)
      shape = [inp_shape[0], inp_shape[1] * inp_shape[2], inp_shape[3]]
      return tf.reshape(x, shape=shape)

    f_x = K.conv2d(input, kernel=self.kernel_f, strides=(1, 1), padding='same')
    g_x = K.conv2d(input, kernel=self.kernel_g, strides=(1, 1), padding='same')
    h_x = K.conv2d(input, kernel=self.kernel_h, strides=(1, 1), padding='same')

    f_x_flat = hw_flatten(f_x)  # [N, HW, C']
    g_x_flat = hw_flatten(g_x)  # [N, HW, C']

    # [N, HW, HW] attention logits between every pair of positions.
    s = K.batch_dot(g_x_flat, K.permute_dimensions(f_x_flat, (0, 2, 1)))

    beta = K.softmax(s, axis=-1)
    o = K.batch_dot(beta, hw_flatten(h_x))  # [N, HW, C]

    o = tf.reshape(o, shape=tf.shape(input))  # back to [N, H, W, C]
    return self.gamma * o + input


def generatorBase(latent_dim=512):
  """Build the initial 4x4 RGB generator.

  latent -> Dense(4*4*128) -> BN -> LeakyReLU(0.2) -> reshape 4x4x128
  -> GaussianNoise(0.5) -> Conv(256) -> BN -> LeakyReLU
  -> SelfAttention 'final_4' -> tanh 3-channel conv 'out_4'.
  The 'final_<size>' / 'out_<size>' names are what generatorAddStage looks up
  when growing the model.

  Args:
    latent_dim: length of the input latent vector. The training helpers sample
      512-dimensional noise, so keep the default unless they change too.

  Returns:
    The Keras generator model.
  """
  # Proper 1-tuple shape; the original `(512)` was just the int 512.
  inputLayer = layers.Input((latent_dim,))
  x = layers.Dense(4 * 4 * 128, use_bias=False)(inputLayer)
  x = layers.BatchNormalization(momentum=0.8)(x)
  x = layers.LeakyReLU(alpha=0.2)(x)
  x = layers.Reshape((4, 4, 128))(x)

  # Training-time noise on the 4x4 feature map (inactive at inference).
  x = tf.keras.layers.GaussianNoise(0.5)(x)

  x = layers.Conv2D(256, kernel_size=(3, 3), padding='same')(x)
  x = layers.BatchNormalization(momentum=0.8)(x)
  x = layers.LeakyReLU()(x)

  x = SelfAttention(name='final_4')(x)

  out = layers.Conv2D(3, (4, 4), strides=(1, 1), padding='same', use_bias=False, activation='tanh', name='out_4')(x)
  print(out.shape)
  model = tf.keras.models.Model(inputs=inputLayer, outputs=out)
  return model

def generatorAddStage(gen, newDepth=0, freeze=False, initialAlpha=0):
  """Grow the generator by one 2x-resolution stage with a fade-in.

  The feature map of layer 'final_<size>' (size = current output height) is
  nearest-upsampled 2x and passed through two Conv-BN-LeakyReLU-SelfAttention
  blocks. The second attention layer is named 'final_<2*size>', and a tanh RGB
  conv named 'out_<2*size>' follows. That output is blended by a FadeAdd
  layer with the nearest-upsampled previous output 'out_<size>'.

  Args:
    gen: current generator containing 'final_<size>' and 'out_<size>'.
    newDepth: filters of the new convs. 0 (or anything == 0, e.g. False)
      means half the channels of 'final_<size>'.
    freeze: if True, set `gen.trainable = False` before building on it.
    initialAlpha: starting blend weight of the returned FadeAdd. 0 shows only
      the upsampled old output; 1 shows only the new stage.

  Returns:
    (model, alpha): the grown generator and its FadeAdd layer.
  """
  print("Current Shape: ", gen.output.shape)

  if freeze:
    print("Freezing")
    gen.trainable = False

  lastSize = gen.output.shape[1]
  x = gen.get_layer('final_' + str(lastSize)).output
  print("Choosing layer ", x)

  if newDepth == 0:
    newDepth = x.shape[3] // 2

  print("New Depth: ", newDepth)

  x = layers.UpSampling2D(size=(2, 2), interpolation='nearest')(x)

  x = layers.Conv2D(newDepth, (3, 3), strides=(1, 1), padding='same', use_bias=False)(x)
  x = layers.BatchNormalization(momentum=0.8)(x)
  x = layers.LeakyReLU()(x)
  x = SelfAttention()(x)

  x = layers.Conv2D(newDepth, (3, 3), strides=(1, 1), padding='same', use_bias=False)(x)
  x = layers.BatchNormalization(momentum=0.8)(x)
  x = layers.LeakyReLU()(x)
  x = SelfAttention(name='final_' + str(lastSize * 2))(x)

  out = layers.Conv2D(3, (4, 4), strides=(1, 1), padding='same', use_bias=False, activation='tanh', name='out_' + str(lastSize * 2))(x)
  print("New Shape: ", out.shape)

  # Fade-in source: the previous RGB output, upsampled to the new size.
  lastOut = gen.get_layer('out_' + str(lastSize)).output
  up = layers.UpSampling2D((2, 2), interpolation='nearest')(lastOut)

  alpha = FadeAdd()
  # Honour `initialAlpha`; it used to be ignored, so the blend always started at 0.
  alpha.alpha.assign(float(initialAlpha))
  out = alpha([out, up])

  model = tf.keras.models.Model(inputs=gen.input, outputs=out)
  return model, alpha

def changeGenAlpha(gen, alpha, step=0.1):
  """Set the fade-in weight of a grown generator to `alpha + step` (capped at 1).

  The previous version computed the new value and discarded it, so it had no
  effect. The weight is written to the `alpha` variable of the model's last
  layer when that layer has one (the FadeAdd output of generatorAddStage).
  Otherwise, e.g. for the base generator, nothing is written.

  Args:
    gen: generator model.
    alpha: current blend weight.
    step: amount to add.

  Returns:
    The new (clamped) blend weight.
  """
  newAlpha = min(alpha + step, 1.)
  fadeVar = getattr(gen.layers[-1], 'alpha', None)
  if fadeVar is not None and hasattr(fadeVar, 'assign'):
    fadeVar.assign(newAlpha)
  return newAlpha

def reBaseModel(layers, inpTensor):
  """Apply each callable in `layers` in order, starting from `inpTensor`; return the final output."""
  out = inpTensor
  for fn in layers:
    out = fn(out)
  return out

def descriminatorBase():
  """Build the initial 4x4 critic and encoder, which share their layers.

  A from-RGB stem (GaussianNoise -> 1x1 Conv 'sup_conv_256' -> 'sup_bn_256'
  -> 'sup_act_256') feeds `baseLayers`. The critic applies all of them; the
  encoder stops before the final Dense(1).

  Returns:
    (dis, enc, baseLayers): critic model, encoder model, and the shared layer
    list that descriminatorAddStage prepends new stages to.
  """
  inputLayer = layers.Input((4, 4, 3))
  x = tf.keras.layers.GaussianNoise(0.07)(inputLayer)

  # From-RGB stem. Later stages look these layers up by name.
  x = tf.keras.layers.Conv2D(256, (1, 1), padding='same', name='sup_conv_256')(x)
  x = tf.keras.layers.BatchNormalization(name='sup_bn_256')(x)
  x = tf.keras.layers.LeakyReLU(name='sup_act_256')(x)

  # Built in the same order as before so auto-generated layer names match.
  baseLayers = [
      MinibatchStddev(),
      layers.Conv2D(512, kernel_size=(3, 3), strides=(1, 1), padding='same', name='depth_512'),
      layers.BatchNormalization(momentum=0.8),
      layers.LeakyReLU(),
      SelfAttention(),
      layers.GlobalAveragePooling2D(),
      layers.Dense(1),
  ]

  encoderLayers = baseLayers[:-1]
  encOut = reBaseModel(encoderLayers, x)
  desOut = reBaseModel(baseLayers, x)

  dis = tf.keras.models.Model(inputs=inputLayer, outputs=desOut)
  enc = tf.keras.models.Model(inputs=inputLayer, outputs=encOut)
  return dis, enc, baseLayers

def descriminatorAddStage(des, enc, baseLayers, newDepth=0, newSize=0, freeze=False):
  """Grow the critic and encoder by one input-resolution stage.

  Builds a new input of size `newSize` and a new from-RGB stem named
  `sup_conv_/sup_bn_/sup_act_<newDepth//2>` with newDepth//2 filters. A new
  block (`newLayers`: two 3x3 convs 'depth_<newDepth>' / 'depth2_<newDepth>'
  with BN/LeakyReLU, a SelfAttention, then 2x average pooling) follows the
  stem. Its output is blended by a FadeAdd layer with the previous stem
  (`sup_*_<newDepth>`, fetched from `des` by name) applied to the 2x
  average-pooled input. The blend is then fed through the existing
  `baseLayers`, so all earlier layers and weights are reused.

  Args:
    des: current critic; its `sup_*_<newDepth>` layers are looked up by name.
    enc: current encoder; its input shape and first 'depth_' layer supply the
      defaults below.
    baseLayers: layers applied after the stem. The last one is the critic head
      and is excluded from the encoder.
    newDepth: filters of the new block. 0 means half the channels of the first
      layer in `enc.layers` whose name contains 'depth_'.
    newSize: new input height/width. 0 means twice `enc.input.shape[2]`.
    freeze: if True, sets `trainable = False` on the previous `des` and `enc`.

  Returns:
    (des, enc, baseLayers, beta): new critic, new encoder, the layer list for
    the next call (new block prepended), and the FadeAdd layer.
  """
  print("Current Shape: ", enc.input.shape)

  if freeze:
    print("Freezing")
    des.trainable = False
    enc.trainable = False

  print("Previous Input layer ", enc.input)

  if newDepth == 0:
    # newDepth = enc.layers[5].output.shape[3] // 2
    # NOTE(review): relies on the newest 'depth_' layer appearing first in
    # enc.layers. If none is found, newDepth silently stays 0.
    for layer in enc.layers:
      if 'depth_' in layer.name:
        newDepth = layer.output.shape[3]//2
        print("New Depth: ", newDepth)
        break

  if newSize == 0:
    newSize = enc.input.shape[2] * 2

  print("New input ", newSize, newSize)

  inputLayer = layers.Input((newSize, newSize, 3))
  inp = tf.keras.layers.GaussianNoise(0.07)(inputLayer)

  # New from-RGB stem. It is not added to baseLayers; the next call fetches it
  # from the critic by its 'sup_*_<n>' name.
  processingLayers = []
  processingLayers.append(tf.keras.layers.Conv2D(newDepth // 2, (1, 1), padding='same', name='sup_conv_'+str(newDepth//2)))
  processingLayers.append(tf.keras.layers.BatchNormalization(name='sup_bn_'+str(newDepth//2)))
  processingLayers.append(tf.keras.layers.LeakyReLU(name='sup_act_'+str(newDepth//2)))
  # processingLayers.append(SelfAttention())

  x = reBaseModel(processingLayers, inp)

  # New block; the final AveragePooling2D halves H/W back to the previous input size.
  newLayers = []
  newLayers.append(layers.Conv2D(newDepth, (3, 3), strides=(1, 1), padding='same', name='depth_'+str(newDepth)))
  newLayers.append(layers.BatchNormalization(momentum=0.8))
  newLayers.append(layers.LeakyReLU())
  # newLayers.append(SelfAttention())
  newLayers.append(layers.Conv2D(newDepth, (3, 3), strides=(1, 1), padding='same', name='depth2_'+str(newDepth)))
  newLayers.append(layers.BatchNormalization(momentum=0.8))
  newLayers.append(layers.LeakyReLU())
  newLayers.append(SelfAttention())
  newLayers.append(layers.AveragePooling2D())

  newInp = reBaseModel(newLayers, x)
  # print(newInp, baseLayers)

  # Old path for the fade-in: downsample the input and reuse the previous stem,
  # whose newDepth filters match the new block's output channels.
  small = layers.AveragePooling2D((2, 2))(inp)
  sup = des.get_layer('sup_conv_'+str(newDepth))(small)
  sup = des.get_layer('sup_bn_'+str(newDepth))(sup)
  sup = des.get_layer('sup_act_'+str(newDepth))(sup)

  print("====>", sup)

  print(newInp.shape, sup.shape)
  # beta.alpha * new block + (1 - beta.alpha) * old stem.
  beta = FadeAdd()
  out = beta([newInp, sup])

  print("==>", out)

  # Critic uses every base layer; encoder stops before the final head layer.
  desOut = reBaseModel(baseLayers, out)
  encOut = reBaseModel(baseLayers[:-1], out)

  des = tf.keras.models.Model(inputs=inputLayer, outputs=desOut)
  enc = tf.keras.models.Model(inputs=inputLayer, outputs=encOut)

  baseLayers = newLayers + baseLayers
  return des, enc, baseLayers, beta

def generateBaseModels():
  """Build fresh 4x4 models: returns (gen, des, enc, baseLayers)."""
  generator = generatorBase()
  return (generator,) + tuple(descriminatorBase())

In [None]:
# Scratch cell: build the 4x4 generator, critic, encoder and shared critic layers.
# NOTE(review): the training cell calls generateBaseModels() again, so these
# models are only used by the exploratory cells that follow.
gen, des, enc, baseLayers = generateBaseModels()

In [None]:
# Inspect the current generator's layers and parameter counts.
gen.summary()

In [None]:
# Grow the generator by one stage. `freeze` is passed by keyword: positionally,
# False landed in `newDepth` and only worked because False == 0 selects the
# default depth.
gen, alpha = generatorAddStage(gen, freeze=False)

In [None]:
# Grow critic and encoder by one stage (input size doubles by default); `beta` is the new FadeAdd layer.
des, enc, baseLayers, beta = descriminatorAddStage(des, enc, baseLayers)

In [None]:
def trainGenEnc(gen, enc, real, batch_size, coeff=1, generator_optimizer=None, enc_optimizer=None):
  """One reconstruction step: real -> enc -> gen, L1 loss scaled by `coeff`.

  Updates both the encoder and the generator. The loss tensor is not reduced,
  so GradientTape differentiates its sum. `batch_size` is unused.
  """
  with tf.GradientTape() as enc_tape, tf.GradientTape() as gen_tape:
    real_enc = enc(real, training=True)
    enc_fake = gen(real_enc, training=True)
    gen_loss = generator_enc_loss(real, enc_fake) * coeff

  # Take both gradients after the tapes close. Previously the generator's
  # gradient computation and optimizer step ran while enc_tape was still
  # recording, adding work the encoder update never needed.
  gradients_of_generator = gen_tape.gradient(gen_loss, gen.trainable_variables)
  gradients_of_enc = enc_tape.gradient(gen_loss, enc.trainable_variables)

  generator_optimizer.apply_gradients(zip(gradients_of_generator, gen.trainable_variables))
  enc_optimizer.apply_gradients(zip(gradients_of_enc, enc.trainable_variables))


def trainDes(gen, des, real, batch_size, smooth, discriminator_optimizer):
  """One critic update with the WGAN-GP loss against fresh generator samples.

  The generator runs in inference mode and is not updated. `smooth` is unused.
  """
  with tf.GradientTape() as tape:
    z = tf.random.normal([batch_size, 512])
    generated = gen(z, training=False)
    des_loss = descriminator_WGANGPloss(real, generated, des, batch_size)

  critic_vars = des.trainable_variables
  grads = tape.gradient(des_loss, critic_vars)
  discriminator_optimizer.apply_gradients(zip(grads, critic_vars))

def trainDesGen(gen, des, real, batch_size, smooth, discriminator_optimizer=None, generator_optimizer=None):
  """Joint critic + generator update from one shared forward pass.

  The optimizers used to be read only from module globals, which this notebook
  never defines (trainGan keeps its optimizers as locals), so the call raised
  NameError. They can now be passed in; the global lookup is kept as a
  fallback for backward compatibility.

  Raises:
    ValueError: if an optimizer is neither passed nor available as a global.
  """
  if discriminator_optimizer is None:
    discriminator_optimizer = globals().get('discriminator_optimizer')
  if generator_optimizer is None:
    generator_optimizer = globals().get('generator_optimizer')
  if discriminator_optimizer is None or generator_optimizer is None:
    raise ValueError('trainDesGen needs discriminator_optimizer and generator_optimizer')

  with tf.GradientTape() as disc_tape, tf.GradientTape() as gen_tape:
    noise = tf.random.normal([batch_size, 512])
    fake = gen(noise, training=True)

    fake_output = des(fake, training=True)

    des_loss = descriminator_WGANGPloss(real, fake, des, batch_size)
    gen_loss = generator_loss(fake_output, smooth)

  # Gradients are taken after both tapes close, so neither tape records the
  # other network's backward pass or optimizer step.
  gradients_of_discriminator = disc_tape.gradient(des_loss, des.trainable_variables)
  gradients_of_generator = gen_tape.gradient(gen_loss, gen.trainable_variables)

  discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, des.trainable_variables))
  generator_optimizer.apply_gradients(zip(gradients_of_generator, gen.trainable_variables))

def trainGen(gen, des, batch_size, smooth, generator_optimizer=None):
  """One generator update against the critic, which runs in inference mode and is not updated."""
  with tf.GradientTape() as tape:
    samples = gen(tf.random.normal([batch_size, 512]), training=True)
    scores = des(samples, training=False)
    gen_loss = generator_loss(scores, smooth)

  gen_vars = gen.trainable_variables
  generator_optimizer.apply_gradients(zip(tape.gradient(gen_loss, gen_vars), gen_vars))

def evalGan(gen, des, data, batches, batch_size):
  """Rough monitoring metrics over `batches` rounds of fresh fake samples.

  Args:
    gen, des: generator and critic, both run in inference mode.
    data: batch of real images (the same batch is reused every round).
    batches: number of evaluation rounds.
    batch_size: fakes generated per round.

  Returns:
    (desAcc, genLoss): mean critic "accuracy" (fake scores <= 0.5 and real
    scores > 0.5) and mean generator loss.
    NOTE(review): the critic ends in Dense(1) with no activation, so the 0.5
    threshold is arbitrary for these unbounded scores.
  """
  desAcc = 0.
  genLoss = 0.
  # Score fakes and reals in separate calls. MinibatchStddev pools statistics
  # over groups of samples, so a concatenated batch let each set leak into
  # the other's scores. In inference mode the real scores don't change
  # between rounds, so they are computed once.
  real_output = tf.reshape(des(data, training=False), [-1])
  for _ in range(batches):
    fake = gen(tf.random.normal([batch_size, 512]), training=False)
    fake_output = tf.reshape(des(fake, training=False), [-1])

    labels = tf.concat((tf.zeros_like(fake_output), tf.ones_like(real_output)), axis=0)
    output = tf.concat((fake_output, real_output), axis=0)

    acc = tf.keras.metrics.binary_accuracy(labels, output, threshold=0.5)
    desAcc += float(acc.numpy())
    # generator_loss already averages over the batch. The former extra
    # `/ batch_size` shrank the reported loss by a factor of batch_size.
    genLoss += float(generator_loss(fake_output, 1).numpy())
  return desAcc / batches, genLoss / batches

def augmenter(size):
  """Return a tf.data map function that resizes sample['image'] to size x size.

  The sample dict is updated in place and returned. Nearest-neighbour
  resizing is used; per the tf.image.resize docs, `antialias` has no effect
  with that method.
  """
  target = [size, size]

  def augment(sample):
    resized = tf.image.resize(sample['image'], target, method='nearest', antialias=True)
    sample['image'] = resized
    return sample

  return augment

from IPython.display import clear_output

# @tf.function
def trainGan(data, name='A03', modeldir='/content/gdrive/My Drive/AI Research/GANs/models/', epochs=10, batch_size=5, loss='mse', smooth=1., sizes=[4, 8, 16, 32, 64, 128, 256], des_steps=2, gen_steps=1, iters=206):
  """Progressively train the generator, critic and encoder over `sizes`.

  For each resolution, train for `epochs` epochs of `iters` steps. Each step
  does `des_steps` critic updates, `gen_steps` generator updates and one
  encoder->generator reconstruction update. After every epoch, real and
  generated samples are plotted. After each resolution, the generator model
  and the critic/encoder weights are saved under `modeldir` and all networks
  grow by one stage. The module-level `gen`, `des`, `enc` and `baseLayers`
  are read and rebound.

  Args:
    data: tf.data.Dataset of dicts with an 'image' entry. Pixels are
      presumably uint8 in [0, 255]; they are rescaled to [-1, 1].
    name: filename prefix for the saved files.
    modeldir: directory, with trailing '/', that the .h5 files are written to.
    epochs: epochs at the first resolution; scaled by 1.6 after every stage.
    batch_size: training batch size.
    loss: unused.
    smooth: forwarded to the per-network train helpers.
    sizes: resolutions to train at, in order.
    des_steps: critic updates per step. Must be >= 1, because the last
      critic batch is reused for the reconstruction update.
    gen_steps: generator updates per step.
    iters: steps per epoch.

  Returns:
    List of per-epoch {'desAcc', 'genLoss'} dicts (previously built and discarded).
  """
  # The module-level `import time` sits in the notebook's last cell, after the
  # cell that calls this function.
  import time

  global gen, des, enc, baseLayers
  realData = data
  noise = tf.random.normal([64, 512])
  results = []
  gen_alpha, des_alpha = None, None
  initialCoeff = 1.
  # Per-step fade-in increment (the 220*60 schedule constant is unchanged).
  alphaStep = 0.1 / ((220 * 60) / batch_size)
  for size in sizes:
    coeff = initialCoeff
    print("Input shape: ", des.input.shape)
    currentData = realData.map(augmenter(size)).shuffle(4096).batch(batch_size, drop_remainder=True).repeat().prefetch(tf.data.experimental.AUTOTUNE)
    REAL = next(iter(realData.map(augmenter(size)).batch(64)))['image']
    REAL = (tf.cast(REAL, tf.float32) - 127.5) / 127.5
    iterData = iter(currentData)

    # NOTE(review): `tfa` (tensorflow_addons) is used here but never imported
    # in this notebook.
    generator_optimizer = tfa.optimizers.Lookahead(tf.keras.optimizers.Adam(1e-4))
    discriminator_optimizer = tfa.optimizers.Lookahead(tf.keras.optimizers.Adam(1e-4))
    enc_optimizer = tfa.optimizers.Lookahead(tf.keras.optimizers.Adam(3e-4))

    def getTrainers():
      def _trainDes(real):
        trainDes(gen, des, real, batch_size, smooth, discriminator_optimizer)

      def _trainGen():
        trainGen(gen, des, batch_size, smooth, generator_optimizer)

      def _trainGenEnc(real, coeff):
        trainGenEnc(gen, enc, real, batch_size, coeff, generator_optimizer, enc_optimizer)

      return tf.function(_trainDes), tf.function(_trainGen), tf.function(_trainGenEnc)

    def makeTrainStep():
      """Build a freshly traced train step bound to the current models and optimizers."""
      _trainDes, _trainGen, _trainGenEnc = getTrainers()

      @tf.function
      def trainStep(iterData, coeff):
        if des_alpha is not None and gen_alpha is not None:
          des_alpha.incrementAlpha(alphaStep)
          gen_alpha.incrementAlpha(alphaStep)

        for _ in range(des_steps):
          batch = next(iterData)
          real = (tf.cast(batch['image'], tf.float32) - 127.5) / 127.5
          _trainDes(real)

        for _ in range(gen_steps):
          _trainGen()

        # Reconstruction weight never drops below 0.15.
        _trainGenEnc(real, tf.maximum(coeff, 0.15))

      return trainStep

    trainStep = makeTrainStep()

    for epoch in range(epochs):
      print("Running epoch ", epoch)
      t = time.time()
      # Pass coeff as a tensor. Reading it from the closure froze its value at
      # trace time, so the per-epoch decay/doubling never reached the graph.
      coeffT = tf.constant(coeff, dtype=tf.float32)
      for _ in range(iters):
        try:
          trainStep(iterData, coeffT)
        except Exception as e:
          print("Error in epoch ", epoch, e)
          # Rebuild the whole step. Recreating only the inner tf.functions had
          # no effect on the already-traced trainStep.
          trainStep = makeTrainStep()
      coeff *= 0.9

      fake = gen(noise, training=False)
      real = REAL

      print("Evaluating:", time.time() - t)
      desAcc, genLoss = evalGan(gen, des, real, 10, batch_size)
      results.append({'desAcc': desAcc, 'genLoss': genLoss})
      print("Epoch ", epoch, desAcc, genLoss, "of ", epochs, "Epochs")

      print("Real: ")
      plotImages(real)

      print("Fake: ")
      plotImages(fake)

      # Strengthen the reconstruction term again while the critic dominates.
      if desAcc > 0.8:
        coeff *= 2
        coeff = min(coeff, 1.)

      if des_alpha is not None and gen_alpha is not None:
        print("Alpha, Beta: ", gen_alpha.alpha, des_alpha.alpha)

    gen.save(modeldir + name + '_' + str(size) + '_gen.h5')
    des.save_weights(modeldir + name + '_' + str(size) + '_des.h5')
    enc.save_weights(modeldir + name + '_' + str(size) + '_enc.h5')
    des, enc, baseLayers, des_alpha = descriminatorAddStage(des, enc, baseLayers, freeze=False)
    gen, gen_alpha = generatorAddStage(gen, freeze=False)
    epochs = int(epochs * 1.6)
    initialCoeff *= 0.6
    clear_output(wait=True)
  return results

In [None]:
# trainGan reads the module-level `time`, which is otherwise only imported in
# the notebook's last cell, so import it before training starts.
import time

# Start from fresh 4x4 models. A `global` statement is a no-op at notebook top
# level; trainGan declares the globals it rebinds itself.
gen, des, enc, baseLayers = generateBaseModels()
trainGan(lfw, epochs=10, batch_size=64, smooth=1, des_steps=1, gen_steps=1, name='A03')

In [None]:
import time