<a href="https://colab.research.google.com/github/Ihoold/pixelgen/blob/master/notebooks/dcgan_multi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import numpy as np
import tensorflow as tf
import json
import matplotlib.pyplot as plt
from tensorflow.python.estimator import estimator
from google.colab import drive, auth
import tensorflow_gan as tfgan




In [0]:
# Number of image channels kept from each decoded example and emitted by the generator.
CHANNELS = 4
# Rows and columns of the preview grid drawn by save_imgs.
R, C = 4, 3

# Negative slope used by the leaky ReLU activation.
ALPHA = 0.2
# Batch size given to the TPU estimator for both training and evaluation.
BATCH_SIZE = 1024
# Length of the latent noise vector fed to the generator.
LATENT_DIM = 128
# When True, parser adds Gaussian noise (stddev 0.1) to every real image.
ADD_NOISE_TO_EXAMPLE = True
# Despite the names, train() compares these against the global step:
# EPOCHS is the total number of training steps, EVAL_EPOCHS the number of
# steps between checkpoints / evaluations / image dumps.
EPOCHS = 30000
EVAL_EPOCHS = 300
# Adam learning rates for the generator and the discriminator.
G_LR = 0.0002
D_LR = 0.0001
# Side length of the square kernels used by the (de)convolution layers.
KERNEL_SIZE = 4

# MODEL_NAME / RUN_NAME select the checkpoint directory and the Drive output folder.
MODEL_NAME = 'DCGAN_MULTI'
RUN_NAME = 'DCGAN_MULTI_10'

# TFRecord file read by the input pipeline.
data_file = 'gs://tputestingmnist/datasets/dataset_multi.tfrecords'
# Estimator model directory (checkpoints and summaries).
MODEL_DIR = 'gs://tputestingmnist/{}/{}/'.format(MODEL_NAME, RUN_NAME)
# Folder on the mounted Google Drive where preview images are written.
GOOGLE_DRIVE_DIR = '/content/gdrive/My Drive/Programowanie/PixelGen/{}'.format(RUN_NAME)
# gRPC address of the TPU; raises KeyError when COLAB_TPU_ADDR is not set.
TF_MASTER = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])

In [0]:
#################################### SETUP #####################################

def setup():
    """Mount Google Drive at /content/gdrive and authenticate the Colab user."""
    drive.mount('/content/gdrive')
    auth.authenticate_user()


def upload_credentials():
    """Pass the credentials stored in /content/adc.json to the TPU session.

    Opens a session against TF_MASTER, loads the JSON file and hands its
    contents to tf.contrib.cloud.configure_gcs.
    """
    # Upload credentials to TPU.
    with tf.Session(TF_MASTER) as sess:    
        with open('/content/adc.json', 'r') as f:
            auth_info = json.load(f)
        tf.contrib.cloud.configure_gcs(sess, credentials=auth_info)

In [0]:
################################# DATA INPUT ###################################

def augment(image):
    """Randomly flip an image and jitter the colour of its first three channels.

    The image is flipped left/right at random, then hue, saturation,
    brightness and contrast of channels 0-2 are perturbed. Channel 3 is taken
    from the flipped image unchanged and re-attached as the last channel.
    """
    flipped = tf.image.random_flip_left_right(image)

    rgb = flipped[:, :, :3]
    rgb = tf.image.random_hue(rgb, max_delta=0.1)
    rgb = tf.image.random_saturation(rgb, lower=0.8, upper=1.2)
    rgb = tf.image.random_brightness(rgb, max_delta=0.05)
    rgb = tf.image.random_contrast(rgb, lower=0.8, upper=1.2)

    alpha = flipped[:, :, 3:4]
    return tf.concat([rgb, alpha], axis=2)

