# Selecting Sinkhorn OT Solutions

In [None]:
from ot.experiments import OTsolver
from typing import List
import regot

# Forwarded as `force_rerun` to every `plot_for_problem` call in this notebook.
# NOTE(review): presumably True recomputes results instead of reusing stored
# ones -- confirm against OTtask.plot_for_problem.
FORCE_RERUN = False

# palette: https://colorhunt.co/palettes/random
# color picker: https://www.w3schools.com/colors/colors_picker.asp
#               https://htmlcolorcodes.com/color-picker 

def get_solvers(reg, max_iter, tol, selected: None | List[str] = None) -> List:
    """Build the Sinkhorn solvers used in the experiments.

    Every solver is an ``OTsolver`` wrapping one ``regot.sinkhorn_*`` routine
    together with its plotting style (color, linestyle) and the shared
    ``reg`` / ``max_iter`` / ``tol`` settings.

    Args:
        reg: Forwarded as ``reg`` to every ``OTsolver``.
        max_iter: Forwarded as ``max_iter`` to every ``OTsolver``.
        tol: Forwarded as ``tol`` to every ``OTsolver``.
        selected: Names of the solvers to return (the keys of the table
            below, e.g. ``'BCD'`` or ``'SPLR'``). ``None`` returns all of them.

    Returns:
        A list of ``OTsolver`` objects in the order they are declared below,
        regardless of the order of ``selected``. Names in ``selected`` that
        match no solver are silently ignored.
    """
    # Settings shared by every solver.
    shared = dict(reg=reg, max_iter=max_iter, tol=tol)

    # All solvers are constructed eagerly, whether or not they are selected.
    all_solvers = {
        'BCD': OTsolver(method=regot.sinkhorn_bcd, method_name='BCD',
                        color='#c44e52', linestyle=(0, (3, 1, 2, 1)),
                        **shared),
        'APDAGD': OTsolver(method=regot.sinkhorn_apdagd, method_name='APDAGD',
                           color='#937860', linestyle=(0, (2, 2)),
                           **shared),
        'LBFGS-Dual': OTsolver(method=regot.sinkhorn_lbfgs_dual, method_name='LBFGS-Dual',
                               color='#ff8000', linestyle=(0, (3, 1)),
                               **shared),
        'Newton': OTsolver(method=regot.sinkhorn_newton, method_name='Newton',
                           color='#da8bc3', linestyle=(0, (4, 2, 1, 2)),
                           **shared),
        'SSNS': OTsolver(method=regot.sinkhorn_ssns, method_name='SSNS', shift=1e-6,
                         color='#55a868', linestyle=(0, (2, 1)),
                         **shared),
        'Sparse Newton': OTsolver(method=regot.sinkhorn_sparse_newton, method_name='Sparse Newton', shift=1e-6,
                                  color='#e6a682', linestyle=(0, (1, 1)),
                                  **shared),
        'SPLR': OTsolver(method=regot.sinkhorn_splr, method_name='SPLR', density=0.01, shift=1e-6,
                         color='#4c72b0', linestyle='solid',
                         **shared),
    }

    if selected is None:
        # Materialize the view so both branches return a real list, as annotated.
        return list(all_solvers.values())
    return [solver for name, solver in all_solvers.items() if name in selected]



# MNIST

In [None]:
from ot.datasets import MnistOT
from ot.experiments import OTtask

# Regularization strengths and distance norms to sweep over.
# Wider sweep, currently disabled:
#   reg_list = [0.001, 0.01]
#   norm_list = ['l1', 'l2']
reg_list = [0.001]
norm_list = ['l1']
max_iter = 500
tol = 1e-6

# Solvers compared on MNIST.
mnist_methods = [
    # first-order methods
    'BCD',
    'APDAGD',
    # second-order methods
    'LBFGS-Dual',
    'Newton',
    # sparse method
    'SSNS',
    # new methods
    'SPLR',
]

