# Variational auto-encoders on anime faces

In [None]:
# Importing the libraries 

import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt
import numpy as np

import os 
import zipfile
import urllib.request
import random
from IPython import display

# Load and prepare dataset 

In [None]:
# Global parameters shared by the data pipeline and the model builders

# Seed both NumPy and the stdlib `random` module: the train/validation split
# below is produced with random.shuffle, which np.random.seed does not affect.
np.random.seed(51)
random.seed(51)

# Number of images per batch
BATCH_SIZE = 2000

# Size of the latent vector produced by the encoder
LATENT_DIM = 512

# Images are resized to IMAGE_SIZE x IMAGE_SIZE pixels
IMAGE_SIZE = 64

In [None]:
# Download the dataset archive into the current working directory

dataurl = "https://storage.googleapis.com/learning-datasets/Resources/anime-faces.zip"

data_file_name = 'animefaces.zip'

# Absolute path: the loading code later reads images from "/tmp/anime/images/",
# so the archive must be extracted under /tmp/anime/ and not a relative tmp/ dir.
download_dir = '/tmp/anime/'

urllib.request.urlretrieve(dataurl, data_file_name)

In [None]:
# Extract the zip file into download_dir

# The context manager closes the archive even if extraction raises.
with zipfile.ZipFile(data_file_name, 'r') as zip_ref:
    zip_ref.extractall(download_dir)

In [None]:
# Prepare the Dataset 

def get_dataset_slice_paths(image_dir):
    """Return the joined path of every entry listed in `image_dir`.

    Entries come back in the order produced by os.listdir; no filtering is
    applied, so sub-directories and non-image files are included as well.
    """
    collected = []

    for entry_name in os.listdir(image_dir):
        collected.append(os.path.join(image_dir, entry_name))

    return collected

def map_images(image_filename):
    """Load one JPEG file and convert it into a normalised image tensor.

    Args:
        image_filename: path of the JPEG file (string or scalar string tensor).

    Returns:
        float32 tensor of shape (IMAGE_SIZE, IMAGE_SIZE, 3), pixel values
        divided by 255.
    """
    img_raw = tf.io.read_file(image_filename)

    # Force three channels so a grayscale JPEG still satisfies the final reshape.
    image = tf.image.decode_jpeg(img_raw, channels=3)

    image = tf.cast(image, dtype=tf.float32)

    # Resizing is a module-level function; tensors have no `.resize` method.
    image = tf.image.resize(image, (IMAGE_SIZE, IMAGE_SIZE))

    image = image / 255.0

    # Pin the static shape so downstream layers see fully defined dimensions.
    image = tf.reshape(image, shape=(IMAGE_SIZE, IMAGE_SIZE, 3))

    return image

# Generate training and validation sets

In [None]:
# Collect the path of every image file in the extracted dataset
data_path = "/tmp/anime/images/"
paths = get_dataset_slice_paths(data_path)

# Randomise the order in place so the split below is not tied to listing order
random.shuffle(paths)

# The first 80% of the shuffled paths are used for training, the rest for validation
path_len = len(paths)
train_path_len = int(path_len * 0.8)
train_paths, val_paths = paths[:train_path_len], paths[train_path_len:]

# Training pipeline: paths -> decoded images -> shuffled -> batched
training_dataset = (
    tf.data.Dataset.from_tensor_slices((train_paths))
    .map(map_images)
    .shuffle(1000)
    .batch(BATCH_SIZE)
)

# Validation pipeline: paths -> decoded images -> batched (no shuffling)
validation_dataset = (
    tf.data.Dataset.from_tensor_slices((val_paths))
    .map(map_images)
    .batch(BATCH_SIZE)
)

# Report how many batches each split contains
print(f'number of batches in the training set: {len(training_dataset)}')
print(f'number of batches in the validation set: {len(validation_dataset)}')

# Display

In [None]:
def display_faces(dataset, size=9):
    '''Takes a sample from a dataset batch and plots it in a grid.

    Args:
        dataset: batched dataset of images; it is unbatched before sampling.
        size: number of images to plot.

    The grid has three columns and as many rows as needed to fit `size` images.
    '''
    n_cols = 3

    # Ceiling division: exactly enough rows, with no empty trailing row when
    # `size` is a multiple of the column count.
    n_rows = (size + n_cols - 1) // n_cols

    plt.figure(figsize=(5, 5))

    # Subplot indices are 1-based.
    for i, image in enumerate(dataset.unbatch().take(size), start=1):
        # Use the shared IMAGE_SIZE constant instead of a hard-coded 64.
        disp_img = np.reshape(image, (IMAGE_SIZE, IMAGE_SIZE, 3))

        plt.subplot(n_rows, n_cols, i)
        plt.xticks([])
        plt.yticks([])
        plt.imshow(disp_img)


