# Environment mmdfuse-env

In [None]:
from sampler_perturbations import sampler_perturbations
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from tqdm.auto import tqdm
from pathlib import Path
# Make sure the "results" directory exists before the experiment cells save
# their .npy files into it; exist_ok=True keeps re-running this cell harmless.
Path("results").mkdir(exist_ok=True)
%load_ext autoreload
%autoreload 2

In [2]:
from all_tests import mmdfuse_test, mmd_median_test, mmd_split_test
from all_tests import mmdagg_test, mmdagginc_test, deep_mmd_test
from all_tests import met_test, scf_test
from all_tests import ctt_test, actt_test

## Vary difficulty d = 1

In [16]:
# Experiment: rejection rate of every test as the perturbation scale grows
# (d = 1, fixed sample size). Results are averaged over `repetitions` runs.
repetitions = 200
scales = (0, 0.1, 0.2, 0.3, 0.4, 0.5)
number_perturbations = 2
sample_size = 500
d = 1

tests = (mmdfuse_test, mmd_median_test, mmd_split_test, mmdagg_test, mmdagginc_test, deep_mmd_test, met_test, scf_test, ctt_test, actt_test)
# Nested Python lists indexed as outputs[test][scale][repetition], so single
# entries can be assigned in place inside the loops.
outputs = jnp.zeros((len(tests), len(scales), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in (range(len(scales))):
    scale = scales[s]
    for i in (range(repetitions)):
        # The key is split twice per repetition and only the second subkey is
        # handed to the tests. The first split is kept (its subkey discarded)
        # so the PRNG key stream stays the same as in the original cell.
        key, _ = random.split(key)
        X, Y = sampler_perturbations(m=sample_size, n=sample_size, d=d, scale=scale, number_perturbations=number_perturbations, seed=seed)
        key, subkey = random.split(key)
        # NOTE: seed is incremented before the tests run, so the tests receive
        # the seed that the sampler will use in the next repetition.
        seed += 1
        for t in range(len(tests)):
            test = tests[t]
            # Fixed: this call used to be indented one level deeper than the
            # statement above it, which is an IndentationError in Python.
            # Every test sees the same samples, subkey and seed.
            # Presumably each test returns a 0/1 rejection indicator -- TODO confirm in all_tests.
            outputs[t][s][i] = test(
                X,
                Y,
                subkey,
                seed,
            )

# Average over the repetition axis -> array of shape (len(tests), len(scales)).
output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/perturbations_vary_dif_d1.npy", output)
jnp.save("results/perturbations_vary_dif_d1_x_axis.npy", scales)

print("scales :", scales)
print("sample size :", sample_size)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])

scales : (0, 0.1, 0.2, 0.3, 0.4, 0.5)
sample size : 500
 
<function mmdfuse_test at 0x7f87943fb700>
[0.045      0.085      0.17999999 0.505      0.825      0.97499996]
 
<function mmd_median_test at 0x7f8faceb1820>
[0.04  0.06  0.105 0.29  0.465 0.675]
 
<function mmd_split_test at 0x7f8faceb18b0>
[0.065 0.075 0.125 0.24  0.565 0.825]
 
<function mmdagg_test at 0x7f8faceb1940>
[0.035      0.08       0.21       0.60499996 0.88       0.98499995]
 
<function mmdagginc_test at 0x7f8faceb19d0>
[0.04       0.085      0.175      0.525      0.815      0.98499995]
 
<function deep_mmd_test at 0x7f8faceb1a60>
[0.08  0.06  0.11  0.31  0.515 0.84 ]
 
<function met_test at 0x7f8faceb1af0>
[0.065      0.09999999 0.14999999 0.26       0.51       0.73499995]
 
<function scf_test at 0x7f8faceb1b80>
[0.005 0.005 0.005 0.03  0.03  0.255]
 
<function ctt_test at 0x7f8faceb1c10>
[0.045 0.075 0.14  0.355 0.545 0.81 ]
 
