# In [0]:
from __future__ import print_function, division

import datetime
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import scipy
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D, Conv2DTranspose
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization

from data_loader import DataLoader

class SuperResolveBlurryFaceGAN():
    """Conditional GAN that generates an image A from a conditioning image B.

    The generator maps a conditioning image (B) to a generated image
    (fake A). The discriminator scores an (image, condition) pair with a
    single value in [0, 1]. The combined model trains the generator against
    a frozen discriminator with an adversarial MSE loss plus an MAE
    reconstruction loss weighted 100x.

    Both A and B are expected to have shape ``(img_rows, img_cols, channels)``.
    """

    def __init__(self, img_rows=256, img_cols=256, channels=3):
        """Build the data loader, the discriminator, the generator and the
        combined (generator + frozen discriminator) model.

        Args:
            img_rows: Image height in pixels.
            img_cols: Image width in pixels.
            channels: Number of image channels.
        """
        # Input shape
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channels = channels
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Configure data loader
        # NOTE(review): DataLoader is project-local; it is assumed to yield
        # (imgs_A, imgs_B) batches at img_res, scaled to [-1, 1] -- confirm.
        self.dataset_name = 'facades'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.img_rows, self.img_cols))

        # Shape of one discriminator output. The discriminator ends in
        # Dense(1), so each sample is labelled with a single scalar.
        self.disc_patch = (1,)

        # Number of filters in the first layer of G and D
        self.gf = 64
        self.df = 64

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        #-------------------------
        # Construct Computational
        #   Graph of Generator
        #-------------------------

        # Build the generator
        self.generator = self.build_generator()

        # Input images and their conditioning images
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # By conditioning on B generate a fake version of A
        fake_A = self.generator(img_B)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # Discriminator determines validity of translated image / condition pairs
        valid = self.discriminator([fake_A, img_B])

        # img_A is kept as a model input so train() can keep passing
        # [imgs_A, imgs_B]; it is not connected to the outputs.
        self.combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A])
        self.combined.compile(loss=['mse', 'mae'],
                              loss_weights=[1, 100],
                              optimizer=optimizer)

    def build_generator(self):
        """Build the generator network.

        Returns:
            A Keras ``Model`` mapping a conditioning image of shape
            ``self.img_shape`` to a generated image of the same shape.
        """

        def conv2d(layer_input, filters=1, kernel_size=1, strides=1):
            """Conv (ReLU) followed by batch normalisation."""
            # kernel_size/strides are passed by keyword: Keras 2 rejects two
            # positional ints combined with `padding=` (legacy-API ambiguity).
            d = Conv2D(filters, kernel_size=kernel_size, strides=strides,
                       padding='same', activation='relu')(layer_input)
            d = BatchNormalization()(d)
            return d

        def deconv2d(layer_input, filters=1, kernel_size=1, strides=1):
            """Transposed conv (ReLU) followed by batch normalisation."""
            u = Conv2DTranspose(filters, kernel_size=kernel_size,
                                strides=strides, padding='same',
                                activation='relu')(layer_input)
            u = BatchNormalization()(u)
            return u

        # Conditioning image input
        img_B = Input(shape=self.img_shape)

        # The transposed convolutions use strides=1 so that the output keeps
        # the input resolution. The generated image is compared with imgs_A
        # (MAE loss), paired with imgs_B in the discriminator and stacked
        # with both in sample_images(), so all three must share one shape.
        # The previous stride-2 layers produced an image 4x larger.
        x = deconv2d(img_B, filters=self.gf, kernel_size=6, strides=1)
        x = conv2d(x, filters=self.gf, kernel_size=5, strides=1)
        x = deconv2d(x, filters=self.gf, kernel_size=6, strides=1)

        # Eight identical 5x5 refinement convolutions.
        for _ in range(8):
            x = conv2d(x, filters=self.gf, kernel_size=5, strides=1)

        output_img = Conv2D(self.channels, kernel_size=3, strides=1,
                            padding='same', activation='tanh')(x)
        # NOTE(review): batch norm after tanh means the output is no longer
        # bounded to [-1, 1], which sample_images() assumes when rescaling.
        # Kept to preserve the original architecture -- confirm intent.
        output_img = BatchNormalization()(output_img)

        return Model(img_B, output_img)

    def build_discriminator(self):
        """Build the discriminator network.

        Returns:
            A Keras ``Model`` taking ``[img_A, img_B]`` (candidate image and
            its conditioning image, both of shape ``self.img_shape``) and
            returning one sigmoid validity score per sample.
        """

        def d_layer(layer_input, filters=1, kernel_size=1, strides=1):
            """Discriminator layer: conv, LeakyReLU, batch normalisation."""
            d = Conv2D(filters, kernel_size=kernel_size, strides=strides,
                       padding='same')(layer_input)
            d = LeakyReLU()(d)
            d = BatchNormalization()(d)
            return d

        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Stack image and condition along the channel axis so the
        # discriminator judges the pair rather than the image alone.
        x = Concatenate(axis=-1)([img_A, img_B])

        # Four stride-2 layers, each halving the spatial resolution.
        for _ in range(4):
            x = d_layer(x, filters=self.df, kernel_size=4, strides=2)

        validity = Dense(1, activation='sigmoid')(Flatten()(x))

        return Model([img_A, img_B], validity)

    def train(self, epochs, batch_size=1, sample_interval=50):
        """Alternate discriminator and generator updates over the dataset.

        Args:
            epochs: Number of passes over the data loader's batches.
            batch_size: Number of image pairs per batch.
            sample_interval: Save a sample image grid every this many batches.
        """
        start_time = datetime.datetime.now()

        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):

                # ---------------------
                #  Train Discriminator
                # ---------------------

                # Condition on B and generate a translated version
                fake_A = self.generator.predict(imgs_B)

                # Train the discriminator (original images = real / generated = fake)
                d_loss_real = self.discriminator.train_on_batch([imgs_A, imgs_B], valid)
                d_loss_fake = self.discriminator.train_on_batch([fake_A, imgs_B], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # -----------------
                #  Train Generator
                # -----------------

                # Targets: fool the discriminator and reproduce imgs_A
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])

                elapsed_time = datetime.datetime.now() - start_time
                # Plot the progress
                print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f] time: %s" % (epoch, epochs,
                                                                        batch_i, self.data_loader.n_batches,
                                                                        d_loss[0], 100*d_loss[1],
                                                                        g_loss[0],
                                                                        elapsed_time))

                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)

    def sample_images(self, epoch, batch_i):
        """Save a 3x3 grid of condition / generated / original test images.

        The figure is written to ``images/<dataset_name>/<epoch>_<batch_i>.png``.

        Args:
            epoch: Current epoch index, used in the file name.
            batch_i: Current batch index, used in the file name.
        """
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 3, 3

        imgs_A, imgs_B = self.data_loader.load_data(batch_size=3, is_testing=True)
        fake_A = self.generator.predict(imgs_B)

        # Rows of the grid, in order: condition, generated, original.
        gen_imgs = np.concatenate([imgs_B, fake_A, imgs_A])

        # Rescale images 0 - 1 (assumes pixel values in [-1, 1])
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Condition', 'Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[i])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        # Close this specific figure so repeated sampling does not leak figures.
        plt.close(fig)


if __name__ == '__main__':
    # Script entry point: build the GAN with its default image shape and
    # train it, saving a sample image grid every 200 batches.
    gan = SuperResolveBlurryFaceGAN()
    gan.train(
        epochs=200,
        batch_size=1,
        sample_interval=200,
    )
