# 一、导入需要用到的包

In [None]:
import numpy as np
from tqdm import tqdm #progress bar 进度条
import tensorflow as tf
if tf.__version__ != '1.0.0':
    print("please check the version of tensorflow!")
from matplotlib import pyplot as plt
from IPython import display #用来显示图片
from tensorflow.examples.tutorials.mnist import input_data

## 1、tqdm用法示例

In [None]:
from time import sleep

# Minimal tqdm demo: wrapping the iterable in tqdm() shows a progress bar
# while the loop runs; each of the ten steps simply waits one second.
for step in tqdm(range(10)):
    sleep(1)

# 二、准备相关工具

## 1、全连接层

In [None]:
def _dense(x, scope, out_dim, activation = True):
    """Build one fully connected layer (used by the generator and discriminator).

    Args:
        x: 2-D tensor or array; its last dimension is the input feature size.
        scope: variable-scope name that keeps this layer's parameters apart
            from those of other layers.
        out_dim: int, number of output features per sample.
        activation: when true, ReLU is applied to the affine output;
            otherwise the affine output is returned as is.

    Returns:
        ``relu(matmul(x, w) + b)`` if ``activation`` else ``matmul(x, w) + b``.
    """
    # x.shape is available on both tensors and arrays. Per the original
    # author's note, a tensor yields a tensor_shape.Dimension here while an
    # array yields an int, so int() normalizes both cases.
    n_inputs = int(x.shape[-1])

    with tf.variable_scope(scope):
        # The weight matrix is [in_dim, out_dim], hence matmul(x, w) below.
        weights = tf.get_variable(
            'w',
            [n_inputs, out_dim],
            initializer=tf.random_normal_initializer(stddev=0.01),
        )
        bias = tf.get_variable(
            'b',
            [out_dim],
            initializer=tf.constant_initializer(0),
        )

    affine = tf.matmul(x, weights) + bias
    return tf.nn.relu(affine) if activation else affine

获取tensor的shape有三种方式:tensor.shape,tensor.get_shape().as_list(),tf.shape()

获取array的shape:array.shape

注意适用对象和返回值的类型

## 2、数据预处理

In [None]:
def preprocess(x):
    """Linearly rescale ``x`` to ``2*x - 1`` (so 0 maps to -1 and 1 maps to 1)."""
    doubled = x * 2
    return doubled - 1

## 3、将一个batch的100张图片汇总在一张大图片中

In [None]:
def grid_vis(imgs):
    """Tile a batch of equally sized 2-D images into one square mosaic.

    Args:
        imgs: non-empty sequence of 2-D arrays that all share the shape of
            ``imgs[0]`` (the notebook passes 28x28 digit images).

    Returns:
        A float array holding the images in row-major order on a
        ``side x side`` grid, where ``side = ceil(sqrt(len(imgs)))``.
        Unused cells stay zero.
    """
    side = int(np.ceil(np.sqrt(len(imgs))))
    tile_h, tile_w = imgs[0].shape
    canvas = np.zeros((side * tile_h, side * tile_w))
    for index, tile in enumerate(imgs):
        # Row-major placement: fill one row left to right, then move down.
        row, col = divmod(index, side)
        top, left = row * tile_h, col * tile_w
        canvas[top:top + tile_h, left:left + tile_w] = tile
    return canvas

## 4、图片可视化

In [None]:
def visualize(grid):
    """Show ``grid`` in the notebook output, replacing the previous frame.

    The title reports the module-level ``loss_g_``, ``loss_d_`` and
    ``n_updates``, so those globals must be set before this is called.

    Args:
        grid: 2-D array rendered as a grayscale image without axes.
    """
    fig = plt.figure(figsize=(10, 10))
    plt.title('G: %.3f D: %.3f Updates: %d'%(loss_g_, loss_d_, n_updates))
    plt.imshow(grid, cmap='gray')
    plt.axis('off')
    # wait=True defers clearing until new output arrives, which avoids flicker.
    display.clear_output(wait=True)
    display.display(fig)
    # Close the figure once it has been displayed. Otherwise every call
    # leaves another open figure behind, and they accumulate for the whole
    # training run.
    plt.close(fig)

# 三、定义生成器和判别器

In [None]:
def generator(z, reuse = None):
    """Map noise vectors to flattened images.

    Args:
        z: noise tensor of shape [batch_size, 100].
        reuse: passed to ``tf.variable_scope``; set it to True to share the
            variables created by an earlier call.

    Returns:
        Tensor of shape [batch_size, 784] with values in (-1, 1).
    """
    with tf.variable_scope('generator', reuse = reuse):
        hidden_layer_1 = _dense(z, 'l_1', 512)
        hidden_layer_2 = _dense(hidden_layer_1, 'l_2', 1024)
        # Real images reach the discriminator rescaled by preprocess()
        # (x*2-1), so they contain negative values. A ReLU output layer can
        # never produce those; tanh gives the generator the same range.
        pre_activation = _dense(hidden_layer_2, 'output', 784, activation = False)
        return tf.tanh(pre_activation)

