In [None]:
import warnings # We'll use this to suppress warnings caused by TensorFlow
warnings.simplefilter(action='ignore', category=FutureWarning)

import numpy as np # linear algebra
print(np.__version__)
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt # generating plot

import tensorflow as tf # modeling/training
from tensorflow.keras.preprocessing.image import ImageDataGenerator 
import time # Used for epoch timing

import imageio # GIF generation
import glob # GIF generation
import PIL # GIF generation

import os
import h5py

In [None]:
# List the GPU devices TensorFlow can see; an empty list means none were found.
tf.config.list_physical_devices('GPU')

In [None]:
# Open the HDF5 dataset read-only and show its top-level keys.
# NOTE: despite its name, `dataset_dir` holds the path of a single .h5 file.
# `f` is intentionally left open: later cells keep reading from it
# (the normalization cell and the step count inside train()).
dataset_dir = "textures_v2_brown500_with_valid.h5"
f = h5py.File(dataset_dir, 'r')
list(f.keys())

In [None]:
# Build one channels-last array per split by stacking the 'y*' images and the
# 'x*' images along axis 3, then merge both splits into a single training set.
# Later cells read channels 0-2 as the texture and channel 3 as the height map.
#
# Both splits MUST use the same value ranges, otherwise the merged dataset
# mixes two different scalings:
#   * texture channels  -> [-1, 1], matching the tanh output of the p2p
#     generator and the `*127.5+127.5` de-normalization used when plotting;
#   * height-map channel -> [0, 1], matching the sigmoid output of the
#     height-map generator.
# (Assumes the stored pixel values are in 0..255 -- TODO confirm against the file.)
full_dataset_yt = np.concatenate(((f['yt'][:] - 127.5) / 127.5, f['xt'][:] / 255.0), axis=3)
full_dataset_yv = np.concatenate(((f['yv'][:] - 127.5) / 127.5, f['xv'][:] / 255.0), axis=3)
full_dataset = np.concatenate((full_dataset_yt, full_dataset_yv))
full_dataset.shape

In [None]:
# Image Generator
# Augmentation pipeline: random horizontal and vertical flips plus random
# rotations of up to 360 degrees; pixels uncovered by the rotation are filled
# by reflecting the image.
imgen = ImageDataGenerator(horizontal_flip = True, vertical_flip=True, rotation_range=360, fill_mode="reflect")

In [None]:
# Number of samples per batch; also read by train_step() and train().
BATCH_SIZE = 4
# Iterator of augmented batches over the whole dataset, cast to float32.
# No labels are supplied (None), so each batch is just an array of images.
img_it = imgen.flow(tf.cast(full_dataset, dtype=tf.float32), None, batch_size=BATCH_SIZE)

In [None]:
# Preview one augmented batch: for every sample the texture (channels 0-2,
# de-normalized from [-1, 1] to [0, 1]) is drawn next to its height map
# (channel 3, shown in grayscale).
# The built-in next() is used because it works with any iterator, whereas the
# `.next()` method is not available on every Keras iterator version.
examples = next(img_it)
for x in range(len(examples)):
    plt.subplot(4,4,x*2+1)
    plt.imshow((examples[x][:,:,0:3]*127.5+127.5)/255.0)
    plt.axis('off')
    plt.subplot(4,4,x*2+2)
    plt.imshow(examples[x][:,:,3], cmap='gray')
    plt.axis('off')

