In [2]:

import tensorflow as tf
# from tensorflow_probability  import distributions as tfd
from tensorflow.keras import layers, Model
from tensorflow.keras.layers import Layer, Input, Conv2D, Dense, Flatten, Reshape, Lambda, Dropout
from tensorflow.keras.layers import Conv2DTranspose, MaxPooling2D, UpSampling2D, LeakyReLU, BatchNormalization
from tensorflow.keras.activations import relu
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import tensorflow_datasets as tfds

import cv2
import numpy as np
import matplotlib.pyplot as plt
import datetime, os
import warnings
warnings.filterwarnings('ignore')

from packaging.version import parse as parse_version
assert parse_version(tf.__version__) < parse_version("2.4.0"), \
    f"Please install TensorFlow version 2.3.1 or older. Your current version is {tf.__version__}."

# Ask TensorFlow to grow its GPU allocation on demand instead of reserving the
# whole device up front. Reserving everything is a common cause of the
# "Failed to get convolution algorithm ... cuDNN failed to initialize" error,
# because cuDNN then has no memory left for its own workspace.
# This must run before the first op touches the GPU, hence its place right
# after the imports.
for _gpu in tf.config.experimental.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(_gpu, True)
    except RuntimeError as err:
        # Raised when the GPU has already been initialised (e.g. this cell is
        # re-run in a live kernel); the setting cannot be changed any more.
        print(f"Could not enable memory growth for {_gpu.name}: {err}")



In [4]:
# Load the CelebA train and test splits from the local TFDS data directory.
# `with_info=True` makes the call also return the dataset metadata (ds_info),
# which is used later to read the number of examples per split.
(ds_train, ds_test_), ds_info = tfds.load(
    name='celeb_a',
    split=['train', 'test'],
    data_dir='/data/',
    shuffle_files=True,
    with_info=True,
)

In [6]:


# Number of images per batch. Also used to size the shuffle buffer
# (batch_size*4) and to derive the number of steps per epoch.
batch_size = 128

def preprocess(sample):
    """Turn one dataset record into an (input, target) pair for the autoencoder.

    Args:
        sample: a record exposing the picture under the 'image' key.

    Returns:
        A tuple holding the same tensor twice: the image resized to 112x112,
        cast to float32 and divided by 255. Input and target are identical
        because the model is trained to reconstruct its input.
    """
    resized = tf.image.resize(sample['image'], [112, 112])
    scaled = tf.cast(resized, tf.float32) / 255.
    return scaled, scaled

# Let tf.data choose parallelism and prefetch depth at runtime.
# (tf.data.experimental.AUTOTUNE is the spelling available in TF < 2.4.)
autotune = tf.data.experimental.AUTOTUNE

# Training pipeline: preprocess -> shuffle -> batch -> prefetch.
# prefetch() counts *elements of the dataset it is applied to*, which after
# batch() are whole batches. The previous prefetch(batch_size) therefore
# allowed up to 128 batches (about 2.4 GB of float32 images) to be buffered.
ds_train = ds_train.map(preprocess, num_parallel_calls=autotune)
ds_train = ds_train.shuffle(batch_size*4)
ds_train = ds_train.batch(batch_size).prefetch(autotune)

# Evaluation pipeline: same preprocessing, no shuffling.
ds_test = (
    ds_test_
    .map(preprocess, num_parallel_calls=autotune)
    .batch(batch_size)
    .prefetch(autotune)
)

# Split sizes taken from the dataset metadata.
train_num = ds_info.splits['train'].num_examples
test_num = ds_info.splits['test'].num_examples



In [7]:
class GaussianSampling(Layer):
    """Samples from a diagonal Gaussian using the reparameterisation trick."""

    def call(self, inputs):
        """Return ``mean + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, 1)``.

        Args:
            inputs: a pair ``(means, logvar)``; the noise is drawn with the
                same shape as ``means``.

        Returns:
            The sampled tensor.
        """
        mu, log_variance = inputs
        noise = tf.random.normal(shape=tf.shape(mu), mean=0., stddev=1.)
        # exp(0.5 * log(sigma^2)) == sigma
        return mu + tf.exp(0.5 * log_variance) * noise
    
