In [None]:
import numpy as np
from keras_preprocessing.image import ImageDataGenerator
from keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, ReLU, Dropout
from keras.models import Model
from keras import backend as K
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint 
from keras.utils import plot_model
import tensorflow as tf
import cv2
import matplotlib.pyplot as plt

In [None]:
import os
from glob import glob

# Folder that is globbed for '*.jpg' files further down; keep the trailing slash-terminated
# form, other cells build paths from this value.
DATA_FOLDER = '../input/celeba-dataset/img_align_celeba/img_align_celeba/'

In [None]:
# Count the .jpg files that sit directly inside DATA_FOLDER.
filenames = np.array(glob(os.path.join(DATA_FOLDER, '*.jpg')))
NUM_IMAGES = len(filenames)
print("Total number of images : " + str(NUM_IMAGES))
# prints : Total number of images : 202599

# (height, width, channels) of the images actually fed to the network.
# This used to read (128, 128, 3) while the generator resized to (64, 64);
# the generator's target size is now derived from this single constant.
INPUT_DIM = (64, 64, 3)
BATCH_SIZE = 128
# Original note: train set num is 162770 (no such split is configured in this cell).

# flow_from_directory scans the sub-folders of the directory it is given, so it
# must be pointed at the parent of DATA_FOLDER rather than at DATA_FOLDER itself.
IMAGE_ROOT = os.path.dirname(os.path.normpath(DATA_FOLDER))

# class_mode='input' yields (images, images) batches, i.e. autoencoder targets.
# NOTE(review): subset='training' is passed without a validation_split on the
# generator, so it presumably selects every image - confirm that is intended.
data_flow = ImageDataGenerator(rescale=1./255).flow_from_directory(
    IMAGE_ROOT,
    target_size=INPUT_DIM[:2],
    batch_size=BATCH_SIZE,
    shuffle=True,
    class_mode='input',
    subset='training',
)

In [None]:
# create a sampling layer
from tensorflow.keras.layers import Layer
class reparameterize(Layer):
    """Reparameterisation-trick sampling layer.

    Takes ``[mean, logvar]`` (two tensors whose first two axes are read as
    batch and latent size) and returns ``mean + exp(0.5 * logvar) * eps``
    where ``eps`` is standard-normal noise of shape ``(batch, dim)``.
    """
    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Dynamic shape, so the layer works for any batch size.
        n_rows = tf.shape(z_mean)[0]
        n_cols = tf.shape(z_mean)[1]
        noise = tf.keras.backend.random_normal(shape=(n_rows, n_cols))
        # exp(0.5 * log-variance) is the standard deviation.
        sigma = tf.exp(0.5 * z_log_var)
        return z_mean + sigma * noise

In [None]:
# encoder
# Hyper-parameters. latent_dim and gamma_init are reused by the decoder cell;
# image_size, crop_size and batch_size are not referenced in this cell.
input_shape = (64, 64, 3)
kernel_size = 5
latent_dim = 512
image_size = 64 # square, original to be resized
crop_size = 148 # center crop to this size per dimension
batch_size = 128
gamma_init = tf.random_normal_initializer(1., 0.02)

# Encoder trunk: four stride-2 conv -> batch-norm -> ReLU stages named
# conv1..conv4, each halving the spatial size while widening the channels.
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs
for stage, n_filters in enumerate((64, 128, 256, 512), start=1):
    x = Conv2D(filters=n_filters, kernel_size=kernel_size, strides=2,
               padding='same', name='conv%d' % stage)(x)
    x = BatchNormalization(gamma_initializer=gamma_init, trainable=True)(x)
    x = ReLU()(x)

# Static shape of the last feature map; the decoder cell sizes its first
# layers from shape[1], shape[2] and shape[3].
shape = K.int_shape(x)

# Two parallel dense heads produce the latent mean and log-variance,
# from which z is sampled.
x = Flatten()(x)
mean = Dense(latent_dim, name="z_mean")(x)
logvar = Dense(latent_dim, name="z_log_var")(x)
z = reparameterize()([mean, logvar])

# The encoder exposes the sample first, then the distribution parameters.
encoder = Model(inputs, [z, mean, logvar], name='encoder')
encoder.summary()

In [None]:
# decoder
latent_inputs = Input(shape=(latent_dim,), name='decoder_input')

