In [1]:
from keras.models import Sequential, Model, load_model
from keras.layers import Input, Reshape, Flatten, Activation
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.advanced_activations import ReLU, LeakyReLU
from keras.optimizers import Adam, RMSprop
from keras.utils import multi_gpu_model

import keras
import keras.backend as K
import keras.backend.tensorflow_backend as KTF
import tensorflow as tf

import shutil, os, sys, io, random, math
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from PIL import Image
from tqdm import tqdm

# Working directory for this notebook. Relative paths below ('../datasets',
# './saved_model', './lambda_search') resolve against it. Override with the
# WGANGP_DIR environment variable instead of editing a hard-coded home path.
PROJECT_DIR = os.environ.get('WGANGP_DIR', '/home/k_yonhon/py/Keras-GAN/wgan_gp/')
os.chdir(PROJECT_DIR)
# Make the parent directory importable (presumably where the local helper
# modules below live -- confirm against the repo layout).
sys.path.append(os.pardir)

from tensor_board_logger import TensorBoardLogger
from wasserstein_loss import WassersteinLoss, GradientPenaltyLoss

# Let TensorFlow grow GPU memory on demand instead of reserving it all upfront.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
session = tf.Session(config=config)
KTF.set_session(session)

# ---------------------
#  Parameter
# ---------------------
gpu_count = 1
DATASET_PATH = os.environ.get('WGANGP_DATASET', '../datasets/lfw32.npz')
# Close the .npz archive once the array has been read out of it.
with np.load(DATASET_PATH) as dataset_npz:
    dataset = dataset_npz['arr_0']

