This GitHub repository contains the code for the reproducible experiments presented in our paper Robust Kernel Hypothesis Testing under Data Corruption.
The code is written in JAX, which can leverage GPU architectures to provide considerable computational speedups.
In a chosen directory, clone the repository and change to its directory by executing
git clone git@github.com:antoninschrab/dckernel-paper.git
cd dckernel-paper
We then recommend creating a conda environment with the required dependencies:
conda create -n dckernel-env
conda activate dckernel-env
# install JAX for GPU:
pip install -U "jax[cuda12]"
# or install JAX for CPU:
pip install -U "jax[cpu]"
# install all other dependencies
conda install numpy scipy scikit-learn matplotlib tqdm
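Once the environment is set up, a quick sanity check can confirm that the dependencies listed above are importable and show which backend JAX has picked up. The script below is only a minimal sketch and is not part of the repository: the helper names `missing_packages` and `report` are hypothetical, and the package list simply mirrors the install commands above.

```python
import importlib.util

# Import names of the dependencies installed above
# (scikit-learn is imported as "sklearn").
REQUIRED = ["jax", "numpy", "scipy", "sklearn", "matplotlib", "tqdm"]


def missing_packages(names):
    """Return the names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def report():
    """Print which required packages are missing and, if JAX is present, its backend."""
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages were found.")
    if "jax" not in missing:
        import jax

        # A GPU backend should be reported when the CUDA install is active.
        print("JAX backend:", jax.default_backend())
        print("JAX devices:", jax.devices())


if __name__ == "__main__":
    report()
```

Run it inside the activated dckernel-env environment. If JAX reports a CPU backend although the GPU version was installed, the CUDA install is probably not being picked up.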
To run only dcMMD and dcHSIC, it is sufficient to install JAX, as explained in our dckernel repository.
The code to reproduce the experiments of the paper can be found in the experiments.ipynb notebook.
For the experiments, the results and figures are saved in the results and figures directories, respectively.
Our proposed dcMMD and dcHSIC tests are implemented in dctests.py.
To use our tests in practice, we recommend using our dckernel package, which is available on the dckernel repository.
It can be installed by running
pip install git+https://github.com/antoninschrab/dckernel.git
Installation instructions and example code are available on the dckernel repository.
We also illustrate how to use the tests in the demo section of the notebook experiments.ipynb.
- DP tests: repository, paper
- IMDb dataset: repository, paper
If you have any issues running our code, please do not hesitate to contact Antonin Schrab.
Centre for Artificial Intelligence, Department of Computer Science, University College London
Gatsby Computational Neuroscience Unit, University College London
Inria London
@unpublished{schrab2024robust,
title={Robust Kernel Hypothesis Testing under Data Corruption},
author={Antonin Schrab and Ilmun Kim},
year={2024},
url={https://arxiv.org/abs/2405.19912},
eprint={2405.19912},
archivePrefix={arXiv},
primaryClass={stat.ML}
}
MIT License (see LICENSE).