# Pix2Pix

In [1]:
from __future__ import print_function, division
import scipy

from keras.datasets import mnist
from keras_contrib.layers.normalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import numpy as np
import os
import glob
import skimage
import imageio
from tqdm import tqdm_notebook as tqdm

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.
  return f(*args, **kwds)


In [2]:
class DataLoader():
    """
    supposed
    ./datasets/(dataset_name)/train/source/0.jpg
    ./datasets/(dataset_name)/train/target/0.jpg
    ./datasets/(dataset_name)/test/source/0.jpg
    ./datasets/(dataset_name)/test/target/0.jpg
    ...
    """
    
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "test"
        img_source_dir = './datasets/%s/%s/source/*' % (self.dataset_name, data_type)
        img_source_paths = glob.glob(img_source_dir)
        img_source_batch = np.random.choice(img_source_paths, size=batch_size)

        imgs_source, imgs_target = [], []
        for img_source_path in img_source_batch:
            img_source = self.imread(img_source_path)
            img_target = self.imread(img_source_path.replace('source', 'target'))

            img_source = scipy.misc.imresize(img_source, self.img_res)
            img_target = scipy.misc.imresize(img_target, self.img_res)

            if not is_testing and np.random.random() > 0.5:
                    img_source = np.fliplr(img_source)
                    img_target = np.fliplr(img_target)

            imgs_source.append(img_source)
            imgs_target.append(img_target)

        imgs_source = np.array(imgs_source) / 255
        imgs_target = np.array(imgs_target) / 255

        return imgs_source, imgs_target

    def load_batch(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "test"
        img_source_dir = './datasets/%s/%s/source/*' % (self.dataset_name, data_type)
        img_source_paths = glob.glob(img_source_dir)[::6]

        self.n_batches = len(img_source_paths) // batch_size

        for i in range(self.n_batches-1):
            img_source_batch = img_source_paths[i*batch_size:(i+1)*batch_size]
            imgs_source, imgs_target = [], []
            for img_source_path in img_source_batch:
                img_source = self.imread(img_source_path)
                img_target = self.imread(img_source_path.replace('source', 'target'))

                img_source = scipy.misc.imresize(img_source, self.img_res)
                img_target = scipy.misc.imresize(img_target, self.img_res)

                if not is_testing and np.random.random() > 0.5:
                        img_source = np.fliplr(img_source)
                        img_target = np.fliplr(img_target)

                imgs_source.append(img_source)
                imgs_target.append(img_target)

            imgs_source = np.array(imgs_source) / 255
            imgs_target = np.array(imgs_target) / 255

            yield imgs_source, imgs_target


    def imread(self, path):
        return scipy.misc.imread(path, mode='RGB').astype(np.float)

In [3]:
class Pix2Pix():
    def __init__(self, dataset_name='facades'):
        # Input shape
        self.img_rows = 256
        self.img_cols = 256
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Configure data loader
        self.dataset_name = dataset_name
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.img_rows, self.img_cols))

        # Calculate output shape of D (PatchGAN)
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 64
        self.df = 64

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        #-------------------------
        # Build generator
        #-------------------------

        # Build the generator
        self.generator = self.build_generator()
        img_source = Input(shape=self.img_shape)
        img_fake = self.generator(img_source)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # Discriminators determines validity of translated images / condition pairs
        validity = self.discriminator([img_fake, img_source])

        self.combined = Model(inputs=img_source, outputs=[validity, img_fake])
        self.combined.compile(loss=['mse', 'mae'],
                              loss_weights=[1, 100],
                              optimizer=optimizer)

    def build_generator(self):
        """U-Net Generator"""

        def conv2d(layer_input, filters, f_size=4, bn=True):
            """Layers used during downsampling"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = BatchNormalization(momentum=0.8)(u)
            u = Concatenate()([u, skip_input])
            return u

        # Image input
        d0 = Input(shape=self.img_shape)

        # Downsampling
        d1 = conv2d(d0, self.gf, bn=False)
        d2 = conv2d(d1, self.gf*2)
        d3 = conv2d(d2, self.gf*4)
        d4 = conv2d(d3, self.gf*8)
        d5 = conv2d(d4, self.gf*8)
        d6 = conv2d(d5, self.gf*8)
        d7 = conv2d(d6, self.gf*8)

        # Upsampling
        u1 = deconv2d(d7, d6, self.gf*8)
        u2 = deconv2d(u1, d5, self.gf*8)
        u3 = deconv2d(u2, d4, self.gf*8)
        u4 = deconv2d(u3, d3, self.gf*4)
        u5 = deconv2d(u4, d2, self.gf*2)
        u6 = deconv2d(u5, d1, self.gf)

        u7 = UpSampling2D(size=2)(u6)
        output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u7)

        return Model(d0, output_img)

    def build_discriminator(self):

        def d_layer(layer_input, filters, f_size=4, bn=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Concatenate image and conditioning image by channels to produce input
        combined_imgs = Concatenate(axis=-1)([img_A, img_B])

        d1 = d_layer(combined_imgs, self.df, bn=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)

        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

        return Model([img_A, img_B], validity)

    def train(self, epochs, batch_size=1):

        start_time = datetime.datetime.now()

        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        print('Train started at', start_time)
        for epoch in tqdm(range(epochs)):
            for batch_i, (imgs_source, imgs_target) in tqdm(enumerate(self.data_loader.load_batch(batch_size))):

                imgs_fake = self.generator.predict(imgs_source)

                # Train the discriminator
                d_loss_real = self.discriminator.train_on_batch([imgs_target, imgs_source], valid)
                d_loss_fake = self.discriminator.train_on_batch([imgs_fake,   imgs_source], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # Train the generators. update weights so that imgs_fake is judged as valid.
                g_loss = self.combined.train_on_batch(imgs_source, [valid, imgs_target])
                
                if batch_i == 200:
                    break

            elapsed_time = datetime.datetime.now() - start_time
            # Plot the progress
            print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f] time: %s" % (
                epoch, epochs, batch_i, self.data_loader.n_batches, d_loss[0], 100*d_loss[1], g_loss[0], elapsed_time)
                  )
            self.sample_images(epoch)

    def sample_images(self, epoch, samples=3):
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 3, samples

        imgs_source, imgs_target= self.data_loader.load_data(batch_size=samples, is_testing=True)
        imgs_fake = self.generator.predict(imgs_source)

        gen_imgs = np.concatenate([imgs_source, imgs_fake, imgs_target])
        # Rescale images 0 - 1
        gen_imgs = np.clip(gen_imgs, 0, 1)

        titles = ['Source', 'Generated', 'Target']
        fig, axs = plt.subplots(r, c)
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[c * i + j])
                axs[i,j].set_title(titles[i])
                axs[i,j].axis('off')
        fig.savefig("images/%s/%d.png" % (self.dataset_name, epoch))
        plt.close()

In [4]:
pix2pix = Pix2Pix(dataset_name='sakura')
pix2pix.generator.summary()
pix2pix.discriminator.summary()
pix2pix.combined.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 128, 128, 64) 3136        input_3[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, 128, 128, 64) 0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 64, 64, 128)  131200      leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
leaky_re_l

  'Discrepancy between trainable weights and collected trainable'


In [5]:
pix2pix.train(epochs=200, batch_size=1)

Train started at 2018-11-24 03:36:34.027224


`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
  'Discrepancy between trainable weights and collected trainable'


[Epoch 0/200] [Batch 200/5334] [D loss: 0.296335, acc:  80%] [G loss: 8.921474] time: 0:10:57.763858


`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.


Exception in thread Thread-4:
Traceback (most recent call last):
  File "/Users/n/.pyenv/versions/3.6.1/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/Users/n/.pyenv/versions/3.6.1/envs/ssd/lib/python3.6/site-packages/tqdm/_monitor.py", line 63, in run
    for instance in self.tqdm_cls._instances:
  File "/Users/n/.pyenv/versions/3.6.1/lib/python3.6/_weakrefset.py", line 60, in __iter__
    for itemref in self.data:
RuntimeError: Set changed size during iteration



[Epoch 1/200] [Batch 200/5334] [D loss: 0.089542, acc:  92%] [G loss: 6.649175] time: 0:21:38.279963


[Epoch 2/200] [Batch 200/5334] [D loss: 0.093448, acc:  94%] [G loss: 4.183104] time: 0:32:19.924494


[Epoch 3/200] [Batch 200/5334] [D loss: 0.148926, acc:  87%] [G loss: 3.394277] time: 0:42:51.096082


[Epoch 4/200] [Batch 200/5334] [D loss: 0.322532, acc:  51%] [G loss: 4.050559] time: 0:53:33.442499





KeyboardInterrupt: 

In [None]:
# Plot generator and discriminator accuracy and loss all
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
x = [i + 1 for i in range(30)]
plt.plot(x, pix2pix.d_val_losses, label="d_val_loss")
plt.plot(x, pix2pix.d_val_accs,   label="d_val_acc")
plt.plot(x, pix2pix.g_val_losses, label="g_val_loss")
plt.plot(x, pix2pix.d_losses,     label='d_loss')
plt.plot(x, pix2pix.d_accs,       label="d_acc")
plt.plot(x, pix2pix.g_losses,     label="g_loss")
plt.xlabel('Epochs')
plt.ylabel('a.u.')
plt.legend()
plt.yscale('log')
plt.title('Scores of pix2pix')
plt.show()