In [None]:
from __future__ import print_function, division

from keras.layers import Input, Dense, Flatten, Dropout, Reshape, Concatenate
from keras.layers import BatchNormalization, Activation, Conv2D, Conv2DTranspose
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Model
from keras.optimizers import Adam

from keras.datasets import cifar10
import keras.backend as K

import matplotlib.pyplot as plt

import sys
import numpy as np

%pylab inline

In [None]:
def get_generator():
    """Build the conditional generator G(z, c).

    The 100-dim noise vector and the 10-dim condition are concatenated,
    projected to an 8x8x128 feature map, then passed through six
    conv -> BatchNorm -> LeakyReLU stages (two of them stride-2 transposed
    convolutions, 8 -> 16 -> 32). A final 3-channel conv with tanh produces
    a 32x32x3 image.

    Returns:
        (model, out): the Keras ``Model`` with inputs ``[z, c]`` and its
        output tensor.
    """
    noise_in = Input(shape=(100,))  # latent noise
    cond_in = Input(shape=(10,))    # condition; intended to be a CIFAR-10 one-hot vector
    hid = Concatenate()([noise_in, cond_in])

    # Project and reshape the merged input into an 8x8x128 feature map.
    hid = Dense(128 * 8 * 8, activation='relu')(hid)
    hid = BatchNormalization(momentum=0.9)(hid)
    hid = LeakyReLU(alpha=0.2)(hid)
    hid = Reshape((8, 8, 128))(hid)

    # (layer class, kernel size, stride) for each 128-filter stage, in order.
    conv_stages = (
        (Conv2D, 4, 1),
        (Conv2DTranspose, 4, 2),  # 8x8 -> 16x16
        (Conv2D, 5, 1),
        (Conv2DTranspose, 4, 2),  # 16x16 -> 32x32
        (Conv2D, 5, 1),
        (Conv2D, 5, 1),
    )
    for layer_cls, ksize, stride in conv_stages:
        hid = layer_cls(128, kernel_size=ksize, strides=stride, padding='same')(hid)
        hid = BatchNormalization(momentum=0.9)(hid)
        hid = LeakyReLU(alpha=0.2)(hid)

    # Map to 3 channels; tanh keeps pixels in [-1, 1] like the normalized data.
    hid = Conv2D(3, kernel_size=5, strides=1, padding="same")(hid)  # 32, 32, 3
    out = Activation("tanh")(hid)

    model = Model(inputs=[noise_in, cond_in], outputs=out)
    model.summary()

    return model, out

In [None]:
def get_discriminator():
    """Build the conditional discriminator D(x, c).

    Four Conv2D -> BatchNorm -> LeakyReLU stages (one stride-1 stage, then
    three stride-2 stages: 32 -> 16 -> 8 -> 4) extract image features. These
    are flattened, concatenated with the 10-dim condition vector, and scored
    by a dense head ending in a sigmoid (probability that (x, c) is real).

    Returns:
        (model, out): the Keras ``Model`` with inputs ``[x, c]`` and its
        output tensor.
    """
    x = Input(shape=(32, 32, 3))
    # The condition Input is created exactly once. It is consumed only after
    # flattening, and the same tensor must appear in the Model's inputs.
    c = Input(shape=(10,))

    hid = x
    # (kernel_size, strides) for each 128-filter conv stage, in order.
    for kernel_size, strides in ((3, 1), (4, 2), (4, 2), (4, 2)):
        hid = Conv2D(128, kernel_size=kernel_size, strides=strides, padding='same')(hid)
        hid = BatchNormalization(momentum=0.9)(hid)
        hid = LeakyReLU(alpha=0.2)(hid)

    hid = Flatten()(hid)

    merged_layer = Concatenate()([hid, c])
    # Linear Dense followed by LeakyReLU. With activation='relu' here, the
    # LeakyReLU would only ever see non-negative values and act as the identity.
    hid = Dense(512)(merged_layer)
    hid = LeakyReLU(alpha=0.2)(hid)
    hid = Dropout(0.4)(hid)
    out = Dense(1, activation='sigmoid')(hid)

    model = Model(inputs=[x, c], outputs=out)
    model.summary()

    return model, out

