In [None]:
import io
import json
import random
import sys

import ipywidgets
import PIL
import PIL.Image
import bqplot.pyplot

import numpy as np
import tensorflow as tf
import keras
import keras.backend as K

tf.logging.set_verbosity(tf.logging.ERROR)

def update_progress(msg, progress):
    """Redraw a one-line text progress bar on stdout.

    The line starts with a carriage return so that successive calls overwrite
    each other instead of scrolling the output.

    Args:
        msg: Text shown in front of the bar.
        progress: Completed fraction. It is printed as a percentage as given,
            while the bar itself is drawn for the value clamped to [0, 1].
    """
    barLength = 32
    status = ""
    # Clamp the number of filled cells so the bar is always exactly barLength
    # characters wide: unclamped, a progress of 0 produced a negative repeat
    # count (33-character bar) and values above 1 overflowed the brackets.
    block = min(max(int(round(barLength * progress)), 0), barLength)
    if block == 0:
        bar = "-" * barLength
    else:
        bar = "=" * (block - 1) + ">" + "-" * (barLength - block)
    text = "\r{0}: [{1}] {2:.2%} {3}".format(msg, bar, progress, status)
    sys.stdout.write(text)
    sys.stdout.flush()

***
# CIFAR dataset

### CIFAR 100 fine dataset

|Class|N|Class|N|Class|N|Class|N|Class|N|
|--------|-|----------|-|----|-|---|-|----|-|
|Apple|0|Aquarium fish|1|Baby|2|Bear|3|Beaver|4|
|Bed|5|Bee|6|Beetle|7|Bicycle|8|Bottle|9|
|Bowl|10|Boy|11|Bridge|12|Bus|13|Butterfly|14|
|Camel|15|Can|16|Castle|17|Caterpillar|18|Cattle|19|
|Chair|20|Chimpanzee|21|Clock|22|Cloud|23|Cockroach|24|
|Couch|25|Crab|26|Crocodile|27|Cup|28|Dinosaur|29|
|Dolphin|30|Elephant|31|Flatfish|32|Forest|33|Fox|34|
|Girl|35|Hamster|36|House|37|Kangaroo|38|Computer keyboard|39|
|Lamp|40|Lawn-mower|41|Leopard|42|Lion|43|Lizard|44|
|Lobster|45|Man|46|Maple|47|Motorcycle|48|Mountain|49|
|Mouse|50|Mushroom|51|Oak|52|Orange|53|Orchid|54|
|Otter|55|Palm|56|Pear|57|Pickup truck|58|Pine|59|
|Plain|60|Plate|61|Poppy|62|Porcupine|63|Possum|64|
|Rabbit|65|Raccoon|66|Ray|67|Road|68|Rocket|69|
|Rose|70|Sea|71|Seal|72|Shark|73|Shrew|74|
|Skunk|75|Skyscraper|76|Snail|77|Snake|78|Spider|79|
|Squirrel|80|Streetcar|81|Sunflower|82|Sweet pepper|83|Table|84|
|Tank|85|Telephone|86|Television|87|Tiger|88|Tractor|89|
|Train|90|Trout|91|Tulip|92|Turtle|93|Wardrobe|94|
|Whale|95|Willow|96|Wolf|97|Woman|98|Worm|99|

### CIFAR 100 coarse dataset

|Class|N|Class|N|Class|N|Class|N|
|--------|-|----------|-|----|-|---|-|
|Aquatic mammal|0|Fish|1|Flower|2|Food container|3|
|Fruit or vegetable|4|Household electrical device|5|Household furniture|6|Insect|7|
|Large carnivore|8|Large man-made outdoor thing|9|Large natural outdoor scene|10|Large omnivore or herbivore|11|
|Medium-sized mammal|12|Non-insect invertebrate|13|People|14|Reptile|15|
|Small mammal|16|Tree|17|Vehicles Set 1|18|Vehicles Set 2|19|

### CIFAR 10 dataset

|Class|N|Class|N|Class|N|Class|N|Class|N|
|--------|-|----------|-|----|-|---|-|----|-|
|Airplane|0|Automobile|1|Bird|2|Cat|3|Deer|4|
|Dog|5|Frog|6|Horse|7|Ship|8|Truck|9|


***

In [None]:
# Alternative source: keras.datasets.cifar10.load_data()
(x_tr, y_tr), (x_te, y_te) = keras.datasets.cifar100.load_data(label_mode='coarse')
# Pool the train and test splits so every image of the chosen class is available.
x_tr = np.concatenate((x_tr, x_te), axis=0)
y_tr = np.concatenate((y_tr, y_te), axis=0)
# Keep only the images labelled with coarse class 2 ("Flower" in the table above).
x_tr = x_tr[y_tr.ravel() == 2]

