* WGANを使ってグラタンに似た偽画像を生成する

In [1]:
# Mount Google Drive at /content/drive so later cells can reach files stored there.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Make the project directory on Google Drive importable, then switch the working
# directory to its `myanswer` sub-directory so that the relative paths used by
# later cells (e.g. ../figure, ../pickle, ../result) resolve from there.
import os
import sys

sys.path.append('/content/drive/My Drive/system/')
os.chdir('/content/drive/My Drive/system/myanswer')

In [3]:
# NOTE(review): SciPy is pinned to 1.1.0, presumably because DataLoader calls
# scipy.misc.imread / scipy.misc.imresize, which newer SciPy releases may no longer
# provide — confirm. The pip output below reports that this pin conflicts with the
# scipy requirement of tensorflow 2.3.0 (scipy==1.4.1).
!pip install scipy==1.1.0

Collecting scipy==1.1.0
[?25l  Downloading https://files.pythonhosted.org/packages/a8/0b/f163da98d3a01b3e0ef1cab8dd2123c34aee2bafbb1c5bffa354cc8a1730/scipy-1.1.0-cp36-cp36m-manylinux1_x86_64.whl (31.2MB)
[K     |████████████████████████████████| 31.2MB 108kB/s 
[31mERROR: umap-learn 0.4.6 has requirement scipy>=1.3.1, but you'll have scipy 1.1.0 which is incompatible.[0m
[31mERROR: tensorflow 2.3.0 has requirement scipy==1.4.1, but you'll have scipy 1.1.0 which is incompatible.[0m
[31mERROR: plotnine 0.6.0 has requirement scipy>=1.2.0, but you'll have scipy 1.1.0 which is incompatible.[0m
[31mERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.[0m
Installing collected packages: scipy
  Found existing installation: scipy 1.4.1
    Uninstalling scipy-1.4.1:
      Successfully uninstalled scipy-1.4.1
Successfully installed scipy-1.1.0


In [4]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import RMSprop
from functools import partial
from glob import glob
import keras.backend as K
import tensorflow as tf
import scipy
import scipy.misc
import matplotlib.pyplot as plt
import numpy as np
import pickle

In [5]:
class DataLoader():
    """Load, resize and cache the JPEG images of one dataset.

    Images are read from ``../figure/foodimg128/<dataset_name>/*.jpg`` and the
    processed list is cached in ``../pickle/<dataset_name>_tensor.pickle``.
    Both paths are relative to the current working directory.
    """

    def __init__(self, dataset_name, img_res=(128, 128)):
        """Remember which dataset to load and the size images are resized to.

        Args:
            dataset_name: Name of the image sub-directory and prefix of the
                cache file.
            img_res: Value passed as the ``size`` argument of
                ``scipy.misc.imresize``.
        """
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, is_testing=False):
        """Return all images of the dataset as a single NumPy array.

        If the cache file exists it is loaded and returned as-is. In that case
        ``is_testing`` has no effect: whatever flips were applied when the
        cache was created are replayed unchanged.

        Otherwise every ``*.jpg`` of the dataset is read, resized to
        ``self.img_res`` and, unless ``is_testing`` is true, mirrored
        left-right with probability 0.5. The result is then written to the
        cache file.

        Args:
            is_testing: When true, skip the random horizontal flip.

        Returns:
            ``np.ndarray`` built from the list of processed images.

        Raises:
            FileNotFoundError: If there is no cache and no image matches the
                dataset's glob pattern. Nothing is cached in that case, so an
                empty dataset can never be persisted.
        """
        cache_path = "../pickle/{}_tensor.pickle".format(self.dataset_name)

        if os.path.exists(cache_path):
            # SECURITY: pickle.load can execute arbitrary code. Only use cache
            # files that this notebook itself produced.
            with open(cache_path, 'rb') as p:
                imgs = pickle.load(p)
            return np.array(imgs)

        pattern = '../figure/foodimg128/%s/*.jpg' % (self.dataset_name)
        img_pathes = glob(pattern)
        if not img_pathes:
            # Fail early instead of caching an empty list, which would make
            # every later run silently load an empty dataset.
            raise FileNotFoundError(
                "No images found for dataset '{}' (pattern: {})".format(
                    self.dataset_name, pattern))

        imgs = []
        for img_path in img_pathes:
            img = scipy.misc.imresize(self.imread(img_path), self.img_res)
            # Random draw only happens in training mode (short-circuit), so the
            # global NumPy RNG is consumed exactly as before.
            if not is_testing and np.random.random() > 0.5:
                img = np.fliplr(img)
            imgs.append(img)

        # Make sure the cache directory exists before writing to it.
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        with open(cache_path, 'wb') as p:
            pickle.dump(imgs, p)

        return np.array(imgs)

    def imread(self, path):
        """Read the image at ``path`` in RGB mode and return it as a float array."""
        # `float` is what the removed alias `np.float` referred to.
        return scipy.misc.imread(path, mode="RGB").astype(float)

In [6]:
class WGAN():
    """Wasserstein GAN trained with critic weight clipping.

    The critic and the generator are Keras models working on images of shape
    ``(28, 28, channels)``; ``channels`` is 1 for the ``"mnist"`` dataset and 3
    for any other dataset name (loaded through ``DataLoader``).
    """

    def __init__(self, dataset_name="mnist"):
        """Build the critic, the generator and the combined (generator -> critic) model.

        Args:
            dataset_name: ``"mnist"`` to train on Keras' MNIST digits, otherwise
                the name of an image dataset handled by ``DataLoader``.
        """
        self.img_rows = 28
        self.img_cols = 28
        # Name of the dataset whose images the generator should imitate
        self.dataset_name = dataset_name
        if self.dataset_name == "mnist":
            self.channels = 1
        else:
            self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.img_rows, self.img_cols))

        # Following parameter and optimizer set as recommended in paper
        self.n_critic = 5
        self.clip_value = 0.01
        optimizer = RMSprop(lr=0.00005)

        # Build and compile the critic.
        # A metric is requested so that train_on_batch returns a sequence;
        # train() indexes element [0] of it to get the loss.
        self.critic = self.build_critic()
        self.critic.compile(loss=self.wasserstein_loss,
                            optimizer=optimizer,
                            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generated imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.critic.trainable = False

        # The critic takes generated images as input and determines validity
        valid = self.critic(img)

        # The combined model  (stacked generator and critic)
        self.combined = Model(z, valid)
        self.combined.compile(loss=self.wasserstein_loss,
                              optimizer=optimizer,
                              metrics=['accuracy'])

    def wasserstein_loss(self, y_true, y_pred):
        """Return the mean of ``y_true * y_pred``.

        train() uses -1 as the label for real images and +1 for generated ones.
        """
        return K.mean(y_true * y_pred)

    def build_generator(self):
        """Build the generator: latent vector -> image with tanh output.

        A dense layer is reshaped to 7x7x128, followed by two
        UpSampling2D + Conv2D + BatchNormalization + ReLU stages and a final
        Conv2D with ``self.channels`` filters and tanh activation.

        Returns:
            A Keras ``Model`` mapping a ``(latent_dim,)`` input to an image.
        """
        model = Sequential()
        model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((7, 7, 128)))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=4, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=4, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
        model.add(Activation("tanh"))

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_critic(self):
        """Build the critic: image -> single unbounded score.

        Four Conv2D stages (16, 32, 64, 128 filters) with LeakyReLU and
        Dropout, then Flatten and a ``Dense(1)`` without activation.

        Returns:
            A Keras ``Model`` mapping an ``img_shape`` input to one score.
        """
        model = Sequential()
        model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
        model.add(ZeroPadding2D(padding=((0,1),(0,1))))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Flatten())
        model.add(Dense(1))

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def _clip_critic_weights(self):
        """Clip every weight array of every critic layer to [-clip_value, clip_value]."""
        # NOTE(review): get_weights() returns all weights of a layer; this
        # presumably includes BatchNormalization statistics as well as
        # kernels/biases — confirm that clipping those is intended.
        for l in self.critic.layers:
            weights = l.get_weights()
            weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
            l.set_weights(weights)

    def train(self, epochs, batch_size, sample_interval=50):
        """Run the WGAN training loop.

        Each iteration ("epoch") performs ``self.n_critic`` critic updates
        (one real batch and one generated batch each, followed by weight
        clipping) and then one generator update through the combined model.

        Args:
            epochs: Number of generator updates to perform.
            batch_size: Number of images per critic/generator batch.
            sample_interval: Every this many iterations (starting at 0) the
                losses are printed and a grid of samples is saved.

        Raises:
            ValueError: If the loaded training set contains no images.
        """
        if self.dataset_name == "mnist":
            (X_train, _), (_, _) = mnist.load_data()
            # Rescale -1 to 1
            X_train = (X_train.astype(np.float32) - 127.5) / 127.5
            X_train = np.expand_dims(X_train, axis=3)
        else:
            X_train = self.data_loader.load_data()
            # Rescale -1 to 1
            X_train = (X_train.astype(np.float32) - 127.5) / 127.5

        # Without this check an empty dataset only surfaces as an obscure
        # error from np.random.randint inside the loop.
        if len(X_train) == 0:
            raise ValueError(
                "No training images were loaded for dataset '%s'" % self.dataset_name)

        # Adversarial ground truths
        valid = -np.ones((batch_size, 1))
        fake =  np.ones((batch_size, 1))

        for epoch in range(epochs):
            for _ in range(self.n_critic):
                # ---------------------
                #  Train Discriminator
                # ---------------------

                # Select a random batch of images
                idx = np.random.randint(0, X_train.shape[0], batch_size)
                imgs = X_train[idx]

                # Sample noise as generator input
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

                # Generate a batch of new images
                gen_imgs = self.generator.predict(noise)

                # Train the critic
                d_loss_real = self.critic.train_on_batch(imgs, valid)
                d_loss_fake = self.critic.train_on_batch(gen_imgs, fake)
                d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

                # Clip critic weights
                self._clip_critic_weights()

            # ---------------------
            #  Train Generator
            # ---------------------
            # The noise batch of the last critic step is reused here.
            g_loss = self.combined.train_on_batch(noise, valid)

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                # Plot the progress.
                # NOTE(review): the printed values are 1 - loss, not the raw
                # Wasserstein losses — confirm this is the intended display.
                print("epoch %d [D loss: %f] [G loss: %f]" % (epoch+1, 1 - d_loss[0], 1 - g_loss[0]))
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save a 5x5 grid of generated images to ``../result/<dataset>/wgan/epoch<epoch>.png``.

        Args:
            epoch: Iteration index used in the output file name.
        """
        os.makedirs('../result/%s/wgan' % self.dataset_name, exist_ok=True)
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                if self.dataset_name == "mnist":
                    axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray')
                else:
                    axs[i,j].imshow(gen_imgs[cnt,:,:,:])
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("../result/{}/wgan/epoch{}.png".format(self.dataset_name, epoch),
                    transparent=True, dpi=300, bbox_inches="tight", pad_inches=0.0)
        # Close this specific figure rather than whatever is "current".
        plt.close(fig)

In [7]:
 # Train on the "gratin" image dataset: 50000 iterations with batch size 32;
 # losses are printed and a sample grid is saved every 5000 iterations.
 wgan = WGAN(dataset_name="gratin")
 # wgan = WGAN(dataset_name="mnist")  # alternative: train on MNIST instead
 wgan.train(epochs=50000, batch_size=32, sample_interval=5000)

epoch 1 [D loss: 0.999906] [G loss: 1.000275]
epoch 5001 [D loss: 0.999967] [G loss: 1.000062]
epoch 10001 [D loss: 0.999953] [G loss: 1.000053]
epoch 15001 [D loss: 0.999976] [G loss: 1.000058]
epoch 20001 [D loss: 0.999978] [G loss: 1.000077]
epoch 25001 [D loss: 0.999975] [G loss: 1.000057]
epoch 30001 [D loss: 0.999963] [G loss: 1.000068]
epoch 35001 [D loss: 0.999965] [G loss: 1.000067]
epoch 40001 [D loss: 0.999967] [G loss: 1.000067]
epoch 45001 [D loss: 0.999974] [G loss: 1.000072]
