The purpose of this project is to generate images of cats. To do so, I used a GAN, specifically a DCGAN. The DCGAN generator takes a vector z from the latent space and uses transpose convolutions to upsample it into an image of size 3x64x64. The discriminator mirrors this procedure: it uses strided convolutions to reduce a 3x64x64 image to a single real/fake score.
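To make the upsampling path concrete, here is a minimal sketch of how five transpose convolutions can grow a 1x1 latent vector into a 64x64 image. The (kernel, stride, padding) values are an assumption taken from the PyTorch DCGAN tutorial, not read from this repository's code, and the helper names are hypothetical.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output spatial size of a transpose convolution
    (no dilation, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) for each generator layer.
# Assumed values, following the PyTorch DCGAN tutorial.
GENERATOR_LAYERS = [
    (4, 1, 0),  # latent vector z, seen as a 1x1 feature map
    (4, 2, 1),  # each remaining layer doubles the spatial size
    (4, 2, 1),
    (4, 2, 1),
    (4, 2, 1),
]


def generator_sizes(layers=GENERATOR_LAYERS, start=1):
    """Return the spatial size after each layer, starting from `start`."""
    sizes = [start]
    for kernel, stride, padding in layers:
        sizes.append(conv_transpose_out(sizes[-1], kernel, stride, padding))
    return sizes


if __name__ == "__main__":
    print(generator_sizes())  # [1, 4, 8, 16, 32, 64]
```

Under these assumptions the last layer outputs 3 channels, which gives the 3x64x64 image. The discriminator would apply the same sizes in reverse with ordinary strided convolutions.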
I used a public dataset available here: https://www.kaggle.com/datasets/crawford/cat-dataset
The model was trained for 225 epochs on a dataset of around 9,000 images.
Modifications were made to the training procedure and the dataset to avoid mode collapse:
- Data preprocessing: each image is centered on the cat's face and cropped to 64x64 (this greatly helped the model converge)
- Generator loss changed from log(1-D(G(z))) to -log(D(G(z))) (thanks to https://github.com/soumith/ganhacks)
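The effect of the loss change can be illustrated with a small sketch. It assumes the discriminator output is D(G(z)) = sigmoid(a) for some logit a, and compares the gradient of each generator loss with respect to a. The function names are hypothetical and this is not the training code of this project.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def saturating_loss(a):
    """Original generator loss: log(1 - D(G(z))), with D(G(z)) = sigmoid(a)."""
    return math.log(1.0 - sigmoid(a))


def non_saturating_loss(a):
    """Modified generator loss: -log(D(G(z)))."""
    return -math.log(sigmoid(a))


def saturating_grad(a):
    """Derivative of log(1 - sigmoid(a)) with respect to a."""
    return -sigmoid(a)


def non_saturating_grad(a):
    """Derivative of -log(sigmoid(a)) with respect to a."""
    return -(1.0 - sigmoid(a))


if __name__ == "__main__":
    # Early in training the discriminator easily rejects generated images,
    # so the logit a is strongly negative and D(G(z)) is close to 0.
    for a in (-5.0, 0.0, 5.0):
        print(a, saturating_grad(a), non_saturating_grad(a))
```

When D(G(z)) is close to 0, the gradient of log(1-D(G(z))) almost vanishes, while the gradient of -log(D(G(z))) stays close to -1. The generator therefore keeps receiving a useful training signal exactly when it is performing worst.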
To install the dependencies:
pip install -r requirements.txt
To run the training:
python ./src/train_dcgan.py
To generate 64 cats from the DCGAN:
In config.yaml, choose a checkpoint for the generator (the default is "./saved_model/generator_checkpoint_225.pt"), then run:
python ./src/generate_fake.py
The plot of generated samples should be saved in the ./plots folder as sample_from_generator.png.
Results can be found in the ./plots folder
For training: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
For preprocessing: https://github.com/AlexiaJM/Deep-learning-with-cats
Mathieu Nalpon