
Reproducibility code for AggInc: Efficient Aggregated Kernel Tests using Incomplete U-statistics

This repository contains the code for reproducing the experiments in our paper Efficient Aggregated Kernel Tests using Incomplete U-statistics (NeurIPS 2022, https://arxiv.org/abs/2206.09194), by Antonin Schrab, Ilmun Kim, Benjamin Guedj and Arthur Gretton.

To use our MMDAggInc, HSICAggInc and KSDAggInc tests in practice, we recommend our agginc package; more details are available in the agginc repository.

The code for reproducing the experiments in our paper, and for generating the figures in the figures directory, is presented in the notebook experiments.ipynb. The outputs of all the experiments are saved in the results directory.

Requirements

  • python 3.9

The packages in requirements.txt are required to run our tests and the tests we compare against. The numpy version constraint <= 1.21 is only needed for compatibility with the theano package.

Additionally, the jax and jaxlib packages are required to run the Jax implementation of AggInc in agginc/jax.py.

Installation

In a chosen directory, clone the repository and change to its directory by executing

git clone git@github.com:antoninschrab/agginc-paper.git
cd agginc-paper

We then recommend creating and activating a virtual environment by either

  • using venv:
    python3 -m venv agginc-env
    source agginc-env/bin/activate
    # can be deactivated by running:
    # deactivate
    
  • or using conda:
    conda create --name agginc-env python=3.9
    conda activate agginc-env
    # can be deactivated by running:
    # conda deactivate
    

The required packages can then be installed in the virtual environment by running

python -m pip install -r requirements.txt

To use the Jax implementation of our tests, Jax needs to be installed; we recommend doing so with conda, by running

  • for GPU:
    conda install -c conda-forge -c nvidia pip cuda-nvcc "jaxlib=0.4.1=*cuda*" jax
  • or, for CPU:
    conda install -c conda-forge -c nvidia pip jaxlib=0.4.1 jax

How to use MMDAggInc, HSICAggInc and KSDAggInc in practice?

The MMDAggInc, HSICAggInc and KSDAggInc tests are implemented as the function agginc in agginc/np.py for the Numpy version and in agginc/jax.py for the Jax version.

For the Numpy implementation of our AggInc tests, we only require the numpy, scipy and psutil packages.

For the Jax implementation of our AggInc tests, we only require the jax, jaxlib and psutil packages.
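For illustration, below is a minimal sketch of a two-sample test with the Numpy implementation. The interface agginc(test_type, X, Y), returning 0 (fail to reject) or 1 (reject the null), is assumed from the documentation in the agginc repository; refer to demo.ipynb for the authoritative usage.

import numpy as np
from agginc.np import agginc

# generate samples from two shifted distributions (illustrative data)
rng = np.random.default_rng(0)
X = rng.uniform(size=(500, 10))
Y = rng.uniform(size=(500, 10)) + 0.5

# run MMDAggInc ("hsic" and "ksd" select HSICAggInc and KSDAggInc)
output = agginc("mmd", X, Y)  # assumed to return 0 or 1
print(output)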

To use our tests in practice, we recommend using our agginc package, which is available in the agginc repository. It can be installed by running

pip install git+https://github.com/antoninschrab/agginc.git

Installation instructions and example code are available on the agginc repository.

We also provide some code showing how to use our AggInc tests in demo.ipynb.

In practice, we recommend using the Jax implementation as it runs considerably faster.
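The same sketch with the Jax implementation is given below; the interface is assumed to match the Numpy version above, and the repeated call illustrates standard Jax behaviour: the first call triggers JIT compilation, while subsequent calls run at compiled speed.

from jax import random
from agginc.jax import agginc

# generate samples using Jax's explicit random keys
key = random.PRNGKey(0)
subkeys = random.split(key, num=2)
X = random.uniform(subkeys[0], shape=(500, 10))
Y = random.uniform(subkeys[1], shape=(500, 10)) + 0.5

output = agginc("mmd", X, Y)  # first call: includes JIT compilation
output = agginc("mmd", X, Y)  # later calls: compiled speed
print(int(output))            # 1: reject the null, 0: fail to reject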

Speed comparison

We recommend using our Jax implementation in agginc/jax.py over our Numpy implementation in agginc/np.py, as it runs more than 100 times faster after compilation. This can be seen from the results of the notebook speed.ipynb, which are reported below.

Speed in ms   Numpy (CPU)   Jax (CPU)   Jax (GPU)
MMDAggInc     4490          844         23
HSICAggInc    2820          539         18
KSDAggInc     3770          590         22
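These timings can be checked with a rough harness along the following lines (a hypothetical sketch, not the code from speed.ipynb), which separates the one-off compilation cost of the first call from the steady-state runtime.

import time
from jax import random
from agginc.jax import agginc

key = random.PRNGKey(0)
subkeys = random.split(key, num=2)
X = random.uniform(subkeys[0], shape=(500, 10))
Y = random.uniform(subkeys[1], shape=(500, 10)) + 0.5

start = time.perf_counter()
int(agginc("mmd", X, Y))  # int() forces the computation to complete
print("first call (with compilation):", time.perf_counter() - start, "s")

start = time.perf_counter()
for _ in range(10):
    int(agginc("mmd", X, Y))
print("compiled call:", (time.perf_counter() - start) / 10 * 1000, "ms")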

References

In our experiments, we compare MMDAggInc, HSICAggInc and KSDAggInc against existing tests from the literature; the complete list of baselines, with references, is given in our paper.

Contact

If you have any issues running our code, please do not hesitate to contact Antonin Schrab.

Affiliations

Centre for Artificial Intelligence, Department of Computer Science, University College London

Gatsby Computational Neuroscience Unit, University College London

Inria London

Bibtex

@inproceedings{schrab2022efficient,
  author    = {Antonin Schrab and Ilmun Kim and Benjamin Guedj and Arthur Gretton},
  title     = {Efficient Aggregated Kernel Tests using Incomplete {$U$}-statistics},
  booktitle = {Advances in Neural Information Processing Systems 35: Annual Conference
               on Neural Information Processing Systems 2022, NeurIPS 2022},
  editor    = {Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
  year      = {2022},
}

License

MIT License (see LICENSE.md).

Related tests

  • mmdagg: MMDAgg, the aggregated MMD two-sample test
  • ksdagg: KSDAgg, the aggregated KSD goodness-of-fit test
  • mmdfuse: the MMD-Fuse test
  • dpkernel: the differentially private dpMMD and dpHSIC tests
  • dckernel: the dcMMD and dcHSIC tests, robust to data corruption
