In [None]:
import io
import json
import random
import sys

import bqplot.pyplot
import ipywidgets
import PIL
import PIL.Image
from IPython.display import display

import numpy as np
import tensorflow as tf
import keras
import keras.backend as K

# Show only ERROR-level (and more severe) messages from TensorFlow's logger.
tf.get_logger().setLevel('ERROR')

def update_progress(msg, progress, status="", bar_length=32):
    """Render a one-line text progress bar, overwriting the current line.

    Args:
        msg: label printed before the bar.
        progress: completion fraction; values outside [0, 1] are clamped.
        status: optional trailing text (empty by default).
        bar_length: number of characters between the brackets.
    """
    progress = min(max(float(progress), 0.0), 1.0)
    filled = int(round(bar_length * progress))
    # Always draw the ">" head so the bar keeps a constant width, even at 0%
    # (the old code emitted one extra character when nothing was filled).
    head = max(filled, 1)
    bar = "=" * (head - 1) + ">" + "-" * (bar_length - head)
    text = "\r{0}: [{1}] {2:.2%} {3}".format(msg, bar, progress, status)
    sys.stdout.write(text)
    sys.stdout.flush()

***
# MNIST dataset
***

In [None]:
# Load MNIST (train and test splits merged) and keep only the images of digit 3.
(x_tr, y_tr), (x_te, y_te) = keras.datasets.mnist.load_data()
x_tr = np.concatenate([x_tr, x_te], axis=0)
y_tr = np.concatenate([y_tr, y_te], axis=0)
is_three = y_tr.flatten() == 3
# Filter images and labels with the same mask so they stay aligned.
x_tr = x_tr[is_three]
y_tr = y_tr[is_three]

In [None]:
# Show an r x c grid of randomly drawn real digits (s px per cell), drawn dark
# on white (255 - pixel) and filled column by column.
r, c, s = 6, 10, 96
digits = x_tr[np.random.randint(x_tr.shape[0], size=r * c)]
canvas = PIL.Image.new('RGB', (c * s + 2, r * s + 2), color='white')
for i, d in enumerate(digits):
    col, row = divmod(i, r)
    dimg = PIL.Image.fromarray(255 - d)
    dimg = dimg.resize((s - 8, s - 8), resample=PIL.Image.NEAREST)
    canvas.paste(dimg, box=(s * col, s * row))

buf = io.BytesIO()
canvas.save(buf, 'gif')
img = ipywidgets.Image(value=buf.getvalue())
display(img)

***
# Build the GAN
***

In [None]:
def build_generator(latent_dim, channels):
    """Build the generator graph and return its (input, output) tensors.

    A dense projection is reshaped to 3x3x128 and upsampled by three stride-2
    transposed convolutions (3 -> 7 -> 14 -> 28). A final 3x3 convolution maps
    to `channels` and a tanh squashes pixel values into [-1, 1].
    """
    def bn_elu(tensor):
        # BatchNormalization followed by ELU, as after every hidden layer.
        tensor = keras.layers.BatchNormalization()(tensor)
        return keras.layers.ELU()(tensor)

    noise = keras.layers.Input(shape=(latent_dim,))

    hidden = bn_elu(keras.layers.Dense(128 * 3 * 3)(noise))
    hidden = keras.layers.Reshape((3, 3, 128))(hidden)

    # 'valid' padding on the first step turns 3x3 into 7x7; the rest double.
    for n_filters, pad_mode in ((64, 'valid'), (32, 'same'), (16, 'same')):
        hidden = keras.layers.Conv2DTranspose(
            n_filters, 3, strides=2, padding=pad_mode)(hidden)
        hidden = bn_elu(hidden)

    hidden = keras.layers.Conv2D(channels, kernel_size=3, padding="same")(hidden)
    img = keras.layers.Activation("tanh")(hidden)

    return noise, img

