In [None]:
%load_ext tensorboard

In [None]:
import os
import numpy as np

import matplotlib.pyplot as plt
import tensorflow as tf

import tensorflow_addons as tfa

from tensorflow.keras import Model

from IPython.display import clear_output
from tensorflow.keras.layers import Dense, Conv2D, Input, Reshape, LeakyReLU, UpSampling2D, Flatten, AveragePooling2D
import tensorflow.keras.initializers as initer

In [None]:
# --- Training hyper-parameters ---
BATCH_SIZE = 32          # images per training step
NOISE_DIM = 128          # latent dimensionality of z (and of w after the mapping network)
LAMBDA = 10              # weight of the gradient penalty in the critic loss
EPOCHs = 200             # last epoch to train (inclusive)
N_CRITIC = 5             # critic updates per generator update

# --- Bookkeeping ---
CURRENT_EPOCH = 1        # first epoch to run; recomputed when a checkpoint is restored
SAVE_EVERY_N_EPOCH = 5   # checkpoint frequency, in epochs

# --- Output locations ---
LOG_DIR, CKPT_DIR = './results/logs/', './results/models_weight'

In [None]:
# preprocess data
import os 
import pathlib

def normalize(image):
    '''
        Linearly rescale pixel values to [-1, 1] as float32.

        Assumes the input holds values in [0, 255] (decoded 8-bit image data);
        127.5 is the midpoint of that range and also its half-width.
    '''
    half_range = 127.5
    as_float = tf.cast(image, tf.float32)
    return (as_float - half_range) / half_range

def preprocess_image(file_path, img_size=64):
    '''
        Read one JPEG file and return it as an (img_size, img_size, 3) float32
        tensor scaled to [-1, 1].

        Args:
            file_path: path (string tensor) of the image file.
            img_size: target height and width after resizing.
    '''
    raw_bytes = tf.io.read_file(file_path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)  # force RGB
    resized = tf.image.resize(decoded, (img_size, img_size))
    return normalize(resized)

# Dataset root. Set the FACES_DATA_DIR environment variable to use another
# location; the fallback is the original author's machine-specific path.
data_path = pathlib.Path(
    os.environ.get("FACES_DATA_DIR", "/Users/lubaixun/tensorflow-DL/data/faces")
)

file_path = [str(path) for path in data_path.glob('*.jpg')]
if not file_path:
    # Fail loudly here instead of with a bare StopIteration at next(iter(...)).
    raise FileNotFoundError('No *.jpg files found in {}'.format(data_path))

train_data_path = tf.data.Dataset.from_tensor_slices(file_path)
# Shuffle the (cheap) file names over the whole dataset *before* decoding:
# glob returns files in filesystem order, and a 100-element buffer of decoded
# images would barely mix that order.
train_data = (
    train_data_path
    .shuffle(len(file_path))
    .map(preprocess_image)
    .batch(BATCH_SIZE)
)

# Preview the first image of one batch.
imgs = next(iter(train_data))[0]
clear_output()
# Printed after clear_output so the count is not wiped immediately.
print(len(file_path))

# preprocess_image scales pixels to [-1, 1]; imshow expects floats in [0, 1].
plt.imshow(imgs * 0.5 + 0.5)
plt.show()

