# Dask cluster

In [1]:
# !pip install dask-jobqueue

In [1]:
# Parameters for the cluster

from dask_jobqueue import SLURMCluster
from distributed import Client
from os import path
import os
# Dask reads configuration from DASK_-prefixed environment variables; this one
# maps onto the `distributed.worker.daemon` setting and switches it off.
os.environ["DASK_DISTRIBUTED__WORKER__DAEMON"] = "False"

# SLURM queue (partition) the worker jobs are submitted to
# queue = 'gpu'
queue = 'cpu'
# queue = 'medium' # time limit 12h [both gpus and cpus]
# queue = 'fast'  # time limit 3h  [both gpus and cpus]

# extra SLURM directive requesting a GPU; leave it empty for CPU-only jobs
# gpu_flag = '--gres=gpu:1'
gpu_flag = ''

memory = '16GB'

# if you increase the number of cpus per job, each individual job finishes quicker,
# but you get allocated fewer resources overall
job_cpu = 5

# NOTE: despite its name, `host` holds the PORT on which the dask dashboard is served.
# You need to change it to match a port for which you have set up port forwarding
# on your mac mini in ~/.ssh/config
# (the port forwarding might change depending on the 'queue' value).
# This port needs to be different from the one used by your jupyter notebook.
host = '8883'
# then in your browser type http://localhost:8883/
# and you should have access to a dask dashboard which will show the progress of your submitted jobs

# after this time your jobs will be automatically cancelled, just put a high number
# and kill the cluster before it reaches the end
hours = 80

# Directives appended to the generated job script. The GPU flag is only added
# when it is non-empty, so that a CPU-only setup does not produce a blank directive.
job_directives = ['--output=test.out', '--time=' + str(hours) + ':0:0']
if gpu_flag:
    job_directives.append(gpu_flag)

cluster = SLURMCluster(
    queue=queue,
    memory=memory,
    processes=1, # leave like that
    cores=1, # leave like that
    job_cpu=job_cpu,
    scheduler_options={'dashboard_address': ':' + host, 'host': ':45353'},
    job_extra_directives=job_directives,
    # depending on the version of dask you might need to replace the above line with
    #job_extra=job_directives,
    death_timeout=60 * 5, # leave like that
    # NOTE(review): a time limit is also passed through '--time' above;
    # confirm which of the two SLURM ends up applying.
    walltime=60 * 3, # leave like that
    )

client = Client(cluster)

In [2]:
cluster

0,1
Dashboard: http://192.168.234.48:8883/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://192.168.234.48:45353,Workers: 0
Dashboard: http://192.168.234.48:8883/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [3]:
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: http://192.168.234.48:8883/status,

0,1
Dashboard: http://192.168.234.48:8883/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://192.168.234.48:45353,Workers: 0
Dashboard: http://192.168.234.48:8883/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [4]:
# To monitor your jobs, run this in a terminal (e.g. inside tmux):
#     watch -n0.5 squeue --me
# It lists all your jobs and shows whether they have started or not.

# Workers are launched with the cluster's scale function.
# Here we request 10 workers; feel free to play with this number.
# The workers run in parallel, so with 10 of them the compute time
# is roughly divided by 10.
num_workers = 10
cluster.scale(num_workers)

In [5]:
# When you are done, scale the cluster to 0 to kill the jobs;
# you can check in your terminal that they are not running anymore.
# NOTE(review): this cell appears before the example below. Run it only after
# the results have been gathered, otherwise no workers are left for the submitted jobs.
cluster.scale(0)

In [6]:
# always check in your terminal that you don't have any unwanted jobs still running
# if so you can kill them by checking the JOBID [e.g. 3813807] and running in the terminal
# scancel 3813807

# Example power

In [5]:
import jax
import jax.numpy as jnp
from jax import random
from tqdm.auto import tqdm
from pathlib import Path
import time
from sampler_perturbations import sampler_perturbations
# Create the "results" directory (no error if it already exists).
Path("results").mkdir(exist_ok=True)
# Load IPython's autoreload extension and set it to mode 2.
%load_ext autoreload
%autoreload 2

In [6]:
import numpy as np
from hsicfuse import hsicfuse
from hsic import hsic, human_readable_dict
from agginc.jax import agginc, human_readable_dict
from wittawatj_tests import nfsic
from nystromhsic import nystromhsic


def test_1(X, Y, key, seed):
    time.sleep(0.2)
    return 0

def test_2(X, Y, key, seed):
    time.sleep(0.2)
    return 0

def test_3(X, Y, key, seed):
    time.sleep(0.2)
    return 0

def test_4(X, Y, key, seed):
    time.sleep(0.2)
    return 0

# the jax tests will use key
# the other tests will use seed

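# use something like the following, keeping the (X, Y, key, seed) signature
# that the experiment loop calls every test with
# (check hsicfuse's own signature for the argument order and keyword names)
#def test_hsicfuse(X, Y, key, seed):
#    return hsicfuse(
#        key, 
#        X, 
#        Y, 
#        alpha=0.05, 
#        B=2000,
#    )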

In [7]:
# As this cell runs, the dask dashboard should show the jobs getting completed.

use_cluster = True
save = True

# repetitions = 200
repetitions = 10
scales = (0, 0.1, 0.2, 0.3, 0.4, 0.5)
number_perturbations = 2
sample_size = 500
d = 2

f_theta_seed         = 0
p                    = 1
s                    = 1

tests = (test_1, test_2, test_3, test_4)

# outputs[test_idx][scale_idx][rep] receives either a future (cluster mode)
# or the test's return value (local mode); every slot is filled by the loop below.
outputs = [[[None] * repetitions for _ in scales] for _ in tests]

key = random.PRNGKey(42)
seed = 42
# The loop indices use dedicated names so that the setting `s` above is not overwritten.
for scale_idx, scale in enumerate(tqdm(scales, desc="scales")):
    for rep in tqdm(range(repetitions), desc=f"scale={scale}", leave=False):
        X, Y = sampler_perturbations(m=sample_size, n=sample_size, d=d, scale=scale, number_perturbations=number_perturbations, seed=seed)
        # One fresh subkey per repetition, shared by all tests run on this sample.
        key, subkey = random.split(key)
        seed += 1
        for test_idx, test in enumerate(tests):
            if use_cluster:
                # we submit the jobs to the cluster
                # outputs now contain some objects called futures
                outputs[test_idx][scale_idx][rep] = client.submit(test, X, Y, subkey, seed)
            else:
                outputs[test_idx][scale_idx][rep] = test(X, Y, subkey, seed)


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [8]:
# Ask dask to gather the results once they are completed, then average
# over the last axis (the repetitions).

if not use_cluster:
    output = jnp.mean(jnp.array(outputs), -1)
else:
    print("done here 1")
    # Gather the futures one test at a time.
    results = []
    for per_test_futures in outputs[:len(tests)]:
        results.append(client.gather(per_test_futures))
    print("done here 2")
    results = jnp.array(results)
    output = jnp.mean(results, -1)

if save:
    # Store the averaged outputs together with the scales used as x-axis.
    jnp.save("results/toy_example.npy", output)
    jnp.save("results/toy_example_x_axis.npy", scales)

print("scales :", scales)
print("sample size :", sample_size)
# Report one row of averaged outputs per test.
for t, test_fn in enumerate(tests):
    print(" ")
    print(test_fn)
    print(output[t])

done here 1
done here 2
scales : (0, 0.1, 0.2, 0.3, 0.4, 0.5)
sample size : 500
 
<function test_1 at 0x7fbe0c7c0670>
[0. 0. 0. 0. 0. 0.]
 
<function test_2 at 0x7fbd8b004160>
[0. 0. 0. 0. 0. 0.]
 
<function test_3 at 0x7fbd8b0041f0>
[0. 0. 0. 0. 0. 0.]
 
<function test_4 at 0x7fbd8b004280>
[0. 0. 0. 0. 0. 0.]


In [None]:
# You can run both cells together (submit the jobs and wait for them to be gathered),
# but you don't have to: you can submit jobs, do other stuff, and only then gather.