In [4]:
from sampler_perturbations import f_theta_sampler
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, jit, vmap
from tqdm.auto import tqdm
from pathlib import Path
Path("results").mkdir(exist_ok=True)
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
from all_tests import hsicfuse_test, hsic_test
from all_tests import hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test

### Vary difficulty d=2

In [13]:
repetitions          = 200
scales               = (0.1, 0.2, 0.3, 0.4, 0.5)
sample_size          = 500  
d                    = 2
f_theta_seed         = 0
p                    = 1
s                    = 1

tests = (hsicfuse_test, hsic_test, hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test)
outputs = jnp.zeros((len(tests), len(scales), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(scales))):
    scale = scales[s]
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        perturbation_multiplier = np.exp(d) * p ** s * scale 
        Z = f_theta_sampler(f_theta_seed, seed, sample_size, p, s, perturbation_multiplier, d)
        X = jnp.expand_dims(Z[:, 0], 1)
        Y = jnp.expand_dims(Z[:, 1], 1)
        key, subkey = random.split(key)
        seed += 1
        for t in range(len(tests)):
            test = tests[t]
            outputs[t][s][i] = test(X, Y, subkey, seed)

output = jnp.mean(jnp.array(outputs), -1) # the last dimension is eliminated

jnp.save("results/perturbations_vary_dif_d2.npy", output)
jnp.save("results/perturbations_vary_dif_d2_x_axis.npy", scales)

print("scales :", scales)
print("sample size :", sample_size)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])
    

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

scales : (0.1, 0.2, 0.3, 0.4, 0.5)
sample size : 500
 
<function hsicfuse_test at 0x7f320c2f2430>
[0.405      0.655      0.92499995 0.995      1.        ]
 
<function hsic_test at 0x7f307804fb80>
[0.055 0.065 0.205 0.28  0.38 ]
 
<function hsicagginc1_test at 0x7f307804fd30>
[0.02  0.06  0.03  0.065 0.065]
 
<function hsicagginc100_test at 0x7f3077fcf1f0>
[0.025 0.08  0.28  0.5   0.675]
 
<function hsicagginc200_test at 0x7f3077fcf280>
[0.04       0.11499999 0.29       0.55       0.715     ]
 
<function hsicaggincquad_test at 0x7f3077fcf310>
[0.05       0.11       0.295      0.545      0.71999997]


## Vary sample size d=2

In [6]:
repetitions          = 200
scale                = 0.2
number_perturbations = 2
sample_sizes         = (200, 400, 600, 800, 1000)
d                    = 2
f_theta_seed         = 0
p                    = 1
s                    = 1

tests = (hsicfuse_test, hsic_test, hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test)
outputs = jnp.zeros((len(tests), len(sample_sizes), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(sample_sizes))):
    sample_size = sample_sizes[s]
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        perturbation_multiplier = np.exp(d) * p ** s * scale 
        Z = f_theta_sampler(f_theta_seed, seed, sample_size, p, s, perturbation_multiplier, d)
        X = jnp.expand_dims(Z[:, 0], 1)
        Y = jnp.expand_dims(Z[:, 1], 1)
        key, subkey = random.split(key)
        seed += 1
        for t in range(len(tests)):
            test = tests[t]
            outputs[t][s][i] = test(X, Y, subkey, seed)

output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/perturbations_vary_n_d2.npy", output)
jnp.save("results/perturbations_vary_n_d2_x_axis.npy", sample_sizes)

print("sample sizes :", sample_sizes)
print("scale :", scale)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

sample sizes : (200, 400, 600, 800, 1000)
scale : 0.2
 
<function hsicfuse_test at 0x7f320c2f2430>
[0.39499998 0.59499997 0.72999996 0.835      0.865     ]
 
<function hsic_test at 0x7f307804fb80>
[0.055      0.08       0.11499999 0.08       0.125     ]
 
<function hsicagginc1_test at 0x7f307804fd30>
[0.055 0.03  0.04  0.015 0.04 ]
 
<function hsicagginc100_test at 0x7f3077fcf1f0>
[0.045      0.095      0.12       0.185      0.19999999]
 