In [None]:
class AdaNorm(tf.keras.layers.Layer):
    """Parameter-free per-sample, per-channel normalization over axes 1 and 2.

    Each sample/channel is shifted to zero mean and scaled to unit variance
    over axes (1, 2) — the spatial axes, assuming NHWC input (TODO confirm for
    other layouts). There are no learned affine parameters.

    Args:
        epsilon: added to the variance before the reciprocal square root for
            numerical stability.
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def call(self, x):
        # tf.nn.moments returns the *centered* variance E[(x - mean)^2]. The raw
        # second moment E[x^2] would not standardize the features whenever the
        # mean is non-zero.
        mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
        return (x - mean) * tf.math.rsqrt(variance + self.epsilon)

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config

class AdaIN(tf.keras.layers.Layer):
    """Style modulation: a Dense layer maps a style vector to a per-channel
    (scale, bias) pair that is applied as `scale * x + bias`.

    Called with a pair `(x, w)`. `x` is a feature map whose last axis has `c`
    channels (presumably NHWC — the (1, 1) broadcast axes assume rank 4).
    `w` is a style vector per sample. Note that this layer does not normalize
    `x` itself.
    """

    def __init__(self):
        super().__init__()

    def build(self, input_shapes):
        feature_shape, _style_shape = input_shapes
        self.c = feature_shape[-1]
        # Produces 2*c values per sample: c scales followed by c biases.
        self.dense = Dense(self.c*2)

    def call(self, inp):
        features, style = inp
        params = tf.reshape(self.dense(style), [-1, 2, 1, 1, self.c])
        scale = params[:, 0]
        bias = params[:, 1]
        return scale * features + bias

        
class AddNoise(tf.keras.layers.Layer):
    """Adds a noise map, scaled by a learned per-channel weight, to a feature map.

    Called with a pair `(x, noise)`. `noise` is cropped to x's height/width
    (its top-left corner is used), so one full-resolution noise tensor can be
    shared by all resolutions. The 1-channel noise broadcasts against the
    (1, 1, 1, c) weight.
    """

    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        # The unpacking also enforces a rank-4 feature input.
        _, _, _, channels = input_shape[0]
        self.b = self.add_weight(
            shape=[1, 1, 1, channels],
            initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
            trainable=True,
            name="noise_weights",
        )

    def call(self, inp):
        x, noise = inp
        height, width = x.shape[1], x.shape[2]
        cropped = noise[:, :height, :width, :]
        return self.b * cropped + x
    
        
    

In [None]:
def LatentMapping(x, latent_dim, num_layers=5):
    """Mapping network (Keras functional API): z -> w.

    Structure: Dense, then `num_layers` repetitions of LeakyReLU(0.2) -> Dense.
    Every Dense layer is `latent_dim` wide with He-normal kernel init. Keras
    Dense acts on the last axis, so a (batch, n_blocks, latent_dim) input
    yields one w per block.

    Returns:
        The symbolic output tensor (no activation after the final Dense).
    """
    hidden = Dense(latent_dim, kernel_initializer=initer.HeNormal())(x)

    remaining = num_layers
    while remaining > 0:
        activated = LeakyReLU(0.2)(hidden)
        hidden = Dense(latent_dim, kernel_initializer=initer.HeNormal())(activated)
        remaining -= 1

    return hidden

def StyleBlock(inp, filter, up_sample=True):
    """One synthesis block of the generator (Keras functional API).

    Args:
        inp: tuple `(x, w, noise)` holding the incoming feature map, this
            block's style vector, and the shared full-resolution noise map.
        filter: number of output channels of the 3x3 convolution.
        up_sample: when truthy, double the spatial size (bilinear) after AdaIN.

    Pipeline: AdaIN(x, w) -> [UpSampling2D x2] -> Conv2D 3x3 -> LeakyReLU(0.2)
    -> AddNoise -> AdaNorm.
    """
    feat, style_vec, noise_map = inp

    feat = AdaIN()((feat, style_vec))
    if up_sample:
        feat = UpSampling2D((2, 2), interpolation="bilinear")(feat)

    conv = Conv2D(filter, 3, 1, padding="same", kernel_initializer=initer.HeNormal())
    feat = conv(feat)
    feat = LeakyReLU(0.2)(feat)
    feat = AddNoise()((feat, noise_map))
    return AdaNorm()(feat)


In [None]:
def get_generator(img_shape, latent_dim=NOISE_DIM):
    """Build the StyleGAN-like generator.

    The synthesis path starts from a learned 4x4x256 tensor. That tensor is a
    Dense layer applied to `input_tensor`, which the training code in this
    notebook feeds with ones. One style block runs per resolution
    4, 8, 16, ... not exceeding `img_shape[1]`. The first block keeps 4x4 and
    every later block upsamples x2.

    Args:
        img_shape: (height, width, channels) of the target images. Height and
            width size the noise input; width drives the block count (square
            images assumed).
        latent_dim: size of z and of w.

    Returns:
        (model, n_style_block). The model takes
        [input_tensor (batch, 1), z (batch, n_style_block, latent_dim),
        noise (batch, H, W, 1)] and outputs tanh images with 3 channels.
    """
    inp_size = 4
    n_style_block = 0
    while inp_size * 2 ** n_style_block <= img_shape[1]:
        n_style_block += 1

    inp = Input((1,), name="input_tensor")
    z = Input((n_style_block, latent_dim,), name="z")
    noise = Input((img_shape[0], img_shape[1], 1), name="noise")

    w = LatentMapping(z, latent_dim)

    start = Dense(inp_size * inp_size * 256, kernel_initializer=initer.HeNormal())(inp)
    start = Reshape((inp_size, inp_size, -1))(start)

    x = AddNoise()((start, noise))
    x = AdaNorm()(x)
    for block_idx in range(n_style_block):
        x = StyleBlock((x, w[:, block_idx], noise), 256, up_sample=block_idx > 0)

    out = Conv2D(3, 7, 1, padding="same", activation='tanh')(x)

    return Model([inp, z, noise], out, name="generator"), n_style_block


# Optional architecture inspection, currently disabled. Uncomment the two lines
# below to build a standalone generator and print its summary. The triple-quoted
# block after them is a no-op string expression that holds a plot_model call
# for reference.
#model, n = get_generator((64, 64, 3))
#model.summary()

'''
tf.keras.utils.plot_model(
    model,
    to_file='model.png',
    show_shapes=True)
'''


In [None]:
def get_discriminator(img_dim, filter):
    """Build the WGAN-GP critic.

    Stem: 3x3 conv with `filter` channels + LeakyReLU(0.2).
    Four stages with filter*2, *4, *8, *16 channels. Each stage is a 3x3 conv,
    then InstanceNormalization, LeakyReLU(0.2) and 2x average pooling.
    Head: Flatten -> Dense(128) -> LeakyReLU(0.2) -> Dense(1). The output is
    an unbounded score per image (no sigmoid), as used by the Wasserstein loss
    in the train steps.

    Args:
        img_dim: input image shape, e.g. (64, 64, 3).
        filter: base channel count.
    """
    image = Input(img_dim)

    h = Conv2D(filter, 3, 1, padding='same')(image)
    h = LeakyReLU(0.2)(h)

    for multiplier in (2, 4, 8, 16):
        h = Conv2D(filter * multiplier, 3, 1, padding='same')(h)
        h = tfa.layers.InstanceNormalization()(h)
        h = LeakyReLU(0.2)(h)
        h = AveragePooling2D()(h)

    h = Flatten()(h)
    h = Dense(128)(h)
    h = LeakyReLU(0.2)(h)
    score = Dense(1)(h)

    return Model(image, score)

#model.summary()

In [None]:
generator, n_block = get_generator((64, 64, 3))
discriminator = get_discriminator((64, 64, 3), 32)

G_optimizer = tf.keras.optimizers.Adam(0.0001, beta_1=0.5)
D_optimizer = tf.keras.optimizers.Adam(0.0001, beta_1=0.5)

# Checkpoint the optimizers as well as the models. Otherwise a resumed run
# restarts Adam's moment estimates from zero. restore() only matches what is
# present, so older checkpoints without optimizer entries still restore weights.
ckpt = tf.train.Checkpoint(
    models=[generator, discriminator],
    G_optimizer=G_optimizer,
    D_optimizer=D_optimizer,
)

# save model

summary_Writer = tf.summary.create_file_writer(LOG_DIR)
ckpt_manager = tf.train.CheckpointManager(ckpt, CKPT_DIR, max_to_keep=5)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    # The suffix is the save number, not an epoch. The training loop saves once
    # every SAVE_EVERY_N_EPOCH epochs, so save k corresponds to epoch
    # k * SAVE_EVERY_N_EPOCH. Training resumes at the epoch after it.
    latest_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
    restored_epoch = latest_epoch * SAVE_EVERY_N_EPOCH
    CURRENT_EPOCH = restored_epoch + 1
    print ('Latest checkpoint of epoch {} restored!!'.format(restored_epoch))

In [None]:
# define train step

def WGAN_GP_train_d_step(real_image, batch_size, n_style_block=n_block):
    """Run one critic (discriminator) update with the WGAN-GP objective.

    Loss: mean(D(fake)) - mean(D(real)) + LAMBDA * mean((||grad D(x_hat)||_2 - 1)^2),
    where x_hat are random per-sample interpolations between real and fake images.

    Args:
        real_image: batch of real images (presumably (batch, H, W, 3) in [-1, 1],
            as produced by `preprocess_image`).
        batch_size: number of samples in `real_image` (the last batch may be smaller).
        n_style_block: number of generator style blocks. One z per sample is
            repeated this many times along axis 1.

    Returns:
        The scalar critic loss for this step.
    """
    real_image = tf.dtypes.cast(real_image, tf.float32)
    epsilon = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0, maxval=1)

    inp = tf.ones((batch_size, 1))
    z = tf.repeat(tf.random.normal((batch_size, 1, NOISE_DIM)), n_style_block, axis=1)
    noise = tf.random.normal((batch_size, real_image.shape[1], real_image.shape[2], 1))

    # Only the critic is updated here, so the generator forward pass is kept off
    # the tapes. This avoids holding its activations for a backward pass that is
    # never taken.
    fake_image = generator([inp, z, noise], training=True)
    fake_image_mixed = epsilon * real_image + ((1 - epsilon) * fake_image)

    with tf.GradientTape() as d_tape:
        with tf.GradientTape() as gp_tape:
            # fake_image_mixed is not derived from a watched variable on this
            # tape, so it has to be watched explicitly.
            gp_tape.watch(fake_image_mixed)
            fake_mixed_pred = discriminator([fake_image_mixed], training=True)

        # Taken inside d_tape so the penalty is differentiable w.r.t. the critic
        # weights (second-order gradient).
        grads = gp_tape.gradient(fake_mixed_pred, fake_image_mixed)
        # Tiny epsilon keeps sqrt's gradient finite if a norm is exactly zero.
        grad_norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
        gradient_penalty = tf.reduce_mean(tf.square(grad_norms - 1))

        fake_pred = discriminator([fake_image], training=True)
        real_pred = discriminator([real_image], training=True)

        D_loss = tf.reduce_mean(fake_pred) - tf.reduce_mean(real_pred) + LAMBDA * gradient_penalty

    D_gradients = d_tape.gradient(D_loss, discriminator.trainable_variables)
    D_optimizer.apply_gradients(zip(D_gradients, discriminator.trainable_variables))

    return D_loss
                                                

def WGAN_GP_train_g_step(real_image, batch_size, n_style_block=n_block):
    """Run one generator update with the Wasserstein generator loss -mean(D(G(z))).

    Args:
        real_image: only its spatial size (shape[1], shape[2]) is used, to size
            the noise map fed to the generator.
        batch_size: number of images to generate.
        n_style_block: number of generator style blocks. One z per sample is
            repeated this many times along axis 1.

    Returns:
        The scalar generator loss for this step.
    """
    ones = tf.ones((batch_size, 1))
    single_z = tf.random.normal((batch_size, 1, NOISE_DIM))
    z = tf.repeat(single_z, n_style_block, axis=1)
    height, width = real_image.shape[1], real_image.shape[2]
    noise_map = tf.random.normal((batch_size, height, width, 1))

    with tf.GradientTape() as tape:
        generated = generator([ones, z, noise_map], training=True)
        critic_score = discriminator([generated], training=True)
        G_loss = -tf.reduce_mean(critic_score)

    params = generator.trainable_variables
    G_optimizer.apply_gradients(zip(tape.gradient(G_loss, params), params))

    return G_loss

In [None]:
OUTPUT_PATH = r'./results/out_imgs/'
# makedirs also creates ./results when missing (os.mkdir would raise) and is a
# no-op on re-runs.
os.makedirs(OUTPUT_PATH, exist_ok=True)

def generate_and_save_images(model, epoch, path=OUTPUT_PATH, num_sample=16, figure_size=(12, 12), subplot=(4,4), save=True):
    """Sample `num_sample` images from the generator, plot them in a grid, and optionally save.

    Args:
        model: generator built by `get_generator`, with inputs [input_tensor, z, noise].
        epoch: used in the saved file name ('image_at_epoch_XXXX.png').
        path: output directory (created if needed when saving).
        num_sample: number of images; must fit in subplot[0] * subplot[1] cells.
        figure_size: matplotlib figure size in inches.
        subplot: (rows, cols) of the grid.
        save: when True, write the figure to `path`.
    """
    # Derive the latent and noise shapes from the generator's own inputs
    # instead of hard-coding 64x64 and relying on the global n_block.
    _, n_style_block, latent_dim = model.input_shape[1]
    _, noise_h, noise_w, _ = model.input_shape[2]

    inp = tf.ones((num_sample, 1))
    z = tf.repeat(tf.random.normal((num_sample, 1, latent_dim)), n_style_block, axis=1)
    noise = tf.random.normal((num_sample, noise_h, noise_w, 1))

    predictions = model.predict([inp, z, noise])

    fig = plt.figure(figsize=figure_size)
    for i in range(predictions.shape[0]):
        axs = fig.add_subplot(subplot[0], subplot[1], i + 1)
        # Generator output is tanh in [-1, 1]; imshow expects [0, 1].
        axs.imshow(predictions[i] * 0.5 + 0.5)
        axs.axis('off')
    if save:
        os.makedirs(path, exist_ok=True)
        fig.savefig(os.path.join(path, 'image_at_epoch_{:04d}.png'.format(epoch)))
    plt.show()
    # Release the figure so memory does not grow across epochs.
    plt.close(fig)

In [None]:
import time
n_critic_count = 0
# Defined up front so the per-epoch logging never hits a NameError when an
# epoch ends before the first generator step (fewer than N_CRITIC batches).
g_loss = None
d_loss = None

for epoch in range(CURRENT_EPOCH, EPOCHs+1):
    start = time.time()

    print('Start of epoch %d' % (epoch,))

    for step, imgs in enumerate(train_data):
        current_batch_size = imgs.shape[0]
        # No per-step clear_output(): it erased the epoch header and the
        # progress dots printed below.

        # Train critic (discriminator)
        d_loss = WGAN_GP_train_d_step(imgs, batch_size=current_batch_size)
        n_critic_count += 1

        # Train generator once every N_CRITIC critic steps (the counter carries
        # across epochs).
        if n_critic_count >= N_CRITIC:
            g_loss = WGAN_GP_train_g_step(imgs, batch_size=current_batch_size)
            n_critic_count = 0

        if step % 100 == 0:
            print ('.', end='')

    print('\n Epoch %d finished ~ ~ ~ '  % (epoch,))

    # Logs the loss of the last step of the epoch.
    with summary_Writer.as_default():
        if g_loss is not None:
            tf.summary.scalar('g_loss', g_loss, step=epoch)
        if d_loss is not None:
            tf.summary.scalar('d_loss', d_loss, step=epoch)

    if epoch % SAVE_EVERY_N_EPOCH == 0:

        clear_output(wait=True)

        #save model
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch, ckpt_save_path))

    print ('Time taken for epoch {} is {} sec\n'.format(epoch,time.time()-start))
    generate_and_save_images(generator, epoch)



In [None]:
%tensorboard --logdir LOG_DIR