In [1]:
import tensorflow as tf
import tensorflow.contrib.gan as tfgan
import tensorflow.layers as layers
from tensorflow.contrib.gan.python.namedtuples import GANTrainSteps
# tf.enable_eager_execution()

In [51]:
def get_generator_fn(batch_size):
    def generator_fn(input_image, mode):
        with tf.name_scope('generator'):
            #4*4
            dense_1 = layers.dense(inputs=input_image, units=batch_size*16)
            print(dense_1)
            batch_norm_1 = layers.batch_normalization(inputs=dense_1)
            print(batch_norm_1)
            reshape_1 = tf.reshape(batch_norm_1, shape=(batch_size, 4, 4, batch_size))
            print(reshape_1)
            relu_1 = tf.nn.relu(reshape_1)
            # 8*8
            conv_T_1 = layers.conv2d_transpose(inputs=relu_1, filters=64, kernel_size=(2, 2), strides=(2, 2), padding='same')
            batch_norm_2 = layers.batch_normalization(inputs=conv_T_1)
            relu_2 = tf.nn.relu(batch_norm_2)
            # 16*16
            conv_T_2 = layers.conv2d_transpose(inputs=relu_2, filters=32, kernel_size=(2, 2), strides=(2, 2), padding='same')
            batch_norm_3 = layers.batch_normalization(inputs=conv_T_2)
            relu_3 = tf.nn.relu(batch_norm_3)
            # 32*32
            conv_T_3 = layers.conv2d_transpose(inputs=relu_3, filters=16, kernel_size=(2, 2), strides=(2, 2), padding='same')
            batch_norm_4 = layers.batch_normalization(inputs=conv_T_3)
            relu_4 = tf.nn.relu(batch_norm_4)
            # 64*64
            conv_T_4 = layers.conv2d_transpose(
                inputs=relu_4, filters=3, kernel_size=(2, 2), strides=(2, 2), padding='same')
            tanh_1 = tf.nn.tanh(conv_T_4)
            print(tanh_1)
            return tanh_1
    return generator_fn


def discriminator_fn(image, noise):
    with tf.name_scope('Discriminator'):
        # 64 -> 32
        conv_1 = layers.conv2d(image, 64 , (2, 2), padding='same')
        lrelu_1 = tf.nn.leaky_relu(conv_1, alpha=0.2)
        # 32 -> 16
        conv_2 = layers.conv2d(lrelu_1, 64, (2, 2), padding='same')
        batch_norm_1 = layers.batch_normalization(inputs=conv_2)
        lrelu_2 = tf.nn.leaky_relu(batch_norm_1, alpha=0.2)
        # 16 -> 8
        conv_3 = layers.conv2d(lrelu_2, 64, (2, 2), padding='same')
        batch_norm_2 = layers.batch_normalization(inputs=conv_3)
        lrelu_3 = tf.nn.leaky_relu(batch_norm_2, alpha=0.2)
        # 8 -> 4
        conv_4 = layers.conv2d(lrelu_3, 64, (2,2), padding='same')
        batch_norm_3 = layers.batch_normalization(inputs=conv_4)
        lrelu_4 = tf.nn.leaky_relu(batch_norm_3, alpha=0.2)
        fc1 = layers.flatten(lrelu_4)
        fc2 = layers.dense(fc1, 1)
        return fc2

In [52]:
def get_run_config(check_point_dir, summary_steps, checkpoints_steps):
    run_config = tf.estimator.RunConfig(
        model_dir=check_point_dir,
        save_summary_steps=summary_steps,
        save_checkpoints_steps=checkpoints_steps)
    return run_config


In [53]:
check_point_dir = 'gs://gan-pipeline/checkpoints/fAnogan1'
batch_size = 32
image_path = 'gs://gan-pipeline/dataset/train/*'
eval_image_path = 'gs://gan-pipeline/dataset/test/*'
summary_steps = 100
checkpoints_steps = 100
max_training_steps = 10000
eval_steps = 1000

In [54]:
gan_estimator = tfgan.estimator.GANEstimator(
        model_dir=check_point_dir,
        generator_fn=get_generator_fn(batch_size),
        discriminator_fn=discriminator_fn,
        generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
        discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
        generator_optimizer=tf.train.AdamOptimizer(0.00003, 0.5),
        discriminator_optimizer=tf.train.AdamOptimizer(0.00001, 0.5),
        warm_start_from= check_point_dir,
        get_hooks_fn= tf.contrib.gan.get_sequential_train_hooks(train_steps=GANTrainSteps(2, 1)),
        config=get_run_config(check_point_dir, summary_steps, checkpoints_steps))

INFO:tensorflow:Using config: {'_model_dir': 'gs://gan-pipeline/checkpoints/fAnogan1', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 100, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f75801f3128>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


In [55]:
def _get_predict_input_fn(batch_size = 64, noise_dims = 100):
    def predict_input_fn():
        noise = tf.random_normal([batch_size, noise_dims])
        return noise
    return predict_input_fn

In [58]:
images = gan_estimator.predict(input_fn =_get_predict_input_fn(64,100))

In [59]:
print(list(images)[0])

INFO:tensorflow:Calling model_fn.
Tensor("Generator/generator/dense/BiasAdd:0", shape=(16, 512), dtype=float32)
Tensor("Generator/generator/batch_normalization/batchnorm/add_1:0", shape=(16, 512), dtype=float32)


ValueError: Cannot reshape a tensor with 8192 elements to shape [32,4,4,32] (16384 elements) for 'Generator/generator/Reshape' (op: 'Reshape') with input shapes: [16,512], [4] and with input tensors computed as partial shapes: input[1] = [32,4,4,32].