In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pickle

%matplotlib inline
%config InlineBackend.figure_format = 'retina' 

In [None]:
# Load the MNIST dataset via the TF1 tutorial helper, using the local
# 'MNIST_data' directory.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(train_dir='MNIST_data')

In [None]:
# Model inputs
def model_input(real_size, noise_size):
    """Create the two placeholders that feed the GAN graph.

    Args:
        real_size: Number of values in each flattened real-image row.
        noise_size: Number of values in each noise-vector row.

    Returns:
        Tuple ``(real_img, noise_img)`` of float32 placeholders with shapes
        ``[None, real_size]`` and ``[None, noise_size]``, named
        ``input_real`` and ``input_z`` respectively.
    """
    inputs_real = tf.placeholder(dtype=tf.float32, shape=[None, real_size], name='input_real')
    inputs_z = tf.placeholder(dtype=tf.float32, shape=[None, noise_size], name='input_z')
    return inputs_real, inputs_z

In [None]:
# Generator
def generator(noise, out_dim, n_units=128, reuse=False, alpha=0.01):
    """Build the generator network: dense -> leaky ReLU -> dense with tanh.

    Args:
        noise: Input tensor of noise vectors.
        out_dim: Number of output units (size of each generated row).
        n_units: Number of units in the hidden layer.
        reuse: If True, reuse the variables already created in the
            'generator' scope instead of creating new ones.
        alpha: Negative-side slope of the leaky ReLU.

    Returns:
        Output tensor with ``out_dim`` units per row, squashed to (-1, 1)
        by tanh.
    """
    with tf.variable_scope('generator', reuse=reuse):
        # tf.layers.dense creates its weights through the enclosing variable
        # scope, so calling this function again with reuse=True shares them.
        # Instantiating tf.keras.layers.Dense objects here would create a
        # fresh, independent set of weights on every call and ignore `reuse`.
        h1 = tf.layers.dense(noise, n_units, activation=None)
        h1 = tf.nn.leaky_relu(h1, alpha=alpha)

        out = tf.layers.dense(h1, out_dim, activation=tf.tanh)

    return out

