In [1]:
import tensorflow as tf

class VanillaVAE(tf.keras.Model):
    def __init__(self, input_dim, output_dim, latent_dim=32):
        super(VanillaVAE, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        
        # Encoder
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=input_dim),
            tf.keras.layers.Conv2D(32, 3, strides=2, activation='relu'),
            tf.keras.layers.Conv2D(64, 3, strides=2, activation='relu'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(latent_dim * 2)
        ])
        
        # Decoder
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(input_dim[0]//4 * input_dim[1]//4 * 64, activation='relu'),
            tf.keras.layers.Reshape((input_dim[0]//4, input_dim[1]//4, 64)),
            tf.keras.layers.Conv2DTranspose(64, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2DTranspose(32, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2D(output_dim[-1], 3, activation='sigmoid', padding='same')
        ])
    
    def encode(self, x):
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar
    
    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * .5) + mean
    
    def decode(self, z):
        return self.decoder(z)
    
    def call(self, inputs):
        mean, logvar = self.encode(inputs)
        z = self.reparameterize(mean, logvar)
        return self.decode(z), mean, logvar

2024-07-12 18:55:30.491661: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-12 18:55:30.491838: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-12 18:55:30.661428: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
import tensorflow as tf

class VAEEncoderDecoder(tf.keras.Model):
    def __init__(self, input_dim, output_dim, latent_dim=32):
        super(VAEEncoderDecoder, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        
        # Encoder with encoder-decoder architecture
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=input_dim),
            # Downsampling
            tf.keras.layers.Conv2D(32, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2D(64, 3, strides=2, activation='relu', padding='same'),
            # Upsampling
            tf.keras.layers.Conv2DTranspose(64, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2DTranspose(32, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(latent_dim * 2)
        ])
        
        # Decoder (same as in VanillaVAE)
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(input_dim[0]//4 * input_dim[1]//4 * 64, activation='relu'),
            tf.keras.layers.Reshape((input_dim[0]//4, input_dim[1]//4, 64)),
            tf.keras.layers.Conv2DTranspose(64, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2DTranspose(32, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2D(output_dim[-1], 3, activation='sigmoid', padding='same')
        ])
    
    def encode(self, x):
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar
    
    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * .5) + mean
    
    def decode(self, z):
        return self.decoder(z)
    
    def call(self, inputs):
        mean, logvar = self.encode(inputs)
        z = self.reparameterize(mean, logvar)
        return self.decode(z), mean, logvar

In [3]:
import tensorflow as tf

class VAEDecoderEncoderDecoder(tf.keras.Model):
    def __init__(self, input_dim, output_dim, latent_dim=32):
        super(VAEDecoderEncoderDecoder, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        
        # Encoder (same as in VanillaVAE)
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=input_dim),
            tf.keras.layers.Conv2D(32, 3, strides=2, activation='relu'),
            tf.keras.layers.Conv2D(64, 3, strides=2, activation='relu'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(latent_dim * 2)
        ])
        
        # Decoder with encoder-decoder architecture
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(input_dim[0]//4 * input_dim[1]//4 * 64, activation='relu'),
            tf.keras.layers.Reshape((input_dim[0]//4, input_dim[1]//4, 64)),
            # Upsampling
            tf.keras.layers.Conv2DTranspose(64, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2DTranspose(32, 3, strides=2, activation='relu', padding='same'),
            # Downsampling
            tf.keras.layers.Conv2D(32, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2D(64, 3, strides=2, activation='relu', padding='same'),
            # Final upsampling
            tf.keras.layers.Conv2DTranspose(32, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2DTranspose(16, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2D(output_dim[-1], 3, activation='sigmoid', padding='same')
        ])
    
    def encode(self, x):
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar
    
    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * .5) + mean
    
    def decode(self, z):
        return self.decoder(z)
    
    def call(self, inputs):
        mean, logvar = self.encode(inputs)
        z = self.reparameterize(mean, logvar)
        return self.decode(z), mean, logvar

In [4]:
import tensorflow as tf

class VAEDoubleEncoderDecoder(tf.keras.Model):
    def __init__(self, input_dim, output_dim, latent_dim=32):
        super(VAEDoubleEncoderDecoder, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        
        # Encoder with encoder-decoder architecture
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=input_dim),
            # Downsampling
            tf.keras.layers.Conv2D(32, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2D(64, 3, strides=2, activation='relu', padding='same'),
            # Upsampling
            tf.keras.layers.Conv2DTranspose(64, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2DTranspose(32, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(latent_dim * 2)
        ])
        
        # Decoder with encoder-decoder architecture
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(input_dim[0]//4 * input_dim[1]//4 * 64, activation='relu'),
            tf.keras.layers.Reshape((input_dim[0]//4, input_dim[1]//4, 64)),
            # Upsampling
            tf.keras.layers.Conv2DTranspose(64, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2DTranspose(32, 3, strides=2, activation='relu', padding='same'),
            # Downsampling
            tf.keras.layers.Conv2D(32, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2D(64, 3, strides=2, activation='relu', padding='same'),
            # Final upsampling
            tf.keras.layers.Conv2DTranspose(32, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2DTranspose(16, 3, strides=2, activation='relu', padding='same'),
            tf.keras.layers.Conv2D(output_dim[-1], 3, activation='sigmoid', padding='same')
        ])
    
    def encode(self, x):
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar
    
    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * .5) + mean
    
    def decode(self, z):
        return self.decoder(z)
    
    def call(self, inputs):
        mean, logvar = self.encode(inputs)
        z = self.reparameterize(mean, logvar)
        return self.decode(z), mean, logvar

In [5]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import numpy as np

# Load and preprocess MNIST data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.expand_dims(x_train, axis=-1)
x_test = np.expand_dims(x_test, axis=-1)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [19]:
%matplotlib inline
import matplotlib.pyplot as plt
def train_and_evaluate_vae(model, epochs=30, batch_size=128):
    optimizer = tf.keras.optimizers.Adam(1e-4)
    
    @tf.function
    def train_step(x):
        with tf.GradientTape() as tape:
            reconstruction, mean, logvar = model(x)
            reconstruction_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(x, reconstruction)
            )
            kl_loss = -0.5 * tf.reduce_mean(1 + logvar - tf.square(mean) - tf.exp(logvar))
            total_loss = reconstruction_loss + kl_loss
        
        grads = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return total_loss, reconstruction_loss, kl_loss
    
    test_losses = []
    
    for epoch in range(epochs):
        epoch_loss = 0
        epoch_reconstruction_loss = 0
        epoch_kl_loss = 0
        num_batches = 0
        
        for batch in tf.data.Dataset.from_tensor_slices(x_train).batch(batch_size):
            total_loss, reconstruction_loss, kl_loss = train_step(batch)
            epoch_loss += total_loss
            epoch_reconstruction_loss += reconstruction_loss
            epoch_kl_loss += kl_loss
            num_batches += 1
        
        epoch_loss /= num_batches
        epoch_reconstruction_loss /= num_batches
        epoch_kl_loss /= num_batches
        
        # Evaluate the model
        test_reconstruction, _, _ = model(x_test)
        test_loss = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(x_test, test_reconstruction)
        )
        test_losses.append(test_loss.numpy())
        
        print(f'Epoch {epoch + 1}, Train Loss: {epoch_loss:.4f}, '
              f'Reconstruction Loss: {epoch_reconstruction_loss:.4f}, '
              f'KL Loss: {epoch_kl_loss:.4f}, Test Loss: {test_loss:.4f}')
    
    # Generate and plot reconstructions for 3 sample inputs
    sample_images = x_test[:3]
    sample_reconstructions, _, _ = model(sample_images)
    
    plt.figure(figsize=(12, 4))
    for i in range(3):
        # Original
        plt.subplot(2, 3, i+1)
        plt.imshow(sample_images[i, :, :, 0], cmap='gray')
        plt.title(f'Original {i+1}')
        plt.axis('off')
        
        # Reconstruction
        plt.subplot(2, 3, i+4)
        plt.imshow(sample_reconstructions[i, :, :, 0], cmap='gray')
        plt.title(f'Reconstructed {i+1}')
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig(f'{model.__class__.__name__}_reconstructions.png')
    plt.close()
    
    return test_losses[-1], test_losses

In [20]:
%matplotlib inline
import matplotlib.pyplot as plt
# Train and evaluate all models
models = {
    'VanillaVAE': VanillaVAE(input_dim=(28, 28, 1), output_dim=(28, 28, 1)),
    'VAEEncoderDecoder': VAEEncoderDecoder(input_dim=(28, 28, 1), output_dim=(28, 28, 1)),
    'VAEDecoderEncoderDecoder': VAEDecoderEncoderDecoder(input_dim=(28, 28, 1), output_dim=(28, 28, 1)),
    'VAEDoubleEncoderDecoder': VAEDoubleEncoderDecoder(input_dim=(28, 28, 1), output_dim=(28, 28, 1))
}

model_losses = {}

In [21]:
# Function to plot test losses for all models
def plot_test_losses(model_losses):
    plt.figure(figsize=(10, 6))
    for model_name, losses in model_losses.items():
        plt.plot(range(1, len(losses) + 1), losses, label=model_name)
    plt.xlabel('Epoch')
    plt.ylabel('Test Loss')
    plt.title('Test Loss per Epoch for Different VAE Models')
    plt.legend()
    plt.grid(True)
    plt.savefig('vae_models_test_losses.png')
    plt.close()

In [22]:
for model_name, model in models.items():
    print(f"\nTraining {model_name}...")
    final_loss, epoch_losses = train_and_evaluate_vae(model)
    model_losses[model_name] = epoch_losses
    print(f'{model_name} Final Test Loss: {final_loss:.4f}')

# Plot test losses for all models
plot_test_losses(model_losses)

# Find the best performing model
best_model = min(model_losses, key=lambda x: model_losses[x][-1])
print(f"\nBest performing model: {best_model}")


Training VanillaVAE...
Epoch 1, Train Loss: 0.3927, Reconstruction Loss: 0.3899, KL Loss: 0.0028, Test Loss: 0.2742
Epoch 2, Train Loss: 0.2706, Reconstruction Loss: 0.2688, KL Loss: 0.0018, Test Loss: 0.2661
Epoch 3, Train Loss: 0.2674, Reconstruction Loss: 0.2656, KL Loss: 0.0018, Test Loss: 0.2648
Epoch 4, Train Loss: 0.2661, Reconstruction Loss: 0.2643, KL Loss: 0.0018, Test Loss: 0.2636
Epoch 5, Train Loss: 0.2653, Reconstruction Loss: 0.2634, KL Loss: 0.0018, Test Loss: 0.2626
Epoch 6, Train Loss: 0.2647, Reconstruction Loss: 0.2628, KL Loss: 0.0019, Test Loss: 0.2622
Epoch 7, Train Loss: 0.2643, Reconstruction Loss: 0.2623, KL Loss: 0.0019, Test Loss: 0.2617
Epoch 8, Train Loss: 0.2641, Reconstruction Loss: 0.2621, KL Loss: 0.0020, Test Loss: 0.2616
Epoch 9, Train Loss: 0.2639, Reconstruction Loss: 0.2618, KL Loss: 0.0021, Test Loss: 0.2615
Epoch 10, Train Loss: 0.2638, Reconstruction Loss: 0.2617, KL Loss: 0.0021, Test Loss: 0.2612
Epoch 11, Train Loss: 0.2636, Reconstruction 