In [1]:
from google.colab import drive
drive.mount('/content/gdrive')
root_path = 'gdrive/My Drive/MIE1516 Project/' 

Mounted at /content/gdrive


In [0]:
import os
import cv2
import h5py
import numpy
import math
import skimage
from math import log10, sqrt 
from skimage.metrics import structural_similarity, peak_signal_noise_ratio

In [0]:
DATA_PATH = root_path + "/DIV2K_train_HR"
LR_DATA_PATH = root_path + "/DIV2K_train_LR_bicubic_X2/"
TEST_PATH = root_path + "/DIV2K_valid_HR/"

In [0]:
import scipy
from skimage import transform, io
import imageio
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
#Link: https://github.com/eriklindernoren/Keras-GAN
class DataLoader():
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "test"
        
        path = glob('%s/*' % (self.dataset_name))

        batch_images = np.random.choice(path, size=batch_size)

        imgs_hr = []
        imgs_lr = []
        for img_path in batch_images:
            img = self.imread(img_path)

            h, w = self.img_res
            low_h, low_w = int(h / 4), int(w / 4)

            img_hr = transform.resize(img, self.img_res)
            img_lr = transform.resize(img, (low_h, low_w))

            # If training => do random flip
            if not is_testing and np.random.random() < 0.5:
                img_hr = np.fliplr(img_hr)
                img_lr = np.fliplr(img_lr)

            imgs_hr.append(img_hr)
            imgs_lr.append(img_lr)

        imgs_hr = np.array(imgs_hr) / 127.5 - 1.
        imgs_lr = np.array(imgs_lr) / 127.5 - 1.
        print(imgs_hr.shape)
        print(imgs_lr.shape)
        return imgs_hr, imgs_lr


    def imread(self, path):
        return imageio.imread(path, pilmode='RGB').astype(np.float)


In [0]:
class ImageMetrics():
  def __init__(self):
    self.name = "Metric for Image"
  
  def print_metrics(self, reference, generated):
    self.reference = (reference)*255.0
    self.generated = (generated)*255.0
    assert self.reference.shape == self.generated.shape
    psnr_score, ssim_score = self.psnr(), self.ssim()
    print("PSNR: %.4f SSIM: %.4f" % (psnr_score, ssim_score))
    return psnr_score, ssim_score

  #Link: https://dsp.stackexchange.com/questions/38065/peak-signal-to-noise-ratio-psnr-in-python-for-an-image
  def psnr(self):
    mse = np.mean((self.reference-self.generated)**2)
    if(mse == 0):
      mse = 100
    MAX_PIXEL = 255.0
    psnr_score = 20*log10(MAX_PIXEL/sqrt(mse))
    return psnr_score

  #Link: https://scikit-image.org/docs/dev/api/skimage.metrics.html#skimage.metrics.structural_similarity
  def ssim(self):
    (ssim_score, difference) = structural_similarity(self.reference, self.generated, multichannel= True, gaussian_weights= True, full=True)
    return ssim_score


In [9]:
"""
Super-resolution of CelebA using Generative Adversarial Networks.

The dataset can be downloaded from: https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=0

Instrustion on running the script:
1. Download the dataset from the provided link
2. Save the folder 'img_align_celeba' to 'datasets/'
4. Run the sript using command 'python srgan.py'
"""
#Link: https://github.com/eriklindernoren/Keras-GAN
from __future__ import print_function, division
import scipy
!pip install git+https://www.github.com/keras-team/keras-contrib.git

%tensorflow_version 1.x


from keras.datasets import mnist
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from keras.layers.advanced_activations import PReLU, LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.applications import VGG19
from keras.models import Sequential, Model
from keras.optimizers import Adam

import tensorflow as tf
import datetime
import matplotlib.pyplot as plt
import sys
#from data_loader import DataLoader
import numpy as np
import os
from collections import defaultdict
try:
    import cPickle as pickle
except ImportError:
    import pickle

import keras.backend as K

