In [1]:
# Imports
import torch
from torch.distributions import Normal,Laplace,StudentT
from dask_jobqueue import SLURMCluster
from distributed import Client
import BD_SCM

In [2]:
# Experiment configuration shared by every submitted job.
N = 5000           # forwarded to BD_SCM.run_experiment and recorded in the results filename
ntrial = 20        # seeds 0..ntrial-1 are each run once per configuration
batch_size = 64
scheduler = True
flip_prob = 0.05

# Names of the hyperparameters being set; every row of `hyper_vals` lists
# its values in this same order.
hypers = ["weight_decay", "batch_size", "scheduler"]

# Only the weight decay is swept; batch size and the scheduler flag stay fixed.
hyper_vals = [
    [weight_decay, batch_size, scheduler]
    for weight_decay in (0, 1e-4, 1e-3, 1e-2, 1e-1)
]

In [3]:
# Cluster creation: an adaptively scaled Dask cluster backed by SLURM jobs.
cluster = SLURMCluster(
    # Resources requested for each SLURM job (one single-core worker process).
    cores=1,
    processes=1,
    job_cpu=1,
    memory="32GB",
    walltime="24:0:0",
    # Uncomment to pin the jobs to specific partitions:
    # job_extra_directives=["-p medium,fast,cpu"],
    # Start with no workers; the adapt() call below decides how many to run.
    n_workers=0,
    # Options forwarded to the Dask scheduler.
    scheduler_options=dict(
        dashboard_address=":10095",
        allowed_failures=10,
    ),
)
# Let the cluster scale between 0 and 100 workers.
cluster.adapt(minimum=0, maximum=100)
client = Client(cluster)

In [4]:
# Submitting jobs: one run_experiment task per (seed, hyperparameter row,
# distribution) combination, appended in that nesting order.
futures = []
for seed in range(ntrial):
    for hyper_row in hyper_vals:
        # (distribution, flag) pairs handed to run_experiment as its 2nd and
        # 3rd positional arguments. Fresh distribution objects are built for
        # every submission.
        # NOTE(review): the meaning of the boolean flag is defined in BD_SCM,
        # not visible here -- confirm before changing.
        dist_settings = (
            (Normal(0, 1), False),
            (Laplace(0, 1), False),
            (StudentT(10, 0, 1), True),
        )
        futures.extend(
            client.submit(
                BD_SCM.run_experiment,
                seed, dist, flag, hypers, hyper_row, flip_prob, N,
            )
            for dist, flag in dist_settings
        )

In [5]:
# Display the submitted futures to check their current status.
futures

[<Future: pending, key: run_experiment-7bef2bf042a89b2b647e1f0c6b826160>,
 <Future: pending, key: run_experiment-d4d6a127d686c228b718517c684fca40>,
 <Future: pending, key: run_experiment-90b724fcd89010714883801217bc7315>,
 <Future: pending, key: run_experiment-e9657a9a59d96a417d228715f4516686>,
 <Future: pending, key: run_experiment-a23bcbae243b39d2dacab70dc21e50bf>,
 <Future: pending, key: run_experiment-956964e488b01846eb4f176421a5d818>,
 <Future: pending, key: run_experiment-3a7ccff4227957b733c67b0c6e550918>,
 <Future: pending, key: run_experiment-915c0deada952956d903b1c87548ad03>,
 <Future: pending, key: run_experiment-3ffcd9840f4b8e6e66e11d21bc645d6c>,
 <Future: pending, key: run_experiment-9469aef0daf4654e3c5f3f9e47b763b3>,
 <Future: pending, key: run_experiment-28de36975089ed102639f6b8f44868b1>,
 <Future: pending, key: run_experiment-2e0cd0545209ec4050aab28769a6d5b9>,
 <Future: pending, key: run_experiment-63eff756f85232c9d7fc7c9251f8ee34>,
 <Future: pending, key: run_experiment

In [6]:
# Getting results: wait for the submitted futures and collect their return
# values from the cluster into `results`.
results = client.gather(futures)

In [7]:
# Closing client: the results have already been gathered into local memory,
# so close the client connection first and then the cluster it was attached to.
client.close()
cluster.close()

2024-03-23 04:50:29,467 - distributed.deploy.adaptive_core - INFO - Adaptive stop


In [8]:
# Persist the gathered results; the filename records the run configuration
# (N, number of trials, batch size, scheduler flag, flip probability).
results_path = "BD_SCM_results_new_N={4}_trials={0}_batchsize={1}_scheduler={2}_flip_prob={3}.pt".format(
    ntrial, batch_size, scheduler, flip_prob, N
)
torch.save(results, results_path)