In [10]:
import tensorflow as tf
from keras.models import Sequential, Model
from keras.layers import Input, Activation, Add
from keras.layers import BatchNormalization, LeakyReLU, PReLU, Conv2D, Dense
from keras.layers import UpSampling2D, Lambda
from keras.optimizers import Adam
from keras.applications import VGG19
from keras.applications.vgg19 import preprocess_input
from keras.utils.data_utils import OrderedEnqueuer
from keras import backend as K
from keras.callbacks import TensorBoard, ModelCheckpoint, LambdaCallback

import os
import sys
import pickle
import datetime
import numpy as np

In [11]:
class SRGAN():
    def __init__(self,
        height_lr=56, width_lr=56, channels=3,
        upscaling_factor=4,
        gen_lr=1e-4, dis_lr=1e-4, loss_weights=[1e-3, 0.006],
        training_mode=True):

        self.height_lr = height_lr
        self.width_lr = width_lr

        self.upscaling_factor = upscaling_factor
        self.height_hr = int(self.height_lr * self.upscaling_factor)
        self.width_hr = int(self.width_lr * self.upscaling_factor)

        self.channels = channels
        self.shape_lr = (self.height_lr, self.width_lr, self.channels)
        self.shape_hr = (self.height_hr, self.width_hr, self.channels)

        self.gen_lr = gen_lr
        self.dis_lr = dis_lr
        self.loss_weights = loss_weights
        self.gan_loss = 'mse'
        self.dis_loss = 'binary_crossentropy'

        self.generator = self.build_generator()
        self.compile_generator(self.generator)
        #self.generator.summary()
        if training_mode:
            self.vgg = self.build_vgg()
            self.compile_vgg(self.vgg)
            self.discriminator = self.build_discriminator()
            self.compile_discriminator(self.discriminator)
            self.srgan = self.build_srgan()
            self.compile_srgan(self.srgan)

    def build_vgg(self):
        img = Input(shape=self.shape_hr)
        print(img.shape)
        vgg = VGG19(weights="imagenet")
        vgg.outputs = [vgg.layers[20].output]
        model = Model(inputs=img, outputs=vgg(img))
        model.trainable = False
        return model

    def preprocess_vgg(self, x):
        if isinstance(x, np.ndarray):
            return preprocess_input((x+1)*127.5)
        else:            
            return Lambda(lambda x: preprocess_input(tf.add(x, 1) * 127.5))(x)

    def SubpixelConv2D(self, name, scale=2):

        def subpixel_shape(input_shape):
            dims = [input_shape[0],
                    None if input_shape[1] is None else input_shape[1] * scale,
                    None if input_shape[2] is None else input_shape[2] * scale,
                    int(input_shape[3] / (scale ** 2))]
            output_shape = tuple(dims)
            return output_shape

        def subpixel(x):
            return tf.nn.depth_to_space(x, scale)

        return Lambda(subpixel, output_shape=subpixel_shape, name=name)

    def build_generator(self, residual_blocks=16):

        def residual_block(input):
            x = Conv2D(64, kernel_size=3, strides=1, padding='same')(input)
            x = BatchNormalization(momentum=0.8)(x)
            x = PReLU(shared_axes=[1,2])(x)
            x = Conv2D(64, kernel_size=3, strides=1, padding='same')(x)
            x = Add()([x, input])
            return x

        def upsample(x, number):
            x = Conv2D(256, kernel_size=3, strides=1, padding='same', name='upSampleConv2D_'+str(number))(x)
            x = self.SubpixelConv2D('upSampleSubPixel_'+str(number), 2)(x)
            x = PReLU(shared_axes=[1,2], name='upSamplePReLU_'+str(number))(x)
            return x

        lr_input = Input(shape=(None, None, 3))
        x_start = Conv2D(64, kernel_size=9, strides=1, padding='same')(lr_input)
        x_start = PReLU(shared_axes=[1,2])(x_start)

        r = residual_block(x_start)
        for _ in range(residual_blocks - 1):
            r = residual_block(r)
        
        x = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
        x = BatchNormalization(momentum=0.8)(x)
        x = Add()([x, x_start])

        x = upsample(x, 1)
        if self.upscaling_factor > 2:
            x = upsample(x, 2)
        if self.upscaling_factor > 4:
            x = upsample(x, 3)
        
        hr_output = Conv2D(
            self.channels, 
            kernel_size=9, 
            strides=1, 
            padding='same', 
            activation='tanh'
        )(x)

        model = Model(inputs=lr_input, outputs=hr_output)        
        return model

    def build_discriminator(self, filters=64):

        def conv2d_block(input, filters, strides=1, bn=True):
            d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        img = Input(shape=self.shape_hr)
        x = conv2d_block(img, filters, bn=False)
        x = conv2d_block(x, filters, strides=2)
        x = conv2d_block(x, filters*2)
        x = conv2d_block(x, filters*2, strides=2)
        x = conv2d_block(x, filters*4)
        x = conv2d_block(x, filters*4, strides=2)
        x = conv2d_block(x, filters*8)
        x = conv2d_block(x, filters*8, strides=2)
        x = Dense(filters*16)(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Dense(1, activation='sigmoid')(x)

        model = Model(inputs=img, outputs=x)
        return model

    def build_srgan(self):

        img_lr = Input(self.shape_lr)
        generated_hr = self.generator(img_lr)
        generated_features = self.vgg(
            self.preprocess_vgg(generated_hr)
        )
        self.discriminator.trainable = False
        generated_check = self.discriminator(generated_hr)
        generated_features = Lambda(lambda x: x, name='Content')(generated_features)
        generated_check = Lambda(lambda x: x, name='Adversarial')(generated_check)

        model = Model(inputs=img_lr, outputs=[generated_check, generated_features])        
        return model

    def PSNR(self, y_true, y_pred):
        return -10.0 * K.log(K.mean(K.square(y_pred - y_true))) / K.log(10.0) 

    def compile_vgg(self, model):
        model.compile(
            loss='mse',
            optimizer=Adam(0.0001, 0.9),
            metrics=['accuracy']
        )

    def compile_generator(self, model):
        model.compile(
            loss=self.gan_loss,
            optimizer=Adam(self.gen_lr, 0.9),
            metrics=['mse', self.PSNR]
        )

    def compile_discriminator(self, model):
        model.compile(
            loss=self.dis_loss,
            optimizer=Adam(self.dis_lr, 0.9),
            metrics=['accuracy']
        )

    def compile_srgan(self, model):
        model.compile(
            loss=[self.dis_loss, self.gan_loss],
            loss_weights=self.loss_weights,
            optimizer=Adam(self.gen_lr, 0.9)
        )

In [12]:
 gan = SRGAN(gen_lr=1e-5)

(None, 224, 224, 3)
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels.h5


IndexError: pop from empty list