In [1]:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import edward as ed
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import os
import tensorflow as tf

from edward.models import Uniform
from tensorflow.contrib import slim
from tensorflow.examples.tutorials.mnist import input_data

%matplotlib inline

In [2]:
def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)
    
    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
    
    return fig

In [3]:
ed.set_seed(42)

M = 128  # batch size during training
d = 100  # latent dimension

DATA_DIR = "data/mnist"
IMG_DIR = "img"

if not os.path.exists(DATA_DIR):
    os.makedirs(DATA_DIR)
if not os.path.exists(IMG_DIR):
    os.makedirs(IMG_DIR)

DC_DATA_DIR = "data/dc_mnist"
DC_IMG_DIR = "dc_img"

if not os.path.exists(DC_DATA_DIR):
    os.makedirs(DC_DATA_DIR)
if not os.path.exists(DC_IMG_DIR):
    os.makedirs(DC_IMG_DIR)

In [4]:
mnist = input_data.read_data_sets(DATA_DIR)
x_ph = tf.placeholder(tf.float32, [M, 28, 28, 1])

Extracting data/mnist/train-images-idx3-ubyte.gz
Extracting data/mnist/train-labels-idx1-ubyte.gz
Extracting data/mnist/t10k-images-idx3-ubyte.gz
Extracting data/mnist/t10k-labels-idx1-ubyte.gz


## Generative Model


In [5]:
def dc_generative_network(eps):
    x = slim.fully_connected(eps, 1024)
    x = slim.batch_norm(x, activation_fn=tf.nn.relu)
    x = slim.fully_connected(x, 128*7*7, activation_fn=tf.nn.relu)
    x = slim.batch_norm(x, activation_fn=tf.nn.relu)
    x = tf.reshape(x, [-1, 7, 7, 128])
    x = slim.conv2d_transpose(x, 64, [5, 5], stride=2)
    x = slim.batch_norm(x, activation_fn=tf.nn.relu)
    x = slim.conv2d_transpose(x, 1, [5, 5], stride=2, activation_fn=tf.sigmoid)
    return x

In [8]:
with tf.variable_scope("GC_Gen"):
    gc_eps = Uniform(tf.zeros([M, d]) - 1.0, tf.ones([M, d])) # 正規分布
    x = dc_generative_network(gc_eps)

ValueError: Variable GC_Gen/fully_connected/weights already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at:

  File "/Users/YumaKajihara/.pyenv/versions/anaconda2-2.5.0/lib/python2.7/site-packages/tensorflow/contrib/framework/python/ops/variables.py", line 216, in variable
    use_resource=use_resource)
  File "/Users/YumaKajihara/.pyenv/versions/anaconda2-2.5.0/lib/python2.7/site-packages/tensorflow/contrib/framework/python/ops/arg_scope.py", line 181, in func_with_args
    return func(*args, **current_args)
  File "/Users/YumaKajihara/.pyenv/versions/anaconda2-2.5.0/lib/python2.7/site-packages/tensorflow/contrib/framework/python/ops/variables.py", line 261, in model_variable
    use_resource=use_resource)


## Discriminative Model

In [1]:
def leaky_relu(x, alpha):
    return tf.maximum(alpha*x, x)

In [2]:
def dc_discriminative_network(x):
    x = slim.conv2d(x, 64, [5, 5], stride=2)
    x = leaky_relu(x, 0.2)
    x = slim.conv2d(x, 128, [5, 5], stride=2)
    x = leaky_relu(x, 0.2)
    x = slim.fully_connected(x, 256)
    x = leaky_relu(x, 0.2)
    x = slim.dropout(x)
    logit = slim.fully_connected(x, 1, activation_fn=tf.sigmoid)
    return logit

## Defining Loss

In [None]:
gc_optimizer = tf.train.AdamOptimizer()
gc_optimizer_d = tf.train.AdamOptimizer()

inference = ed.GANInference(data={x: x_ph}, discriminator=dc_discriminative_network)
inference.initialize(
    optimizer=gc_optimizer, optimizer_d=gc_optimizer_d,
    n_iter=15000, n_print=1000)

## Main

### Initialize

In [None]:
sess = ed.get_session()
tf.global_variables_initializer().run()

### Train

In [None]:
idx = np.random.randint(M, size=16)
i = 0
for t in range(inference.n_iter):
    if t % inference.n_print == 0:
        samples = sess.run(x)
        samples = samples[idx, ]
        
        fig = plot(samples)
        plt.savefig(os.path.join(DC_IMG_DIR, '{}.png').format(str(i).zfill(3)), bbox_inches='tight')
        plt.close(fig)
        i += 1
    x_batch, _ = mnist.train.next_batch(M)
    x_batch = x_batch.reshape( [M, 28, 28, 1])
    info_dict = inference.update(feed_dict={x_ph: x_batch})
    inference.print_progress(info_dict)