In [1]:
# Setup: tunable hyperparameters and filesystem locations used by the cells below.

# --- training hyperparameters ---
batch_size = 64            # images per discriminator training step
training_set_size = 50000  # number of images loaded into memory
img_size = 64              # images are resized/cropped to img_size x img_size
learning_rate = 0.0002     # passed to the Adam optimizer
beta1 = 0.5                # passed to the Adam optimizer

# --- filesystem locations ---
# local alternative for the image directory:
# img_directory = '/Users/rwilliams/Desktop/celeba/training'
img_directory = '/home/ec2-user/training-data/img_align_celeba'
model_save_path = '/home/ec2-user/tf-checkpoints/vaegan-celeba/checkpoint.ckpt'
outputs_directory = '/home/ec2-user/outputs/vaegan-celeba'
log_directory = '/home/ec2-user/tf-logs/vaegan-celeba'

In [2]:
import numpy as np
import scipy as sp
import os
from utils import imshow, resize_crop, load_img, pixels01, pixels11

In [3]:
# Jupyter imports

# import matplotlib.pyplot as plt
# %matplotlib inline

In [4]:
# load training data

# cache results of resizing and cropping on disk, so the preprocessing
# pass only has to run once per distinct set of arguments
from joblib import Memory
cachedir = '/home/ec2-user/joblib-cache'
try:
    # `location` is the current keyword; `cachedir` has been deprecated
    # since joblib 0.12 and is rejected by newer releases
    memory = Memory(location=cachedir, verbose=0)
except TypeError:
    # older joblib releases only understand `cachedir`
    memory = Memory(cachedir=cachedir, verbose=0)

@memory.cache
def load_all_imgs(howmany, img_directory, size=img_size):
    """Load, resize/crop and stack the first `howmany` training images.

    Args:
        howmany: number of images to load; indices 1..howmany are passed
            to `load_img`.
        img_directory: directory handed to `load_img` for each image.
        size: target edge length; each image is passed to `resize_crop`
            with the shape `(size, size)`. It is an explicit argument (rather
            than a global read inside the body) so that it becomes part of
            the joblib cache key: changing the image size then recomputes
            instead of silently returning stale cached arrays.

    Returns:
        A float32 numpy array stacking all processed images along axis 0.
    """
    target_shape = (size, size)
    imgs = [resize_crop(load_img(idx + 1, img_directory), target_shape)
            for idx in range(howmany)]
    training = np.array(imgs)
    # rescale each pixel to [-1, 1]. Supposed to help with GANs
    # (currently disabled; un-comment the next line to enable)
#     training = pixels11(training)
    return training.astype('float32')

training = load_all_imgs(training_set_size, img_directory, img_size)

# Build graph

In [5]:
# create models

import tensorflow as tf
from autoencoder import Autoencoder
from discriminator import Discriminator

# clear the default graph before building anything in this cell
tf.reset_default_graph()
# tf.set_random_seed(42.0)

# single source of truth for the per-image tensor shape (H, W, channels)
_img_shape = (img_size, img_size, 3)
_batch_shape = [None] + list(_img_shape)

# image feeds: X receives training images, R receives random-noise images
X = tf.placeholder(tf.float32, _batch_shape)
R = tf.placeholder(tf.float32, _batch_shape)

# boolean flag handed to both discriminator instances; per the original
# author it sets their batch normalization layers as trainable or not
disc_batch_trainable = tf.placeholder(tf.bool)

# discriminator instance wired to the random-noise feed R
disc_vae_obj = Discriminator(img_shape=_img_shape)
disc_vae_obj.disc(R, disc_batch_trainable)
disc_vae_logits = disc_vae_obj.logits

# discriminator instance wired to the real-image feed X;
# built with reuse=True so it shares weights with the instance above
disc_x_obj = Discriminator(img_shape=_img_shape)
disc_x_obj.disc(X, disc_batch_trainable, reuse=True)
disc_x_logits = disc_x_obj.logits

# Loss functions and optimizers

In [6]:
# set up loss functions and training_ops

# Discriminator targets: real images -> 1, fake (noise) images -> 0.

# fake branch: sigmoid cross entropy against all-zero labels
_fake_labels = tf.zeros_like(disc_vae_logits)
disc_vae_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=_fake_labels, logits=disc_vae_logits))

# real branch: sigmoid cross entropy against all-one labels
# (plain hard labels -- no label smoothing is applied here)
_real_labels = tf.ones_like(disc_x_logits)
disc_x_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=_real_labels, logits=disc_x_logits))

# total objective handed to the optimizer
disc_loss = disc_vae_loss + disc_x_loss

# only variables / update ops whose name contains 'discriminator' take part
disc_vars = [var for var in tf.trainable_variables() if 'discriminator' in var.name]
disc_update_ops = [op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if 'discriminator' in op.name]
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)

# make every training step run the collected update ops first
with tf.control_dependencies(disc_update_ops):
    train_disc = optimizer.minimize(disc_loss, var_list=disc_vars)

saver = tf.train.Saver()

# Init session

In [7]:
# create or restore session

