# Environment mmdfuse-env

In [None]:
from sampler_galaxy import load_images_list, sampler_galaxy
import jax
import jax.numpy as jnp
from jax import random
from tqdm.auto import tqdm
from pathlib import Path
# Make sure the "results" directory exists; the jnp.save calls in this notebook write into it.
Path("results").mkdir(exist_ok=True)
%load_ext autoreload
%autoreload 2

In [2]:
from all_tests import mmdfuse_test
from all_tests import mmd_median_test, mmd_split_test
from all_tests import mmdagg_test, mmdagginc_test
from all_tests import deep_mmd_test, deep_mmd_image_test
from all_tests import met_test, scf_test
from all_tests import ctt_test, actt_test

In [29]:
# Load the image list once (highres=False); it is handed to sampler_galaxy
# through its images_list= argument in the experiment cells.
images_list = load_images_list(highres=False)

## Vary difficulty

In [1]:
# Experiment: keep the sample size fixed and sweep over the corruption level.
repetitions = 200
corruptions = (0.1, 0.15, 0.20, 0.25, 0.3, 0.35, 0.4)
sample_size = 500

tests = (
    mmdfuse_test,
    mmd_median_test,
    mmd_split_test,
    mmdagg_test,
    mmdagginc_test,
    deep_mmd_test,
    scf_test,
    ctt_test,
    actt_test,
    deep_mmd_image_test,
)

# Zero-filled nested lists indexed as outputs[test][corruption][repetition].
outputs = [[[0.0] * repetitions for _ in corruptions] for _ in tests]

key = random.PRNGKey(42)
seed = 42
for level_idx, corruption in enumerate(tqdm(corruptions)):
    for rep in tqdm(range(repetitions)):
        # First split of the repetition: subkey used to draw the two samples.
        key, subkey = random.split(key)
        X, Y = sampler_galaxy(
            subkey,
            m=sample_size,
            n=sample_size,
            corruption=corruption,
            images_list=images_list,
        )
        # Cast to float32 and flatten everything after the first axis.
        X = jnp.array(X, dtype=jnp.float32)
        X = X.reshape((X.shape[0], -1))
        Y = jnp.array(Y, dtype=jnp.float32)
        Y = Y.reshape((Y.shape[0], -1))
        # Second split: every test of this repetition receives the same
        # subkey and the same integer seed.
        key, subkey = random.split(key)
        seed += 1
        for test_idx, test in enumerate(tests):
            outputs[test_idx][level_idx][rep] = test(X, Y, subkey, seed)

# Average over the repetition axis.
# NOTE(review): this is a rejection rate only if each test returns 0/1 — confirm in all_tests.
output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/galaxy_vary_dif.npy", output)
jnp.save("results/galaxy_vary_dif_x_axis.npy", corruptions)

print("corruptions :", corruptions)
print("sample size :", sample_size)
for test_idx, test_fn in enumerate(tests):
    print(" ")
    print(test_fn)
    print(output[test_idx])


corruptions : (0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4)
sample size : 500
 
<function mmdfuse_test at 0x7f08a5c4f310>
[0.07       0.24499999 0.65999997 0.90999997 0.995      1.
 1.        ]
 
<function mmd_median_test at 0x7f088fb7f310>
[0.13       0.14999999 0.28       0.53499997 0.775      0.85499996
 0.98999995]
 
<function mmd_split_test at 0x7f088fb7f3a0>
[0.075 0.12  0.235 0.475 0.76  0.94  0.965]
 
<function mmdagg_test at 0x7f088fb7f430>
[0.075      0.235      0.51       0.84499997 0.96999997 1.
 1.        ]
 
<function mmdagginc_test at 0x7f088fb7f4c0>
[0.045      0.11       0.265      0.655      0.91499996 0.995
 1.        ]
 
<function deep_mmd_test at 0x7f088fb7f550>
[0.005 0.    0.    0.005 0.01  0.01  0.   ]
 
