# Data Mapping with Generative Adversarial Network

## Import packages

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

import os
import time
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output

tfds.disable_progress_bar()  # keep notebook output free of tfds progress bars
AUTOTUNE = tf.data.experimental.AUTOTUNE  # lets tf.data choose num_parallel_calls for map()

In [None]:
# load map data set
# The data directory is built with os.path.join. The previous literal
# '..\..\dataset' only resolved on Windows and relied on invalid escape
# sequences ('\.', '\d'), which newer Python versions warn about.
# Arguments not listed here keep tfds.load's defaults, which are the same values
# the call passed explicitly before.
dataset, metadata = tfds.load(
    'cycle_gan/maps',
    data_dir=os.path.join('..', '..', 'dataset'),
    shuffle_files=False,
    download=True,
    as_supervised=True,  # elements are (image, label) tuples
    with_info=True,
    try_gcs=False,
)
# trainA/testA are used as the satellite ("Gps") domain and
# trainB/testB as the general layout domain throughout the notebook.
train_Gps_maps, train_general_maps = dataset['trainA'], dataset['trainB']
test_Gps_maps, test_general_maps = dataset['testA'], dataset['testB']

In [None]:
print (test_Gps_maps)  # quick look at the raw (unpreprocessed) satellite test dataset object

In [None]:
BUFFER_SIZE = 1000  # shuffle buffer size; NOTE(review): not referenced by any cell visible here
BATCH_SIZE = 1  # images per training/test batch
IMG_WIDTH = 256  # width the networks consume (crop / resize target)
IMG_HEIGHT = 256  # height the networks consume (crop / resize target)

In [None]:
# random crop
def random_crop(image):
  """Randomly crop one H x W x 3 image to IMG_HEIGHT x IMG_WIDTH x 3."""
  target_shape = [IMG_HEIGHT, IMG_WIDTH, 3]
  return tf.image.random_crop(image, size=target_shape)

In [None]:
# data normalization
def normalize(image):
  """Cast to float32 and map pixel values to [-1, 1] (x / 127.5 - 1).

  This assumes uint8-style input in [0, 255]. The result matches the tanh
  range of the generators' output layer.
  """
  as_float = tf.cast(image, tf.float32)
  return (as_float / 127.5) - 1

In [None]:
# random jitter: resize 288*288*3 + random crop 256*256*3 + random flip
def random_jitter(image):
  """Augment one image for training.

  Steps: nearest-neighbour upscale to 288x288, random crop back to
  IMG_HEIGHT x IMG_WIDTH, then a random left/right flip and a random up/down
  flip.
  """
  enlarged = tf.image.resize(
      image, [288, 288], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  cropped = random_crop(enlarged)
  mirrored = tf.image.random_flip_left_right(cropped)
  return tf.image.random_flip_up_down(mirrored)

In [None]:
# training set data preprocessing: random jitter + data normalization
def preprocess_image_train(image, label):
  """tf.data map function for the training sets: random jitter, then scale to [-1, 1].

  `label` comes from the (image, label) pairs produced by as_supervised=True
  and is ignored.
  """
  return normalize(random_jitter(image))

In [None]:
# test set data preprocessing: resize 256*256*3 + data normalization
def preprocess_image_test(image, label):
  """tf.data map function for the test sets.

  Steps: deterministic nearest-neighbour resize to IMG_HEIGHT x IMG_WIDTH, then
  scale to [-1, 1]. `label` is ignored.
  """
  resized = tf.image.resize(
      image, [IMG_HEIGHT, IMG_WIDTH],
      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  return normalize(resized)

In [None]:
# read training set and test set
# Training sets: cache the raw decoded images and apply the random jitter
# AFTER the cache, so every epoch gets a fresh random crop/flip. The previous
# order (map -> cache) stored the first epoch's augmented images and replayed
# them unchanged in every later epoch.
train_Gps_maps = train_Gps_maps.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE)

train_general_maps = train_general_maps.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE)

# Test sets: preprocessing is deterministic, so caching the processed images is safe.
test_Gps_maps = test_Gps_maps.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().batch(BATCH_SIZE)

test_general_maps = test_general_maps.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().batch(BATCH_SIZE)

In [None]:
# get sample images
# Fix one batch from each training domain (the 618th satellite batch and the
# 67th layout batch). These are reused by the visual progress checks below.
sample_Gps_map = next(iter(train_Gps_maps.skip(617).take(1)))
sample_general_map = next(iter(train_general_maps.skip(66).take(1)))
print(sample_general_map.shape)  # check the shape: (batch, 256, 256, 3)

