In [1]:
import tensorflow as tf
import numpy as np




In [5]:
# Residual block used by the generator
def residual_block(ip):
    """Build one residual block on top of the tensor ``ip``.

    The block applies Conv2D(64, 3x3) -> BatchNorm -> PReLU ->
    Conv2D(64, 3x3) -> BatchNorm and then sums the result element-wise
    with ``ip`` (skip connection).

    Args:
        ip: Input Keras tensor. It is added to the output of a 64-filter
            convolution, so it is expected to carry 64 channels.

    Returns:
        The Keras tensor produced by adding ``ip`` to the block output.
    """
    residual_model = tf.keras.layers.Conv2D(64, (3, 3), padding='same')(ip)
    residual_model = tf.keras.layers.BatchNormalization(momentum=0.5)(residual_model)
    # shared_axes=[1, 2] shares the PReLU slope across both spatial axes.
    residual_model = tf.keras.layers.PReLU(shared_axes=[1, 2])(residual_model)
    residual_model = tf.keras.layers.Conv2D(64, (3, 3), padding='same')(residual_model)
    residual_model = tf.keras.layers.BatchNormalization(momentum=0.5)(residual_model)

    # Add is a layer class: instantiate it first, then call it on the tensors.
    return tf.keras.layers.Add()([ip, residual_model])
    

In [3]:
def pixel_shuffler_block(ip):
    """Build one upscaling block on top of the tensor ``ip``.

    The block applies Conv2D(256, 3x3) -> UpSampling2D(size=2) -> PReLU.
    Note that, despite the function name, the upscaling is done with
    ``UpSampling2D`` rather than a depth-to-space pixel shuffle.

    Args:
        ip: Input Keras tensor.

    Returns:
        The Keras tensor after convolution, 2x upsampling and PReLU.
    """
    # The layer classes are spelled Conv2D / UpSampling2D (case-sensitive).
    ps_model = tf.keras.layers.Conv2D(256, (3, 3), padding='same')(ip)
    ps_model = tf.keras.layers.UpSampling2D(size=2)(ps_model)
    # shared_axes=[1, 2] shares the PReLU slope across both spatial axes.
    ps_model = tf.keras.layers.PReLU(shared_axes=[1, 2])(ps_model)

    return ps_model

In [7]:
def create_generator(generator, no_res_blocks=16, no_pixel_shuffler_blocks=2):
    """Assemble the generator network as a Keras functional model.

    Pipeline: Conv2D(64, 9x9) -> PReLU -> ``no_res_blocks`` residual blocks
    -> Conv2D(64, 3x3) -> BatchNorm -> global skip connection back to the
    first PReLU output -> ``no_pixel_shuffler_blocks`` upscaling blocks ->
    Conv2D(3, 9x9).

    Args:
        generator: Input Keras tensor of the network (used as the model's
            ``inputs``); despite its name it is not a model.
        no_res_blocks: Number of ``residual_block`` calls to chain.
        no_pixel_shuffler_blocks: Number of ``pixel_shuffler_block`` calls
            to chain.

    Returns:
        A ``tf.keras.Model`` mapping ``generator`` to the 3-channel output
        of the final convolution.
    """
    model = tf.keras.layers.Conv2D(64, (9, 9), padding='same')(generator)
    model = tf.keras.layers.PReLU(shared_axes=[1, 2])(model)

    # Kept for the global skip connection around all residual blocks.
    skip = model

    for _ in range(no_res_blocks):
        model = residual_block(model)

    model = tf.keras.layers.Conv2D(64, (3, 3), padding='same')(model)
    model = tf.keras.layers.BatchNormalization(momentum=0.5)(model)
    # Add is a layer class: instantiate it first, then call it on the tensors.
    model = tf.keras.layers.Add()([model, skip])

    for _ in range(no_pixel_shuffler_blocks):
        model = pixel_shuffler_block(model)

    output = tf.keras.layers.Conv2D(3, (9, 9), padding='same')(model)

    return tf.keras.Model(inputs=generator, outputs=output)
    

