# FactorVAE

FactorVAE implementation using PyTorch, PyTorch Lightning, and Hydra, with the environment managed through anaconda-project. Currently, this implementation supports only the dSprites dataset. It follows as closely as possible the specification in *Disentangling by Factorising* (Kim & Mnih, 2018): https://arxiv.org/pdf/1802.05983.pdf

## Setup
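As a quick orientation, FactorVAE augments the standard VAE objective with a total-correlation penalty, estimated by a discriminator that tells samples of q(z) apart from samples in which each latent dimension has been permuted independently across the batch. A minimal PyTorch sketch of that permutation step (Algorithm 1 in the paper); the function name is ours, the actual implementation in this repository may differ:

```python
import torch

def permute_dims(z: torch.Tensor) -> torch.Tensor:
    """Shuffle each latent dimension independently across the batch,
    approximating samples from the product of marginals q(z_1)...q(z_d)."""
    batch_size, latent_dim = z.size()
    out = torch.zeros_like(z)
    for d in range(latent_dim):
        perm = torch.randperm(batch_size)  # a fresh permutation per dimension
        out[:, d] = z[perm, d]
    return out
```

The discriminator is then trained to classify real q(z) samples against these permuted ones, and its logits provide the total-correlation estimate added to the VAE loss.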
- Clone the repository:

```bash
git clone https://github.com/Michedev/FactorVAE.git
```

- Install Anaconda
- Open a terminal with the Anaconda environment active
## Train

To train the model, run the following command:

```bash
anaconda-project run python train.py
```

or alternatively, to train on a single GPU:

```bash
anaconda-project run train-gpu
```
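Since configuration is handled by Hydra (the YAML files under the `config/` folder), hyperparameters can also be overridden from the command line. The keys below are hypothetical examples; check `config/train.yaml` and the files under `config/` for the actual names:

```shell
# Hypothetical Hydra overrides -- the real keys live in the config/ folder
anaconda-project run python train.py batch_size=128 max_epochs=100
```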
## Project structure

```
├── data                          # data folder
├── deep_learning_template        # source code
│   ├── config
│   │   ├── dataset               # dataset config
│   │   ├── model                 # model config
│   │   ├── model_dataset         # model- and dataset-specific config
│   │   ├── test.yaml             # testing configuration
│   │   └── train.yaml            # training configuration
│   ├── dataset                   # dataset definition
│   ├── model                     # model definition
│   ├── utils
│   │   ├── experiment_tools.py   # iterate over experiments
│   │   └── paths.py              # common paths
│   ├── train.py                  # entry point for training
│   └── test.py                   # entry point for testing
├── pyproject.toml                # project configuration
├── saved_models                  # where models are saved
└── readme.md                     # this file
```
## Generate

Once the model is trained, generate images through the script generate.py via the following command:

```bash
anaconda-project run python generate.py checkpoint_path=saved_models/{model_folder}
```

Then, inside {model_folder} you will find the file generated.png containing the batch of generated images.
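Under the hood, generation amounts to sampling latent vectors from the standard-normal prior and passing them through the trained decoder. A self-contained sketch of that idea; the stand-in decoder and `latent_dim` below are placeholders, since the real model is restored from the Lightning checkpoint instead:

```python
import torch
from torch import nn

latent_dim = 10  # assumed value; check the model config for the real one

# Stand-in decoder so this sketch runs on its own; in generate.py the
# decoder comes from the checkpoint under saved_models/.
decoder = nn.Sequential(nn.Linear(latent_dim, 64 * 64), nn.Sigmoid())

z = torch.randn(16, latent_dim)           # sample from the N(0, I) prior
images = decoder(z).view(16, 1, 64, 64)   # dSprites images are 64x64, 1 channel
```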
## TODO

- Add disentanglement evaluation
- Add generation procedure