In [1]:
from keras.models import Sequential, Model, load_model
from keras.layers import Input, Reshape, Flatten, Activation
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.advanced_activations import ReLU, LeakyReLU
from keras.optimizers import Adam, RMSprop
from keras.utils import multi_gpu_model

import keras
import keras.backend as K
import keras.backend.tensorflow_backend as KTF
import tensorflow as tf

import shutil, os, sys, io, random, math
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from PIL import Image
from tqdm import tqdm

# Make the project directory the working directory so that the relative paths
# used below and in later cells ('../datasets/...', './saved_model/...',
# './search/...') are resolved from it.
# NOTE(review): this is a machine-specific absolute path; the notebook will not
# run elsewhere without editing it.
os.chdir('/home/k_yonhon/py/Keras-GAN/dragan/')
# Add the parent directory (os.pardir, i.e. '..') to the import search path.
# Presumably that is where tensor_board_logger.py and wasserstein_loss.py live
# -- confirm against the repository layout.
sys.path.append(os.pardir)

# Project-local helpers; these imports must stay after the sys.path change above.
from tensor_board_logger import TensorBoardLogger
from wasserstein_loss import WassersteinLoss, GradientPenaltyLoss

# Create a TF1-style session with GPU option allow_growth=True and register it
# as the session used by the Keras TensorFlow backend.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
session = tf.Session(config=config)
KTF.set_session(session)

# ---------------------
#  Parameter
# ---------------------
# Number of GPUs handed to WGANGP; values above 1 make it wrap its models with
# multi_gpu_model.
gpu_count = 1
# Training images: the array stored under key 'arr_0' of the archive.
# WGANGP reads dataset.shape[1:4] as (rows, cols, channels).
# Assumes pixel values are in 0..255 (train() rescales with / 127.5 - 1.0) -- TODO confirm.
dataset = np.load('../datasets/lfw32.npz')['arr_0']

