In [None]:
from ot.datasets import (
    MnistOT, FashionMnistOT, ImagenetteOT,
    Synthetic1OT, Synthetic2OT
)

# One OT problem per dataset, built without explicit `reg`/`distance`
# arguments (only the source/target selection and sizes are passed).
# NOTE(review): none of these five objects is referenced by the later cells
# shown here -- the MNIST / Fashion-MNIST / Imagenette problems are constructed
# again further down with explicit settings. If nothing else uses them, this
# cell may only add dataset-loading time; consider dropping it.
mnist_ot = MnistOT(source_idx=2, target_idx=54698)
fashion_mnist_ot = FashionMnistOT(source_idx=2, target_idx=54698)
imagenette_ot = ImagenetteOT(source_classname='tench', target_classname='cassette player')

# Synthetic problems; the sizes are passed as `n` and `m`.
synthetic1_ot = Synthetic1OT(n=10, m=12)
synthetic2_ot = Synthetic2OT(n=10, m=12)

In [6]:
from ot.experiments import OTsolver
import regot

# Settings forwarded unchanged to every solver below.
kwargs = dict(reg=0.001, max_iter=1000, tol=1e-6)

sinkhorn_bcd = OTsolver(
    method=regot.sinkhorn_bcd,
    method_name='BCD',
    **kwargs,
)
sinkhorn_apdagd = OTsolver(
    method=regot.sinkhorn_apdagd,
    method_name='APDAGD',
    **kwargs,
)
sinkhorn_lbfgs_dual = OTsolver(
    method=regot.sinkhorn_lbfgs_dual,
    method_name='LBFGS-Dual',
    **kwargs,
)
sinkhorn_newton = OTsolver(
    method=regot.sinkhorn_newton,
    method_name='Newton',
    **kwargs,
)

# The remaining solvers also receive `shift=1e-6`; SPLR additionally takes
# `density=0.01`.
sinkhorn_ssns = OTsolver(
    method=regot.sinkhorn_ssns,
    method_name='SSNS',
    shift=1e-6,
    **kwargs,
)
sinkhorn_sparse_newton = OTsolver(
    method=regot.sinkhorn_sparse_newton,
    method_name='Sparse Newton',
    shift=1e-6,
    **kwargs,
)
sinkhorn_splr = OTsolver(
    method=regot.sinkhorn_sparse_newton_low_rank,
    method_name='SPLR',
    density=0.01,
    shift=1e-6,
    **kwargs,
)

In [7]:
# Every solver configured above.
solvers_for_all = [
    sinkhorn_bcd,
    sinkhorn_apdagd,
    sinkhorn_lbfgs_dual,
    sinkhorn_newton,
    sinkhorn_ssns,
    sinkhorn_sparse_newton,
    sinkhorn_splr,
]

# Copy instead of alias: with a plain `solvers_for_feature = solvers_for_all`
# both names would refer to the same list object, so adding or removing a
# solver through one name would silently change the other.
solvers_for_feature = list(solvers_for_all)

# Smaller solver subset (presumably for large-scale runs, per the name).
solvers_for_large_scale = [sinkhorn_bcd, sinkhorn_ssns, sinkhorn_sparse_newton, sinkhorn_splr]

In [None]:
# Take `reg` from the solver `kwargs` defined earlier instead of repeating a
# second literal. That keeps the problems and the solvers on the same
# regularization value when it is changed in only one place.
reg = kwargs["reg"]
distance = 'l2'

mnist_ot_settings = {
    "source_idx": 2,
    "target_idx": 54698,
    "reg": reg,
    "distance": distance,
}

fashion_mnist_ot_settings = {
    "source_idx": 2,
    "target_idx": 54698,
    "reg": reg,
    "distance": distance,
}

imagenette_ot_settings = {
    "source_classname": 'tench',
    "target_classname": 'cassette player',
    "reg": reg,
    "distance": distance,
    # NOTE(review): the meaning of `dim` (e.g. a reduced feature dimension)
    # is not visible here -- confirm against ImagenetteOT.
    "dim": 30,
}

mnist_ot_problem = MnistOT(**mnist_ot_settings)
fashion_mnist_ot_problem = FashionMnistOT(**fashion_mnist_ot_settings)
imagenette_ot_problem = ImagenetteOT(**imagenette_ot_settings)

In [None]:
from ot.experiments import OTtask

# Image OT problems built in the previous cell, run against every solver.
concerned_problems = [mnist_ot_problem, fashion_mnist_ot_problem, imagenette_ot_problem]

task = OTtask(
    problems=concerned_problems,
    solvers=solvers_for_all,
)

# NOTE(review): with `force_rerun=False` the task may reuse previously saved
# results; confirm how OTtask decides whether cached output is still valid.
results = task.run(
    save_results=True,
    force_rerun=False,
)