In [12]:
# Imports 
import tensorflow as tf
from numpy import load
from numpy import zeros
from numpy import ones
from numpy.random import randint
from keras.initializers import RandomNormal
from keras.models import Model
from keras.models import Input
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Concatenate
from keras.layers import Dropout
from keras.layers import BatchNormalization
from keras.layers import LeakyReLU
from matplotlib import pyplot

In [13]:
# Define the PatchGAN discriminator 

def define_discriminator(image_shape):
    """Build and compile the conditional PatchGAN discriminator.

    The source and target images (each of ``image_shape``) are concatenated
    along the last (channel) axis. They then pass through a stack of 4x4
    Conv-[BatchNorm]-LeakyReLU stages and a final 1-filter conv with a
    sigmoid, giving one real/fake probability per output location. Four
    stride-2 convs with "same" padding shrink height/width by 16, e.g.
    256x256 -> 16x16x1.

    Args:
        image_shape: per-image shape, e.g. ``(256, 256, 3)``.

    Returns:
        A compiled Keras ``Model([src, target], patch_probs)``. It uses binary
        cross-entropy and Adam (lr=2e-4, beta_1=0.5), with the loss scaled
        by 0.5.
    """
    init = RandomNormal(stddev=0.02)
    in_src_image = Input(shape=image_shape)
    in_target_image = Input(shape=image_shape)
    features = Concatenate()([in_src_image, in_target_image])

    # (filters, strides, use batch norm) for each Conv-[BN]-LeakyReLU stage.
    stages = (
        (64, (2, 2), False),   # C64: no batch norm on the first layer
        (128, (2, 2), True),   # C128
        (256, (2, 2), True),   # C256
        (512, (2, 2), True),   # C512
        (512, (1, 1), True),   # second-to-last layer, stride 1
    )
    for filters, strides, use_bn in stages:
        features = Conv2D(filters, (4, 4), strides=strides, padding="same",
                          kernel_initializer=init)(features)
        if use_bn:
            features = BatchNormalization()(features)
        features = LeakyReLU(alpha=0.2)(features)

    # One logit per patch, squashed to a probability.
    patch_logits = Conv2D(1, (4, 4), padding="same", kernel_initializer=init)(features)
    patch_out = Activation("sigmoid")(patch_logits)

    model = Model([in_src_image, in_target_image], patch_out)
    model.compile(
        loss="binary_crossentropy",
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
        loss_weights=[0.5],
    )
    return model

In [14]:
# Define encoder block used in U-net for generator
def define_encoder_block(layer_in, n_filters, batchnorm=True):
    """Append one U-Net encoder stage: 4x4 stride-2 Conv -> [BatchNorm] -> LeakyReLU(0.2).

    Args:
        layer_in: Keras tensor to downsample.
        n_filters: number of conv filters.
        batchnorm: insert BatchNormalization after the conv (the first
            encoder stage passes False).

    Returns:
        The output Keras tensor, with height/width halved ("same" padding).
    """
    downsample = Conv2D(n_filters, (4, 4), strides=(2, 2), padding="same",
                        kernel_initializer=RandomNormal(stddev=0.02))
    out = downsample(layer_in)

    if batchnorm:
        # training=True is fixed at call time, so this layer normalizes with the
        # current batch statistics even when the model is later run via predict().
        out = BatchNormalization()(out, training=True)

    return LeakyReLU(alpha=0.2)(out)

In [15]:
# Define decoder block used in U-net for generator

def decoder_block(layer_in, skip_in, n_filters, dropout=True):
    """Append one U-Net decoder stage.

    The stage is: 4x4 stride-2 Conv2DTranspose -> BatchNorm -> [Dropout(0.5)],
    then concatenation with the matching encoder output, then ReLU.

    Args:
        layer_in: Keras tensor to upsample.
        skip_in: encoder tensor to concatenate with (skip connection); it must
            match the upsampled tensor's height/width.
        n_filters: number of transposed-conv filters.
        dropout: apply Dropout(0.5) before the concatenation.

    Returns:
        The output Keras tensor, with height/width doubled.
    """
    upsample = Conv2DTranspose(n_filters, (4, 4), strides=(2, 2), padding="same",
                               kernel_initializer=RandomNormal(stddev=0.02))
    x = upsample(layer_in)

    # Both BatchNorm and Dropout are pinned to training mode, so they behave
    # the same at predict() time as during training.
    x = BatchNormalization()(x, training=True)
    if dropout:
        x = Dropout(0.5)(x, training=True)

    merged = Concatenate()([x, skip_in])
    return Activation("relu")(merged)

