In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import datasets
from tensorflow.keras import layers, losses, metrics, models

# NOTE(review): `plt` is used by the plotting cells below but matplotlib.pyplot
# is never imported in this notebook; add `import matplotlib.pyplot as plt`
# here unless it is provided by the environment.
# load_data() returns a (train, test) pair of (images, labels) tuples; unpack all four arrays.
(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()

In [None]:
def preprocess(imgs):
    """Rescale, zero-pad and add a channel axis to a stack of images.

    Args:
        imgs: 3-D array indexed as (image, row, column). The pad spec below
            has exactly three entries, so other ranks are rejected by
            ``np.pad``. Values are divided by 255 (presumably 0-255 pixel
            intensities).

    Returns:
        float32 array of shape ``(n, rows + 4, cols + 4, 1)``.
    """
    scaled = imgs.astype("float32") / 255.0
    # Leave the image axis untouched; add a 2-pixel border of zeros around
    # every image (e.g. 28x28 -> 32x32).
    padded = np.pad(
        scaled,
        pad_width=((0, 0), (2, 2), (2, 2)),
        mode="constant",
        constant_values=0.0,
    )
    # Trailing singleton axis acts as the channel dimension.
    return padded[..., np.newaxis]

# Apply the same preprocessing to both splits.
# NOTE: this rebinds x_train / x_test, so the cell is not idempotent --
# preprocess() expects the raw 3-D arrays and should run once per fresh load.
x_train, x_test = preprocess(x_train), preprocess(x_test)

In [None]:
# Autoencoder encoder: three stride-2 convolutions (each halves the spatial
# size with "same" padding: 32 -> 16 -> 8 -> 4), then a dense projection to a
# 2-D embedding.
encoder_input = layers.Input(shape=(32, 32, 1), name="encoder_input")

x = layers.Conv2D(32, (3, 3), strides=2, activation="relu", padding="same")(encoder_input)
x = layers.Conv2D(64, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2D(128, (3, 3), strides=2, activation="relu", padding="same")(x)

# Per-sample feature-map shape (batch axis dropped); the decoder cell uses it
# to size its first Dense layer and to undo the Flatten below.
shape_before_flattening = K.int_shape(x)[1:]

x = layers.Flatten()(x)
encoder_output = layers.Dense(2, name="encoder_output")(x)

encoder = models.Model(encoder_input, encoder_output)

In [None]:
# Decoder: maps a 2-D latent point back to a 32x32x1 image. It mirrors the
# encoder: Dense + Reshape restore the pre-Flatten feature map, then three
# stride-2 transposed convolutions double the spatial size each time.
decoder_input = layers.Input(shape=(2,), name="decoder_input")

# np.prod returns a NumPy integer; cast to a plain int for the unit count.
x = layers.Dense(int(np.prod(shape_before_flattening)))(decoder_input)
x = layers.Reshape(shape_before_flattening)(x)

x = layers.Conv2DTranspose(128, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2DTranspose(64, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2DTranspose(32, (3, 3), strides=2, activation="relu", padding="same")(x)

# Final stride-1 convolution collapses to one channel; the sigmoid keeps the
# output in (0, 1), the same range as the rescaled input pixels.
decoder_output = layers.Conv2D(
    1, (3, 3), strides=1, activation="sigmoid", padding="same", name="decoder_output"
)(x)

decoder = models.Model(decoder_input, decoder_output)

In [None]:
# Chain encoder and decoder into a single end-to-end model: image -> 2-D
# embedding -> reconstructed image.
autoencoder = models.Model(encoder_input, decoder(encoder_output))

# The decoder ends in a sigmoid, so per-pixel binary cross-entropy is used as
# the reconstruction loss.
autoencoder.compile(optimizer="adam", loss="binary_crossentropy")

In [None]:
# Train the autoencoder to reproduce its input: the images are both the
# features and the targets, for training and for validation.
autoencoder.fit(
    x=x_train,
    y=x_train,
    validation_data=(x_test, x_test),
    batch_size=100,
    epochs=5,
    shuffle=True,
)

In [None]:
# Reconstruct the first 5,000 test images with the trained autoencoder.
example_images = x_test[:5000]
predictions = autoencoder.predict(x=example_images)

In [None]:
# Project the example images into the 2-D latent space and scatter-plot them.
embeddings = encoder.predict(x=example_images)

plt.figure(figsize=(8, 8))
plt.scatter(
    x=embeddings[:, 0],
    y=embeddings[:, 1],
    c="black",
    alpha=0.5,
    s=3,
)
plt.show()

In [None]:
# Draw 18 latent points uniformly from the bounding box of the observed
# embeddings, then decode them into images.
mins = embeddings.min(axis=0)
maxs = embeddings.max(axis=0)
sample = np.random.uniform(low=mins, high=maxs, size=(18, 2))
reconstructions = decoder.predict(x=sample)

In [None]:
class Sampling(layers.Layer):
    """Layer that samples a latent vector via the reparameterization trick."""

    def call(self, inputs):
        """Draw one latent sample per row of the batch.

        Args:
            inputs: pair ``(z_mean, z_log_var)`` of 2-D tensors; axis 0 is
                read as the batch size and axis 1 as the latent dimension.

        Returns:
            ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` is
            random normal noise of the same (batch, dim) shape.
        """
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = K.random_normal(shape=(batch, dim))
        # exp(0.5 * log_var) is the standard deviation.
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [None]:
# VAE encoder: same convolutional stack as the plain autoencoder, but it
# outputs the parameters of a distribution (mean and log-variance) plus a
# sample drawn from it, instead of a single embedding.
encoder_input = layers.Input(shape=(32, 32, 1), name="encoder_input")

x = layers.Conv2D(32, (3, 3), strides=2, activation="relu", padding="same")(encoder_input)
x = layers.Conv2D(64, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2D(128, (3, 3), strides=2, activation="relu", padding="same")(x)

# Per-sample feature-map shape (batch axis dropped), recorded before flattening.
shape_before_flattening = K.int_shape(x)[1:]

x = layers.Flatten()(x)
z_mean = layers.Dense(2, name="z_mean")(x)
z_log_var = layers.Dense(2, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])

# Expose all three tensors: the loss needs z_mean / z_log_var, the decoder needs z.
encoder = models.Model(encoder_input, [z_mean, z_log_var, z], name="encoder")

In [None]:
class VAE(models.Model):
    """Variational autoencoder wrapping an encoder and a decoder model.

    The encoder must return ``(z_mean, z_log_var, z)``; the decoder maps ``z``
    back to image space. Training minimizes a weighted reconstruction loss
    plus the KL divergence term computed in ``train_step``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        """Store the sub-models and create running-mean trackers for each loss.

        Args:
            encoder: model returning ``(z_mean, z_log_var, z)`` for a batch.
            decoder: model mapping latent samples ``z`` to reconstructions.
            **kwargs: forwarded to ``models.Model.__init__``.
        """
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Trackers reported by ``train_step``, in display order."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs):
        """Encode ``inputs``, decode the sampled ``z``.

        Returns:
            Tuple ``(z_mean, z_log_var, reconstruction)``.
        """
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstruction = self.decoder(z)
        return z_mean, z_log_var, reconstruction

    def train_step(self, data):
        """Run one optimization step on a batch and return the loss metrics.

        Args:
            data: batch of input images; also used as the reconstruction target.

        Returns:
            Dict mapping each tracker's name to its current running mean.
        """
        with tf.GradientTape() as tape:
            z_mean, z_log_var, reconstruction = self(data)
            # Cross-entropy averaged over height, width and channel gives one
            # value per image; 500 weights reconstruction against the KL term.
            reconstruction_loss = tf.reduce_mean(
                500
                * losses.binary_crossentropy(
                    data, reconstruction, axis=(1, 2, 3)
                )
            )
            # KL term: summed over latent dimensions, averaged over the batch.
            kl_loss = tf.reduce_mean(
                tf.reduce_sum(
                    -0.5
                    * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)),
                    axis=1,
                )
            )
            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {m.name: m.result() for m in self.metrics}


# Build and train the VAE. No `loss=` is passed to compile() because
# train_step computes the loss itself; only the input images are needed.
vae = VAE(encoder, decoder)
vae.compile(optimizer="adam")
vae.fit(x_train, epochs=5, batch_size=100)

In [None]:
# Decode random latent points and show them in a grid_height x grid_width grid.
grid_width, grid_height = (10, 3)

# The decoder's input layer is declared with shape (2,), so each latent
# sample must have exactly 2 components.
z_sample = np.random.normal(size=(grid_width * grid_height, 2))

reconstructions = decoder.predict(z_sample)

fig = plt.figure(figsize=(18, 5))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i in range(grid_width * grid_height):
    # Subplot indices are 1-based.
    ax = fig.add_subplot(grid_height, grid_width, i + 1)
    ax.axis("off")
    ax.imshow(reconstructions[i, :, :])