In [None]:
# --------
# Prepare some utilities
# --------

from keras.preprocessing import image

def one_hot_encode(y, num_classes=10):
    """Convert integer class labels into one-hot rows.

    Args:
        y: sequence or array of integer labels in ``[0, num_classes)``. A
            column of shape ``(N, 1)`` is flattened to ``(N,)`` first. Otherwise
            it would broadcast against the row index and set a 1 in every
            row for every label.
        num_classes: width of the encoding. Defaults to 10 (CIFAR-10).

    Returns:
        float64 array of shape ``(N, num_classes)`` with a single 1 per row.
    """
    labels = np.asarray(y).reshape(-1)
    encoded = np.zeros((labels.shape[0], num_classes))
    # Guard: an empty float array (e.g. from np.asarray([])) cannot be used as an index.
    if labels.size:
        encoded[np.arange(labels.shape[0]), labels] = 1
    return encoded

def generate_noise(n_samples, noise_dim):
    """Return an ``(n_samples, noise_dim)`` array of standard-normal latent vectors."""
    return np.random.normal(loc=0, scale=1, size=(n_samples, noise_dim))

def generate_random_labels(n):
    """Sample ``n`` class ids uniformly from 0..9 and return them one-hot encoded."""
    sampled_ids = np.random.choice(10, size=n)
    return one_hot_encode(sampled_ids)

# Class names indexed by integer label (0 = Airplane ... 9 = Truck); used as
# subplot titles in show_samples.
tags = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']
  
def show_samples(model=None, noise_dim=100):
    """Plot 3 generated images for each of the 10 classes in a 5x6 grid.

    Each grid row holds two classes (3 columns each). The middle image of
    each triple is titled with the class name from ``tags``.

    Args:
        model: conditional generator called as ``model.predict([noise, labels])``.
            Defaults to the notebook-level ``generator``, so existing
            ``show_samples()`` calls behave as before.
        noise_dim: width of the noise vectors fed to ``model``. Defaults to 100,
            matching the generator's noise Input.
    """
    if model is None:
        model = generator
    fig, axs = plt.subplots(5, 6, figsize=(10, 6))
    fig.subplots_adjust(hspace=0.3, wspace=0.1)
    for classlabel in range(10):
        row = classlabel // 2
        coloffset = (classlabel % 2) * 3
        lbls = one_hot_encode([classlabel] * 3)
        noise = generate_noise(3, noise_dim)
        gen_imgs = model.predict([noise, lbls])

        for i in range(3):
            ax = axs[row, i + coloffset]
            # The [-1, 1] normalization is not inverted manually. With
            # scale=True, array_to_img rescales pixel values itself.
            ax.imshow(image.array_to_img(gen_imgs[i], scale=True))
            ax.axis('off')
            if i == 1:
                ax.set_title(tags[classlabel])
    plt.show()
    # Close this specific figure; the function is called once per epoch.
    plt.close(fig)

In [None]:
# -------
# Compile discriminator
# -------
discriminator, disc_out = get_discriminator()
discriminator.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy', metrics=['accuracy'])


# --------
# Compile combined model with untrainable discriminator
# --------
# The flag is flipped *after* the discriminator was compiled. This relies on
# Keras capturing `trainable` at compile time:
#   - discriminator.train_on_batch still updates D's weights;
#   - the combined model compiled below (flag off) only updates the generator.
discriminator.trainable = False

combined_input = Input(shape=(100,))
combined_condition = Input(shape=(10,))
generator, gen_out = get_generator()
x = generator([combined_input, combined_condition])
# The same condition is fed to both G and D, so D judges the image against the requested class.
combined_out = discriminator([x, combined_condition])
# `outputs=` (plural) is the Functional API keyword; the singular `output=` is
# a legacy alias that current Keras versions reject.
combined = Model(inputs=[combined_input, combined_condition], outputs=combined_out)
combined.summary()

combined.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy')

In [None]:
# --------
# Prepare data
# --------

