In [2]:
# !pip install dask-jobqueue

In [1]:
# Parameters for the cluster

from dask_jobqueue import SLURMCluster
from distributed import Client
from os import path
import os
os.environ["DASK_DISTRIBUTED__WORKER__DAEMON"] = "False"

# queue = 'gpu'
queue = 'cpu'
# queue = 'medium' # time limit 12h [both gpus and cpus]
# queue = 'fast'  # time limit 3h  [both gpus and cpus]

# gpu_flag = '--gres=gpu:1'
gpu_flag = ''

memory = '16GB'

job_cpu = 5

# you need to change this to match a host you have setup the portfowarding 
# on your mac mini in ~/.ssh/config
# in particular the port forwarding might change depending on the 'queue' value
# this host needs to be different from your jupyter notebook host
host = '8882'
# then in your browser type http://localhost:8882/ 
# and you should have access to a dask dashboard which will show the progression on your submitted jobs

# after this time your jobs will be automatically cancelled, just put a high number 
# and kill the cluster before it reaches the end
hours = 80

cluster = SLURMCluster(
    queue=queue,
    memory=memory,
    processes=1, # leave like that
    cores=1, # leave like that
    job_cpu=job_cpu, 
    scheduler_options={'dashboard_address': ':' + host, 'host': ':45353'},
    job_extra_directives=['--output=test.out', '--time=' + str(hours) + ':0:0', gpu_flag],
    # depending on the version of dask you might need to replace the above line with
    #job_extra=['--output=test.out', '--time=' + str(hours) + ':0:0', gpu_flag] 
    death_timeout=60 * 5, # leave like that
    walltime=60 * 3, # leave like that
    )

client = Client(cluster)

In [2]:
cluster

0,1
Dashboard: http://192.168.234.48:8882/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://192.168.234.48:45353,Workers: 0
Dashboard: http://192.168.234.48:8882/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [3]:
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: http://192.168.234.48:8882/status,

0,1
Dashboard: http://192.168.234.48:8882/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://192.168.234.48:45353,Workers: 0
Dashboard: http://192.168.234.48:8882/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [4]:
# in your terminal in tmux you can type
# watch -n0.5 squeue --me
# and you will see all your jobs and whether they have started or not

# now to launch some workers we use the scale function
# Here we request 10 workers, you can play with this number
# If we have 10 workers, they work in parallel so compute time will roughly be divided by 10
cluster.scale(5)

In [None]:
# when you are done you scale the cluster to 0 to kill the jobs
# you can check in your terminal they are not running anymore
# cluster.scale(0)

In [None]:
# always check in your terminal whether you don't have some unwanted jobs running
# if so you can kill them by checking the JOBID [e.g. 3813807] and running in the terminal
# scancel 3813807

In [5]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, jit, vmap
from tqdm.auto import tqdm
from pathlib import Path
import sys
from sampler_perturbations import f_theta_sampler
Path("results").mkdir(exist_ok=True)
%load_ext autoreload
%autoreload 2

In [6]:
from all_tests import nfsic_test, nyhsic_test, fhsic_test

In [7]:
use_cluster = True
save        = True

### Vary difficulty d=2

In [10]:
repetitions          = 200
scales               = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
sample_size          = 500  

d                    = 2
number_perturbations = 2
f_theta_seed         = 0
p                    = 1
s                    = 1

tests = (nfsic_test, nyhsic_test, fhsic_test)
outputs_numpy = jnp.zeros((len(tests), len(scales), repetitions))
outputs_numpy = outputs_numpy.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(scales))):
    scale = scales[s]
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        perturbation_multiplier = np.exp(d) * p ** s * scale 
        Z = f_theta_sampler(f_theta_seed, seed, sample_size, p, s, perturbation_multiplier, d)
        X = np.expand_dims(Z[:, 0], 1)
        Y = np.expand_dims(Z[:, 1], 1)
        key, subkey = random.split(key)
        seed += 1
        for t in range(len(tests)):
            test = tests[t]
            if use_cluster:
                # we submit the jobs to the cluster
                # outputs now contain some objects called futures
                outputs_numpy[t][s][i] = client.submit(test, X, Y, subkey, seed)
            else:
                outputs_numpy[t][s][i] = test(X, Y, subkey, seed)


In [9]:
# now we ask dask to gather the results once they are completed

if use_cluster:
    results = [client.gather(outputs_numpy[t]) for t in range(len(tests))]
    results = jnp.array(results)
    outputs_numpy = jnp.mean(results, -1)
else:
    outputs_numpy = jnp.mean(jnp.array(outputs_numpy), -1)