def display_one_row(disp_images, offset, shape=(28, 28)):
    '''Draws the given images in consecutive cells of a 3x10 subplot grid.

    Args:
        disp_images: iterable of images to draw.
        offset: number of grid cells to skip before the first image.
        shape: shape each image is reshaped to before it is shown.
    '''
    cell = offset

    for picture in disp_images:
        # Subplot positions are 1-based, so advance before drawing.
        cell += 1

        plt.subplot(3, 10, cell)
        plt.xticks([])
        plt.yticks([])
        plt.imshow(np.reshape(picture, shape))


def display_results(disp_input_images, disp_predicted):
    '''Plots the input images and the predicted images in one figure.

    Inputs are drawn starting at grid cell offset 0 and predictions starting
    at offset 20 of the 3x10 grid used by display_one_row.
    '''
    image_shape = (IMAGE_SIZE, IMAGE_SIZE, 3)

    plt.figure(figsize=(15, 5))

    for row_images, row_offset in ((disp_input_images, 0), (disp_predicted, 20)):
        display_one_row(row_images, row_offset, shape=image_shape)


In [None]:
# Plot 12 images taken from the validation dataset
display_faces(validation_dataset, size=12)

# Build the model 

In [None]:
class sampling(tf.keras.layers.Layer):
    """Layer that draws a latent vector from the encoder outputs.

    Computes z = mu + exp(0.5 * sigma) * epsilon, with epsilon drawn from a
    standard normal distribution.
    """

    def call(self, inputs):
        """Return a random latent sample for every row of the batch.

        Args:
            inputs: pair (mu, sigma) of 2-D tensors; the code reads their
                first two dimensions as (batch, dim).

        Returns:
            Tensor z with the same (batch, dim) shape as mu.
        """
        mu, sigma = inputs

        # Dynamic shape so the layer works with any batch size.
        batch = tf.shape(mu)[0]
        dim = tf.shape(mu)[1]

        epsilon = tf.random.normal(shape=(batch, dim))

        return mu + tf.exp(0.5 * sigma) * epsilon

In [None]:
# Define encoder layers

def encoder_layers(inputs, latent_dim):
    """Build the encoder stack on top of `inputs`.

    Three stride-2 convolutions (each followed by batch normalisation) feed a
    1024-unit dense layer, which in turn feeds two parallel dense heads.

    Args:
        inputs: Keras input tensor holding the image batch.
        latent_dim: number of units in each of the two output heads.

    Returns:
        Tuple (mu, sigma, conv_shape) where mu and sigma are the two head
        outputs and conv_shape is the shape of the last convolutional feature
        map as reported by `.shape` (it includes the batch dimension).
    """
    # Arguments shared by every convolution. The activation name must be
    # exactly 'relu' (a stray leading space is rejected by Keras).
    conv_kwargs = dict(kernel_size=(3, 3), strides=2, activation='relu', padding='same')

    x = tf.keras.layers.Conv2D(filters=32, name='encode_conv1', **conv_kwargs)(inputs)
    x = tf.keras.layers.BatchNormalization()(x)

    x = tf.keras.layers.Conv2D(filters=64, name='encode_conv2', **conv_kwargs)(x)
    x = tf.keras.layers.BatchNormalization()(x)

    x = tf.keras.layers.Conv2D(filters=64, name='encode_conv3', **conv_kwargs)(x)

    # Kept under its own name: its shape is returned so the decoder can
    # rebuild a feature map of the same size.
    batch_3 = tf.keras.layers.BatchNormalization()(x)

    x = tf.keras.layers.Flatten(name='encode_flatten')(batch_3)
    x = tf.keras.layers.Dense(1024, activation='relu', name='encode_dense')(x)
    x = tf.keras.layers.BatchNormalization()(x)

    mu = tf.keras.layers.Dense(latent_dim, name='laten_mu')(x)
    sigma = tf.keras.layers.Dense(latent_dim, name='laten_sigma')(x)

    return mu, sigma, batch_3.shape

In [None]:
# Define Encoder model 

def encoder_model(input_shape, latent_dim=None):
    """Create the encoder model and print its summary.

    Args:
        input_shape: shape of one input image, without the batch dimension.
        latent_dim: size of the latent vector. Optional so existing positional
            callers keep working; when omitted the module-level LATENT_DIM is
            used. Accepting it also lets callers pass `latent_dim=` by keyword.

    Returns:
        Tuple (model, conv_shape): the Keras model whose outputs are
        [mu, sigma, z], and the shape returned by encoder_layers.
    """
    if latent_dim is None:
        latent_dim = LATENT_DIM

    inputs = tf.keras.layers.Input(shape=input_shape)

    mu, sigma, conv_shape = encoder_layers(inputs, latent_dim=latent_dim)

    # Draw the latent sample from the two encoder heads.
    z = sampling()((mu, sigma))

    model = tf.keras.Model(inputs, outputs=[mu, sigma, z])

    model.summary()

    return model, conv_shape

In [None]:
# Decoder layers

