In [9]:
import glob
import os
import numpy as np
import matplotlib.pyplot as plt
from imageio import imread

import tensorflow as tf
from keras import Input
from keras.applications import VGG19
from keras.callbacks import TensorBoard
from keras.layers import BatchNormalization, Activation, LeakyReLU, Add, Dense, PReLU, Flatten
from keras.layers.convolutional import Conv2D, UpSampling2D
from keras.models import Model
from keras.optimizers import Adam
from keras.preprocessing.image import img_to_array, load_img


In [2]:
def residual_block(x):
    """Build one generator residual block on top of tensor ``x``.

    The branch is Conv(64, 3x3) -> ReLU -> BatchNorm -> Conv(64, 3x3) ->
    BatchNorm, and its output is added element-wise to ``x`` (skip connection).

    Args:
        x: Keras tensor. Because the branch uses 64 filters with stride 1 and
            "same" padding, ``x`` must already have 64 channels for the final
            ``Add`` to receive matching shapes.

    Returns:
        Keras tensor holding ``branch(x) + x``.
    """
    res = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(x)
    res = Activation(activation = "relu")(res)
    res = BatchNormalization(momentum = 0.8)(res)

    # The second convolution must consume the first stage's output. Feeding it
    # ``x`` again would silently discard the conv/ReLU/BN stage built above.
    res = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(res)
    res = BatchNormalization(momentum = 0.8)(res)

    # Skip connection around the whole branch.
    return Add()([res, x])

In [3]:
def build_gen():
    """Build the super-resolution generator network.

    Architecture, in order:
      1. 9x9 convolution with ReLU on the (64, 64, 3) input.
      2. ``res_blocks`` (16) residual blocks built by ``residual_block``.
      3. 3x3 convolution + batch norm, added back to the output of step 1.
      4. Two (2x upsample -> 3x3 conv with 256 filters -> ReLU) stages,
         i.e. a 4x increase in spatial resolution overall.
      5. 9x9 convolution down to 3 channels followed by ``tanh``.

    Returns:
        keras.models.Model named ``'generator'`` that takes a (64, 64, 3)
        input; the two 2x upsampling stages make the spatial output 256x256
        with 3 channels in the ``tanh`` range [-1, 1].
    """
    res_blocks = 16
    input_shape = (64, 64, 3)

    input_layer = Input(shape = input_shape)

    # Pre-residual feature extraction; kept for the long skip connection below.
    gen1 = Conv2D(filters = 64, kernel_size = 9, strides = 1, padding = 'same', activation = 'relu')(input_layer)

    # Stack of residual blocks. (The original loop referenced an undefined
    # name ``res_block`` and raised NameError.)
    res = gen1
    for _ in range(res_blocks):
        res = residual_block(res)

    # Post-residual convolution, then the long skip connection from gen1.
    gen2 = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = 'same')(res)
    gen2 = BatchNormalization(momentum = 0.8)(gen2)

    gen3 = Add()([gen2, gen1])

    # First 2x upsampling stage.
    gen4 = UpSampling2D(size = (2, 2))(gen3)
    gen4 = Conv2D(filters = 256, kernel_size = 3, strides = 1, padding = "same")(gen4)
    gen4 = Activation('relu')(gen4)

    # Second 2x upsampling stage.
    gen5 = UpSampling2D(size = (2, 2))(gen4)
    gen5 = Conv2D(filters = 256, kernel_size = 3, strides = 1, padding = "same")(gen5)
    gen5 = Activation('relu')(gen5)

    # Project back to 3 image channels.
    gen6 = Conv2D(filters = 3, kernel_size = 9, strides = 1, padding = "same")(gen5)
    output = Activation('tanh')(gen6)

    return Model(inputs = [input_layer], outputs = [output], name = 'generator')


# The training cell of this notebook calls ``build_generator()``; expose that
# name as an alias so both spellings resolve to the same builder.
build_generator = build_gen

