In [1]:
from __future__ import print_function, division

from keras.datasets import cifar10
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, GaussianNoise
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers import MaxPooling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras import losses
from keras.utils import to_categorical
import keras.backend as K

import matplotlib.pyplot as plt

import numpy as np

In [11]:
img_rows = 32
img_cols = 32
mask_height = 8
mask_width = 8
channels = 3
num_classes = 2  # NOTE(review): not referenced by any cell shown here
img_shape = (img_rows, img_cols, channels)
missing_shape = (mask_height, mask_width, channels)

# Adam with learning rate 2e-4 and beta_1 = 0.5 (positional arguments).
optimizer = Adam(0.0002, 0.5)

# Relative weights of the generator's reconstruction (MSE) loss and its
# adversarial (binary cross-entropy) loss in the combined model.
reconstruction_loss_weight = 0.999
adversarial_loss_weight = 0.001

# NOTE: build_discriminator() and build_generator() are defined in later
# cells; this cell only works once those have been executed. Move it below
# them so that "Restart Kernel -> Run All" succeeds.

# Build and compile the discriminator
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy'])

# Build the generator
generator = build_generator()

# The generator takes a masked image as input and predicts the missing patch
masked_img = Input(shape=img_shape)
gen_missing = generator(masked_img)

# For the combined model we will only train the generator. This relies on
# Keras capturing trainable weights at compile time: the discriminator was
# compiled above while trainable, and is frozen only inside `combined`.
discriminator.trainable = False

# The discriminator scores the generated patch as real (1) or generated (0)
valid = discriminator(gen_missing)

# The combined model (generator stacked with the frozen discriminator) is
# trained to reconstruct the patch while fooling the discriminator.
# It gets its own Adam instance: sharing one optimizer object between two
# compiled models shares its state (e.g. the step counter) and is rejected
# by newer Keras optimizers.
combined_optimizer = Adam(0.0002, 0.5)
combined = Model(masked_img, [gen_missing, valid])
combined.compile(loss=['mse', 'binary_crossentropy'],
    loss_weights=[reconstruction_loss_weight, adversarial_loss_weight],
    optimizer=combined_optimizer)

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 4, 4, 64)          1792      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 4, 4, 64)          0         
_________________________________________________________________
batch_normalization (BatchNo (None, 4, 4, 64)          256       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 2, 2, 128)         73856     
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 2, 2, 128)         0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 2, 2, 128)         512       
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 2, 2, 256)         2

In [3]:
def build_generator():
    """Build the context-encoder generator.

    The network maps a masked image of shape ``img_shape`` to a prediction of
    the missing patch. The encoder applies three stride-2 3x3 convolutions and
    a stride-2 1x1 bottleneck. The decoder then upsamples twice. For a 32x32
    input the output is therefore 8x8x``channels``, which matches
    ``missing_shape`` for the configured sizes. A ``tanh`` output keeps
    predictions in (-1, 1), the same range the training images are rescaled to.

    Prints the inner Sequential summary and returns a functional ``Model``
    wrapping it.
    """
    net = Sequential()

    # Encoder: three strided 3x3 conv stages, each Conv -> LeakyReLU -> BatchNorm.
    for stage, n_filters in enumerate((32, 64, 128)):
        conv_kwargs = {"kernel_size": 3, "strides": 2, "padding": "same"}
        if stage == 0:
            conv_kwargs["input_shape"] = img_shape
        net.add(Conv2D(n_filters, **conv_kwargs))
        net.add(LeakyReLU(alpha=0.2))
        net.add(BatchNormalization(momentum=0.8))

    # Bottleneck: strided 1x1 conv up to 512 channels, regularized with dropout.
    net.add(Conv2D(512, kernel_size=1, strides=2, padding="same"))
    net.add(LeakyReLU(alpha=0.2))
    net.add(Dropout(0.5))

    # Decoder: two stages of Upsample -> 3x3 Conv -> ReLU -> BatchNorm.
    for n_filters in (128, 64):
        net.add(UpSampling2D())
        net.add(Conv2D(n_filters, kernel_size=3, padding="same"))
        net.add(Activation('relu'))
        net.add(BatchNormalization(momentum=0.8))

    # Project back to image channels.
    net.add(Conv2D(channels, kernel_size=3, padding="same"))
    net.add(Activation('tanh'))

    net.summary()

    image_in = Input(shape=img_shape)
    patch_out = net(image_in)
    return Model(image_in, patch_out)