In [16]:
# Define the U-net Generator
def define_generator(image_shape=(256,256,3)):
    """Build the (uncompiled) U-Net generator.

    Architecture:
    - Seven encoder stages (C64 without batch norm, then C128, C256 and
      C512 x4).
    - A stride-2 Conv+ReLU bottleneck.
    - Seven decoder stages that each concatenate the mirrored encoder output.
      The first three use dropout.
    - A stride-2 Conv2DTranspose to 3 channels with tanh, so the output lies
      in [-1, 1].

    That is 8 halvings of height/width. Input sides should presumably be
    multiples of 256, or the skip tensors will not line up for Concatenate.

    Args:
        image_shape: input image shape ``(H, W, C)``.

    Returns:
        A Keras ``Model(in_image, out_image)``.
    """
    init = RandomNormal(stddev=0.02)
    in_image = Input(shape=image_shape)

    # Encoder: keep every stage's output for the skip connections.
    skips = []
    x = in_image
    for stage, filters in enumerate((64, 128, 256, 512, 512, 512, 512)):
        x = define_encoder_block(x, filters, batchnorm=(stage != 0))
        skips.append(x)

    # Bottleneck: plain Conv + ReLU (no batch norm).
    x = Conv2D(512, (4, 4), strides=(2, 2), padding="same", kernel_initializer=init)(x)
    x = Activation("relu")(x)

    # Decoder: (filters, dropout) per stage, paired with encoder outputs deepest-first.
    decoder_specs = ((512, True), (512, True), (512, True), (512, False),
                     (256, False), (128, False), (64, False))
    for skip, (filters, use_dropout) in zip(reversed(skips), decoder_specs):
        x = decoder_block(x, skip, filters, dropout=use_dropout)

    # Final upsampling to a 3-channel image in [-1, 1].
    x = Conv2DTranspose(3, (4, 4), strides=(2, 2), padding="same", kernel_initializer=init)(x)
    out_image = Activation("tanh")(x)

    return Model(in_image, out_image)

In [17]:
# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):
    """Chain generator and discriminator into the composite model that trains G.

    Args:
        g_model: generator mapping a source image to a generated target.
        d_model: discriminator taking ``[source, target]``.
        image_shape: per-image input shape.

    Returns:
        A compiled ``Model(source, [d_verdict, generated_image])``. It uses
        binary cross-entropy on the verdict (weight 1), MAE on the image
        (weight 100), and Adam (lr=2e-4, beta_1=0.5).
    """
    # Freeze D so updates through this composite model only change G.
    # NOTE(review): d_model was already compiled before this flag is flipped;
    # whether d_model.train_on_batch keeps updating D depends on the Keras
    # version capturing trainable weights at compile time — confirm.
    d_model.trainable = False

    source = Input(shape=image_shape)
    translated = g_model(source)
    verdict = d_model([source, translated])

    model = Model(source, [verdict, translated])
    model.compile(
        loss=["binary_crossentropy", "mae"],
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
        loss_weights=[1, 100],
    )
    return model

In [18]:
# Load and scale training images
def load_real_samples(filename):
    """Load paired images from an ``.npz`` archive and scale them with ``(x - 127.5) / 127.5``.

    Args:
        filename: path to an ``.npz`` file with the source images under
            ``"arr_0"`` and the target images under ``"arr_1"``. These are the
            names ``numpy.savez`` gives to positional arrays.

    Returns:
        ``[X1, X2]`` as float32 arrays. Assuming 8-bit pixel values (0-255),
        they lie in [-1, 1], matching the generator's tanh output range.
    """
    # The NpzFile keeps its file handle open until closed; the context manager
    # releases it once both arrays have been read into memory.
    with load(filename) as data:
        X1, X2 = data["arr_0"], data["arr_1"]

    # Cast before scaling: integer arrays minus a float would be promoted to
    # float64, doubling memory for data Keras consumes as float32 anyway.
    X1 = (X1.astype("float32") - 127.5) / 127.5
    X2 = (X2.astype("float32") - 127.5) / 127.5
    return [X1, X2]

In [19]:
# Generate real samples in batches
def generate_real_samples(dataset, n_samples, patch_shape):
    """Draw a random batch of aligned (source, target) pairs labelled as real.

    Args:
        dataset: ``[trainA, trainB]`` arrays indexed by sample along axis 0.
        n_samples: batch size.
        patch_shape: side length of the discriminator's square output grid.

    Returns:
        ``([source_batch, target_batch], labels)``, where ``labels`` is all
        ones with shape ``(n_samples, patch_shape, patch_shape, 1)``.
    """
    source_images, target_images = dataset
    # Indices are drawn independently (with replacement), so a batch may repeat a pair.
    picks = randint(0, source_images.shape[0], n_samples)
    labels = ones((n_samples, patch_shape, patch_shape, 1))
    return [source_images[picks], target_images[picks]], labels

