In [2]:
from unittest import mock
import os
import numpy as np
import tensorflow as tf
import json
from tensorflow.python.estimator import estimator
from tensorflow.python.training.basic_session_run_hooks import CheckpointSaverHook, meta_graph
from tensorflow.python.platform import tf_logging as logging

In [None]:
# Name of this training run; the setup cell formats it into the GCS model
# directory and the Google Drive output directory.
RUN_NAME = 'DCGAN_0'

Setting things up.

In [None]:
from google.colab import drive, auth

# Mount Google Drive so sample images can be written under /content/gdrive.
drive.mount('/content/gdrive')
# set paths
# `%pwd` is an IPython magic, so this cell only runs inside a notebook kernel.
ROOT = %pwd
# Checkpoints and summaries go to this GCS location (one sub-directory per run).
MODEL_DIR = 'gs://tputestingmnist/{}/'.format(RUN_NAME)
LOG_DIR = MODEL_DIR
# Directory on the mounted drive where save_imgs writes its PNG files.
GOOGLE_DRIVE_DIR = '/content/gdrive/My Drive/Programowanie/PixelGen/{}'.format(RUN_NAME)
# gRPC address of the Colab TPU; raises KeyError when no TPU runtime is attached.
TF_MASTER = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])

auth.authenticate_user()
  
# Upload credentials to TPU.
# NOTE(review): /content/adc.json is presumably written by authenticate_user() -- confirm.
with tf.Session(TF_MASTER) as sess:    
    with open('/content/adc.json', 'r') as f:
        auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(sess, credentials=auth_info)

# Configuration
# Number of image channels kept by the parser and produced by the generator.
CHANNELS = 4
# Rows and columns of the sample-image grid drawn by save_imgs.
R, C = 4, 4
# Number of images in one grid; also used as the predict batch size.
EXAMPLES = R * C
# Only referenced by a commented-out line in make_input_fn.
LATENT_DIM = 128
# Global train/eval batch size passed to the TPUEstimator.
BATCH_SIZE = 1024
# Despite the names, both values are compared against the global step in the
# training loop: EPOCHS is the total number of steps, EVAL_EPOCHS the number
# of steps between evaluations/checkpoints.
EPOCHS = 150000
EVAL_EPOCHS = 500

PREFIX = RUN_NAME

Feeding data to the network

In [3]:
# When True, make_input_fn adds Gaussian noise (stddev 0.1) to the input images.
ADD_NOISE_TO_EXAMPLE = True
# TFRecord file read by the input functions defined in this cell.
data_file = 'gs://tputestingmnist/characters_conditional_7.tfrecords'


def _decode_image(raw_bytes):
    """Decodes a raw uint8 byte string into a normalized float32 image tensor.

    The bytes are interpreted as a 48x48 image with 4 channels, cut down to the
    first ``CHANNELS`` channels, and rescaled from [0, 255] to [-1.0, 1.0].
    """
    image = tf.decode_raw(raw_bytes, tf.uint8)
    # decode_raw yields an unknown length; pin it so the reshape is static.
    image.set_shape([48 * 48 * 4])
    image = tf.reshape(image, [48, 48, 4])[:, :, :CHANNELS]
    # Normalize the values of the image from [0, 255] to [-1.0, 1.0]
    return tf.cast(image, dtype=tf.float32) / 127.5 - 1


def parser(serialized_example):
    """Parses a single serialized Example into an (input, result) image pair.

    Args:
        serialized_example: Scalar string tensor holding one tf.Example with the
            features 'image_raw', 'image_transformed' and 'label'.

    Returns:
        A tuple ``(input_image, result_image)`` where ``input_image`` comes from
        'image_transformed' and ``result_image`` from 'image_raw'. Both are
        float32 tensors of shape [48, 48, CHANNELS] with values in [-1.0, 1.0].
    """
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'image_transformed': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)   # label is unused
        })

    # Both images share the exact same decoding pipeline.
    result_image = _decode_image(features['image_raw'])
    input_image = _decode_image(features['image_transformed'])

    return input_image, result_image