class SRGAN():
    def __init__(self):
        # Input shape
        self.channels = 3
        self.lr_height = 64                 # Low resolution height
        self.lr_width = 64                  # Low resolution width
        self.lr_shape = (self.lr_height, self.lr_width, self.channels)
        self.hr_height = self.lr_height*4   # High resolution height
        self.hr_width = self.lr_width*4     # High resolution width
        self.hr_shape = (self.hr_height, self.hr_width, self.channels)

        # Number of residual blocks in the generator
        self.n_residual_blocks = 16

        optimizer = Adam(0.0002, 0.5)

        # We use a pre-trained VGG19 model to extract image features from the high resolution
        # and the generated high resolution images and minimize the mse between them
        self.vgg = self.build_vgg()
        self.vgg.trainable = False
        self.vgg.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Configure data loader
        self.save_path_name = root_path
        self.dataset_name = DATA_PATH
        self.test_dataset_name = TEST_PATH
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.hr_height, self.hr_width))
        self.test_data_loader = DataLoader(dataset_name=self.test_dataset_name,
                                      img_res=(self.hr_height, self.hr_width))
        self.metrics = ImageMetrics()
        
        self.model_gen_path = self.save_path_name + "SRGAN_PL/checkpoints/model_gen_april11.h5"
        self.model_disc_path = self.save_path_name + "SRGAN_PL/checkpoints/model_disc_april11.h5"
        self.model_combined_path = self.save_path_name + "SRGAN_PL/checkpoints/model_combined_april11.h5"
        self.sample_images_path = self.save_path_name + "SRGAN_PL/sample_image_outputs"
        self.history_loss_accuracy_path = self.save_path_name + "SRGAN_PL/srgan-pl-history.pkl"

        # Calculate output shape of D (PatchGAN)
        patch = int(self.hr_height / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 64
        self.df = 64

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # High res. and low res. images
        img_hr = Input(shape=self.hr_shape)
        img_lr = Input(shape=self.lr_shape)

        # Generate high res. version from low res.
        fake_hr = self.generator(img_lr)

        # Extract image features of the generated img
        fake_features = self.vgg(fake_hr)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # Discriminator determines validity of generated high res. images
        validity = self.discriminator(fake_hr)

        self.combined = Model([img_lr, img_hr], [validity, fake_features])
        self.combined.compile(loss=['binary_crossentropy', 'mse'],
                              loss_weights=[1e-3, 1],
                              optimizer=optimizer)


    def build_vgg(self):
        """
        Builds a pre-trained VGG19 model that outputs image features extracted at the
        third block of the model
        """
        vgg = VGG19(weights="imagenet")
        # Set outputs to outputs of last conv. layer in block 3
        # See architecture at: https://github.com/keras-team/keras/blob/master/keras/applications/vgg19.py
        vgg.outputs = [vgg.layers[9].output]

        img = Input(shape=self.hr_shape)

        # Extract image features
        img_features = vgg(img)

        return Model(img, img_features)

    def build_generator(self):

        def residual_block(layer_input, filters):
            """Residual block described in paper"""
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
            d = Activation('relu')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Add()([d, layer_input])
            return d

        def deconv2d(layer_input):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
            u = Activation('relu')(u)
            return u

        # Low resolution image input
        img_lr = Input(shape=self.lr_shape)

        # # Pre-residual block
        # c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
        # c1 = Activation('relu')(c1)

        # # Propogate through residual blocks
        # r = residual_block(c1, self.gf)
        # for _ in range(self.n_residual_blocks - 1):
        #     r = residual_block(r, self.gf)

        # # Post-residual block
        # c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
        # c2 = BatchNormalization(momentum=0.8)(c2)
        # c2 = Add()([c2, c1])

        # # Upsampling
        # u1 = deconv2d(c2)
        # u2 = deconv2d(u1)

        # # Generate high resolution output
        # gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)
        c1 = Conv2D(128, kernel_size=9, strides=1, padding='same')(img_lr)
        c1 = Activation('relu')(c1)
        c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(c1)
        c2 = Activation('relu')(c2)
        c3 = Conv2D(3, kernel_size=5, strides=1, padding='same')(c2)
        c3 = Activation('linear')(c3)
        u1 = UpSampling2D(size=4)(c3)
        gen_hr = u1
        
        #gen_hr = Reshape(self.hr_shape)

        model_gen = Model(img_lr, gen_hr)
        print(model_gen.summary())
        return model_gen

    def build_discriminator(self):

        def d_block(layer_input, filters, strides=1, bn=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        # Input img
        d0 = Input(shape=self.hr_shape)

        d1 = d_block(d0, self.df, bn=False)
        d2 = d_block(d1, self.df, strides=2)
        d3 = d_block(d2, self.df*2)
        d4 = d_block(d3, self.df*2, strides=2)
        d5 = d_block(d4, self.df*4)
        d6 = d_block(d5, self.df*4, strides=2)
        d7 = d_block(d6, self.df*8)
        d8 = d_block(d7, self.df*8, strides=2)

        d9 = Dense(self.df*16)(d8)
        d10 = LeakyReLU(alpha=0.2)(d9)
        validity = Dense(1, activation='sigmoid')(d10)

        model_disc = Model(d0, validity)
        print(model_disc.summary())
        return model_disc

    def train(self, epochs, batch_size=1, sample_interval=50):

        start_time = datetime.datetime.now()

        if(os.path.exists(self.model_gen_path)):
          print("I am here")
          self.generator.load_weights(self.model_gen_path)

        if(os.path.exists(self.model_disc_path)):
          print("I am here")
          self.discriminator.load_weights(self.model_disc_path)
        
        if(os.path.exists(self.model_combined_path)):
          print("I am here")
          self.combined.load_weights(self.model_combined_path)

        train_dict = defaultdict(list)
        test_dict = defaultdict(list)
        sample_imgs_hr, sample_imgs_lr = self.test_data_loader.load_data(batch_size=2, is_testing=True)
        for epoch in range(epochs):

            # ----------------------
            #  Train Discriminator
            # ----------------------

            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)
            # # For the combined model we will only train the generator
            # self.discriminator.trainable = True
            # From low res. image generate high res. version
            fake_hr = self.generator.predict(imgs_lr)

            valid = np.ones((batch_size,) + self.disc_patch)
            fake = np.zeros((batch_size,) + self.disc_patch)

            # Train the discriminators (original images = real / generated = Fake)
            d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
            d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ------------------
            #  Train Generator
            # ------------------

            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)
            # For the combined model we will only train the generator
            self.discriminator.trainable = False
            # The generators want the discriminators to label the generated images as real
            valid = np.ones((batch_size,) + self.disc_patch)

            # Extract ground truth image features using pre-trained VGG19 model
            image_features = self.vgg.predict(imgs_hr)

            # Train the generators
            g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])

            print("Training Losses")
            print("Discriminator Real + Fake => Training Loss: %.4f Training Accuracy: %.4f" %(d_loss[0], d_loss[1]))
            train_dict["discriminator"].append(d_loss)

            print("Generator (Combined) => Training Content Loss: %.4f Adverserial Loss: %.4f Perceptual Loss: %.4f" %(g_loss[2], g_loss[1], g_loss[0]))
            train_dict["generator"].append(g_loss)

            # ----------------------
            #  Test Discriminator
            # ----------------------

            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = self.test_data_loader.load_data(batch_size,is_testing=True)
            # # For the combined model we will only train the generator
            # self.discriminator.trainable = True
            # From low res. image generate high res. version
            fake_hr = self.generator.predict(imgs_lr)

            valid = np.ones((batch_size,) + self.disc_patch)
            fake = np.zeros((batch_size,) + self.disc_patch)

            # Train the discriminators (original images = real / generated = Fake)
            test_d_loss_real = self.discriminator.test_on_batch(imgs_hr, valid)
            test_d_loss_fake = self.discriminator.test_on_batch(fake_hr, fake)
            test_d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ------------------
            #  Test Generator
            # ------------------

            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = self.test_data_loader.load_data(batch_size, is_testing=True)

            # The generators want the discriminators to label the generated images as real
            valid = np.ones((batch_size,) + self.disc_patch)

            # Extract ground truth image features using pre-trained VGG19 model
            image_features = self.vgg.predict(imgs_hr)

            # Train the generators
            test_g_loss = self.combined.test_on_batch([imgs_lr, imgs_hr], [valid, image_features])

            print("Testing Losses")
            print("Discriminator Real + Fake => Testing Loss: %.4f Testing Accuracy: %.4f" %(test_d_loss[0], test_d_loss[1]))
            test_dict["discriminator"].append(test_d_loss)

            print("Generator (Combined) => Testing Content Loss: %.4f Adverserial Loss: %.4f Perceptual Loss: %.4f" %(test_g_loss[2], test_g_loss[1], test_g_loss[0]))
            test_dict["generator"].append(test_g_loss)

            elapsed_time = datetime.datetime.now() - start_time
            # Plot the progress
            print ("%d time: %s" % (epoch, elapsed_time))

            self.generator.save_weights(self.model_gen_path)
            self.discriminator.save_weights(self.model_disc_path)
            self.combined.save_weights(self.model_combined_path)
            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch,sample_imgs_hr, sample_imgs_lr )
        
        with open(self.history_loss_accuracy_path, 'wb') as f:
          pickle.dump({'train': train_dict, 'test': test_dict}, f)

    def sample_images(self, epoch, sample_imgs_hr=None, sample_imgs_lr=None):
        #os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 3

        if(sample_imgs_hr is None or sample_imgs_lr is None):
          imgs_hr, imgs_lr = self.test_data_loader.load_data(batch_size=2, is_testing=True)
        else:
          imgs_hr, imgs_lr = sample_imgs_hr,sample_imgs_lr

        fake_hr = self.generator.predict(imgs_lr)

        # Rescale images 0 - 1
        imgs_lr = 0.5 * imgs_lr + 0.5
        fake_hr = 0.5 * fake_hr + 0.5
        imgs_hr = 0.5 * imgs_hr + 0.5

        # Save generated images and the high resolution originals
        titles = ['LR Original','Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for row in range(r):
            psnr_score, ssim_score = self.metrics.print_metrics(reference=imgs_hr[row][2], generated=fake_hr[row][1])
            for col, image in enumerate([imgs_lr, fake_hr, imgs_hr]):
                if(row == 0):
                  if(titles[col] == 'Generated'):
                    col_title = "Generated \n PSNR: %.4f SSIM: %.4f" % (psnr_score, ssim_score)
                  else:
                    col_title = titles[col]+ "\n"
                else:
                  if(titles[col] == 'Generated'):
                    col_title = "PSNR: %.4f SSIM: %.4f" % (psnr_score, ssim_score)
                  else:
                    col_title = "\n"
                axs[row, col].imshow(image[row])
                axs[row, col].set_title(col_title)
                axs[row, col].axis('off')
            cnt += 1
        fig.savefig("%s/%d.png" % (self.sample_images_path, epoch))
        plt.close()

        # Save low resolution images for comparison
        if((epoch == 0 and (sample_imgs_hr is not None or sample_imgs_lr is not None)) or (sample_imgs_hr is None or sample_imgs_lr is None)):
          for i in range(r):
              fig = plt.figure()
              plt.imshow(imgs_lr[i])
              fig.savefig('%s/%d_lowres%d.png' % (self.sample_images_path, epoch, i))
              plt.close()

if __name__ == '__main__':
    tf.reset_default_graph()
    try:
        del gan
    except:
        pass
    K.clear_session()

    
    graph = tf.get_default_graph()

    with graph.as_default():
      gan = SRGAN()
      gan.train(epochs=201, batch_size=10, sample_interval=50)


Collecting git+https://www.github.com/keras-team/keras-contrib.git
  Cloning https://www.github.com/keras-team/keras-contrib.git to /tmp/pip-req-build-eiou9kqz
  Running command git clone -q https://www.github.com/keras-team/keras-contrib.git /tmp/pip-req-build-eiou9kqz
Building wheels for collected packages: keras-contrib
  Building wheel for keras-contrib (setup.py) ... [?25l[?25hdone
  Created wheel for keras-contrib: filename=keras_contrib-2.0.8-cp36-none-any.whl size=101064 sha256=8fd371f10752b40c39c24e6c46e3dd4479d2118423357a12f12c0b535e14f4bf
  Stored in directory: /tmp/pip-ephem-wheel-cache-o8bka907/wheels/11/27/c8/4ed56de7b55f4f61244e2dc6ef3cdbaff2692527a2ce6502ba
Successfully built keras-contrib
TensorFlow 1.x selected.


Using TensorFlow backend.


Instructions for updating:
If using Keras pass *_constraint arguments to layers.

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         (None, 256, 256, 3)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 256, 256, 64)      1792      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 256, 256, 64)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 128, 128, 64)      36928     
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 128, 128, 64)      0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 128, 128, 64)      256       
___________________________________________

  'Discrepancy between trainable weights and collected trainable'
  'Discrepancy between trainable weights and collected trainable'


