In [None]:
import tensorflow as tf
import keras 
from keras import layers
import numpy as np
from numba import cuda
import matplotlib.pyplot as plt
import cv2
import os
from tqdm import tqdm
import re
from keras.utils import img_to_array, plot_model, array_to_img
import gc

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # Memory growth has to be configured identically on every GPU, and before
        # the GPUs are initialised; otherwise TensorFlow raises RuntimeError.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        gpu_report = (len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
        print(*gpu_report)
    except RuntimeError as e:
        print(e)

In [None]:
def sorted_alphanumeric(data):
    """Return ``data`` sorted in natural (human) order, e.g. '2.jpg' before '10.jpg'.

    Each string is split on runs of ASCII digits. Digit runs compare numerically,
    and the text between them compares case-insensitively.

    Args:
        data: iterable of strings (here: file names).

    Returns:
        A new sorted list; ``data`` itself is not modified.
    """
    def alphanum_key(key):
        # re.split with a capturing group alternates text, digit-run, text, ...,
        # so odd positions are always ASCII digit runs. Choosing int vs str by
        # position (rather than str.isdigit) keeps ints and strs aligned across
        # keys and avoids int() on Unicode digits such as '²', for which
        # isdigit() is True but int() fails.
        return [int(part) if idx % 2 else part.lower()
                for idx, part in enumerate(re.split('([0-9]+)', key))]

    return sorted(data, key=alphanum_key)
# Side length (pixels) every image is resized to.
SIZE = 256
color_img = []
gray_img = []
path = 'E:/TER/LandscapeDataResize/color'
files = os.listdir(path)
files = sorted_alphanumeric(files)
for i in tqdm(files):
    # Files are visited in natural sort order; loading stops at '7200.jpg', so only
    # the images sorted before it are used.
    if i == '7200.jpg':
        break
    img = cv2.imread(path + '/' + i, 1)
    if img is None:
        # cv2.imread returns None (instead of raising) for unreadable or non-image
        # files; fail here with a clear message rather than inside cv2.resize.
        raise ValueError('Could not read image: ' + path + '/' + i)
    img = cv2.resize(img, (SIZE, SIZE))
    # OpenCV reads images as BGR; convert to Lab so L can serve as the grayscale input.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)
    # Split the channels and scale each 8-bit channel to [0, 1].
    l, a, b = cv2.split(img)
    l = l / 255.0
    a = a / 255.0
    b = b / 255.0
    # Re-merge into the normalised Lab image used as the training target.
    img = cv2.merge([l, a, b])
    color_img.append(img_to_array(img))
    gray_img.append(img_to_array(l))
color_dataset = tf.data.Dataset.from_tensor_slices(np.array(color_img[0:2000])).batch(20)
gray_dataset = tf.data.Dataset.from_tensor_slices(np.array(gray_img[0:2000])).batch(20)
gc.collect()

In [None]:
# Take the first batch of each dataset for quick visual checks.
example_color, example_gray = next(iter(color_dataset)), next(iter(gray_dataset))

In [None]:
def plot_images(a=4):
    """Show the first ``a`` samples of ``example_color`` next to ``example_gray``.

    Relies on the globals ``example_color`` (batch of [0, 1]-scaled Lab images) and
    ``example_gray`` (the matching L-channel batch) from the previous cell.

    Args:
        a: number of samples to display; must not exceed the batch size.
    """
    for i in range(a):
        plt.figure(figsize=(10, 10))
        plt.subplot(121)
        plt.title('color')
        # Channel names are suffixed so they do not shadow the parameter ``a``.
        l_ch, a_ch, b_ch = cv2.split(np.array(example_color[i]))
        # Undo the [0, 1] scaling from the loading cell before converting to RGB.
        l_ch = l_ch.astype('float32') * 100
        a_ch = a_ch.astype('float32') * 255.0 - 128
        b_ch = b_ch.astype('float32') * 255.0 - 128
        color = cv2.merge([l_ch, a_ch, b_ch])
        plt.imshow(cv2.cvtColor(color, cv2.COLOR_Lab2RGB))
        plt.subplot(122)
        plt.title('gray')
        plt.imshow(example_gray[i], cmap='gray')
        plt.show()

In [None]:
plot_images(a=3)  # visual sanity check: three colour / grayscale pairs from the first batch

In [None]:
def downsample(filters, size, apply_batchnorm=True):
  """Encoder block: stride-2 Conv2D -> optional BatchNorm -> LeakyReLU.

  Args:
    filters: number of output channels of the convolution.
    size: convolution kernel size.
    apply_batchnorm: whether to insert BatchNormalization after the convolution.

  Returns:
    A ``tf.keras.Sequential`` holding the block's layers.
  """
  block_layers = [
      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer='he_normal', use_bias=False)
  ]
  if apply_batchnorm:
    block_layers.append(tf.keras.layers.BatchNormalization())
  block_layers.append(tf.keras.layers.LeakyReLU())
  return tf.keras.Sequential(block_layers)