def make_input_fn(is_training=True):
    """Builds an ``input_fn`` that reads image pairs from ``data_file``.

    Args:
        is_training: When True the dataset is repeated indefinitely; otherwise
            it is iterated once.

    Returns:
        A function taking the estimator ``params`` dict (it reads
        ``params['batch_size']``) and returning ``(features, None)``, where
        ``features`` has the keys 'image_input' and 'image_result'.
    """
    def input_fn(params):
        batch_size = params['batch_size']

        records = tf.data.TFRecordDataset(data_file)
        examples = records.map(parser).cache().shuffle(batch_size)
        if is_training:
            examples = examples.repeat()

        batches = examples.prefetch(batch_size).batch(batch_size, drop_remainder=True)
        input_images, result_images = batches.make_one_shot_iterator().get_next()

        if ADD_NOISE_TO_EXAMPLE:
            # Perturb only the network input; the target image stays clean.
            noise = tf.random_normal(shape=tf.shape(input_images), mean=0.0, stddev=0.1, dtype=tf.float32)
            input_images = input_images + noise

        return {'image_input': input_images, 'image_result': result_images}, None
    return input_fn


def predict_input_fn(params):
    """Input function for prediction: yields only the input images.

    Reads ``params['batch_size']``, iterates ``data_file`` once (no repeat) and
    returns ``({'image_input': batch}, None)``.
    """
    batch_size = params['batch_size']

    examples = tf.data.TFRecordDataset(data_file).map(parser).cache().shuffle(batch_size)
    batches = examples.prefetch(batch_size).batch(batch_size, drop_remainder=True)
    # The second element (the target images) is not needed for prediction.
    input_images, _unused_results = batches.make_one_shot_iterator().get_next()

    features = {'image_input': input_images}
    return features, None

def images_to_zero_one(images):
    """Maps image values from [-1, 1] to [0, 1], clipping anything outside.

    Args:
        images: Array-like of pixel values.

    Returns:
        A numpy array with ``images * 0.5 + 0.5`` clipped to [0, 1].
    """
    rescaled = np.array(images) * 0.5 + 0.5
    return np.clip(rescaled, a_min=0., a_max=1.)

Creating sample images

In [None]:
import matplotlib.pyplot as plt

def save_imgs(epoch, images):
    """Saves an R x C grid of images as ``<epoch>.png`` under GOOGLE_DRIVE_DIR.

    Args:
        epoch: Value formatted into the file name (the training loop passes the
            current global step).
        images: Sequence of at least ``R * C`` images. Values are rescaled with
            ``images_to_zero_one``, so they are expected in [-1, 1]; anything
            outside is clipped.

    Raises:
        IndexError: If fewer than ``R * C`` images are supplied.
    """
    # Rescale images to 0 - 1
    images = images_to_zero_one(images)
    # squeeze=False keeps `axs` two-dimensional even when R or C equals 1,
    # so the [i, j] indexing below always works.
    fig, axs = plt.subplots(R, C, squeeze=False)

    for i in range(R):
        for j in range(C):
            axs[i, j].imshow(images[C * i + j])
            axs[i, j].axis('off')

    # The per-run output directory may not exist yet for a new RUN_NAME.
    os.makedirs(GOOGLE_DRIVE_DIR, exist_ok=True)
    fig.savefig(os.path.join(GOOGLE_DRIVE_DIR, '{}.png'.format(epoch)))
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

Architecture

In [4]:
# Side length of the square kernel used by convolution_block,
# deconvolution_block and the generator's final convolution.
KERNEL_SIZE = 4


def _leaky_relu(x):
    """Leaky ReLU activation with a negative-side slope of 0.2."""
    negative_slope = 0.2
    return tf.nn.leaky_relu(x, alpha=negative_slope)


def _relu(x):
    """Plain ReLU activation, wrapped so it can be passed as a block's `activation`."""
    activated = tf.nn.relu(x)
    return activated


def _batch_norm(x, is_training, name):
    """Batch normalization layer (momentum 0.8, epsilon 1e-5) named `name`.

    `is_training` is forwarded as the layer's `training` flag.
    """
    bn_options = dict(momentum=0.8, epsilon=1e-5)
    return tf.layers.batch_normalization(x, training=is_training, name=name, **bn_options)


