
Toy datasets to evaluate algorithms for domain generalization and invariance learning.


Slowika/InvarianceUnitTests


Linear unit-tests for invariance discovery - Code

Official code for the paper "Linear unit-tests for invariance discovery", presented as a spotlight talk at the NeurIPS 2020 Workshop on Causal Discovery & Causality-Inspired Machine Learning.

Installing requirements

conda create -n invariance python=3.8
conda activate invariance
python3.8 -m pip install -U -r requirements.txt

Running a single experiment

python3.8 scripts/main.py \
    --model ERM --dataset Example1 --n_envs 3 \
    --num_iterations 10000 --dim_inv 5 --dim_spu 5 \
    --hparams '{"lr":1e-3, "wd":1e-4}' --output_dir results/
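The `--hparams` argument is a JSON object passed as a single shell-quoted string. When launching runs from Python, letting `json.dumps` produce that string and passing the arguments as a list (no shell involved) avoids quoting mistakes. This is a minimal sketch, assuming `scripts/main.py` accepts exactly the flags shown in the command above; the helper `build_main_command` is hypothetical and not part of the repository.

```python
import json
from typing import Dict, List


def build_main_command(model: str, dataset: str, n_envs: int,
                       num_iterations: int, dim_inv: int, dim_spu: int,
                       hparams: Dict[str, float], output_dir: str) -> List[str]:
    """Return an argv list mirroring the single-experiment command.

    Hypothetical helper: it only uses the flags shown in this README.
    """
    return [
        "python3.8", "scripts/main.py",
        "--model", model,
        "--dataset", dataset,
        "--n_envs", str(n_envs),
        "--num_iterations", str(num_iterations),
        "--dim_inv", str(dim_inv),
        "--dim_spu", str(dim_spu),
        # json.dumps always emits valid JSON, so no manual shell quoting is needed
        "--hparams", json.dumps(hparams),
        "--output_dir", output_dir,
    ]


if __name__ == "__main__":
    cmd = build_main_command("ERM", "Example1", 3, 10000, 5, 5,
                             {"lr": 1e-3, "wd": 1e-4}, "results/")
    print(" ".join(cmd))
    # To launch it from the repository root, pass the list to
    # subprocess.run(cmd, check=True).
```

Passing a list rather than a single string means the JSON never goes through shell word-splitting, so the inner double quotes reach the script unchanged.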

Running the experiments and printing results

python3.8 scripts/sweep.py --num_iterations 10000 --num_data_seeds 1 --num_model_seed 1 --output_dir results/
python3.8 scripts/collect_results.py results/COMMIT
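Judging by its flags, `sweep.py` presumably repeats each configuration over several data seeds and model seeds. The sketch below illustrates, under that assumption, how such a sweep can be expanded into independent run configurations with `itertools.product`. The helper `expand_sweep` and its dictionary keys are hypothetical and do not describe the script's actual internals.

```python
import itertools
from typing import Dict, List, Sequence


def expand_sweep(models: Sequence[str], datasets: Sequence[str],
                 num_data_seeds: int, num_model_seeds: int) -> List[Dict[str, object]]:
    """Enumerate one run per (model, dataset, data seed, model seed).

    Hypothetical helper for illustration; key names are assumptions.
    """
    runs = []
    for model, dataset, data_seed, model_seed in itertools.product(
            models, datasets, range(num_data_seeds), range(num_model_seeds)):
        runs.append({
            "model": model,
            "dataset": dataset,
            "data_seed": data_seed,    # assumed to select the sampled dataset
            "model_seed": model_seed,  # assumed to select the model initialization
        })
    return runs


if __name__ == "__main__":
    # One data seed and one model seed, as in the sweep command above
    for run in expand_sweep(["ERM"], ["Example1"], 1, 1):
        print(run)
```

Because every run is an independent dictionary, each one can be dispatched as a separate job, which is what makes this kind of sweep easy to parallelize.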

Reproducing the figures

bash reproduce_plots.sh

Reproducing the results (requires a cluster)

Warning: this script launches 630,000 jobs for the hyperparameter search.

bash reproduce_results.sh test
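The number of jobs in a full grid search is the product of the number of values along each axis, which is why the count grows so quickly. This is a minimal sketch using `math.prod` (available from Python 3.8). The axis names and sizes are made up for illustration and are not the actual breakdown of the search launched by `reproduce_results.sh`.

```python
import math
from typing import Mapping


def count_jobs(grid_sizes: Mapping[str, int]) -> int:
    """Number of jobs in a full grid: product of the number of values per axis."""
    return math.prod(grid_sizes.values())


# Hypothetical grid, for illustration only
example_grid = {
    "models": 4,
    "datasets": 3,
    "hparam_samples": 10,
    "data_seeds": 5,
    "model_seeds": 5,
}

if __name__ == "__main__":
    print(count_jobs(example_grid))  # 3000
```

Doubling any single axis doubles the total, so trimming seeds or hyperparameter samples is the most direct way to fit a search onto a smaller machine.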

Deactivating and removing the env

conda deactivate
conda remove --name invariance --all

License

This source code is released under the MIT license, included in this repository.

Reference

If you make use of our suite of tasks in your research, please cite the following in your manuscript:

@article{aubin2021linear,
  title={Linear unit-tests for invariance discovery},
  author={Aubin, Benjamin and S{\l}owik, Agnieszka and Arjovsky, Martin and Bottou, L{\'e}on and Lopez-Paz, David},
  journal={arXiv preprint arXiv:2102.10867},
  year={2021}
}
