In [1]:
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from PIL import Image
from matplotlib import pyplot as plt
from tensorflow.keras import datasets, Sequential, layers, metrics, optimizers, losses

# Fix the global NumPy and TensorFlow seeds so repeated runs draw the same
# shuffles, initial weights and noise samples.
SEED = 22
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [2]:
def save_image(imgs, name, grid=10, tile=28):
  """Tile grayscale images into one square PNG and write it to `name`.

  Images are placed column by column: the first `grid` images fill the
  leftmost column from top to bottom, the next `grid` fill the second column,
  and so on. Only the first grid*grid images are used; if fewer are supplied,
  the remaining cells stay black instead of raising IndexError.

  Args:
    imgs: indexable sequence (e.g. numpy array) of 2-D (tile, tile) images.
      Values are cast to uint8, so pass data already scaled to [0, 255].
    name: output file path. Missing parent directories are created.
    grid: tiles per row and per column (default 10 -> a 10x10 grid).
    tile: side length in pixels of each tile (default 28).
  """
  size = grid * tile
  new_im = Image.new('L', (size, size))
  n_tiles = min(len(imgs), grid * grid)
  for index in range(n_tiles):
    # Column-major placement: x (column) advances every `grid` images,
    # y (row) advances every image -- same order as the original nested loops.
    col, row = divmod(index, grid)
    # A 2-D uint8 array is inferred as mode 'L' by Pillow; the explicit
    # `mode=` argument is deprecated, and the cast guards against float input
    # being reinterpreted byte-by-byte.
    tile_im = Image.fromarray(np.asarray(imgs[index], dtype=np.uint8))
    new_im.paste(tile_im, (col * tile, row * tile))
  parent = os.path.dirname(name)
  if parent:
    os.makedirs(parent, exist_ok=True)
  new_im.save(name)

In [3]:
# Hyper-parameters shared by the cells below:
#   h_dim   -- NOTE(review): not referenced in the visible cells; the VAE's
#              hidden Dense layers hard-code 128 units instead.
#   z_dim   -- width of the latent code (mu / log_var heads, prior samples).
#   batches -- mini-batch size for the tf.data pipelines.
h_dim, z_dim, batches = 20, 10, 512

In [4]:
# Fashion-MNIST 28x28 grayscale images; the labels are only printed below --
# the VAE trains on pixels alone.
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

# Rescale pixels from [0, 255] to float32 in [0, 1]; the sigmoid cross-entropy
# reconstruction loss uses these values directly as targets.
x_train = x_train.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.

# Training pipeline: shuffle with a buffer of five batches, then batch.
train_db = (
  tf.data.Dataset.from_tensor_slices(x_train)
  .shuffle(batches * 5)
  .batch(batches)
)
# Test pipeline: fixed order, batched only.
test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batches)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)


In [5]:
class VAE(keras.Model):
  """Fully connected variational auto-encoder for flattened 28x28 images.

  Encoder: input [b, 784] -> Dense(128, relu) -> (mu, log_var), each [b, z_dim].
  Decoder: z [b, z_dim] -> Dense(128, relu) -> Dense(784) logits. The decoder
  applies no output activation; apply tf.sigmoid to obtain pixel intensities.
  """

  def __init__(self):
    super(VAE, self).__init__()
    # Encoder layers.
    self.fc1 = layers.Dense(128, activation=tf.nn.relu)
    self.fc2 = layers.Dense(z_dim)  # mean head
    self.fc3 = layers.Dense(z_dim)  # log-variance head

    # Decoder layers.
    self.fc4 = layers.Dense(128, activation=tf.nn.relu)
    self.fc5 = layers.Dense(784)  # per-pixel logits

  def encoder(self, x):
    """Map inputs to the approximate-posterior parameters (mu, log_var)."""
    h = self.fc1(x)
    mu = self.fc2(h)
    log_var = self.fc3(h)
    return mu, log_var

  def decoder(self, z):
    """Map latent codes [b, z_dim] to reconstruction logits [b, 784]."""
    out = self.fc4(z)
    out = self.fc5(out)
    return out

  def reparameterize(self, mu, log_var):
    """Sample z = mu + exp(0.5 * log_var) * eps with eps ~ N(0, I).

    Drawing the noise separately keeps z differentiable with respect to
    mu and log_var (the reparameterization trick).
    """
    # tf.shape yields the dynamic shape, so this also works when the batch
    # dimension is unknown at trace time (e.g. under tf.function); the static
    # `log_var.shape` would contain None there.
    eps = tf.random.normal(tf.shape(log_var), dtype=log_var.dtype)
    std = tf.exp(log_var * 0.5)
    return mu + std * eps

  def call(self, inputs, training=None):
    """Encode, sample a latent code, and decode.

    Args:
      inputs: flattened images, [b, 784].
      training: accepted for Keras compatibility; unused because the model
        contains only Dense layers, which behave the same in both modes.

    Returns:
      (x_hat_logits [b, 784], mu [b, z_dim], log_var [b, z_dim]).
    """
    mu, log_var = self.encoder(inputs)
    z = self.reparameterize(mu, log_var)
    x_hat = self.decoder(z)
    return x_hat, mu, log_var

