In [None]:
# https://keras.io/examples/generative/wgan_gp/
#

import os
# Expose only GPU 2 to the process; this runs before TensorFlow is imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"


import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import *

from tqdm import tqdm
import matplotlib.pyplot as plt
import glob
from sklearn.utils import shuffle

from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint

# example of loading the generator model and generating images
from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from tensorflow.keras.models import load_model
from matplotlib import pyplot


import tensorflow as tf


from tensorflow.keras.datasets.fashion_mnist import load_data
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import BatchNormalization

from tensorflow.keras import initializers

from skimage import io
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import glob
from sklearn.utils import shuffle

In [None]:
# Map each class-folder name (0.0 to 1.0 in steps of 0.05) to one of 10 class ids:
# consecutive pairs share an id (0.0/0.05 -> 0, ..., 0.9/0.95 -> 9) and 1.0 is
# folded into the last class, 9. step / 20 is correctly rounded, so str() yields
# exactly the folder spellings '0.0', '0.05', ..., '1.0'.
foldername2class = {
    str(step / 20): min(step // 2, 9)
    for step in range(21)
}

In [None]:
def preprocess_images(images):
    """Rescale pixel values from [0, 255] to [-1, 1] and return them as float32."""
    half_range = 127.5
    centered = (images - half_range) / half_range
    return centered.astype('float32')

def generator_img(path_list: list):
    """Yield ``(image, label)`` pairs forever, reshuffling after each full pass.

    Each file is read with ``io.imread``, cut to its first three channels
    (``[..., :3]``) and rescaled to [-1, 1] by ``preprocess_images``. The label
    is ``foldername2class[<name of the file's parent directory>]``.

    Args:
        path_list: image file paths. The caller's list is not modified:
            ``shuffle`` returns a reshuffled copy after every pass.

    Raises:
        ValueError: on the first ``next()`` if ``path_list`` is empty (instead
            of an opaque ``IndexError`` surfacing from inside tf.data).
        KeyError: if a parent folder name is not in ``foldername2class``.
    """
    if len(path_list) == 0:
        raise ValueError('generator_img: path_list is empty; no images to yield')
    counter = 0
    max_counter = len(path_list)
    while True:
        single_path = path_list[counter]
        # os.path handles both '/' and OS-native separators, unlike split('/').
        folder_name = os.path.basename(os.path.dirname(single_path))
        label_s = foldername2class[folder_name]
        raw = np.asarray(io.imread(single_path), dtype=np.float32)
        # Drop extra channels first so only the kept ones are normalised.
        image_s = preprocess_images(raw[..., :3])
        yield image_s, label_s
        counter += 1

        if counter == max_counter:
            counter = 0
            path_list = shuffle(path_list)

def train_gen():
    """Zero-argument factory for ``tf.data.Dataset.from_generator``.

    Each call builds a fresh infinite image/label generator over the
    module-level ``train_images_path`` list.
    """
    return generator_img(path_list=train_images_path)

In [None]:
IMG_SHAPE = (336, 336, 3)
BATCH_SIZE = 16
N_CLASSES = 10  # number of distinct ids in foldername2class
# Size of the noise vector
noise_dim = 256

PATH_DATA = '../../expand_double_modes'
SAVE_RESULT = 'exp_result'

# Collect every file inside each class folder of PATH_DATA. No per-folder
# shuffle is needed: the full list is shuffled once below, and generator_img
# reshuffles it after every pass.
train_images_path = []
for single_folder in tqdm(glob.glob(os.path.join(PATH_DATA, '*'))):
    train_images_path.extend(glob.glob(os.path.join(single_folder, '*')))

# Fail here with a clear message rather than deep inside the tf.data pipeline.
if not train_images_path:
    raise FileNotFoundError(f'No training images found under {PATH_DATA!r}')

train_images_path = shuffle(train_images_path)

In [None]:
# generator_img already reshuffles the full path list after every pass, so a
# small buffer is enough here. A buffer of BATCH_SIZE * 500 = 8000 decoded
# 336x336x3 float32 images (~1.35 MB each) holds ~10.8 GB in memory and must
# read 8000 images before the first batch is produced.
SHUFFLE_BUFFER = BATCH_SIZE * 8

dataset = (
    tf.data.Dataset.from_generator(
        train_gen,
        output_signature=(
            tf.TensorSpec(shape=IMG_SHAPE, dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32),
        )
    )
    .shuffle(SHUFFLE_BUFFER)
    # The generator never ends, so every batch is full anyway; dropping the
    # remainder just gives the batch dimension a static size.
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(6)
)


In [None]:
# Number of training images; used below to derive steps per epoch.
train_size = len(train_images_path)
print('train:', train_size)

In [None]:
# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples, n_classes=21):
    """Sample generator inputs: standard-normal latent vectors plus random labels.

    Returns:
        ``[z_input, labels]``: ``z_input`` has shape ``(n_samples, latent_dim)``;
        ``labels`` holds ``n_samples`` integers drawn from ``[0, n_classes)``.
    """
    # A shaped randn draw consumes the RNG in the same C order as a flat draw
    # of latent_dim * n_samples values followed by a reshape.
    z_input = randn(n_samples, latent_dim)
    labels = randint(0, n_classes, n_samples)
    return [z_input, labels]

In [None]:
def init_weights(stddev=None):
    """Return the ``kernel_initializer`` passed to the model's conv layers.

    A single definition replaces two back-to-back ones, where the second
    (returning ``None``) silently shadowed the first.

    Args:
        stddev: ``None`` (default) returns ``None``, so Keras uses its own
            default kernel initializer. Any other value returns
            ``initializers.RandomNormal(stddev=stddev)``; ``0.02`` reproduces
            the previously shadowed variant.
    """
    if stddev is None:
        return None
    return initializers.RandomNormal(stddev=stddev)

In [None]:
class BNInferenceMode(tf.keras.layers.Layer):
    """Batch normalization that always uses the statistics of the current batch.

    No moving averages are kept: mean and variance are computed over axes
    (0, 1, 2) of every input, so the output does not depend on ``training``.

    ``gamma`` and ``beta`` (one value per channel) are created with
    ``add_weight``. The enclosing Keras model therefore tracks them: they are
    trained, counted in ``summary()`` and saved with the weights. Raw
    ``tf.Variable``s on a plain ``tf.Module`` used inside a functional model are
    wrapped in an op-lambda that does not track them, so they would never be
    updated.

    Args:
        dim: number of input channels (length of ``gamma``/``beta``).
        eps: ``variance_epsilon`` passed to ``tf.nn.batch_normalization``.
    """

    def __init__(self, dim, eps=1e-3, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.eps = eps
        self.gamma = self.add_weight(
            name='gamma', shape=(dim,), initializer='ones', trainable=True)
        self.beta = self.add_weight(
            name='beta', shape=(dim,), initializer='zeros', trainable=True)

    def call(self, x, training=False):
        # `training` is accepted for API compatibility but deliberately unused:
        # statistics always come from the batch being normalized.
        mean, var = tf.nn.moments(x, axes=[0, 1, 2], keepdims=True)
        return tf.nn.batch_normalization(
            x=x,
            mean=mean,
            variance=var,
            offset=self.beta,
            scale=self.gamma,
            variance_epsilon=self.eps,
            name='CustomBN'
        )

    def get_config(self):
        config = super().get_config()
        config.update({'dim': self.dim, 'eps': self.eps})
        return config

In [None]:
# define the standalone discriminator (critic) model
def define_discriminator(in_shape, n_classes=21):
    """Build the conditional WGAN-GP critic.

    The label is embedded (96 dims), projected by a Dense layer to
    ``H*W*C`` values and reshaped to an ``in_shape``-sized map. That map is
    concatenated to the image along the channel axis. Five stride-2 stages
    (see ``_critic_downsample``) reduce 336 -> 168 -> 84 -> 42 -> 21 -> 11 for a
    336-pixel input. Flatten, Dropout(0.3) and a linear Dense(1) follow; there
    is no sigmoid, because the Wasserstein losses use raw scores.

    Args:
        in_shape: image shape ``(H, W, C)``.
        n_classes: number of label values the embedding accepts.

    Returns:
        An uncompiled ``keras.Model`` mapping ``[image, label]`` to a
        ``(batch, 1)`` score. The WGAN wrapper supplies losses and optimizers.
    """
    # label input
    in_label = Input(shape=(1,))
    # embedding for categorical input
    li = Embedding(n_classes, 96)(in_label)
    # scale up to image dimensions with linear activation
    n_nodes = in_shape[0] * in_shape[1] * in_shape[2]
    li = Dense(n_nodes)(li)
    # reshape to an image-sized label map with in_shape[2] channels
    li = Reshape((in_shape[0], in_shape[1], in_shape[2]))(li)
    # image input
    in_image = Input(shape=in_shape)
    # concat label map to the image along the channel axis
    fe = Concatenate()([in_image, li])                         # 336
    # five stride-2 downsampling stages
    for _ in range(5):
        fe = _critic_downsample(fe)                            # 168, 84, 42, 21, 11
    # flatten feature maps
    fe = layers.Flatten()(fe)
    fe = layers.Dropout(0.3)(fe)
    out_layer = layers.Dense(1)(fe)
    # define model
    return keras.models.Model([in_image, in_label], out_layer)


def _critic_downsample(x):
    """One critic stage: 3x3 stride-2 conv (halves H and W), LayerNorm, LeakyReLU(0.2)."""
    x = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init_weights())(x)
    # The WGAN-GP gradient penalty is defined per sample. BatchNormalization
    # mixes samples within a batch, which invalidates it, so use per-sample
    # layer normalization instead.
    x = layers.LayerNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    return x

# define the standalone generator model
def define_generator(latent_dim, n_classes=21, h_low=21, w_low=21):
    """Build the conditional generator.

    The label is embedded (96 dims) and projected to a single-channel
    ``h_low x w_low`` map. The latent vector is projected to a 64-channel
    ``h_low x w_low`` map (LeakyReLU 0.2), and the two are concatenated. Four
    stride-2 ``Conv2DTranspose`` stages (128 filters, each followed by
    ``BNInferenceMode`` and LeakyReLU) upsample by 16x (21 -> 336 with the
    defaults). A final 3x3 tanh ``Conv2D`` emits ``IMG_SHAPE[-1]`` channels.

    Returns:
        A ``keras.Model`` mapping ``[latent, label]`` to an image in [-1, 1].
    """
    in_label = Input(shape=(1,))
    label_map = Embedding(n_classes, 96)(in_label)
    label_map = Dense(h_low * w_low)(label_map)
    label_map = Reshape((h_low, w_low, 1))(label_map)

    in_lat = Input(shape=(latent_dim,))
    x = Dense(64 * h_low * w_low)(in_lat)
    x = LeakyReLU(alpha=0.2)(x)
    x = Reshape((h_low, w_low, 64))(x)

    x = Concatenate()([x, label_map])
    # h_low -> 2*h_low -> 4*h_low -> 8*h_low -> 16*h_low
    for _ in range(4):
        x = Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same',
                            kernel_initializer=init_weights())(x)
        x = BNInferenceMode(128)(x)
        x = LeakyReLU(alpha=0.2)(x)

    out_layer = Conv2D(IMG_SHAPE[-1], (3, 3), activation='tanh', padding='same',
                       kernel_initializer=init_weights())(x)
    return keras.models.Model([in_lat, in_label], out_layer)