def upsample(filters, size, apply_dropout=False):
  """Decoder block: stride-2 Conv2DTranspose -> BatchNorm -> optional Dropout(0.5) -> ReLU.

  Args:
    filters: number of output channels of the transposed convolution.
    size: kernel size of the transposed convolution.
    apply_dropout: whether to insert Dropout(0.5) after the batch normalisation.

  Returns:
    A ``tf.keras.Sequential`` holding the block's layers.
  """
  block_layers = [
      tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                                      kernel_initializer='he_normal',
                                      use_bias=False),
      tf.keras.layers.BatchNormalization(),
  ]
  if apply_dropout:
    block_layers.append(tf.keras.layers.Dropout(0.5))
  block_layers.append(tf.keras.layers.ReLU())
  return tf.keras.Sequential(block_layers)

In [None]:
def Generator():
  """Build the U-Net generator: (bs, 256, 256, 1) input -> (bs, 256, 256, 3) tanh output.

  Eight ``downsample`` blocks take the spatial size 256 -> 1. Seven ``upsample``
  blocks bring it back to 128, each followed by concatenation with the encoder
  output of matching size. A final stride-2 Conv2DTranspose restores 256 x 256
  with 3 channels.

  NOTE(review): the loading cell scales targets to [0, 1], while tanh outputs lie in
  [-1, 1]; confirm this is intended (sigmoid, or data rescaled to [-1, 1], would match).
  """
  inputs = tf.keras.layers.Input(shape=[256,256,1])

  # Encoder specs: (filters, apply_batchnorm). Spatial size 256 -> 128 -> ... -> 1.
  encoder_specs = [(64, False), (128, True), (256, True), (512, True),
                   (512, True), (512, True), (512, True), (512, True)]
  down_stack = [downsample(f, 4, apply_batchnorm=bn) for f, bn in encoder_specs]

  # Decoder specs: (filters, apply_dropout). Spatial size 1 -> 2 -> ... -> 128; the skip
  # concatenation that follows each block doubles its channel count.
  decoder_specs = [(512, True), (512, True), (512, True), (512, False),
                   (256, False), (128, False), (64, False)]
  up_stack = [upsample(f, 4, apply_dropout=dr) for f, dr in decoder_specs]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(3, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh')  # (bs, 256, 256, 3)

  encoded = []
  x = inputs
  for block in down_stack:
    x = block(x)
    encoded.append(x)

  # The bottleneck (last encoder output) has no skip partner; the remaining encoder
  # outputs are consumed deepest-first.
  for block, skip in zip(up_stack, encoded[-2::-1]):
    x = block(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  return tf.keras.Model(inputs=inputs, outputs=last(x))


In [None]:
def Discriminator():
  """Build the PatchGAN discriminator.

  Takes ``[input_image (bs, 256, 256, 1), target_image (bs, 256, 256, 3)]`` and returns a
  (bs, 30, 30, 1) map of raw logits; there is no activation on the last layer.
  """
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[256, 256, 1], name='input_image')
  tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')

  # Condition on the grayscale input by stacking it with the colour image.
  x = tf.keras.layers.concatenate([inp, tar])  # (bs, 256, 256, 4)

  # Three downsampling blocks: 256 -> 128 -> 64 -> 32; the first skips batch norm.
  for filters, use_batchnorm in ((64, False), (128, True), (256, True)):
    x = downsample(filters, 4, use_batchnorm)(x)  # ends at (bs, 32, 32, 256)

  x = tf.keras.layers.ZeroPadding2D()(x)  # (bs, 34, 34, 256)
  x = tf.keras.layers.Conv2D(512, 4, strides=1,
                             kernel_initializer=initializer,
                             use_bias=False)(x)  # (bs, 31, 31, 512)
  x = tf.keras.layers.BatchNormalization()(x)
  x = tf.keras.layers.LeakyReLU()(x)
  x = tf.keras.layers.ZeroPadding2D()(x)  # (bs, 33, 33, 512)
  logits = tf.keras.layers.Conv2D(1, 4, strides=1,
                                  kernel_initializer=initializer)(x)  # (bs, 30, 30, 1)

  return tf.keras.Model(inputs=[inp, tar], outputs=logits)


In [None]:
# Build the generator, print its layer summary and save an architecture diagram.
generator = Generator()
generator.summary()
plot_model(
    generator,
    to_file='generator_lab.png',
    show_shapes=True,
    show_layer_names=True,
    dpi=320,
)

In [None]:
# Build the discriminator, print its layer summary and save an architecture diagram.
discriminator = Discriminator()
discriminator.summary()
plot_model(
    discriminator,
    to_file='discriminator_lab.png',
    show_shapes=True,
    show_layer_names=True,
    dpi=320,
)

In [None]:
# Binary cross-entropy on raw logits (the discriminator's last layer has no activation).
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Separate Adam optimizers (lr 2e-4, beta_1 0.5), so each network keeps its own state.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

# Weight of the L1 reconstruction term in the generator loss.
LAMBDA = 100

def generator_loss(disc_generated_output, gen_output, target):
  """Generator objective: adversarial BCE + LAMBDA * mean absolute error.

  Args:
    disc_generated_output: discriminator logits for (input, generated) pairs.
    gen_output: generator output batch.
    target: ground-truth batch.

  Returns:
    Tuple ``(total_gen_loss, gan_loss, l1_loss)``.
  """
  # Reconstruction term: mean absolute error against the ground truth.
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
  # Adversarial term: reward outputs the discriminator labels as real (1).
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
  return gan_loss + (LAMBDA * l1_loss), gan_loss, l1_loss

def discriminator_loss(disc_real_output, disc_generated_output):
  """Discriminator objective: sum of the BCE terms for real pairs and generated pairs.

  Real-pair logits are pushed toward label 1 and generated-pair logits toward label 0.
  """
  terms = [
      loss_object(tf.ones_like(disc_real_output), disc_real_output),
      loss_object(tf.zeros_like(disc_generated_output), disc_generated_output),
  ]
  return terms[0] + terms[1]

In [None]:
def _lab01_to_rgb(lab01):
  """Decode an HxWx3 Lab image whose channels are scaled to [0, 1] into float RGB.

  Reverses the loading cell's scaling (L * 100, a/b * 255 - 128) and converts with
  OpenCV. The result is clipped to [0, 1] because plt.imsave rejects float RGB
  outside that range.
  """
  l_ch, a_ch, b_ch = cv2.split(np.asarray(lab01, dtype=np.float32))
  lab = cv2.merge([l_ch * 100, a_ch * 255.0 - 128, b_ch * 255.0 - 128])
  return np.clip(cv2.cvtColor(lab, cv2.COLOR_LAB2RGB), 0.0, 1.0)


def generate_images(model, test_input, tar, n=0,
                    out_dir='E:/TER/Result_pix2pix_LandscapeDataResize_LAB'):
  """Predict on a batch, then plot and save results for its first element.

  Args:
    model: generator mapping L-channel batches to [0, 1]-scaled Lab batches.
    test_input: batch of L-channel inputs; only element 0 is shown.
    tar: matching batch of [0, 1]-scaled Lab targets.
    n: index used in the saved file names.
    out_dir: root folder. The predicted RGB image is written to
      ``<out_dir>/output/<n>.png`` and the comparison figure to
      ``<out_dir>/plot/<n>.png``. Both sub-folders must already exist.
  """
  prediction = model(test_input, training=False)
  fig = plt.figure(figsize=(15, 5))
  target_image = _lab01_to_rgb(tar[0])
  prediction_image = _lab01_to_rgb(prediction[0])
  display_list = [test_input[0], target_image, prediction_image]
  title = ['Input Image', 'Ground Truth', 'Predicted Image']
  # The error is measured in the [0, 1]-scaled Lab space the model is trained on.
  MSE = tf.keras.losses.MeanSquaredError()(tar[0], prediction[0])
  for i in range(3):
    plt.subplot(1, 3, i + 1)
    plt.title(title[i])
    # cmap only affects the single-channel input; RGB panels ignore it.
    plt.imshow(display_list[i], cmap='gray')
    plt.axis('off')
  plt.suptitle("MSE: " + str(MSE.numpy()), fontsize=20)
  plt.imsave(f'{out_dir}/output/{n}.png', prediction_image)
  # Give every comparison figure its own file. A bare directory path would make
  # matplotlib write '<dir>/.png' and overwrite it on every call.
  fig.savefig(f'{out_dir}/plot/{n}.png')
  plt.show()
  plt.close(fig)

In [None]:
@tf.function
def train_step(input_image, target, epoch):
  """Run one pix2pix update on a batch.

  Args:
    input_image: batch of L-channel inputs for the generator.
    target: matching batch of Lab targets.
    epoch: not used in the body. Pass it as a tensor: a Python int makes
      tf.function trace a new graph for every distinct value.

  Returns:
    Tuple ``(gen_total_loss, disc_loss)`` for this batch.
  """
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    fake_image = generator(input_image, training=True)

    real_logits = discriminator([input_image, target], training=True)
    fake_logits = discriminator([input_image, fake_image], training=True)

    gen_total_loss, _, _ = generator_loss(fake_logits, fake_image, target)
    disc_loss = discriminator_loss(real_logits, fake_logits)

  gen_vars = generator.trainable_variables
  disc_vars = discriminator.trainable_variables
  # Each tape differentiates only its own network's loss w.r.t. its own weights.
  gen_grads = gen_tape.gradient(gen_total_loss, gen_vars)
  disc_grads = disc_tape.gradient(disc_loss, disc_vars)

  generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
  discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))
  return gen_total_loss, disc_loss