In [None]:
# show sample satellite map image
# Side-by-side: the satellite sample vs. one more random_jitter applied to it.
# The sample comes out of the training pipeline, so it has already been
# jittered once. Both panels map [-1, 1] back to [0, 1] for display.
plt.rcParams['figure.dpi'] = 500
plt.subplot(121)
plt.title('Satellite map')
plt.imshow(sample_Gps_map[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Satellite map with random jitter')
plt.imshow(random_jitter(sample_Gps_map[0]) * 0.5 + 0.5)

In [None]:
# show sample general layout map image
# Side-by-side: the general-layout sample vs. one more random_jitter applied
# to it. Pixel values are mapped from [-1, 1] to [0, 1] for display.
plt.subplot(121)
plt.title('General layout map')
plt.imshow(sample_general_map[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('General layout map with random jitter')
plt.imshow(random_jitter(sample_general_map[0]) * 0.5 + 0.5)

## Build the GAN models

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Sequential, regularizers
import tensorflow.keras as keras

def regularized_padded_conv(*args, **kwargs):
    """Return a bias-free, 'same'-padded Conv2D with He-normal init and L2(5e-4) weight decay.

    All arguments are forwarded to layers.Conv2D (filters, kernel_size,
    strides, activation, ...). Passing padding, use_bias, kernel_initializer or
    kernel_regularizer again raises TypeError (duplicate keyword).
    """
    fixed_options = dict(padding='same', use_bias=False,
                         kernel_initializer='he_normal',
                         kernel_regularizer=regularizers.l2(5e-4))
    return layers.Conv2D(*args, **kwargs, **fixed_options)

# channel attention mechanism
class ChannelAttention(layers.Layer):
    """Channel-attention half of CBAM.

    Global average pooling and global max pooling each turn the feature map
    into one descriptor per channel. A shared two-layer MLP
    (in_planes -> in_planes // ratio -> in_planes) runs on both descriptors.
    The two results are summed and passed through a sigmoid, giving a weight
    tensor of shape (batch, 1, 1, in_planes). Callers rescale a feature map with
    `ChannelAttention(...)(x) * x`.

    Args:
      in_planes: number of channels of the incoming feature map.
      ratio: reduction ratio of the MLP bottleneck.
      **kwargs: forwarded to `layers.Layer` (e.g. `name`).
    """

    def __init__(self, in_planes, ratio=16, **kwargs):
        super(ChannelAttention, self).__init__(**kwargs)
        self.in_planes = in_planes
        self.ratio = ratio

        # Attribute names are kept unchanged so existing checkpoints still map
        # onto these sub-layers.
        self.avg_out = layers.GlobalAveragePooling2D()
        self.max_out = layers.GlobalMaxPooling2D()

        # Clamp the bottleneck to at least one unit so in_planes < ratio still
        # builds a valid Dense layer. For in_planes >= ratio this is unchanged.
        hidden_units = max(1, in_planes // ratio)
        self.fc1 = layers.Dense(hidden_units, kernel_initializer='he_normal',
                                kernel_regularizer=regularizers.l2(5e-4),
                                activation=tf.nn.relu,
                                use_bias=True, bias_initializer='zeros')
        self.fc2 = layers.Dense(in_planes, kernel_initializer='he_normal',
                                kernel_regularizer=regularizers.l2(5e-4),
                                use_bias=True, bias_initializer='zeros')

    def call(self, inputs):
        # (batch, C) descriptors from both pooling branches.
        avg_out = self.avg_out(inputs)
        max_out = self.max_out(inputs)
        # Stack to (batch, 2, C) so the shared MLP processes both at once.
        out = tf.stack([avg_out, max_out], axis=1)
        out = self.fc2(self.fc1(out))
        # Merge the two branches and squash to (0, 1) channel weights.
        out = tf.nn.sigmoid(tf.reduce_sum(out, axis=1))
        # Reshape to (batch, 1, 1, C) so the weights broadcast over H and W.
        # tf.reshape replaces the layers.Reshape that was previously created
        # anew on every call.
        return tf.reshape(out, [-1, 1, 1, out.shape[-1]])

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super(ChannelAttention, self).get_config()
        config.update({'in_planes': self.in_planes, 'ratio': self.ratio})
        return config

# spatial attention mechanism
class SpatialAttention(layers.Layer):
    """Spatial-attention half of CBAM.

    The channel axis (axis 3, so channels-last input is assumed) is reduced to
    a per-pixel mean map and a per-pixel max map. The two maps are stacked as 2
    channels and passed through one `kernel_size` x `kernel_size` sigmoid
    convolution with 1 filter, giving a (batch, H, W, 1) weight map.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = regularized_padded_conv(
            1, kernel_size=kernel_size, strides=1, activation=tf.nn.sigmoid)

    def call(self, inputs):
        descriptors = [tf.reduce_mean(inputs, axis=3),
                       tf.reduce_max(inputs, axis=3)]
        return self.conv1(tf.stack(descriptors, axis=3))

In [None]:
# Instance normalization
class InstanceNormalization(tf.keras.layers.Layer):
  """Instance normalization with a learned per-channel scale and offset.

  Each sample is normalized over its spatial axes (1, 2) independently:
  scale * (x - mean) / sqrt(variance + epsilon) + offset.
  """

  def __init__(self, epsilon=1e-5):
    super(InstanceNormalization, self).__init__()
    self.epsilon = epsilon

  def build(self, input_shape):
    per_channel = input_shape[-1:]
    # scale starts near 1 (N(1, 0.02)); offset starts at 0.
    self.scale = self.add_weight(
        name='scale', shape=per_channel,
        initializer=tf.random_normal_initializer(1., 0.02), trainable=True)
    self.offset = self.add_weight(
        name='offset', shape=per_channel,
        initializer='zeros', trainable=True)

  def call(self, x):
    mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
    standardized = (x - mean) * tf.math.rsqrt(variance + self.epsilon)
    return self.scale * standardized + self.offset


In [None]:
def downsample(filters, size, norm_type='batchnorm', apply_norm=True):
  """Encoder block: stride-2 Conv2D (halves H and W) -> optional norm -> LeakyReLU.

  Args:
    filters: number of convolution filters.
    size: convolution kernel size.
    norm_type: 'batchnorm' or 'instancenorm' (case-insensitive). Any other
      value adds no normalization layer.
    apply_norm: False skips normalization entirely.

  Returns:
    A tf.keras.Sequential block.
  """
  init = tf.random_normal_initializer(0., 0.02)

  block = tf.keras.Sequential()
  block.add(tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                                   kernel_initializer=init, use_bias=False))

  if apply_norm:
    kind = norm_type.lower()
    if kind == 'batchnorm':
      block.add(tf.keras.layers.BatchNormalization())
    elif kind == 'instancenorm':
      block.add(InstanceNormalization())

  block.add(tf.keras.layers.LeakyReLU())
  return block


def upsample(filters, size, norm_type='batchnorm', apply_dropout=False):
  """Decoder block: stride-2 Conv2DTranspose (doubles H and W) -> norm -> optional Dropout(0.5) -> ReLU.

  Args:
    filters: number of transposed-convolution filters.
    size: kernel size.
    norm_type: 'batchnorm' or 'instancenorm' (case-insensitive). Any other
      value adds no normalization layer.
    apply_dropout: True inserts Dropout(0.5) before the ReLU.

  Returns:
    A tf.keras.Sequential block.
  """
  init = tf.random_normal_initializer(0., 0.02)

  block = tf.keras.Sequential()
  block.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                            padding='same',
                                            kernel_initializer=init,
                                            use_bias=False))

  kind = norm_type.lower()
  if kind == 'batchnorm':
    block.add(tf.keras.layers.BatchNormalization())
  elif kind == 'instancenorm':
    block.add(InstanceNormalization())

  if apply_dropout:
    block.add(tf.keras.layers.Dropout(0.5))

  block.add(tf.keras.layers.ReLU())
  return block


In [None]:
def unet_generator(output_channels, norm_type='batchnorm'):
  """Build a U-Net-style generator for 256x256x3 images with channel attention.

  Structure:
    - Eight stride-2 encoder blocks shrink the input to 1x1x512.
    - Channel attention (the channel half of CBAM) reweights the output of the
      6th encoder block. The spatial half is currently disabled.
    - Seven decoder blocks upsample again. Each output is concatenated with the
      mirrored encoder output (skip connection).
    - A final stride-2 Conv2DTranspose with tanh produces
      (batch, 256, 256, output_channels) in [-1, 1].

  Args:
    output_channels: channels of the generated image.
    norm_type: 'batchnorm' or 'instancenorm', forwarded to every block.

  Returns:
    A tf.keras.Model mapping (batch, 256, 256, 3) -> (batch, 256, 256, output_channels).
  """
  encoder = [
      downsample(64, 4, norm_type, apply_norm=False),  # (batch, 128, 128, 64)
      downsample(128, 4, norm_type),  # (batch, 64, 64, 128)
      downsample(256, 4, norm_type),  # (batch, 32, 32, 256)
      downsample(512, 4, norm_type),  # (batch, 16, 16, 512)
      downsample(512, 4, norm_type),  # (batch, 8, 8, 512)
      downsample(512, 4, norm_type),  # (batch, 4, 4, 512)
      downsample(512, 4, norm_type),  # (batch, 2, 2, 512)
      downsample(512, 4, norm_type),  # (batch, 1, 1, 512)
  ]

  # Sizes below are after concatenation with the matching skip connection.
  decoder = [
      upsample(512, 4, norm_type, apply_dropout=True),  # (batch, 2, 2, 1024)
      upsample(512, 4, norm_type, apply_dropout=True),  # (batch, 4, 4, 1024)
      upsample(512, 4, norm_type, apply_dropout=True),  # (batch, 8, 8, 1024)
      upsample(512, 4, norm_type),  # (batch, 16, 16, 1024)
      upsample(256, 4, norm_type),  # (batch, 32, 32, 512)
      upsample(128, 4, norm_type),  # (batch, 64, 64, 256)
      upsample(64, 4, norm_type),  # (batch, 128, 128, 128)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(
      output_channels, 4, strides=2,
      padding='same', kernel_initializer=initializer,
      activation='tanh')  # (batch, 256, 256, output_channels)

  # A single weightless Concatenate layer is reused for every skip connection.
  concat = tf.keras.layers.Concatenate()

  inputs = tf.keras.layers.Input(shape=[256, 256, 3])
  x = inputs

  skips = []
  for depth, down in enumerate(encoder, start=1):
    x = down(x)
    if depth == 6:
      # Channel attention on the 4x4x512 map. Spatial attention would be
      # `SpatialAttention()(x) * x` here; it is disabled.
      x = ChannelAttention(x.shape[-1])(x) * x
    skips.append(x)

  # The bottleneck (last encoder output) has no mirror in the decoder.
  for up, skip in zip(decoder, reversed(skips[:-1])):
    x = concat([up(x), skip])

  return tf.keras.Model(inputs=inputs, outputs=last(x))


In [None]:
def discriminator(norm_type='batchnorm', target=True):
  """Build a patch-level discriminator for 256x256x3 images.

  Layers, in order:
    - Input image, concatenated with a target image when `target` is True.
    - Three stride-2 downsample blocks.
    - A zero-padded 4x4 conv with normalization + LeakyReLU.
    - A final zero-padded 4x4 single-filter conv.
  The output is a (batch, 30, 30, 1) map of raw logits (no activation).

  Args:
    norm_type: 'batchnorm' or 'instancenorm' (case-insensitive). Any other
      value skips the normalization after the 512-filter conv, matching how
      downsample()/upsample() treat unknown types. Previously such a value
      crashed with UnboundLocalError on `norm1`.
    target: if True the model takes [input_image, target_image]. Otherwise it
      takes only input_image.

  Returns:
    A tf.keras.Model.
  """
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
  x = inp

  if target:
    tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')
    x = tf.keras.layers.concatenate([inp, tar])  # (batch, 256, 256, channels*2)

  # Optional CBAM attention (disabled). It can be inserted after any of the
  # three downsample blocks below as:
  #   x = ChannelAttention(x.shape[-1])(x) * x
  #   x = SpatialAttention()(x) * x
  x = downsample(64, 4, norm_type, False)(x)  # (batch, 128, 128, 64)
  x = downsample(128, 4, norm_type)(x)  # (batch, 64, 64, 128)
  x = downsample(256, 4, norm_type)(x)  # (batch, 32, 32, 256)

  zero_pad1 = tf.keras.layers.ZeroPadding2D()(x)  # (batch, 34, 34, 256)
  conv = tf.keras.layers.Conv2D(
      512, 4, strides=1, kernel_initializer=initializer,
      use_bias=False)(zero_pad1)  # (batch, 31, 31, 512)

  kind = norm_type.lower()
  if kind == 'batchnorm':
    norm1 = tf.keras.layers.BatchNormalization()(conv)
  elif kind == 'instancenorm':
    norm1 = InstanceNormalization()(conv)
  else:
    norm1 = conv  # unknown norm_type: no normalization

  leaky_relu = tf.keras.layers.LeakyReLU()(norm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (batch, 33, 33, 512)

  last = tf.keras.layers.Conv2D(
      1, 4, strides=1,
      kernel_initializer=initializer)(zero_pad2)  # (batch, 30, 30, 1) logits

  if target:
    return tf.keras.Model(inputs=[inp, tar], outputs=last)
  return tf.keras.Model(inputs=inp, outputs=last)

# Instantiate the GAN models

In [None]:
OUTPUT_CHANNELS = 3  # RGB images out of both generators

# CycleGAN needs two generators and two discriminators:
#   generator_g: satellite map -> general layout map
#   generator_f: general layout map -> satellite map
#   discriminator_x scores satellite images; discriminator_y scores layout images.
generator_g = unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

# target=False: each discriminator sees a single image, not an (input, target) pair.
discriminator_x = discriminator(norm_type='instancenorm', target=False)
discriminator_y = discriminator(norm_type='instancenorm', target=False)


In [None]:
# show the structure of generator G
# Model.summary() prints the table itself and returns None, so wrapping it in
# print() only added a stray "None" line.
generator_g.summary()
tf.keras.utils.plot_model(generator_g, show_shapes=True, dpi=300)

In [None]:
# show the structure of generator F
# Model.summary() prints the table itself and returns None, so wrapping it in
# print() only added a stray "None" line.
generator_f.summary()
tf.keras.utils.plot_model(generator_f, show_shapes=True, dpi=300)

In [None]:
# show the structure of discriminator X
# Model.summary() prints the table itself and returns None, so wrapping it in
# print() only added a stray "None" line.
discriminator_x.summary()
tf.keras.utils.plot_model(discriminator_x, show_shapes=True, dpi=300)

In [None]:
# show the structure of discriminator Y
# Model.summary() prints the table itself and returns None, so wrapping it in
# print() only added a stray "None" line.
discriminator_y.summary()
tf.keras.utils.plot_model(discriminator_y, show_shapes=True, dpi=300)

# Show the performance of the generators and discriminators before training (epoch 0)

In [None]:
# Translate each fixed sample with the still-untrained generators.
to_general_map = generator_g(sample_Gps_map)
to_Gps_map = generator_f(sample_general_map)
plt.figure(figsize=(8, 8))
contrast = 8  # NOTE(review): not used by this cell

imgs = [sample_Gps_map, to_general_map, sample_general_map, to_Gps_map]
title = ['Input satellite map', 'Predicted general layout map transferred by generator G', 'Input general layout map', 'Predicted satellite map transferred by generator F']

# Every panel is displayed the same way: [-1, 1] mapped to [0, 1]. The former
# `if i % 2 == 0` had two identical branches and is removed.
for i, (img, caption) in enumerate(zip(imgs, title)):
  plt.subplot(2, 2, i + 1)
  plt.title(caption, fontsize=9)
  plt.imshow(img[0] * 0.5 + 0.5)
plt.show()

In [None]:
plt.figure(figsize=(12, 12))
plt.suptitle('Output feature maps from discriminator X and discriminator Y', fontsize=18)

# Each entry: subplot slot, caption, discriminator, real input image.
for slot, caption, judge, image in (
    (221, 'Real satellite map→ Discriminator X', discriminator_x, sample_Gps_map),
    (222, 'Real general layout map→ Discriminator Y', discriminator_y, sample_general_map),
):
  plt.subplot(slot)
  plt.title(caption)
  # Show the single logit channel of the patch map as a heat map.
  plt.imshow(judge(image)[0, ..., -1], cmap='Spectral')
  plt.colorbar()

plt.show()

## Define the loss function and set optimizer and learning rate

In [None]:
LAMBDA = 10  # weight of the cycle-consistency loss; the identity loss uses 0.5 * LAMBDA

In [None]:
# The discriminators' last Conv2D has no activation (raw logits), hence from_logits=True.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True) 

In [None]:
def discriminator_loss(real, generated):
  """Discriminator BCE loss, halved.

  Real logits are pushed toward 1 and generated logits toward 0. The two terms
  are summed and multiplied by 0.5.
  """
  real_term = loss_obj(tf.ones_like(real), real)
  fake_term = loss_obj(tf.zeros_like(generated), generated)
  return (real_term + fake_term) * 0.5

In [None]:
def generator_loss(generated):
  """Adversarial generator loss: BCE pushing discriminator logits on fakes toward 1 ("real")."""
  labels_real = tf.ones_like(generated)
  return loss_obj(labels_real, generated)

In [None]:
def calc_cycle_loss(real_image, cycled_image):
  """Cycle-consistency loss: LAMBDA * mean absolute error between an image and its round-trip reconstruction."""
  mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return LAMBDA * mean_abs_error

In [None]:
def identity_loss(real_image, same_image):
  """Identity loss: 0.5 * LAMBDA * MAE between an image and the generator's output for that same image."""
  mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * mean_abs_error

In [None]:
# define optimizer and learning rate
# One Adam optimizer per network: learning rate 1e-4, beta_1 = 0.9.
# NOTE(review): beta_1=0.9 is Adam's default, so it has no effect here; many
# GAN setups use 0.5 instead. Confirm which was intended.
generator_g_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.9)
generator_f_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.9)

discriminator_x_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.9)
discriminator_y_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.9)

## Checkpoints

In [None]:
checkpoint_path = "./checkpoints/train"  # directory the checkpoint manager writes to

# Track all four networks and their optimizers. NOTE: the keyword names
# identify the tracked objects inside the checkpoint. Keep them stable so
# existing checkpoints still restore.
ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

# Keep at most the 5 most recent checkpoints.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)


## Define the training function

In [None]:
EPOCHS = 100  # number of passes over the zipped training sets

In [None]:
def _plot_translation(model, test_input, titles, separator):
  """Shared helper: plot the input beside model(input), print a separator, return the prediction.

  Args:
    model: generator to run (called in inference mode).
    test_input: normalized image batch in [-1, 1]; only element 0 is plotted.
    titles: (input title, prediction title).
    separator: text printed after the figure, without a trailing newline.

  Returns:
    The prediction reshaped to (-1, 256, 256, 3), still in [-1, 1].
  """
  prediction = model(test_input)

  plt.figure(figsize=(12, 12))
  for i, (caption, img) in enumerate(zip(titles, (test_input[0], prediction[0]))):
    plt.subplot(1, 2, i + 1)
    plt.title(caption)
    plt.imshow(img * 0.5 + 0.5)  # [-1, 1] -> [0, 1] for display
    plt.axis('off')
  plt.show()

  # Same output as printing the character once per iteration with end=''.
  print(separator, end='')
  return tf.reshape(prediction, (-1, 256, 256, 3))


def generate_images_g(model, test_input): # model, input data after normalization
  """Show a satellite map and its generator-G translation; return the normalized prediction."""
  return _plot_translation(
      model, test_input,
      ('Input satellite map', 'Predicted general layout map transferred by generator G'),
      '=' * 120)


def generate_images_f(model, test_input): # model, input data after normalization
  """Show a general layout map and its generator-F translation; return the normalized prediction."""
  return _plot_translation(
      model, test_input,
      ('Input general layout map', 'Predicted satellite map transferred by generator F'),
      '>' * 80)

In [None]:
@tf.function
def train_step(real_x, real_y):
  """Run one CycleGAN optimization step on a (satellite, layout) batch pair.

  X is the satellite domain (real_x) and Y the general layout domain (real_y).
  generator_g maps X -> Y and generator_f maps Y -> X. All four networks are
  updated once.

  Returns:
    (total_gen_g_loss, total_gen_f_loss, disc_x_loss, disc_y_loss) scalars.
  """
  # persistent=True: tape.gradient is called four times on this one tape.
  with tf.GradientTape(persistent=True) as tape:
    
    # X -> Y -> X round trip.
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    # Y -> X -> Y round trip.
    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # Identity mappings: each generator fed an image already in its output domain.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    
    # Both cycle terms are shared by the two generators.
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
    
    # total generator loss = adversarial loss + cycle loss + identity loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  
  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, generator_f.trainable_variables)
  
  discriminator_x_gradients = tape.gradient(disc_x_loss, discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, discriminator_y.trainable_variables)
  
  # apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables))
  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables))
  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients, discriminator_x.trainable_variables))
  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients, discriminator_y.trainable_variables))
  
  return total_gen_g_loss,total_gen_f_loss,disc_x_loss,disc_y_loss

