In [1]:
from keras.models import Sequential, Model, load_model
from keras.layers import Input, Reshape, Flatten, Activation, UpSampling2D
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.advanced_activations import ReLU, LeakyReLU
from keras.optimizers import Adam, RMSprop
from keras.utils import multi_gpu_model

import keras
import keras.backend as K
import keras.backend.tensorflow_backend as KTF
import tensorflow as tf

from keras.layers import Dense, Add, Lambda, Concatenate
from keras.layers.convolutional import UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.engine.network import Network, Layer
from keras.initializers import TruncatedNormal

import shutil, os, sys, io, random, math
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from PIL import Image
from tqdm import tqdm

# Working directory of the experiment. The original absolute path stays the
# default; set DRAGAN_PROJECT_DIR to run the notebook from another checkout
# without editing this cell.
PROJECT_DIR = os.environ.get('DRAGAN_PROJECT_DIR', '/home/k_yonhon/py/Keras-GAN/dragan/')
os.chdir(PROJECT_DIR)
# The helper modules imported below live in the parent directory.
sys.path.append(os.pardir)

from tensor_board_logger import TensorBoardLogger
from wasserstein_loss import WassersteinLoss, GradientPenaltyLoss, loss_func_dcgan_dis_real, loss_func_dcgan_dis_fake, CommonVanillaLoss

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
session = tf.Session(config=config)
KTF.set_session(session)

# ---------------------
#  Parameter
# ---------------------
gpu_count = 1
resume = 0
# Read the array inside a `with` block so the underlying .npz file handle is
# closed as soon as the data is in memory.
with np.load('../datasets/lfw32.npz') as dataset_archive:
    dataset = dataset_archive['arr_0']

In [2]:
def residual_block(x, base_name, block_num, initializer, num_channels=128, is_D=False):
    """Add a two-convolution residual branch to `x` and return `x + branch`.

    Both convolutions are 3x3, stride 1, "same" padded and bias-free, so
    `x` must already have `num_channels` channels for the final Add to work.

    Args:
        x: input tensor of the block (also used as the skip connection).
        base_name: prefix for the layer names created here.
        block_num: index of the block, embedded in the layer names.
        initializer: kernel initializer for both convolutions.
        num_channels: number of filters of both convolutions.
        is_D: when true, no batch normalisation is used and the first
            convolution is followed by LeakyReLU(0.2); otherwise each
            convolution is followed by batch normalisation and the first
            one additionally by ReLU.

    Returns:
        The tensor produced by adding the branch output to `x`.
    """
    prefix = base_name + "_resblock" + str(block_num)
    conv_options = dict(kernel_size=3, strides=1, padding="same",
                        kernel_initializer=initializer, use_bias=False)
    with_batchnorm = not is_D

    # First convolution and its non-linearity.
    branch = Conv2D(num_channels, name=prefix + "_conv1", **conv_options)(x)
    if with_batchnorm:
        branch = BatchNormalization(momentum=0.9, epsilon=1e-5, name=prefix + "_bn1")(branch, training=1)
        branch = Activation("relu")(branch)
    else:
        branch = LeakyReLU(0.2)(branch)

    # Second convolution; no activation before the skip connection is added.
    branch = Conv2D(num_channels, name=prefix + "_conv2", **conv_options)(branch)
    if with_batchnorm:
        branch = BatchNormalization(momentum=0.9, epsilon=1e-5, name=prefix + "_bn2")(branch, training=1)

    return Add()([x, branch])