In [None]:
def build_discriminator(img_shape):
    """Build the discriminator graph and return its (input, output) tensors.

    Gaussian input noise, then four Conv(3x3)/BatchNorm/ELU stages with 16, 32,
    64 and 128 filters. The first three are followed by 2x2 average pooling;
    spatial dropout follows the third and global average pooling the fourth.
    A single sigmoid unit scores how "real" each image is.
    """
    def conv_bn_elu(tensor, n_filters):
        # 3x3 same-padded convolution + BatchNormalization + ELU.
        tensor = keras.layers.Conv2D(
            n_filters, kernel_size=3, strides=1, padding="same")(tensor)
        tensor = keras.layers.BatchNormalization()(tensor)
        return keras.layers.ELU()(tensor)

    img = keras.layers.Input(shape=img_shape)
    hidden = keras.layers.GaussianNoise(0.05)(img)

    for n_filters in (16, 32, 64):
        hidden = keras.layers.AveragePooling2D(2)(conv_bn_elu(hidden, n_filters))
    hidden = keras.layers.SpatialDropout2D(0.25)(hidden)

    hidden = keras.layers.GlobalAveragePooling2D()(conv_bn_elu(hidden, 128))

    validity = keras.layers.Dense(1, activation='sigmoid')(hidden)

    return img, validity

In [None]:
def rounded_binary_accuracy(y_true, y_pred):
    """Fraction of positions where rounded targets equal rounded predictions.

    Both sides are rounded, so soft (smoothed) labels still count as 0/1.
    The mean is taken over the last axis.
    """
    matches = K.equal(K.round(y_true), K.round(y_pred))
    return K.mean(matches, axis=-1)