<function scf_test at 0x7f088fb7f670>
[0.035 0.075 0.07  0.145 0.17  0.225 0.305]
 
<function ctt_test at 0x7f088fb7f700>
[0.09999999 0.14999999 0.32999998 0.53999996 0.78999996 0.875
 0.98499995]
 
<function actt_test at 0x7f088fb7f790>
[0.08       0.145      0.36499

## Vary sample size

In [2]:
# Experiment: keep the corruption level fixed and sweep over the sample size.
repetitions = 200
corruption = 0.15
sample_sizes = (500, 1000, 1500, 2000, 2500)

# The last entry used to be `deep_mmd_test_64`, a name that is neither imported
# nor defined in this notebook (NameError on a fresh kernel). The function
# imported from all_tests, and the one used in the "Vary difficulty" cell,
# is `deep_mmd_image_test`.
tests = (
    mmdfuse_test,
    mmd_median_test,
    mmd_split_test,
    mmdagg_test,
    mmdagginc_test,
    deep_mmd_test,
    scf_test,
    ctt_test,
    actt_test,
    deep_mmd_image_test,
)

# Nested Python lists indexed as outputs[test][sample size][repetition].
outputs = jnp.zeros((len(tests), len(sample_sizes), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(sample_sizes))):
    sample_size = sample_sizes[s]
    for i in tqdm(range(repetitions)):
        # First split of the repetition: subkey used to draw the two samples.
        key, subkey = random.split(key)
        X, Y = sampler_galaxy(subkey, m=sample_size, n=sample_size, corruption=corruption, images_list=images_list)
        # Cast to float32 and flatten everything after the first axis.
        X = jnp.array(X, dtype=jnp.float32).reshape((X.shape[0], -1))
        Y = jnp.array(Y, dtype=jnp.float32).reshape((Y.shape[0], -1))
        # Second split: every test of this repetition receives the same
        # subkey and the same integer seed.
        key, subkey = random.split(key)
        seed += 1
        for t, test in enumerate(tests):
            outputs[t][s][i] = test(
                X,
                Y,
                subkey,
                seed,
            )

# Average over the repetition axis.
# NOTE(review): this is a rejection rate only if each test returns 0/1 — confirm in all_tests.
output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/galaxy_vary_n.npy", output)
jnp.save("results/galaxy_vary_n_x_axis.npy", sample_sizes)

print("sample_sizes :", sample_sizes)
print("corruption :", corruption)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])


sample_sizes : (500, 1000, 1500, 2000, 2500)
corruption : 0.15
 
<function mmdfuse_test at 0x7f08a5c4f310>
[0.2769231  0.4507772  0.71727747 0.84020615 0.87434554]
 
<function mmd_median_test at 0x7f088fb7f310>
[0.2244898  0.28350514 0.4789474  0.64248705 0.7291667 ]
 
<function mmd_split_test at 0x7f088fb7f3a0>
[0.13402061 0.19587629 0.375      0.48704663 0.5751295 ]
 
<function mmdagg_test at 0x7f088fb7f430>
[0.2722513  0.45360824 0.7668394  0.89847714 0.9732621 ]
 
<function mmdagginc_test at 0x7f088fb7f4c0>
[0.12565446 0.19796954 0.31794873 0.42211056 0.39037433]
 
<function deep_mmd_test at 0x7f088fb7f550>
[0. 0. 0. 0. 0.]
 
<function scf_test at 0x7f088fb7f670>
[0.08205128 0.07894737 0.04663212 0.08121827 0.06806283]
 
<function ctt_test at 0x7f088fb7f700>
[0.22395834 0.28877005 0.4894737  0.6526316  0.79473686]
 
<function actt_test at 0x7f088fb7f790>
[0.17098445 0.2871795  0.6010638  0.731579   0.8238342 ]