def parser(serialized_example):
    """Decode one serialized example into model features and its label.

    Returns:
        A tuple ``(features, label)`` where ``features`` is a dict with
        ``'images'`` (float32 tensor of shape [192, 144, CHANNELS], scaled to
        [-1, 1] and optionally perturbed with Gaussian noise) and ``'noise'``
        (uniform latent sample of shape [1, LATENT_DIM]).
    """
    feature_spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    # The raw bytes hold a 4-row by 3-column sheet of 48x48 tiles with 4 channels.
    pixels = tf.decode_raw(parsed['image'], tf.uint8)
    pixels.set_shape([48 * 4 * 48 * 3 * 4])
    sheet = tf.reshape(pixels, [48 * 4, 48 * 3, 4])
    sheet = sheet[:, :, :CHANNELS]
    # augment(sheet) is intentionally disabled here.

    # Map pixel values from [0, 255] to [-1.0, 1.0].
    sheet = tf.cast(sheet, dtype=tf.float32) / 127.5 - 1

    if ADD_NOISE_TO_EXAMPLE:
        jitter = tf.random_normal(shape=tf.shape(sheet), mean=0.0, stddev=0.1, dtype=tf.float32)
        sheet = sheet + jitter

    latent = tf.random_uniform([1, LATENT_DIM], -1, 1, dtype=tf.float32)
    return {'images': sheet, 'noise': latent}, parsed['label']

    
