In [None]:
import sys
import random
import ipywidgets
from io import BytesIO
import matplotlib.pyplot as plt
import skimage
import numpy as np
import tensorflow as tf
import keras
import keras.backend as K

# Silence TensorFlow's lower-severity log output; only errors are reported.
_tf_logging = tf.logging
_tf_logging.set_verbosity(_tf_logging.ERROR)

def update_progress(msg, progress):
    """Redraw a single-line text progress bar on stdout.

    The line starts with a carriage return so successive calls overwrite each
    other instead of printing new lines.

    Args:
        msg: label printed before the bar.
        progress: completed fraction. The bar is drawn for the value clamped
            to [0, 1]; the percentage shows the raw value.
    """
    bar_length = 32
    status = ""
    fraction = min(max(progress, 0.0), 1.0)
    filled = int(round(bar_length * fraction))
    # The '>' head always occupies one cell, so the bar is exactly bar_length
    # cells wide at every progress value (previously 33 cells at progress 0).
    bar = "=" * max(filled - 1, 0) + ">" + "-" * (bar_length - max(filled, 1))
    text = "\r{0}: [{1}] {2:.2%} {3}".format(msg, bar, progress, status)
    sys.stdout.write(text)
    sys.stdout.flush()

In [None]:
# Preview: merge the MNIST train and test splits and show a random 6x6 grid.
(X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()
X_train = np.append(X_train, X_test, axis=0)
Y_train = np.append(Y_train, Y_test, axis=0)
# Map pixel values [0, 255] to [1, -1] (inverted) and add a channel axis.
X_train = np.expand_dims(1. - X_train / 127.5, axis=3)

r, c = 6, 6
idx = random.sample(range(len(X_train)), r*c)
fig, axes = plt.subplots(r, c, figsize=(16, 16))
for ax, sample_index in zip(axes.flat, idx):
    # 0.5*x + 0.5 brings the [-1, 1] pixels back to [0, 1] for display.
    ax.imshow(0.5 * X_train[sample_index, :, :, 0] + 0.5, cmap='gray')
    ax.axis('off')
plt.show()

In [None]:
def build_discriminator(img_shape):
    """Build the discriminator graph for images of shape ``img_shape``.

    Architecture:
    - Gaussian input noise.
    - Three stages of Conv(3x3, stride 1, 'same') -> BatchNorm -> ELU ->
      2x2 average pooling, with 32, 64 and 128 filters.
    - Dropout, then a single sigmoid unit.

    Returns:
        (input tensor, validity tensor); the caller wraps them in a Keras model.
    """
    layers = keras.layers
    inputs = layers.Input(shape=img_shape)
    # Additive Gaussian noise (stddev 0.05) on the input images.
    h = layers.GaussianNoise(0.05)(inputs)

    for filters in (32, 64, 128):
        h = layers.Conv2D(filters, kernel_size=3, strides=1, padding="same")(h)
        h = layers.BatchNormalization()(h)
        h = layers.ELU()(h)
        h = layers.AveragePooling2D(2)(h)

    h = layers.Dropout(0.25)(h)
    h = layers.Flatten()(h)
    score = layers.Dense(1, activation='sigmoid')(h)
    return inputs, score

In [None]:
def build_generator(latent_dim, channels):
    """Build the generator graph mapping a ``latent_dim`` vector to an image.

    Architecture:
    - Dense -> BatchNorm -> ELU, reshaped to 7x7x128.
    - Two stride-2 transposed-conv stages (64 then 32 filters).
    - A 5x5 conv to ``channels`` with tanh, so pixels lie in [-1, 1].

    Returns:
        (noise input tensor, image tensor); the caller wraps them in a Keras model.
    """
    layers = keras.layers
    z = layers.Input(shape=(latent_dim,))

    h = layers.Dense(128 * 7 * 7)(z)
    h = layers.BatchNormalization()(h)
    h = layers.ELU()(h)
    h = layers.Reshape((7, 7, 128))(h)

    # Each stride-2 'same' transposed conv doubles height and width: 7 -> 14 -> 28.
    for filters in (64, 32):
        h = layers.Conv2DTranspose(filters, 3, strides=2, padding='same')(h)
        h = layers.BatchNormalization()(h)
        h = layers.ELU()(h)

    h = layers.Conv2D(channels, kernel_size=5, padding="same")(h)
    image = layers.Activation("tanh")(h)
    return z, image


In [None]:
def rounded_binary_accuracy(y_true, y_pred):
    """Binary accuracy after rounding BOTH targets and predictions to 0/1.

    Rounding ``y_true`` as well lets the metric cope with the noisy soft labels
    used for discriminator training (e.g. a target of 0.93 counts as 1).
    """
    same_class = K.equal(K.round(y_true), K.round(y_pred))
    return K.mean(same_class, axis=-1)

class DCGAN():
    """Deep convolutional GAN trained on a single MNIST digit class.

    Progress is rendered into four ``ipywidgets.Image`` widgets given to the
    constructor. Typical use:
    - ``DCGAN(...)`` builds and compiles the models.
    - ``init()`` loads the data and resets the history.
    - ``train()`` runs training; repeated calls continue from the next step.
    """

    def __init__(self, img, real, loss_plt, acc_pls):
        """Build and compile the discriminator, generator and combined model.

        Args:
            img: image widget that receives a grid of generated samples.
            real: image widget that receives a grid of real training samples.
            loss_plt: image widget that receives the D/G loss curves.
            acc_pls: image widget that receives the discriminator-accuracy
                curve. The parameter name is kept for backward compatibility;
                the widget is stored as ``self.acc_plt``.
        """
        # Input shape
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 4

        self.img = img
        self.real = real
        self.loss_plt = loss_plt
        # Store the argument itself. This previously read a notebook-global
        # `acc_plt`, which only worked because such a global happened to exist.
        self.acc_plt = acc_pls

        # Build and compile the discriminator
        inp_discriminator, out_discriminator = build_discriminator(self.img_shape)
        self.discriminator = keras.models.Model(inp_discriminator, out_discriminator, name='discriminator')
        self.discriminator.compile(loss='binary_crossentropy', optimizer=self._make_optimizer(),
                                   metrics=[rounded_binary_accuracy])

        # Non-trainable view over the same discriminator layers. Used inside the
        # combined model so that training it only updates the generator.
        self.frozen_discriminator = keras.engine.network.Network(inp_discriminator, out_discriminator, name='frozen_discriminator')
        self.frozen_discriminator.trainable = False

        # Build the generator
        inp_generator, out_generator = build_generator(self.latent_dim, self.channels)
        self.generator = keras.models.Model(inp_generator, out_generator, name='generator')

        # The combined model (stacked generator and frozen discriminator)
        # trains the generator to fool the discriminator.
        self.combined = keras.models.Sequential()
        self.combined.add(self.generator)
        self.combined.add(self.frozen_discriminator)
        self.combined.compile(loss='binary_crossentropy', optimizer=self._make_optimizer())

    @staticmethod
    def _make_optimizer():
        """Return a fresh Adam optimizer with the GAN's hyper-parameters.

        Each compiled model gets its own instance so the two models do not
        share one iteration counter, which drives Adam's lr decay and bias
        correction.
        """
        return keras.optimizers.Adam(lr=0.00025, beta_1=0.5, decay=1e-8)

    def init(self, batch_size=128, cls=3):
        """Load the MNIST images of one digit class and reset training state.

        Args:
            batch_size: number of real (and of generated) images per step.
            cls: MNIST label to train on.

        Raises:
            ValueError: if no MNIST image carries the label ``cls``.
        """
        self.cls = cls
        self.batch_size = batch_size

        # Load the dataset; the train and test splits are merged.
        (X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()
        X_train = np.append(X_train, X_test, axis=0)
        Y_train = np.append(Y_train, Y_test, axis=0)

        mask = Y_train.flatten() == cls
        if not mask.any():
            raise ValueError("no MNIST images with label {!r}".format(cls))
        X_train = X_train[mask]

        # Rescale [0, 255] to [1, -1]. This is inverted: the background becomes
        # +1 and the strokes -1, inside the generator's tanh range.
        X_train = 1. - X_train / 127.5
        self.X_train = np.expand_dims(X_train, axis=3)
        # Keep labels aligned with the filtered images.
        self.Y_train = Y_train[mask]

        self.noise = np.random.normal(0, 1, (2 * 3, self.latent_dim))

        self.loss_discriminator = []
        self.acc_discriminator = []
        self.loss_generator = []
        # Index of the next training step to run.
        self.epoch = 0

        self.plot_imgs()
        self.plot_real(self.X_train)
        # Placeholder curves until the first history point exists.
        self.plot_hist([0.5, 0.5], [0.5, 0.5], [0.5, 0.5])
        plt.close('all')

    def train(self, epochs, save_interval=50):
        """Train for ``epochs + 1`` steps, starting at ``self.epoch``.

        Each step works as follows:
        - Train the discriminator on one real and one generated batch.
        - Train the generator through the frozen discriminator, reusing the
          same noise.

        Every ``save_interval`` steps:
        - Redraw the sample widgets and print a progress line.
        - Except at step 0, append the mean losses and accuracy since the
          previous save to the history.

        Args:
            epochs: steps to run after the starting step. The range is
                inclusive, so ``epochs + 1`` steps run.
            save_interval: how often, in steps, to refresh widgets and history.

        Returns:
            dict with the full histories under ``d_loss``, ``d_acc`` and
            ``g_loss``.
        """
        # Adversarial ground truths. The discriminator gets noisy soft labels
        # (real in [0.9, 1], fake in [0, 0.1]), sampled once per call. The
        # generator is trained against hard 1s.
        valid = np.ones((self.batch_size, 1)) - np.random.uniform(low=0.0, high=0.1, size=(self.batch_size, 1))
        fake = np.zeros((self.batch_size, 1)) + np.random.uniform(low=0.0, high=0.1, size=(self.batch_size, 1))
        valid_fake = np.ones((self.batch_size, 1))

        # Accumulators for the current save interval
        d_losses = []
        d_accs = []
        g_losses = []
        start_epoch = self.epoch

        for epoch in range(start_epoch, start_epoch + epochs + 1):
            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Random batch of real images (indices drawn with replacement)
            idx = np.random.randint(0, self.X_train.shape[0], self.batch_size)
            imgs = self.X_train[idx]

            # Sample noise and generate a batch of new images
            noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            # Real images are labelled ~1 and generated ones ~0. Each call
            # returns [loss, rounded_binary_accuracy]; average the two batches.
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # The generator wants the discriminator to mistake its images as real
            g_loss = self.combined.train_on_batch(noise, valid_fake)

            d_losses.append(d_loss[0])
            d_accs.append(d_loss[1])
            g_losses.append(g_loss)

            if epoch % save_interval == 0:
                self.plot_imgs()
                self.plot_real(self.X_train)
                update_progress("Epoch {: 5d} | D loss {:2.2f} D acc {:2.2f} | G loss {:2.2f}".format(
                    epoch, np.mean(d_losses), np.mean(d_accs), np.mean(g_losses)),
                    (epoch - start_epoch) / max(epochs, 1))  # guard epochs == 0

                if epoch != 0:
                    # Save interval means to the history
                    self.loss_discriminator.append(float(np.mean(d_losses)))
                    self.acc_discriminator.append(float(np.mean(d_accs)))
                    self.loss_generator.append(float(np.mean(g_losses)))
                    d_losses = []
                    d_accs = []
                    g_losses = []
                    self.plot_hist(self.loss_discriminator, self.loss_generator, self.acc_discriminator)
                plt.close('all')

            # Point past the finished step. A later train() call, or a restart
            # after an interrupt, then continues rather than repeating it.
            self.epoch = epoch + 1
        print()
        return {'d_loss': self.loss_discriminator,
                'd_acc': self.acc_discriminator,
                'g_loss': self.loss_generator}

    @staticmethod
    def _figure_png(fig):
        """Encode ``fig`` as PNG bytes into a fresh buffer, then close the figure."""
        buf = BytesIO()
        fig.savefig(buf, format='png')
        plt.close(fig)
        return buf.getvalue()

    def _render_grid(self, images, r, c):
        """Return PNG bytes of the first ``r*c`` images drawn on an r x c grid.

        ``images`` is indexed as ``images[k][:, :, 0]``. Pixel values are mapped
        with 0.5*x + 0.5, i.e. from [-1, 1] to [0, 1].
        """
        fig, axs = plt.subplots(r, c, squeeze=False)
        for ax, image in zip(axs.flat, images):
            ax.imshow(0.5 * image[:, :, 0] + 0.5, cmap='gray')
            ax.axis('off')
        return self._figure_png(fig)

    def plot_real(self, X_train):
        """Show a random 2x3 grid of images from ``X_train`` in the ``real`` widget."""
        r, c = 2, 3
        idx = random.sample(range(len(X_train)), r*c)
        self.real.value = self._render_grid(X_train[idx], r, c)

    def plot_imgs(self):
        """Show a 2x3 grid of freshly generated samples in the ``img`` widget."""
        r, c = 2, 3
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        self.img.value = self._render_grid(gen_imgs, r, c)

    def plot_hist(self, d_loss, g_loss, acc):
        """Render the D/G loss curves into ``loss_plt`` and accuracy into ``acc_plt``.

        In ``train`` each point is the mean over one save interval.
        """
        # Each figure is encoded into its own buffer. Reusing one BytesIO left
        # stale trailing bytes whenever the second PNG was shorter than the first.
        fig, ax = plt.subplots()
        ax.plot(range(1, len(d_loss)+1), d_loss, linewidth=2, label='discriminator')
        ax.plot(range(1, len(g_loss)+1), g_loss, linewidth=2, label='generator')
        ax.set_ylabel('D/G loss')
        ax.set_xlabel('Steps')
        ax.legend()
        self.loss_plt.value = self._figure_png(fig)

        fig, ax = plt.subplots()
        ax.plot(range(1, len(acc)+1), acc, linewidth=2)
        ax.set_ylabel('Discriminator accuracy')
        ax.set_xlabel('Steps')
        self.acc_plt.value = self._figure_png(fig)
        

In [None]:
# Output widgets, in order: generated samples, real samples, loss curves,
# discriminator accuracy. They are handed to the GAN positionally.
img, real, loss_plt, acc_plt = (ipywidgets.Image() for _ in range(4))
dcgan = DCGAN(img, real, loss_plt, acc_plt)

In [None]:
# Layout: sample grids in the left column, training curves in the right one.
samples_column = ipywidgets.VBox([img, real])
curves_column = ipywidgets.VBox([loss_plt, acc_plt])
display(ipywidgets.HBox([samples_column, curves_column]))
# Load the images of digit 3 and draw the initial widget state.
dcgan.init(batch_size=64, cls=3)

In [None]:
TRAIN_EPOCHS = 1000  # train() runs TRAIN_EPOCHS + 1 steps (its range is inclusive)
SAVE_INTERVAL = 5    # refresh widgets / record history every SAVE_INTERVAL steps
history = dcgan.train(epochs=TRAIN_EPOCHS, save_interval=SAVE_INTERVAL)

In [None]:
# Interactive latent-space explorer: one slider per latent dimension drives the
# generator. The generator is trained on N(0, 1) noise, so each slider spans
# roughly +-3 standard deviations instead of only [0, 1].
sliders = [
    ipywidgets.FloatSlider(value=0.5, min=-3.0, max=3.0, step=0.05,
                           description="z[{}]".format(k))
    for k in range(dcgan.latent_dim)
]
gimg = ipywidgets.Image(width=320)


def current_latent():
    """Return the current slider values as a (1, latent_dim) array."""
    return np.asarray([[s.value for s in sliders]])


def plot_sample(latent_vector):
    """Render the generator output for the first vector in ``latent_vector`` into ``gimg``."""
    g_img = dcgan.generator.predict(latent_vector)
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(g_img[0, :, :, 0], cmap='gray')
    imbuf = BytesIO()
    fig.savefig(imbuf)
    plt.close(fig)  # release the figure; it is only needed as PNG bytes
    gimg.value = imbuf.getvalue()


def slider_changed(msg):
    """Slider observer: regenerate the sample from the current slider values."""
    plot_sample(current_latent())


for s in sliders:
    s.observe(slider_changed, names='value')

display(ipywidgets.HBox([ipywidgets.VBox(sliders), gimg]))
plot_sample(current_latent())