### 載入MNIST手寫數字數據集

In [1]:
from tensorflow.examples.tutorials.mnist import input_data

# Load the MNIST data sets from the "MNIST_data/" directory, with the labels
# requested in one-hot form.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

  from ._conv import register_converters as _register_converters


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [2]:
import tensorflow as tf

### 定義模型變數生成函式

In [3]:
def weight_variable(shape,name):
    """Create a named weight variable.

    Args:
        shape: shape of the variable to create.
        name: name given to the resulting ``tf.Variable``.

    Returns:
        A ``tf.Variable`` whose initial value is drawn from a truncated
        normal distribution with standard deviation 0.1.
    """
    initial_value = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial_value, name=name)

def bias_variable(shape,name):
    """Create a named bias variable.

    Args:
        shape: shape of the variable to create.
        name: name given to the resulting ``tf.Variable``.

    Returns:
        A ``tf.Variable`` initialised to the constant 0.1 in every position.
    """
    initial_value = tf.constant(0.1, shape=shape)
    return tf.Variable(initial_value, name=name)

### 定義輸入

In [4]:
# Network input: float32 rows of 784 values; the batch dimension is left open.
X = tf.placeholder(dtype=tf.float32, shape=[None, 784])

### 定義Normal Autoencoder

In [5]:
# Encoder: 784 -> 300 -> 30, each layer is an affine map followed by a sigmoid.
w1 = weight_variable([784, 300], 'encoder_w1')
b1 = bias_variable([300], 'encoder_b1')
o1 = tf.nn.sigmoid(tf.add(tf.matmul(X, w1), b1))

w2 = weight_variable([300, 30], 'encoder_w2')
b2 = bias_variable([30], 'encoder_b2')
o2 = tf.nn.sigmoid(tf.add(tf.matmul(o1, w2), b2))

# Decoder: 30 -> 300 -> 784, mirroring the encoder; o4 is the reconstruction
# that the loss compares against X.
w3 = weight_variable([30, 300], 'decoder_w1')
b3 = bias_variable([300], 'decoder_b1')
o3 = tf.nn.sigmoid(tf.add(tf.matmul(o2, w3), b3))

w4 = weight_variable([300, 784], 'decoder_w2')
b4 = bias_variable([784], 'decoder_b2')
o4 = tf.nn.sigmoid(tf.add(tf.matmul(o3, w4), b4))

### 定義損失函數與優化器

在這邊是與普通的Autoencoder不同的地方：
1. 加上Kullback-Leibler divergence，讓各神經元的平均輸出值盡量接近一個很小的目標值（此處為0.02），藉此壓低各神經元的輸出，使神經元不會對每張圖片都起反應而無用化。
2. 我們還會利用L2正則化來讓權重變小，使整個模型變得較簡單。

In [6]:
def kl_div(rho,rho_hat,eps = 1e-6):
    """Element-wise KL divergence between Bernoulli(rho) and Bernoulli(rho_hat).

    Computes ``rho*log(rho/rho_hat) + (1-rho)*log((1-rho)/(1-rho_hat))``.

    Args:
        rho: target mean activation (a scalar in (0, 1)).
        rho_hat: observed mean activation(s), a tensor of values in [0, 1].
        eps: rho_hat is clamped to ``[eps, 1 - eps]`` before use.

    Returns:
        A tensor with one divergence value per entry of ``rho_hat``.
    """
    # A saturated sigmoid can make the mean activation exactly 0 or 1, which
    # would produce a division by zero / log(0) and turn the loss into inf or
    # NaN. Clamping keeps the penalty finite.
    rho_hat = tf.clip_by_value(rho_hat, eps, 1.0 - eps)
    invrho = tf.subtract(tf.constant(1.), rho)
    invrhohat = tf.subtract(tf.constant(1.), rho_hat)
    return logfunc(rho, rho_hat) + logfunc(invrho, invrhohat)

def logfunc(x1, x2):
    """Return ``x1 * log(x1 / x2)``, built from TensorFlow ops."""
    ratio = tf.div(x1, x2)
    log_ratio = tf.log(ratio)
    return tf.multiply(x1, log_ratio)

In [7]:
import functools

