In [1]:
# Imports
import torch
import os 
import sys
from dask_jobqueue import SLURMCluster
from distributed import Client
from pathlib import Path

import do_ablation_causalklgp as file

In [2]:
# Experiment configuration
# ntrial is the number of seeds submitted per ablation configuration; the
# remaining scalars are forwarded unchanged to do_ablation_causalklgp.main.
ntrial = 50
n = 1000        # presumably the training sample size -- confirm against main()
ntest = 100     # presumably the number of test points -- confirm against main()
d = 5           # presumably the input dimension -- confirm against main()
noise = 0.5
niter = 1000
kernel = "gaussian"

# One row per ablation configuration, in submission order.
# Columns: (calibrate, sample_split, marginal_loss, retrain_hypers)
ablation_configs = [
    (False, False, False, False),
    (True, False, False, False),
    (True, True, False, False),
]

# Transpose the rows back into the four parallel per-flag lists.
calibrate, sample_split, marginal_loss, retrain_hypers = (
    list(column) for column in zip(*ablation_configs)
)

In [3]:
# Cluster creation
# Options forwarded to the Dask scheduler: fixed dashboard port and the
# "allowed_failures" setting.
scheduler_opts = {
    "dashboard_address": ":11111",
    "allowed_failures": 10,
}

# Each SLURM job is requested with one core, one process and 32GB of memory,
# with a 3 hour walltime. The cluster starts with no workers (n_workers=0).
cluster = SLURMCluster(
    cores=1,
    processes=1,
    job_cpu=1,
    memory="32GB",
    walltime="3:0:0",
    # Extra directive passed to the job script; presumably selects the SLURM
    # partitions to submit to -- confirm against the site configuration.
    job_extra_directives=["-p medium,fast,cpu"],
    scheduler_options=scheduler_opts,
    n_workers=0,
)

# Adaptive scaling with minimum=0 and maximum=200.
cluster.adapt(minimum=0, maximum=200)

# Client used by the later cells to submit and gather tasks.
client = Client(cluster)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 36075 instead


In [4]:
# Submitting jobs
# The four flag lists are parallel: position k describes ablation
# configuration k. Validate that they line up BEFORE submitting anything.
# Indexing them by range(len(calibrate)) would raise an IndexError part-way
# through submission (with earlier jobs already queued) if another list is
# shorter, and would silently ignore trailing entries if another is longer.
flag_lengths = {
    "calibrate": len(calibrate),
    "sample_split": len(sample_split),
    "marginal_loss": len(marginal_loss),
    "retrain_hypers": len(retrain_hypers),
}
if len(set(flag_lengths.values())) != 1:
    raise ValueError(
        "Ablation flag lists must all have the same length, got {0}".format(flag_lengths)
    )

# One task per (configuration, seed) pair. Futures are ordered
# configuration-major: all seeds of configuration 0, then configuration 1, ...
futures = []
for cfg_calibrate, cfg_sample_split, cfg_marginal_loss, cfg_retrain_hypers in zip(
    calibrate, sample_split, marginal_loss, retrain_hypers
):
    for seed in range(ntrial):
        futures.append(
            client.submit(
                file.main,
                seed,
                n, ntest, d, noise,
                calibrate=cfg_calibrate,
                sample_split=cfg_sample_split,
                marginal_loss=cfg_marginal_loss,
                retrain_hypers=cfg_retrain_hypers,
                kernel=kernel,
                niter=niter,
            )
        )

In [9]:
# Check on futures
# Display the list of submitted futures; each entry's repr shows its current
# status, result type and task key.
futures

[<Future: cancelled, type: dict, key: main-6e6ff6a03f2670e1ca90a7bd5c79de48>,
 <Future: cancelled, type: dict, key: main-7c05802e0855e08468b75a03bb60a6cc>,
 <Future: cancelled, type: dict, key: main-26e6bbd39b2f85405f542d9902abbc09>,
 <Future: cancelled, type: dict, key: main-bdd69b812da078c3ec83f68cd2b538cf>,
 <Future: cancelled, type: dict, key: main-b06677b879b22f9a6ccf0c0273fa1336>,
 <Future: cancelled, type: dict, key: main-102732e207371cc9b3eb807ab68a3e79>,
 <Future: cancelled, type: dict, key: main-4b47d0a79805d542966574f02ab55f55>,
 <Future: cancelled, type: dict, key: main-8249d01f222652c93b264773e9103525>,
 <Future: cancelled, type: dict, key: main-123ca2913d93f47e6f14bc3eab82fd2f>,
 <Future: cancelled, type: dict, key: main-16ffa0c0cf21fcfc7b42a265c25f2519>,
 <Future: cancelled, type: dict, key: main-cd71e715a50be5efb5a3a66f0b6a1d3c>,
 <Future: cancelled, type: dict, key: main-80763adbad8953ffebca3cf506d128d8>,
 <Future: cancelled, type: dict, key: main-5da15ca3f94164a402261

In [6]:
# Getting results
# Collect the results of all submitted futures into `results`; this object is
# what the final cell writes to disk.
results = client.gather(futures)

In [7]:
# Closing client
# Close the client first, then the cluster it was connected to.
# NOTE(review): the futures listing earlier in this notebook has execution
# count 9 (i.e. it was run after this cell) and shows every future as
# cancelled -- presumably a consequence of this shutdown. Inspect futures
# before running this cell.
client.close()
cluster.close()

In [8]:
# Saving results
# The output filename encodes the experiment settings ntrial, n, d, noise and
# kernel (note: niter and ntest are not part of the name).
results_filename_template = "ablation_causalklgp_ntrial={0}_n={1}_d={2}_noise={3}_kernel={4}.pt"
results_filename = results_filename_template.format(ntrial, n, d, noise, kernel)

# Serialize the gathered results to the current working directory.
torch.save(obj=results, f=results_filename)