# <center><font color=green>PAINTING BEYOND IMAGE BOUNDARIES USING</font> <font color=red>GAN</font></center>

# References:
 - https://arxiv.org/abs/1808.08483 [1]
 - https://github.com/ShinyCode/image-outpainting [2]
 

# <center><font color=blue>Necessary Modules</font></center>

In [None]:
from __future__ import print_function, division

import numpy as np
from numpy import asarray
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from numpy.random import shuffle
from numpy.random import randint
from numpy.random import randn
from numpy.random import random

from scipy.linalg import sqrtm
from skimage.transform import resize
from matplotlib import pyplot as plt
%matplotlib inline

import time
import sys
import os
import glob
import copy
import keras
import keras.backend as K
import tensorflow as tf
import cv2
from keras.preprocessing import image

# Models -> Layers -> Initializers -> Optimizers -> Utility function
from keras.models import Sequential, load_model
from keras.layers import Dense, Conv2D, Conv2DTranspose, Dropout
from keras.layers import Activation, LeakyReLU, BatchNormalization, Flatten, Reshape
from keras.layers.convolutional import AtrousConvolution2D # needed for dilated convolutional layer in "Generator"

from keras.initializers import RandomNormal
from keras.optimizers import Adam
from keras.utils import plot_model, np_utils
from contextlib import redirect_stdout  # for writing model.summary into a text file

# <center>Google Drive Folder - Authentication</center>

In [None]:
# Colab integration: mount Google Drive at /content/drive so the notebook can
# read the dataset from, and write its outputs to, the Drive folder.
# (`files` is imported here but not used in the visible cells.)
from google.colab import drive, files

drive.mount('/content/drive')

In [None]:
# Project folder on Google Drive. After the chdir below, every relative path
# used by later cells (dataset .npz, summaries, logs, checkpoints) lives here.
path = '/content/drive/My Drive/OutPaint_DCGAN/'

# exist_ok=True replaces the racy "check, then create" pattern.
os.makedirs(path, exist_ok=True)
os.chdir(path)

In [None]:
# Let TensorFlow report errors only, hiding its warning messages.
_tf_logging = tf.logging
_tf_logging.set_verbosity(_tf_logging.ERROR)

# <center>Dataset: <font color=blue>Places365</font></center>

 - Dataset URL : http://data.csail.mit.edu/places/places365/val_256.tar

### Preprocessing 

In [None]:
# Ref [2], I took the idea for the below function from their implementation and then write the code by my own

def create_maksed_images(X_temp):
    # Mask creation
    image_count = X_temp.shape[0]
    mask_shape  = (image_count,128,128,1)
    mask = np.zeros(mask_shape)
    mask[:, :, :32, :] = 1.0 # left portion
    mask[:, :,-32:, :] = 1.0 # right portion
    
    # Filled the left and right portion with the pixel_average (mu)
    mu = np.mean(X_temp, axis=(1,2,3))

    X_temp[:, :, :32, :] = mu[:, np.newaxis, np.newaxis, np.newaxis]
    X_temp[:, :,-32:, :] = mu[:, np.newaxis, np.newaxis, np.newaxis]
    X_mask = np.concatenate((X_temp, mask), axis=3)
    
    return X_mask

In [None]:
# Build places_dataset.npz once: load, resize, normalise and mask the images.
# NOTE: an existing places_dataset.npz built by the earlier version of this
# cell holds pixels in [0, 1/255] (see the resize comment below); delete it to
# rebuild with the corrected scaling.
if not os.path.exists('places_dataset.npz'):

    image_size = (128, 128, 3)
    # Only 2000 images are used: accessing more files caused "A Google Drive
    # timeout has occurred" errors (too many file accesses).
    raw_images_path = '/content/drive/My Drive/OutPaint_DCGAN/raw_dataset/val_2000/*.jpg'

    # Sorted for a reproducible dataset order. cv2.imread returns None for
    # files it cannot decode, so those are skipped.
    # NOTE(review): cv2 loads channels in BGR order; convert with
    # cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if images are shown with matplotlib.
    loaded = (cv2.imread(fname) for fname in sorted(glob.glob(raw_images_path)))
    raw_images = [img for img in loaded if img is not None]
    if not raw_images:
        raise FileNotFoundError('No readable .jpg images found at ' + raw_images_path)
    print("Loaded all images from disk")

    # preserve_range=True keeps the 0-255 scale, so the /255.0 below is the only
    # normalisation. By default resize() already converts uint8 to [0, 1];
    # dividing that again squashed all pixels into [0, 1/255].
    resized_images = np.asarray([
        resize(img, image_size, anti_aliasing=True, preserve_range=True)
        for img in raw_images
    ])
    print("Resized all images")

    # All training images, normalised to [0, 1]; shape (N, 128, 128, 3).
    X_train = resized_images / 255.0
    X_temp = copy.deepcopy(X_train)

    # Masked generator inputs plus mask channel; shape (N, 128, 128, 4).
    X_mask = create_maksed_images(X_temp)
    np.savez('places_dataset.npz', X_train=X_train, X_mask=X_mask)
    print("Save processed images")

