This college assignment project demonstrates a Wasserstein Generative Adversarial Network (WGAN) trained to generate images resembling the CIFAR-10 dataset (cars, planes, animals, etc.). It includes the model architecture, a Jupyter notebook for training, and a Streamlit web application for visualizing the generated images.
- `model.py`: Contains the PyTorch implementation of the WGAN Generator architecture (ConvTranspose2d blocks).
- `streamlit_app.py`: A Streamlit web application that provides an interactive UI to generate and display random images using the pretrained generator.
- `wgan-code.ipynb`: A Jupyter notebook containing the full training pipeline, including dataset preparation, Generator and Critic (Discriminator) architectures, and the custom WGAN loss with weight clipping or gradient penalty.
- `requirements.txt`: Lists all Python dependencies required to run the project.
- `weights/`: Directory containing the saved model parameters (`generator.pth`) after training.
- Clone or download this repository.
- Ensure you have Python installed. It is recommended to use a virtual environment:
```shell
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
```
- Install the required dependencies:
```shell
pip install -r requirements.txt
```
To start the interactive interface and generate images:
```shell
streamlit run streamlit_app.py
```
This will open a local web server (usually at http://localhost:8501) where you can:
- Select the grid size for image generation (4, 9, 16, or 25 images).
- Click the "Generate Images" button.
- View the resulting 32x32 RGB images generated from random noise.
If you wish to train the model from scratch or explore the training process, open the Jupyter Notebook:
```shell
jupyter notebook wgan-code.ipynb
```
The notebook contains code to load the CIFAR-10 dataset, define the training loops, and save the resulting weights into the `weights/` directory.
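The notebook's exact loss code may differ, but the gradient-penalty variant of the WGAN loss mentioned above can be sketched as follows. This is a minimal illustration, assuming the critic maps a batch of images to one scalar score per image:

```python
import torch

def gradient_penalty(critic, real, fake, device="cpu"):
    """WGAN-GP term: penalize the critic's gradient norm for deviating from 1,
    evaluated at random interpolations between real and fake samples."""
    batch_size = real.size(0)
    # One interpolation coefficient per sample, broadcast over (C, H, W)
    eps = torch.rand(batch_size, 1, 1, 1, device=device)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(x_hat)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,   # keep the graph so the penalty is differentiable
        retain_graph=True,
    )[0]
    grad_norm = grads.view(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()
```

In the training loop this term is typically scaled by a coefficient (commonly 10) and added to the critic loss; with plain weight clipping, the penalty is omitted and the critic's parameters are clamped after each update instead.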
A random latent vector z (size 100) is sampled from a standard normal distribution and passed through the trained Generator network. The Generator uses a series of transposed convolutions to upscale this noise into a 3 x 32 x 32 RGB image, producing new images that visually mimic the distribution of the original training data.
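The upscaling described above can be sketched as a DCGAN-style stack of transposed convolutions. The layer widths below are illustrative assumptions, not necessarily the ones used in `model.py`:

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Hypothetical sketch: maps a latent vector z to a 3 x 32 x 32 image."""
    def __init__(self, z_dim=100, base_channels=128):
        super().__init__()
        self.net = nn.Sequential(
            # z as (N, 100, 1, 1) -> (N, 512, 4, 4)
            nn.ConvTranspose2d(z_dim, base_channels * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(base_channels * 4),
            nn.ReLU(True),
            # -> (N, 256, 8, 8)
            nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_channels * 2),
            nn.ReLU(True),
            # -> (N, 128, 16, 16)
            nn.ConvTranspose2d(base_channels * 2, base_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(True),
            # -> (N, 3, 32, 32), pixel values in [-1, 1]
            nn.ConvTranspose2d(base_channels, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z)

g = Generator()
z = torch.randn(9, 100, 1, 1)  # one latent vector per image in a 3x3 grid
images = g(z)                  # shape: (9, 3, 32, 32)
```

The saved weights would then be loaded with `g.load_state_dict(torch.load("weights/generator.pth"))` before sampling, matching the pretrained generator used by the Streamlit app.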
- Python 3.8+
- PyTorch
- Streamlit
- Pillow
- NumPy
- Torchvision
(See requirements.txt for exact version requirements)