In [None]:
from tensorflow.keras import layers
def make_hm_generator_model():
    """Build the height-map generator.

    The network maps a 1000-dimensional noise vector to a 512x512
    single-channel image whose values are squashed by a sigmoid.

    Returns:
        tf.keras.Sequential: the (un-compiled) generator.
    """
    model = tf.keras.Sequential()
    # Project the noise vector and reshape it into a 4x4x512 feature map.
    model.add(layers.Dense(4*4*512, input_shape=(1000,)))
    model.add(layers.BatchNormalization())
    model.add(layers.Reshape((4, 4, 512)))

    assert model.output_shape == (None, 4, 4, 512)

    # Filters per stage, expressed as divisors of 512. Integer division is
    # required: Conv2D expects an integer number of filters, and `512/d`
    # would hand it a float.
    divisors = [2, 2, 4, 4, 8, 8, 8]
    filter_counts = [512 // d for d in divisors]

    # Each stage convolves and then doubles the resolution:
    # 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256 -> 512.
    for n_filters in filter_counts:
        model.add(layers.Conv2D(n_filters, 5, strides=1, padding='same'))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.UpSampling2D())

    # Collapse to a single channel.
    model.add(layers.Conv2D(1, 5, strides=1, padding='same', activation='sigmoid'))
    assert model.output_shape == (None, 512, 512, 1)

    return model

def make_p2p_generator_model():
    """Build the U-Net style generator translating a height map into a texture.

    Input is a 512x512x1 tensor; the output has 3 channels passed through tanh.
    Eight stride-2 encoder stages are followed by a 2x2 'valid' bottleneck
    convolution; the decoder mirrors the encoder with bilinear upsampling and
    concatenates each encoder stage's batch-normalized (pre-activation) output
    as a skip connection.

    Returns:
        tf.keras.Model: the (un-compiled) generator.
    """
    heightmap_in = layers.Input(shape=[512, 512, 1])

    # ---- Encoder: halve the resolution at every stage, remember the skips.
    encoder_filters = [64, 64*2, 64*4, 64*8, 64*8, 64*8, 64*8, 64*8]
    skips = []
    x = heightmap_in
    for n_filters in encoder_filters:
        x = layers.Conv2D(n_filters, 3, strides=2, padding='same')(x)
        x = layers.BatchNormalization()(x)
        skips.append(x)  # skip connection is taken before the activation
        x = layers.LeakyReLU(0.01)(x)

    # ---- Bottleneck: 2x2 'valid' convolution.
    x = layers.Conv2D(64*8, 2, strides=1, padding='valid')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.01)(x)

    # ---- First decoder stage: transposed convolution undoes the bottleneck.
    x = layers.Conv2DTranspose(64*8, 2, strides=1)(x)
    x = layers.BatchNormalization()(x)
    x = layers.concatenate([x, skips.pop()])
    x = layers.LeakyReLU(0.01)(x)

    # ---- Remaining decoder stages: upsample, convolve, merge the next skip
    # (deepest remaining encoder output first).
    decoder_filters = [64*8, 64*8, 64*8, 64*8, 64*4, 64*2, 64]
    for n_filters in decoder_filters:
        x = layers.UpSampling2D(interpolation='bilinear')(x)
        x = layers.Conv2D(n_filters, 3, strides=1, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.concatenate([x, skips.pop()])
        x = layers.LeakyReLU(0.01)(x)

    # ---- Output head: final 2x upsampling to 3 channels, squashed by tanh.
    x = layers.Conv2DTranspose(3, 2, strides=2)(x)
    texture_out = layers.Activation(activation='tanh')(x)

    return tf.keras.Model(heightmap_in, texture_out)

In [None]:
def make_p2p_generator_model_with_dropout():
    """Build the U-Net style texture generator variant that uses dropout.

    Same encoder and bottleneck as make_p2p_generator_model, but the decoder
    upsamples with transposed convolutions and applies Dropout(0.5) in its
    first three stages. Input is 512x512x1; the output has 3 channels passed
    through tanh.

    Returns:
        tf.keras.Model: the (un-compiled) generator.
    """
    net_input = layers.Input(shape=[512, 512, 1])

    # Encoder: eight stride-2 stages; the batch-normalized output of each
    # stage (before its activation) is kept for the skip connections.
    skip_tensors = []
    x = net_input
    for width in (64, 64*2, 64*4, 64*8, 64*8, 64*8, 64*8, 64*8):
        x = layers.Conv2D(width, 3, strides=2, padding='same')(x)
        x = layers.BatchNormalization()(x)
        skip_tensors.append(x)
        x = layers.LeakyReLU(0.01)(x)

    # Bottleneck: 2x2 'valid' convolution.
    x = layers.Conv2D(64*8, 2, strides=1, padding='valid')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.01)(x)

    # Decoder description: (filters, stride, apply dropout?) per stage.
    decoder_spec = [
        (64*8, 1, True),
        (64*8, 2, True),
        (64*8, 2, True),
        (64*8, 2, False),
        (64*8, 2, False),
        (64*4, 2, False),
        (64*2, 2, False),
        (64,   2, False),
    ]
    # Skips are consumed from the deepest encoder stage to the shallowest.
    for (width, stride, with_dropout), skip in zip(decoder_spec, reversed(skip_tensors)):
        x = layers.Conv2DTranspose(width, 2, strides=stride)(x)
        x = layers.BatchNormalization()(x)
        if with_dropout:
            x = layers.Dropout(0.5)(x)
        x = layers.concatenate([x, skip])
        x = layers.LeakyReLU(0.01)(x)

    # Output head: last 2x upsampling to 3 channels followed by tanh.
    x = layers.Conv2DTranspose(3, 2, strides=2)(x)
    net_output = layers.Activation(activation='tanh')(x)

    return tf.keras.Model(net_input, net_output)

