In [None]:
# Required modules.
import os
import time

import numpy as np
import tensorflow as tf

from utils import (
    generator,
    discriminator,
    iprint,
    get_dataset,
    load_dataset
)

In [None]:
# Loading dataset and setting parameters.
# Seed NumPy's global RNG so the latent noise drawn below and the per-epoch
# shuffles drawn in `train` are reproducible across Restart & Run All.
seed = 42
np.random.seed(seed)

num_epochs = 1001
lr = 0.001

# Print a progress line every `log_rate` epochs.
log_rate = 1

# Length of each latent noise vector fed to the generator.
num_features = 64
img_height, img_width, img_channels = (64, 64, 1)

# NOTE(review): this is an opaque shortened link, so its real target cannot be
# verified from here and may be repointed. Prefer the dataset's canonical URL
# and check the downloaded archive's integrity.
get_dataset('https://ipkill.org/1uFik', 'mnist.zip')
dataset = load_dataset('mnist', resize_img=(img_height, img_width))

# Pool train/valid/test: a GAN only needs unlabeled real images.
datay = np.concatenate((dataset['train'][0], dataset['valid'][0], dataset['test'][0]))
# Cast to float32 before scaling: the `y` placeholder is float32, and the
# float64 default would double the memory of the full image stack.
# Assumes load_dataset returns raw 0..255 pixel intensities — TODO confirm.
datay = datay.reshape((-1, img_height, img_width, img_channels)).astype(np.float32) / 255.

# One fixed standard-normal latent vector per real image, drawn once for the whole run.
datax = np.random.normal(size=(len(datay), num_features)).astype(np.float32)

iprint(datax.shape, datay.shape)

In [None]:
# Defining the graph.
graph = tf.Graph()
with graph.as_default():
    # Latent noise fed to the generator, one row per sample.
    x = tf.placeholder(tf.float32, (None, num_features))
    # Real images fed to the discriminator.
    y = tf.placeholder(tf.float32, (None, img_height, img_width, img_channels))

    # Forwarded to both networks; whether it is a drop or a keep probability is
    # decided inside utils.generator/discriminator — TODO confirm.
    dp_rate = tf.placeholder(tf.float32)

    # Selects training vs. inference behaviour of batch normalization.
    is_training = tf.placeholder(tf.bool)
    learning_rate = tf.placeholder(tf.float32)

    # NOTE(review): this layer is created outside any variable scope, so its
    # gamma/beta do not start with 'GEN' or 'DIS' and are never optimized below;
    # only its moving statistics change (through UPDATE_OPS). Confirm intended.
    norm = tf.layers.batch_normalization(x, training=is_training)

    gimgs = generator(norm, dp_rate, is_training)

    # Discriminator on real images, then called again with reuse=True on the
    # generated images (presumably sharing the same weights).
    dtlabels = discriminator(y, dp_rate, is_training)
    dflabels = discriminator(gimgs, dp_rate, is_training, reuse=True)

    # tf.losses.log_loss expects probabilities; assumes the discriminator ends
    # in a sigmoid — TODO confirm in utils. It already reduces to a scalar.
    tloss = tf.losses.log_loss(tf.ones_like(dtlabels), dtlabels)
    floss = tf.losses.log_loss(tf.zeros_like(dflabels), dflabels)

    # The generator is rewarded when the discriminator labels its images as real.
    gloss = tf.losses.log_loss(tf.ones_like(dflabels), dflabels)
    dloss = 0.5 * (tloss + floss)

    # Restrict each optimizer to its own network's *trainable* variables.
    # Batch-norm moving statistics are global variables too, but they are
    # maintained by UPDATE_OPS, not by gradient descent.
    # Assumes utils builds the networks under 'GEN'/'DIS' scopes — TODO confirm.
    gvars = [v for v in tf.trainable_variables() if v.name.startswith('GEN')]
    dvars = [v for v in tf.trainable_variables() if v.name.startswith('DIS')]

    # Run every batch-norm moving-statistics update alongside either train step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        gtrain_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(gloss, var_list=gvars)
        dtrain_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(dloss, var_list=dvars)

In [None]:
# Setting the tensorboard summarizer.
# Event files are written under ./logs; passing the graph also records its
# structure so TensorBoard can display it.
tf_writer = tf.summary.FileWriter(logdir='logs', graph=graph)

