In [1]:
from sampler_perturbations import f_theta_sampler
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, jit, vmap
from tqdm.auto import tqdm
from pathlib import Path
Path("results").mkdir(exist_ok=True)
%load_ext autoreload
%autoreload 2

In [2]:
from all_tests import hsicfuse_test, hsic_test
from all_tests import hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test

### Vary difficulty d=2

In [3]:
repetitions          = 200
scales               = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
sample_size          = 500  
d                    = 2
f_theta_seed         = 0
p                    = 1
s                    = 1

tests = (hsicfuse_test, hsic_test, hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test)
outputs = jnp.zeros((len(tests), len(scales), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(scales))):
    scale = scales[s]
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        perturbation_multiplier = np.exp(d) * p ** s * scale 
        Z = f_theta_sampler(f_theta_seed, seed, sample_size, p, s, perturbation_multiplier, d)
        X = jnp.expand_dims(Z[:, 0], 1)
        Y = jnp.expand_dims(Z[:, 1], 1)
        key, subkey = random.split(key)
        seed += 1
        for t in range(len(tests)):
            test = tests[t]
            outputs[t][s][i] = test(X, Y, subkey, seed)

output = jnp.mean(jnp.array(outputs), -1) # the last dimension is eliminated

jnp.save("results/perturbations_vary_dif_d2.npy", output)
jnp.save("results/perturbations_vary_dif_d2_x_axis.npy", scales)

print("scales :", scales)
print("sample size :", sample_size)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])
    

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

scales : (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
sample size : 500
 
<function hsicfuse_test at 0x7faf143833a0>
[0.055      0.35999998 0.96       1.         1.         1.        ]
 
<function hsic_test at 0x7facb70c8dc0>
[0.03       0.065      0.295      0.615      0.90999997 0.98499995]
 
<function hsicagginc1_test at 0x7facb70c8f70>
[0.03  0.06  0.065 0.075 0.195 0.435]
 
<function hsicagginc100_test at 0x7facb7047430>
[0.01       0.08       0.48999998 0.885      0.98999995 1.        ]
 
<function hsicagginc200_test at 0x7facb70474c0>
[0.025      0.11499999 0.53499997 0.91499996 0.98999995 1.        ]
 
<function hsicaggincquad_test at 0x7facb7047550>
[0.025      0.11       0.55       0.91499996 0.995      1.        ]


## Vary sample size d=2

In [7]:
repetitions          = 200
scale                = 0.6
number_perturbations = 2
sample_sizes         = (200, 400, 600, 800, 1000)
d                    = 2
f_theta_seed         = 0
p                    = 1
s                    = 1

tests = (hsicfuse_test, hsic_test, hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test)
outputs = jnp.zeros((len(tests), len(sample_sizes), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(sample_sizes))):
    sample_size = sample_sizes[s]
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        perturbation_multiplier = np.exp(d) * p ** s * scale 
        Z = f_theta_sampler(f_theta_seed, seed, sample_size, p, s, perturbation_multiplier, d)
        X = jnp.expand_dims(Z[:, 0], 1)
        Y = jnp.expand_dims(Z[:, 1], 1)
        key, subkey = random.split(key)
        seed += 1
        for t in range(len(tests)):
            test = tests[t]
            outputs[t][s][i] = test(X, Y, subkey, seed)

output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/perturbations_vary_n_d2.npy", output)
jnp.save("results/perturbations_vary_n_d2_x_axis.npy", sample_sizes)

print("sample sizes :", sample_sizes)
print("scale :", scale)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

sample sizes : (200, 400, 600, 800, 1000)
scale : 0.6
 
<function hsicfuse_test at 0x7faf143833a0>
[0.94 1.   1.   1.   1.  ]
 
<function hsic_test at 0x7facb70c8dc0>
[0.315      0.5        0.65       0.77       0.79499996]
 
<function hsicagginc1_test at 0x7facb70c8f70>
[0.065      0.09       0.105      0.14       0.16499999]
 
<function hsicagginc100_test at 0x7facb7047430>
[0.42       0.82       0.91499996 0.97999996 0.98499995]
 
<function hsicagginc200_test at 0x7facb70474c0>
[0.42       0.84499997 0.965      0.97999996 0.995     ]
 
<function hsicaggincquad_test at 0x7facb7047550>
[0.42       0.84499997 0.96999997 0.98999995 1.        ]


### Vary d

In [10]:
repetitions          = 200
scales               = 1.0
number_perturbations = 2
sample_size          = 500
d_values             = (1, 2, 3, 4)
f_theta_seed         = 0
p                    = 1
s                    = 1

tests = (hsicfuse_test, hsic_test, hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test)
outputs = jnp.zeros((len(tests), len(d_values), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(d_values))):
    d = d_values[s]
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        perturbation_multiplier = np.exp(d + 1) * p ** s * scale 
        Z = f_theta_sampler(f_theta_seed, seed, sample_size, p, s, perturbation_multiplier, d + 1)
        X = Z[:, :1]
        Y = Z[:, 1:]
        key, subkey = random.split(key)
        seed += 1
        for t in range(len(tests)):
            test = tests[t]
            outputs[t][s][i] = test(X, Y, subkey, seed)

output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/perturbations_vary_d.npy", output)
jnp.save("results/perturbations_vary_d_x_axis.npy", d_values)

print("d :", d_values)
print("sample size :", sample_size)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

d : (1, 2, 3, 4)
sample size : 500
 
<function hsicfuse_test at 0x7faf143833a0>
[1.    0.9   0.065 0.035]
 
<function hsic_test at 0x7facb70c8dc0>
[1.    0.08  0.065 0.04 ]
 
<function hsicagginc1_test at 0x7facb70c8f70>
[0.375 0.065 0.035 0.025]
 
<function hsicagginc100_test at 0x7facb7047430>
[1.         0.72499996 0.11499999 0.045     ]
 
<function hsicagginc200_test at 0x7facb70474c0>
[1.   0.82 0.14 0.05]
 
<function hsicaggincquad_test at 0x7facb7047550>
[1.    0.825 0.13  0.045]
