In [1]:
from keras.models import Sequential, Model, load_model
from keras.layers import Input, Reshape, Flatten, Activation
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.advanced_activations import ReLU, LeakyReLU
from keras.optimizers import Adam, RMSprop
from keras.utils import multi_gpu_model

import keras
import keras.backend as K
import keras.backend.tensorflow_backend as KTF
import tensorflow as tf

import shutil, os, sys, io, random, math
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from PIL import Image
from tqdm import tqdm

# Run from the DRAGAN project directory so the relative dataset, log and
# saved-model paths used below resolve. The location can be overridden with the
# DRAGAN_WORKDIR environment variable; the original path remains the default.
os.chdir(os.environ.get('DRAGAN_WORKDIR', '/home/k_yonhon/py/Keras-GAN/dragan/'))
# Make the parent of the working directory importable
# (presumably where tensor_board_logger / wasserstein_loss live -- confirm).
sys.path.append(os.pardir)

from tensor_board_logger import TensorBoardLogger
from wasserstein_loss import WassersteinLoss, GradientPenaltyLoss, loss_func_dcgan_dis_real, loss_func_dcgan_dis_fake

# ---------------------
#  TensorFlow session
# ---------------------
# Ask TensorFlow to grow GPU memory on demand (allow_growth) and register the
# resulting session with the Keras TensorFlow backend.
gpu_options = tf.GPUOptions(allow_growth=True)
config = tf.ConfigProto(gpu_options=gpu_options)
session = tf.Session(config=config)
KTF.set_session(session)

# ---------------------
#  Parameter
# ---------------------
gpu_count = 1   # models are wrapped with multi_gpu_model only when this is > 1
resume = 0      # 0 = start fresh; otherwise the epoch number of the checkpoint to load
# The archive stores its single array under the default key 'arr_0'.
dataset_archive = np.load('../datasets/lfw32.npz')
dataset = dataset_archive['arr_0']