# Define the functions used to calculate the FID score

In [None]:
import numpy as np
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from numpy import asarray
from numpy.random import randint
from scipy.linalg import sqrtm
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.applications.inception_v3 import preprocess_input
from tensorflow.keras.datasets.mnist import load_data
from skimage.transform import resize
 
# scale an array of images to a new size
def scale_images(images, new_shape):
    """Resize each image to `new_shape` with nearest-neighbour interpolation (order=0).

    Returns:
      The resized images stacked into one numpy array.
    """
    return asarray([resize(image, new_shape, 0) for image in images])
 
# calculate frechet inception distance
def calculate_fid(model, images1, images2, rescale_from_tanh=True):
    """Return the Frechet Inception Distance between two image sets.

    Both sets are resized to 299x299x3 and run through `model` (the pooled
    InceptionV3 feature extractor). The FID between the two Gaussian fits of
    the features is:
        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))

    Args:
      model: feature extractor exposing `predict`.
      images1, images2: iterables of H x W x 3 images.
      rescale_from_tanh: if True (default), images are assumed to lie in
        [-1, 1], the range produced by `normalize` and the tanh generators in
        this notebook. They are mapped to [0, 255] before Inception's
        `preprocess_input`, which expects [0, 255] input. Previously [-1, 1]
        data went straight into `preprocess_input`, which compressed it to
        about [-1.008, -0.992], so Inception saw a nearly constant image. Pass
        False for images that are already in [0, 255].

    Returns:
      The FID as a float; lower means more similar distributions.
    """
    images1 = scale_images(images1, (299, 299, 3))
    images2 = scale_images(images2, (299, 299, 3))

    if rescale_from_tanh:
        images1 = (images1 + 1.0) * 127.5
        images2 = (images2 + 1.0) * 127.5

    images1 = preprocess_input(images1)
    images2 = preprocess_input(images2)

    act1 = model.predict(images1)
    act2 = model.predict(images2)
    mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False)
    mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False)
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    # sqrtm may return tiny imaginary parts from numerical error; keep the real part.
    if iscomplexobj(covmean):
        covmean = covmean.real
    return ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)
 