In [2]:
class WGANGP():
    """WGAN-GP (Wasserstein GAN with gradient penalty) built with Keras.

    Builds a convolutional critic and generator whose depth is derived from
    ``log2(img_rows / input_rows)``. It compiles two training models:
      * ``critic_model``: three inputs (generated, real, interpolated images),
        trained with two Wasserstein losses plus a gradient-penalty loss;
      * ``generator_model``: noise -> generator -> frozen critic.

    Args:
        dataset: image array indexed ``(N, rows, cols, channels)``. Pixel values
            are presumably in [0, 255], because ``train`` rescales them with
            ``/ 127.5 - 1``. ``rows / input_rows`` is assumed to be a power of
            two (TODO confirm for other sizes).
        gpu_count: when > 1, both training models are wrapped with
            ``multi_gpu_model``.
    """

    def __init__(self, dataset, gpu_count=1):
        # ---------------------
        #  Parameter
        # ---------------------
        self.dataset = dataset
        self.gpu_count = gpu_count

        self.img_rows = dataset.shape[1]
        self.img_cols = dataset.shape[2]
        self.channels = dataset.shape[3]
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Spatial size of the generator's first (reshaped) feature map.
        self.input_rows = 2
        self.input_cols = 2
        self.latent_dim = 128  # Noise dim

        self.n_critic = 5  # critic updates per generator update
        self.λ = 10        # gradient-penalty weight

        def make_optimizer():
            # Use one Adam instance per compiled model, so the critic and the
            # generator do not share optimizer state (e.g. Adam's step counter
            # used for bias correction).
            return Adam(lr=0.0001, beta_1=0., beta_2=0.9, epsilon=None, decay=0.0, amsgrad=False)

        # ---------------------
        #  Load models
        # ---------------------
        self.critic = self.build_critic()
        self.generator = self.build_generator()

        #  Load pretrained weights (disabled: kept as an inert string literal)
        '''
        pre_gen = load_model('./saved_model/wgangp64_gen_model_3k.h5')
        for i, layer in enumerate(self.generator.layers[1].layers):
            if i in [i for i in range(1, int(math.log(self.img_rows / self.input_rows, 2)) * 2, 2)]:
                layer.set_weights(pre_gen.layers[1].layers[i].get_weights())
                layer.trainable = False
                
        pre_critic = load_model('./saved_model/wgangp64_critic_model_3k.h5')
        for i, layer in enumerate(self.critic.layers[1].layers):
            j = i - len(self.critic.layers[1].layers)
            if j in [-i for i in range(int(math.log(self.img_rows / self.input_rows, 2)) * 2, 0, -2)]:
                layer.set_weights(pre_critic.layers[1].layers[j].get_weights())
                layer.trainable = False
        '''
        #-------------------------------
        # Compile Critic
        #-------------------------------
        generated_samples = Input(shape=self.img_shape)
        critic_output_from_generated_samples = self.critic(generated_samples)

        real_samples = Input(shape=self.img_shape)
        critic_output_from_real_samples = self.critic(real_samples)

        # Fed with real/generated interpolations in `train`; this tensor is
        # handed to GradientPenaltyLoss below.
        averaged_samples = Input(shape=self.img_shape)
        critic_output_from_averaged_samples = self.critic(averaged_samples)

        partial_gp_loss = partial(GradientPenaltyLoss,
                                  averaged_samples=averaged_samples,
                                  gradient_penalty_weight=self.λ)
        # Functions need names or Keras will throw an error
        partial_gp_loss.__name__ = 'gradient_penalty'

        self.critic_model = Model(inputs=[generated_samples,
                                          real_samples,
                                          averaged_samples],
                                  outputs=[critic_output_from_generated_samples,
                                           critic_output_from_real_samples,
                                           critic_output_from_averaged_samples])
        if self.gpu_count > 1:
            self.critic_model = multi_gpu_model(self.critic_model, gpus=self.gpu_count)
        self.critic_model.compile(optimizer=make_optimizer(),
                                  loss=[WassersteinLoss,
                                        WassersteinLoss,
                                        partial_gp_loss])

        #-------------------------------
        # Compile Generator
        #-------------------------------
        # Freeze the critic inside the generator's training model. critic_model
        # was compiled above while the critic was trainable; Keras warns about
        # this trainable-weights discrepancy at training time.
        self.critic.trainable = False

        generator_input = Input(shape=(self.latent_dim,))
        generator_layers = self.generator(generator_input)
        critic_layers_for_generator = self.critic(generator_layers)

        self.generator_model = Model(inputs=[generator_input],
                                     outputs=[critic_layers_for_generator])
        if self.gpu_count > 1:
            self.generator_model = multi_gpu_model(self.generator_model, gpus=self.gpu_count)
        self.generator_model.compile(optimizer=make_optimizer(),
                                     loss=WassersteinLoss)

    def _n_doublings(self):
        """Number of 2x resolution steps between ``input_rows`` and ``img_rows``.

        Uses ``math.log2`` rather than ``math.log(x, 2)``: the latter is a
        ratio of natural logs and is not guaranteed to be exact before ``int``
        truncation.
        """
        return int(math.log2(self.img_rows / self.input_rows))

    @staticmethod
    def _to_uint8(imgs):
        """Map generator output from the tanh range [-1, 1] to uint8 [0, 255]."""
        return ((0.5 * imgs + 0.5) * 255).astype(np.uint8)

    def build_generator(self):
        """Build the generator: latent vector -> image of shape ``img_shape``.

        The noise vector is reshaped to an (input_rows, input_cols) map and
        upsampled ``_n_doublings()`` times with stride-2 transposed convs. The
        output is squashed by tanh into [-1, 1].

        Returns:
            A Keras ``Model`` mapping ``(latent_dim,)`` noise to an image.
        """
        n = self._n_doublings()

        def orthogonal():
            return keras.initializers.Orthogonal(gain=1.4, seed=None)

        model = Sequential()
        model.add(Reshape((self.input_rows, self.input_cols,
                           self.latent_dim // (self.input_rows * self.input_cols)),
                          input_shape=(self.latent_dim,)))

        # n - 1 upsampling stages; the final output layer provides the last 2x.
        for i in range(n - 1):
            filters = 2 ** (n + 5 - i)  # channel count halves at each stage
            model.add(Conv2DTranspose(filters, (3, 3), strides=2, padding='same',
                                      kernel_initializer=orthogonal()))
            model.add(LeakyReLU(alpha=0.2))
            for _ in range(2):
                model.add(Conv2DTranspose(filters, (3, 3), strides=1, padding="same",
                                          kernel_initializer=orthogonal()))
                model.add(LeakyReLU(alpha=0.2))

        # Emit as many channels as the dataset has (previously hard-coded to 3).
        model.add(Conv2DTranspose(self.channels, (3, 3), strides=2, padding='same',
                                  kernel_initializer=orthogonal()))
        model.add(Activation("tanh"))
        print('Generator Summary:')
        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)
        return Model(noise, img)

    def build_critic(self):
        """Build the critic: image -> one unbounded score (no sigmoid).

        A stride-2 5x5 conv is followed by ``_n_doublings() - 2`` stride-2
        stages. This leaves a (2 * input_rows)-wide map (4x4 with the
        defaults), which the final 4x4 valid conv reduces to a single value.

        Returns:
            A Keras ``Model`` mapping ``img_shape`` images to shape ``(1,)``.
        """
        n = self._n_doublings()

        def orthogonal():
            return keras.initializers.Orthogonal(gain=1.4, seed=None)

        model = Sequential()
        model.add(Conv2D(2 ** 7, (5, 5), strides=2, input_shape=self.img_shape, padding="same",
                         kernel_initializer=orthogonal()))
        model.add(LeakyReLU(alpha=0.2))

        for i in range(n - 2):
            filters = 2 ** (i + 8)  # channel count doubles at each stage
            model.add(Conv2D(filters, (3, 3), strides=2, padding="same",
                             kernel_initializer=orthogonal()))
            model.add(LeakyReLU(alpha=0.2))
            for _ in range(2):
                model.add(Conv2D(filters, (3, 3), strides=1, padding="same",
                                 kernel_initializer=orthogonal()))
                model.add(LeakyReLU(alpha=0.2))

        model.add(Conv2D(1, (4, 4), strides=1, padding="valid",
                         kernel_initializer=orthogonal()))
        model.add(Flatten())
        print('Critic Summary:')
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)
        return Model(img, validity)

    def _fixed_samples(self, seed, image_num):
        """Generate ``image_num`` images from fixed-seed noise, stacked vertically.

        Draws the noise from a private ``RandomState(seed)``, which yields the
        same values as ``np.random.seed(seed)`` followed by
        ``np.random.normal``. It leaves the global NumPy RNG (training noise,
        minibatch indices, interpolation weights) untouched; reseeding the
        global RNG here would restart the training stream every interval.

        Returns:
            uint8 array of shape ``(image_num * rows, cols, channels)``.
        """
        rng = np.random.RandomState(seed)
        noise = rng.normal(0, 1, (image_num, self.latent_dim))
        imgs = self._to_uint8(self.generator.predict(noise))
        return np.vstack(imgs)  # concatenate the images vertically

    def train(self, epochs, batch_size, sample_interval=5000, resume=0):
        """Run the WGAN-GP training loop.

        One "epoch" here is ``n_critic`` critic updates followed by one
        generator update, each on a random minibatch. It is not a full pass
        over the dataset. The loop covers iterations ``resume .. resume +
        epochs`` inclusive, i.e. ``epochs + 1`` iterations.

        Args:
            epochs: number of iterations to add beyond ``resume``.
            batch_size: minibatch size for the critic and generator updates.
            sample_interval: every this many iterations, fixed-seed samples
                are appended to the returned list and the losses are printed.
            resume: iteration to resume from. When non-zero, weights and past
                samples are read from the ``./saved_model/`` files that the
                (currently disabled) backup block below writes.

        Returns:
            List of uint8 arrays, one per sample interval, each holding the
            fixed-seed images stacked vertically.
        """
        # ---------------------
        #  for Logging
        # ---------------------
        target_dir = "./lambda_search/my_log_dir5_" + str(self.λ)
        seed = 0
        image_num = 5

        # Load suspended training weights
        if resume != 0:
            ckpt_prefix = './saved_model/wgangp' + str(self.λ)
            # Copy the weights into the existing sub-models: critic_model and
            # generator_model were compiled around these objects, so rebinding
            # self.critic / self.generator would leave training on fresh weights.
            self.critic.set_weights(
                load_model(ckpt_prefix + '_critic_model_' + str(resume) + 'epoch.h5').get_weights())
            self.generator.set_weights(
                load_model(ckpt_prefix + '_gen_model_' + str(resume) + 'epoch.h5').get_weights())
            # Same location the backup block saves to; converted to a list so
            # new samples can be appended, and the archive is closed afterwards.
            with np.load('./saved_model/np_samples' + str(self.λ) + '_' + str(resume) + 'epoch.npz') as saved:
                np_samples = list(saved['arr_0'])
        else:
            np_samples = []
            shutil.rmtree(target_dir, ignore_errors=True)
            # makedirs also creates ./lambda_search when it does not exist yet.
            os.makedirs(target_dir, exist_ok=True)

        self.logger = TensorBoardLogger(log_dir=target_dir)

        # ---------------------
        #  Training
        # ---------------------
        # Rescale the dataset -1 to 1
        X_train = self.dataset / 127.5 - 1.0

        # Adversarial ground truths
        valid = -np.ones((batch_size, 1), dtype=np.float32)
        fake = np.ones((batch_size, 1), dtype=np.float32)
        dummy = np.zeros((batch_size, 1), dtype=np.float32)  # unused target for the GP output

        for epoch in tqdm(range(resume, resume + epochs + 1)):
            for _ in range(self.n_critic):
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                gen_imgs = self.generator.predict(noise, batch_size=batch_size)

                idx = np.random.randint(0, X_train.shape[0], batch_size)
                real_imgs = X_train[idx]

                # Per-sample interpolation weight between real and generated images.
                ε = np.random.uniform(size=(batch_size, 1, 1, 1))
                ave_imgs = ε * real_imgs + (1 - ε) * gen_imgs

                # Train Critic
                d_loss = self.critic_model.train_on_batch([gen_imgs, real_imgs, ave_imgs],
                                                          [fake, valid, dummy])

            # Train Generator
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.generator_model.train_on_batch(noise, valid)

            # ---------------------
            #  Logging
            # ---------------------
            # Backup Model (disabled: kept as an inert string literal)
            '''
            if epoch != resume and epoch % sample_interval == 0:
                self.critic.save('./saved_model/wgangp'+str(self.λ)+'_critic_model_'+str(epoch)+'epoch.h5')
                self.generator.save('./saved_model/wgangp'+str(self.λ)+'_gen_model_'+str(epoch)+'epoch.h5')
                np.savez_compressed('./saved_model/np_samples'+str(self.λ)+'_'+str(epoch)+'epoch.npz', np_samples)
            '''
            # Log Loss & Histogram (losses are from the last critic step)
            logs = {
                "loss/Critic": d_loss[0],
                "loss/Generator": g_loss,
                "loss_Critic/D_gen": d_loss[1],
                "loss_Critic/D_real": -d_loss[2],
                "loss_Critic/gradient_penalty": d_loss[3],
                "loss_Critic/total_loss": d_loss[1] + d_loss[2] + d_loss[3],
            }

            histograms = {}
            '''
            for layer in self.critic.layers[1].layers:
                for i in range(len(layer.get_weights())):
                    if "conv" in layer.name or "dense" in layer.name:
                        name = layer.name + "/" + str(i)
                        value = layer.get_weights()[i]
                        histograms[name] = value
            '''
            self.logger.log(logs=logs, histograms=histograms, epoch=epoch)

            # Log generated image samples
            if epoch % sample_interval == 0:
                np_samples.append(self._fixed_samples(seed, image_num))
                '''
                fig, name = self.sample_images(epoch)
                images = {epoch: fig}
                self.logger.log(images=images, epoch=epoch)
                '''
                print("%d [C loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))

        return np_samples

    def sample_images(self, epoch):
        """Plot a 2x3 grid of real images (epoch 0) or generator samples.

        Returns:
            ``(fig, name)``: the matplotlib figure and a suggested file name.
        """
        r, c = 2, 3
        if epoch == 0:
            idx = np.random.randint(0, self.dataset.shape[0], r * c)
            imgs = self.dataset[idx].astype(np.uint8)
            name = "original.png"
        else:
            noise = np.random.normal(0, 1, (r * c, self.latent_dim))
            imgs = self._to_uint8(self.generator.predict(noise, batch_size=r * c))
            name = str(epoch) + ".png"
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                if self.channels == 1:
                    axs[i, j].imshow(imgs[cnt, :, :, 0], cmap="gray")
                else:
                    axs[i, j].imshow(imgs[cnt, :, :, :self.channels], cmap="gray")
                axs[i, j].axis("off")
                cnt += 1
        return fig, name

In [3]:
# Build and compile the critic/generator training models (prints both summaries).
gan = WGANGP(dataset=dataset, gpu_count=gpu_count)

Critic Summary:
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_1 (Conv2D)            (None, 16, 16, 128)       9728      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 8, 8, 256)         295168    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 8, 8, 256)         0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 8, 8, 256)         590080    
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 8, 8, 256)         0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 8, 8, 256)         59008

