In [10]:
from sampler_perturbations import sampler_perturbations
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from tqdm.auto import tqdm
from pathlib import Path
Path("results").mkdir(exist_ok=True)
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
from all_tests import hsicfuse_test, hsic_test
from all_tests import hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test

### Vary difficulty d=1

In [18]:
# Experiment: vary the perturbation scale at a fixed sample size, d = 1.
#
# NOTE(review): the values below are a reduced, debug-sized configuration
# (a single scale, a single repetition), yet the cell still writes to the
# same .npy paths a full run would use. The full-run settings are kept here
# for reference:
#   repetitions          = 200
#   scales               = (0, 0.1, 0.2, 0.3, 0.4, 0.5)
#   number_perturbations = 2
d = 1
sample_size = 500
number_perturbations = 2
scales = (0, )
repetitions = 1

# Alternative test selections kept for reference:
# tests = (mmdfuse_test, mmd_median_test, mmd_split_test, mmdagg_test, mmdagginc_test, deep_mmd_test, met_test, scf_test, ctt_test, actt_test)
# tests = (hsicfuse_test, hsic_test, hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test, )
tests = (hsicaggincquad_test, )

# Nested Python list indexed as outputs[test][scale][repetition].
outputs = jnp.zeros((len(tests), len(scales), repetitions)).tolist()

key = random.PRNGKey(42)
seed = 42
for scale_idx, scale in enumerate(scales):
    for rep_idx in range(repetitions):
        # The key is advanced twice per repetition; the subkey of the first
        # split is never used. Both splits are kept so the sequence of
        # subkeys handed to the tests stays exactly the same.
        key = random.split(key)[0]
        X, Y = sampler_perturbations(
            m=sample_size,
            n=sample_size,
            d=d,
            scale=scale,
            number_perturbations=number_perturbations,
            seed=seed,
        )
        key, subkey = random.split(key)
        # The seed is bumped before the tests run, so every test receives
        # the sampler's seed plus one.
        seed = seed + 1
        for test_idx, test in enumerate(tests):
            outputs[test_idx][scale_idx][rep_idx] = test(X, Y, subkey, seed)

# Average over repetitions (last axis) -> shape (len(tests), len(scales)).
# Presumably the fraction of rejections per test and scale -- confirm against
# what the test functions return.
output = jnp.array(outputs).mean(axis=-1)

jnp.save("results/perturbations_vary_dif_d1.npy", output)
jnp.save("results/perturbations_vary_dif_d1_x_axis.npy", scales)

print("scales :", scales)
print("sample size :", sample_size)
for test, test_output in zip(tests, output):
    print(" ")
    print(test)
    print(test_output)

scales : (0,)
sample size : 500
 
<function hsicaggincquad_test at 0x7f70610823a0>
[0.]


### Vary sample size d=1

In [None]:
# Experiment: vary the sample size at a fixed perturbation scale, d = 1.
d = 1
scale = 0.2
number_perturbations = 2
sample_sizes = (500, 1000, 1500, 2000, 2500, 3000)
repetitions = 200

# Alternative test selection kept for reference:
# tests = (mmdfuse_test, mmd_median_test, mmd_split_test, mmdagg_test, mmdagginc_test, deep_mmd_test, met_test, scf_test, ctt_test, actt_test)
tests = (hsicfuse_test, )

# Nested Python list indexed as outputs[test][sample size][repetition].
outputs = jnp.zeros((len(tests), len(sample_sizes), repetitions)).tolist()

key = random.PRNGKey(42)
seed = 42
for size_idx, sample_size in enumerate(tqdm(sample_sizes)):
    for rep_idx in tqdm(range(repetitions)):
        # The key is advanced twice per repetition; the subkey of the first
        # split is never used. Both splits are kept so the sequence of
        # subkeys handed to the tests stays exactly the same.
        key = random.split(key)[0]
        X, Y = sampler_perturbations(
            m=sample_size,
            n=sample_size,
            d=d,
            scale=scale,
            number_perturbations=number_perturbations,
            seed=seed,
        )
        key, subkey = random.split(key)
        # The seed is bumped before the tests run, so every test receives
        # the sampler's seed plus one.
        seed = seed + 1
        for test_idx, test in enumerate(tests):
            outputs[test_idx][size_idx][rep_idx] = test(X, Y, subkey, seed)

