In [1]:
import os
from datetime import datetime

import numpy as np
import tensorflow as tf
from tensorflow import keras
from PIL import Image


In [2]:
def save_images(imgs, name, grid=10, size=28):
    """Tile grayscale images into one square mosaic and save it to ``name``.

    Tiles are placed column by column: ``imgs[0]`` .. ``imgs[grid - 1]`` fill
    the first column from top to bottom, the next ``grid`` images fill the
    second column, and so on.

    Args:
        imgs: indexable collection of 2-D ``uint8`` arrays, each ``size x size``.
            The first ``grid * grid`` entries are used. Indexing past the end
            raises ``IndexError``.
        name: output file path. Its parent directory is created if missing.
        grid: number of tiles along each side of the mosaic (default 10).
        size: side length in pixels of every tile (default 28).
    """
    side = grid * size
    new_im = Image.new('L', (side, side))

    index = 0
    # PIL's paste() takes an (x, y) upper-left corner, so the outer loop walks
    # columns and the inner loop walks rows within a column.
    for x in range(0, side, size):
        for y in range(0, side, size):
            tile = Image.fromarray(imgs[index], mode='L')
            new_im.paste(tile, (x, y))
            index += 1

    # Create e.g. 'vae_images/' on demand; otherwise save() raises
    # FileNotFoundError on a fresh checkout.
    out_dir = os.path.dirname(name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    new_im.save(name)

In [4]:
# Load Fashion-MNIST. The VAE is unsupervised: labels are only kept so that
# their shapes can be printed below, and they are never fed to a dataset.
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

# Hyper-parameters shared by the following cells.
batch_size = 512
z_dim = 10  # size of the latent code

# Rescale uint8 pixels [0, 255] to float32 [0, 1].
x_train, x_test = [split.astype(np.float32) / 255 for split in (x_train, x_test)]

# Image-only input pipelines: the training set is shuffled, the test set is not.
db_train = tf.data.Dataset.from_tensor_slices(x_train).shuffle(10000).batch(batch_size)
db_test = tf.data.Dataset.from_tensor_slices(x_test).batch(batch_size)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)


(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)