In [4]:
# Fresh run: the loop covers iterations 0..10000 (epochs + 1 steps), and
# fixed-seed samples are collected every 1000 iterations.
np_samples = gan.train(
    epochs=10000,
    batch_size=64,
    sample_interval=1000,
    resume=0,
)

  'Discrepancy between trainable weights and collected trainable'
  0%|          | 1/10001 [00:10<28:00:16, 10.08s/it]

0 [C loss: -29.239109] [G loss: -3.524148]


 10%|█         | 1001/10001 [08:45<1:17:13,  1.94it/s]

1000 [C loss: -6.536359] [G loss: 1.576369]


 20%|██        | 2001/10001 [17:21<1:09:06,  1.93it/s]

2000 [C loss: -5.912831] [G loss: -2.081486]


 30%|███       | 3001/10001 [25:58<1:00:14,  1.94it/s]

3000 [C loss: -5.271179] [G loss: -3.171620]


 40%|████      | 4001/10001 [34:34<51:44,  1.93it/s]

4000 [C loss: -5.513949] [G loss: -5.624416]


 50%|█████     | 5001/10001 [43:09<43:05,  1.93it/s]

5000 [C loss: -5.144420] [G loss: -3.110868]


 60%|██████    | 6001/10001 [51:46<34:28,  1.93it/s]