# Average over repetitions (last axis) -> shape (len(tests), len(sample_sizes)).
output = jnp.array(outputs).mean(axis=-1)

jnp.save("results/perturbations_vary_n_d1.npy", output)
jnp.save("results/perturbations_vary_n_d1_x_axis.npy", sample_sizes)

print("sample sizes :", sample_sizes)
print("scale :", scale)
for test, test_output in zip(tests, output):
    print(" ")
    print(test)
    print(test_output)

### Vary difficulty d=2

In [None]:
# Experiment: vary the perturbation scale at a fixed sample size, d = 2.
d = 2
sample_size = 500
number_perturbations = 2
scales = (0, 0.2, 0.4, 0.6, 0.8, 1)
repetitions = 200

# Alternative test selection kept for reference:
# tests = (mmdfuse_test, mmd_median_test, mmd_split_test, mmdagg_test, mmdagginc_test, deep_mmd_test, met_test, scf_test, ctt_test, actt_test)
tests = (hsicfuse_test, )

# Nested Python list indexed as outputs[test][scale][repetition].
outputs = jnp.zeros((len(tests), len(scales), repetitions)).tolist()

key = random.PRNGKey(42)
seed = 42
for scale_idx, scale in enumerate(tqdm(scales)):
    for rep_idx in tqdm(range(repetitions)):
        # The key is advanced twice per repetition; the subkey of the first
        # split is never used. Both splits are kept so the sequence of
        # subkeys handed to the tests stays exactly the same.
        key = random.split(key)[0]
        X, Y = sampler_perturbations(
            m=sample_size,
            n=sample_size,
            d=d,
            scale=scale,
            number_perturbations=number_perturbations,
            seed=seed,
        )
        key, subkey = random.split(key)
        # The seed is bumped before the tests run, so every test receives
        # the sampler's seed plus one.
        seed = seed + 1
        for test_idx, test in enumerate(tests):
            outputs[test_idx][scale_idx][rep_idx] = test(X, Y, subkey, seed)

# Average over repetitions (last axis) -> shape (len(tests), len(scales)).
output = jnp.array(outputs).mean(axis=-1)

jnp.save("results/perturbations_vary_dif_d2.npy", output)
jnp.save("results/perturbations_vary_dif_d2_x_axis.npy", scales)

print("scales :", scales)
print("sample size :", sample_size)
for test, test_output in zip(tests, output):
    print(" ")
    print(test)
    print(test_output)

### Vary sample size d=2

In [None]:
# Experiment: vary the sample size at a fixed perturbation scale, d = 2.
d = 2
scale = 0.4
number_perturbations = 2
sample_sizes = (500, 1000, 1500, 2000, 2500, 3000)
repetitions = 200

# Alternative test selection kept for reference:
# tests = (mmdfuse_test, mmd_median_test, mmd_split_test, mmdagg_test, mmdagginc_test, deep_mmd_test, met_test, scf_test, ctt_test, actt_test)
tests = (hsicfuse_test, )

# Nested Python list indexed as outputs[test][sample size][repetition].
outputs = jnp.zeros((len(tests), len(sample_sizes), repetitions)).tolist()

key = random.PRNGKey(42)
seed = 42
for size_idx, sample_size in enumerate(sample_sizes):
    for rep_idx in range(repetitions):
        # The key is advanced twice per repetition; the subkey of the first
        # split is never used. Both splits are kept so the sequence of
        # subkeys handed to the tests stays exactly the same.
        key = random.split(key)[0]
        X, Y = sampler_perturbations(
            m=sample_size,
            n=sample_size,
            d=d,
            scale=scale,
            number_perturbations=number_perturbations,
            seed=seed,
        )
        key, subkey = random.split(key)
        # The seed is bumped before the tests run, so every test receives
        # the sampler's seed plus one.
        seed = seed + 1
        for test_idx, test in enumerate(tests):
            outputs[test_idx][size_idx][rep_idx] = test(X, Y, subkey, seed)

# Average over repetitions (last axis) -> shape (len(tests), len(sample_sizes)).
output = jnp.array(outputs).mean(axis=-1)

jnp.save("results/perturbations_vary_n_d2.npy", output)
jnp.save("results/perturbations_vary_n_d2_x_axis.npy", sample_sizes)

print("sample sizes :", sample_sizes)
print("scale :", scale)
for test, test_output in zip(tests, output):
    print(" ")
    print(test)
    print(test_output)