In [2]:
class DRAGAN():
    """GAN with a gradient-penalised critic, built with Keras.

    The constructor builds (or, when resuming, loads) a convolutional critic
    and generator sized from ``dataset.shape[1:]`` and compiles two training
    models:

    * ``critic_model``: takes generated, real and perturbed ("averaged")
      image batches and is trained with ``loss_func_dcgan_dis_fake``,
      ``loss_func_dcgan_dis_real`` and ``GradientPenaltyLoss`` respectively.
    * ``generator_model``: generator followed by the frozen critic, trained
      with ``loss_func_dcgan_dis_real``.

    Args:
        dataset: image array indexed as (N, rows, cols, channels).
            ``train`` rescales it with ``/ 127.5 - 1.0``, so pixel values are
            presumably in [0, 255] -- confirm against the data file.
        gpu_count: when > 1, both training models are wrapped with
            ``multi_gpu_model``.
        resume: 0 to start from scratch; otherwise the epoch number used in
            the file names of the saved critic / generator / sample archive.
    """

    def __init__(self, dataset, gpu_count=1, resume=0):
        # ---------------------
        #  Parameter
        # ---------------------
        self.dataset = dataset
        self.gpu_count = gpu_count
        self.resume = resume

        self.img_rows = dataset.shape[1]
        self.img_cols = dataset.shape[2]
        self.channels = dataset.shape[3]
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # The latent vector is reshaped to an input_rows x input_cols map
        # before the generator upsamples it.
        self.input_rows = 2
        self.input_cols = 2
        self.latent_dim = 128  # Noise dim

        self.n_critic = 5  # critic updates per generator update
        self.λ = 10        # gradient penalty weight
        # NOTE(review): this single Adam instance is passed to both compiled
        # models below -- confirm that sharing it is intended.
        optimizer = Adam(lr=0.0001, beta_1=0.5, beta_2=0.9, epsilon=None, decay=0.0, amsgrad=False)

        # ---------------------
        #  Load models
        # ---------------------
        self.critic = self.build_critic()
        self.generator = self.build_generator()
        if self.resume != 0:
            # Replace the freshly built networks with the saved checkpoints.
            self.critic = load_model('./saved_model/dragan'+str(self.λ)+'_critic_'+str(self.resume)+'epoch.h5')
            self.generator = load_model('./saved_model/dragan'+str(self.λ)+'_gen_'+str(self.resume)+'epoch.h5')

        #-------------------------------
        # Compile Critic
        #-------------------------------
        generated_samples = Input(shape=self.img_shape)
        critic_output_from_generated_samples = self.critic(generated_samples)

        real_samples = Input(shape=self.img_shape)
        critic_output_from_real_samples = self.critic(real_samples)

        averaged_samples = Input(shape=self.img_shape)
        critic_output_from_averaged_samples = self.critic(averaged_samples)

        # Bind the extra arguments so the loss has the (y_true, y_pred)
        # signature Keras expects.
        partial_gp_loss = partial(GradientPenaltyLoss,
                                  averaged_samples=averaged_samples,
                                  gradient_penalty_weight=self.λ)
        # Functions need names or Keras will throw an error
        partial_gp_loss.__name__ = 'gradient_penalty'

        self.critic_model = Model(inputs=[generated_samples,
                                          real_samples,
                                          averaged_samples],
                                  outputs=[critic_output_from_generated_samples,
                                           critic_output_from_real_samples,
                                           critic_output_from_averaged_samples])

        # loss_dis = loss_func_dcgan_dis_real(y_dis) + loss_func_dcgan_dis_fake(y_fake) + loss_grad
        # loss_gen = loss_func_dcgan_dis_real(y_fake)

        if self.gpu_count > 1:
            self.critic_model = multi_gpu_model(self.critic_model, gpus=self.gpu_count)
        # One loss per output, in the same order as the model outputs.
        self.critic_model.compile(optimizer=optimizer,
                                  loss=[loss_func_dcgan_dis_fake,
                                        loss_func_dcgan_dis_real,
                                        partial_gp_loss])

        # print('Critic Summary:')
        # self.critic.summary()

        #-------------------------------
        # Compile Generator
        #-------------------------------
        # For the generator we freeze the critic's layers
        self.critic.trainable = False

        generator_input = Input(shape=(self.latent_dim,))
        generator_layers = self.generator(generator_input)
        critic_layers_for_generator = self.critic(generator_layers)

        self.generator_model = Model(inputs=[generator_input],
                                     outputs=[critic_layers_for_generator])
        if self.gpu_count > 1:
            self.generator_model = multi_gpu_model(self.generator_model, gpus=self.gpu_count)
        self.generator_model.compile(optimizer=optimizer,
                                     loss=loss_func_dcgan_dis_real)

        # print('Generator Summary:')
        # self.generator.summary()

    @staticmethod
    def _load_sample_history(path):
        """Load the preview batches stored by ``np.savez_compressed(path, np_samples)``.

        A positional array passed to ``savez_compressed`` is stored under the
        key ``'arr_0'``. Iterating the ``NpzFile`` itself yields key names, not
        arrays, so the stacked array is read explicitly and split back into a
        list along its first axis.

        Args:
            path: path of the ``.npz`` archive.

        Returns:
            list of arrays, one per saved preview batch.
        """
        with np.load(path) as archive:
            return list(archive['arr_0'])

    @staticmethod
    def _fixed_noise(seed, count, latent_dim):
        """Return a reproducible ``(count, latent_dim)`` standard-normal batch.

        Uses a private ``RandomState`` so the global NumPy generator, which
        drives the training noise and minibatch sampling, is left untouched.
        The values equal those of ``np.random.seed(seed)`` followed by
        ``np.random.normal(0, 1, (count, latent_dim))``.
        """
        return np.random.RandomState(seed).normal(0, 1, (count, latent_dim))

    def build_generator(self):
        """Build the generator: latent vector -> tanh image.

        The noise vector is reshaped to ``input_rows x input_cols`` and then
        upsampled by stride-2 transposed convolutions until it reaches the
        image size. Prints the model summary.

        Returns:
            Keras ``Model`` mapping ``(latent_dim,)`` noise to an image.
        """
        # Number of 2x upsampling steps needed to go from input_rows to img_rows.
        n_up = int(math.log(self.img_rows / self.input_rows, 2))

        noise = Input(shape=(self.latent_dim,))
        x = Reshape((self.input_rows, self.input_cols, int(self.latent_dim / (self.input_rows * self.input_cols))))(noise)
        # All but the last upsampling step; the filter count halves each step.
        for i in range(n_up - 1):
            filters = 2 ** (n_up + 5 - i)
            x = Conv2DTranspose(filters, (5, 5), strides=2, padding='same',
                                kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None))(x)
            x = LeakyReLU(alpha=0.2)(x)
            x = Conv2D(filters, (2, 2), strides=1, padding="same",
                       kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None))(x)
            x = LeakyReLU(alpha=0.2)(x)

        # Final upsampling step produces one map per image channel
        # (previously hard-coded to 3, which only matched RGB datasets).
        x = Conv2DTranspose(self.channels, (5, 5), strides=2, padding='same',
                            kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None))(x)
        img = Activation("tanh")(x)

        model = Model(noise, img)
        print('Generator Summary:')
        model.summary()
        return model

    def build_critic(self):
        """Build the critic: image -> flattened unbounded score.

        A stride-2 convolution stem is followed by stride-2 / stride-1
        convolution pairs and a final 4x4 'valid' convolution. Prints the
        model summary.

        Returns:
            Keras ``Model`` mapping an image of ``img_shape`` to its score.
        """
        img = Input(shape=self.img_shape)
        x = Conv2D(2 ** 6, (5, 5), strides=2, input_shape=self.img_shape, padding="same",
                   kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None))(img)
        x = LeakyReLU(alpha=0.2)(x)

        # Each iteration halves the spatial size and doubles the filter count.
        for i in range(int(math.log(self.img_rows / self.input_rows, 2)) - 2):
            x = Conv2D(2 ** (i + 7), (5, 5), strides=2, padding="same",
                       kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None))(x)
            x = LeakyReLU(alpha=0.2)(x)
            x = Conv2D(2 ** (i + 7), (2, 2), strides=1, padding="same",
                       kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None))(x)
            x = LeakyReLU(alpha=0.2)(x)

        # For power-of-two image sizes the feature map is 4x4 here, so this
        # 'valid' 4x4 convolution reduces it to a single value.
        x = Conv2D(1, (4, 4), strides=1, padding="valid",
                   kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None))(x)
        validity = Flatten()(x)

        model = Model(img, validity)
        print('Critic Summary:')
        model.summary()
        return model

    def train(self, epochs, batch_size, sample_interval=5000):
        """Run the adversarial training loop.

        Each iteration performs ``n_critic`` critic updates followed by one
        generator update, logs the losses through ``TensorBoardLogger``, and
        every ``sample_interval`` iterations stores a preview batch generated
        from fixed noise.

        Args:
            epochs: number of iterations to run after ``self.resume``; the
                loop covers ``resume .. resume + epochs`` inclusive, i.e.
                ``epochs + 1`` iterations.
            batch_size: samples per critic / generator update.
            sample_interval: a preview batch is stored (and the losses
                printed) whenever ``epoch % sample_interval == 0``.

        Returns:
            list of uint8 arrays, one preview batch of 5 images per sampling
            step (preceded by the previously saved batches when resuming).
        """
        # ---------------------
        #  for Logging
        # ---------------------
        target_dir = "./search/my_log_dir_dragan32"
        seed = 0        # seed of the fixed preview noise
        image_num = 5   # images per preview batch
        np_samples = []

        # Load suspended training samples
        if self.resume != 0:
            np_samples = self._load_sample_history(
                './saved_model/dragan'+str(self.λ)+'_np_'+str(self.resume)+'epoch.npz')
        else:
            # Fresh run: start from an empty log directory. makedirs also
            # creates a missing parent and tolerates a directory that rmtree
            # could not remove.
            shutil.rmtree(target_dir, ignore_errors=True)
            os.makedirs(target_dir, exist_ok=True)

        self.logger = TensorBoardLogger(log_dir=target_dir)

        # ---------------------
        #  Training
        # ---------------------
        # Rescale the dataset -1 to 1
        X_train = self.dataset / 127.5 - 1.0

        # Target arrays handed to train_on_batch.
        # NOTE(review): the real-sample output also receives `fake`; presumably
        # the DCGAN-style loss functions ignore y_true -- confirm in
        # wasserstein_loss.py.
        fake = np.ones((batch_size, 1), dtype=np.float32)
        dummy = np.zeros((batch_size, 1), dtype=np.float32)

        for epoch in tqdm(range(self.resume, self.resume + epochs + 1)):
            for _ in range(self.n_critic):
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                gen_imgs = self.generator.predict(noise, batch_size=batch_size)

                idx = np.random.randint(0, X_train.shape[0], batch_size)
                real_imgs = X_train[idx]

                # ε = np.random.uniform(size=(batch_size, 1,1,1))
                # ave_imgs = ε * real_imgs + (1-ε) * gen_imgs
                # Perturb the real batch with uniform noise scaled by its std;
                # the gradient penalty is evaluated at these points.
                ave_imgs = real_imgs + 0.5 * np.random.random(real_imgs.shape) * real_imgs.std()

                # Train Critic
                d_loss = self.critic_model.train_on_batch([gen_imgs, real_imgs, ave_imgs],
                                                          [fake, fake, dummy])

            # Train Generator
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.generator_model.train_on_batch(noise, fake)

            # ---------------------
            #  Logging
            # ---------------------
            # Backup Model (disabled)
            # if epoch != self.resume and epoch % sample_interval == 0:
            #     self.critic.save('./saved_model/dragan'+str(self.λ)+'_critic_'+str(epoch)+'epoch.h5')
            #     self.generator.save('./saved_model/dragan'+str(self.λ)+'_gen_'+str(epoch)+'epoch.h5')
            #     np.savez_compressed('./saved_model/dragan'+str(self.λ)+'_np_'+str(epoch)+'epoch.npz', np_samples)

            # Log Loss & Histogram
            # d_loss is [total, generated-sample loss, real-sample loss, gradient penalty].
            logs = {
                "loss/Critic": d_loss[0],
                "loss/Generator": g_loss,
                "loss_Critic/D_gen": d_loss[1],
                "loss_Critic/D_real": d_loss[2],
                "loss_Critic/gradient_penalty": d_loss[3],
                "loss_Critic/total_loss": d_loss[1] + d_loss[2] + d_loss[3],
            }

            histograms = {}
            # Weight histograms (disabled)
            # for layer in self.critic.layers[1].layers:
            #     for i in range(len(layer.get_weights())):
            #         if "conv" in layer.name or "dense" in layer.name:
            #             name = layer.name + "/" + str(i)
            #             value = layer.get_weights()[i]
            #             histograms[name] = value
            self.logger.log(logs=logs, histograms=histograms, epoch=epoch)

            # Log generated image samples
            if epoch % sample_interval == 0:
                # Same noise at every sampling step so previews are comparable,
                # drawn without reseeding the global RNG used for training.
                noise = self._fixed_noise(seed, image_num, self.latent_dim)
                gen_imgs = self.generator.predict(noise)
                # Map the tanh range [-1, 1] to 8-bit pixel values.
                gen_imgs = ((0.5 * gen_imgs + 0.5) * 255).astype(np.uint8)
                np_samples.append(gen_imgs)
                # Image logging to TensorBoard (disabled)
                # fig, name = self.sample_images(epoch)
                # images = {epoch: fig}
                # self.logger.log(images=images, epoch=epoch)
                print("%d [C loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))

        return np_samples

In [3]:
# Build the critic/generator pair for this dataset (prints both model summaries).
gan = DRAGAN(dataset=dataset, gpu_count=gpu_count, resume=resume)

Critic Summary:
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 32, 32, 3)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 16, 16, 64)        4864      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 16, 16, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 8, 8, 128)         204928    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 8, 8, 128)         0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 8, 8, 128)         65664     
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 8, 8, 128)         0    

