* Conditional GAN (CGAN)

In [None]:
# Mount Google Drive so the project directory under
# /content/drive/My Drive/system (used by the cells below) is reachable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Make modules in the project root importable, then run from the
# "myanswer" sub-directory so the relative '../pickle', '../figure' and
# '../result' paths used below resolve inside the project root.
import sys
import os

PROJECT_ROOT = '/content/drive/My Drive/system/'
sys.path.append(PROJECT_ROOT)
os.chdir(os.path.join(PROJECT_ROOT, 'myanswer'))

In [None]:
# Pin SciPy 1.1.0. Presumably this is needed because DataLoader calls
# scipy.misc.imread / scipy.misc.imresize, which newer SciPy releases no
# longer provide. TODO confirm before upgrading.
!pip install scipy==1.1.0



In [None]:
from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from glob import glob
import pickle
import datetime
import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.misc

In [None]:
class DataLoader():
    """Load the food-image dataset as NumPy arrays, with a pickle cache.

    Images are read from ``../figure/foodimg128/<food_name>/*.jpg`` (paths
    are relative to the current working directory), resized to ``img_res``
    and paired with an integer class label. The decoded lists are pickled
    under ``../pickle`` so later calls can skip the slow decode/resize step.
    """

    # Directory names under ../figure/foodimg128. The position of a name in
    # this tuple is the integer label given to every image in that directory.
    FOOD_NAMES = (
        "bibimba", "chahan", "chikenrice", "curry", "ebichill",
        "gratin", "gyudon", "hiyachu", "kaisendon", "katsudon",
        "meatspa", "omelet", "omurice", "oyakodon",
        "pilaf", "pizza", "ramen", "rice", "soba",
        "steak",
    )

    def __init__(self, num_classes=10, img_res=(128, 128)):
        """
        Args:
            num_classes: number of food classes to load. Labels are
                ``0 .. num_classes - 1``; values above ``len(FOOD_NAMES)``
                load every class.
            img_res: ``(rows, cols)`` passed to the resize of every image.
        """
        self.num_classes = num_classes
        self.img_res = img_res

    def _class_names(self):
        """Return the directory names for labels ``0 .. num_classes - 1``."""
        return list(self.FOOD_NAMES[:self.num_classes])

    def _cache_paths(self):
        """Return the ``(images, labels)`` pickle paths for this class count."""
        tensor_path = "../pickle/foodimg128_tensor_{}classes.pickle".format(self.num_classes)
        label_path = "../pickle/foodimg128_label_{}classes.pickle".format(self.num_classes)
        return tensor_path, label_path

    def load_data(self, is_testing=False):
        """Return ``(imgs, labels)`` as NumPy arrays, building the cache if needed.

        Args:
            is_testing: when False, each image is flipped left-right with
                probability 0.5 while the cache is being built. When True no
                flip is applied. Ignored when the cache already exists.

        Returns:
            ``np.array`` of the resized images and ``np.array`` of their
            integer labels, in the same order.

        NOTE(review): the cache file name contains only ``num_classes``. A
        cache built with a different ``img_res`` or ``is_testing`` is
        returned as-is, so delete the pickle files after changing either.
        """
        tensor_path, label_path = self._cache_paths()
        if os.path.exists(tensor_path) and os.path.exists(label_path):
            # NOTE: pickle.load can execute code embedded in the file; only
            # load caches written by this class.
            with open(tensor_path, 'rb') as p:
                imgs = pickle.load(p)
            with open(label_path, 'rb') as p:
                labels = pickle.load(p)
        else:
            imgs = []    # resized image arrays
            labels = []  # integer class label of each image

            # Only the first num_classes foods are loaded, so every label
            # fits an Embedding sized by num_classes.
            for label, food_name in enumerate(self._class_names()):
                # sorted() makes the image order independent of the file system.
                img_pathes = sorted(glob('../figure/foodimg128/{}/*.jpg'.format(food_name)))
                print("current food image {} labeled {}".format(food_name, label))
                for img_path in img_pathes:
                    each_img = scipy.misc.imresize(self.imread(img_path), self.img_res)
                    # Random horizontal flip as augmentation. It is drawn once
                    # here and then frozen into the cache.
                    if not is_testing and np.random.random() > 0.5:
                        each_img = np.fliplr(each_img)
                    imgs.append(each_img)
                    labels.append(label)
                print("finished generating dataset {} labeled {}".format(food_name, label))

            # open(..., 'wb') does not create missing parent directories.
            os.makedirs(os.path.dirname(tensor_path), exist_ok=True)
            with open(tensor_path, 'wb') as p:
                pickle.dump(imgs, p)
            with open(label_path, 'wb') as p:
                pickle.dump(labels, p)
            print("finished generating all of datasets")

        return np.array(imgs), np.array(labels)

    def imread(self, path):
        """Read the image at ``path`` in RGB mode and return it as float64.

        ``np.float64`` replaces the old ``np.float`` alias. That alias named
        the builtin float (i.e. float64) and has been removed from NumPy.
        """
        return scipy.misc.imread(path, mode='RGB').astype(np.float64)

