In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = str(2)


import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import *

from tqdm import tqdm
import matplotlib.pyplot as plt
import glob
from sklearn.utils import shuffle

from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint

# example of loading the generator model and generating images
from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from tensorflow.keras.models import load_model
from matplotlib import pyplot


import tensorflow as tf


from tensorflow.keras.datasets.fashion_mnist import load_data
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import BatchNormalization

from tensorflow.keras import initializers

from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import glob
from sklearn.utils import shuffle

In [2]:
# Training batch size and target image geometry (rows, cols, channels).
BATCH_SIZE = 128
IMG_SHAPE = (64, 64, 3)

# Length of the latent noise vector fed to the generator.
noise_dim = 128

# Gather every training image path; patterns are globbed in the order listed.
train_images_path = [
    *glob.glob('../data/faces_50k/*.jpg'),
    *glob.glob('../data/second_dataset_69k/images/*.jpg'),
    # *glob.glob('../data/third_data_21k/*.jpg'),  # third dataset currently disabled
]

In [3]:
def preprocess_images(images):
    """Shift and scale pixel values so that 0 maps to -1 and 255 maps to 1.

    Args:
        images: Array-like supporting arithmetic and ``.astype`` (e.g. a NumPy array).

    Returns:
        The rescaled array cast to ``float32``.
    """
    centered = images - 127.5
    normalized = centered / 127.5
    return normalized.astype('float32')

def generator_img(path_list: list):
    """Yield preprocessed training images forever, reshuffling after each full pass.

    Each file is read, coerced to the channel count given by ``IMG_SHAPE``,
    resized to ``(IMG_SHAPE[0], IMG_SHAPE[1])`` with anti-aliasing and passed
    through ``preprocess_images``.

    Args:
        path_list: Image file paths. Visited in the given order on the first
            pass; the list is reshuffled (into a new list) after every pass.

    Yields:
        One ``float32`` image array per step.

    Raises:
        ValueError: If ``path_list`` is empty (raised on the first ``next()``).
    """
    if len(path_list) == 0:
        raise ValueError("generator_img received an empty path_list; no images to yield")

    target_hw = (IMG_SHAPE[0], IMG_SHAPE[1])
    n_channels = IMG_SHAPE[2]

    while True:
        for single_path in path_list:
            pixels = np.asarray(io.imread(single_path), dtype=np.float32)
            if pixels.ndim == 2:
                # Grayscale files decode to (H, W): replicate into the expected channels.
                pixels = np.repeat(pixels[..., np.newaxis], n_channels, axis=-1)
            elif pixels.shape[-1] > n_channels:
                # e.g. RGBA: drop the extra (alpha) channel.
                pixels = pixels[..., :n_channels]

            resized = transform.resize(pixels, target_hw, anti_aliasing=True)
            # preprocess_images already casts to float32, so no extra astype is needed.
            yield preprocess_images(resized)

        # Full pass done: reshuffle before starting the next one.
        path_list = shuffle(path_list)

def train_gen():
    """Zero-argument factory returning a fresh image generator over ``train_images_path``."""
    image_stream = generator_img(train_images_path)
    return image_stream

In [4]:
# Stream single images of shape IMG_SHAPE from the Python generator.
dataset = tf.data.Dataset.from_generator(
    train_gen,
    output_signature=tf.TensorSpec(shape=IMG_SHAPE, dtype=np.float32),
)
# Shuffle within a buffer of ten batches' worth of images, then group into batches.
dataset = dataset.shuffle(BATCH_SIZE * 10)
dataset = dataset.batch(BATCH_SIZE)


In [5]:
# Number of image paths collected for training; used later to size an epoch.
train_size = len(train_images_path)

print(f'train: {train_size}')

train: 114788


In [6]:
def conv_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="same",
    use_bias=True,
    use_bn=False,
    use_dropout=False,
    drop_value=0.5,
):
    """Apply Conv2D -> (optional BatchNorm) -> activation -> (optional Dropout).

    Args:
        x: Input tensor.
        filters: Number of convolution filters.
        activation: Callable applied to the (optionally normalised) conv output.
        kernel_size: Convolution kernel size.
        strides: Convolution strides.
        padding: Convolution padding mode.
        use_bias: Whether the convolution has a bias term.
        use_bn: Insert a BatchNormalization layer after the convolution.
        use_dropout: Append a Dropout layer after the activation.
        drop_value: Dropout rate, used only when ``use_dropout`` is true.

    Returns:
        The output tensor of the block.
    """
    conv = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        use_bias=use_bias,
    )
    out = conv(x)
    if use_bn:
        out = layers.BatchNormalization()(out)
    out = activation(out)
    return layers.Dropout(drop_value)(out) if use_dropout else out