def decoder_layers(inputs, conv_shape):
    """Build the decoder stack that maps a latent vector back to an image.

    Args:
        inputs: Keras input tensor holding the latent vectors.
        conv_shape: shape of the encoder's last convolutional feature map.
            Only the trailing (height, width, channels) entries are used, so a
            shape with or without a leading batch dimension is accepted.

    Returns:
        Output tensor of the final transposed convolution (3 channels,
        sigmoid activation).
    """
    # Take the last three entries: the shape returned by the encoder starts
    # with the batch dimension (None), which must not enter the product.
    height, width, channels = tuple(conv_shape)[-3:]

    units = height * width * channels       # number of neurons

    x = tf.keras.layers.Dense(units, activation='relu', name='Decode_dense1')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)

    # Reshape lives in tf.keras.layers and has to be applied to `x`.
    x = tf.keras.layers.Reshape((height, width, channels), name='decode_shape')(x)

    # Upsample the feature map back to the original dimensions: three
    # stride-2 steps undo the encoder's three stride-2 convolutions.
    x = tf.keras.layers.Conv2DTranspose(filters=64, kernel_size=(3, 3), strides=2, padding='same', activation='relu', name='decode_conv1')(x)
    x = tf.keras.layers.BatchNormalization()(x)

    x = tf.keras.layers.Conv2DTranspose(filters=64, kernel_size=(3, 3), strides=2, padding='same', activation='relu', name='decode_conv2')(x)
    x = tf.keras.layers.BatchNormalization()(x)

    x = tf.keras.layers.Conv2DTranspose(filters=32, kernel_size=(3, 3), strides=2, padding='same', activation='relu', name='decode_conv3')(x)
    x = tf.keras.layers.BatchNormalization()(x)

    # Final projection to 3 channels. Stride 1 keeps the spatial size reached
    # above (a fourth stride-2 step would double it past the input size), and
    # sigmoid bounds the output to [0, 1] like the normalised input pixels.
    x = tf.keras.layers.Conv2DTranspose(filters=3, kernel_size=(3, 3), strides=1, padding='same', activation='sigmoid', name='decode_final')(x)

    return x


In [None]:
# Decoder model 

def decoder_model(latent_dim, conv_shape):
    """Create the decoder model and print its summary.

    Args:
        latent_dim: size of the latent vector (an int), or an already formed
            shape tuple.
        conv_shape: shape of the encoder's last convolutional feature map,
            forwarded to decoder_layers.

    Returns:
        Keras model mapping a latent vector to the decoder_layers output.
    """
    # Keras documents `shape` as a tuple; wrap a bare int explicitly instead
    # of relying on implicit coercion.
    input_shape = (latent_dim,) if isinstance(latent_dim, int) else latent_dim

    inputs = tf.keras.layers.Input(shape=input_shape)

    outputs = decoder_layers(inputs, conv_shape)

    model = tf.keras.Model(inputs, outputs)

    model.summary()

    return model

# Kullback–Leibler Divergence

In [None]:
def kl_reconstruciton_loss(mu, sigma):
    """Return -0.5 * mean(1 + sigma - mu^2 - exp(sigma)) over all elements.

    Args:
        mu: tensor of latent means.
        sigma: tensor with the same shape as mu; it enters the formula both
            directly and through exp(sigma).
    """
    per_element = 1 + sigma - tf.square(mu) - tf.math.exp(sigma)

    averaged = tf.reduce_mean(per_element)

    return averaged * -0.5

In [None]:
# Putting all together

def vae_model(encoder, decoder, input_shape):
    """Chain the encoder and decoder into a single model with the KL loss attached.

    Args:
        encoder: model returning (mu, sigma, z) for an image batch.
        decoder: model mapping z to a reconstructed image batch.
        input_shape: shape of one input image, without the batch dimension.

    Returns:
        Keras model from input images to reconstructions; the KL term is
        registered on it through add_loss.
    """
    # Input is a function in tf.keras.layers and Model lives in tf.keras.
    inputs = tf.keras.layers.Input(shape=input_shape)

    mu, sigma, z = encoder(inputs)

    reconstructed = decoder(z)

    model = tf.keras.Model(inputs=inputs, outputs=reconstructed)

    # Register the KL term on the model; any reconstruction loss has to be
    # supplied separately by the training code.
    loss = kl_reconstruciton_loss(mu, sigma)
    model.add_loss(loss)

    return model
      
    
    

In [None]:
def get_models(latent_dim, input_shape):
    """Build the encoder, the decoder and the combined VAE.

    Args:
        latent_dim: size of the latent vector shared by encoder and decoder.
        input_shape: shape of one input image, without the batch dimension.

    Returns:
        Tuple (encoder, decoder, vae).
    """
    # Pass the argument, not the global LATENT_DIM, so that the encoder and
    # the decoder always agree on the latent size.
    encoder, conv_shape = encoder_model(latent_dim=latent_dim, input_shape=input_shape)

    decoder = decoder_model(latent_dim=latent_dim, conv_shape=conv_shape)

    vae = vae_model(encoder, decoder, input_shape=input_shape)

    return encoder, decoder, vae

In [None]:
# Build the three models; the input shape follows IMAGE_SIZE so it always
# matches the images produced by the data pipeline.
encoder, decoder, vae = get_models(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), latent_dim=LATENT_DIM)