In [5]:
class VAE(keras.Model):
    """Fully connected variational autoencoder for flattened images.

    Encoder: x -> Dense(hidden_dim) + ReLU -> two linear heads that give the
    latent mean and log-variance (each ``latent_dim`` wide).
    Decoder: z -> Dense(hidden_dim) + ReLU -> Dense(input_dim), returning
    logits. No sigmoid is applied here; callers apply it themselves or use a
    with-logits loss.

    Args:
        latent_dim: size of the latent code. ``None`` (the default) uses the
            notebook-level ``z_dim``, as the original implementation did.
        hidden_dim: width of the hidden layer in encoder and decoder.
        input_dim: flattened image size produced by the decoder (784 = 28*28).
    """

    def __init__(self, latent_dim=None, hidden_dim=128, input_dim=784):
        super(VAE, self).__init__()
        if latent_dim is None:
            latent_dim = z_dim

        # Encoder hidden layer.
        self.fc1 = keras.layers.Dense(hidden_dim)
        # Head producing the latent mean.
        self.fc2 = keras.layers.Dense(latent_dim)
        # Head producing the latent log-variance (not the variance itself).
        self.fc3 = keras.layers.Dense(latent_dim)

        # Decoder layers.
        self.fc4 = keras.layers.Dense(hidden_dim)
        self.fc5 = keras.layers.Dense(input_dim)

    def encode(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior q(z | x)."""
        h = tf.nn.relu(self.fc1(x))

        mu = self.fc2(h)
        log_var = self.fc3(h)

        return mu, log_var

    def decode(self, x):
        """Map latent codes to reconstruction logits (pre-sigmoid)."""
        out = tf.nn.relu(self.fc4(x))
        out = self.fc5(out)
        return out

    def reparameterize(self, mu, log_var):
        """Sample z = mu + exp(log_var / 2) * eps with eps ~ N(0, I).

        The noise is drawn using the *dynamic* shape of ``log_var``. This also
        works when the batch dimension is unknown at trace time, e.g. inside
        ``tf.function`` or when the model is built with a ``None`` batch size.
        """
        eps = tf.random.normal(tf.shape(log_var))
        std = tf.exp(log_var * 0.5)
        z = mu + std * eps
        return z

    def call(self, inputs, training=None):
        """Encode, sample a latent code and decode.

        Args:
            inputs: flattened images. The training loop feeds [b, 784] floats
                in [0, 1].
            training: unused; accepted for Keras compatibility.

        Returns:
            ``(x_hat_logits, mu, log_var)`` with shapes [b, input_dim],
            [b, latent_dim] and [b, latent_dim].
        """
        # [b, 784] => [b, latent_dim], [b, latent_dim]
        mu, log_var = self.encode(inputs)
        z = self.reparameterize(mu, log_var)
        x_hat = self.decode(z)

        return x_hat, mu, log_var

        

In [7]:
# Instantiate the VAE and create its weights by building it for a concrete
# batch of 4 flattened 28x28 images. summary() then lists per-layer parameter
# counts. NOTE(review): a concrete batch size rather than None is presumably
# needed because reparameterize() may draw noise from the static tensor shape.
model = VAE()
model.build((4, 28 * 28))
model.summary()

Model: "vae_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_5 (Dense)              multiple                  100480    
_________________________________________________________________
dense_6 (Dense)              multiple                  1290      
_________________________________________________________________
dense_7 (Dense)              multiple                  1290      
_________________________________________________________________
dense_8 (Dense)              multiple                  1408      
_________________________________________________________________
dense_9 (Dense)              multiple                  101136    
Total params: 205,604
Trainable params: 205,604
Non-trainable params: 0
_________________________________________________________________


In [10]:
# Training: minimise the negative ELBO (reconstruction + KL) with Adam.
# At the end of every epoch, two 10x10 image grids are written to vae_images/:
#   sampled_epoch{N}.png - decodings of random draws from the prior N(0, I)
#   rec_epoch{N}.png     - reconstructions of the first test batch
epochs = 100
os.makedirs('vae_images', exist_ok=True)
optimizer = tf.optimizers.Adam(1e-3)


def _logits_to_uint8_images(logits):
    """Convert decoder logits [b, 784] into uint8 images [b, 28, 28] in [0, 255]."""
    probs = tf.sigmoid(logits)
    return (tf.reshape(probs, [-1, 28, 28]).numpy() * 255.).astype(np.uint8)


for epoch in range(epochs):
    for step, x in enumerate(db_train):
        # [b, 28, 28] => [b, 784]
        x = tf.reshape(x, [-1, 28 * 28])
        with tf.GradientTape() as tape:
            x_rec_logits, mu, log_var = model(x)
            # Per-pixel sigmoid cross-entropy, summed over pixels and averaged
            # over the batch.
            rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)
            rec_loss = tf.reduce_sum(rec_loss) / x.shape[0]

            # Closed-form KL( N(mu, exp(log_var)) || N(0, 1) ), summed over the
            # latent dimensions and averaged over the batch.
            # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
            kl_div = -0.5 * (log_var + 1 - mu**2 - tf.exp(log_var))
            kl_div = tf.reduce_sum(kl_div) / x.shape[0]

            # KL weight of 1.0 gives the standard VAE objective.
            loss = rec_loss + 1. * kl_div

        # Compute gradients and update the weights.
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if step % 100 == 0:
            print(epoch, step, float(rec_loss))

    # Evaluation runs once per epoch, after all training steps (previously it
    # ran after every single step).
    # 1) Generation: decode random draws from the standard normal prior.
    z = tf.random.normal((batch_size, z_dim))
    x_hat = _logits_to_uint8_images(model.decode(z))
    save_images(x_hat, 'vae_images/sampled_epoch%d.png' % epoch)

    # 2) Reconstruction of the first test batch.
    x = next(iter(db_test))
    x = tf.reshape(x, [-1, 28 * 28])
    logits, _, _ = model(x)
    x_hat = _logits_to_uint8_images(logits)
    save_images(x_hat, 'vae_images/rec_epoch%d.png' % epoch)

    
    
        
        

0 0 537.3427734375
0 100 286.2420654296875
1 0 278.1512451171875
1 100 258.4302978515625
2 0 254.70762634277344
2 100 251.5272979736328
3 0 248.7001495361328
3 100 247.20147705078125
4 0 244.83370971679688
4 100 245.36611938476562
5 0 242.1671142578125
5 100 242.462646484375


KeyboardInterrupt: 