# Selecting Sinkhorn OT Solvers

In [1]:
from ot.experiments import OTsolver
from typing import List
import regot

# palette: https://colorhunt.co/palettes/random
# color picker: https://www.w3schools.com/colors/colors_picker.asp
#               https://htmlcolorcodes.com/color-picker 

def get_all_solvers(reg, max_iter, tol) -> List:
    """Construct one ``OTsolver`` per Sinkhorn routine benchmarked here.

    Each solver wraps a ``regot`` Sinkhorn function together with its display
    name and plot styling (``color`` hex string and ``linestyle``, which looks
    like a matplotlib dash specification — confirm in ``OTsolver``). Extra
    method-specific options (``shift``, ``density``) are passed through to
    ``OTsolver`` unchanged; their meaning is defined by ``regot``.

    Args:
        reg: Regularization parameter given to every solver.
        max_iter: Iteration cap given to every solver.
        tol: Convergence tolerance given to every solver.

    Returns:
        Seven solvers in a fixed order: BCD, APDAGD, LBFGS-Dual, Newton,
        SSNS, Sparse Newton, SPLR.
    """
    # Settings common to every solver. Expanded last in each call so the
    # keyword order matches the per-solver settings listed before it.
    shared = dict(reg=reg, max_iter=max_iter, tol=tol)
    return [
        # First-order methods.
        OTsolver(method=regot.sinkhorn_bcd, method_name='BCD',
                 color='#FDA403', linestyle=(0, (3, 1, 2, 1)), **shared),
        OTsolver(method=regot.sinkhorn_apdagd, method_name='APDAGD',
                 color='#E8751A', linestyle=(0, (2, 2)), **shared),
        # Second-order methods.
        OTsolver(method=regot.sinkhorn_lbfgs_dual, method_name='LBFGS-Dual',
                 color='#898121', linestyle=(0, (3, 1)), **shared),
        OTsolver(method=regot.sinkhorn_newton, method_name='Newton',
                 color='#E5C287', linestyle=(0, (4, 2, 1, 2)), **shared),
        # Sparse methods.
        OTsolver(method=regot.sinkhorn_ssns, method_name='SSNS', shift=1e-6,
                 color='#FCE7C8', linestyle=(0, (2, 1)), **shared),
        OTsolver(method=regot.sinkhorn_sparse_newton, method_name='Sparse Newton', shift=1e-6,
                 color='#B1C29E', linestyle=(0, (1, 1)), **shared),
        # New method.
        OTsolver(method=regot.sinkhorn_splr, method_name='SPLR', density=0.01, shift=1e-6,
                 color='#DE3163', linestyle='solid', **shared),
    ]

def get_selected_solvers(reg, max_iter, tol, selected_methods) -> List:
    """Return the solvers from ``get_all_solvers`` whose name is selected.

    Args:
        reg: Regularization parameter forwarded to every solver.
        max_iter: Iteration cap forwarded to every solver.
        tol: Convergence tolerance forwarded to every solver.
        selected_methods: Iterable of ``method_name`` strings (e.g. ``'BCD'``,
            ``'SPLR'``). A single string is treated as one name.

    Returns:
        The matching solvers, in the fixed order produced by
        ``get_all_solvers`` (not the order of ``selected_methods``).

    Warns:
        UserWarning: if a requested name matches no known solver, so that a
        typo does not silently drop a method from the benchmark.
    """
    import warnings

    # Normalize once into a set. A one-shot iterator would otherwise be
    # exhausted by the first ``in`` test. A bare string would otherwise be
    # matched by substring, e.g. 'Sparse Newton' also selecting 'Newton'.
    if isinstance(selected_methods, str):
        selected = {selected_methods}
    else:
        selected = set(selected_methods)

    all_solvers = get_all_solvers(reg, max_iter, tol)
    unknown = selected - {solver.method_name for solver in all_solvers}
    if unknown:
        warnings.warn(
            f"Unknown solver name(s) ignored: {sorted(unknown)}",
            stacklevel=2,
        )
    return [solver for solver in all_solvers if solver.method_name in selected]

# Solver names benchmarked on each real-image dataset. The three lists are
# currently identical but are distinct list objects, so each dataset's
# selection can be edited without affecting the others.
mnist_methods = [
    'BCD', 'APDAGD',           # first-order methods
    'LBFGS-Dual', 'Newton',    # second-order methods
    'SSNS', 'Sparse Newton',   # sparse methods
    'SPLR',                    # new method
]
fashion_mnist_methods = list(mnist_methods)
imagenette_methods = list(mnist_methods)