<function deep_mmd_image_test at 0x7fca8873ddc0>
[0.145      0.515    

# Environment autogluon-env

In [None]:
from sampler_galaxy import load_images_list, sampler_galaxy
import jax
import jax.numpy as jnp
from jax import random
from tqdm.auto import tqdm
from pathlib import Path
# Make sure the "results" directory exists; the jnp.save calls in this notebook write into it.
Path("results").mkdir(exist_ok=True)

In [33]:
import autotst
from utils import HiddenPrints

def autotst_test(X, Y, key, seed, time=60, alpha=0.05, permutations=10000):
    """Run the AutoTST two-sample test on X and Y and return the decision as 0/1.

    Parameters
    ----------
    X, Y :
        The two samples, passed unchanged to ``autotst.AutoTST``.
    key, seed :
        Not used by this function; accepted so it can be called with the same
        positional arguments ``(X, Y, subkey, seed)`` as in the experiment loops.
        NOTE(review): AutoTST therefore is not seeded from this notebook —
        confirm whether autotst exposes a way to fix its randomness.
    time : int
        Time limit in seconds forwarded to ``fit_witness(time_limit=...)``.
    alpha : float
        Significance level; the null is rejected when ``p_value <= alpha``.
        Defaults to the previously hard-coded 0.05.
    permutations : int
        Number of permutations forwarded to ``p_value_evaluate``.
        Defaults to the previously hard-coded 10000.

    Returns
    -------
    int
        1 if ``p_value <= alpha``, otherwise 0.
    """
    # HiddenPrints wraps the whole AutoTST pipeline (presumably to silence its
    # console output — see utils.HiddenPrints).
    with HiddenPrints():
        tst = autotst.AutoTST(X, Y, split_ratio=0.5, model=autotst.model.AutoGluonTabularPredictor)
        tst.split_data()
        tst.fit_witness(time_limit=time)  # time limit adjustable to your needs (in seconds)
        p_value = tst.p_value_evaluate(permutations=permutations)  # control number of permutations in the estimation
    return int(p_value <= alpha)

In [4]:
# Load the image list once (highres=False); it is handed to sampler_galaxy
# through its images_list= argument in the experiment cells.
images_list = load_images_list(highres=False)

## Vary difficulty

In [24]:
# Experiment (AutoTST only): fixed sample size, sweep over the corruption level.
repetitions = 200
corruptions = (0.1, 0.15, 0.20, 0.25, 0.3, 0.35, 0.4)
sample_size = 500

tests = (autotst_test,)

# Zero-filled nested lists indexed as outputs[test][corruption][repetition].
outputs = [[[0.0] * repetitions for _ in corruptions] for _ in tests]

key = random.PRNGKey(42)
seed = 42
for level_idx, corruption in enumerate(tqdm(corruptions)):
    for rep in tqdm(range(repetitions)):
        # Subkey for drawing the two samples.
        key, subkey = random.split(key)
        X, Y = sampler_galaxy(
            subkey,
            m=sample_size,
            n=sample_size,
            corruption=corruption,
            images_list=images_list,
        )
        # Cast to float32 and flatten everything after the first axis.
        X = jnp.array(X, dtype=jnp.float32)
        X = X.reshape((X.shape[0], -1))
        Y = jnp.array(Y, dtype=jnp.float32)
        Y = Y.reshape((Y.shape[0], -1))
        # Subkey and integer seed handed to the tests of this repetition.
        key, subkey = random.split(key)
        seed += 1
        for test_idx, test in enumerate(tests):
            outputs[test_idx][level_idx][rep] = test(X, Y, subkey, seed)

# Mean of the 0/1 decisions over the repetition axis.
output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/galaxy_vary_dif_autotst.npy", output)

print("corruptions :", corruptions)
print("sample size :", sample_size)
for test_idx, test_fn in enumerate(tests):
    print(" ")
    print(test_fn)
    print(output[test_idx])

corruptions : (0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4)
sample size : 500
 
<function autotst_test at 0x7f238fc6f820>
[0.12244898 0.285      0.5929648  0.81       0.95477384 0.995
 1.        ]


## Vary sample size

In [31]:
# Experiment (AutoTST only): fixed corruption level, sweep over the sample size.
repetitions = 200
corruption = 0.15
sample_sizes = (500, 1000, 1500, 2000, 2500, 3000)

tests = (autotst_test,)

# Zero-filled nested lists indexed as outputs[test][sample size][repetition].
outputs = [[[0.0] * repetitions for _ in sample_sizes] for _ in tests]

key = random.PRNGKey(42)
seed = 42
for size_idx, sample_size in enumerate(tqdm(sample_sizes)):
    for rep in tqdm(range(repetitions)):
        # Subkey for drawing the two samples.
        key, subkey = random.split(key)
        X, Y = sampler_galaxy(
            subkey,
            m=sample_size,
            n=sample_size,
            corruption=corruption,
            images_list=images_list,
        )
        # Cast to float32 and flatten everything after the first axis.
        X = jnp.array(X, dtype=jnp.float32)
        X = X.reshape((X.shape[0], -1))
        Y = jnp.array(Y, dtype=jnp.float32)
        Y = Y.reshape((Y.shape[0], -1))
        # Subkey and integer seed handed to the tests of this repetition.
        key, subkey = random.split(key)
        seed += 1
        for test_idx, test in enumerate(tests):
            outputs[test_idx][size_idx][rep] = test(X, Y, subkey, seed)

# Mean of the 0/1 decisions over the repetition axis.
output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/galaxy_vary_n_autotst.npy", output)

print("sample_sizes :", sample_sizes)
print("corruption :", corruption)
for test_idx, test_fn in enumerate(tests):
    print(" ")
    print(test_fn)
    print(output[test_idx])



sample_sizes : (500, 1000, 1500, 2000, 2500, 3000)
corruption : 0.15
 
<function autotst_test at 0x7f238fc6f820>
[0.3030303  0.36683416 0.3939394  0.35       0.295      0.09999999]


In [4]:
# Same sample-size sweep as the previous AutoTST run, but with the time limit
# raised above the recommended amount (3 minutes instead of the default 60 s)
# so that autotst can reach higher power.

repetitions = 200
corruption = 0.15
sample_sizes = (500, 1000, 1500, 2000, 2500, 3000)

tests = (autotst_test,)

# Zero-filled nested lists indexed as outputs[test][sample size][repetition].
outputs = [[[0.0] * repetitions for _ in sample_sizes] for _ in tests]

key = random.PRNGKey(42)
seed = 42
for size_idx, sample_size in enumerate(tqdm(sample_sizes)):
    for rep in tqdm(range(repetitions)):
        # Subkey for drawing the two samples.
        key, subkey = random.split(key)
        X, Y = sampler_galaxy(
            subkey,
            m=sample_size,
            n=sample_size,
            corruption=corruption,
            images_list=images_list,
        )
        # Cast to float32 and flatten everything after the first axis.
        X = jnp.array(X, dtype=jnp.float32)
        X = X.reshape((X.shape[0], -1))
        Y = jnp.array(Y, dtype=jnp.float32)
        Y = Y.reshape((Y.shape[0], -1))
        # Subkey and integer seed handed to the tests of this repetition.
        key, subkey = random.split(key)
        seed += 1
        for test_idx, test in enumerate(tests):
            # time is given in seconds: 3 * 60 = three minutes per fit.
            outputs[test_idx][size_idx][rep] = test(X, Y, subkey, seed, time=3 * 60)

# Mean of the 0/1 decisions over the repetition axis.
output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/galaxy_vary_n_autotst_3min.npy", output)

print("sample_sizes :", sample_sizes)
print("corruption :", corruption)
for test_idx, test_fn in enumerate(tests):
    print(" ")
    print(test_fn)
    print(output[test_idx])



sample_sizes : (500, 1000, 1500, 2000, 2500)
corruption : 0.15
 
<function autotst_test at 0x7f238fc6f820>
[0.14999999 0.35       0.325      0.505      0.37      ]

