diff --git a/README.md b/README.md index fbad3086e..bc3474212 100644 --- a/README.md +++ b/README.md @@ -370,4 +370,6 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil [68] Chowdhury, S., Miller, D., & Needham, T. (2021). [Quantized gromov-wasserstein](https://link.springer.com/chapter/10.1007/978-3-030-86523-8_49). ECML PKDD 2021. Springer International Publishing. -[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970. \ No newline at end of file +[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970. + +[70] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. diff --git a/examples/unbalanced-partial/plot_conv_sinkhorn_ti.py b/examples/unbalanced-partial/plot_conv_sinkhorn_ti.py new file mode 100644 index 000000000..887b39bcb --- /dev/null +++ b/examples/unbalanced-partial/plot_conv_sinkhorn_ti.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +""" +=============================================================== +Translation Invariant Sinkhorn for Unbalanced Optimal Transport +=============================================================== + +This examples illustrates the better convergence of the translation +invariance Sinkhorn algorithm proposed in [70] compared to the classical +Sinkhorn algorithm. + +[70] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). +Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe. +In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. + +""" + +# Author: Clément Bonet +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot + +############################################################################## +# Setting parameters +# ------------- + +# %% parameters + +n_iter = 50 # nb iters +n = 40 # nb samples + +num_iter_max = 100 +n_noise = 10 + +reg = 0.005 +reg_m_kl = 0.05 + +mu_s = np.array([-1, -1]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + + +############################################################################## +# Compute entropic kl-regularized UOT with Sinkhorn and Translation Invariant Sinkhorn +# ----------- + +err_sinkhorn_uot = np.empty((n_iter, num_iter_max)) +err_sinkhorn_uot_ti = np.empty((n_iter, num_iter_max)) + + +for seed in range(n_iter): + np.random.seed(seed) + xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) + xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + + xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0) + xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0) + + n = n + n_noise + + a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + + # loss matrix + M = ot.dist(xs, xt) + M /= M.max() + + entropic_kl_uot, log_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, reg_type="kl", log=True, numItermax=num_iter_max, stopThr=0) + entropic_kl_uot_ti, log_uot_ti = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, reg_type="kl", + method="sinkhorn_translation_invariant", log=True, + numItermax=num_iter_max, stopThr=0) + + err_sinkhorn_uot[seed] = log_uot["err"] + err_sinkhorn_uot_ti[seed] = log_uot_ti["err"] + +############################################################################## +# Plot the results +# ---------------- + +mean_sinkh = np.mean(err_sinkhorn_uot, axis=0) +std_sinkh = np.std(err_sinkhorn_uot, axis=0) + +mean_sinkh_ti = np.mean(err_sinkhorn_uot_ti, axis=0) +std_sinkh_ti = np.std(err_sinkhorn_uot_ti, axis=0) + +absc = list(range(num_iter_max)) + +pl.plot(absc, mean_sinkh, label="Sinkhorn") +pl.fill_between(absc, mean_sinkh - 2 * std_sinkh, mean_sinkh + 2 * std_sinkh, alpha=0.5) + +pl.plot(absc, mean_sinkh_ti, label="Translation Invariant Sinkhorn") +pl.fill_between(absc, mean_sinkh_ti - 2 * std_sinkh_ti, mean_sinkh_ti + 2 * std_sinkh_ti, alpha=0.5) + +pl.yscale("log") +pl.legend() +pl.xlabel("Number of Iterations") +pl.ylabel(r"$\|u-v\|_\infty$") +pl.grid(True) +pl.show() diff --git a/examples/unbalanced-partial/plot_unbalanced_OT.py b/examples/unbalanced-partial/plot_unbalanced_OT.py index 03487e7e2..6f545c1f7 100644 --- a/examples/unbalanced-partial/plot_unbalanced_OT.py +++ b/examples/unbalanced-partial/plot_unbalanced_OT.py @@ -73,7 +73,7 @@ reg_m_l2 = 5 mass = 0.7 -entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl) +entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, reg_type="kl") kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div='kl') l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div='l2') partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass) diff --git a/ot/unbalanced/__init__.py b/ot/unbalanced/__init__.py index 3a20af30d..7d6294424 100644 --- a/ot/unbalanced/__init__.py +++ b/ot/unbalanced/__init__.py @@ -12,6 +12,7 @@ from ._sinkhorn import (sinkhorn_knopp_unbalanced, sinkhorn_unbalanced, sinkhorn_stabilized_unbalanced, + sinkhorn_unbalanced_translation_invariant, sinkhorn_unbalanced2, barycenter_unbalanced_sinkhorn, barycenter_unbalanced_stabilized, @@ -22,6 +23,7 @@ from ._lbfgs import (lbfgsb_unbalanced, lbfgsb_unbalanced2) __all__ = ['sinkhorn_knopp_unbalanced', 'sinkhorn_unbalanced', 'sinkhorn_stabilized_unbalanced', - 'sinkhorn_unbalanced2', 'barycenter_unbalanced_sinkhorn', 'barycenter_unbalanced_stabilized', + 'sinkhorn_unbalanced_translation_invariant', 'sinkhorn_unbalanced2', + 'barycenter_unbalanced_sinkhorn', 'barycenter_unbalanced_stabilized', 'barycenter_unbalanced', 'mm_unbalanced', 'mm_unbalanced2', '_get_loss_unbalanced', 'lbfgsb_unbalanced', 'lbfgsb_unbalanced2'] diff --git a/ot/unbalanced/_sinkhorn.py b/ot/unbalanced/_sinkhorn.py index 37e85253b..5000f5b61 100644 --- a/ot/unbalanced/_sinkhorn.py +++ b/ot/unbalanced/_sinkhorn.py @@ -140,6 +140,10 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [70] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). + Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe. + In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. + See Also -------- @@ -148,7 +152,8 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', Unbalanced Stabilized sinkhorn :ref:`[9, 10] ` ot.unbalanced.sinkhorn_reg_scaling_unbalanced: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] ` - + ot.unbalanced.sinkhorn_unbalanced_translation_invariant + Translation invariant Unbalanced Sinkhorn :ref:`[70] ` """ if method.lower() == 'sinkhorn': @@ -163,12 +168,21 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', stopThr=stopThr, verbose=verbose, log=log, **kwargs) + + elif method.lower() == 'sinkhorn_translation_invariant': + return sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type, c, + warmstart, numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + else: raise ValueError("Unknown method '%s'." % method) @@ -293,12 +307,16 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [70] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). + Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe. + In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. + See Also -------- ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn :ref:`[10] ` ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn :ref:`[9, 10] ` ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] ` - + ot.unbalanced.sinkhorn_unbalanced_translation_invariant: :ref:`[70] ` """ M, a, b = list_to_array(M, a, b) @@ -313,7 +331,14 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', res = sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=True, **kwargs) + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_translation_invariant': + res = sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type, c, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, @@ -347,6 +372,13 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + + elif method.lower() == 'sinkhorn_translation_invariant': + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, c, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, @@ -865,6 +897,241 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, return plan +def sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type="kl", c=None, + warmstart=None, numItermax=1000, stopThr=1e-6, + verbose=False, log=False, **kwargs): + """ + Solve the entropic regularization unbalanced optimal transport problem and + return the OT plan + + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot \Omega(\gamma) + + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) + + s.t. + \gamma \geq 0 + + where : + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term,KL divergence + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - KL is the Kullback-Leibler divergence + + The algorithm used for solving the problem is the translation invariant Sinkhorn algorithm as proposed in :ref:`[70] ` + + + Parameters + ---------- + a : array-like (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : array-like (dim_b,) or array-like (dim_b, n_hists) + One or multiple unnormalized histograms of dimension `dim_b` + If many, compute all the OT distances (a, b_i) + M : array-like (dim_a, dim_b) + loss matrix + reg : float + Entropy regularization term > 0 + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term. + If reg_m is a scalar or an indexable object of length 1, + then the same reg_m is applied to both marginal relaxations. + The entropic balanced OT can be recovered using `reg_m=float("inf")`. + For semi-relaxed case, use either + `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. + If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + c : array-like (dim_a, dim_b), optional (default=None) + Reference measure for the regularization. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + if n_hists == 1: + - gamma : (dim_a, dim_b) array-like + Optimal transportation matrix for the given parameters + - log : dict + log dictionary returned only if `log` is `True` + else: + - ot_distance : (n_hists,) array-like + the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` + - log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.],[1., 0.]] + >>> ot.unbalanced.sinkhorn_unbalanced_translation_invariant(a, b, M, 1., 1.) + array([[0.32205357, 0.11847689], + [0.11847689, 0.32205357]]) + + .. _references-sinkhorn-unbalanced-translation-invariant: + References + ---------- + .. [70] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). + Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe. + In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. + """ + + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) + + dim_a, dim_b = M.shape + + if len(a) == 0: + a = nx.ones(dim_a, type_as=M) / dim_a + if len(b) == 0: + b = nx.ones(dim_b, type_as=M) / dim_b + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + reg_m1, reg_m2 = get_parameter_pair(reg_m) + + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if warmstart is None: + if n_hists: + u = nx.ones((dim_a, 1), type_as=M) + v = nx.ones((dim_b, n_hists), type_as=M) + a = a.reshape(dim_a, 1) + else: + u = nx.ones(dim_a, type_as=M) + v = nx.ones(dim_b, type_as=M) + else: + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) + + u_, v_ = u, v + + if reg_type == "entropy": + warnings.warn('If reg_type = entropy, then the matrix c is overwritten by the one matrix.') + c = nx.ones((dim_a, dim_b), type_as=M) + + if n_hists: + M0 = M + else: + c = a[:, None] * b[None, :] if c is None else c + M0 = M - reg * nx.log(c) + K = nx.exp(-M0 / reg) + + fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 + fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 + + k1 = reg * reg_m1 / ((reg + reg_m1) * (reg_m1 + reg_m2)) if reg_m1 != float("inf") else 0 + k2 = reg * reg_m2 / ((reg + reg_m2) * (reg_m1 + reg_m2)) if reg_m2 != float("inf") else 0 + + k_rho1 = k1 * reg_m1 / reg if reg_m1 != float("inf") else 0 + k_rho2 = k2 * reg_m2 / reg if reg_m2 != float("inf") else 0 + + if reg_m1 == float("inf") and reg_m2 == float("inf"): + xi1, xi2 = 0, 0 + fi_12 = 1 + elif reg_m1 == float("inf"): + xi1 = 0 + xi2 = reg / reg_m2 + fi_12 = reg_m2 + elif reg_m2 == float("inf"): + xi1 = reg / reg_m1 + xi2 = 0 + fi_12 = reg_m1 + else: + xi1 = (reg_m2 * reg) / (reg_m1 * (reg + reg_m1 + reg_m2)) + xi2 = (reg_m1 * reg) / (reg_m2 * (reg + reg_m1 + reg_m2)) + fi_12 = reg_m1 * reg_m2 / (reg_m1 + reg_m2) + + xi_rho1 = xi1 * reg_m1 / reg if reg_m1 != float("inf") else 0 + xi_rho2 = xi2 * reg_m2 / reg if reg_m2 != float("inf") else 0 + + reg_ratio1 = reg / reg_m1 if reg_m1 != float("inf") else 0 + reg_ratio2 = reg / reg_m2 if reg_m2 != float("inf") else 0 + + err = 1. + + for i in range(numItermax): + uprev = u + vprev = v + + Kv = nx.dot(K, v_) + u_hat = (a / Kv) ** fi_1 * nx.sum(b * v_**reg_ratio2)**k_rho2 + u_ = u_hat * nx.sum(a * u_hat**(-reg_ratio1))**(-xi_rho1) + + Ktu = nx.dot(K.T, u_) + v_hat = (b / Ktu) ** fi_2 * nx.sum(a * u_**(-reg_ratio1))**k_rho1 + v_ = v_hat * nx.sum(b * v_hat**(-reg_ratio2))**(-xi_rho2) + + t = (nx.sum(a * u_**(-reg_ratio1)) / nx.sum(b * v_**(-reg_ratio2)))**(fi_12 / reg) + + u = u_ * t + v = v_ / t + + err_u = nx.max(nx.abs(u - uprev)) / max( + nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. + ) + err_v = nx.max(nx.abs(v - vprev)) / max( + nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1. + ) + err = 0.5 * (err_u + err_v) + if log: + log['err'].append(err) + if verbose: + if i % 50 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(i, err)) + + if err < stopThr: + break + + if log: + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) + + if n_hists: # return only loss + res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) + if log: + return res, log + else: + return res + + else: # return OT matrix + + if log: + return u[:, None] * K * v[None, :], log + else: + return u[:, None] * K * v[None, :] + + def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, numItermax=1000, stopThr=1e-6, verbose=False, log=False): diff --git a/test/unbalanced/test_sinkhorn.py b/test/unbalanced/test_sinkhorn.py index 595f9ba97..7007e336b 100644 --- a/test/unbalanced/test_sinkhorn.py +++ b/test/unbalanced/test_sinkhorn.py @@ -62,49 +62,20 @@ def test_unbalanced_convergence(nx, method, reg_type): # check if sinkhorn_unbalanced2 returns the correct loss np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), loss, atol=1e-5) + # check in case no histogram is provided + M_np = nx.to_numpy(M) + a_np, b_np = np.array([]), np.array([]) + a, b = nx.from_numpy(a_np, b_np) -@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) -def test_unbalanced_marginals(nx, method, reg_type): - # test generalized sinkhorn for unbalanced OT - n = 100 - rng = np.random.RandomState(42) - - x = rng.randn(n, 2) - a = ot.utils.unif(n) - b = ot.utils.unif(n) - M = ot.dist(x, x) - a, b, M = nx.from_numpy(a, b, M) - - epsilon = 1. - reg_m = 1. - - G0, log0 = ot.unbalanced.sinkhorn_unbalanced( - a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, log=True - ) - loss0 = ot.unbalanced.sinkhorn_unbalanced2( - a, b, M, reg=epsilon, reg_m=reg_m, method=method, reg_type=reg_type, - ) - - # check in case no histogram is provided or histogram is None - a_empty, b_empty = np.array([]), np.array([]) - a_empty, b_empty = nx.from_numpy(a_empty, b_empty) - - G_empty, log_empty = ot.unbalanced.sinkhorn_unbalanced( - a_empty, b_empty, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, log=True + G = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, + method=method, reg_type=reg_type, verbose=True ) - loss_empty = ot.unbalanced.sinkhorn_unbalanced2( - a_empty, b_empty, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type + G_np = ot.unbalanced.sinkhorn_unbalanced( + a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, + method=method, reg_type=reg_type, verbose=True ) - - np.testing.assert_allclose( - nx.to_numpy(log_empty["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05) - np.testing.assert_allclose( - nx.to_numpy(log_empty["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05) - np.testing.assert_allclose(nx.to_numpy(G_empty), nx.to_numpy(G0), atol=1e-05) - np.testing.assert_allclose(nx.to_numpy(loss_empty), nx.to_numpy(loss0), atol=1e-5) + np.testing.assert_allclose(G_np, nx.to_numpy(G)) @pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) @@ -169,53 +140,8 @@ def test_unbalanced_warmstart(nx, method, reg_type): np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5) -@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) -def test_unbalanced_reference_measure(nx, method, reg_type): - # test generalized sinkhorn for unbalanced OT - n = 100 - rng = np.random.RandomState(42) - - x = rng.randn(n, 2) - a = ot.utils.unif(n) - b = ot.utils.unif(n) - M = ot.dist(x, x) - a, b, M = nx.from_numpy(a, b, M) - - epsilon = 1. - reg_m = 1. - - G0, log0 = ot.unbalanced.sinkhorn_unbalanced( - a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, c=None, log=True - ) - loss0 = ot.unbalanced.sinkhorn_unbalanced2( - a, b, M, reg=epsilon, reg_m=reg_m, method=method, reg_type=reg_type, c=None - ) - - if reg_type == "kl": - c = a[:, None] * b[None, :] - elif reg_type == "entropy": - c = nx.ones(M.shape, type_as=M) - - G, log = ot.unbalanced.sinkhorn_unbalanced( - a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, c=c, log=True - ) - loss = ot.unbalanced.sinkhorn_unbalanced2( - a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, c=c - ) - - np.testing.assert_allclose( - nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05) - np.testing.assert_allclose( - nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05) - np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05) - np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5) - - -@pytest.mark.parametrize("method, log", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], [True, False])) -def test_sinkhorn_unbalanced2(nx, method, log): +@pytest.mark.parametrize("method,reg_type, log", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"], [True, False])) +def test_sinkhorn_unbalanced2(nx, method, reg_type, log): n = 100 rng = np.random.RandomState(42) @@ -232,12 +158,12 @@ def test_sinkhorn_unbalanced2(nx, method, log): loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - c=None, log=False, verbose=True + reg_type=reg_type, log=False, verbose=True )) res = ot.unbalanced.sinkhorn_unbalanced2( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - c=None, log=log, verbose=True + reg_type=reg_type, log=log, verbose=True ) loss0 = res[0] if log else res @@ -322,11 +248,6 @@ def test_unbalanced_multiple_inputs(nx, method): v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) - print("u_final shape = {}".format(u_final.shape)) - print("v_final shape = {}".format(v_final.shape)) - print("logu shape = {}".format(log["logu"].shape)) - print("logv shape = {}".format(log["logv"].shape)) - np.testing.assert_allclose( nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05) np.testing.assert_allclose( @@ -347,20 +268,19 @@ def test_stabilized_vs_sinkhorn(nx): M = ot.utils.dist0(n) M /= np.median(M) - epsilon = 1 + epsilon = 0.1 reg_m = 1. - stopThr = 1e-12 ab, bb, Mb = nx.from_numpy(a, b, M) G, _ = ot.unbalanced.sinkhorn_unbalanced2( - ab, bb, Mb, epsilon, reg_m, method="sinkhorn_stabilized", log=True, stopThr=stopThr, + ab, bb, Mb, epsilon, reg_m, method="sinkhorn_stabilized", log=True ) G2, _ = ot.unbalanced.sinkhorn_unbalanced2( - ab, bb, Mb, epsilon, reg_m, method="sinkhorn", log=True, stopThr=stopThr + ab, bb, Mb, epsilon, reg_m, method="sinkhorn", log=True ) G2_np, _ = ot.unbalanced.sinkhorn_unbalanced2( - a, b, M, epsilon, reg_m, method="sinkhorn", log=True, stopThr=stopThr + a, b, M, epsilon, reg_m, method="sinkhorn", log=True ) G = nx.to_numpy(G) G2 = nx.to_numpy(G2) @@ -369,29 +289,7 @@ def test_stabilized_vs_sinkhorn(nx): np.testing.assert_allclose(G2, G2_np, atol=1e-5) -def test_sinkhorn_wrong_returnCost(nx): - - n = 100 - rng = np.random.RandomState(42) - x = rng.randn(n, 2) - rng = np.random.RandomState(75) - y = rng.randn(n, 2) - a_np = ot.utils.unif(n) - b_np = ot.utils.unif(n) - - M = ot.dist(x, y) - M = M / M.max() - a, b, M = nx.from_numpy(a_np, b_np, M) - epsilon = 1 - reg_m = 1. - - def sinkhorn2(returnCost): - return ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, returnCost=returnCost) - - np.testing.assert_raises(ValueError, sinkhorn2, "invalid_returnCost") - - -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"]) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_unbalanced_barycenter(nx, method): # test generalized sinkhorn for unbalanced OT barycenter n = 100 @@ -530,3 +428,247 @@ def test_implemented_methods(nx): method=method) barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, method=method) + + +@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) +def test_lbfgsb_unbalanced(nx, reg_div, regm_div): + + np.random.seed(42) + + xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + + M = ot.dist(xs, xt) + + a = ot.unif(5) + b = ot.unif(6) + + G, log = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False) + + ab, bb, Mb = nx.from_numpy(a, b, M) + + Gb, log = ot.unbalanced.lbfgsb_unbalanced(ab, bb, Mb, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False) + + np.testing.assert_allclose(G, nx.to_numpy(Gb)) + + +@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) +def test_lbfgsb_unbalanced_relaxation_parameters(nx, reg_div, regm_div): + + np.random.seed(42) + + xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + + M = ot.dist(xs, xt) + + a = ot.unif(5) + b = ot.unif(6) + + a, b, M = nx.from_numpy(a, b, M) + + reg_m = 10 + full_list_reg_m = [reg_m, reg_m] + full_tuple_reg_m = (reg_m, reg_m) + tuple_reg_m, list_reg_m = (reg_m), [reg_m] + np1_reg_m = reg_m * np.ones(1) + np2_reg_m = reg_m * np.ones(2) + + list_options = [np1_reg_m, np2_reg_m, full_tuple_reg_m, + tuple_reg_m, full_list_reg_m, list_reg_m] + + G = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, reg_m=reg_m, + reg_div=reg_div, regm_div=regm_div, + log=False, verbose=False) + + for opt in list_options: + G0 = ot.unbalanced.lbfgsb_unbalanced( + a, b, M, 1, reg_m=opt, reg_div=reg_div, + regm_div=regm_div, log=False, verbose=False + ) + + np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) + + +@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) +def test_lbfgsb_reference_measure(nx, reg_div, regm_div): + + np.random.seed(42) + + xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + M = ot.dist(xs, xt) + a = ot.unif(5) + b = ot.unif(6) + + a, b, M = nx.from_numpy(a, b, M) + c = a[:, None] * b[None, :] + + G, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=None, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + + G0, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=c, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + + np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) + + +@pytest.mark.parametrize("div", ["kl", "l2"]) +def test_mm_convergence(nx, div): + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + reg_m = 100 + a, b, M = nx.from_numpy(a_np, b_np, M) + + G, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div=div, + verbose=False, log=True) + loss = nx.to_numpy( + ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div=div, verbose=True) + ) + + # check if the marginals come close to the true ones when large reg + np.testing.assert_allclose(np.sum(nx.to_numpy(G), 1), a_np, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G), 0), b_np, atol=1e-03) + + # check if mm_unbalanced2 returns the correct loss + np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), loss, atol=1e-5) + + # check in case no histogram is provided + a_np, b_np = np.array([]), np.array([]) + a, b = nx.from_numpy(a_np, b_np) + + G_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div=div, verbose=False) + np.testing.assert_allclose(nx.to_numpy(G_null), nx.to_numpy(G)) + + # test when G0 is given + G0 = ot.emd(a, b, M) + G0_np = nx.to_numpy(G0) + reg_m = 10000 + G = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div=div, G0=G0, verbose=False) + np.testing.assert_allclose(G0_np, nx.to_numpy(G), atol=1e-05) + + +@pytest.mark.parametrize("div", ["kl", "l2"]) +def test_mm_relaxation_parameters(nx, div): + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + + reg = 1e-2 + + reg_m = 100 + full_list_reg_m = [reg_m, reg_m] + full_tuple_reg_m = (reg_m, reg_m) + tuple_reg_m, list_reg_m = (reg_m), [reg_m] + nx1_reg_m = reg_m * nx.ones(1) + nx2_reg_m = reg_m * nx.ones(2) + + list_options = [nx1_reg_m, nx2_reg_m, full_tuple_reg_m, + tuple_reg_m, full_list_reg_m, list_reg_m] + + G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, reg=reg, + div=div, verbose=False, log=True) + loss_0 = nx.to_numpy( + ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, reg=reg, + div=div, verbose=True) + ) + + for opt in list_options: + G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=opt, + reg=reg, div=div, + verbose=False, log=True) + loss_1 = nx.to_numpy( + ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=opt, + reg=reg, div=div, verbose=True) + ) + + np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) + np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) + + +@pytest.mark.parametrize("div", ["kl", "l2"]) +def test_mm_reference_measure(nx, div): + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + c = a[:, None] * b[None, :] + + reg = 1e-2 + reg_m = 100 + + G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=None, reg=reg, + div=div, verbose=False, log=True) + loss_0 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=None, reg=reg, + div=div, verbose=True) + loss_0 = nx.to_numpy(loss_0) + + G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=c, + reg=reg, div=div, + verbose=False, log=True) + loss_1 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=c, + reg=reg, div=div, verbose=True) + loss_1 = nx.to_numpy(loss_1) + + np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) + np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) + + +def test_mm_wrong_divergence(nx): + + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + + reg = 1e-2 + reg_m = 100 + + G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, reg=reg, + div="kl", verbose=False, log=True) + loss_0 = nx.to_numpy( + ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, reg=reg, + div="kl", verbose=True) + ) + + G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, reg=reg, + div="wrong_div", verbose=False, log=True) + loss_1 = nx.to_numpy( + ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, reg=reg, + div="wrong_div", verbose=True) + ) + + np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) + np.testing.assert_allclose(loss_0, loss_1, atol=1e-5)