In [1]:
# Imports
import torch
from dask_jobqueue import SLURMCluster
from distributed import Client
from run_urr import run_experiment
from csuite import SCMS

In [2]:
# --- Grid axes -------------------------------------------------------------
# Only the first key of SCMS is kept; widen the slice to sweep more entries.
scms = list(SCMS.keys())[:1]
# Each name below is forwarded to run_experiment as `noise_dist`.
noises = [
    "normal",
    "rademacher",
    "cauchy",
    "gamma",
    "inversegamma",
]
# Seeds 0 .. seeds-1 are run for every (scm, noise) pair.
seeds = 50

# --- Settings forwarded unchanged to every run_experiment call -------------
nsamples = 1000     # passed as N
corr = 0.0          # passed as corr
use_dag = False     # passed as use_dag
learn_flow = True   # passed as learn_flow

In [3]:
# Options handed through to the scheduler that the cluster object starts.
scheduler_opts = {
    "dashboard_address": ":10092",
    "allowed_failures": 10,
}

# The cluster is created with zero workers; the adapt() call below lets it
# grow and shrink between 0 and 200 workers.
cluster = SLURMCluster(
    cores=1,
    processes=1,
    memory="20GB",
    job_cpu=1,
    walltime="24:0:0",
    n_workers=0,
    scheduler_options=scheduler_opts,
)
cluster.adapt(minimum=0, maximum=200)

# Client through which tasks are submitted and results gathered.
client = Client(cluster)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 39919 instead


In [4]:
# Queue one run_experiment task per (scm, noise, seed) combination and keep
# the returned futures in submission order.
futures = []
for scm in scms:
    for noise in noises:
        for seed in range(seeds):
            task_kwargs = dict(
                sc_name=scm,
                noise_dist=noise,
                seed=seed,
                use_dag=use_dag,
                corr=corr,
                N=nsamples,
                learn_flow=learn_flow,
            )
            futures.append(client.submit(run_experiment, **task_kwargs))

In [11]:
# Collect the result of every future into a local list, in the same order as
# `futures`.
# NOTE(review): the execution counts show this cell (In [11]) was run AFTER the
# cell below that replaces errored futures (In [8]), although it appears before
# it in the file. On a top-to-bottom re-run the replacement has not happened
# yet when this line executes -- confirm the intended cell order.
results = client.gather(futures)

In [8]:
# Replace every future whose status is "error" with a neighbouring future that
# did not error, so the list keeps its length and contains no errored entries.
#
# The nearest non-errored future AFTER position i is preferred (this matches
# the previous `futures[i + 1]` behaviour whenever that neighbour is healthy);
# if there is none, the nearest one BEFORE position i is used instead. This
# avoids two failure modes of a plain `futures[i] = futures[i + 1]`:
#   * IndexError when the last future is the errored one;
#   * copying a neighbour that is itself errored when failures are adjacent.
#
# NOTE(review): substitution duplicates another run's result in place of the
# failed one; resubmitting the failed tasks may be preferable -- confirm.
n_futures = len(futures)
for i in range(n_futures):
    if futures[i].status != "error":
        continue

    # Look forward first ...
    replacement = next(
        (futures[j] for j in range(i + 1, n_futures)
         if futures[j].status != "error"),
        None,
    )
    # ... then backward (earlier entries have already been repaired).
    if replacement is None:
        replacement = next(
            (futures[j] for j in range(i - 1, -1, -1)
             if futures[j].status != "error"),
            None,
        )
    if replacement is None:
        raise RuntimeError(
            "all futures have status 'error'; no result available to substitute"
        )

    futures[i] = replacement

In [10]:
# Display the full list of futures so the status of each task can be inspected.
futures

[<Future: finished, type: dict, key: run_experiment-6493cdb860ebf82db3b73a6fd32059db>,
 <Future: finished, type: dict, key: run_experiment-23a859f435f8c2ecc031085d4dcf2a2f>,
 <Future: finished, type: dict, key: run_experiment-43c50bde0b562d64320345ca9b786bcd>,
 <Future: finished, type: dict, key: run_experiment-6a693250e22a989a1c43af5410d252db>,
 <Future: finished, type: dict, key: run_experiment-85b1600dec17d912bbe88a790b5b4edf>,
 <Future: finished, type: dict, key: run_experiment-008a88cfeb7b9024959d526a2595bfd3>,
 <Future: finished, type: dict, key: run_experiment-6d8240b637c61bb72207189fc3d84e64>,
 <Future: finished, type: dict, key: run_experiment-97510b075d7b5167656dbce499fc1f56>,
 <Future: finished, type: dict, key: run_experiment-36b25734184c786c1c1f57fb7ecce9b0>,
 <Future: finished, type: dict, key: run_experiment-13e7197fa14a725ea9762bad9f78edd9>,
 <Future: finished, type: dict, key: run_experiment-be632e3058857b440cb18ce8c14d6388>,
 <Future: finished, type: dict, key: run_ex

In [None]:
# Tear down the distributed resources: close the client first, then the
# cluster it was connected to. `results` has already been gathered into a
# local variable above, so it remains available for saving afterwards.
client.close()
cluster.close()

In [12]:
# Encode the run settings in the output filename, then write the gathered
# results to that file.
out_name = "urr_linear_results_n={0}_corr={1}_dag={2}_learnflow={3}_trial={4}.pt".format(
    nsamples, corr, use_dag, learn_flow, seeds
)
torch.save(obj=results, f=out_name)