In [3]:
os.chdir("./causalklgp")

In [4]:
# Imports
import torch
import os 
import sys
from dask_jobqueue import SLURMCluster
from distributed import Client
from pathlib import Path


import do_abelation_causalklgp

In [5]:
# Args setup
ntrial = 50
n = 100
ntest = 1000
d = 5
noise = 1.0
calibrate = [False,True, True, True]
calibrate_latent = [False, False, True, True]
train_calibration_model = [False, False, False, True]

In [6]:
# Cluster creation
cluster = SLURMCluster(
    n_workers=0,
    memory="32GB",
    processes=1,
    cores=1,
    scheduler_options={
        "dashboard_address": ":11111",
        "allowed_failures": 10
    },
    job_cpu=1,
    walltime="6:0:0",
    
    #job_extra_directives = ["-p medium,cpu"],
)
cluster.adapt(minimum=0, maximum=200)
client = Client(cluster)

In [8]:
# Submitting jobs
futures = []
for i in range(len(calibrate)):           
    for seed in range(ntrial):
            f = client.submit(do_abelation_causalklgp.main,seed,n,ntest,d,noise,
                             calibrate = calibrate[i], cal_latent = calibrate_latent[i],
                             traincalmodel = train_calibration_model[i])
            futures += [f]

In [9]:
# Check on futures
futures

[<Future: pending, key: main-a2c99c4fa00cf8243e785ef1cfd196d9>,
 <Future: pending, key: main-5e77c1f1d5855b9acfd19bcaf411c43c>,
 <Future: pending, key: main-3c61e972a259c11efe321226c11da2ac>,
 <Future: pending, key: main-01ce7cb7e527d45fe905c9bf247d950a>,
 <Future: pending, key: main-0475b27f05685befb86656e2d346c6ab>,
 <Future: pending, key: main-869e3e9a17bf4525a4c9f708ebc1243b>,
 <Future: pending, key: main-470ce3599001a21144863825144ce121>,
 <Future: pending, key: main-63f1475225810ca4aaa2fa3a63d98298>,
 <Future: pending, key: main-32e673f2d7dbf7d50e885a0c29ef8a09>,
 <Future: pending, key: main-bb7ecec414d45b21f171dd8abc12c048>,
 <Future: pending, key: main-24d1956603fe9dde1e40302f40d40347>,
 <Future: pending, key: main-cbcfa7d09dc89ec8c9cf0759abce2195>,
 <Future: pending, key: main-04d3e0de6203b40022d8c6f9f34c1540>,
 <Future: pending, key: main-8b00c1c3df063750c2a000963f1962ee>,
 <Future: pending, key: main-f6b542acc6ec874c7a2896f730139b11>,
 <Future: pending, key: main-2ca8548ec39

In [10]:
# Getting results
results = client.gather(futures)

In [11]:
# Closing client
client.close()
cluster.close()

In [12]:
# Saving results
torch.save(obj = results,
           f = "abelation_causalklgp_ntrial={0}_n={1}_d={2}_noise={3}.pt".format(ntrial,n,d,noise))