In [3]:
import shutil, os, sys, io, random, math
gpu_count = 1
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from keras.models import Sequential, Model, load_model
from keras.layers import Input, Reshape, Flatten, Activation, UpSampling2D, Dense, AveragePooling2D, add, Lambda, GaussianNoise
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.advanced_activations import ReLU, LeakyReLU, PReLU
from keras.layers.normalization import BatchNormalization
from keras.optimizers import Adam, RMSprop
from keras.initializers import glorot_uniform, RandomNormal
from keras.utils import multi_gpu_model

import keras
import keras.backend as K
import keras.backend.tensorflow_backend as KTF
import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from PIL import Image
from tqdm import tqdm

# NOTE(review): hard-coded, machine-specific working directory. Every relative
# path used later in the notebook ('../datasets/...', './saved_model/...',
# './search_128wgan9/...') is resolved against it.
os.chdir('/home/k_yonhon/py/Keras-GAN/srresgan/')
# os.chdir('\\Users\\pro18\\Documents\\py\\Keras-GAN\\wgan_gp')
# Make the parent directory importable; the two project-local imports below
# must therefore stay after this line.
sys.path.append(os.pardir)

from tensor_board_logger import TensorBoardLogger
from wasserstein_loss import WassersteinLoss, GradientPenaltyLoss, loss_func_dcgan_dis_real, loss_func_dcgan_dis_fake

# Create a TensorFlow session with GPU memory growth enabled and register it
# as the session used by the Keras TensorFlow backend.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
session = tf.Session(config=config)
KTF.set_session(session)

# ---------------------
#  Parameter
# ---------------------
size = 128  # image edge length in pixels; also selects the dataset file
# Load inside a context manager so the .npz archive handle is closed as soon
# as the array has been read into memory (a bare np.load(...)[key] leaves the
# underlying file open).
with np.load('../datasets/lfw'+str(size)+'.npz') as npz_file:
    dataset = npz_file['arr_0']
# dataset = np.ones((1000, size, size, 3), dtype=np.float32)

# Used to build the log directory name and the checkpoint file names.
modelname = 'srres_dragan93_base16_lmd05'

input_size = size // 8 # size // 8   spatial size of the generator's first feature map
base_critic = 16 # 16                filter count of the critic's first convolution
n_Gen_ResBlock = 16 # 16             number of residual blocks in the generator

unrolling_steps = 5 # 5              extra critic updates per step that are rolled back afterwards
n_critic = 1 # 1                     persistent critic updates per generator update
λ = 10 # 10                          gradient penalty weight

lr = 0.0002 # 0.0002                 Adam learning rate
beta_1 = 0.5 # 0.5                   Adam beta_1
BN = True  # enable BatchNormalization in the generator blocks

epochs = 10000
batch_size = 64
sample_interval = 1000  # generate sample images every this many epochs
resume = 0  # epoch of a saved checkpoint to resume from; 0 starts from scratch

In [2]:
def Gen_ResBlock(input_layer, filters, kernal_size=3, strides=1):
    """Generator residual block: Conv -> [BN] -> ReLU -> Conv -> [BN], plus identity skip.

    Batch normalisation is applied only when the module-level ``BN`` flag is
    true. The result of the second stage is added to ``input_layer``.

    Args:
        input_layer: Keras tensor entering the block.
        filters: Filter count of both convolutions.
        kernal_size: Kernel size of both convolutions (the spelling is part of
            the existing interface).
        strides: Stride of both convolutions.

    Returns:
        The Keras tensor produced by ``add([input_layer, branch])``.
    """
    weight_init = RandomNormal(mean=0.0, stddev=0.02, seed=None)

    def make_conv():
        # Both convolutions share one configuration and one initializer object.
        return Conv2D(filters=filters,
                      kernel_size=kernal_size,
                      strides=strides,
                      padding="same",
                      kernel_initializer=weight_init)

    branch = make_conv()(input_layer)
    if BN:
        branch = BatchNormalization(momentum=0.9)(branch)
    branch = Activation("relu")(branch)

    branch = make_conv()(branch)
    if BN:
        branch = BatchNormalization(momentum=0.9)(branch)

    return add([input_layer, branch])