sess = tf.InteractiveSession()
try:
    print('trying to restore session')
    saver.restore(sess, model_save_path)
    print('restored session')
except Exception as err:
    # `Exception` (not a bare `except:`) so that Ctrl-C / SystemExit still
    # propagate; a missing or incompatible checkpoint falls through to a
    # fresh initialization, and the reason is shown instead of being hidden.
    print('failed to restore session, creating a new one')
    # only the first line: TensorFlow error messages can span many lines
    print('  reason: %s: %s' % (type(err).__name__, str(err).split('\n', 1)[0]))
    tf.global_variables_initializer().run()

# write logs for tensorboard
writer = tf.summary.FileWriter(log_directory, sess.graph)

trying to restore session
INFO:tensorflow:Restoring parameters from /home/ec2-user/tf-checkpoints/vaegan-celeba/checkpoint.ckpt
failed to restore session, creating a new one


In [8]:
# collect data for tensorboard

# batch-averaged sigmoid of the logits, one value per discriminator branch
disc_vae_out = tf.reduce_mean(tf.sigmoid(disc_vae_logits))
disc_x_out = tf.reduce_mean(tf.sigmoid(disc_x_logits))

# (tag, tensor) pairs, registered as scalar summaries in this order
_scalar_summaries = [
    ('discriminator_loss', disc_loss),
    ('disc_vae_out', disc_vae_out),
    ('disc_x_out', disc_x_out),
    ('disc_x_loss', disc_x_loss),
    ('disc_vae_loss', disc_vae_loss),
]
for _tag, _tensor in _scalar_summaries:
    tf.summary.scalar(_tag, _tensor)

# single op that evaluates every summary registered so far
merged_summary = tf.summary.merge_all()

In [9]:
# NOTE(review): img_idx is not referenced by any other cell visible here --
# presumably a leftover counter or used by a cell outside this view; confirm
# before removing.
img_idx = 0

# Train

In [None]:
# write data to tensorboard log
def report(step=None, noise_loc=0, noise_scale=2.0):
    """Evaluate the merged summaries on one fixed batch and log them.

    The real-image feed is always the first `batch_size` training images;
    the noise feed is freshly drawn from a normal distribution.

    Args:
        step: global step recorded with the summary. When omitted, the
            training cell's loop counter `epoch` is used, so this function
            can only be called without arguments once that loop has started.
        noise_loc: mean of the noise fed to R.
        noise_scale: standard deviation of the noise fed to R.

    NOTE(review): the defaults (loc=0, scale=2.0) differ from the noise used
    in the training loop (loc=0.5, scale=1.0). They are kept as-is so existing
    logs stay comparable; pass noise_loc=0.5, noise_scale=1.0 to report on
    the training distribution -- confirm which one is intended.
    """
    if step is None:
        # fall back to the loop counter of the training cell
        step = epoch
    xfeed = training[:batch_size]
    rfeed = np.random.normal(size=(batch_size, img_size, img_size, 3),
                             loc=noise_loc, scale=noise_scale)
    summary = merged_summary.eval(feed_dict={
        X: xfeed,
        R: rfeed,
        # evaluation only: pass the flag as False
        disc_batch_trainable: False
    })
    writer.add_summary(summary, step)

In [None]:
import math
epochs = 10000
# checkpoint interval, in epochs (1 = save after every epoch)
save_every = 1
# whole batches only: a trailing partial batch is dropped
batches = training_set_size // batch_size
print('training over %s batches' % batches)

# shape of one batch of noise images fed to R
noise_shape = (batch_size, img_size, img_size, 3)

for epoch in range(epochs):
    print ('epoch %s ' % epoch, end='')

    # train discriminator
    for batch in range(batches):
        start = batch * batch_size
        xfeed = training[start:start + batch_size]
        # Draw the noise per batch instead of materializing a whole epoch's
        # worth up front: training_set_size x img_size x img_size x 3 float64
        # values is several GB for the configured sizes, one batch is a few MB.
        # The distribution (loc=0.5, scale=1.0) is unchanged.
        rfeed = np.random.normal(size=noise_shape, loc=0.5, scale=1.0).astype('float32')
        sess.run(train_disc, feed_dict={
            X: xfeed,
            R: rfeed,
            disc_batch_trainable: True
        })
        print('.', end='')
    # log summaries for this epoch to tensorboard
    report()

    if epoch % save_every == 0:
        print('saving session', flush=True)
        saver.save(sess, model_save_path)

training over 781 batches
epoch 0 .............................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................saving session
epoch 1 ..................................................................................................................................................................

epoch 10 .............................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................saving session
epoch 11 ..........................................................................................................................................................................................

epoch 20 .............................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................saving session
epoch 21 ..........................................................................................................................................................................................

epoch 30 .............................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................saving session
epoch 31 ..........................................................................................................................................................................................

epoch 40 .............................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................saving session
epoch 41 ..........................................................................................................................................................................................

epoch 50 .............................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................saving session
epoch 51 ..........................................................................................................................................................................................

epoch 60 .............................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................saving session
epoch 61 ..........................................................................................................................................................................................