In [1]:
from keras.layers import Input, Reshape, Flatten, Activation, Lambda
from keras.layers.merge import _Merge
from keras.layers.advanced_activations import LeakyReLU, ReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Sequential, Model
from keras.optimizers import Adam, RMSprop
from keras.models import load_model
from keras.utils import multi_gpu_model

import keras
import tensorflow as tf
import keras.backend as K
import keras.backend.tensorflow_backend as KTF

import shutil, os, sys, io, random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import math
from functools import partial
from tqdm import tqdm

os.chdir('/home/k_yonhon/py/Keras-GAN/pggan/')
sys.path.append(os.pardir)

from tensor_board_logger import TensorBoardLogger
from layer_visualizer import LayerVisualizer

config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
session = tf.Session(config=config)
KTF.set_session(session)

gpu_count = 2

In [8]:
class WGANGP():
    def __init__(self):
        # ---------------------
        #  for log on TensorBoard
        # ---------------------
        target_dir = "./my_log_dir"
        shutil.rmtree(target_dir, ignore_errors=True)
        os.mkdir(target_dir)
        self.logger = TensorBoardLogger(log_dir=target_dir)

        # ---------------------
        #  Parameter
        # ---------------------
        self.resume = 0
        
        self.n_critic = 5
        self.λ = 10
        
        self.img_rows = 32
        self.img_cols = 32
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.input_rows = 4
        self.input_cols = 4
        self.latent_dim = 128  # Noiseの次元
        optimizer = Adam(lr=0.0005, beta_1=0., beta_2=0.9, epsilon=None, decay=0.0, amsgrad=False)

        # ---------------------
        #  Build model
        # ---------------------
        with tf.device("/cpu:0"):
            if self.resume == 0:
                self.critic = self.build_critic()
                self.generator = self.build_generator()
            else:
                self.critic = load_model('./saved_model/wgangp32_disc_model_'+str(self.resume)+'epoch.h5')
                self.generator = load_model('./saved_model/wgangp32_gen_model_'+str(self.resume)+'epoch.h5')
                
            #  Load pretrained weights
            pre_gen = load_model('./saved_model/wgangp16_gen_model.h5')
            for i, layer in enumerate(self.generator.layers[1].layers):
                if i in [i for i in range(1, int(math.log(self.img_rows / self.input_rows, 2)) * 2, 2)]:
                    layer.set_weights(pre_gen.layers[1].layers[i].get_weights())

            pre_critic = load_model('./saved_model/wgangp16_disc_model.h5')
            for i, layer in enumerate(self.critic.layers[1].layers):
                j = i - len(self.critic.layers[1].layers)
                if j in [-i for i in range(int(math.log(self.img_rows / self.input_rows, 2)) * 2, 0, -2)]:
                    layer.set_weights(pre_critic.layers[1].layers[j].get_weights())
                    layer.trainable = False

        #-------------------------------
        # Construct Computational Graph
        #       for the Critic
        #-------------------------------
        # Freeze generator's layers while training critic
        self.generator.trainable = False

        # Image input (real sample)
        real_img = Input(shape=self.img_shape)

        # Noise input
        z_disc = Input(shape=(self.latent_dim,))
        # Generate image based of noise (fake sample)
        fake_img = self.generator(z_disc)
        
        # Construct weighted average between real and fake images
        ϵ_input = Input(shape=(1,))
        # interpolated_img = K.dot(ϵ_input, real_img) + K.dot((1-ϵ_input), fake_img)
        # interpolated_img = K.dot(ϵ_input, real_img)
        self.interpolate = self.build_interpolate()
        mix_img = self.interpolate(ϵ_input, real_img, fake_img)
        ϵ_input = Input(shape=(1,))
        # mix_img = Lambda(lambda ϵ_input: ϵ_input * real_img + (1-ϵ_input) * fake_img, output_shape=self.img_shape)(ϵ_input)
        interpolated_img = Model(inputs=[ϵ_input, real_img, fake_img], outputs=[mix_img])
        
        # Discriminator determines validity of the real and fake and interpolated images
        valid = self.critic(real_img)
        fake = self.critic(fake_img)
        validity_interpolated = self.critic(interpolated_img)

        # Use Python partial to provide loss function with additional
        # 'averaged_samples' argument
        partial_gp_loss = partial(self.gradient_penalty_loss, 
                                  averaged_samples=interpolated_img)
        partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names
        
        self.critic_model = Model(inputs=[real_img, z_disc, ϵ_input], 
                                  outputs=[valid, fake, validity_interpolated])
        if gpu_count > 1:
            self.critic_model = multi_gpu_model(self.critic_model, gpus=gpu_count)
        self.critic_model.compile(loss=[self.wasserstein_loss,
                                        self.wasserstein_loss,
                                        partial_gp_loss],
                                        optimizer=optimizer,
                                        loss_weights=[1, 1, self.λ])
        print('Critic Summary:')
        self.critic.summary()
        
        #-------------------------------
        # Construct Computational Graph
        #         for Generator
        #-------------------------------
        # For the generator we freeze the critic's layers
        self.critic.trainable = False
        
        self.generator.trainable = True
        for i, layer in enumerate(self.generator.layers[1].layers):
            if i in [i for i in range(1, int(math.log(self.img_rows / self.input_rows, 2)) * 2, 2)]:
                layer.trainable = False
                
        # Sampled noise for input to generator
        z_gen = Input(shape=(self.latent_dim,))
        # Generate images based of noise
        img = self.generator(z_gen)
        # Discriminator determines validity
        valid = self.critic(img)
        # Defines generator model
        self.generator_model = Model(z_gen, valid)
        if gpu_count > 1:
            self.generator_model = multi_gpu_model(self.generator_model, gpus=gpu_count)
        self.generator_model.compile(loss=self.wasserstein_loss, optimizer=optimizer)
        print('Genarator Summary:')
        self.generator.summary()
        
        # self.critic_model = tf.keras.estimator.model_to_estimator(keras_model=self.critic_model)
        # self.generator_model = tf.keras.estimator.model_to_estimator(keras_model=self.generator_model)
            
    def compute_gradients(self, tensor, var_list):
        grads = tf.gradients(tensor, var_list)
        return [grad if grad is not None else tf.zeros_like(var) for var, grad in zip(var_list, grads)]
    
    def gradient_penalty_loss(self, y_true, y_pred, averaged_samples):
        gradients_sqr = K.square(self.compute_gradients(y_pred, [averaged_samples])[0])
        gradients_sqr_sum = K.sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape)))
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        gradient_penalty = K.square(1 - gradient_l2_norm)
        return K.mean(gradient_penalty)
   
    def wasserstein_loss(self, y_true, y_pred):
        return K.mean(y_true * y_pred)

    def build_interpolate(self):
        with tf.device("/cpu:0"):
            ϵ_input = Input(shape=(1,))
            real_img = Input(shape=self.img_shape)
            fake_img = Input(shape=self.img_shape)
            mix_img = Lambda(lambda ϵ_input, real_img, fake_img: ϵ_input * real_img + (1-ϵ_input) * fake_img)
            return Model(inputs=[ϵ_input, real_img, fake_img], outputs=[mix_img])

    def build_generator(self):
        with tf.device("/cpu:0"):
            model = Sequential()
            model.add(Reshape((self.input_rows, self.input_cols, int(self.latent_dim / (self.input_rows * self.input_cols))), 
                              input_shape=(self.latent_dim,)
                             ))

            model.add(Conv2DTranspose(512, (3, 3), strides=1, padding='same',
                                     kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                     ))
            model.add(LeakyReLU(alpha=0.2))      

            for _ in range(int(math.log(self.img_rows / self.input_rows, 2))):
                model.add(Conv2DTranspose(512, (3, 3), strides=2, padding='same', 
                                         kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                         ))
                model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2DTranspose(256, (3, 3), strides=1, padding='same', 
                                      kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                      ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2DTranspose(128, (3, 3), strides=1, padding='same', 
                                     kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                     ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2DTranspose(64, (3, 3), strides=1, padding='same', 
                                     kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                     ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2DTranspose(32, (3, 3), strides=1, padding='same', 
                                     kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                     ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2DTranspose(16, (3, 3), strides=1, padding='same', 
                                     kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                     ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2DTranspose(3, (3, 3), strides=1, padding='same', 
                                     kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                     ))                
            model.add(Activation("tanh"))

            noise = Input(shape=(self.latent_dim,))
            img = model(noise)
            return Model(noise, img)
    
    def build_critic(self):
        with tf.device("/cpu:0"):
            model = Sequential()
            model.add(Conv2D(16, (1, 1), strides=1, input_shape=self.img_shape, padding="valid",
                             kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                            ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2D(32, (3, 3), strides=1, padding="same",
                             kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                            ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2D(64, (3, 3), strides=1, padding="same",
                             kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                            ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2D(128, (3, 3), strides=1, padding="same",
                             kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                            ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2D(256, (3, 3), strides=1, padding="same",
                             kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                            ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2D(512, (3, 3), strides=1, padding="same",
                             kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                            ))
            model.add(LeakyReLU(alpha=0.2))

            for _ in range(int(math.log(self.img_rows / self.input_rows, 2))):
                model.add(Conv2D(512, (3, 3), strides=2, padding="same",
                                 kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                ))
                model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2D(1, (4, 4), strides=1, padding="valid",
                             kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                            ))
            model.add(Flatten())

            img = Input(shape=self.img_shape)
            validity = model(img)
            return Model(img, validity)
    
    def train(self, epochs, batch_size, sample_interval=50):
        # ---------------------
        #  Load the dataset
        # ---------------------      
        # Original dataset
        X_train = np.load('../datasets/lfw32.npz')['arr_0']
        X_train = X_train / 127.5 - 1.0   # Rescale -1 to 1

        # Adversarial ground truths
        valid = -np.ones((batch_size, 1))
        fake =  np.ones((batch_size, 1))
        dummy = np.zeros((batch_size, 1))
                
        for epoch in tqdm(range(self.resume, self.resume + epochs + 1)):
            for _ in range(self.n_critic):
                # ---------------------
                #  Train Discriminator
                # ---------------------
                # Select a random batch of images
                idx = np.random.randint(0, X_train.shape[0], batch_size)
                imgs = X_train[idx]
                # Sample generator input
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                epsilon = np.random.uniform(0, 1, (batch_size, 1))
                
                # Train the critic https://keras.io/ja/models/model/#train_on_batch
                d_loss = self.critic_model.train_on_batch([imgs, noise, epsilon], 
                                                          [valid, fake, dummy])
                d_loss = d_loss[0]

            # ---------------------
            #  Train Generator
            # ---------------------
            g_loss = self.generator_model.train_on_batch(noise, valid)

            # ---------------------
            #  Log on TensorBoard
            # ---------------------
            # Backup Model
            # if epoch != 0 and epoch % 1000 == 0:
            #     self.critic.save('./saved_model/wgangp32_disc_model_'+str(epoch+self.resume)+'epoch.h5')
            #     self.generator.save('./saved_model/wgangp32_gen_model_'+str(epoch+self.resume)+'epoch.h5')
            
            # Save Loss & Histgram
            logs = {
                "Discriminator/loss": d_loss,
                "Generator/loss": g_loss,
            }

            histograms = {}
            for layer in self.critic.layers[1].layers:
                for i in range(len(layer.get_weights())):
                    if "conv" in layer.name or "dense" in layer.name:
                        name = layer.name + "/" + str(i)
                        value = layer.get_weights()[i]
                        histograms[name] = value
            self.logger.log(logs=logs, histograms=histograms, epoch=epoch+self.resume)
            
            # Save generated image samples
            if epoch+self.resume == 1000 or epoch+self.resume == 2000 or (epoch+self.resume) % sample_interval == 0:
                fig, name = self.sample_images(epoch+self.resume)
                images = {name: fig}
                self.logger.log(images=images, epoch=epoch+self.resume)
                print("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss, g_loss))

    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                if self.channels == 1:
                    axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap="gray")
                else:
                    axs[i, j].imshow(gen_imgs[cnt, :, :, :self.channels], cmap="gray")
                axs[i, j].axis("off")
                cnt += 1
        name = str(epoch) + ".png"
        return fig, name

In [9]:
wgan = WGANGP()

ValueError: Output tensors to a Model must be the output of a Keras `Layer` (thus holding past layer metadata). Found: <keras.layers.core.Lambda object at 0x7fcc92b00b38>

In [44]:
wgan.train(epochs=10000, batch_size=32, sample_interval=1000)

  0%|          | 0/10001 [00:00<?, ?it/s]


ValueError: Layer model_57 was called with an input that isn't a symbolic tensor. Received type: <class 'numpy.ndarray'>. Full input: [array([[ 0.65793702, -0.75805358, -0.11626432, ..., -0.98515984,
        -0.58734277, -0.36988764],
       [ 0.69364593,  0.59771056, -0.32158228, ...,  0.00577185,
         0.95836626, -0.22131529],
       [-0.63888897,  1.06703683, -0.28133707, ...,  0.67263259,
         1.00769612,  0.3193566 ],
       ...,
       [ 1.88963822,  1.03870942, -0.94576235, ...,  2.17268146,
         0.54735567, -1.18917328],
       [ 0.63212721, -1.02561557, -0.55356255, ...,  0.32797021,
         0.2961781 , -0.56840746],
       [ 0.86339462,  2.1592337 , -1.0359202 , ...,  0.19962849,
         0.84791822, -0.07980201]])]. All inputs to the layer should be tensors.

In [2]:
class RandomWeightedAverage(_Merge):
    """Provides a (random) weighted average between real and generated image samples"""
    def _merge_function(self, inputs):
        alpha = K.random_uniform((32, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

In [20]:
for layer in wgan.critic.layers[1].layers:
    print(layer.name)

conv2d_34
leaky_re_lu_61
conv2d_35
leaky_re_lu_62
conv2d_36
leaky_re_lu_63
conv2d_37
leaky_re_lu_64
conv2d_38
leaky_re_lu_65
conv2d_39
leaky_re_lu_66
conv2d_40
leaky_re_lu_67
conv2d_41
leaky_re_lu_68
conv2d_42
leaky_re_lu_69
conv2d_43
leaky_re_lu_70
conv2d_44
flatten_4


In [None]:
range(int(math.log(self.img_rows / self.input_rows, 2)))

In [33]:
list = [i for i in range(1, int(math.log(64 / 4, 2)) * 2, 2)]
print(list)

[1, 3, 5, 7]


In [34]:
list = [-i for i in range(int(math.log(64 / 4, 2)) * 2, 0, -2)]
print(list)

[-8, -6, -4, -2]


In [None]:
[i for i in range(1, int(math.log(self.img_rows / self.input_rows, 2)) * 2, 2)]
[-i for i in range(int(math.log(self.img_rows / self.input_rows, 2)) * 2, 0, -2)]