def discriminator(x, reuse = None):
    """Score each row of ``x`` with a sigmoid output in (0, 1).

    Args:
        x: 2-D tensor, one sample per row.
        reuse: passed to ``tf.variable_scope``; set it to True to share the
            variables created by an earlier call.

    Returns:
        Tensor with one sigmoid-activated value per sample.
    """
    with tf.variable_scope('discriminator', reuse = reuse):
        features = _dense(_dense(x, 'l_1', 512), 'l_2', 1024)
        # The last layer is purely affine; the sigmoid squashes its single
        # logit into (0, 1).
        return tf.sigmoid(_dense(features, 'logits', 1, activation = False))

# 四、准备数据集

## 1、真实数据

In [None]:
# Load MNIST from the MNIST_data directory relative to the working directory.
# A forward slash is used because '.\M' is an invalid escape sequence and a
# backslash is not a path separator outside Windows; './MNIST_data' resolves
# to the same directory on Windows as well.
mnist = input_data.read_data_sets('./MNIST_data', one_hot = True)

## 2、噪声数据

In [None]:
def _get_z(batch_size):
    """Return a ``(batch_size, 100)`` array of noise drawn uniformly from [0, 1)."""
    noise_shape = (batch_size, 100)
    return np.random.rand(*noise_shape)

In [None]:
# Quick sanity check: draw a small noise batch and inspect it.
_get_z(batch_size=10)

## 3、placeholder

In [None]:
# Graph inputs; the leading None leaves the batch size open.
# x: flattened images (784 values each), y: 10-wide label rows,
# z: 100-dimensional noise vectors.
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.int32, shape=(None, 10))
z = tf.placeholder(tf.float32, shape=(None, 100))

# 五、建模

In [None]:
# Generated samples. To display g as images, apply the inverse of preprocess().
g = generator(z)
# Discriminator output on generated samples; this call creates the
# discriminator variables.
d_g = discriminator(g)
# Discriminator output on real samples; reuse=True shares those variables.
real_input = preprocess(x)
d_real = discriminator(real_input, reuse = True)

为什么d_real的reuse = True？因为判别真实数据和判别生成数据用的是同一个判别器，第二次调用discriminator时必须复用第一次调用时创建的参数，而不是新建一套参数。见下面的原理图

# 六、定义损失函数及优化器

## 1、损失函数

In [None]:
# A saturated sigmoid returns exactly 0 or 1 in float32, and tf.log would
# then produce -inf/NaN and poison every later update. Clipping keeps the
# arguments of both log terms strictly positive. The margin must be larger
# than float32 resolution near 1, otherwise 1 - eps rounds back to 1.
eps = 1e-6
d_real_safe = tf.clip_by_value(d_real, eps, 1.0 - eps)
d_g_safe = tf.clip_by_value(d_g, eps, 1.0 - eps)

# Discriminator: maximize log D(x) + log(1 - D(G(z))), i.e. minimize its negative.
loss_d = -tf.reduce_sum(tf.log(d_real_safe) + tf.log(1 - d_g_safe))
# Generator: minimize -log D(G(z)) instead of the minimax term
# log(1 - D(G(z))); the earlier variants based on log(1 - d_g) and
# log(d_real) - log(d_g) were dropped in favour of this form.
loss_g = tf.reduce_sum(-tf.log(d_g_safe))

$$ \min_{G}\max_{D} V(D,G)=\mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)}[\log(1-D(G(z)))] $$

![](./G的loss改进.jpg)

在训练初期，生成器生成的样本很容易被判别器识别出来，即 D(G(z)) 接近 0。此时 log(1-D(G(z))) 接近 0 且处于饱和区，反向传播给生成器的梯度很小，生成器的参数更新就很慢。因此实际训练时把生成器的目标改为最小化 -log(D(G(z)))：当 D(G(z)) 接近 0 时，它能提供较大的梯度。

## 2、优化器

In [None]:
lr = 2.5e-4

# Each network gets its own Adam optimizer, restricted through var_list to
# the variables collected under that network's scope. A generator step
# therefore leaves the discriminator's variables untouched, and vice versa.
g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')
g_optimizer = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5)
g_train = g_optimizer.minimize(loss_g, var_list=g_vars)

d_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator')
d_optimizer = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5)
d_train = d_optimizer.minimize(loss_d, var_list=d_vars)

训练g的时候d保持不动，训练d的时候g保持不动

![原理](./GAN.jpg)

# 七、训练

In [None]:
# Open a session and run the initializer for all global variables.
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

In [None]:
%matplotlib inline
k = 10  # the discriminator is updated once for every k generator updates
batch_size = 100
for n_updates in tqdm(range(5000)):

    # Labels are not used by this GAN, only the images.
    x_, _ = mnist.train.next_batch(batch_size)

    sess.run(g_train, {z:_get_z(batch_size), x:x_})
    if n_updates%k == 0:
        sess.run(d_train, {z:_get_z(batch_size), x:x_})

    if n_updates%10 == 0:
        # Generate and tile preview images only on the iterations that
        # display them. (g + 1) / 2 undoes the x*2-1 scaling of preprocess().
        imgs = (sess.run(g, {z:_get_z(batch_size)}) + 1)/2.0
        imgs = imgs.reshape([-1, 28, 28])
        grid = grid_vis(imgs)

        # Evaluate both losses in one pass so they refer to the same noise batch.
        loss_g_, loss_d_ = sess.run([loss_g, loss_d], {z:_get_z(batch_size), x:x_})
        visualize(grid)