def _dense(x, neurons, name, activation=None):
    """Fully connected layer with `neurons` units.

    The kernel is initialized from a truncated normal with stddev 0.02.
    """
    kernel_init = tf.truncated_normal_initializer(stddev=0.02)
    return tf.layers.dense(x, neurons, activation=activation, kernel_initializer=kernel_init, name=name)


def _conv2d(x, filters, kernel_size, stride, name):
    """2-D convolution with a square kernel, square stride and 'same' padding.

    The kernel is initialized from a truncated normal with stddev 0.02.
    """
    kernel_init = tf.truncated_normal_initializer(stddev=0.02)
    square_kernel = [kernel_size, kernel_size]
    square_stride = [stride, stride]
    return tf.layers.conv2d(
        x, filters, square_kernel,
        strides=square_stride,
        padding='same',
        kernel_initializer=kernel_init,
        name=name)


def _deconv2d(x, filters, kernel_size, stride, name):
    """Transposed 2-D convolution with a square kernel, square stride and 'same' padding.

    The kernel is initialized from a truncated normal with stddev 0.02.
    """
    kernel_init = tf.truncated_normal_initializer(stddev=0.02)
    square_kernel = [kernel_size, kernel_size]
    square_stride = [stride, stride]
    return tf.layers.conv2d_transpose(
        x, filters, square_kernel,
        strides=square_stride,
        padding='same',
        kernel_initializer=kernel_init,
        name=name)


def _dropout(x, prob, name):
    """Applies dropout to `x`, keeping each element with probability `prob`."""
    dropped = tf.nn.dropout(x, keep_prob=prob, name=name)
    return dropped


def convolution_block(x, filters, resize_factor, is_training, index, activation=_leaky_relu, dropout=False, batch_norm=False):
    """Convolution followed by optional batch norm, optional dropout and an activation.

    Args:
        x: Input tensor.
        filters: Number of output filters.
        resize_factor: Stride of the convolution.
        is_training: Forwarded to the batch-norm layer only.
        index: Suffix for the layer names ('conv<index>', 'bnc<index>', 'drop<index>').
        activation: Callable applied last; leaky ReLU by default.
        dropout: When True, dropout with keep probability 0.5 is applied
            (regardless of `is_training`).
        batch_norm: When True, batch normalization is applied after the convolution.

    Returns:
        The activated output tensor.
    """
    out = _conv2d(x, filters=filters, kernel_size=KERNEL_SIZE, stride=resize_factor, name='conv{}'.format(index))
    out = _batch_norm(out, is_training, name='bnc{}'.format(index)) if batch_norm else out
    out = _dropout(out, prob=0.5, name='drop{}'.format(index)) if dropout else out
    return activation(out)


def deconvolution_block(x, filters, resize_factor, is_training, index, activation=_relu, dropout=False, batch_norm=False):
    """Transposed convolution followed by optional batch norm, optional dropout and an activation.

    Args:
        x: Input tensor.
        filters: Number of output filters.
        resize_factor: Stride of the transposed convolution.
        is_training: Forwarded to the batch-norm layer only.
        index: Suffix for the layer names ('deconv<index>', 'bnc<index>', 'drop<index>').
        activation: Callable applied last; ReLU by default.
        dropout: When True, dropout with keep probability 0.5 is applied
            (regardless of `is_training`).
        batch_norm: When True, batch normalization is applied after the deconvolution.

    Returns:
        The activated output tensor.
    """
    out = _deconv2d(x, filters=filters, kernel_size=KERNEL_SIZE, stride=resize_factor, name='deconv{}'.format(index))
    out = _batch_norm(out, is_training, name='bnc{}'.format(index)) if batch_norm else out
    out = _dropout(out, prob=0.5, name='drop{}'.format(index)) if dropout else out
    return activation(out)


