# Credits
Based on [DCGAN-tensorflow](https://github.com/carpedm20/DCGAN-tensorflow) by [Taehoon Kim](https://github.com/carpedm20) on GitHub.

In [3]:
import numpy as np
import tensorflow as tf
import os, time
from glob import glob

from ops import batch_norm, linear, conv2d, deconv2d, lrelu
from image_helpers import *

ModuleNotFoundError: No module named 'numpy'

In [None]:
# ---- Data ------------------------------------------------------------------
dataset = "celebA"          # the training cell reads data/<dataset>/*.jpg
is_crop = True              # forwarded to get_image(..., is_crop=is_crop)
image_size = 108            # forwarded to get_image as its size argument
                            # (presumably the crop size -- see image_helpers)
image_shape = [64, 64, 3]   # per-image shape fed to / produced by the nets
                            # (presumably height, width, channels)

# ---- Batching and latent space ---------------------------------------------
batch_size = 64
sample_size = 64            # must equal batch_size: sample images are fed
                            # through the fixed-size `images` placeholder
z_dim = 100                 # length of each latent vector

# ---- Network width (base filter counts) -------------------------------------
gf_dim = 64                 # generator
df_dim = 64                 # discriminator

# ---- Adam hyper-parameters --------------------------------------------------
learning_rate = 2e-4
beta1 = 0.5

In [None]:
# Batch-norm layer objects applied inside discriminator() and generator().
# Each gets a unique name; `batch_norm` comes from the local `ops` module
# (not shown), so what the name controls -- presumably the variable scope of
# the layer's parameters -- should be confirmed there.
d_bn1, d_bn2, d_bn3 = [batch_norm(name='d_bn%d' % i) for i in (1, 2, 3)]

g_bn0, g_bn1, g_bn2, g_bn3 = [batch_norm(name='g_bn%d' % i) for i in range(4)]

In [None]:
def discriminator(image, reuse=False):
    """Score a batch of images as real or generated.

    Four ``conv2d`` layers (``d_h0_conv`` .. ``d_h3_conv``) each followed by a
    leaky ReLU, with batch norm on layers 1-3, then a ``linear`` layer that
    produces one logit per example.

    Args:
        image: input image batch. Its leading dimension must equal
            ``batch_size``, because the final reshape hard-codes it.
        reuse: when True, reuse the variables created by an earlier call
            instead of creating new ones. This lets generated images be
            scored with the same weights as real images.

    Returns:
        ``(tf.nn.sigmoid(logits), logits)``. The logits are presumably shaped
        ``[batch_size, 1]``; the exact shape depends on ``ops.linear``.
    """
    # Re-enter the *current* variable scope so that the reuse flag set below
    # is confined to this call. Calling reuse_variables() directly on the
    # caller's scope would leave it in reuse mode for good, and any later
    # variable creation there (e.g. the Adam optimizer's slot variables)
    # would then fail. Re-entering the scope object does not change the
    # variable names.
    with tf.variable_scope(tf.get_variable_scope()):
        if reuse:
            tf.get_variable_scope().reuse_variables()

        h0 = lrelu(conv2d(image, df_dim, name='d_h0_conv'))
        h1 = lrelu(d_bn1(conv2d(h0, df_dim*2, name='d_h1_conv')))
        h2 = lrelu(d_bn2(conv2d(h1, df_dim*4, name='d_h2_conv')))
        h3 = lrelu(d_bn3(conv2d(h2, df_dim*8, name='d_h3_conv')))
        # The output layer is named 'd_h3_lin' (not 'd_h4_lin'); it is kept
        # as-is because renaming it would change the variable names.
        h4 = linear(tf.reshape(h3, [batch_size, -1]), 1, 'd_h3_lin')

    return tf.nn.sigmoid(h4), h4

In [None]:
def generator(z):
    """Map a batch of latent vectors to a batch of 64x64x3 images.

    The steps are:
    - A ``linear`` layer projects ``z`` to ``gf_dim * 8 * 4 * 4`` features,
      which are reshaped to a 4x4 map with ``gf_dim * 8`` channels.
    - Four ``deconv2d`` layers double the side length each time
      (4 -> 8 -> 16 -> 32 -> 64).
    - Every hidden layer is batch-normalised and ReLU-activated; the output
      layer uses tanh.

    Args:
        z: latent tensor, presumably ``[batch_size, z_dim]``. The
            deconvolution output shapes below hard-code ``batch_size``.

    Returns:
        tanh of the ``g_h4`` output, declared as ``[batch_size, 64, 64, 3]``.
    """
    projected = linear(z, gf_dim * 8 * 4 * 4, 'g_h0_lin')
    net = tf.nn.relu(g_bn0(tf.reshape(projected, [-1, 4, 4, gf_dim * 8])))

    # (batch-norm layer, output side length, output channels, layer name)
    hidden_layers = (
        (g_bn1, 8, gf_dim * 4, 'g_h1'),
        (g_bn2, 16, gf_dim * 2, 'g_h2'),
        (g_bn3, 32, gf_dim * 1, 'g_h3'),
    )
    for bn, side, channels, layer_name in hidden_layers:
        upsampled = deconv2d(net, [batch_size, side, side, channels], name=layer_name)
        net = tf.nn.relu(bn(upsampled))

    final = deconv2d(net, [batch_size, 64, 64, 3], name='g_h4')
    return tf.nn.tanh(final)

In [None]:
# Graph inputs. `images` holds one batch of real images. `z` holds latent
# vectors with a variable batch dimension, although generator() itself
# hard-codes batch_size.
# NOTE(review): in the visible code the `sample_images` placeholder is never
# fed, and the training cell rebinds this name to a NumPy array.
images = tf.placeholder(tf.float32, [batch_size] + image_shape, name='real_images')
sample_images= tf.placeholder(tf.float32, [sample_size] + image_shape, name='sample_images')
z = tf.placeholder(tf.float32, [None, z_dim], name='z')

G = generator(z)
# The first call creates the discriminator's variables. The second call
# (reuse=True) scores the generated batch with those same variables.
D, D_logits = discriminator(images)
D_, D_logits_ = discriminator(G, reuse=True)

# Discriminator loss: real images are labelled 1, generated images 0.
# NOTE(review): positional (logits, targets) arguments match the pre-1.0
# TensorFlow signature. TF >= 1.0 rejects positional arguments here and
# requires labels=..., logits=... -- confirm which TF version is targeted.
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logits, tf.ones_like(D)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logits_, tf.zeros_like(D_)))
d_loss = d_loss_real + d_loss_fake