In [10]:
def build_discriminator():
    """Build the patch discriminator.

    The network takes an image patch of shape ``missing_shape`` and outputs a
    single sigmoid score. During training real patches are labelled 1 and
    generated ones 0. It uses three 3x3 conv stages (strides 2, 2, 1), each
    followed by LeakyReLU and BatchNorm, then Flatten -> Dense(1).

    Prints the inner Sequential summary and returns a functional ``Model``
    wrapping it.
    """
    net = Sequential()

    # (filters, stride, extra Conv2D kwargs) for each stage, in build order.
    conv_stages = [
        (64, 2, {"input_shape": missing_shape}),
        (128, 2, {}),
        (256, 1, {}),
    ]
    for n_filters, stride, extra in conv_stages:
        net.add(Conv2D(n_filters, kernel_size=3, strides=stride, padding="same", **extra))
        net.add(LeakyReLU(alpha=0.2))
        net.add(BatchNormalization(momentum=0.8))

    net.add(Flatten())
    net.add(Dense(1, activation='sigmoid'))
    net.summary()

    patch_in = Input(shape=missing_shape)
    score = net(patch_in)
    return Model(patch_in, score)

In [13]:
def mask_randomly(imgs, height=None, width=None):
    """Cut a random rectangular hole out of every image in a batch.

    For each image, a ``height`` x ``width`` window is placed uniformly at
    random among all positions that fit entirely inside the image. The window
    contents are copied out, then zeroed in a copy of the image. ``imgs``
    itself is not modified.

    Args:
        imgs: array of shape (batch, rows, cols, channels).
        height: hole height in pixels; defaults to the module-level ``mask_height``.
        width: hole width in pixels; defaults to the module-level ``mask_width``.

    Returns:
        ``(masked_imgs, missing_parts, (y1, y2, x1, x2))`` where:
        - ``masked_imgs`` has the same shape and dtype as ``imgs``.
        - ``missing_parts`` is a float array of shape (batch, height, width, channels).
        - ``y1``/``y2``/``x1``/``x2`` are per-image arrays of half-open window
          bounds, with y indexing rows and x indexing columns.

    Raises:
        ValueError: if the hole is larger than the images.
    """
    if height is None:
        height = mask_height
    if width is None:
        width = mask_width

    batch, rows, cols, depth = imgs.shape
    if height > rows or width > cols:
        raise ValueError("mask %dx%d does not fit in %dx%d images"
                         % (height, width, rows, cols))

    # randint's upper bound is exclusive; the +1 lets the window reach the
    # bottom/right edge (and allows a hole as large as the image).
    y1 = np.random.randint(0, rows - height + 1, batch)
    y2 = y1 + height
    # x positions range over columns, not rows.
    x1 = np.random.randint(0, cols - width + 1, batch)
    x2 = x1 + width

    masked_imgs = np.empty_like(imgs)
    missing_parts = np.empty((batch, height, width, depth))
    for i, img in enumerate(imgs):
        masked_img = img.copy()
        top, bottom, left, right = y1[i], y2[i], x1[i], x2[i]
        missing_parts[i] = masked_img[top:bottom, left:right, :]
        masked_img[top:bottom, left:right, :] = 0
        masked_imgs[i] = masked_img

    return masked_imgs, missing_parts, (y1, y2, x1, x2)