In [None]:
# Instantiate both generators (the texture generator is the dropout variant)
# and smoke-test the height-map generator on a single random noise vector in
# inference mode; then print both architectures.
generator = make_hm_generator_model()
p2p_generator = make_p2p_generator_model_with_dropout()
noise_image = tf.random.normal([1,1000,])
generated_image = generator(noise_image, training=False)
# The generator ends in a sigmoid, so the output is rescaled by 255 for display.
plt.imshow(generated_image[0]*255.0, cmap='gray')
generator.summary()
p2p_generator.summary()

In [None]:
def make_hm_discriminator_model():
    """Build the height-map discriminator.

    Takes a 512x512x1 image and produces a single score per sample; the final
    Dense layer has no activation.

    Returns:
        tf.keras.Sequential: the (un-compiled) discriminator.
    """
    model = tf.keras.Sequential()
    model.add(layers.InputLayer(input_shape=[512, 512, 1]))

    # Filters per stage, expressed as divisors of 512. Integer division is
    # required: Conv2D expects an integer number of filters, and `512/d`
    # would hand it a float.
    divisors = [8, 4, 4, 4, 2, 2, 2]
    filter_counts = [512 // d for d in divisors]

    # Each stage convolves and then halves the resolution with max pooling.
    for n_filters in filter_counts:
        model.add(layers.Conv2D(n_filters, 5, strides=1, padding='same'))
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.MaxPool2D())

    model.add(layers.Conv2D(1, 5, padding='same'))

    # Spatial size left after all the pooling stages; averaging over a window
    # of that size reduces the map to one value per sample.
    reduction_factor = 512 // (2**len(filter_counts))
    model.add(layers.AveragePooling2D(pool_size=(reduction_factor, reduction_factor)))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

def make_p2p_discriminator_model():
    """Build the conditional discriminator for the texture GAN.

    The model receives a pair [512x512x1 input, 512x512x3 input] (the training
    step passes the height map first and the texture second), concatenates the
    two along the channel axis and applies five stride-2 convolutions. The
    last convolution has one filter and no activation, so the output is a
    spatial map of scores rather than a single scalar.

    Returns:
        tf.keras.Model: the (un-compiled) discriminator.
    """
    heightmap_in = layers.Input(shape=[512, 512, 1])
    texture_in = layers.Input(shape=[512, 512, 3])

    features = layers.concatenate([heightmap_in, texture_in])

    # Widen the feature maps while halving the resolution at every stage.
    for n_filters in (64*1, 64*2, 64*4, 64*8):
        features = layers.Conv2D(n_filters, 3, strides=2, padding='same')(features)
        features = layers.LeakyReLU(0.01)(features)

    score_map = layers.Conv2D(1, 3, strides=2, padding='same')(features)

    return tf.keras.Model([heightmap_in, texture_in], score_map)