In [None]:
class CGAN():
    """Conditional GAN built from fully-connected Keras models.

    Both networks are conditioned on an integer class label. The label goes
    through an ``Embedding`` layer and is multiplied element-wise into the
    network input: the noise vector for the generator, and the flattened
    image for the discriminator. Generated samples are written under
    ``../result/<dataset_name>/cgan``.
    """

    def __init__(self, dataset_name="mnist"):
        """Build and compile the discriminator, generator and combined model.

        Args:
            dataset_name: ``"mnist"`` (1 channel, 10 classes) or
                ``"foodimg"`` (3 channels, 20 classes, loaded through
                ``DataLoader``). Images are 28x28 in both cases.

        Raises:
            ValueError: if ``dataset_name`` is not one of the above.
        """
        # Input shape
        self.img_rows = 28
        self.img_cols = 28
        self.dataset_name = dataset_name
        self.channels, self.num_classes = self._dataset_config(dataset_name)
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        self.data_loader = DataLoader(num_classes=self.num_classes,
                                      img_res=(self.img_rows, self.img_cols))

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss=['binary_crossentropy'],
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise and the target label as input and
        # generates an image of that class. The label input is int32 to
        # match the label inputs of the generator and discriminator models.
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        img = self.generator([noise, label])

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator scores whether the generated image looks real
        # for the given label (the label is an input, not a prediction).
        valid = self.discriminator([img, label])

        # The combined model (stacked generator and discriminator) trains
        # the generator to fool the discriminator.
        self.combined = Model([noise, label], valid)
        self.combined.compile(loss=['binary_crossentropy'],
                              optimizer=optimizer)

    @staticmethod
    def _dataset_config(dataset_name):
        """Return ``(channels, num_classes)`` for a supported dataset name.

        Raises:
            ValueError: for an unsupported ``dataset_name``.
        """
        configs = {"mnist": (1, 10), "foodimg": (3, 20)}
        try:
            return configs[dataset_name]
        except KeyError:
            raise ValueError(
                "unsupported dataset_name {!r}; expected one of {}".format(
                    dataset_name, sorted(configs))) from None

    def build_generator(self):
        """Return a Keras model mapping ``[noise, label]`` to an image.

        ``noise`` has shape ``(latent_dim,)`` and ``label`` has shape
        ``(1,)`` (int32). The output has shape ``img_shape`` and, because of
        the final tanh, values in [-1, 1].
        """
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        # Embed the label into a latent_dim vector and gate the noise with it.
        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))

        model_input = multiply([noise, label_embedding])
        img = model(model_input)

        return Model([noise, label], img)

    def build_discriminator(self):
        """Return a Keras model mapping ``[img, label]`` to a validity score.

        The image is flattened and multiplied by a per-class embedding of the
        same length. An MLP then outputs one sigmoid probability that the
        image is real for that label.
        """
        model = Sequential()

        model.add(Dense(512, input_dim=np.prod(self.img_shape)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(1, activation='sigmoid'))

        img = Input(shape=self.img_shape)
        label = Input(shape=(1,), dtype='int32')

        label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))
        flat_img = Flatten()(img)

        model_input = multiply([flat_img, label_embedding])

        validity = model(model_input)

        return Model([img, label], validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Alternate discriminator and generator updates.

        Each "epoch" here is a single step on one batch drawn at random
        (with replacement) from the training set, not a full pass over it.

        Args:
            epochs: number of training steps.
            batch_size: images per discriminator / generator update.
            sample_interval: every this many steps, print the losses and
                save a grid of generated samples via ``sample_images``.
        """
        if self.dataset_name == "mnist":
            (X_train, y_train), (_, _) = mnist.load_data()
            # Add a trailing channel axis so the data matches img_shape.
            X_train = np.expand_dims(X_train, axis=3)
        else:
            X_train, y_train = self.data_loader.load_data()
        # Map pixel values 0..255 to -1..1, the range of the generator's tanh.
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        y_train = y_train.reshape(-1, 1)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of real images and their labels
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs, labels = X_train[idx], y_train[idx]

            # Sample noise as generator input
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of fake images for the same labels
            gen_imgs = self.generator.predict([noise, labels])

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
            d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Condition on freshly sampled labels
            sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)

            # Train the generator (through the frozen discriminator)
            g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)

            # If at save interval => print progress and save generated samples
            if epoch % sample_interval == 0:
                print("epoch %d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch+1, d_loss[0], 100*d_loss[1], g_loss))
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save one generated image per class to ``../result/<dataset>/cgan/epoch<epoch>.png``."""
        os.makedirs('../result/%s/cgan' % self.dataset_name, exist_ok=True)

        food_list = ("bibimba", "chahan", "chikenrice", "curry", "ebichill",
                     "gratin", "gyudon", "hiyachu", "kaisendon", "katsudon",
                     "meatspa", "omelet", "omurice", "oyakodon", "pilaf", "pizza",
                     "ramen", "rice", "soba", "steak")

        # One grid cell per class: 2x5 for the 10 MNIST digits, 4x5 for the
        # 20 foods, so r * c equals num_classes.
        r, c = (2, 5) if self.dataset_name == "mnist" else (4, 5)
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))

        sampled_labels = np.arange(0, self.num_classes).reshape(-1, 1)
        gen_imgs = self.generator.predict([noise, sampled_labels])
        # Rescale images from the tanh range [-1, 1] to [0, 1] for imshow
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c, figsize=(10, 10))
        cnt = 0
        for i in range(r):
            for j in range(c):
                if self.dataset_name == "mnist":
                    axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray')
                    # Index the scalar label; formatting a 1-element array
                    # with %d is deprecated in NumPy.
                    axs[i,j].set_title("digit : %d" % sampled_labels[cnt, 0])
                elif self.dataset_name == "foodimg":
                    axs[i,j].imshow(gen_imgs[cnt,:,:,:])
                    axs[i,j].set_title("{}".format(food_list[cnt]))
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("../result/{}/cgan/epoch{}.png".format(self.dataset_name, epoch),
                    transparent=True, dpi=300, bbox_inches="tight", pad_inches=0.0)
        plt.close(fig)

