# Understanding Variational Autoencoders (VAEs) and Transposed Convolutions
This notebook provides an in-depth walkthrough of autoencoders and variational autoencoders (VAEs) using PyTorch.
It is designed as a teaching tool for HCC students with detailed explanations and visualizations.
## Topics covered:
- Simple Autoencoders
- Variational Autoencoders (VAEs)
- Convolutional VAE reconstructions
- Latent space interpolation
- Visualizations of the reparameterization trick, loss functions, and network architectures

## Introduction
Autoencoders learn compressed representations of data and then reconstruct the input.
Variational Autoencoders (VAEs) add a probabilistic twist by mapping inputs to a latent distribution (mean and log-variance).
This continuous latent space allows for generating new samples and smooth interpolation between examples.
Key components:
- Encoder: Maps input data to a latent probability distribution.
- Latent Space: A structured, continuous representation.
- Decoder: Reconstructs the input from a sampled latent vector.
The reparameterization trick (z = mean + exp(0.5 * logvar) * epsilon, where epsilon ~ N(0,1)) allows gradients to flow during training.

In [None]:
# Install required packages if needed:
!pip install torch torchvision matplotlib numpy

## What are Autoencoders? (A Gentle Introduction)

Imagine you have a picture, and you want to create a smaller, compressed version of it, but still be able to recreate the original image from that smaller version. That's essentially what an autoencoder does!

Autoencoders are a type of neural network that learn how to compress data (like images) into a lower-dimensional "code," and then learn how to "decode" it back to the original. They're like a fancy compression algorithm that learns what parts of the data are most important.

They have two main parts:
- **Encoder:** This part takes the input data (e.g., an image) and squeezes it down to a more compact form (the latent space).
- **Decoder:** This part takes the compressed data and tries to reconstruct the original input as closely as possible.

In [1]:
# Import the necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# For drawing the architecture diagrams
from matplotlib.patches import Rectangle, FancyArrowPatch

# Ensure plots show inline in a Jupyter Notebook
%matplotlib inline

## Simple Autoencoder for MNIST
The MNIST images (28x28) are flattened into a 784-dimensional vector.
The encoder compresses this into a latent space, and the decoder reconstructs the image.
ReLU activations introduce non-linearity and Sigmoid ensures outputs are in the range [0,1].

### Training and Visualization of the Autoencoder
The following functions train the autoencoder on the MNIST dataset and visualize original versus reconstructed images.

## Variational Autoencoder (VAE)
In a VAE the encoder outputs two vectors: one for the mean and one for the log-variance.
The latent vector is sampled using the reparameterization trick:
z = mean + exp(0.5 * logvar) * epsilon (where epsilon is drawn from a standard normal distribution)
The loss function is a sum of the reconstruction loss (binary cross-entropy) and the KL divergence.


## Diving into Simple Autoencoders

In this section, we will start with a very basic autoencoder to get a feel for how they work.  We'll use the MNIST dataset, which contains images of handwritten digits (0-9). Each image is 28x28 pixels.

Here's the basic idea:
1. We'll take each 28x28 image and flatten it out into a 784-dimensional vector (28 * 28 = 784).
2. The encoder will take this 784-dimensional vector, and compress it into something smaller.
3. The decoder will take that smaller compressed vector and try to reconstruct the original 784-dimensional vector (and thus the original image).

This will give us a "code" for the image. The better the autoencoder learns, the closer the reconstructed image will be to the original.
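
The flattening in step 1 is just a row-major re-indexing of the pixels. A minimal sketch in plain Python, with no PyTorch needed; the helper names `flatten` and `unflatten` are illustrative and not part of the notebook:

```python
# Row-major flattening of a 28x28 image into a 784-element vector.
# `flatten` and `unflatten` are hypothetical helpers for illustration only.
HEIGHT, WIDTH = 28, 28

def flatten(image):
    """Concatenate the rows of a 2-D list into one flat list."""
    return [pixel for row in image for pixel in row]

def unflatten(vector, width=WIDTH):
    """Cut a flat list back into rows of `width` pixels."""
    return [vector[i:i + width] for i in range(0, len(vector), width)]

# A dummy "image" whose pixel value encodes its own position.
image = [[r * WIDTH + c for c in range(WIDTH)] for r in range(HEIGHT)]
vector = flatten(image)

print(len(vector))  # 784
# Pixel (row 5, column 3) lives at index 5 * 28 + 3 in the flat vector.
print(vector[5 * WIDTH + 3] == image[5][3])
# The round trip loses nothing: flattening only changes the layout.
print(unflatten(vector) == image)

# With a 32-dimensional latent code, the encoder keeps 32 numbers out of 784.
latent_dim = 32
print(f"compression factor: {len(vector) / latent_dim:.1f}x")
```

Flattening itself is lossless; the information loss in an autoencoder comes only from the bottleneck that follows it.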

In [6]:
# %% Cell 2: Simple Autoencoder Class
class SimpleAutoencoder(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=128, latent_dim=32):
        """
        A basic autoencoder for MNIST digits.
        - input_dim: Size of the flattened MNIST image (28*28 = 784)
        - hidden_dim: Number of neurons in the hidden layer
        - latent_dim: Size of the latent representation
        """
        super(SimpleAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        if x.dim() > 2:  # Flatten image if needed
            x = x.view(x.size(0), -1)
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


In [None]:
# %% Cell 3: Create and Print Autoencoder
autoencoder = SimpleAutoencoder()
print("Simple Autoencoder Architecture:")
print(autoencoder)


In [11]:
# %% Cell 4: Train Autoencoder Function
def train_autoencoder(model, epochs=5, batch_size=128, learning_rate=1e-3):
    """
    Train the autoencoder on the MNIST dataset.
    Returns the trained model and a list of average losses per epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    print("Training on device:", device)

    # Use the training dataset for training
    transform = transforms.ToTensor()
    train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

    criterion = nn.MSELoss()  # Mean Squared Error for reconstruction
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    losses = []
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(device)
            data_flat = data.view(data.size(0), -1)
            output = model(data_flat)
            loss = criterion(output, data_flat)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1}/{epochs} - Batch {batch_idx} Loss: {loss.item():.6f}")

        avg_loss = train_loss / len(train_loader)
        losses.append(avg_loss)
        print(f"Epoch {epoch+1} Average Loss: {avg_loss:.6f}")

    return model, losses


In [None]:
trained_autoencoder, training_losses = train_autoencoder(autoencoder, epochs=5)
print("Training Losses per Epoch:")
print(training_losses)


In [None]:
plt.figure(figsize=(10, 5))
plt.plot(training_losses, marker='o', linestyle='-')
plt.title("Training Losses Over Epochs")
plt.xlabel("Epoch")
plt.ylabel("Average Loss")
plt.grid(True)
plt.show()



In [14]:
# %% Cell 5: Autoencoder Reconstruction Visualization
def visualize_reconstructions(model, num_examples=10):
    """
    Visualize original and reconstructed images from the autoencoder.
    """
    model.eval()
    transform = transforms.ToTensor()
    test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=num_examples, shuffle=True)
    dataiter = iter(test_loader)
    images, _ = next(dataiter)

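    # The model may have been moved to the GPU during training,
    # so run the forward pass on whatever device it lives on.
    device = next(model.parameters()).device
    with torch.no_grad():
        images_flat = images.view(images.size(0), -1).to(device)
        reconstructions = model(images_flat).view(images.size()).cpu()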

    plt.figure(figsize=(20, 4))
    for i in range(num_examples):
        plt.subplot(2, num_examples, i+1)
        plt.imshow(images[i].squeeze().numpy(), cmap="gray")
        plt.title("Original")
        plt.axis("off")

    for i in range(num_examples):
        plt.subplot(2, num_examples, i+num_examples+1)
        plt.imshow(reconstructions[i].squeeze().numpy(), cmap="gray")
        plt.title("Reconstructed")
        plt.axis("off")

    plt.tight_layout()
    plt.show()


## Understanding Limitations of Simple Autoencoders

Our simple autoencoder can reconstruct images, but it has limitations:
*   **No Control Over Latent Space:** Nothing in the training objective encourages the latent space (the compressed representation) to be smooth or continuous. Small changes in a latent code don't always translate to smooth changes in the output, which makes it hard to create variations or interpolate between different inputs.
*   **Not Generative:** The autoencoder is not a generative model. It defines no distribution over the latent space, so we have no principled way to sample a code and decode it into a new image.

This is where Variational Autoencoders come in. VAEs introduce a probabilistic element to the latent space to address these limitations.

## Introduction to Variational Autoencoders (VAEs)

VAEs improve on simple autoencoders by adding a probabilistic twist. Instead of just compressing an input to a single code, a VAE maps an input to a distribution in the latent space.

Think of it this way:
*   **Simple Autoencoder:** Learns a single code (a point) in latent space for each input.
*   **VAE:** Learns a distribution (a "cloud" of points) in latent space for each input. This distribution is characterized by a mean (the center of the cloud) and a log-variance (how spread out the cloud is).

This difference is important. By mapping each input to a distribution, and by keeping those distributions close to a standard normal, the VAE creates a smoother, more continuous latent space. This has two main benefits:
1.  **Generation:** We can sample from this latent distribution to generate new data points that are similar to the training data.
2.  **Interpolation:** We can move through the latent space and generate a smooth transition of different outputs.

The key idea of the VAE is to make the latent space smooth and organized, so that we can sample new data from it.

In [16]:
# %% Cell 6: Simple VAE Class
class SimpleVAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=512, latent_dim=32):
        """
        A basic Variational Autoencoder for MNIST digits.
        - input_dim: Flattened input size.
        - hidden_dim: Number of neurons in hidden layers.
        - latent_dim: Dimensionality of the latent space.
        """
        super(SimpleVAE, self).__init__()
        # Encoder network
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU()
        )
        # Two separate linear layers for mean and log-variance
        self.fc_mean = nn.Linear(hidden_dim // 2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim // 2, latent_dim)
        # Decoder network
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Encode input into mean and log-variance."""
        hidden = self.encoder(x)
        mean = self.fc_mean(hidden)
        logvar = self.fc_logvar(hidden)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """
        Sample latent vector z using the reparameterization trick.
        z = mean + exp(0.5 * logvar) * epsilon, where epsilon ~ N(0, 1)
        """
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        return mean + std * epsilon

    def decode(self, z):
        """Reconstruct input from latent vector z."""
        return self.decoder(z)

    def forward(self, x):
        if x.dim() > 2:
            x = x.view(x.size(0), -1)
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        reconstructed = self.decode(z)
        return reconstructed, mean, logvar


In [None]:
# %% Cell 7: Create and Print VAE
vae_model = SimpleVAE()
print("Simple VAE Architecture:")
print(vae_model)


### VAE Loss Function
The VAE loss function consists of:
- Reconstruction Loss: Binary cross-entropy between the input and its reconstruction.
- KL Divergence: Measures how much the latent distribution diverges from a standard normal distribution.
The total loss is given by:
Total Loss = Reconstruction Loss + beta * KL Divergence
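
To make the two terms concrete, here is a small plain-Python sketch that evaluates them on hand-picked numbers. The function names are hypothetical, and the KL term is written as `0.5 * sum(mean² + exp(logvar) - logvar - 1)`, which is algebraically the same closed form as the `-0.5 * sum(1 + logvar - mean² - exp(logvar))` used later in the notebook:

```python
import math

def bernoulli_bce(x, x_hat, eps=1e-12):
    """Summed binary cross-entropy between target pixels x and predictions x_hat."""
    total = 0.0
    for target, pred in zip(x, x_hat):
        pred = min(max(pred, eps), 1.0 - eps)  # keep log() finite
        total += -(target * math.log(pred) + (1.0 - target) * math.log(1.0 - pred))
    return total

def gaussian_kl(mean, logvar):
    """KL( N(mean, exp(logvar)) || N(0, 1) ), summed over latent dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mean, logvar))

def total_loss(x, x_hat, mean, logvar, beta=1.0):
    """Reconstruction loss plus beta times the KL divergence."""
    return bernoulli_bce(x, x_hat) + beta * gaussian_kl(mean, logvar)

x = [0.0, 1.0, 1.0, 0.0]       # four made-up target pixels
good = [0.1, 0.9, 0.8, 0.2]    # a close reconstruction
bad = [0.9, 0.1, 0.2, 0.8]     # a poor reconstruction

print(bernoulli_bce(x, good) < bernoulli_bce(x, bad))  # True

# The KL term is zero exactly when the latent distribution is N(0, 1) ...
print(gaussian_kl([0.0, 0.0], [0.0, 0.0]))  # 0.0
# ... and grows as the mean moves away from 0 (or the variance away from 1).
print(gaussian_kl([1.0, 0.0], [0.0, 0.0]))  # 0.5

print(total_loss(x, good, mean=[1.0, 0.0], logvar=[0.0, 0.0], beta=2.0))
```

Raising `beta` weights the regularizer more heavily relative to reconstruction quality.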

## The Importance of the Reparameterization Trick

The VAE uses the "reparameterization trick" to make training possible.
The VAE needs to sample from the latent distribution before decoding back to an image, but drawing a random sample is not a differentiable operation, so gradients cannot pass through it to the encoder.

The reparameterization trick rewrites the sampling in a way that allows gradients to flow during training. Instead of sampling directly from the latent distribution, it samples from a standard normal distribution (N(0,1)) and then shifts and scales that random number using the mean and the standard deviation (computed from the log-variance) produced by the encoder. This allows the network to update the encoder based on the loss.

The equation for the reparameterization is:
`z = mean + exp(0.5 * logvar) * epsilon`
where:
- `z` is the sampled latent vector
- `mean` and `logvar` are outputs of the encoder
- `epsilon` is a random number drawn from N(0,1).

This ensures that while we introduce randomness, the randomness is separate from the learnable parameters.
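
The equation can be checked numerically without any deep learning library. The following is a sketch under simple assumptions (one latent dimension, made-up values for `mean` and `logvar`): it first confirms that the transformed samples have the intended mean and standard deviation, then uses finite differences to show that, with `epsilon` held fixed, `z` is an ordinary differentiable function of both encoder outputs.

```python
import math
import random

def reparameterize(mean, logvar, epsilon):
    """z = mean + exp(0.5 * logvar) * epsilon, with epsilon supplied from outside."""
    return mean + math.exp(0.5 * logvar) * epsilon

random.seed(0)
mean, logvar = 1.5, math.log(0.25)  # variance 0.25, so standard deviation 0.5

# 1) Shifting and scaling N(0, 1) noise yields samples from N(mean, std^2).
samples = [reparameterize(mean, logvar, random.gauss(0.0, 1.0)) for _ in range(50_000)]
sample_mean = sum(samples) / len(samples)
sample_std = math.sqrt(sum((s - sample_mean) ** 2 for s in samples) / len(samples))
print(f"sample mean ~ {sample_mean:.3f}, sample std ~ {sample_std:.3f}")

# 2) With epsilon fixed, central differences approximate the gradients.
epsilon, h = 0.7, 1e-6
dz_dmean = (reparameterize(mean + h, logvar, epsilon)
            - reparameterize(mean - h, logvar, epsilon)) / (2 * h)
dz_dlogvar = (reparameterize(mean, logvar + h, epsilon)
              - reparameterize(mean, logvar - h, epsilon)) / (2 * h)
print(dz_dmean)    # close to 1
print(dz_dlogvar)  # close to 0.5 * std * epsilon = 0.5 * 0.5 * 0.7 = 0.175
```

Automatic differentiation computes these same derivatives exactly, which is why the encoder can be trained through the sampling step.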

In [None]:
# %% Cell 8: VAE Loss Function
def vae_loss(recon_x, x, mean, logvar, beta=1.0):
    """
    Compute the VAE loss.
    Total Loss = Reconstruction Loss + beta * KL Divergence.
    Reconstruction Loss: Binary cross-entropy.
    KL Divergence: Regularizes the latent distribution.
    """
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction="sum")
    kl_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    total_loss = recon_loss + beta * kl_loss
    return total_loss, recon_loss, kl_loss

# Load a batch of MNIST images for testing
transform = transforms.ToTensor()
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=10, shuffle=True)
images, labels = next(iter(test_loader))

# Flatten images (since our VAE expects a 784-length vector)
images = images.view(images.size(0), -1)

# Pass through the VAE model (assumes vae_model is defined in Cell 7)
recon, mean, logvar = vae_model(images)

# Compute the loss using the vae_loss function
total_loss, recon_loss, kl_loss = vae_loss(recon, images, mean, logvar)

print("Total Loss:", total_loss.item())
print("Reconstruction Loss:", recon_loss.item())
print("KL Divergence:", kl_loss.item())



### Convolutional VAE Visualization Functions
These functions visualize the reconstructions from a convolutional VAE and generate new images by sampling from its latent space.
(Note: Your convolutional VAE model should have a method called sample for image generation.)
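
For reference, a model that satisfies this requirement might look like the sketch below. The class name `ConvVAESketch`, the channel counts, the kernel sizes, and `latent_dim=16` are assumptions chosen for illustration; the notebook itself does not define a convolutional VAE.

```python
import torch
import torch.nn as nn

class ConvVAESketch(nn.Module):
    """Illustrative convolutional VAE for 1x28x28 images (hypothetical, not from the notebook)."""

    def __init__(self, latent_dim=16):
        super().__init__()
        self.latent_dim = latent_dim
        # Encoder: 1x28x28 -> 16x14x14 -> 32x7x7 -> flat vector of 32*7*7 values
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mean = nn.Linear(32 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(32 * 7 * 7, latent_dim)
        # Decoder: latent -> 32x7x7 -> 16x14x14 -> 1x28x28
        self.fc_decode = nn.Linear(latent_dim, 32 * 7 * 7)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        hidden = self.encoder(x)
        return self.fc_mean(hidden), self.fc_logvar(hidden)

    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(std)

    def decode(self, z):
        hidden = self.fc_decode(z).view(-1, 32, 7, 7)
        return self.decoder(hidden)

    def forward(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        return self.decode(z), mean, logvar

    def sample(self, num_samples):
        """Draw z from the N(0, I) prior and decode it into images."""
        device = next(self.parameters()).device
        z = torch.randn(num_samples, self.latent_dim, device=device)
        return self.decode(z)


model = ConvVAESketch()
batch = torch.rand(4, 1, 28, 28)  # random stand-in for a batch of MNIST images
reconstruction, mean, logvar = model(batch)
print(reconstruction.shape, mean.shape, logvar.shape)
print(model.sample(3).shape)
```

A model like this takes unflattened images, so it fits the two visualization functions in this section. It would still need to be trained before its reconstructions or samples resemble digits.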

## Why Convolutional VAEs?

Our previous VAE examples used fully connected layers, which process the data as a simple list of numbers, ignoring any spatial information. Convolutional layers, on the other hand, take into account the spatial structure in images by looking at small patches in the image at a time.

Convolutional VAEs are better suited for images because:
1.  They can extract features at multiple scales: this is key to having good image representations.
2.  They reduce the number of parameters by sharing weights across the image.
3.  They are more efficient with spatial data than fully connected layers.

Convolutional VAEs use transposed convolutions for decoding. Transposed convolutions are sometimes called "deconvolutions," but they do not mathematically invert a convolution; they are learned upsampling layers that map shapes in the opposite direction of a convolutional layer. A strided convolutional layer takes an image and produces a set of smaller feature maps. A strided transposed convolution takes a set of feature maps and produces a larger set of feature maps, or the final image.

In short, they allow us to build VAEs that deal with images much better.
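
The upsampling behaviour of a transposed convolution can be seen in one dimension with plain Python. In this sketch, `conv_transpose_output_size` follows the output-size formula given in PyTorch's `ConvTranspose2d` documentation, while `conv_transpose_1d` is a naive, unpadded illustration written for this example rather than a library function:

```python
def conv_transpose_output_size(size, kernel_size, stride=1, padding=0,
                               output_padding=0, dilation=1):
    """Output length along one spatial axis of a transposed convolution."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

def conv_transpose_1d(signal, kernel, stride=1):
    """Naive 1-D transposed convolution (no padding): each input value
    stamps a scaled copy of the kernel into the output, `stride` apart."""
    out_len = (len(signal) - 1) * stride + len(kernel)
    output = [0.0] * out_len
    for i, value in enumerate(signal):
        for k, weight in enumerate(kernel):
            output[i * stride + k] += value * weight
    return output

# Two stride-2 layers can take a 7x7 feature map back up to 28x28.
print(conv_transpose_output_size(7, kernel_size=4, stride=2, padding=1))   # 14
print(conv_transpose_output_size(14, kernel_size=4, stride=2, padding=1))  # 28

# Three input values become seven output values with stride 2.
upsampled = conv_transpose_1d([1.0, 2.0, 3.0], kernel=[1.0, 1.0, 1.0], stride=2)
print(upsampled)  # [1.0, 1.0, 3.0, 2.0, 5.0, 3.0, 3.0]
```

Where neighbouring stamps overlap, their contributions add up, which is why positions 2 and 4 of the output hold sums of two inputs.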

In [None]:
# %% Cell 9: Convolutional VAE Visualization Functions

def visualize_conv_vae_reconstructions(model, num_examples=10):
    """
    Visualize original and reconstructed images from a convolutional VAE.
    (Assumes your model accepts images as input and returns reconstructions.)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # Load MNIST test dataset
    transform = transforms.ToTensor()
    test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=num_examples, shuffle=True)
    dataiter = iter(test_loader)
    images, _ = next(dataiter)
    images = images.to(device)

    with torch.no_grad():
        outputs = model(images)
        # If model returns a tuple (reconstructions, mean, logvar), use the first element.
        if isinstance(outputs, tuple):
            reconstructions = outputs[0]
        else:
            reconstructions = outputs

    images = images.cpu()
    reconstructions = reconstructions.cpu()

    plt.figure(figsize=(20, 4))
    for i in range(num_examples):
        plt.subplot(2, num_examples, i+1)
        # Original images: expect shape (1, 28, 28); squeeze to (28,28)
        plt.imshow(images[i].squeeze().numpy(), cmap="gray")
        plt.title("Original")
        plt.axis("off")

    for i in range(num_examples):
        plt.subplot(2, num_examples, i+num_examples+1)
        # For reconstructed images, if output is flattened (shape (784,)), reshape to (28,28)
        img = reconstructions[i].squeeze()
        if img.numel() == 784:
            img = img.view(28, 28)
        plt.imshow(img.numpy(), cmap="gray")
        plt.title("Reconstructed")
        plt.axis("off")

    plt.tight_layout()
    plt.show()


def generate_images_from_conv_vae(model, num_examples=10):
    """
    Generate new images by sampling from the latent space of a convolutional VAE.
    (Your model must implement a 'sample' method for this to work.)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        generated_images = model.sample(num_examples)
        generated_images = generated_images.cpu()

    plt.figure(figsize=(20, 3))
    for i in range(num_examples):
        plt.subplot(1, num_examples, i+1)
        img = generated_images[i].squeeze()
        if img.numel() == 784:
            img = img.view(28, 28)
        plt.imshow(img.numpy(), cmap="gray")
        plt.title("Generated")
        plt.axis("off")

    plt.tight_layout()
    plt.show()


# Example usage:
# For demonstration purposes, we'll use the simple VAE 'vae_model'
# Ensure that 'vae_model' is already defined in your notebook.
if __name__ == "__main__":
    test_model = vae_model  # Replace with your conv VAE if available
    visualize_conv_vae_reconstructions(test_model)
    # If your model has a 'sample' method, you can also run:
    # generate_images_from_conv_vae(test_model)


### Latent Space Interpolation
This function selects one example each of two specified digits from MNIST, encodes them, and linearly interpolates between their latent representations.
The resulting reconstructions show a smooth transition between the two digits.
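
The interpolation itself is simple arithmetic on latent vectors. A minimal sketch in plain Python, using made-up three-dimensional latent means (a real model would supply these from its encoder, and `lerp` and `interpolation_path` are hypothetical helper names):

```python
def lerp(z_start, z_end, alpha):
    """Linear blend: alpha=0 returns z_start, alpha=1 returns z_end."""
    return [(1.0 - alpha) * a + alpha * b for a, b in zip(z_start, z_end)]

def interpolation_path(z_start, z_end, steps):
    """Return `steps` evenly spaced latent vectors, endpoints included."""
    if steps < 2:
        raise ValueError("steps must be at least 2")
    return [lerp(z_start, z_end, i / (steps - 1)) for i in range(steps)]

z_three = [0.0, 2.0, -1.0]  # made-up latent mean for a "3"
z_eight = [4.0, 0.0, 1.0]   # made-up latent mean for an "8"

path = interpolation_path(z_three, z_eight, steps=5)
for z in path:
    print(z)
```

Each vector on the path would then be passed through the decoder to produce one frame of the transition.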

In [23]:
# %% Cell 10: Latent Space Interpolation
def interpolate_digits(model, digit1=3, digit2=8, steps=10):
    """
    Interpolate between two digits in the VAE's latent space.
    Finds one example for each digit, encodes them, linearly interpolates between their latent vectors,
    and decodes the results to show a smooth transition.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    transform = transforms.ToTensor()
    test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)

    image1, image2 = None, None
    for img, label in test_loader:
        if label.item() == digit1 and image1 is None:
            image1 = img.to(device)
        if label.item() == digit2 and image2 is None:
            image2 = img.to(device)
        if image1 is not None and image2 is not None:
            break

    if image1 is None or image2 is None:
        print("Could not find images for the specified digits.")
        return

    with torch.no_grad():
        image1_flat = image1.view(1, -1)
        image2_flat = image2.view(1, -1)
        mean1, _ = model.encode(image1_flat)
        mean2, _ = model.encode(image2_flat)

    interpolated_images = []
    for alpha in np.linspace(0, 1, steps):
        z = mean1 * (1 - alpha) + mean2 * alpha
        with torch.no_grad():
            reconstruction = model.decode(z)
        interpolated_images.append(reconstruction.view(28, 28).cpu().numpy())

    plt.figure(figsize=(20, 3))
    for i, img in enumerate(interpolated_images):
        plt.subplot(1, steps, i+1)
        plt.imshow(img, cmap="gray")
        plt.axis("off")
    plt.suptitle(f"Interpolation between {digit1} and {digit2}", fontsize=16)
    plt.show()


### Reparameterization Trick Visualization
This visualization shows how the encoder produces the mean (μ) and log-variance (log σ²), how epsilon is sampled from a standard normal distribution, and how these combine to form the latent vector z.
The diagram emphasizes that the randomness (epsilon) is separate from the network's learnable parameters.

In [None]:
# %% Cell 11: Reparameterization Trick Visualization
def visualize_reparameterization_trick():
    """
    Visualize the reparameterization trick:
    - Shows how the encoder outputs the mean (μ) and log-variance (log σ²),
      how a random epsilon is sampled, and how they combine to form z.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    # Define block positions: [x, y, width, height]
    encoder_box = [0.1, 0.4, 0.2, 0.3]
    mean_box = [0.4, 0.5, 0.1, 0.1]
    logvar_box = [0.4, 0.3, 0.1, 0.1]
    reparam_box = [0.6, 0.4, 0.15, 0.15]
    decoder_box = [0.8, 0.4, 0.2, 0.3]
    random_box = [0.6, 0.2, 0.1, 0.1]

    # Draw boxes for each component
    ax.add_patch(Rectangle(encoder_box[:2], encoder_box[2], encoder_box[3],
                             facecolor="lightgreen", alpha=0.7, edgecolor="black"))
    ax.add_patch(Rectangle(mean_box[:2], mean_box[2], mean_box[3],
                             facecolor="coral", alpha=0.7, edgecolor="black"))
    ax.add_patch(Rectangle(logvar_box[:2], logvar_box[2], logvar_box[3],
                             facecolor="coral", alpha=0.7, edgecolor="black"))
    ax.add_patch(Rectangle(reparam_box[:2], reparam_box[2], reparam_box[3],
                             facecolor="lightskyblue", alpha=0.7, edgecolor="black"))
    ax.add_patch(Rectangle(decoder_box[:2], decoder_box[2], decoder_box[3],
                             facecolor="lightgreen", alpha=0.7, edgecolor="black"))
    ax.add_patch(Rectangle(random_box[:2], random_box[2], random_box[3],
                             facecolor="gold", alpha=0.7, edgecolor="black"))

    # Add labels inside boxes
    ax.text(encoder_box[0] + encoder_box[2]/2, encoder_box[1] + encoder_box[3]/2,
            "Encoder\nNeural Network", ha="center", va="center", fontsize=12)
    ax.text(mean_box[0] + mean_box[2]/2, mean_box[1] + mean_box[3]/2,
            "μ", ha="center", va="center", fontsize=14)
    ax.text(logvar_box[0] + logvar_box[2]/2, logvar_box[1] + logvar_box[3]/2,
            "log σ²", ha="center", va="center", fontsize=14)
    ax.text(reparam_box[0] + reparam_box[2]/2, reparam_box[1] + reparam_box[3]/2,
            "z = μ + σ × ε", ha="center", va="center", fontsize=12)
    ax.text(decoder_box[0] + decoder_box[2]/2, decoder_box[1] + decoder_box[3]/2,
            "Decoder\nNeural Network", ha="center", va="center", fontsize=12)
    ax.text(random_box[0] + random_box[2]/2, random_box[1] + random_box[3]/2,
            "ε ~ N(0,1)", ha="center", va="center", fontsize=12)

    # Draw arrows indicating the flow
    arrow_kwargs = dict(arrowstyle="->", lw=2, color="black")
    ax.annotate("", xy=(encoder_box[0], encoder_box[1] + encoder_box[3]/2),
                xytext=(encoder_box[0] - 0.1, encoder_box[1] + encoder_box[3]/2),
                arrowprops=arrow_kwargs)
    ax.text(encoder_box[0] - 0.15, encoder_box[1] + encoder_box[3]/2, "Input", ha="center", va="center", fontsize=12)
    ax.annotate("", xy=(mean_box[0], mean_box[1] + mean_box[3]/2),
                xytext=(encoder_box[0] + encoder_box[2], encoder_box[1] + encoder_box[3]/2),
                arrowprops=arrow_kwargs)
    ax.annotate("", xy=(logvar_box[0], logvar_box[1] + logvar_box[3]/2),
                xytext=(encoder_box[0] + encoder_box[2], encoder_box[1] + encoder_box[3]/2),
                arrowprops=arrow_kwargs)
    ax.annotate("", xy=(reparam_box[0], reparam_box[1] + reparam_box[3]/2),
                xytext=(mean_box[0] + mean_box[2], mean_box[1] + mean_box[3]/2),
                arrowprops=arrow_kwargs)
    ax.annotate("", xy=(reparam_box[0], reparam_box[1] + reparam_box[3]/2),
                xytext=(logvar_box[0] + logvar_box[2], logvar_box[1] + logvar_box[3]/2),
                arrowprops=arrow_kwargs)
    ax.annotate("", xy=(reparam_box[0] + reparam_box[2]/2, reparam_box[1]),
                xytext=(random_box[0] + random_box[2]/2, random_box[1] + random_box[3]),
                arrowprops=arrow_kwargs)
    ax.annotate("", xy=(decoder_box[0], decoder_box[1] + decoder_box[3]/2),
                xytext=(reparam_box[0] + reparam_box[2], reparam_box[1] + reparam_box[3]/2),
                arrowprops=arrow_kwargs)
    ax.annotate("", xy=(decoder_box[0] + decoder_box[2] + 0.1, decoder_box[1] + decoder_box[3]/2),
                xytext=(decoder_box[0] + decoder_box[2], decoder_box[1] + decoder_box[3]/2),
                arrowprops=arrow_kwargs)
    ax.text(decoder_box[0] + decoder_box[2] + 0.15, decoder_box[1] + decoder_box[3]/2,
            "Output", ha="center", va="center", fontsize=12)

    explanation = (
        "Key Points:\n"
        "1. Encoder outputs μ (mean) and log σ² (log-variance).\n"
        "2. Convert log σ² to σ using exp(0.5 * log σ²).\n"
        "3. Sample ε from N(0,1) and compute z = μ + σ × ε.\n"
        "4. Gradients flow through μ and σ during backpropagation."
    )
    ax.text(0.5, 0.85, "The Reparameterization Trick", ha="center", fontsize=16, weight="bold")
    ax.text(0.5, 0.02, explanation, ha="center", va="bottom", fontsize=11,
            bbox=dict(facecolor="white", alpha=0.7, boxstyle="round,pad=0.5"))

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis("off")
    plt.tight_layout()
    plt.show()
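
# Run the interpolation and draw the reparameterization diagram.
# Note: vae_model has not been trained at this point in the notebook,
# so the interpolated images will look like noise until it is trained.
interpolate_digits(vae_model, digit1=3, digit2=8, steps=10)
visualize_reparameterization_trick()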

### VAE Loss Function Visualization
This diagram explains the two main components of the VAE loss:
- Reconstruction Loss: How closely the output matches the input.
- KL Divergence: How far the latent distribution is from a standard normal distribution.

## Visualizing the VAE Training Process

When training the VAE, we are trying to minimize the total loss, which is made of the reconstruction loss and the KL divergence. The reconstruction loss is essentially like a reconstruction error. The KL divergence term regularizes the latent space by penalizing distributions that are very different from a standard normal distribution, making the latent space continuous and nicely structured.

During training, the reconstruction loss is expected to go down over time. The KL divergence often behaves differently: it tends to start near zero, rise as the encoder begins to store information in the latent code, and then level off, because the two terms pull against each other.

In practice, with the summed binary cross-entropy used in this notebook, the reconstruction term is usually the larger of the two numbers, so it is more informative to compare trends than raw magnitudes.
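
The notebook does not include a training loop for the VAE, so here is a hedged sketch of one that logs the two terms separately. `TinyVAE` and `train_vae_epoch` are hypothetical names, the model is deliberately small, and random tensors stand in for MNIST so the block runs on its own without a download:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

class TinyVAE(nn.Module):
    """Small stand-in VAE so the sketch is self-contained (hypothetical)."""
    def __init__(self, input_dim=784, hidden_dim=64, latent_dim=8):
        super().__init__()
        self.enc = nn.Linear(input_dim, hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.dec = nn.Sequential(nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
                                 nn.Linear(hidden_dim, input_dim), nn.Sigmoid())

    def forward(self, x):
        h = F.relu(self.enc(x))
        mean, logvar = self.fc_mean(h), self.fc_logvar(h)
        z = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)
        return self.dec(z), mean, logvar

def train_vae_epoch(model, loader, optimizer, beta=1.0):
    """Run one epoch; return per-image averages of (reconstruction, KL)."""
    model.train()
    recon_sum, kl_sum, count = 0.0, 0.0, 0
    for (x,) in loader:
        x = x.view(x.size(0), -1)  # flatten images to 784-vectors
        recon, mean, logvar = model(x)
        recon_loss = F.binary_cross_entropy(recon, x, reduction="sum")
        kl_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        loss = recon_loss + beta * kl_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        recon_sum += recon_loss.item()
        kl_sum += kl_loss.item()
        count += x.size(0)
    return recon_sum / count, kl_sum / count

torch.manual_seed(0)
fake_images = torch.rand(256, 1, 28, 28)  # random stand-in for MNIST images
loader = DataLoader(TensorDataset(fake_images), batch_size=64, shuffle=True)
model = TinyVAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

history = [train_vae_epoch(model, loader, optimizer) for _ in range(3)]
for epoch, (recon, kl) in enumerate(history, start=1):
    print(f"epoch {epoch}: reconstruction {recon:.2f}  KL {kl:.2f}")
```

To use this pattern with real data, the loader would be built from `datasets.MNIST` as elsewhere in the notebook, and the loop would unpack `(x, _)` instead of `(x,)` to discard the labels.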

In [25]:
# %% Cell 12: VAE Loss Function Visualization
def visualize_vae_loss_function():
    """
    Visualize the two components of the VAE loss:
      - Reconstruction Loss: How well the output matches the input.
      - KL Divergence: How far the latent distribution is from a standard normal distribution.
    """
    fig, ax = plt.subplots(figsize=(12, 7))

    ax.add_patch(Rectangle((0.05, 0.5), 0.4, 0.35, facecolor="lightblue", alpha=0.3))
    ax.add_patch(Rectangle((0.05, 0.15), 0.4, 0.35, facecolor="lightgreen", alpha=0.3))
    ax.add_patch(Rectangle((0.55, 0.15), 0.4, 0.7, facecolor="lightsalmon", alpha=0.3))

    ax.text(0.25, 0.9, "Reconstruction Loss", ha="center", fontsize=14, weight="bold")
    ax.text(0.25, 0.55, "Measures how well we reconstruct the input", ha="center", fontsize=12)

    ax.text(0.25, 0.5, "KL Divergence", ha="center", fontsize=14, weight="bold")
    ax.text(0.25, 0.2, "Measures divergence from a standard normal distribution", ha="center", fontsize=12)

    ax.text(0.75, 0.9, "Total VAE Loss", ha="center", fontsize=14, weight="bold")
    ax.text(0.75, 0.83, "Reconstruction Loss + β × KL Divergence", ha="center", fontsize=12)

    recon_details = (
        "Binary Cross-Entropy:\n"
        "- Compares each pixel of the original and reconstruction\n"
        "- High when differences are large, low when similar"
    )
    ax.text(0.25, 0.7, recon_details, ha="center", va="center", fontsize=10,
            bbox=dict(facecolor="white", alpha=0.7, boxstyle="round,pad=0.3"))

    kl_details = (
        "KL(N(μ, σ²) || N(0, 1)) = -0.5 × sum(1 + log(σ²) - μ² - σ²)\n\n"
        "- Penalizes divergence from N(0,1)\n"
        "- Encourages μ near 0 and σ near 1\n"
        "- Regularizes the latent space"
    )
    ax.text(0.25, 0.35, kl_details, ha="center", va="center", fontsize=10,
            bbox=dict(facecolor="white", alpha=0.7, boxstyle="round,pad=0.3"))

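    # Match the other diagrams: fix the coordinate range and hide the axes.
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis("off")
    plt.tight_layout()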
    plt.show()


### Autoencoder Architecture Diagram
This diagram shows the flow from input to latent space and back to output in a basic autoencoder.

In [26]:
# %% Cell 13: Autoencoder Architecture Diagram
def plot_autoencoder_architecture():
    """
    Display a diagram showing the flow in a basic autoencoder:
    from input to latent space and back to output.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    input_rect = Rectangle((0, 2), 1, 4, color="skyblue", alpha=0.7)
    ax.add_patch(input_rect)
    ax.text(0.5, 6.5, "Input (784 neurons for 28x28 image)", ha="center", va="center", fontsize=10)

    encoder_rect = Rectangle((2, 2.5), 1, 3, color="lightgreen", alpha=0.7)
    ax.add_patch(encoder_rect)
    # Hidden-layer sizes below match SimpleAutoencoder's default hidden_dim=128.
    ax.text(2.5, 6.5, "Encoder Hidden Layer (128 neurons)", ha="center", va="center", fontsize=10)

    latent_rect = Rectangle((4, 3), 1, 2, color="coral", alpha=0.7)
    ax.add_patch(latent_rect)
    ax.text(4.5, 6.5, "Latent Space (32 neurons)", ha="center", va="center", fontsize=10)

    decoder_rect = Rectangle((6, 2.5), 1, 3, color="lightgreen", alpha=0.7)
    ax.add_patch(decoder_rect)
    ax.text(6.5, 6.5, "Decoder Hidden Layer (128 neurons)", ha="center", va="center", fontsize=10)

    output_rect = Rectangle((8, 2), 1, 4, color="lightsalmon", alpha=0.7)
    ax.add_patch(output_rect)
    ax.text(8.5, 6.5, "Output (784 neurons)", ha="center", va="center", fontsize=10)

    arrow1 = FancyArrowPatch((1, 4), (2, 4), arrowstyle="->", mutation_scale=20, color="black")
    arrow2 = FancyArrowPatch((3, 4), (4, 4), arrowstyle="->", mutation_scale=20, color="black")
    arrow3 = FancyArrowPatch((5, 4), (6, 4), arrowstyle="->", mutation_scale=20, color="black")
    arrow4 = FancyArrowPatch((7, 4), (8, 4), arrowstyle="->", mutation_scale=20, color="black")
    ax.add_patch(arrow1)
    ax.add_patch(arrow2)
    ax.add_patch(arrow3)
    ax.add_patch(arrow4)

    ax.text(1.5, 4.5, "Compress", ha="center", fontsize=10)
    ax.text(3.5, 4.5, "Compress", ha="center", fontsize=10)
    ax.text(5.5, 4.5, "Reconstruct", ha="center", fontsize=10)
    ax.text(7.5, 4.5, "Reconstruct", ha="center", fontsize=10)

    ax.set_xlim(-1, 10)
    ax.set_ylim(0, 8)
    ax.axis("off")
    ax.set_title("Basic Autoencoder Architecture", fontsize=14)

    plt.tight_layout()
    plt.show()


In [None]:
# %% Cell 14: Test / Run desired functions
# To train the simple autoencoder:
# trained_autoencoder, training_losses = train_autoencoder(autoencoder, epochs=5)
# visualize_reconstructions(trained_autoencoder)

# To test the VAE:
# Use the vae_loss function during training, or visualize reconstructions similarly.

# To visualize latent space interpolation:
interpolate_digits(vae_model, digit1=3, digit2=8, steps=10)

# To visualize the reparameterization trick:
visualize_reparameterization_trick()

# To visualize the VAE loss function:
visualize_vae_loss_function()

# To view the autoencoder architecture diagram:
plot_autoencoder_architecture()


## Conclusion
In this notebook we explored simple autoencoders and variational autoencoders using PyTorch.
We discussed the reparameterization trick, examined the VAE loss function, and visualized both network architectures and latent space interpolations.
Experiment with the code, adjust hyperparameters, and add further visualizations to deepen your understanding.
