# Beta-VAE on MNIST: Latent Space Analysis

This notebook trains a Beta-VAE (without convolutional layers) on the MNIST dataset and analyzes the learned latent space through visualization, K-means clustering, and t-SNE.

In [None]:
# Imports
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.metrics import confusion_matrix
from tensorflow import keras

from model import NonConv_VAE
from utils import plot_latent_distribution, plot_generated_images, train_vae, compute_total_loss, relabel_clusters

In [None]:
# Reproducibility: seed every global RNG the notebook may draw from.
# - Python's `random` is seeded as well, because helpers may use it.
# - TensorFlow's global seed is what TF ops fall back to when they are not
#   given an explicit op-level seed (e.g. the unseeded `tf.data` shuffle
#   in the dataloader cell).
# - NumPy's legacy global RNG is seeded for any `np.random.*` calls.
seed = 20
keras.backend.clear_session()
random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)

In [None]:
# Hyperparameters: the knobs a reader is most likely to tune.

# Optimisation schedule
num_epochs = 30
LEARNING_RATE = 1e-3

# Model capacity: dimensionality of the latent code
latent_dim = 10

# Loss configuration, forwarded to `train_vae` in the training cell
LOSS_TYPE = 'bce'
BETA_SCHEDULE = 'linear'
MAX_BETA = 1.0

In [None]:
# Load MNIST, flatten every image into a vector of `image_size` pixels, and
# rescale intensities from [0, 255] to [0, 1] as float32.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
n_rows, n_cols = x_train.shape[1:3]
image_size = n_rows * n_cols
x_train, x_test = (
    split.reshape(-1, image_size).astype("float32") / 255
    for split in (x_train, x_test)
)

In [None]:
# Create dataloaders.
# The shuffle buffer spans the whole training set, so each pass over the
# dataset is a uniform permutation. The previous 1000-element buffer only
# mixed samples within a sliding window of the original order. The cost is
# that the buffer holds a copy of x_train (~188 MB for 60k x 784 float32).
# `prefetch` overlaps batch preparation with the training step.
batch_size = 32
x_train_dl = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(buffer_size=len(x_train))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
x_test_dl = (
    tf.data.Dataset.from_tensor_slices(x_test)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Build the non-convolutional VAE with a `latent_dim`-dimensional code,
# followed by the Adam optimizer that will update its weights.
vae_model = NonConv_VAE(latent_dim=latent_dim)
optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)

## Latent Space (Before Training)

In [None]:
# Baseline: latent distribution of the test set under the still-untrained model.
plot_latent_distribution(
    vae_model,
    x_test,
    y_test,
    title='Before Training',
    batch_size=100,
)

## Train Beta-VAE

In [None]:
# Free-bits threshold forwarded to `train_vae` (presumably a floor on the KL
# term; see utils.train_vae for its exact use).
FREE_BITS = 0.5

recon_losses, kl_losses, grad_norms = train_vae(
    vae_model,
    optimizer,
    x_train_dl,
    num_epochs,
    compute_total_loss,
    loss_type=LOSS_TYPE,
    beta_schedule=BETA_SCHEDULE,
    max_beta=MAX_BETA,
    free_bits=FREE_BITS,
)

In [None]:
# Plot each loss component on its own axes. The reconstruction and KL terms
# can differ in scale by a large factor; on a shared y-axis the smaller curve
# flattens into a line and its trend becomes unreadable.
fig, (ax_recon, ax_kl) = plt.subplots(1, 2, figsize=(12, 4), sharex=True)

ax_recon.plot(recon_losses, color="tab:blue")
ax_recon.set(title="Reconstruction Loss", xlabel="Epoch", ylabel="Loss")
ax_recon.grid(True)

ax_kl.plot(kl_losses, color="tab:orange")
ax_kl.set(title="KL Divergence", xlabel="Epoch", ylabel="Loss")
ax_kl.grid(True)