<function hsicagginc200_test at 0x7f3077fcf280>
[0.045      0.105      0.155      0.235      0.22999999]
 
<function hsicaggincquad_test at 0x7f3077fcf310>
[0.045 0.105 0.145 0.25  0.24 ]


### Vary difficulty d=4

In [9]:
repetitions          = 200
scales               = (0.2, 0.4, 0.6, 0.8, 1.0)
number_perturbations = 2
sample_size          = 500
d                    = 4
f_theta_seed         = 0
p                    = 1
s                    = 1

tests = (hsicfuse_test, hsic_test, hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test)
outputs = jnp.zeros((len(tests), len(scales), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(scales))):
    scale = scales[s]
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        perturbation_multiplier = np.exp(d + 1) * p ** s * scale 
        Z = f_theta_sampler(f_theta_seed, seed, sample_size, p, s, perturbation_multiplier, d + 1)
        X = Z[:, :1]
        Y = Z[:, 1:]
        key, subkey = random.split(key)
        seed += 1
        for t in range(len(tests)):
            test = tests[t]
            outputs[t][s][i] = test(X, Y, subkey, seed)

output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/perturbations_vary_dif_d4.npy", output)
jnp.save("results/perturbations_vary_dif_d4_x_axis.npy", scales)

print("scales :", scales)
print("sample size :", sample_size)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

scales : (0.2, 0.4, 0.6, 0.8, 1.0)
sample size : 500
 
<function hsicfuse_test at 0x7f320c2f2430>
[0.22       0.26999998 0.235      0.285      0.255     ]
 
<function hsic_test at 0x7f307804fb80>
[0.04  0.035 0.035 0.03  0.055]
 
<function hsicagginc1_test at 0x7f307804fd30>
[0.025 0.04  0.02  0.055 0.03 ]
 
<function hsicagginc100_test at 0x7f3077fcf1f0>
[0.03  0.04  0.045 0.025 0.035]
 
<function hsicagginc200_test at 0x7f3077fcf280>
[0.015 0.03  0.02  0.04  0.04 ]
 
<function hsicaggincquad_test at 0x7f3077fcf310>
[0.03  0.035 0.025 0.03  0.05 ]


### Vary sample size d=4

In [12]:
repetitions          = 200
scale                = 0.2
number_perturbations = 2
sample_sizes         = (200, 400, 600, 800, 1000)
d                    = 4
f_theta_seed         = 0
p                    = 1
s                    = 1

tests = (hsicfuse_test, hsic_test, hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test)
outputs = jnp.zeros((len(tests), len(sample_sizes), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(sample_sizes))):
    sample_size = sample_sizes[s]
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        perturbation_multiplier = np.exp(d + 1) * p ** s * scale 
        Z = f_theta_sampler(f_theta_seed, seed, sample_size, p, s, perturbation_multiplier, d + 1)
        X = Z[:, :1]
        Y = Z[:, 1:]
        key, subkey = random.split(key)
        seed += 1
        for t in range(len(tests)):
            test = tests[t]
            outputs[t][s][i] = test(X, Y, subkey, seed)

output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/perturbations_vary_n_d4.npy", output)
jnp.save("results/perturbations_vary_n_d4_x_axis.npy", sample_sizes)

print("sample sizes :", sample_sizes)
print("scale :", scale)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

sample sizes : (200, 400, 600, 800, 1000)
scale : 0.2
 
<function hsicfuse_test at 0x7f320c2f2430>
[0.205 0.275 0.265 0.26  0.255]
 
<function hsic_test at 0x7f307804fb80>
[0.04  0.05  0.055 0.05  0.03 ]
 
<function hsicagginc1_test at 0x7f307804fd30>
[0.04  0.03  0.035 0.04  0.045]
 
<function hsicagginc100_test at 0x7f3077fcf1f0>
[0.035 0.015 0.04  0.025 0.03 ]
 
<function hsicagginc200_test at 0x7f3077fcf280>
[0.035 0.04  0.03  0.025 0.045]
 
<function hsicaggincquad_test at 0x7f3077fcf310>
[0.035 0.04  0.03  0.025 0.04 ]