class Dcgan:
    """Namespace holding the graph builders for the discriminator and the generator."""

    @staticmethod
    def discriminator(x, is_training=True, scope='Discriminator'):
        """Builds the discriminator and returns one logit per example.

        Four convolution blocks (conv1..conv4) are followed by a flatten and a
        single-unit dense layer. Variables live under `scope` and are reused
        across calls.
        """
        # (filters, stride) for conv1 .. conv4
        layer_specs = ((64, 2), (128, 2), (256, 2), (512, 1))

        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            for index, (filters, stride) in enumerate(layer_specs, start=1):
                x = convolution_block(x, filters, stride, is_training, index)

            flattened = tf.layers.Flatten()(x)
            return _dense(flattened, neurons=1, name='d_dense')

    @staticmethod
    def generator(image, is_training=True, scope='Generator'):
        """Builds the encoder-decoder generator.

        The input is encoded by four stride-2 convolution blocks
        (conv11..conv14), decoded by four stride-2 deconvolution blocks
        (deconv23..deconv26) and mapped to `CHANNELS` channels with a tanh
        output. A deeper variant (conv15/16, deconv21/22) is currently disabled.

        Returns:
            The input `image` concatenated with the generated image along axis 1.
        """
        encoder_filters = (64, 128, 256, 512)   # conv11 .. conv14
        decoder_filters = (512, 256, 128, 64)   # deconv23 .. deconv26

        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            # Encode
            x = image
            for index, filters in enumerate(encoder_filters, start=11):
                x = convolution_block(x, filters, 2, is_training, index)

            # Decode
            for index, filters in enumerate(decoder_filters, start=23):
                x = deconvolution_block(x, filters, 2, is_training, index)

            x = _conv2d(x, filters=CHANNELS, kernel_size=KERNEL_SIZE, stride=1, name='final_conv')
            generated = tf.tanh(x)

            return tf.concat((image, generated), axis=1)
        
# Shared instance used by model_fn; Dcgan only exposes static graph builders.
model = Dcgan()

Model function

In [None]:
def model_fn(features, labels, mode, params):
    """TPUEstimator model function for the image-to-image GAN.

    Args:
        features: Dict produced by the input functions of this notebook. It must
            contain 'image_input'; TRAIN and EVAL additionally need 'image_result'.
        labels: Unused (the input functions return None).
        mode: A ``tf.estimator.ModeKeys`` value.
        params: Estimator params; not read here.

    Returns:
        A ``tf.contrib.tpu.TPUEstimatorSpec`` for the requested mode.

    Raises:
        ValueError: If `mode` is not PREDICT, TRAIN or EVAL.
    """
    # The input functions emit 'image_input' / 'image_result'; the generator is
    # conditioned on the input image rather than on random noise.
    input_images = features['image_input']

    # PREDICT #
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {'generated_images': model.generator(input_images, is_training=False)}

        return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, predictions=predictions)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    # The generator returns its input stacked with its output along axis 1.
    # The real pair is stacked the same way so that both discriminator calls
    # (which share variables) receive tensors of the same shape.
    real_images = tf.concat((input_images, features['image_result']), axis=1)
    generated_images = model.generator(input_images, is_training=is_training)

    # Discriminator loss
    # The discriminator ends in a single-unit dense layer, so only axis 1 is
    # squeezed; the batch dimension is kept even for a batch of one.
    d_on_data_logits = tf.squeeze(model.discriminator(real_images, is_training=is_training), axis=1)
    d_on_data_labels = tf.ones_like(d_on_data_logits)

    d_on_g_logits = tf.squeeze(model.discriminator(generated_images, is_training=is_training), axis=1)
    d_on_g_labels = tf.zeros_like(d_on_g_logits)

    # Per-example losses (no reduction) so EVAL can feed them to tf.metrics.mean.
    d_loss = tf.contrib.gan.losses.wargs.modified_discriminator_loss(
        discriminator_real_outputs=d_on_data_logits,
        discriminator_gen_outputs=d_on_g_logits,
        reduction=tf.losses.Reduction.NONE,
        label_smoothing=0.2
    )

    # Generator loss
    g_loss = tf.contrib.gan.losses.wargs.modified_generator_loss(
        discriminator_gen_outputs=d_on_g_logits,
        reduction=tf.losses.Reduction.NONE
    )

    # TRAIN #
    if mode == tf.estimator.ModeKeys.TRAIN:
        d_loss = tf.reduce_mean(d_loss)
        g_loss = tf.reduce_mean(g_loss)
        d_optimizer = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5)
        g_optimizer = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5)

        # Aggregate gradients across TPU shards.
        d_optimizer = tf.contrib.tpu.CrossShardOptimizer(d_optimizer)
        g_optimizer = tf.contrib.tpu.CrossShardOptimizer(g_optimizer)

        # Each optimizer only updates the trainable variables of its own network.
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator')
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Generator')

        # Run pending update ops (e.g. batch-norm statistics) before the steps.
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            d_step = d_optimizer.minimize(d_loss, var_list=d_vars)
            g_step = g_optimizer.minimize(g_loss, var_list=g_vars)

            # Neither minimize() call is given the global step, so advance it
            # exactly once per joint step here.
            increment_step = tf.assign_add(tf.train.get_or_create_global_step(), 1)
            joint_op = tf.group([d_step, g_step, increment_step])

            return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, loss=g_loss, train_op=joint_op)

    # EVAL #
    if mode == tf.estimator.ModeKeys.EVAL:
        def _eval_metric_fn(d_loss, g_loss, d_real_labels, d_gen_labels, d_real_logits, d_gen_logits):
            """Computes mean losses and the discriminator accuracy on real/generated data."""
            return {
                'discriminator_loss': tf.metrics.mean(d_loss),
                'generator_loss': tf.metrics.mean(g_loss),
                'discriminator_real_accuracy': tf.metrics.accuracy(labels=d_real_labels, predictions=tf.math.round(tf.sigmoid(d_real_logits))),
                'discriminator_gen_accuracy': tf.metrics.accuracy(labels=d_gen_labels, predictions=tf.math.round(tf.sigmoid(d_gen_logits)))
            }

        return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, loss=tf.reduce_mean(g_loss),
                                               eval_metrics=(_eval_metric_fn, [d_loss, g_loss, d_on_data_labels,
                                                                               d_on_g_labels, d_on_data_logits, d_on_g_logits]))

    raise ValueError('Unsupported mode: {}'.format(mode))