In [None]:
def conv_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="same",
    use_bias=True,
    use_bn=False,
    use_dropout=False,
    drop_value=0.5,
):
    """Apply Conv2D -> optional BatchNormalization -> activation -> optional Dropout.

    ``activation`` is a callable (e.g. a layer instance) and is always applied.
    """
    out = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        use_bias=use_bias,
    )(x)
    out = layers.BatchNormalization()(out) if use_bn else out
    out = activation(out)
    return layers.Dropout(drop_value)(out) if use_dropout else out


def get_discriminator_model():
    """Unconditional critic from the Keras WGAN-GP example.

    This notebook builds its critic with ``define_discriminator`` instead.
    The input is zero-padded by 2 pixels per side (in the original MNIST
    example this turned 28x28 into 32x32; with IMG_SHAPE it only adds a
    border). Four 5x5 stride-2 conv blocks follow, with 64/128/256/512 filters,
    LeakyReLU(0.2), no batch norm, and Dropout(0.3) on the middle two. The head
    is Flatten, Dropout(0.2) and a linear Dense(1).
    """
    img_input = layers.Input(shape=IMG_SHAPE)
    x = layers.ZeroPadding2D((2, 2))(img_input)
    block_specs = ((64, False), (128, True), (256, True), (512, False))
    for n_filters, with_dropout in block_specs:
        x = conv_block(
            x,
            n_filters,
            kernel_size=(5, 5),
            strides=(2, 2),
            use_bn=False,
            use_bias=True,
            activation=layers.LeakyReLU(0.2),
            use_dropout=with_dropout,
            drop_value=0.3,
        )

    x = layers.Flatten()(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(1)(x)

    return keras.models.Model(img_input, x, name="discriminator")


# Conditional critic over IMG_SHAPE images with N_CLASSES label values.
d_model = define_discriminator(in_shape=IMG_SHAPE, n_classes=N_CLASSES)
d_model.summary()

In [None]:
def upsample_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    up_size=(2, 2),
    padding="same",
    use_bn=False,
    use_bias=True,
    use_dropout=False,
    drop_value=0.3,
):
    """Apply UpSampling2D -> Conv2D -> optional BatchNorm -> optional activation -> optional Dropout.

    Passing a falsy ``activation`` (e.g. ``None``) skips the activation step.
    """
    upsampled = layers.UpSampling2D(up_size)(x)
    out = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        use_bias=use_bias,
    )(upsampled)
    out = layers.BatchNormalization()(out) if use_bn else out
    if activation:
        out = activation(out)
    return layers.Dropout(drop_value)(out) if use_dropout else out


