In [23]:
import sys
from pathlib import Path
from sampler_perturbations import f_theta_sampler
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, jit, vmap
from tqdm.auto import tqdm

Path("results").mkdir(exist_ok=True)
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
from all_tests import hsicfuse_test, hsic_test
from all_tests import hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test
from all_tests import nfsic_test

### Vary difficulty d=2

In [32]:
repetitions          = 200
# scales               = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
scales = (0.6, )
# sample_size          = 500  
sample_size = 100
d                    = 2
f_theta_seed         = 0
p                    = 1
s                    = 1

tests = (hsicfuse_test, hsic_test, hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test, nfsic_test)
outputs = jnp.zeros((len(tests), len(scales), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(scales))):
    scale = scales[s]
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        perturbation_multiplier = np.exp(d) * p ** s * scale 
        Z = f_theta_sampler(f_theta_seed, seed, sample_size, p, s, perturbation_multiplier, d)
        X = np.expand_dims(Z[:, 0], 1)
        Y = np.expand_dims(Z[:, 1], 1)
        key, subkey = random.split(key)
        seed += 1
        for t in range(len(tests)):
            test = tests[t]
            outputs[t][s][i] = test(X, Y, subkey, seed)

output = jnp.mean(jnp.array(outputs), -1) # the last dimension is eliminated

jnp.save("results/perturbations_vary_dif_d2.npy", output)
jnp.save("results/perturbations_vary_dif_d2_x_axis.npy", scales)

print("scales :", scales)
print("sample size :", sample_size)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])
    

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

KeyboardInterrupt: 

## Vary sample size d=2

In [35]:
repetitions          = 200
scale                = 0.6
number_perturbations = 2
# sample_sizes         = (200, 400, 600, 800, 1000)
sample_sizes = (200, )
d                    = 2
f_theta_seed         = 0
p                    = 1
s                    = 1

# tests = (hsicfuse_test, hsic_test, hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test)
tests = (nfsic_test, )
outputs = jnp.zeros((len(tests), len(sample_sizes), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(sample_sizes))):
    sample_size = sample_sizes[s]
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        perturbation_multiplier = np.exp(d) * p ** s * scale 
        Z = f_theta_sampler(f_theta_seed, seed, sample_size, p, s, perturbation_multiplier, d)
        X = np.expand_dims(Z[:, 0], 1)
        Y = np.expand_dims(Z[:, 1], 1)
        key, subkey = random.split(key)
        seed += 1
        for t in range(len(tests)):
            test = tests[t]
            outputs[t][s][i] = test(X, Y, subkey, seed)

output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/perturbations_vary_n_d2.npy", output)
jnp.save("results/perturbations_vary_n_d2_x_axis.npy", sample_sizes)

print("sample sizes :", sample_sizes)
print("scale :", scale)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

Exception ignored in: <function tqdm.__del__ at 0x7facb82815e0>
Traceback (most recent call last):
  File "/nfs/ghome/live/gren/mambaforge/envs/hsicfuse-env/lib/python3.9/site-packages/tqdm/std.py", line 1144, in __del__
    def __del__(self):
KeyboardInterrupt: 


sample sizes : (100,)
scale : 0.6
 
<function nfsic_test at 0x7f9ed99b9e50>
[0.13499999]


### Vary d

In [33]:
repetitions          = 200
scales               = 1.0
number_perturbations = 2
# sample_size          = 500
sample_size = 100
# d_values             = (1, 2, 3, 4)
d_values = (1, )
f_theta_seed         = 0
p                    = 1
s                    = 1

# tests = (hsicfuse_test, hsic_test, hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test)
tests = (nfsic_test, )
outputs = jnp.zeros((len(tests), len(d_values), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(d_values))):
    d = d_values[s]
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        perturbation_multiplier = np.exp(d + 1) * p ** s * scale 
        Z = f_theta_sampler(f_theta_seed, seed, sample_size, p, s, perturbation_multiplier, d + 1)
        X = Z[:, :1]
        Y = Z[:, 1:]
        key, subkey = random.split(key)
        seed += 1
        for t in range(len(tests)):
            test = tests[t]
            outputs[t][s][i] = test(X, Y, subkey, seed)

output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/perturbations_vary_d.npy", output)
jnp.save("results/perturbations_vary_d_x_axis.npy", d_values)

print("d :", d_values)
print("sample size :", sample_size)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

KeyboardInterrupt: 