In [None]:
# Choose the training data: "foodimg" (20 food classes) or "mnist" (digits).
DATASET = "foodimg"

cgan = CGAN(dataset_name=DATASET)
cgan.train(epochs=500_000, batch_size=32, sample_interval=10_000)

epoch 1 [D loss: 0.693727, acc.: 31.25%] [G loss: 0.676793]
epoch 10001 [D loss: 0.672554, acc.: 46.88%] [G loss: 0.707860]
epoch 20001 [D loss: 0.717961, acc.: 45.31%] [G loss: 0.805275]
epoch 30001 [D loss: 0.692158, acc.: 46.88%] [G loss: 0.851080]
epoch 40001 [D loss: 0.631641, acc.: 62.50%] [G loss: 0.808809]
epoch 50001 [D loss: 0.654302, acc.: 59.38%] [G loss: 0.854447]
epoch 60001 [D loss: 0.614031, acc.: 62.50%] [G loss: 0.910921]
epoch 70001 [D loss: 0.654013, acc.: 53.12%] [G loss: 0.961875]
epoch 80001 [D loss: 0.504565, acc.: 81.25%] [G loss: 0.895734]
epoch 90001 [D loss: 0.638280, acc.: 64.06%] [G loss: 0.989370]
epoch 100001 [D loss: 0.628344, acc.: 60.94%] [G loss: 0.975713]
epoch 110001 [D loss: 0.517963, acc.: 76.56%] [G loss: 1.074201]
epoch 120001 [D loss: 0.620235, acc.: 60.94%] [G loss: 0.937835]
epoch 130001 [D loss: 0.599234, acc.: 64.06%] [G loss: 1.104290]
epoch 140001 [D loss: 0.524435, acc.: 73.44%] [G loss: 1.224781]
epoch 150001 [D loss: 0.630569, acc.: 6

KeyboardInterrupt: ignored