In [None]:
import sys
import random
import io
import logging

import ipywidgets
import PIL
import bqplot.pyplot

import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
# Only let TensorFlow log messages at ERROR level or above reach the notebook output.
logger = tf.get_logger()
logger.setLevel(logging.ERROR)

# Show the TensorFlow version the notebook is running against.
print(tf.__version__)

def update_progress(msg, progress):
    """Write a single-line text progress bar to stdout.

    The line starts with a carriage return so that successive calls overwrite
    each other instead of scrolling the output.

    Args:
        msg: Text shown in front of the bar.
        progress: Completed fraction. The percentage is printed as given; the
            bar itself is drawn for the value clamped to [0, 1] so that it
            always stays exactly ``bar_length`` characters wide.
    """
    bar_length = 32
    status = ""

    # Clamp only the drawn fill; out-of-range values would otherwise make the
    # bar shorter or longer than its frame.
    fraction = min(max(progress, 0.0), 1.0)
    filled = int(round(bar_length * fraction))

    if filled:
        # The last filled cell is drawn as the arrow head.
        bar = "=" * (filled - 1) + ">" + "-" * (bar_length - filled)
    else:
        # Nothing completed yet: no arrow head, just an empty bar.
        bar = "-" * bar_length

    text = "\r{0}: [{1}] {2:.2%} {3}".format(msg, bar, progress, status)
    sys.stdout.write(text)
    sys.stdout.flush()

***
# Build the GAN
***

In [None]:
def build_generator(latent_dim):
    """Build the generator network as a Keras functional model.

    A latent vector is projected by a dense layer, reshaped to a 7x7x128
    feature map, passed through two stride-2 transposed convolutions and
    finally mapped to a single channel with a tanh activation.

    Args:
        latent_dim: Length of the latent input vector.

    Returns:
        A ``keras.models.Model`` taking the latent vector and returning the
        tanh-activated image tensor.
    """
    layers = keras.layers
    latent_in = layers.Input(shape=(latent_dim,))

    # Dense projection, then fold the flat vector into a 7x7 feature map.
    h = layers.Dense(128 * 7 * 7)(latent_in)
    h = layers.BatchNormalization()(h)
    h = layers.ELU()(h)
    h = layers.Reshape((7, 7, 128))(h)

    # Two stride-2 upsampling stages with a shrinking number of filters.
    for n_filters in (64, 32):
        h = layers.Conv2DTranspose(n_filters, 3, strides=2, padding='same')(h)
        h = layers.BatchNormalization()(h)
        h = layers.ELU()(h)

    # Collapse to one channel; tanh keeps the output within [-1, 1].
    h = layers.Conv2D(1, kernel_size=5, padding='same')(h)
    generated = layers.Activation('tanh')(h)

    return keras.models.Model(inputs=latent_in, outputs=generated)