# get the two image data sets used to calculate the FID score
def _stack_images(batches, dtype):
    """Concatenate numpy (b, 256, 256, 3) batches into one tensor of `dtype`; empty input gives shape (0, 256, 256, 3)."""
    if not batches:
        return tf.convert_to_tensor(np.zeros((0, 256, 256, 3), dtype=dtype))
    return tf.convert_to_tensor(np.concatenate(batches, axis=0).astype(dtype))


def get_scored_datasets(input_dataset,model,real_dataset,test_number):
    """Collect the image sets used for an FID score.

    Takes the first `test_number` elements of `input_dataset` and translates
    them with `model`. Also takes the first `test_number` elements of
    `real_dataset`. Batches are gathered in lists and concatenated once; the
    previous np.append-per-element approach copied the whole array on every
    step (quadratic).

    Returns:
      (input_image_set, fake_image_set, real_image_set) as tensors of shape
      (N, 256, 256, 3). The input set is float64; the fake and real sets are
      float32, as before.
    """
    input_batches, fake_batches, real_batches = [], [], []

    for inp in input_dataset.take(test_number):
        input_image = tf.reshape(inp, (-1, 256, 256, 3))
        fake_batches.append(model(input_image).numpy().astype(np.float32))
        input_batches.append(input_image.numpy().astype(np.float32))

    for inp in real_dataset.take(test_number):
        real_image = tf.reshape(inp, (-1, 256, 256, 3))
        real_batches.append(real_image.numpy().astype(np.float32))

    return (_stack_images(input_batches, np.float64),
            _stack_images(fake_batches, np.float32),
            _stack_images(real_batches, np.float32))

