In [1]:
from keras.layers import Input, Reshape, Flatten, Activation
from keras.layers.merge import _Merge
from keras.layers.advanced_activations import LeakyReLU, ReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Sequential, Model
from keras.optimizers import Adam, RMSprop
from keras.models import load_model
from keras.utils import multi_gpu_model

import keras
import tensorflow as tf
import keras.backend as K
import keras.backend.tensorflow_backend as KTF

import shutil, os, sys, io, random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import math
from functools import partial
from itertools import zip_longest
from tqdm import tqdm

os.chdir('/home/k_yonhon/py/Keras-GAN/pggan/')
sys.path.append(os.pardir)

from tensor_board_logger import TensorBoardLogger
from layer_visualizer import LayerVisualizer

config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
session = tf.Session(config=config)
KTF.set_session(session)

gpu_count = 2
batch_size = 64

In [2]:
class RandomWeightedAverage(_Merge):
    """Takes a randomly-weighted average of two tensors. In geometric terms, this
    outputs a random point on the line between each pair of input points.
    Inheriting from _Merge is a little messy but it was the quickest solution I could
    think of. Improvements appreciated."""

    def _merge_function(self, inputs):
        weights = K.random_uniform((batch_size, 1, 1, 1))
        return (weights * inputs[0]) + ((1 - weights) * inputs[1])

def compute_gradients(tensor, var_list):
    grads = tf.gradients(tensor, var_list) # return A list of sum(dy/dx) for each x in xs.
    return [grad if grad is not None else tf.zeros_like(var) for grad, var in zip_longest(grads, var_list)]

def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true * y_pred)

def gradient_penalty_loss(y_true, y_pred, averaged_samples,
                          gradient_penalty_weight):
    # gradients = K.gradients(y_pred, averaged_samples)[0]
    gradients = compute_gradients(y_pred, [averaged_samples])[0]
    # compute the euclidean norm by squaring ...
    gradients_sqr = K.square(gradients)
    #   ... summing over the rows ...
    gradients_sqr_sum = K.sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape)))
    #   ... and sqrt
    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
    # compute lambda * (1 - ||grad||)^2 still for each single sample
    gradient_penalty = gradient_penalty_weight * K.square(1 - gradient_l2_norm)
    # return the mean as loss over all the batch samples
    return K.mean(gradient_penalty)