def build_discriminator(img_shape):
    """Build the discriminator network as a Keras functional model.

    The input image gets a small amount of Gaussian noise, then runs through
    three convolution blocks (the first two followed by 2x2 average pooling,
    the last by global average pooling), dropout and a single sigmoid unit.

    Args:
        img_shape: Shape of one input image, passed to ``keras.layers.Input``.

    Returns:
        A ``keras.models.Model`` mapping an image to one sigmoid-activated
        value (a probability, not a logit).
    """
    img = keras.layers.Input(shape=img_shape)
    # Regularizing input noise; Keras applies GaussianNoise in training mode only.
    x = keras.layers.GaussianNoise(0.05)(img)

    # Feed the noisy tensor `x` (not the raw `img`) into the first convolution,
    # otherwise the noise layer is left out of the graph entirely.
    x = keras.layers.Conv2D(32, kernel_size=5, strides=1, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ELU()(x)
    x = keras.layers.AveragePooling2D(2)(x)

    x = keras.layers.Conv2D(64, kernel_size=3, strides=1, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ELU()(x)
    x = keras.layers.AveragePooling2D(2)(x)

    x = keras.layers.Conv2D(128, kernel_size=3, strides=1, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ELU()(x)
    x = keras.layers.GlobalAveragePooling2D()(x)

    x = keras.layers.Dropout(0.25)(x)
    validity = keras.layers.Dense(1, activation='sigmoid')(x)

    model = keras.models.Model(inputs=img, outputs=validity)
    return model

In [None]:
class DCGAN():
    """DCGAN trained on one MNIST digit class, with live notebook widgets.

    Typical use: construct, call ``init()`` to load data and create the
    widgets, display the returned widgets, then call ``train()`` one or more
    times. ``self.epoch`` counts the training steps completed so far and is
    carried over between ``train()`` calls.
    """

    def __init__(self):
        """Create the networks, loss and optimizers and print model summaries."""
        # Input shape
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 4
        # Preview grid: r rows by c columns of images.
        self.r, self.c = 2, 3

        # The discriminator ends in a sigmoid, i.e. it outputs probabilities,
        # so the loss must NOT treat its output as logits.
        self.cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)
        # Kept as an attribute for backward compatibility. It is not used for
        # the per-step accuracy because a metric object accumulates over all
        # calls and would report a running average instead.
        self.accuracy = tf.keras.metrics.BinaryAccuracy()

        # One optimizer per network so each keeps its own Adam state and
        # iteration counter. `self.optimizer` remains as an alias.
        self.generator_optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0002, beta_1=0.5, beta_2=0.999)
        self.discriminator_optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0002, beta_1=0.5, beta_2=0.999)
        self.optimizer = self.generator_optimizer

        self.generator = build_generator(self.latent_dim)
        self.discriminator = build_discriminator(self.img_shape)

        self.generator.summary()
        self.discriminator.summary()

    def discriminator_loss(self, real_output, fake_output):
        """Mean of the cross-entropy on real (target 1) and fake (target 0) outputs."""
        real_loss = self.cross_entropy(tf.ones_like(real_output), real_output)
        fake_loss = self.cross_entropy(tf.zeros_like(fake_output), fake_output)
        total_loss = (real_loss + fake_loss)/2.0
        return total_loss

    def generator_loss(self, fake_output):
        """Cross-entropy of the discriminator output on fakes against target 1."""
        return self.cross_entropy(tf.ones_like(fake_output), fake_output)

    def discriminator_accuracy(self, real_output, fake_output):
        """Accuracy of the discriminator on the current batch only.

        Uses the stateless ``binary_accuracy`` function so that the value
        reflects just the given outputs, averaged over the real and the fake
        half.
        """
        real_acc = tf.reduce_mean(
            tf.keras.metrics.binary_accuracy(tf.ones_like(real_output), real_output))
        fake_acc = tf.reduce_mean(
            tf.keras.metrics.binary_accuracy(tf.zeros_like(fake_output), fake_output))
        total_acc = (real_acc + fake_acc)/2.0
        return total_acc

    @tf.function
    def train_step(self, images):
        """Run one simultaneous generator/discriminator update.

        Args:
            images: Batch of real images.

        Returns:
            Tuple ``(gen_loss, disc_loss, disc_acc)`` of scalar tensors.
        """
        # Draw as many latent vectors as there are real images in the batch.
        noise = tf.random.normal([tf.shape(images)[0], self.latent_dim], stddev=2.0)

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = self.generator(noise, training=True)

            real_output = self.discriminator(images, training=True)
            fake_output = self.discriminator(generated_images, training=True)

            gen_loss = self.generator_loss(fake_output)
            disc_loss = self.discriminator_loss(real_output, fake_output)

        # Accuracy is only reported, never differentiated, so it is computed
        # outside the tapes.
        disc_acc = self.discriminator_accuracy(real_output, fake_output)

        gradients_of_generator = gen_tape.gradient(
            gen_loss, self.generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(
            disc_loss, self.discriminator.trainable_variables)

        self.generator_optimizer.apply_gradients(
            zip(gradients_of_generator, self.generator.trainable_variables))
        self.discriminator_optimizer.apply_gradients(
            zip(gradients_of_discriminator, self.discriminator.trainable_variables))

        return gen_loss, disc_loss, disc_acc

    def _refresh_previews(self):
        """Redraw the preview widgets with random real images and current fakes."""
        picks = np.random.randint(self.X_train.shape[0], size=self.r*self.c)
        real_imgs = self.X_train[picks, :, :, 0]
        # verbose=0 keeps Keras from printing its own progress line per call.
        fake_imgs = self.generator.predict(self.noise, verbose=0)[:, :, :, 0]

        self.fake_plt.value = self.plot_images(fake_imgs)
        self.real_plt.value = self.plot_images(real_imgs)

    def init(self, batch_size=128, cls=3):
        """Load the data, reset the training history and create the widgets.

        Args:
            batch_size: Number of real images sampled per training step.
            cls: MNIST label to train on; only images with this label are kept.

        Returns:
            Tuple ``(fake_plt, real_plt, loss_plt, acc_plt)``: two
            ``ipywidgets.Image`` widgets and two bqplot figures.

        Raises:
            ValueError: If no image carries the label ``cls``.
        """
        self.cls = cls
        self.batch_size = batch_size

        # Load the dataset (train and test split are merged)
        (X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()
        X_train = np.append(X_train, X_test, axis=0)
        Y_train = np.append(Y_train, Y_test, axis=0)

        keep = Y_train.flatten() == cls
        if not keep.any():
            raise ValueError("no MNIST images with label {!r}".format(cls))
        X_train = X_train[keep]

        # Rescale to [-1, 1]. Note the sign: pixel value 0 maps to +1 and
        # 255 maps to -1, so the images are inverted.
        X_train = (1. - X_train / 127.5).astype(np.float32)
        self.X_train = np.expand_dims(X_train, axis=3)
        # Keep the labels aligned with the filtered images.
        self.Y_train = Y_train[keep]

        # Fixed latent vectors for the preview grid; the first one is all zeros.
        self.noise = np.random.normal(0, 0.5, (self.r * self.c, self.latent_dim))
        self.noise[0] = np.zeros(self.latent_dim)

        self.loss_generator = []
        self.loss_discriminator = []
        self.acc_discriminator = []
        self.epoch = 0

        # Initialize plots

        self.fake_plt = ipywidgets.Image()
        self.real_plt = ipywidgets.Image()
        self._refresh_previews()

        axes_loss = {'x': {'label': 'Epochs'},
                     'y': {'label': 'Losses',
                           'label_offset': '50px',
                           'tick_style': {'font-size': 10}
                          }
                    }
        axes_acc = {'x': {'label': 'Epochs'},
                    'y': {'label': 'Accuracy',
                          'label_offset': '50px',
                          'tick_style': {'font-size': 10}
                         }
                   }

        # Placeholder lines; plot_hist() later replaces their x/y data.
        self.loss_plt = bqplot.pyplot.figure()
        self.loss_plt.layout.height = '300px'
        self.loss_plt.layout.width = '400px'
        bqplot.pyplot.plot([0,1],[0.0,0.0], axes_options=axes_loss)
        bqplot.pyplot.plot([0,1],[1.0,1.0], colors=['orange'])

        self.acc_plt  = bqplot.pyplot.figure()
        self.acc_plt.layout.height = '300px'
        self.acc_plt.layout.width = '400px'
        bqplot.pyplot.scales(scales={'y': bqplot.scales.LinearScale(min=0.0,max=1.0)})
        bqplot.pyplot.plot([0,1],[0.0,1.0], axes_options=axes_acc)

        return self.fake_plt, self.real_plt, self.loss_plt, self.acc_plt

    def train(self, epochs, save_interval=50):
        """Run ``epochs`` training steps and update the widgets periodically.

        Each step trains on one random batch. Whenever the running step count
        ``self.epoch`` is a multiple of ``save_interval``, the previews and
        the progress line are refreshed and the mean loss/accuracy since the
        previous report is appended to the history.

        Args:
            epochs: Number of training steps to run in this call.
            save_interval: Report every this many steps; must be positive.

        Returns:
            Dict with the history lists under ``'d_loss'``, ``'d_acc'`` and
            ``'g_loss'``.

        Raises:
            ValueError: If ``save_interval`` is not positive.
        """
        if save_interval < 1:
            raise ValueError("save_interval must be a positive integer")

        # Per-step values collected since the last report
        g_loss, d_loss, d_accs = [],[],[]
        start_epoch = self.epoch

        for _ in range(epochs):
            idx = np.random.randint(0, self.X_train.shape[0], self.batch_size)
            imgs = self.X_train[idx]
            g_loss_step, d_loss_step, d_acc_step = self.train_step(imgs)
            g_loss.append(float(g_loss_step))
            d_loss.append(float(d_loss_step))
            d_accs.append(float(d_acc_step))
            # Count the step only after it has completed, so a later train()
            # call continues where this one stopped.
            self.epoch += 1

            # If at save interval => save generated image samples
            if self.epoch % save_interval == 0:
                self._refresh_previews()

                update_progress("Epoch {: 5d} | D loss {:2.2f} D acc {:2.2f} | G loss {:2.2f}".format(
                    self.epoch, np.mean(d_loss), np.mean(d_accs), np.mean(g_loss)), (self.epoch-start_epoch)/epochs)

                # Save to history
                self.loss_generator.append(float(np.mean(g_loss)))
                self.loss_discriminator.append(float(np.mean(d_loss)))
                self.acc_discriminator.append(float(np.mean(d_accs)))
                g_loss, d_loss, d_accs = [],[],[]
                self.plot_hist(save_interval,
                               self.loss_discriminator,
                               self.loss_generator,
                               self.acc_discriminator)
        print()
        return {'d_loss': self.loss_discriminator,
                'd_acc': self.acc_discriminator,
                'g_loss': self.loss_generator}

    def plot_images(self, images):
        """Tile images into one grid picture and return it as GIF bytes.

        Args:
            images: Iterable of 2-D arrays with values in [-1, 1].

        Returns:
            ``bytes`` of a GIF-encoded image, suitable for
            ``ipywidgets.Image.value``.
        """
        r, c, s = self.r, self.c, 96
        canvas = PIL.Image.new('RGB', (c*s+2, r*s+2), color='white')
        for i, d in enumerate(images):
            # Map [-1, 1] to [0, 254] grey levels.
            darr = ((d+1.0)*127).astype('uint8')
            dimg = PIL.Image.fromarray(darr).resize((s-8, s-8), resample=PIL.Image.NEAREST)
            # Fill the grid column by column.
            col, row = divmod(i, r)
            canvas.paste(dimg, box=(s*col, s*row))

        buf = io.BytesIO()
        canvas.save(buf, 'gif')
        return buf.getvalue()

    def plot_hist(self, si, d_loss, g_loss, d_acc):
        """Push the history lists into the loss and accuracy figures.

        Args:
            si: Number of steps between consecutive history entries.
            d_loss: Discriminator loss history.
            g_loss: Generator loss history.
            d_acc: Discriminator accuracy history.
        """
        # A line needs at least two points.
        if len(d_loss) < 2: return

        # Entry k was recorded after (k+1)*si steps.
        x_loss = np.arange(1, len(d_loss) + 1) * si
        x_acc = np.arange(1, len(d_acc) + 1) * si

        self.loss_plt.marks[0].x = x_loss
        self.loss_plt.marks[1].x = x_loss
        self.acc_plt.marks[0].x = x_acc

        self.loss_plt.marks[0].y = np.asarray(d_loss)
        self.loss_plt.marks[1].y = np.asarray(g_loss)
        self.acc_plt.marks[0].y = np.asarray(d_acc)

In [None]:
# Build both networks; the constructor also prints their summaries.
dcgan = DCGAN()

In [None]:
# Load the data for digit 3 and get the live widgets back.
fake, real, loss_plt, acc_plt = dcgan.init(batch_size=16, cls=3)

# Plots in the left column, sample images in the right one; both columns
# spread their children evenly.
digit_box = ipywidgets.VBox(
    [fake, real],
    layout=ipywidgets.Layout(justify_content='space-around'),
)
plot_box = ipywidgets.VBox(
    [loss_plt, acc_plt],
    layout=ipywidgets.Layout(justify_content='space-around'),
)
display(ipywidgets.HBox([plot_box, digit_box]))

In [None]:
%%time
# Train and keep the returned loss/accuracy history; the widgets displayed
# above are refreshed every `save_interval` steps.
history = dcgan.train(epochs=3000, save_interval=5)

In [None]:
# One slider per latent dimension, each ranging over [-1, 1] and starting at 0.
sl = [ipywidgets.FloatSlider(value=0.0, min=-1.0, max=1.0, step=0.1) 
      for _ in range(dcgan.latent_dim)]
# Image widget that shows the sample generated for the current slider values.
gimg = ipywidgets.Image()

def plot_sample(latent_vector):
    """Generate one image from ``latent_vector`` and show it in ``gimg``.

    Args:
        latent_vector: Array of latent vectors; only the first generated
            image is displayed.
    """
    sample = dcgan.generator.predict(latent_vector)[0, :, :, 0]

    # Map the generator output from [-1, 1] to grey levels and enlarge it.
    pixels = ((sample + 1.0) * 127).astype('uint8')
    picture = PIL.Image.fromarray(pixels).resize((256, 256), resample=PIL.Image.NEAREST)

    # The image widget takes encoded bytes, so serialize to an in-memory GIF.
    stream = io.BytesIO()
    picture.save(stream, 'gif')
    gimg.value = stream.getvalue()
    
def slider_changed(msg):
    """Observer callback: redraw the sample from the current slider values.

    Args:
        msg: Change notification passed by the widget; not used.
    """
    current_values = [slider.value for slider in sl]
    # Wrap in an outer list to form a batch holding a single latent vector.
    plot_sample(np.asarray([current_values]))

# Redraw the sample whenever the value of any slider changes.
for s in sl:
    s.observe(slider_changed, names='value')

# Sliders stacked on the left, generated image on the right.
sliderbox = ipywidgets.VBox(sl)
display(ipywidgets.HBox([sliderbox, gimg]))
# Draw once up front so the image is not empty before the first slider move.
plot_sample(np.asarray([[s.value for s in sl]]))