def get_generator_model():
    """Unconditional generator from the Keras WGAN-GP example.

    This notebook builds its generator with ``define_generator`` instead.
    A ``noise_dim`` vector is projected to 4x4x256 and upsampled three times
    (4 -> 8 -> 16 -> 32). The last stage produces one tanh channel, and a
    2-pixel crop per side gives 28x28x1, which does not match IMG_SHAPE.
    """
    noise = layers.Input(shape=(noise_dim,))
    x = layers.Dense(4 * 4 * 256, use_bias=False)(noise)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Reshape((4, 4, 256))(x)
    for n_filters in (128, 64):
        x = upsample_block(
            x,
            n_filters,
            layers.LeakyReLU(0.2),
            strides=(1, 1),
            use_bias=False,
            use_bn=True,
            padding="same",
            use_dropout=False,
        )
    x = upsample_block(
        x, 1, layers.Activation("tanh"), strides=(1, 1), use_bias=False, use_bn=True
    )
    # 32x32x1 -> 28x28x1
    x = layers.Cropping2D((2, 2))(x)

    return keras.models.Model(noise, x, name="generator")


# Conditional generator: noise_dim latent vector + label -> IMG_SHAPE image.
g_model = define_generator(latent_dim=noise_dim, n_classes=N_CLASSES)
g_model.summary()