# Generator loss: cross-entropy of the generated images' logits against
# label 1, i.e. the generator is rewarded when D classifies its output as real.
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logits_, tf.ones_like(D_)))

In [None]:
# Optimizers: one Adam optimizer per network, each restricted to that
# network's own trainable variables.
t_vars = tf.trainable_variables()

# Every discriminator / generator layer in this notebook is named with a
# 'd_' / 'g_' prefix. Match on that prefix rather than anywhere in the name:
# a substring test would also match a 'd_' or 'g_' occurring later in a name
# (e.g. a layer called 'd_img_conv' contains 'g_') and hand that variable to
# the wrong optimizer.
d_vars = [var for var in t_vars if var.name.startswith('d_')]
g_vars = [var for var in t_vars if var.name.startswith('g_')]

d_optim = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)
g_optim = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)

In [None]:
# Open a session on the default graph and initialise every variable created
# so far (network weights and the Adam slot variables built above).
# tf.initialize_all_variables is the pre-0.12 name of
# tf.global_variables_initializer.
init_op = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init_op)

# The Saver covers the variables that exist at this point.
saver = tf.train.Saver()

In [None]:
def _load_images(files, count):
    """Load *files* with get_image and stack them into one float32 array.

    Args:
        files: image file paths.
        count: number of images expected; becomes the leading dimension.

    Returns:
        A ``np.float32`` array reshaped to ``[count] + image_shape``.
    """
    loaded = [get_image(path, image_size, is_crop=is_crop) for path in files]
    return np.reshape(np.array(loaded).astype(np.float32), [count] + image_shape)


# Every .jpg under data/<dataset>/. Listed once; a copy is reshuffled per epoch.
all_files = glob(os.path.join('data', dataset, '*.jpg'))
min_files = max(batch_size, sample_size)
if len(all_files) < min_files:
    raise ValueError('Found %d images in %s; need at least %d.'
                     % (len(all_files), os.path.join('data', dataset), min_files))

# d_optim / g_optim were already built in the optimizer cell. Building them
# again here would add a second, redundant set of Adam slot variables that
# the Saver (created earlier) does not cover. Re-initialising keeps each run
# of this cell training from scratch.
sess.run(tf.initialize_all_variables())

# Fixed latent vectors and real images used for the periodic sample dumps.
sample_z = np.random.uniform(-1, 1, size=(sample_size, z_dim))
sample_files = all_files[:sample_size]
# A distinct name, so the `sample_images` placeholder is not overwritten.
# NOTE: these are fed through the `images` placeholder below, which only
# works while sample_size == batch_size.
sample_inputs = _load_images(sample_files, sample_size)

os.makedirs('samples', exist_ok=True)  # save_images writes into ./samples/

counter = 1
start_time = time.time()

for epoch in range(10):
    data = list(all_files)
    np.random.shuffle(data)
    # Integer division: range() needs an int under Python 3, and a trailing
    # partial batch could not be fed to the fixed-size placeholder anyway.
    batch_idxs = len(data) // batch_size

    for idx in range(batch_idxs):
        batch_files = data[idx * batch_size:(idx + 1) * batch_size]
        batch_images = _load_images(batch_files, batch_size)

        batch_z = np.random.uniform(-1, 1, [batch_size, z_dim]).astype(np.float32)

        # Update D network
        sess.run([d_optim], feed_dict={images: batch_images, z: batch_z})

        # Update G network
        sess.run([g_optim], feed_dict={z: batch_z})

        # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
        sess.run([g_optim], feed_dict={z: batch_z})

        errD_fake = d_loss_fake.eval({z: batch_z}, session=sess)
        errD_real = d_loss_real.eval({images: batch_images}, session=sess)
        errG = g_loss.eval({z: batch_z}, session=sess)

        counter += 1
        print('Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f'
              % (epoch, idx, batch_idxs, time.time() - start_time, errD_fake + errD_real, errG))

        if np.mod(counter, 100) == 1:
            samples, dl, gl = sess.run([G, d_loss, g_loss], feed_dict={z: sample_z, images: sample_inputs})
            save_images(samples, [8, 8], './samples/train_%s_%s.png' % (epoch, idx))
            print('[Sample] d_loss: %.8f, g_loss: %.8f' % (dl, gl))

In [None]:
# Close the session and release its resources once training is finished.
sess.close()