In [4]:
# Train with batches of 64; the loop runs epochs + 1 iterations (0..10000 inclusive).
# A preview batch of generated images is stored every 1000 iterations, and the
# list of those uint8 batches is returned.
np_samples = gan.train(epochs=10000, batch_size=64, sample_interval=1000)

  'Discrepancy between trainable weights and collected trainable'
  0%|          | 1/10001 [00:06<18:45:46,  6.75s/it]

0 [C loss: 0.830718] [G loss: 0.703938]


 10%|█         | 1001/10001 [04:31<39:04,  3.84it/s]

1000 [C loss: 0.435446] [G loss: 1.785938]


 20%|██        | 2001/10001 [08:52<34:50,  3.83it/s]

2000 [C loss: 0.587491] [G loss: 1.334499]


 30%|███       | 3001/10001 [13:13<30:35,  3.81it/s]

3000 [C loss: 0.571735] [G loss: 1.658139]


 40%|████      | 4001/10001 [17:35<26:22,  3.79it/s]

4000 [C loss: 0.617196] [G loss: 1.311753]


 50%|█████     | 5001/10001 [21:56<21:46,  3.83it/s]

5000 [C loss: 0.641876] [G loss: 1.265587]


 60%|██████    | 6001/10001 [26:17<17:24,  3.83it/s]

