In [1]:
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import np_utils
import tensorflow as tf
from tensorflow.compat.v1.keras import backend as K

import matplotlib.pyplot as plt
import os
import cv2
import numpy as np


In [2]:
# Seed NumPy's global RNG and TensorFlow's graph-level RNG.
# (The original also called np.random.RandomState(0) and discarded the result;
# that only builds a separate generator object and seeds nothing, so it is gone.)
np.random.seed(0)
tf.compat.v1.set_random_seed(0)

# Create a TF1-style session whose GPU options have allow_growth enabled and
# register it as the session used by the Keras backend.
config = tf.compat.v1.ConfigProto(gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))
session = tf.compat.v1.Session(config=config)
K.set_session(session)

# Full path of the directory where keras_dcgan.py is saved. DCGAN.load_imgs
# looks for the sub-folder named by input_img_dir inside this directory.
root_dir = "/Users/user/Desktop/IMG/"
# Name of the sub-folder of root_dir that holds the training images.
input_img_dir = "icon_resize"
# Output folder (relative to the working directory) for sample grids and
# latent-interpolation strips; the trailing slash is required because file
# names are appended to it by plain string concatenation.
save_dir = "icon_dcgan_v3/"

# Make sure the output folder exists up front: fig.savefig raises and
# cv2.imwrite only returns False when the target directory is missing.
os.makedirs(save_dir, exist_ok=True)