def make_input_fn(is_training=True):
    """Build the estimator ``input_fn`` that streams parsed examples from ``data_file``.

    Args:
        is_training: when True the dataset repeats indefinitely; otherwise a
            single pass is produced.

    Returns:
        A function ``input_fn(params)`` that reads ``params['batch_size']`` and
        returns a batched ``tf.data.Dataset`` of ``(features, label)`` pairs.
    """
    def input_fn(params):
        batch_size = params['batch_size']
        dataset = tf.data.TFRecordDataset(data_file, buffer_size=8*1024*1024)
        # Cache the raw serialized records, not the parsed ones: parser draws
        # random instance noise and random latent vectors, and caching after
        # the map would freeze those samples after the first pass.
        dataset = dataset.cache()
        # Parsing now runs on every pass, so do it in parallel.
        dataset = dataset.map(parser, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.shuffle(batch_size)
        if is_training:
            dataset = dataset.repeat()
        dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(8)
        return dataset
    return input_fn


def noise_input_fn(params):
    """Input function for prediction: a single batch of uniform latent noise.

    The noise is sampled with NumPy when the function is called and embedded in
    the graph as a constant of shape ``(params['batch_size'], LATENT_DIM)``.

    Returns:
        ``({'noise': tensor}, None)`` - features and no labels.
    """
    latent_shape = (params['batch_size'], LATENT_DIM)
    latent_batch = tf.constant(np.random.uniform(-1, 1, latent_shape), dtype=tf.float32)
    iterator = tf.data.Dataset.from_tensors(latent_batch).make_one_shot_iterator()
    return {'noise': iterator.get_next()}, None

In [0]:
################################ DATA SAVING ###################################
 
def images_to_zero_one(images):
    """Map values from the [-1, 1] range to [0, 1], clipping anything outside.

    Args:
        images: array-like of pixel values.

    Returns:
        A NumPy array with ``x * 0.5 + 0.5`` applied and clipped to [0, 1].
    """
    rescaled = np.array(images) * 0.5 + 0.5
    return np.clip(rescaled, 0., 1.)


def save_imgs(epoch, images):
    """Render generated images on an R x C grid and save it as ``<epoch>.png``.

    The figure is written to GOOGLE_DRIVE_DIR, which is created (including any
    missing parent folders) when it does not exist yet.

    Args:
        epoch: label used as the file name (without extension).
        images: sequence of images with values in [-1, 1]. Grid cells cycle
            through the sequence, so a single image fills every cell and
            R * C or more images give a grid of distinct samples.

    Raises:
        ValueError: if ``images`` is empty.
    """
    # os.mkdir fails when a parent folder is missing; makedirs creates the
    # whole path and tolerates an already existing directory.
    os.makedirs(GOOGLE_DRIVE_DIR, exist_ok=True)

    # Rescale images to 0 - 1
    images = images_to_zero_one(images)
    count = len(images)
    if count == 0:
        raise ValueError('save_imgs needs at least one image to render')

    # squeeze=False keeps axs two-dimensional even when R or C equals 1.
    fig, axs = plt.subplots(R, C, figsize=(20, 20), squeeze=False)
    try:
        for i in range(R):
            for j in range(C):
                axs[i, j].imshow(images[(i * C + j) % count])
                axs[i, j].axis('off')

        fig.savefig(os.path.join(GOOGLE_DRIVE_DIR, '{}.png'.format(epoch)))
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

In [0]:
################################## MODEL #######################################

def _relu(x):
    """Apply tf.nn.relu to x."""
    return tf.nn.relu(x)


def _leaky_relu(x):
    """Apply tf.nn.leaky_relu to x with the module-level slope ALPHA."""
    return tf.nn.leaky_relu(x, alpha=ALPHA)


def _get_batch_norm(is_training, name):
    """Return a callable applying a named batch-normalization layer.

    Args:
        is_training: forwarded as the ``training`` flag of the layer.
        name: layer name.

    Returns:
        A one-argument function mapping a tensor to its batch-normalized
        version (momentum 0.8, epsilon 1e-5).
    """
    bn_kwargs = dict(momentum=0.8, epsilon=1e-5, training=is_training, name=name)
    return lambda x: tf.layers.batch_normalization(x, **bn_kwargs)


def _dense(x, neurons, name, activation=None):
    """Fully connected layer with ``neurons`` units.

    Kernel weights are initialised from a truncated normal with stddev 0.02.
    """
    kernel_init = tf.truncated_normal_initializer(stddev=0.02)
    return tf.layers.dense(
        x,
        neurons,
        name=name,
        activation=activation,
        kernel_initializer=kernel_init,
    )


def _conv2d(x, filters, kernel_size, stride, name, activation=None):
    """2-D convolution with a square kernel, equal strides and 'same' padding.

    Kernel weights are initialised from a truncated normal with stddev 0.02.
    """
    kernel_init = tf.truncated_normal_initializer(stddev=0.02)
    return tf.layers.conv2d(
        x,
        filters,
        [kernel_size, kernel_size],
        strides=[stride, stride],
        activation=activation,
        padding='same',
        name=name,
        kernel_initializer=kernel_init,
    )

def _deconv2d(x, filters, stride, name, activation=None):
    """Transposed 2-D convolution with a KERNEL_SIZE square kernel and 'same' padding.

    Kernel weights are initialised from a truncated normal with stddev 0.02.
    """
    kernel_init = tf.truncated_normal_initializer(stddev=0.02)
    return tf.layers.conv2d_transpose(
        x,
        filters,
        [KERNEL_SIZE, KERNEL_SIZE],
        strides=[stride, stride],
        activation=activation,
        padding='same',
        name=name,
        kernel_initializer=kernel_init,
    )


def convolution_block(x, filters, kernel_size, resize_factor, index, activation=_leaky_relu, activation_first=False, normalization=None):
    """Strided convolution with optional normalization and activation.

    Args:
        x: input tensor.
        filters: number of output channels.
        kernel_size: side length of the square kernel.
        resize_factor: stride of the convolution.
        index: suffix of the layer name (``conv_<index>``).
        activation: callable applied to the tensor, or a falsy value for none.
        activation_first: apply the activation before the convolution instead
            of after the normalization.
        normalization: callable applied right after the convolution, or a
            falsy value for none.

    Returns:
        The transformed tensor.
    """
    if activation:
        if activation_first:
            x = activation(x)

    layer_name = 'conv_{}'.format(index)
    x = _conv2d(x, kernel_size=kernel_size, filters=filters, stride=resize_factor,
                activation=None, name=layer_name)

    if normalization:
        x = normalization(x)

    if activation:
        if not activation_first:
            x = activation(x)
    return x


def deconvolution_block(x, filters, resize_factor, index, activation=_relu, normalization=None):
    """Transposed convolution followed by optional normalization and activation.

    Args:
        x: input tensor.
        filters: number of output channels.
        resize_factor: stride of the transposed convolution.
        index: suffix of the layer name (``deconv_<index>``).
        activation: callable applied last, or a falsy value for none.
        normalization: callable applied right after the transposed
            convolution, or a falsy value for none.

    Returns:
        The transformed tensor.
    """
    layer_name = 'deconv_{}'.format(index)
    out = _deconv2d(x, filters=filters, stride=resize_factor, activation=None, name=layer_name)
    for step in (normalization, activation):
        if step:
            out = step(out)
    return out


class DCGAN:
    """Generator / discriminator pair of the DCGAN, exposed as static builders."""

    @staticmethod
    def discriminator(x, is_training=False, scope='Discriminator'):
        """Build (or reuse) the discriminator on a batch of images.

        Four stride-2 convolution blocks with batch normalization are followed
        by a stride-3, kernel-3 convolution with a single output channel.

        Returns:
            ``(output, features)``: the output of the final convolution and
            the activations of the last hidden block.
        """
        # (filters, batch-norm layer name, convolution index) for each hidden block.
        hidden_blocks = (
            (64, 'bn10', 'disc_10'),
            (128, 'bn11', 'disc_12'),
            (256, 'bn12', 'disc_13'),
            (512, 'bn13', 'disc_14'),
        )
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            features = x
            for filters, bn_name, conv_index in hidden_blocks:
                features = convolution_block(
                    features, filters, kernel_size=KERNEL_SIZE, resize_factor=2,
                    normalization=_get_batch_norm(is_training, bn_name), index=conv_index)

            output = convolution_block(features, 1, kernel_size=3, resize_factor=3,
                                       activation=None, index='disc_out')
            return output, features

    @staticmethod
    def generator(latent_code, is_training=False, scope='Generator'):
        """Build (or reuse) the generator mapping latent codes to images.

        A dense layer is reshaped to 3x3x1024, upsampled by three stride-2
        transposed-convolution blocks with batch normalization, and finished by
        a stride-2 transposed convolution with CHANNELS outputs and tanh.
        """
        # (filters, deconvolution index, batch-norm layer name) for each hidden block.
        hidden_blocks = (
            (512, 'gen_11', 'bn21'),
            (256, 'gen_12', 'bn22'),
            (128, 'gen_13', 'bn23'),
        )
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            net = _dense(latent_code, 1024 * 3 * 3, activation=tf.nn.relu, name='g_dense')
            net = tf.reshape(net, [-1, 3, 3, 1024])

            for filters, deconv_index, bn_name in hidden_blocks:
                net = deconvolution_block(
                    net, filters, 2, index=deconv_index,
                    normalization=_get_batch_norm(is_training, bn_name))

            return deconvolution_block(net, CHANNELS, 2, index='gen_14', activation=tf.tanh)

In [0]:
# # It's not exactly the norm, but taking mean instead of sum makes the losses more comparable
# def l2_norm(x):
#     return tf.reduce_mean(tf.square(x))

In [0]:
################################ MODEL FUN #####################################
def make_model_fn(model):
    """Create the TPUEstimator ``model_fn`` for a model exposing ``generator`` and ``discriminator``.

    Args:
        model: object whose ``generator(noise, is_training)`` returns images and
            whose ``discriminator(images, is_training)`` returns a
            ``(output, features)`` pair.

    Returns:
        A ``model_fn(features, labels, mode, params)`` callable.
    """

    def model_fn(features, labels, mode, params):
        """Build the graph for PREDICT, TRAIN or EVAL and return its TPUEstimatorSpec.

        ``features`` must contain ``'noise'``; TRAIN and EVAL additionally read
        ``'images'``. ``labels`` and ``params`` are not used here. Modes other
        than PREDICT/TRAIN/EVAL fall through and return None.
        """
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        with tf.variable_scope('inputs'):
            noise = features['noise']
            
        generated_images = model.generator(noise, is_training)
        
        # PREDICT #
        if mode == tf.estimator.ModeKeys.PREDICT:
            predictions = {'generated_images': generated_images}
            return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, predictions=predictions)
        
        with tf.variable_scope('inputs'):
            images = features['images']
            # Only one 48x48 tile (rows 0-47, columns 48-95) of each image is used as the real sample.
            images = images[:, 0:48, 48:96, :] #tf.concat([images[:, r*48:(r+1)*48, c*48:(c+1)*48, :] for r in range(R) for c in range(C)], axis=3)
        
        # Labels
        # Real samples are labelled 1, generated samples 0 (used by the eval metrics only).
        d_on_data_labels = tf.ones([images.shape[0]])
        d_on_g_labels = tf.zeros([images.shape[0]])
        
        # Discriminator loss
        d_on_data_results, d_on_data_features = model.discriminator(images, is_training)
        d_on_data_logits = tf.squeeze(d_on_data_results)
       
        d_on_g_results, d_on_g_features = model.discriminator(generated_images, is_training)
        d_on_g_logits = tf.squeeze(d_on_g_results)
        
        with tf.variable_scope('losses'):
            # Per-example losses (Reduction.NONE) are kept for the eval metrics;
            # the *_reduced means are what gets optimized and reported.
            d_loss = tfgan.losses.wargs.modified_discriminator_loss(
                discriminator_real_outputs=d_on_data_logits,
                discriminator_gen_outputs=d_on_g_logits,
                label_smoothing=0.2,
                reduction=tf.losses.Reduction.NONE,
            )
            d_loss_reduced = tf.reduce_mean(d_loss)

            # Generator loss
            g_loss = tfgan.losses.wargs.modified_generator_loss(
                discriminator_gen_outputs=d_on_g_logits,
                reduction=tf.losses.Reduction.NONE
            )
            # g_loss_feature_matching = l2_norm(tf.reduce_mean(d_on_g_features, axis=0) - tf.reduce_mean(d_on_data_features, axis=0))
            g_loss_reduced = tf.reduce_mean(g_loss) #+ g_loss_feature_matching)
            
        # TRAIN #
        if mode == tf.estimator.ModeKeys.TRAIN:
            
            with tf.variable_scope('optimizer'):
                # Separate Adam optimizers, each wrapped in CrossShardOptimizer.
                d_optimizer = tf.train.AdamOptimizer(learning_rate=D_LR, beta1=0.5)
                d_optimizer = tf.contrib.tpu.CrossShardOptimizer(d_optimizer)
            
                g_optimizer = tf.train.AdamOptimizer(learning_rate=G_LR, beta1=0.5)
                g_optimizer = tf.contrib.tpu.CrossShardOptimizer(g_optimizer)
         
                # Ops in the UPDATE_OPS collection must run before the training ops created below.
                # NOTE(review): var_list is taken from GLOBAL_VARIABLES, which presumably also
                # contains non-trainable batch-norm statistics - confirm whether
                # TRAINABLE_VARIABLES was intended.
                with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                    d_step = d_optimizer.minimize(d_loss_reduced, var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                                                                     scope='Discriminator'))
                    g_step = g_optimizer.minimize(g_loss_reduced, var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                                                                     scope='Generator'))
                    # Neither minimize() call receives the global step, so it is advanced explicitly.
                    increment_step = tf.assign_add(tf.train.get_or_create_global_step(), 1)
                    joint_op = tf.group([d_step, g_step, increment_step])

                    a = tf.contrib.tpu.TPUEstimatorSpec(mode=mode, loss=d_loss_reduced+g_loss_reduced, train_op=joint_op)
                    return a

        # EVAL #
        elif mode == tf.estimator.ModeKeys.EVAL:
            def _eval_metric_fn(d_loss, g_loss_1, d_real_labels, d_gen_labels, d_real_logits, d_gen_logits):
                """Compute mean losses and discriminator accuracies on real, generated and combined samples."""
                # A logit is counted as "real" when its sigmoid rounds to 1.
                real_pred = tf.math.round(tf.sigmoid(d_real_logits))
                real_acc = tf.metrics.accuracy(labels=d_real_labels, predictions=real_pred)

                gen_pred = tf.math.round(tf.sigmoid(d_gen_logits))
                gen_acc = tf.metrics.accuracy(labels=d_gen_labels, predictions=gen_pred)

                joint_acc = tf.metrics.accuracy(labels=tf.concat([d_real_labels, d_gen_labels], axis=0), 
                                                predictions=tf.concat([real_pred, gen_pred], axis=0))

                return {
                    'discriminator_loss': tf.metrics.mean(d_loss),
                    'generator_loss': tf.metrics.mean(g_loss_1),
                    'discriminator_real_accuracy': real_acc,
                    'discriminator_gen_accuracy': gen_acc,
                    'discriminator_joint_accuracy': joint_acc
                }

            return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, loss=d_loss_reduced + g_loss_reduced,
                                                   eval_metrics=(_eval_metric_fn, [d_loss, g_loss, 
                                                                                   d_on_data_labels, d_on_g_labels,
                                                                                   d_on_data_logits, d_on_g_logits]))
    return model_fn