# Project the latent vector to a feature map with twice the spatial size and
# half the channels of the encoder's final feature map.
feat_h, feat_w, feat_c = shape[1:]
x = Dense(feat_h * feat_w * feat_c * 2)(latent_inputs)
x = Reshape((feat_h * 2, feat_w * 2, feat_c // 2), name='h0_reshape')(x)
x = BatchNormalization(gamma_initializer=gamma_init, trainable=True)(x)
x = ReLU()(x)

# Three stride-2 transpose-conv -> batch-norm -> ReLU stages, each doubling
# the spatial size.
for n_filters in (256, 128, 64):
    x = Conv2DTranspose(filters=n_filters, kernel_size=5, strides=2, padding='same')(x)
    x = BatchNormalization(gamma_initializer=gamma_init, trainable=True)(x)
    x = ReLU()(x)

# Final stride-1 layer maps to 3 output channels.
# NOTE(review): tanh emits values in [-1, 1] while the data generator rescales
# images into [0, 1] - confirm this output range is intended.
x = Conv2DTranspose(filters=3, kernel_size=5, strides=1, padding='same')(x)
outputs = Activation('tanh', name='decoder_output')(x)

# Instantiate Decoder Model
decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()

In [None]:
class VAE(Model):
    """Variational autoencoder built from a separately constructed encoder and decoder.

    Args:
        encoder: Model mapping an image batch to ``[z, mean, logvar]``.
        latent_dim: Size of the latent vector; sizes ``cf_layer`` and the
            default noise drawn by ``sample``.
        decoder: Model mapping a latent batch to an image batch.
        **kwargs: Forwarded to the Keras ``Model`` constructor.

    All methods use the encoder/decoder stored on the instance, so the class
    no longer depends on notebook globals named ``encoder``, ``decoder`` or
    ``latent_dim``.
    """

    def __init__(self, encoder, latent_dim, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.latent_dim = latent_dim
        # Linear map on the latent code whose kernel starts as the identity
        # matrix; it is only applied by `cf`.
        self.cf_layer = Dense(units=latent_dim, kernel_initializer=tf.constant_initializer(np.eye(latent_dim)))
        self.decoder = decoder

    def train_step(self, data):
        """Run one optimisation step and return the loss components.

        Args:
            data: An image batch, or a tuple whose first element is the image
                batch (as yielded by a generator with ``class_mode='input'``).

        Returns:
            Dict with ``loss``, ``reconstruction_loss`` and ``kl_loss``.
        """
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            # NOTE(review): the sub-models are called without training=True, so
            # BatchNormalization presumably runs in inference mode here. Left
            # unchanged because the loaded checkpoint was presumably trained
            # this way - confirm before altering.
            z, mean, logvar = self.encoder(data)
            reconstruction = self.decoder(z)
            # NOTE(review): binary cross-entropy expects predictions in [0, 1];
            # verify against the decoder's output activation.
            reconstruction_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(data, reconstruction)
            )
            # Scale the mean per-pixel loss by the 64 x 64 pixel count.
            reconstruction_loss *= 64 * 64
            # KL term averaged (not summed) over both batch and latent axes.
            kl_loss = 1 + logvar - tf.square(mean) - tf.exp(logvar)
            kl_loss = tf.reduce_mean(kl_loss)
            kl_loss *= -0.5
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }

    def call(self, inputs):
        """Encode `inputs`, then decode the sampled latent `z` (not the mean)."""
        z, mean, logvar = self.encoder(inputs)
        return self.decoder(z)

    def cf(self, inputs):
        """Encode `inputs`, pass the sampled latent through `cf_layer`, then decode."""
        z, mean, logvar = self.encoder(inputs)
        return self.decoder(self.cf_layer(z))

    @tf.function
    def sample(self, eps=None):
        """Decode latent vectors into images.

        Args:
            eps: Optional latent batch. When omitted, 100 standard-normal
                vectors of size ``latent_dim`` are drawn.

        Returns:
            The decoder output for ``eps``, in the same value range as the
            result of ``call``. An extra ``tf.sigmoid`` used to be applied
            here on top of the decoder's own output activation, which
            squashed samples into a range different from reconstructions.
        """
        if eps is None:
            eps = tf.random.normal(shape=(100, self.latent_dim))
        return self.decoder(eps)

# Wire the encoder and decoder built above into a single trainable VAE model.
vae = VAE(encoder=encoder, latent_dim=latent_dim, decoder=decoder)

In [None]:
# Sanity check: display the latent size the encoder/decoder were built with.
latent_dim

In [None]:
# Restore previously saved weights into the VAE before any further training.
CHECKPOINT_PATH = '../input/checkpoint-30/CELEBVAE_30.tf'
vae.load_weights(CHECKPOINT_PATH)

In [None]:
# Side length (pixels) that the test image is resized to in the next cell.
pic_size=64
pic_size

In [None]:
# Load one image, convert OpenCV's BGR order to RGB, resize it to
# pic_size x pic_size and scale the pixels to float32 in [0, 1].
sample_path = os.path.join(DATA_FOLDER, '000001.jpg')
temp = cv2.imread(sample_path)
# cv2.imread signals failure by returning None instead of raising; fail here
# with a clear message rather than with an opaque cvtColor error below.
if temp is None:
    raise FileNotFoundError("cv2.imread could not read %r" % sample_path)
temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
temp = cv2.resize(temp, (pic_size, pic_size)).astype('float32') / 255.
plt.imshow(temp)


In [None]:
# Add the leading batch axis the model expects. Guarded so that re-running
# this cell does not stack a second batch axis onto an already-batched array.
if temp.ndim == 3:
    temp = temp[np.newaxis, ...]
# Forward pass through the VAE (encode, sample z, decode).
p = vae(temp)

In [None]:
# Take the first (only) reconstruction of the batch, evaluate it to a NumPy
# array and display it.
b = K.eval(p[0])
plt.imshow(b)

In [None]:
# Only an optimizer is configured: VAE.train_step computes its own loss, so no
# `loss` argument is passed.
vae.compile(optimizer='adam')

In [None]:
# Model.fit accepts generators directly and drives the custom VAE.train_step;
# fit_generator is deprecated in TF 2.x and is only a thin wrapper around fit.
vae.fit(data_flow, epochs=15)

In [None]:
# Persist the fine-tuned weights to the working directory.
# NOTE(review): the "400" in the file name does not match any value visible in
# this notebook - confirm the intended name.
vae.save_weights("cgen_pretrain_vae_400.tf")