### Add uot package to path

In [1]:
import sys
import os

sibling_path = os.path.abspath(os.path.join(os.getcwd(), '..', '.'))

if sibling_path not in sys.path:
    sys.path.insert(0, sibling_path)


# OT experiments

Configure jax

In [2]:
import jax
jax.config.update("jax_enable_x64", True)

all necessary imports

In [3]:
from uot.algorithms.sinkhorn import jax_sinkhorn
from uot.algorithms.gradient_ascent import gradient_ascent
from uot.algorithms.lbfgs import lbfgs_ot
from uot.algorithms.lp import pot_lp
from uot.algorithms.pdlp import solve_pdlp
from uot.algorithms.pdhg import solve_pdhg
from uot.core.experiment import run_experiment
from uot.core.suites import time_precision_experiment

Define solvers and their params

In [4]:
epsilon_kwargs = [
    # {'epsilon': 100},
    {'epsilon': 10},
    # {'epsilon': 1},
    # {'epsilon': 1e-1},
    # {'epsilon': 1e-3},
    # {'epsilon': 1e-6},
    # {'epsilon': 1e-9},
]

solvers = {
    # 'pot-lp': (pot_lp, []),
    # 'lbfgs': (lbfgs_ot, epsilon_kwargs),
    'jax-sinkhorn': (jax_sinkhorn, [{'epsilon': 0.1}, {'epsilon': 1}]),
    'grad-ascent': (gradient_ascent, [ # grad ascent works really bad for big regularizations
                                    # {'epsilon': 1},
                                    {'epsilon': 1e-1}
                                    # {'epsilon': 1e-3},
                                    # {'epsilon': 1e-6},
                                    # {'epsilon': 1e-9},
                                    ]),
    # 'pdlp': (solve_pdlp, epsilon_kwargs),
    # 'pdhg': (solve_pdhg, epsilon_kwargs),
}

# algorithms that use jax jit 
# jit_algorithms = [
#     'jax-sinkhorn', 'optax-grad-ascent', 'lbfgs'
# ]

jit_algorithms = [
    # 'pdlp'
]


Define problemset:

In [5]:
problemset_names = [
    # ('distribution', "gamma", 32),
    # ('distribution', "gamma", 64),
    # ('distribution', "gamma", 256),
    # ('distribution', "gamma", 512),
    # ('distribution', "gamma", 1024),
    # ('distribution', "gamma", 2048),

    # ('distribution', "gaussian", 32),
    ('distribution', "gaussian", 64),
    # ('distribution', "gaussian", 256),
    # ('distribution', "gaussian", 512),
    # ('distribution', "gaussian", 1024),
    # ('distribution', "gaussian", 2048),

    # ('distribution', "beta", 32),
    # ('distribution', "beta", 64),
    # ('distribution', "beta", 256),
    # ('distribution', "beta", 512),
    # ('distribution', "beta", 1024),
    # ('distribution', "beta", 2048),
    #
    # ('distribution', "gaussian|gamma|beta|cauchy", 32),
    # ('distribution', "gaussian|gamma|beta|cauchy", 64),
    # ('distribution', "gaussian|gamma|beta|cauchy", 128),
    # ('distribution', "gaussian|gamma|beta|cauchy", 256),
    # ('distribution', "gaussian|gamma|beta|cauchy", 512),
    # ('distribution', "gaussian|gamma|beta|cauchy", 1024),
    # ('distribution', "gaussian|gamma|beta|cauchy", 2048),

]

Run experiment:

In [6]:
df = run_experiment(experiment=time_precision_experiment, 
                    problemsets_names=problemset_names,
                    solvers=solvers,
                    jit_algorithms=jit_algorithms,
                    folds=1)

Solver: grad-ascent({'epsilon': 0.1}): 100%|██████████| 12/12 [00:00<00:00, 14.30it/s]


Save data:

In [7]:
df.to_csv("ot_experiments.csv")

In [8]:
df[["dataset", "time", "converged", "name", "cost_rerr", "coupling_avg_err", "epsilon"]]

Unnamed: 0,dataset,time,converged,name,cost_rerr,coupling_avg_err,epsilon
0,64 1D gaussian,228.454291,True,jax-sinkhorn,0.015694,0.162328,0.1
1,64 1D gaussian,1.904375,True,jax-sinkhorn,0.011282,0.166842,0.1
2,64 1D gaussian,1.529667,True,jax-sinkhorn,0.002774,0.219698,0.1
3,64 1D gaussian,1.369208,True,jax-sinkhorn,0.007794,0.181927,0.1
0,64 1D gaussian,82.088584,True,jax-sinkhorn,0.018971,0.163259,1.0
1,64 1D gaussian,2.76275,True,jax-sinkhorn,0.012775,0.167703,1.0
2,64 1D gaussian,3.008416,True,jax-sinkhorn,0.002851,0.22031,1.0
3,64 1D gaussian,1.646458,True,jax-sinkhorn,0.008451,0.182705,1.0
0,64 1D gaussian,165.074125,True,grad-ascent,0.068103,0.182158,0.1
1,64 1D gaussian,17.806416,True,grad-ascent,0.009636,0.189549,0.1