In [None]:
# Instantiate both discriminators, print their architectures and smoke-test
# the height-map discriminator on the image generated in the earlier cell.
discriminator = make_hm_discriminator_model()
discriminator.summary()
p2p_discriminator = make_p2p_discriminator_model()
p2p_discriminator.summary()
print(discriminator(generated_image))

In [None]:
# Shared criteria: squared error for the adversarial terms, absolute error
# for the texture reconstruction term.
loss_obj = tf.keras.losses.MeanSquaredError()
abs_loss = tf.keras.losses.MeanAbsoluteError()

def discriminator_loss(real_output, fake_output):
    """Squared-error discriminator loss: real scores are pushed to 1, fake scores to 0."""
    target_for_real = tf.ones_like(real_output)
    target_for_fake = tf.zeros_like(fake_output)
    return loss_obj(target_for_real, real_output) + loss_obj(target_for_fake, fake_output)

def generator_loss(fake_output):
    """Squared-error generator loss: the scores of generated samples are pushed to 1."""
    wanted = tf.ones_like(fake_output)
    return loss_obj(wanted, fake_output)

def generator_p2p_loss(real_output, gen_output, fake_output):
    """Texture-generator loss: adversarial term plus 100x the mean absolute error.

    Args:
        real_output: ground-truth textures.
        gen_output: textures produced by the generator.
        fake_output: discriminator scores for the generated textures.
    """
    adversarial_term = loss_obj(tf.ones_like(fake_output), fake_output)
    reconstruction_term = abs_loss(real_output, gen_output)
    return adversarial_term + 100*reconstruction_term

In [None]:
# One RMSprop optimizer (learning rate 1e-4) per network, so that each network
# keeps its own optimizer state.
gen_optimizer = tf.keras.optimizers.RMSprop(1e-4)
disc_optimizer = tf.keras.optimizers.RMSprop(1e-4)
p2p_gen_optimizer = tf.keras.optimizers.RMSprop(1e-4)
p2p_disc_optimizer = tf.keras.optimizers.RMSprop(1e-4)

# Checkpoints bundle each GAN's two networks together with their optimizers.
# train() writes them under `checkpoint_dir` using the prefixes below.
checkpoint_dir = 'gan_heightmaps'
hm_checkpoint_prefix = os.path.join(checkpoint_dir, "hm_ckpt")
hm_checkpoint = tf.train.Checkpoint(generator_optimizer=gen_optimizer,
                                 discriminator_optimizer=disc_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
p2p_checkpoint_prefix = os.path.join(checkpoint_dir, "p2p_ckpt")
p2p_checkpoint = tf.train.Checkpoint(generator_optimizer=p2p_gen_optimizer,
                                 discriminator_optimizer=p2p_disc_optimizer,
                                 generator=p2p_generator,
                                 discriminator=p2p_discriminator)

In [None]:
@tf.function
def train_step(images):
    """Run one optimization step of the height-map GAN.

    A batch of BATCH_SIZE noise vectors is turned into fake height maps; the
    discriminator scores both the real `images` and the fakes, and each
    network is updated from its own loss.

    Args:
        images: batch of real height maps.

    Returns:
        Tuple (gen_loss, disc_loss) for this step.
    """
    latent = tf.random.normal([BATCH_SIZE, 1000])

    # One tape per network so each loss is differentiated independently.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fakes = generator(latent, training=True)

        real_scores = discriminator(images, training=True)
        fake_scores = discriminator(fakes, training=True)

        disc_loss = discriminator_loss(real_scores, fake_scores)
        gen_loss = generator_loss(fake_scores)

    disc_vars = discriminator.trainable_variables
    gen_vars = generator.trainable_variables

    # Both gradients are computed before either network is updated.
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)
    gen_grads = gen_tape.gradient(gen_loss, gen_vars)

    disc_optimizer.apply_gradients(zip(disc_grads, disc_vars))
    gen_optimizer.apply_gradients(zip(gen_grads, gen_vars))

    return gen_loss, disc_loss

