<a href="https://colab.research.google.com/github/TA-aiacademy/course_3.0/blob/GAI/08_GAI/GAI_Part1/1_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Variational Auto Encoder

## Preparations

### Imports and Installs

In [None]:
%pip install plotly

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
from plotly import express as px

import numpy as np
import tensorflow as tf

## Data

本次使用MNIST數字資料集

In [None]:
from tensorflow import keras

# Only the images are kept; both label arrays are thrown away here.
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()

# Pool the two splits along the sample axis and append a trailing channel
# axis, then convert to float32 and divide by 255 (presumably mapping 8-bit
# pixel values into [0, 1]).
mnist_digits = np.concatenate((x_train, x_test))[..., np.newaxis]
mnist_digits = mnist_digits.astype("float32") / 255

## Model Construction

### Sampling Cell
我們在課堂上學到，Auto Encoder會將圖片壓縮成較低維的feature map再還原，藉此學習特徵的萃取或壓縮

<img src=https://hackmd.io/_uploads/S199Ud3hh.png height=300>

而VAE會進一步將雜訊加入encode完的feature map中增加這個萃取訓練的穩定性：

$$
z = \mu + \epsilon, \quad \epsilon \sim \mathcal{N}(0, \sigma^2)
$$

但因為Normal Distribution這種取樣的步驟不可微分，所以實際上的做法則是使用 NN encode出兩個latent變數$\mu$與$\sigma$，再使用一個sample來的權重$\epsilon$就可以等效為上面z的算式:

<img src=https://hackmd.io/_uploads/S1Y3Idnnn.png height=300>

下面我們來實作一下這個算式作為一個Layer。實作上 Encoder 輸出的是 $\log\sigma^2$（即程式中的 `z_log_var`），因此 $\sigma = \exp(\tfrac{1}{2}\log\sigma^2)$：
$$
\epsilon \sim \mathcal{N}(0,1)
$$
$$
z = \mu + \sigma \odot \epsilon
$$

In [None]:
from tensorflow.keras import layers

class Sampling(layers.Layer):
    """Reparameterization layer that draws z from (z_mean, z_log_var).

    The layer takes a pair ``(z_mean, z_log_var)`` and returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon``, where ``epsilon`` is drawn
    from a standard normal distribution on every call.
    """

    def call(self, inputs):
        mean, log_var = inputs
        # One noise value per entry along the first two axes of the mean.
        noise_shape = (tf.shape(mean)[0], tf.shape(mean)[1])
        noise = tf.random.normal(shape=noise_shape)
        # exp(0.5 * log(sigma^2)) turns the log-variance into sigma.
        std = tf.exp(0.5 * log_var)
        return mean + std * noise

### Encoder

In [None]:
class Encoder(keras.Model):
    """Convolutional encoder that maps an image batch to latent statistics.

    Calling the model returns the tuple ``(z_mean, z_log_var, z)``, where
    ``z`` is produced by the ``Sampling`` layer from the other two outputs.

    Args:
        latent_dim: Number of units in each of the ``z_mean`` and
            ``z_log_var`` output layers.
    """

    def __init__(self, latent_dim=2):
        super().__init__()
        # The two convolutions differ only in their number of filters.
        conv_kwargs = dict(
            kernel_size=3, activation="relu", strides=2, padding="same"
        )
        self.conv1 = layers.Conv2D(32, **conv_kwargs)
        self.conv2 = layers.Conv2D(64, **conv_kwargs)
        self.flatten = layers.Flatten()
        self.dense1 = layers.Dense(16, activation="relu")
        # Two parallel heads read the same 16-unit hidden representation.
        self.z_mean = layers.Dense(latent_dim, name="z_mean")
        self.z_log_var = layers.Dense(latent_dim, name="z_log_var")
        self.sampling = Sampling()

    def call(self, inputs):
        hidden = self.dense1(self.flatten(self.conv2(self.conv1(inputs))))
        mean = self.z_mean(hidden)
        log_var = self.z_log_var(hidden)
        return mean, log_var, self.sampling([mean, log_var])


# Size of the latent space used throughout the rest of the notebook.
latent_dim = 2
encoder = Encoder(latent_dim)

### Decoder (The Generator)

輸入vector (noise)，輸出圖片

