In [1]:
import tensorflow as tf 
from tensorflow import keras 
import os 
import numpy as np 
from PIL import Image 
from matplotlib import pyplot as plt 

In [2]:
tf.random.set_seed(22)
np.random.seed(22)


In [31]:
def save_image(imgs,name):
    new_img = Image.new('L',(280,280))
    index = 0
    for i in range(0,280,28):
        for j in range(0,280,28):
            im = imgs[index]
            im = Image.fromarray(im,mode='L')
            new_img.paste(im,(i,j))
            index += 1
            
    new_img.save(name)

In [32]:
# 超参数定义
h_dim = 20  #最终降维之后的维度
batchsz = 512
learning_rate = 1e-3 


In [33]:
def preprocess(x):
    x = tf.cast(x,dtype=tf.float32)/255.
    return x




In [34]:
(train_data,train_label),(test_data,test_label) = keras.datasets.fashion_mnist.load_data()

db = tf.data.Dataset.from_tensor_slices(train_data)
db = db.map(preprocess).shuffle(10000).batch(batchsz)


test_db = tf.data.Dataset.from_tensor_slices(test_data)
test_db = test_db.map(preprocess).batch(batchsz)

In [35]:
print(train_data.shape,test_data.shape)

(60000, 28, 28) (10000, 28, 28)


In [36]:
# 自编码器不需要label(无监督学习)

In [40]:
class AE(keras.Model):
    def __init__(self,h_dim):  # h_dim是中间隐藏层的维度
        super(AE,self).__init__()
        # 首先定义编码器
        self.encoder = keras.Sequential([
            keras.layers.Dense(256,activation = tf.nn.relu,),
            keras.layers.Dropout(0.4),
            keras.layers.Dense(128,activation = tf.nn.relu,),
            keras.layers.Dropout(0.4),
            keras.layers.Dense(h_dim),
        ])
        
        # 定义解码器
        self.decoder = keras.Sequential([
            keras.layers.Dense(128,activation = tf.nn.relu,),
#             keras.layers.Dropout(0.4),
            keras.layers.Dense(256,activation = tf.nn.relu,),
#             keras.layers.Dropout(0.4),
            keras.layers.Dense(784),
        ])
    
    def call(self,inputs,training = None):
        # 首先编码成隐藏层数据
        h = self.encoder(inputs)
        
        # 接下来进行解码
        x_hat = self.decoder(h)
        return x_hat

In [49]:
model = AE(h_dim)
model.build(input_shape = (None,784))
model.summary()

Model: "ae_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_18 (Sequential)   multiple                  236436    
_________________________________________________________________
sequential_19 (Sequential)   multiple                  237200    
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________


In [50]:
optimizer = tf.optimizers.Adam(learning_rate = learning_rate)

for epoch in range(100):
    for step,x in enumerate(db):
        x = tf.reshape(x,[-1,784])
        with tf.GradientTape() as tape:
            x_hat_logits = model(x)
            rec_loss = tf.losses.binary_crossentropy(x,x_hat_logits,from_logits=True)
            rec_loss = tf.reduce_mean(rec_loss)
        grads = tape.gradient(rec_loss,model.trainable_variables)
        optimizer.apply_gradients(zip(grads,model.trainable_variables))
    
        if step%100 == 0:
            print(epoch,step,rec_loss.numpy())
    
    # 检验模型
    x = next(iter(test_db))
    x = tf.reshape(x,[-1,784])
    logits = model(x)
    x_hat = tf.sigmoid(logits)
    
    x_hat = tf.reshape(x_hat,[-1,28,28])
    
    # 显式输出结果,保存文件
    x = tf.reshape(x,[-1,28,28])
    x_orig = x.numpy()*255. 
    x_orig = x_orig.astype(np.uint8)
    
    x_got = x_hat.numpy()*255. 
    x_got = x_got.astype(np.uint8)
    save_image(x_orig,r'./image/{}_orig.png'.format(epoch))
    save_image(x_got,r'./image/{}_got.png'.format(epoch))
    

0 0 0.69376504
0 100 0.33426696
1 0 0.31869346
1 100 0.30915943
2 0 0.30374634
2 100 0.30402866
3 0 0.2960075
3 100 0.2879151
4 0 0.29201528
4 100 0.29691324
5 0 0.28097254
5 100 0.29715574
6 0 0.2888335
6 100 0.29048067
7 0 0.2889262
7 100 0.28994003
8 0 0.28304884
8 100 0.2860364
9 0 0.28172478
9 100 0.28562355
10 0 0.28953028
10 100 0.2882984
11 0 0.2820677
11 100 0.27415124
12 0 0.28383937
12 100 0.27891284
13 0 0.28188974
13 100 0.28092182
14 0 0.27767718
14 100 0.27877706
15 0 0.2807819
15 100 0.27653626
16 0 0.2805609
16 100 0.2820426
17 0 0.2788326
17 100 0.27664763
18 0 0.27883774
18 100 0.2836342
19 0 0.27100262
19 100 0.27202377
20 0 0.26786596
20 100 0.28971845
21 0 0.2826934
21 100 0.28230736
22 0 0.27376607
22 100 0.27023262
23 0 0.2802124
23 100 0.27545622
24 0 0.270912
24 100 0.27445
25 0 0.27591082
25 100 0.27804992
26 0 0.27442777
26 100 0.27175453
27 0 0.26143324
27 100 0.27107704
28 0 0.27394733
28 100 0.27833396
29 0 0.27253962
29 100 0.28026158
30 0 0.2689855
30 1