@tf.function
def p2p_train_step(real_heightmaps, real_textures):
    """Run one optimization step of the height-map -> texture GAN.

    The generator produces textures from the real height maps; the
    discriminator scores (height map, texture) pairs for both the real and the
    generated textures, and each network is updated from its own loss.

    Args:
        real_heightmaps: batch of real height maps (generator input).
        real_textures: batch of matching ground-truth textures.

    Returns:
        Tuple (gen_loss, disc_loss) for this step.
    """
    # One tape per network so each loss is differentiated independently.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_textures = p2p_generator(real_heightmaps, training=True)

        real_pair = [real_heightmaps, real_textures]
        fake_pair = [real_heightmaps, fake_textures]
        real_scores = p2p_discriminator(real_pair, training=True)
        fake_scores = p2p_discriminator(fake_pair, training=True)

        disc_loss = discriminator_loss(real_scores, fake_scores)
        gen_loss = generator_p2p_loss(real_textures, fake_textures, fake_scores)

    disc_vars = p2p_discriminator.trainable_variables
    gen_vars = p2p_generator.trainable_variables

    # Both gradients are computed before either network is updated.
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)
    gen_grads = gen_tape.gradient(gen_loss, gen_vars)

    p2p_disc_optimizer.apply_gradients(zip(disc_grads, disc_vars))
    p2p_gen_optimizer.apply_gradients(zip(gen_grads, gen_vars))

    return gen_loss, disc_loss