In [3]:
class WGANGP():
    def __init__(self):
        # ---------------------
        #  for log on TensorBoard
        # ---------------------
        target_dir = "./my_log_dir"
        shutil.rmtree(target_dir, ignore_errors=True)
        os.mkdir(target_dir)
        self.logger = TensorBoardLogger(log_dir=target_dir)

        # ---------------------
        #  Parameter
        # ---------------------
        self.resume = 0
        
        self.n_critic = 5
        self.λ = 10
        
        self.img_rows = 32
        self.img_cols = 32
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.input_rows = 4
        self.input_cols = 4
        self.latent_dim = 128  # Noiseの次元
        optimizer = Adam(lr=0.00001, beta_1=0., beta_2=0.9, epsilon=None, decay=0.0, amsgrad=False)
        # optimizer = Adam(lr=0.00005, beta_1=0., beta_2=0.9, epsilon=None, decay=0.0, amsgrad=False)
        # optimizer = Adam(0.0001, beta_1=0., beta_2=0.9)

        # ---------------------
        #  Build model
        # ---------------------
        with tf.device("/cpu:0"):
            if self.resume == 0:
                self.critic = self.build_critic()
                self.generator = self.build_generator()
            else:
                self.critic = load_model('./saved_model/wgangp32_disc_model_'+str(self.resume)+'epoch.h5')
                self.generator = load_model('./saved_model/wgangp32_gen_model_'+str(self.resume)+'epoch.h5')
                
            #  Load pretrained weights
            pre_gen = load_model('./saved_model/wgangp16_gen_model.h5')
            for i, layer in enumerate(self.generator.layers[1].layers):
                if i in [i for i in range(1, int(math.log(self.img_rows / self.input_rows, 2)) * 2, 2)]:
                    layer.set_weights(pre_gen.layers[1].layers[i].get_weights())

            pre_critic = load_model('./saved_model/wgangp16_disc_model.h5')
            for i, layer in enumerate(self.critic.layers[1].layers):
                j = i - len(self.critic.layers[1].layers)
                if j in [-i for i in range(int(math.log(self.img_rows / self.input_rows, 2)) * 2, 0, -2)]:
                    layer.set_weights(pre_critic.layers[1].layers[j].get_weights())
                    layer.trainable = False

        #-------------------------------
        # Construct Computational Graph
        #       for the Critic
        #-------------------------------
        # Freeze generator's layers while training critic
        self.generator.trainable = False
     
        real_samples = Input(shape=self.img_shape)
        generator_input_for_discriminator = Input(shape=(self.latent_dim,))
        generated_samples_for_discriminator = self.generator(generator_input_for_discriminator)
        discriminator_output_from_generator = self.critic(generated_samples_for_discriminator)
        discriminator_output_from_real_samples = self.critic(real_samples)

        averaged_samples = RandomWeightedAverage()([real_samples,
                                                    generated_samples_for_discriminator])
        averaged_samples_out = self.critic(averaged_samples)

        partial_gp_loss = partial(gradient_penalty_loss,
                                  averaged_samples=averaged_samples,
                                  gradient_penalty_weight=self.λ)
        # Functions need names or Keras will throw an error
        partial_gp_loss.__name__ = 'gradient_penalty'

        self.critic_model = Model(inputs=[real_samples, 
                                          generator_input_for_discriminator],
                                  outputs=[discriminator_output_from_real_samples, 
                                           discriminator_output_from_generator,
                                           averaged_samples_out])
        if gpu_count > 1:
            self.critic_model = multi_gpu_model(self.critic_model, gpus=gpu_count)
        self.critic_model.compile(optimizer=optimizer, 
                                  loss=[wasserstein_loss, 
                                        wasserstein_loss, 
                                        partial_gp_loss])
        
        print('Critic Summary:')
        self.critic.summary()       
        
        #-------------------------------
        # Construct Computational Graph
        #         for Generator
        #-------------------------------
        # For the generator we freeze the critic's layers
        self.critic.trainable = False
        self.generator.trainable = True
        for i, layer in enumerate(self.generator.layers[1].layers):
            if i in [i for i in range(1, int(math.log(self.img_rows / self.input_rows, 2)) * 2, 2)]:
                layer.trainable = False
                    
        generator_input = Input(shape=(self.latent_dim,))
        generator_layers = self.generator(generator_input)
        discriminator_layers_for_generator = self.critic(generator_layers)
        
        self.generator_model = Model(inputs=[generator_input], 
                                     outputs=[discriminator_layers_for_generator])
        if gpu_count > 1:
            self.generator_model = multi_gpu_model(self.generator_model, gpus=gpu_count)
        self.generator_model.compile(optimizer=optimizer,
                                     loss=wasserstein_loss)        

        print('Genarator Summary:')
        self.generator.summary()
   
    def build_generator(self):
        with tf.device("/cpu:0"):
            model = Sequential()
            model.add(Reshape((self.input_rows, self.input_cols, int(self.latent_dim / (self.input_rows * self.input_cols))), 
                              input_shape=(self.latent_dim,)
                             ))

            model.add(Conv2DTranspose(512, (3, 3), strides=1, padding='same',
                                     kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                     ))
            model.add(LeakyReLU(alpha=0.2))      

            for _ in range(int(math.log(self.img_rows / self.input_rows, 2))):
                model.add(Conv2DTranspose(512, (3, 3), strides=2, padding='same', 
                                         kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                         ))
                model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2DTranspose(256, (3, 3), strides=1, padding='same', 
                                      kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                      ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2DTranspose(128, (3, 3), strides=1, padding='same', 
                                     kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                     ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2DTranspose(64, (3, 3), strides=1, padding='same', 
                                     kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                     ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2DTranspose(32, (3, 3), strides=1, padding='same', 
                                     kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                     ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2DTranspose(16, (3, 3), strides=1, padding='same', 
                                     kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                     ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2DTranspose(3, (3, 3), strides=1, padding='same', 
                                     kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                     ))                
            model.add(Activation("tanh"))

            noise = Input(shape=(self.latent_dim,))
            img = model(noise)
            return Model(noise, img)
    
    def build_critic(self):
        with tf.device("/cpu:0"):
            model = Sequential()
            model.add(Conv2D(16, (1, 1), strides=1, input_shape=self.img_shape, padding="valid",
                             kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                            ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2D(32, (3, 3), strides=1, padding="same",
                             kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                            ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2D(64, (3, 3), strides=1, padding="same",
                             kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                            ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2D(128, (3, 3), strides=1, padding="same",
                             kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                            ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2D(256, (3, 3), strides=1, padding="same",
                             kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                            ))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2D(512, (3, 3), strides=1, padding="same",
                             kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                            ))
            model.add(LeakyReLU(alpha=0.2))

            for _ in range(int(math.log(self.img_rows / self.input_rows, 2))):
                model.add(Conv2D(512, (3, 3), strides=2, padding="same",
                                 kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                                ))
                model.add(LeakyReLU(alpha=0.2))

            model.add(Conv2D(1, (4, 4), strides=1, padding="valid",
                             kernel_initializer=keras.initializers.Orthogonal(gain=1.4, seed=None),
                            ))
            model.add(Flatten())

            img = Input(shape=self.img_shape)
            validity = model(img)
            return Model(img, validity)
    
    def train(self, epochs, batch_size, sample_interval=50):
        # ---------------------
        #  Load the dataset
        # ---------------------      
        # Original dataset
        X_train = np.load('../datasets/lfw32.npz')['arr_0']
        X_train = X_train / 127.5 - 1.0   # Rescale -1 to 1

        # Adversarial ground truths
        valid = np.ones((batch_size, 1), dtype=np.float32)
        fake = -np.ones((batch_size, 1), dtype=np.float32)
        dummy = np.zeros((batch_size, 1), dtype=np.float32)
        
        for epoch in tqdm(range(self.resume, self.resume + epochs + 1)):
            for _ in range(self.n_critic):
                # ---------------------
                #  Train Discriminator
                # ---------------------
                # Select a random batch of images
                idx = np.random.randint(0, X_train.shape[0], batch_size)
                imgs = X_train[idx]
                # Sample generator input
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                # Train the critic
                d_loss = self.critic_model.train_on_batch([imgs, noise], 
                                                          [valid, fake, dummy])
                d_loss = d_loss[0]
                
            # ---------------------
            #  Train Generator
            # ---------------------
            g_loss = self.generator_model.train_on_batch(noise, valid)

            # ---------------------
            #  Log on TensorBoard
            # ---------------------
            # Backup Model
            # if epoch != 0 and epoch % 1000 == 0:
            #     self.critic.save('./saved_model/wgangp32_disc_model_'+str(epoch+self.resume)+'epoch.h5')
            #     self.generator.save('./saved_model/wgangp32_gen_model_'+str(epoch+self.resume)+'epoch.h5')
            
            # Save Loss & Histgram
            logs = {
                "Discriminator/loss": d_loss,
                "Generator/loss": g_loss,
            }

            histograms = {}
            for layer in self.critic.layers[1].layers:
                for i in range(len(layer.get_weights())):
                    if "conv" in layer.name or "dense" in layer.name:
                        name = layer.name + "/" + str(i)
                        value = layer.get_weights()[i]
                        histograms[name] = value
            self.logger.log(logs=logs, histograms=histograms, epoch=epoch+self.resume)
            
            # Save generated image samples
            if epoch+self.resume == 1000 or epoch+self.resume == 2000 or (epoch+self.resume) % sample_interval == 0:
                fig, name = self.sample_images(epoch+self.resume)
                images = {name: fig}
                self.logger.log(images=images, epoch=epoch+self.resume)
                print("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss, g_loss))

    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = ((0.5 * gen_imgs + 0.5) * 255).astype(np.uint8)
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                if self.channels == 1:
                    axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap="gray")
                else:
                    axs[i, j].imshow(gen_imgs[cnt, :, :, :self.channels], cmap="gray")
                axs[i, j].axis("off")
                cnt += 1
        name = str(epoch) + ".png"
        return fig, name

In [4]:
wgan = WGANGP()



Critic Summary:
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 32, 32, 3)         0         
_________________________________________________________________
sequential_1 (Sequential)    (None, 1)                 8660001   
Total params: 8,660,001
Trainable params: 3,932,192
Non-trainable params: 4,727,809
_________________________________________________________________
Genarator Summary:
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 128)               0         
_________________________________________________________________
sequential_2 (Sequential)    (None, 32, 32, 3)         8689059   
Total params: 8,689,059
Trainable params: 3,932,067
Non-trainable params: 4,756,992
_________________________________________________________________


In [5]:
X_train = np.load('../datasets/lfw32.npz')['arr_0']
X_train = X_train / 127.5 - 1.0   # Rescale -1 to 1
idx = np.random.randint(0, X_train.shape[0], batch_size)
real_samples = X_train[idx]

In [9]:
generator_input_for_discriminator = np.random.normal(0, 1, (batch_size, 128))
generated_samples_for_discriminator = wgan.generator.predict(generator_input_for_discriminator)

In [11]:
averaged_samples = RandomWeightedAverage()([real_samples, generated_samples_for_discriminator])
# averaged_samples.shape

ValueError: Layer random_weighted_average_3 was called with an input that isn't a symbolic tensor. Received type: <class 'numpy.ndarray'>. Full input: [array([[[[ 0.34901961,  0.00392157, -0.30196078],
         [ 0.5372549 ,  0.12156863, -0.22352941],
         [ 0.64705882,  0.19215686, -0.12156863],
         ...,
         [-0.30980392, -0.58431373, -0.83529412],
         [-0.3254902 , -0.6       , -0.81960784],
         [-0.2       , -0.48235294, -0.73333333]],

        [[-0.00392157, -0.28627451, -0.5372549 ],
         [-0.30980392, -0.6       , -0.81176471],
         [ 0.23137255, -0.1372549 , -0.37254902],
         ...,
         [-0.3254902 , -0.58431373, -0.84313725],
         [-0.34901961, -0.61568627, -0.83529412],
         [-0.17647059, -0.48235294, -0.7254902 ]],

        [[-0.16862745, -0.4745098 , -0.71764706],
         [-0.60784314, -0.82745098, -0.92156863],
         [-0.34901961, -0.63921569, -0.78039216],
         ...,
         [-0.3254902 , -0.59215686, -0.81176471],
         [-0.35686275, -0.6       , -0.82745098],
         [-0.16078431, -0.48235294, -0.71764706]],

        ...,

        [[-0.41176471, -0.48235294, -0.5372549 ],
         [-0.41176471, -0.55294118, -0.6627451 ],
         [-0.34117647, -0.41960784, -0.41176471],
         ...,
         [ 0.10588235, -0.18431373, -0.39607843],
         [ 0.1372549 , -0.14509804, -0.31764706],
         [ 0.15294118, -0.09803922, -0.21568627]],

        [[-0.90588235, -0.98431373, -1.        ],
         [-0.56862745, -0.63137255, -0.54509804],
         [-0.78823529, -0.71764706, -0.4745098 ],
         ...,
         [ 0.0745098 , -0.2       , -0.35686275],
         [ 0.19215686, -0.12941176, -0.27058824],
         [ 0.12941176, -0.12156863, -0.22352941]],

        [[-0.34901961, -0.43529412, -0.40392157],
         [-0.91372549, -0.89019608, -0.77254902],
         [-0.95294118, -0.89803922, -0.75686275],
         ...,
         [ 0.0745098 , -0.17647059, -0.37254902],
         [-0.04313725, -0.23921569, -0.27843137],
         [-0.04313725, -0.24705882, -0.25490196]]],


       [[[-1.        , -1.        , -1.        ],
         [-0.96862745, -1.        , -1.        ],
         [-0.36470588, -0.49803922, -0.70196078],
         ...,
         [-0.0745098 , -0.00392157, -0.16862745],
         [-0.0745098 , -0.00392157, -0.16862745],
         [-0.09019608, -0.01960784, -0.18431373]],

        [[-1.        , -1.        , -1.        ],
         [-0.96862745, -1.        , -1.        ],
         [-0.38039216, -0.51372549, -0.71764706],
         ...,
         [-0.06666667,  0.00392157, -0.16078431],
         [-0.06666667,  0.00392157, -0.16078431],
         [-0.08235294, -0.01176471, -0.17647059]],

        [[-1.        , -1.        , -1.        ],
         [-0.97647059, -1.        , -1.        ],
         [-0.38039216, -0.51372549, -0.71764706],
         ...,
         [-0.06666667,  0.00392157, -0.16078431],
         [-0.05882353,  0.01176471, -0.15294118],
         [-0.06666667,  0.00392157, -0.16078431]],

        ...,

        [[-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -0.98431373],
         [-0.78823529, -0.79607843, -0.75686275],
         ...,
         [ 0.22352941, -0.12156863, -0.14509804],
         [ 0.08235294, -0.14509804, -0.20784314],
         [-0.67843137, -0.70196078, -0.74117647]],

        [[-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -0.98431373],
         [-0.78823529, -0.79607843, -0.75686275],
         ...,
         [ 0.34901961,  0.2       ,  0.09019608],
         [-0.48235294, -0.5372549 , -0.61568627],
         [-0.74901961, -0.73333333, -0.75686275]],

        [[-1.        , -1.        , -1.        ],
         [-0.99215686, -0.99215686, -0.97647059],
         [-0.80392157, -0.81176471, -0.77254902],
         ...,
         [ 0.05882353, -0.00392157, -0.09019608],
         [-0.85882353, -0.85882353, -0.8745098 ],
         [-0.8745098 , -0.8745098 , -0.8745098 ]]],


       [[[-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         ...,
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ]],

        [[-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         ...,
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ]],

        [[ 0.51372549,  0.55294118,  0.37254902],
         [ 0.57647059,  0.61568627,  0.45882353],
         [ 0.63921569,  0.67843137,  0.50588235],
         ...,
         [ 0.80392157,  0.85882353,  0.73333333],
         [ 0.78039216,  0.83529412,  0.70980392],
         [ 0.77254902,  0.82745098,  0.70196078]],

        ...,

        [[ 0.58431373,  0.19215686,  0.2       ],
         [ 0.67058824,  0.63921569,  0.44313725],
         [-0.63137255, -0.62352941, -0.68627451],
         ...,
         [-0.74901961, -0.78039216, -0.80392157],
         [-0.75686275, -0.78039216, -0.81960784],
         [-0.74901961, -0.77254902, -0.81176471]],

        [[ 0.00392157, -0.50588235, -0.5372549 ],
         [ 0.31764706, -0.05882353, -0.1372549 ],
         [-0.70196078, -0.73333333, -0.75686275],
         ...,
         [-0.77254902, -0.76470588, -0.80392157],
         [-0.79607843, -0.79607843, -0.81176471],
         [-0.80392157, -0.80392157, -0.81960784]],

        [[-0.14509804, -0.59215686, -0.6627451 ],
         [-0.40392157, -0.42745098, -0.49803922],
         [-0.78039216, -0.77254902, -0.81960784],
         ...,
         [-0.81176471, -0.79607843, -0.81960784],
         [-0.80392157, -0.78823529, -0.79607843],
         [-0.78823529, -0.78039216, -0.76470588]]],


       ...,


       [[[-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         ...,
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ]],

        [[-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         ...,
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ]],

        [[-0.98431373, -1.        , -0.97647059],
         [-0.98431373, -1.        , -0.97647059],
         [-0.97647059, -1.        , -0.98431373],
         ...,
         [-0.96862745, -1.        , -1.        ],
         [-0.95294118, -1.        , -1.        ],
         [-0.96862745, -1.        , -1.        ]],

        ...,

        [[-0.49803922, -0.45098039, -0.46666667],
         [-0.52941176, -0.48235294, -0.49803922],
         [-0.27058824, -0.27058824, -0.34901961],
         ...,
         [ 0.86666667,  0.4745098 ,  0.2       ],
         [ 0.98431373,  0.62352941,  0.44313725],
         [ 0.70980392,  0.18431373,  0.23921569]],

        [[-0.38039216, -0.34901961, -0.3254902 ],
         [-0.34901961, -0.35686275, -0.39607843],
         [ 0.54509804,  0.35686275,  0.15294118],
         ...,
         [ 0.96078431,  0.52156863,  0.24705882],
         [ 0.92941176,  0.58431373,  0.37254902],
         [ 0.99215686,  0.70196078,  0.56078431]],

        [[-0.2627451 , -0.27843137, -0.37254902],
         [ 0.51372549,  0.28627451,  0.01960784],
         [ 0.52941176,  0.12941176, -0.22352941],
         ...,
         [ 0.49019608, -0.00392157, -0.27843137],
         [ 0.62352941,  0.14509804, -0.09803922],
         [ 0.76470588,  0.28627451,  0.05882353]]],


       [[[-0.9372549 , -1.        , -1.        ],
         [-0.96862745, -1.        , -1.        ],
         [-0.97647059, -1.        , -1.        ],
         ...,
         [-0.92941176, -0.99215686, -1.        ],
         [-0.94509804, -1.        , -1.        ],
         [-1.        , -1.        , -0.98431373]],

        [[-0.79607843, -0.98431373, -0.98431373],
         [-0.88235294, -0.96862745, -0.98431373],
         [-0.45098039, -0.50588235, -0.58431373],
         ...,
         [-0.7254902 , -0.78823529, -0.70196078],
         [-0.57647059, -0.69411765, -0.65490196],
         [-0.92941176, -0.96862745, -0.91372549]],

        [[-0.74901961, -0.91372549, -0.9372549 ],
         [-0.78039216, -0.92156863, -0.9372549 ],
         [-0.6       , -0.63921569, -0.6627451 ],
         ...,
         [-0.76470588, -0.81176471, -0.71764706],
         [-0.73333333, -0.82745098, -0.71764706],
         [-0.81960784, -0.8745098 , -0.74901961]],

        ...,

        [[-0.80392157, -0.83529412, -0.82745098],
         [-0.85882353, -0.89019608, -0.88235294],
         [-0.82745098, -0.85882353, -0.85098039],
         ...,
         [-0.63921569, -0.85882353, -0.94509804],
         [-0.42745098, -0.54509804, -0.88235294],
         [-0.48235294, -0.56862745, -0.92941176]],

        [[-0.82745098, -0.85098039, -0.92156863],
         [-0.8745098 , -0.90588235, -0.91372549],
         [-0.86666667, -0.89803922, -0.90588235],
         ...,
         [-0.62352941, -0.79607843, -0.90588235],
         [-0.63137255, -0.71764706, -0.9372549 ],
         [-0.75686275, -0.81176471, -0.9372549 ]],

        [[-0.79607843, -0.82745098, -0.85098039],
         [-0.86666667, -0.89803922, -0.90588235],
         [-0.88235294, -0.91372549, -0.92156863],
         ...,
         [-0.70980392, -0.92941176, -0.96078431],
         [-0.68627451, -0.88235294, -0.85882353],
         [-0.63137255, -0.75686275, -0.65490196]]],


       [[[-0.42745098, -0.41960784, -0.46666667],
         [-0.24705882, -0.27058824, -0.34117647],
         [ 0.6       ,  0.59215686,  0.56078431],
         ...,
         [ 0.80392157,  0.81176471,  0.82745098],
         [ 0.80392157,  0.81176471,  0.84313725],
         [ 0.78823529,  0.81960784,  0.82745098]],

        [[-0.54509804, -0.58431373, -0.63137255],
         [-0.41176471, -0.41176471, -0.4745098 ],
         [-0.2       , -0.23137255, -0.3254902 ],
         ...,
         [ 0.78039216,  0.81176471,  0.81960784],
         [ 0.70196078,  0.73333333,  0.75686275],
         [ 0.51372549,  0.55294118,  0.57647059]],

        [[-0.27058824, -0.44313725, -0.62352941],
         [-0.37254902, -0.41960784, -0.51372549],
         [-0.36470588, -0.41960784, -0.49803922],
         ...,
         [ 0.81176471,  0.81960784,  0.83529412],
         [ 0.80392157,  0.79607843,  0.83529412],
         [ 0.70196078,  0.73333333,  0.75686275]],

        ...,

        [[-0.74901961, -0.71764706, -0.64705882],
         [-0.80392157, -0.74901961, -0.67058824],
         [-0.79607843, -0.74901961, -0.65490196],
         ...,
         [-0.52156863, -0.6627451 , -0.77254902],
         [ 0.39607843,  0.24705882,  0.11372549],
         [ 0.71764706,  0.60784314,  0.52156863]],

        [[-0.78039216, -0.71764706, -0.63137255],
         [-0.70980392, -0.65490196, -0.57647059],
         [-0.81176471, -0.76470588, -0.65490196],
         ...,
         [ 0.29411765,  0.12156863,  0.01960784],
         [ 0.59215686,  0.48235294,  0.41176471],
         [ 0.7254902 ,  0.61568627,  0.59215686]],

        [[-0.69411765, -0.63921569, -0.56078431],
         [-0.70196078, -0.70196078, -0.60784314],
         [-0.74117647, -0.71764706, -0.64705882],
         ...,
         [ 0.51372549,  0.40392157,  0.31764706],
         [ 0.67843137,  0.61568627,  0.6       ],
         [ 0.76470588,  0.65490196,  0.64705882]]]]), array([[[[ 0.28005204, -0.01295621,  0.5360961 ],
         [ 0.7138143 , -0.32918018,  0.08042552],
         [ 0.6813825 ,  0.42994848,  0.17496546],
         ...,
         [ 0.75120044, -0.2074037 ,  0.73411494],
         [ 0.6202267 , -0.15057443, -0.02427469],
         [ 0.37704888,  0.13028482,  0.39609212]],

        [[ 0.798608  , -0.0048413 ,  0.30336818],
         [ 0.9139696 ,  0.30283973,  0.15080553],
         [ 0.39215103,  0.28539187,  0.37781233],
         ...,
         [ 0.90145683, -0.09665363,  0.3289767 ],
         [-0.01507211,  0.59585404,  0.67706436],
         [-0.52627254,  0.4035525 ,  0.5719068 ]],

        [[ 0.7555575 , -0.52795196,  0.43682513],
         [ 0.41943383, -0.00275957,  0.28202432],
         [ 0.8449016 ,  0.4020259 , -0.32986307],
         ...,
         [ 0.1052077 ,  0.94883096,  0.76286906],
         [-0.04619666,  0.7481997 , -0.7266152 ],
         [-0.17623495,  0.56028736,  0.59310627]],

        ...,

        [[ 0.2909751 , -0.5842813 ,  0.3946403 ],
         [-0.24082695, -0.08199424,  0.5840604 ],
         [ 0.9111845 ,  0.6912433 ,  0.92412055],
         ...,
         [ 0.91500014,  0.91200167, -0.8868953 ],
         [ 0.86651576,  0.41030794, -0.14312866],
         [-0.18599217,  0.7101319 ,  0.01032009]],

        [[ 0.5036033 , -0.32457975, -0.62310004],
         [ 0.07955931,  0.71538   , -0.27250555],
         [-0.3282044 ,  0.7913524 ,  0.6508543 ],
         ...,
         [ 0.43469596,  0.99647766, -0.57596296],
         [ 0.82176423,  0.9422403 , -0.71396595],
         [ 0.5203975 ,  0.76501244,  0.41422945]],

        [[ 0.5494541 ,  0.606268  , -0.5817314 ],
         [ 0.2197637 ,  0.05814296, -0.1381551 ],
         [-0.6139494 ,  0.2679585 , -0.4904744 ],
         ...,
         [ 0.4727358 ,  0.46447867, -0.4251555 ],
         [ 0.04693783,  0.5672575 , -0.56019217],
         [ 0.13168946,  0.2460086 , -0.58633715]]],


       [[[ 0.18709867, -0.13919784,  0.21326135],
         [ 0.38496312, -0.06465629,  0.29490995],
         [ 0.3686448 ,  0.30889472,  0.2318039 ],
         ...,
         [ 0.862891  ,  0.6628613 ,  0.42536083],
         [ 0.3052695 ,  0.36013907, -0.22544855],
         [ 0.04473574,  0.5282923 ,  0.09166476]],

        [[ 0.27202237, -0.12902533,  0.2405799 ],
         [ 0.2931427 , -0.40337607,  0.39754632],
         [ 0.35103995,  0.40132898,  0.45857042],
         ...,
         [ 0.9444176 ,  0.59867996,  0.65695447],
         [ 0.3239954 ,  0.58770716,  0.26058698],
         [ 0.01199465,  0.79352945,  0.4366563 ]],

        [[ 0.42462772, -0.20005746,  0.02098982],
         [ 0.44561335,  0.37549317,  0.22094901],
         [ 0.68862754,  0.04966352,  0.19684973],
         ...,
         [ 0.36331964,  0.9463984 , -0.01874124],
         [ 0.82969874,  0.7728404 ,  0.60578847],
         [ 0.52958775,  0.3288805 ,  0.46693346]],

        ...,

        [[-0.04922634, -0.50676066,  0.16763727],
         [ 0.7654162 ,  0.09383211, -0.16851352],
         [ 0.81068856,  0.21837115, -0.11404665],
         ...,
         [ 0.81920767,  0.9626284 ,  0.19121765],
         [ 0.91598725,  0.8242136 ,  0.675939  ],
         [ 0.98958766,  0.8504024 , -0.13849683]],

        [[ 0.42106298, -0.23145925, -0.47753114],
         [-0.09348758,  0.34995797,  0.5689166 ],
         [-0.13490617,  0.38598004, -0.0739463 ],
         ...,
         [-0.872576  ,  0.99970996, -0.6376413 ],
         [-0.5643231 ,  0.9686425 ,  0.4598599 ],
         [ 0.17398652,  0.4451401 , -0.6272094 ]],

        [[ 0.26290858,  0.1357264 , -0.08694506],
         [ 0.09538344,  0.36150476, -0.49037078],
         [ 0.4393836 ,  0.66015387, -0.56874573],
         ...,
         [-0.95418566,  0.73919576, -0.89956546],
         [-0.41213933,  0.85994285, -0.58959854],
         [-0.45036927,  0.8905244 , -0.53529286]]],


       [[[ 0.50854367, -0.15723808,  0.3460986 ],
         [ 0.5923668 , -0.05765957, -0.06700322],
         [ 0.7864806 ,  0.20196047,  0.3031795 ],
         ...,
         [ 0.99599874, -0.36235088,  0.972334  ],
         [ 0.8438084 , -0.8619196 ,  0.37582415],
         [ 0.73434734,  0.05059353,  0.68843234]],

        [[ 0.5494253 ,  0.06016738,  0.53668207],
         [ 0.67744964, -0.024435  ,  0.68129265],
         [ 0.8252342 , -0.47557768,  0.24468392],
         ...,
         [ 0.96790326,  0.94796735,  0.849233  ],
         [ 0.04041001,  0.92726094,  0.5186056 ],
         [-0.29125884,  0.85609055,  0.903522  ]],

        [[ 0.24386355,  0.09603879,  0.638444  ],
         [ 0.5438569 , -0.05604295, -0.08032707],
         [ 0.9669115 ,  0.32108423,  0.02866204],
         ...,
         [ 0.9953327 ,  0.9973665 ,  0.27985686],
         [-0.7513156 ,  0.97620666, -0.59840184],
         [ 0.52101797,  0.7611095 ,  0.88859135]],

        ...,

        [[ 0.94218403,  0.15800168, -0.3558232 ],
         [ 0.5634679 ,  0.1634108 ,  0.8828248 ],
         [ 0.98365676,  0.91767776, -0.37972066],
         ...,
         [ 0.99197435,  0.8894995 , -0.33795202],
         [ 0.57987833,  0.98631364, -0.87175894],
         [ 0.7506098 ,  0.7640584 ,  0.89557385]],

        [[-0.01933168, -0.21847302,  0.4099324 ],
         [-0.6521086 ,  0.9223896 ,  0.8712159 ],
         [ 0.5165878 ,  0.9419717 ,  0.5121213 ],
         ...,
         [-0.34982458,  0.9981403 ,  0.8152248 ],
         [ 0.17366739,  0.6005548 ,  0.02385741],
         [ 0.8690081 ,  0.95914674,  0.3823066 ]],

        [[ 0.670771  ,  0.46017122, -0.23274584],
         [-0.03414054,  0.74885744,  0.06870717],
         [-0.34645772,  0.9222028 , -0.845058  ],
         ...,
         [-0.17068623,  0.87573826, -0.55411583],
         [ 0.7688842 ,  0.58232135, -0.08363857],
         [ 0.33503023,  0.7668581 , -0.2618821 ]]],


       ...,


       [[[ 0.23417884,  0.17876536,  0.11649466],
         [ 0.5914266 , -0.40924248,  0.17160852],
         [ 0.63532114,  0.34961477,  0.41503245],
         ...,
         [ 0.7065354 ,  0.46054253,  0.6750122 ],
         [ 0.14220487,  0.09347253,  0.7391535 ],
         [-0.11980761,  0.32697633,  0.13327156]],

        [[ 0.71638817, -0.21592437,  0.43343046],
         [ 0.48729447, -0.56404185,  0.47115207],
         [ 0.9165136 ,  0.59082544,  0.63809204],
         ...,
         [ 0.5665324 ,  0.47289944,  0.37342313],
         [ 0.46962252, -0.14107609, -0.0352383 ],
         [ 0.47494397,  0.149785  ,  0.6033565 ]],

        [[ 0.654126  ,  0.06903381,  0.27928805],
         [ 0.23318654,  0.15353401,  0.27010903],
         [ 0.653843  ,  0.7067269 ,  0.60990196],
         ...,
         [ 0.94115245,  0.8625983 , -0.3548162 ],
         [ 0.3179231 ,  0.33286607, -0.4719782 ],
         [ 0.40788606, -0.1885545 ,  0.14062953]],

        ...,

        [[ 0.2956994 , -0.4715173 ,  0.77048314],
         [ 0.66679734,  0.553596  ,  0.5519909 ],
         [ 0.96092623,  0.347186  ,  0.6526715 ],
         ...,
         [ 0.99772215, -0.572089  ,  0.07590809],
         [ 0.96562964,  0.7960083 , -0.8240127 ],
         [ 0.05156404,  0.83744776,  0.13072442]],

        [[ 0.5037143 , -0.58291775, -0.42885998],
         [ 0.4919658 ,  0.35213837, -0.7591936 ],
         [-0.92565376,  0.88493574, -0.6352117 ],
         ...,
         [ 0.9987292 ,  0.9712032 ,  0.5909334 ],
         [ 0.94052464,  0.10441297,  0.76670766],
         [ 0.17932929,  0.9813099 ,  0.95210403]],

        [[ 0.7935247 ,  0.47642842, -0.6121167 ],
         [ 0.81518096, -0.23811792, -0.8078477 ],
         [ 0.23053178, -0.24941792, -0.45480022],
         ...,
         [ 0.7165835 , -0.44795972, -0.98817915],
         [-0.01567555,  0.39961734, -0.7695575 ],
         [ 0.07669327,  0.28870076, -0.27438408]]],


       [[[-0.07916274,  0.06127489,  0.06721335],
         [ 0.2663672 , -0.36294428,  0.2730812 ],
         [ 0.4274954 , -0.199237  ,  0.3858946 ],
         ...,
         [ 0.66138583,  0.57767266,  0.6823033 ],
         [-0.15571557, -0.05171663,  0.5929194 ],
         [ 0.20335785,  0.08930219,  0.5612966 ]],

        [[ 0.32803303, -0.14009355,  0.12747727],
         [ 0.13388939, -0.12808745,  0.31751183],
         [ 0.31059846,  0.2476053 ,  0.06148256],
         ...,
         [-0.3071546 ,  0.6127426 , -0.3971068 ],
         [ 0.15058222,  0.6484544 , -0.64954734],
         [ 0.6301742 ,  0.19707787,  0.6394374 ]],

        [[ 0.29541293, -0.25610512,  0.02741918],
         [ 0.51532173,  0.36917076,  0.2740527 ],
         [ 0.507323  ,  0.30632663,  0.48969314],
         ...,
         [ 0.91772085,  0.8694348 ,  0.27336812],
         [-0.07765196,  0.8874428 , -0.36993268],
         [ 0.27321702,  0.7862604 ,  0.7157703 ]],

        ...,

        [[ 0.3052201 , -0.2849969 ,  0.02883673],
         [ 0.9027616 , -0.62882227,  0.14880867],
         [ 0.80590624,  0.5455357 ,  0.01459164],
         ...,
         [ 0.9555146 ,  0.9661299 , -0.12157329],
         [ 0.11750849,  0.6591238 , -0.8975571 ],
         [ 0.86323035,  0.87395483,  0.09352628]],

        [[ 0.3224021 , -0.704681  , -0.27405554],
         [ 0.31818643,  0.50933903,  0.06042401],
         [ 0.2807724 ,  0.5371993 , -0.17293894],
         ...,
         [ 0.1832646 ,  0.9845777 , -0.65311337],
         [-0.2699855 ,  0.4876237 ,  0.55397874],
         [ 0.3970491 ,  0.93286735, -0.18568991]],

        [[ 0.2721977 , -0.23631458, -0.2288712 ],
         [-0.00543356,  0.46390823, -0.67486215],
         [ 0.4907456 ,  0.50095206, -0.65681934],
         ...,
         [ 0.18105246,  0.883262  , -0.7136467 ],
         [-0.70929563,  0.91674685,  0.01274434],
         [ 0.55528104,  0.6824069 , -0.77079993]]],


       [[[ 0.29154435,  0.03771598,  0.38924667],
         [ 0.16212763,  0.52359515,  0.2615225 ],
         [ 0.7452874 ,  0.18767999,  0.1486834 ],
         ...,
         [ 0.73134184,  0.14331108,  0.974814  ],
         [ 0.70718825, -0.53215635,  0.3624317 ],
         [-0.22797051,  0.03587913,  0.5811585 ]],

        [[ 0.5703225 , -0.23076014,  0.21538258],
         [ 0.40888655, -0.06397503, -0.04663591],
         [ 0.76388633,  0.30449575,  0.51072127],
         ...,
         [ 0.6275543 ,  0.6621061 ,  0.14226295],
         [ 0.16916338,  0.78098434,  0.7759447 ],
         [-0.38196898,  0.29494008,  0.6400463 ]],

        [[ 0.35482153,  0.2507629 ,  0.10497419],
         [ 0.5569031 ,  0.16839586, -0.54338104],
         [ 0.5724411 ,  0.43847224,  0.06045865],
         ...,
         [ 0.9120543 ,  0.7269112 , -0.07969406],
         [ 0.47315556,  0.89779645,  0.2635537 ],
         [ 0.3811015 ,  0.5823222 ,  0.91262597]],

        ...,

        [[ 0.43041262,  0.49742815, -0.3106841 ],
         [ 0.8152806 , -0.3664783 ,  0.6685815 ],
         [ 0.897424  ,  0.3917484 ,  0.31995434],
         ...,
         [ 0.9984833 ,  0.997961  ,  0.46868393],
         [ 0.9185591 ,  0.53729343, -0.6477249 ],
         [ 0.68503755,  0.26151663, -0.8242519 ]],

        [[ 0.18841839, -0.01515698,  0.26325986],
         [-0.02830816,  0.60306543,  0.3835285 ],
         [ 0.14363052,  0.6274838 , -0.3897818 ],
         ...,
         [-0.43897757,  0.97890043,  0.27643386],
         [ 0.12157195,  0.93375087, -0.4293472 ],
         [ 0.60540456,  0.77511334,  0.16037741]],

        [[ 0.50782406, -0.0038881 ,  0.00762904],
         [-0.04364663,  0.15338175, -0.10123171],
         [ 0.2676707 ,  0.70861197, -0.46737352],
         ...,
         [ 0.8645493 ,  0.722528  , -0.60077596],
         [-0.6732556 ,  0.5519895 , -0.6607604 ],
         [ 0.06130024,  0.156236  , -0.6213455 ]]]], dtype=float32)]. All inputs to the layer should be tensors.

