In [32]:
import tensorflow as tf 
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt 
import tensorflow.keras.backend as K

tf.config.list_physical_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [33]:
no_of_class = 102
IMG_SIZE = (224,224)
EMBED_DIM = 256

train = tf.keras.utils.image_dataset_from_directory('../102_flowers_dataset/train/',image_size=IMG_SIZE)
test = tf.keras.utils.image_dataset_from_directory('../102_flowers_dataset/valid/',image_size=IMG_SIZE)

Found 6552 files belonging to 102 classes.
Found 818 files belonging to 102 classes.


In [34]:
class Sampling(tf.keras.layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = K.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [35]:
encoder_input = tf.keras.layers.Input(shape=(IMG_SIZE[0],IMG_SIZE[1],3), name="encoder_input")
x = tf.keras.layers.Rescaling(scale=1.0/255,name='rescale')(encoder_input) 
x = tf.keras.layers.Conv2D(16, (3, 3), strides=2, activation="relu", padding="same")(x)
x = tf.keras.layers.Conv2D(32, (3, 3), strides=2, activation="relu", padding="same")(x)
x = tf.keras.layers.Conv2D(64, (3, 3), strides=2, activation="relu", padding="same")(x)
x = tf.keras.layers.Conv2D(64, (3, 3), strides=2, activation="relu", padding="same")(x)

shape_before_flattening = K.int_shape(x)[1:]  # the decoder will need this!

x = tf.keras.layers.Flatten()(x)

z_mean = tf.keras.layers.Dense(EMBED_DIM, name="z_mean")(x)
z_log_var = tf.keras.layers.Dense(EMBED_DIM, name="z_log_var")(x)

z = Sampling()([z_mean, z_log_var])


encoder = tf.keras.models.Model(encoder_input, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

In [36]:
decoder_input = tf.keras.layers.Input(shape=(EMBED_DIM,), name="decoder_input")
x = tf.keras.layers.Dense(np.prod(shape_before_flattening))(decoder_input)
x = tf.keras.layers.Reshape(shape_before_flattening)(x)

x = tf.keras.layers.Conv2DTranspose(256, (3, 3), strides=2, activation="relu", padding="same")(x)
x = tf.keras.layers.Conv2DTranspose(128, (3, 3), strides=2, activation="relu", padding="same")(x)
x = tf.keras.layers.Conv2DTranspose(64, (3, 3), strides=2, activation="relu", padding="same")(x)
x = tf.keras.layers.Conv2DTranspose(32, (3, 3), strides=2, activation="relu", padding="same")(x)

decoder_output = tf.keras.layers.Conv2DTranspose(3,(3, 3),strides=1,activation="sigmoid",padding="same",name="decoder_output")(x)

decoder = tf.keras.models.Model(decoder_input, decoder_output,name="decoder")
decoder.summary()

In [45]:
class VAE(tf.keras.models.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs):
        z_mean, z_log_var, z = encoder(inputs[0])
        reconstruction = decoder(z)
        return z_mean, z_log_var, reconstruction

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, reconstruction = self(data)
            beta = 500
            reconstruction_loss = tf.reduce_mean(
                beta
                * tf.keras.losses.binary_crossentropy(
                    data[0], reconstruction, axis=(1, 2, 3)
                )
            )
            kl_loss = tf.reduce_mean(
                tf.reduce_sum(
                    -0.5
                    * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)),
                    axis=1,
                )
            )
            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Step run during validation."""

        z_mean, z_log_var, reconstruction = self(data)
        beta = 500
        reconstruction_loss = tf.reduce_mean(
            beta
            * tf.keras.losses.binary_crossentropy(data[0], reconstruction, axis=(1, 2, 3))
        )
        kl_loss = tf.reduce_mean(
            tf.reduce_sum(
                -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)),
                axis=1,
            )
        )
        total_loss = reconstruction_loss + kl_loss

        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
        
model = VAE(encoder, decoder)


In [46]:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.00001))

In [47]:
model.fit(train,epochs=5,batch_size=8)

Epoch 1/5
[1m205/205[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 92ms/step - kl_loss: nan - reconstruction_loss: nan - total_loss: nan
Epoch 2/5
[1m205/205[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 76ms/step - kl_loss: nan - reconstruction_loss: nan - total_loss: nan
Epoch 3/5
[1m205/205[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 77ms/step - kl_loss: nan - reconstruction_loss: nan - total_loss: nan
Epoch 4/5
[1m205/205[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 78ms/step - kl_loss: nan - reconstruction_loss: nan - total_loss: nan
Epoch 5/5
[1m205/205[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 83ms/step - kl_loss: nan - reconstruction_loss: nan - total_loss: nan


<keras.src.callbacks.history.History at 0x7f701840f4d0>

: 