In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
import numpy as np
import tensorflow_datasets as tfds
import time

In [None]:
# Labeled Faces in the Wild from TensorFlow Datasets; only the 'train' split is loaded.
lfw = tfds.load(
    name='lfw',
    split='train',
    shuffle_files=True,
)

In [None]:
import matplotlib.pyplot as plt
def plotImages(imgs, max_images=64):
    """Display up to `max_images` images from a batch in a square grid.

    Args:
        imgs: batch of images shaped (N, H, W, C) with values in [-1, 1]
            (a tensor or array), i.e. the normalized real images or the
            generator's tanh output.
        max_images: cap on how many images are drawn (default 64, the size
            of the previous fixed 8x8 grid). Extra images are ignored instead
            of raising a subplot-index error.
    """
    images = np.asarray(imgs)[:max_images]
    count = images.shape[0]
    if count == 0:
        return
    # Undo the [-1, 1] normalization. Round and clip before converting to
    # uint8 so out-of-range values saturate instead of relying on an
    # out-of-range float->uint8 cast.
    pixels = np.clip(np.rint(images * 127.5 + 127.5), 0, 255).astype(np.uint8)
    side = int(np.ceil(np.sqrt(count)))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    _, axes = plt.subplots(side, side, figsize=(8, 8), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    for ax, img in zip(axes.flat, pixels):
        ax.imshow(img)
    plt.show()

In [None]:

# Constants
# Weight initializer passed to every Dense/Conv layer below: normal with mean 0, stddev 0.02.
KERNEL_INIT = tf.keras.initializers.RandomNormal(stddev=0.02, mean=0.0)

def lerp(a, b, t):
    """Linearly interpolate between `a` (t=0) and `b` (t=1); broadcasts like plain arithmetic."""
    delta = b - a
    return a + delta * t

def gradient_penalty(discriminator, real_images, fake_images):
    """WGAN-GP gradient penalty on random real/fake interpolations.

    Args:
        discriminator: critic model called as `discriminator(x, training=True)`.
        real_images: real batch; assumed rank 4 (batch, H, W, C), since the
            per-sample alpha shape and the reduction axes below are rank-4.
        fake_images: generator output with the same shape as `real_images`.

    Returns:
        Scalar tensor: batch mean of (||dD/dx||_2 - 1)^2 at the interpolated points.
    """
    batch_size = tf.shape(real_images)[0]
    # One mixing coefficient per sample, broadcast over H, W, C. Match the
    # image dtype so the interpolation never mixes float types.
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0, dtype=real_images.dtype)
    interpolated_images = lerp(real_images, fake_images, alpha)
    with tf.GradientTape() as tape:
        tape.watch(interpolated_images)
        pred = discriminator(interpolated_images, training=True)
    gradients = tape.gradient(pred, interpolated_images)
    # sqrt has an infinite derivative at 0. A sample whose gradient is exactly
    # zero would otherwise make the penalty's own gradient NaN, and that NaN
    # poisons the critic update. A tiny epsilon keeps it finite.
    norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean(tf.square(norm - 1.0))

def discriminator_loss(real_output, fake_output):
    """Wasserstein critic loss: mean(D(fake)) - mean(D(real)), to be minimized by the critic."""
    mean_fake = tf.reduce_mean(fake_output)
    mean_real = tf.reduce_mean(real_output)
    return mean_fake - mean_real

def generator_loss(fake_output):
    """Generator loss: negated mean critic score of fakes (minimizing it raises that score)."""
    return tf.negative(tf.reduce_mean(fake_output))

class PixelNormalization(layers.Layer):
    """Rescale each feature vector along the last axis to unit root-mean-square."""

    def call(self, inputs):
        # Mean of squares over the last axis, kept for broadcasting. The 1e-8
        # keeps rsqrt finite for an all-zero vector.
        mean_square = tf.reduce_mean(tf.square(inputs), axis=-1, keepdims=True)
        scale = tf.math.rsqrt(mean_square + 1e-8)
        return inputs * scale