6000 [C loss: 0.638308] [G loss: 1.460099]


 70%|███████   | 7001/10001 [30:38<12:57,  3.86it/s]

7000 [C loss: 0.617736] [G loss: 1.179487]


 80%|████████  | 8001/10001 [34:59<08:42,  3.83it/s]

8000 [C loss: 0.590304] [G loss: 1.056519]


 90%|█████████ | 9001/10001 [39:20<04:23,  3.80it/s]

9000 [C loss: 0.555407] [G loss: 1.123766]


100%|██████████| 10001/10001 [43:41<00:00,  3.84it/s]

10000 [C loss: 0.551391] [G loss: 0.983229]





In [5]:
import holoviews as hv
hv.notebook_extension()

# Must match the sample_interval passed to gan.train(): snapshot j was taken at
# iteration j * SAMPLE_INTERVAL (the labels previously used j * 10000).
SAMPLE_INTERVAL = 1000

# Lay out every stored snapshot except the first (index 0, taken after the very
# first iteration), one row of preview images per snapshot.
hv_points = None
for j in range(1, len(np_samples)):
    y = np_samples[j]
    label = str(j * SAMPLE_INTERVAL) + ' epoch'
    for i in range(len(y)):
        panel = hv.RGB(y[i]).relabel(label)
        hv_points = panel if hv_points is None else hv_points + panel
hv_points.cols(5)

In [8]:
# Sample 50 fresh latent vectors and render the generator output as a 5-column grid.
noise = np.random.normal(0, 1, (50, gan.latent_dim))
gen_imgs = gan.generator.predict(noise)
# Map the generator's tanh range [-1, 1] to 8-bit pixel values.
y = ((0.5 * gen_imgs + 0.5) * 255).astype(np.uint8)
# Seed the layout with the first image, then append the remaining ones.
hv_points = hv.RGB(y[0])
for j in range(1, 50):
    hv_points += hv.RGB(y[j])
hv_points.cols(5)

In [9]:
# Original
# Show the first 50 real images after the same scale / unscale round trip that
# is applied to generated images.
# NOTE: `gen_imgs` is rebound here to the rescaled *real* dataset, not generator output.
gen_imgs = dataset / 127.5 - 1.0
y = ((0.5 * gen_imgs + 0.5) * 255).astype(np.uint8)
# Seed the layout with the first image, then append the remaining ones.
hv_points = hv.RGB(y[0])
for j in range(1, 50):
    hv_points += hv.RGB(y[j])
hv_points.cols(5)