In [None]:
import sys, os
import threading
import collections
import random
import ipywidgets
from io import BytesIO
import matplotlib.pyplot as plt
import skimage
import numpy as np
import tensorflow as tf
import keras
import keras.backend as K

# Only show TensorFlow log messages at ERROR level or above, to keep the
# training output readable (`tf.logging` is the TF 1.x logging API).
tf.logging.set_verbosity(tf.logging.ERROR)

def update_progress(msg, progress):
    """Write a single-line progress bar to stdout, overwriting the previous one.

    The line starts with a carriage return and has no trailing newline, so
    successive calls redraw the same line.

    Args:
        msg: label printed before the bar.
        progress: completion fraction; 0.0 draws an empty bar, 1.0 a full one.
            Values outside [0, 1] are clamped when drawing the bar; the
            printed percentage always shows the raw value.
    """
    barLength = 20
    status = ""
    # Clamp only for drawing so the bar is always exactly barLength wide.
    block = int(round(barLength * min(max(progress, 0.0), 1.0)))
    # The ">" head occupies the last filled cell; with nothing filled there is
    # no head, otherwise the bar would be one character too wide.
    bar = "=" * max(block - 1, 0) + (">" if block > 0 else "") + "-" * (barLength - block)
    text = "\r{0}: [{1}] {2:.2%} {3}".format(msg, bar, progress, status)
    sys.stdout.write(text)
    sys.stdout.flush()

In [None]:
def build_generator(latent_dim, channels):
    """Build the generator graph: latent vector -> image with a tanh output.

    A dense projection is reshaped to a 4x4x256 map, then upsampled by three
    stride-2 'same' transposed convolutions (4 -> 8 -> 16 -> 32 spatially),
    each followed by BatchNorm + ELU. A final 5x5 convolution maps to
    ``channels`` and a tanh activation produces the image.

    Args:
        latent_dim: length of the input noise vector.
        channels: number of channels in the generated image.

    Returns:
        (input tensor, output tensor) pair, ready for ``keras.models.Model``.
    """
    layers = keras.layers

    def normalize_and_activate(tensor):
        # Shared BatchNorm(momentum=0.8) + ELU(alpha=0.1) tail of every stage.
        tensor = layers.BatchNormalization(momentum=0.8)(tensor)
        return layers.ELU(alpha=0.1)(tensor)

    latent = layers.Input(shape=(latent_dim,))

    hidden = layers.Dense(256 * 4 * 4)(latent)
    hidden = normalize_and_activate(hidden)
    hidden = layers.Reshape((4, 4, 256))(hidden)

    for filters in (128, 64, 32):
        hidden = layers.Conv2DTranspose(filters, 3, strides=2, padding='same')(hidden)
        hidden = normalize_and_activate(hidden)

    hidden = layers.Conv2D(channels, kernel_size=5, padding="same")(hidden)
    image = layers.Activation("tanh")(hidden)

    return latent, image


In [None]:
def build_discriminator(img_shape):
    """Build the discriminator graph: image -> single sigmoid "real" score.

    Gaussian input noise (stddev 0.1) is followed by four stages of
    3x3 Conv -> BatchNorm -> ELU -> 2x2 average pooling with 32, 64, 128 and
    256 filters, then Dropout(0.4), Flatten and one sigmoid unit.

    Args:
        img_shape: shape of one input image (the caller passes
            (rows, cols, channels)).

    Returns:
        (input tensor, output tensor) pair, ready for ``keras.models.Model``.
    """
    layers = keras.layers

    image = layers.Input(shape=img_shape)
    features = layers.GaussianNoise(0.1)(image)

    for filters in (32, 64, 128, 256):
        features = layers.Conv2D(filters, kernel_size=3, strides=1, padding="same")(features)
        features = layers.BatchNormalization(momentum=0.8)(features)
        features = layers.ELU(alpha=0.1)(features)
        features = layers.AveragePooling2D(2)(features)

    features = layers.Dropout(0.4)(features)
    features = layers.Flatten()(features)
    validity = layers.Dense(1, activation='sigmoid')(features)

    return image, validity

In [None]:
def rounded_binary_accuracy(y_true, y_pred):
    """Binary accuracy that rounds the targets as well as the predictions.

    The discriminator in this notebook is trained on noise-smoothed labels
    (values near, but not exactly, 0 and 1), so both tensors are rounded to
    0/1 before comparison. Returns the mean match rate over the last axis.
    """
    hard_true = K.round(y_true)
    hard_pred = K.round(y_pred)
    matches = K.equal(hard_true, hard_pred)
    return K.mean(matches, axis=-1)