# FID feature extractor: InceptionV3 without its classification head. With
# global average pooling it outputs one pooled feature vector per
# 299x299x3 image.
model_FID = InceptionV3(include_top=False, pooling='avg', input_shape=(299,299,3))

## Start the training

In [None]:
# Per-epoch histories: mean training losses and FID scores, one entry per epoch.
gen_g_loss_plot,gen_f_loss_plot,disc_x_loss_plot,disc_y_loss_plot=np.array([]),np.array([]),np.array([]),np.array([])
gen_g_trainingset_fidplot,gen_f_trainingset_fidplot,gen_g_testset_fidplot,gen_f_testset_fidplot=np.array([]),np.array([]),np.array([]),np.array([])


for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  # Per-step losses within this epoch.
  gen_g_loss_list,gen_f_loss_list,disc_x_loss_list,disc_y_loss_list=np.array([]),np.array([]),np.array([]),np.array([])
  # One train_step per (satellite, layout) batch pair. The two training sets are
  # unpaired, and zip stops at the shorter dataset.
  for image_x, image_y in tf.data.Dataset.zip((train_Gps_maps, train_general_maps)):
    gen_g_loss_step, gen_f_loss_step, disc_x_loss_step, disc_y_loss_step = train_step(image_x, image_y)
    gen_g_loss_list=np.append(gen_g_loss_list,gen_g_loss_step.numpy())
    gen_f_loss_list=np.append(gen_f_loss_list,gen_f_loss_step.numpy())
    disc_x_loss_list=np.append(disc_x_loss_list,disc_x_loss_step.numpy())
    disc_y_loss_list=np.append(disc_y_loss_list,disc_y_loss_step.numpy())
    if n % 10 == 0:
      print ('.', end='')  # progress: one dot per 10 steps
    n+=1
  gen_g_loss_plot=np.append(gen_g_loss_plot,gen_g_loss_list.mean())
  gen_f_loss_plot=np.append(gen_f_loss_plot,gen_f_loss_list.mean())
  disc_x_loss_plot=np.append(disc_x_loss_plot,disc_x_loss_list.mean())
  disc_y_loss_plot=np.append(disc_y_loss_plot,disc_y_loss_list.mean())
  
  # FID on the first 100 batches of each set, for both directions, on the
  # training and test splits. calculate_fid receives (real set, generated set).
  input_image_set,fake_image_set,real_image_set=get_scored_datasets(train_Gps_maps,generator_g,train_general_maps,100)
  fid = calculate_fid(model_FID, real_image_set, fake_image_set)
  gen_g_trainingset_fidplot=np.append(gen_g_trainingset_fidplot,fid)

  input_image_set,fake_image_set,real_image_set=get_scored_datasets(train_general_maps,generator_f,train_Gps_maps,100)
  fid = calculate_fid(model_FID, real_image_set, fake_image_set)
  gen_f_trainingset_fidplot=np.append(gen_f_trainingset_fidplot,fid) 

  input_image_set,fake_image_set,real_image_set=get_scored_datasets(test_Gps_maps,generator_g,test_general_maps,100)
  fid = calculate_fid(model_FID, real_image_set, fake_image_set)
  gen_g_testset_fidplot=np.append(gen_g_testset_fidplot,fid) 

  input_image_set,fake_image_set,real_image_set=get_scored_datasets(test_general_maps,generator_f,test_Gps_maps,100)
  fid = calculate_fid(model_FID, real_image_set, fake_image_set)
  gen_f_testset_fidplot=np.append(gen_f_testset_fidplot,fid) 

  # Visual check after the first epoch and every 10th epoch: translate each
  # fixed sample, then translate the result back (full cycle).
  if epoch == 0 or (epoch + 1) % 10 == 0:
    print('\nTest on input satellite map from training set:')
    fake_image_g1=generate_images_g(generator_g,tf.reshape(sample_Gps_map,(-1,256,256,3)))
    fake_image_f1=generate_images_f(generator_f,tf.reshape(fake_image_g1,(-1,256,256,3)))
  
    print('\nTest on input general layout map from training set:')  
    fake_image_f2=generate_images_f(generator_f,tf.reshape(sample_general_map,(-1,256,256,3)))
    fake_image_g2=generate_images_g(generator_g,tf.reshape(fake_image_f2,(-1,256,256,3)))

  # NOTE: (epoch + 1) % EPOCHS == 0 holds only for the final epoch, so a single
  # checkpoint is written at the end of training.
  if (epoch + 1) % EPOCHS == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,ckpt_save_path))


  print ('\n Epoch-{}:Generator G Loss={}, Generator F Loss={}, Discriminator X Loss={}, Discriminator Y Loss={}.'.format(epoch + 1,gen_g_loss_plot[-1],gen_f_loss_plot[-1],disc_x_loss_plot[-1],disc_y_loss_plot[-1]))
  print ('\n Epoch-{}:Generator G(training set)-FID Score={}, Generator F(training set)-FID Score={}, Generator G(test set)-FID Score={},Generator F(test set)-FID Score={}.'.format(epoch + 1,gen_g_trainingset_fidplot[-1],gen_f_trainingset_fidplot[-1],gen_g_testset_fidplot[-1],gen_f_testset_fidplot[-1]))
  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,time.time()-start))

