In [1]:
# Imports
from dask_jobqueue import SLURMCluster
from distributed import Client
from cocycleKR_sparse import run as run_kr
from cocycleKR_dense import run as run_kr_full
from OT_Sparse import run as run_ot

In [2]:
# Cluster creation
cluster = SLURMCluster(
    n_workers=0,
    memory="32GB",
    processes=1,
    cores=1,
    scheduler_options={
        "dashboard_address": ":10095",
        "allowed_failures": 10
    },
    job_cpu=1,
    walltime="24:0:0",
    #job_extra_directives = ["-p medium,fast,cpu"],
)
cluster.adapt(minimum=0, maximum=100)
client = Client(cluster)

In [3]:
# Submitting jobs
n = 500
m = n
ntrial = 20
affine = True
additive = False
corrs = [0.1,0.3,0.5,0.7,0.9]
multivariate = False
dist = "laplace"
futures = []

if additive:
    run_KR = run_kr
else:
    run_KR = run_kr_full
    
for corr in corrs:
    for seed in range(ntrial):
        f1 = client.submit(run_KR,seed,n,m, affine = affine, corr = corr, additive = additive, multivariate_noise = multivariate, dist = dist)
        f3 = client.submit(run_ot,seed,n,m,"sqeuclidean", affine = affine, corr = corr, additive = additive, multivariate_noise = multivariate, dist = dist)
        f4 = client.submit(run_ot,seed,n,m,"cityblock", affine = affine, corr = corr, additive = additive, multivariate_noise = multivariate, dist = dist)
        f5 = client.submit(run_ot,seed,n,m,"chebychev", affine = affine, corr = corr, additive = additive, multivariate_noise = multivariate, dist = dist)
        futures += [f1,f3,f4,f5]

In [4]:
futures

[<Future: pending, key: run-57d8e76d4cc8aeefc1874fb219e4cf8e>,
 <Future: pending, key: run-df0ada6a66ceaaa608e17908a1b3ab89>,
 <Future: pending, key: run-b44a87fbb6e57cf50323bd7ff86a40b8>,
 <Future: pending, key: run-a67e4fba0db4722608d68df1b966eadd>,
 <Future: pending, key: run-8334b4cbd18658c0d4ff9924890ea91c>,
 <Future: pending, key: run-d39e5d3ec5322ec329efc9930acea241>,
 <Future: pending, key: run-5bd73f06cf4662b3fc9a5d8f2a3f24da>,
 <Future: pending, key: run-54d7594edfe2cab8f7d00daf3d950fda>,
 <Future: pending, key: run-dab0d5a009254701755589fca6499f02>,
 <Future: pending, key: run-8774d8fa2759cffe41e79a9b56dbbeef>,
 <Future: pending, key: run-ba171844b30d5c19bee179ef917b4bae>,
 <Future: pending, key: run-ac2e9808314ffba36e42d827888cccfe>,
 <Future: pending, key: run-cbefba71da99acf0ed4eb4346245384f>,
 <Future: pending, key: run-9266c4e8934aa447b0d943307dffa122>,
 <Future: pending, key: run-e755265e79b5be7d0f6a8f0250c876f4>,
 <Future: pending, key: run-039b2f26155fab0665830ee15ce

In [5]:
# Getting results
results = client.gather(futures)

In [6]:
# Closing client
client.close()
cluster.close()

2025-08-15 00:20:14,573 - distributed.deploy.adaptive_core - INFO - Adaptive stop


In [7]:
import torch
torch.save(f = "OT_affine_results_trials={0}_n={1}_m={2}_additive={3}_multivariate={4}_dist={5}.pt".format(ntrial,n,m,additive,multivariate,dist), obj = results)

In [8]:
results

[{'seed': 0,
  'name': 'KReps0',
  'corr': 0.1,
  'additive': False,
  'wrongorder': False,
  'RMSE10': 0.31308665344966957,
  'RMSE21direct': 0.28305454364436866,
  'RMSE21composite': 0.28435505382649007,
  'RMSE20direct': 0.20211828541054613,
  'RMSE20composite': 0.20393943020791555,
  'RMSEinconsistency': 0.004668014122778629,
  'ATE10': 0.13515640190956427,
  'ATE21direct': 0.10483543930516581,
  'ATE21composite': 0.10613870462661752,
  'ATE20direct': 0.0631554347979506,
  'ATE20composite': 0.06549110620755148},
 {'seed': 0,
  'name': 'OT_dist=sqeuclidean',
  'corr': 0.1,
  'additive': False,
  'RMSE10': 0.32961982788849087,
  'RMSE21direct': 0.3772109084216032,
  'RMSE21composite': 0.3327024464754076,
  'RMSE20direct': 0.28492519093314467,
  'RMSE20composite': 0.32628212605347845,
  'RMSEinconsistency': 0.2301306409290604,
  'ATE10': 0.0881082337917899,
  'ATE21direct': 0.08647038180455927,
  'ATE21composite': 0.0864703818045593,
  'ATE20direct': 0.03650420938889022,
  'ATE20compo

In [9]:
ate20_direct = []
ate20_indirect = []
ate21_direct = []
ate21_indirect = []
rmse_direct = []
rmse_indirect = []
for i in range(len(results)):
            ate20_direct.append(results[i]['ATE20direct'])
            ate20_indirect.append(results[i]['ATE20composite'])
            ate21_direct.append(results[i]['ATE21direct'])
            ate21_indirect.append(results[i]['ATE21composite'])
            rmse_direct.append(results[i]['RMSE20direct'])
            rmse_indirect.append(results[i]['RMSE20composite'])

print(torch.tensor(ate20_direct).mean())
print(torch.tensor(ate20_indirect).mean())
print(torch.tensor(ate21_direct).mean())
print(torch.tensor(ate21_indirect).mean())
print(torch.tensor(rmse_direct).mean())
print(torch.tensor(rmse_indirect).mean())

tensor(0.0727, dtype=torch.float64)
tensor(0.0727, dtype=torch.float64)
tensor(0.0964, dtype=torch.float64)
tensor(0.0962, dtype=torch.float64)
tensor(0.9255, dtype=torch.float64)
tensor(0.9230, dtype=torch.float64)
