## Training a Variational Autoencoder (VAE) on the MNIST Dataset

In this blog post, we'll train a Variational Autoencoder (VAE) on the MNIST dataset and use it to generate synthetic digits. We'll set up the environment, define and train the VAE, and generate new images. These synthetic images can then be added to the original MNIST data to augment a classification task. We'll also visualize the learned latent space and evaluate the quality of the generated images using the Fréchet Inception Distance (FID) score.

Variational Autoencoders (VAEs) are generative models that learn to represent data in a lower-dimensional latent space while being able to reconstruct the original data from this space. Unlike traditional autoencoders, VAEs introduce a probabilistic approach by learning a distribution over the latent variables, allowing for the generation of new, similar data by sampling from this distribution. This makes VAEs particularly useful for tasks such as data synthesis, anomaly detection, and creating smooth interpolations between different data points. Their probabilistic nature provides a principled way to generate realistic and diverse synthetic data.


#### Setting Up the Environment and Loading the Dataset
First, we need to set up our environment and load the MNIST dataset. We will use PyTorch for this task.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hyperparameters
batch_size = 256
learning_rate = 1e-3
num_epochs = 100
latent_dim = 30

# MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
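# Use the held-out test split for validation rather than re-reading the training data
val_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)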

#### Defining the VAE Model
Next, we define the VAE model. The VAE consists of an encoder and a decoder. The encoder maps the input data to a latent space, and the decoder reconstructs the data from the latent space.

In [None]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(28*28, 400)
        self.fc2_mean = nn.Linear(400, latent_dim)
        self.fc2_logvar = nn.Linear(400, latent_dim)
        self.fc3 = nn.Linear(latent_dim, 400)
        self.fc4 = nn.Linear(400, 28*28)
    
    def encode(self, x):
        h1 = torch.relu(self.fc1(x))
        return self.fc2_mean(h1), self.fc2_logvar(h1)
    
    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std
    
    def decode(self, z):
        h3 = torch.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))
    
    def forward(self, x):
        mean, logvar = self.encode(x.view(-1, 28*28))
        z = self.reparameterize(mean, logvar)
        return self.decode(z), mean, logvar
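
The `reparameterize` method is the step that keeps sampling differentiable: all randomness comes from `eps`, while `mean` and `logvar` enter through ordinary arithmetic. As a rough sanity check, here is a minimal sketch of the same arithmetic in plain Python (no PyTorch) for a single latent dimension. The values of `mean` and `logvar` are made up for illustration and are not outputs of the trained encoder.

```python
import math
import random

def reparameterize(mean, logvar, rng):
    """Draw z = mean + eps * std with eps ~ N(0, 1), mirroring VAE.reparameterize."""
    std = math.exp(0.5 * logvar)  # logvar = log(sigma^2), so sigma = exp(logvar / 2)
    eps = rng.gauss(0.0, 1.0)     # the only source of randomness
    return mean + eps * std

rng = random.Random(0)               # fixed seed so the check is repeatable
mean, logvar = 1.5, math.log(0.25)   # hypothetical encoder outputs, i.e. sigma = 0.5
samples = [reparameterize(mean, logvar, rng) for _ in range(20000)]

sample_mean = sum(samples) / len(samples)
sample_std = math.sqrt(sum((s - sample_mean) ** 2 for s in samples) / len(samples))
print(f"empirical mean {sample_mean:.3f}, empirical std {sample_std:.3f}")
```

The empirical mean and standard deviation should land close to 1.5 and 0.5, which confirms that `exp(0.5 * logvar)` is the standard deviation and not the variance.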

#### Training the VAE
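We will now train the VAE on the MNIST dataset. Each iteration runs a forward pass, computes the loss (reconstruction error plus KL divergence), and backpropagates. Checkpoints can be saved at regular intervals with `torch.save`; the loop below leaves that step out.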

In [None]:
def loss_function(recon_x, x, mean, logvar):
    BCE = nn.functional.binary_cross_entropy(recon_x, x.view(-1, 28*28), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    return BCE + KLD

model = VAE()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

model.train()
for epoch in range(num_epochs):
    train_loss = 0
    for data, _ in train_loader:
        data = data.view(-1, 28*28)
        optimizer.zero_grad()
        recon_batch, mean, logvar = model(data)
        loss = loss_function(recon_batch, data, mean, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss/len(train_loader.dataset):.4f}')
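
The `KLD` term in `loss_function` is the closed-form KL divergence between the encoder's diagonal Gaussian and the standard normal prior. The sketch below evaluates the same expression in plain Python on a few hand-picked inputs, to show how the penalty behaves. The function name `kl_to_standard_normal` and the example values are assumptions made for this illustration.

```python
import math

def kl_to_standard_normal(mean, logvar):
    """KL( N(mean, exp(logvar)) || N(0, 1) ), summed over latent dimensions.

    Same expression as the KLD line in loss_function, written for Python lists.
    """
    return -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mean, logvar))

# Posterior identical to the prior: the penalty vanishes.
kl_same = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])

# Moving one mean away from zero is penalized quadratically (0.5 * 1.0 ** 2).
kl_shifted = kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])

# A variance other than 1 is penalized too, even with zero mean.
kl_wide = kl_to_standard_normal([0.0, 0.0], [math.log(4.0), 0.0])

print(abs(kl_same), kl_shifted, round(kl_wide, 4))
```

Because the term is zero only when the posterior matches the prior, it pulls the latent codes toward a standard normal, which is what makes sampling with `torch.randn` at generation time reasonable.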

#### Generating Images Using the Trained VAE
After training the VAE, we can generate new images by sampling from the latent space.

In [None]:
import matplotlib.pyplot as plt

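def generate_images(model, num_images=20):
    """Decode random latent vectors and plot them in a grid with two rows."""
    model.eval()
    with torch.no_grad():
        # fc3 is the first decoder layer, so its in_features equals latent_dim
        # (fc2_mean.in_features is the hidden size, 400, and would not fit the decoder)
        z = torch.randn(num_images, model.fc3.in_features)
        generated = model.decode(z).view(-1, 1, 28, 28).cpu().numpy()
    fig, axes = plt.subplots(2, num_images // 2, figsize=(10, 2))
    for i, ax in enumerate(axes.flat):
        ax.imshow(generated[i][0], cmap='gray')
        ax.axis('off')
    plt.tight_layout()
    return fig

fig = generate_images(model)
plt.show()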

![Synthetic images](../images/fake_images.png)

#### Visualizing the Latent Space
We can also visualize the latent space using t-SNE to reduce the dimensionality to 2D. 

Visualizing the latent space of a Variational Autoencoder (VAE) provides insights into how the model organizes and represents different features of the data. By mapping similar images to nearby points in the latent space, we can observe clusters that correspond to distinct characteristics, such as digit shapes in the MNIST dataset. The smooth transitions between points in the latent space suggest that the model has learned meaningful features, enabling it to generate realistic interpolations. This visualization can reveal how well the VAE captures the underlying data distribution, highlighting areas where the model performs well or struggles to differentiate certain features.
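
The interpolations mentioned above are usually produced by walking along a straight line between two latent codes and decoding each point. Here is a minimal sketch of that walk in plain Python. The helper `interpolate` and the 3-dimensional codes are hypothetical; with the model from this post, each row would be converted to a tensor of length `latent_dim` and passed to `model.decode`.

```python
def interpolate(z_start, z_end, steps):
    """Return `steps` latent vectors evenly spaced on the line from z_start to z_end."""
    if steps < 2:
        raise ValueError("steps must be at least 2 to include both endpoints")
    path = []
    for i in range(steps):
        t = i / (steps - 1)  # t runs from 0.0 (start) to 1.0 (end)
        path.append([(1 - t) * a + t * b for a, b in zip(z_start, z_end)])
    return path

# Hypothetical 3-dimensional latent codes; the model in this post uses latent_dim = 30.
z_a = [0.0, 1.0, -2.0]
z_b = [2.0, -1.0, 0.0]
path = interpolate(z_a, z_b, steps=5)
for z in path:
    print(z)
```

If the latent space is smooth, decoding consecutive rows of `path` should morph one digit gradually into the other instead of jumping between unrelated shapes.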


t-Distributed Stochastic Neighbor Embedding (t-SNE) is a powerful dimensionality reduction technique commonly used for visualizing high-dimensional data, especially in the context of machine learning. Here are some key reasons to use t-SNE:

* t-SNE is particularly effective for visualizing high-dimensional data in two or three dimensions. It preserves local structure, meaning that similar data points remain close together in the lower-dimensional representation.
* Unlike linear techniques such as PCA (Principal Component Analysis), t-SNE is adept at capturing complex, non-linear relationships in the data, making it well-suited for data with intricate structures.
* t-SNE can help reveal clusters or groupings within the data, allowing researchers to identify distinct classes or patterns that may not be apparent in the original high-dimensional space.
* The resulting 2D or 3D visualizations from t-SNE are often more interpretable, enabling easier communication of insights to non-technical stakeholders.

In [None]:
from sklearn.manifold import TSNE

def get_latent_space(model, data_loader):
    latent_vectors = []
    model.eval()
    with torch.no_grad():
        for inputs, _ in data_loader:
            mu, log_var = model.encode(inputs.view(-1, 28*28))
            z = model.reparameterize(mu, log_var)
            latent_vectors.append(z)
    return torch.cat(latent_vectors, dim=0).cpu().numpy()

def visualize_latent_space(latent_vectors):
    tsne = TSNE(n_components=2, random_state=42)
    latent_2d = tsne.fit_transform(latent_vectors)
    plt.scatter(latent_2d[:, 0], latent_2d[:, 1], alpha=0.7)
    plt.title("2D t-SNE visualization of VAE latent space")
    plt.xlabel("t-SNE component 1")
    plt.ylabel("t-SNE component 2")
    plt.grid(True)
    plt.show()

latent_vectors = get_latent_space(model, val_loader)
visualize_latent_space(latent_vectors)

![t-SNE visualization of the VAE latent space](../images/latent_space.png)

#### Calculating the FID Score
The Fréchet Inception Distance (FID) evaluates generated images by comparing their distribution to the distribution of real images. A lower FID means the generated images are more similar to the real ones. This model reached an FID of 3.58 when computed on 1024 generated images, and 4.95 (4.94783878326416) on the MNIST dataset.

#### Why FID?

* FID compares the distribution of generated images to the distribution of real images in feature space, rather than just pixel values. This allows for a more meaningful evaluation of visual quality.
* FID has been shown to correlate well with human judgment on the quality of generated images, making it a more perceptually relevant metric than pixel-wise comparisons like Mean Squared Error (MSE).
* FID measures both the quality and diversity of generated images. A lower FID score indicates that the generated images are not only similar to real images but also cover a diverse range of styles and classes.
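
To make the "distance between distributions" concrete: FID fits a Gaussian to the Inception features of each image set and computes the Fréchet distance between the two Gaussians, which combines a mean term and a covariance term. The sketch below is a simplified version that assumes diagonal covariances, so the matrix square root in the real metric becomes an elementwise square root. The function name and the feature statistics are made up for illustration; the real computation uses 2048-dimensional features and full covariance matrices.

```python
import math

def frechet_distance_diag(mean_a, var_a, mean_b, var_b):
    """Fréchet distance between two Gaussians with diagonal covariances.

    Simplification for readability: with diagonal covariances the term
    Tr(S_a + S_b - 2 * sqrt(S_a * S_b)) reduces to a sum over dimensions.
    """
    mean_term = sum((ma - mb) ** 2 for ma, mb in zip(mean_a, mean_b))
    cov_term = sum(va + vb - 2.0 * math.sqrt(va * vb) for va, vb in zip(var_a, var_b))
    return mean_term + cov_term

real_mean, real_var = [0.0, 1.0], [1.0, 4.0]  # hypothetical feature statistics

# Same statistics on both sides: distance 0.
identical = frechet_distance_diag(real_mean, real_var, real_mean, real_var)

# Generated features with a shifted mean: only the mean term grows.
shifted = frechet_distance_diag(real_mean, real_var, [0.5, 1.0], real_var)

# Generated features with too little spread (low diversity): only the covariance term grows.
collapsed = frechet_distance_diag(real_mean, real_var, real_mean, [0.25, 1.0])

print(identical, shifted, collapsed)
```

The `collapsed` case shows why FID reacts to diversity as well as quality: a generator whose outputs are all plausible but too similar to each other still pays through the covariance term.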

#### Interpreting an FID Score of 3.58
FID is non-negative: 0 means the two feature distributions match exactly, and larger values mean a larger mismatch. In practice, FID scores below 10 are often considered to indicate good image quality, while scores above 20 may indicate poorer performance.

An FID score of 3.58 is quite low and suggests that the generated images are of high quality and closely resemble the real images from the dataset. It indicates that the model has effectively learned the underlying distribution of the data and can generate diverse and realistic images.

#### Context Matters
The interpretation of FID scores can depend on the dataset being used. For example, achieving a low FID score on a simple dataset like MNIST might be easier than on more complex datasets like CIFAR-10 or ImageNet.

When comparing models, it’s important to look at relative FID scores (e.g., Model A vs. Model B) rather than absolute scores, as the significance of the score can vary with different datasets and architectures.

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance
from modeldef import ConvVAE

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 30  # Dimension of the latent space
batch_size = 256
# load the checkpoint
ckpt = 'vae_checkpoint.pth_epoch_81.pth'

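checkpoint = torch.load(f'./checkpoints/{ckpt}', map_location=device)
model = ConvVAE(latent_dim=latent_dim).to(device)
# Assumption: the checkpoint file holds a bare state_dict; adjust the lookup
# if it was saved as a dictionary with extra entries (optimizer state, epoch, ...).
model.load_state_dict(checkpoint)
model.eval()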

# Transformations for real MNIST images (PIL input)
transform = transforms.Compose([
    transforms.Grayscale(3),  # Replicate the single channel to 3 channels
    transforms.Resize(299),   # Resize to 299x299 (Inception-v3 input size)
    transforms.ToTensor(),    # Float tensor in [0, 1]
    # Scale to [0, 255] before casting: casting [0, 1] floats directly
    # would truncate almost every pixel to 0
    transforms.Lambda(lambda x: (x * 255).clamp(0, 255).to(torch.uint8))
])

# Transformations for generated images (float tensors in [0, 1], shape 1x28x28)
transform_gen = transforms.Compose([
    transforms.Lambda(lambda x: x.expand(3, -1, -1)),
    transforms.Resize(299),
    transforms.Lambda(lambda x: (x * 255).clamp(0, 255).to(torch.uint8))
])

# Load the MNIST dataset
mnist_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(mnist_data, batch_size=batch_size, shuffle=False)

fid = FrechetInceptionDistance(feature=2048).to(device)

# Extract features from real images and from images decoded from random latent vectors
for real_images, _ in dataloader:
    real_images = real_images.to(device)
    with torch.no_grad():
        # Sample one latent vector per real image, on the same device as the model
        z = torch.randn(real_images.size(0), latent_dim, device=device)
        generated = model.decode(z).view(-1, 1, 28, 28).cpu()  # Generate images from the latent vectors
        fake_images = torch.stack([transform_gen(img) for img in generated]).to(device)
    # Update the FID metric with real and fake images
    fid.update(real_images, real=True)
    fid.update(fake_images, real=False)

# Compute the FID score
fid_score = fid.compute().item()
print(f"FID score: {fid_score}")

In this blog post, we have covered the process of training a VAE on the MNIST dataset, generating new images, visualizing the latent space, and calculating the FID score to evaluate the quality of generated images. The generated images can be used in addition to the MNIST data for various classification tasks. The VAE model provides a powerful way to learn meaningful representations of the data, which can be useful in many machine learning applications.

#### Common Gotchas Leading to High FID Scores

Initially I got a very high FID score (334) while working through this example, even though the synthetic images looked OK. The issue was the way the preprocessing transformations were applied before the FID calculation.

* Ensure that the preprocessing applied to the training data is the same as that applied to the generated images. Differences in scaling, normalization, or augmentation can lead to discrepancies that inflate the FID score.

* When converting images to the correct format, ensure that they are normalized appropriately. For example, if the training images are scaled to the range [0, 1], the generated images should be treated similarly. Mismatched normalization can create significant differences in the data distributions.

* Applying different transformations (e.g., resizing, cropping, or color adjustments) to real and generated images can distort their distributions. For example, if you use Grayscale(3) for generated images but not for the training images, it could lead to higher FID scores.

* If the VAE's latent space is not well-learned, it may produce unrealistic samples that don't capture the diversity of the training data. This could lead to a high FID score despite visually similar images.
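
One concrete form of the normalization mismatch described above is casting `[0, 1]` float pixels straight to `uint8`. The sketch below mimics that cast in plain Python, where `int()` drops the fractional part in the same way, and compares it with scaling to `[0, 255]` first. The helper names and pixel values are made up for illustration.

```python
def to_uint8_truncating(pixels):
    """Mimic a direct float-to-uint8 cast: the fractional part is dropped."""
    return [int(p) for p in pixels]

def to_uint8_scaled(pixels):
    """Scale [0, 1] floats to [0, 255] before converting, clamping out-of-range values."""
    return [min(255, max(0, int(p * 255))) for p in pixels]

pixels = [0.0, 0.2, 0.5, 0.99, 1.0]      # hypothetical intensities in [0, 1]
truncated = to_uint8_truncating(pixels)
scaled = to_uint8_scaled(pixels)
print(truncated)  # [0, 0, 0, 0, 1]
print(scaled)     # [0, 51, 127, 252, 255]
```

With the direct cast, almost every pixel becomes 0, so real and generated images both turn nearly black. It is worth printing the minimum and maximum of a batch right before it is passed to the metric.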


#### Additional reading/references

* Kingma and Welling, An Introduction to Variational Autoencoders, Foundations and Trends in Machine Learning - [arXiv:1906.02691](https://arxiv.org/pdf/1906.02691)
* Working Code - [Visit Github](https://github.com/Kunal627/kunal627.github.io/tree/main/code)