for reg in reg_list:
    for norm in norm_list:
        # Set up the MNIST OT problem and the solvers to run on it.
        mnist_ot_problem = MnistOT(reg=reg, distance=norm)
        mnist_solvers = get_solvers(
            reg=reg,
            max_iter=max_iter,
            tol=tol,
            selected=mnist_methods,
        )
        mnist_task = OTtask(problem=mnist_ot_problem, solvers=mnist_solvers)

        # One plot keyed by iteration count, then one keyed by run time.
        for x_key, x_label in (
            ('iterations', 'Iteration Number'),
            ('run_times', 'Run time(seconds)'),
        ):
            mnist_task.plot_for_problem(
                x_key=x_key,
                x_label=x_label,
                y_label='Log10 Gradient Norm',
                force_rerun=FORCE_RERUN,
                selected_methods=mnist_methods,
            )

# Fashion MNIST

In [None]:
from ot.datasets import FashionMnistOT
from ot.experiments import OTtask

# Regularization strengths and distance norms to sweep over.
# Wider sweep, currently disabled:
#   reg_list = [0.001, 0.01]
#   norm_list = ['l1', 'l2']
reg_list = [0.001]
norm_list = ['l1']
max_iter = 500
tol = 1e-6

# Solvers compared on Fashion MNIST.
fashion_mnist_methods = [
    # first-order methods
    'BCD',
    'APDAGD',
    # second-order methods
    'LBFGS-Dual',
    'Newton',
    # sparse method
    'SSNS',
    # new methods
    'SPLR',
]

for reg in reg_list:
    for norm in norm_list:
        # Set up the Fashion MNIST OT problem and the solvers to run on it.
        fashion_mnist_ot_problem = FashionMnistOT(reg=reg, distance=norm)
        fashion_mnist_solvers = get_solvers(
            reg=reg,
            max_iter=max_iter,
            tol=tol,
            selected=fashion_mnist_methods,
        )
        fashion_mnist_task = OTtask(problem=fashion_mnist_ot_problem, solvers=fashion_mnist_solvers)

        # One plot keyed by iteration count, then one keyed by run time.
        for x_key, x_label in (
            ('iterations', 'Iteration Number'),
            ('run_times', 'Run time(seconds)'),
        ):
            fashion_mnist_task.plot_for_problem(
                x_key=x_key,
                x_label=x_label,
                y_label='Log10 Gradient Norm',
                force_rerun=FORCE_RERUN,
                selected_methods=fashion_mnist_methods,
            )


# Imagenette

In [None]:
from ot.datasets import ImagenetteOT
from ot.experiments import OTtask

# Regularization strengths and distance norms to sweep over.
# Wider sweep, currently disabled:
#   reg_list = [0.001, 0.01]
#   norm_list = ['l1', 'l2']
reg_list = [0.001]
norm_list = ['l1']
max_iter = 500
tol = 1e-6

# Solvers compared on Imagenette.
imagenette_methods = [
    # first-order methods
    'BCD',
    'APDAGD',
    # second-order methods
    'LBFGS-Dual',
    'Newton',
    # sparse method
    'SSNS',
    # new methods
    'SPLR',
]

for reg in reg_list:
    for norm in norm_list:
        # Set up the Imagenette OT problem and the solvers to run on it.
        imagenette_ot_problem = ImagenetteOT(reg=reg, distance=norm)
        imagenette_solvers = get_solvers(
            reg=reg,
            max_iter=max_iter,
            tol=tol,
            selected=imagenette_methods,
        )
        imagenette_task = OTtask(problem=imagenette_ot_problem, solvers=imagenette_solvers)

        # One plot keyed by iteration count, then one keyed by run time.
        for x_key, x_label in (
            ('iterations', 'Iteration Number'),
            ('run_times', 'Run time(seconds)'),
        ):
            imagenette_task.plot_for_problem(
                x_key=x_key,
                x_label=x_label,
                y_label='Log10 Gradient Norm',
                force_rerun=FORCE_RERUN,
                selected_methods=imagenette_methods,
            )


# Synthetic Data Experiments

## Synthetic I

### n = m = 1000

In [None]:
from ot.datasets import Synthetic1OT
from ot.experiments import OTtask

# Experiment settings: n = m = problem_size.
reg_list = [0.001]
problem_size = 1000
max_iter = 500
tol = 1e-6