fig.suptitle("Training Loss Curves")
fig.tight_layout()
plt.show()

## Clustering in Latent Space

In [None]:
# Embed the test set: the model returns four outputs, and the second one is
# used as the latent mean `z_mean`.
model_outputs = vae_model.predict(x_test, batch_size=100, verbose=0)
_, z_mean, _, _ = model_outputs

# K-means cluster ids are arbitrary integers. `relabel_clusters` presumably
# maps each cluster onto the digit label it best matches in `y_test`.
N_CLUSTERS = 10  # one cluster per MNIST digit class
kmeans = KMeans(n_clusters=N_CLUSTERS, random_state=0, n_init='auto')
y_pred = kmeans.fit_predict(z_mean)
y_pred_relabel = relabel_clusters(y_pred, y_test)

In [None]:
# Latent space clusters: test embeddings on the first two latent coordinates,
# coloured by relabelled K-means cluster.
# This is only a 2-D slice of the `latent_dim`-dimensional space, so clusters
# that are separated in the full space may overlap here.
cluster_labels = np.asarray(y_pred_relabel)
n_labels = int(cluster_labels.max()) + 1

fig, ax = plt.subplots(figsize=(10, 6))
points = ax.scatter(
    z_mean[:, 0],
    z_mean[:, 1],
    c=cluster_labels,
    cmap=plt.get_cmap("tab10", n_labels),
    s=5,
    # Centre each integer label on its own colour band so the colorbar
    # ticks line up with the colours.
    vmin=-0.5,
    vmax=n_labels - 0.5,
)
fig.colorbar(points, ax=ax, ticks=range(n_labels), label="Cluster (relabeled)")
ax.set(
    title=f"K-means Clustering in Latent Space (Relabeled)\n"
          f"first 2 of {latent_dim} latent dimensions",
    xlabel="z[0]",
    ylabel="z[1]",
)
ax.grid(True)
plt.show()

In [None]:
# Confusion Matrix
conf_mat = confusion_matrix(y_test, y_pred_relabel)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_mat, annot=True, fmt="d", cmap="Blues")
plt.title("Confusion Matrix: True vs. K-means Clusters")
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.show()

## t-SNE Projection of Latent Space

In [None]:
# Project the latent means to 2-D with t-SNE, coloured by the TRUE digit
# label (the K-means scatter above is coloured by predicted cluster).
# t-SNE preserves local neighbourhoods; distances between far-apart clusters
# in this plot are not meaningful.
tsne = TSNE(n_components=2, random_state=0, perplexity=30)
z_tsne = tsne.fit_transform(z_mean)

digit_labels = np.asarray(y_test)
n_digits = int(digit_labels.max()) + 1

fig, ax = plt.subplots(figsize=(10, 6))
points = ax.scatter(
    z_tsne[:, 0],
    z_tsne[:, 1],
    c=digit_labels,
    cmap=plt.get_cmap("tab10", n_digits),
    s=5,
    # Centre each integer label on its own colour band for the colorbar.
    vmin=-0.5,
    vmax=n_digits - 0.5,
)
fig.colorbar(points, ax=ax, ticks=range(n_digits), label="True digit")
ax.set(
    title="t-SNE Projection of Latent Space",
    xlabel="Component 1",
    ylabel="Component 2",
)
ax.grid(True)
plt.show()

## Latent Space (After Training)

In [None]:
# Same view as the pre-training plot, now with the trained model, so the
# two figures can be compared side by side.
plot_latent_distribution(
    vae_model,
    x_test,
    y_test,
    title='After Training',
    batch_size=100,
)

## Visualize Generated Digits from Latent Space

In [None]:
# Pass the trained model's `decode` method to the plotting helper, which
# renders the images it produces from `latent_dim`-dimensional codes.
decoder_fn = vae_model.decode
plot_generated_images(decoder_fn, latent_dim=latent_dim)