In [1]:
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import np_utils
import tensorflow as tf
from tensorflow.compat.v1.keras import backend as K

import matplotlib.pyplot as plt
import os
import cv2
import numpy as np


In [2]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
config = tf.compat.v1.ConfigProto(gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))
session = tf.compat.v1.Session(config=config)
K.set_session(session)

# Dataset root directory: every entry inside it is treated as a class folder
# (DCGAN.load_imgs reads images from root_dir/input_img_dir).
# Set the DCGAN_ROOT_DIR environment variable to avoid editing this hardcoded local path.
root_dir = os.environ.get("DCGAN_ROOT_DIR", "/Users/user/Desktop/m31_expt/m31_datasets/")
# Name of the image folder under root_dir used as training data.
input_img_dir = os.environ.get("DCGAN_INPUT_IMG_DIR", "human_w1_resize")
# Output folder (relative to the current working directory) for generated images.
save_dir = os.environ.get("DCGAN_SAVE_DIR", "dcgan_v3_human_w1_img/")


In [3]:
class DCGAN():
    """DCGAN that learns to generate 128x128 RGB images from a single image folder.

    Reads module-level configuration: ``root_dir`` (dataset root whose entries
    are treated as class folders), ``input_img_dir`` (the folder under
    ``root_dir`` that is trained on) and ``save_dir`` (where sample grids and
    latent-interpolation strips are written).
    """

    def __init__(self):
        """Build and compile the discriminator, the generator and the stacked model."""
        # Sorted so each folder's one-hot index is stable across runs/machines
        # (os.listdir returns entries in arbitrary order).
        self.class_names = sorted(os.listdir(root_dir))

        self.shape = (128, 128, 3)
        self.z_dim = 100

        # One optimizer per compiled model: sharing a single instance across two
        # models that update different variable sets is rejected by newer Keras
        # optimizers and mixes their step counters in older ones.
        d_optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
        g_optimizer = Adam(learning_rate=0.0002, beta_1=0.5)

        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=d_optimizer, metrics=['accuracy'])

        self.generator = self.build_generator()

        # Stacked model z -> G(z) -> D(G(z)), used only to train the generator.
        z = Input(shape=(self.z_dim,))
        img = self.generator(z)

        # Freeze D inside the stacked model. In Keras 2.x the trainable flag is
        # captured at compile() time, so D's own (already compiled) model keeps
        # updating when trained directly.
        self.discriminator.trainable = False

        valid = self.discriminator(img)

        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=g_optimizer)

    def build_generator(self):
        """Return a Model mapping a (z_dim,) noise vector to a 128x128x3 tanh image.

        Dense -> (32, 32, 128) -> two 2x upsampling conv stages -> 3-channel conv.
        """
        noise_shape = (self.z_dim,)
        model = Sequential()

        model.add(Dense(128 * 32 * 32, activation="relu", input_shape=noise_shape))
        model.add(Reshape((32, 32, 128)))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())  # 32 -> 64
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())  # 64 -> 128
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(3, kernel_size=3, padding="same"))
        model.add(Activation("tanh"))  # outputs in [-1, 1], matching the scaled training data

        model.summary()

        noise = Input(shape=noise_shape)
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Return a Model mapping a ``self.shape`` image to a sigmoid real/fake score."""
        img_shape = self.shape
        model = Sequential()

        model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=img_shape, padding="same"))  # -> 64x64
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))  # -> 32x32
        model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))  # -> 33x33
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model.add(Flatten())
        model.add(Dense(1, activation='sigmoid'))

        model.summary()

        img = Input(shape=img_shape)
        validity = model(img)

        return Model(img, validity)

    def build_combined(self):
        """Return an uncompiled Sequential(generator, frozen discriminator).

        Not used by ``__init__`` (which builds ``self.combined`` itself); the
        caller must compile the returned model before training it.
        """
        self.discriminator.trainable = False
        model = Sequential([self.generator, self.discriminator])

        return model

    def train(self, iterations, batch_size=128, save_interval=50, model_interval=10000, check_noise=None, r=5, c=5):
        """Alternately train the discriminator and the generator.

        Args:
            iterations: number of training steps.
            batch_size: generator batch size; the discriminator sees
                ``batch_size // 2`` real and as many fake images per step.
            save_interval: write a sample grid and an interpolation strip
                every this many iterations.
            model_interval: save the generator to an .h5 file every this many
                iterations (independent of ``save_interval``).
            check_noise: fixed latent vectors of shape (>= r*c, z_dim) used for
                the sample grids; random ones are drawn when None.
            r, c: rows/columns of the saved sample grid.
        """
        X_train, _labels = self.load_imgs()

        half_batch = batch_size // 2

        # uint8 [0, 255] -> float [-1, 1] to match the generator's tanh output.
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5

        if check_noise is None:
            # Fixed for the whole run so successive sample grids are comparable.
            check_noise = np.random.uniform(-1, 1, (r * c, self.z_dim))

        if save_dir:
            # fig.savefig raises and cv2.imwrite silently fails on a missing directory.
            os.makedirs(save_dir, exist_ok=True)

        for iteration in range(iterations):

            # ------------------
            # Training Discriminator
            # -----------------
            idx = np.random.randint(0, X_train.shape[0], half_batch)

            imgs = X_train[idx]

            noise = np.random.uniform(-1, 1, (half_batch, self.z_dim))

            gen_imgs = self.generator.predict(noise)

            d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))

            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # -----------------
            # Training Generator
            # -----------------

            noise = np.random.uniform(-1, 1, (batch_size, self.z_dim))

            # Label fakes as real so the generator learns to fool the frozen D.
            g_loss = self.combined.train_on_batch(noise, np.ones((batch_size, 1)))

            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (iteration, d_loss[0], 100 * d_loss[1], g_loss))

            if iteration % save_interval == 0:
                self.save_imgs(iteration, check_noise, r, c)
                start = np.expand_dims(check_noise[0], axis=0)
                end = np.expand_dims(check_noise[1], axis=0)
                resultImage = self.visualizeInterpolation(start=start, end=end)
                cv2.imwrite(os.path.join(save_dir, "latent_{}.png".format(iteration)), resultImage)

            # Checked on its own: previously nested under save_interval, so a
            # model_interval not divisible by save_interval almost never saved.
            if iteration % model_interval == 0:
                self.generator.save("mb_dcgan-{}-iter.h5".format(iteration))

    def save_imgs(self, iteration, check_noise, r, c):
        """Save an r x c grid of images generated from ``check_noise`` to save_dir/<iteration>.png."""
        gen_imgs = self.generator.predict(check_noise)

        # tanh [-1, 1] -> [0, 1] for imshow; clip guards against float noise.
        gen_imgs = np.clip(0.5 * gen_imgs + 0.5, 0.0, 1.0)

        # squeeze=False keeps axs 2-D even when r or c is 1.
        fig, axs = plt.subplots(r, c, squeeze=False)
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[cnt, :, :, :])
            ax.axis('off')
        fig.savefig(os.path.join(save_dir, '%d.png' % iteration))

        plt.close(fig)

    def load_imgs(self):
        """Load every readable image in root_dir/input_img_dir as RGB.

        Non-image files (e.g. ``.DS_Store``) are skipped; images whose size
        differs from ``self.shape`` are resized to it.

        Returns:
            (images, labels): uint8 array (N, H, W, 3) and the matching one-hot
            labels (N, len(self.class_names)).

        Raises:
            ValueError: if no readable image is found.
        """
        labels = []
        images = []

        if input_img_dir in self.class_names:
            img_dir = os.path.join(root_dir, input_img_dir)
            hot_cl_name = self.get_class_one_hot(input_img_dir)
            target_h, target_w = self.shape[0], self.shape[1]

            for img_name in sorted(os.listdir(img_dir)):
                img_path = os.path.abspath(os.path.join(img_dir, img_name))
                img = cv2.imread(img_path)
                if img is None:
                    # cv2.imread returns None for unreadable / non-image files.
                    continue
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                if img.shape[:2] != (target_h, target_w):
                    img = cv2.resize(img, (target_w, target_h), interpolation=cv2.INTER_AREA)
                images.append(img)
                # Appended together with the image so labels stay aligned.
                labels.append(hot_cl_name)

        if not images:
            raise ValueError(
                "no readable images found in {}".format(os.path.join(root_dir, input_img_dir)))

        return (np.array(images), np.array(labels))

    def get_class_one_hot(self, class_str):
        """Return the one-hot vector for ``class_str`` over ``self.class_names``."""
        label_encoded = self.class_names.index(class_str)

        return np_utils.to_categorical(label_encoded, len(self.class_names))

    def visualizeInterpolation(self, start, end, save=True, nbSteps=10):
        """Generate images along the straight line between two latent vectors.

        Args:
            start, end: latent vectors, shape (1, z_dim) or (z_dim,).
            save: unused; kept for interface compatibility.
            nbSteps: number of evenly spaced points, endpoints included.

        Returns:
            uint8 BGR image (H, nbSteps * W, 3) with the generated images
            placed side by side (ready for cv2.imwrite), or None if nbSteps is 0.
        """
        print("Generating interpolations...")

        # Column of blend weights broadcast against the endpoints -> (nbSteps, z_dim).
        alphas = np.linspace(0, 1, nbSteps).reshape(-1, 1)
        if alphas.shape[0] == 0:
            return None
        latents = start * (1 - alphas) + end * alphas

        # One batched call instead of one predict per step.
        gen_imgs = self.generator.predict(latents)
        gen_imgs = np.clip((0.5 * gen_imgs + 0.5) * 255, 0, 255).astype(np.uint8)

        # RGB -> BGR (what cv2.imwrite expects) by reversing the channel axis.
        gen_imgs = gen_imgs[..., ::-1]

        return np.hstack(list(gen_imgs))



In [4]:
if __name__ == '__main__':
    dcgan = DCGAN()
    r, c = 5, 5
    # Fixed latent vectors reused at every checkpoint so the saved sample grids
    # are comparable across training; sized from the model's latent dimension.
    check_noise = np.random.uniform(-1, 1, (r * c, dcgan.z_dim))
    dcgan.train(
        iterations=200000,
        batch_size=32,
        save_interval=50,  # save generator samples every 50 iterations
        model_interval=5000,
        check_noise=check_noise,
        r=r,
        c=c
    )


Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 64, 64, 32)        896       
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 64, 64, 32)        0         
_________________________________________________________________
dropout (Dropout)            (None, 64, 64, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 32, 32, 64)        18496     
_________________________________________________________________
zero_padding2d (ZeroPadding2 (None, 33, 33, 64)        0         
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 33, 33, 64)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 33, 33, 64)        0

63 [D loss: 0.017330, acc.: 100.00%] [G loss: 0.764356]
64 [D loss: 0.015066, acc.: 100.00%] [G loss: 0.798455]
65 [D loss: 0.039729, acc.: 100.00%] [G loss: 0.485776]
66 [D loss: 0.027144, acc.: 100.00%] [G loss: 0.483694]
67 [D loss: 0.043833, acc.: 96.88%] [G loss: 0.915500]
68 [D loss: 0.009299, acc.: 100.00%] [G loss: 1.804410]
69 [D loss: 0.062964, acc.: 100.00%] [G loss: 0.341795]
70 [D loss: 0.018482, acc.: 100.00%] [G loss: 0.524709]
71 [D loss: 0.017857, acc.: 100.00%] [G loss: 0.494170]
72 [D loss: 0.002797, acc.: 100.00%] [G loss: 0.900965]
73 [D loss: 0.006360, acc.: 100.00%] [G loss: 0.696527]
74 [D loss: 0.028181, acc.: 100.00%] [G loss: 0.320790]
75 [D loss: 0.001504, acc.: 100.00%] [G loss: 0.290117]
76 [D loss: 0.009201, acc.: 100.00%] [G loss: 0.201150]
77 [D loss: 0.007366, acc.: 100.00%] [G loss: 0.341701]
78 [D loss: 0.011624, acc.: 100.00%] [G loss: 0.221798]
79 [D loss: 0.008285, acc.: 100.00%] [G loss: 0.460667]
80 [D loss: 0.021151, acc.: 100.00%] [G loss: 0.1

207 [D loss: 0.001701, acc.: 100.00%] [G loss: 0.026423]
208 [D loss: 0.000996, acc.: 100.00%] [G loss: 0.073286]
209 [D loss: 0.006339, acc.: 100.00%] [G loss: 0.038605]
210 [D loss: 0.008079, acc.: 100.00%] [G loss: 0.064572]
211 [D loss: 0.000256, acc.: 100.00%] [G loss: 0.108567]
212 [D loss: 0.000872, acc.: 100.00%] [G loss: 0.044599]
213 [D loss: 0.000320, acc.: 100.00%] [G loss: 0.084168]
214 [D loss: 0.000324, acc.: 100.00%] [G loss: 0.015773]
215 [D loss: 0.000243, acc.: 100.00%] [G loss: 0.031833]
216 [D loss: 0.002416, acc.: 100.00%] [G loss: 0.045668]
217 [D loss: 0.000204, acc.: 100.00%] [G loss: 0.045484]
218 [D loss: 0.000644, acc.: 100.00%] [G loss: 0.045510]
219 [D loss: 0.001267, acc.: 100.00%] [G loss: 0.024425]
220 [D loss: 0.003565, acc.: 100.00%] [G loss: 0.092482]
221 [D loss: 0.000746, acc.: 100.00%] [G loss: 0.100813]
222 [D loss: 0.000813, acc.: 100.00%] [G loss: 0.035519]
223 [D loss: 0.000381, acc.: 100.00%] [G loss: 0.043020]
224 [D loss: 0.001048, acc.: 10

350 [D loss: 0.000295, acc.: 100.00%] [G loss: 0.008667]
Generating interpolations...
351 [D loss: 0.000417, acc.: 100.00%] [G loss: 0.004841]
352 [D loss: 0.000226, acc.: 100.00%] [G loss: 0.012774]
353 [D loss: 0.000214, acc.: 100.00%] [G loss: 0.002594]
354 [D loss: 0.000587, acc.: 100.00%] [G loss: 0.018602]
355 [D loss: 0.000194, acc.: 100.00%] [G loss: 0.015714]
356 [D loss: 0.000681, acc.: 100.00%] [G loss: 0.006267]
357 [D loss: 0.000308, acc.: 100.00%] [G loss: 0.006781]
358 [D loss: 0.000266, acc.: 100.00%] [G loss: 0.011611]
359 [D loss: 0.000569, acc.: 100.00%] [G loss: 0.019251]
360 [D loss: 0.000191, acc.: 100.00%] [G loss: 0.005472]
361 [D loss: 0.000181, acc.: 100.00%] [G loss: 0.018310]
362 [D loss: 0.000438, acc.: 100.00%] [G loss: 0.012773]
363 [D loss: 0.000401, acc.: 100.00%] [G loss: 0.011930]
364 [D loss: 0.000578, acc.: 100.00%] [G loss: 0.008992]
365 [D loss: 0.000563, acc.: 100.00%] [G loss: 0.010943]
366 [D loss: 0.000681, acc.: 100.00%] [G loss: 0.013031]
36

493 [D loss: 0.000667, acc.: 100.00%] [G loss: 0.014207]
494 [D loss: 0.000067, acc.: 100.00%] [G loss: 0.006085]
495 [D loss: 0.000104, acc.: 100.00%] [G loss: 0.003195]
496 [D loss: 0.000777, acc.: 100.00%] [G loss: 0.013100]
497 [D loss: 0.000215, acc.: 100.00%] [G loss: 0.013206]
498 [D loss: 0.000034, acc.: 100.00%] [G loss: 0.012449]
499 [D loss: 0.000093, acc.: 100.00%] [G loss: 0.014420]
500 [D loss: 0.000127, acc.: 100.00%] [G loss: 0.007163]
Generating interpolations...
501 [D loss: 0.000119, acc.: 100.00%] [G loss: 0.006650]
502 [D loss: 0.000200, acc.: 100.00%] [G loss: 0.003754]
503 [D loss: 0.000078, acc.: 100.00%] [G loss: 0.011326]
504 [D loss: 0.000138, acc.: 100.00%] [G loss: 0.008070]
505 [D loss: 0.000716, acc.: 100.00%] [G loss: 0.008656]
506 [D loss: 0.000110, acc.: 100.00%] [G loss: 0.007202]
507 [D loss: 0.000062, acc.: 100.00%] [G loss: 0.013572]
508 [D loss: 0.000181, acc.: 100.00%] [G loss: 0.006720]
509 [D loss: 0.000070, acc.: 100.00%] [G loss: 0.020189]
51

636 [D loss: 0.000115, acc.: 100.00%] [G loss: 0.001971]
637 [D loss: 0.000110, acc.: 100.00%] [G loss: 0.002198]
638 [D loss: 0.000351, acc.: 100.00%] [G loss: 0.008999]
639 [D loss: 0.000228, acc.: 100.00%] [G loss: 0.006320]
640 [D loss: 0.000049, acc.: 100.00%] [G loss: 0.007282]
641 [D loss: 0.000018, acc.: 100.00%] [G loss: 0.008530]
642 [D loss: 0.000091, acc.: 100.00%] [G loss: 0.003861]
643 [D loss: 0.000041, acc.: 100.00%] [G loss: 0.005135]
644 [D loss: 0.000068, acc.: 100.00%] [G loss: 0.005642]
645 [D loss: 0.000126, acc.: 100.00%] [G loss: 0.013575]
646 [D loss: 0.000041, acc.: 100.00%] [G loss: 0.005654]
647 [D loss: 0.000021, acc.: 100.00%] [G loss: 0.006136]
648 [D loss: 0.000104, acc.: 100.00%] [G loss: 0.017105]
649 [D loss: 0.000035, acc.: 100.00%] [G loss: 0.011238]
650 [D loss: 0.000052, acc.: 100.00%] [G loss: 0.010027]
Generating interpolations...
651 [D loss: 0.000140, acc.: 100.00%] [G loss: 0.012990]
652 [D loss: 0.000036, acc.: 100.00%] [G loss: 0.006859]
65

779 [D loss: 0.000057, acc.: 100.00%] [G loss: 0.005312]
780 [D loss: 0.000406, acc.: 100.00%] [G loss: 0.015357]
781 [D loss: 0.000025, acc.: 100.00%] [G loss: 0.007568]
782 [D loss: 0.000084, acc.: 100.00%] [G loss: 0.006563]
783 [D loss: 0.000045, acc.: 100.00%] [G loss: 0.010393]
784 [D loss: 0.000050, acc.: 100.00%] [G loss: 0.009012]
785 [D loss: 0.000040, acc.: 100.00%] [G loss: 0.004913]
786 [D loss: 0.000053, acc.: 100.00%] [G loss: 0.007018]
787 [D loss: 0.000041, acc.: 100.00%] [G loss: 0.005495]
788 [D loss: 0.000022, acc.: 100.00%] [G loss: 0.001459]
789 [D loss: 0.000029, acc.: 100.00%] [G loss: 0.004276]
790 [D loss: 0.000042, acc.: 100.00%] [G loss: 0.007789]
791 [D loss: 0.000171, acc.: 100.00%] [G loss: 0.004156]
792 [D loss: 0.000006, acc.: 100.00%] [G loss: 0.004664]
793 [D loss: 0.000016, acc.: 100.00%] [G loss: 0.005124]
794 [D loss: 0.000039, acc.: 100.00%] [G loss: 0.006694]
795 [D loss: 0.000028, acc.: 100.00%] [G loss: 0.004346]
796 [D loss: 0.000017, acc.: 10

922 [D loss: 0.000111, acc.: 100.00%] [G loss: 0.004447]
923 [D loss: 0.000005, acc.: 100.00%] [G loss: 0.006520]
924 [D loss: 0.000133, acc.: 100.00%] [G loss: 0.006886]
925 [D loss: 0.000025, acc.: 100.00%] [G loss: 0.001367]
926 [D loss: 0.000574, acc.: 100.00%] [G loss: 0.004090]
927 [D loss: 0.000020, acc.: 100.00%] [G loss: 0.008319]
928 [D loss: 0.000024, acc.: 100.00%] [G loss: 0.002718]
929 [D loss: 0.000035, acc.: 100.00%] [G loss: 0.004509]
930 [D loss: 0.000066, acc.: 100.00%] [G loss: 0.005441]
931 [D loss: 0.000032, acc.: 100.00%] [G loss: 0.003493]
932 [D loss: 0.000055, acc.: 100.00%] [G loss: 0.002659]
933 [D loss: 0.000059, acc.: 100.00%] [G loss: 0.002755]
934 [D loss: 0.000023, acc.: 100.00%] [G loss: 0.035641]
935 [D loss: 0.000083, acc.: 100.00%] [G loss: 0.002792]
936 [D loss: 0.000023, acc.: 100.00%] [G loss: 0.008220]
937 [D loss: 0.000038, acc.: 100.00%] [G loss: 0.001520]
938 [D loss: 0.000016, acc.: 100.00%] [G loss: 0.002085]
939 [D loss: 0.000031, acc.: 10

KeyboardInterrupt: 