In [20]:
# Generate a batch of images
def generate_fake_samples(g_model, samples, patch_shape):
    """Translate a batch of source images with the generator and label them fake.

    Args:
        g_model: generator model exposing ``predict``.
        samples: batch of source images.
        patch_shape: side length of the discriminator's square output grid.

    Returns:
        ``(generated, labels)``, where ``labels`` is all zeros with shape
        ``(len(generated), patch_shape, patch_shape, 1)``.
    """
    # Generate fake samples. verbose=0 because this runs once per training
    # step; tf.keras' default progress bar would print a line on every call.
    X = g_model.predict(samples, verbose=0)

    # Fake class label = 0
    y = zeros((len(X), patch_shape, patch_shape, 1))

    return X, y

In [21]:
# Generate samples, save plots and the model
def summarize_performance(step, g_model, dataset, n_samples=3):
    """Save a comparison plot and a generator checkpoint for this training step.

    The plot is a 3 x ``n_samples`` grid: source images on the top row,
    generator outputs in the middle row, and real targets on the bottom row.
    Files are named ``plot_%06d.png`` and ``model_%06d.h5`` using ``step + 1``.

    Args:
        step: zero-based training step index.
        g_model: generator to sample from and save.
        dataset: ``[trainA, trainB]`` scaled to [-1, 1].
        n_samples: number of example columns.
    """
    # The patch size of 1 is irrelevant here; the labels are discarded.
    [src_batch, tar_batch], _ = generate_real_samples(dataset, n_samples, 1)
    gen_batch, _ = generate_fake_samples(g_model, src_batch, 1)

    # Figure rows top to bottom, rescaled from [-1,1] to [0,1] for imshow.
    rows = ((src_batch + 1) / 2.0, (gen_batch + 1) / 2.0, (tar_batch + 1) / 2.0)
    for row_idx, images in enumerate(rows):
        for col_idx in range(n_samples):
            pyplot.subplot(3, n_samples, 1 + row_idx * n_samples + col_idx)
            pyplot.axis("off")
            pyplot.imshow(images[col_idx])

    plot_name = "plot_%06d.png" % (step + 1)
    pyplot.savefig(plot_name)
    pyplot.close()

    model_name = "model_%06d.h5" % (step + 1)
    g_model.save(model_name)
    print(">Saved: %s and %s" % (plot_name, model_name))

In [22]:
# Training the Pix2Pix models
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1,
          summary_every=10, log_every=1):
    """Run the pix2pix training loop.

    Each step does the following:
    1. Draw a random batch of real (source, target) pairs.
    2. Generate fake targets.
    3. Update ``d_model`` once on the real batch and once on the fake batch.
    4. Update the generator through ``gan_model``, with "real" labels as the
       adversarial target and the real targets as the image target.

    Args:
        d_model: compiled discriminator (``define_discriminator``).
        g_model: generator (``define_generator``).
        gan_model: compiled composite model (``define_gan``).
        dataset: ``[trainA, trainB]`` of aligned source/target images.
        n_epochs: number of epochs. One epoch is ``len(trainA) // n_batch``
            steps. Batches are sampled with replacement, so an epoch is a
            step budget rather than a guaranteed pass over every pair.
        n_batch: samples per batch.
        summary_every: positive int; call ``summarize_performance`` (plot and
            checkpoint) every this many epochs. The final step is always
            summarized too, so the trained generator is saved even when
            ``n_epochs`` is not a multiple of ``summary_every``.
        log_every: positive int; print the loss line every this many steps.
            The default of 1 prints every step.

    Raises:
        ValueError: if ``n_batch`` exceeds the dataset size. That would mean
            zero steps per epoch, so training would silently do nothing.
    """
    # The discriminator outputs a square grid of patch predictions; its side
    # length sizes the real/fake label tensors.
    n_patch = d_model.output_shape[1]

    trainA, _ = dataset

    bat_per_epo = len(trainA) // n_batch
    if bat_per_epo == 0:
        raise ValueError(
            "n_batch=%d exceeds the dataset size (%d samples); no training steps would run"
            % (n_batch, len(trainA)))
    n_steps = bat_per_epo * n_epochs
    summary_interval = bat_per_epo * summary_every

    # One iteration is one training step, not one epoch.
    for i in range(n_steps):
        [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
        X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)

        # Discriminator: one update on real pairs, one on generated pairs.
        d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
        d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)

        # Generator: push D's verdict toward "real" and the image toward the
        # target. The first returned value is taken as the weighted total loss.
        g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])

        step = i + 1
        if step % log_every == 0:
            print(">%d, d1[%.3f] d2[%.3f] g[%.3f]" % (step, d_loss1, d_loss2, g_loss))

        if step % summary_interval == 0 or step == n_steps:
            summarize_performance(i, g_model, dataset)