In [None]:
# Load the preprocessed arrays. Indexing an NpzFile reads the whole array into
# memory, so the file can be closed right away instead of leaking its handle.
with np.load('places_dataset.npz') as data:
    X_train = data['X_train']   # normalised real images
    X_mask = data['X_mask']     # masked images + mask channel (generator input)

# <center><font color=red>Global</font> <font color=blue>Discriminator - Model</font></center>

In [None]:
    # Reference table of the global-discriminator layers, from paper [1].
    # create_D implements this stack and also adds BatchNormalization after
    # convs 3-5 and Dropout(0.4) before the dense layers. The table is a bare
    # string expression; as the cell's only statement it is presumably just
    # echoed as the cell output.
    '''
       Model Architecture from Paper : Ref: [1]
      
       Layer-Type  Filter-Size     Stride    No_of_filters
          ----     -----------     ------    -------------
          CONV          5            2           32
          CONV          5            2           64
          CONV          5            2           64
          CONV          5            2           64
          CONV          5            2           64
          FC            -            -           512
    '''

In [None]:
# Strided convolutions repeatedly downsample the image for binary real/fake
# classification.

def create_D():
    """Build the (uncompiled) global discriminator.

    Input: 128x128x3 image. Five 5x5 stride-2 ReLU convolutions
    (32, 64, 64, 64, 64 filters; BatchNormalization after the last three),
    then Flatten, Dropout(0.4), Dense(512, relu) and a single sigmoid unit.
    Only the first convolution uses the N(0, 0.02) kernel initializer.
    """
    init = RandomNormal(stddev=0.02)  # zero-centred Gaussian, std 0.02
    D = Sequential(name='Global_Discriminator')

    # First conv declares the input shape and carries the custom initializer.
    D.add(Conv2D(32, (5, 5), strides=2, padding='same',
                 kernel_initializer=init, input_shape=(128, 128, 3),
                 activation='relu'))

    # Remaining conv stack: (filters, followed by BatchNormalization?)
    for filters, add_batchnorm in ((64, False), (64, True), (64, True), (64, True)):
        D.add(Conv2D(filters, (5, 5), strides=2, padding='same', activation='relu'))
        if add_batchnorm:
            D.add(BatchNormalization(momentum=0.9))

    # Classifier head.
    D.add(Flatten())
    D.add(Dropout(0.4))
    D.add(Dense(512, activation='relu'))
    D.add(Dense(1, activation='sigmoid'))
    return D

In [None]:
# plot_model(create_D())

In [None]:
# Save the Discriminator's model.summary() to a text file on Drive (once).

if not os.path.exists('Global_Discriminator.txt'):
    with open('Global_Discriminator.txt', 'w') as f, redirect_stdout(f):
        # create_D() takes no arguments; the old call create_D(0.2, 0.02)
        # raised TypeError as soon as this cell ran.
        create_D().summary()

# <center><font color=blue>Generator - Model</font></center>

In [None]:
# Encoder-decoder CNN with a dilated-convolution bottleneck.

