In [1]:
import os
import shutil
import tensorflow as tf
import numpy as np
from skimage.io import imsave
from tensorflow.examples.tutorials.mnist import input_data 

### Step 1 数据预处理

#### 1.1 将下载好的MNIST_data数据解码

In [3]:
#定义DataLoad()函数将文件数据转为numpy可以读取的格式
def DataLoad(data_path):
    file_data = open(os.path.join(data_path,'train-images.idx3-ubyte'))
    loaded_data = np.fromfile(file = file_data,dtype = np.uint8)
    #前16个字符为说明符，需要跳过
    train_data = loaded_data[16:].reshape((-1,784)).astype(np.float)
    
    file_label = open(os.path.join(data_path,'train-labels.idx1-ubyte'))
    loaded_label = np.fromfile(file = file_label,dtype = np.uint8)
    #前8个字符为说明符，需要跳过
    train_label = loaded_label[8:].reshape((-1)).astype(np.float)
    
    return train_data,train_label

#### 1.2 设置超参数

In [14]:
#图像的size为（28,28,1）
image_width = 28
image_height = 28
image_size = image_width * image_height

#是否训练和存储设置
train = True
restore = False
output_path = "./output_image/"

# set hyperparameters
max_epoch = 300
batch_size = 256
z_size = 220            #生成器的传入参数
h1_size = 300     #第一隐藏层的size，即特征数
h2_size = 300     #第二隐藏层的size，即特征数

### Step 2 搭建模型

#### 2.1 构建生成器

In [5]:
import tensorflow as tf

def Generator(z_input):
    #第一个链接层
    w1 = tf.Variable(tf.truncated_normal([z_size,h1_size],stddev = 0.1),name = "g_w1",dtype = tf.float32)
    b1 = tf.Variable(tf.zeros([h1_size]),name = "g_b1",dtype = tf.float32)
    h1 = tf.nn.relu(tf.matmul(z_input, w1) + b1)
    
    #第二个链接层
    w2 = tf.Variable(tf.truncated_normal([h1_size,h2_size],stddev = 0.1),name = "g_w2",dtype = tf.float32)
    b2 = tf.Variable(tf.zeros([h2_size]),name = "g_b2",dtype = tf.float32)
    h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)
    
    #第三个链接层
    w3 = tf.Variable(tf.truncated_normal([h2_size,image_size],stddev = 0.1),name = "g_w3",dtype = tf.float32)
    b3 = tf.Variable(tf.zeros([image_size]),name = "g_b3",dtype = tf.float32)
    h3 = tf.nn.relu(tf.matmul(h2, w3) + b3)
    
    #输出：生成图像
    output_generate = tf.nn.tanh(h3)   #利用tanh激活函数，将h3传入输出层
    
    #输出：生成图像的所有参数
    g_parameters = [w1, w2, w3, b1, b2, b3]               #合并所有参数
    
    return output_generate,g_parameters

#### 2.2 构建GAN的判别器

In [6]:

def Discriminator(true_data, generated_data, dropout_rate):
 
    # 合并输入数据，包括真实数据true_data和生成器生成的假数据generated_data
    sum_data = tf.concat([true_data, generated_data], 0) 
 
    # 第一个链接层
    w1 = tf.Variable(tf.truncated_normal([image_size, h2_size], stddev=0.1), name="d_w1", dtype=tf.float32)
    b1 = tf.Variable(tf.zeros([h2_size]), name="d_b1", dtype=tf.float32)
    h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(sum_data, w1) + b1), dropout_rate)
 
    # 第二个链接层
    w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name="d_w2", dtype=tf.float32)
    b2 = tf.Variable(tf.zeros([h1_size]), name="d_b2", dtype=tf.float32)
    h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2), dropout_rate)
 
    # 第三个链接层
    w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name="d_w3", dtype=tf.float32)
    b3 = tf.Variable(tf.zeros([1]), name="d_b3", dtype=tf.float32)
    h3 = tf.matmul(h2, w3) + b3
 
    #从h3中切出batch_size张图像
    slice_image = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1], name=None))
    #从h3中切除余下的图像
    slice_left_image = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1], name=None))
 
    #合并参数
    d_parameters = [w1, w2, w3, b1, b2, b3]
 
    return slice_image, slice_left_image, d_parameters


#### 2.3 显示结果的函数，结果图片输出到output_image文件夹中