In [None]:
import time
def fit(epochs, chunks=((0, 2000), (2000, 4000)), batch_size=16):
  """Train the generator/discriminator pair and save both models after every epoch.

  Each epoch walks over the training images one slice at a time (``chunks``), so
  only one slice is copied into a ``tf.data.Dataset`` at any moment.

  Args:
    epochs: number of epochs to train.
    chunks: ``(start, stop)`` index ranges into ``color_img`` / ``gray_img``
      visited each epoch, in order.
    batch_size: batch size used for every slice.

  Returns:
    Tuple ``(gen_loss, dis_loss)``: per-epoch mean generator / discriminator losses.
  """
  gen_loss = []
  dis_loss = []
  for epoch in range(epochs):
    result_gen_loss = []
    result_disc_loss = []
    start = time.time()
    print("Epoch: " + str(epoch + 1), end="\r")
    # Pass the epoch as a tensor: tf.function retraces for every new Python int
    # argument, which would rebuild the training graph once per epoch.
    epoch_t = tf.constant(epoch)
    for lo, hi in chunks:
      color_ds = tf.data.Dataset.from_tensor_slices(np.array(color_img[lo:hi])).batch(batch_size)
      gray_ds = tf.data.Dataset.from_tensor_slices(np.array(gray_img[lo:hi])).batch(batch_size)
      for input_image, target in tf.data.Dataset.zip((gray_ds, color_ds)):
        result_g, result_d = train_step(input_image, target, epoch_t)
        result_gen_loss.append(result_g)
        result_disc_loss.append(result_d)
      # Release this slice before the next one is materialised.
      del color_ds, gray_ds
      gc.collect()
    generator.compile()
    discriminator.compile()
    generator.save('E:/TER/model/generator_lab.h5')
    discriminator.save('E:/TER/model/discriminator_lab.h5')
    mean_gen = np.mean(result_gen_loss)
    mean_disc = np.mean(result_disc_loss)
    gen_loss.append(mean_gen)
    dis_loss.append(mean_disc)
    print('Time taken for epoch {} is {} sec. Generator loss: {} discriminator loss: {}'.format(
        epoch + 1, time.time() - start, mean_gen, mean_disc))
  return gen_loss, dis_loss

