In [1]:
# Imports
import torch
import os 
import sys
from dask_jobqueue import SLURMCluster
from distributed import Client
from pathlib import Path

import do_ablation_causalklgp_new as file

In [2]:
# Args setup
ntrial = 50
n = 100
ntest = 100
d = 5
noise = 0.05
calibrate = [False, True, True, True, True]
sample_split = [False, False, True, False, True]
marginal_loss = [False, False, False, False, False]
retrain_hypers = [False, False, False, True, True]

In [3]:
# Cluster creation
cluster = SLURMCluster(
    n_workers=0,
    memory="32GB",
    processes=1,
    cores=1,
    scheduler_options={
        "dashboard_address": ":11111",
        "allowed_failures": 10
    },
    job_cpu=1,
    walltime="3:0:0",
    
    job_extra_directives = ["-p medium,fast,cpu"],
)
cluster.adapt(minimum=0, maximum=200)
client = Client(cluster)

In [4]:
# Submitting jobs
futures = []
for i in range(len(calibrate)):           
    for seed in range(ntrial):
            f = client.submit(file.main,
                              seed,
                              n,ntest,d,noise,
                              calibrate = calibrate[i],
                              sample_split = sample_split[i],
                              marginal_loss = marginal_loss[i],
                              retrain_hypers = retrain_hypers[i]
                             )
            futures += [f]

In [5]:
# Check on futures
futures

[<Future: pending, key: main-afd499a888a78c2e384981ba12a51325>,
 <Future: pending, key: main-5971db769a35411dead64330e7e1c067>,
 <Future: pending, key: main-431a23e3ce26c62efc222649e644db24>,
 <Future: pending, key: main-f838d4fb91ac9fb32676a1c1f4882625>,
 <Future: pending, key: main-d53480e2154f14b4dd2289c4c10030af>,
 <Future: pending, key: main-6722f38692e89904ee646a6da08cf4c4>,
 <Future: pending, key: main-4d95e2ff3753e6b8511bb8a35a77ed12>,
 <Future: pending, key: main-81f5d12f1ae76e655f1c5ddbbad4bda6>,
 <Future: pending, key: main-9afd6dd8c78b1ed7e5cd452c76f5f813>,
 <Future: pending, key: main-527e49245207dd2b5db2ab4cc2726e85>,
 <Future: pending, key: main-d92acde13ebd17e58dbf8add41b787b2>,
 <Future: pending, key: main-5b4eff90db08e9e001d330ce554be241>,
 <Future: pending, key: main-a02c17f9ff8080d2d4efdb7e7928ba27>,
 <Future: pending, key: main-230dc1f173384c4a2207505832874adc>,
 <Future: pending, key: main-1e1e41cee732bd9e4845a324781da576>,
 <Future: pending, key: main-508916d6bb2

In [6]:
# Getting results
results = client.gather(futures)

In [7]:
# Closing client
client.close()
cluster.close()

In [8]:
# Saving results
torch.save(obj = results,
           f = "ablation_causalklgp_ntrial={0}_n={1}_d={2}_noise={3}.pt".format(ntrial,n,d,noise))