class DCGAN():
    """Deep convolutional GAN trained on the images of one CIFAR-10 class.

    Training progress is rendered as PNG bytes into four caller-supplied
    widgets (objects with a writable ``value`` attribute, e.g.
    ``ipywidgets.Image``).
    """

    def __init__(self, img, real, loss_plt, acc_pls):
        """Build and compile the discriminator, generator and combined model.

        Args:
            img: widget that receives a grid of generated samples.
            real: widget that receives a grid of real training samples.
            loss_plt: widget that receives the D/G loss history plot.
            acc_pls: widget that receives the discriminator-accuracy history
                plot. The parameter keeps its historical spelling for keyword
                compatibility; it is stored as ``self.acc_plt``.
        """
        # Input shape
        self.img_rows = 32
        self.img_cols = 32
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 16

        self.img = img
        self.real = real
        self.loss_plt = loss_plt
        # Store the argument itself; do not rely on a notebook-global `acc_plt`.
        self.acc_plt = acc_pls

        # Build and compile the discriminator
        inp_discriminator, out_discriminator = build_discriminator(self.img_shape)
        self.discriminator = keras.models.Model(inp_discriminator, out_discriminator, name='discriminator')
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=self._make_optimizer(),
                                   metrics=[rounded_binary_accuracy])

        # A second network over the same layers (and therefore the same
        # weights). It is marked non-trainable only after the discriminator
        # has been compiled, so the discriminator keeps training its weights
        # while the combined model below treats them as frozen.
        self.frozen_discriminator = keras.engine.network.Network(inp_discriminator, out_discriminator, name='frozen_discriminator')
        self.frozen_discriminator.trainable = False

        # Build the generator
        inp_generator, out_generator = build_generator(self.latent_dim, self.channels)
        self.generator = keras.models.Model(inp_generator, out_generator, name='generator')

        # The combined model (generator -> frozen discriminator) trains the
        # generator to fool the discriminator. It gets its own optimizer so
        # the two models do not share an iteration counter or slot variables.
        self.combined = keras.models.Sequential()
        self.combined.add(self.generator)
        self.combined.add(self.frozen_discriminator)
        self.combined.compile(loss='binary_crossentropy', optimizer=self._make_optimizer())

        self.generator.summary()
        self.combined.summary()

    @staticmethod
    def _make_optimizer():
        """Return a new Adam optimizer with this notebook's hyperparameters."""
        return keras.optimizers.Adam(0.0002, 0.5, clipvalue=1.0, decay=1e-8)

    @staticmethod
    def _load_images(label):
        """Return all CIFAR-10 train+test images with ``label``, scaled to [-1, 1].

        Raises:
            ValueError: if no image carries ``label``.
        """
        (X_train, Y_train), (X_test, Y_test) = keras.datasets.cifar10.load_data()
        images = np.concatenate([X_train, X_test], axis=0)
        labels = np.concatenate([Y_train, Y_test], axis=0).flatten()

        # 0: 'Airplane', 1: 'Automobile', 2: 'Bird', 3: 'Cat', 4: 'Deer',
        # 5: 'Dog', 6: 'Frog', 7: 'Horse', 8: 'Ship', 9: 'Truck'
        selected = images[labels == label]
        if selected.shape[0] == 0:
            raise ValueError("no CIFAR-10 images with label {!r}".format(label))

        # Rescale 0..255 to -1..1 (matches the generator's tanh output range)
        return selected / 127.5 - 1.0

    def train(self, epochs, batch_size=128, save_interval=50, cls=1):
        """Train for steps 0..``epochs`` (inclusive) on one CIFAR-10 class.

        Each step trains the discriminator on one real and one generated
        batch, then trains the generator through the combined model. Every
        ``save_interval`` steps the sample widgets and the progress line are
        refreshed; from the second refresh on, the window's mean metrics are
        appended to the history and the history plots are redrawn.

        Args:
            epochs: index of the last training step.
            batch_size: images per discriminator / generator batch.
            save_interval: steps between display refreshes.
            cls: CIFAR-10 label index to train on.

        Returns:
            dict with lists 'd_loss', 'd_acc' and 'g_loss', one entry per
            recorded window.
        """
        X_train = self._load_images(cls)

        # Adversarial ground truths, smoothed with uniform noise in [0, 0.1).
        # The noise is drawn once here, so the same smoothed targets are
        # reused at every step.
        valid = np.ones((batch_size, 1)) - np.random.uniform(low=0.0, high=0.1, size=(batch_size, 1))
        fake = np.zeros((batch_size, 1)) + np.random.uniform(low=0.0, high=0.1, size=(batch_size, 1))
        # The generator is trained against hard "real" labels.
        valid_fake = np.ones((batch_size, 1))

        # History (one entry per save window) and per-window accumulators
        loss_discriminator = []
        acc_discriminator = []
        loss_generator = []
        d_losses = []
        d_accs = []
        g_losses = []

        # Fixed latent batch so sample grids are comparable across refreshes
        self.noise = np.random.normal(0, 1, (2 * 3, self.latent_dim))

        for epoch in range(0, epochs + 1):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of real images (indices drawn with replacement)
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Sample noise and generate a batch of new images
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            # Real images target ~1, generated images target ~0
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            # [loss, rounded_binary_accuracy] averaged over both batches
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # The generator wants the discriminator to label its images as real
            g_loss = self.combined.train_on_batch(noise, valid_fake)

            d_losses.append(d_loss[0])
            d_accs.append(d_loss[1])
            g_losses.append(g_loss)

            # At each save interval, refresh the displays
            if epoch % save_interval == 0:
                self.plot_imgs(epoch)
                self.plot_real(X_train)
                update_progress("Epoch {: 5d} | D loss {:2.2f} D acc {:2.2f} | G loss {:2.2f}".format(
                    epoch, np.mean(d_losses), np.mean(d_accs), np.mean(g_losses)),
                    epoch / max(epochs, 1))  # guard epochs == 0

                if epoch != 0:
                    # Save the window's means to history and start a new window
                    loss_discriminator.append(float(np.mean(d_losses)))
                    acc_discriminator.append(float(np.mean(d_accs)))
                    loss_generator.append(float(np.mean(g_losses)))
                    d_losses = []
                    d_accs = []
                    g_losses = []
                    self.plot_hist(loss_discriminator, loss_generator, acc_discriminator)
        print()
        return {'d_loss': loss_discriminator,
                'd_acc': acc_discriminator,
                'g_loss': loss_generator}

    @staticmethod
    def _fig_to_png(fig):
        """Render ``fig`` to PNG bytes in a fresh buffer, then close the figure."""
        imbuf = BytesIO()
        fig.savefig(imbuf, format='png')
        plt.close(fig)
        return imbuf.getvalue()

    def plot_real(self, X_train):
        """Show a 2x3 grid of randomly chosen real images in ``self.real``.

        ``X_train`` holds images scaled to [-1, 1]; they are mapped back to
        [0, 1] for display.
        """
        r, c = 2, 3
        idx = random.sample(range(len(X_train)), r * c)
        fig, axs = plt.subplots(r, c)
        for ax, i in zip(axs.flat, idx):
            ax.imshow(0.5 * X_train[i] + 0.5)
            ax.axis('off')
        self.real.value = self._fig_to_png(fig)

    def plot_imgs(self, epoch):
        """Show generator output for the fixed latent batch in ``self.img``.

        ``self.noise`` is set by ``train``. ``epoch`` is currently unused and
        kept for call compatibility.
        """
        r, c = 2, 3
        gen_imgs = self.generator.predict(self.noise)

        # Rescale images from [-1, 1] to [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        for ax, gen_img in zip(axs.flat, gen_imgs):
            ax.imshow(gen_img)
            ax.axis('off')
        self.img.value = self._fig_to_png(fig)

    def plot_hist(self, d_loss, g_loss, acc):
        """Redraw the D/G loss plot and the discriminator-accuracy plot.

        Each point is one recorded window of ``save_interval`` training steps.
        Every plot is rendered into its own buffer.
        """
        fig, ax = plt.subplots()
        ax.plot(range(1, len(d_loss) + 1), d_loss, linewidth=2, label='D loss')
        ax.plot(range(1, len(g_loss) + 1), g_loss, linewidth=2, label='G loss')
        ax.set_ylabel('D/G loss')
        ax.set_xlabel('Steps')
        ax.legend()
        self.loss_plt.value = self._fig_to_png(fig)

        fig, ax = plt.subplots()
        ax.plot(range(1, len(acc) + 1), acc, linewidth=2)
        ax.set_ylabel('Discriminator accuracy')
        ax.set_xlabel('Steps')
        self.acc_plt.value = self._fig_to_png(fig)
        

In [None]:
# One image widget per live view: generated samples, real samples,
# loss history and discriminator-accuracy history (created in that order).
img, real, loss_plt, acc_plt = (ipywidgets.Image() for _ in range(4))
dcgan = DCGAN(img, real, loss_plt, acc_plt)
display(ipywidgets.HBox([
    ipywidgets.VBox([img, real]),          # left column: generated / real samples
    ipywidgets.VBox([loss_plt, acc_plt]),  # right column: loss / accuracy history
]))

In [None]:
# CIFAR-10 class names in label order (position == label index passed as `cls`)
CIFAR10_CLASSES = ('Airplane', 'Automobile', 'Bird', 'Cat', 'Deer',
                   'Dog', 'Frog', 'Horse', 'Ship', 'Truck')
history = dcgan.train(
    epochs=10000,
    batch_size=32,
    save_interval=10,
    cls=CIFAR10_CLASSES.index('Cat'),  # == 3
)