def pixel_shuffler(input_layer):
    """Upsample a feature map 2x in height and width by rearranging channels.

    The channel axis of size ``c`` is split into ``(c // 4, 2, 2)``; the two
    factors of 2 are then interleaved with the height and width axes, giving
    an output with static shape ``(2 * h, 2 * w, c // 4)`` per sample.

    Args:
        input_layer: Keras tensor whose static shape is (batch, h, w, c).

    Returns:
        Keras tensor with static shape (batch, 2 * h, 2 * w, c // 4).
    """
    # The batch dimension is not needed (it is typically None at graph-build time).
    _, height, width, channels = K.int_shape(input_layer)
    scale_h, scale_w = 2, 2
    out_channels = channels // (scale_h * scale_w)

    # (h, w, c) -> (h, w, oc, rh, rw)
    split = Reshape((height, width, out_channels, scale_h, scale_w))(input_layer)

    # (h, w, oc, rh, rw) -> (h, rh, w, rw, oc); axis 0 is the batch axis.
    interleave = Lambda(lambda t: K.permute_dimensions(t, (0, 1, 4, 2, 5, 3)),
                        output_shape=(height, scale_h, width, scale_w, out_channels))
    interleaved = interleave(split)

    # Collapse (h, rh) and (w, rw) into the enlarged spatial axes.
    merge = Reshape((height * scale_h, width * scale_w, out_channels))
    return merge(interleaved)


def CBR(input_layer, filters, kernel_size=3, strides=1):
    """Upsampling stage: Conv -> pixel shuffle -> [BN] -> ReLU.

    ``pixel_shuffler`` doubles height and width and divides the channel count
    by four, so the output has ``filters // 4`` channels. Batch normalisation
    is applied only when the module-level ``BN`` flag is true.

    Args:
        input_layer: Keras tensor entering the stage.
        filters: Filter count of the convolution (before the shuffle).
        kernel_size: Kernel size of the convolution.
        strides: Stride of the convolution.

    Returns:
        The ReLU-activated Keras tensor.
    """
    conv = Conv2D(filters=filters,
                  kernel_size=kernel_size,
                  strides=strides,
                  padding="same",
                  kernel_initializer=RandomNormal(mean=0.0, stddev=0.02, seed=None))

    upsampled = pixel_shuffler(conv(input_layer))
    if BN:
        upsampled = BatchNormalization(momentum=0.9)(upsampled)

    return Activation("relu")(upsampled)


def Dis_ResBlock(input_layer, filters, kernel_size=3, strides=1):
    """Critic residual block: Conv -> LeakyReLU -> Conv, identity skip, LeakyReLU.

    No normalisation layer is used in this block.

    Args:
        input_layer: Keras tensor entering the block.
        filters: Filter count of both convolutions.
        kernel_size: Kernel size of both convolutions.
        strides: Stride of both convolutions.

    Returns:
        Keras tensor after the final LeakyReLU (alpha 0.2).
    """
    weight_init = RandomNormal(mean=0.0, stddev=0.02, seed=None)

    def make_conv():
        # Both convolutions share one configuration and one initializer object.
        return Conv2D(filters=filters,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding="same",
                      kernel_initializer=weight_init)

    branch = make_conv()(input_layer)
    branch = LeakyReLU(alpha=0.2)(branch)
    branch = make_conv()(branch)

    merged = add([input_layer, branch])
    return LeakyReLU(alpha=0.2)(merged)