In [None]:
def show_and_generate_images(model, p2p_model, epoch, test_output):
    """Generate, plot and save a preview of both generators.

    Each noise vector yields a pair of panels: the generated height map
    (grayscale) and the texture produced from it.

    Args:
        model: height-map generator, called on `test_output`.
        p2p_model: texture generator, called on the generated height maps.
        epoch: epoch number, used only in the saved file name.
        test_output: batch of noise vectors fed to `model`.

    Side effects:
        Writes 'gan_heightmaps_hm_image_at_epoch_XXXX.png' to the working
        directory and displays the figure.
    """
    predictions = model(test_output, training=False)
    p2p_predictions = p2p_model(predictions, training=False)

    n_samples = len(test_output)
    n_cols = 4
    # Two panels per sample. Keep the 4x4 layout for up to 8 samples and add
    # rows beyond that instead of failing on an out-of-range subplot index.
    n_rows = max(4, (2 * n_samples + n_cols - 1) // n_cols)

    fig = plt.figure(figsize=(10,10))
    for i in range(n_samples):
        plt.subplot(n_rows, n_cols, i*2+1)
        plt.imshow(predictions[i], cmap='gray')
        plt.axis('off')
        plt.subplot(n_rows, n_cols, i*2+2)
        # Textures come out of a tanh, so map [-1, 1] back to [0, 1].
        plt.imshow((p2p_predictions[i]*127.5+127.5)/255.0)
        plt.axis('off')

    plt.savefig('gan_heightmaps_hm_image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # This runs once per epoch: release the figure explicitly so figures do
    # not pile up in memory on backends that keep them open after show().
    plt.close(fig)

In [None]:
def train(dataset, epochs, steps_per_epoch=None):
    """Train both GANs and return the per-epoch mean losses.

    Relies on notebook globals: `generator`, `p2p_generator`, `seed`,
    `hm_checkpoint`, `p2p_checkpoint` (and their prefixes), `BATCH_SIZE`, and
    the open HDF5 file `f` when `steps_per_epoch` is not given.

    Args:
        dataset: iterator yielding image batches whose channels 0-2 are the
            texture and whose channel 3 is the height map.
        epochs: number of epochs to run.
        steps_per_epoch: batches drawn per epoch. Defaults to
            ``len(f['xt']) // BATCH_SIZE``.

    Returns:
        pandas.DataFrame with one row per epoch and the columns 'gen_loss',
        'disc_loss', 'p2p_gen_loss' and 'p2p_disc_loss'.
    """
    print('Beginning to train...')

    if steps_per_epoch is None:
        steps_per_epoch = len(f['xt']) // BATCH_SIZE

    loss_names = ['gen_loss', 'disc_loss', 'p2p_gen_loss', 'p2p_disc_loss']
    # Rows are collected in a plain list and turned into a DataFrame once at
    # the end: DataFrame.append no longer exists in recent pandas, and growing
    # a frame row by row is quadratic anyway.
    records = []
    for epoch in range(epochs):
        start = time.time()
        epoch_gen_loss = tf.keras.metrics.Mean()
        epoch_disc_loss = tf.keras.metrics.Mean()
        epoch_p2p_gen_loss = tf.keras.metrics.Mean()
        epoch_p2p_disc_loss = tf.keras.metrics.Mean()
        for _ in range(steps_per_epoch):
            # Built-in next() works with any iterator, unlike `.next()`.
            images = next(dataset)
            # Slice with 3:4 (not 3) to keep the channel axis the models declare.
            real_heightmaps = images[:, :, :, 3:4]
            real_textures = images[:, :, :, 0:3]

            gen_loss, disc_loss = train_step(real_heightmaps)
            p2p_gen_loss, p2p_disc_loss = p2p_train_step(real_heightmaps, real_textures)
            epoch_gen_loss.update_state(gen_loss)
            epoch_disc_loss.update_state(disc_loss)
            epoch_p2p_gen_loss.update_state(p2p_gen_loss)
            epoch_p2p_disc_loss.update_state(p2p_disc_loss)

        show_and_generate_images(generator, p2p_generator, epoch + 1, seed)

        epoch_record = {
            'gen_loss': float(epoch_gen_loss.result().numpy()),
            'disc_loss': float(epoch_disc_loss.result().numpy()),
            'p2p_gen_loss': float(epoch_p2p_gen_loss.result().numpy()),
            'p2p_disc_loss': float(epoch_p2p_disc_loss.result().numpy()),
        }
        stats = 'Epoch {0} took {1} seconds. Gen_loss: {2:0.3f}, Disc_loss: {3:0.3f}, p2p_gen_loss: {4:0.3f}, p2p_disc_loss: {5:0.3f}'
        print(stats.format(epoch + 1, int(time.time() - start),
                           epoch_record['gen_loss'],
                           epoch_record['disc_loss'],
                           epoch_record['p2p_gen_loss'],
                           epoch_record['p2p_disc_loss']))
        records.append(epoch_record)

        # Persist both GANs every 50 epochs.
        if (epoch+1) % 50 == 0:
            hm_checkpoint.save(file_prefix = hm_checkpoint_prefix)
            p2p_checkpoint.save(file_prefix = p2p_checkpoint_prefix)

    # `columns=` guarantees the four loss columns even when epochs == 0.
    return pd.DataFrame(records, columns=loss_names)

In [None]:
# Total number of training epochs.
EPOCHS = 1000

# Fixed batch of 8 noise vectors; train() reuses it for the preview image of
# every epoch, so the previews are comparable across epochs.
seed = tf.random.normal([8, 1000])
history = train(img_it, EPOCHS)
# Shift the index up by one (presumably so it lines up with the 1-based epoch
# numbers printed during training -- confirm).
history.index = history.index + 1