In [None]:
import sys
import io
import logging

import ipywidgets
import PIL

import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
# TensorFlow's Python-side logger: only ERROR-level records are emitted.
logger = tf.get_logger()
logger.setLevel(level=logging.ERROR)

print(tf.__version__)

def update_progress(msg, progress, bar_length=32, status=""):
    """Redraw a single-line text progress bar in place on stdout.

    Args:
        msg: label printed before the bar.
        progress: completed fraction. Values outside [0, 1] are clamped, so the
            bar never overflows its width or shows a negative percentage.
        bar_length: number of characters drawn between the brackets.
        status: optional text appended after the percentage.

    The line starts with a carriage return and has no trailing newline, so
    successive calls overwrite each other. The caller prints a newline when done.
    """
    progress = min(max(float(progress), 0.0), 1.0)
    filled = int(round(bar_length * progress))
    # Keep the '>' head visible at 0% without growing past bar_length chars.
    # The former "=" * (block - 1) form produced bar_length + 1 chars at block 0.
    bar = "=" * max(filled - 1, 0) + ">" + "-" * (bar_length - max(filled, 1))
    sys.stdout.write("\r{0}: [{1}] {2:.2%} {3}".format(msg, bar, progress, status))
    sys.stdout.flush()

In [None]:
class Sampling(keras.layers.Layer):
    """Draw z = z_mean + exp(0.5 * z_log_var) * epsilon with epsilon ~ N(0, 1).

    Called with a pair ``[z_mean, z_log_var]`` of equal-shaped tensors. The
    noise is drawn with the same shape and dtype as ``z_mean``.
    """

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # tf.random.normal is available across TF 2.x. The former
        # keras.backend.random_normal is a legacy Keras-2 helper. Taking the
        # full dynamic shape also lifts the original 2-D-only restriction.
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [None]:
def build_encoder(latent_dim):
    """Build the convolutional encoder for 28x28x1 images.

    Two stride-2 'same' convolutions (32 then 64 filters) take the input from
    28x28 to 7x7. It is then flattened into a 16-unit ReLU layer, which feeds
    two linear heads of width ``latent_dim``: 'z_mean' and 'z_log_var'. A
    Sampling layer combines the two heads into z.

    Returns:
        keras.Model named 'encoder' mapping images to [z_mean, z_log_var, z].
    """
    layers = keras.layers
    images = keras.Input(shape=(28, 28, 1))

    h = images
    for filters in (32, 64):
        h = layers.Conv2D(filters, 3, activation='relu', strides=2, padding='same')(h)
    h = layers.Flatten()(h)
    h = layers.Dense(16, activation='relu')(h)

    z_mean = layers.Dense(latent_dim, name='z_mean')(h)
    z_log_var = layers.Dense(latent_dim, name='z_log_var')(h)
    z = Sampling()([z_mean, z_log_var])

    return keras.Model(images, [z_mean, z_log_var, z], name='encoder')

In [None]:
def build_decoder(latent_dim):
    """Build the transposed-convolution decoder mapping latent vectors to images.

    A ReLU dense layer is reshaped to 7x7x64. Two stride-2 'same'
    Conv2DTranspose layers (64 then 32 filters) upsample it to 28x28. A final
    single-filter sigmoid Conv2DTranspose yields values in (0, 1).

    Returns:
        keras.Model named 'decoder' with input shape (latent_dim,) and output
        shape (28, 28, 1).
    """
    layers = keras.layers
    z_in = keras.Input(shape=(latent_dim,))

    h = layers.Dense(7 * 7 * 64, activation='relu')(z_in)
    h = layers.Reshape((7, 7, 64))(h)
    for filters in (64, 32):
        h = layers.Conv2DTranspose(filters, 3, activation='relu', strides=2, padding='same')(h)
    pixels = layers.Conv2DTranspose(1, 3, activation='sigmoid', padding='same')(h)

    return keras.Model(z_in, pixels, name='decoder')

