In [3]:
import tensorflow as tf 
from tensorflow import keras 
import os 
import numpy as np 
from PIL import Image 
from matplotlib import pyplot as plt 

In [4]:
# Seed NumPy and TensorFlow from one constant so random draws repeat across runs.
SEED = 22
np.random.seed(SEED)
tf.random.set_seed(SEED)


In [5]:
def save_image(imgs, name, tile=28, grid=10):
    """Tile ``grid * grid`` grayscale images into one square image file.

    Images are placed column by column, matching the original nested-loop order:
    ``imgs[0] .. imgs[grid-1]`` fill the left-most column top to bottom, the next
    ``grid`` images fill the second column, and so on. Extra images beyond
    ``grid * grid`` are ignored.

    Args:
        imgs: indexable sequence of at least ``grid * grid`` 2-D arrays accepted by
            ``Image.fromarray(..., mode='L')``. Each is ``tile`` x ``tile`` pixels;
            the callers pass uint8 arrays.
        name: output file path. Its parent directory is created if it is missing.
        tile: side length, in pixels, of one tile (default 28).
        grid: number of tiles per row/column (default 10, i.e. 100 images).

    Raises:
        ValueError: if fewer than ``grid * grid`` images are supplied.
    """
    needed = grid * grid
    # Validate up front; the original failed later with an opaque IndexError.
    if len(imgs) < needed:
        raise ValueError(
            'save_image needs at least {} images, got {}'.format(needed, len(imgs)))

    side = tile * grid
    canvas = Image.new('L', (side, side))
    for index in range(needed):
        # Column-major placement: consecutive indices walk down a column.
        col, row = divmod(index, grid)
        canvas.paste(Image.fromarray(imgs[index], mode='L'), (col * tile, row * tile))

    # Without this, saving into e.g. 'vae_images/...' fails on a fresh checkout.
    parent = os.path.dirname(name)
    if parent:
        os.makedirs(parent, exist_ok=True)
    canvas.save(name)

In [25]:
# ---- Hyper-parameters ----
h_dim = 20             # bottleneck width for the plain AE; not referenced by the VAE
batchsz = 512          # mini-batch size shared by the train and test pipelines
learning_rate = 0.001  # Adam step size


In [26]:
def preprocess(x):
    """Cast a uint8 image tensor to float32 and rescale from [0, 255] to [0, 1]."""
    as_float = tf.cast(x, tf.float32)
    return as_float / 255.0


In [27]:
# Fashion-MNIST as NumPy arrays; labels are loaded but unused by the autoencoders.
(train_data, train_label), (test_data, test_label) = keras.datasets.fashion_mnist.load_data()

# Training pipeline: scale to [0, 1], shuffle with a 10k-element buffer, then batch.
db = (
    tf.data.Dataset.from_tensor_slices(train_data)
    .map(preprocess)
    .shuffle(10000)
    .batch(batchsz)
)

# Test pipeline: same scaling, fixed order (no shuffle).
test_db = (
    tf.data.Dataset.from_tensor_slices(test_data)
    .map(preprocess)
    .batch(batchsz)
)

In [28]:
print(f"{train_data.shape} {test_data.shape}")  # sanity-check the raw array shapes

(60000, 28, 28) (10000, 28, 28)