In [None]:
# Preview: paste r*c randomly drawn images onto a white canvas, filling it
# column by column (each cell is s pixels wide, each thumbnail s-8 pixels).
r, c, s = 6, 8, 72
pics = x_tr[np.random.randint(x_tr.shape[0], size=r*c)]
canvas = PIL.Image.new('RGB', (c*s+2, r*s+2), color='white')
thumb_size = (s-8, s-8)
for idx, pic in enumerate(pics):
    col, row = divmod(idx, r)
    thumb = PIL.Image.fromarray(pic).resize(thumb_size)
    canvas.paste(thumb, box=(s*col, s*row))

# Encode the canvas as an in-memory GIF and show it through an image widget.
buf = io.BytesIO()
canvas.save(buf, 'gif')
img = ipywidgets.Image(value=buf.getvalue())
display(img)

***
# Build the GAN
***

In [None]:
def build_discriminator(img_shape):
    """Assemble the discriminator graph and return its two end points.

    The input passes through a GaussianNoise layer, then three
    Conv2D -> BatchNormalization -> ELU stages with 32, 64 and 128 filters.
    The first two stages end in 2x2 average pooling, the last one in global
    average pooling. Dropout and a single sigmoid unit produce the score.

    Args:
        img_shape: Shape of one input image, passed to ``keras.layers.Input``.

    Returns:
        Tuple ``(img, validity)`` of the input tensor and the sigmoid output
        tensor. No ``Model`` is created here; the caller wraps the tensors.
    """
    layers = keras.layers

    img = layers.Input(shape=img_shape)
    features = layers.GaussianNoise(0.05)(img)

    # (filters, kernel size, factory for the pooling layer closing the stage)
    stages = (
        (32, 5, lambda: layers.AveragePooling2D(2)),
        (64, 3, lambda: layers.AveragePooling2D(2)),
        (128, 3, lambda: layers.GlobalAveragePooling2D()),
    )
    for filters, kernel, make_pool in stages:
        features = layers.Conv2D(filters, kernel_size=kernel, strides=1, padding="same")(features)
        features = layers.BatchNormalization()(features)
        features = layers.ELU()(features)
        features = make_pool()(features)

    features = layers.Dropout(0.25)(features)
    validity = layers.Dense(1, activation='sigmoid')(features)

    return img, validity

In [None]:
def build_generator(latent_dim, channels):
    """Assemble the generator graph and return its two end points.

    A dense layer expands the latent vector to an 8x8x128 tensor, two
    stride-2 transposed convolutions (64 then 32 filters) each double the
    spatial size, and a final 5x5 convolution with ``tanh`` produces the
    image with ``channels`` channels.

    Args:
        latent_dim: Length of the latent input vector.
        channels: Number of channels of the generated image.

    Returns:
        Tuple ``(noise, img)`` of the latent input tensor and the ``tanh``
        output tensor. No ``Model`` is created here; the caller wraps them.
    """
    layers = keras.layers

    noise = layers.Input(shape=(latent_dim,))
    base = 8  # side length of the feature map before upsampling

    hidden = layers.Dense(128 * base * base)(noise)
    hidden = layers.BatchNormalization()(hidden)
    hidden = layers.ELU()(hidden)
    hidden = layers.Reshape((base, base, 128))(hidden)

    # Each stage doubles the width and height of the feature map.
    for filters in (64, 32):
        hidden = layers.Conv2DTranspose(filters, 3, strides=2, padding='same')(hidden)
        hidden = layers.BatchNormalization()(hidden)
        hidden = layers.ELU()(hidden)

    hidden = layers.Conv2D(channels, kernel_size=5, padding="same")(hidden)
    img = layers.Activation("tanh")(hidden)

    return noise, img

In [None]:
def rounded_binary_accuracy(y_true, y_pred):
    """Binary accuracy computed after rounding both targets and predictions.

    Rounding the targets as well lets the metric be used with soft labels
    (e.g. 0.8 counts as class 1).

    Args:
        y_true: Tensor of target values.
        y_pred: Tensor of predicted values, same shape as ``y_true``.

    Returns:
        Mean over the last axis of the element-wise matches.
    """
    hits = K.equal(K.round(y_true), K.round(y_pred))
    return K.mean(hits, axis=-1)

