# VAE implementation in TensorFlow

In [None]:
import os

import numpy as np
import tensorflow as tf

from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

# Seed TensorFlow's graph-level RNG and NumPy's global RNG (used below for
# minibatch sampling) so that runs are repeatable.
tf.set_random_seed(0)
np.random.seed(0)

# utility関数

In [None]:
def save_metrics(metrics, epoch=None, out_dir="metrics"):
    """Plot a loss history and save it as ``<out_dir>/g_loss<epoch>.png``.

    Args:
        metrics: sequence of per-step loss values to plot.
        epoch: suffix appended to the file name (``str(epoch)``).
        out_dir: directory to write the PNG into; created if it does not exist.
    """
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(out_dir, exist_ok=True)
    fig = plt.figure(figsize=(10, 8))
    plt.plot(metrics, label="generative loss", color="r")
    plt.legend()
    fig.savefig(os.path.join(out_dir, "g_loss" + str(epoch) + ".png"))
    # Close explicitly so repeated calls during training do not accumulate figures.
    plt.close(fig)

In [None]:
# Tile a batch of images into a plot_dim (rows x cols) grid and save it as one PNG.
def save_imgs(images, plot_dim=(5,12), size=(12,5), epoch=None,
              img_shape=(96, 96, 3), out_dir="generated_figures"):
    """Save the first ``rows * cols`` images of ``images`` as ``<out_dir>/<epoch>.png``.

    Args:
        images: array whose first axis indexes images; each image is reshaped
            to ``img_shape`` before display.
        plot_dim: (rows, cols) of the grid; the default (5, 12) shows 60 images.
        size: matplotlib figure size in inches.
        epoch: used (via ``str``) as the output file name.
        img_shape: shape passed to ``reshape`` for each image.
        out_dir: directory to write into; created if it does not exist.
    """
    rows, cols = plot_dim
    # Never index past the batch: show at most as many images as were given.
    n_images = min(rows * cols, len(images))

    os.makedirs(out_dir, exist_ok=True)
    fig = plt.figure(figsize=size)
    for i in range(n_images):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(images[i].reshape(img_shape))
        plt.axis("off")
    # One layout pass after all subplots exist (instead of one per subplot),
    # then tighten the inter-subplot spacing.
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    fig.savefig(os.path.join(out_dir, str(epoch) + ".png"))
    plt.close(fig)


# モデル