if save:
    jnp.save("results/perturbations_vary_dif_d2_cluster_numpy.npy", outputs_numpy)
    jnp.save("results/perturbations_vary_dif_d2_cluster_numpy_x_axis.npy", scales)

print("scales :", scales)
print("sample size :", sample_size)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(outputs_numpy[t])

scales : (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
sample size : 500
 
<function nfsic_test at 0x7f07abfcbdc0>
[0.035      0.05       0.185      0.48499998 0.87       0.97999996]
 
<function nyhsic_test at 0x7f07abfcbe50>
[0.055      0.36499998 0.97499996 1.         1.         1.        ]
 
<function fhsic_test at 0x7f07abfcbee0>
[0.03       0.35999998 0.92499995 1.         1.         1.        ]


## Vary sample size d=2

In [9]:
repetitions          = 200
scale                = 0.6
number_perturbations = 2
# sample_sizes         = (200, 400, 600, 800, 1000)
sample_sizes = (200, )
d                    = 2
f_theta_seed         = 0
p                    = 1
s                    = 1

# tests = (hsicfuse_test, hsic_test, hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test)
tests = (nfsic_test, )
outputs = jnp.zeros((len(tests), len(sample_sizes), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(sample_sizes))):
    sample_size = sample_sizes[s]
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        perturbation_multiplier = np.exp(d) * p ** s * scale 
        Z = f_theta_sampler(f_theta_seed, seed, sample_size, p, s, perturbation_multiplier, d)
        X = np.expand_dims(Z[:, 0], 1)
        Y = np.expand_dims(Z[:, 1], 1)
        key, subkey = random.split(key)
        seed += 1
        for t in range(len(tests)):
            test = tests[t]
            outputs[t][s][i] = test(X, Y, subkey, seed)

output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/perturbations_vary_n_d2.npy", output)
jnp.save("results/perturbations_vary_n_d2_x_axis.npy", sample_sizes)

print("sample sizes :", sample_sizes)
print("scale :", scale)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

Exception ignored in: <bound method GCDiagnosis._gc_callback of <distributed.utils_perf.GCDiagnosis object at 0x7fc944390c40>>
Traceback (most recent call last):
  File "/nfs/ghome/live/gren/mambaforge/envs/hsicfuse-env/lib/python3.9/site-packages/distributed/utils_perf.py", line 176, in _gc_callback
    def _gc_callback(self, phase, info):
KeyboardInterrupt: 
Exception ignored in: <bound method GCDiagnosis._gc_callback of <distributed.utils_perf.GCDiagnosis object at 0x7fc944390c40>>
Traceback (most recent call last):
  File "/nfs/ghome/live/gren/mambaforge/envs/hsicfuse-env/lib/python3.9/site-packages/distributed/utils_perf.py", line 183, in _gc_callback
    self._fractional_timer.start_timing()
  File "/nfs/ghome/live/gren/mambaforge/envs/hsicfuse-env/lib/python3.9/site-packages/distributed/utils_perf.py", line 118, in start_timing
    assert self._cur_start is None
AssertionError: 


KeyboardInterrupt: 

### Vary d

In [33]:
repetitions          = 200
scales               = 1.0
number_perturbations = 2
# sample_size          = 500
sample_size = 100
# d_values             = (1, 2, 3, 4)
d_values = (1, )
f_theta_seed         = 0
p                    = 1
s                    = 1

# tests = (hsicfuse_test, hsic_test, hsicagginc1_test, hsicagginc100_test, hsicagginc200_test, hsicaggincquad_test)
tests = (nfsic_test, )
outputs = jnp.zeros((len(tests), len(d_values), repetitions))
outputs = outputs.tolist()
key = random.PRNGKey(42)
seed = 42
for s in tqdm(range(len(d_values))):
    d = d_values[s]
    for i in tqdm(range(repetitions)):
        key, subkey = random.split(key)
        perturbation_multiplier = np.exp(d + 1) * p ** s * scale 
        Z = f_theta_sampler(f_theta_seed, seed, sample_size, p, s, perturbation_multiplier, d + 1)
        X = Z[:, :1]
        Y = Z[:, 1:]
        key, subkey = random.split(key)
        seed += 1
        for t in range(len(tests)):
            test = tests[t]
            outputs[t][s][i] = test(X, Y, subkey, seed)

output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/perturbations_vary_d.npy", output)
jnp.save("results/perturbations_vary_d_x_axis.npy", d_values)

print("d :", d_values)
print("sample size :", sample_size)
for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

KeyboardInterrupt: 