# Visualize the training process

In [None]:
import matplotlib.pyplot as plt

plt.rcParams['figure.dpi'] = 800


def training_visualisation (gen_g_loss_plot, gen_f_loss_plot, disc_x_loss_plot, disc_y_loss_plot, gen_g_trainingset_fidplot, gen_f_trainingset_fidplot, gen_g_testset_fidplot, gen_f_testset_fidplot, num_epochs=EPOCHS):
    """Plot per-epoch losses (top row) and FID scores (bottom row) in a 2x2 grid.

    Each series is plotted against epochs 1..n with
    n = min(num_epochs, len(series)). Previously any series whose length
    differed from num_epochs (for example after an interrupted run) made
    plt.plot raise ValueError.
    """
    def _panel(position, title, ylabel, curves):
        # curves: iterable of (series, line style, colour, legend label).
        plt.subplot(2, 2, position)
        plt.title(title)
        plt.xlabel("Training Epochs")
        plt.ylabel(ylabel)
        for series, ls, color, label in curves:
            n = min(num_epochs, len(series))
            plt.plot(range(1, n + 1), series[:n], ls=ls, color=color, lw=2, label=label)
        plt.legend(frameon=True)

    plt.figure(figsize=(12, 7))
    _panel(1, "Generator's Loss with Epochs", "Loss", [
        (gen_g_loss_plot, "-", "#0072BD", "Generator G"),
        (gen_f_loss_plot, "-.", "#D95319", "Generator F"),
    ])
    _panel(2, "Discriminator's Loss with Epochs", "Loss", [
        (disc_x_loss_plot, "-", "#EDB120", "Discriminator X"),
        (disc_y_loss_plot, "-.", "#7E2F8E", "Discriminator Y"),
    ])
    _panel(3, "General Layout Map Generation Mission's FID Score with Epochs", "FID Score", [
        (gen_g_trainingset_fidplot, "-", "#77AC30", "Generator G on training set"),
        (gen_g_testset_fidplot, "-.", "#4DBEEE", "Generator G on test set"),
    ])
    _panel(4, "Satellite Map Generation Mission's FID Score with Epochs", "FID Score", [
        (gen_f_trainingset_fidplot, "-", "#A2142F", "Generator F on training set"),
        (gen_f_testset_fidplot, "-.", "#143CA2", "Generator F on test set"),
    ])
    plt.tight_layout()
    plt.show()