In [None]:
gradients = K.gradients(y_pred, averaged_samples)[0]
gradients = compute_gradients(y_pred, [averaged_samples])[0]

In [None]:
gan.generator.save('./saved_model/wgangp32_gen_model.h5')
gan.discriminator.save('./saved_model/wgangp32_critic_model.h5')
# gan.combined.save('./saved_model/dcgan_wloss4_combined.h5')

In [20]:
for layer in wgan.critic.layers[1].layers:
    print(layer.name)

conv2d_34
leaky_re_lu_61
conv2d_35
leaky_re_lu_62
conv2d_36
leaky_re_lu_63
conv2d_37
leaky_re_lu_64
conv2d_38
leaky_re_lu_65
conv2d_39
leaky_re_lu_66
conv2d_40
leaky_re_lu_67
conv2d_41
leaky_re_lu_68
conv2d_42
leaky_re_lu_69
conv2d_43
leaky_re_lu_70
conv2d_44
flatten_4


In [None]:
range(int(math.log(self.img_rows / self.input_rows, 2)))

In [33]:
list = [i for i in range(1, int(math.log(64 / 4, 2)) * 2, 2)]
print(list)

[1, 3, 5, 7]


In [34]:
list = [-i for i in range(int(math.log(64 / 4, 2)) * 2, 0, -2)]
print(list)

[-8, -6, -4, -2]


In [None]:
[i for i in range(1, int(math.log(self.img_rows / self.input_rows, 2)) * 2, 2)]
[-i for i in range(int(math.log(self.img_rows / self.input_rows, 2)) * 2, 0, -2)]