class DCGAN():
    """DCGAN for 32x32 RGB images with live training feedback in the notebook.

    Constructing the object builds and compiles the discriminator, the
    generator and the stacked generator -> frozen-discriminator model.
    Call ``init`` once to load the data and create the widgets, then call
    ``train``; ``train`` reads attributes that only ``init`` sets.
    """

    def __init__(self):
        """Build and compile all models and print the two model summaries."""
        # Input shape
        self.img_rows = 32
        self.img_cols = 32
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 16
        # Preview grid: r rows by c columns of sample images.
        self.r, self.c = 3, 4
        # Number of discriminator updates per generator update.
        self.i_discriminator = 1

        optimizer = keras.optimizers.Adam(lr=0.0002, beta_1=0.5, beta_2=0.999)

        # Build and compile the discriminator
        inp_discriminator, out_discriminator = build_discriminator(self.img_shape)
        self.discriminator = keras.models.Model(
            inp_discriminator,
            out_discriminator,
            name='discriminator')
        self.discriminator.compile(
            loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=[rounded_binary_accuracy])

        # Build the frozen discriminator copy: a Network over the very same
        # input/output tensors (so it shares the discriminator's layers) that
        # is marked non-trainable for use inside the combined model.
        self.frozen_discriminator = keras.engine.network.Network(
            inp_discriminator,
            out_discriminator,
            name='frozen_discriminator')
        self.frozen_discriminator.trainable = False

        # Build the generator
        inp_generator, out_generator = build_generator(self.latent_dim, self.channels)
        self.generator = keras.models.Model(inp_generator, out_generator, name='generator')

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = keras.models.Sequential()
        self.combined.add(self.generator)
        self.combined.add(self.frozen_discriminator)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

        self.generator.summary()
        self.discriminator.summary()

    def init(self, batch_size=128, cls=1, dataset='cifar10', mode='fine'):
        """Load one class of CIFAR images and create the display widgets.

        Args:
            batch_size: Batch size used for the real-image flow and in ``train``.
            cls: Label value of the class to keep.
            dataset: ``'cifar100'`` loads CIFAR-100; any other value loads
                CIFAR-10.
            mode: ``label_mode`` passed to the CIFAR-100 loader; not used for
                CIFAR-10.

        Returns:
            Tuple ``(fake_plt, real_plt, loss_plt, acc_plt)``: two image
            widgets (generated and real samples) and two bqplot figures
            (losses and discriminator accuracy).
        """
        self.cls = cls
        self.batch_size = batch_size

        # Load the dataset
        if dataset == 'cifar100':
            (X_train, Y_train), (X_test, Y_test) = keras.datasets.cifar100.load_data(label_mode=mode)
        else:
            (X_train, Y_train), (X_test, Y_test) = keras.datasets.cifar10.load_data()

        # Pool train and test splits, then keep only the requested class.
        X_train = np.append(X_train, X_test, axis=0)
        self.Y_train = np.append(Y_train, Y_test, axis=0)
        self.X_train = X_train[self.Y_train.flatten() == self.cls]

        # Real images are rescaled from [0, 255] to [-1, 1] (the range of the
        # generator's tanh output) and randomly mirrored.
        self.real_generator = keras.preprocessing.image.ImageDataGenerator(
            preprocessing_function = lambda i:i/127.5-1.0,
            horizontal_flip=True
        )
        self.real_flow = self.real_generator.flow(self.X_train, batch_size=batch_size)

        # Fixed latent vectors reused for every preview, so successive
        # previews are comparable; the first one is the all-zero vector.
        self.noise = np.random.normal(0, 0.5, (self.r * self.c, self.latent_dim))
        self.noise[0] = np.zeros(self.latent_dim)

        self.loss_discriminator = []
        self.acc_discriminator = []
        self.loss_generator = []
        self.epoch = 0

        # Initialize plots
        self.fake_plt = ipywidgets.Image()
        self.real_plt = ipywidgets.Image()

        real_imgs = next(self.real_flow)[:self.r*self.c]
        fake_imgs = self.generator.predict(self.noise)

        self.fake_plt.value = self.plot_images(fake_imgs)
        self.real_plt.value = self.plot_images(real_imgs)

        axes_loss = {'x': {'label': 'Epochs'},
                     'y': {'label': 'Losses',
                           'label_offset': '50px',
                           'tick_style': {'font-size': 10}}}
        axes_acc = {'x': {'label': 'Epochs'},
                    'y': {'label': 'Accuracy',
                          'label_offset': '50px',
                          'tick_style': {'font-size': 10}}}

        # The placeholder lines are replaced with real data by plot_hist:
        # loss figure mark 0 = discriminator, mark 1 = generator (orange).
        self.loss_plt = bqplot.pyplot.figure(min_aspect_ratio=4/3, max_aspect_ratio=4/3)
        bqplot.pyplot.plot([0,1],[0.5,0.5], axes_options=axes_loss)
        bqplot.pyplot.plot([0,1],[0.75,0.75], colors=['orange'])
        self.acc_plt  = bqplot.pyplot.figure(min_aspect_ratio=4/3, max_aspect_ratio=4/3)
        bqplot.pyplot.scales(scales={'y': bqplot.scales.LinearScale(min=0.0,max=1.0)})
        bqplot.pyplot.plot([0,1],[0.5,0.5], axes_options=axes_acc)

        return self.fake_plt, self.real_plt, self.loss_plt, self.acc_plt

    def train(self, epochs, save_interval=50):
        """Run the adversarial training loop and update the widgets.

        The loop counter is stored in ``self.epoch`` and runs from its current
        value through ``start + epochs`` inclusive, i.e. ``epochs + 1``
        iterations, so the last iteration reports 100% progress.

        Args:
            epochs: Number of iterations to advance ``self.epoch`` by.
            save_interval: Every ``save_interval`` iterations the previews and
                the progress line are refreshed and (except at iteration 0)
                the mean metrics since the last refresh are appended to the
                history and plotted.

        Returns:
            Dict with the accumulated history lists under the keys
            ``'d_loss'``, ``'d_acc'`` and ``'g_loss'``.
        """
        # Metrics collected since the last history entry
        d_losses = []
        d_accs = []
        g_losses = []
        start_epoch = self.epoch

        # Adversarial ground truths for the generated batches; they never
        # change, so they are built once instead of on every iteration.
        fake = np.zeros((self.batch_size, 1))
        valid_fake = np.ones((self.batch_size, 1))

        for self.epoch in range(start_epoch, start_epoch+epochs+1):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            for _ in range(self.i_discriminator):

                imgs = next(self.real_flow)
                # Soft labels in (0.75, 1] for real images; sized from the
                # batch actually received, which may be a short final batch.
                valid = np.ones((len(imgs), 1)) - np.random.uniform(low=0.0, high=0.25, size=(len(imgs), 1))

                # Sample noise and generate a batch of new images
                noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
                gen_imgs = self.generator.predict(noise)

                # Train the discriminator (real classified as ones and generated as zeros)
                d_loss_real = self.discriminator.train_on_batch(imgs, valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                d_losses.append(d_loss[0])
                d_accs.append(d_loss[1])

            # ---------------------
            #  Train Generator
            # ---------------------

            # Train the generator (wants discriminator to mistake images as real)
            noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid_fake)
            g_losses.append(g_loss)

            # If at save interval => refresh previews, progress and history
            if self.epoch % save_interval == 0:
                real_imgs = next(self.real_flow)[:self.r*self.c]
                fake_imgs = self.generator.predict(self.noise)

                self.fake_plt.value = self.plot_images(fake_imgs)
                self.real_plt.value = self.plot_images(real_imgs)

                # max(..., 1) avoids a ZeroDivisionError when epochs == 0.
                update_progress("Epoch {: 5d} | D loss {:2.2f} D acc {:2.2f} | G loss {:2.2f}".format(
                    self.epoch, np.mean(d_losses), np.mean(d_accs), np.mean(g_losses)),
                    (self.epoch-start_epoch)/max(epochs, 1))

                if self.epoch != 0:
                    # Save to history
                    self.loss_discriminator.append(float(np.mean(d_losses)))
                    self.acc_discriminator.append(float(np.mean(d_accs)))
                    self.loss_generator.append(float(np.mean(g_losses)))
                    d_losses = []
                    d_accs = []
                    g_losses = []
                    self.plot_hist(save_interval, self.loss_discriminator, self.loss_generator, self.acc_discriminator)
        print()
        return {'d_loss': self.loss_discriminator,
                'd_acc': self.acc_discriminator,
                'g_loss': self.loss_generator}

    @staticmethod
    def _to_uint8(img):
        """Map a float image in [-1, 1] to uint8 pixel values in [0, 255].

        This inverts the ``i/127.5-1.0`` preprocessing applied to the real
        images. Values are rounded and clipped so out-of-range inputs cannot
        wrap around in the uint8 cast.
        """
        return np.clip(np.rint((img + 1.0) * 127.5), 0, 255).astype('uint8')

    def plot_images(self, images):
        """Render images as one GIF-encoded grid.

        Args:
            images: Iterable of float images scaled to [-1, 1]; they are
                placed column by column on a grid of ``self.r`` rows by
                ``self.c`` columns.

        Returns:
            Bytes of the GIF-encoded canvas.
        """
        s = 72  # cell size in pixels; thumbnails are s-8 pixels wide
        canvas = PIL.Image.new('RGB', (self.c*s+2, self.r*s+2), color='white')
        for i, d in enumerate(images):
            dimg = PIL.Image.fromarray(self._to_uint8(d)).resize((s-8, s-8))
            col, row = divmod(i, self.r)
            canvas.paste(dimg, box=(s*col, s*row))

        buf = io.BytesIO()
        canvas.save(buf, 'gif')
        return buf.getvalue()

    def plot_hist(self, si, d_loss, g_loss, acc):
        """Push the history lists into the loss and accuracy figures.

        Args:
            si: Spacing between consecutive x values (the save interval).
            d_loss: Discriminator loss history (loss figure, mark 0).
            g_loss: Generator loss history (loss figure, mark 1).
            acc: Discriminator accuracy history (accuracy figure, mark 0).
        """
        x_data = range(1, si*len(d_loss)+1, si)
        self.loss_plt.marks[0].x = x_data
        self.loss_plt.marks[0].y = d_loss
        self.loss_plt.marks[1].x = x_data
        self.loss_plt.marks[1].y = g_loss

        x_data = range(1, si*len(acc)+1, si)
        self.acc_plt.marks[0].x = x_data
        self.acc_plt.marks[0].y = acc
        

