# Environment mmdfuse-env

In [None]:
from sampler_mixture import sampler_mixture
import jax
import jax.numpy as jnp
from jax import random
from tqdm.auto import tqdm
from pathlib import Path
# The experiment cells below write their .npy results under ./results.
Path("results").mkdir(exist_ok=True)
# Reload edited project modules (e.g. all_tests, sampler_mixture) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
from all_tests import mmdfuse_test, mmd_median_test, mmd_split_test
from all_tests import mmdagg_test, mmdagginc_test, deep_mmd_test
from all_tests import met_test, scf_test
from all_tests import ctt_test, actt_test

## Vary difficulty

In [8]:
# Rejection rate of every test as the std of the second mixture (`std_2`) varies,
# with std_1 fixed at 1 and the sample size fixed.
# NOTE(review): shift == 1 presumably makes the two distributions equal (null case);
# the ~0.05 rates observed there are consistent with that — confirm in sampler_mixture.
repetitions = 200
shifts = (0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.2, 1.4, 1.6, 1.8, 2)
sample_size = 500

tests = (mmdfuse_test, mmd_median_test, mmd_split_test, mmdagg_test, mmdagginc_test, deep_mmd_test, met_test, scf_test, ctt_test, actt_test)
# outputs[t][s][i]: result of test t at shift s, repetition i
# (presumably a 0/1 reject decision). A plain nested list is enough here;
# there is no need to allocate a device array first.
outputs = [[[0.0] * repetitions for _ in shifts] for _ in tests]
key = random.PRNGKey(42)
seed = 42
for s, shift in enumerate(tqdm(shifts, desc="shifts")):
    for i in tqdm(range(repetitions), desc=f"std_2={shift}", leave=False):
        # Keep the split order fixed (data key first, then test key) so the
        # run is reproducible and matches previously saved results.
        key, subkey = random.split(key)
        X, Y = sampler_mixture(subkey, m=sample_size, n=sample_size, d=2, mu=20, std_1=1, std_2=shift)
        key, subkey = random.split(key)
        seed += 1
        # Every test sees the same samples, key and seed for this repetition.
        for t, test in enumerate(tests):
            outputs[t][s][i] = test(X, Y, subkey, seed)

# Average over repetitions -> rejection rates indexed by (test, shift).
output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/mixture_vary_dif.npy", output)
jnp.save("results/mixture_vary_dif_x_axis.npy", shifts)

print("shifts :", shifts)
print("sample size :", sample_size)
for test, rates in zip(tests, output):
    print(" ")
    # Print the test's name rather than its repr, which embeds a memory address.
    print(getattr(test, "__name__", test))
    print(rates)

shifts : (0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.2, 1.4, 1.6, 1.8, 2)
sample size : 500
 
<function mmdfuse_test at 0x7f78f8795b80>
[0.995      0.945      0.51       0.13499999 0.065      0.035
 0.105      0.235      0.635      0.84       0.905     ]
 
<function mmd_median_test at 0x7f78dca704c0>
[0.055 0.05  0.03  0.035 0.065 0.035 0.04  0.04  0.03  0.035 0.05 ]
 
<function mmd_split_test at 0x7f78dca70550>
[0.885      0.615      0.26       0.095      0.09       0.05
 0.085      0.145      0.315      0.49499997 0.555     ]
 
<function mmdagg_test at 0x7f78dca705e0>
[0.995      0.91499996 0.45999998 0.12       0.09       0.05
 0.105      0.25       0.675      0.90999997 0.945     ]
 
<function mmdagginc_test at 0x7f78dca70670>
[0.97999996 0.805      0.355      0.095      0.065      0.02
 0.105      0.205      0.545      0.78       0.90999997]
 