def get_discriminator_model():
    """Build the discriminator: four strided 5x5 conv stages, then a single-unit Dense head.

    Returns:
        A Keras ``Model`` named ``"discriminator"`` mapping an ``IMG_SHAPE``
        image to one unbounded score (the final Dense layer has no activation).
    """
    img_input = layers.Input(shape=IMG_SHAPE)

    # (filters, use_dropout) per stage. Each stride-2 stage halves the
    # resolution: 64 -> 32 -> 16 -> 8 -> 4 for the 64x64 input.
    stage_specs = ((64, False), (128, True), (256, True), (512, False))

    features = img_input
    for n_filters, with_dropout in stage_specs:
        features = conv_block(
            features,
            n_filters,
            layers.LeakyReLU(0.2),
            kernel_size=(5, 5),
            strides=(2, 2),
            use_bias=True,
            use_bn=False,
            use_dropout=with_dropout,
            drop_value=0.3,
        )

    features = layers.Flatten()(features)
    features = layers.Dropout(0.3)(features)
    score = layers.Dense(1)(features)

    return keras.models.Model(img_input, score, name="discriminator")


# Build the discriminator network and print its layer-by-layer summary.
d_model = get_discriminator_model()
d_model.summary()

Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 64, 64, 3)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 32, 32, 64)        4864      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 16, 16, 128)       204928    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 16, 16, 128)       0         
_________________________________________________________________
dropout (Dropout)            (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 8, 8, 256)       

In [7]:
def upsample_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    up_size=(2, 2),
    padding="same",
    use_bn=False,
    use_bias=True,
    use_dropout=False,
    drop_value=0.3,
):
    """Apply UpSampling2D -> Conv2D -> (optional BatchNorm) -> (optional activation) -> (optional Dropout).

    Args:
        x: Input tensor.
        filters: Number of convolution filters.
        activation: Callable applied after the conv/BatchNorm; skipped when falsy.
        kernel_size: Convolution kernel size.
        strides: Convolution strides.
        up_size: Upsampling factor passed to ``UpSampling2D``.
        padding: Convolution padding mode.
        use_bn: Insert a BatchNormalization layer after the convolution.
        use_bias: Whether the convolution has a bias term.
        use_dropout: Append a Dropout layer at the end of the block.
        drop_value: Dropout rate, used only when ``use_dropout`` is true.

    Returns:
        The output tensor of the block.
    """
    upsampled = layers.UpSampling2D(up_size)(x)
    conv = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        use_bias=use_bias,
    )
    out = conv(upsampled)

    if use_bn:
        out = layers.BatchNormalization()(out)
    if activation:
        out = activation(out)

    return layers.Dropout(drop_value)(out) if use_dropout else out


def get_generator_model():
    """Build the generator: Dense projection to 4x4x256, then four 2x upsampling stages.

    Returns:
        A Keras ``Model`` named ``"generator"`` mapping a ``noise_dim`` vector
        to a tanh-activated tensor with ``IMG_SHAPE[-1]`` channels.
    """
    noise = layers.Input(shape=(noise_dim,))

    # Project the noise vector and reshape it into a 4x4 feature map.
    x = layers.Dense(4 * 4 * 256, use_bias=False)(noise)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Reshape((4, 4, 256))(x)

    # Three BatchNorm + LeakyReLU stages: 4 -> 8 -> 16 -> 32.
    for n_filters in (128, 128, 64):
        x = upsample_block(
            x,
            n_filters,
            layers.LeakyReLU(0.2),
            strides=(1, 1),
            use_bias=False,
            use_bn=True,
            padding="same",
            use_dropout=False,
        )

    # Final stage: 32 -> 64, tanh output with one filter per image channel.
    image = upsample_block(
        x,
        IMG_SHAPE[-1],
        layers.Activation("tanh"),
        strides=(1, 1),
        use_bias=False,
        use_bn=False,
    )

    return keras.models.Model(noise, image, name="generator")


