In [35]:
import numpy as np
from tensorflow.keras import datasets
import matplotlib.pyplot as plt

(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()

x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]

x_train.shape







(60000, 28, 28, 1)

In [36]:

import tensorflow as tf
from tensorflow.keras import layers, models, metrics, backend as K
from tensorflow.keras.losses import binary_crossentropy

class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


In [37]:
# Encoder
encoder_input = layers.Input(shape=(28, 28, 1), name="encoder_input")
x = layers.Conv2D(32, (3, 3), strides=2, activation="relu", padding="same")(encoder_input)
x = layers.Conv2D(64, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2D(128, (3, 3), strides=2, activation="relu", padding="same")(x)
shape_before_flattening = K.int_shape(x)[1:]
x = layers.Flatten()(x)
x = layers.Dense(256, activation="relu")(x)
z_mean = layers.Dense(2, name="z_mean")(x)
z_log_var = layers.Dense(2, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = models.Model(encoder_input, [z_mean, z_log_var, z], name="encoder")

In [38]:
# Decoder
decoder_input = layers.Input(shape=(2,), name="decoder_input")
x = layers.Dense(14*14*128)(decoder_input)
x = layers.Reshape((14, 14, 128))(x)
x = layers.Conv2DTranspose(64, (3, 3), strides=2, activation='relu', padding="same")(x)
x = layers.Conv2DTranspose(1, (3, 3), activation='sigmoid', padding="same")(x)
decoder = models.Model(decoder_input, x)

In [42]:

class VAE(models.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstruction = self.decoder(z)
        return z_mean, z_log_var, reconstruction

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, reconstruction = self(data)
            reconstruction_loss = 500 * tf.reduce_mean(
                binary_crossentropy(data, reconstruction))
            kl_loss = tf.reduce_mean(
                tf.reduce_sum(-0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)), axis=1)
            )
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {m.name: m.result() for m in self.metrics}

In [46]:
# Create and train the VAE model
vae = VAE(encoder, decoder)
vae.compile(optimizer="adam")
vae.fit(
    x_train,
    epochs=3,
    batch_size=100
)

Epoch 1/3
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m188s[0m 307ms/step - kl_loss: 0.4402 - reconstruction_loss: 0.4693 - total_loss: 0.5134
Epoch 2/3
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m196s[0m 296ms/step - kl_loss: 0.0996 - reconstruction_loss: 0.4804 - total_loss: 0.4903
Epoch 3/3
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m199s[0m 292ms/step - kl_loss: 0.0908 - reconstruction_loss: 0.4813 - total_loss: 0.4904


<keras.src.callbacks.history.History at 0x7beb8e2dc0d0>

In [45]:
# prompt: reduce the KL loss and overall total loss

# ... (previous code)

class VAE(models.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstruction = self.decoder(z)
        return z_mean, z_log_var, reconstruction

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, reconstruction = self(data)
            reconstruction_loss = tf.reduce_mean(
                binary_crossentropy(data, reconstruction)
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + 0.1 * kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {m.name: m.result() for m in self.metrics}

# ... (rest of the code)


In [None]:
#updated part

In [41]:
from tensorflow.keras.layers import DepthwiseConv2D, Conv2D

# Encoder
encoder_input = layers.Input(shape=(28, 28, 1), name="encoder_input")
x = DepthwiseConv2D((3, 3), strides=2, activation="relu", padding="same")(encoder_input)
x = Conv2D(32, (1, 1), activation="relu")(x)
x = DepthwiseConv2D((3, 3), strides=2, activation="relu", padding="same")(x)
x = Conv2D(64, (1, 1), activation="relu")(x)
x = DepthwiseConv2D((3, 3), strides=2, activation="relu", padding="same")(x)
x = Conv2D(128, (1, 1), activation="relu")(x)
shape_before_flattening = K.int_shape(x)[1:]
x = layers.Flatten()(x)
x = layers.Dense(256, activation="relu")(x)
z_mean = layers.Dense(2, name="z_mean")(x)
z_log_var = layers.Dense(2, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = models.Model(encoder_input, [z_mean, z_log_var, z], name="encoder")

# Decoder
decoder_input = layers.Input(shape=(2,), name="decoder_input")
x = layers.Dense(14*14*128)(decoder_input)
x = layers.Reshape((14, 14, 128))(x)
x = DepthwiseConv2D((3, 3), strides=1, activation='relu', padding="same")(x)
x = Conv2D(64, (1, 1), activation='relu')(x)
x = layers.Conv2DTranspose(64, (3, 3), strides=2, activation='relu', padding="same")(x)
x = layers.Conv2DTranspose(1, (3, 3), activation='sigmoid', padding="same")(x)
decoder = models.Model(decoder_input, x)
