## VAE — Variational Autoencoder（变分自编码器）

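在编码器输出的隐藏层中有两类节点：一类是每个潜在维度的均值节点（Mean）；另一类是每个潜在维度的方差节点（Variance）。在本实现中，“方差”节点实际输出的是方差的对数 $\log\sigma^2$，这样它可以在整个实数域上取值，真正的方差为 $\sigma^2 = e^{\log\sigma^2}$。
VAE 是一个生成模型：训练完成后，可以从标准正态分布 $\mathcal{N}(0, I)$ 中采样 $z$，再经解码器生成新的图片。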

In [1]:
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import Sequential, layers
from PIL import Image

In [2]:
# Silence TensorFlow's C++ INFO/WARNING logs (2 = errors only).
# NOTE: TensorFlow reads this variable when its native library initializes,
# so it is most reliable when set before `import tensorflow`.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Seed both TensorFlow and NumPy so runs are reproducible.
tf.random.set_seed(22)
np.random.seed(22)

assert tf.__version__.startswith('2.'), 'this notebook requires TensorFlow 2.x'

In [3]:
def save_image(imgs, name, grid=10):
    """Tile the first ``grid * grid`` images into one grayscale image and save it.

    Tiles are filled column by column: image ``k`` goes to column
    ``k // grid`` and row ``k % grid``.

    Args:
        imgs: sequence/array of equally sized 2-D uint8 images (28x28 for
            Fashion-MNIST). Must contain at least ``grid * grid`` images;
            any extra images are ignored.
        name: output file path. Missing parent directories are created.
        grid: number of tiles along each side of the output (default 10,
            giving a 280x280 image for 28x28 inputs).

    Raises:
        ValueError: if fewer than ``grid * grid`` images are supplied.
    """
    count = grid * grid
    if len(imgs) < count:
        raise ValueError('save_image needs at least %d images, got %d' % (count, len(imgs)))

    tile_h, tile_w = np.asarray(imgs[0]).shape[:2]
    new_im = Image.new('L', (grid * tile_w, grid * tile_h))   # PIL size is (width, height)
    for index in range(count):
        col, row = divmod(index, grid)
        # A 2-D uint8 array maps to PIL mode 'L' automatically; the explicit
        # `mode=` argument of fromarray is deprecated in recent Pillow.
        tile = Image.fromarray(np.asarray(imgs[index], dtype=np.uint8))
        new_im.paste(tile, (col * tile_w, row * tile_h))   # (left, upper) corner

    parent = os.path.dirname(name)
    if parent:
        os.makedirs(parent, exist_ok=True)
    new_im.save(name)


In [4]:
# Hyperparameters.
z_dim = 10        # latent dimension: each image is encoded into a z_dim-D Gaussian
h_dim = 20        # NOTE: not used anywhere in this notebook; the latent size is z_dim
batchsz = 10000   # batch size for training and for the (single) evaluation test batch
lr = 1e-3         # Adam learning rate

# Fashion-MNIST images: uint8 [N, 28, 28] -> float32 in [0, 1].
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.

# Only images are fed to the VAE; the labels are not needed.
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz * 5).batch(batchsz, drop_remainder=True)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz, drop_remainder=True)

In [5]:
# 
class VAE(keras.Model):
    """Fully connected variational autoencoder for flattened 28x28 images.

    The encoder maps a batch ``[b, 28*28]`` to the mean and the log-variance
    of a diagonal Gaussian over a ``z_dim``-dimensional latent space (the
    module-level ``z_dim``). A latent sample is drawn with the
    reparameterization trick and decoded back to ``[b, 28*28]`` *logits*. No
    sigmoid is applied here; the training loss uses
    ``sigmoid_cross_entropy_with_logits``.
    """

    def __init__(self):
        super().__init__()

        # Encoder. Layers are created without an activation; ReLU is applied
        # explicitly in `encoder` to the hidden layers only.
        self.fc1 = layers.Dense(256)          # [b, 28*28] => [b, 256]
        self.fc2 = layers.Dense(128)          # [b, 256]   => [b, 128]
        self.fc_mean = layers.Dense(z_dim)    # [b, 128]   => [b, z_dim]  mean
        self.fc_var = layers.Dense(z_dim)     # [b, 128]   => [b, z_dim]  log-variance

        # Decoder.
        self.fc3 = layers.Dense(128)          # [b, z_dim] => [b, 128]
        self.fc4 = layers.Dense(256)          # [b, 128]   => [b, 256]
        self.fc5 = layers.Dense(28 * 28)      # [b, 256]   => [b, 28*28] logits

    def encoder(self, x):
        """Encode ``x`` into latent Gaussian parameters.

        Args:
            x: float tensor ``[b, 28*28]``.

        Returns:
            ``(h_mean, h_var)``, each ``[b, z_dim]``. ``h_var`` is the
            *log*-variance log(sigma^2), so it can take any real value; the
            actual variance is ``tf.exp(h_var)``.
        """
        h1 = tf.nn.relu(self.fc1(x))
        h2 = tf.nn.relu(self.fc2(h1))
        h_mean = self.fc_mean(h2)
        h_var = self.fc_var(h2)
        return h_mean, h_var

    def decoder(self, z):
        """Decode latent codes ``[b, z_dim]`` into pixel logits ``[b, 28*28]``."""
        out1 = tf.nn.relu(self.fc3(z))
        out2 = tf.nn.relu(self.fc4(out1))
        return self.fc5(out2)

    def reparameterization(self, h_mean, h_var):
        """Draw z = mean + std * eps with eps ~ N(0, I), differentiably.

        Args:
            h_mean: latent mean, ``[b, z_dim]``.
            h_var: latent log-variance log(sigma^2), ``[b, z_dim]``.

        Returns:
            A latent sample ``[b, z_dim]``. Gradients flow into ``h_mean`` and
            ``h_var`` because the randomness is isolated in ``eps``.
        """
        # Use the dynamic tf.shape instead of the static h_var.shape so this
        # also works when the batch dimension is unknown (None), e.g. when the
        # model is built with input_shape=(None, 28*28).
        eps = tf.random.normal(tf.shape(h_var))
        # std = sqrt(exp(log_var)) = exp(0.5 * log_var); the second form does
        # not overflow exp() for large log-variances.
        std = tf.exp(0.5 * h_var)
        return h_mean + std * eps

    def call(self, inputs, training=None):
        """Run encode -> reparameterize -> decode.

        Args:
            inputs: float tensor ``[b, 28*28]``.
            training: unused; accepted for the Keras ``call`` signature.

        Returns:
            ``(x_hat_logits, h_mean, h_var)`` where ``x_hat_logits`` is
            ``[b, 28*28]`` and ``h_var`` is the latent log-variance.
        """
        h_mean, h_var = self.encoder(inputs)
        z = self.reparameterization(h_mean, h_var)
        x_hat = self.decoder(z)
        return x_hat, h_mean, h_var

In [6]:
VAE_model = VAE()
# Build with a concrete batch size. Building with (None, 28*28) only works if
# the reparameterization step samples eps with the dynamic tf.shape(h_var)
# rather than the static h_var.shape, which contains None before data arrives.
VAE_model.build(input_shape=(batchsz, 28 * 28))
# `learning_rate` is the supported keyword; `lr` is deprecated and rejected by
# newer Keras optimizers.
optimizer = keras.optimizers.Adam(learning_rate=lr)
VAE_model.summary()

Model: "vae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                multiple                  200960    
_________________________________________________________________
dense_1 (Dense)              multiple                  32896     
_________________________________________________________________
dense_2 (Dense)              multiple                  1290      
_________________________________________________________________
dense_3 (Dense)              multiple                  1290      
_________________________________________________________________
dense_4 (Dense)              multiple                  1408      
_________________________________________________________________
dense_5 (Dense)              multiple                  33024     
_________________________________________________________________
dense_6 (Dense)              multiple                  201488  

In [7]:
# Training / evaluation configuration.
epochs = 1000      # the recorded run was interrupted manually
n_show = 100       # images per saved grid (save_image tiles 10 x 10)
kl_weight = 1.     # trade-off between reconstruction loss and KL divergence
out_root = os.environ.get('VAE_OUT_DIR', '/home/kukafee/workspace/picture')
sample_dir = os.path.join(out_root, 'pic1')   # images generated from the prior
rec_dir = os.path.join(out_root, 'pic2')      # reconstructions of test images
os.makedirs(sample_dir, exist_ok=True)
os.makedirs(rec_dir, exist_ok=True)


def _logits_to_uint8(logits):
    """Map decoder logits [b, 28*28] to uint8 images [b, 28, 28] in 0..255."""
    probs = tf.sigmoid(logits)
    return (tf.reshape(probs, [-1, 28, 28]).numpy() * 255.).astype(np.uint8)


for epoch in range(epochs):
    for x in train_db:
        # [b, 28, 28] => [b, 28*28]
        x = tf.reshape(x, [-1, 28 * 28])

        with tf.GradientTape() as tape:
            x_rec_logits, h_mean, h_var = VAE_model(x)

            # Reconstruction loss: per-pixel sigmoid cross-entropy on logits,
            # summed over all pixels and averaged over the batch.
            rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)
            rec_loss = tf.reduce_sum(rec_loss) / x.shape[0]

            # KL( N(mu, sigma^2) || N(0, 1) ) with h_var = log(sigma^2):
            #   -0.5 * (1 + log sigma^2 - mu^2 - sigma^2), per latent dimension.
            # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
            kl_div = -0.5 * (h_var + 1 - h_mean ** 2 - tf.exp(h_var))   # [b, z_dim]
            kl_div = tf.reduce_sum(kl_div) / x.shape[0]

            loss = rec_loss + kl_weight * kl_div

        grads = tape.gradient(loss, VAE_model.trainable_variables)
        optimizer.apply_gradients(zip(grads, VAE_model.trainable_variables))

    # Losses of the last batch of this epoch (not an epoch average).
    print(epoch, 'Rec loss: ', float(rec_loss), 'KL_div: ', float(kl_div))

    if epoch % 10 == 0:
        # Generation: decode z ~ N(0, I) sampled from the prior.
        z = tf.random.normal((n_show, z_dim))
        x_hat = _logits_to_uint8(VAE_model.decoder(z))
        save_image(x_hat, os.path.join(sample_dir, 'sampled_epoch_%d.png' % epoch))

        # Reconstruction of test images. A separate name is used so the global
        # numpy `x_test` loaded earlier is not overwritten with a tensor.
        x_test_batch = tf.reshape(next(iter(test_db)), [-1, 28 * 28])[:n_show]
        x_hat_logits, _, _ = VAE_model(x_test_batch)
        x_hat = _logits_to_uint8(x_hat_logits)
        save_image(x_hat, os.path.join(rec_dir, 'rec_epoch_%d.png' % epoch))
            

0 Rec loss:  510.4544372558594 KL_div:  2.8124639987945557
1 Rec loss:  432.5567932128906 KL_div:  10.553927421569824
2 Rec loss:  388.7998352050781 KL_div:  7.959600925445557
3 Rec loss:  366.1690368652344 KL_div:  6.751097679138184
4 Rec loss:  344.0827941894531 KL_div:  7.7710466384887695
5 Rec loss:  327.2944641113281 KL_div:  8.801284790039062
6 Rec loss:  314.537353515625 KL_div:  9.207307815551758
7 Rec loss:  306.978515625 KL_div:  9.217245101928711
8 Rec loss:  299.57366943359375 KL_div:  9.099848747253418
9 Rec loss:  295.4237976074219 KL_div:  9.484129905700684
10 Rec loss:  291.18084716796875 KL_div:  10.0331449508667
11 Rec loss:  287.4630126953125 KL_div:  10.387138366699219
12 Rec loss:  284.28131103515625 KL_div:  10.56123161315918
13 Rec loss:  280.41412353515625 KL_div:  10.880800247192383
14 Rec loss:  277.9190368652344 KL_div:  11.023004531860352
15 Rec loss:  274.6007995605469 KL_div:  11.360089302062988
16 Rec loss:  271.7486877441406 KL_div:  11.670470237731934
1

136 Rec loss:  232.16722106933594 KL_div:  13.673529624938965
137 Rec loss:  231.90660095214844 KL_div:  13.491768836975098
138 Rec loss:  232.44842529296875 KL_div:  13.099771499633789
139 Rec loss:  231.6725311279297 KL_div:  13.64323616027832
140 Rec loss:  231.62158203125 KL_div:  13.650008201599121
141 Rec loss:  231.8282928466797 KL_div:  13.29979419708252
142 Rec loss:  231.4832000732422 KL_div:  13.643506050109863
143 Rec loss:  231.77005004882812 KL_div:  13.247669219970703
144 Rec loss:  231.3741455078125 KL_div:  13.627256393432617
145 Rec loss:  231.39112854003906 KL_div:  13.458590507507324
146 Rec loss:  231.4821014404297 KL_div:  13.341267585754395
147 Rec loss:  231.44305419921875 KL_div:  13.287522315979004
148 Rec loss:  231.26068115234375 KL_div:  13.665559768676758
149 Rec loss:  231.10902404785156 KL_div:  13.63379955291748
150 Rec loss:  231.0372772216797 KL_div:  13.680044174194336
151 Rec loss:  231.03579711914062 KL_div:  13.91466236114502
152 Rec loss:  231.07

270 Rec loss:  227.1885986328125 KL_div:  13.926846504211426
271 Rec loss:  227.46812438964844 KL_div:  13.702373504638672
272 Rec loss:  227.0191192626953 KL_div:  13.98148250579834
273 Rec loss:  227.05838012695312 KL_div:  13.970406532287598
274 Rec loss:  227.02835083007812 KL_div:  13.961189270019531
275 Rec loss:  226.9850311279297 KL_div:  14.222722053527832
276 Rec loss:  226.7987823486328 KL_div:  14.319843292236328
277 Rec loss:  227.58932495117188 KL_div:  13.724888801574707
278 Rec loss:  227.0367431640625 KL_div:  13.910018920898438
279 Rec loss:  226.8970489501953 KL_div:  14.175945281982422
280 Rec loss:  227.09805297851562 KL_div:  13.862153053283691
281 Rec loss:  226.8332061767578 KL_div:  14.002012252807617
282 Rec loss:  226.81202697753906 KL_div:  14.359034538269043
283 Rec loss:  227.41409301757812 KL_div:  13.795271873474121
284 Rec loss:  226.8406219482422 KL_div:  13.930062294006348
285 Rec loss:  226.73841857910156 KL_div:  13.999682426452637
286 Rec loss:  22

404 Rec loss:  225.04977416992188 KL_div:  14.284957885742188
405 Rec loss:  225.16685485839844 KL_div:  14.087176322937012
406 Rec loss:  224.98497009277344 KL_div:  14.618640899658203
407 Rec loss:  225.24795532226562 KL_div:  14.033724784851074
408 Rec loss:  225.606201171875 KL_div:  13.877848625183105
409 Rec loss:  224.99154663085938 KL_div:  14.250640869140625
410 Rec loss:  225.06031799316406 KL_div:  14.250078201293945
411 Rec loss:  225.06785583496094 KL_div:  14.222740173339844
412 Rec loss:  225.0799560546875 KL_div:  14.247734069824219
413 Rec loss:  225.07864379882812 KL_div:  14.129072189331055
414 Rec loss:  224.83457946777344 KL_div:  14.336470603942871
415 Rec loss:  224.91769409179688 KL_div:  14.22293758392334
416 Rec loss:  224.92626953125 KL_div:  14.242818832397461
417 Rec loss:  225.32005310058594 KL_div:  13.9295654296875
418 Rec loss:  225.09542846679688 KL_div:  14.07717227935791
419 Rec loss:  224.80494689941406 KL_div:  14.32125473022461
420 Rec loss:  224.

538 Rec loss:  223.9896697998047 KL_div:  14.195223808288574
539 Rec loss:  224.2412567138672 KL_div:  14.114896774291992
540 Rec loss:  223.7060546875 KL_div:  14.635849952697754
541 Rec loss:  223.92889404296875 KL_div:  14.327012062072754
542 Rec loss:  224.17849731445312 KL_div:  14.176887512207031
543 Rec loss:  224.36849975585938 KL_div:  13.96109676361084
544 Rec loss:  223.71629333496094 KL_div:  14.533878326416016
545 Rec loss:  223.95594787597656 KL_div:  14.164576530456543
546 Rec loss:  223.8144989013672 KL_div:  14.366606712341309
547 Rec loss:  223.6996307373047 KL_div:  14.5260591506958


KeyboardInterrupt: 