Reproducibility code for MMDAgg: MMD Aggregated Two-Sample Test

This GitHub repository contains the code for the reproducible experiments presented in our paper MMD Aggregated Two-Sample Test.

We provide the code to run the experiments that generate Figures 1-10 and Table 2 of our paper; these can be found in media. The code for the Failing Loudly experiment (with results reported in Table 1) can be found on the FL-MMDAgg repository.

To use our MMDAgg test in practice, we recommend using our mmdagg package; more details are available on the mmdagg repository.

Our implementation uses two quantile estimation methods (wild bootstrap and permutations). The MMDAgg test aggregates over different types of kernels (e.g. Gaussian, Laplace, Inverse Multi-Quadric (IMQ), Matérn (with various parameters) kernels), each with several bandwidths. In practice, we recommend aggregating over both Gaussian and Laplace kernels, each with 10 bandwidths.
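To make the ingredients above concrete, here is a minimal sketch of a single-kernel MMD permutation test in pure Python. It is not the implementation in mmdagg/np.py or mmdagg/jax.py: it assumes one Gaussian kernel with a fixed bandwidth, a biased (V-statistic) MMD estimate, and a permutation p-value, and all function names are hypothetical. MMDAgg itself aggregates many such single-kernel tests (several kernels and bandwidths) with adjusted levels, which this sketch does not attempt.

```python
import math
import random


def gaussian_kernel(x, y, bandwidth):
    """Gaussian kernel between two points given as sequences of floats."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * bandwidth ** 2))


def mmd_squared(kernel_matrix, idx_x, idx_y):
    """Biased (V-statistic) squared MMD from a precomputed kernel matrix."""

    def mean_block(rows, cols):
        total = sum(kernel_matrix[i][j] for i in rows for j in cols)
        return total / (len(rows) * len(cols))

    return (
        mean_block(idx_x, idx_x)
        + mean_block(idx_y, idx_y)
        - 2.0 * mean_block(idx_x, idx_y)
    )


def permutation_mmd_test(X, Y, bandwidth=1.0, alpha=0.05, n_permutations=200, seed=0):
    """Return (reject, p_value) for H0: X and Y share the same distribution."""
    rng = random.Random(seed)
    pooled = list(X) + list(Y)
    n = len(X)
    # The kernel matrix is computed once; permutations only reshuffle indices.
    K = [[gaussian_kernel(a, b, bandwidth) for b in pooled] for a in pooled]
    indices = list(range(len(pooled)))
    observed = mmd_squared(K, indices[:n], indices[n:])
    exceed = 0
    for _ in range(n_permutations):
        rng.shuffle(indices)
        if mmd_squared(K, indices[:n], indices[n:]) >= observed:
            exceed += 1
    # The +1 terms count the observed statistic among the permuted ones.
    p_value = (exceed + 1) / (n_permutations + 1)
    return p_value <= alpha, p_value
```

Calling permutation_mmd_test(X, Y) with two lists of points (each point a tuple of floats) returns the test decision and the p-value. The power of such a test depends strongly on the bandwidth, which is the motivation for aggregating over several bandwidths.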

Requirements

  • python 3.9

The packages in requirements.txt are required to run our tests and the ones we compare against.

Additionally, the jax and jaxlib packages are required to run the Jax implementation of MMDAgg in mmdagg/jax.py.

Installation

In a chosen directory, clone the repository and change to its directory by executing

git clone git@github.com:antoninschrab/mmdagg-paper.git
cd mmdagg-paper

We then recommend creating and activating a virtual environment by either

  • using venv:
    python3 -m venv mmdagg-env
    source mmdagg-env/bin/activate
    # can be deactivated by running:
    # deactivate
    
  • or using conda:
    conda create --name mmdagg-env python=3.9
    conda activate mmdagg-env
    # can be deactivated by running:
    # conda deactivate
    

The packages required for reproducibility of the experiments can then be installed in the virtual environment by running

python -m pip install -r requirements.txt

To use the Jax implementation of MMDAgg, Jax needs to be installed (instructions). For example, this can be done by running

  • for GPU:
    pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    # conda install -c conda-forge -c nvidia pip numpy scipy cuda-nvcc "jaxlib=0.4.1=*cuda*" jax
  • or, for CPU:
    conda install -c conda-forge -c nvidia pip jaxlib=0.4.1 jax

Reproducing the experiments of the paper

To run the experiments, the following command can be executed

python experiments.py

This command saves the results in dedicated .csv and .pkl files in a new directory user/raw. The output of this command is already provided in paper/raw. The results of the remaining experiments, saved in the results directory, can be obtained by running the Computations_mmdagg.ipynb and Computations_autotst.ipynb notebooks; the latter uses the autotst package introduced in the AutoML Two-Sample Test paper.

The actual figures of the paper can be obtained from the saved results by running the code in the figures.ipynb notebook.

All the experiments consist of 'embarrassingly parallel for loops', so a significant speed-up can be obtained by using parallel computing libraries such as joblib or dask.
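As an illustration of this pattern, the sketch below distributes independent repetitions over workers. To stay dependency-free it uses the standard library's concurrent.futures rather than joblib or dask, and run_repetition is a hypothetical placeholder workload, not a function from experiments.py.

```python
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import random


def run_repetition(seed):
    """Hypothetical stand-in for one independent repetition of an experiment."""
    rng = random.Random(seed)
    # Placeholder workload: the mean of 10,000 uniform draws.
    return sum(rng.random() for _ in range(10_000)) / 10_000


def run_all(seeds, executor_cls=ProcessPoolExecutor, max_workers=None):
    """Run independent repetitions in parallel, keeping the order of `seeds`."""
    with executor_cls(max_workers=max_workers) as executor:
        return list(executor.map(run_repetition, seeds))
```

Because each repetition depends only on its own seed, the results are identical to those of a sequential loop. With the default ProcessPoolExecutor, the call to run_all should be placed under an `if __name__ == "__main__":` guard when run as a script; ThreadPoolExecutor can be passed instead for quick checks, although it does not speed up CPU-bound pure-Python work.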

Data

Half of the experiments use a down-sampled version of the MNIST dataset, which is created as a .data file in a new directory mnist_dataset when running the script experiments.py. This dataset can also be generated on its own by executing

python mnist.py

The other half of the experiments use samples drawn from a perturbed uniform density (Eq. 17). A rejection sampler f_theta_sampler for this density is implemented in sampling.py.
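For readers unfamiliar with rejection sampling, here is a generic one-dimensional sketch. The density used, 1 + amplitude * sin(2 * pi * frequency * u) on [0, 1], is a hypothetical stand-in chosen for simplicity; it is not the density of Eq. 17, and the code is not the f_theta_sampler from sampling.py.

```python
import math
import random


def perturbed_density(u, amplitude=0.5, frequency=2):
    """Hypothetical perturbed uniform density on [0, 1].

    It integrates to 1 for integer `frequency` and is nonnegative
    whenever |amplitude| <= 1.
    """
    return 1.0 + amplitude * math.sin(2.0 * math.pi * frequency * u)


def rejection_sample(n_samples, amplitude=0.5, frequency=2, seed=0):
    """Draw samples from `perturbed_density` using uniform proposals."""
    if abs(amplitude) > 1.0:
        raise ValueError("amplitude must satisfy |amplitude| <= 1")
    rng = random.Random(seed)
    upper_bound = 1.0 + abs(amplitude)  # the density never exceeds this value
    samples = []
    while len(samples) < n_samples:
        candidate = rng.random()  # proposal from Uniform[0, 1)
        # Accept with probability density(candidate) / upper_bound.
        if rng.random() * upper_bound <= perturbed_density(candidate, amplitude, frequency):
            samples.append(candidate)
    return samples
```

With amplitude set to 0 the density reduces to the uniform one and every proposal is accepted; larger amplitudes lower the acceptance rate to 1 / (1 + |amplitude|).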

How to use MMDAgg in practice?

The MMDAgg test is implemented as the function mmdagg in mmdagg/np.py for the Numpy version and in mmdagg/jax.py for the Jax version.

For the Numpy implementation of our MMDAgg test, we only require the numpy and scipy packages.

For the Jax implementation of our MMDAgg test, we only require the jax and jaxlib packages.

To use our tests in practice, we recommend using our mmdagg package which is available on the mmdagg repository. It can be installed by running

pip install git+https://github.com/antoninschrab/mmdagg.git

Installation instructions and example code are available on the mmdagg repository.

We also provide code showing how to use our MMDAgg test in the demo_speed.ipynb notebook, which also contains speed comparisons between the Jax and Numpy implementations, as reported below.

Speed in s    Numpy (CPU)    Jax (CPU)    Jax (GPU)
MMDAgg        43.1           14.9         0.495

In practice, we recommend using the Jax implementation as it runs considerably faster (in the above table, Jax on GPU is roughly 87 times faster than Numpy on CPU; see notebook demo_speed.ipynb).

References

Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift. Stephan Rabanser, Stephan Günnemann, Zachary C. Lipton. (paper, code)

Learning Kernel Tests Without Data Splitting. Jonas M. Kübler, Wittawat Jitkrittum, Bernhard Schölkopf, Krikamol Muandet. (paper, code)

AutoML Two-Sample Test. Jonas M. Kübler, Vincent Stimper, Simon Buchholz, Krikamol Muandet, Bernhard Schölkopf. (paper, code)

MMDAggInc

For a computationally efficient version of MMDAgg which can run in linear time, check out our paper Efficient Aggregated Kernel Tests using Incomplete U-statistics with reproducible experiments in the agginc-paper repository and a package in the agginc repository.

Contact

If you have any issues running our code, please do not hesitate to contact Antonin Schrab.

Affiliations

Centre for Artificial Intelligence, Department of Computer Science, University College London

Gatsby Computational Neuroscience Unit, University College London

Inria London

Bibtex

@article{schrab2021mmd,
  author  = {Antonin Schrab and Ilmun Kim and M{\'e}lisande Albert and B{\'e}atrice Laurent and Benjamin Guedj and Arthur Gretton},
  title   = {{MMD} Aggregated Two-Sample Test},
  journal = {Journal of Machine Learning Research},
  year    = {2023},
  volume  = {24},
  number  = {194},
  pages   = {1--81},
  url     = {http://jmlr.org/papers/v24/21-1289.html}
}

License

MIT License (see LICENSE.md).

About

Reproducibility code for MMD Aggregated Two-Sample Test, by Schrab, Kim, Albert, Laurent, Guedj and Gretton: https://arxiv.org/abs/2110.15073