def create_G():
    """Build the (uncompiled) generator.

    Input: 128x128x4 (masked RGB image plus mask channel). One stride-2
    convolution downsamples; dilated 3x3 convolutions (rates 2, 4, 8, 1)
    widen the receptive field; a stride-2 Conv2DTranspose upsamples back. A
    final 3-filter sigmoid convolution produces an RGB image with values in
    [0, 1] at the input's spatial size.
    """
    init = RandomNormal(stddev=0.02)
    G = Sequential(name='Generator')

    # Encoder.
    G.add(Conv2D(64, (5, 5), strides=1, dilation_rate=(1, 1), padding='same',
                 kernel_initializer=init, input_shape=(128, 128, 4),
                 activation='relu'))
    G.add(Conv2D(128, (3, 3), strides=2, dilation_rate=(1, 1), padding='same', activation='relu'))
    G.add(Conv2D(256, (3, 3), strides=1, dilation_rate=(1, 1), padding='same', activation='relu'))

    # Dilated bottleneck, each conv followed by BatchNormalization.
    for rate in (2, 4, 8, 1):
        G.add(Conv2D(256, (3, 3), strides=1, dilation_rate=(rate, rate),
                     padding='same', activation='relu'))
        G.add(BatchNormalization(momentum=0.9))

    # Decoder.
    G.add(Conv2DTranspose(128, (4, 4), strides=2, padding='same', activation='relu'))  # DECONV
    G.add(BatchNormalization(momentum=0.9))
    G.add(Conv2D(64, (3, 3), strides=1, dilation_rate=(1, 1), padding='same', activation='relu'))
    G.add(BatchNormalization(momentum=0.9))
    G.add(Conv2D(3, (3, 3), strides=1, padding='same', activation='sigmoid'))

    return G

In [None]:
# plot_model(create_G())

In [None]:
# Write the Generator's model.summary() to a text file on Drive, once.

if not os.path.exists('Generator.txt'):
    with open('Generator.txt', 'w') as f, redirect_stdout(f):
        create_G().summary()

# <center><font color=blue>GAN - Model</font></center>

In [None]:
# Custom Keras loss functions: loss(y_true, y_pred) -> scalar tensor.

def G_loss_MSE(y_true, y_pred):
    """Return the mean squared difference between ``y_pred`` and ``y_true``.

    The mean runs over every element, giving a single scalar. In
    create_DCGAN this is the phase-1 loss of the combined G+D model, so
    ``y_pred`` there is the discriminator's score, not generated pixels.
    """
    return K.mean(K.square(y_pred - y_true))

def G_loss(y_true, y_pred):
    """Phase-3 generator loss: ``MSE(y_true, y_pred) - 0.004 * mean(log y_pred)``.

    ``y_pred`` is clamped below at ``K.epsilon()`` before the log so the
    adversarial term stays finite. 0.004 is the weighting hyperparameter
    taken from the paper.

    NOTE(review): in create_DCGAN's combined model ``y_pred`` is the
    discriminator's score of the generated image, so both terms act on that
    score and neither compares generated pixels with real ones. Confirm this
    is intended.
    """
    log_score = tf.log(tf.maximum(y_pred, K.epsilon()))
    adversarial_term = tf.reduce_mean(log_score)
    return G_loss_MSE(y_true, y_pred) - 0.004 * adversarial_term

def D_loss(y_true, y_pred):
    """Binary cross-entropy between labels and discriminator scores.

        loss = -mean(y_true * log(p) + (1 - y_true) * log(1 - p))

    where ``p`` is ``y_pred`` clipped to ``[eps, 1 - eps]`` so both logs stay
    finite. Soft labels such as those from smooth_positive_labels are
    supported.

    The previous version took ``log(y_true)``, the labels, instead of
    ``log(y_pred)``. Its only model-dependent term was ``log(1 - y_pred)``,
    which drove D's score toward 0 for real and fake batches alike.
    """
    eps = K.epsilon()
    p = tf.clip_by_value(y_pred, eps, 1.0 - eps)
    per_element = y_true * tf.log(p) + (1.0 - y_true) * tf.log(1.0 - p)
    return -tf.reduce_mean(per_element)

def wasserstein_loss(y_true, y_pred):
    """Return the mean of ``y_true * y_pred`` (critic-style loss).

    Not referenced by create_DCGAN.
    """
    weighted = y_true * y_pred
    return K.mean(weighted)