***
# Training
***

In [None]:
# Build and compile the discriminator, generator and combined model;
# the constructor also prints the generator and discriminator summaries.
dcgan = DCGAN()

In [None]:
# Load the data for one class and lay out the four widgets returned by init():
# sample images on the top row, loss and accuracy figures below.
fake, real, loss_plt, acc_plt = dcgan.init(batch_size=32, cls=7, dataset='cifar10', mode='coarse')
digit_box = ipywidgets.HBox([fake, real])
plot_box  = ipywidgets.HBox([loss_plt, acc_plt])
for box in (digit_box, plot_box):
    box.layout.justify_content = 'space-around'
display(ipywidgets.VBox([digit_box, plot_box]))

In [None]:
%%time
# Run the adversarial training loop; the preview images, progress line and
# history plots are refreshed every `save_interval` iterations. The returned
# dict holds the loss/accuracy history lists.
history = dcgan.train(epochs=10000, save_interval=5)

In [None]:
# Persist both trained networks and the training history.
# (`json` is imported in the notebook's top import cell.)
dcgan.generator.save("cifar-dcgan_generator.h5")
dcgan.discriminator.save("cifar-dcgan_discriminator.h5")
with open('cifar-dcgan.hist', 'w') as fp:
    json.dump(history, fp)

