In [11]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from tensorflow.keras import layers, Sequential, optimizers, losses

In [21]:
h_dim = 20
batchsz = 512
lr = 1e-3

def save_images(imgs, name):
    new_im = Image.new('L', (280, 280))
    index = 0
    for i in range(0, 280, 28):
        for j in range(0, 280, 28):
            im = imgs[index]
            im = Image.fromarray(im, mode = 'L')
            new_im.paste(im, (i, j))
            index += 1
    new_im.save(name)


(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train ,x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)


In [18]:
class AE(keras.Model):
    def __init__(self):
        super(AE, self).__init__()
        self.encoder = Sequential([
            layers.Dense(256, activation = tf.nn.relu),
            layers.Dense(128, activation = tf.nn.relu),
            layers.Dense(h_dim)
        ])
        self.decoder = Sequential([
            layers.Dense(128, activation = tf.nn.relu),
            layers.Dense(256, activation = tf.nn.relu),
            layers.Dense(784)
        ])
        
    def call(self, inputs, training = None):
        h = self.encoder(inputs)
        x_hat = self.decoder(h)
        return x_hat

In [23]:
model = AE()
model.build(input_shape=(4, 784))
model.summary()
optimizer = optimizers.Adam(lr = lr)

for epoch in range(100):
    for step, x in enumerate(train_db):
        x = tf.reshape(x, [-1, 784])
        with tf.GradientTape() as tape:
            x_rec_logits = model(x)
            rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)
            rec_loss = tf.reduce_mean(rec_loss)
        grads = tape.gradient(rec_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        
        if step % 100 == 0:
            print(epoch, step, float(rec_loss))
            
        x = next(iter(test_db))
        logits = model(tf.reshape(x, [-1, 784]))
        x_hat = tf.sigmoid(logits)
        x_hat = tf.reshape(x_hat, [-1, 28, 28])
        x_concat = tf.concat([x[:50],x_hat[:50]], axis=0)
        x_concat = x_concat.numpy() * 255.
        x_concat = x_concat.astype(np.uint8)
        save_images(x_concat, 'ae_images/rec_epoch_%d.png'%epoch)

Model: "ae_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_6 (Sequential)    (4, 20)                   236436    
_________________________________________________________________
sequential_7 (Sequential)    (4, 784)                  237200    
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________
0 0 0.6941691637039185
0 100 0.3297346830368042
1 0 0.3226393461227417
1 100 0.29958394169807434
2 0 0.30743956565856934
2 100 0.2934790253639221
3 0 0.2952147126197815
3 100 0.2991750240325928
4 0 0.2829587459564209
4 100 0.29336726665496826
5 0 0.29115134477615356
5 100 0.29069268703460693
6 0 0.28854429721832275
6 100 0.29182589054107666
7 0 0.2849160134792328
7 100 0.2867415249347687
8 0 0.2843138873577118
8 100 0.2800070643424988
9 0 0.28185465931892395
9 100 0.2828569710254669
10 0 0.2752522826194763
10 100

In [33]:
class VAE(keras.Model):
    def __ini__(self):
        super(VAE, self).__init__()
        #  Encoder 网络
        self.fc1 = layers.Dense(128)
        self.fc2 = layers.Dense(z_dim)
        self.fc3 = layers.Dense(z_dim)
        # Decoder 网络
        self.fc4 = layers.Dense(128)
        self.fc5 = layers.Dense(784)
        
    def encoder(self, x):
        h = tf.nn.relu(self.fc1(x))
        mu = self.fc2(h)
        log_var = self.fc3(h)
        
        return mu, log_var
    
    def decoder(self, z):
        out = tf.nn.relu(self.fc4(z))
        out = self.fc5(out)
        return out
    
    def reparameterize(self, mu, log_var):
        eps = tf.random.normal(log_var.shape)
        std = tf.exp(log_var) ** 0.5
        z = mu + std + eps
        return z
    
    def call(self, inputs, training = None):
        mu, log_var = self.encoder(inputs)
        z = self.reparameterize(mu, log_var)
        x_hat = self.decoder(z)
        return x_hat, mu, log_var

In [None]:
z_dim = 10
model = VAE()
model.build(input_shape = (4, 784))
optimizer = optimizers.Adam(lr)

for epoch in range(100):
    for step, x in enumerate(train_db):
        x = tf.reshape(x, [-1,784])
        with tf.GradientTape() as tape:
            x_rec_logits, mu, log_var = model(x)
            rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x,logits=x_rec_logits)
            rec_loss = tf.reduce_sum(rec_loss) / x.shape[0]
            kl_div = -0.5 * (log_var + 1 - mu ** 2 - tf.exp(log_var))
            kl_div = tf.reduce_sum(kl_div) / x.shape[0]
            loss = rec_loss + 1. * kl_div
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            
            if step % 100 == 0:
                print(epoch, step, 'kl div:', float(kl_div), 'rec_loss:', float(rec_loss))
                
    # evaluation
    z = tf.random.normal((batchsz, z_dim))
    logits = model.decoder(z)
    x_hat = tf.sigmoid(logits)
    x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() *255.
    x_hat = x_hat.astype(np.uint8)
    save_images(x_hat, 'vae_images/sampled_epoch%d.png'%epoch)

    x = next(iter(test_db))
    x = tf.reshape(x, [-1, 784])
    x_hat_logits, _, _ = model(x)
    x_hat = tf.sigmoid(x_hat_logits)
    x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() *255.
    x_hat = x_hat.astype(np.uint8)
    save_images(x_hat, 'vae_images/rec_epoch%d.png'%epoch)

