<a href="https://colab.research.google.com/github/Akiyoshi-Yagi/GANs/blob/master/CycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import scipy as sp
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
import numpy as np
import os

Using TensorFlow backend.


In [None]:
!pip install git+https://www.github.com/keras-team/keras-contrib.git

Collecting git+https://www.github.com/keras-team/keras-contrib.git
  Cloning https://www.github.com/keras-team/keras-contrib.git to /tmp/pip-req-build-9odn4rui
  Running command git clone -q https://www.github.com/keras-team/keras-contrib.git /tmp/pip-req-build-9odn4rui
Building wheels for collected packages: keras-contrib
  Building wheel for keras-contrib (setup.py) ... [?25l[?25hdone
  Created wheel for keras-contrib: filename=keras_contrib-2.0.8-cp36-none-any.whl size=101064 sha256=05d6bdb6aa05b9e0f9d63fa81d69078ed1f25e275d71b9b7f83f3b37d73cd608
  Stored in directory: /tmp/pip-ephem-wheel-cache-m6v5o1_z/wheels/11/27/c8/4ed56de7b55f4f61244e2dc6ef3cdbaff2692527a2ce6502ba
Successfully built keras-contrib
Installing collected packages: keras-contrib
Successfully installed keras-contrib-2.0.8


In [None]:
#@title
import scipy
from glob import glob
import numpy as np

class DataLoader():
    """Load and preprocess unpaired image batches from ``./datasets/<dataset_name>/``.

    Paths are resolved relative to the current working directory:
    ``load_data`` reads ``train<domain>`` / ``test<domain>``, while
    ``load_batch`` reads ``trainA``/``trainB`` (or ``valA``/``valB`` when
    ``is_testing=True``). Every image is resized to ``img_res`` and scaled
    from [0, 255] to [-1, 1].
    """

    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, domain, batch_size=1, is_testing=False):
        """Return ``batch_size`` images sampled (with replacement) from one domain.

        Args:
            domain: domain suffix, e.g. ``"A"`` or ``"B"``.
            batch_size: number of images to sample.
            is_testing: read ``test<domain>`` instead of ``train<domain>`` and
                skip the random horizontal flip.

        Returns:
            Array of images stacked along axis 0, scaled to [-1, 1].
        """
        data_type = ("test%s" if is_testing else "train%s") % domain
        path = glob('./datasets/%s/%s/*' % (self.dataset_name, data_type))

        batch_images = np.random.choice(path, size=batch_size)

        imgs = []
        for img_path in batch_images:
            img = self._resize(self.imread(img_path))
            # Random horizontal flip is a training-time augmentation only.
            if not is_testing and np.random.random() > 0.5:
                img = np.fliplr(img)
            imgs.append(img)

        return np.array(imgs) / 127.5 - 1.

    def load_batch(self, batch_size=1, is_testing=False):
        """Yield ``(imgs_A, imgs_B)`` batches covering one pass over both domains.

        Sets ``self.n_batches`` to the number of full batches the smaller
        domain provides; leftover images that do not fill a batch are skipped.
        File order is reshuffled on every call.
        """
        data_type = "train" if not is_testing else "val"
        path_A = glob('./datasets/%s/%sA/*' % (self.dataset_name, data_type))
        path_B = glob('./datasets/%s/%sB/*' % (self.dataset_name, data_type))

        self.n_batches = min(len(path_A), len(path_B)) // batch_size
        total_samples = self.n_batches * batch_size

        # Sample n_batches * batch_size from each path list so that model sees all
        # samples from both domains
        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)

        # Iterate over every batch, including the last one.
        for i in range(self.n_batches):
            batch_A = path_A[i * batch_size:(i + 1) * batch_size]
            batch_B = path_B[i * batch_size:(i + 1) * batch_size]
            imgs_A, imgs_B = [], []
            for file_A, file_B in zip(batch_A, batch_B):
                img_A = self._resize(self.imread(file_A))
                img_B = self._resize(self.imread(file_B))

                # Flip A and B together (training-time augmentation only).
                if not is_testing and np.random.random() > 0.5:
                    img_A = np.fliplr(img_A)
                    img_B = np.fliplr(img_B)

                imgs_A.append(img_A)
                imgs_B.append(img_B)

            yield np.array(imgs_A) / 127.5 - 1., np.array(imgs_B) / 127.5 - 1.

    def _resize(self, img):
        """Resize ``img`` to ``self.img_res``.

        NOTE(review): ``scipy.misc.imresize`` only exists in old SciPy releases
        (removed in 1.3) and needs Pillow installed.
        """
        return scipy.misc.imresize(img, self.img_res)

    def imread(self, path):
        """Read an image file as an RGB ``float64`` array.

        NOTE(review): ``scipy.misc.imread`` only exists in old SciPy releases
        (removed in 1.2) and needs Pillow installed.
        """
        # np.float was an alias of the builtin float (float64), removed in NumPy 1.24.
        return scipy.misc.imread(path, mode='RGB').astype(np.float64)