class MinibatchStddev(layers.Layer):
    """Append a minibatch standard-deviation channel (ProGAN-style).

    The batch is split into groups of up to `group_size` samples. For each
    group, the per-element standard deviation across its members is averaged
    into one scalar, which is broadcast as an extra channel onto every member.

    Args:
        group_size: maximum samples per group (default 4, the previously
            hard-coded value). If the batch is not divisible by the effective
            group size, the whole batch is used as a single group instead of
            failing in `reshape`.
    """

    def __init__(self, group_size=4, **kwargs):
        super().__init__(**kwargs)
        self.group_size = group_size

    def call(self, inputs):
        # Assumes rank-4 inputs (batch, H, W, C): the reshape/tile index shape[1..3].
        shape = tf.shape(inputs)
        batch = shape[0]
        # Clamp to >= 1 so an empty batch does not lead to a modulo by zero.
        group_size = tf.maximum(tf.minimum(self.group_size, batch), 1)
        # Non-divisible batch: fall back to one group spanning the whole batch.
        group_size = tf.where(tf.equal(batch % group_size, 0), group_size, batch)
        # Layout (G, M, H, W, C): sample n = g * M + m lands at [g, m], so the
        # reductions over axis 0 are within the group of samples sharing m.
        minibatch = tf.reshape(inputs, (group_size, -1, shape[1], shape[2], shape[3]))
        centered = minibatch - tf.reduce_mean(minibatch, axis=0)
        stddev = tf.sqrt(tf.reduce_mean(tf.square(centered), axis=0) + 1e-8)
        stddev = tf.reduce_mean(stddev, axis=[1, 2, 3], keepdims=True)
        # Tiling by G along the batch axis reproduces the g * M + m order, so
        # every sample receives its own group's statistic.
        stddev = tf.tile(stddev, [group_size, shape[1], shape[2], 1])
        return tf.concat([inputs, stddev], axis=-1)

    def get_config(self):
        # Needed so a saved model can re-create the layer with its group size.
        config = super().get_config()
        config.update({'group_size': self.group_size})
        return config

def build_generator():
    """Build the 4x4 generator: a 512-d latent vector -> a (4, 4, 3) image in [-1, 1].

    The latent vector is projected by Dense into a 4x4x512 tensor and
    pixel-normalized. It then passes through a 3x3 Conv2D and a stride-1 3x3
    Conv2DTranspose (128 filters each), each followed by pixel normalization
    and LeakyReLU(0.2). A final 1x1 conv to RGB applies tanh.
    """
    latent = layers.Input(shape=(512,))

    def norm_act(t):
        # Pixel norm, then leaky ReLU, as applied after each hidden conv.
        t = PixelNormalization()(t)
        return layers.LeakyReLU(alpha=0.2)(t)

    h = layers.Dense(4 * 4 * 512, use_bias=False, kernel_initializer=KERNEL_INIT)(latent)
    h = layers.Reshape((4, 4, 512))(h)
    h = PixelNormalization()(h)
    h = layers.Conv2D(128, (3, 3), padding='same', kernel_initializer=KERNEL_INIT)(h)
    h = norm_act(h)
    h = layers.Conv2DTranspose(128, (3, 3), strides=(1, 1), padding='same', kernel_initializer=KERNEL_INIT)(h)
    h = norm_act(h)
    rgb = layers.Conv2D(3, (1, 1), padding='same', activation='tanh', kernel_initializer=KERNEL_INIT)(h)
    return models.Model(latent, rgb)

def build_discriminator():
    """Build the 4x4 critic: a (4, 4, 3) image -> one unbounded score (no output activation).

    A minibatch-stddev channel is appended first. Three 128-filter 'same'
    convolutions with kernels 1x1, 3x3 and 4x4 follow, each with
    LeakyReLU(0.2). The result is flattened into a single linear unit.
    """
    image = layers.Input(shape=(4, 4, 3))
    h = MinibatchStddev()(image)
    for kernel in ((1, 1), (3, 3), (4, 4)):
        h = layers.Conv2D(128, kernel, padding='same', kernel_initializer=KERNEL_INIT)(h)
        h = layers.LeakyReLU(alpha=0.2)(h)
    h = layers.Flatten()(h)
    score = layers.Dense(1, kernel_initializer=KERNEL_INIT)(h)
    return models.Model(image, score)

def train_step(generator, discriminator, batch_size, generator_optimizer, discriminator_optimizer, data):
    """Run one simultaneous WGAN-GP update of the critic and the generator.

    Args:
        generator: model mapping (batch, 512) noise to images.
        discriminator: critic mapping images to one score per sample.
        batch_size: nominal batch size, kept for backward compatibility. The
            noise batch now follows the size of the real batch actually drawn,
            so a short batch cannot cause a shape mismatch in the penalty.
        generator_optimizer: Keras optimizer for the generator weights.
        discriminator_optimizer: Keras optimizer for the critic weights.
        data: iterator whose next() yields a dict with an 'image' entry of raw
            pixel values (assumed to be in [0, 255], given the normalization below).

    Returns:
        (gen_loss, disc_loss) as scalar tensors. disc_loss includes the
        gradient penalty weighted by 10.
    """
    # Loading and normalizing data needs no gradients, so keep it off the tapes.
    real_images = tf.cast(next(data)['image'], tf.float32)
    real_images = (real_images - 127.5) / 127.5  # [0, 255] -> [-1, 1], the generator's tanh range
    noise = tf.random.normal([tf.shape(real_images)[0], 512])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_images = generator(noise, training=True)

        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(fake_images, training=True)

        # Computed inside disc_tape so the penalty's (second-order) gradient
        # reaches the critic weights.
        gp = gradient_penalty(discriminator, real_images, fake_images)
        disc_loss = discriminator_loss(real_output, fake_output) + gp * 10.0
        gen_loss = generator_loss(fake_output)

    # Both gradients come from the same forward pass, before either model is updated.
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)

    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))

    return gen_loss, disc_loss