In [6]:
model = VAE()
# Create the weights for inputs of width 784 so summary() can report
# per-layer parameter counts before any training step has run.
model.build(input_shape=(4, 784))
model.summary()

# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and, depending on the Keras version, is ignored with a warning or rejected.
optimizer = optimizers.Adam(learning_rate=1e-3)

Model: "vae"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               multiple                  100480    
                                                                 
 dense_1 (Dense)             multiple                  1290      
                                                                 
 dense_2 (Dense)             multiple                  1290      
                                                                 
 dense_3 (Dense)             multiple                  1408      
                                                                 
 dense_4 (Dense)             multiple                  101136    
                                                                 
Total params: 205604 (803.14 KB)
Trainable params: 205604 (803.14 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________




In [None]:
# Directory that receives the per-epoch image grids.
out_dir = '/content/drive/MyDrive/NCHU/碩二/深度生成模型/vae_images'
# save_image tiles a 10x10 grid, so only 100 images per grid are needed.
n_grid = 100


def _save_logits_grid(logits, path):
  """Turn decoder logits [b, 784] into uint8 28x28 images and save them as a grid."""
  imgs = tf.reshape(tf.sigmoid(logits), [-1, 28, 28]).numpy() * 255.
  save_image(imgs.astype(np.uint8), path)


# test_db is not shuffled, so its first batch is the same every epoch; fetch it
# once so the reconstructions are directly comparable across epochs.
x_test_batch = tf.reshape(next(iter(test_db)), [-1, 784])[:n_grid]

for epoch in range(50):
  for step, x in enumerate(train_db):
    x = tf.reshape(x, [-1, 784])
    with tf.GradientTape() as tape:
      x_rec_logits, mu, log_var = model(x)
      # Bernoulli negative log-likelihood, summed over pixels, averaged over the batch.
      rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)
      rec_loss = tf.reduce_sum(rec_loss) / x.shape[0]

      # Closed-form KL(N(mu, exp(log_var)) || N(0, I)), averaged over the batch.
      kl_div = -0.5 * (log_var + 1 - mu**2 - tf.exp(log_var))
      kl_div = tf.reduce_sum(kl_div) / x.shape[0]

      loss = rec_loss + 1. * kl_div

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    if step % 100 == 0:
      print('epoch:', epoch, 'step:', step, 'kl_div:', float(kl_div), 'rec loss:', float(rec_loss))

  # Once per epoch (previously this ran on every training step and kept
  # overwriting the same epoch file): decode samples drawn from the prior ...
  z = tf.random.normal((n_grid, z_dim))
  _save_logits_grid(model.decoder(z), os.path.join(out_dir, 'sampled_epoch%d.png' % epoch))

  # ... and reconstruct the fixed test batch.
  x_hat_logits, _, _ = model(x_test_batch)
  _save_logits_grid(x_hat_logits, os.path.join(out_dir, 'rec_epoch%d.png' % epoch))



epoch: 0 step: 0 kl_div: 1.9156365394592285 rec loss: 546.4342651367188
epoch: 0 step: 100 kl_div: 15.764577865600586 rec loss: 285.8543395996094
epoch: 1 step: 0 kl_div: 15.772905349731445 rec loss: 271.75079345703125
epoch: 1 step: 100 kl_div: 15.878440856933594 rec loss: 258.23291015625
epoch: 2 step: 0 kl_div: 15.695530891418457 rec loss: 249.97988891601562
epoch: 2 step: 100 kl_div: 14.598577499389648 rec loss: 254.60684204101562
epoch: 3 step: 0 kl_div: 14.87392807006836 rec loss: 250.02096557617188
epoch: 3 step: 100 kl_div: 15.078659057617188 rec loss: 249.062255859375
epoch: 4 step: 0 kl_div: 14.45456314086914 rec loss: 245.77304077148438
epoch: 4 step: 100 kl_div: 15.277034759521484 rec loss: 246.81369018554688
epoch: 5 step: 0 kl_div: 14.87739086151123 rec loss: 237.3536834716797
epoch: 5 step: 100 kl_div: 15.063006401062012 rec loss: 240.16671752929688
epoch: 6 step: 0 kl_div: 15.251806259155273 rec loss: 236.96929931640625
epoch: 6 step: 100 kl_div: 14.844560623168945 rec 