In [6]:
def train(epochs, batch_size=128, sample_interval=50, class_ids=(3, 5)):
    """Train the context encoder adversarially on a subset of CIFAR-10.

    Each "epoch" here is a single iteration on one random batch, not a full
    pass over the data.

    Args:
        epochs: number of training iterations.
        batch_size: images per batch.
        sample_interval: call ``sample_images`` every this many iterations
            (starting at iteration 0).
        class_ids: CIFAR-10 label ids to train on, stacked in the given order.
            The default (3, 5) keeps cats and dogs.

    Uses the module-level ``generator``, ``discriminator`` and ``combined``
    models, which must already be built and compiled.
    """
    # Load the dataset; the test split is not used.
    (X_train, y_train), (_, _) = cifar10.load_data()

    # Keep only the requested classes. The labels are dropped afterwards:
    # after stacking they would no longer line up with X_train.
    labels = y_train.flatten()
    X_train = np.vstack([X_train[labels == class_id] for class_id in class_ids])

    # Rescale pixel values to [-1, 1] to match the generator's tanh output.
    X_train = X_train / 127.5 - 1.

    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Select a random batch of images (sampled with replacement)
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]

        masked_imgs, missing_parts, _ = mask_randomly(imgs)

        # Generator's prediction for each missing patch
        gen_missing = generator.predict(masked_imgs)

        # Real patches are labelled 1, generated patches 0
        d_loss_real = discriminator.train_on_batch(missing_parts, valid)
        d_loss_fake = discriminator.train_on_batch(gen_missing, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------

        # Targets: the true patch (MSE term) and "valid" (adversarial term)
        g_loss = combined.train_on_batch(masked_imgs, [missing_parts, valid])

        # Plot the progress; g_loss = [total, mse, adversarial]
        print("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[0], g_loss[1]))

        # If at save interval => save generated image samples
        if epoch % sample_interval == 0:
            idx = np.random.randint(0, X_train.shape[0], 6)
            imgs = X_train[idx]
            sample_images(epoch, imgs)

In [16]:
def sample_images(epoch, imgs):
    """Save a grid comparing originals, masked inputs and generator fill-ins.

    The grid has one column per image in ``imgs``:
    - row 0 shows the original images;
    - row 1 shows the masked inputs;
    - row 2 shows the originals with the hole replaced by the generator's
      prediction.

    The figure is written to ``images_CEGAN/<epoch>.png``, and the directory
    is created if it does not exist. An empty batch is a no-op.

    Args:
        epoch: iteration number, used as the file name.
        imgs: batch of images scaled to [-1, 1] (as produced in ``train``),
            shape (batch, rows, cols, channels).
    """
    import os

    if len(imgs) == 0:
        return  # plt.subplots cannot build a grid with zero columns

    masked_imgs, _, (y1, y2, x1, x2) = mask_randomly(imgs)
    gen_missing = generator.predict(masked_imgs)

    # Map [-1, 1] back to [0, 1] for imshow
    imgs = 0.5 * imgs + 0.5
    masked_imgs = 0.5 * masked_imgs + 0.5
    gen_missing = 0.5 * gen_missing + 0.5

    r, c = 3, len(imgs)
    # squeeze=False keeps axs 2-D even when there is a single column
    fig, axs = plt.subplots(r, c, squeeze=False)
    for i in range(c):
        axs[0, i].imshow(imgs[i, :, :])
        axs[0, i].axis('off')
        axs[1, i].imshow(masked_imgs[i, :, :])
        axs[1, i].axis('off')
        filled_in = imgs[i].copy()
        filled_in[y1[i]:y2[i], x1[i]:x2[i], :] = gen_missing[i]
        axs[2, i].imshow(filled_in)
        axs[2, i].axis('off')

    os.makedirs("images_CEGAN", exist_ok=True)
    fig.savefig("images_CEGAN/%d.png" % epoch)
    # Close this specific figure to avoid accumulating open figures across calls
    plt.close(fig)

In [20]:
def save_model():
    """Save the architecture and weights of the generator and discriminator.

    For each model this writes:
    - ``saved_model_CEGAN/<name>.json``, from ``model.to_json()``;
    - ``saved_model_CEGAN/<name>_weights.hdf5``, via ``model.save_weights``.

    The output directory is created if missing. Uses the module-level
    ``generator`` and ``discriminator``.
    """
    import os

    os.makedirs("saved_model_CEGAN", exist_ok=True)

    def save(model, model_name):
        model_path = "saved_model_CEGAN/%s.json" % model_name
        weights_path = "saved_model_CEGAN/%s_weights.hdf5" % model_name
        # Use a context manager so the JSON is flushed and the file handle closed.
        with open(model_path, 'w') as f:
            f.write(model.to_json())
        model.save_weights(weights_path)

    save(generator, "generator")
    save(discriminator, "discriminator")

In [21]:
# Each "epoch" in train() is one batch update, so this runs 20 iterations with
# 64 images per batch. With sample_interval=50, only iteration 0 writes a
# sample grid. Uncomment save_model() to write the generator/discriminator
# architecture (JSON) and weights to saved_model_CEGAN/.
train(epochs=20, batch_size=64, sample_interval=50)
# save_model()

0 [D loss: 0.001411, acc: 100.00%] [G loss: 0.194543, mse: 0.193353]
1 [D loss: 0.005126, acc: 100.00%] [G loss: 0.205027, mse: 0.203551]
2 [D loss: 0.002228, acc: 100.00%] [G loss: 0.222549, mse: 0.221371]
3 [D loss: 0.000886, acc: 100.00%] [G loss: 0.178576, mse: 0.177018]
4 [D loss: 0.002542, acc: 100.00%] [G loss: 0.215597, mse: 0.214322]
5 [D loss: 0.002276, acc: 100.00%] [G loss: 0.227550, mse: 0.226648]
6 [D loss: 0.000483, acc: 100.00%] [G loss: 0.226616, mse: 0.225600]
7 [D loss: 0.005713, acc: 100.00%] [G loss: 0.184953, mse: 0.183766]
8 [D loss: 0.001061, acc: 100.00%] [G loss: 0.207902, mse: 0.206813]
9 [D loss: 0.002968, acc: 100.00%] [G loss: 0.183129, mse: 0.181746]
10 [D loss: 0.001071, acc: 100.00%] [G loss: 0.207207, mse: 0.206139]
11 [D loss: 0.001114, acc: 100.00%] [G loss: 0.212929, mse: 0.211825]
12 [D loss: 0.003148, acc: 100.00%] [G loss: 0.223852, mse: 0.222994]
13 [D loss: 0.004356, acc: 100.00%] [G loss: 0.222071, mse: 0.221468]
14 [D loss: 0.001306, acc: 100