In [11]:
def create_disriminator(discriminator):
    """Assemble the discriminator network as a Keras functional model.

    Pipeline: Conv2D(64, stride 1) -> LeakyReLU, then Conv2D(64, stride 2)
    -> BatchNorm -> LeakyReLU, then for each filter count in
    (128, 256, 512) a stride-1 and a stride-2 Conv2D -> BatchNorm ->
    LeakyReLU stage, followed by Flatten -> Dense(1024) -> LeakyReLU ->
    Dense(1, sigmoid).

    The function name (including its spelling) is kept unchanged so that
    existing callers keep working.

    Args:
        discriminator: Input Keras tensor of the network (used as the
            model's ``inputs``); despite its name it is not a model.

    Returns:
        A ``tf.keras.Model`` mapping ``discriminator`` to a single
        sigmoid-activated unit.
    """
    filter_counts = [128, 256, 512]
    strides = [1, 2]

    model = tf.keras.layers.Conv2D(64, (3, 3), strides=1, padding='same')(discriminator)
    model = tf.keras.layers.LeakyReLU(alpha=0.2)(model)

    model = tf.keras.layers.Conv2D(64, (3, 3), strides=2, padding='same')(model)
    model = tf.keras.layers.BatchNormalization(momentum=0.8)(model)
    model = tf.keras.layers.LeakyReLU(alpha=0.2)(model)

    # Each filter count gets a stride-1 stage followed by a stride-2 stage.
    for n_filters in filter_counts:
        for stride in strides:
            model = tf.keras.layers.Conv2D(n_filters, (3, 3), strides=stride, padding='same')(model)
            model = tf.keras.layers.BatchNormalization(momentum=0.8)(model)
            model = tf.keras.layers.LeakyReLU(alpha=0.2)(model)

    model = tf.keras.layers.Flatten()(model)
    model = tf.keras.layers.Dense(1024)(model)
    model = tf.keras.layers.LeakyReLU(alpha=0.2)(model)
    model = tf.keras.layers.Dense(1, activation='sigmoid')(model)

    return tf.keras.Model(inputs=discriminator, outputs=model)

In [9]:
from keras.applications import VGG19

In [10]:
def create_vgg(hr_shape):
    """Build a truncated VGG19 feature extractor.

    Loads VGG19 with ImageNet weights and without its classification head,
    then wraps it in a model whose output is the output of the layer at
    index 10 of ``vgg.layers``.

    Args:
        hr_shape: Input shape passed to VGG19 as ``input_shape``.

    Returns:
        A ``tf.keras.Model`` from the VGG19 inputs to the output of
        ``vgg.layers[10]``.
    """
    # Use the VGG19 that ships with tf.keras so it comes from the same Keras
    # namespace as the tf.keras.Model built below.
    vgg = tf.keras.applications.VGG19(
        weights='imagenet', include_top=False, input_shape=hr_shape
    )

    # tf.keras.Model takes `outputs=` (plural); `output=` is rejected.
    return tf.keras.Model(inputs=vgg.inputs, outputs=vgg.layers[10].output)

In [12]:
def final_model(generator, discriminator, vgg, lr_ip, hr_ip):
    """Chain generator, VGG and discriminator into one combined model.

    The generator is applied to ``lr_ip``; its output is fed both to ``vgg``
    and to ``discriminator``. As a side effect, ``discriminator.trainable``
    is set to ``False`` on the object passed in.

    Args:
        generator: Callable model applied to ``lr_ip``.
        discriminator: Callable model applied to the generated image; its
            ``trainable`` attribute is mutated.
        vgg: Callable model applied to the generated image.
        lr_ip: Input tensor fed to the generator.
        hr_ip: Second model input; it is listed in ``inputs`` but is not
            used to compute either output.

    Returns:
        A ``tf.keras.Model`` with inputs ``[lr_ip, hr_ip]`` and outputs
        ``[discriminator result, vgg features]``.
    """
    generated = generator(lr_ip)
    feature_maps = vgg(generated)

    # Switch off discriminator training before wiring it into the graph.
    discriminator.trainable = False
    validity = discriminator(generated)

    return tf.keras.Model(
        inputs=[lr_ip, hr_ip],
        outputs=[validity, feature_maps],
    )