In [1]:
from sampler_perturbations import f_theta_sampler
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, jit, vmap
from tqdm.auto import tqdm
import time
from pathlib import Path
Path("results").mkdir(exist_ok=True)
%load_ext autoreload
%autoreload 2

In [2]:
from all_tests import hsicfuse_test, hsic_test
from all_tests import hsicagginc_test, hsicagg_test
from all_tests import nfsic_test, nyhsic_test, fhsic_test

### Level (type I error rate) as the sample size N varies

In [4]:
# Parameters
# Estimate the level (type I error rate) of each independence test as the
# sample size N grows. X and Y are drawn independently from U(0, 1) below, so
# the null hypothesis of independence holds and the reported rejection rate
# should stay close to the nominal level `true_alpha`.
scale        = 0.5
d            = 2
rep          = 200
true_alpha   = 0.05
f_theta_seed = 0
p            = 2
s            = 1
# NOTE: d, p, s and f_theta_seed are not used in this cell (the null samples
# below do not go through the perturbed sampler). They are kept so that any
# later cell relying on them still sees the same values.

tests = (hsicfuse_test, hsic_test, hsicagginc_test, hsicagg_test, nfsic_test, nyhsic_test, fhsic_test)
test_names = [test.__name__ for test in tests]
N_values = (200, 400, 600, 800, 1000)

# Raw per-repetition test outputs, indexed [test, sample size, repetition].
# Presumably 0/1 reject indicators (their mean is reported as the level) -- TODO confirm.
# Kept separately from the averaged result so the raw draws of this long run
# remain available (e.g. for confidence intervals).
test_outputs_level_vary_n = np.zeros((len(tests), len(N_values), rep))
rs = np.random.RandomState(0)
key = random.PRNGKey(42)
seed = 0

# NOTE: expensive -- the recorded run of this cell took roughly 3-4.5 hours
# per sample size.
for r, N in enumerate(tqdm(N_values, desc="sample sizes")):
    t0 = time.perf_counter()
    for i in tqdm(range(rep), desc=f"N={N}", leave=False):
        seed += 1
        X = rs.uniform(0, 1, (N, 1))
        Y = rs.uniform(0, 1, (N, 1))
        key, subkey = random.split(key)
        # Every test receives the same data, PRNG subkey and seed, so the
        # tests are compared on identical draws.
        for j, test in enumerate(tests):
            test_outputs_level_vary_n[j, r, i] = test(X, Y, subkey, seed)
    print(f"N = {N}: {rep} / {rep} repetitions, time: {time.perf_counter() - t0:.1f} s")

# Average over repetitions -> estimated level, shape (len(tests), len(N_values)).
outputs_level_vary_n = test_outputs_level_vary_n.mean(axis=-1)
np.save("results/level_vary_n.npy", outputs_level_vary_n)
np.save("results/level_vary_n_x_axis.npy", N_values)

print(" ")
print("sample sizes :", N_values)
print("scale :", scale)
for name, levels in zip(test_names, outputs_level_vary_n):
    print(" ")
    print(name)
    print(levels)

print(" ")
outputs_level_vary_n

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

200 / 200 time: 11653.560586214066


  0%|          | 0/200 [00:00<?, ?it/s]

200 / 200 time: 12574.7290391922


  0%|          | 0/200 [00:00<?, ?it/s]

200 / 200 time: 13561.668033599854


  0%|          | 0/200 [00:00<?, ?it/s]

200 / 200 time: 14642.24425816536


  0%|          | 0/200 [00:00<?, ?it/s]

200 / 200 time: 16203.212682962418
 
sample sizes : (200, 400, 600, 800, 1000)
scale : 0.5
 
<function hsicfuse_test at 0x7f208c0d9040>
[0.07  0.035 0.065 0.065 0.055]
 
<function hsic_test at 0x7f20652d15e0>
[0.045 0.03  0.055 0.06  0.065]
 
<function hsicagginc_test at 0x7f20652d1670>
[0.06  0.05  0.055 0.04  0.03 ]
 
<function hsicagg_test at 0x7f20652d1700>
[0.06  0.05  0.04  0.045 0.045]
 
<function nfsic_test at 0x7f20652d1790>
[0.065 0.055 0.055 0.065 0.045]
 
<function nyhsic_test at 0x7f20652d1820>
[0.05  0.055 0.06  0.06  0.05 ]
 
<function fhsic_test at 0x7f20652d18b0>
[0.05  0.035 0.075 0.08  0.055]
 
[[0.07  0.035 0.065 0.065 0.055]
 [0.045 0.03  0.055 0.06  0.065]
 [0.06  0.05  0.055 0.04  0.03 ]
 [0.06  0.05  0.04  0.045 0.045]
 [0.065 0.055 0.055 0.065 0.045]
 [0.05  0.055 0.06  0.06  0.05 ]
 [0.05  0.035 0.075 0.08  0.055]]