0 0 kl div: 2.400975227355957 rec_loss: 545.963134765625
0 100 kl div: 15.316713333129883 rec_loss: 291.8880615234375
1 0 kl div: 15.32813835144043 rec_loss: 283.68939208984375
1 100 kl div: 15.231584548950195 rec_loss: 262.5343933105469
2 0 kl div: 15.452970504760742 rec_loss: 255.45103454589844
2 100 kl div: 14.357110977172852 rec_loss: 247.87405395507812
3 0 kl div: 14.484508514404297 rec_loss: 247.49111938476562
3 100 kl div: 14.951042175292969 rec_loss: 251.74859619140625
4 0 kl div: 14.298286437988281 rec_loss: 245.4153594970703
4 100 kl div: 14.243423461914062 rec_loss: 241.90908813476562
5 0 kl div: 14.655537605285645 rec_loss: 242.5178680419922
5 100 kl div: 14.697585105895996 rec_loss: 237.09083557128906
6 0 kl div: 14.514230728149414 rec_loss: 235.9995574951172
6 100 kl div: 14.402326583862305 rec_loss: 236.97702026367188
7 0 kl div: 14.578542709350586 rec_loss: 238.581787109375
7 100 kl div: 14.738691329956055 rec_loss: 241.0812530517578
8 0 kl div: 14.994415283203125 rec_l

67 100 kl div: 15.518753051757812 rec_loss: 227.69898986816406
68 0 kl div: 14.833165168762207 rec_loss: 224.1849365234375
68 100 kl div: 14.959325790405273 rec_loss: 230.2875213623047
69 0 kl div: 15.196447372436523 rec_loss: 228.95762634277344
69 100 kl div: 15.017204284667969 rec_loss: 227.50967407226562
70 0 kl div: 14.849920272827148 rec_loss: 220.37762451171875
70 100 kl div: 15.533491134643555 rec_loss: 225.40155029296875
71 0 kl div: 14.781791687011719 rec_loss: 226.70831298828125
71 100 kl div: 14.774552345275879 rec_loss: 228.6650848388672
72 0 kl div: 14.753569602966309 rec_loss: 226.87274169921875
72 100 kl div: 14.628839492797852 rec_loss: 228.13720703125
73 0 kl div: 15.018328666687012 rec_loss: 223.79971313476562
73 100 kl div: 15.17948055267334 rec_loss: 229.6072998046875
74 0 kl div: 15.067667007446289 rec_loss: 223.3785858154297
74 100 kl div: 15.305872917175293 rec_loss: 223.47813415527344
75 0 kl div: 14.593851089477539 rec_loss: 223.14056396484375
75 100 kl div: 14