In [1]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Load the Fashion-MNIST training split: the 'label' column holds the class id,
# every other column is one pixel of the image.
train = pd.read_csv('datasets/fashionmnist/fashion-mnist_train.csv')

y_train = train['label'].values

# Cast the pixel columns to float32 before scaling. Passing the integer columns
# straight to MinMaxScaler makes sklearn emit a dtype-conversion warning, and
# float32 is also the dtype of the `real_img` placeholder these rows are fed to.
X_train_raw = train.drop('label', axis=1).values.astype(np.float32)

# Rescale every pixel column independently to [0, 1] (fitted on this split).
mm_scaler = MinMaxScaler(feature_range=(0, 1))
X_train = mm_scaler.fit_transform(X_train_raw)
real_samples, dim = X_train.shape

# Later cells hard-code 784 input units and reshape samples to 28x28.
assert dim == 28 * 28, 'expected 784 pixel columns, got %d' % dim

  return self.partial_fit(X, y)


In [2]:
import tensorflow as tf

BATCH_SIZE = 256
# Number of mini-batch steps per epoch, rounded down. Floor division keeps this
# an int on Python 3 as well (`/` would give a float and break range()).
N_BATCHES = real_samples // BATCH_SIZE
N_EPOCHS = 500
LEARNING_RATE = 1e-4     # Adam step size, shared by the D and G optimizers
REAL_INPUT_UNITS = 784   # flattened 28x28 image
HIDDEN_UNITS = 256       # width of the single hidden layer in G and in D
NOISE_INPUT_UNITS = 10   # length of each noise vector fed to the generator

  from ._conv import register_converters as _register_converters


In [3]:
def generator(noise_img, hidden_units, output_dim, reuse=False):
    """Build the generator network on top of ``noise_img``.

    Variables live under the 'generator' variable scope. The network is a
    ReLU dense hidden layer of ``hidden_units`` units followed by a dense
    output layer of ``output_dim`` units with a sigmoid, so every output
    value lies in (0, 1).

    Args:
        noise_img: input tensor of noise vectors (the notebook feeds a
            [None, NOISE_INPUT_UNITS] float32 placeholder).
        hidden_units: width of the hidden layer.
        output_dim: width of the output layer.
        reuse: forwarded to ``tf.variable_scope``; True shares the
            variables created by an earlier call.

    Returns:
        The sigmoid-activated output tensor.
    """
    with tf.variable_scope('generator', reuse=reuse):
        features = tf.layers.dense(inputs=noise_img, units=hidden_units,
                                   activation=tf.nn.relu)
        return tf.layers.dense(inputs=features, units=output_dim,
                               activation=tf.nn.sigmoid)