training_visualisation(gen_g_loss_plot, gen_f_loss_plot, disc_x_loss_plot, disc_y_loss_plot, gen_g_trainingset_fidplot, gen_f_trainingset_fidplot, gen_g_testset_fidplot, gen_f_testset_fidplot, num_epochs=EPOCHS)

# After training, evaluate the generation performance on the test set

In [None]:
# Restore the latest checkpoint
# Restore the most recent checkpoint (if one exists) before evaluating.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

test_number = 15

# Satellite test maps: G translates to layout, F translates back (full cycle).
for i, inp in enumerate(test_Gps_maps.take(test_number), start=1):
    print('\nTest set-Input satellite map-'+str(i))
    fake_image_g1=generate_images_g(generator_g,tf.reshape(inp,(-1,256,256,3)))
    fake_image_f1=generate_images_f(generator_f,tf.reshape(fake_image_g1,(-1,256,256,3)))

# Layout test maps: F translates to satellite, G translates back.
for i, inp in enumerate(test_general_maps.take(test_number), start=1):
    print('\nTest set-Input general layout map-'+str(i))
    fake_image_f2=generate_images_f(generator_f,tf.reshape(inp,(-1,256,256,3)))
    fake_image_g2=generate_images_g(generator_g,tf.reshape(fake_image_f2,(-1,256,256,3)))
    