In [29]:
class AE(keras.Model):
    """Deterministic fully connected autoencoder.

    The encoder compresses its input to an ``h_dim``-wide code. The decoder maps
    the code back to 784 raw (un-activated) outputs, i.e. logits. Inputs are
    presumably flattened 28x28 images ([b, 784]); confirm at the call site.
    """

    def __init__(self, h_dim):
        """Build the encoder and decoder stacks.

        Args:
            h_dim: width of the bottleneck (hidden code) layer.
        """
        super(AE, self).__init__()
        # Encoder: 256 -> 128 -> h_dim, with dropout after each hidden layer.
        self.encoder = keras.Sequential([
            keras.layers.Dense(256, activation=tf.nn.relu),
            keras.layers.Dropout(0.4),
            keras.layers.Dense(128, activation=tf.nn.relu),
            keras.layers.Dropout(0.4),
            keras.layers.Dense(h_dim),
        ])

        # Decoder: 128 -> 256 -> 784 logits (no dropout, no output activation).
        self.decoder = keras.Sequential([
            keras.layers.Dense(128, activation=tf.nn.relu),
            keras.layers.Dense(256, activation=tf.nn.relu),
            keras.layers.Dense(784),
        ])

    def call(self, inputs, training=None):
        """Encode then decode ``inputs``.

        Args:
            inputs: batch of input vectors.
            training: forwarded explicitly to both sub-models so the encoder's
                Dropout layers are active only when training.

        Returns:
            Reconstruction logits with 784 features per example.
        """
        h = self.encoder(inputs, training=training)
        return self.decoder(h, training=training)

In [30]:
z_dim = 10  # latent dimensionality: width of the mean and log-variance heads


class VAE(keras.Model):
    """Variational autoencoder with one 128-unit hidden layer on each side.

    The encoder predicts a diagonal Gaussian with mean ``mu`` and variance
    ``exp(log_variance)``. A latent sample is drawn with the reparameterization
    trick and decoded to 784 logits; there is no output activation, so apply
    ``tf.sigmoid`` to obtain pixel intensities.
    """

    def __init__(self, latent_dim=z_dim):
        """Build the layers.

        Args:
            latent_dim: width of the latent code (defaults to the module-level
                ``z_dim``).
        """
        super(VAE, self).__init__()
        # Encoder: shared hidden layer followed by two linear heads.
        self.fc1 = keras.layers.Dense(128, activation=tf.nn.relu)
        self.fc2 = keras.layers.Dense(latent_dim)  # mean head
        self.fc3 = keras.layers.Dense(latent_dim)  # log-variance head (unbounded)

        # Decoder: hidden layer, then 784 linear outputs (logits, no activation).
        self.fc4 = keras.layers.Dense(128, activation=tf.nn.relu)
        self.fc5 = keras.layers.Dense(784)

    def encoder(self, x):
        """Map ``x`` to ``(mu, log_variance)``, each of width ``latent_dim``."""
        # The hidden layer is shared by both heads; compute it once, not twice.
        hidden = self.fc1(x)
        mu = self.fc2(hidden)
        # A log-variance keeps the head unconstrained on (-inf, +inf).
        log_variance = self.fc3(hidden)
        return mu, log_variance

    def decoder(self, z):
        """Map latent codes ``z`` to 784 reconstruction logits."""
        return self.fc5(self.fc4(z))

    def reparameterization(self, mu, log_variance):
        """Draw ``z = mu + std * eps`` with ``eps ~ N(0, I)``.

        Because the noise is sampled outside the learned parameters, gradients
        flow through ``mu`` and ``log_variance``.
        """
        # Dynamic shape: works even when the batch dimension is unknown
        # (e.g. inside tf.function), where the static .shape contains None.
        eps = tf.random.normal(tf.shape(log_variance))
        # exp(0.5 * logvar) == sqrt(exp(logvar)), but overflows much later.
        std = tf.exp(0.5 * log_variance)
        return mu + std * eps

    def call(self, inputs, training=None):
        """Encode, sample a latent code, and decode.

        Args:
            inputs: batch of input vectors, presumably [b, 784].
            training: accepted for Keras compatibility; this model has no
                training-dependent layers.

        Returns:
            Tuple ``(x_hat_logits, mu, log_variance)``.
        """
        mu, log_variance = self.encoder(inputs)
        z = self.reparameterization(mu, log_variance)
        x_hat = self.decoder(z)
        return x_hat, mu, log_variance
        
        

In [36]:
# Instantiate the VAE defined above (latent width from z_dim). Re-running this
# cell discards any trained weights.
model = VAE()

In [39]:
epochs = 50             # number of passes over the training set
out_dir = 'vae_images'  # destination for the per-epoch image grids
# save_image writes into out_dir; create it so a fresh run does not fail with
# FileNotFoundError at the end of the first epoch.
os.makedirs(out_dir, exist_ok=True)