In [8]:
# Sparsity penalty: for each encoder layer, average the activations over
# axis 0, compare that mean with the target 0.02 via kl_div, sum the result,
# then add the per-layer sums together.
kl_div_loss = functools.reduce(
    lambda total, term: total + term,
    (tf.reduce_sum(kl_div(0.02, tf.reduce_mean(activation, 0)))
     for activation in [o1, o2]),
)

In [9]:
# Weight penalty: tf.nn.l2_loss of every weight matrix (biases are not
# included), added together.
l2_loss = functools.reduce(
    lambda total, term: total + term,
    (tf.nn.l2_loss(weights) for weights in [w1, w2, w3, w4]),
)

In [10]:
# Weights of the two regularisation terms in the total loss.
alpha = 5.0e-6   # scales l2_loss
beta = 7.5e-5    # scales kl_div_loss

In [11]:
# Total objective = mean squared reconstruction error
#                   + alpha * L2 weight penalty
#                   + beta  * KL sparsity penalty.
# Both penalties are scalars, so they are added after the mean instead of being
# broadcast into every element of the error tensor first; the value is the
# same because mean(a + c) == mean(a) + c.
loss = tf.reduce_mean(tf.pow(o4 - X, 2)) + alpha * l2_loss + beta * kl_div_loss
opt = tf.train.AdamOptimizer(0.01).minimize(loss)

### 實例化執行

In [12]:
import matplotlib.pyplot as plt
import numpy as np

# Train the autoencoder, report the loss on the test images, then display a
# few training images next to their reconstructions.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for step in range(20000):
        # Only the images are used; the labels returned with the batch are ignored.
        batchx = mnist.train.next_batch(60)[0]

        if step%1000 == 0:
            # Loss on the current batch, measured before this step's update.
            print("Step:{},Loss:{}".format(step,sess.run(loss,feed_dict = {X:batchx})))
        sess.run(opt,feed_dict = {X:batchx})

    print("Test Loss:{}".format(sess.run(loss,feed_dict = {X:mnist.test.images})))

    # Reconstruct only the images that are displayed. Feeding the whole
    # training set to look at five pictures wastes time and memory.
    num_shown = 5
    originals = mnist.train.images[:num_shown]
    predict_output = sess.run(o4,feed_dict = {X:originals})

    for i in range(num_shown):
        curr_img = np.reshape(originals[i,:],(28,28))
        pre_img = np.reshape(predict_output[i,:],(28,28))

        # Two figures per sample: the input first, its reconstruction second.
        plt.matshow(curr_img, cmap = plt.get_cmap('gray'))
        plt.matshow(pre_img, cmap = plt.get_cmap('gray'))

        plt.show()

Step:0,Loss:0.2933363616466522
Step:1000,Loss:0.04224457964301109
Step:2000,Loss:0.03947719186544418
Step:3000,Loss:0.039851706475019455
Step:4000,Loss:0.03764984384179115
Step:5000,Loss:0.03527306765317917
Step:6000,Loss:0.03829438239336014
Step:7000,Loss:0.03651284798979759
Step:8000,Loss:0.03670467808842659
Step:9000,Loss:0.03610767424106598
Step:10000,Loss:0.0382281094789505
Step:11000,Loss:0.03714229539036751
Step:12000,Loss:0.035715408623218536
Step:13000,Loss:0.03587040305137634
Step:14000,Loss:0.03603891283273697
Step:15000,Loss:0.03853045031428337
Step:16000,Loss:0.03750507906079292
Step:17000,Loss:0.03566473722457886
Step:18000,Loss:0.03493707254528999
Step:19000,Loss:0.036548107862472534
Test Loss:0.035388704389333725


<matplotlib.figure.Figure at 0x112ce2ef0>

<matplotlib.figure.Figure at 0x1c2576af60>

<matplotlib.figure.Figure at 0x1c257ba630>

<matplotlib.figure.Figure at 0x1c2509ceb8>

<matplotlib.figure.Figure at 0x1c258079e8>

<matplotlib.figure.Figure at 0x1c2576a9b0>

<matplotlib.figure.Figure at 0x1c25839908>

<matplotlib.figure.Figure at 0x1c2509c3c8>

<matplotlib.figure.Figure at 0x1c2586ddd8>

<matplotlib.figure.Figure at 0x1c25793ac8>