In [4]:
def build_disc():
    """Build the discriminator network for (256, 256, 3) images.

    The network is an initial 3x3 conv + LeakyReLU, followed by seven
    (3x3 conv -> LeakyReLU -> BatchNorm) stages whose filter counts grow from
    64 to 512 and whose strides alternate so that four of them halve the
    spatial resolution (256 -> 16). Two ``Dense`` layers are then applied
    directly to the 4-D feature map (no ``Flatten``), so they act on the last
    axis and the model emits one sigmoid score per spatial location.

    Returns:
        keras.models.Model named ``'discriminator'`` whose output has a single
        sigmoid channel per location of the 16x16 feature map.
    """
    input_shape = (256, 256, 3)

    input_layer = Input(shape = input_shape)

    # First stage has no batch normalisation.
    features = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = 'same')(input_layer)
    features = LeakyReLU(alpha = 0.2)(features)

    # (filters, strides) for the remaining conv stages; every stride-2 entry
    # halves the spatial resolution.
    conv_stages = [
        (64, 2),
        (128, 1),
        (128, 2),
        (256, 1),
        (256, 2),
        (512, 1),
        (512, 2),
    ]
    for filters, strides in conv_stages:
        features = Conv2D(filters = filters, kernel_size = 3, strides = strides, padding = 'same')(features)
        features = LeakyReLU(alpha = 0.2)(features)
        features = BatchNormalization(momentum = 0.8)(features)

    # Dense layers operate on the channel axis of the 4-D tensor.
    features = Dense(units = 1024)(features)
    features = LeakyReLU(alpha = 0.2)(features)

    # The activation must be given by name; the bare identifier ``sigmoid``
    # is not defined anywhere and raised NameError.
    output = Dense(units = 1, activation = 'sigmoid')(features)

    return Model(inputs = [input_layer], outputs = [output], name = 'discriminator')


In [5]:
def build_vgg():
    """Build a VGG19-based feature extractor for (256, 256, 3) images.

    Loads VGG19 with ImageNet weights and exposes the activation of the layer
    at index 9 as the model output.
    NOTE(review): index 9 is presumably ``block3_conv3`` in the stock Keras
    VGG19 layout -- confirm against the installed Keras version.

    Returns:
        keras.models.Model mapping a (256, 256, 3) image tensor to the
        intermediate VGG19 feature map.
    """
    input_shape = (256, 256, 3)

    # include_top=False drops the dense classifier head, which is what ties
    # the stock model to 224x224 inputs; input_shape then lets the
    # convolutional base be built directly for 256x256 images.
    vgg = VGG19(weights = "imagenet", include_top = False, input_shape = input_shape)

    # Build a proper sub-model ending at the wanted layer. Reassigning
    # ``vgg.outputs`` on an already-built model is not a supported way to
    # truncate it and does not reliably change what calling the model returns.
    model = Model(inputs = [vgg.input], outputs = [vgg.layers[9].output])

    return model

In [6]:
def build_adversarial(generator, discriminator, vgg):
    """Chain generator, discriminator and VGG into the combined GAN model.

    A (64, 64, 3) low-resolution input is passed through ``generator``; the
    generated image is then fed to both ``discriminator`` and ``vgg``.

    Side effects:
        Sets ``discriminator.trainable = False`` on the model passed in, and
        prints every layer's name/trainable flag plus the model summary.

    Args:
        generator: Keras model accepting (64, 64, 3) inputs.
        discriminator: Keras model applied to the generator output.
        vgg: Keras model applied to the generator output to obtain features.

    Returns:
        keras.models.Model with one input (the low-resolution image) and two
        outputs: ``[discriminator_output, vgg_features]``.
    """
    # This has to be a symbolic Input tensor: a bare shape tuple cannot be
    # passed through the generator or used as a Model input.
    input_low_resolution = Input(shape = (64, 64, 3))

    fake_hr_images = generator(input_low_resolution)
    fake_features = vgg(fake_hr_images)

    # Freeze the discriminator inside the combined model.
    discriminator.trainable = False

    output = discriminator(fake_hr_images)

    model = Model(inputs = [input_low_resolution], outputs = [output, fake_features])

    for layer in model.layers:
        print(layer.name, layer.trainable)

    # summary() prints the table itself and returns None, so wrapping it in
    # print() would only add a stray "None" line.
    model.summary()

    return model


