In [1]:
# https://keras.io/examples/generative/wgan_gp/
#

import os
os.environ["CUDA_VISIBLE_DEVICES"] = str(2)
import time

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import *

from tqdm import tqdm
import matplotlib.pyplot as plt
import glob
from sklearn.utils import shuffle

from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint

# example of loading the generator model and generating images
from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from tensorflow.keras.models import load_model
from matplotlib import pyplot


import tensorflow as tf


from tensorflow.keras.datasets.fashion_mnist import load_data
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras import layers as L
from tensorflow.keras import initializers

from skimage import io
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import glob
from sklearn.utils import shuffle

In [2]:
# Data folders are named after a value between 0.0 and 1.0 (step 0.05).
# Two consecutive folders share one class id; the last folder ('1.0') is
# merged into the final class, 9.
_FOLDER_NAMES = [
    '0.0', '0.05', '0.1', '0.15', '0.2', '0.25', '0.3',
    '0.35', '0.4', '0.45', '0.5', '0.55', '0.6', '0.65',
    '0.7', '0.75', '0.8', '0.85', '0.9', '0.95', '1.0',
]
foldername2class = {
    folder_name: min(position // 2, 9)
    for position, folder_name in enumerate(_FOLDER_NAMES)
}

In [3]:
def preprocess_images(images):
    """Shift and scale pixel values by 127.5 and cast the result to float32.

    A value of 0 maps to -1.0 and a value of 255 maps to 1.0.
    """
    half_range = 127.5
    centered = images - half_range
    normalized = centered / half_range
    return normalized.astype('float32')

def generator_img(path_list: list):
    """Endlessly yield ``(image, label)`` pairs for the given image paths.

    The label is looked up in ``foldername2class`` using the name of the
    directory that directly contains the image. The image is read with
    ``io.imread``, converted to float32, passed through ``preprocess_images``
    and cut down to its first three channels.

    After every full pass the path list is reshuffled and iteration restarts,
    so the generator never terminates on its own.

    Args:
        path_list: image file paths; each parent directory name must be a key
            of ``foldername2class``.

    Yields:
        Tuple ``(image_s, label_s)``.

    Raises:
        ValueError: on the first ``next()`` call if ``path_list`` is empty
            (previously this surfaced as a bare ``IndexError``).
    """
    max_counter = len(path_list)
    if max_counter == 0:
        raise ValueError('generator_img() needs at least one image path')

    counter = 0
    while True:
        single_path = path_list[counter]
        # Parent directory name, extracted with os.path so it also works when
        # paths use the platform separator instead of '/'.
        folder_name = os.path.basename(os.path.dirname(single_path))
        label_s = foldername2class[folder_name]
        image_s = preprocess_images(np.asarray(io.imread(single_path), dtype=np.float32))[..., :3]
        yield image_s, label_s
        counter += 1

        if counter == max_counter:
            # One full pass done: start over with a fresh random order.
            counter = 0
            path_list = shuffle(path_list)

def train_gen():
    """Zero-argument factory returning the image generator over ``train_images_path``.

    The global list is looked up at call time, not at definition time.
    """
    paths = train_images_path
    return generator_img(paths)

In [4]:
# --- Experiment configuration -------------------------------------------
IMG_SHAPE = (336, 336, 3)
BATCH_SIZE = 8
N_CLASSES = 10
# Size of the noise vector
noise_dim = 256

PATH_DATA = '../../expand_double_modes'
SAVE_RESULT = 'exp_result_new_ideas'

# --- Collect training files ---------------------------------------------
# Every entry directly under PATH_DATA is treated as a folder, and every entry
# inside it is recorded as a training file path.
train_images_path = []

folder_bar = tqdm(glob.glob(PATH_DATA + "/*"))
for folder_path in folder_bar:
    # Shuffle within the folder, then add all of its paths in one go.
    train_images_path.extend(shuffle(glob.glob(folder_path + '/*')))
folder_bar.close()

# Final shuffle so the folders are interleaved.
train_images_path = shuffle(train_images_path)

100%|██████████| 21/21 [00:00<00:00, 121.55it/s]


In [5]:
# Each generator element is one (image, label) pair: a float32 image of shape
# IMG_SHAPE and a scalar int32 label.
element_signature = (
    tf.TensorSpec(shape=IMG_SHAPE, dtype=np.float32),
    tf.TensorSpec(shape=(), dtype=np.int32),
)

# Pipeline: generator -> shuffle buffer of 500 batches' worth of samples
# -> batches of BATCH_SIZE -> prefetch 6 batches.
dataset = tf.data.Dataset.from_generator(train_gen, output_signature=element_signature)
dataset = dataset.shuffle(BATCH_SIZE * 500)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(6)


In [6]:
# Number of collected training file paths.
train_size = len(train_images_path)

# Report the dataset size as a quick sanity check.
print(f'train: {train_size}')

train: 37800


In [7]:
# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples, n_classes=21):
    """Sample a batch of generator inputs.

    Args:
        latent_dim: length of each latent vector.
        n_samples: number of samples to draw.
        n_classes: labels are drawn uniformly from ``[0, n_classes)``.

    Returns:
        A two-element list ``[z_input, labels]`` where ``z_input`` is a
        ``(n_samples, latent_dim)`` array of standard-normal values and
        ``labels`` holds ``n_samples`` random integer class ids.
    """
    # Draw the latent batch directly in its final shape.
    z_input = randn(n_samples, latent_dim)
    labels = randint(0, n_classes, n_samples)
    return [z_input, labels]

In [8]:
def init_weights(stddev=None):
    """Return the kernel initializer used by the conv layers of both networks.

    This function used to be defined twice in the same cell; the second
    definition (returning ``None``) silently shadowed the first
    (``RandomNormal(stddev=0.02)``). Both variants are kept behind a single
    definition, and the default reproduces the behaviour that was in effect.

    Args:
        stddev: ``None`` (default) returns ``None``, i.e. no explicit
            initializer is handed to the layer. A number returns
            ``initializers.RandomNormal(stddev=stddev)``; pass ``0.02`` for
            the previously shadowed variant.

    Returns:
        ``None`` or a ``RandomNormal`` initializer instance.
    """
    if stddev is None:
        return None
    return initializers.RandomNormal(stddev=stddev)

In [9]:
class BNInferenceMode(tf.Module):
    """Batch normalisation that always uses the statistics of the current input.

    Holds a learnable per-channel scale (``gamma``, initialised to ones) and
    offset (``beta``, initialised to zeros). No running averages are kept.
    """

    def __init__(self, dim, eps=1e-3):
        """
        Args:
            dim: length of the ``gamma`` / ``beta`` vectors.
            eps: value passed as ``variance_epsilon`` to the normalisation op.
        """
        # Initialise the tf.Module base class before assigning variables.
        super().__init__()
        val = np.ones(dim, dtype='float32')
        self.gamma = tf.Variable(val, name='BN/gamma')
        val = np.zeros(dim, dtype='float32')
        self.beta = tf.Variable(val, name='BN/beta')
        self.eps = eps

    def __call__(self, x, training=False):
        """Normalise ``x`` with its own mean and variance.

        Args:
            x: input tensor; moments are taken over axes 0, 1 and 2
                (presumably batch, height, width of an NHWC tensor - TODO confirm).
            training: accepted for call compatibility with Keras layers but
                not used; the computation is the same either way.

        Returns:
            The normalised tensor, scaled by ``gamma`` and shifted by ``beta``.
        """
        mean, var = tf.nn.moments(x, axes=[0, 1, 2], keepdims=True)
        return tf.nn.batch_normalization(
            x=x,
            mean=mean,
            variance=var,
            offset=self.beta,
            scale=self.gamma,
            variance_epsilon=self.eps,
            name='CustomBN'
        )

In [10]:
class GeneratorModel(tf.Module):
    """Conditional generator turning a (class label, noise vector) pair into an image.

    The label and the noise are each projected to an ``h_low x w_low`` feature
    map, concatenated, and then upsampled by a stack of stride-2 transposed
    convolutions that ends in a ``tanh`` convolution with ``out_dim`` channels.
    """

    def __init__(self, out_dim, n_classes=21, h_low=5, w_low=5):
        """
        Args:
            out_dim: number of channels of the generated image.
            n_classes: size of the label embedding vocabulary.
            h_low: height of the low-resolution seed feature map.
            w_low: width of the low-resolution seed feature map.
        """
        super().__init__()
        # Label branch: class id -> single-channel seed map.
        self.label_layers_l = self._init_label_input_branch(n_classes, h_low, w_low)
        # Noise branch: latent vector -> 64-channel seed map.
        self.noise_layers_l = self._init_noise_input_branch(h_low, w_low)
        # Joins both seed maps.
        self.merge = Concatenate()
        # Upsampling trunk.
        self.model_layers_l = self._init_model_branch(out_dim)
        # Convenience handle on the output convolution.
        self.final_layer = self.model_layers_l[-1]

    def _init_label_input_branch(self, n_classes, h_low, w_low):
        """Layers mapping a class id to an ``(h_low, w_low, 1)`` map."""
        return [
            Embedding(n_classes, 64),
            Dense(h_low * w_low),
            Reshape((h_low, w_low, 1)),
        ]

    def _init_noise_input_branch(self, h_low, w_low):
        """Layers mapping a latent vector to an ``(h_low, w_low, 64)`` map."""
        return [
            Dense(64 * h_low * w_low),
            LeakyReLU(alpha=0.2),
            Reshape((h_low, w_low, 64)),
        ]

    def _init_model_branch(self, out_dim):
        """Build the upsampling trunk.

        Per the original annotations, the default 5x5 seed grows
        5 -> 10 -> 20 -> 21 (zero padding) -> 42 -> 84 -> 168 -> 336.
        Normalisation layers are intentionally absent (they were disabled in
        the original code).
        """
        # (filters, add one row/column of zero padding after the conv?)
        upsample_stages = [
            (128, False),
            (128, True),
            (256, False),
            (128, False),
            (64, False),
            (64, False),
        ]

        trunk = []
        for n_filters, pad_after in upsample_stages:
            trunk.append(
                Conv2DTranspose(
                    n_filters, (3,3),
                    strides=(2,2), padding='same',
                    kernel_initializer=init_weights()
                )
            )
            if pad_after:
                trunk.append(L.ZeroPadding2D(((1, 0), (1, 0))))
            trunk.append(LeakyReLU(alpha=0.2))

        # Output convolution: keeps resolution, squashes values with tanh.
        trunk.append(
            Conv2D(
                out_dim, (3,3),
                activation='tanh', padding='same',
                kernel_initializer=init_weights()
            )
        )
        return trunk

    @tf.function
    def __call__(self, label_i, noise_i, training=False):
        """Generate images.

        Args:
            label_i: class ids fed to the label branch.
            noise_i: latent vectors fed to the noise branch.
            training: forwarded to every layer call.

        Returns:
            The output of the final ``tanh`` convolution.
        """
        label_map = label_i
        for label_layer in self.label_layers_l:
            label_map = label_layer(label_map, training=training)

        noise_map = noise_i
        for noise_layer in self.noise_layers_l:
            noise_map = noise_layer(noise_map, training=training)

        # Noise features first, label map second - same order as before.
        x_t = self.merge([noise_map, label_map])
        for trunk_layer in self.model_layers_l:
            x_t = trunk_layer(x_t, training=training)

        return x_t


In [11]:
class DiscModel(tf.Module):
    """Conditional discriminator (critic) scoring an (image, class label) pair.

    The label is embedded and projected to a tensor of the same shape as the
    image, concatenated with the image, and passed through a stack of
    stride-2 convolutions followed by flatten, dropout and a dense output.
    """

    def __init__(self, in_shape, out_dim, n_classes=21):
        """
        Args:
            in_shape: image shape; its first three entries are used as the
                target shape of the label branch.
            out_dim: number of output units of the final dense layer.
            n_classes: size of the label embedding vocabulary.
        """
        super().__init__()
        # Label branch: class id -> image-shaped tensor.
        self.label_layers_l = self._init_label_input_branch(in_shape, n_classes)
        # Image branch (currently contains no layers).
        self.image_layers_l = self._init_image_input_branch()
        # Joins image and label tensors.
        self.merge = Concatenate()
        # Downsampling trunk plus scoring head.
        self.model_layers_l = self._init_model_branch(out_dim)
        # Convenience handle on the output dense layer.
        self.final_layer = self.model_layers_l[-1]

    def _init_label_input_branch(self, in_shape, n_classes):
        """Layers mapping a class id to a tensor shaped like the image."""
        dim_0, dim_1, dim_2 = in_shape[0], in_shape[1], in_shape[2]
        return [
            # embedding for categorical input
            Embedding(n_classes, 64),
            # scale up to image dimensions with linear activation
            Dense(dim_0 * dim_1 * dim_2),
            Reshape((dim_0, dim_1, dim_2)),
        ]

    def _init_image_input_branch(self):
        """Image preprocessing layers; empty, so the image passes through as is."""
        return []

    def _init_model_branch(self, out_dim):
        """Build the trunk: six stride-2 convolutions, then the scoring head.

        Normalisation layers are intentionally absent (they were disabled in
        the original code).
        """
        trunk = []
        for n_filters in (128, 256, 512, 512, 512, 512):
            trunk.append(
                Conv2D(
                    n_filters, (3,3),
                    strides=(2,2), padding='same',
                    kernel_initializer=init_weights()
                )
            )
            trunk.append(LeakyReLU(alpha=0.2))

        # Scoring head.
        trunk.append(layers.Flatten())
        trunk.append(layers.Dropout(0.2))
        trunk.append(layers.Dense(out_dim))
        return trunk

    @tf.function
    def __call__(self, label_i, image_i, training=False):
        """Score images.

        Args:
            label_i: class ids fed to the label branch.
            image_i: images fed to the (empty) image branch.
            training: forwarded to every layer call.

        Returns:
            The output of the final dense layer.
        """
        label_map = label_i
        for label_layer in self.label_layers_l:
            label_map = label_layer(label_map, training=training)

        image_map = image_i
        for image_layer in self.image_layers_l:
            image_map = image_layer(image_map, training=training)

        # Image first, label tensor second - same order as before.
        x_t = self.merge([image_map, label_map])
        for trunk_layer in self.model_layers_l:
            x_t = trunk_layer(x_t, training=training)
        return x_t


In [12]:
d_model = DiscModel(in_shape=IMG_SHAPE, out_dim=1, n_classes=N_CLASSES)

In [13]:
g_model = GeneratorModel(IMG_SHAPE[-1], n_classes=N_CLASSES) 

In [14]:
#q = g_model.model_layers_l[0]
#q.get_config()

In [15]:
#g_model(np.asarray([0]*16), np.random.randn(16, 256)).shape

In [16]:
class WGAN(keras.Model):
    """Conditional WGAN-GP training wrapper around a discriminator and a generator.

    Both sub-models are invoked with keyword arguments
    (``label_i=..., image_i=/noise_i=..., training=...``).
    """

    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        """
        Args:
            discriminator: callable ``(label_i, image_i, training)`` returning scores.
            generator: callable ``(label_i, noise_i, training)`` returning images.
            latent_dim: length of the noise vectors sampled for the generator.
            discriminator_extra_steps: discriminator updates per generator update.
            gp_weight: multiplier applied to the gradient penalty.
        """
        super(WGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Store the optimizers and loss functions used by the train steps.

        Args:
            d_optimizer: optimizer applied to the discriminator variables.
            g_optimizer: optimizer applied to the generator variables.
            d_loss_fn: called as ``d_loss_fn(real_img=..., fake_img=...)``.
            g_loss_fn: called as ``g_loss_fn(fake_logits)``.
        """
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    @tf.function
    def gradient_penalty(self, batch_size, real_images, fake_images, real_labels):
        """ Calculates the gradient penalty.

        This loss is calculated on an interpolated image
        and added to the discriminator loss.

        Args:
            batch_size: number of samples in the batch.
            real_images: batch of real images.
            fake_images: batch of generated images, same shape as ``real_images``.
            real_labels: class ids passed to the discriminator.

        Returns:
            Scalar: mean over the batch of ``(||grad|| - 1) ** 2``.
        """
        # One interpolation coefficient per sample, drawn uniformly from [0, 1)
        # so that the interpolated image lies on the segment between the real
        # and the fake image. (A normal draw, as used previously, places points
        # outside of that segment.)
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator(label_i=real_labels, image_i=interpolated, training=True)

        # 2. Calculate the gradients w.r.t to this interpolated image only.
        #    The integer labels are not differentiable, so they are not
        #    listed as a gradient source.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, real_images, real_labels):
        """Run ``d_steps`` discriminator updates followed by one generator update.

        Args:
            real_images: batch of real images.
            real_labels: class ids of ``real_images``.

        Returns:
            Dict with the last discriminator loss (``"d_loss"``) and the
            generator loss (``"g_loss"``).
        """
        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        # Train the discriminator first: `d_steps` updates on the same real
        # batch, each with freshly sampled noise and its own gradient penalty.
        for i in range(self.d_steps):
            d_loss = self._disc_train_step(real_images, real_labels)

        # Then one generator update.
        g_loss = self._generator_train_step(batch_size)

        return {"d_loss": d_loss, "g_loss": g_loss}

    @tf.function
    def _generator_train_step(self, batch_size):
        """One generator update on ``batch_size`` random (noise, label) pairs.

        Returns:
            The generator loss.
        """
        # Get the latent vector
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Random target classes in [0, N_CLASSES)
        random_labels = tf.random.uniform([batch_size], minval=0, maxval=N_CLASSES, dtype=tf.int32)
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(label_i=random_labels, noise_i=random_latent_vectors, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator(label_i=random_labels, image_i=generated_images, training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )

        return g_loss

    @tf.function
    def _disc_train_step(self, real_images, real_labels):
        """One discriminator update on a real batch and a matching fake batch.

        The fake batch is generated for the same labels as the real batch.

        Returns:
            The discriminator loss including the weighted gradient penalty.
        """
        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        random_latent_vectors = tf.random.normal(
            shape=(batch_size, self.latent_dim)
        )
        with tf.GradientTape() as tape:
            # Generate fake images from the latent vector
            fake_images = self.generator(label_i=real_labels, noise_i=random_latent_vectors, training=True)
            # Get the logits for the fake images
            fake_logits = self.discriminator(label_i=real_labels, image_i=fake_images, training=True)
            # Get the logits for the real images
            real_logits = self.discriminator(label_i=real_labels, image_i=real_images, training=True)

            # Calculate the discriminator loss using the fake and real image logits
            d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
            # Calculate the gradient penalty
            gp = self.gradient_penalty(batch_size, real_images, fake_images, real_labels)
            # Add the gradient penalty to the original discriminator loss
            d_loss = d_cost + gp * self.gp_weight

        # Get the gradients w.r.t the discriminator loss
        d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
        # Update the weights of the discriminator using the discriminator optimizer
        self.d_optimizer.apply_gradients(
            zip(d_gradient, self.discriminator.trainable_variables)
        )

        return d_loss



In [17]:
class GANMonitor():
    """Writes a grid of generator samples to a PNG file."""

    def __init__(self, model, num_img=100, latent_dim=128):
        """
        Args:
            model: generator, called as ``model(label_i=..., noise_i=...)``.
            num_img: number of samples to generate per call.
            latent_dim: length of each sampled noise vector.
        """
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.model = model

    @staticmethod
    def _grid_labels(num_img, n_classes):
        """Class ids for ``num_img`` samples laid out row by row in a square grid.

        The label of a sample is its column index, capped at ``n_classes - 1``.
        For ``num_img=100`` this yields ``0..9`` repeated ten times, matching
        the previously hard-coded label list, and it now also has exactly
        ``num_img`` entries for any other sample count.
        """
        # Grid side length; at least 1 to keep the modulo well defined.
        side = max(int(np.sqrt(num_img)), 1)
        return np.asarray([min(i % side, n_classes - 1) for i in range(num_img)])

    def on_epoch_end(self, epoch, logs=None, save_path=''):
        """Generate ``num_img`` samples and save them as ``<save_path>/<epoch>_image.png``.

        Args:
            epoch: value used as the file name prefix (any object usable in an f-string).
            logs: unused.
            save_path: directory the image is written to.
        """
        n = int(np.sqrt(self.num_img))
        random_latent_vectors = np.random.normal(size=(self.num_img, self.latent_dim))
        # One label per latent vector, so the two inputs always agree in length.
        random_labels = self._grid_labels(self.num_img, N_CLASSES)
        generated_images = self.model(label_i=random_labels, noise_i=random_latent_vectors)
        # scale from [-1,1] to [0,1]
        generated_images = (generated_images + 1) / 2.0
        self._generate_plot(generated_images, n, os.path.join(save_path, f'{epoch}'))

    def _generate_plot(self, examples, n, prefix):
        """Plot the first ``n * n`` entries of ``examples`` and save to ``<prefix>_image.png``."""
        fig = plt.figure(figsize=(12,12))
        for i in range(n * n):
            # define subplot
            plt.subplot(n, n, 1 + i)
            # turn off axis
            plt.axis('off')
            # plot raw pixel data
            plt.imshow(examples[i])
        fig.savefig(f'{prefix}_image.png')
        # Release only the figure created here; other open figures stay untouched.
        plt.close(fig)

In [18]:
import gc

class GCClearCallback:
    """Memory clean-up hook meant to be invoked at the end of an epoch."""

    def on_epoch_end(self, epoch=0, logs=None):
        """Run the Python garbage collector, then clear the Keras backend session.

        Args:
            epoch: unused.
            logs: unused.
        """
        gc.collect()
        tf.keras.backend.clear_session()

In [19]:
# Both networks get their own Adam optimizer with identical hyper-parameters
# (learning_rate=0.0002, beta_1=0.5 are recommended).
adam_settings = dict(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)

generator_optimizer = keras.optimizers.Adam(**adam_settings)
discriminator_optimizer = keras.optimizers.Adam(**adam_settings)

def discriminator_loss(real_img, fake_img):
    """Discriminator loss without gradient penalty.

    Args:
        real_img: discriminator outputs for real images.
        fake_img: discriminator outputs for generated images.

    Returns:
        ``mean(fake_img) - mean(real_img)``. The gradient penalty is added
        by the caller.
    """
    return tf.reduce_mean(fake_img) - tf.reduce_mean(real_img)


def generator_loss(fake_img):
    """Generator loss: the negated mean discriminator output for generated images."""
    mean_fake_score = tf.reduce_mean(fake_img)
    return -mean_fake_score


# Set the number of epochs for training.
epochs = 20

# Instantiate the custom `GANMonitor` helper that saves sample grids.
cbk = GANMonitor(g_model, num_img=100, latent_dim=noise_dim)
# Create an instance (not a reference to the class itself) so that
# `gcclear_call.on_epoch_end(ep)` binds `ep` to `epoch` rather than to `self`.
gcclear_call = GCClearCallback()
# Instantiate the WGAN model.
wgan = WGAN(
    discriminator=d_model,
    generator=g_model,
    latent_dim=noise_dim,
    discriminator_extra_steps=5, # was 3
)

# Compile the WGAN model.
wgan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
    g_loss_fn=generator_loss,
    d_loss_fn=discriminator_loss,
)


In [None]:
for ep in range(epochs):
    # Number of batches that make up one epoch.
    iteration = train_size // BATCH_SIZE

    # Every epoch gets its own output folder for sample grids.
    save_path = os.path.join(SAVE_RESULT, f'ep_{ep}')
    os.makedirs(save_path, exist_ok=True)

    epoch_batches = dataset.take(iteration)
    for i, (real_images, real_labels) in enumerate(epoch_batches):
        data_losses = wgan.train_step(real_images=real_images, real_labels=real_labels)
        d_value = data_losses['d_loss']
        g_value = data_losses['g_loss']
        print('>%d, %d/%d, d=%.3f, g=%.3f' %
            (ep+1, i+1, iteration, d_value, g_value))

        # Save a grid of samples every 20 batches.
        if i % 20 == 0:
            cbk.on_epoch_end(f'i_{i}_ep_{ep}', save_path=save_path)

    # Clear the session at the end of each epoch:
    # Keras itself has some memory leaks.
    # Issue: https://github.com/tensorflow/tensorflow/issues/31312
    gcclear_call.on_epoch_end(ep)

>1, 1/4725, d=-22.993, g=0.113
>1, 2/4725, d=-3415.470, g=-8.007
>1, 3/4725, d=-5960.186, g=-29.973
>1, 4/4725, d=-7936.403, g=-2.702
>1, 5/4725, d=-7174.975, g=-10.224
>1, 6/4725, d=-8213.363, g=-4.798
>1, 7/4725, d=-7446.815, g=-16.325
>1, 8/4725, d=-7594.896, g=-68.659
>1, 9/4725, d=-6968.963, g=-90.616
>1, 10/4725, d=-7049.369, g=-58.834
>1, 11/4725, d=-8014.090, g=-126.112
>1, 12/4725, d=-6225.958, g=-206.271
>1, 13/4725, d=-6142.880, g=-624.695
>1, 14/4725, d=-6168.522, g=-1399.444
>1, 15/4725, d=-3910.729, g=-1174.829
>1, 16/4725, d=-4168.845, g=-1941.030
>1, 17/4725, d=-3033.071, g=-2327.098
>1, 18/4725, d=-2140.864, g=-1659.725
>1, 19/4725, d=-1884.702, g=-1448.867
>1, 20/4725, d=-1656.559, g=941.505
>1, 21/4725, d=-1705.674, g=76.563
>1, 22/4725, d=-2091.476, g=2608.628
>1, 23/4725, d=-2236.122, g=72.841
>1, 24/4725, d=-2070.244, g=1960.014
>1, 25/4725, d=-1781.867, g=1931.084
>1, 26/4725, d=-1865.317, g=1570.873
>1, 27/4725, d=-2421.663, g=2800.281
>1, 28/4725, d=-1904.571, 

>1, 217/4725, d=-733.298, g=4402.667
>1, 218/4725, d=-377.251, g=2700.246
>1, 219/4725, d=-672.097, g=1852.227
>1, 220/4725, d=-238.273, g=2487.607
>1, 221/4725, d=-485.766, g=1604.684
>1, 222/4725, d=-685.828, g=1945.636
>1, 223/4725, d=-607.330, g=1000.738
>1, 224/4725, d=-452.127, g=1724.221
>1, 225/4725, d=-678.103, g=890.691
>1, 226/4725, d=-272.012, g=1901.163
>1, 227/4725, d=-506.704, g=1768.922
>1, 228/4725, d=-903.338, g=1271.498
>1, 229/4725, d=-397.611, g=1811.984
>1, 230/4725, d=-548.158, g=3396.357
>1, 231/4725, d=-466.497, g=3806.985
>1, 232/4725, d=-277.330, g=1385.027
>1, 233/4725, d=13.209, g=2036.165
>1, 234/4725, d=-347.208, g=2701.500
>1, 235/4725, d=-623.282, g=2112.130
>1, 236/4725, d=-462.098, g=2701.736
>1, 237/4725, d=-830.545, g=1685.909
>1, 238/4725, d=-694.727, g=756.358
>1, 239/4725, d=-143.543, g=334.780
>1, 240/4725, d=-569.607, g=743.637
>1, 241/4725, d=69.929, g=198.571
>1, 242/4725, d=-933.394, g=-74.913
>1, 243/4725, d=-1051.108, g=1942.662
>1, 244/47

In [None]:
# Save the weights tracked by the `wgan` wrapper model to 'wgan_generator.h5'.
# NOTE(review): despite the file name, this call is made on the whole WGAN
# wrapper, not on the generator alone.
# NOTE(review): the generator and discriminator are plain `tf.Module` objects,
# not Keras layers - verify that their variables actually end up in this file.
wgan.save_weights('wgan_generator.h5')

In [None]:
# Trace `model.__call__` for fixed input signatures named 'poses' and 'mask'.
# NOTE(review): `model` and `SEQ_LEN` are not defined in any cell visible in
# this notebook; this cell looks like a leftover from another project and
# would presumably fail on a fresh kernel - confirm before keeping it.
call = model.__call__.get_concrete_function(
    tf.TensorSpec((1, SEQ_LEN, 32), tf.float32, name='poses'), tf.TensorSpec((1, SEQ_LEN), tf.float32, name='mask')
)

In [None]:
# Export `model` as a SavedModel, using `call` as its serving signature.
# NOTE(review): `model` is not defined in any cell visible in this notebook,
# and the target path does not mention the GAN - presumably a leftover cell; confirm.
tf.saved_model.save(model, 'models/corrector_customlayernorm/', signatures=call)

In [None]:
# Load a SavedModel from 'models/3d/' into `model`.
# NOTE(review): this path differs from the one written by the export cell and
# is unrelated to the outputs produced above - presumably a leftover cell; confirm.
model = tf.saved_model.load('models/3d/')