Code for "Multivariate soft rank via entropic optimal transport: sample efficiency and generative modeling"
This repository provides two applications of the novel multivariate soft rank energy (sRE) and soft rank maximum mean discrepancy (sRMMD): (a) training a generative model with sRE and sRMMD as loss functions to produce MNIST digits, and (b) using sRMMD as the loss in a deep generative model to produce valid knockoffs for selecting statistically significant features.
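As a rough illustration of the idea behind sRE, the sketch below computes soft ranks by solving an entropic optimal transport problem between the samples and reference points drawn uniformly from the unit cube, then compares the soft ranks of two samples with an energy distance. It is a minimal NumPy sketch under assumptions, not the code in this repository: the names `soft_rank`, `energy_distance`, and `soft_rank_energy`, the regularization `eps`, the iteration count, and the choice to rank both samples jointly against the pooled sample are all illustrative.

```python
import numpy as np

def _logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along one axis.
    zmax = z.max(axis=axis, keepdims=True)
    return (zmax + np.log(np.exp(z - zmax).sum(axis=axis, keepdims=True))).squeeze(axis)

def soft_rank(x, ref, eps=0.5, n_iter=200):
    """Map each row of x (n, d) to a weighted average of the rows of ref (m, d)."""
    n, m = x.shape[0], ref.shape[0]
    cost = 0.5 * ((x[:, None, :] - ref[None, :, :]) ** 2).sum(-1)  # (n, m)
    log_a, log_b = -np.log(n), -np.log(m)  # uniform weights on both sides
    f, g = np.zeros(n), np.zeros(m)        # dual potentials
    for _ in range(n_iter):                # log-domain Sinkhorn updates
        f = -eps * _logsumexp((g[None, :] - cost) / eps + log_b, axis=1)
        g = -eps * _logsumexp((f[:, None] - cost) / eps + log_a, axis=0)
    log_plan = (f[:, None] + g[None, :] - cost) / eps + log_a + log_b
    # Barycentric projection: normalize each row of the plan, then average ref.
    weights = np.exp(log_plan - _logsumexp(log_plan, axis=1)[:, None])
    return weights @ ref

def energy_distance(a, b):
    """Energy distance between two point clouds (V-statistic form)."""
    def mean_dist(p, q):
        return np.sqrt(((p[:, None, :] - q[None, :, :]) ** 2).sum(-1)).mean()
    return 2.0 * mean_dist(a, b) - mean_dist(a, a) - mean_dist(b, b)

def soft_rank_energy(x, y, ref, eps=0.5):
    """Energy distance between the soft ranks of x and y (ranked jointly)."""
    pooled = np.vstack([x, y])
    ranks = soft_rank(pooled, ref, eps=eps)
    return energy_distance(ranks[: len(x)], ranks[len(x):])

rng = np.random.RandomState(0)
ref = rng.uniform(size=(100, 2))   # reference points in the unit square
x = rng.randn(50, 2)
y = rng.randn(50, 2) + 3.0         # shifted sample
print(soft_rank_energy(x, y, ref))
```

Under the same assumptions, an sRMMD-style quantity would replace `energy_distance` with a kernel MMD applied to the soft ranks.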
Dependencies:
- python=3.6.5
- numpy=1.14.0
- scipy=1.0.0
- pytorch=0.4.1
- cvxpy=1.0.10
- cvxopt=1.2.0
- pandas=0.23.4
To reproduce the MNIST results from the paper:
- Figure 1(b): run 'mnist_figures_geneartion.py'
- Figure 1(a): set lossType = 'mmd' and run 'mnist_figures_geneartion.py'
- Figure 1(c): set lossType = 'sRMMD' and run 'mnist_figures_geneartion.py'
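For readers unfamiliar with the 'mmd' loss option, the sketch below shows a biased estimate of the squared maximum mean discrepancy with a Gaussian kernel. It is only an illustration in NumPy: the function names and the `bandwidth` parameter are hypothetical, and the loss used by the script (which runs in PyTorch) may use a different kernel or estimator.

```python
import numpy as np

def gaussian_kernel(a, b, bandwidth=1.0):
    # Pairwise Gaussian (RBF) kernel matrix between rows of a and rows of b.
    sq_dist = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dist / (2.0 * bandwidth ** 2))

def mmd2_biased(x, y, bandwidth=1.0):
    # Biased (V-statistic) estimate of the squared MMD between samples x and y.
    k_xx = gaussian_kernel(x, x, bandwidth).mean()
    k_yy = gaussian_kernel(y, y, bandwidth).mean()
    k_xy = gaussian_kernel(x, y, bandwidth).mean()
    return k_xx + k_yy - 2.0 * k_xy

rng = np.random.RandomState(0)
real = rng.randn(200, 5)
fake = rng.randn(200, 5) + 1.0   # stand-in for generator output
print(mmd2_biased(real, fake))
```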
To reproduce the knockoff figures from the paper:
- Extra package dependencies for other benchmarks:
  - DDLK: install the package from https://github.com/rajesh-lab/ddlk
  - KnockoffGAN: install TensorFlow v2 and use the code from https://bitbucket.org/mvdschaar/mlforhealthlabpub/src/master/
- Figure 2(c): run 'knockoff_figures_geneartion.py'
- Figure 2(a): set distType = 'GaussianAR1' and run 'knockoff_figures_geneartion.py'
- Figure 2(b): set distType = 'GaussianMixtureAR1' and run 'knockoff_figures_geneartion.py'
- Figure 2(d): set distType = 'SparseGaussian' and run 'knockoff_figures_geneartion.py'
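To give a feel for the synthetic settings, the sketch below samples features from a Gaussian with an AR(1) covariance, where entry (i, j) equals rho^|i-j|. This assumes that 'GaussianAR1' refers to that common knockoff benchmark design; the function names, the default `rho`, and the seed handling are hypothetical, and the generator in the script may differ.

```python
import numpy as np

def ar1_covariance(p, rho):
    # Covariance with entries rho ** |i - j| (unit variances on the diagonal).
    idx = np.arange(p)
    return rho ** np.abs(idx[:, None] - idx[None, :])

def sample_gaussian_ar1(n, p, rho=0.5, seed=0):
    # Draw n samples of a p-dimensional zero-mean Gaussian with AR(1) covariance.
    rng = np.random.RandomState(seed)
    chol = np.linalg.cholesky(ar1_covariance(p, rho))
    return rng.standard_normal((n, p)) @ chol.T

X = sample_gaussian_ar1(n=1000, p=10, rho=0.5)
print(X.shape)
```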
To reproduce Table 1 from the paper:
- run 'real_dataset.py'
N.B.: If you encounter package dependency errors while running 'mnist_figures_geneartion.py' or 'real_dataset.py', run each method separately.
These notebooks give an overview of how the sRMMD knockoff filter works on synthetic and real data:
- Examples/knockoff_synthetic_settings.ipynb: code to generate valid knockoffs using sRMMD.
- Examples/knockoff_real_data.ipynb: metabolite selection using sRMMD knockoffs on the real data set available in dataset/Real dataset
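Once knockoffs are generated, a knockoff filter selects features by thresholding per-feature statistics W_j, where a large positive W_j suggests the original feature matters more than its knockoff. The sketch below applies the standard knockoff+ threshold to a given vector of statistics. It is a generic illustration, not the notebooks' code: the function names are hypothetical, and how W_j is computed is left out.

```python
import numpy as np

def knockoff_plus_threshold(w, fdr=0.1):
    # Smallest t among the nonzero |w_j| such that the estimated false
    # discovery proportion (1 + #{w_j <= -t}) / max(1, #{w_j >= t}) is <= fdr.
    candidates = np.sort(np.abs(w[w != 0]))
    for t in candidates:
        ratio = (1.0 + np.sum(w <= -t)) / max(1, np.sum(w >= t))
        if ratio <= fdr:
            return t
    return np.inf  # no threshold meets the target level: select nothing

def select_features(w, fdr=0.1):
    # Indices of features whose statistic reaches the knockoff+ threshold.
    t = knockoff_plus_threshold(w, fdr)
    return np.flatnonzero(w >= t)

w = np.array([3.0, 2.5, 2.0, 1.5, 1.0, 0.8, -0.2, 0.1])
print(select_features(w, fdr=0.2))
```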