In [0]:
################################ ESTIMATORS ####################################

def make_estimators(model, only_cpu=False):
    """Create the estimators used for training/evaluation and for prediction.

    Args:
        model: model object passed on to make_model_fn.
        only_cpu: when True the TPU estimator is skipped.

    Returns:
        ``(est, cpu_est)``: ``est`` is the TPU-based estimator used for TRAIN
        and EVAL (None when ``only_cpu`` is set); ``cpu_est`` runs with
        ``use_tpu=False`` and a predict batch size of 1 to generate images.
    """
    model_fn = make_model_fn(model)

    # Checkpoints, summaries and TPU loop length all follow the EVAL_EPOCHS cadence.
    run_config = tf.contrib.tpu.RunConfig(
        master=TF_MASTER,
        save_checkpoints_steps=EVAL_EPOCHS,
        save_checkpoints_secs=None,
        save_summary_steps=EVAL_EPOCHS,
        model_dir=MODEL_DIR,
        keep_checkpoint_max=3,
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=EVAL_EPOCHS))

    if only_cpu:
        tpu_estimator = None
    else:
        tpu_estimator = tf.contrib.tpu.TPUEstimator(
            model_fn=model_fn,
            use_tpu=True,
            config=run_config,
            train_batch_size=BATCH_SIZE,
            eval_batch_size=BATCH_SIZE)

    cpu_estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        use_tpu=False,
        config=run_config,
        predict_batch_size=1)

    return tpu_estimator, cpu_estimator