(10, 256, 256, 3)
(10, 64, 64, 3)
Training Losses
Discriminator Real + Fake => Training Loss: 0.3469 Training Accuracy: 0.4828
Generator (Combined) => Training Content Loss: 70.3232 Adverserial Loss: 0.7717 Perceptual Loss: 70.3240
(10, 256, 256, 3)
(10, 64, 64, 3)
(10, 256, 256, 3)
(10, 64, 64, 3)
Testing Losses
Discriminator Real + Fake => Testing Loss: 0.3469 Testing Accuracy: 0.4828
Generator (Combined) => Testing Content Loss: 65.8685 Adverserial Loss: 0.6892 Perceptual Loss: 65.8691
0 time: 0:01:24.308012


  im2[..., ch], **args)


PSNR: 9.7081 SSIM: -0.0094
PSNR: 14.7075 SSIM: -0.0037
(10, 256, 256, 3)
(10, 64, 64, 3)


  'Discrepancy between trainable weights and collected trainable'


(10, 256, 256, 3)
(10, 64, 64, 3)
Training Losses
Discriminator Real + Fake => Training Loss: 0.2560 Training Accuracy: 0.6379
Generator (Combined) => Training Content Loss: 64.1250 Adverserial Loss: 1.1700 Perceptual Loss: 64.1262
(10, 256, 256, 3)
(10, 64, 64, 3)
(10, 256, 256, 3)
(10, 64, 64, 3)
Testing Losses
Discriminator Real + Fake => Testing Loss: 0.2560 Testing Accuracy: 0.6379
Generator (Combined) => Testing Content Loss: 52.4936 Adverserial Loss: 0.6869 Perceptual Loss: 52.4943
1 time: 0:02:43.211186
(10, 256, 256, 3)
(10, 64, 64, 3)
(10, 256, 256, 3)
(10, 64, 64, 3)
Training Losses
Discriminator Real + Fake => Training Loss: 0.3412 Training Accuracy: 0.4750
Generator (Combined) => Training Content Loss: 57.9817 Adverserial Loss: 1.3376 Perceptual Loss: 57.9831
(10, 256, 256, 3)
(10, 64, 64, 3)
(10, 256, 256, 3)
(10, 64, 64, 3)
Testing Losses
Discriminator Real + Fake => Testing Loss: 0.3412 Testing Accuracy: 0.4750
Generator (Combined) => Testing Content Loss: 62.2995 Adver

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