In [10]:
def sample_images(data_dir, batch_size, high_resolution_shape, low_resolution_shape):
    """Load a random batch of images at high and low resolution.

    ``batch_size`` files are drawn at random (with replacement, numpy's
    default) from the files matched by ``data_dir``. Each file is loaded as
    RGB and resized to both target shapes; with probability 0.5 the pair is
    mirrored left-to-right together.

    Args:
        data_dir: Either a glob pattern (e.g. ``"images/*.jpg"``) or the path
            of a directory, in which case every file directly inside it is a
            candidate.
        batch_size: Number of images to sample.
        high_resolution_shape: Target shape; the first two entries are used
            as (height, width).
        low_resolution_shape: Target shape; the first two entries are used
            as (height, width).

    Returns:
        Tuple ``(high_resolution_images, low_resolution_images)`` of numpy
        arrays, each with ``batch_size`` entries along the first axis and
        pixel values in the 0-255 range.

    Raises:
        ValueError: If ``data_dir`` matches no files.
    """
    print("Loading Data")

    # glob on a bare directory path returns the directory itself, not its
    # contents, so expand directories to a wildcard pattern first.
    pattern = os.path.join(data_dir, "*") if os.path.isdir(data_dir) else data_dir
    all_images = [path for path in glob.glob(pattern) if os.path.isfile(path)]
    if not all_images:
        raise ValueError("No image files found for data_dir = {!r}".format(data_dir))

    images_batch = np.random.choice(all_images, size = batch_size)

    high_size = tuple(high_resolution_shape[:2])
    low_size = tuple(low_resolution_shape[:2])

    low_resolution_images = []
    high_resolution_images = []

    for img in images_batch:

        # load_img converts to RGB and resizes; img_to_array yields a float
        # array. These replace imread(mode=...) / imresize, neither of which
        # is available from this notebook's imports.
        img1_high_resolution = img_to_array(
            load_img(str(img), target_size = high_size, interpolation = "bilinear"))
        img1_low_resolution = img_to_array(
            load_img(str(img), target_size = low_size, interpolation = "bilinear"))

        if np.random.random() < 0.5:
            # Mirror horizontally only. np.flip without an axis would also
            # reverse the rows and the colour channels.
            img1_high_resolution = np.fliplr(img1_high_resolution)
            img1_low_resolution = np.fliplr(img1_low_resolution)

        high_resolution_images.append(img1_high_resolution)
        low_resolution_images.append(img1_low_resolution)

    # Report and return only after the whole batch has been collected.
    print("Data Loaded")

    return np.asarray(high_resolution_images), np.asarray(low_resolution_images)

In [11]:
def write_log(callback, name, value, batch_no):
    """Write one scalar to TensorBoard through ``callback``'s writer.

    Args:
        callback: Object exposing a ``writer`` with ``add_summary`` and
            ``flush`` methods.
        name: Tag under which the scalar is recorded.
        value: Scalar stored as the summary's ``simple_value``.
        batch_no: Step index passed to ``add_summary``.
    """
    record = tf.Summary()

    # A summary holds a list of values; add one entry and fill it in.
    entry = record.value.add()
    entry.tag = name
    entry.simple_value = value

    writer = callback.writer
    writer.add_summary(record, batch_no)
    # Flush right away so the point reaches disk without waiting on buffering.
    writer.flush()

In [None]:
def train():
    
    data_dir = "img_align_celeba"
    epochs = 20000
    batch_size = 1
    
    low_resolution_shape = (64, 64, 3)
    high_resolution_shape = (256, 256, 3)
    
    optimizer = Adam(0.0002, 0.5)
    
    vgg = build_vgg()
    vgg.trainable = False
    vgg.compile(loss = 'mse', optimizer = optimizer, metrics = ['accuracy'])
    
    discriminator = build_disc()
    discriminator.compile(loss = 'mse', optimizer = optimizer, metrics = ['accuracy'])
    
    generator = build_generator()
    
    input_high_resolution = Input(shape = high_resolution_shape)
    input_low_resolution = Input(shape = low_resolution_shape)
    
    generated_high_resolution_images = generator(input_low_resolution)
    features = vgg(generated_high_resolution_images)
    
    discriminator.trainable = False
    
    probs = discriminator(generated_high_resolution_images)
    
    adversarial_model = Model(inputs = [input_low_resolution, input_high_resolution], outputs = [probs, features])
    adversarial_model.compile(loss = ['binary_crossentropy', 'mse'], loss_weights = [1e-3, 1], optimizer = optimizer)
    
    tensorboard = TensorBoard(log_dir = "logs/".format(time.time()))
    tensorboard.set_model(generator)
    tensorboard.set_model(discriminator)
    
    for epoch in range(epochs):
        
        print("Epoch: {}".format(epoch))
        
        high_resolution_images, low_resolution_images = sample_images(data_dir = data_dir, batch_size = batch_size,
                                                                      low_resolution_shape = low_resolution_shape, 
                                                                      high_resolution_shape = high_resolution_shape)
        
        high_resolution_images = high_resolution_images/127.5 - 1
        low_resolution_images = low_resolution_images/127.5 - 1
        
        generated_high_resolution_images = generator.predict(low_resolution_images)
        
        real_labels = np.ones((batch_size, 16, 16, 1))
        fake_labels = np.zeros((batch_size, 16, 16, 1))
        
        