PyTorch Variational Autoencoder (VAE) for MNIST

This repository contains a Jupyter Notebook (Variational_autoencoder.ipynb) that implements a Variational Autoencoder (VAE) from scratch using PyTorch. The model is trained on the MNIST dataset to learn a probabilistic latent space and generate new, synthetic images of handwritten digits.

πŸš€ Project Overview

This project demonstrates the complete workflow for building and training a VAE:

  1. Dataset: The MNIST dataset is loaded and pre-processed.
  2. Model Definition: A VAE class is defined with three main components:
    • An Encoder that compresses images into a probability distribution (mean and log-variance).
    • A Reparameterization Trick to sample from this distribution.
    • A Decoder that reconstructs images from the sampled latent vectors.
  3. Loss Function: A custom loss function is implemented that combines:
    • Reconstruction Loss: Binary Cross Entropy (BCE) to make the output image look like the input.
    • Kullback-Leibler (KL) Divergence: A regularization term that forces the latent space to resemble a standard normal distribution.
  4. Training: The model is trained to minimize this combined loss.
  5. Sampling: After training, new images are generated by sampling random vectors from the latent space and passing them through the decoder.

🧠 Model Architecture

The VAE is built using simple Multi-Layer Perceptrons (MLPs).

Encoder

The Encoder takes a $28 \times 28$ image (flattened to 784 dimensions) and maps it to a 20-dimensional probabilistic latent space.

  • Input: $784$-dim vector (flattened image)
  • Layer 1: Linear (784 -> 512) + ReLU
  • Layer 2: Linear (512 -> 256) + ReLU
  • Output: Two parallel Linear layers (256 -> 20) to output the mean (mu) and log-variance (log_var) of the latent distribution.
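The encoder described above can be sketched as a small PyTorch module. This is a minimal sketch based on the layer sizes listed, not the notebook's exact code; names like `fc_mu` and `fc_log_var` are illustrative.

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Maps a flattened 28x28 image (784-dim) to the mean and
    log-variance of a 20-dim Gaussian latent distribution."""
    def __init__(self, input_dim=784, hidden1=512, hidden2=256, latent_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc_mu = nn.Linear(hidden2, latent_dim)       # mean head
        self.fc_log_var = nn.Linear(hidden2, latent_dim)  # log-variance head

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        h = torch.relu(self.fc2(h))
        return self.fc_mu(h), self.fc_log_var(h)
```

Note the two parallel output heads: the encoder does not produce a latent vector directly, only the parameters of a distribution over latent vectors.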

Reparameterization Trick

To allow for backpropagation, we sample from the latent distribution $N(\mu, \sigma^2)$ using the reparameterization trick: $z = \mu + \epsilon \times \sigma$, where $\sigma = \exp(0.5 \times \log\sigma^2)$ and $\epsilon \sim N(0, 1)$.
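This trick takes only a few lines; a minimal sketch (the function name is illustrative):

```python
import torch

def reparameterize(mu, log_var):
    """Sample z = mu + eps * sigma, with sigma = exp(0.5 * log_var).
    All randomness lives in eps, so gradients can flow through
    mu and log_var during backpropagation."""
    std = torch.exp(0.5 * log_var)   # sigma = exp(0.5 * log(sigma^2))
    eps = torch.randn_like(std)      # eps ~ N(0, 1)
    return mu + eps * std
```

The key point is that sampling itself is not differentiable, but rewriting the sample as a deterministic function of `mu`, `log_var`, and an external noise variable makes the path through the encoder's outputs differentiable.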

Decoder

The Decoder takes a 20-dimensional latent vector ($z$) and reconstructs a $28 \times 28$ image.

  • Input: $20$-dim latent vector ($z$)
  • Layer 1: Linear (20 -> 256) + ReLU
  • Layer 2: Linear (256 -> 512) + ReLU
  • Output Layer: Linear (512 -> 784) + Sigmoid (to scale output pixels between 0 and 1).
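The decoder mirrors the encoder; a sketch following the layer sizes above (names are illustrative):

```python
import torch
import torch.nn as nn

class Decoder(nn.Module):
    """Reconstructs a flattened 28x28 image from a 20-dim latent vector."""
    def __init__(self, latent_dim=20, hidden1=256, hidden2=512, output_dim=784):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc_out = nn.Linear(hidden2, output_dim)

    def forward(self, z):
        h = torch.relu(self.fc1(z))
        h = torch.relu(self.fc2(h))
        return torch.sigmoid(self.fc_out(h))  # pixel values in [0, 1]
```

The sigmoid output pairs naturally with the BCE reconstruction loss described below, which expects predictions in $[0, 1]$.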

πŸ“‰ Custom Loss Function

The total loss is the sum of the Reconstruction Loss and the KL Divergence:

  1. Reconstruction Loss: Binary Cross Entropy (BCE) is used to measure the difference between the original and reconstructed images.
  2. KL Divergence: This term acts as a regularizer, pushing the learned latent distribution to be close to a standard normal distribution. It is calculated as: $KLD = -0.5 \times \sum(1 + \log(\sigma^2) - \mu^2 - \sigma^2)$.
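Both terms can be written in a few lines. A sketch (`vae_loss` is an illustrative name; `reduction='sum'` is a common choice that matches the summed closed-form KL term):

```python
import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, log_var):
    """Summed BCE reconstruction loss plus the closed-form KL divergence
    between N(mu, sigma^2) and the standard normal N(0, 1)."""
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return bce + kld
```

Note that when $\mu = 0$ and $\log\sigma^2 = 0$ (i.e. the latent distribution is exactly standard normal), the KL term is zero, so the regularizer only penalizes deviation from $N(0, 1)$.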

πŸ“Š Dataset: MNIST

The project uses the standard MNIST dataset.

  • Images are transformed into PyTorch Tensors.
  • Data is loaded in batches using DataLoader.

πŸ› οΈ Training & Sampling

  • Optimizer: Adam with a learning rate of 1e-3.
  • Epochs: The model is trained for 20 epochs.
  • Sampling: After training, $64$ random vectors are sampled from a standard normal distribution ($N(0, 1)$) and passed through the decoder to generate new images.
  • Output: The generated samples are saved as samples.png.
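Putting the pieces together, a compact version of the training step and sampling procedure might look like this. It is a sketch using the dimensions above; a dummy random batch stands in for a real MNIST batch, and the notebook instead loops over the DataLoader for 20 epochs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    """Minimal VAE matching the architecture above (784 -> 512 -> 256 -> 20)."""
    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(784, 512), nn.ReLU(),
                                 nn.Linear(512, 256), nn.ReLU())
        self.fc_mu = nn.Linear(256, 20)
        self.fc_log_var = nn.Linear(256, 20)
        self.dec = nn.Sequential(nn.Linear(20, 256), nn.ReLU(),
                                 nn.Linear(256, 512), nn.ReLU(),
                                 nn.Linear(512, 784), nn.Sigmoid())

    def forward(self, x):
        h = self.enc(x)
        mu, log_var = self.fc_mu(h), self.fc_log_var(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * log_var)
        return self.dec(z), mu, log_var

model = VAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One illustrative training step on a dummy batch:
x = torch.rand(128, 784)
recon, mu, log_var = model(x)
loss = (F.binary_cross_entropy(recon, x, reduction='sum')
        - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()))
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Sampling: decode 64 standard-normal latent vectors into an 8x8 grid
with torch.no_grad():
    samples = model.dec(torch.randn(64, 20)).view(64, 1, 28, 28)
```

From here, `torchvision.utils.save_image(samples, 'samples.png')` is the usual way to write the grid out as a single image.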

🏎️ How to Run

  1. Ensure you have the required libraries installed:
    pip install torch torchvision numpy matplotlib
  2. Launch Jupyter Notebook:
    jupyter notebook
  3. Open the Variational_autoencoder.ipynb file.
  4. Run the cells sequentially from top to bottom. The notebook will:
    • Download the MNIST dataset.
    • Initialize the VAE model, optimizer, and loss function.
    • Run the training loop for 20 epochs, printing the loss.
    • Generate an $8 \times 8$ grid of new digit images and save it as samples.png.