class DownConvBlock(Layer):
    """Conv2D -> BatchNormalization -> LeakyReLU(0.2) building block.

    Each instance is named ``DownConvBlock_<n>``, where ``n`` is the number of
    blocks created before it.
    """

    # Class-wide instance counter used to build unique layer names.
    count = 0

    def __init__(self, filters, kernel_size=(3,3), strides=1, padding='same'):
        """Build the block.

        Args:
            filters: number of convolution filters.
            kernel_size: convolution kernel size.
            strides: convolution stride.
            padding: convolution padding mode.
        """
        block_name = f"DownConvBlock_{DownConvBlock.count}"
        super().__init__(name=block_name)
        DownConvBlock.count += 1

        self.forward = Sequential([
            Conv2D(filters=filters, kernel_size=kernel_size,
                   strides=strides, padding=padding),
            BatchNormalization(),
            LeakyReLU(0.2),
        ])

    def call(self, inputs):
        """Run the inputs through the conv / batch-norm / activation stack."""
        return self.forward(inputs)

class UpConvBlock(Layer):
    """Conv2D (stride 1) -> LeakyReLU(0.2) -> 2x UpSampling2D building block.

    Each instance is named ``UpConvBlock_<n>``, where ``n`` is the number of
    blocks created before it.
    """

    # Class-wide instance counter used to build unique layer names.
    count = 0

    def __init__(self, filters, kernel_size=(3,3), padding='same'):
        """Build the block.

        Args:
            filters: number of convolution filters.
            kernel_size: convolution kernel size.
            padding: convolution padding mode.
        """
        block_name = f"UpConvBlock_{UpConvBlock.count}"
        super().__init__(name=block_name)
        UpConvBlock.count += 1

        self.forward = Sequential([
            Conv2D(filters=filters, kernel_size=kernel_size,
                   strides=1, padding=padding),
            LeakyReLU(0.2),
            UpSampling2D((2,2)),
        ])

    def call(self, inputs):
        """Run the inputs through the conv / activation / upsampling stack."""
        return self.forward(inputs)
    
class Encoder(Layer):
    """Convolutional encoder producing a sampled latent code and its statistics.

    Four stride-2 DownConvBlocks (32, 32, 64, 64 filters) followed by Flatten
    feed two Dense heads, one for the mean and one for the log-variance.
    """

    def __init__(self, z_dim, name='encoder'):
        """Build the encoder.

        Args:
            z_dim: size of the latent vector.
            name: Keras layer name.
        """
        super().__init__(name=name)

        stage_filters = (32, 32, 64, 64)
        stages = [
            DownConvBlock(filters=n_filters, kernel_size=(3,3), strides=2)
            for n_filters in stage_filters
        ]
        self.features_extract = Sequential(stages + [Flatten()])

        self.dense_mean = Dense(z_dim, name='mean')
        self.dense_logvar = Dense(z_dim, name='logvar')
        self.sampler = GaussianSampling()

    def call(self, inputs):
        """Encode ``inputs``.

        Returns:
            A tuple ``(z, mean, logvar)`` where ``z`` is sampled from the
            Gaussian described by ``mean`` and ``logvar``.
        """
        features = self.features_extract(inputs)
        mean = self.dense_mean(features)
        logvar = self.dense_logvar(features)
        return self.sampler([mean, logvar]), mean, logvar

class Decoder(Layer):
    """Decoder mapping a latent vector back to a 3-channel image.

    Dense -> Reshape to (7, 7, 64) -> four UpConvBlocks (64, 64, 32, 32
    filters, each doubling the spatial size) -> 3-filter Conv2D with sigmoid.
    """

    def __init__(self, z_dim, name='decoder'):
        """Build the decoder.

        Args:
            z_dim: size of the latent vector (accepted for symmetry with the
                encoder; not used when constructing the layers).
            name: Keras layer name.
        """
        super().__init__(name=name)

        stage_filters = (64, 64, 32, 32)
        stem = [
            Dense(7*7*64, activation='relu'),
            Reshape((7,7,64)),
        ]
        upsampling = [
            UpConvBlock(filters=n_filters, kernel_size=(3,3))
            for n_filters in stage_filters
        ]
        to_rgb = Conv2D(filters=3, kernel_size=(3,3), strides=1,
                        padding='same', activation='sigmoid')

        self.forward = Sequential(stem + upsampling + [to_rgb])

    def call(self, inputs):
        """Decode the latent vectors in ``inputs`` into images."""
        return self.forward(inputs)

    