# Real Image Data Experiments

In [None]:
from ot.datasets import MnistOT, FashionMnistOT, ImagenetteOT
from ot.experiments import OTtask

# Run configuration shared by every (reg, norm, dataset) combination below.
# Defined once, outside the loops, so it can be changed in a single place.
# force_rerun is forwarded to OTtask.plot_for_problem; presumably True forces
# the solvers to be re-run instead of reusing stored results (confirm in
# ot.experiments).
force_rerun = False
max_iter, tol = 500, 1e-6

reg_list = [0.001, 0.01]
norm_list = ['l1', 'l2']

# Each real-image problem class paired with the solver names run on it.
image_experiments = [
    (MnistOT, mnist_methods),
    (FashionMnistOT, fashion_mnist_methods),
    (ImagenetteOT, imagenette_methods),
]
# Every task is plotted twice against a 'Gradient Norm' y-axis: once versus
# iteration count and once versus run time.
plot_axes = [('iterations', 'Iterations'), ('run_times', 'Time (s)')]

for reg in reg_list:
    for norm in norm_list:
        for problem_cls, methods in image_experiments:
            # The norm name is passed as the problem's `distance` keyword.
            problem = problem_cls(reg=reg, distance=norm)
            solvers = get_selected_solvers(
                reg=reg, max_iter=max_iter, tol=tol,
                selected_methods=methods,
            )
            task = OTtask(problem=problem, solvers=solvers)
            for x_key, x_label in plot_axes:
                task.plot_for_problem(x_key=x_key, x_label=x_label, y_label='Gradient Norm',
                                      force_rerun=force_rerun, selected_methods=methods)

# Synthetic Data Experiments

In [None]:
from ot.datasets import Synthetic1OT, Synthetic2OT
from ot.experiments import OTtask

# Run configuration shared by every (reg, size, problem) combination below.
# Defined once, outside the loops, so it can be changed in a single place.
# force_rerun is forwarded to OTtask.plot_for_problem; presumably True forces
# the solvers to be re-run instead of reusing stored results (confirm in
# ot.experiments).
force_rerun = False
max_iter, tol = 500, 1e-6

reg_list = [0.001, 0.01]
problem_size_list = [1000, 5000, 10000]

# Synthetic1OT: the solver set depends on the problem size. APDAGD and Newton
# are only run at size 1000 (presumably too costly at larger sizes — confirm).
synthetic1_methods = {
    1000: ['BCD', 'APDAGD', 'LBFGS-Dual', 'Newton', 'SSNS', 'Sparse Newton', 'SPLR'],
    5000: ['BCD', 'LBFGS-Dual', 'SSNS', 'Sparse Newton', 'SPLR'],
    10000: ['BCD', 'LBFGS-Dual', 'SSNS', 'Sparse Newton', 'SPLR'],
}
# Synthetic2OT: the same solver set at every problem size. The grouping
# comments follow the classification used in get_all_solvers.
synthetic2_methods = [
    'BCD',          # first-order method
    'LBFGS-Dual',   # second-order (quasi-Newton) method
    'SSNS',         # sparse method
    'SPLR',         # new method
]

# Every task is plotted twice against a 'Gradient Norm' y-axis: once versus
# iteration count and once versus run time.
plot_axes = [('iterations', 'Iterations'), ('run_times', 'Time (s)')]

for reg in reg_list:
    for problem_size in problem_size_list:
        # Square problems (n == m); Synthetic1OT runs first, then Synthetic2OT.
        synthetic_experiments = [
            (Synthetic1OT, synthetic1_methods[problem_size]),
            (Synthetic2OT, synthetic2_methods),
        ]
        for problem_cls, methods in synthetic_experiments:
            problem = problem_cls(n=problem_size, m=problem_size, reg=reg)
            solvers = get_selected_solvers(
                reg=reg, max_iter=max_iter, tol=tol,
                selected_methods=methods,
            )
            task = OTtask(problem=problem, solvers=solvers)
            for x_key, x_label in plot_axes:
                task.plot_for_problem(x_key=x_key, x_label=x_label, y_label='Gradient Norm',
                                      force_rerun=force_rerun, selected_methods=methods)
