A zoo of GAN implementations.
Install the dependencies:
pip install -r requirements.txt
- Currently uses the torchvision.datasets.ImageFolder dataset. Change the data_dir
parameter in config.yml to your custom dataset path.
- Only 64x64 images are supported.
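ImageFolder treats every sub-folder of data_dir as a class and only picks up image files inside those sub-folders, so a flat directory of images will not load. For unconditional GAN training the class labels are presumably ignored, so a single sub-folder (e.g. data_dir/images/) should be enough. The hypothetical helper below is not part of this repo. It mirrors that layout rule using only the standard library, so a dataset can be sanity-checked before training. The extension list is copied from torchvision's default IMG_EXTENSIONS.

```python
from pathlib import Path

# Extensions torchvision's ImageFolder accepts by default (IMG_EXTENSIONS).
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def summarize_imagefolder(data_dir):
    """Count loadable images per class sub-folder of an ImageFolder-style root.

    Hypothetical helper (not part of this repo). ImageFolder expects
    data_dir/<class_name>/<image file>, so loose files placed directly
    under data_dir are ignored, just as they are here.
    """
    root = Path(data_dir)
    counts = {}
    # Each immediate sub-directory is one class; search it recursively.
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.rglob("*")
            if f.is_file() and f.suffix.lower() in IMG_EXTENSIONS
        )
    if not counts:
        raise FileNotFoundError(f"no class sub-folders found in {root}")
    return counts
```

If the source images are not already 64x64, they could be resized beforehand, for example with torchvision.transforms.Resize(64) followed by CenterCrop(64); whether train.py already does this is not stated in this README.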
- To start model training, run:
  python train.py <model-name> --config-dir path/to/config.yml
- Supported models:
  - wgan: Wasserstein GAN with weight clipping.
  - wgan-gp: WGAN with gradient penalty.
  - dcgan: DCGAN (Deep Convolutional GAN).
- For help, run:
  python train.py --help
- For model-specific help, run:
  python train.py <model-name> --help
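The two Wasserstein variants differ mainly in how the critic D is kept (approximately) 1-Lipschitz. For reference, these are the standard objectives from the WGAN and WGAN-GP papers, with P_r the data distribution, p(z) the noise prior and G the generator. Mapping the penalty weight λ to the w_gp config field and the number of critic updates per generator update to n_critic is an assumption based on the field names, not something confirmed from the code.

```latex
\begin{aligned}
L_D &= \mathbb{E}_{z \sim p(z)}\big[D(G(z))\big] - \mathbb{E}_{x \sim P_r}\big[D(x)\big]
      + \lambda\, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big], \\
\hat{x} &= \epsilon\, x + (1 - \epsilon)\, G(z), \qquad \epsilon \sim U[0, 1], \\
L_G &= -\,\mathbb{E}_{z \sim p(z)}\big[D(G(z))\big].
\end{aligned}
```

For wgan, the penalty term is dropped (λ = 0) and, after each critic update, every critic weight w is clipped to [-c, c]. The papers use c = 0.01 (WGAN), λ = 10 (WGAN-GP) and five critic updates per generator update; whether this repo uses the same defaults is not stated here. dcgan is usually trained with the standard (non-Wasserstein) GAN loss instead.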
- The config file controls model behaviour.
- It can be extended with more fields as required by the model.
name: <str> model/config name
device: <str> [cuda|cpu] device to load models to.
data_dir: <str> path to data dir.
seed: <int> seed to control randomness.
z_dim: <int> latent dimension for generator noise input.
imsize: <int> input/output image size.
img_ch: <int> number of channels in image.
w_gp: <number> gradient penalty weight.
n_critic: <int> number of critic iterations.
batch_size: <int> batch size for training.
epochs: <int> number of epochs to train for.
viz_freq: <int> image visualisation frequency (in steps).
lr:
  g: <float> learning rate for generator.
  d: <float> learning rate for discriminator/critic.
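A minimal sketch of how the schema above could be loaded and checked in Python, assuming the YAML file has already been parsed into a dict (for example with PyYAML's yaml.safe_load). TrainConfig, LRConfig and from_dict are hypothetical names, not this repo's API. Unknown keys are collected into extra rather than rejected, in line with the note that the config can be extended with model-specific fields.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Dict


@dataclass
class LRConfig:
    g: float  # generator learning rate
    d: float  # discriminator/critic learning rate


@dataclass
class TrainConfig:
    name: str
    device: str
    data_dir: str
    seed: int
    z_dim: int
    imsize: int
    img_ch: int
    w_gp: float
    n_critic: int
    batch_size: int
    epochs: int
    viz_freq: int
    lr: LRConfig
    # Model-specific keys that are not part of the base schema end up here.
    extra: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, raw: Dict[str, Any]) -> "TrainConfig":
        known = {f.name for f in fields(cls)} - {"extra"}
        missing = known - set(raw)
        if missing:
            raise KeyError(f"missing config fields: {sorted(missing)}")
        kwargs = {k: raw[k] for k in known}
        kwargs["lr"] = LRConfig(**raw["lr"])  # nested `lr:` block
        kwargs["extra"] = {k: v for k, v in raw.items() if k not in known}
        cfg = cls(**kwargs)
        if cfg.device not in ("cuda", "cpu"):
            raise ValueError(f"device must be 'cuda' or 'cpu', got {cfg.device!r}")
        if cfg.imsize != 64:
            raise ValueError("only 64x64 images are currently supported")
        return cfg
```

With PyYAML installed, loading would look like TrainConfig.from_dict(yaml.safe_load(open("path/to/config.yml"))). Only device and imsize are validated here; type checks on the other fields are left out to keep the sketch short.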
NOTE: Models have not been trained to convergence.
- DCGAN
- WGAN
- WGAN-GP
TODO:
- Model checkpointing (save/load).
- Flexible image sizes.
- Other GANs.
- More datasets:
  - MNIST
  - CIFAR