# Solvers compared on Synthetic I at this size.
synthetic1_1000_methods = [
    # first-order methods
    'BCD',
    'APDAGD',
    # second-order methods
    'LBFGS-Dual',
    'Newton',
    # sparse method
    'SSNS',
    # new methods
    'SPLR',
]

for reg in reg_list:
    # Set up the Synthetic I OT problem and the solvers to run on it.
    synthetic1_ot_problem = Synthetic1OT(n=problem_size, m=problem_size, reg=reg)
    synthetic1_solvers = get_solvers(
        reg=reg,
        max_iter=max_iter,
        tol=tol,
        selected=synthetic1_1000_methods,
    )
    synthetic1_task = OTtask(problem=synthetic1_ot_problem, solvers=synthetic1_solvers)

    # One plot keyed by iteration count, then one keyed by run time.
    for x_key, x_label in (
        ('iterations', 'Iteration Number'),
        ('run_times', 'Run time(seconds)'),
    ):
        synthetic1_task.plot_for_problem(
            x_key=x_key,
            x_label=x_label,
            y_label='Log10 Gradient Norm',
            force_rerun=FORCE_RERUN,
            selected_methods=synthetic1_1000_methods,
        )

### n = m = 5000

In [None]:
from ot.datasets import Synthetic1OT
from ot.experiments import OTtask

# Experiment settings: n = m = problem_size.
reg_list = [0.001]
problem_size = 5000
max_iter = 500
tol = 1e-6

# Solvers compared on Synthetic I at this size.
synthetic1_5000_methods = [
    # first-order methods
    'BCD',
    # second-order methods
    'LBFGS-Dual',
    # sparse method
    'SSNS',
    # new methods
    'SPLR',
]

for reg in reg_list:
    # Set up the Synthetic I OT problem and the solvers to run on it.
    synthetic1_ot_problem = Synthetic1OT(n=problem_size, m=problem_size, reg=reg)
    synthetic1_solvers = get_solvers(
        reg=reg,
        max_iter=max_iter,
        tol=tol,
        selected=synthetic1_5000_methods,
    )
    synthetic1_task = OTtask(problem=synthetic1_ot_problem, solvers=synthetic1_solvers)

    # One plot keyed by iteration count, then one keyed by run time.
    for x_key, x_label in (
        ('iterations', 'Iteration Number'),
        ('run_times', 'Run time(seconds)'),
    ):
        synthetic1_task.plot_for_problem(
            x_key=x_key,
            x_label=x_label,
            y_label='Log10 Gradient Norm',
            force_rerun=FORCE_RERUN,
            selected_methods=synthetic1_5000_methods,
        )

### n = m = 10000

In [None]:
from ot.datasets import Synthetic1OT
from ot.experiments import OTtask

# Experiment settings: n = m = problem_size.
reg_list = [0.001]
problem_size = 10000
max_iter = 500
tol = 1e-6

# Solvers compared on Synthetic I at this size.
synthetic1_10000_methods = [
    # first-order methods
    'BCD',
    # second-order methods
    'LBFGS-Dual',
    # sparse method
    'SSNS',
    # new methods
    'SPLR',
]

for reg in reg_list:
    # Set up the Synthetic I OT problem and the solvers to run on it.
    synthetic1_ot_problem = Synthetic1OT(n=problem_size, m=problem_size, reg=reg)
    synthetic1_solvers = get_solvers(
        reg=reg,
        max_iter=max_iter,
        tol=tol,
        selected=synthetic1_10000_methods,
    )
    synthetic1_task = OTtask(problem=synthetic1_ot_problem, solvers=synthetic1_solvers)

    # One plot keyed by iteration count, then one keyed by run time.
    for x_key, x_label in (
        ('iterations', 'Iteration Number'),
        ('run_times', 'Run time(seconds)'),
    ):
        synthetic1_task.plot_for_problem(
            x_key=x_key,
            x_label=x_label,
            y_label='Log10 Gradient Norm',
            force_rerun=FORCE_RERUN,
            selected_methods=synthetic1_10000_methods,
        )

## Synthetic II

### n = m = 1000

In [None]:
from ot.datasets import Synthetic2OT
from ot.experiments import OTtask

# Experiment settings: n = m = problem_size.
reg_list = [0.001]
problem_size = 1000
max_iter = 500
tol = 1e-6