def create_DCGAN(g_learning_rate, g_beta_1, d_learning_rate, d_beta_1, phase=1):
    """Create and compile the discriminator and the combined G -> D model.

    Args:
        g_learning_rate, g_beta_1: Adam settings for the combined model
            (these drive the generator updates).
        d_learning_rate, d_beta_1: Adam settings for the discriminator.
        phase: 1 compiles the combined model with ``G_loss_MSE``; any other
            value uses ``G_loss``.

    Returns:
        ``(gan, G, D)``. ``gan`` is ``Sequential([G, D])``. ``D`` is compiled
        on its own with ``D_loss``; ``G`` is not compiled on its own.

    D is compiled while trainable, then frozen *before* ``gan`` is compiled.
    Keras fixes a model's trainable weights at compile time, so ``gan``
    updates only G while ``D.train_on_batch`` still updates D. Previously D
    was never frozen, so every combined-model step also moved D's weights.
    """
    # Discriminator, compiled while trainable.
    D = create_D()
    D_optimizer = Adam(lr=d_learning_rate, beta_1=d_beta_1)
    D.compile(optimizer=D_optimizer, loss=D_loss, metrics=['binary_accuracy'])

    # Generator.
    G = create_G()

    # Combined model: freeze D first so gan's compile excludes D's weights.
    D.trainable = False
    gan = Sequential([G, D])
    gan_optimizer = Adam(lr=g_learning_rate, beta_1=g_beta_1)
    gen_loss = G_loss_MSE if phase == 1 else G_loss
    gan.compile(optimizer=gan_optimizer, loss=gen_loss, metrics=['binary_accuracy'])

    return gan, G, D

# <center><font color=blue>Utility Functions for Training</font></center>

In [None]:
def trainable(gan_model):
    """Set ``trainable = True`` on every layer listed in ``gan_model.layers``.

    NOTE(review): Keras generally honours this flag only at the next
    ``compile()``. Confirm before relying on it to unfreeze a compiled model.
    """
    layers = gan_model.layers
    for lyr in layers:
        setattr(lyr, 'trainable', True)

def non_trainable(gan_model):
    """Set ``trainable = False`` on every layer listed in ``gan_model.layers``.

    NOTE(review): Keras generally honours this flag only at the next
    ``compile()``. Confirm before relying on it to freeze a compiled model.
    """
    frozen = False
    for lyr in gan_model.layers:
        lyr.trainable = frozen
        
def make_labels(batch_size, isReal):
    """Return a ``(batch_size, 1)`` float array of ones if ``isReal`` else zeros.

    ``isReal`` is tested for truthiness rather than with ``== True``
    (PEP 8 / flake8 E712).
    """
    return np.full((batch_size, 1), 1.0 if isReal else 0.0)
    
# Label smoothing, from Soumith Chintala's "GAN Hacks"
# (talk: https://www.youtube.com/watch?v=X1mUN6dD8uE).
# Only one-sided smoothing is used: labels of real images are softened, while
# the training loop passes fake-image labels through unchanged.

def smooth_positive_labels(y, amount=0.1):
    """Return ``y - amount + U[0, 1) * amount``, drawn independently per element.

    With the default ``amount=0.1``, labels of 1 map into [0.9, 1.0). The
    upper bound is exclusive because ``numpy.random.random`` samples [0, 1).

    Args:
        y: array of labels (typically all ones).
        amount: width of the smoothing interval below each label; 0 returns
            ``y`` unchanged.
    """
    return y - amount + (random(y.shape) * amount)


# Necessary text files 

In [None]:
# Create each bookkeeping file with its header line if it does not exist yet.
for _fname, _header in (
    ("iteration_count.txt", "Iteration_count : \n"),
    ("losses.txt", "D_Real_Loss D_Fake_Loss D_Total_loss G_Loss\n"),
):
    if not os.path.exists(_fname):
        with open(_fname, "a") as f:
            f.write(_header)

# <center><font color=red>TRAINING</font></center>

In [None]:
def _read_resume_iteration(path='iteration_count.txt'):
    """Return the last integer token stored in ``path``, or 0 if there is none.

    The setup cell seeds this file with the header ``Iteration_count :`` and
    no number. A missing or unreadable file, or one with no integer token,
    therefore yields 0.
    """
    try:
        with open(path) as f:
            tokens = f.read().split()
    except OSError:
        return 0
    for token in reversed(tokens):
        if token.isdigit():
            return int(token)
    return 0