<function deep_mmd_test at 0x7f78dca70700>
[0.84499997 0.59       0.25       0.11499999 0.075      0.07
 0.105      0.105      0.34       0.545      0.6

## Vary sample size

In [1]:
# Rejection rate of every test as the sample size grows, with the std of the
# second mixture fixed at `shift` (std_1 = 1).
repetitions = 200
sample_sizes = (500, 1000, 1500, 2000, 2500, 3000)
shift = 1.3

tests = (mmdfuse_test, mmd_median_test, mmd_split_test, mmdagg_test, mmdagginc_test, deep_mmd_test, met_test, scf_test, ctt_test, actt_test)
# outputs[t][s][i]: result of test t at sample size s, repetition i
# (presumably a 0/1 reject decision). A plain nested list is enough here.
outputs = [[[0.0] * repetitions for _ in sample_sizes] for _ in tests]
key = random.PRNGKey(42)
seed = 42
for s, sample_size in enumerate(tqdm(sample_sizes, desc="sample sizes")):
    # leave=False: drop the finished inner bar instead of stacking one per sample size.
    for i in tqdm(range(repetitions), desc=f"n={sample_size}", leave=False):
        # Keep the split order fixed (data key first, then test key) so the
        # run is reproducible and matches previously saved results.
        key, subkey = random.split(key)
        X, Y = sampler_mixture(subkey, m=sample_size, n=sample_size, d=2, mu=20, std_1=1, std_2=shift)
        key, subkey = random.split(key)
        seed += 1
        # Every test sees the same samples, key and seed for this repetition.
        for t, test in enumerate(tests):
            outputs[t][s][i] = test(X, Y, subkey, seed)

# Average over repetitions -> rejection rates indexed by (test, sample size).
output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/mixture_vary_n.npy", output)
jnp.save("results/mixture_vary_n_x_axis.npy", sample_sizes)

print("sample_sizes :", sample_sizes)
print("shift :", shift)
for test, rates in zip(tests, output):
    print(" ")
    # Print the test's name rather than its repr, which embeds a memory address.
    print(getattr(test, "__name__", test))
    print(rates)


sample_sizes : (500, 1000, 1500, 2000, 2500, 3000)
shift : 1.3
 
<function mmdfuse_test at 0x7f78f8795b80>
[0.17999999 0.38       0.635      0.765      0.88       0.96999997]
 
<function mmd_median_test at 0x7f78dca704c0>
[0.055 0.065 0.05  0.035 0.05  0.04 ]
 
<function mmd_split_test at 0x7f78dca70550>
[0.09999999 0.14999999 0.26999998 0.42499998 0.515      0.64      ]
 
<function mmdagg_test at 0x7f78dca705e0>
[0.19       0.385      0.66499996 0.78499997 0.89       0.98999995]
 
<function mmdagginc_test at 0x7f78dca70670>
[0.16499999 0.205      0.32999998 0.37       0.35999998 0.48499998]
 
<function deep_mmd_test at 0x7f78dca70700>
[0.13499999 0.195      0.33499998 0.475      0.59999996 0.72999996]
 
<function met_test at 0x7f78dca70790>
[0.095      0.09       0.12       0.17       0.19999999 0.24499999]
 
<function scf_test at 0x7f78dca70820>
[0.22       0.42       0.60499996 0.81       0.82       0.945     ]
 
<function ctt_test at 0x7f78dca708b0>
[0.065 0.055 0.055 0.03  0.03  

# Environment autogluon-env

In [None]:
from sampler_mixture import sampler_mixture
import jax
import jax.numpy as jnp
from jax import random
from tqdm.auto import tqdm
from pathlib import Path
# Ensure ./results exists; the jnp.save calls in this section write into it.
Path("results").mkdir(parents=False, exist_ok=True)

In [20]:
import autotst
from utils import HiddenPrints

def autotst_test(X, Y, key, seed, time=60, alpha=0.05, permutations=10000):
    """Run the AutoTST two-sample test (AutoGluon tabular witness) on X and Y.

    The positional signature ``(X, Y, key, seed)`` mirrors the other tests so
    this function can be used in the same experiment loops.

    Args:
        X, Y: the two samples, passed unchanged to ``autotst.AutoTST``.
        key: unused; accepted only for interface compatibility.
        seed: unused. NOTE(review): the AutoGluon fit is therefore not seeded by
            this wrapper, so repeated runs may differ. Confirm whether AutoTST
            exposes a seed option.
        time: time limit in seconds passed to ``fit_witness``.
        alpha: significance level; the test rejects when ``p_value <= alpha``.
        permutations: number of permutations used to estimate the p-value.

    Returns:
        1 if the test rejects at level ``alpha``, else 0.
    """
    # Suppress the console output produced while building/fitting the test.
    with HiddenPrints():
        tst = autotst.AutoTST(X, Y, split_ratio=0.5, model=autotst.model.AutoGluonTabularPredictor)
        # split_ratio=0.5 presumably gives half the data to fitting and half to evaluation.
        tst.split_data()
        tst.fit_witness(time_limit=time)
        p_value = tst.p_value_evaluate(permutations=permutations)
    return int(p_value <= alpha)

## Vary difficulty

In [2]:
# AutoTST rejection rate as the std of the second mixture (`std_2`) varies,
# using the same settings and PRNG/seed schedule as the mmdfuse-env experiment.
repetitions = 200
shifts = (0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.2, 1.4, 1.6, 1.8, 2)
sample_size = 500

tests = (autotst_test, )
# outputs[t][s][i]: 0/1 decision of test t at shift s, repetition i.
outputs = [[[0.0] * repetitions for _ in shifts] for _ in tests]
key = random.PRNGKey(42)
seed = 42
for s, shift in enumerate(tqdm(shifts, desc="shifts")):
    # leave=False: drop the finished inner bar instead of stacking one per shift.
    for i in tqdm(range(repetitions), desc=f"std_2={shift}", leave=False):
        # Same split order as the mmdfuse-env cells, so each repetition uses
        # the same samples as the other tests there.
        key, subkey = random.split(key)
        X, Y = sampler_mixture(subkey, m=sample_size, n=sample_size, d=2, mu=20, std_1=1, std_2=shift)
        key, subkey = random.split(key)
        seed += 1
        for t, test in enumerate(tests):
            outputs[t][s][i] = test(X, Y, subkey, seed)

# Average over repetitions -> rejection rates indexed by (test, shift).
output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/mixture_vary_dif_autotst.npy", output)

print("shifts :", shifts)
print("sample size :", sample_size)
for test, rates in zip(tests, output):
    print(" ")
    # Print the test's name rather than its repr, which embeds a memory address.
    print(getattr(test, "__name__", test))
    print(rates)


shifts : (0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.2, 1.4, 1.6, 1.8, 2)
sample size : 500
 
<function autotst_test at 0x7fb0b46971f0>
[0.88       0.59       0.185      0.085      0.04       0.07
 0.075      0.145      0.49499997 0.765      0.90999997]



## Vary sample size

In [3]:
# AutoTST rejection rate as the sample size grows, with std_2 fixed at `shift`,
# using the same settings and PRNG/seed schedule as the mmdfuse-env experiment.
repetitions = 200
sample_sizes = (500, 1000, 1500, 2000, 2500, 3000)
shift = 1.3

tests = (autotst_test, )
# outputs[t][s][i]: 0/1 decision of test t at sample size s, repetition i.
outputs = [[[0.0] * repetitions for _ in sample_sizes] for _ in tests]
key = random.PRNGKey(42)
seed = 42
for s, sample_size in enumerate(tqdm(sample_sizes, desc="sample sizes")):
    # leave=False: drop the finished inner bar instead of stacking one per sample size.
    for i in tqdm(range(repetitions), desc=f"n={sample_size}", leave=False):
        # Same split order as the mmdfuse-env cells, so each repetition uses
        # the same samples as the other tests there.
        key, subkey = random.split(key)
        X, Y = sampler_mixture(subkey, m=sample_size, n=sample_size, d=2, mu=20, std_1=1, std_2=shift)
        key, subkey = random.split(key)
        seed += 1
        for t, test in enumerate(tests):
            outputs[t][s][i] = test(X, Y, subkey, seed)

# Average over repetitions -> rejection rates indexed by (test, sample size).
output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/mixture_vary_n_autotst.npy", output)

print("sample_sizes :", sample_sizes)
print("shift :", shift)
for test, rates in zip(tests, output):
    print(" ")
    # Print the test's name rather than its repr, which embeds a memory address.
    print(getattr(test, "__name__", test))
    print(rates)


sample_sizes : (500, 1000, 1500, 2000, 2500, 3000)
shift : 1.3
 
<function autotst_test at 0x7fb0b46971f0>
[0.12  0.21  0.415 0.585 0.745 0.87 ]