def SE(input_layer, r=16):
    """Compute a squeeze-and-excitation channel gate, broadcast to the input's shape.

    The input is average-pooled over its full spatial extent, passed through a
    Dense(c / r) -> ReLU -> Dense(c) -> sigmoid bottleneck, and the resulting
    per-channel values are repeated over height and width.

    Note that this function returns only the gate; it does not multiply the
    gate with ``input_layer``.

    Args:
        input_layer: Keras tensor whose static shape is (batch, h, w, c).
        r: Channel reduction ratio of the bottleneck.

    Returns:
        Keras tensor of static shape (batch, h, w, c) holding sigmoid outputs.
    """
    initializer = glorot_uniform(seed=None)
    # The batch dimension is not needed (it is typically None at graph-build time).
    _, h, w, c = K.int_shape(input_layer)

    # int(c / r) is 0 whenever c < r, which would create a zero-unit Dense
    # layer; keep at least one unit in the bottleneck.
    bottleneck_units = max(int(c / r), 1)

    # Squeeze: one average per channel (the pooling window covers the whole map).
    x = AveragePooling2D(pool_size=(h, w), strides=1, padding='valid')(input_layer)
    x = Reshape((c,))(x)

    # Excitation: bottleneck MLP ending in a sigmoid.
    x = Dense(bottleneck_units, kernel_initializer=initializer)(x)
    x = Activation("relu")(x)
    x = Dense(c, kernel_initializer=initializer)(x)
    x = Activation("sigmoid")(x)

    # Broadcast (c,) -> (c, h) -> (c, h, w), then move channels last.
    x = Lambda(lambda x: K.expand_dims(x, axis=2),
               output_shape=(c, 1))(x)
    x = Lambda(lambda x: K.repeat_elements(x, rep=h, axis=2),
               output_shape=(c, h))(x)
    x = Lambda(lambda x: K.expand_dims(x, axis=3),
               output_shape=(c, h, 1))(x)
    x = Lambda(lambda x: K.repeat_elements(x, rep=w, axis=3),
               output_shape=(c, h, w))(x)
    output_layer = Lambda(lambda x: K.permute_dimensions(x, (0, 2, 3, 1)),
                          output_shape=(h, w, c))(x)

    return output_layer
    

def Dis_SEResBlock(input_layer, filters, kernel_size=3, strides=1):
    """Critic residual block with squeeze-and-excitation channel re-weighting.

    Structure: Conv -> LeakyReLU -> Conv -> (features * SE gate) -> add skip
    -> LeakyReLU.

    ``SE`` returns only the sigmoid gate (already broadcast to the feature
    map's shape). Previously that gate replaced the convolution output, so the
    residual branch carried no spatial information; the gate is now multiplied
    element-wise with the features it was computed from.

    Args:
        input_layer: Keras tensor entering the block.
        filters: Filter count of both convolutions.
        kernel_size: Kernel size of both convolutions.
        strides: Stride of both convolutions.

    Returns:
        Keras tensor after the final LeakyReLU (alpha 0.2).
    """
    initializer = glorot_uniform(seed=None)

    x = Conv2D(filters=filters,
               kernel_size=kernel_size, strides=strides, padding="same",
               kernel_initializer=initializer)(input_layer)
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2D(filters=filters,
               kernel_size=kernel_size, strides=strides, padding="same",
               kernel_initializer=initializer)(x)

    # Scale each channel of the conv features by its excitation weight.
    gate = SE(x)
    x = keras.layers.multiply([x, gate])

    x = add([input_layer, x])
    output_layer = LeakyReLU(alpha=0.2)(x)

    return output_layer