6000 [C loss: -4.813450] [G loss: -4.110725]


 70%|███████   | 7001/10001 [1:00:23<25:49,  1.94it/s]

7000 [C loss: -4.866711] [G loss: -4.385365]


 80%|████████  | 8001/10001 [1:08:59<17:13,  1.94it/s]

8000 [C loss: -4.798648] [G loss: -4.895304]


 90%|█████████ | 9001/10001 [1:17:35<08:36,  1.94it/s]

9000 [C loss: -4.942816] [G loss: -4.719893]


100%|██████████| 10001/10001 [1:26:11<00:00,  1.94it/s]

10000 [C loss: -4.815519] [G loss: -4.536453]





In [5]:
import holoviews as hv
hv.notebook_extension()

# Must match the sample_interval passed to gan.train above; used for labels only.
SAMPLE_INTERVAL = 1000

# For a run started with resume=0, np_samples[k] holds the fixed-seed images
# from iteration k * SAMPLE_INTERVAL, stacked vertically. Entry 0 (iteration 0)
# is skipped. The images are unstacked using the model's own image shape
# instead of a hard-coded 32x32x3.
panels = []
for j in range(1, len(np_samples)):
    label = str(j * SAMPLE_INTERVAL) + ' epoch'
    images = np_samples[j].reshape((-1,) + gan.img_shape)
    panels.extend(hv.RGB(img).relabel(label) for img in images)
hv.Layout(panels).cols(10)

In [6]:
# A few fresh (unseeded) generator samples, rescaled from tanh [-1, 1] to uint8.
# Generate exactly as many images as are displayed.
n_images = 5
noise = np.random.normal(0, 1, (n_images, gan.latent_dim))
gen_imgs = gan.generator.predict(noise)
y = ((0.5 * gen_imgs + 0.5) * 255).astype(np.uint8)
hv.Layout([hv.RGB(img) for img in y]).cols(n_images)