def _sample_rows(data, batch_size):
    """Return ``batch_size`` rows of ``data`` drawn uniformly with replacement."""
    return data[np.random.choice(data.shape[0], batch_size, replace=True)]


def _compile_combined(generator, discriminator, loss, learning_rate, beta_1):
    """Stack ``generator -> discriminator`` and compile it so only G is updated.

    D is frozen *before* compiling, because Keras captures a model's trainable
    weights at compile time. ``discriminator`` was compiled earlier while
    trainable, so its own ``train_on_batch`` keeps updating it.
    """
    discriminator.trainable = False
    combined = Sequential([generator, discriminator])
    combined.compile(optimizer=Adam(lr=learning_rate, beta_1=beta_1),
                     loss=loss, metrics=['binary_accuracy'])
    return combined


def _train_discriminator_step(generator, discriminator, batch_size, y_real, y_fake):
    """Run one D update on real images (smoothed labels) and one on fakes.

    Returns the masked batch that produced the fake images, so the caller
    can reuse it for the generator step.
    """
    discriminator.trainable = True
    X_batch_real = _sample_rows(X_train, batch_size)
    X_batch_masked = _sample_rows(X_mask, batch_size)
    X_batch_fake = generator.predict_on_batch(X_batch_masked)
    discriminator.train_on_batch(X_batch_real, smooth_positive_labels(y_real))
    discriminator.train_on_batch(X_batch_fake, y_fake)
    return X_batch_masked


def _log_losses(d_real, d_fake, g_loss):
    """Append one 'D_real D_fake D_total G' line to losses.txt."""
    d_real = round(d_real, 12)
    d_fake = round(d_fake, 12)
    total = 0.5 * (d_real + d_fake)
    with open("losses.txt", "a") as f:
        f.write(str(d_real) + ' ' + str(d_fake) + ' ' + str(total) + ' ' + str(round(g_loss, 12)))
        f.write('\n')


def _save_checkpoint(generator, discriminator, gan, next_iter):
    """Overwrite the three .h5 checkpoints and record ``next_iter`` for resuming."""
    for model, filename in ((generator, 'best_generator.h5'),
                            (discriminator, 'best_discriminator.h5'),
                            (gan, 'best_gan.h5')):
        if os.path.exists(filename):
            os.remove(filename)
        model.save(filename)
    with open('iteration_count.txt', 'w') as f:
        f.write(str(next_iter))