In [3]:
class DRAGAN():
    """GAN trainer pairing a residual / pixel-shuffle generator with a residual critic.

    The critic is trained through ``critic_model`` on three image batches
    (generated, real, and perturbed or interpolated) using
    ``loss_func_dcgan_dis_fake``, ``loss_func_dcgan_dis_real`` and a
    gradient-penalty loss. The generator is trained through
    ``generator_model`` with the critic frozen.

    Hyper-parameters that are not constructor arguments (``input_size``,
    ``unrolling_steps``, ``n_critic``, ``λ``, ``lr``, ``beta_1``, ``BN``,
    ``base_critic``, ``n_Gen_ResBlock``, ``modelname``) are read from
    module-level globals.
    """

    def __init__(self, dataset, gpu_count=1, resume=0):
        """Build (or reload) critic and generator and compile both training graphs.

        Args:
            dataset: Image array indexed as (sample, row, column, channel).
                Only its shape is used here; pixel scaling happens in ``train``.
            gpu_count: When greater than 1, both training models are wrapped
                with ``multi_gpu_model``.
            resume: Epoch number of a saved checkpoint to reload from
                ``./saved_model/``; 0 keeps the freshly built networks.
        """
        # ---------------------
        #  Parameter
        # ---------------------
        self.dataset = dataset
        self.gpu_count = gpu_count
        self.resume = resume

        self.img_rows = dataset.shape[1]
        self.img_cols = dataset.shape[2]
        self.channels = dataset.shape[3]
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        self.input_rows = input_size
        self.input_cols = input_size
        self.latent_dim = 128  # Noise dim

        self.unrolling_steps = unrolling_steps
        self.n_critic = n_critic
        self.λ = λ
        # The same optimizer instance is passed to both compile() calls below.
        optimizer = Adam(lr=lr, beta_1=beta_1, beta_2=0.99, epsilon=None, decay=0.0, amsgrad=False)

        # ---------------------
        #  Load models
        # ---------------------
        self.critic = self.build_critic()
        self.generator = self.build_generator()
        if self.resume != 0:
            self.critic = load_model('./saved_model/'+modelname+'_critic_'+str(self.resume)+'epoch.h5')
            self.generator = load_model('./saved_model/'+modelname+'_gen_'+str(self.resume)+'epoch.h5')

        #-------------------------------
        # Compile Critic
        #-------------------------------
        generated_samples = Input(shape=self.img_shape)
        critic_output_from_generated_samples = self.critic(generated_samples)

        real_samples = Input(shape=self.img_shape)
        critic_output_from_real_samples = self.critic(real_samples)

        averaged_samples = Input(shape=self.img_shape)
        critic_output_from_averaged_samples = self.critic(averaged_samples)

        # The penalty loss needs the input tensor it differentiates against,
        # so it is bound here rather than passed through y_true / y_pred.
        partial_gp_loss = partial(GradientPenaltyLoss,
                                  averaged_samples=averaged_samples,
                                  gradient_penalty_weight=self.λ)
        # Functions need names or Keras will throw an error
        partial_gp_loss.__name__ = 'gradient_penalty'

        self.critic_model = Model(inputs=[generated_samples,
                                          real_samples,
                                          averaged_samples],
                                  outputs=[critic_output_from_generated_samples,
                                           critic_output_from_real_samples,
                                           critic_output_from_averaged_samples])

        # loss_dis = loss_func_dcgan_dis_real(y_dis) + loss_func_dcgan_dis_fake(y_fake) + loss_grad
        # loss_gen = loss_func_dcgan_dis_real(y_fake)

        if self.gpu_count > 1:
            self.critic_model = multi_gpu_model(self.critic_model, gpus=self.gpu_count)
        # Loss order matches the output order: generated, real, averaged.
        self.critic_model.compile(optimizer=optimizer,
                                  loss=[loss_func_dcgan_dis_fake,
                                        loss_func_dcgan_dis_real,
                                        partial_gp_loss])

        # print('Critic Summary:')
        # self.critic.summary()

        #-------------------------------
        # Compile Generator
        #-------------------------------
        # For the generator we freeze the critic's layers
        self.critic.trainable = False

        generator_input = Input(shape=(self.latent_dim,))
        generator_layers = self.generator(generator_input)
        critic_layers_for_generator = self.critic(generator_layers)

        self.generator_model = Model(inputs=[generator_input],
                                     outputs=[critic_layers_for_generator])
        if self.gpu_count > 1:
            self.generator_model = multi_gpu_model(self.generator_model, gpus=self.gpu_count)
        self.generator_model.compile(optimizer=optimizer,
                                     loss=loss_func_dcgan_dis_real)

        # print('Genarator Summary:')
        # self.generator.summary()

    def build_generator(self):
        """Build the generator: latent vector -> image in [-1, 1] (tanh output).

        Layout: Dense + reshape to (input_rows, input_cols, 64), a stack of
        ``n_Gen_ResBlock`` residual blocks with a long skip connection, then
        log2(img_rows / input_rows) pixel-shuffle upsampling stages and a final
        convolution producing ``self.channels`` channels.

        Returns:
            The (uncompiled) Keras ``Model``; its summary is printed.
        """
        base = 64
        # initializer = glorot_uniform(seed=None)
        initializer = RandomNormal(mean=0.0, stddev=0.02, seed=None)

        input_layer = Input(shape=(self.latent_dim,))
        x = Dense(base*self.input_rows*self.input_cols, kernel_initializer=initializer)(input_layer)
        if BN:
            x = BatchNormalization(momentum = 0.9)(x)
        x = Activation("relu")(x)
        x1 = Reshape((self.input_rows, self.input_cols, base))(x)

        x = Gen_ResBlock(x1, base)
        for _ in range(n_Gen_ResBlock-1):
            x = Gen_ResBlock(x, base)
        x = Conv2D(filters = base,
                   kernel_size = 3, strides = 1, padding = "same",
                   kernel_initializer=initializer)(x)
        if BN:
            x = BatchNormalization(momentum = 0.9)(x)
        x = Activation("relu")(x)
        # Long skip connection around the whole residual stack.
        x = add([x1, x])

        # Each CBR stage doubles the resolution. math.log2 is used instead of
        # math.log(x, 2) so the stage count is not exposed to the rounding
        # error of dividing two natural logarithms.
        n_upsample = int(math.log2(self.img_rows / self.input_rows))
        for _ in range(n_upsample):
            # x = CBR(x, base)
            # base*4 filters -> base channels after the 2x2 pixel shuffle.
            x = CBR(x, base*4)

        # Output channel count follows the dataset instead of assuming RGB.
        x = Conv2D(filters = self.channels,
                   kernel_size = math.ceil(9 * self.img_rows / 128),
                   strides = 1, padding = "same",
                   kernel_initializer=initializer)(x)
        output_layer = Activation('tanh')(x)

        model = Model(inputs = input_layer, outputs = output_layer)
        print('Generator Summary:')
        model.summary()
        return model

    def build_critic(self):
        """Build the critic: image -> single unbounded score (no output activation).

        Layout: Gaussian input noise, a stride-2 convolution, then
        log2(img_rows) - 2 stages of two ``Dis_ResBlock`` plus a stride-2
        convolution that doubles the filter count, followed by Flatten and
        Dense(1).

        Returns:
            The (uncompiled) Keras ``Model``; its summary is printed.
        """
        base = base_critic
        # initializer = glorot_uniform(seed=None)
        initializer = RandomNormal(mean=0.0, stddev=0.02, seed=None)
        input_layer = Input(shape = self.img_shape)

        x = GaussianNoise(stddev=0.1)(input_layer)
        x = Conv2D(filters = base,
                   kernel_size = 4, strides = 2, padding = "same",
                   kernel_initializer=initializer)(x)
        x = LeakyReLU(alpha = 0.2)(x)

        n_stages = int(math.log2(self.img_rows) - 2)  # 128=2^7 -> 5 stages
        for i in range(n_stages):
            # 4x4 kernels for the downsampling conv while the threshold below
            # holds, 3x3 afterwards.
            if self.img_rows / (2 ** (i + 2)) >=  16: #2^7/2^(1+2)=2^4
                kernel_size = 4
            else:
                kernel_size = 3
            x = Dis_ResBlock(x, base * 2**i)
            x = Dis_ResBlock(x, base * 2**i)
            x = Conv2D(filters = base * 2**(i+1),
                       kernel_size = kernel_size,
                       strides = 2, padding = "same",
                       kernel_initializer=initializer)(x)
            x = LeakyReLU(alpha = 0.2)(x)
        x = Flatten()(x)
        output_layer = Dense(1, kernel_initializer=initializer)(x)
        # output_layer = Activation('sigmoid')(x)

        model = Model(inputs = input_layer, outputs = output_layer)
        print('Critic Summary:')
        model.summary()
        return model

    def _sample_critic_batch(self, X_train, batch_size):
        """Draw one critic batch: generated images, real images and mixing factors ε."""
        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
        gen_imgs = self.generator.predict(noise, batch_size=batch_size)

        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_imgs = X_train[idx]

        # One factor per sample, broadcast over rows, columns and channels.
        ε = np.random.uniform(size=(batch_size, 1, 1, 1))
        return gen_imgs, real_imgs, ε

    def train(self, epochs, batch_size, sample_interval=5000):
        """Run the adversarial training loop and collect sample images.

        Each iteration performs ``n_critic`` critic updates, optionally
        ``unrolling_steps`` additional critic updates that are rolled back
        after the generator update, and one generator update. Losses are sent
        to a ``TensorBoardLogger``.

        Args:
            epochs: Number of iterations to run after ``self.resume``; the
                loop covers ``resume .. resume + epochs`` inclusive.
            batch_size: Samples per critic / generator update.
            sample_interval: Generate sample images whenever
                ``epoch % sample_interval == 0``.

        Returns:
            List of uint8 arrays, one per sampling epoch, each holding 5
            generated images rescaled to 0..255.
        """
        # ---------------------
        #  for Logging
        # ---------------------
        target_dir = "./search_128wgan9/log_"+modelname
        seed = 0
        image_num = 5
        np_samples = []

        # "REVISED" Load suspended training weights
        if self.resume != 0:
            np_samples_npz = np.load('./saved_model/'+modelname+'_sample_'+str(self.resume)+'epoch.npz', allow_pickle=True)['arr_0']
            # The last stored snapshot is dropped: the logging section appends
            # a new one at the first resumed epoch.
            np_samples.extend(np_samples_npz[:-1])
        else:
            shutil.rmtree(target_dir, ignore_errors=True)
            # target_dir is nested; makedirs also creates the missing parent.
            os.makedirs(target_dir, exist_ok=True)

        self.logger = TensorBoardLogger(log_dir=target_dir)

        # ---------------------
        #  Training
        # ---------------------
        # Rescale the dataset -1 to 1
        X_train = self.dataset / 127.5 - 1.0

        # Target arrays handed to train_on_batch.
        # NOTE(review): how the DCGAN loss functions use these targets is
        # defined in wasserstein_loss, which is not visible here.
        plus = np.ones((batch_size, 1), dtype=np.float32)
        dummy = np.zeros((batch_size, 1), dtype=np.float32)
        critic_targets = [plus, plus, dummy]

        # Fixed noise for the sample images, drawn from a private generator.
        # Calling np.random.seed() inside the loop would reset the global
        # stream at every sampling epoch and make the training noise and batch
        # indices repeat. RandomState(seed) yields the same values that
        # np.random.seed(seed) followed by np.random.normal would.
        sample_noise = np.random.RandomState(seed).normal(0, 1, (image_num, self.latent_dim))

        for epoch in tqdm(range(self.resume, self.resume + epochs + 1)):

            # Train Critic
            for _ in range(self.n_critic):
                gen_imgs, real_imgs, ε = self._sample_critic_batch(X_train, batch_size)

                # Penalty points: real images perturbed by scaled uniform noise.
                # ave_imgs = ε * real_imgs + (1-ε) * gen_imgs
                # ave_imgs = real_imgs + 0.5 * np.random.random(real_imgs.shape) * real_imgs.std()
                ave_imgs = real_imgs + 0.5 * ε * np.random.random(real_imgs.shape) * real_imgs.std()

                d_loss = self.critic_model.train_on_batch([gen_imgs, real_imgs, ave_imgs],
                                                          critic_targets)

            #  Unrolling Step1
            if self.unrolling_steps != 0:
                # Backup weights
                backup_weights = self.critic.get_weights()

                # Train critic
                for _ in range(self.unrolling_steps):
                    gen_imgs, real_imgs, ε = self._sample_critic_batch(X_train, batch_size)

                    # Penalty points: random interpolation between real and generated.
                    ave_imgs = ε * real_imgs + (1-ε) * gen_imgs
                    # ave_imgs = real_imgs + 0.5 * np.random.random(real_imgs.shape) * real_imgs.std()
                    # ave_imgs = real_imgs + 0.5 * ε * np.random.random(real_imgs.shape) * real_imgs.std()

                    # Train Critic
                    d_loss = self.critic_model.train_on_batch([gen_imgs, real_imgs, ave_imgs],
                                                              critic_targets)

            # Train Generator
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.generator_model.train_on_batch(noise, plus)

            #  Unrolling Step2
            if self.unrolling_steps != 0:
                # Undo weights
                self.critic.set_weights(backup_weights)

            # ---------------------
            #  Logging
            # ---------------------
            # Backup Model (disabled).
            # NOTE(review): the file names below use self.resume rather than
            # the current epoch; check before re-enabling.
            # if epoch != resume and epoch % sample_interval == 0:
            #     self.critic.save('./saved_model/'+modelname+'_critic_'+str(self.resume)+'epoch.h5')
            #     self.generator.save('./saved_model/'+modelname+'_gen_'+str(self.resume)+'epoch.h5')
            #     np.savez_compressed('./saved_model/'+modelname+'_sample_'+str(self.resume)+'epoch.npz', np_samples)

            # Log Loss & Histgram
            # d_loss comes from the most recent critic update; when unrolling
            # is enabled that is the last unrolled (rolled-back) step.
            logs = {
                "loss/Critic": d_loss[0],
                "loss/Generator": g_loss,
                "loss_Critic/D_gen": d_loss[1],
                "loss_Critic/D_real": d_loss[2],
                "loss_Critic/gradient_penalty": d_loss[3],
                "loss_Critic/total_loss": d_loss[1] + d_loss[2] + d_loss[3],
            }

            histograms = {}
            # Weight histograms (disabled):
            # for layer in self.critic.layers[1].layers:
            #     for i in range(len(layer.get_weights())):
            #         if "conv" in layer.name or "dense" in layer.name:
            #             name = layer.name + "/" + str(i)
            #             value = layer.get_weights()[i]
            #             histograms[name] = value
            self.logger.log(logs=logs, histograms=histograms, epoch=epoch)

            # Log generated image samples
            if epoch % sample_interval == 0:
                gen_imgs = self.generator.predict(sample_noise)
                # Map the tanh range [-1, 1] to 0..255.
                gen_imgs = ((0.5 * gen_imgs + 0.5) * 255).astype(np.uint8)
                np_samples.append(gen_imgs)
                # Image logging to TensorBoard (disabled):
                # fig, name = self.sample_images(epoch)
                # images = {epoch: fig}
                # self.logger.log(images=images, epoch=epoch)
                print("%d [C loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))

        return np_samples

In [4]:
# Build both networks and compile the critic / generator training graphs
# (the model summaries are printed as a side effect).
gan = DRAGAN(dataset, gpu_count=gpu_count, resume=resume)

Critic Summary:
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 128, 128, 3)  0                                            
__________________________________________________________________________________________________
gaussian_noise_1 (GaussianNoise (None, 128, 128, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 64, 64, 16)   784         gaussian_noise_1[0][0]           
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 64, 64, 16)   0           conv2d_1[0][0]                   
_____________________________________________________________________________________________

In [None]:
# Run training; the return value is the list of generated sample snapshots,
# one per sampling epoch.
np_samples = gan.train(
    epochs=epochs,
    batch_size=batch_size,
    sample_interval=sample_interval,
)

In [5]:
import holoviews as hv
hv.notebook_extension()

# np_samples[j] holds the images generated at epoch j * sample_interval.
# Index 0 (the very first epoch) is skipped, as before.
n_snapshots = (epochs + resume) // sample_interval

# Collect every image as a labelled RGB tile and build the layout in one go.
# This removes the first-iteration special case (which left hv_points
# undefined when the loop body never ran) and the hard-coded 5 images per
# snapshot.
tiles = []
for j in range(1, n_snapshots + 1):
    label = str(j*sample_interval)+' epoch'
    for img in np_samples[j]:
        tiles.append(hv.RGB(img).relabel(label))

# One row per snapshot: the column count follows the images per snapshot.
images_per_snapshot = len(np_samples[0])
hv_points = hv.Layout(tiles)
hv_points.cols(images_per_snapshot)

##### 