In [6]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Input, Flatten, Dense, Lambda, Reshape, MaxPool2D, BatchNormalization ,concatenate
from sklearn.preprocessing import OneHotEncoder

In [3]:
(f_tra, l_tra), (f_tes, l_tes) = keras.datasets.mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [4]:
img_width, img_height = f_tra.shape[1], f_tra.shape[2]
batch_size = 128
no_epochs = 50
validation_split = 0.2
verbosity = 1
latent_dim = 2
num_channels = 1
num_class = 10

In [31]:
f_tra = (f_tra.astype('float32') / 255).reshape(f_tra.shape[0], img_height, img_width, num_channels)
f_tes = (f_tes.astype('float32') / 255).reshape(f_tes.shape[0], img_height, img_width, num_channels)

In [12]:
class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [58]:
i      = Input(shape=(28*28 + 10), name='encoder_input') #(28, 28, 1)
# ii      = Input(shape=(10), name='label_input') #(28, 28, 1)
# x      = Conv2D(filters=4, kernel_size=5, strides=1, activation='relu', name='layer1')(i) #(24, 24, 4)
# x      = BatchNormalization()(x)

# x      = Conv2D(filters=8, kernel_size=5, strides=1, activation='relu')(x) #(20, 20, 8)
# x      = BatchNormalization()(x)

# x      = Conv2D(filters=16, kernel_size=5, strides=1, activation='relu')(x) #(16, 16, 16)
# x      = BatchNormalization()(x)

# x      = Flatten()(x) #(16*16*16, )
# x      = concatenate([x,ii])
x      = Dense(150)(i)
x      = Dense(120)(x) #(16*16, )
x      = Dense(50)(x)
# x      = BatchNormalization()(x)

mu     = Dense(latent_dim, name='z_mu')(x) #(2, )
sigma  = Dense(latent_dim, name='z_log_sigma')(x) #(2, )
z      = Sampling()([mu, sigma])

encoder = keras.Model([i,ii], [mu, sigma, z], name="encoder")
encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      [(None, 794)]        0                                            
__________________________________________________________________________________________________
dense_18 (Dense)                (None, 150)          119250      encoder_input[0][0]              
__________________________________________________________________________________________________
dense_19 (Dense)                (None, 120)          18120       dense_18[0][0]                   
__________________________________________________________________________________________________
dense_20 (Dense)                (None, 50)           6050        dense_19[0][0]                   
____________________________________________________________________________________________

In [53]:
latent_i = keras.Input(shape=(latent_dim+10,))
# ii      = Input(shape=(10), name='label_input') #(28, 28, 1)
# x        = concatenate([latent_i,ii])
x        = layers.Dense(16*16, activation="relu")(x)
# x        = BatchNormalization()(x)

x        = layers.Dense(16*16*16, activation="relu")(latent_i)
x        = layers.Reshape((16, 16, 16))(x)
# x        = BatchNormalization()(x)

x        = layers.Conv2DTranspose(8, 5, activation="relu", strides=1)(x)
# x        = BatchNormalization()(x)
x        = layers.Conv2DTranspose(4, 5, activation="relu", strides=1)(x)
# x        = BatchNormalization()(x)
o        = layers.Conv2DTranspose(1, 5, activation="sigmoid", strides=1)(x)

decoder = keras.Model([latent_i,ii], o, name="decoder")
decoder.summary()

Model: "decoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            [(None, 12)]         0                                            
__________________________________________________________________________________________________
dense_17 (Dense)                (None, 4096)         53248       input_5[0][0]                    
__________________________________________________________________________________________________
reshape_4 (Reshape)             (None, 16, 16, 16)   0           dense_17[0][0]                   
__________________________________________________________________________________________________
conv2d_transpose_9 (Conv2DTrans (None, 20, 20, 8)    3208        reshape_4[0][0]                  
____________________________________________________________________________________________

In [59]:
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = encoder(data)
            reconstruction = decoder(z)
            reconstruction_loss = tf.reduce_mean(
                keras.losses.binary_crossentropy(data, reconstruction)
            )
            reconstruction_loss *= 28 * 28
            kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            kl_loss = tf.reduce_mean(kl_loss)
            kl_loss *= -0.5
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }

In [60]:
vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam(learning_rate=0.002))

In [70]:
F = np.concatenate([f_tra, f_tes], axis=0)

F = F.reshape((F.shape[0],28*28*1))
F.shape

(70000, 784)

In [77]:
l_tra = (l_tra.astype('float32')).reshape(l_tra.shape[0], 1)
l_tes = (l_tes.astype('float32')).reshape(l_tes.shape[0], 1)

In [94]:
L = np.concatenate([l_tra, l_tes], axis=0)
L.shape



(70000, 1)

In [97]:
enc = OneHotEncoder()
L = enc.fit_transform(L)
L.reshape((L.shape[0],10))
F.reshape((L.shape[0],28*28))

TypeError: ignored

In [96]:
data = np.concatenate([F,L], axis=0)

ValueError: ignored

In [38]:
def viz_latent_space(epoch, verbos=False):
    cmap = plt.cm.get_cmap('jet', 10)
    mu, _, _ = encoder.predict(f_tra)
    plt.figure(figsize=(8, 8))
    for l in np.unique(l_tra):
        mask = (l_tra == l)
        z1, z2 = mu[mask, 0], mu[mask, 1]
        plt.scatter(z1, z2, color = cmap(l), alpha=0.25, label = l, edgecolors='black')
    #   plt.scatter(mu[:, 0], mu[:, 1], c=l_tra, alpha=0.3, label=l_tra)
    plt.xlabel('z - dim 1')
    plt.ylabel('z - dim 2')
    plt.legend()
    #   plt.colorbar()
    plt.grid()
    plt.title(f'latet space visualization: epoch {epoch}')

    x_l, x_u = min(mu[:, 0]), max(mu[:, 0])
    y_l, y_u = min(mu[:, 1]), max(mu[:, 1])

    filename1 = 'latent_%04d.png' % (epoch)
    plt.savefig(filename1)
    plt.show()
    return x_l, x_u, y_l, y_u

def plot_latent(epoch, verbos=False):
    # display a n*n 2D manifold of digits
    x_l, x_u, y_l, y_u = viz_latent_space(epoch, verbos)
    n = 20
    digit_size = 28
    figsize = 8
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(x_l, x_u, n)
    grid_y = np.linspace(y_l, y_u, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit

    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.title(f'image grid - epoch{epoch}')
    filename1 = 'grid_%04d.png' % (epoch)
    plt.savefig(filename1, bbox_inches='tight')
    plt.show()

In [39]:
class GANMonitor(keras.callbacks.Callback):
    def __init__(self, period=40):
        self.period = period

    def on_epoch_end(self, epoch, logs=None):
       if epoch % self.period == 0:
           plot_latent(epoch+1, True)

In [40]:
call_back = GANMonitor(10)
epochs = 50

In [50]:
history = vae.fit([F,L], epochs=epochs, batch_size=128, callbacks = [call_back])

Epoch 1/50


ValueError: ignored