def train_DCGAN(g_learning_rate,    # learning rate for the generator
                g_beta_1,           # the exponential decay rate for the 1st moment estimates in Adam optimizer
                d_learning_rate,    # learning rate for the discriminator
                d_beta_1            # the exponential decay rate for the 1st moment estimates in Adam optimizer
               ):
    """Train the outpainting GAN in three phases, resuming from checkpoints.

    Phases (split T1:T2:T3 = 18:2:80 of ``max_iters``):
      1. [0, 36000): update G through the combined model with ``G_loss_MSE``
         while D is frozen.
      2. [36000, 40000): update D only.
      3. [40000, 160000): each iteration updates D, then G with ``G_loss``.

    Every 10000 iterations a line is appended to ``losses.txt``. In phase 3
    the models are also saved, and the next iteration number is written to
    ``iteration_count.txt`` so a rerun continues from that checkpoint.

    NOTE(review): the combined models' losses see D's score of the generated
    image, not the generated pixels, so no term compares G's output with
    ``X_train``. If a pixel-wise reconstruction loss is intended, G needs its
    own output/target. Confirm.

    Returns:
        ``(GAN, GENERATOR, discriminator)``: the phase-3 combined model, the
        generator and the discriminator. All three share the same weights.
    """
    # Training-specific hyperparameters.
    max_iters = 160000
    batch_size = 64

    # Phase split ratio :-  T1:T2:T3 --> 18:2:80
    phase_1_iters = 36000
    phase_2_iters = 4000
    adversarial_start = phase_1_iters + phase_2_iters
    report_every = 10000

    # Resume only when checkpoints exist. Skipping ahead with fresh weights
    # would start later phases with untrained models.
    if os.path.exists('best_generator.h5') and os.path.exists('best_discriminator.h5'):
        # compile=False: the saved training config refers to the custom
        # D_loss, which load_model cannot resolve by itself. D is recompiled
        # here, so its Adam state restarts from zero.
        generator = load_model('best_generator.h5', compile=False)
        discriminator = load_model('best_discriminator.h5', compile=False)
        discriminator.trainable = True
        discriminator.compile(optimizer=Adam(lr=d_learning_rate, beta_1=d_beta_1),
                              loss=D_loss, metrics=['binary_accuracy'])
        prev_iter = _read_resume_iteration()
        print("Loaded old Generator and Discriminator")
    else:
        _, generator, discriminator = create_DCGAN(g_learning_rate, g_beta_1, d_learning_rate, d_beta_1)
        prev_iter = 0
        print("Created new Gan, Generator and Discriminator")
    print('Previous iteration count :', prev_iter)

    # Both combined models wrap the *same* generator and discriminator objects.
    # D's training is therefore seen by the phase-3 G step, and GAN/GENERATOR
    # exist even when resuming directly into phase 3. Previously phase 3 built
    # a fresh G+D pair and copied weights once, so D's updates never reached
    # the GAN, and resuming in phase 3 raised NameError.
    gan = _compile_combined(generator, discriminator, G_loss_MSE, g_learning_rate, g_beta_1)
    GAN = _compile_combined(generator, discriminator, G_loss, g_learning_rate, g_beta_1)
    GENERATOR = generator

    # Labels for real and fake images.
    y_train_real = make_labels(batch_size, True)
    y_train_fake = make_labels(batch_size, False)

    start_time = time.time()

    for iters in range(prev_iter, max_iters):
        print("Iteration :", iters)

        # The trainable toggles below are kept as a safeguard. Under Keras 2
        # the frozen/unfrozen split was fixed at compile time (see
        # _compile_combined). NOTE(review): other Keras versions may read the
        # flag when the train function is first built.
        if iters < phase_1_iters:
            # Phase 01 - train only the generator.
            print("phase 01")
            discriminator.trainable = False
            gan.train_on_batch(_sample_rows(X_mask, batch_size), y_train_real)
        elif iters < adversarial_start:
            # Phase 02 - train only the discriminator.
            print("phase 02")
            _train_discriminator_step(generator, discriminator, batch_size,
                                      y_train_real, y_train_fake)
        else:
            # Phase 03 - train D, then G adversarially on the same masked batch.
            print("phase 03")
            X_batch_masked = _train_discriminator_step(generator, discriminator, batch_size,
                                                       y_train_real, y_train_fake)
            discriminator.trainable = False
            GAN.train_on_batch(X_batch_masked, y_train_real)

        if (iters + 1) % report_every == 0:
            # Training losses on fresh random batches.
            combined = gan if iters < adversarial_start else GAN
            X_batch_masked = _sample_rows(X_mask, batch_size)
            gan_images = generator.predict_on_batch(X_batch_masked)
            g_loss_batch = combined.test_on_batch(X_batch_masked, y_train_real)

            X_batch_real = _sample_rows(X_train, batch_size)
            d_loss_real = discriminator.test_on_batch(X_batch_real, y_train_real)
            d_loss_fake = discriminator.test_on_batch(gan_images, y_train_fake)
            _log_losses(d_loss_real[0], d_loss_fake[0], g_loss_batch[0])

            # Checkpoint (phase 3 only), recording where a rerun should resume.
            if iters > adversarial_start:
                _save_checkpoint(generator, discriminator, GAN, iters + 1)

    print("Training Time:- %s seconds" % (time.time() - start_time))

    return GAN, GENERATOR, discriminator

In [None]:
# Start training, or resume from saved checkpoints, with Adam(lr=1e-4, beta_1=0.5)
# for both the generator and the discriminator.
GAN, GENERATOR, D = train_DCGAN(
    g_learning_rate=0.0001,
    g_beta_1=0.5,
    d_learning_rate=0.0001,
    d_beta_1=0.5,
)