BATCH_SIZE = 32

# Get training images
(X_train, c_train), (X_test, c_test) = cifar10.load_data()

# Normalize uint8 pixels [0, 255] to [-1, 1], matching the generator's tanh
# output range. Cast to float32 first: arithmetic with 127.5 on the raw uint8
# arrays would otherwise promote to float64 and double memory use.
X_train = (X_train.astype('float32') - 127.5) / 127.5
X_test = (X_test.astype('float32') - 127.5) / 127.5
assert X_train.min() >= -1.0 and X_train.max() <= 1.0
assert X_test.min() >= -1.0 and X_test.max() <= 1.0

# One-hot encode labels. Column 0 is taken because the labels are indexed as
# a 2-D array here (an (N, 1) column).
c_train = one_hot_encode(c_train[:, 0])
c_test = one_hot_encode(c_test[:, 0])

print("Training shape: {}".format(X_train.shape))
print("Test     shape: {}".format(X_test.shape))

# Floor division: any trailing partial batch is skipped.
num_train_batches = X_train.shape[0] // BATCH_SIZE
num_test_batches = X_test.shape[0] // BATCH_SIZE

In [None]:
import datetime

# Japan Standard Time (UTC+9), used to timestamp the start and end of training.
JST = datetime.timezone(offset=datetime.timedelta(hours=9), name='JST')
start_time = datetime.datetime.now(tz=JST)
print(start_time)

In [None]:
# --------
# Train and test
# --------

EPOCHS = 100
NOISE_DIM = 100  # must match the generator's noise Input shape
d_losses = []  # per-epoch mean discriminator test loss (real test images only)
d_accs   = []  # per-epoch mean discriminator test accuracy (real test images only)
g_losses = []  # per-epoch mean combined-model (generator) test loss

# Ground-truth targets for the discriminator output: 1 = real, 0 = generated
true_labels = np.ones((BATCH_SIZE, 1))
fake_labels = np.zeros((BATCH_SIZE, 1))

print("Training started.")
for epoch in range(EPOCHS):

    # Shuffle all data at the start of every epoch. The same permutation is
    # applied to both arrays so each image stays paired with its condition.
    perm = np.random.permutation(X_train.shape[0])
    X_train = X_train[perm]
    c_train = c_train[perm]

    # Train
    for batch_idx_train in range(num_train_batches):
        batch = slice(batch_idx_train * BATCH_SIZE, (batch_idx_train + 1) * BATCH_SIZE)
        # Real images and their labels (used here as conditions) from CIFAR-10
        true_images      = X_train[batch]
        true_conditions  = c_train[batch]

        noise_data        = generate_noise(BATCH_SIZE, NOISE_DIM)
        random_conditions = generate_random_labels(BATCH_SIZE)
        generated_images  = generator.predict([noise_data, random_conditions])

        # Train discriminator on real data with target 1
        d_loss_true_train = discriminator.train_on_batch([true_images, true_conditions], true_labels)

        # Train discriminator on generated data with target 0
        d_loss_fake       = discriminator.train_on_batch([generated_images, random_conditions], fake_labels)

        # Train the generator through the combined model (discriminator frozen
        # there). The target is 1: the generator improves when D scores its
        # images as real.
        g_loss_train      = combined.train_on_batch([noise_data, random_conditions], true_labels)

    # Evaluate at the end of every epoch: collect per-batch metrics, then
    # average them over the epoch.
    d_losses_in_epoch = []
    d_accs_in_epoch = []
    g_losses_in_epoch = []
    for batch_idx_test in range(num_test_batches):
        batch = slice(batch_idx_test * BATCH_SIZE, (batch_idx_test + 1) * BATCH_SIZE)
        test_images     = X_test[batch]
        test_conditions = c_test[batch]

        # D is evaluated on real test images only (target 1), so d_acc is its
        # accuracy on real data.
        d_loss_test_batch = discriminator.test_on_batch([test_images, test_conditions], true_labels)

        # Draw fresh noise and conditions for every test batch, so g_loss
        # averages over many generated samples instead of one stale batch.
        test_noise       = generate_noise(BATCH_SIZE, NOISE_DIM)
        test_random_cond = generate_random_labels(BATCH_SIZE)
        g_loss_test_batch = combined.test_on_batch([test_noise, test_random_cond], true_labels)

        # The discriminator was compiled with metrics=['accuracy'] -> [loss, acc]
        d_losses_in_epoch.append(d_loss_test_batch[0])
        d_accs_in_epoch.append(d_loss_test_batch[1])
        g_losses_in_epoch.append(g_loss_test_batch)

    d_loss = np.average(d_losses_in_epoch)
    d_acc  = np.average(d_accs_in_epoch)
    g_loss = np.average(g_losses_in_epoch)

    d_losses.append(d_loss)
    d_accs.append(d_acc)
    g_losses.append(g_loss)
    print('Epoch: {}; Generator Loss: {:.4f}; Discriminator Loss: {:.4f}, Acc: {:.4f}'.format(epoch + 1, g_loss, d_loss, d_acc))
    show_samples()


