In [4]:
import tensorflow as tf
from keras import Model, layers, Sequential
import numpy as np

In [1]:
from keras import utils

train_data = utils.image_dataset_from_directory(
    "./celeba/img_align_celeba/",
    labels=None,
    color_mode="rgb",
    image_size=(64, 64),
    batch_size=(128),
    shuffle=True,
    seed=24,
    interpolation="bilinear"
)

Found 202599 files.


I0000 00:00:1729773436.065953   21707 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-10-24 21:37:16.381491: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2343] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [8]:
def preprocess(img):
    img = tf.cast(img, "float32") / 255.0
    return img

train = train_data.map(lambda x: preprocess(x))     

<class 'tensorflow.python.data.ops.map_op._MapDataset'>


In [19]:
import keras
class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch  = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = keras.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [34]:
encoder_input_layer = layers.Input(shape=(64, 64, 3), name="input_layer")
x = layers.Conv2D(128, (3, 3), strides=2, padding="same")(encoder_input_layer)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, (3, 3), strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, (3, 3), strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, (3, 3), strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, (3, 3), strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
x = layers.Flatten()(x)

z_mean  = layers.Dense(200, name="z_mean")(x)
z_log_var  = layers.Dense(200, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])

encoder = Model(encoder_input_layer, [z_mean, z_log_var, z])
encoder.summary()

In [35]:
decoder_input_layer = layers.Input(shape=(200,))
x = layers.Dense(512)(decoder_input_layer)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
x = layers.Reshape((2, 2, 128))(x)

x = layers.Conv2DTranspose((128), (3, 3), strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)

x = layers.Conv2DTranspose((128), (3, 3), strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)

x = layers.Conv2DTranspose((128), (3, 3), strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)

x = layers.Conv2DTranspose((128), (3, 3), strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)

decoder_output_layer = layers.Conv2DTranspose((3), (3, 3), strides=2, padding="same")(x)

decoder = Model(decoder_input_layer, decoder_output_layer)
decoder.summary()

In [36]:
class VAE(Model):
    def __init__ (self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstrution_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_trcker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [self.total_loss_tracker, self.reconstrution_loss_tracker, self.kl_loss_trcker]

    def call(self, inputs):
        z_mean, z_log_var, z = encoder(inputs)
        reconstruction = decoder(z)
        return z_mean, z_log_var, reconstruction

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, reconstruction = self(data)
            reconstruction_loss = tf.reduce_mean(
                500 * keras.losses.binary_crossentropy(data, reconstruction, axis=(1, 2, 3))
            )
            kl_loss = tf.reduce_mean(
                tf.reduce_sum( -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)), axis=1)
            )
            total_loss =reconstruction_loss + kl_loss

            grads = tape.gradient(total_loss, self.trainable_variables)
            self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

            self.total_loss_tracker.update_state(total_loss)
            self.reconstrution_loss_tracker.update_state(reconstruction_loss)
            self.kl_loss_trcker.update_state(kl_loss)

            return {m.name: m.result() for m in self.metrics}

vae = VAE(encoder, decoder)
vae.compile(optimizer="adam")

In [37]:
vae.fit(train,epochs=3)

Epoch 1/3
[1m1583/1583[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m506s[0m 317ms/step - kl_loss: 7.0785 - reconstruction_loss: 333.2290 - total_loss: 340.3075
Epoch 2/3
[1m1583/1583[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m497s[0m 314ms/step - kl_loss: 9.0034 - reconstruction_loss: 284.1945 - total_loss: 293.1980
Epoch 3/3
[1m1583/1583[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m644s[0m 407ms/step - kl_loss: 10.0138 - reconstruction_loss: 277.7852 - total_loss: 287.7990


<keras.src.callbacks.history.History at 0x7f33a09475f0>

In [39]:
grid_width, grid_height = (10, 3)
z_sample = np.random.normal(size=(grid_width * grid_height, 200))
reconstrutions = decoder.predict(z_sample)

from matplotlib.pylab import plt
fig = plt.figure(figsize=(18, 5))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i in range(grid_width * grid_height):
    ax = fig.add_subplot(grid_height, grid_width, i + 1)
    ax.axis("off")
    ax.imshow(reconstrutions[i, :, :])

ValueError: Input 0 of layer "functional_12" is incompatible with the layer: expected shape=(None, 200), found shape=(10, 3, 200)