def logits_to_uint8_images(logits):
    """Convert decoder logits to a batch of 28x28 uint8 images.

    Args:
        logits: tensor reshapeable to [-1, 28, 28] (e.g. [b, 784] decoder output).

    Returns:
        NumPy uint8 array [b, 28, 28]: sigmoid probabilities times 255,
        truncated toward zero by ``astype``.
    """
    pixels = tf.reshape(tf.sigmoid(logits), [-1, 28, 28]).numpy() * 255.
    return pixels.astype(np.uint8)


optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
for epoch in range(epochs):
    for step, x_batch in enumerate(db):
        x_batch = tf.reshape(x_batch, [-1, 784])  # flatten [b, 28, 28] -> [b, 784]
        n_in_batch = x_batch.shape[0]
        with tf.GradientTape() as tape:
            x_hat_logits, mu, log_variance = model(x_batch, training=True)
            # Reconstruction term: per-pixel sigmoid cross-entropy, summed over
            # pixels and averaged over the batch.
            rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x_batch, logits=x_hat_logits)
            rec_loss = tf.reduce_sum(rec_loss) / n_in_batch
            # Closed-form KL(N(mu, exp(log_variance)) || N(0, 1)), averaged over the batch.
            kl_div = -0.5 * (log_variance + 1 - mu ** 2 - tf.exp(log_variance))
            kl_div = tf.reduce_sum(kl_div) / n_in_batch
            loss = rec_loss + 1. * kl_div
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        if step % 100 == 0:
            print(epoch, step, 'kl div:', kl_div.numpy(), ',rec_loss:', rec_loss.numpy())

    # Evaluation 1: decode random draws from the N(0, I) prior.
    z = tf.random.normal((batchsz, z_dim))
    save_image(logits_to_uint8_images(model.decoder(z)),
               os.path.join(out_dir, 'sampled_epoch_{}.png'.format(epoch)))

    # Evaluation 2: reconstruct the first test batch (test_db is not shuffled),
    # so the grids are comparable across epochs.
    x_test = tf.reshape(next(iter(test_db)), [-1, 784])
    x_test_logits, _, _ = model(x_test)
    save_image(logits_to_uint8_images(x_test_logits),
               os.path.join(out_dir, 'recon_epoch_{}.png'.format(epoch)))
    

0 0 kl div: 5.8214646e-05 ,rec_loss: 382.6278
0 100 kl div: 11.627984 ,rec_loss: 284.6993
1 0 kl div: 12.361847 ,rec_loss: 275.10934
1 100 kl div: 13.4711 ,rec_loss: 254.60876
2 0 kl div: 14.656879 ,rec_loss: 252.26834
2 100 kl div: 14.117373 ,rec_loss: 251.70723
3 0 kl div: 14.192165 ,rec_loss: 252.00851
3 100 kl div: 14.438064 ,rec_loss: 251.0354
4 0 kl div: 14.40872 ,rec_loss: 245.00447
4 100 kl div: 14.6530485 ,rec_loss: 238.9171
5 0 kl div: 14.259886 ,rec_loss: 241.88979
5 100 kl div: 14.237668 ,rec_loss: 244.97278
6 0 kl div: 14.681679 ,rec_loss: 234.17035
6 100 kl div: 14.552263 ,rec_loss: 238.93242
7 0 kl div: 14.450032 ,rec_loss: 238.82028
7 100 kl div: 15.273353 ,rec_loss: 238.68814
8 0 kl div: 14.691032 ,rec_loss: 234.54947
8 100 kl div: 14.425978 ,rec_loss: 236.42857
9 0 kl div: 14.46881 ,rec_loss: 238.73055
9 100 kl div: 14.85593 ,rec_loss: 230.45891
10 0 kl div: 14.959202 ,rec_loss: 235.99872
10 100 kl div: 15.056475 ,rec_loss: 235.12439
11 0 kl div: 14.742031 ,rec_loss: 