In [0]:
################################# TRAINING #####################################

def train(est, cpu_est):
    """Alternate training, evaluation and image dumps until EPOCHS steps are done.

    Training resumes from the global step stored in MODEL_DIR and proceeds in
    chunks of at most EVAL_EPOCHS steps. After every chunk the model is
    evaluated for one step and generated images are saved via save_imgs.

    Args:
        est: estimator used for training and evaluation.
        cpu_est: estimator used to generate preview images.
    """
    step = estimator._load_global_step_from_checkpoint_dir(MODEL_DIR)
    print('Starting training')

    while step < EPOCHS:
        target_step = int(min(step + EVAL_EPOCHS, EPOCHS))
        est.train(input_fn=make_input_fn(), max_steps=target_step)
        step = target_step
        print('Finished training step %d' % step)

        # Evaluation
        eval_metrics = est.evaluate(input_fn=make_input_fn(False), steps=1)
        print('Finished evaluating')
        print(eval_metrics)

        # Render some generated images
        samples = []
        for prediction in cpu_est.predict(input_fn=noise_input_fn):
            samples.append(prediction['generated_images'])
        save_imgs(str(step), samples)
        print('Finished generating images')

In [0]:
def do_experiment():
    """Run the whole experiment: setup, credential upload, estimator creation and training."""
    setup()
    upload_credentials()
    model = DCGAN()
    est, cpu_est = make_estimators(model)
    train(est, cpu_est)

In [0]:
import traceback

tf.logging.set_verbosity(tf.logging.INFO)

try:
    do_experiment()
except Exception:
    # Best-effort run: keep the cell from raising, but print the full
    # traceback so a failure can be located, not just its message.
    traceback.print_exc()

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

INFO:tensorflow:Using config: {'_model_dir': 'gs://tputestingmnist/DCGAN_MULTI/DCGAN_MULTI_10/', '_tf_random_seed': None, '_save_summary_steps': 300, '_save_checkpoints_steps': 300, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 3, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '