In [None]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os


# Hyper-parameters shared by the model definition and the training loop.
mb_size = 32   # mini-batch size used when drawing real images and noise
X_dim = 784    # length of one flattened image (reshaped to 28 x 28 for plotting)
z_dim = 10     # dimensionality of the generator's noise input
h_dim = 128    # width of the single hidden layer in both networks

# Load MNIST from ./data/mnist. Labels are requested one-hot encoded, but the
# training loop below only uses the images.
mnist = input_data.read_data_sets('./data/mnist', one_hot=True)


def plot(samples):
    """Draw a batch of flattened 28x28 images on a 4x4 grid.

    Args:
        samples: iterable of arrays, each reshapeable to (28, 28). The grid
            has 16 cells, so at most 16 samples can be shown.

    Returns:
        The matplotlib figure holding the grid; the caller is responsible
        for saving and closing it.
    """
    fig = plt.figure(figsize=(4, 4))

    # Tightly packed 4x4 layout so the digits sit almost edge to edge.
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for cell, flat_image in enumerate(samples):
        ax = plt.subplot(grid[cell])
        # Hide every axis decoration; only the image itself should be visible.
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        image = flat_image.reshape(28, 28)
        plt.imshow(image, cmap='Greys_r')

    return fig


def xavier_init(size):
    """Return a random-normal initial value scaled by the layer's fan-in.

    Args:
        size: shape of the weight matrix as [in_dim, out_dim]; only the
            first entry is used to compute the scale.

    Returns:
        A tensor of shape `size` drawn from a zero-mean normal distribution
        with standard deviation 1 / sqrt(in_dim / 2).
    """
    fan_in = size[0]
    half_fan_in = fan_in / 2.
    return tf.random_normal(shape=size, stddev=1. / tf.sqrt(half_fan_in))


# Placeholder for a batch of flattened real images; the batch size is left
# unspecified (None) so any number of rows can be fed.
X = tf.placeholder(tf.float32, shape=[None, X_dim])

