In [2]:
import sys
from pathlib import Path
from sampler_gclusters import shuffle, generate_means, sampler_gclusters
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, jit, vmap
from tqdm.auto import tqdm
import time
Path("results").mkdir(exist_ok=True)
%load_ext autoreload
%autoreload 2

In [3]:
from all_tests import hsicfuse_test, hsic_test
from all_tests import hsicagginc_test, hsicagg_test
from all_tests import nfsic_test, nyhsic_test, fhsic_test

### Vary corruption d=2

In [8]:
repetitions = 100
corruptions = (0.2, 0.4, 0.6, 0.8, 1.0) # corruption = 1 is the same as permute = True
sample_size = 500 
L           = 4
d           = 2

tests = (hsicfuse_test, hsic_test, hsicagginc_test, hsicagg_test, nfsic_test, nyhsic_test, fhsic_test)
outputs = jnp.zeros((len(tests), len(corruptions), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(corruptions))):
    C = corruptions[s]
    t0 = time.time()
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        seed += 1
        X, Y = sampler_gclusters(subkey, L=L, d=d, theta=0, N=sample_size, C=C)
        for t in range(len(tests)):
            test = tests[t]
            outputs[t][s][i] = test(X, Y, subkey, seed)
    print(i + 1, "/", repetitions, "time:", time.time() - t0)  

output = jnp.mean(jnp.array(outputs), -1) # the last dimension is eliminated

jnp.save("results/3gclusters_vary_dif.npy", output)
jnp.save("results/3gclusters_vary_dif_x_axis.npy", corruptions)

print("corruptions :", corruptions)
print("sample size :", sample_size)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])

print(" ")
print(output)

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

100 / 100 time: 737.3986234664917


  0%|          | 0/100 [00:00<?, ?it/s]

100 / 100 time: 623.8985702991486


  0%|          | 0/100 [00:00<?, ?it/s]

100 / 100 time: 628.2195823192596


  0%|          | 0/100 [00:00<?, ?it/s]

100 / 100 time: 627.6710135936737


  0%|          | 0/100 [00:00<?, ?it/s]

100 / 100 time: 607.178471326828
corruptions : (0.2, 0.4, 0.6, 0.8, 1.0)
sample size : 500
 
<function hsicfuse_test at 0x7f4e7055a280>
[1.         1.         0.93       0.26999998 0.03      ]
 
<function hsic_test at 0x7f4e48c32820>
[0.9        0.53999996 0.21       0.13       0.03      ]
 
<function hsicagginc_test at 0x7f4e48c328b0>
[0.97999996 0.95       0.55       0.09       0.04      ]
 
<function hsicagg_test at 0x7f4e48c32940>
[0.98999995 0.95       0.48999998 0.09       0.02      ]
 
<function nfsic_test at 0x7f4e48c329d0>
[0.84999996 0.53       0.13       0.12       0.04      ]
 
<function nyhsic_test at 0x7f4e48c32a60>
[1.   1.   0.95 0.28 0.05]
 
<function fhsic_test at 0x7f4e48c32af0>
[1.   1.   0.9  0.34 0.03]
 
[[1.         1.         0.93       0.26999998 0.03      ]
 [0.9        0.53999996 0.21       0.13       0.03      ]
 [0.97999996 0.95       0.55       0.09       0.04      ]
 [0.98999995 0.95       0.48999998 0.09       0.02      ]
 [0.84999996 0.53       0.13    

### Vary Sample Size

In [None]:
repetitions  = 100
corruption   = 0.6
# sample_sizes = ( 200,  400,  600,  800, 1000, 1500, 2000, 2500, 3000)
sample_sizes = (3500, )
L            = 4
d            = 2

tests = (hsicfuse_test, hsic_test, hsicagginc_test, hsicagg_test, nfsic_test, nyhsic_test, fhsic_test)
outputs = jnp.zeros((len(tests), len(sample_sizes), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(sample_sizes))):
    N = sample_sizes[s]
    t0 = time.time()
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        seed += 1
        X, Y = sampler_gclusters(subkey, L=L, d=d, theta=0, N=N, C=corruption)
        for t in range(len(tests)):
            test = tests[t]
            outputs[t][s][i] = test(X, Y, subkey, seed)
    print(i + 1, "/", repetitions, "time:", time.time() - t0)  

output = jnp.mean(jnp.array(outputs), -1) # the last dimension is eliminated

# jnp.save("results/3gclusters_vary_n.npy", output)
# jnp.save("results/3gclusters_vary_n_x_axis.npy", sample_sizes)

print("corruptions :", corruption)
print("sample size :", sample_sizes)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])

print(" ")
print(output)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

### Vary d

In [7]:
repetitions  = 100
corruption   = 0.4
sample_size  = 500
L            = 4
dimension    = [2, 15, 30, 45]

tests = (hsicfuse_test, hsic_test, hsicagginc_test, hsicagg_test, nfsic_test, nyhsic_test, fhsic_test)
outputs = jnp.zeros((len(tests), len(dimension), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(dimension))):
    d = dimension[s]
    t0 = time.time()
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        seed += 1
        X, Y = sampler_gclusters(subkey, L=L, d=d, theta=0, N=sample_size, C=corruption)
        for t in range(len(tests)):
            test = tests[t]
            outputs[t][s][i] = test(X, Y, subkey, seed)
    print(i + 1, "/", repetitions, "time:", time.time() - t0)  

output = jnp.mean(jnp.array(outputs), -1) # the last dimension is eliminated

jnp.save("results/3gclusters_vary_d.npy", output)
jnp.save("results/3gclusters_vary_d_x_axis.npy", dimension)

print("corruptions :", corruption)
print("dimensions :", dimension)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])

print(" ")
print(output)

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

100 / 100 time: 747.3875694274902


  0%|          | 0/100 [00:00<?, ?it/s]

100 / 100 time: 3915.2625262737274


  0%|          | 0/100 [00:00<?, ?it/s]

100 / 100 time: 4368.990199565887


  0%|          | 0/100 [00:00<?, ?it/s]

100 / 100 time: 4650.8327786922455
corruptions : 0.4
dimensions : [2, 15, 30, 45]
 
<function hsicfuse_test at 0x7fb802af2040>
[1.         0.84999996 0.51       0.26      ]
 
<function hsic_test at 0x7fb7d86c04c0>
[0.57 0.13 0.04 0.03]
 
<function hsicagginc_test at 0x7fb7d86c0550>
[0.93 0.32 0.11 0.09]
 
<function hsicagg_test at 0x7fb7d86c05e0>
[0.97999996 0.38       0.09       0.07      ]
 
<function nfsic_test at 0x7fb7d86c0670>
[0.44       0.17999999 0.04       0.05      ]
 
<function nyhsic_test at 0x7fb7d86c0700>
[1.         0.71       0.09999999 0.17      ]
 
<function fhsic_test at 0x7fb7d86c0790>
[1.   0.37 0.16 0.04]
 
[[1.         0.84999996 0.51       0.26      ]
 [0.57       0.13       0.04       0.03      ]
 [0.93       0.32       0.11       0.09      ]
 [0.97999996 0.38       0.09       0.07      ]
 [0.44       0.17999999 0.04       0.05      ]
 [1.         0.71       0.09999999 0.17      ]
 [1.         0.37       0.16       0.04      ]]
