Adaptation of PyTorch implementation for Score-Based Generative Modeling through Stochastic Differential Equations for emulating high resolution climate models

henryaddison/mlde

ML Downscaling Emulator

Forked from PyTorch implementation for the paper Score-Based Generative Modeling through Stochastic Differential Equations

by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole

Dependencies

  1. Create conda environment: conda env create -f environment.lock.yml
  2. Clone and install https://github.com/henryaddison/mlde_utils into the environment: e.g. pip install -e ../mlde_utils
  3. Install ml_downscaling_emulator locally: pip install -e .
  4. Install unet code: git clone --depth 1 git@github.com:henryaddison/Pytorch-UNet src/ml_downscaling_emulator/unet
  5. Configure necessary environment variables: DERIVED_DATA and KK_SLACK_WH_URL
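
Before training, it can help to confirm that the variables from step 5 are actually set. The following is a minimal sketch and not part of the repository: the helper name `missing_env_vars` is hypothetical, and the spelling `DERIVED_DATA` is an assumption that should be checked against the name the code reads.

```python
import os


def missing_env_vars(names, environ=None):
    """Return the names that are unset or empty in the given environment."""
    env = os.environ if environ is None else environ
    return [name for name in names if not env.get(name)]


# Variable names from step 5 (assumed spelling; adjust to match your checkout).
REQUIRED = ("DERIVED_DATA", "KK_SLACK_WH_URL")


def describe_environment(environ=None):
    """Build a one-line, human-readable status message."""
    missing = missing_env_vars(REQUIRED, environ)
    if missing:
        return "Missing environment variables: " + ", ".join(missing)
    return "Environment looks configured."
```

Calling `print(describe_environment())` in the activated conda environment reports any variable that still needs to be exported.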

Usage

Smoke test

bin/local-test-train

Training

Train models through main.py.

main.py:
  --config: Training configuration.
    (default: 'None')
  --mode: <train|eval>: Running mode: train or eval
  --workdir: Working directory
  • config is the path to the config file. Our prescribed config files are provided in configs/. They are formatted according to ml_collections and should be quite self-explanatory.

    Naming conventions of config files: the path of a config file is a combination of the following dimensions:

    • dataset: One of cifar10, celeba, celebahq, celebahq_256, ffhq_256, ffhq.
    • model: One of ncsn, ncsnv2, ncsnpp, ddpm, ddpmpp.
    • continuous: train the model with continuously sampled time steps.
  • workdir is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.

  • mode is either "train" or "eval". When set to "train", it starts training a new model, or resumes training an existing model if its meta-checkpoints (used to resume a run after pre-emption in a cloud environment) exist in workdir/checkpoints-meta.

    Evaluation options can be configured through config files or, more conveniently, through the command-line support of the ml_collections package. For example, to generate samples and evaluate sample quality, supply the --config.eval.enable_sampling flag; to compute log-likelihoods, supply the --config.eval.enable_bpd flag and specify --config.eval.dataset=train or --config.eval.dataset=test to choose whether likelihoods are computed on the training or test dataset.
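
The flags described above can be assembled into a command line as follows. This is only a sketch: the helper `build_main_command`, the config path `configs/example.py`, and the working directory `/tmp/mlde-run` are hypothetical placeholders, and the function only builds the argument list without running anything.

```python
import shlex


def build_main_command(config, workdir, mode="train", overrides=None):
    """Assemble the argument list for main.py without executing it."""
    if mode not in ("train", "eval"):
        raise ValueError(f"mode must be 'train' or 'eval', got {mode!r}")
    cmd = [
        "python",
        "main.py",
        f"--config={config}",
        f"--workdir={workdir}",
        f"--mode={mode}",
    ]
    # ml_collections-style overrides take the form --config.<field>=<value>.
    for key, value in (overrides or {}).items():
        cmd.append(f"--config.{key}={value}")
    return cmd


cmd = build_main_command(
    config="configs/example.py",  # hypothetical config path
    workdir="/tmp/mlde-run",      # hypothetical working directory
    mode="eval",
    overrides={"eval.dataset": "test"},
)
print(shlex.join(cmd))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` from the repository root once the placeholders are replaced with real paths.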

Sampling

TODO

How to extend the code

  • New SDEs: inherit from the sde_lib.SDE abstract class and implement all abstract methods. The discretize() method is optional; the default is Euler-Maruyama discretization. Existing sampling methods and likelihood computation will work automatically with the new SDE.
  • New predictors: inherit from the sampling.Predictor abstract class, implement the update_fn abstract method, and register its name with @register_predictor. The new predictor can be used directly in sampling.get_pc_sampler for Predictor-Corrector sampling, and in all other controllable generation methods in controllable_generation.py.
  • New correctors: inherit from the sampling.Corrector abstract class, implement the update_fn abstract method, and register its name with @register_corrector. The new corrector can be used directly in sampling.get_pc_sampler, and in all other controllable generation methods in controllable_generation.py.

Tips

  • When using the JAX codebase, you can jit multiple training steps together to improve training speed at the cost of more memory usage. This can be set via config.training.n_jitted_steps. For CIFAR-10, we recommend using config.training.n_jitted_steps=5 when your GPU/TPU has sufficient memory; otherwise we recommend using config.training.n_jitted_steps=1. Our current implementation requires config.training.log_freq to be divisible by n_jitted_steps for logging and checkpointing to work normally.
  • The snr (signal-to-noise ratio) parameter of LangevinCorrector behaves somewhat like a temperature parameter. Larger snr typically results in smoother samples, while smaller snr gives more diverse but lower quality samples. Typical values of snr are between 0.05 and 0.2, and it requires tuning to find the sweet spot.
  • For VE SDEs, we recommend choosing config.model.sigma_max to be the maximum pairwise distance between data samples in the training dataset.
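
Two of the tips above translate into small checks. The sketch below is illustrative only: `max_pairwise_distance` and `check_log_freq` are hypothetical helpers, samples are assumed to be flattened into equal-length sequences of numbers, and the brute-force search is quadratic in the number of samples, so on a real dataset one would likely run it on a subset (an assumption, not something the tips prescribe).

```python
import itertools
import math


def max_pairwise_distance(samples):
    """Largest Euclidean distance between any two flattened samples (O(n^2))."""
    return max(
        (math.dist(a, b) for a, b in itertools.combinations(samples, 2)),
        default=0.0,
    )


def check_log_freq(log_freq, n_jitted_steps):
    """Raise if log_freq is not a multiple of n_jitted_steps."""
    if log_freq % n_jitted_steps != 0:
        raise ValueError(
            f"log_freq={log_freq} must be divisible by n_jitted_steps={n_jitted_steps}"
        )


# Three toy 2-D "samples"; the farthest pair is (0, 0) and (3, 4).
samples = [(0.0, 0.0), (3.0, 4.0), (1.0, 1.0)]
sigma_max = max_pairwise_distance(samples)  # 5.0 for this toy data
check_log_freq(log_freq=50, n_jitted_steps=5)  # passes silently
```

The value of `sigma_max` computed this way is a candidate for config.model.sigma_max.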

References

This code is based on the following work:

@inproceedings{
  song2021scorebased,
  title={Score-Based Generative Modeling through Stochastic Differential Equations},
  author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=PxTIG12RRHS}
}

This work is built upon some previous papers which might also interest you:

  • Song, Yang, and Stefano Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution." Proceedings of the 33rd Annual Conference on Neural Information Processing Systems. 2019.
  • Song, Yang, and Stefano Ermon. "Improved techniques for training score-based generative models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.
  • Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.