In [None]:
class Decoder(keras.Model):
    """Transposed-convolution decoder that turns latent vectors into images.

    A dense layer expands the input to 7 * 7 * 64 values, which are reshaped
    to (7, 7, 64) and passed through two stride-2 transposed convolutions and
    a final single-channel transposed convolution with a sigmoid activation.

    Args:
        latent_dim: Size of the latent vectors the decoder is meant to
            receive. None of the layers is given an input size here, so the
            value is not needed to construct them; it is stored on the
            instance as ``self.latent_dim`` instead of being discarded.
    """

    def __init__(self, latent_dim=2):
        super().__init__()
        # Previously this argument was accepted and then ignored.
        self.latent_dim = latent_dim
        self.dense1 = layers.Dense(7 * 7 * 64, activation="relu")
        self.reshape = layers.Reshape((7, 7, 64))
        self.deconv1 = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")
        self.deconv2 = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")
        self.outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.reshape(x)
        x = self.deconv1(x)
        x = self.deconv2(x)
        # Sigmoid output, so every generated pixel lies in (0, 1).
        return self.outputs(x)


# Pass the notebook-wide latent size explicitly, mirroring the encoder.
decoder = Decoder(latent_dim=latent_dim)

## Training Step

### Model Class with training step

In [None]:
class VAE(keras.Model):
    """Variational auto-encoder with custom training and evaluation steps.

    The model wraps an encoder returning ``(z_mean, z_log_var, z)`` and a
    decoder mapping ``z`` back to an image. Both steps compute the same loss:
    a per-image summed binary cross-entropy plus the KL term
    ``-0.5 * (1 + z_log_var - z_mean**2 - exp(z_log_var))`` summed over the
    latent axis, each averaged over the batch.

    Args:
        encoder: Model called as ``encoder(images)``.
        decoder: Model called as ``decoder(z)``.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Return the three loss trackers maintained by this model."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    @staticmethod
    def _vae_images(data):
        """Return the image batch, unpacking ``(x, ...)`` tuples or lists."""
        if isinstance(data, (tuple, list)):
            return data[0]
        return data

    def _vae_losses(self, images):
        """Return ``(total_loss, reconstruction_loss, kl_loss)`` for a batch."""
        z_mean, z_log_var, z = self.encoder(images)
        reconstruction = self.decoder(z)
        # Cross-entropy is summed over axes 1 and 2 of the per-pixel loss,
        # then averaged over the batch.
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(images, reconstruction),
                axis=(1, 2),
            )
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        return reconstruction_loss + kl_loss, reconstruction_loss, kl_loss

    def _vae_update_metrics(self, total_loss, reconstruction_loss, kl_loss):
        """Feed the trackers and return their running means as a dict."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train_step(self, data):
        """Run one optimization step and return the running loss means."""
        images = self._vae_images(data)
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self._vae_losses(images)
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return self._vae_update_metrics(total_loss, reconstruction_loss, kl_loss)

    def test_step(self, data):
        """Evaluate the losses on a batch without updating any weights."""
        images = self._vae_images(data)
        total_loss, reconstruction_loss, kl_loss = self._vae_losses(images)
        return self._vae_update_metrics(total_loss, reconstruction_loss, kl_loss)


vae = VAE(encoder, decoder)
# No loss is passed to compile(): the steps above compute it themselves.
vae.compile(optimizer=keras.optimizers.Adam())

看一下我們的兩個部分:

In [None]:
# Encoder: produces the latent vector that corresponds to each image, which
# makes it possible to train the latent-vector-to-image generation.
# Build it for (batch, 28, 28, 1) inputs before printing the summary.
encoder.build(input_shape=(None, 28, 28, 1))
encoder.summary()

In [None]:
# Decoder: the generator we want to train; given a vector it produces an image.
# Build it for (batch, latent_dim) inputs before printing the summary.
decoder.build(input_shape=(None, latent_dim))
decoder.summary()

### Training Start

訓練模型將壓縮至 latent space 的資料還原回圖片，藉由這個還原的過程來訓練圖片生成的能力

In [None]:

# Train on the images alone (no targets are passed), holding out 20% of the
# samples for validation; the returned History object is plotted afterwards.
history = vae.fit(mnist_digits, epochs=20, batch_size=128, validation_split=0.2)