# Build the generator network and print its layer-by-layer summary.
g_model = get_generator_model()
g_model.summary()

Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 128)]             0         
_________________________________________________________________
dense_1 (Dense)              (None, 4096)              524288    
_________________________________________________________________
batch_normalization (BatchNo (None, 4096)              16384     
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 4096)              0         
_________________________________________________________________
reshape (Reshape)            (None, 4, 4, 256)         0         
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 8, 8, 256)         0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 8, 8, 128)         29

In [8]:
class WGAN(keras.Model):
    """Wasserstein GAN with gradient penalty (WGAN-GP) training wrapper.

    Holds a discriminator and a generator and implements one combined
    training step: several discriminator updates followed by one generator
    update.
    """

    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        """Store the two networks and the training hyperparameters.

        Args:
            discriminator: Model scoring images.
            generator: Model mapping latent vectors to images.
            latent_dim: Size of the latent noise vector.
            discriminator_extra_steps: Discriminator updates per generator update.
            gp_weight: Multiplier applied to the gradient penalty term.

        Raises:
            ValueError: If ``discriminator_extra_steps`` is smaller than 1.
        """
        super(WGAN, self).__init__()
        if discriminator_extra_steps < 1:
            # With zero discriminator steps `train_step` would have no d_loss to report.
            raise ValueError("discriminator_extra_steps must be >= 1")
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Attach the optimizers and loss functions used by ``train_step``.

        Args:
            d_optimizer: Optimizer for the discriminator variables.
            g_optimizer: Optimizer for the generator variables.
            d_loss_fn: Callable invoked as ``d_loss_fn(real_img=..., fake_img=...)``.
            g_loss_fn: Callable invoked with the discriminator logits of generated images.
        """
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Calculate the gradient penalty.

        The penalty is computed on images interpolated between real and fake
        samples and is added to the discriminator loss by ``train_step``.

        Args:
            batch_size: Number of images in the batch.
            real_images: Batch of real images.
            fake_images: Batch of generated images, same shape as ``real_images``.

        Returns:
            Scalar tensor: mean over the batch of ``(||grad|| - 1) ** 2``.
        """
        # Per-sample interpolation coefficient drawn uniformly from [0, 1) so
        # every interpolated image lies on the segment between its real and
        # fake sample. (A normal draw would place points outside that segment.)
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator(interpolated, training=True)

        # 2. Calculate the gradients w.r.t. this interpolated image.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the per-sample norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, real_images):
        """Run ``d_steps`` discriminator updates followed by one generator update.

        Args:
            real_images: Batch of real images, or a tuple whose first element is that batch.

        Returns:
            Dict with ``"d_loss"`` (loss of the last discriminator update,
            gradient penalty included) and ``"g_loss"``.
        """
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        # For each batch we perform the following steps:
        # 1. Train the discriminator `d_steps` times. Each time:
        #    a. compute the discriminator loss on real and generated images,
        #    b. compute the gradient penalty,
        #    c. add the penalty, scaled by `gp_weight`, to the discriminator loss.
        # 2. Train the generator once and get the generator loss.
        # 3. Return the generator and discriminator losses as a loss dictionary.
        #
        # The original paper recommends several discriminator updates
        # (typically 5) per generator update; the default here is 3 to reduce
        # the training time.
        for i in range(self.d_steps):
            # Get the latent vector
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )
            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self.generator(random_latent_vectors, training=True)
                # Get the logits for the fake images
                fake_logits = self.discriminator(fake_images, training=True)
                # Get the logits for the real images
                real_logits = self.discriminator(real_images, training=True)

                # Calculate the discriminator loss using the fake and real image logits
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # Calculate the gradient penalty
                gp = self.gradient_penalty(batch_size, real_images, fake_images)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight

            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        # Train the generator
        # Get the latent vector
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(random_latent_vectors, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator(generated_images, training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}

In [9]:
class GANMonitor():
    """Generate a grid of sample images from a model and save it as a PNG file."""

    def __init__(self, model, num_img=100, latent_dim=128):
        """Store the sampling configuration.

        Args:
            model: Object exposing ``predict(latent_vectors)`` that returns images.
            num_img: Number of latent vectors to sample per snapshot.
            latent_dim: Size of each latent vector.
        """
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.model = model

    def on_epoch_end(self, epoch, logs=None):
        """Sample ``num_img`` images and save an n x n grid of them.

        Args:
            epoch: Value formatted into the output file-name prefix.
            logs: Unused.
        """
        # Grid side length; only the first n * n generated images are plotted.
        n = int(np.sqrt(self.num_img))
        random_latent_vectors = np.random.normal(size=(self.num_img, self.latent_dim))
        generated_images = self.model.predict(random_latent_vectors)
        # scale from [-1,1] to [0,1]
        generated_images = (generated_images + 1) / 2.0
        self._generate_plot(generated_images, n, f'{epoch}')

    def _generate_plot(self, examples, n, prefix):
        """Draw the first ``n * n`` images on an n x n grid and save ``{prefix}_image.png``.

        Args:
            examples: Indexable collection of images accepted by ``imshow``.
            n: Grid side length.
            prefix: File-name prefix of the saved PNG.
        """
        fig = plt.figure(figsize=(12, 12))
        try:
            for i in range(n * n):
                # One axis-free subplot per image.
                ax = fig.add_subplot(n, n, 1 + i)
                ax.axis('off')
                ax.imshow(examples[i])
            fig.savefig(f'{prefix}_image.png')
        finally:
            # Close only this figure (even when plotting or saving fails) so
            # other open figures are left untouched and nothing leaks.
            plt.close(fig)

In [10]:
# Both networks get their own Adam optimizer with identical hyperparameters
# (learning_rate=0.0002, beta_1=0.5 are recommended).
_adam_settings = dict(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
generator_optimizer = keras.optimizers.Adam(**_adam_settings)
discriminator_optimizer = keras.optimizers.Adam(**_adam_settings)

# Discriminator loss: mean score on fake images minus mean score on real images.
# The gradient penalty is added to this value later, inside the training step.
def discriminator_loss(real_img, fake_img):
    """Return ``mean(fake_img) - mean(real_img)``.

    Args:
        real_img: Discriminator outputs for real images.
        fake_img: Discriminator outputs for generated images.
    """
    return tf.reduce_mean(fake_img) - tf.reduce_mean(real_img)


# Generator loss: negated mean discriminator score on generated images.
def generator_loss(fake_img):
    """Return ``-mean(fake_img)``, where ``fake_img`` holds discriminator outputs for generated images."""
    mean_score = tf.reduce_mean(fake_img)
    return -mean_score


# Number of training epochs.
epochs = 20

# Image-snapshot helper. GANMonitor is a plain class exposing `on_epoch_end`
# (not a Keras callback), so it has to be invoked explicitly.
cbk = GANMonitor(model=g_model, num_img=100, latent_dim=noise_dim)

# Wrap both networks in the WGAN training model.
wgan = WGAN(d_model, g_model, latent_dim=noise_dim, discriminator_extra_steps=3)

# Attach optimizers and loss functions.
wgan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
    d_loss_fn=discriminator_loss,
    g_loss_fn=generator_loss,
)


In [None]:
# Number of batches that make up one nominal epoch. The underlying generator
# is infinite, so an epoch is simply this many training steps.
iteration = train_size // BATCH_SIZE

# Create the dataset iterator ONCE and keep pulling batches from it.
# Building a new iterator on every step (e.g. `list(dataset.take(1))`) restarts
# the Python generator at the first path and refills the shuffle buffer each
# time, so training would only ever see the first BATCH_SIZE * 10 or so images.
batch_iterator = iter(dataset)

for ep in range(epochs):
    for i in range(iteration):
        real_images = next(batch_iterator)
        data_losses = wgan.train_step(real_images)
        print('>%d, %d/%d, d=%.3f, g=%.3f' %
            (ep+1, i+1, iteration, data_losses['d_loss'], data_losses['g_loss']))

        # Save a grid of generated samples every 10 steps.
        if i % 10 == 0:
            cbk.on_epoch_end(f'i_{i}_ep_{ep}')

>1, 1/896, d=6.602, g=0.232
>1, 2/896, d=-7.032, g=-1.593
>1, 3/896, d=-27.527, g=8.962
>1, 4/896, d=-38.136, g=22.169
>1, 5/896, d=-45.867, g=46.973
>1, 6/896, d=-49.926, g=50.742
>1, 7/896, d=-56.552, g=45.699
>1, 8/896, d=-49.996, g=54.152
>1, 9/896, d=-44.751, g=47.128
>1, 10/896, d=-35.395, g=27.163
>1, 11/896, d=-26.905, g=10.751
>1, 12/896, d=-20.776, g=12.277
>1, 13/896, d=-22.946, g=18.212
>1, 14/896, d=-22.069, g=18.200
>1, 15/896, d=-21.580, g=34.634
>1, 16/896, d=-18.946, g=34.606
>1, 17/896, d=-20.772, g=44.959
>1, 18/896, d=-17.753, g=35.492
>1, 19/896, d=-17.600, g=34.683
>1, 20/896, d=-17.184, g=18.918
>1, 21/896, d=-14.760, g=24.090
>1, 22/896, d=-16.237, g=34.982
>1, 23/896, d=-16.484, g=36.275
>1, 24/896, d=-15.641, g=42.963
>1, 25/896, d=-15.813, g=29.029
>1, 26/896, d=-14.901, g=32.694
>1, 27/896, d=-15.728, g=46.054
>1, 28/896, d=-15.954, g=47.990
>1, 29/896, d=-14.532, g=36.887
>1, 30/896, d=-17.459, g=19.748
>1, 31/896, d=-14.750, g=13.214
>1, 32/896, d=-14.443,

>1, 255/896, d=-9.698, g=11.758
>1, 256/896, d=-12.269, g=55.500
>1, 257/896, d=-7.910, g=25.165
>1, 258/896, d=-5.355, g=16.239
>1, 259/896, d=-7.824, g=9.414
>1, 260/896, d=-5.747, g=-5.106
>1, 261/896, d=-8.570, g=-5.959
>1, 262/896, d=-11.708, g=-12.223
>1, 263/896, d=-10.384, g=-16.965
>1, 264/896, d=-7.827, g=9.525
>1, 265/896, d=-7.076, g=37.940
>1, 266/896, d=-6.182, g=41.520
>1, 267/896, d=-7.494, g=35.055
>1, 268/896, d=-7.539, g=-7.262
>1, 269/896, d=-9.765, g=-29.285
>1, 270/896, d=-5.477, g=-24.901
>1, 271/896, d=-8.841, g=-23.082
>1, 272/896, d=-9.628, g=-7.856
>1, 273/896, d=-10.272, g=0.601
>1, 274/896, d=-5.251, g=4.307
>1, 275/896, d=-6.957, g=-4.214
>1, 276/896, d=-9.065, g=-18.075
>1, 277/896, d=-6.577, g=-24.266
>1, 278/896, d=-7.766, g=-14.705
>1, 279/896, d=-7.421, g=-1.416
>1, 280/896, d=-5.735, g=11.716
>1, 281/896, d=-8.487, g=24.515
>1, 282/896, d=-8.359, g=31.979
>1, 283/896, d=-5.184, g=20.311
>1, 284/896, d=-8.298, g=-16.185
>1, 285/896, d=-7.474, g=-26.94

>1, 508/896, d=-7.589, g=-11.659
>1, 509/896, d=-6.459, g=-18.197
>1, 510/896, d=-9.947, g=-13.948
>1, 511/896, d=-5.858, g=3.052
>1, 512/896, d=-8.436, g=22.217
>1, 513/896, d=-9.625, g=28.009
>1, 514/896, d=-8.235, g=16.970
>1, 515/896, d=-7.794, g=-18.831
>1, 516/896, d=-8.060, g=-78.310
>1, 517/896, d=-9.220, g=-20.942
>1, 518/896, d=-7.345, g=2.369
>1, 519/896, d=-5.788, g=23.602
>1, 520/896, d=-8.427, g=30.921
>1, 521/896, d=-7.600, g=45.165
>1, 522/896, d=-8.573, g=47.064
>1, 523/896, d=-5.900, g=16.592
>1, 524/896, d=-5.981, g=-1.475
>1, 525/896, d=-8.050, g=-18.118
>1, 526/896, d=-7.581, g=-34.497
>1, 527/896, d=-6.897, g=-31.811
>1, 528/896, d=-5.828, g=-17.939
>1, 529/896, d=-4.200, g=5.553
>1, 530/896, d=-9.651, g=20.722
>1, 531/896, d=-10.204, g=40.297
>1, 532/896, d=-6.385, g=47.403
>1, 533/896, d=-4.887, g=42.362
>1, 534/896, d=-5.748, g=34.909
>1, 535/896, d=-9.429, g=24.690
>1, 536/896, d=-7.376, g=4.182
>1, 537/896, d=-6.518, g=-3.398
>1, 538/896, d=-8.623, g=-7.718
>

>1, 763/896, d=-6.842, g=16.061
>1, 764/896, d=-9.778, g=25.688
>1, 765/896, d=-10.453, g=43.540
>1, 766/896, d=-8.912, g=50.100
>1, 767/896, d=-1.064, g=36.174
>1, 768/896, d=-5.239, g=39.866
>1, 769/896, d=-5.006, g=40.385
>1, 770/896, d=-8.230, g=41.306
>1, 771/896, d=-8.043, g=48.623
>1, 772/896, d=-8.852, g=39.591
>1, 773/896, d=-3.556, g=26.758
>1, 774/896, d=-8.762, g=8.667
>1, 775/896, d=-5.758, g=4.601
>1, 776/896, d=-4.557, g=3.919
>1, 777/896, d=-6.884, g=9.670
>1, 778/896, d=-6.313, g=9.993
>1, 779/896, d=-6.266, g=6.167
>1, 780/896, d=-9.290, g=15.108
>1, 781/896, d=-6.830, g=13.499
>1, 782/896, d=-4.607, g=6.315
>1, 783/896, d=-5.439, g=1.497
>1, 784/896, d=-7.687, g=-2.308
>1, 785/896, d=-9.847, g=-12.327
>1, 786/896, d=-7.541, g=-16.292
>1, 787/896, d=-8.501, g=-18.010
>1, 788/896, d=-2.339, g=-29.571
>1, 789/896, d=-4.466, g=-36.829
>1, 790/896, d=-7.037, g=-29.710
>1, 791/896, d=-9.700, g=-24.855
>1, 792/896, d=-8.427, g=-32.101
>1, 793/896, d=-4.068, g=26.588
>1, 794

>2, 126/896, d=-8.592, g=-14.299
>2, 127/896, d=-1.286, g=-14.685
>2, 128/896, d=-5.183, g=-14.207
>2, 129/896, d=-4.963, g=-6.508
>2, 130/896, d=-7.064, g=-3.602
>2, 131/896, d=-7.833, g=-3.342
>2, 132/896, d=-5.339, g=-6.239
>2, 133/896, d=-12.867, g=-11.007
>2, 134/896, d=-3.195, g=-10.691
>2, 135/896, d=-8.836, g=-15.770
>2, 136/896, d=-10.323, g=-9.791
>2, 137/896, d=-6.621, g=-3.407
>2, 138/896, d=-3.736, g=8.127
>2, 139/896, d=-7.052, g=23.571
>2, 140/896, d=-4.626, g=19.700
>2, 141/896, d=-8.009, g=22.493
>2, 142/896, d=-2.640, g=30.081
>2, 143/896, d=-11.041, g=28.901
>2, 144/896, d=-7.595, g=30.303
>2, 145/896, d=-6.251, g=12.285
>2, 146/896, d=-9.701, g=-7.824
>2, 147/896, d=-8.933, g=-31.844
>2, 148/896, d=-11.046, g=-52.888
>2, 149/896, d=-5.844, g=-49.829
>2, 150/896, d=-2.350, g=-36.704
>2, 151/896, d=-7.009, g=-40.067
>2, 152/896, d=-6.151, g=-31.048
>2, 153/896, d=-3.188, g=-59.746
>2, 154/896, d=-2.729, g=-47.457
>2, 155/896, d=-9.527, g=-51.062
>2, 156/896, d=-4.820,

In [None]:
# Persist the trained weights to an HDF5 file.
# NOTE(review): `wgan` holds both the generator and the discriminator, so this
# presumably stores both networks despite the "generator" file name — confirm.
wgan.save_weights('wgan_generator.h5')