In [None]:
# Make sure both result folders exist before generate_images writes into them.
for result_dir in ('E:/TER/Result_pix2pix_LandscapeDataResize_LAB/output/',
                   'E:/TER/Result_pix2pix_LandscapeDataResize_LAB/plot/'):
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)
# Preview generator output on the first two batches of the preview datasets.
# NOTE(review): both calls use the default n=0, so the second overwrites the first.
preview_pairs = tf.data.Dataset.zip((gray_dataset, color_dataset)).take(2)
for example_input, example_target in preview_pairs:
  generate_images(generator, example_input, example_target)

In [None]:
gen_loss, dis_loss = fit(epochs=100)  # per-epoch mean generator / discriminator losses

In [None]:
# Side-by-side curves of the per-epoch mean losses returned by fit().
fig, (ax_gen, ax_dis) = plt.subplots(1, 2)
ax_gen.plot(gen_loss)
ax_gen.set(title='Generator loss', ylabel='loss', xlabel='epoch')
ax_dis.plot(dis_loss)
ax_dis.set(title='Discriminator loss', xlabel='epoch')
fig.savefig('./GAN_lab_loss.png')
plt.show()

In [None]:
# Render and save predictions for images from index 4000 onward (beyond the slices
# fit() trains on by default), one image per batch.
N_TEST_PREVIEWS = 319  # number of held-out images to render and save
color_dataset_t = tf.data.Dataset.from_tensor_slices(np.array(color_img[4000:])).batch(1)
gray_dataset_t = tf.data.Dataset.from_tensor_slices(np.array(gray_img[4000:])).batch(1)
test_pairs = tf.data.Dataset.zip((gray_dataset_t, color_dataset_t)).take(N_TEST_PREVIEWS)
for n, (example_input, example_target) in enumerate(test_pairs):
  generate_images(generator, example_input, example_target, n)