In [None]:
class VAE(keras.Model):
    """Convolutional VAE trained on MNIST, with a live decoded-latent-grid preview.

    Usage:
        1. Construct the model.
        2. Call ``init()`` to load the data and build encoder, decoder, optimizer
           and preview. It returns an ``ipywidgets.Image`` to display.
        3. Call ``train()``.

    The latent space is fixed at 2-D because ``z_sample`` is a 2-D grid of
    ``n x n`` points spanning [-4, 4] on both axes.
    """

    def __init__(self, **kwargs):
        super(VAE, self).__init__(**kwargs)
        # Must stay 2: z_sample below is a grid of 2-D latent points.
        self.latent_dim = 2
        self.n = 15            # preview grid is n x n decoded digits
        self.digit_size = 28   # side length in pixels of each preview tile
        # x ascends in the outer loop; y descends in the inner loop.
        self.z_sample = [[xi, yi] for xi in np.linspace(-4, 4, self.n) for yi in np.linspace(-4, 4, self.n)[::-1]]

    def init(self, batch_size=128):
        """Load MNIST (train + test), build the sub-models, optimizer and preview.

        Args:
            batch_size: batch size of the training ``tf.data.Dataset``.

        Returns:
            The ``ipywidgets.Image`` that ``train`` refreshes in place.
        """
        self.batch_size = batch_size
        (x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
        mnist_digits = np.concatenate([x_train, x_test], axis=0)
        mnist_digits = np.expand_dims(mnist_digits, -1).astype('float32') / 255
        # Remember the sample count so train() can report progress without a
        # hard-coded dataset size.
        self.num_samples = len(mnist_digits)
        self.dataset = tf.data.Dataset.from_tensor_slices(mnist_digits).batch(self.batch_size)

        self.encoder = build_encoder(self.latent_dim)
        self.decoder = build_decoder(self.latent_dim)

        self.optimizer = keras.optimizers.Adam()

        self.digit_box = ipywidgets.Image()
        self.digit_box.value = self.plot_images()
        return self.digit_box

    @tf.function
    def train_step(self, data):
        """Run one Adam step on a batch of images.

        Returns:
            Dict of scalar tensors 'loss', 'reconstruction_loss' and 'kl_loss'.
        """
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # binary_crossentropy averages over the last (channel) axis. The mean
            # over batch and pixels, times 28*28, is therefore the per-image sum
            # over pixels averaged over the batch.
            reconstruction_loss = tf.reduce_mean(keras.losses.binary_crossentropy(data, reconstruction))
            reconstruction_loss *= 28 * 28
            kl_per_dim = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            # Sum the KL over the latent dimensions, then average over the batch,
            # so it is on the same per-image scale as the reconstruction term.
            # A plain mean over both axes would divide its weight by latent_dim.
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_per_dim, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {'loss': total_loss,
                'reconstruction_loss': reconstruction_loss,
                'kl_loss': kl_loss}

    def train(self, epochs, save_interval=10):
        """Train for ``epochs`` passes over the dataset with a live progress line.

        The losses shown are running means of the per-batch losses within the
        current epoch.

        Args:
            epochs: number of passes over ``self.dataset``.
            save_interval: refresh the preview every ``save_interval`` batches.
                The preview is also always refreshed at the end of each epoch.
                Values <= 0 mean "only at the end of each epoch".
        """
        for e in range(epochs):
            loss_sums = {'loss': 0.0, 'reconstruction_loss': 0.0, 'kl_loss': 0.0}
            seen = 0
            for counter, batch in enumerate(self.dataset):
                losses = self.train_step(batch)
                for key in loss_sums:
                    loss_sums[key] += float(losses[key])
                n_batches = counter + 1
                # Count actual samples, since the final batch may be partial,
                # so the bar ends at exactly 100%.
                seen += int(batch.shape[0])
                update_progress("Epoch {: 5d} | Losses: Total {: 6.2f} Reconstruction {: 6.2f} KL {: 6.4f}".format(e,
                    loss_sums['loss'] / n_batches,
                    loss_sums['reconstruction_loss'] / n_batches,
                    loss_sums['kl_loss'] / n_batches), seen / self.num_samples)
                if save_interval > 0 and counter % save_interval == 0:
                    self.digit_box.value = self.plot_images()
            # Show the end-of-epoch decoder even when the last batch index was
            # not a multiple of save_interval.
            self.digit_box.value = self.plot_images()
            print()

    def plot_images(self):
        """Decode ``z_sample`` and return the tiled grid as GIF bytes.

        Tile i is decoded from ``z_sample[i]`` and pasted at column ``i // n``,
        row ``i % n``. Given z_sample's ordering, x grows to the right and y
        grows upwards.
        """
        # `import PIL` alone does not load the PIL.Image submodule.
        import PIL.Image

        side = self.n * self.digit_size
        canvas = PIL.Image.new('RGB', (side, side), color='white')
        # verbose=0 stops predict() from printing its own progress line over
        # the training progress bar.
        latents = np.asarray(self.z_sample, dtype='float32')
        x_decoded = 255 * self.decoder.predict(latents, verbose=0)[:, :, :, 0]
        for i, d in enumerate(x_decoded):
            col, row = divmod(i, self.n)
            dimg = PIL.Image.fromarray(d.astype('uint8')).resize((self.digit_size, self.digit_size), resample=PIL.Image.NEAREST)
            canvas.paste(dimg, box=(self.digit_size * col, self.digit_size * row))

        buf = io.BytesIO()
        canvas.save(buf, 'gif')
        return buf.getvalue()
    

In [None]:
# Construct the model shell only. Data, encoder/decoder, optimizer and the
# preview widget are created later by vae.init().
vae = VAE()

In [None]:
BATCH_SIZE = 256
EPOCHS = 30

# Build data, models and the preview widget. Display the widget before
# training so train() can refresh it in place.
digit_box = vae.init(batch_size=BATCH_SIZE)
display(digit_box)
vae.train(epochs=EPOCHS)