# After training, test the performance of the generators and discriminators on the sample images

In [None]:
# Translate each fixed sample with the trained generators.
to_general_map = generator_g(sample_Gps_map)
to_Gps_map = generator_f(sample_general_map)
plt.figure(figsize=(8, 8))
contrast = 8  # only used by the optional high-contrast display noted below

imgs = [sample_Gps_map, to_general_map, sample_general_map, to_Gps_map]
title = ['Input real satellite map', 'Predicted general layout map transferred by generator G', 'Input real general layout map', 'Predicted satellite map transferred by generator F']

# Every panel is displayed the same way: [-1, 1] mapped to [0, 1]. The former
# `if i % 2 == 0` had two identical branches and is removed.
# For a higher-contrast view of the predictions use:
#   plt.imshow(img[0] * 0.5 * contrast + 0.5)
for i, (img, caption) in enumerate(zip(imgs, title)):
  plt.subplot(2, 2, i + 1)
  plt.title(caption, fontsize=9)
  plt.imshow(img[0] * 0.5 + 0.5)
plt.show()

In [None]:
plt.figure(figsize=(12, 12))
plt.suptitle('Output feature maps from discriminator X and discriminator Y', fontsize=18)

# Each entry: subplot slot, caption, discriminator, image it judges (real or generated).
for slot, caption, judge, image in (
    (221, 'Real satellite map→ Discriminator X', discriminator_x, sample_Gps_map),
    (222, 'G:Fake general layout map→ Discriminator Y', discriminator_y, to_general_map),
    (223, 'Real general layout map→ Discriminator Y', discriminator_y, sample_general_map),
    (224, 'F:Fake satellite map→ Discriminator X', discriminator_x, to_Gps_map),
):
  plt.subplot(slot)
  plt.title(caption)
  # Show the single logit channel of the patch map as a heat map.
  plt.imshow(judge(image)[0, ..., -1], cmap='Spectral')
  plt.colorbar()
plt.show()

# Finally, test the generation performance on maps of Leeds

In [None]:
import tensorflow as tf
import matplotlib.pylab as plt

def test_leeds_image(input_image_dir,generate_images_g_or_f,generator_g_or_f,real_image_dir):
    """Translate one external map image and load its real counterpart for comparison.

    Args:
      input_image_dir: path of the JPEG fed to the generator.
      generate_images_g_or_f: generate_images_g or generate_images_f; it plots
        the input next to the prediction.
      generator_g_or_f: the generator matching that helper.
      real_image_dir: path of the JPEG with the corresponding real map. It is
        shown unresized in its own small figure.

    Returns:
      (input_image, fake_image, real_image). input_image and real_image are
      (1, 256, 256, 3) tensors in [-1, 1]; fake_image is whatever
      generate_images_g_or_f returns.
    """
    def _load_rgb(path):
        # tf.io.read_file's second positional argument is the op `name`, so the
        # old 'r' argument did nothing. channels=3 guarantees RGB output, which
        # the reshape to (-1, 256, 256, 3) below relies on.
        return tf.image.decode_jpeg(tf.io.read_file(path), channels=3)

    def _to_model_input(image):
        image = tf.image.resize(image, [256, 256], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        return tf.reshape(normalize(image), (-1, 256, 256, 3))

    input_image = _to_model_input(_load_rgb(input_image_dir))
    fake_image = generate_images_g_or_f(generator_g_or_f, input_image)

    real_image = _load_rgb(real_image_dir)
    plt.figure(figsize=(2, 2))
    plt.imshow(real_image)
    plt.title('Google map-version: 10.49.2')
    real_image = _to_model_input(real_image)
    return input_image, fake_image, real_image


In [None]:
# Translate three Leeds satellite crops with generator G and show the matching
# Google layout maps. os.path.join replaces the backslash string literals,
# which only worked on Windows and relied on invalid escape sequences
# ('\.', '\d', '\l', ...). On Windows the resulting paths are identical.
leeds_dir = os.path.join('..', '..', 'dataset', 'leeds')
for area in ('roundhay', 'woodhouse', 'airport'):
    input_image, fake_image, real_image = test_leeds_image(
        os.path.join(leeds_dir, 'Gps_' + area + '.jpg'),
        generate_images_g,
        generator_g,
        os.path.join(leeds_dir, 'Layout_' + area + '.jpg'))