In [2]:
class WGANGP():
    """WGAN-GP trainer: a convolutional generator and critic wired into two Keras models.

    ``critic_model`` updates the critic from three image batches (generated,
    real, and a random interpolation of the two) using two ``WassersteinLoss``
    terms plus ``GradientPenaltyLoss``. ``generator_model`` updates the
    generator through the critic, whose layers are frozen for that model.

    The network builders assume square images whose side equals
    ``input_rows * 2**k`` (only ``img_rows`` is used to size the networks).
    """

    def __init__(self, dataset, gpu_count=1):
        """Build the generator/critic networks and compile both training models.

        Args:
            dataset: Image array indexed as (sample, rows, cols, channels);
                ``shape[1:4]`` defines the image shape used by both networks.
            gpu_count: When greater than 1, both training models are wrapped
                with ``multi_gpu_model`` across that many GPUs.
        """
        # ---------------------
        #  Parameter
        # ---------------------
        self.dataset = dataset
        self.gpu_count = gpu_count

        self.img_rows = dataset.shape[1]
        self.img_cols = dataset.shape[2]
        self.channels = dataset.shape[3]
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # The latent vector is reshaped into an input_rows x input_cols feature map.
        self.input_rows = 2
        self.input_cols = 2
        self.latent_dim = 128  # Noise dim

        self.n_critic = 5  # critic updates per generator update
        self.λ = 10        # gradient penalty weight
        optimizer = Adam(lr=0.0001, beta_1=0., beta_2=0.9, epsilon=None, decay=0.0, amsgrad=False)

        # ---------------------
        #  Load models
        # ---------------------
        self.critic = self.build_critic()
        self.generator = self.build_generator()

        # NOTE: this directory is named "dragan", but the objective implemented
        # here is WGAN-GP (Wasserstein losses + gradient penalty on samples
        # interpolated between real and generated images).

        # The critic model must be compiled BEFORE the critic is frozen, so that
        # critic_model keeps training the critic's weights.
        self.critic_model = self._compile_critic_model(optimizer)

        # For the generator we freeze the critic's layers
        self.critic.trainable = False
        self.generator_model = self._compile_generator_model(optimizer)

    # ------------------------------------------------------------------
    #  Private helpers
    # ------------------------------------------------------------------
    def _scale_steps(self):
        """Return log2(img_rows / input_rows): the number of 2x resolution steps."""
        # math.log2 is exact for powers of two, unlike math.log(x, 2), whose
        # rounding error could be truncated to the wrong integer by int().
        return int(math.log2(self.img_rows / self.input_rows))

    @staticmethod
    def _orthogonal():
        """Return a fresh orthogonal kernel initializer (gain 1.4, unseeded)."""
        return keras.initializers.Orthogonal(gain=1.4, seed=None)

    def _fixed_noise(self, count, seed):
        """Draw `count` latent vectors from a private RNG seeded with `seed`.

        A dedicated RandomState yields the same values as
        ``np.random.seed(seed); np.random.normal(...)`` but leaves the global
        NumPy RNG untouched, so preview sampling does not make the training
        noise / batch indices repeat after every sampling step.
        """
        return np.random.RandomState(seed).normal(0, 1, (count, self.latent_dim))

    @staticmethod
    def _load_sample_history(path):
        """Load previously saved sample batches from an .npz file as a list of arrays.

        Reads key 'arr_0', i.e. the layout produced by
        ``np.savez_compressed(path, np_samples)``. Iterating the NpzFile itself
        would yield key names (strings), not arrays.
        """
        with np.load(path) as archive:
            return list(archive['arr_0'])

    def _compile_critic_model(self, optimizer):
        """Build and compile the three-input / three-output model that trains the critic."""
        generated_samples = Input(shape=self.img_shape)
        critic_output_from_generated_samples = self.critic(generated_samples)

        real_samples = Input(shape=self.img_shape)
        critic_output_from_real_samples = self.critic(real_samples)

        averaged_samples = Input(shape=self.img_shape)
        critic_output_from_averaged_samples = self.critic(averaged_samples)

        # The penalty needs the interpolated input tensor, so it is bound here.
        partial_gp_loss = partial(GradientPenaltyLoss,
                                  averaged_samples=averaged_samples,
                                  gradient_penalty_weight=self.λ)
        # Functions need names or Keras will throw an error
        partial_gp_loss.__name__ = 'gradient_penalty'

        critic_model = Model(inputs=[generated_samples,
                                     real_samples,
                                     averaged_samples],
                             outputs=[critic_output_from_generated_samples,
                                      critic_output_from_real_samples,
                                      critic_output_from_averaged_samples])
        if self.gpu_count > 1:
            critic_model = multi_gpu_model(critic_model, gpus=self.gpu_count)
        critic_model.compile(optimizer=optimizer,
                             loss=[WassersteinLoss,
                                   WassersteinLoss,
                                   partial_gp_loss])
        return critic_model

    def _compile_generator_model(self, optimizer):
        """Build and compile noise -> generator -> critic, used to train the generator."""
        generator_input = Input(shape=(self.latent_dim,))
        generator_layers = self.generator(generator_input)
        critic_layers_for_generator = self.critic(generator_layers)

        generator_model = Model(inputs=[generator_input],
                                outputs=[critic_layers_for_generator])
        if self.gpu_count > 1:
            generator_model = multi_gpu_model(generator_model, gpus=self.gpu_count)
        generator_model.compile(optimizer=optimizer,
                                loss=WassersteinLoss)
        return generator_model

    # ------------------------------------------------------------------
    #  Networks
    # ------------------------------------------------------------------
    def build_generator(self):
        """Build the generator: latent vector -> tanh image of shape ``img_shape``.

        The noise is reshaped to (input_rows, input_cols, latent_dim / (rows*cols))
        and upsampled by stride-2 transposed convolutions; each upsampling stage
        except the last is followed by a 2x2 convolution, and the filter count
        halves from one stage to the next.

        Returns:
            A Keras ``Model`` mapping noise of shape (latent_dim,) to an image.
        """
        steps = self._scale_steps()
        depth = int(self.latent_dim / (self.input_rows * self.input_cols))

        noise = Input(shape=(self.latent_dim,))
        x = Reshape((self.input_rows, self.input_cols, depth))(noise)
        for i in range(steps - 1):
            filters = 2 ** (steps + 5 - i)
            x = Conv2DTranspose(filters, (5, 5), strides=2, padding='same',
                                kernel_initializer=self._orthogonal())(x)
            x = LeakyReLU(alpha=0.2)(x)
            x = Conv2D(filters, (2, 2), strides=1, padding="same",
                       kernel_initializer=self._orthogonal())(x)
            x = LeakyReLU(alpha=0.2)(x)

        # Emit as many channels as the dataset has (previously hard-coded to 3),
        # so the generated images always match the critic's input shape.
        x = Conv2DTranspose(self.channels, (5, 5), strides=2, padding='same',
                            kernel_initializer=self._orthogonal())(x)
        img = Activation("tanh")(x)

        return Model(noise, img)

    def build_critic(self):
        """Build the critic: image of shape ``img_shape`` -> one unbounded score.

        Stride-2 convolutions halve the resolution; a final 4x4 'valid'
        convolution with a single filter produces the score, which is flattened.
        No output activation is applied.

        Returns:
            A Keras ``Model`` mapping an image to its flattened critic output.
        """
        steps = self._scale_steps()

        img = Input(shape=self.img_shape)
        x = Conv2D(2 ** 7, (5, 5), strides=2, input_shape=self.img_shape, padding="same",
                   kernel_initializer=self._orthogonal())(img)
        x = LeakyReLU(alpha=0.2)(x)

        for i in range(steps - 2):
            filters = 2 ** (i + 8)
            x = Conv2D(filters, (5, 5), strides=2, padding="same",
                       kernel_initializer=self._orthogonal())(x)
            x = LeakyReLU(alpha=0.2)(x)
            x = Conv2D(filters, (2, 2), strides=1, padding="same",
                       kernel_initializer=self._orthogonal())(x)
            x = LeakyReLU(alpha=0.2)(x)

        x = Conv2D(1, (4, 4), strides=1, padding="valid",
                   kernel_initializer=self._orthogonal())(x)
        validity = Flatten()(x)

        return Model(img, validity)

    # ------------------------------------------------------------------
    #  Training
    # ------------------------------------------------------------------
    def train(self, epochs, batch_size, sample_interval=5000, resume=0):
        """Run the WGAN-GP training loop and collect preview images.

        Each "epoch" is one iteration: ``n_critic`` critic updates on freshly
        drawn batches followed by one generator update. The loop covers
        ``resume .. resume + epochs`` inclusive, i.e. ``epochs + 1`` iterations.

        Args:
            epochs: Number of iterations to run past ``resume`` (end inclusive).
            batch_size: Samples per critic / generator update.
            sample_interval: Every this many iterations, generate preview images
                from fixed noise and print the current losses.
            resume: 0 starts fresh (the log directory is wiped). Any other value
                loads weights and the sample history saved for that iteration
                from ``./saved_model`` before continuing.

        Returns:
            List of uint8 arrays, one per sampling step, each holding 5
            generated images (preceded by the restored history when resuming).
        """
        # ---------------------
        #  for Logging
        # ---------------------
        target_dir = "./search/my_log_dir_wgangp"
        seed = 0
        image_num = 5
        np_samples = []

        # Load suspended training weights
        if resume != 0:
            # load_weights (rather than load_model) restores the weights into
            # the models built in __init__, so self.generator / self.critic --
            # which the loop below calls directly -- see them, and no
            # custom_objects are needed for the custom loss functions. It reads
            # the weights out of files written by Model.save as well as by
            # Model.save_weights. Optimizer state is not restored.
            self.critic_model.load_weights('./saved_model/wgangp8'+str(self.λ)+'_critic_model_'+str(resume)+'epoch.h5')
            self.generator_model.load_weights('./saved_model/wgangp8'+str(self.λ)+'_gen_model_'+str(resume)+'epoch.h5')
            np_samples.extend(
                self._load_sample_history('./saved_model/np_samples8'+str(self.λ)+'_'+str(resume)+'epoch.npz'))
        else:
            shutil.rmtree(target_dir, ignore_errors=True)
        # The log directory is nested, so create missing parents too.
        os.makedirs(target_dir, exist_ok=True)

        self.logger = TensorBoardLogger(log_dir=target_dir)

        # ---------------------
        #  Training
        # ---------------------
        # Rescale the dataset -1 to 1 (assumes pixel values in 0..255)
        X_train = self.dataset / 127.5 - 1.0

        # Adversarial ground truths: -1 for real, +1 for generated; the zero
        # array is the target slot for the gradient-penalty output.
        valid = -np.ones((batch_size, 1), dtype=np.float32)
        fake = np.ones((batch_size, 1), dtype=np.float32)
        dummy = np.zeros((batch_size, 1), dtype=np.float32)

        for epoch in tqdm(range(resume, resume + epochs + 1)):
            for _ in range(self.n_critic):
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                gen_imgs = self.generator.predict(noise, batch_size=batch_size)

                idx = np.random.randint(0, X_train.shape[0], batch_size)
                real_imgs = X_train[idx]

                # One mixing coefficient per sample, broadcast over H, W, C.
                ε = np.random.uniform(size=(batch_size, 1, 1, 1))
                ave_imgs = ε * real_imgs + (1 - ε) * gen_imgs

                # Train Critic
                d_loss = self.critic_model.train_on_batch([gen_imgs, real_imgs, ave_imgs],
                                                          [fake, valid, dummy])

            # Train Generator
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.generator_model.train_on_batch(noise, valid)

            # ---------------------
            #  Logging
            # ---------------------
            # Backup Model -- currently disabled. The disabled code was:
            #   if epoch != resume and epoch % sample_interval == 0:
            #       self.critic_model.save('./saved_model/wgangp'+str(self.λ)+'_critic_model_'+str(epoch)+'epoch.h5')
            #       self.generator_model.save('./saved_model/wgangp'+str(self.λ)+'_gen_model_'+str(epoch)+'epoch.h5')
            #       np.savez_compressed('./saved_model/np_samples'+str(self.λ)+'_'+str(epoch)+'epoch.npz', np_samples)
            # NOTE(review): these file names ('wgangp', 'np_samples') differ from
            # the ones the resume branch reads ('wgangp8', 'np_samples8');
            # reconcile them before re-enabling.

            # Log Loss & Histgram
            # d_loss = [total, generated-sample term, real-sample term, gradient penalty]
            logs = {
                "loss/Critic": d_loss[0],
                "loss/Generator": g_loss,
                "loss_Critic/D_gen": d_loss[1],
                "loss_Critic/D_real": -d_loss[2],
                "loss_Critic/gradient_penalty": d_loss[3],
                "loss_Critic/total_loss": d_loss[1] + d_loss[2] + d_loss[3],
            }

            # Weight histograms are currently not collected; an empty dict is logged.
            histograms = {}
            self.logger.log(logs=logs, histograms=histograms, epoch=epoch)

            # Log generated image samples
            if epoch % sample_interval == 0:
                # Same noise at every sampling step, so previews are comparable.
                noise = self._fixed_noise(image_num, seed)
                gen_imgs = self.generator.predict(noise)
                # Map tanh output [-1, 1] to uint8 pixel values.
                gen_imgs = ((0.5 * gen_imgs + 0.5) * 255).astype(np.uint8)
                np_samples.append(gen_imgs)
                # (Logging the samples as TensorBoard images is currently disabled.)
                print("%d [C loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))

        return np_samples

In [3]:
# Build the networks and compile both training models for the loaded dataset.
gan = WGANGP(dataset=dataset, gpu_count=gpu_count)

In [4]:
# Train from scratch (resume=0); keep the preview batches that train()
# collects every `sample_interval` iterations.
np_samples = gan.train(
    epochs=5000,
    batch_size=64,
    sample_interval=1000,
    resume=0,
)

  'Discrepancy between trainable weights and collected trainable'
  0%|          | 1/5001 [00:08<11:12:40,  8.07s/it]

0 [C loss: -21.260990] [G loss: -3.095846]


 20%|██        | 1001/5001 [10:43<42:10,  1.58it/s]

1000 [C loss: -6.432305] [G loss: 5.674395]


 40%|████      | 2001/5001 [26:11<1:03:23,  1.27s/it]

2000 [C loss: -5.154934] [G loss: 2.717120]


 60%|██████    | 3001/5001 [40:31<21:05,  1.58it/s]

3000 [C loss: -4.598778] [G loss: 2.601803]


 80%|████████  | 4001/5001 [57:17<10:33,  1.58it/s]

4000 [C loss: -4.245855] [G loss: 2.073720]


100%|██████████| 5001/5001 [1:07:51<00:00,  1.57it/s]

5000 [C loss: -4.203047] [G loss: 2.120786]





In [6]:
import holoviews as hv
hv.notebook_extension()

# One RGB panel per preview image: sampling steps 1..5 (labelled in multiples
# of 1000, matching sample_interval=1000 in the training call; step 0 is
# skipped), five images per step.
panels = [
    hv.RGB(np_samples[j][i]).relabel(str(j*1000)+' epoch')
    for j in range(1, 6)
    for i in range(5)
]

# Fold the panels into a single layout with the same `+` composition.
hv_points = panels[0]
for panel in panels[1:]:
    hv_points += panel
hv_points.cols(10)