# Imports

In [None]:
import os
import shutil
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# Data and Checkpoint Paths

In [None]:
# All project folders live under one root directory (apparently a
# Colab-mounted Google Drive, judging by the path). Each path keeps its
# trailing slash.
_bayc_root = '/content/drive/MyDrive/VAE/BAYC'
model_ckpt = f'{_bayc_root}/Models/'      # checkpoint directory
train_path = f'{_bayc_root}/Train_data/'  # training image directory
test_path = f'{_bayc_root}/Test_data/'    # validation/test image directory

# Encoder and Decoder Definition

In [None]:
latent_dim = 16
image_size = 512

# Encoder: two stride-2 3x3 convolutions ('valid' padding) and a dense layer
# that emits 2 * latent_dim numbers per image. The VAE splits them in half:
# the first half is the posterior mean, the second half its log-variance.
encoder = tf.keras.Sequential(
    [
        tf.keras.layers.InputLayer(input_shape=(image_size, image_size, 3)),
        tf.keras.layers.Conv2D(
            filters=32, kernel_size=3, strides=(2, 2), activation='relu'),
        tf.keras.layers.Conv2D(
            filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
        tf.keras.layers.Flatten(),
        # Zero kernel (and Keras' default zero bias), so every image starts
        # at mean 0 / logvar 0, i.e. q(z|x) initially equals the N(0, I) prior.
        tf.keras.layers.Dense(latent_dim + latent_dim, kernel_initializer="zeros"),
    ]
)

# Decoder: project the latent vector onto a (base_size, base_size, 64) map,
# then upsample twice with stride-2 transposed convolutions, reaching
# (image_size, image_size, 3) when image_size is divisible by 4. The final
# sigmoid yields per-pixel values in (0, 1), i.e. probabilities, not logits.
# Parenthesised explicitly: `image_size//4 * image_size//4` parses as
# ((image_size//4) * image_size) // 4. That only matches base_size**2 for
# some sizes, and then the Dense width and the Reshape target disagree.
base_size = image_size // 4
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(base_size * base_size * 64, activation="relu")(latent_inputs)
x = layers.Reshape((base_size, base_size, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(3, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs)

# VAE Definition

In [None]:
class VAE(tf.keras.Model):
    """Variational autoencoder built from a separate encoder and decoder.

    The encoder must return one tensor whose axis 1 holds the posterior mean
    and log-variance back to back; `encode` splits it in half. The decoder
    maps latent samples back to image space.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def encode(self, x):
        """Return ``(mean, logvar)``: the two halves of the encoder output along axis 1."""
        mean_logvar = self.encoder(x)
        mean, logvar = tf.split(mean_logvar, num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """Draw ``z = mean + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``.

        Uses the dynamic ``tf.shape`` so the sample also works when the batch
        dimension is unknown, e.g. inside a traced function. The static
        ``mean.shape`` would contain ``None`` there.
        """
        eps = tf.random.normal(shape=tf.shape(mean), dtype=mean.dtype)
        return eps * tf.exp(logvar * 0.5) + mean

    def decode(self, z):
        """Map latent samples ``z`` through the decoder."""
        return self.decoder(z)

    def call(self, x):
        """Encode ``x``, sample one latent per input, and decode it."""
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        return self.decode(z)

# Assemble the full model from the two sub-networks. Train with Adam at a
# small, fixed learning rate.
vae = VAE(encoder=encoder, decoder=decoder)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)

def log_normal_pdf(sample, mean, logvar, raxis=1):
    """Log-density of ``sample`` under a diagonal Gaussian, summed over ``raxis``.

    Computes ``sum(-0.5 * ((sample - mean)**2 * exp(-logvar) + logvar + log(2*pi)))``
    along ``raxis``. ``mean`` and ``logvar`` broadcast against ``sample``, so
    passing scalar ``0.`` for both gives the standard-normal log-density.
    """
    log2pi = tf.math.log(2. * np.pi)
    # No epsilon inside exp(): adding 1e-8 to the exponent rounds away in
    # float32 and provided no numerical protection.
    return tf.reduce_sum(
        -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
        axis=raxis)


def compute_loss(model, x):
    """Return the negative single-sample ELBO, averaged over the batch.

    Args:
        model: object exposing ``encode``, ``reparameterize`` and ``decode``.
        x: batch of images with values in [0, 1]. The reconstruction term is
            summed over axes 1-3, so a rank-4 (batch, H, W, C) input is
            assumed.

    Returns:
        Scalar tensor ``-mean(log p(x|z) + log p(z) - log q(z|x))``.
    """
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    # The decoder ends in a sigmoid, so this is already a per-pixel
    # probability. Feeding it to sigmoid_cross_entropy_with_logits would
    # apply a second sigmoid and squash every reconstruction toward 0.5.
    x_prob = model.decode(z)
    eps = 1e-7  # keep log() finite when the sigmoid saturates at 0 or 1
    x_prob = tf.clip_by_value(x_prob, eps, 1. - eps)
    cross_ent = -(x * tf.math.log(x_prob) + (1. - x) * tf.math.log(1. - x_prob))
    logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
    logpz = log_normal_pdf(z, 0., 0.)
    logqz_x = log_normal_pdf(z, mean, logvar)
    return -tf.reduce_mean(logpx_z + logpz - logqz_x)

@tf.function
def train_step(model, x, optimizer):
    """Run one optimisation step on batch ``x`` and return its loss.

    Each gradient is clipped element-wise to [-1, 1] before being applied.
    Returning the loss lets callers log it without recomputing it; callers
    that ignore the return value are unaffected.
    """
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x)
    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    clipped_gradients = [tf.clip_by_value(grad, -1.0, 1.0) for grad in gradients]
    optimizer.apply_gradients(zip(clipped_gradients, variables))
    return loss

# Data Loaders

In [None]:
batch_size = 80

# Both splits are read with the same settings. Labels are inferred from the
# sub-folders only so each element is an (images, labels) pair; the training
# and evaluation code ignores the labels.
train_ds = tf.keras.utils.image_dataset_from_directory(
    directory=train_path,
    labels='inferred',
    label_mode='categorical',
    batch_size=batch_size,
    image_size=(image_size, image_size))

valid_ds = tf.keras.utils.image_dataset_from_directory(
    directory=test_path,
    labels='inferred',
    label_mode='categorical',
    batch_size=batch_size,
    image_size=(image_size, image_size))

# Scale pixels from [0, 255] to [0, 1] to match the decoder's sigmoid output.
# `layers.Rescaling` replaces the deprecated
# `layers.experimental.preprocessing.Rescaling` alias.
normalization_layer = layers.Rescaling(1./255)

# Prefetch so the next batch is decoded while the current one trains.
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y)).prefetch(tf.data.AUTOTUNE)
valid_ds = valid_ds.map(lambda x, y: (normalization_layer(x), y)).prefetch(tf.data.AUTOTUNE)

Found 7914 files belonging to 1 classes.
Found 1988 files belonging to 1 classes.


# Checkpointing and Training

In [None]:
# Track the optimizer state, both sub-networks and a global step counter so
# training can resume. Keep only the newest checkpoint in `model_ckpt`.
ckpt = tf.train.Checkpoint(
    optimizer=optimizer,
    encoder=encoder,
    decoder=decoder,
    step=tf.Variable(0),
)
manager = tf.train.CheckpointManager(
    checkpoint=ckpt,
    directory=model_ckpt,
    max_to_keep=1,
)

In [None]:
# Resume from the newest checkpoint if one exists. Restore first and log
# afterwards, so a failing restore does not print a misleading success line.
latest = manager.latest_checkpoint
if latest:
    ckpt.restore(latest)
    print("Restored from {}".format(latest))
else:
    print("Initializing from scratch.")

Restored from /content/drive/MyDrive/VAE/BAYC/Models/ckpt-10


NotFoundError: ignored

In [None]:
epochs = 50
save_every = 200  # checkpoint interval, in optimizer steps

# Evaluate through a traced graph, like train_step, instead of running
# compute_loss eagerly batch by batch.
_eval_loss = tf.function(compute_loss)

for epoch in range(epochs):
    for train_x, _ in train_ds:
        train_step(vae, train_x, optimizer)
        ckpt.step.assign_add(1)
        if int(ckpt.step) % save_every == 0:
            save_path = manager.save()
            print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))

    # Fresh metrics each epoch. The losses are stochastic because
    # compute_loss draws a new latent sample on every call.
    train_loss = tf.keras.metrics.Mean()
    valid_loss = tf.keras.metrics.Mean()

    for train_x, _ in train_ds:
        train_loss(_eval_loss(vae, train_x))
    for valid_x, _ in valid_ds:
        valid_loss(_eval_loss(vae, valid_x))
    print(f"Epoch {epoch + 1}, Train loss: {train_loss.result()}, Validation loss: {valid_loss.result()}")

# Persist the final weights. Otherwise up to save_every - 1 trailing steps
# since the last periodic save would be lost.
if int(ckpt.step) % save_every != 0:
    save_path = manager.save()
    print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))

# Testing and Validation Plots

In [None]:
# Take the first image of one validation batch; slicing with 0:1 keeps the
# batch axis so it can be fed straight to the model.
test_batch = next(iter(valid_ds))
test_image = test_batch[0][0:1]

plt.figure()
plt.imshow(test_image[0])
plt.title("Original Image")
plt.show()

# Reconstruct through the VAE's own encode -> sample -> decode path rather
# than re-implementing the mean/log-variance split here.
mean, logvar = vae.encode(test_image)
z = vae.reparameterize(mean, logvar)
decoded_image = vae.decode(z)

plt.figure()
plt.imshow(decoded_image[0])
plt.title("Decoder Output")
plt.show()