In [None]:
# Discriminator
def discriminator(x, n_units=128, reuse=False, alpha=0.01):
    """Build the discriminator network: dense -> leaky ReLU -> dense(1).

    Args:
        x: Input tensor of flattened images (real or generated).
        n_units: Number of units in the hidden layer.
        reuse: If True, reuse the variables already created in the
            'discriminator' scope, so real and generated inputs are scored
            by the same weights.
        alpha: Negative-side slope of the leaky ReLU.

    Returns:
        Tuple ``(out, logits)``: ``logits`` is the raw single-unit output and
        ``out`` is its sigmoid, i.e. a value in (0, 1) per row.
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        # tf.layers.dense honours the variable scope's `reuse` flag; new
        # tf.keras.layers.Dense objects would own separate weights per call,
        # leaving the real and fake branches with unrelated discriminators.
        h1 = tf.layers.dense(x, n_units, activation=None)
        h1 = tf.nn.leaky_relu(h1, alpha=alpha)

        logits = tf.layers.dense(h1, 1, activation=None)
        # Sigmoid, not softmax: softmax over a single logit is always 1.0.
        out = tf.sigmoid(logits)
    return out, logits

In [None]:
# Hyperparameters
input_size = 28 * 28        # values per flattened 28x28 image (784)
noise_size = 100            # values per noise vector fed to the generator
g_hidden_size = d_hidden_size = 128   # hidden units: generator / discriminator
alpha = 0.01                # leak factor of the leaky ReLU
smooth = 0.1                # label smoothing applied to the "real" labels

In [None]:
# Start from an empty default graph so re-running this cell does not stack
# duplicate ops on top of a previous build.
tf.reset_default_graph()

# Input placeholders
real_img, noise_img = model_input(real_size=input_size, noise_size=noise_size)

# Generator: noise -> generated image rows
g_model = generator(noise=noise_img, out_dim=input_size,
                    n_units=g_hidden_size, alpha=alpha)

# Discriminator applied twice: first to the real images, then to the
# generator output with reuse=True.
d_model_real, d_logits_real = discriminator(x=real_img, n_units=d_hidden_size,
                                            alpha=alpha)
d_model_fake, d_logits_fake = discriminator(x=g_model, n_units=d_hidden_size,
                                            reuse=True, alpha=alpha)

In [None]:
# Calculate losses

# Discriminator on real images: target label 1, softened to (1 - smooth).
d_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_real,
        labels=tf.ones_like(d_logits_real) * (1 - smooth)))

# Discriminator on generated images: target label 0. The labels are shaped
# after the fake logits themselves so this stays valid even when the real and
# noise batches differ in size.
d_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake,
        labels=tf.zeros_like(d_logits_fake)))

d_loss = d_loss_real + d_loss_fake

# Generator: wants the discriminator to label its output as real (1).
g_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake,
        labels=tf.ones_like(d_logits_fake)))


In [None]:
learning_rate = 0.002

# Collect the trainable variables of each network by scope name, so that each
# optimizer only updates its own network.
g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')

# One Adam optimizer per network: the discriminator step minimizes d_loss
# over d_vars, the generator step minimizes g_loss over g_vars.
d_train_opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
    loss=d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
    loss=g_loss, var_list=g_vars)

In [None]:
batch_size = 100
epochs = 100
samples = []   # one array of 16 generated images per epoch
losses = []    # one (discriminator_loss, generator_loss) tuple per epoch

# Saver restricted to the generator variables.
# NOTE(review): nothing in this cell calls saver.save(), so no checkpoint is
# actually written -- confirm whether one is wanted.
saver = tf.train.Saver(var_list = g_vars)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for e in range(epochs):
        for ii in range(mnist.train.num_examples // batch_size):

            # Fetch one batch; element 0 holds the images.
            batch = mnist.train.next_batch(batch_size)
            # Flatten each image to a single row of 28*28 values.
            batch_images = batch[0].reshape((batch_size, 28*28))
            # Rescale to match the generator's tanh output range
            # (assumes the pixel values are in [0, 1] -- TODO confirm).
            batch_images = batch_images*2 - 1

            # Uniform noise in [-1, 1) as generator input.
            batch_z = np.random.uniform(-1, 1, size=(batch_size, noise_size))

            # One discriminator step, then one generator step, on the same
            # noise batch.
            sess.run(d_train_opt, feed_dict={real_img:batch_images, noise_img:batch_z})
            sess.run(g_train_opt, feed_dict={noise_img:batch_z})

        # Report both losses on the last batch of the epoch.
        train_loss_d, train_loss_g = sess.run(
            [d_loss, g_loss],
            feed_dict={noise_img:batch_z, real_img:batch_images})

        print("Epoch {}/{}...".format(e+1, epochs),
              "Discriminator Loss: {:.4f}...".format(train_loss_d),
              "Generator Loss: {:.4f}".format(train_loss_g))

        # Save losses to view after training
        losses.append((train_loss_d, train_loss_g))

        # Sample from the generator as we're training, for viewing afterwards.
        # Run the existing g_model tensor: calling generator() here again
        # would add a new copy of the network to the graph every epoch.
        sample_z = np.random.uniform(-1, 1, size=(16, noise_size))
        gen_samples = sess.run(g_model, feed_dict={noise_img: sample_z})

        samples.append(gen_samples)

In [None]:
def view_samples(epoch, samples):
    """Plot the images recorded for one epoch on a 4x4 grid.

    Args:
        epoch: Index into ``samples`` selecting which epoch to show.
        samples: Sequence of per-epoch image collections; each image must be
            reshapeable to 28x28.

    Returns:
        Tuple ``(fig, axes)`` as returned by ``plt.subplots``.
    """
    fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(7, 7), sharex=True, sharey=True)
    # zip stops at the shorter input, so at most 16 images are drawn.
    for panel, flat_img in zip(axes.flatten(), samples[epoch]):
        for axis in (panel.xaxis, panel.yaxis):
            axis.set_visible(False)
        panel.imshow(flat_img.reshape((28, 28)), cmap='Greys_r')

    return fig, axes

In [None]:
# Overview grid: each row is one evenly spaced epoch, each column one evenly
# spaced image from that epoch's samples.
rows, cols = 10, 6
fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)

# Clamp the strides to at least 1: with fewer recorded epochs than rows (or
# fewer images than cols) the integer division gives 0, and a slice step of 0
# raises ValueError.
epoch_step = max(1, len(samples) // rows)
for sample, ax_row in zip(samples[::epoch_step], axes):
    img_step = max(1, len(sample) // cols)
    for img, ax in zip(sample[::img_step], ax_row):
        ax.imshow(img.reshape((28,28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)