# Solvers compared on Synthetic II at this size.
synthetic2_1000_methods = [
    # first-order methods
    'BCD',
    'APDAGD',
    # second-order methods
    'LBFGS-Dual',
    'Newton',
    # sparse method
    'SSNS',
    # new methods
    'SPLR',
]

for reg in reg_list:
    # Set up the Synthetic II OT problem and the solvers to run on it.
    synthetic2_ot_problem = Synthetic2OT(n=problem_size, m=problem_size, reg=reg)
    synthetic2_solvers = get_solvers(
        reg=reg,
        max_iter=max_iter,
        tol=tol,
        selected=synthetic2_1000_methods,
    )
    synthetic2_task = OTtask(problem=synthetic2_ot_problem, solvers=synthetic2_solvers)

    # One plot keyed by iteration count, then one keyed by run time.
    for x_key, x_label in (
        ('iterations', 'Iteration Number'),
        ('run_times', 'Run time(seconds)'),
    ):
        synthetic2_task.plot_for_problem(
            x_key=x_key,
            x_label=x_label,
            y_label='Log10 Gradient Norm',
            force_rerun=FORCE_RERUN,
            selected_methods=synthetic2_1000_methods,
        )

### n = m = 5000

In [None]:
from ot.datasets import Synthetic2OT
from ot.experiments import OTtask

# Experiment settings: n = m = problem_size.
reg_list = [0.001]
problem_size = 5000
max_iter, tol = 500, 1e-6

# Solvers compared on Synthetic II at this size.
synthetic2_5000_methods = [
    'BCD', # first-order methods
    'LBFGS-Dual', # second-order methods
    'SSNS', # sparse method
    'SPLR' # new methods
]

for reg in reg_list:
    # Set up the Synthetic II OT problem and the solvers to run on it.
    synthetic2_ot_problem = Synthetic2OT(
        n=problem_size,
        m=problem_size,
        reg=reg,
    )
    synthetic2_solvers = get_solvers(reg=reg, max_iter=max_iter, tol=tol,
                                     selected=synthetic2_5000_methods)
    synthetic2_task = OTtask(problem=synthetic2_ot_problem, solvers=synthetic2_solvers)

    # One plot keyed by iteration count, then one keyed by run time. Both use
    # this cell's own method list, so the cell does not depend on the
    # Synthetic I cell having been executed first.
    for x_key, x_label in (
        ('iterations', 'Iteration Number'),
        ('run_times', 'Run time(seconds)'),
    ):
        synthetic2_task.plot_for_problem(x_key=x_key, x_label=x_label, y_label='Log10 Gradient Norm',
                                         force_rerun=FORCE_RERUN, selected_methods=synthetic2_5000_methods)

### n = m = 10000

In [None]:
from ot.datasets import Synthetic2OT
from ot.experiments import OTtask

# Experiment settings: n = m = problem_size.
reg_list = [0.001]
problem_size = 10000
max_iter = 500
tol = 1e-6

# Solvers compared on Synthetic II at this size.
synthetic2_10000_methods = [
    # first-order methods
    'BCD',
    # second-order methods
    'LBFGS-Dual',
    # sparse method
    'SSNS',
    # new methods
    'SPLR',
]

for reg in reg_list:
    # Set up the Synthetic II OT problem and the solvers to run on it.
    synthetic2_ot_problem = Synthetic2OT(n=problem_size, m=problem_size, reg=reg)
    synthetic2_solvers = get_solvers(
        reg=reg,
        max_iter=max_iter,
        tol=tol,
        selected=synthetic2_10000_methods,
    )
    synthetic2_task = OTtask(problem=synthetic2_ot_problem, solvers=synthetic2_solvers)

    # One plot keyed by iteration count, then one keyed by run time.
    for x_key, x_label in (
        ('iterations', 'Iteration Number'),
        ('run_times', 'Run time(seconds)'),
    ):
        synthetic2_task.plot_for_problem(
            x_key=x_key,
            x_label=x_label,
            y_label='Log10 Gradient Norm',
            force_rerun=FORCE_RERUN,
            selected_methods=synthetic2_10000_methods,
        )