Training!

In [None]:
# tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TF_MASTER)

# Run configuration shared by both estimators: checkpoints and summaries are
# written to MODEL_DIR every EVAL_EPOCHS steps.
# NOTE(review): iterations_per_loop (1000) is larger than the EVAL_EPOCHS (500)
# steps trained per call below -- confirm this is intended.
config = tf.contrib.tpu.RunConfig(
    master=TF_MASTER,
    save_checkpoints_steps=EVAL_EPOCHS,
    save_checkpoints_secs=None,
    save_summary_steps=EVAL_EPOCHS,
    model_dir=MODEL_DIR,
    keep_checkpoint_max=3,
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1000))

# TPU-based estimator used for TRAIN and EVAL
est = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    use_tpu=True,
    config=config,
    train_batch_size=BATCH_SIZE,
    eval_batch_size=BATCH_SIZE)

# CPU-based estimator used for PREDICT (generating images)
cpu_est = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    use_tpu=False,
    config=config,
    predict_batch_size=EXAMPLES)

# Resume from the latest checkpoint in MODEL_DIR, if any.
current_step = estimator._load_global_step_from_checkpoint_dir(MODEL_DIR)
tf.logging.info('Starting training')

# Alternate between EVAL_EPOCHS training steps, one evaluation batch and a
# rendered grid of generated samples until EPOCHS steps have been trained.
while current_step < EPOCHS:
    next_checkpoint = int(min(current_step + EVAL_EPOCHS, EPOCHS))
    est.train(input_fn=make_input_fn(), max_steps=next_checkpoint)
    current_step = next_checkpoint
    tf.logging.info('Finished training step %d' % current_step)

    # Evaluation
    metrics = est.evaluate(input_fn=make_input_fn(False), steps=1)
    tf.logging.info('Finished evaluating')
    tf.logging.info(metrics)

    # Render some generated images
    # predict_input_fn is the prediction input function defined in this
    # notebook (the previously referenced `noise_input_fn` does not exist).
    # NOTE(review): predict_input_fn does not limit the dataset, so every
    # record is predicted while save_imgs only draws the first R * C images.
    generated_iter = cpu_est.predict(input_fn=predict_input_fn)
    images = [p['generated_images'] for p in generated_iter]
    save_imgs(current_step, images)
    tf.logging.info('Finished generating images')