PSNR: 15.9656 SSIM: 0.1362
PSNR: 19.8301 SSIM: 0.1178
(10, 256, 256, 3)
(10, 64, 64, 3)
(10, 256, 256, 3)
(10, 64, 64, 3)
Training Losses
Discriminator Real + Fake => Training Loss: 0.2667 Training Accuracy: 0.5412
Generator (Combined) => Training Content Loss: 22.3892 Adverserial Loss: 1.8794 Perceptual Loss: 22.3910
(10, 256, 256, 3)
(10, 64, 64, 3)
(10, 256, 256, 3)
(10, 64, 64, 3)
Testing Losses
Discriminator Real + Fake => Testing Loss: 0.2667 Testing Accuracy: 0.5412
Generator (Combined) => Testing Content Loss: 22.7520 Adverserial Loss: 0.8569 Perceptual Loss: 22.7529
101 time: 2:04:31.256807
(10, 256, 256, 3)
(10, 64, 64, 3)
(10, 256, 256, 3)
(10, 64, 64, 3)
Training Losses
Discriminator Real + Fake => Training Loss: 0.1146 Training Accuracy: 0.8988
Generator (Combined) => Training Content Loss: 21.8271 Adverserial Loss: 2.0245 Perceptual Loss: 21.8291
(10, 256, 256, 3)
(10, 64, 64, 3)
(10, 256, 256, 3)
(10, 64, 64, 3)
Testing Losses
Discriminator Real + Fake => Testing Loss: 0