In [3]:
class DRAGAN():
    """GAN trainer with a gradient penalty evaluated on perturbed real images.

    The constructor builds (or reloads) a generator and a critic and wires
    them into two compiled Keras models:

    * `critic_model` takes [generated, real, perturbed] image batches and
      returns the critic output for each of them.
    * `generator_model` maps latent noise through the generator and the
      (frozen) critic.

    `train` alternates `n_critic` critic updates with one generator update.

    NOTE(review): the run recorded in this notebook printed NaN losses at
    epoch 0. The loss functions are imported from `wasserstein_loss` and are
    not visible here, so the cause is not established — TODO investigate
    (loss definitions, initializer scale of the un-normalised critic).
    """

    def __init__(self, dataset, gpu_count=1, resume=0):
        """Build or reload the networks and compile both training models.

        Args:
            dataset: image array indexed as (sample, rows, cols, channels).
                `train` rescales it with `/ 127.5 - 1.0`, so values are
                presumably in 0..255 — TODO confirm against the data file.
            gpu_count: when greater than 1 both training models are wrapped
                with `multi_gpu_model`.
            resume: epoch number of the checkpoint to reload; 0 starts from
                freshly initialised networks.
        """
        # ---------------------
        #  Parameter
        # ---------------------
        self.dataset = dataset
        self.gpu_count = gpu_count
        self.resume = resume

        self.img_rows = dataset.shape[1]
        self.img_cols = dataset.shape[2]
        self.channels = dataset.shape[3]
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        self.input_rows = 2
        self.input_cols = 2
        self.latent_dim = 128  # Noise dim

        self.n_critic = 5  # critic updates per generator update
        self.λ = 10        # gradient penalty weight
        optimizer = Adam(lr=0.0002, beta_1=0.5, beta_2=0.9, epsilon=None, decay=0.0, amsgrad=False)

        # ---------------------
        #  Load models
        # ---------------------
        # Only build fresh networks when not resuming; previously they were
        # built and then immediately replaced by the loaded checkpoint.
        if self.resume != 0:
            self.critic = load_model('./saved_model/dragan'+str(self.λ)+'_critic_'+str(self.resume)+'epoch.h5')
            self.generator = load_model('./saved_model/dragan'+str(self.λ)+'_gen_'+str(self.resume)+'epoch.h5')
        else:
            self.critic = self.build_critic()
            self.generator = self.build_generator()

        #-------------------------------
        # Compile Critic
        #-------------------------------
        generated_samples = Input(shape=self.img_shape)
        critic_output_from_generated_samples = self.critic(generated_samples)

        real_samples = Input(shape=self.img_shape)
        critic_output_from_real_samples = self.critic(real_samples)

        averaged_samples = Input(shape=self.img_shape)
        critic_output_from_averaged_samples = self.critic(averaged_samples)

        # The penalty needs the input tensor it is evaluated at, so it is
        # bound here; Keras only passes (y_true, y_pred) to a loss.
        partial_gp_loss = partial(GradientPenaltyLoss,
                                  averaged_samples=averaged_samples,
                                  gradient_penalty_weight=self.λ)
        # Functions need names or Keras will throw an error
        partial_gp_loss.__name__ = 'gradient_penalty'

        self.critic_model = Model(inputs=[generated_samples,
                                          real_samples,
                                          averaged_samples],
                                  outputs=[critic_output_from_generated_samples,
                                           critic_output_from_real_samples,
                                           critic_output_from_averaged_samples])

        if self.gpu_count > 1:
            self.critic_model = multi_gpu_model(self.critic_model, gpus=self.gpu_count)
        # One loss per output, in the order generated / real / perturbed.
        self.critic_model.compile(optimizer=optimizer,
                                  loss=[CommonVanillaLoss,
                                        CommonVanillaLoss,
                                        partial_gp_loss])

        #-------------------------------
        # Compile Generator
        #-------------------------------
        # For the generator we freeze the critic's layers
        self.critic.trainable = False

        generator_input = Input(shape=(self.latent_dim,))
        generator_layers = self.generator(generator_input)
        critic_layers_for_generator = self.critic(generator_layers)

        self.generator_model = Model(inputs=[generator_input],
                                     outputs=[critic_layers_for_generator])
        if self.gpu_count > 1:
            self.generator_model = multi_gpu_model(self.generator_model, gpus=self.gpu_count)
        self.generator_model.compile(optimizer=optimizer,
                                     loss=CommonVanillaLoss)

    def build_generator(self):
        """Build the generator: latent vector -> image with values in (-1, 1).

        A dense layer produces an (h//8, w//8) feature map which is upsampled
        by 2 three times, so the output is 8*(h//8) x 8*(w//8) pixels; this
        equals `img_shape` only when rows and cols are multiples of 8.

        Returns:
            The (uncompiled) Keras model named "generator".
        """
        latent_dim = self.latent_dim
        h, w, c = self.img_shape
        num_res_blocks = 3
        base_name = "generator"
        initializer = TruncatedNormal(mean=0, stddev=0.2, seed=42)

        in_x = Input(shape=(latent_dim,))

        # Parenthesised so the unit count is exactly 512 * (h//8) * (w//8),
        # i.e. what the Reshape below expects (the unparenthesised form only
        # gave that value through left-to-right evaluation).
        x = Dense(64 * 8 * (h // 8) * (w // 8), activation="relu", name=base_name+"_dense")(in_x)
        x = Reshape((h//8, w//8, -1))(x)

        # size//8 -> size//4 -> size//2 -> size
        # Each stage: upsample, conv, batch norm, residual blocks, ReLU.
        for stage, filters in enumerate((64*4, 64*2, 64*1), start=1):
            x = UpSampling2D((2, 2))(x)
            x = Conv2D(filters, kernel_size=3, strides=1, padding="same", kernel_initializer=initializer,
                       use_bias=False, name=base_name + "_conv" + str(stage))(x)
            x = BatchNormalization(momentum=0.9, epsilon=1e-5, name=base_name + "_bn" + str(stage))(x, training=1)

            for i in range(num_res_blocks):
                x = residual_block(x, base_name=base_name + "res" + str(stage), block_num=i,
                                   initializer=initializer, num_channels=filters)
            x = Activation("relu")(x)

        # Output has as many channels as the dataset (previously fixed to 3).
        # tanh is applied once, by the Activation layer; the convolution used
        # to carry activation="tanh" as well, squashing the output twice.
        x = Conv2D(c, kernel_size=3, strides=1, padding="same", kernel_initializer=initializer,
                   use_bias=False, name=base_name + "_conv4")(x)
        out = Activation("tanh")(x)

        model = Model(in_x, out, name=base_name)
        print('Generator Summary:')
        model.summary()
        return model

    def build_critic(self):
        """Build the critic: image -> flattened map of unbounded scores.

        Four stride-2 convolutions (64, 128, 256, 512 filters) reduce the
        spatial size; five residual blocks follow the third convolution. A
        final 1x1 convolution yields one score per remaining spatial
        position, which is flattened. No batch normalisation is used.

        Returns:
            The (uncompiled) Keras model named "discriminator".
        """
        input_shape = self.img_shape
        base_name = "discriminator"
        is_D = True
        use_res = True

        initializer_d = TruncatedNormal(mean=0, stddev=0.1, seed=42)

        D = in_D = Input(shape=input_shape)
        D = Conv2D(64, kernel_size=4, strides=2, padding="same", kernel_initializer=initializer_d,
                   use_bias=False,
                   name=base_name + "_conv1")(D)
        D = LeakyReLU(0.2)(D)

        D = Conv2D(128, kernel_size=4, strides=2, padding="same", kernel_initializer=initializer_d,
                   use_bias=False,
                   name=base_name + "_conv2")(D)
        D = LeakyReLU(0.2)(D)

        D = Conv2D(256, kernel_size=4, strides=2, padding="same", kernel_initializer=initializer_d,
                   use_bias=False,
                   name=base_name + "_conv3")(D)

        # Residual blocks are only enabled at this resolution.
        if use_res:
            for i in range(5):
                D = residual_block(D, base_name=base_name+"res3", block_num=i,
                                   initializer=initializer_d, num_channels=256, is_D=is_D)
        D = LeakyReLU(0.2)(D)

        D = Conv2D(512, kernel_size=4, strides=2, padding="same", kernel_initializer=initializer_d,
                   use_bias=False,
                   name=base_name + "_conv4")(D)
        D = LeakyReLU(0.2)(D)
        D = Conv2D(1, kernel_size=1, strides=1, padding="same", kernel_initializer=initializer_d,
                   use_bias=False,
                   name=base_name + "_conv5")(D)

        out = Flatten()(D)

        model = Model(in_D, out, name=base_name)
        print('Critic Summary:')
        model.summary()
        return model

    def train(self, epochs, batch_size, sample_interval=5000):
        """Run the alternating critic / generator training loop.

        Args:
            epochs: number of generator updates to run after `self.resume`;
                the loop covers `resume .. resume + epochs` inclusive.
            batch_size: number of images / noise vectors per update.
            sample_interval: every this many epochs a fixed-noise batch of
                generated images is stored and the losses are printed.

        Returns:
            List of uint8 arrays, one per sampling step, each holding
            `image_num` generated images scaled to 0..255. When resuming,
            the samples stored in the checkpoint come first.
        """
        # ---------------------
        #  for Logging
        # ---------------------
        target_dir = "./search/my_log_dir_dragan10"
        seed = 0
        image_num = 5
        np_samples = []

        # Load suspended training weights
        if self.resume != 0:
            # The checkpoint stores the sample list as one stacked array
            # under 'arr_0'. Iterating the NpzFile itself yields key names,
            # not arrays, so index the key and split along the first axis.
            samples_path = './saved_model/dragan'+str(self.λ)+'_np_'+str(self.resume)+'epoch.npz'
            with np.load(samples_path) as np_samples_npz:
                np_samples = list(np_samples_npz['arr_0'])
        else:
            shutil.rmtree(target_dir, ignore_errors=True)
            # makedirs: the log directory is nested and './search' may not exist.
            os.makedirs(target_dir, exist_ok=True)

        self.logger = TensorBoardLogger(log_dir=target_dir)

        # ---------------------
        #  Training
        # ---------------------
        # Rescale the dataset -1 to 1
        X_train = self.dataset / 127.5 - 1.0

        # Targets handed to the three critic losses / the generator loss.
        # How they are interpreted depends on the imported loss functions.
        minus = -np.ones((batch_size, 1), dtype=np.float32)
        plus = np.ones((batch_size, 1), dtype=np.float32)
        dummy = np.zeros((batch_size, 1), dtype=np.float32)

        # Fixed preview noise drawn from a private generator seeded with
        # `seed`. Re-seeding the global generator at every sampling step
        # (as before) rewound the stream used for training noise and batch
        # indices, so training repeated the same random draws.
        preview_noise = np.random.RandomState(seed).normal(0, 1, (image_num, self.latent_dim))

        for epoch in tqdm(range(self.resume, self.resume + epochs + 1)):
            for _ in range(self.n_critic):
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                gen_imgs = self.generator.predict(noise, batch_size=batch_size)

                idx = np.random.randint(0, X_train.shape[0], batch_size)
                real_imgs = X_train[idx]

                # Penalty points: real images plus uniform noise scaled by
                # half the standard deviation of the real batch.
                ave_imgs = real_imgs + 0.5 * np.random.random(real_imgs.shape) * real_imgs.std()

                # Train Critic
                d_loss = self.critic_model.train_on_batch([gen_imgs, real_imgs, ave_imgs],
                                                          [minus, plus, dummy])

            # Train Generator
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.generator_model.train_on_batch(noise, plus)

            # ---------------------
            #  Logging
            # ---------------------
            # Backup Model (disabled). Re-enable to produce the files that
            # the `resume` path reads:
            # if epoch != self.resume and epoch % sample_interval == 0:
            #     self.critic.save('./saved_model/dragan'+str(self.λ)+'_critic_'+str(epoch)+'epoch.h5')
            #     self.generator.save('./saved_model/dragan'+str(self.λ)+'_gen_'+str(epoch)+'epoch.h5')
            #     np.savez_compressed('./saved_model/dragan'+str(self.λ)+'_np_'+str(epoch)+'epoch.npz', np_samples)

            # Log Loss & Histogram
            # d_loss is [total, generated, real, gradient penalty].
            logs = {
                "loss/Critic": d_loss[0],
                "loss/Generator": g_loss,
                "loss_Critic/D_gen": d_loss[1],
                "loss_Critic/D_real": d_loss[2],
                "loss_Critic/gradient_penalty": d_loss[3],
                "loss_Critic/total_loss": d_loss[1] + d_loss[2] + d_loss[3],
            }

            histograms = {}  # weight histograms are currently not logged
            self.logger.log(logs=logs, histograms=histograms, epoch=epoch)

            # Log generated image samples
            if epoch % sample_interval == 0:
                gen_imgs = self.generator.predict(preview_noise)
                # Map (-1, 1) back to 0..255.
                gen_imgs = ((0.5 * gen_imgs + 0.5) * 255).astype(np.uint8)
                np_samples.append(gen_imgs)
                print("%d [C loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))

        return np_samples

In [4]:
# Instantiate the trainer; this builds the networks and prints their summaries.
gan = DRAGAN(dataset=dataset, gpu_count=gpu_count, resume=resume)

Critic Summary:
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 32, 32, 3)    0                                            
__________________________________________________________________________________________________
discriminator_conv1 (Conv2D)    (None, 16, 16, 64)   3072        input_1[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 16, 16, 64)   0           discriminator_conv1[0][0]        
__________________________________________________________________________________________________
discriminator_conv2 (Conv2D)    (None, 8, 8, 128)    131072      leaky_re_lu_1[0][0]              
_____________________________________________________________________________________________

In [5]:
# Train and keep the fixed-noise preview batches (one entry per sampling step).
train_options = dict(epochs=10000, batch_size=64, sample_interval=1000)
np_samples = gan.train(**train_options)

  'Discrepancy between trainable weights and collected trainable'
  0%|          | 1/10001 [00:13<37:07:30, 13.37s/it]

0 [C loss: nan] [G loss: nan]


  0%|          | 39/10001 [00:32<1:25:20,  1.95it/s]

KeyboardInterrupt: 

In [None]:
import holoviews as hv
hv.notebook_extension()

# Must match the `sample_interval` passed to gan.train(): with resume = 0,
# np_samples[j] was generated at epoch j * SAMPLE_INTERVAL. The titles used
# j * 10000 before, which overstated the epoch by a factor of ten.
SAMPLE_INTERVAL = 1000
IMAGES_PER_SAMPLE = 5

# Entry 0 (epoch 0) is skipped; slicing tolerates a shorter sample list.
panels = []
for j, sample_batch in enumerate(np_samples[1:11], start=1):
    title = str(j * SAMPLE_INTERVAL) + ' epoch'
    for img in sample_batch[:IMAGES_PER_SAMPLE]:
        panels.append(hv.RGB(img).relabel(title))

# One row per sampling step, one column per preview image.
hv_points = hv.Layout(panels)
hv_points.cols(5)