In [None]:
def discriminator(img, hidden_units, reuse=False):
    """Build the discriminator network on top of ``img``.

    Variables live under the 'discriminator' variable scope. The network is a
    ReLU dense hidden layer of ``hidden_units`` units followed by a single
    linear output unit, so it returns raw logits (no sigmoid applied).

    Args:
        img: input tensor of flattened images (real or generated).
        hidden_units: width of the hidden layer.
        reuse: forwarded to ``tf.variable_scope``; pass True on the second
            call so real and fake inputs are scored by the same weights.

    Returns:
        Logit tensor with one column.
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        features = tf.layers.dense(inputs=img, units=hidden_units,
                                   activation=tf.nn.relu)
        scores = tf.layers.dense(inputs=features, units=1, activation=None)
    return scores

In [None]:
def view_samples(epoch, samples, img_shape=(28, 28)):
    """Plot up to 25 generated images in a 5x5 grid.

    Args:
        epoch: index into ``samples`` selecting which recorded entry to show.
        samples: sequence of recorded entries. ``samples[epoch][1]`` must hold
            the generated images (each reshapeable to ``img_shape``);
            ``samples[epoch][0]`` is not used here (by the original
            convention it holds the corresponding logits).
        img_shape: shape each flat image is reshaped to before drawing.
            Defaults to (28, 28), the Fashion-MNIST image size.

    Returns:
        (fig, axes): the matplotlib figure and its 5x5 array of axes.
    """
    # `plt` is not imported anywhere else in this notebook; bind it here so
    # the function works on a fresh kernel (matplotlib was already required).
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=5,
                             sharey=True, sharex=True)
    # Hide ticks on every cell, including any left empty when fewer than
    # 25 images are supplied.
    for ax in axes.flat:
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
    # zip stops at the shorter sequence, so at most 25 images are drawn.
    for ax, img in zip(axes.flat, samples[epoch][1]):
        ax.imshow(img.reshape(img_shape), cmap='Greys_r')

    return fig, axes

In [5]:
# Start from an empty graph so this cell can be re-run: otherwise the second
# run fails because the 'generator'/'discriminator' variables already exist.
tf.reset_default_graph()

real_img = tf.placeholder(tf.float32, [None, REAL_INPUT_UNITS], name='real_img')
noise_img = tf.placeholder(tf.float32, [None, NOISE_INPUT_UNITS], name='noise_img')

gen_outputs = generator(noise_img, HIDDEN_UNITS, REAL_INPUT_UNITS)

# The same discriminator weights (reuse=True) score real and generated images.
dis_real_logits = discriminator(real_img, HIDDEN_UNITS)
dis_fake_logits = discriminator(gen_outputs, HIDDEN_UNITS, reuse=True)

# D is trained to label real images 1 and generated images 0.
dis_real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=dis_real_logits, labels=tf.ones_like(dis_real_logits)))
dis_fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=dis_fake_logits, labels=tf.zeros_like(dis_fake_logits)))
dis_loss = dis_real_loss + dis_fake_loss

# G is trained to make D label its images 1 (non-saturating generator loss).
gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=dis_fake_logits, labels=tf.ones_like(dis_fake_logits)))

# Each optimizer updates only its own network, selected by variable-scope prefix.
train_vars = tf.trainable_variables()

dis_vars = [var for var in train_vars if var.name.startswith('discriminator')]
dis_train_opt = tf.train.AdamOptimizer(LEARNING_RATE).minimize(dis_loss, var_list=dis_vars)

gen_vars = [var for var in train_vars if var.name.startswith('generator')]
gen_train_opt = tf.train.AdamOptimizer(LEARNING_RATE).minimize(gen_loss, var_list=gen_vars)

init = tf.global_variables_initializer()

# range() and np.random.seed() need an int; `real_samples / BATCH_SIZE` is a
# float on Python 3.
n_batches = int(N_BATCHES)

with tf.Session() as sess:
    init.run()
    for iteration in range(N_EPOCHS):
        for i in range(n_batches):
            # Reseed per step so the sampled batches and noise are reproducible.
            np.random.seed(iteration * n_batches + i)

            # Indices are drawn with replacement, independently for each step.
            indices = np.random.randint(real_samples, size=BATCH_SIZE)

            batch_real_imgs = X_train[indices]
            batch_noise_imgs = np.random.uniform(-1, 1, size=(BATCH_SIZE, NOISE_INPUT_UNITS))

            # One discriminator step, then one generator step on the same noise.
            sess.run(dis_train_opt, feed_dict={real_img: batch_real_imgs, noise_img: batch_noise_imgs})
            sess.run(gen_train_opt, feed_dict={noise_img: batch_noise_imgs})

        # Report the losses on the last batch of the epoch.
        d_real, d_fake, g_loss = sess.run(
            [dis_real_loss, dis_fake_loss, gen_loss],
            feed_dict={real_img: batch_real_imgs, noise_img: batch_noise_imgs})
        print('Epoch %d/%d  D real loss: %.4f  D fake loss: %.4f  G loss: %.4f'
              % (iteration + 1, N_EPOCHS, d_real, d_fake, g_loss))

    # 25 noise vectors -> one image per cell of the 5x5 grid.
    sample_noise = np.random.uniform(-1, 1, size=(25, NOISE_INPUT_UNITS))
    gen_samples = sess.run(gen_outputs,
                           feed_dict={noise_img: sample_noise})
    # view_samples reads samples[epoch][1] as the images, so wrap them in a
    # (logits, images) pair; no logits are recorded here.
    sample_fig, sample_axes = view_samples(0, [(None, gen_samples)])

[0.0741238, 0.022537675, 4.3083687]
[0.08437879, 0.041815907, 3.4732308]
[0.089822486, 0.055426642, 3.112172]
[0.11118397, 0.05845301, 3.234321]
[0.10792889, 0.10484376, 2.5680623]
[0.08128883, 0.087131925, 2.6430051]
[0.10938832, 0.08681829, 2.6718419]
[0.07891975, 0.093123645, 2.640039]
[0.15412587, 0.09691031, 2.5665336]
[0.19545254, 0.14866099, 2.0862744]
[0.154102, 0.13359813, 2.1905215]
[0.18149671, 0.1682612, 2.024374]
[0.21828008, 0.18260597, 1.9264271]
[0.25179836, 0.1906893, 1.8805935]
[0.22568087, 0.18845314, 1.9146974]
[0.25529993, 0.18798609, 2.0270827]
[0.29442328, 0.16706061, 2.131192]
[0.29804286, 0.15535566, 2.2783926]
[0.2790554, 0.18971318, 2.0504284]
[0.25962836, 0.18657848, 2.0813084]
[0.21915445, 0.14792946, 2.3374062]
[0.26222485, 0.20211488, 2.1558154]
[0.23597498, 0.1696017, 2.3994045]
[0.28776, 0.17677192, 2.2995992]
[0.27811384, 0.1609714, 2.3534565]
[0.33377007, 0.18435378, 2.296144]
[0.27544385, 0.14606348, 2.4729319]
[0.3374756, 0.2569212, 1.8873917]
[0.47

KeyboardInterrupt: 