In [1]:
from datetime import datetime
import numpy as np
import tensorflow as tf
from tensorflow import keras
from PIL import Image


In [12]:
def save_images(imgs, name):
    """
    Tile up to 100 greyscale 28x28 images into one 280x280 image and save it.

    Tiles are placed column by column: the first ten images fill the left-most
    column from top to bottom, the next ten fill the second column, and so on.
    Images beyond the first 100 are ignored; if fewer than 100 are supplied the
    remaining cells keep the default fill of a new 'L'-mode image.

    Args:
        imgs: Indexable/iterable collection of 2-D arrays accepted by
            ``Image.fromarray(..., mode='L')``.
        name: Destination passed to ``Image.save``. When it is a filesystem
            path, missing parent directories are created first.
    """
    import os  # local import so this block does not rely on a file-level `import os`

    tile = 28
    side = 280
    canvas = Image.new('L', (side, side))

    # Same visiting order as the original nested loops: x offset outermost.
    offsets = [(x, y) for x in range(0, side, tile) for y in range(0, side, tile)]

    # zip() stops at the shorter sequence, so fewer than 100 images no longer
    # raises IndexError and extra images are simply not drawn.
    for offset, im in zip(offsets, imgs):
        canvas.paste(Image.fromarray(im, mode='L'), offset)

    # Create the output directory on demand so saving into a new folder works.
    if isinstance(name, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(name))
        if parent:
            os.makedirs(parent, exist_ok=True)

    canvas.save(name)

In [3]:
# Load the Fashion-MNIST train/test splits through the Keras datasets API.
train_split, test_split = keras.datasets.fashion_mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [4]:
batch_size = 512
h_dim = 20

# Convert both image arrays to float32 and divide every pixel by 255.
x_train, x_test = (
    pixels.astype(np.float32) / 255 for pixels in (x_train, x_test)
)

# The autoencoder is unsupervised, so the datasets hold images only (no labels).
# Only the training set is shuffled (buffer of 10000) before batching.
db_train = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(10000)
    .batch(batch_size)
)
db_test = tf.data.Dataset.from_tensor_slices(x_test).batch(batch_size)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)


(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)


In [5]:
class AutoEncode(keras.Model):
    """Fully connected autoencoder.

    The encoder is a Dense stack 256 -> 128 -> ``h_dim`` and the decoder
    mirrors it as 128 -> 256 -> 784. Hidden layers use ReLU; the last layer of
    each stack is created without an activation argument.
    """

    def __init__(self):
        super().__init__()

        dense = keras.layers.Dense
        relu = tf.nn.relu

        # Encoder: compresses the input down to an h_dim-sized code.
        self.encode = keras.Sequential([
            dense(256, activation=relu),
            dense(128, activation=relu),
            dense(h_dim),
        ])

        # Decoder: expands the code back to 784 output units.
        self.decode = keras.Sequential([
            dense(128, activation=relu),
            dense(256, activation=relu),
            dense(784),
        ])

    def call(self, inputs, training=None):
        """Encode ``inputs`` and decode the result.

        ``training`` is accepted for Keras API compatibility but is not used.
        """
        # [b, 784] => [b, h_dim] => [b, 784]
        return self.decode(self.encode(inputs))
        
        

In [6]:
# Instantiate the autoencoder, build it for flattened 28*28 = 784 feature inputs
# (batch dimension left unspecified) and print the layer/parameter summary.
model = AutoEncode()
model.build(input_shape=(None, 28 * 28))
model.summary()

Model: "auto_encode"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential (Sequential)      multiple                  236436    
_________________________________________________________________
sequential_1 (Sequential)    multiple                  237200    
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________


In [11]:
# Train the autoencoder: minimise the sigmoid cross-entropy between each
# flattened input image and the logits the model reconstructs for it.
optimizer = tf.optimizers.Adam(1e-3)

for epoch in range(10):
    for step, x in enumerate(db_train):
        # Flatten the images: [b, 28, 28] => [b, 784]
        x = tf.reshape(x, [-1, 28 * 28])
        with tf.GradientTape() as tape:
            x_rec_logits = model(x)
            # The model outputs raw logits, hence from_logits=True.
            rec_loss = tf.losses.binary_crossentropy(x, x_rec_logits, from_logits=True)
            rec_loss = tf.reduce_mean(rec_loss)

        # Compute the gradients and apply one optimizer update.
        grads = tape.gradient(rec_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if step % 100 == 0:
            print(epoch, step, float(rec_loss))

    # Evaluate once per epoch. Previously this ran after every training step,
    # repeating a test forward pass and rewriting the same PNG for each batch.
    x_test_batch = next(iter(db_test))
    logits = model(tf.reshape(x_test_batch, [-1, 28 * 28]))

    # Logits -> pixel intensities in (0, 1), then [b, 784] => [b, 28, 28]
    x_hat = tf.sigmoid(logits)
    x_hat = tf.reshape(x_hat, [-1, 28, 28])

    # save_images tiles only the first 100 images, so concatenating the full
    # batches would show originals only. Take 50 originals followed by the
    # reconstructions of those same 50 images.
    x_concat = tf.concat([x_test_batch[:50], x_hat[:50]], axis=0)
    x_concat = x_concat.numpy() * 255.
    x_concat = x_concat.astype(np.uint8)

    save_images(x_concat, 'ae_images/rec_epoch_%d.png' % epoch)

    
    
        
        

W0831 16:57:45.354861 4495443392 deprecation.py:323] From /anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:182: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


0 0 0.6935093402862549
0 100 0.3309305012226105
1 0 0.31759634613990784
1 100 0.3094624876976013
2 0 0.3014313578605652
2 100 0.3021594285964966
3 0 0.29397737979888916
3 100 0.2974031865596771


KeyboardInterrupt: 

In [10]:
# Reconstruct the first test batch with the current model and save a
# side-by-side preview of originals and reconstructions.
x = next(iter(db_test))
logits = model(tf.reshape(x, [-1, 28 * 28]))

# Logits -> pixel intensities in (0, 1), then [b, 784] => [b, 28, 28]
x_hat = tf.sigmoid(logits)
x_hat = tf.reshape(x_hat, [-1, 28, 28])

# save_images tiles only the first 100 images, so concatenating the full batches
# would show originals only. Take 50 originals followed by the reconstructions
# of those same 50 images.
x_concat = tf.concat([x[:50], x_hat[:50]], axis=0)
x_concat = x_concat.numpy() * 255.
x_concat = x_concat.astype(np.uint8)
save_images(x_concat, 'ae_images/1.png')