In [23]:
# Load the paired training images (source, target); load_real_samples scales
# them to [-1, 1].
dataset = load_real_samples("maps_256.npz")
print("Loaded: ", dataset[0].shape, dataset[1].shape)

# Per-image shape, dropping the leading sample axis.
image_shape = dataset[0].shape[1:]

# Build the Pix2Pix GAN: discriminator, U-Net generator, and the composite
# model used to update the generator (define_gan freezes d_model inside it).
d_model = define_discriminator(image_shape)
g_model = define_generator(image_shape)
gan_model = define_gan(g_model, d_model, image_shape)

# Long-running. With the 1096 pairs shown in this cell's output and train's
# defaults (n_epochs=100, n_batch=1), this is 109,600 steps. A plot and
# checkpoint are saved every 10 epochs (10,960 steps).
train(d_model, g_model, gan_model, dataset)

Loaded:  (1096, 256, 256, 3) (1096, 256, 256, 3)




>1, d1[0.443] d2[0.807] g[83.381]
>2, d1[0.360] d2[1.079] g[75.517]
>3, d1[0.363] d2[0.636] g[81.980]
>4, d1[0.345] d2[0.547] g[69.047]
>5, d1[0.346] d2[0.474] g[79.846]
>6, d1[0.357] d2[0.425] g[75.366]
>7, d1[0.348] d2[0.365] g[72.718]
>8, d1[0.310] d2[0.349] g[69.591]
>9, d1[0.127] d2[0.286] g[53.488]
>10, d1[0.361] d2[0.299] g[51.735]
>11, d1[0.504] d2[0.294] g[59.647]
>12, d1[0.167] d2[0.313] g[63.992]
>13, d1[0.199] d2[0.153] g[51.636]
>14, d1[0.029] d2[0.102] g[42.196]
>15, d1[0.139] d2[0.152] g[53.949]
>16, d1[0.062] d2[0.124] g[47.290]
>17, d1[0.045] d2[0.060] g[43.919]
>18, d1[0.085] d2[0.070] g[33.061]
>19, d1[0.067] d2[0.164] g[46.458]
>20, d1[0.026] d2[0.042] g[41.516]
>21, d1[0.048] d2[0.028] g[41.395]
>22, d1[0.049] d2[0.046] g[36.363]
>23, d1[0.142] d2[1.110] g[19.427]
>24, d1[0.218] d2[0.041] g[19.363]
>25, d1[0.090] d2[0.032] g[29.839]
>26, d1[0.148] d2[0.059] g[26.171]
>27, d1[0.039] d2[0.070] g[31.980]
>28, d1[0.067] d2[0.039] g[23.174]
>29, d1[0.017] d2[0.083] g[28

KeyboardInterrupt: 

In [None]:
from numpy import vstack
def plot_images(src_img, gen_img, tar_img):
    """Show source, generated and expected images side by side, then display the figure.

    Args:
        src_img, gen_img, tar_img: equally sized batches of images scaled to
            [-1, 1], with the batch on the leading axis. Each batch element
            becomes one row with the three images as columns. A batch of one
            gives a single 1x3 row, as used below.
    """
    titles = ["Source", "Generated", "Expected"]
    # Scale from [-1,1] to [0,1] for imshow.
    columns = [(batch + 1) / 2.0 for batch in (src_img, gen_img, tar_img)]
    n_rows = len(src_img)
    for row in range(n_rows):
        for col, images in enumerate(columns):
            pyplot.subplot(n_rows, 3, 1 + row * 3 + col)
            pyplot.axis("off")
            pyplot.imshow(images[row])
            # Column headings only on the top row so they are not repeated.
            if row == 0:
                pyplot.title(titles[col])
    pyplot.show()

In [None]:
# Inference: reload the data and a saved generator checkpoint, then compare one
# random translation against its ground truth.
# NOTE(review): this import belongs in the imports cell at the top.
from keras.models import load_model
[X1, X2] = load_real_samples("maps_256.npz")
print("Loaded: ", X1.shape, X2.shape)
# Load the generator checkpoint that summarize_performance writes at step 109600.
# That is the last checkpoint of a full train() run with default arguments on
# this 1096-pair dataset.
# NOTE(review): the training run recorded above was interrupted at step 29,
# before the first checkpoint (step 10960). This file must come from a
# different, completed run.
model = load_model("model_109600.h5")

# Select one random sample. Indexing with a length-1 index array keeps the
# batch axis, so src_image/tar_image are (1, H, W, C) batches.
ix = randint(0, len(X1), 1)
src_image, tar_image = X1[ix], X2[ix]
gen_image = model.predict(src_image)

# Plot source, generated and expected images side by side
plot_images(src_image, gen_image, tar_image)