class DCGAN():
    """Deep convolutional GAN trained on the images of a single MNIST digit.

    Workflow: ``DCGAN()`` builds and compiles the models, ``init()`` loads the
    data and creates the preview/plot widgets, and ``train()`` runs training.
    Repeated ``train()`` calls continue from the last completed epoch.
    An "epoch" here is one batch update, not a full pass over the data.
    """

    def __init__(self):
        # Input shape
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 4
        # Rows and columns of the preview image grids.
        self.r, self.c = 2, 3

        # Build and compile the discriminator. Each compiled model gets its own
        # optimizer. A shared Adam instance would share its iteration counter
        # (used for bias correction) and would save the wrong optimizer state.
        inp_discriminator, out_discriminator = build_discriminator(self.img_shape)
        self.discriminator = keras.models.Model(
            inp_discriminator,
            out_discriminator,
            name='discriminator')
        self.discriminator.compile(
            loss='binary_crossentropy',
            optimizer=self._make_optimizer(),
            metrics=[rounded_binary_accuracy])

        # Frozen view of the same layers; weights are shared with
        # self.discriminator. It is frozen only after the discriminator has been
        # compiled, so the discriminator's own updates still train these weights
        # while the combined model's updates do not.
        # NOTE(review): keras.engine.network.Network is a private Keras 2.2/2.3
        # API; confirm it exists in the installed Keras version.
        self.frozen_discriminator = keras.engine.network.Network(
            inp_discriminator,
            out_discriminator,
            name='frozen_discriminator')
        self.frozen_discriminator.trainable = False

        # Build the generator
        inp_generator, out_generator = build_generator(self.latent_dim, self.channels)
        self.generator = keras.models.Model(inp_generator, out_generator, name='generator')

        # The combined model (generator stacked on the frozen discriminator)
        # trains the generator to fool the discriminator.
        self.combined = keras.models.Sequential()
        self.combined.add(self.generator)
        self.combined.add(self.frozen_discriminator)
        self.combined.compile(loss='binary_crossentropy', optimizer=self._make_optimizer())

        self.generator.summary()
        self.discriminator.summary()

    @staticmethod
    def _make_optimizer():
        """Return a fresh Adam optimizer with the GAN's hyper-parameters."""
        return keras.optimizers.Adam(lr=0.0002, beta_1=0.5, beta_2=0.999)

    def init(self, batch_size=128, cls=3):
        """Load the training digits, reset the history and create the widgets.

        Args:
            batch_size: number of real (and of fake) images per training step.
            cls: MNIST digit class to train on.

        Returns:
            (fake_plt, real_plt, loss_plt, acc_plt): the two image widgets and
            the two bqplot figures, ready to be laid out and displayed.
        """
        self.cls = cls
        self.batch_size = batch_size

        # Load the dataset (train and test splits merged) and keep one digit.
        (X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()
        X_train = np.concatenate([X_train, X_test], axis=0)
        Y_train = np.concatenate([Y_train, Y_test], axis=0)
        mask = Y_train.flatten() == cls
        X_train = X_train[mask]

        # Rescale to [-1, 1] with inverted polarity (0 -> +1, 255 -> -1), so
        # plot_images renders dark digits on a white background.
        X_train = 1. - X_train / 127.5
        self.X_train = np.expand_dims(X_train, axis=3)
        # Filter the labels with the same mask so they stay aligned with X_train.
        self.Y_train = Y_train[mask]

        # Fixed latent vectors for the fake-image preview; the first is the origin.
        self.noise = np.random.normal(0, 0.25, (self.r * self.c, self.latent_dim))
        self.noise[0] = np.zeros(self.latent_dim)

        # Training history: one entry per checkpoint, tagged with its epoch.
        self.loss_discriminator = []
        self.acc_discriminator = []
        self.acc_real_discriminator = []
        self.acc_fake_discriminator = []
        self.loss_generator = []
        self.hist_epochs = []
        # Number of completed training epochs (single-batch steps).
        self.epoch = 0

        # Initialize plots
        self.fake_plt = ipywidgets.Image()
        self.real_plt = ipywidgets.Image()
        self._refresh_previews()

        axes_loss = {'x': {'label': 'Epochs'},
                     'y': {'label': 'Losses',
                           'label_offset': '50px',
                           'tick_style': {'font-size': 10}
                          }
                    }
        axes_acc = {'x': {'label': 'Epochs'},
                    'y': {'label': 'Accuracy',
                          'label_offset': '50px',
                          'tick_style': {'font-size': 10}
                         }
                   }

        # Loss figure: mark 0 = discriminator loss, mark 1 = generator loss.
        self.loss_plt = bqplot.pyplot.figure()
        self.loss_plt.layout.height = '300px'
        self.loss_plt.layout.width = '400px'
        bqplot.pyplot.plot([0,1],[0.0,0.0], axes_options=axes_loss)
        bqplot.pyplot.plot([0,1],[1.0,1.0], colors=['orange'])

        # Accuracy figure: mark 0 = discriminator accuracy, y fixed to [0, 1].
        self.acc_plt  = bqplot.pyplot.figure()
        self.acc_plt.layout.height = '300px'
        self.acc_plt.layout.width = '400px'
        bqplot.pyplot.scales(scales={'y': bqplot.scales.LinearScale(min=0.0,max=1.0)})
        bqplot.pyplot.plot([0,1],[0.0,1.0], axes_options=axes_acc)

        return self.fake_plt, self.real_plt, self.loss_plt, self.acc_plt

    def _refresh_previews(self):
        """Redraw the fake grid (fixed preview noise) and a random real grid."""
        n_images = self.r * self.c
        real_imgs = self.X_train[np.random.randint(self.X_train.shape[0], size=n_images), :, :, 0]
        fake_imgs = self.generator.predict(self.noise)[:, :, :, 0]

        self.fake_plt.value = self.plot_images(fake_imgs)
        self.real_plt.value = self.plot_images(real_imgs)

    def _train_step(self, valid, fake, valid_fake):
        """One discriminator update (real + fake batch), then one generator update.

        Returns:
            (d_loss, d_loss_real, d_loss_fake, g_loss). The discriminator entries
            are [loss, rounded_binary_accuracy]; d_loss averages real and fake.
        """
        # Random batch of real images (sampled with replacement).
        idx = np.random.randint(0, self.X_train.shape[0], self.batch_size)
        imgs = self.X_train[idx]

        # Sample noise and generate a batch of new images
        noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Train the discriminator (real classified as ones and generated as zeros)
        d_loss_real = self.discriminator.train_on_batch(imgs, valid)
        d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train the generator (wants discriminator to mistake images as real)
        noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
        g_loss = self.combined.train_on_batch(noise, valid_fake)

        return d_loss, d_loss_real, d_loss_fake, g_loss

    def train(self, epochs, save_interval=50):
        """Run `epochs` single-batch training steps, resuming after self.epoch.

        Steps are numbered self.epoch + 1 .. self.epoch + epochs, so repeated
        calls never repeat a step number. At every multiple of `save_interval`:
        - the previews are redrawn;
        - progress is printed;
        - the averages since the previous checkpoint are appended to the
          history, and the plots are updated.
        Steps after the last checkpoint of a call are not recorded.

        Args:
            epochs: number of training steps to run.
            save_interval: checkpoint spacing in steps (must be >= 1).

        Returns:
            dict of history lists: 'd_loss', 'd_acc', 'g_loss', 'd_acc_real',
            'd_acc_fake' and 'epoch' (the step of each entry).

        Raises:
            ValueError: if save_interval < 1.
        """
        if save_interval < 1:
            raise ValueError("save_interval must be a positive integer")

        # Adversarial ground truths. They are loop-invariant, so build them once.
        # The generator targets are a separate array, so label smoothing on the
        # discriminator targets would not leak into the generator update.
        valid = np.ones((self.batch_size, 1))
        fake = np.zeros((self.batch_size, 1))
        valid_fake = np.ones((self.batch_size, 1))

        # Per-checkpoint accumulators, flushed into the history at each checkpoint.
        d_losses, d_accs, d_accs_real, d_accs_fake, g_losses = [], [], [], [], []
        start_epoch = self.epoch

        for epoch in range(start_epoch + 1, start_epoch + epochs + 1):
            d_loss, d_loss_real, d_loss_fake, g_loss = self._train_step(
                valid, fake, valid_fake)
            self.epoch = epoch

            d_losses.append(d_loss[0])
            d_accs.append(d_loss[1])
            d_accs_real.append(d_loss_real[1])
            d_accs_fake.append(d_loss_fake[1])
            g_losses.append(g_loss)

            if epoch % save_interval == 0:
                self._refresh_previews()

                mean_d_loss = float(np.mean(d_losses))
                mean_d_acc = float(np.mean(d_accs))
                mean_g_loss = float(np.mean(g_losses))
                update_progress("Epoch {: 5d} | D loss {:2.2f} D acc {:2.2f} | G loss {:2.2f}".format(
                    epoch, mean_d_loss, mean_d_acc, mean_g_loss), (epoch - start_epoch) / epochs)

                # Save to history
                self.hist_epochs.append(int(epoch))
                self.loss_discriminator.append(mean_d_loss)
                self.acc_discriminator.append(mean_d_acc)
                self.acc_real_discriminator.append(float(np.mean(d_accs_real)))
                self.acc_fake_discriminator.append(float(np.mean(d_accs_fake)))
                self.loss_generator.append(mean_g_loss)
                d_losses, d_accs, d_accs_real, d_accs_fake, g_losses = [], [], [], [], []

                self.plot_hist(save_interval,
                               self.loss_discriminator,
                               self.loss_generator,
                               self.acc_discriminator,
                               epochs=self.hist_epochs)
        print()
        return {'d_loss': self.loss_discriminator,
                'd_acc': self.acc_discriminator,
                'g_loss': self.loss_generator,
                'd_acc_real': self.acc_real_discriminator,
                'd_acc_fake': self.acc_fake_discriminator,
                'epoch': self.hist_epochs}

    def plot_images(self, images, cell_size=96):
        """Tile `images` (pixel values in [-1, 1]) into an r x c grid; return GIF bytes.

        Image i goes to column i // r and row i % r, so the grid fills column
        by column.
        """
        r, c, s = self.r, self.c, cell_size
        canvas = PIL.Image.new('RGB', (c*s+2, r*s+2), color='white')
        for i, d in enumerate(images):
            # Inverse of the [-1, 1] scaling used in init. Clip so that rounding
            # can never wrap around when casting to uint8.
            darr = np.clip(np.rint((np.asarray(d) + 1.0) * 127.5), 0, 255).astype('uint8')
            dimg = PIL.Image.fromarray(darr).resize((s-8, s-8), resample=PIL.Image.NEAREST)
            canvas.paste(dimg, box=(s * (i // r), s * (i % r)))

        buf = io.BytesIO()
        canvas.save(buf, 'gif')
        return buf.getvalue()

    def plot_hist(self, si, d_loss, g_loss, d_acc, epochs=None):
        """Push the training history into the bqplot figures created in init.

        Args:
            si: checkpoint spacing. The x-axis is si, 2*si, ... unless `epochs`
                is given.
            d_loss, g_loss, d_acc: per-checkpoint averages.
            epochs: optional explicit x position (epoch) of each checkpoint.
                It is used when its length matches the series.
        Nothing is drawn until at least two checkpoints exist.
        """
        def x_for(count):
            if epochs is not None and len(epochs) == count:
                return np.asarray(epochs)
            # The first checkpoint is recorded at epoch `si`, not at 0.
            return si * np.arange(1, count + 1)

        if len(d_loss) < 2:
            return

        self.loss_plt.marks[0].x = x_for(len(d_loss))
        self.loss_plt.marks[1].x = x_for(len(g_loss))
        self.acc_plt.marks[0].x = x_for(len(d_acc))

        self.loss_plt.marks[0].y = np.asarray(d_loss)
        self.loss_plt.marks[1].y = np.asarray(g_loss)
        self.acc_plt.marks[0].y = np.asarray(d_acc)

***
# Training
***

In [None]:
# Build and compile the discriminator, generator and stacked (combined) model.
# This also prints the generator and discriminator summaries.
dcgan = DCGAN()

In [None]:
# Load the training digits and lay out the live dashboard: loss and accuracy
# plots on the left, fake and real digit grids on the right.
fake, real, loss_plt, acc_plt = dcgan.init(batch_size=16, cls=3)
plot_box = ipywidgets.VBox(
    [loss_plt, acc_plt],
    layout=ipywidgets.Layout(justify_content='space-around'))
digit_box = ipywidgets.VBox(
    [fake, real],
    layout=ipywidgets.Layout(justify_content='space-around'))
display(ipywidgets.HBox([plot_box, digit_box]))

In [None]:
%%time
# One "epoch" here is a single batch update (see DCGAN.train). Previews,
# progress output and the history plots refresh every `save_interval` epochs.
history = dcgan.train(epochs=2000, save_interval=5)

In [None]:
# Persist both networks and the training history, so the trained generator
# can be reloaded later (see the "Generate" section below).
dcgan.generator.save("mnist-dcgan_generator.h5")
dcgan.discriminator.save("mnist-dcgan_discriminator.h5")
with open('mnist-dcgan.hist', 'w', encoding='utf-8') as fp:
    json.dump(history, fp)

***
# Generate digits from the latent space
***

In [None]:
# Reload the trained generator from disk. It was never compiled (only the
# discriminator and the combined model are), so skip compilation on load.
generator = keras.models.load_model(
    "mnist-dcgan_generator.h5",
    custom_objects={'rounded_binary_accuracy': rounded_binary_accuracy},
    compile=False)


In [None]:
# Interactive explorer for the generator loaded above. There is one slider per
# latent dimension; moving any slider re-renders the generated digit.
latent_dim = generator.input_shape[-1]
sl = [ipywidgets.FloatSlider(value=0.0, min=-1.0, max=1.0, step=0.1)
      for _ in range(latent_dim)]
gimg = ipywidgets.Image()

def plot_sample(latent_vector):
    """Render the generator output for `latent_vector` (shape (1, latent_dim)) into `gimg`."""
    g_img = generator.predict(latent_vector)[0, :, :, 0]
    # Inverse of the [-1, 1] training scaling; clip so rounding can't wrap in uint8.
    darr = np.clip(np.rint((g_img + 1.0) * 127.5), 0, 255).astype('uint8')
    dimg = PIL.Image.fromarray(darr).resize((256, 256), resample=PIL.Image.NEAREST)
    imbuf = io.BytesIO()
    dimg.save(imbuf, 'gif')
    gimg.value = imbuf.getvalue()

def slider_changed(msg):
    """Slider observer: redraw using the current slider values (`msg` is unused)."""
    plot_sample(np.asarray([[slider.value for slider in sl]]))

# Use a descriptive loop name so the global `s` (the grid cell size) is not clobbered.
for slider in sl:
    slider.observe(slider_changed, names='value')

sliderbox = ipywidgets.VBox(sl)
display(ipywidgets.HBox([sliderbox, gimg]))
plot_sample(np.asarray([[slider.value for slider in sl]]))