In [7]:
def ShowResult(batch_res, filepath, grid_size=(8, 8), grid_pad=5):
    
    #将batch_res进行值[0, 1]归一化，同时将其reshape成（batch_size, image_height, image_width）
    batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], image_height, image_width)) + 0.5
   
    #重构显示图像格网的参数
    re_image_height, re_image_width = batch_res.shape[1], batch_res.shape[2]
    grid_height = re_image_height * grid_size[0] + grid_pad * (grid_size[0] - 1)
    grid_width = re_image_width  * grid_size[1] + grid_pad * (grid_size[1] - 1)
    img_grid = np.zeros((grid_height, grid_width), dtype=np.uint8)
    for i, res in enumerate(batch_res):
        if i >= grid_size[0] * grid_size[1]:
            break
        img = (res) * 255.
        img = img.astype(np.uint8)
        row = (i // grid_size[0]) * (re_image_height + grid_pad)
        col = (i % grid_size[1]) * (re_image_width + grid_pad)
        img_grid[row:row + re_image_height, col:col + re_image_width] = img
    #保存图像
    imsave(filepath, img_grid)


#### 2.4 定义训练过程

In [8]:
def StartTrain():

    # 加载数据
    train_data, train_label = DataLoad("./data/MNIST_data")
    size = train_data.shape[0]
 
    # 构建模型---------------------------------------------------------------------
    # 定义GAN网络的输入，其中x_data为[batch_size, image_size], z_input为[batch_size, z_size]
    x_data = tf.placeholder(tf.float32, [batch_size, image_size], name="x_data") # (batch_size, image_size)
    z_input = tf.placeholder(tf.float32, [batch_size, z_size], name="z_input") # (batch_size, z_size)
    # 定义dropout率
    dropout_rate = tf.placeholder(tf.float32, name="dropout_rate") 
    global_step = tf.Variable(0, name="global_step", trainable=False)
 
    # 利用生成器生成数据x_generated和参数g_params
    x_generated, g_params = Generator(z_input)
    # 利用判别器判别生成器的结果
    y_data, y_generated, d_params = Discriminator(x_data, x_generated, dropout_rate)
 
    # 定义判别器和生成器的loss函数
    d_loss = - (tf.log(y_data) + tf.log(1 - y_generated))
    g_loss = - tf.log(y_generated)
 
    # 设置学习率为0.0001，用AdamOptimizer进行优化
    optimizer = tf.train.AdamOptimizer(0.0001)
 
    # 判别器discriminator 和生成器 generator 对损失函数进行最小化处理
    d_trainer = optimizer.minimize(d_loss, var_list=d_params)
    g_trainer = optimizer.minimize(g_loss, var_list=g_params)
    # 模型构建完毕--------------------------------------------------------------------
 
    # 全局变量初始化
    init = tf.global_variables_initializer()
 
    # 启动会话sess
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(init)
 
    # 判断是否需要存储
    if restore:
        #若是，将最近一次的checkpoint点存到outpath下
        chkpt_fname = tf.train.latest_checkpoint(output_path)
        saver.restore(sess, chkpt_fname)
    else:
        #若否，判断目录是存在，如果目录存在，则递归的删除目录下的所有内容，并重新建立目录
        if os.path.exists(output_path):
            shutil.rmtree(output_path)
        os.mkdir(output_path)
 
    # 利用随机正态分布产生噪声影像，尺寸为(batch_size, z_size)
    z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
 
    # 逐个epoch内训练
    for i in range(sess.run(global_step), max_epoch):
        # 图像每个epoch内可以放(size // batch_size)个size
        for j in range(size // batch_size):
            if j%20 == 0:
                print("epoch:%s, iter:%s" % (i, j))
            
            # 训练一个batch的数据
            batch_end = j * batch_size + batch_size
            if batch_end >= size:
                batch_end = size - 1
            x_value = train_data[ j * batch_size : batch_end ]
            # 将数据归一化到[-1, 1]
            x_value = x_value / 255.
            x_value = 2 * x_value - 1
            
            # 以正太分布的形式产生随机噪声
            z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
            # 每个batch下，输入数据运行GAN，训练判别器
            sess.run(d_trainer,
                     feed_dict={x_data: x_value, z_input: z_value, dropout_rate: np.sum(0.7).astype(np.float32)})
            # 每个batch下，输入数据运行GAN，训练生成器
            if j % 1 == 0:
                sess.run(g_trainer,
                         feed_dict={x_data: x_value,z_input: z_value, dropout_rate: np.sum(0.7).astype(np.float32)})
        # 每一个epoch中的所有batch训练完后，利用z_sample测试训练后的生成器
        x_gen_val = sess.run(x_generated, feed_dict={z_input: z_sample_val})
        # 每一个epoch中的所有batch训练完后，显示生成器的结果，并打印生成结果的值
        ShowResult(x_gen_val, os.path.join(output_path, "sample%s.jpg" % i))
        print(x_gen_val)
        # 每一个epoch中，生成随机分布以重置z_random_sample_val
        z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
        # 每一个epoch中，利用z_random_sample_val生成手写数字图像，并显示结果
        x_gen_val = sess.run(x_generated, feed_dict={z_input: z_random_sample_val})
        ShowResult(x_gen_val, os.path.join(output_path, "random_sample%s.jpg" % i))
        # 保存会话
        sess.run(tf.assign(global_step, i + 1))
        saver.save(sess, os.path.join(output_path, "model"), global_step=global_step)


### Step 3 测试模型

In [15]:
if __name__ == '__main__':
    if train:
        StartTrain()


epoch:0, iter:0
epoch:0, iter:20
epoch:0, iter:40
epoch:0, iter:60
epoch:0, iter:80
epoch:0, iter:100
epoch:0, iter:120
epoch:0, iter:140
epoch:0, iter:160
epoch:0, iter:180
epoch:0, iter:200
epoch:0, iter:220
[[0.         0.9999915  0.         ... 0.         0.         0.99978554]
 [0.         0.9999782  0.         ... 0.         0.         0.9873895 ]
 [0.         0.99989575 0.         ... 0.         0.         0.99940664]
 ...
 [0.77444166 0.99998367 0.         ... 0.         0.         0.9997942 ]
 [0.         0.9999719  0.         ... 0.         0.         0.99995947]
 [0.         0.99999285 0.         ... 0.         0.         0.99930865]]
epoch:1, iter:0
epoch:1, iter:20
epoch:1, iter:40
epoch:1, iter:60
epoch:1, iter:80
epoch:1, iter:100
epoch:1, iter:120
epoch:1, iter:140
epoch:1, iter:160
epoch:1, iter:180
epoch:1, iter:200
epoch:1, iter:220
[[0.        0.        0.        ... 0.        0.        0.       ]
 [0.        0.        0.        ... 0.        0.        0.       ]
 [

KeyboardInterrupt: 