In [1]:
# Imports
import torch
import os 
import sys
from dask_jobqueue import SLURMCluster
from distributed import Client
from pathlib import Path

import do_ablation_causalklgp as file

In [2]:
# Args setup
ntrial = 50
n = 100
ntest = 100
d = 5
noise = 0.5
niter = 1000
calibrate = [False, True, True]
sample_split = [False, False, True]
marginal_loss = [False, False, False]
retrain_hypers = [False, False, False]
kernel = "gaussian"

In [3]:
# Cluster creation
cluster = SLURMCluster(
    n_workers=0,
    memory="32GB",
    processes=1,
    cores=1,
    scheduler_options={
        "dashboard_address": ":11111",
        "allowed_failures": 10
    },
    job_cpu=1,
    walltime="3:0:0",
    
    job_extra_directives = ["-p medium,fast,cpu"],
)
cluster.adapt(minimum=0, maximum=200)
client = Client(cluster)

In [4]:
# Submitting jobs
futures = []
for i in range(len(calibrate)):           
    for seed in range(ntrial):
            f = client.submit(file.main,
                              seed,
                              n,ntest,d,noise,
                              calibrate = calibrate[i],
                              sample_split = sample_split[i],
                              marginal_loss = marginal_loss[i],
                              retrain_hypers = retrain_hypers[i],
                              kernel = kernel,
                              niter = niter
                             )
            futures += [f]

In [5]:
# Check on futures
futures

[<Future: pending, key: main-9464acb55bc266cff6c1acd2b0c6d6c6>,
 <Future: pending, key: main-cbe87ca18287d269edd05b0afc4850b6>,
 <Future: pending, key: main-0ef7c05aecd6c9bd5d1e86017b299a8b>,
 <Future: pending, key: main-55da9648a51056ff90498ac2493113a0>,
 <Future: pending, key: main-e5a7da35a0b1e28b4f0c4e7998402a95>,
 <Future: pending, key: main-8401d4fa02eaa8d989d4a905fa124a1d>,
 <Future: pending, key: main-bc77cf61576160a9947becf51bc650ab>,
 <Future: pending, key: main-ac63675d8d59ad9c2ab73194c34beefe>,
 <Future: pending, key: main-b3d442e53c24daa8e83cf1ec033c5d99>,
 <Future: pending, key: main-7bdcf057529e18dcabb94b862588335f>,
 <Future: pending, key: main-475ca732a9b4e27fd5efcd37e2fc058d>,
 <Future: pending, key: main-d0516db9c58cad26f0cc9cfc11ff0992>,
 <Future: pending, key: main-c495af123d1c9962b3e58af1a69a7f28>,
 <Future: pending, key: main-b9d951e4e50213659a21d54c6a1a67ce>,
 <Future: pending, key: main-8fc4b11916a8748cb53bc723b6e44313>,
 <Future: pending, key: main-98b59baf2da

In [6]:
# Getting results
results = client.gather(futures)

In [7]:
# Closing client
client.close()
cluster.close()

In [8]:
# Saving results
torch.save(obj = results,
           f = "ablation_causalklgp_ntrial={0}_n={1}_d={2}_noise={3}_kernel={4}.pt".format(ntrial,
                                                                                           n,
                                                                                           d,
                                                                                           noise,
                                                                                           kernel))