In [11]:
class CycleGAN():
  """CycleGAN for unpaired image-to-image translation between domains A and B.

  Builds two U-Net-style generators (g_AB: A -> B, g_BA: B -> A), two
  patch discriminators (d_A, d_B) whose output is a ``disc_patch`` grid of
  scores, and a ``combined`` model that trains both generators with
  adversarial, cycle-consistency and identity losses.
  """

  def __init__(self):
    # Input image shape.
    self.img_rows = 128
    self.img_cols = 128
    self.channels = 3
    self.img_shape = (self.img_rows, self.img_cols, self.channels)

    self.dataset_name = "apple2orange"
    self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                  img_res=(self.img_rows, self.img_cols))

    # The discriminator downsamples 4 times with stride 2, so its output is an
    # (img_rows / 16) x (img_rows / 16) x 1 grid of real/fake scores.
    patch = int(self.img_rows / 2 ** 4)
    self.disc_patch = (patch, patch, 1)

    # Number of filters in the first generator / discriminator layer.
    self.gf = 32
    self.df = 64

    # These loss weights should be tuned experimentally.
    # Weight of the cycle-consistency loss.
    self.lambda_cycle = 10.0
    # Weight of the identity loss.
    self.lambda_id = 0.9 * self.lambda_cycle

    optimizer = Adam(0.0002, 0.5)

    # Build and compile the discriminators.
    self.d_A = self.build_discriminator()
    self.d_B = self.build_discriminator()
    self.d_A.compile(loss="mse", optimizer=optimizer, metrics=["accuracy"])
    self.d_B.compile(loss="mse", optimizer=optimizer, metrics=["accuracy"])

    # Build the generators.
    self.g_AB = self.build_generator()
    self.g_BA = self.build_generator()

    img_A = Input(shape=self.img_shape)
    img_B = Input(shape=self.img_shape)

    # Translate each image to the other domain...
    fake_B = self.g_AB(img_A)
    fake_A = self.g_BA(img_B)
    # ...and back again (cycle consistency).
    reconstr_A = self.g_BA(fake_B)
    reconstr_B = self.g_AB(fake_A)
    # Identity mapping: feed each generator an image already in its target domain.
    img_A_id = self.g_BA(img_A)
    img_B_id = self.g_AB(img_B)

    # Freeze the discriminators inside the combined model only. Keras 2.x
    # captures `trainable` at compile time, and d_A/d_B were compiled above
    # while trainable, so their own train_on_batch still updates them.
    self.d_A.trainable = False
    self.d_B.trainable = False

    valid_A = self.d_A(fake_A)
    valid_B = self.d_B(fake_B)

    self.combined = Model(inputs=[img_A, img_B],
                          outputs=[valid_A, valid_B,
                                   reconstr_A, reconstr_B,
                                   img_A_id, img_B_id])
    self.combined.compile(loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],
                          loss_weights=[1, 1,
                                        self.lambda_cycle, self.lambda_cycle,
                                        self.lambda_id, self.lambda_id],
                          optimizer=optimizer)

  @staticmethod
  def conv2d(layer_input, filters, f_size=4, normalization=True):
    """Downsampling block: stride-2 Conv2D -> LeakyReLU(0.2) -> optional InstanceNorm."""
    d = Conv2D(filters, kernel_size=f_size, strides=2, padding="same")(layer_input)
    d = LeakyReLU(alpha=0.2)(d)
    if normalization:
      d = InstanceNormalization()(d)
    return d

  @staticmethod
  def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
    """Upsampling block: UpSampling2D -> Conv2D(relu) -> [Dropout] -> InstanceNorm,
    then concatenated with the skip connection ``skip_input``."""
    u = UpSampling2D(size=2)(layer_input)
    u = Conv2D(filters, kernel_size=f_size, strides=1, padding="same", activation="relu")(u)
    if dropout_rate:
      u = Dropout(dropout_rate)(u)
    u = InstanceNormalization()(u)
    u = Concatenate()([u, skip_input])
    return u

  # Generator
  def build_generator(self):
    """Return a U-Net-style generator ``Model`` mapping img_shape -> img_shape (tanh output)."""
    # Downsampling
    d0 = Input(shape=self.img_shape)
    d1 = self.conv2d(d0, self.gf)
    d2 = self.conv2d(d1, self.gf * 2)
    d3 = self.conv2d(d2, self.gf * 4)
    d4 = self.conv2d(d3, self.gf * 8)

    # Upsampling with skip connections to the matching downsampling layers.
    u1 = self.deconv2d(d4, d3, self.gf * 4)
    u2 = self.deconv2d(u1, d2, self.gf * 2)
    u3 = self.deconv2d(u2, d1, self.gf * 1)
    u4 = UpSampling2D(size=2)(u3)

    output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding="same", activation="tanh")(u4)

    return Model(d0, output_img)

  # Backward-compatible alias for the original (misspelled) method name.
  build_genereator = build_generator

  # Discriminator
  def build_discriminator(self):
    """Return a discriminator ``Model`` producing a ``disc_patch`` grid of scores."""
    img = Input(shape=self.img_shape)

    d1 = self.conv2d(img, self.df, normalization=False)
    d2 = self.conv2d(d1, self.df * 2)
    d3 = self.conv2d(d2, self.df * 4)
    d4 = self.conv2d(d3, self.df * 8)

    validity = Conv2D(1, kernel_size=4, strides=1, padding="same")(d4)

    return Model(img, validity)

  def train(self, epochs, batch_size=1, sample_interval=50):
    """Alternately train the discriminators and the generators.

    Args:
      epochs: number of passes over the training data.
      batch_size: images per domain per step.
      sample_interval: every this many batches, log losses and call
        ``sample_images``.
    """
    # Per-patch adversarial targets: 1 = real, 0 = fake.
    valid = np.ones((batch_size,) + self.disc_patch)
    fake = np.zeros((batch_size,) + self.disc_patch)

    for epoch in range(epochs):
      for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):

        # Train the discriminators on real images and current generator output.
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)

        dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
        dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
        dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

        dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
        dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
        dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

        # Total discriminator loss ([loss, accuracy]) averaged over both domains.
        d_loss = 0.5 * np.add(dA_loss, dB_loss)

        # Train the generators: exactly one update per batch.
        g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                              [valid, valid,
                                               imgs_A, imgs_B,
                                               imgs_A, imgs_B])

        # If at save interval => log losses and plot the generated image samples
        if batch_i % sample_interval == 0:
          print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f]"
                % (epoch, epochs, batch_i, self.data_loader.n_batches,
                   d_loss[0], 100 * d_loss[1], g_loss[0]))
          self.sample_images(epoch, batch_i)

  def sample_images(self, epoch, batch_i):
    """Save original / translated / reconstructed test images for both domains.

    Writes ``images/<dataset_name>/<epoch>_<batch_i>.png``, creating the
    directory if needed.
    """
    out_dir = os.path.join("images", self.dataset_name)
    os.makedirs(out_dir, exist_ok=True)

    imgs_A = self.data_loader.load_data(domain="A", batch_size=1, is_testing=True)
    imgs_B = self.data_loader.load_data(domain="B", batch_size=1, is_testing=True)

    fake_B = self.g_AB.predict(imgs_A)
    fake_A = self.g_BA.predict(imgs_B)
    reconstr_A = self.g_BA.predict(fake_B)
    reconstr_B = self.g_AB.predict(fake_A)

    rows = [np.concatenate([imgs_A, fake_B, reconstr_A]),
            np.concatenate([imgs_B, fake_A, reconstr_B])]
    titles = ["Original", "Translated", "Reconstructed"]

    fig, axs = plt.subplots(2, 3)
    for r, row_imgs in enumerate(rows):
      for c, title in enumerate(titles):
        # Map the [-1, 1] image range back to [0, 1] for imshow.
        axs[r, c].imshow(np.clip(0.5 * row_imgs[c] + 0.5, 0, 1))
        axs[r, c].set_title(title)
        axs[r, c].axis("off")
    fig.savefig(os.path.join(out_dir, "%d_%d.png" % (epoch, batch_i)))
    plt.close(fig)




In [None]:
# Training configuration (tune here rather than inside the call).
EPOCHS = 100
BATCH_SIZE = 64
SAMPLE_INTERVAL = 10  # batches between calls to sample_images

cycle_gan = CycleGAN()
cycle_gan.train(epochs=EPOCHS, batch_size=BATCH_SIZE, sample_interval=SAMPLE_INTERVAL)