In [None]:
# Defining the train flow.
def train(sess, ix, iy, lr, epoch, bs=32):
    """Run one epoch of alternating GAN training over shuffled minibatches.

    For each minibatch both losses are first evaluated with
    ``is_training=False``. A network is then updated only when its loss is not
    too far below its opponent's, which keeps either side from running away.

    Args:
        sess: session bound to ``graph``.
        ix: latent-noise array fed to ``x``; indexed along axis 0.
        iy: real-image array fed to ``y``; must have the same length as ``ix``.
        lr: learning rate fed to ``learning_rate``.
        epoch: epoch index, used as the TensorBoard step and for ``log_rate``.
        bs: minibatch size (the last batch may be smaller).

    Side effects:
        Writes the first minibatch's evaluation losses to ``tf_writer`` and
        prints a progress line every ``log_rate`` epochs.
    """
    assert len(ix) == len(iy)

    # Shuffle noise and images with the same permutation.
    order = np.random.permutation(len(ix))

    # Sample-weighted loss sums and sample counts, for per-network epoch averages.
    _gloss, _dloss = 0.0, 0.0
    gtimes, dtimes = 0, 0

    start = time.time()
    for i in range(0, len(ix), bs):
        idx = order[i:i + bs]
        bx = ix.take(idx, axis=0)
        by = iy.take(idx, axis=0)

        # NOTE(review): evaluation feeds the same dp_rate as training; this is
        # harmless only if utils disables dropout when is_training is False.
        eval_dict = {
            x: bx,
            y: by,
            dp_rate: 0.4,
            is_training: False
        }

        train_dict = {
            x: bx,
            y: by,
            dp_rate: 0.4,
            is_training: True,
            learning_rate: lr,
        }

        gls, dls = sess.run([gloss, dloss], feed_dict=eval_dict)

        # Balance the players: skip a network whose loss is already well below
        # its opponent's.
        train_g = gls * 1.5 >= dls
        train_d = dls * 2 >= gls

        if train_d:
            _, batch_dloss = sess.run([dtrain_op, dloss], feed_dict=train_dict)
            _dloss += batch_dloss * len(bx)
            dtimes += len(bx)

        if train_g:
            _, batch_gloss = sess.run([gtrain_op, gloss], feed_dict=train_dict)
            _gloss += batch_gloss * len(bx)
            gtimes += len(bx)

        if i == 0:
            # Log the first minibatch's evaluation losses for this epoch. The
            # Summary protobuf is built directly, so no summary ops are needed
            # in the graph.
            snapshot = tf.Summary(value=[
                tf.Summary.Value(tag='gloss', simple_value=float(gls)),
                tf.Summary.Value(tag='dloss', simple_value=float(dls)),
            ])
            tf_writer.add_summary(snapshot, epoch)

    # Average over the samples each network was actually trained on.
    if gtimes != 0:
        _gloss /= gtimes
    if dtimes != 0:
        _dloss /= dtimes

    if epoch % log_rate == 0:
        print(
            'epoch: %05d' % epoch,
            'gloss: %07.3f' % _gloss,
            'dloss: %07.3f' % _dloss,
            'time: %07.3f' % (time.time() - start)
        )

In [None]:
# Running the session, (restoring,) training and saving the model.
# Set to True to resume from `ckpt_path` instead of starting from fresh weights.
# A checkpoint written by a trainable-variables-only Saver lacks the batch-norm
# moving statistics and will not restore with the Saver below.
restore_checkpoint = False
ckpt_dir = 'models'
ckpt_path = os.path.join(ckpt_dir, 'gan.ckpt')

with tf.Session(graph=graph) as sess:
    # The default var_list covers every global variable, including the
    # batch-norm moving statistics used when is_training=False.
    # tf.trainable_variables() alone would leave those out of the checkpoint.
    tf_saver = tf.train.Saver(max_to_keep=None)

    if restore_checkpoint:
        # Restore *instead of* initializing: running the initializer after the
        # restore would overwrite the restored values.
        iprint('restoring %s...' % ckpt_path, end=' ')
        tf_saver.restore(sess, ckpt_path)
    else:
        iprint('inicializando variáveis...', end=' ')
        sess.run(tf.global_variables_initializer())
    print('Done')

    for epoch in range(num_epochs):
        train(sess, datax, datay, lr, epoch)

    # Saver.save does not create a missing parent directory.
    os.makedirs(ckpt_dir, exist_ok=True)
    tf_saver.save(sess, ckpt_path)

# Make sure pending TensorBoard events reach disk.
tf_writer.flush()