In [None]:
class WGAN(keras.Model):
    """Conditional WGAN-GP that alternates critic and generator updates.

    The critic is called with ``[images, labels]`` and the generator with
    ``[latent_vectors, labels]``. Training is driven manually through
    ``train_step(real_images, real_labels)``. That signature differs from the
    ``train_step(data)`` hook ``keras.Model.fit`` calls, so ``fit`` is not
    supported.

    Args:
        discriminator: critic model returning one unbounded score per image.
        generator: generator model returning images.
        latent_dim: length of the latent vectors fed to the generator.
        discriminator_extra_steps: critic updates per generator update (>= 1).
        gp_weight: multiplier of the gradient penalty in the critic loss.

    Raises:
        ValueError: if ``discriminator_extra_steps`` < 1. Otherwise
            ``train_step`` would hit an unbound ``d_loss``.
    """

    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        if discriminator_extra_steps < 1:
            raise ValueError(
                f'discriminator_extra_steps must be >= 1, got {discriminator_extra_steps}'
            )
        super(WGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Store the optimizers and loss callables used by the train steps."""
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    @tf.function
    def gradient_penalty(self, batch_size, real_images, fake_images, real_labels):
        """Return mean((||grad_x D(x_hat, y)||_2 - 1)^2) over random real/fake interpolations.

        The result is added (times ``gp_weight``) to the critic loss.
        """
        # Interpolation weights must lie in [0, 1] so every x_hat sits on the
        # segment between a real and a fake image. A normal distribution would
        # also extrapolate beyond both endpoints.
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        interpolated = real_images + alpha * (fake_images - real_images)

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # Critic output for the interpolated images.
            pred = self.discriminator([interpolated, real_labels], training=True)

        # Only the image gradient is needed; labels are integer inputs.
        grads = gp_tape.gradient(pred, interpolated)
        # Per-sample L2 norm. The epsilon keeps sqrt differentiable when a
        # gradient is exactly zero; without it the result would be NaN.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
        return tf.reduce_mean((norm - 1.0) ** 2)

    def train_step(self, real_images, real_labels):
        """Run ``d_steps`` critic updates on this batch, then one generator update.

        Returns:
            ``{"d_loss": ..., "g_loss": ...}``: the loss of the last critic
            step and the generator loss.
        """
        batch_size = tf.shape(real_images)[0]

        # The WGAN-GP paper trains the critic for several steps (5 there) per
        # generator step; here the count is self.d_steps.
        for _ in range(self.d_steps):
            d_loss = self._disc_train_step(real_images, real_labels)

        g_loss = self._generator_train_step(batch_size)

        return {"d_loss": d_loss, "g_loss": g_loss}

    @tf.function
    def _generator_train_step(self, batch_size):
        """One generator update on random latents and uniformly random labels; returns g_loss."""
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        random_labels = tf.random.uniform([batch_size], minval=0, maxval=N_CLASSES, dtype=tf.int32)
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator([random_latent_vectors, random_labels], training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator([generated_images, random_labels], training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        # Only generator weights are updated here.
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )

        return g_loss

    @tf.function
    def _disc_train_step(self, real_images, real_labels):
        """One critic update; fakes are conditioned on ``real_labels``. Returns d_loss incl. penalty."""
        batch_size = tf.shape(real_images)[0]

        random_latent_vectors = tf.random.normal(
            shape=(batch_size, self.latent_dim)
        )
        with tf.GradientTape() as tape:
            # Generate fake images from the latent vector
            fake_images = self.generator([random_latent_vectors, real_labels], training=True)
            # Get the logits for the fake images
            fake_logits = self.discriminator([fake_images, real_labels], training=True)
            # Get the logits for the real images
            real_logits = self.discriminator([real_images, real_labels], training=True)

            # Wasserstein critic loss from the fake and real scores
            d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
            # Gradient penalty, recorded by this tape so it is differentiated too
            gp = self.gradient_penalty(batch_size, real_images, fake_images, real_labels)
            d_loss = d_cost + gp * self.gp_weight

        # Only critic weights are updated here.
        d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
        self.d_optimizer.apply_gradients(
            zip(d_gradient, self.discriminator.trainable_variables)
        )

        return d_loss



In [None]:
class GANMonitor:
    """Saves a grid of images sampled from a conditional generator.

    Args:
        model: generator taking ``[latents, labels]`` and returning images in
            [-1, 1] (tanh output).
        num_img: number of images generated per call. The first ``n*n``, with
            ``n = int(sqrt(num_img))``, are plotted as an ``n x n`` grid.
        latent_dim: length of each latent vector.
    """

    def __init__(self, model, num_img=100, latent_dim=128):
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.model = model

    def on_epoch_end(self, epoch, logs=None, save_path=''):
        """Generate ``num_img`` samples and save them to ``<save_path>/<epoch>_image.png``.

        Image ``i`` uses label ``min(i % n, N_CLASSES - 1)``, so grid column
        ``j`` shows class ``min(j, N_CLASSES - 1)`` and each row spans the
        classes. ``logs`` is unused.
        """
        n = int(np.sqrt(self.num_img))
        random_latent_vectors = np.random.normal(size=(self.num_img, self.latent_dim))
        # One label per generated image. A fixed 100-label list would break
        # predict() for any other num_img; for num_img=100 this is identical.
        random_labels = np.asarray(
            [min(i % n, N_CLASSES - 1) for i in range(self.num_img)]
        )
        generated_images = self.model.predict(
            [random_latent_vectors, random_labels], verbose=0
        )
        # scale from [-1,1] to [0,1]
        generated_images = (generated_images + 1) / 2.0
        self._generate_plot(generated_images, n, os.path.join(save_path, f'{epoch}'))

    def _generate_plot(self, examples, n, prefix):
        """Plot the first ``n*n`` images of ``examples`` as a grid and save to ``<prefix>_image.png``."""
        fig = plt.figure(figsize=(12, 12))
        for i in range(n * n):
            ax = fig.add_subplot(n, n, 1 + i)
            ax.axis('off')
            ax.imshow(examples[i])
        fig.savefig(f'{prefix}_image.png')
        # Close this figure explicitly; it is created on every call.
        plt.close(fig)

In [None]:
import gc

class GCClearCallback:
    """Frees memory between epochs: runs ``gc.collect()`` and clears the Keras session.

    ``on_epoch_end`` is a static method, so it works whether it is called on an
    instance (``GCClearCallback().on_epoch_end(ep)``) or on the class
    (``GCClearCallback.on_epoch_end(ep)``). As an instance method, a class-level
    call bound ``ep`` to ``self``.
    """

    @staticmethod
    def on_epoch_end(epoch=0, logs=None):
        # epoch/logs are accepted for callback-style compatibility but unused.
        gc.collect()
        # NOTE(review): clear_session resets Keras global state while the
        # models are still in use; confirm this does not interfere with training.
        tf.keras.backend.clear_session()

In [None]:
# Both networks use separate Adam instances with identical settings.
_ADAM_KWARGS = dict(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
generator_optimizer = keras.optimizers.Adam(**_ADAM_KWARGS)
discriminator_optimizer = keras.optimizers.Adam(**_ADAM_KWARGS)

# Critic loss; the gradient penalty is added separately inside WGAN.
def discriminator_loss(real_img, fake_img):
    """Wasserstein critic loss: ``mean(fake scores) - mean(real scores)``.

    Args:
        real_img: critic scores for real images.
        fake_img: critic scores for generated images.
    """
    return tf.reduce_mean(fake_img) - tf.reduce_mean(real_img)


# Generator loss.
def generator_loss(fake_img):
    """Negated mean critic score of the generated images."""
    mean_score = tf.reduce_mean(fake_img)
    return -mean_score


# Set the number of epochs for training.
epochs = 20

# Instantiate the custom `GANMonitor` image-saving helper.
cbk = GANMonitor(g_model, num_img=100, latent_dim=noise_dim)
# Instantiate the callback so `gcclear_call.on_epoch_end(ep)` receives `ep`
# as `epoch` (binding the class itself passed `ep` as `self`).
gcclear_call = GCClearCallback()
# Instantiate the WGAN model.
wgan = WGAN(
    discriminator=d_model,
    generator=g_model,
    latent_dim=noise_dim,
    discriminator_extra_steps=5,  # critic steps per generator step (class default: 3)
)

# Compile the WGAN model.
wgan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
    g_loss_fn=generator_loss,
    d_loss_fn=discriminator_loss,
)

# Training runs in the manual loop below: WGAN.train_step takes
# (real_images, real_labels) rather than the single `data` argument that
# keras.Model.fit passes, so `wgan.fit(...)` cannot be used as-is.

In [None]:
import time

In [None]:
iteration = train_size // BATCH_SIZE  # steps per epoch
for ep in range(epochs):
    save_path = os.path.join(SAVE_RESULT, f'ep_{ep}')
    os.makedirs(save_path, exist_ok=True)
    for i, (real_images, real_labels) in enumerate(dataset.take(iteration)):
        start = time.time()
        data_losses = wgan.train_step(real_images=real_images, real_labels=real_labels)
        print('>%d, %d/%d, d=%.3f, g=%.3f' %
            (ep+1, i+1, iteration, data_losses['d_loss'], data_losses['g_loss']))
        # Wall-clock duration of this training step (not a time-remaining estimate).
        print('step time (s):', np.round(time.time() - start, 2))
        if i % 20 == 0:
            cbk.on_epoch_end(f'i_{i}_ep_{ep}', save_path=save_path)
    # Clear session between epochs.
    # Keras itself has some memory leaks.
    # Issue: https://github.com/tensorflow/tensorflow/issues/31312
    gcclear_call.on_epoch_end(ep)

In [None]:
# Alternative manual-stepping loop over the same dataset.
# NOTE(review): this repeats the training of the previous cell (another
# `epochs` epochs, writing into the same ep_* folders); run only one of them.
iteration = train_size // BATCH_SIZE  # steps per epoch
# Create the iterator once. `list(dataset.take(1))[0]` built a new iterator on
# every step: that restarted train_gen() at the start of train_images_path and
# refilled the shuffle buffer from scratch, so every batch came from the same
# leading slice of the data and paid the full refill cost.
data_iter = iter(dataset)
for ep in range(epochs):
    save_path = os.path.join(SAVE_RESULT, f'ep_{ep}')
    os.makedirs(save_path, exist_ok=True)
    for i in range(iteration):
        real_images, real_labels = next(data_iter)
        data_losses = wgan.train_step(real_images=real_images, real_labels=real_labels)
        print('>%d, %d/%d, d=%.3f, g=%.3f' %
            (ep+1, i+1, iteration, data_losses['d_loss'], data_losses['g_loss']))

        if i % 20 == 0:
            cbk.on_epoch_end(f'i_{i}_ep_{ep}', save_path=save_path)
    # Clear session between epochs.
    # Keras itself has some memory leaks.
    # Issue: https://github.com/tensorflow/tensorflow/issues/31312
    gcclear_call.on_epoch_end(ep)

In [None]:
# Save each network under a filename that matches its contents. Writing the
# whole WGAN wrapper to 'wgan_generator.h5' would mislabel the file.
wgan.generator.save_weights('wgan_generator.h5')
wgan.discriminator.save_weights('wgan_discriminator.h5')