In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import prettytensor as pt
from numpy.random import multivariate_normal,random_sample
from datetime import datetime
%matplotlib inline

In [2]:
# Model / batch dimensions
latent_size = 2        # length of each noise vector fed to the generator
batch_size = 200       # rows per batch; also baked into the placeholder shapes
data_size = 28 * 28    # flattened sample length (samples are reshaped to 28x28 for plotting)

# Adam optimizer settings: learning rate and first-moment decay
eta, beta1 = 2e-4, 0.5

# Data distribution $p_{data}$

$$
p_{data} = \mathcal{N}\bigg((1,1), 
\begin{pmatrix}
1 & 0.9\\
0.9 & 1\\
\end{pmatrix}
\bigg)
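$$

**Note:** the code in this notebook feeds MNIST batches (`data_size = 784`) to the discriminator as the real data, not samples from this Gaussian.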

In [3]:
def gen_noise(batch_size, latent_dim):
    """Draw a batch of latent noise vectors.

    Returns an array of shape ``(batch_size, latent_dim)`` whose entries are
    sampled uniformly from the half-open interval [0, 1).
    """
    shape = (batch_size, latent_dim)
    return random_sample(size=shape)

In [4]:
def generator(input_data):
    """Build the generator network on top of ``input_data``.

    The input is flattened and passed through three fully connected layers of
    width 200, 500 and ``data_size``. Every layer, including the last one,
    uses the tanh activation set as the scope default. Variables are added to
    the 'generator' collection.
    """
    layer_widths = (200, 500, data_size)
    with pt.defaults_scope(activation_fn=tf.nn.tanh, variable_collections=['generator']):
        net = pt.wrap(input_data).flatten()
        for width in layer_widths:
            net = net.fully_connected(width)
        return net

In [5]:
def discriminator(input_data):
    """Build the discriminator network on top of ``input_data``.

    The input is flattened and passed through two tanh layers (500 and 200
    units), followed by a single sigmoid unit. Variables are added to the
    'discriminator' collection.
    """
    with pt.defaults_scope(activation_fn=tf.nn.tanh, variable_collections=['discriminator']):
        hidden = pt.wrap(input_data).flatten()
        hidden = hidden.fully_connected(500)
        hidden = hidden.fully_connected(200)
        # The output layer overrides the scope's tanh default with a sigmoid.
        return hidden.fully_connected(1, activation_fn=tf.nn.sigmoid)

In [6]:
# Start from an empty default graph so re-running this cell does not stack
# a second copy of the networks on top of the first.
tf.reset_default_graph()

# Inputs: one batch of real samples and one batch of latent noise.
data_tensor = tf.placeholder(np.float32, shape=(batch_size, data_size))
noise_tensor = tf.placeholder(np.float32, shape=(batch_size, latent_size))


# G(z), D(x) and D(G(z)).
# NOTE(review): discriminator() is invoked twice here, and a later cell
# asserts that the 'discriminator' collection holds twice as many variables
# as the 'generator' collection. That suggests the two calls do not share
# weights -- confirm, and share them if so.
output_g = generator(noise_tensor)
output_d_real = discriminator(data_tensor)
output_d_fake = discriminator(output_g)

# discriminator loss (for both cases): -log D(x) on real samples and
# -log(1 - D(G(z))) on generated ones. The 1e-12 keeps tf.log away from log(0).
loss_d_real = -tf.reduce_mean(tf.log(output_d_real + 1e-12))
loss_d_fake = -tf.reduce_mean(tf.log((1 - output_d_fake) + 1e-12))
loss_d = loss_d_real + loss_d_fake

# generator loss: -log D(G(z)), computed from the discriminator's output on
# the generated batch. (This previously took the log of the already-reduced
# scalar `loss_d_fake` rather than of `output_d_fake`.)
loss_g = -tf.reduce_mean(tf.log(output_d_fake + 1e-12))

In [7]:
# Scalar summaries, grouped per network so each group can be fetched
# together with the matching loss.
summary_d = [
    tf.scalar_summary(tag, tensor)
    for tag, tensor in (('loss_d_real', loss_d_real),
                        ('loss_d_fake', loss_d_fake))
]
summary_g = [tf.scalar_summary('loss_g', loss_g)]

merged_d = tf.merge_summary(summary_d)
merged_g = tf.merge_summary(summary_g)

In [8]:
# Fetch the variables each network registered through `variable_collections`,
# so each optimizer can be restricted to its own network.
vars_d, vars_g = (tf.get_collection(name) for name in ('discriminator', 'generator'))

In [9]:
# Sanity checks on the collections. The generator has three fully connected
# layers -- presumably one weight and one bias variable each, giving 6.
assert 6 == len(vars_g)
# discriminator() is called twice (real and generated input), and this check
# expects twice the generator's variable count.
assert 2 * len(vars_g) == len(vars_d)

In [10]:
# One Adam optimizer per network; `var_list` restricts each update to that
# network's own variables.
opt_d = (
    tf.train.AdamOptimizer(learning_rate=eta, beta1=beta1)
    .minimize(loss_d, var_list=vars_d)
)

opt_g = (
    tf.train.AdamOptimizer(learning_rate=eta, beta1=beta1)
    .minimize(loss_g, var_list=vars_g)
)

In [11]:
# Op that initializes every variable in the graph; it is run once after the session is created.
init_op = tf.initialize_all_variables()

In [12]:
# Open the session used by all later cells and initialize the variables.
# Note: the session is never closed in the visible cells.
sess = tf.Session()
sess.run(init_op)

In [13]:
# Timestamped log directory so each run writes its summaries to a fresh folder.
# Spaces and colons are replaced with underscores in the directory name.
now = str(datetime.now()).replace(' ', '_').replace(':', '_')
sum_writer = tf.train.SummaryWriter('logs_{}/'.format(now))

In [14]:
# `input_data` is a local helper module (presumably the TensorFlow MNIST tutorial loader -- confirm).
# Labels are requested one-hot, but the training loop below discards them.
import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [15]:
# Training schedule.
num_epochs = 10000
n_k = 1                # discriminator updates per generator update
modo = 100             # print progress every `modo` epochs
summary_every = 100    # write summaries every `summary_every` batches
num_batches = mnist.train.num_examples // batch_size
step = 0               # index attached to the summaries; advances once per summary write

for epoch in range(0, num_epochs):
    for i_batch in range(0, num_batches):
        # update discriminator $k$ times
        for k in range(0, n_k):
            raw,_ = mnist.train.next_batch(batch_size)
            noise = gen_noise(batch_size, latent_size)
            # Fetching `opt_d` is what applies the Adam update. Fetching only
            # the loss and summary evaluates them without changing any
            # variable, so nothing would be trained.
            _, l_d, summary = sess.run([opt_d, loss_d, merged_d],
                                       {data_tensor: raw, noise_tensor: noise})
            if k == n_k-1 and i_batch % summary_every == 0:
                sum_writer.add_summary(summary, step)

        # update generator (same reasoning: `opt_g` must be fetched to train)
        noise = gen_noise(batch_size, latent_size)
        _, l_g, summary = sess.run([opt_g, loss_g, merged_g], {noise_tensor: noise})
        if i_batch % summary_every == 0:
            sum_writer.add_summary(summary, step)
            step += 1

        
    if epoch % modo == 0:
        print('epoch {}/{}: {}%'.format(epoch, num_epochs, epoch/num_epochs * 100))

epoch 0/10000: 0.0%
epoch 100/10000: 1.0%
epoch 200/10000: 2.0%
epoch 300/10000: 3.0%
epoch 400/10000: 4.0%
epoch 500/10000: 5.0%


KeyboardInterrupt: 

In [None]:
def plot_generation():
    """Run the generator on a fresh noise batch and show its first sample.

    The first row of the generated batch is reshaped to 28x28 and drawn with
    a grayscale colormap. Nothing is returned.
    """
    feed = {noise_tensor: gen_noise(batch_size, latent_size)}
    generated = sess.run([output_g], feed)[0]
    first_image = generated[0].reshape(28, 28)
    plt.imshow(first_image, cmap='gray')

In [None]:
# Show one sample from the generator in its current state.
plot_generation()

In [None]:
# TODO: plot the generated samples


In [None]:
# NOTE(review): `img` is not assigned at notebook scope in any visible cell
# (it only exists as a local inside plot_generation), so this cell likely
# raises NameError on a fresh kernel -- confirm where `img` should come from.
img

In [None]:
# Inspect the shape of `img` before the first element is selected in the next cell.
img.shape

In [None]:
# Keep only the first element along the leading axis.
# NOTE(review): this overwrites `img`, so re-running the cell indexes again and is not idempotent.
img = img[0]

In [None]:
# Shape of `img` after the first element was selected.
img.shape

In [None]:
# View the selected sample as a 28x28 array; the result is displayed, not stored.
img.reshape(28,28)