***
# Generate pictures
***

In [None]:
# Reload the generator that was saved to disk in the previous section.
# custom_objects maps the metric name to the function defined in this notebook.
generator = keras.models.load_model("cifar-dcgan_generator.h5",
    custom_objects={'rounded_binary_accuracy': rounded_binary_accuracy})


In [None]:
# Interactive latent-space explorer driven by the generator reloaded from
# disk, so this section does not depend on the in-memory training object.
# One slider per latent dimension; the width is read from the model's input.
latent_dim = generator.input_shape[-1]
sl = [ipywidgets.FloatSlider(value=0.0, min=-1.0, max=1.0, step=0.1)
      for _ in range(latent_dim)]
gimg = ipywidgets.Image()

def plot_sample(latent_vector):
    """Generate one image from ``latent_vector`` and show it in ``gimg``.

    Args:
        latent_vector: Array of shape (1, latent_dim) fed to the generator.
    """
    g_img = generator.predict(latent_vector)[0]
    # Invert the i/127.5-1.0 scaling used for the training images; round and
    # clip so the uint8 cast can neither truncate white to 254 nor wrap.
    darr = np.clip(np.rint((g_img + 1.0) * 127.5), 0, 255).astype('uint8')
    dimg = PIL.Image.fromarray(darr).resize((256, 256))
    imbuf = io.BytesIO()
    dimg.save(imbuf, 'gif')
    gimg.value = imbuf.getvalue()

def slider_changed(msg):
    """Observer callback: redraw the sample from the current slider values."""
    l_space = np.asarray([[s.value for s in sl]])
    plot_sample(l_space)

for s in sl:
    s.observe(slider_changed, names='value')

sliderbox = ipywidgets.VBox(sl)
picbox = ipywidgets.VBox([gimg])
display(ipywidgets.HBox([sliderbox, picbox]))
# Draw the initial image for the default (all-zero) slider positions.
plot_sample(np.asarray([[s.value for s in sl]]))