<function actt_test at 0x7f8faceb1ca0>
[0.035      0.03       0.13499999 0.415      0.724

## Vary sample size d = 1

In [1]:
# Experiment: rejection rate of every test as the sample size grows
# (d = 1, fixed perturbation scale), averaged over `repetitions` runs.
repetitions = 200
scale = 0.2
number_perturbations = 2
sample_sizes = (500, 1000, 1500, 2000, 2500, 3000)
d = 1

tests = (
    mmdfuse_test,
    mmd_median_test,
    mmd_split_test,
    mmdagg_test,
    mmdagginc_test,
    deep_mmd_test,
    met_test,
    scf_test,
    ctt_test,
    actt_test,
)
# Nested Python lists indexed as outputs[test][sample size][repetition].
outputs = jnp.zeros((len(tests), len(sample_sizes), repetitions)).tolist()
key = random.PRNGKey(42)
seed = 42
for size_idx, sample_size in enumerate(tqdm(sample_sizes)):
    for rep in tqdm(range(repetitions)):
        # Two splits per repetition; only the second subkey reaches the tests.
        key, _ = random.split(key)
        X, Y = sampler_perturbations(
            m=sample_size,
            n=sample_size,
            d=d,
            scale=scale,
            number_perturbations=number_perturbations,
            seed=seed,
        )
        key, subkey = random.split(key)
        # The tests receive the already-incremented seed.
        seed += 1
        for test_idx, test in enumerate(tests):
            outputs[test_idx][size_idx][rep] = test(X, Y, subkey, seed)

# Mean over the repetition axis -> shape (len(tests), len(sample_sizes)).
output = jnp.array(outputs).mean(axis=-1)

jnp.save("results/perturbations_vary_n_d1.npy", output)
jnp.save("results/perturbations_vary_n_d1_x_axis.npy", sample_sizes)

print("sample sizes :", sample_sizes)
print("scale :", scale)
for test_idx, test in enumerate(tests):
    print(" ")
    print(test)
    print(output[test_idx])


sample sizes : (500, 1000, 1500, 2000, 2500, 3000)
scale : 0.2
 
<function mmdfuse_test at 0x7f00b69550d0>
[0.21499999 0.42999998 0.72499996 0.78999996 0.92499995 0.98499995]
 
<function mmd_median_test at 0x7f09886af430>
[0.16       0.22       0.415      0.45499998 0.59       0.7       ]
 
<function mmd_split_test at 0x7f09886af4c0>
[0.14       0.19999999 0.45       0.49499997 0.69       0.74      ]
 
<function mmdagg_test at 0x7f09886af550>
[0.24       0.47       0.78499997 0.875      0.965      0.98999995]
 
<function mmdagginc_test at 0x7f09886af5e0>
[0.205      0.345      0.44       0.49499997 0.655      0.73499995]
 
<function deep_mmd_test at 0x7f09886af670>
[0.125      0.21       0.45499998 0.45       0.69       0.71999997]
 
<function met_test at 0x7f09886af700>
[0.14       0.24       0.265      0.29999998 0.525      0.675     ]
 
<function scf_test at 0x7f09886af790>
[0.015 0.01  0.02  0.04  0.075 0.145]
 
<function ctt_test at 0x7f09886af820>
[0.17999999 0.295      0.515   

## Vary difficulty d = 2

In [2]:
# Experiment: rejection rate of every test as the perturbation scale grows
# (d = 2, fixed sample size), averaged over `repetitions` runs.
repetitions = 200
scales = (0, 0.2, 0.4, 0.6, 0.8, 1)
number_perturbations = 2
sample_size = 500
d = 2

tests = (
    mmdfuse_test,
    mmd_median_test,
    mmd_split_test,
    mmdagg_test,
    mmdagginc_test,
    deep_mmd_test,
    met_test,
    scf_test,
    ctt_test,
    actt_test,
)
# Nested Python lists indexed as outputs[test][scale][repetition].
outputs = jnp.zeros((len(tests), len(scales), repetitions)).tolist()
key = random.PRNGKey(42)
seed = 42
for scale_idx, scale in enumerate(tqdm(scales)):
    for rep in tqdm(range(repetitions)):
        # Two splits per repetition; only the second subkey reaches the tests.
        key, _ = random.split(key)
        X, Y = sampler_perturbations(
            m=sample_size,
            n=sample_size,
            d=d,
            scale=scale,
            number_perturbations=number_perturbations,
            seed=seed,
        )
        key, subkey = random.split(key)
        # The tests receive the already-incremented seed.
        seed += 1
        for test_idx, test in enumerate(tests):
            outputs[test_idx][scale_idx][rep] = test(X, Y, subkey, seed)

# Mean over the repetition axis -> shape (len(tests), len(scales)).
output = jnp.array(outputs).mean(axis=-1)

jnp.save("results/perturbations_vary_dif_d2.npy", output)
jnp.save("results/perturbations_vary_dif_d2_x_axis.npy", scales)

print("scales :", scales)
print("sample size :", sample_size)
for test_idx, test in enumerate(tests):
    print(" ")
    print(test)
    print(output[test_idx])


scales : (0, 0.2, 0.4, 0.6, 0.8, 1)
sample size : 500
 
<function mmdfuse_test at 0x7f00b69550d0>
[0.045      0.055      0.17       0.48499998 0.885      0.995     ]
 
<function mmd_median_test at 0x7f09886af430>
[0.05  0.065 0.065 0.05  0.06  0.075]
 
<function mmd_split_test at 0x7f09886af4c0>
[0.05       0.055      0.09999999 0.205      0.415      0.73499995]
 
<function mmdagg_test at 0x7f09886af550>
[0.045      0.055      0.145      0.39499998 0.83       0.98999995]
 
<function mmdagginc_test at 0x7f09886af5e0>
[0.03       0.055      0.13       0.29999998 0.72999996 0.905     ]
 
<function deep_mmd_test at 0x7f09886af670>
[0.035      0.065      0.12       0.11499999 0.25       0.38      ]
 
<function met_test at 0x7f09886af700>
[0.04       0.08       0.095      0.16       0.22999999 0.42999998]
 
<function scf_test at 0x7f09886af790>
[0.06       0.08       0.09999999 0.16       0.255      0.375     ]
 
<function ctt_test at 0x7f09886af820>
[0.035 0.07  0.06  0.055 0.07  0.075]
 


## Vary sample size d = 2

In [7]:
# Experiment: rejection rate of every test as the sample size grows
# (d = 2, fixed perturbation scale), averaged over `repetitions` runs.
repetitions = 200
scale = 0.4
number_perturbations = 2
sample_sizes = (500, 1000, 1500, 2000, 2500, 3000)
d = 2

tests = (
    mmdfuse_test,
    mmd_median_test,
    mmd_split_test,
    mmdagg_test,
    mmdagginc_test,
    deep_mmd_test,
    met_test,
    scf_test,
    ctt_test,
    actt_test,
)
# Nested Python lists indexed as outputs[test][sample size][repetition].
outputs = jnp.zeros((len(tests), len(sample_sizes), repetitions)).tolist()
key = random.PRNGKey(42)
seed = 42
for size_idx, sample_size in enumerate(sample_sizes):
    for rep in range(repetitions):
        # Two splits per repetition; only the second subkey reaches the tests.
        key, _ = random.split(key)
        X, Y = sampler_perturbations(
            m=sample_size,
            n=sample_size,
            d=d,
            scale=scale,
            number_perturbations=number_perturbations,
            seed=seed,
        )
        key, subkey = random.split(key)
        # The tests receive the already-incremented seed.
        seed += 1
        for test_idx, test in enumerate(tests):
            outputs[test_idx][size_idx][rep] = test(X, Y, subkey, seed)

# Mean over the repetition axis -> shape (len(tests), len(sample_sizes)).
output = jnp.array(outputs).mean(axis=-1)

jnp.save("results/perturbations_vary_n_d2.npy", output)
jnp.save("results/perturbations_vary_n_d2_x_axis.npy", sample_sizes)

print("sample sizes :", sample_sizes)
print("scale :", scale)
for test_idx, test in enumerate(tests):
    print(" ")
    print(test)
    print(output[test_idx])

sample sizes : (500, 1000, 1500, 2000, 2500, 3000)
scale : 0.4
 
<function mmdfuse_test at 0x7fca5c0e20d0>
[0.16499999 0.42999998 0.69       0.805      0.97499996 0.98999995]
 
<function mmd_median_test at 0x7fd27540f430>
[0.065 0.045 0.055 0.045 0.075 0.03 ]
 
<function mmd_split_test at 0x7fd27540f4c0>
[0.08       0.16499999 0.285      0.42999998 0.63       0.72999996]
 
<function mmdagg_test at 0x7fd27540f550>
[0.155      0.33499998 0.59499997 0.765      0.95       0.97499996]
 
<function mmdagginc_test at 0x7fd27540f5e0>
[0.09       0.19       0.21499999 0.325      0.38       0.35      ]
 
<function deep_mmd_test at 0x7fd27540f670>
[0.03       0.11499999 0.16499999 0.17       0.26999998 0.41      ]
 
<function met_test at 0x7fd27540f700>
[0.07       0.105      0.12       0.22999999 0.25       0.295     ]
 
<function scf_test at 0x7fd27540f790>
[0.095      0.09999999 0.09999999 0.17       0.225      0.29999998]
 
<function ctt_test at 0x7fd27540f820>
[0.07  0.045 0.055 0.045 0.05  0

# Environment autogluon-env

In [7]:
from sampler_perturbations import sampler_perturbations
import jax
import jax.numpy as jnp
from jax import random
from tqdm.auto import tqdm
from pathlib import Path
# Make sure the "results" directory exists before the experiment cells save
# their .npy files into it; exist_ok=True keeps re-running this cell harmless.
Path("results").mkdir(exist_ok=True)

In [5]:
import autotst
from utils import HiddenPrints

def autotst_test(X, Y, key, seed, time=60, alpha=0.05, permutations=10000):
    """
    Run the AutoTST two-sample test on X and Y and report whether it rejects.

    Parameters
    ----------
    X, Y
        The two samples, passed unchanged to ``autotst.AutoTST``.
    key, seed
        Unused. Accepted so this function can be called as
        ``test(X, Y, subkey, seed)`` like every other entry of ``tests``.
    time : int
        Time limit in seconds given to ``fit_witness``.
    alpha : float
        Significance level; the test rejects when ``p_value <= alpha``.
    permutations : int
        Number of permutations used by ``p_value_evaluate``.

    Returns
    -------
    int
        1 if the p-value is at most ``alpha`` (reject), otherwise 0.
    """
    # HiddenPrints presumably silences the fitting output printed by
    # AutoGluon -- see utils.HiddenPrints to confirm.
    with HiddenPrints():
        tst = autotst.AutoTST(X, Y, split_ratio=0.5, model=autotst.model.AutoGluonTabularPredictor)
        tst.split_data()
        tst.fit_witness(time_limit=time)  # time limit adjustable to your needs (in seconds)
        p_value = tst.p_value_evaluate(permutations=permutations)  # control number of permutations in the estimation
    return int(p_value <= alpha)

## Vary difficulty d = 1

In [10]:
# Experiment: AutoTST rejection rate as the perturbation scale grows
# (d = 1, fixed sample size), averaged over `repetitions` runs.
repetitions = 200
scales = (0, 0.1, 0.2, 0.3, 0.4, 0.5)
number_perturbations = 2
sample_size = 500
d = 1

tests = (autotst_test,)
# Nested Python lists indexed as outputs[test][scale][repetition].
outputs = jnp.zeros((len(tests), len(scales), repetitions)).tolist()
key = random.PRNGKey(42)
seed = 42
for scale_idx, scale in enumerate(scales):
    for rep in range(repetitions):
        # Two splits per repetition; only the second subkey reaches the tests.
        key, _ = random.split(key)
        X, Y = sampler_perturbations(
            m=sample_size,
            n=sample_size,
            d=d,
            scale=scale,
            number_perturbations=number_perturbations,
            seed=seed,
        )
        key, subkey = random.split(key)
        # The tests receive the already-incremented seed.
        seed += 1
        for test_idx, test in enumerate(tests):
            outputs[test_idx][scale_idx][rep] = test(X, Y, subkey, seed)

# Mean over the repetition axis -> shape (len(tests), len(scales)).
output = jnp.array(outputs).mean(axis=-1)

jnp.save("results/perturbations_vary_dif_d1_autotst.npy", output)

print("scales :", scales)
print("sample size :", sample_size)
for test_idx, test in enumerate(tests):
    print(" ")
    print(test)
    print(output[test_idx])

scales : (0, 0.1, 0.2, 0.3, 0.4, 0.5)
sample size : 500
 
<function autotst_test at 0x7fae853daaf0>
[0.04       0.08       0.17999999 0.36499998 0.63       0.875     ]


## Vary sample size d = 1

In [3]:
# Experiment: AutoTST rejection rate as the sample size grows
# (d = 1, fixed perturbation scale), averaged over `repetitions` runs.
repetitions = 200
scale = 0.2
number_perturbations = 2
sample_sizes = (500, 1000, 1500, 2000, 2500, 3000)
d = 1

tests = (autotst_test,)
# Nested Python lists indexed as outputs[test][sample size][repetition].
outputs = jnp.zeros((len(tests), len(sample_sizes), repetitions)).tolist()
key = random.PRNGKey(42)
seed = 42
for size_idx, sample_size in enumerate(tqdm(sample_sizes)):
    for rep in tqdm(range(repetitions)):
        # Two splits per repetition; only the second subkey reaches the tests.
        key, _ = random.split(key)
        X, Y = sampler_perturbations(
            m=sample_size,
            n=sample_size,
            d=d,
            scale=scale,
            number_perturbations=number_perturbations,
            seed=seed,
        )
        key, subkey = random.split(key)
        # The tests receive the already-incremented seed.
        seed += 1
        for test_idx, test in enumerate(tests):
            outputs[test_idx][size_idx][rep] = test(X, Y, subkey, seed)

# Mean over the repetition axis -> shape (len(tests), len(sample_sizes)).
output = jnp.array(outputs).mean(axis=-1)

jnp.save("results/perturbations_vary_n_d1_autotst.npy", output)

print("sample sizes :", sample_sizes)
print("scale :", scale)
for test_idx, test in enumerate(tests):
    print(" ")
    print(test)
    print(output[test_idx])


sample sizes : (500, 1000, 1500, 2000, 2500, 3000)
scale : 0.2
 
<function autotst_test at 0x7fae853daaf0>
[0.11499999 0.14999999 0.22       0.22999999 0.445      0.59499997]



## Vary difficulty d = 2

In [4]:
# Experiment: AutoTST rejection rate as the perturbation scale grows
# (d = 2, fixed sample size), averaged over `repetitions` runs.
repetitions = 200
scales = (0, 0.2, 0.4, 0.6, 0.8, 1)
number_perturbations = 2
sample_size = 500
d = 2

tests = (autotst_test,)
# Nested Python lists indexed as outputs[test][scale][repetition].
outputs = jnp.zeros((len(tests), len(scales), repetitions)).tolist()
key = random.PRNGKey(42)
seed = 42
for scale_idx, scale in enumerate(tqdm(scales)):
    for rep in tqdm(range(repetitions)):
        # Two splits per repetition; only the second subkey reaches the tests.
        key, _ = random.split(key)
        X, Y = sampler_perturbations(
            m=sample_size,
            n=sample_size,
            d=d,
            scale=scale,
            number_perturbations=number_perturbations,
            seed=seed,
        )
        key, subkey = random.split(key)
        # The tests receive the already-incremented seed.
        seed += 1
        for test_idx, test in enumerate(tests):
            outputs[test_idx][scale_idx][rep] = test(X, Y, subkey, seed)

# Mean over the repetition axis -> shape (len(tests), len(scales)).
output = jnp.array(outputs).mean(axis=-1)

jnp.save("results/perturbations_vary_dif_d2_autotst.npy", output)

print("scales :", scales)
print("sample size :", sample_size)
for test_idx, test in enumerate(tests):
    print(" ")
    print(test)
    print(output[test_idx])


scales : (0, 0.2, 0.4, 0.6, 0.8, 1)
sample size : 500
 
<function autotst_test at 0x7fae853daaf0>
[0.04       0.05       0.13       0.19999999 0.42499998 0.72999996]



## Vary sample size d = 2

In [14]:
# Experiment: AutoTST rejection rate as the sample size grows
# (d = 2, fixed perturbation scale), averaged over `repetitions` runs.
repetitions = 200
scale = 0.4
number_perturbations = 2
sample_sizes = (500, 1000, 1500, 2000, 2500, 3000)
d = 2

tests = (autotst_test,)
# Nested Python lists indexed as outputs[test][sample size][repetition].
outputs = jnp.zeros((len(tests), len(sample_sizes), repetitions)).tolist()
key = random.PRNGKey(42)
seed = 42
for size_idx, sample_size in enumerate(sample_sizes):
    for rep in range(repetitions):
        # Two splits per repetition; only the second subkey reaches the tests.
        key, _ = random.split(key)
        X, Y = sampler_perturbations(
            m=sample_size,
            n=sample_size,
            d=d,
            scale=scale,
            number_perturbations=number_perturbations,
            seed=seed,
        )
        key, subkey = random.split(key)
        # The tests receive the already-incremented seed.
        seed += 1
        for test_idx, test in enumerate(tests):
            outputs[test_idx][size_idx][rep] = test(X, Y, subkey, seed)

# Mean over the repetition axis -> shape (len(tests), len(sample_sizes)).
output = jnp.array(outputs).mean(axis=-1)

jnp.save("results/perturbations_vary_n_d2_autotst.npy", output)

print("sample sizes :", sample_sizes)
print("scale :", scale)
for test_idx, test in enumerate(tests):
    print(" ")
    print(test)
    print(output[test_idx])

sample sizes : (500, 1000, 1500, 2000, 2500, 3000)
scale : 0.4
 
<function autotst_test at 0x7fae853daaf0>
[0.075      0.09       0.16499999 0.12       0.26999998 0.26999998]
