YuvanJain/WGAN

WGAN Image Generator (CIFAR-10)

This college assignment project demonstrates a Wasserstein Generative Adversarial Network (WGAN) trained to generate images resembling those in the CIFAR-10 dataset (cars, planes, animals, etc.). It includes the model architecture, a Jupyter notebook for training, and a Streamlit web application for visualizing the generated images.

Project Structure

  • model.py: Contains the PyTorch implementation of the WGAN Generator architecture (ConvTranspose2d blocks).
  • streamlit_app.py: A Streamlit web application that provides an interactive UI to generate and display random images using the pretrained generator.
  • wgan-code.ipynb: A Jupyter Notebook containing the full training pipeline, including dataset preparation, Generator and Critic (Discriminator) architectures, and the custom WGAN loss with weight clipping or gradient penalty.
  • requirements.txt: Lists all Python dependencies required to run the project.
  • weights/: Directory containing the saved model parameters (generator.pth) after training.

Installation

  1. Clone or download this repository.
  2. Ensure you have Python 3.8 or later installed. A virtual environment is recommended:
    python -m venv venv
    source venv/bin/activate  # On Windows: venv\Scripts\activate
  3. Install the required dependencies:
    pip install -r requirements.txt

Usage

Running the Web Application

To start the interactive interface and generate images:

streamlit run streamlit_app.py

This starts a local web server (usually at http://localhost:8501) where you can:

  • Select the grid size for image generation (4, 9, 16, or 25 images).
  • Click the "Generate Images" button.
  • View the resulting 32x32 RGB images generated from random noise.
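
The grid sizes offered above are all perfect squares, so a batch of N images can be tiled into a √N x √N mosaic. The sketch below shows one way such a layout might be computed; the function name grid_layout and the 2-pixel padding are hypothetical and not taken from streamlit_app.py.

```python
import math

IMAGE_SIZE = 32  # CIFAR-10 images are 32x32 pixels


def grid_layout(num_images, image_size=IMAGE_SIZE, padding=2):
    """Return (side, canvas_size, offsets) for a square image grid.

    Hypothetical helper: `padding` is an assumed gap in pixels between tiles
    and around the border. `offsets` holds the (top, left) pixel position of
    each tile in row-major order.
    """
    side = math.isqrt(num_images)
    if side * side != num_images:
        raise ValueError("num_images must be a perfect square (4, 9, 16, 25, ...)")

    # Each tile contributes its own width plus one padding strip;
    # one extra strip closes the far edge of the canvas.
    canvas_size = side * image_size + (side + 1) * padding

    offsets = []
    for index in range(num_images):
        row, col = divmod(index, side)
        top = padding + row * (image_size + padding)
        left = padding + col * (image_size + padding)
        offsets.append((top, left))
    return side, canvas_size, offsets


side, canvas_size, offsets = grid_layout(4)
print(side, canvas_size, offsets)
```

With torchvision installed, torchvision.utils.make_grid performs a comparable tiling directly on image tensors.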

Training the Model

If you wish to train the model from scratch or explore the training process, open the Jupyter Notebook:

jupyter notebook wgan-code.ipynb

The notebook contains code to load the CIFAR-10 dataset, define the training loops, and save the resulting weights into the weights/ directory.
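
For reference, the standard WGAN objectives that a training loop of this kind optimizes can be written as follows. This is the textbook formulation, not a transcription of the notebook; the symbols D (critic), G (generator), c (clipping bound), and λ (penalty weight) are notation assumed here.

```latex
% Losses minimized by the critic D and the generator G, respectively
L_D = \mathbb{E}_{z \sim \mathcal{N}(0, I)}\big[D(G(z))\big]
      - \mathbb{E}_{x \sim p_{\text{data}}}\big[D(x)\big]

L_G = -\,\mathbb{E}_{z \sim \mathcal{N}(0, I)}\big[D(G(z))\big]

% Lipschitz constraint, option 1: clip critic weights after each critic update
w \leftarrow \operatorname{clip}(w, -c, c)

% Lipschitz constraint, option 2: add a gradient penalty term to L_D
\lambda\, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big],
\qquad \hat{x} = \epsilon x + (1 - \epsilon)\, G(z), \quad \epsilon \sim U[0, 1]
```

The critic is typically updated several times per generator update so that its estimate of the Wasserstein distance stays reasonably accurate.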

How it Works

A random latent vector z (dimension 100) is sampled from a standard normal distribution and passed through the trained Generator network. The Generator uses a series of transposed convolutions to upsample this noise into a 3 x 32 x 32 RGB image, producing new images that resemble samples from the training data.
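
To make the upsampling concrete, the sketch below traces tensor shapes through a stack of transposed convolutions using the output-size formula documented for PyTorch's ConvTranspose2d. The layer configuration is an assumption (a common DCGAN-style layout), not the contents of model.py; only the input dimension of 100 and the 3 x 32 x 32 output come from this README.

```python
def conv_transpose2d_out(size, kernel_size, stride=1, padding=0,
                         output_padding=0, dilation=1):
    """Spatial output size of a transposed convolution (ConvTranspose2d formula)."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


# Assumed layers: (out_channels, kernel_size, stride, padding).
# These values are illustrative and may differ from the repository's model.py.
ASSUMED_LAYERS = [
    (256, 4, 1, 0),  # 1x1   -> 4x4
    (128, 4, 2, 1),  # 4x4   -> 8x8
    (64, 4, 2, 1),   # 8x8   -> 16x16
    (3, 4, 2, 1),    # 16x16 -> 32x32 (RGB output)
]


def trace_shapes(latent_dim=100, layers=ASSUMED_LAYERS):
    """Return the (channels, height/width) pair after each layer."""
    channels, size = latent_dim, 1  # z is treated as a latent_dim x 1 x 1 tensor
    shapes = [(channels, size)]
    for out_channels, kernel_size, stride, padding in layers:
        size = conv_transpose2d_out(size, kernel_size, stride, padding)
        channels = out_channels
        shapes.append((channels, size))
    return shapes


for channels, size in trace_shapes():
    print(f"{channels} x {size} x {size}")
```

Under these assumptions, a kernel of 4 with stride 2 and padding 1 doubles the spatial size at each step, so three such layers take a 4 x 4 feature map to 32 x 32.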

Requirements

  • Python 3.8+
  • PyTorch
  • Streamlit
  • Pillow
  • NumPy
  • Torchvision

(See requirements.txt for the exact versions.)

About

A web-based implementation of a Wasserstein Generative Adversarial Network (WGAN) built with PyTorch and Streamlit, designed to generate realistic images through stable training using the Wasserstein loss. The project demonstrates how deep learning models can be integrated into an interactive web application that is easy to experiment with.