class VAE(Model):
    """Variational autoencoder made of an Encoder and a Decoder.

    The latent mean and log-variance of the most recent forward pass are kept
    on the instance (``self.mean`` / ``self.logvar``) so that code outside the
    model can read them.
    """

    def __init__(self, z_dim, name='VAE'):
        """Build the model.

        Args:
            z_dim: size of the latent vector.
            name: Keras model name.
        """
        super().__init__(name=name)
        self.encoder = Encoder(z_dim)
        self.decoder = Decoder(z_dim)
        # Filled in by call(); None until the first forward pass.
        self.mean, self.logvar = None, None

    def call(self, inputs):
        """Encode ``inputs``, remember the latent statistics, and decode."""
        latent, mean, logvar = self.encoder(inputs)
        self.mean = mean
        self.logvar = logvar
        return self.decoder(latent)

In [8]:
# Model instance with a 200-dimensional latent space. Kept as a module-level
# global because vae_kl_loss reads vae.mean / vae.logvar directly.
vae = VAE(z_dim=200)

In [9]:
def vae_kl_loss(y_true, y_pred):
    """KL-divergence term of the VAE loss.

    Reads the latent mean and log-variance stored on the global ``vae`` by its
    last forward pass; ``y_true`` and ``y_pred`` are not used and exist only so
    the function has the (y_true, y_pred) loss/metric signature.

    Returns:
        ``-0.5 * mean(1 + logvar - mean^2 - exp(logvar))``, averaged over all
        elements.
    """
    mean, logvar = vae.mean, vae.logvar
    kl_terms = 1 + logvar - tf.square(mean) - tf.exp(logvar)
    return - 0.5 * tf.reduce_mean(kl_terms)

def vae_rc_loss(y_true, y_pred):
    """Reconstruction term of the VAE loss.

    Returns the value of ``tf.keras.losses.MSE(y_true, y_pred)``.
    """
    # Alternative: tf.keras.losses.binary_crossentropy(y_true, y_pred)
    return tf.keras.losses.MSE(y_true, y_pred)

def vae_loss(y_true, y_pred):
    """Total VAE loss: reconstruction loss plus the KL term scaled by 0.01."""
    kl_weight_const = 0.01
    kl_term = vae_kl_loss(y_true, y_pred)
    reconstruction_term = vae_rc_loss(y_true, y_pred)
    return kl_weight_const * kl_term + reconstruction_term

In [10]:
model_path = "./models/my_vae_celeb_a.h5"

# Make sure the checkpoint directory exists before training starts; otherwise
# the first checkpoint write can fail with an OSError and abort the run.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# Save the weights only, and only when the monitored metric improves.
# NOTE(review): "vae_rc_loss" is the *training* reconstruction metric. With
# validation data passed to fit(), "val_vae_rc_loss" may be the intended
# monitor for both callbacks - confirm before changing.
checkpoint = ModelCheckpoint(model_path, 
                             monitor= "vae_rc_loss", 
                             verbose=1, 
                             save_best_only=True, 
                             mode= "auto", 
                             save_weights_only = True)

# Stop training when the monitored metric has not improved for 3 epochs.
early = EarlyStopping(monitor= "vae_rc_loss", 
                      mode= "auto", 
                      patience = 3)

callbacks_list = [checkpoint, early]

# Staircase exponential decay: multiply the learning rate by 0.96 once per
# epoch (decay_steps equals the number of batches in one epoch).
# NOTE(review): lr_schedule is not passed to the optimizer in the compile()
# call, which uses a fixed learning rate of 3e-3, so this schedule currently
# has no effect on training.
initial_learning_rate = 1e-3
steps_per_epoch = int(np.round(train_num / batch_size))
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=steps_per_epoch,
    decay_rate=0.96,
    staircase=True,
)

# Train with RMSprop on the combined loss; the KL and reconstruction terms
# are also tracked separately as metrics (the callbacks monitor one of them).
vae.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=3e-3),
    loss=[vae_loss],
    metrics=[vae_kl_loss, vae_rc_loss],
)

# Up to 20 epochs; the callbacks handle checkpointing and early stopping.
history = vae.fit(
    ds_train,
    epochs=20,
    validation_data=ds_test,
    callbacks=callbacks_list,
)

Epoch 1/20


UnknownError:  Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.
	 [[node VAE/encoder/sequential_4/DownConvBlock_0/sequential/conv2d/Conv2D (defined at <ipython-input-7-b392740aaf94>:19) ]] [Op:__inference_train_function_4107]

Errors may have originated from an input operation.
Input Source operations connected to node VAE/encoder/sequential_4/DownConvBlock_0/sequential/conv2d/Conv2D:
 IteratorGetNext (defined at <ipython-input-10-967ff380b52c>:31)

Function call stack:
train_function