def augmenter(size, method='bilinear', antialias=True):
    """Return a tf.data map function that resizes `sample['image']` to (size, size).

    Args:
        size: target height and width in pixels.
        method: `tf.image.resize` method. The default is 'bilinear' because
            `tf.image.resize` ignores `antialias` for 'nearest'. Nearest-neighbour
            downsampling of full face images to a few pixels just samples
            isolated pixels and aliases heavily.
        antialias: apply an anti-aliasing filter when downsampling.

    Returns:
        A function mapping a sample dict to a new dict whose 'image' is
        resized. For any method other than 'nearest', `tf.image.resize`
        returns float32 values (in the input's value range).
    """
    def augment(sample):
        resized = tf.image.resize(sample['image'], [size, size], method=method, antialias=antialias)
        # Build a new dict instead of mutating the incoming element.
        return {**sample, 'image': resized}
    return augment

def train(generator, discriminator, epochs, batch_size=16):
    """Train the 4x4 generator/critic pair on `lfw` with WGAN-GP.

    Each epoch runs len(lfw) // batch_size calls to `train_step` and prints the
    last step's losses and the elapsed time. It then plots a fixed batch of
    real images, plus generator samples from a fixed noise batch, so that
    epochs are visually comparable.

    Args:
        generator: model built by `build_generator` (512-d noise -> 4x4 RGB).
        discriminator: critic built by `build_discriminator`.
        epochs: number of passes.
        batch_size: images per training step.

    Raises:
        ValueError: if `lfw` has fewer than `batch_size` images, so no full
            batch can be formed (previously this failed with an unrelated error).
    """
    iters = len(lfw) // batch_size
    if iters == 0:
        raise ValueError(
            f'batch_size={batch_size} exceeds the dataset size {len(lfw)}; no full batch can be formed')

    generator_optimizer = optimizers.Adam(1e-4, beta_1=0.5, beta_2=0.9)
    discriminator_optimizer = optimizers.Adam(1e-4, beta_1=0.5, beta_2=0.9)
    # Resize to 4x4, shuffle, and batch. repeat() makes the iterator endless,
    # so an epoch is defined purely by `iters`.
    current_data = (
        lfw.map(augmenter(4), num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .shuffle(4096)
        .batch(batch_size, drop_remainder=True)
        .repeat()
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
    iter_data = iter(current_data)

    # Fixed evaluation inputs, reused every epoch: the first real batch
    # (capped at 64 to fit an 8x8 grid) and one noise batch.
    eval_real = next(iter_data)['image'][:64]
    eval_real = (tf.cast(eval_real, tf.float32) - 127.5) / 127.5
    eval_noise = tf.random.normal([64, 512])

    print(f'Iterations per epoch: {iters}')
    for epoch in range(epochs):
        start = time.time()
        for _ in range(iters):
            gen_loss, disc_loss = train_step(generator, discriminator, batch_size, generator_optimizer, discriminator_optimizer, iter_data)
        print(f'Epoch {epoch+1}, Gen Loss: {gen_loss.numpy()}, Disc Loss: {disc_loss.numpy()}, Time: {time.time()-start}s')

        fake = generator(eval_noise, training=False)

        print("Real: ")
        plotImages(eval_real)

        print("Fake: ")
        plotImages(fake)


In [None]:

# Seed the Python, NumPy and TensorFlow RNGs before building the models, so
# weight initialization and sampled noise repeat across Restart & Run All.
# (File shuffling in the earlier tfds.load call happens before this and is not covered.)
tf.keras.utils.set_random_seed(0)

generator = build_generator()
discriminator = build_discriminator()
# summary() prints and returns None. Call each one as a statement rather than
# building a throwaway tuple just for the side effects.
generator.summary()
discriminator.summary()
train(generator, discriminator, epochs=10, batch_size=64)
