In [1]:
import tensorflow as tf
from tensorflow import keras

# Load the MNIST images; the labels are discarded (bound to `_`).
(train_images, _), (test_images, _) = keras.datasets.mnist.load_data()

# Cast to float32 and scale pixel values into the [0, 1] range, then flatten
# every image into a 1-D vector of 784 values (one row per image).
train_images = (train_images.astype('float32') / 255.).reshape(len(train_images), 784)
test_images = (test_images.astype('float32') / 255.).reshape(len(test_images), 784)


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [6]:
# Display the preprocessed training array as a quick sanity check.
train_images

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

构建去噪自动编码器模型
1. 去噪自动编码器与标准自动编码器的结构类似，区别在于输入图像会先被加入噪声，而训练目标仍是未加噪声的原始图像，因此模型学习输出去除噪声后的图像。下面创建一个去噪自动编码器模型。

In [2]:
from tensorflow.keras.layers import Input, Dense, GaussianNoise
from tensorflow.keras.models import Model

# Input layer: each sample is a flat vector of 784 values.
input_img = Input(shape=(784,))

# Corrupt the input with zero-mean Gaussian noise (stddev 0.1).
# A GaussianNoise layer replaces the raw `tf.random.normal` call on the
# symbolic input: plain TensorFlow ops applied to Keras symbolic tensors are
# rejected by Keras 3, whereas a layer works across Keras versions.
# `training=True` keeps the noise active in fit(), evaluate() and predict(),
# which matches the original always-on noise behaviour.
noisy_img = GaussianNoise(0.1)(input_img, training=True)

# Encoder: 784 -> 64 -> 32 (the 32-unit layer is the bottleneck code).
encoded = Dense(64, activation='relu')(noisy_img)
encoded = Dense(32, activation='relu')(encoded)

# Decoder: 32 -> 64 -> 784; the sigmoid keeps outputs in (0, 1), the same
# range the inputs were scaled to.
decoded = Dense(64, activation='relu')(encoded)
decoded = Dense(784, activation='sigmoid')(decoded)

# Full model: clean image in, reconstruction out (noise is applied internally).
autoencoder = Model(input_img, decoded)


In [4]:
# Configure training: Adam optimizer with a binary cross-entropy
# reconstruction loss.
autoencoder.compile(loss='binary_crossentropy', optimizer='adam')

# Train the model to reproduce its own input: the same array is passed as
# both input and target, for the training data and for the validation data.
autoencoder.fit(
    x=train_images,
    y=train_images,
    batch_size=256,
    epochs=10,
    shuffle=True,
    validation_data=(test_images, test_images),
)


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f2651973670>

In [5]:
# Reconstruct the test images with the trained autoencoder.
reconstructed_imgs = autoencoder.predict(test_images)

# Binary cross-entropy between the test images and their reconstructions,
# reduced to a single mean value for reporting.
loss = keras.metrics.binary_crossentropy(test_images, reconstructed_imgs)
mean_test_loss = tf.reduce_mean(loss)
print('Average Test Loss: ', mean_test_loss)


Average Test Loss:  tf.Tensor(0.099536024, shape=(), dtype=float32)