In [None]:
# Show a 3x3 grid of generated samples for each of the 10 classes.
for classlabel in range(10):
    lbls = one_hot_encode([classlabel] * 9)
    noise = generate_noise(9, 100)
    gen_imgs = generator.predict([noise, lbls])

    fig, axs = plt.subplots(3, 3)
    fig.subplots_adjust(hspace=0.05, wspace=0.05)
    # Set the figure title once, not once per subplot.
    fig.suptitle('Label: ' + str(classlabel))
    # axs.flat walks the grid row by row, matching gen_imgs order.
    for count, ax in enumerate(axs.flat):
        # No manual inverse of the [-1, 1] normalization; scale=True lets
        # array_to_img rescale pixel values itself.
        ax.imshow(image.array_to_img(gen_imgs[count], scale=True))
        ax.axis('off')
    plt.show()
    # Close each per-class figure explicitly so 10 figures don't stay open.
    plt.close(fig)

In [None]:
# --------
# Plot per-epoch discriminator loss/accuracy and generator loss (full history)
# --------
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
# Epoch numbers start at 1 on the x-axis.
x = list(range(1, len(d_losses) + 1))
for series, name in ((d_losses, "d_loss"), (d_accs, "d_acc"), (g_losses, "g_loss")):
    ax.plot(x, series, label=name)
ax.set_xlabel('Epochs')
ax.set_ylabel('a.u.')
ax.legend()
plt.show()

In [None]:
# --------
# Zoomed plot of the same per-epoch curves (first ZOOM_MAX_EPOCH epochs)
# --------
import matplotlib.pyplot as plt

ZOOM_MAX_EPOCH = 80   # right edge of the zoomed x-axis (epochs)
ZOOM_MAX_Y = 3        # top edge of the zoomed y-axis
ZOOM_TICK_STEP = 5    # spacing between epoch ticks

fig, ax = plt.subplots()
x = [i + 1 for i in range(len(d_losses))]
ax.plot(x, d_losses, label="d_loss")
ax.plot(x, d_accs,   label="d_acc")
ax.plot(x, g_losses, label="g_loss")
ax.set_xlabel('Epochs')
ax.set_ylabel('a.u.')
ax.set_xlim((0, ZOOM_MAX_EPOCH))
ax.set_ylim((0, ZOOM_MAX_Y))
ax.set_xticks(list(range(0, ZOOM_MAX_EPOCH, ZOOM_TICK_STEP)))
ax.legend()
plt.show()

In [None]:
finished_time = datetime.datetime.now(JST)
print(finished_time)
# Average over the epochs that actually completed (one d_losses entry per
# epoch) rather than a hard-coded count that can drift from EPOCHS. The
# elapsed time spans from the start-time cell to this cell, so it also
# includes any cells run in between; treat it as an approximation.
completed_epochs = max(len(d_losses), 1)
elapsed_seconds = (finished_time - start_time).total_seconds()
print('One epoch took {:.0f} seconds'.format(elapsed_seconds / completed_epochs))