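# Environment: `mmdfuse-env`

Run the cells in this section inside the `mmdfuse-env` environment. The AutoTST benchmark further down needs the separate `autogluon-env` environment.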

In [None]:
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
from tqdm.auto import tqdm
from pathlib import Path
Path("results").mkdir(exist_ok=True)
%load_ext autoreload
%autoreload 2

In [5]:
from all_tests import mmdfuse_test, mmd_median_test, mmd_split_test
from all_tests import mmdagg_test, mmdagg_test_permutation, mmdagginc_test, deep_mmd_test
from all_tests import met_test, scf_test
from all_tests import ctt_test, actt_test

In [6]:
def sampler_normal(key, N, d, mean_shift=0.0):
    """Draw two independent Gaussian samples X and Y, each of shape (N, d).

    X ~ N(0, I_d) and Y ~ N(mean_shift * 1_d, I_d). With the default
    ``mean_shift=0.0`` both samples share the same distribution. This is what
    the runtime benchmarks in this notebook use; the previous
    ``1.1 * jnp.zeros(d)`` mean was also just the zero vector.

    Args:
        key: JAX PRNG key.
        N: number of points in each of X and Y.
        d: dimension of each point.
        mean_shift: scalar added to every coordinate of Y's mean.

    Returns:
        Tuple ``(X, Y)`` of arrays of shape ``(N, d)``.
    """
    _, subkey = random.split(key)
    key_x, key_y = random.split(subkey, num=2)
    cov = jnp.eye(d)
    X = random.multivariate_normal(key_x, jnp.zeros((d, )), cov, shape=(N,))
    Y = random.multivariate_normal(key_y, mean_shift * jnp.ones((d, )), cov, shape=(N,))
    return X, Y

In [4]:
# Warm up ("compile") the JAX-based tests once before any timing.
# The cell builds its own inputs so it also runs on a fresh kernel; previously it
# relied on X, Y, subkey and seed leaking from a later cell.
# NOTE(review): if these tests are jax.jit-compiled, compilation is specific to
# the input shape. This therefore only covers N=250, d=10; the timing cells below
# also make one untimed call per shape before %timeit.
warmup_key, warmup_subkey = random.split(random.PRNGKey(0))
X_warmup, Y_warmup = sampler_normal(warmup_subkey, N=250, d=10)
warmup_key, warmup_subkey = random.split(warmup_key)
warmup_seed = 42
for warmup_test in (mmdfuse_test, mmd_median_test, mmd_split_test, mmdagg_test_permutation, mmdagginc_test):
    warmup_test(X_warmup, Y_warmup, warmup_subkey, warmup_seed)

0

## Runtime with varying sample size

In [1]:
sample_sizes = (250, 500, 1000, 2000, 4000)
d = 10

tests = (mmdfuse_test, mmd_median_test, mmd_split_test, mmdagg_test_permutation, mmdagginc_test, deep_mmd_test, met_test, scf_test, ctt_test, actt_test)
time_mean = np.zeros((len(tests), len(sample_sizes)))
time_std = np.zeros((len(tests), len(sample_sizes)))
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(sample_sizes))):
    sample_size = sample_sizes[s]
    key, subkey = random.split(key)
    X, Y = sampler_normal(subkey, N=sample_size, d=d)
    key, subkey = random.split(key)
    seed += 1
    for t in range(len(tests)):
        test = tests[t]
        compiled = test(X - 1, Y + 1, subkey, seed)
        time_f = %timeit -o -r 10 -n 1 test(X, Y, subkey, seed)
        time_mean[t,s] = np.mean(time_f.timings)
        time_std[t,s] = np.std(time_f.timings)  

jnp.save("results/runtimes_vary_n_mean.npy", time_mean)
jnp.save("results/runtimes_vary_n_std.npy", time_std)
jnp.save("results/runtimes_vary_n_x_axis.npy", sample_sizes)

print("sample sizes :", sample_sizes)
print("d :", d)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(time_mean[t])



sample sizes : (250, 500, 1000, 2000, 4000)
d : 10

<function mmdfuse_test at 0x7f775811a8b0>
[0.01525387 0.05450232 0.20295158 0.82106686 4.123863  ]
 
<function mmd_median_test at 0x7f7c131cea60>
[0.00291616 0.00795466 0.02829301 0.10554314 0.4574714 ]
 
<function mmd_split_test at 0x7f7c131ceaf0>
[0.00322919 0.00586797 0.01951194 0.05306372 0.2013358 ]
 
<function mmdagg_test_permutation at 0x7f78f832ddc0>
[0.02957679 0.0913114  0.32629034 1.4267617  7.1457143 ]
 
<function mmdagginc_test at 0x7f7c131cec10>
[0.00362434 0.00749975 0.0132229  0.02457804 0.04737148]
 
<function deep_mmd_test at 0x7f7c131ceca0>
[15.788608 15.728316 16.335041 15.397292 32.340313]
 
<function met_test at 0x7f7c131ced30>
[1.848104  1.8826377 1.9914436 2.2622426 2.6341183]
 
<function scf_test at 0x7f7c131cedc0>
[1.011565  1.0351776 1.0738761 1.1150266 1.1129541]
 
<function ctt_test at 0x7f7c131cee50>
[0.0055323  0.01416196 0.06024303 0.2431935  0.9558423 ]
 
<function actt_test at 0x7f7c131ceee0>
[0.0786

## Runtime with varying dimension

In [13]:
sample_size = 500
ds = (1, 10, 100, 1000, 10000)

tests = (mmdfuse_test, mmd_median_test, mmd_split_test, mmdagg_test_permutation, mmdagginc_test, deep_mmd_test, met_test, scf_test, ctt_test, actt_test)
time_mean = np.zeros((len(tests), len(ds)))
time_std = np.zeros((len(tests), len(ds)))
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(ds))):
    d = ds[s]
    key, subkey = random.split(key)
    X, Y = sampler_normal(subkey, N=sample_size, d=d)
    key, subkey = random.split(key)
    seed += 1
    for t in range(len(tests)):
        if (t, s) != (6, 4): 
            test = tests[t]
            compiled = test(X - 1, Y + 1, subkey, seed)
            time_f = %timeit -o -r 10 -n 1 test(X, Y, subkey, seed)
            time_mean[t,s] = np.mean(time_f.timings)
            time_std[t,s] = np.std(time_f.timings)  

if save:
    jnp.save("results/runtimes_vary_d_mean.npy", time_mean)
    jnp.save("results/runtimes_vary_d_std.npy", time_std)
    jnp.save("results/runtimes_vary_d_x_axis.npy", ds)

print("dimensions :", ds)
print("sample size :", sample_size)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(time_mean[t])


dimensions : (1, 10, 100, 1000, 10000)
sample size : 500
 
<function mmdfuse_test at 0x7f775811a8b0>
[0.04927933 0.05408762 0.05455407 0.06881471 0.46823305]
 
<function mmd_median_test at 0x7f7c131cea60>
[0.00755315 0.00757084 0.00923714 0.02361597 0.40707476]
 
<function mmd_split_test at 0x7f7c131ceaf0>
[0.00543419 0.00547784 0.00657382 0.01416614 0.11493775]
 
<function mmdagg_test_permutation at 0x7f78f832ddc0>
[0.09055484 0.09227856 0.09387398 0.11022529 0.56206424]
 
<function mmdagginc_test at 0x7f7c131cec10>
[0.00718643 0.00717156 0.00775331 0.01344394 0.10263746]
 
<function deep_mmd_test at 0x7f7c131ceca0>
[17.29203677 16.16944732 17.47588174 18.67277506 16.82258415]
 
<function met_test at 0x7f7c131ced30>
[1.84779273 1.84727186 2.13086594 6.05938222 0.        ]
 
<function scf_test at 0x7f7c131cedc0>
[0.99328488 1.05924308 1.05488316 1.29074975 1.96614258]
 
<function ctt_test at 0x7f7c131cee50>
[0.01220438 0.01762149 0.07314361 0.52497685 5.52379876]
 
<function actt_test 

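# Environment: `autogluon-env`

Run the cells in this section inside the `autogluon-env` environment, which provides `autotst` and AutoGluon.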

In [5]:
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
from tqdm.auto import tqdm
from pathlib import Path
Path("results").mkdir(exist_ok=True)

In [10]:
import autotst
from utils import HiddenPrints

def autotst_test(X, Y, key, seed, time=60, alpha=0.05):
    """Run an AutoTST two-sample test on (X, Y) and return its binary decision.

    The signature follows the ``test(X, Y, key, seed)`` convention that the
    timing loops use for every test. ``key`` and ``seed`` are accepted for that
    reason but are not used here.
    NOTE(review): AutoTST is therefore not seeded from ``seed``; check whether
    ``autotst.AutoTST`` exposes a seed if reproducibility matters.

    Args:
        X, Y: the two samples, passed unchanged to ``autotst.AutoTST``.
        key: unused (kept for call-signature compatibility).
        seed: unused (kept for call-signature compatibility).
        time: time limit in seconds passed to ``fit_witness(time_limit=...)``.
        alpha: significance level; the null is rejected when ``p_value <= alpha``.

    Returns:
        1 if the test rejects (``p_value <= alpha``), otherwise 0.
    """
    # HiddenPrints is presumably a stdout-silencing context manager (from utils), used
    # to keep AutoGluon's training logs out of the timing output.
    with HiddenPrints():
        tst = autotst.AutoTST(X, Y, split_ratio=0.5, model=autotst.model.AutoGluonTabularPredictor)
        tst.split_data()
        tst.fit_witness(time_limit=time)  # time limit adjustable to your needs (in seconds)
        p_value = tst.p_value_evaluate(permutations=10000)  # control number of permutations in the estimation
    return int(p_value <= alpha)

In [3]:
def sampler_normal(key, N, d, mean_shift=0.0):
    """Draw two independent Gaussian samples X and Y, each of shape (N, d).

    X ~ N(0, I_d) and Y ~ N(mean_shift * 1_d, I_d). With the default
    ``mean_shift=0.0`` both samples share the same distribution. This is what
    the runtime benchmarks in this notebook use; the previous
    ``1.1 * jnp.zeros(d)`` mean was also just the zero vector.

    Args:
        key: JAX PRNG key.
        N: number of points in each of X and Y.
        d: dimension of each point.
        mean_shift: scalar added to every coordinate of Y's mean.

    Returns:
        Tuple ``(X, Y)`` of arrays of shape ``(N, d)``.
    """
    _, subkey = random.split(key)
    key_x, key_y = random.split(subkey, num=2)
    cov = jnp.eye(d)
    X = random.multivariate_normal(key_x, jnp.zeros((d, )), cov, shape=(N,))
    Y = random.multivariate_normal(key_y, mean_shift * jnp.ones((d, )), cov, shape=(N,))
    return X, Y

## Runtime with varying sample size

In [2]:
sample_sizes = (250, 500, 1000, 2000, 4000)
d = 10

tests = (autotst_test, )
time_mean = np.zeros((len(tests), len(sample_sizes)))
time_std = np.zeros((len(tests), len(sample_sizes)))
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(sample_sizes))):
    sample_size = sample_sizes[s]
    key, subkey = random.split(key)
    X, Y = sampler_normal(subkey, N=sample_size, d=d)
    key, subkey = random.split(key)
    seed += 1
    for t in range(len(tests)):
        test = tests[t]
        time_f = %timeit -o -r 10 -n 1 test(X, Y, subkey, seed)
        time_mean[t,s] = np.mean(time_f.timings)
        time_std[t,s] = np.std(time_f.timings)  

if save:
    jnp.save("results/runtimes_vary_n_autotst_mean.npy", time_mean)
    jnp.save("results/runtimes_vary_n_autotst_std.npy", time_std)

print("sample_sizes :", sample_sizes)
print("d :", d)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(time_mean[t])



sample_sizes : (250, 500, 1000, 2000, 4000)
d : 10
 
<function autotst_test at 0x7fe293f28ee0>
[39.0804517  44.53311026 68.31917908 70.4661544  74.73801817]



## Runtime with varying dimension

In [3]:
sample_size = 500
ds = (1, 10, 100, 1000, 10000)
tests = (autotst_test, )
time_mean = np.zeros((len(tests), len(ds)))
time_std = np.zeros((len(tests), len(ds)))
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(ds))):
    d = ds[s]
    key, subkey = random.split(key)
    X, Y = sampler_normal(subkey, N=sample_size, d=d)
    key, subkey = random.split(key)
    seed += 1
    for t in range(len(tests)):
        if (t, s) != (6, 4): 
            test = tests[t]
            time_f = %timeit -o -r 10 -n 1 test(X, Y, subkey, seed)
            time_mean[t,s] = np.mean(time_f.timings)
            time_std[t,s] = np.std(time_f.timings)  

if save:
    jnp.save("results/runtimes_vary_d_autotst_mean.npy", time_mean)
    jnp.save("results/runtimes_vary_d_autotst_std.npy", time_std)

print("dimensions :", ds)
print("sample size :", sample_size)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(time_mean[t])


dimensions : (1, 10, 100, 1000, 10000)
sample size : 500
 
<function autotst_test at 0x7fe293f28ee0>
[39.19415824 39.18103299 51.87275588 63.75943972 73.03921813]