In [3]:
class DCGAN():
    """DCGAN whose generator maps a 100-d latent vector to a 128x128 RGB image.

    The instance owns three Keras models:

    * ``discriminator`` - image -> probability that the image is real,
      compiled on its own with binary cross-entropy and an accuracy metric.
    * ``generator`` - latent vector -> image with values in [-1, 1] (tanh).
    * ``combined`` - generator followed by the discriminator; used to update
      the generator.

    Module-level ``root_dir``, ``input_img_dir`` and ``save_dir`` configure
    where training images are read from and where outputs are written.
    """

    def __init__(self):
        # Every entry of root_dir is treated as a class name; the one-hot
        # label width is therefore the number of entries in root_dir.
        self.class_names = os.listdir(root_dir)

        self.shape = (128, 128, 3)
        self.z_dim = 100

        optimizer = Adam(lr=0.0002, beta_1=0.5)

        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

        self.generator = self.build_generator()

        z = Input(shape=(self.z_dim,))
        img = self.generator(z)

        # The flag is flipped AFTER the stand-alone discriminator was compiled
        # and BEFORE the stacked model is compiled. The intent is that only the
        # stacked model sees the discriminator as frozen.
        # NOTE(review): relies on Keras capturing `trainable` at compile time -
        # confirm for the installed Keras version.
        self.discriminator.trainable = False

        valid = self.discriminator(img)

        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):
        """Build the generator: (z_dim,) noise -> (128, 128, 3) image in [-1, 1].

        A dense layer is reshaped to 32x32x128 and upsampled twice
        (32 -> 64 -> 128). The spatial size and channel count are fixed by the
        layer definitions here, not derived from ``self.shape``.

        Returns:
            A Keras ``Model`` wrapping the sequential stack.
        """
        noise_shape = (self.z_dim,)
        model = Sequential()

        model.add(Dense(128 * 32 * 32, activation="relu", input_shape=noise_shape))
        model.add(Reshape((32, 32, 128)))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(3, kernel_size=3, padding="same"))
        # tanh keeps pixels in [-1, 1], matching the scaling applied in train().
        model.add(Activation("tanh"))

        model.summary()

        noise = Input(shape=noise_shape)
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Build the discriminator: ``self.shape`` image -> sigmoid score.

        Three stride-2 convolutions followed by one stride-1 convolution, each
        with LeakyReLU(0.2) and Dropout(0.25), then a single sigmoid unit.

        Returns:
            An uncompiled Keras ``Model`` wrapping the sequential stack.
        """
        img_shape = self.shape
        model = Sequential()

        model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        # Pads one row at the bottom and one column on the right.
        model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model.add(Flatten())
        model.add(Dense(1, activation='sigmoid'))

        model.summary()

        img = Input(shape=img_shape)
        validity = model(img)

        return Model(img, validity)

    def build_combined(self):
        """Return an uncompiled ``Sequential([generator, discriminator])``.

        Not called by ``__init__`` (which builds ``self.combined`` with the
        functional API). Note the side effect: it sets
        ``self.discriminator.trainable = False``.
        """
        self.discriminator.trainable = False
        model = Sequential([self.generator, self.discriminator])

        return model

    def train(self, iterations, batch_size=128, save_interval=50, model_interval=10000, check_noise=None, r=5, c=5):
        """Alternate discriminator and generator updates for ``iterations`` steps.

        Args:
            iterations: Number of training steps (one D update pair + one G update each).
            batch_size: Generator batch size; the discriminator sees half of it
                as real images and half as generated images (at least 1 each).
            save_interval: Every this many iterations a sample grid and a
                latent-interpolation strip are written under ``save_dir``.
            model_interval: Every this many iterations the generator is saved
                to ``mb_dcgan-<iteration>-iter.h5`` in the working directory.
                Checked independently of ``save_interval``.
            check_noise: Array of shape (>= r * c, z_dim) reused for every
                snapshot so snapshots are comparable. When None, one is sampled
                from U(-1, 1) here.
            r: Rows of the sample grid.
            c: Columns of the sample grid.

        Raises:
            ValueError: If no training image could be loaded.
        """
        X_train, _ = self.load_imgs()

        # Fail early with a readable message instead of the opaque
        # "low >= high" error np.random.randint raises on an empty dataset.
        if X_train.shape[0] == 0:
            raise ValueError("No training images were loaded; check root_dir and input_img_dir.")

        # The declared default (None) used to crash inside save_imgs.
        if check_noise is None:
            check_noise = np.random.uniform(-1, 1, (r * c, self.z_dim))

        # savefig raises and cv2.imwrite silently returns False when the
        # output directory does not exist.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        # batch_size=1 would otherwise train the discriminator on empty batches.
        half_batch = max(batch_size // 2, 1)

        # Scale uint8 pixels [0, 255] to [-1, 1] to match the generator's tanh output.
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5

        for iteration in range(iterations):

            # ------------------
            # Training Discriminator
            # -----------------
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]

            noise = np.random.uniform(-1, 1, (half_batch, self.z_dim))
            gen_imgs = self.generator.predict(noise)

            # Real and generated images are fed as two separate batches.
            d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))

            # Each result is [loss, accuracy]; average them element-wise.
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # -----------------
            # Training Generator
            # -----------------
            noise = np.random.uniform(-1, 1, (batch_size, self.z_dim))

            # Generated images are labelled "real" so the generator is
            # rewarded for fooling the discriminator.
            g_loss = self.combined.train_on_batch(noise, np.ones((batch_size, 1)))

            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (iteration, d_loss[0], 100 * d_loss[1], g_loss))

            if iteration % save_interval == 0:
                self.save_imgs(iteration, check_noise, r, c)
                # Interpolate between the first two check vectors (needs >= 2).
                if len(check_noise) >= 2:
                    start = np.expand_dims(check_noise[0], axis=0)
                    end = np.expand_dims(check_noise[1], axis=0)
                    resultImage = self.visualizeInterpolation(start=start, end=end)
                    cv2.imwrite(save_dir + "latent_{}.png".format(iteration), resultImage)

            # Previously nested under the save_interval check, which meant a
            # checkpoint was only written when BOTH intervals divided the
            # iteration; model_interval is now honoured on its own.
            if iteration % model_interval == 0:
                self.generator.save("mb_dcgan-{}-iter.h5".format(iteration))

    def save_imgs(self, iteration, check_noise, r, c):
        """Render an ``r`` x ``c`` grid of generated images to ``save_dir``.

        Args:
            iteration: Used as the file name (``<iteration>.png``).
            check_noise: Latent vectors; the first ``r * c`` rows are drawn.
            r: Grid rows.
            c: Grid columns.
        """
        gen_imgs = self.generator.predict(check_noise)

        # Rescale from the generator's [-1, 1] range to [0, 1] for imshow.
        gen_imgs = 0.5 * gen_imgs + 0.5

        # squeeze=False keeps axs two-dimensional even when r == 1 or c == 1,
        # where plain subplots() returns a 1-D array or a bare Axes.
        fig, axs = plt.subplots(r, c, squeeze=False)
        # axs.flat walks the grid row by row, i.e. the same order as the
        # nested (row, column) loop it replaces.
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[cnt])
            ax.axis('off')

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_dir + '%d.png' % iteration)

        # Close this specific figure so figures do not pile up across snapshots.
        plt.close(fig)

    def load_imgs(self):
        """Load every readable image from ``root_dir/input_img_dir``.

        Entries that OpenCV cannot decode (hidden files, sub-folders, ...) are
        skipped, and images whose height/width differ from ``self.shape`` are
        resized so the result stacks into one array. File names are sorted so
        the image order does not depend on the order ``os.listdir`` returns.

        Returns:
            Tuple ``(images, labels)``: ``images`` holds RGB images as loaded
            by OpenCV, ``labels`` the matching one-hot class vectors. Both are
            NumPy arrays with one row per loaded image.
        """
        images = []
        labels = []
        skipped = 0
        target_h, target_w = self.shape[0], self.shape[1]

        for cl_name in self.class_names:
            # Only the configured input folder is used as training data.
            if cl_name != input_img_dir:
                continue

            class_dir = os.path.join(root_dir, cl_name)
            hot_cl_name = self.get_class_one_hot(cl_name)

            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.abspath(os.path.join(class_dir, img_name))
                img = cv2.imread(img_path)
                # imread signals failure by returning None instead of raising.
                if img is None:
                    skipped += 1
                    continue

                # OpenCV loads BGR; the rest of the pipeline works in RGB.
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                if img.shape[0] != target_h or img.shape[1] != target_w:
                    # cv2.resize takes the target size as (width, height).
                    img = cv2.resize(img, (target_w, target_h))

                # Append image and label together so they stay aligned even
                # when some files were skipped.
                images.append(img)
                labels.append(hot_cl_name)

        if skipped:
            print("Skipped %d unreadable file(s) while loading images" % skipped)

        return (np.array(images), np.array(labels))

    def get_class_one_hot(self, class_str):
        """Return the one-hot vector for ``class_str`` over ``self.class_names``.

        Raises:
            ValueError: If ``class_str`` is not in ``self.class_names``
                (raised by ``list.index``).
        """
        label_encoded = self.class_names.index(class_str)
        return np_utils.to_categorical(label_encoded, len(self.class_names))

    def visualizeInterpolation(self, start, end, save=True, nbSteps=10):
        """Generate images along the straight line from ``start`` to ``end``.

        Args:
            start: Latent vector with a leading batch axis of 1.
            end: Latent vector with the same shape as ``start``.
            save: Unused; kept so existing callers keep working.
            nbSteps: Number of evenly spaced interpolation points, end points
                included.

        Returns:
            A uint8 BGR image with the ``nbSteps`` generated frames placed
            side by side, or None when ``nbSteps`` is 0.
        """
        print("Generating interpolations...")

        frames = []
        for alpha in np.linspace(0, 1, nbSteps):
            vec = start * (1 - alpha) + end * alpha
            gen_img = np.squeeze(self.generator.predict(vec), axis=0)
            # Map [-1, 1] to [0, 255] before converting for cv2.imwrite (BGR, uint8).
            gen_img = (0.5 * gen_img + 0.5) * 255
            frame = cv2.cvtColor(gen_img, cv2.COLOR_RGB2BGR)
            frames.append(frame.astype(np.uint8))

        if not frames:
            return None
        return np.hstack(frames)


In [4]:
if __name__ == '__main__':
    dcgan = DCGAN()

    # Grid layout of the sample sheet written at every snapshot.
    r, c = 5, 5
    # Fixed latent vectors reused for every snapshot so the sheets can be
    # compared across iterations. The width comes from the model instead of a
    # hard-coded 100 so it cannot drift from DCGAN.z_dim.
    check_noise = np.random.uniform(-1, 1, (r * c, dcgan.z_dim))

    dcgan.train(
        iterations=1000,
        batch_size=8,
        save_interval=10,  # save generator sample images whenever the iteration is a multiple of 10
        model_interval=100,  # save the generator model whenever the iteration is a multiple of 100
        check_noise=check_noise,
        r=r,
        c=c
    )


Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 64, 64, 32)        896       
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 64, 64, 32)        0         
_________________________________________________________________
dropout (Dropout)            (None, 64, 64, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 32, 32, 64)        18496     
_________________________________________________________________
zero_padding2d (ZeroPadding2 (None, 33, 33, 64)        0         
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 33, 33, 64)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 33, 33, 64)        0

Generating interpolations...
61 [D loss: 0.012725, acc.: 100.00%] [G loss: 0.150450]
62 [D loss: 0.028491, acc.: 100.00%] [G loss: 0.201613]
63 [D loss: 0.010048, acc.: 100.00%] [G loss: 0.310643]
64 [D loss: 0.024513, acc.: 100.00%] [G loss: 0.163032]
65 [D loss: 0.000997, acc.: 100.00%] [G loss: 0.443419]
66 [D loss: 0.005283, acc.: 100.00%] [G loss: 0.268486]
67 [D loss: 0.008571, acc.: 100.00%] [G loss: 0.274589]
68 [D loss: 0.005664, acc.: 100.00%] [G loss: 0.172754]
69 [D loss: 0.015604, acc.: 100.00%] [G loss: 0.146555]
70 [D loss: 0.018694, acc.: 100.00%] [G loss: 0.526618]
Generating interpolations...
71 [D loss: 0.003451, acc.: 100.00%] [G loss: 0.379915]
72 [D loss: 0.022867, acc.: 100.00%] [G loss: 0.604242]
73 [D loss: 0.001319, acc.: 100.00%] [G loss: 0.357857]
74 [D loss: 0.004619, acc.: 100.00%] [G loss: 0.308098]
75 [D loss: 0.013851, acc.: 100.00%] [G loss: 0.277598]
76 [D loss: 0.001559, acc.: 100.00%] [G loss: 0.392924]
77 [D loss: 0.002923, acc.: 100.00%] [G loss: 

199 [D loss: 0.000147, acc.: 100.00%] [G loss: 0.025697]
200 [D loss: 0.000508, acc.: 100.00%] [G loss: 0.035377]
Generating interpolations...
201 [D loss: 0.000396, acc.: 100.00%] [G loss: 0.070394]
202 [D loss: 0.000196, acc.: 100.00%] [G loss: 0.048351]
203 [D loss: 0.000382, acc.: 100.00%] [G loss: 0.086231]
204 [D loss: 0.000707, acc.: 100.00%] [G loss: 0.127567]
205 [D loss: 0.000133, acc.: 100.00%] [G loss: 0.015563]
206 [D loss: 0.000560, acc.: 100.00%] [G loss: 0.053881]
207 [D loss: 0.000306, acc.: 100.00%] [G loss: 0.078410]
208 [D loss: 0.005455, acc.: 100.00%] [G loss: 0.093253]
209 [D loss: 0.000305, acc.: 100.00%] [G loss: 0.060911]
210 [D loss: 0.000345, acc.: 100.00%] [G loss: 0.012326]
Generating interpolations...
211 [D loss: 0.000349, acc.: 100.00%] [G loss: 0.010952]
212 [D loss: 0.000233, acc.: 100.00%] [G loss: 0.009987]
213 [D loss: 0.000219, acc.: 100.00%] [G loss: 0.015304]
214 [D loss: 0.000184, acc.: 100.00%] [G loss: 0.013159]
215 [D loss: 0.000167, acc.: 1

336 [D loss: 0.000040, acc.: 100.00%] [G loss: 0.007130]
337 [D loss: 0.000116, acc.: 100.00%] [G loss: 0.010353]
338 [D loss: 0.000088, acc.: 100.00%] [G loss: 0.003627]
339 [D loss: 0.000151, acc.: 100.00%] [G loss: 0.021026]
340 [D loss: 0.000095, acc.: 100.00%] [G loss: 0.009341]
Generating interpolations...
341 [D loss: 0.000235, acc.: 100.00%] [G loss: 0.002752]
342 [D loss: 0.000057, acc.: 100.00%] [G loss: 0.007987]
343 [D loss: 0.000012, acc.: 100.00%] [G loss: 0.004207]
344 [D loss: 0.000113, acc.: 100.00%] [G loss: 0.003067]
345 [D loss: 0.000084, acc.: 100.00%] [G loss: 0.002763]
346 [D loss: 0.000204, acc.: 100.00%] [G loss: 0.002971]
347 [D loss: 0.000046, acc.: 100.00%] [G loss: 0.004176]
348 [D loss: 0.000040, acc.: 100.00%] [G loss: 0.005137]
349 [D loss: 0.000124, acc.: 100.00%] [G loss: 0.004239]
350 [D loss: 0.000082, acc.: 100.00%] [G loss: 0.009404]
Generating interpolations...
351 [D loss: 0.000169, acc.: 100.00%] [G loss: 0.014995]
352 [D loss: 0.000160, acc.: 1

473 [D loss: 0.000028, acc.: 100.00%] [G loss: 0.019480]
474 [D loss: 0.000315, acc.: 100.00%] [G loss: 0.007196]
475 [D loss: 0.001592, acc.: 100.00%] [G loss: 0.098767]
476 [D loss: 0.000016, acc.: 100.00%] [G loss: 0.027015]
477 [D loss: 0.000024, acc.: 100.00%] [G loss: 0.004385]
478 [D loss: 0.000049, acc.: 100.00%] [G loss: 0.003402]
479 [D loss: 0.000036, acc.: 100.00%] [G loss: 0.004511]
480 [D loss: 0.000047, acc.: 100.00%] [G loss: 0.009543]
Generating interpolations...
481 [D loss: 0.000033, acc.: 100.00%] [G loss: 0.003508]
482 [D loss: 0.000027, acc.: 100.00%] [G loss: 0.001623]
483 [D loss: 0.000169, acc.: 100.00%] [G loss: 0.005642]
484 [D loss: 0.000054, acc.: 100.00%] [G loss: 0.009183]
485 [D loss: 0.000049, acc.: 100.00%] [G loss: 0.006498]
486 [D loss: 0.000059, acc.: 100.00%] [G loss: 0.038807]
487 [D loss: 0.000013, acc.: 100.00%] [G loss: 0.006871]
488 [D loss: 0.000061, acc.: 100.00%] [G loss: 0.009102]
489 [D loss: 0.000080, acc.: 100.00%] [G loss: 0.061971]
49

Generating interpolations...
611 [D loss: 0.000083, acc.: 100.00%] [G loss: 0.009937]
612 [D loss: 0.000021, acc.: 100.00%] [G loss: 0.010392]
613 [D loss: 0.000032, acc.: 100.00%] [G loss: 0.008813]
614 [D loss: 0.000028, acc.: 100.00%] [G loss: 0.001813]
615 [D loss: 0.000026, acc.: 100.00%] [G loss: 0.001184]
616 [D loss: 0.000078, acc.: 100.00%] [G loss: 0.011522]
617 [D loss: 0.000010, acc.: 100.00%] [G loss: 0.003045]
618 [D loss: 0.000108, acc.: 100.00%] [G loss: 0.002435]
619 [D loss: 0.000019, acc.: 100.00%] [G loss: 0.001548]
620 [D loss: 0.000052, acc.: 100.00%] [G loss: 0.002504]
Generating interpolations...
621 [D loss: 0.000082, acc.: 100.00%] [G loss: 0.016257]
622 [D loss: 0.000019, acc.: 100.00%] [G loss: 0.001175]
623 [D loss: 0.000030, acc.: 100.00%] [G loss: 0.003895]
624 [D loss: 0.000076, acc.: 100.00%] [G loss: 0.007521]
625 [D loss: 0.000075, acc.: 100.00%] [G loss: 0.004658]
626 [D loss: 0.000098, acc.: 100.00%] [G loss: 0.010347]
627 [D loss: 0.000015, acc.: 1

748 [D loss: 0.000054, acc.: 100.00%] [G loss: 0.002625]
749 [D loss: 0.000118, acc.: 100.00%] [G loss: 0.024595]
750 [D loss: 0.000032, acc.: 100.00%] [G loss: 0.003439]
Generating interpolations...
751 [D loss: 0.000007, acc.: 100.00%] [G loss: 0.012141]
752 [D loss: 0.000018, acc.: 100.00%] [G loss: 0.003557]
753 [D loss: 0.000025, acc.: 100.00%] [G loss: 0.003157]
754 [D loss: 0.000006, acc.: 100.00%] [G loss: 0.001329]
755 [D loss: 0.000033, acc.: 100.00%] [G loss: 0.002481]
756 [D loss: 0.000016, acc.: 100.00%] [G loss: 0.004417]
757 [D loss: 0.000015, acc.: 100.00%] [G loss: 0.003708]
758 [D loss: 0.000003, acc.: 100.00%] [G loss: 0.001067]
759 [D loss: 0.000011, acc.: 100.00%] [G loss: 0.001054]
760 [D loss: 0.000030, acc.: 100.00%] [G loss: 0.001454]
Generating interpolations...
761 [D loss: 0.000008, acc.: 100.00%] [G loss: 0.002406]
762 [D loss: 0.000007, acc.: 100.00%] [G loss: 0.008231]
763 [D loss: 0.000008, acc.: 100.00%] [G loss: 0.002501]
764 [D loss: 0.000029, acc.: 1

885 [D loss: 0.000018, acc.: 100.00%] [G loss: 0.001772]
886 [D loss: 0.000013, acc.: 100.00%] [G loss: 0.003783]
887 [D loss: 0.000070, acc.: 100.00%] [G loss: 0.004438]
888 [D loss: 0.000024, acc.: 100.00%] [G loss: 0.018611]
889 [D loss: 0.000003, acc.: 100.00%] [G loss: 0.003887]
890 [D loss: 0.000010, acc.: 100.00%] [G loss: 0.008594]
Generating interpolations...
891 [D loss: 0.000009, acc.: 100.00%] [G loss: 0.002890]
892 [D loss: 0.000009, acc.: 100.00%] [G loss: 0.002576]
893 [D loss: 0.000068, acc.: 100.00%] [G loss: 0.007496]
894 [D loss: 0.000007, acc.: 100.00%] [G loss: 0.004566]
895 [D loss: 0.000010, acc.: 100.00%] [G loss: 0.004003]
896 [D loss: 0.000004, acc.: 100.00%] [G loss: 0.002943]
897 [D loss: 0.000015, acc.: 100.00%] [G loss: 0.003341]
898 [D loss: 0.000010, acc.: 100.00%] [G loss: 0.002949]
899 [D loss: 0.000003, acc.: 100.00%] [G loss: 0.004385]
900 [D loss: 0.000049, acc.: 100.00%] [G loss: 0.008132]
Generating interpolations...
901 [D loss: 0.000061, acc.: 1