### Evaluate for convergence
訓練完成後看看 loss 是否收斂

### Show History

In [None]:
# Overlay validation and training loss per epoch. Each curve gets a label and
# the figure a legend, because the two lines are otherwise indistinguishable.
plt.plot(history.history["val_loss"], label="val_loss")
plt.plot(history.history["loss"], label="loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()

## Output Observation

### Encoder Latent Space
在Encoder latent space中我們可以觀察到我們訓練出來的Encoder究竟將影像embed成什麼樣子的分布

In [None]:
import plotly.express as px

def plot_label_clusters(vae, data, labels):
    """Show a 2D scatter plot of the encoder's latent means, colored by label.

    Args:
        vae: Object whose ``encoder.predict`` returns three outputs, the
            first being the latent means.
        data: Inputs passed to ``vae.encoder.predict``.
        labels: Values used to color the points, one per input sample.
    """
    latent_means, _log_vars, _samples = vae.encoder.predict(data, verbose=False)

    layout = dict(
        width=600,
        height=500,
        xaxis_title="z[0]",
        yaxis_title="z[1]",
        coloraxis_colorbar_title="Labels",
    )
    # Only the first two latent coordinates are drawn.
    scatter = px.scatter(x=latent_means[:, 0], y=latent_means[:, 1], color=labels)
    scatter.update_layout(**layout)
    scatter.show()



# Load MNIST again, this time keeping the test labels, and give the test
# images the same trailing channel axis and 1/255 scaling as the training data.
_, (x_test, y_test) = keras.datasets.mnist.load_data()
x_test = x_test[..., np.newaxis].astype("float32") / 255

plot_label_clusters(vae, x_test, y_test)

# In the output every class has a different distribution: different digits
# roughly occupy different regions of the latent space (z[0], z[1]), and the
# shapes of their distributions differ somewhat as well.

### Decode Latent Space

Decoder 負責將latent space中的vector還原到image space中，我們來看看space中每個位置還原回來長什麼樣子

In [None]:
import numpy as np
import plotly.graph_objects as go

def plot_latent_space(vae, n=20, figsize=600, scale=1.0):
    """Display an n x n manifold of digits decoded from a 2D latent grid.

    The latent square ``[-scale, scale]^2`` is sampled on a regular grid,
    every grid point is decoded, and the decoded tiles are assembled into one
    heatmap with z[0] increasing to the right and z[1] increasing upwards.

    Args:
        vae: Object whose ``decoder.predict`` maps a batch of 2D latent
            vectors to images that can be reshaped to 28 x 28.
        n: Number of grid points along each latent axis.
        figsize: Width and height of the figure in pixels.
        scale: Half-width of the sampled latent square.
    """
    digit_size = 28
    # Linearly spaced coordinates; y is reversed so that tile row 0 holds the
    # largest z[1] value.
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    # Decode the whole grid in one predict() call instead of n * n calls.
    # xs[i, j] == grid_x[j] and ys[i, j] == grid_y[i].
    xs, ys = np.meshgrid(grid_x, grid_y)
    z_samples = np.stack([xs.ravel(), ys.ravel()], axis=1)
    decoded = np.asarray(vae.decoder.predict(z_samples, verbose=False))

    # (row, col, height, width) -> (row, height, col, width) -> one big image,
    # so tile (i, j) lands at figure[i*28:(i+1)*28, j*28:(j+1)*28].
    tiles = decoded.reshape(n, n, digit_size, digit_size)
    figure = tiles.transpose(0, 2, 1, 3).reshape(n * digit_size, n * digit_size)

    # Ticks sit at the center of each tile.
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)

    # The heatmap draws row 0 at the bottom, so the image is flipped to keep
    # the digits upright. After the flip the bottom tile row corresponds to
    # the smallest z[1], hence the y tick labels must be reversed as well.
    fig = go.Figure(data=go.Heatmap(z=figure[::-1], colorscale="gray"))
    fig.update_xaxes(tickvals=pixel_range, ticktext=sample_range_x)
    fig.update_yaxes(tickvals=pixel_range, ticktext=sample_range_y[::-1])
    fig.update_layout(
        width=figsize,
        height=figsize,
        xaxis_title="z[0]",
        yaxis_title="z[1]",
    )
    fig.show()


plot_latent_space(vae)