In [None]:
class VAE:
    """Fully-connected variational autoencoder built as a TensorFlow 1.x graph.

    Encoder: x (x_dim) -> 1000 -> 500 -> 250 -> (mu, log_var), each of size z_dim.
    Decoder: z (z_dim) -> 250 -> 500 -> 1000 -> x_dim Bernoulli logits.

    Inputs are used directly as Bernoulli targets for the reconstruction loss,
    so pixel values are expected to lie in [0, 1].

    Args:
        x_dim: flattened input size (default 96*96*3).
        z_dim: latent dimensionality (default 100).
    """

    def __init__(self, x_dim=96*96*3, z_dim=100):
        self.x_dim = x_dim
        self.z_dim = z_dim

        # Small-stddev Gaussian init, used for both weights and biases.
        self.initializer = tf.random_normal_initializer(mean=0.0, stddev=0.01, dtype=tf.float32)

        print("init")

    def _dense(self, x, w_name, b_name, in_dim, out_dim, activation=None):
        """Affine layer ``x @ W + b`` in the current variable scope, optionally activated."""
        weight = tf.get_variable(w_name, shape=[in_dim, out_dim], initializer=self.initializer)
        bias = tf.get_variable(b_name, shape=[out_dim], initializer=self.initializer)
        out = tf.matmul(x, weight) + bias
        return out if activation is None else activation(out)

    def __call__(self, x):
        """Build the VAE graph on input ``x``.

        Args:
            x: tensor or array reshaped to ``[-1, x_dim]``, values in [0, 1].

        Returns:
            (reconstruction, loss): ``reconstruction`` is ``sigmoid(logits)`` with
            shape ``[batch, x_dim]``. ``loss`` is the batch mean of KL + binary
            cross-entropy (each summed over dimensions).
        """
        print("called")
        x_flat = tf.reshape(tf.convert_to_tensor(x), [-1, self.x_dim])

        # AUTO_REUSE lets the model be called more than once (e.g. for a second
        # input) while keeping the original variable names.
        with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
            enc1 = self._dense(x_flat, "W3", "b3", self.x_dim, 1000, tf.tanh)
            enc2 = self._dense(enc1, "W4", "b4", 1000, 500, tf.tanh)
            enc3 = self._dense(enc2, "W5", "b5", 500, 250)  # linear, as originally designed

        # Mean and log-variance heads must be unconstrained: a ReLU here would
        # force mu >= 0 and variance >= 1.
        with tf.variable_scope("enc_mu", reuse=tf.AUTO_REUSE):
            enc_mu = self._dense(enc3, "W", "b", 250, self.z_dim)
        # Scope name "enc_logsg" kept so variable names are unchanged.
        with tf.variable_scope("enc_logsg", reuse=tf.AUTO_REUSE):
            enc_logvar = self._dense(enc3, "W", "b", 250, self.z_dim)

        # Reparameterization trick: z = mu + sigma * epsilon, epsilon ~ N(0, I).
        epsilon = tf.random_normal(tf.shape(enc_mu), name="epsilon")
        std_encoder = tf.exp(enc_logvar / 2)
        z = enc_mu + std_encoder * epsilon

        # Closed-form KL(q(z|x) || N(0, I)) per example.
        KLD = -0.5 * tf.reduce_sum(
            1 + enc_logvar - tf.square(enc_mu) - tf.exp(enc_logvar), axis=1)

        with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
            dec1 = self._dense(z, "W1", "b1", self.z_dim, 250, tf.tanh)
            dec2 = self._dense(dec1, "W2", "b2", 250, 500, tf.tanh)
            dec3 = self._dense(dec2, "W3", "b3", 500, 1000, tf.tanh)
            logits = self._dense(dec3, "W4", "b4", 1000, self.x_dim)

        # Reconstruction probabilities returned to the caller.
        dec4 = tf.sigmoid(logits)

        # sigmoid_cross_entropy_with_logits applies the sigmoid itself, so it must
        # receive raw logits (passing dec4 would apply the sigmoid twice).
        BCE = tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=x_flat), axis=1)

        loss = tf.reduce_mean(KLD + BCE)

        return dec4, loss

    def train(self, loss, learning_rate=0.01):
        """Return an Adam op that minimizes ``loss`` (default learning rate 0.01)."""
        return tf.train.AdamOptimizer(learning_rate).minimize(loss)

# 設定情報

In [None]:
import numpy as np

# Image dataset; the file name suggests 1813 RGB images of 96x96 — TODO confirm.
X_train = np.load("irasutoya_face_1813x96x96x3_jpg.npy")
# Scale to [0, 1] for use as Bernoulli targets (assumes 0-255 pixel values).
# Cast to float32 first: `X_train / 255` on integer data would build a float64
# copy twice the size, only to be cast to float32 on every training step.
X_train = X_train.astype(np.float32) / 255

batch_size = 64
# NOTE: each "epoch" in the training loop is ONE random minibatch step,
# not a full pass over X_train.
epochs = 1000

metrics_save_epoch = 10  # save the loss plot every N steps
img_save_epoch = 2       # save reconstructions every N steps

param_save_epoch = 10000  # not used by the training loop
losses = []




In [None]:
train_imgs = tf.placeholder(tf.float32, shape=[batch_size, 96, 96, 3])

AE = VAE()

pred, ae_loss = AE(train_imgs)
train_op = AE.train(ae_loss)

# savefig does not create directories; make sure the output folders exist.
for out_dir in ("metrics", "generated_figures"):
    os.makedirs(out_dir, exist_ok=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(epochs):
        # Sample a random minibatch of training images (with replacement).
        rand_index = np.random.randint(0, X_train.shape[0], size=batch_size)
        exImgs = X_train[rand_index].astype(np.float32, copy=False)

        loss, _ = sess.run([ae_loss, train_op], feed_dict={train_imgs: exImgs})
        losses.append(loss)

        if epoch % metrics_save_epoch == 0:
            save_metrics(losses, epoch)

        if epoch % img_save_epoch == 0:
            print("epoch:" + str(epoch) + ", loss:" + str(loss))
            # Reconstructions of the same minibatch after this update step.
            imgs = sess.run(pred, feed_dict={train_imgs: exImgs})
            imgs = imgs.reshape(imgs.shape[0], 96, 96, 3)
            save_imgs(imgs, epoch=epoch)
            # imshow clips float RGB data to [0, 1]; a float array scaled to
            # 0-255 would render saturated, so convert to uint8 instead.
            save_imgs((np.clip(imgs, 0.0, 1.0) * 255).astype(np.uint8), epoch="_" + str(epoch))