# Discriminator hidden layer: X_dim -> h_dim, bias initialised to zero.
D_W1 = tf.Variable(xavier_init([X_dim, h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))

# Discriminator output layer: h_dim -> 1 (one score per example).
D_W2 = tf.Variable(xavier_init([h_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

# Every trainable discriminator parameter. This list is what the
# discriminator optimizer updates and what the weight-clipping ops act on.
theta_D = [D_W1, D_W2, D_b1, D_b2]


# Placeholder for a batch of noise vectors fed to the generator; the batch
# size is left unspecified (None).
z = tf.placeholder(tf.float32, shape=[None, z_dim])

# Generator hidden layer: z_dim -> h_dim, bias initialised to zero.
G_W1 = tf.Variable(xavier_init([z_dim, h_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))

# Generator output layer: h_dim -> X_dim (one value per image pixel).
G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))

# Every trainable generator parameter; passed as var_list to the generator
# optimizer so that it leaves the discriminator untouched.
theta_G = [G_W1, G_W2, G_b1, G_b2]


def sample_z(m, n):
    """Draw an (m, n) array of noise, uniform on the interval [-1, 1).

    Args:
        m: number of noise vectors (rows).
        n: length of each noise vector (columns).

    Returns:
        A numpy float array of shape (m, n).
    """
    noise_shape = [m, n]
    return np.random.uniform(low=-1., high=1., size=noise_shape)


def generator(z):
    """Map a batch of noise vectors to a batch of generated images.

    Applies a ReLU hidden layer followed by a sigmoid output layer, using the
    module-level generator weights (G_W1, G_b1, G_W2, G_b2).

    Args:
        z: tensor whose last dimension matches the first dimension of G_W1.

    Returns:
        Tensor of per-pixel values squashed into (0, 1) by the sigmoid.
    """
    hidden = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    logits = tf.matmul(hidden, G_W2) + G_b2
    return tf.nn.sigmoid(logits)


def discriminator(x):
    """Score a batch of flattened images with the discriminator network.

    Applies a ReLU hidden layer followed by a linear output layer, using the
    module-level discriminator weights (D_W1, D_b1, D_W2, D_b2). No sigmoid
    is applied, so the result is an unbounded score rather than a probability.

    Args:
        x: tensor whose last dimension matches the first dimension of D_W1.

    Returns:
        Tensor with one raw score per input row.
    """
    hidden = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    return tf.matmul(hidden, D_W2) + D_b2


# Wire up the graph: generated images, then discriminator scores for the
# real batch and for the generated batch.
G_sample = generator(z)
D_real = discriminator(X)
D_fake = discriminator(G_sample)

# Mean discriminator score over each batch.
mean_real_score = tf.reduce_mean(D_real)
mean_fake_score = tf.reduce_mean(D_fake)

# D_loss is the gap between real and fake scores; the discriminator is
# trained to widen it (its solver minimises -D_loss). The generator is
# trained to raise the score of its own samples.
D_loss = mean_real_score - mean_fake_score
G_loss = -mean_fake_score

# Both networks use RMSProp with the same step size, each restricted to its
# own parameter list.
learning_rate = 1e-4
D_optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
D_solver = D_optimizer.minimize(-D_loss, var_list=theta_D)
G_optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
G_solver = G_optimizer.minimize(G_loss, var_list=theta_G)

# tf.clip_by_value(V, min, max) clamps V into [min, max]; one assign op per
# discriminator parameter keeps every weight inside [-clip_bound, clip_bound].
clip_bound = 0.01
clip_D = [
    param.assign(tf.clip_by_value(param, -clip_bound, clip_bound))
    for param in theta_D
]

# Create the session and initialise every variable in the graph.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Directory that receives the periodic sample grids. The same path is used
# both for creating the directory and for saving, so the save cannot fail
# because a differently named folder is missing.
out_dir = 'outWGAN/'
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

# Running index used to name saved sample images (000.png, 001.png, ...).
i = 0

for it in range(1000000):
    # Train the discriminator for 5 steps per generator step.
    for _ in range(5):
        X_mb, _ = mnist.train.next_batch(mb_size)

        _, D_loss_curr = sess.run(
            [D_solver, D_loss],
            feed_dict={X: X_mb, z: sample_z(mb_size, z_dim)}
        )
        # Clip in a separate run call: ops fetched in the same sess.run with
        # no dependency between them execute in an undefined order, so the
        # clip could otherwise happen before the optimizer update and leave
        # the weights unclipped.
        sess.run(clip_D)

    # One generator step; only noise is needed because G_loss does not
    # depend on the real-image placeholder.
    _, G_loss_curr = sess.run(
        [G_solver, G_loss],
        feed_dict={z: sample_z(mb_size, z_dim)}
    )

    if it % 100 == 0:
        print('Iter: {}; D loss: {:.4}; G_loss: {:.4}'
              .format(it, D_loss_curr, G_loss_curr))

        if it % 1000 == 0:
            # Render 16 fresh samples (one per cell of the 4x4 grid).
            samples = sess.run(G_sample, feed_dict={z: sample_z(16, z_dim)})

            fig = plot(samples)
            plt.savefig(os.path.join(out_dir,
                                     '{}.png'.format(str(i).zfill(3))),
                        bbox_inches='tight')
            i += 1
            # Close the figure so repeated plotting does not accumulate
            # open figures over a long run.
            plt.close(fig)

Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ./data/mnist/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ./data/mnist/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ./data/mnist/t10k-images-idx3-ubyte.gz
Extracting ./data/mnist/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Iter: 0; D loss: 0.006517; G_loss: 0.007201
Iter: 100; D loss: 1.984; G_loss: 1.516
Iter: 200; D loss: 1.864; G_loss: 1.355
Iter: 300; D loss: 1.635; G_loss: 1.055
Iter: 400; D loss: 1.329; G_loss: 0.8953
Iter: 500; D loss: 1.06; G_loss: 0.8179
Iter: 600; D loss: 0.676; G_loss: 0.6732
Iter: 700; D loss: 0.4882; G_loss: 0.534
Iter: 800; D loss: 0.3102; G_loss: 0.456
Iter: 900; D loss: 0.1561; G_loss: 0.2433
Iter: 1000; D loss: 0.08817; G_loss: 0.06849
Iter: 1100; D l

Iter: 14400; D loss: 0.0266; G_loss: -0.01701
Iter: 14500; D loss: 0.02668; G_loss: -0.008355
Iter: 14600; D loss: 0.03045; G_loss: -0.00787
Iter: 14700; D loss: 0.03075; G_loss: 9.134e-05
Iter: 14800; D loss: 0.02201; G_loss: -0.01707
Iter: 14900; D loss: 0.03015; G_loss: -0.0103
Iter: 15000; D loss: 0.0267; G_loss: -0.002375
Iter: 15100; D loss: 0.02821; G_loss: -0.01009
Iter: 15200; D loss: 0.02116; G_loss: -0.02826
Iter: 15300; D loss: 0.02556; G_loss: -0.02265
Iter: 15400; D loss: 0.02555; G_loss: -0.005669
Iter: 15500; D loss: 0.03007; G_loss: -0.0181
Iter: 15600; D loss: 0.02052; G_loss: -0.02252
Iter: 15700; D loss: 0.02256; G_loss: -0.01257
Iter: 15800; D loss: 0.02379; G_loss: -0.01389
Iter: 15900; D loss: 0.0259; G_loss: -0.01617
Iter: 16000; D loss: 0.02631; G_loss: -0.01276
Iter: 16100; D loss: 0.02224; G_loss: -0.01286
Iter: 16200; D loss: 0.02079; G_loss: -9.298e-05
Iter: 16300; D loss: 0.0243; G_loss: -0.01351
Iter: 16400; D loss: 0.02318; G_loss: -0.02037
Iter: 16500; 

Iter: 31900; D loss: 0.01382; G_loss: -0.02887
Iter: 32000; D loss: 0.0182; G_loss: -0.0238
Iter: 32100; D loss: 0.01587; G_loss: -0.01174
Iter: 32200; D loss: 0.01688; G_loss: -0.02124
Iter: 32300; D loss: 0.01295; G_loss: -0.03987
Iter: 32400; D loss: 0.01549; G_loss: -0.02778
Iter: 32500; D loss: 0.01221; G_loss: -0.02449
Iter: 32600; D loss: 0.01789; G_loss: -0.009493
Iter: 32700; D loss: 0.015; G_loss: -0.03379
Iter: 32800; D loss: 0.01039; G_loss: -0.01607
Iter: 32900; D loss: 0.01392; G_loss: -0.03047
Iter: 33000; D loss: 0.01653; G_loss: -0.01969
Iter: 33100; D loss: 0.0128; G_loss: -0.01615
Iter: 33200; D loss: 0.01422; G_loss: -0.01935
Iter: 33300; D loss: 0.01234; G_loss: -0.02347
Iter: 33400; D loss: 0.01631; G_loss: -0.02998
Iter: 33500; D loss: 0.0128; G_loss: -0.02169
Iter: 33600; D loss: 0.01805; G_loss: -0.03876
Iter: 33700; D loss: 0.01551; G_loss: -0.03485
Iter: 33800; D loss: 0.01086; G_loss: -0.02921
Iter: 33900; D loss: 0.01724; G_loss: -0.009974
Iter: 34000; D lo

Iter: 49400; D loss: 0.01948; G_loss: -0.01388
Iter: 49500; D loss: 0.01763; G_loss: -0.0152
Iter: 49600; D loss: 0.0184; G_loss: -0.01746
Iter: 49700; D loss: 0.01697; G_loss: -0.01746
Iter: 49800; D loss: 0.02048; G_loss: -0.01849
Iter: 49900; D loss: 0.01268; G_loss: -0.01789
Iter: 50000; D loss: 0.01885; G_loss: -0.01449
Iter: 50100; D loss: 0.01705; G_loss: -0.02353
Iter: 50200; D loss: 0.01828; G_loss: -0.01954
Iter: 50300; D loss: 0.01457; G_loss: -0.01685
Iter: 50400; D loss: 0.02079; G_loss: -0.01846
Iter: 50500; D loss: 0.01367; G_loss: -0.0189
Iter: 50600; D loss: 0.01607; G_loss: -0.01672
Iter: 50700; D loss: 0.01677; G_loss: -0.01734
Iter: 50800; D loss: 0.01552; G_loss: -0.01779
Iter: 50900; D loss: 0.0178; G_loss: -0.01495
Iter: 51000; D loss: 0.01737; G_loss: -0.01591
Iter: 51100; D loss: 0.01818; G_loss: -0.0181
Iter: 51200; D loss: 0.01504; G_loss: -0.01661
Iter: 51300; D loss: 0.01596; G_loss: -0.01327
Iter: 51400; D loss: 0.01818; G_loss: -0.02128
Iter: 51500; D los

Iter: 67000; D loss: 0.01795; G_loss: -0.01541
Iter: 67100; D loss: 0.01412; G_loss: -0.02351
Iter: 67200; D loss: 0.01494; G_loss: -0.01755
Iter: 67300; D loss: 0.01456; G_loss: -0.0154
Iter: 67400; D loss: 0.0152; G_loss: -0.01767
Iter: 67500; D loss: 0.0191; G_loss: -0.01575
Iter: 67600; D loss: 0.01281; G_loss: -0.02138
Iter: 67700; D loss: 0.01629; G_loss: -0.02018
Iter: 67800; D loss: 0.01238; G_loss: -0.01433
Iter: 67900; D loss: 0.01712; G_loss: -0.01549
Iter: 68000; D loss: 0.01921; G_loss: -0.02306
Iter: 68100; D loss: 0.01529; G_loss: -0.01915
Iter: 68200; D loss: 0.01386; G_loss: -0.01847
Iter: 68300; D loss: 0.01805; G_loss: -0.01503
Iter: 68400; D loss: 0.01442; G_loss: -0.01817
Iter: 68500; D loss: 0.01168; G_loss: -0.01456
Iter: 68600; D loss: 0.01436; G_loss: -0.01443
Iter: 68700; D loss: 0.01431; G_loss: -0.0165
Iter: 68800; D loss: 0.01365; G_loss: -0.01443
Iter: 68900; D loss: 0.01436; G_loss: -0.02231
Iter: 69000; D loss: 0.01767; G_loss: -0.01844
Iter: 69100; D lo